Compare commits

...

473 Commits

Author SHA1 Message Date
comfyanonymous
418eb7062d Support new LTXV VAE. 2024-12-20 04:38:29 -05:00
comfyanonymous
cac68ca813 Fix some more video tiled encode issues.
The downscale_ratio formula for the temporal had issues with some frame
numbers.
2024-12-19 23:14:03 -05:00
comfyanonymous
52c1d933b2 Fix tiled hunyuan video VAE encode issue.
Some shapes like 1024x1024 with tile_size 256 and overlap 64 had issues.
2024-12-19 22:55:15 -05:00
catboxanon
3cacd3fca5 Support preview images embedded in safetensors metadata (#6119)
* Support preview images embedded in safetensors metadata

* Add unit test for safetensors embedded image previews
2024-12-19 14:01:56 -08:00
comfyanonymous
2dda7c11a3 More proper fix for the memory issue. 2024-12-19 16:21:56 -05:00
comfyanonymous
3ad3248ad7 Fix lowvram bug when using a model multiple times in a row.
The memory system would load an extra 64MB each time until either the
model was completely in memory or OOM.
2024-12-19 16:04:56 -05:00
comfyanonymous
c441048a4f Make VAE Encode tiled node work with video VAE. 2024-12-19 05:31:39 -05:00
comfyanonymous
9f4b181ab3 Add fast previews for hunyuan video. 2024-12-18 18:24:23 -05:00
comfyanonymous
cbbf077593 Small optimizations. 2024-12-18 18:23:28 -05:00
Chenlei Hu
0c04a6ae78 Add .github folder to maintainer owner list (#6027) 2024-12-18 15:06:53 -05:00
Chenlei Hu
416ccc9e45 Update web content to release v1.5.19 (#6105) 2024-12-18 15:06:20 -05:00
comfyanonymous
ff2ff02168 Support old diffusion-pipe hunyuan video loras. 2024-12-18 06:23:54 -05:00
comfyanonymous
4c5c4ddeda Fix regression in VAE code on old pytorch versions. 2024-12-18 03:08:28 -05:00
comfyanonymous
79badea452 Add ConditioningStableAudio.
This lets you control the seconds_start and seconds_total parameters for
the Stable Audio model.
2024-12-18 03:01:12 -05:00
comfyanonymous
37e5390f5f Add: --use-sage-attention to enable SageAttention.
You need to have the library installed first.
2024-12-18 01:56:10 -05:00
comfyanonymous
a4f59bc65e Pick attention implementation based on device in llama code. 2024-12-18 01:30:20 -05:00
comfyanonymous
ca457f7ba1 Properly tokenize the template for hunyuan video. 2024-12-17 16:22:02 -05:00
comfyanonymous
cd6f615038 Fix tiled vae not working with some shapes. 2024-12-17 16:22:02 -05:00
Terry Jia
517669aaa3 add preview 3d node (#6070)
* add preview 3d node

* mark 3d nodes as EXPERIMENTAL
2024-12-17 10:42:24 -08:00
comfyanonymous
e4e1bff605 Support diffusion-pipe hunyuan video lora format. 2024-12-17 07:14:21 -05:00
comfyanonymous
d6656b0c0c Support llama hunyuan video text encoder in scaled fp8 format. 2024-12-17 04:19:22 -05:00
comfyanonymous
f4cdedea62 Fix regression with ltxv VAE. 2024-12-17 02:17:31 -05:00
comfyanonymous
39b1fc4ccc Adjust used dtypes for hunyuan video VAE and diffusion model. 2024-12-16 23:31:10 -05:00
comfyanonymous
0b25f47bd9 Add some missing imports. 2024-12-16 19:42:01 -05:00
comfyanonymous
bda1482a27 Basic Hunyuan Video model support. 2024-12-16 19:35:40 -05:00
comfyanonymous
19ee5d9d8b Don't expand mask when not necessary.
Expanding seems to slow down inference.
2024-12-16 18:22:50 -05:00
Raphael Walker
61b50720d0 Add support for attention masking in Flux (#5942)
* fix attention OOM in xformers

* allow passing attention mask in flux attention

* allow an attn_mask in flux

* attn masks can be done using replace patches instead of a separate dict

* fix return types

* fix return order

* enumerate

* patch the right keys

* arg names

* fix a silly bug

* fix xformers masks

* replace match with if, elif, else

* mask with image_ref_size

* remove unused import

* remove unused import 2

* fix pytorch/xformers attention

This corrects a weird inconsistency with skip_reshape.
It also allows masks of various shapes to be passed, which will be
automtically expanded (in a memory-efficient way) to a size that is
compatible with xformers or pytorch sdpa respectively.

* fix mask shapes
2024-12-16 18:21:17 -05:00
Alexander Dyadyun
0f954f34af Update README.md (#6071)
The last ROCM 6.2 build was November 22nd, after that date new builds use ROCM 6.2.4.

The builds from the new URL have been tested and work without problems.
2024-12-16 15:24:54 -05:00
Chenlei Hu
5262901c5c Update web content to release v1.5.18 (#6075) 2024-12-16 11:38:24 -08:00
Terry Jia
cc550d5908 use String directly to set bg color for load 3d canvas (#6057) 2024-12-16 10:51:40 -08:00
comfyanonymous
6d1a3f7d00 Fix case of ExecutionBlocker not handled correctly with INPUT_IS_LIST. 2024-12-15 08:41:35 -05:00
Alexander Piskun
1b3a650f19 (fix): added "model_type" to photomaker node (#6047) 2024-12-15 00:18:02 -05:00
comfyanonymous
e83063bf24 Support conv3d in PatchEmbed. 2024-12-14 05:46:04 -05:00
Dr.Lt.Data
558b7d8b22 fix: prestartup script is not applied due to extra_model_paths.yaml and ensure custom paths are used during startup (#5872)
* fix: The custom nodes installed in the paths specified in `extra_model_paths.yaml` encounter a bug where the prestartup script is not imported.

* Ensure custom paths are used during startup
https://github.com/comfyanonymous/ComfyUI/pull/5794
2024-12-13 18:21:32 -05:00
Alexander Piskun
caf2074773 add_model_folder_path: ensure unique paths by removing duplicates (#5998)
* add_model_folder_path: ensure unique paths by removing duplicates

Signed-off-by: bigcat88 <bigcat88@icloud.com>

* refactored "add_model_folder_path" and added tests

---------

Signed-off-by: bigcat88 <bigcat88@icloud.com>
2024-12-13 18:19:22 -05:00
Terry Jia
bdf393792d add load 3d node support (#5564)
* add load 3d node support

* remove Preview3D from BE
2024-12-13 18:13:52 -05:00
comfyanonymous
4e14032c02 Make pad_to_patch_size function work on multi dim. 2024-12-13 07:22:05 -05:00
Chenlei Hu
59d58b1158 [Security] Fix potential XSS on /view (#6034) 2024-12-13 04:56:43 -05:00
Chenlei Hu
563291ee51 Enforce all pyflake lint rules (#6033)
* Enforce F821 undefined-name

* Enforce all pyflake lint rules
2024-12-12 19:29:37 -05:00
Chenlei Hu
6c0377f43e Enforce F821 undefined-name (#6032) 2024-12-12 19:24:41 -05:00
Chenlei Hu
2cddbf0821 Lint and fix undefined names (1/N) (#6028) 2024-12-12 18:55:26 -05:00
Chenlei Hu
60749f345d Lint and fix undefined names (3/N) (#6030) 2024-12-12 18:49:40 -05:00
Chenlei Hu
d4426dce7c Lint and fix undefined names (2/N) (#6029) 2024-12-12 18:48:21 -05:00
Chenlei Hu
d9d7f3c619 Lint all unused variables (#5989)
* Enable F841

* Autofix

* Remove all unused variable assignment
2024-12-12 17:59:16 -05:00
comfyanonymous
fd5dfb812c Set initial load devices for te and model to mps device on mac. 2024-12-12 06:00:31 -05:00
Chenlei Hu
3dfdddcc91 Update README (Add new keybinding entries) (#6020) 2024-12-11 15:55:38 -08:00
Hayden
5747bc6457 Optimize model library (#5841)
* Move model manager routes

* Add experiment model manager api

* Fix cache causing returns to be empty

* Fix unable to compare sub-dir caches

* Skip non-existent folders

* Add model preview

* Revert 'Move model manager routes'

* move model_filemanager.py to app/

* Update model_manager.py

3.8 compatibility

---------
2024-12-11 18:12:04 -05:00
yoinked
5bea1d2ec9 Add MaHiRo (improved/alternate CFG) (#5975)
* Add MaHiRo (improved CFG)

long explanation of what it is is [here](https://huggingface.co/spaces/yoinked/blue-arxiv) (2024-1208.1) 


note: if the node name has encoding issues (utf 8/whatever), id suggest to replace the face at the end with `(>w<)`

* add it to nodes.py, add description, and make it a post_cfg function

* fix

* revert the sampler_cfg_function thing

* switch cfg to args["denoised"]
2024-12-11 16:51:51 -05:00
Yoland Yan
5def9fbc83 Update CI workflow to remove Windows testing configuration (#6007)
- Commented out Windows OS from the CI matrix in test-ci.yml.
- Removed the test-win-nightly job to streamline testing on macOS and Linux only.
- Adjusted the matrix strategy to focus on Python versions and CUDA compatibility without Windows support.
2024-12-11 16:48:41 -05:00
comfyanonymous
7a7efe8424 Support loading some checkpoint files with nested dicts. 2024-12-11 08:04:54 -05:00
comfyanonymous
44db978531 Fix a few things in text enc code for models with no eos token. 2024-12-10 23:07:26 -05:00
comfyanonymous
1c8d11e48a Support different types of tokenizers.
Support tokenizers without an eos token.

Pass full sentences to tokenizer for more efficient tokenizing.
2024-12-10 15:03:39 -05:00
Chenlei Hu
a220d11e6b Replace pylint with ruff (#5987) 2024-12-09 22:04:23 -05:00
catboxanon
23827ca312 Add cond_scale to sampler_post_cfg_function (#5985) 2024-12-09 20:13:18 -05:00
Chenlei Hu
0fd4e6c778 Lint unused import (#5973)
* Lint unused import

* nit

* Remove unused imports

* revert fix_torch import

* nit
2024-12-09 15:24:39 -05:00
comfyanonymous
e2fafe0686 Make CLIP set last layer node work with t5 models. 2024-12-09 03:57:14 -05:00
comfyanonymous
6579632201 Remove unused imports and variables. 2024-12-08 08:08:12 -05:00
comfyanonymous
ac2f0523ca Set env vars to disable telemetry in libs used by some custom nodes. 2024-12-07 14:51:45 -05:00
Haoming
fbf68c4e52 clamp input (#5928) 2024-12-07 14:00:31 -05:00
Chenlei Hu
93477f8efe Add code owners (#5873)
* Add code owners

* Update owners

* nit

* Inline owners

* Remove team links

* Add Kosinkadink
2024-12-06 22:00:54 -05:00
comfyanonymous
8af9a91e0c A few improvements to #5937. 2024-12-06 05:49:15 -05:00
Michael Kupchick
005d2d3a13 ltxv: add noise to guidance image to ensure generated motion. (#5937) 2024-12-06 05:46:08 -05:00
comfyanonymous
1e21f4c14e Make timestep ranges more usable on rectified flow models.
This breaks some old workflows but should make the nodes actually useful.
2024-12-05 16:40:58 -05:00
comfyanonymous
9a616b81c1 Add rescaling_scale from STG to SkipLayerGuidanceDiT. 2024-12-04 19:25:50 -05:00
comfyanonymous
3bed56bb13 Add another ROCm tip. 2024-12-04 15:14:12 -05:00
filtered
4e402b11c6 Reland union type (#5900)
* Reapply "Add union link connection type support (#5806)" (#5889)

This reverts commit bf9a90a145.

* Fix union type breaks existing type workarounds

* Add non-string test

* Add tests for hacks and non-string types

* Support python versions lower than 3.11
2024-12-04 15:12:10 -05:00
Chenlei Hu
48272448ad [Developer Experience] Add node typing (#5676)
* [Developer Experience] Add node typing

* Shim StrEnum

* nit

* nit

* nit
2024-12-04 15:01:00 -05:00
Jedrzej Kosinski
f7695b5f9e Add Create Hook Keyframes Interp. node to simplify creating groups of keyframes without external nodes (#5896) 2024-12-03 21:03:09 -05:00
comfyanonymous
452179fe4f Make ModelPatcher class clone function work with inheritance. 2024-12-03 13:57:57 -05:00
Chenlei Hu
bf9a90a145 Revert "Add union link connection type support (#5806)" (#5889)
This reverts commit 8d4e06324f.
2024-12-03 13:06:34 -05:00
comfyanonymous
c1b92b719d Some optimizations to euler a. 2024-12-03 06:11:52 -05:00
Alexander Piskun
cdc3b97dd5 resolve relative paths in YAML configuration for extra model paths (#5847)
Signed-off-by: bigcat88 <bigcat88@icloud.com>
2024-12-03 06:02:01 -05:00
Chenlei Hu
8d4e06324f Add union link connection type support (#5806)
* Add union type support

* Move code

* nit
2024-12-03 05:46:00 -05:00
comfyanonymous
57e8bf6a9f Fix case where a memory leak could cause crash.
Now the only symptom of code messing up and keeping references to a model
object when it should not will be endless prints in the log instead of the
next workflow crashing ComfyUI.
2024-12-02 19:49:49 -05:00
Jedrzej Kosinski
0ee322ec5f ModelPatcher Overhaul and Hook Support (#5583)
* Added hook_patches to ModelPatcher for weights (model)

* Initial changes to calc_cond_batch to eventually support hook_patches

* Added current_patcher property to BaseModel

* Consolidated add_hook_patches_as_diffs into add_hook_patches func, fixed fp8 support for model-as-lora feature

* Added call to initialize_timesteps on hooks in process_conds func, and added call prepare current keyframe on hooks in calc_cond_batch

* Added default_conds support in calc_cond_batch func

* Added initial set of hook-related nodes, added code to register hooks for loras/model-as-loras, small renaming/refactoring

* Made CLIP work with hook patches

* Added initial hook scheduling nodes, small renaming/refactoring

* Fixed MaxSpeed and default conds implementations

* Added support for adding weight hooks that aren't registered on the ModelPatcher at sampling time

* Made Set Clip Hooks node work with hooks from Create Hook nodes, began work on better Create Hook Model As LoRA node

* Initial work on adding 'model_as_lora' lora type to calculate_weight

* Continued work on simpler Create Hook Model As LoRA node, started to implement ModelPatcher callbacks, attachments, and additional_models

* Fix incorrect ref to create_hook_patches_clone after moving function

* Added injections support to ModelPatcher + necessary bookkeeping, added additional_models support in ModelPatcher, conds, and hooks

* Added wrappers to ModelPatcher to facilitate standardized function wrapping

* Started scaffolding for other hook types, refactored get_hooks_from_cond to organize hooks by type

* Fix skip_until_exit logic bug breaking injection after first run of model

* Updated clone_has_same_weights function to account for new ModelPatcher properties, improved AutoPatcherEjector usage in partially_load

* Added WrapperExecutor for non-classbound functions, added calc_cond_batch wrappers

* Refactored callbacks+wrappers to allow storing lists by id

* Added forward_timestep_embed_patch type, added helper functions on ModelPatcher for emb_patch and forward_timestep_embed_patch, added helper functions for removing callbacks/wrappers/additional_models by key, added custom_should_register prop to hooks

* Added get_attachment func on ModelPatcher

* Implement basic MemoryCounter system for determing with cached weights due to hooks should be offloaded in hooks_backup

* Modified ControlNet/T2IAdapter get_control function to receive transformer_options as additional parameter, made the model_options stored in extra_args in inner_sample be a clone of the original model_options instead of same ref

* Added create_model_options_clone func, modified type annotations to use __future__ so that I can use the better type annotations

* Refactored WrapperExecutor code to remove need for WrapperClassExecutor (now gone), added sampler.sample wrapper (pending review, will likely keep but will see what hacks this could currently let me get rid of in ACN/ADE)

* Added Combine versions of Cond/Cond Pair Set Props nodes, renamed Pair Cond to Cond Pair, fixed default conds never applying hooks (due to hooks key typo)

* Renamed Create Hook Model As LoRA nodes to make the test node the main one (more changes pending)

* Added uuid to conds in CFGGuider and uuids to transformer_options to allow uniquely identifying conds in batches during sampling

* Fixed models not being unloaded properly due to current_patcher reference; the current ComfyUI model cleanup code requires that nothing else has a reference to the ModelPatcher instances

* Fixed default conds not respecting hook keyframes, made keyframes not reset cache when strength is unchanged, fixed Cond Set Default Combine throwing error, fixed model-as-lora throwing error during calculate_weight after a recent ComfyUI update, small refactoring/scaffolding changes for hooks

* Changed CreateHookModelAsLoraTest to be the new CreateHookModelAsLora, rename old ones as 'direct' and will be removed prior to merge

* Added initial support within CLIP Text Encode (Prompt) node for scheduling weight hook CLIP strength via clip_start_percent/clip_end_percent on conds, added schedule_clip toggle to Set CLIP Hooks node, small cleanup/fixes

* Fix range check in get_hooks_for_clip_schedule so that proper keyframes get assigned to corresponding ranges

* Optimized CLIP hook scheduling to treat same strength as same keyframe

* Less fragile memory management.

* Make encode_from_tokens_scheduled call cleaner, rollback change in model_patcher.py for hook_patches_backup dict

* Fix issue.

* Remove useless function.

* Prevent and detect some types of memory leaks.

* Run garbage collector when switching workflow if needed.

* Moved WrappersMP/CallbacksMP/WrapperExecutor to patcher_extension.py

* Refactored code to store wrappers and callbacks in transformer_options, added apply_model and diffusion_model.forward wrappers

* Fix issue.

* Refactored hooks in calc_cond_batch to be part of get_area_and_mult tuple, added extra_hooks to ControlBase to allow custom controlnets w/ hooks, small cleanup and renaming

* Fixed inconsistency of results when schedule_clip is set to False, small renaming/typo fixing, added initial support for ControlNet extra_hooks to work in tandem with normal cond hooks, initial work on calc_cond_batch merging all subdicts in returned transformer_options

* Modified callbacks and wrappers so that unregistered types can be used, allowing custom_nodes to have their own unique callbacks/wrappers if desired

* Updated different hook types to reflect actual progress of implementation, initial scaffolding for working WrapperHook functionality

* Fixed existing weight hook_patches (pre-registered) not working properly for CLIP

* Removed Register/Direct hook nodes since they were present only for testing, removed diff-related weight hook calculation as improved_memory removes unload_model_clones and using sample time registered hooks is less hacky

* Added clip scheduling support to all other native ComfyUI text encoding nodes (sdxl, flux, hunyuan, sd3)

* Made WrapperHook functional, added another wrapper/callback getter, added ON_DETACH callback to ModelPatcher

* Made opt_hooks append by default instead of replace, renamed comfy.hooks set functions to be more accurate

* Added apply_to_conds to Set CLIP Hooks, modified relevant code to allow text encoding to automatically apply hooks to output conds when apply_to_conds is set to True

* Fix cached_hook_patches not respecting target_device/memory_counter results

* Fixed issue with setting weights from hooks instead of copying them, added additional memory_counter check when caching hook patches

* Remove unnecessary torch.no_grad calls for hook patches

* Increased MemoryCounter minimum memory to leave free by *2 until a better way to get inference memory estimate of currently loaded models exists

* For encode_from_tokens_scheduled, allow start_percent and end_percent in add_dict to limit which scheduled conds get encoded for optimization purposes

* Removed a .to call on results of calculate_weight in patch_hook_weight_to_device that was screwing up the intermediate results for fp8 prior to being passed into stochastic_rounding call

* Made encode_from_tokens_scheduled work when no hooks are set on patcher

* Small cleanup of comments

* Turn off hook patch caching when only 1 hook present in sampling, replace some current_hook = None with calls to self.patch_hooks(None) instead to avoid a potential edge case

* On Cond/Cond Pair nodes, removed opt_ prefix from optional inputs

* Allow both FLOATS and FLOAT for floats_strength input

* Revert change, does not work

* Made patch_hook_weight_to_device respect set_func and convert_func

* Make discard_model_sampling True by default

* Add changes manually from 'master' so merge conflict resolution goes more smoothly

* Cleaned up text encode nodes with just a single clip.encode_from_tokens_scheduled call

* Make sure encode_from_tokens_scheduled will respect use_clip_schedule on clip

* Made nodes in nodes_hooks be marked as experimental (beta)

* Add get_nested_additional_models for cases where additional_models could have their own additional_models, and add robustness for circular additional_models references

* Made finalize_default_conds area math consistent with other sampling code

* Changed 'opt_hooks' input of Cond/Cond Pair Set Default Combine nodes to 'hooks'

* Remove a couple old TODO's and a no longer necessary workaround
2024-12-02 14:51:02 -05:00
comfyanonymous
79d5ceae6e Improved memory management. (#5450)
* Less fragile memory management.

* Fix issue.

* Remove useless function.

* Prevent and detect some types of memory leaks.

* Run garbage collector when switching workflow if needed.

* Fix issue.
2024-12-02 14:39:34 -05:00
comfyanonymous
2d5b3e0078 Remove useless code. 2024-12-02 06:49:55 -05:00
comfyanonymous
8e4118c0de make dpm_2_ancestral work with rectified flow. 2024-12-01 07:37:41 -05:00
comfyanonymous
3fc6ebcdd7 Add basic style model "multiply" strength. 2024-11-30 07:27:11 -05:00
comfyanonymous
20a560eb97 How to enable experimental memory efficient attention on ROCm RDNA3. 2024-11-29 06:19:49 -05:00
Dr.Lt.Data
82c5308561 Backward compatibility patch for changes in the method signature of InpaintModelConditioning. (#5825)
https://github.com/comfyanonymous/ComfyUI/issues/5813
2024-11-28 20:30:28 -05:00
comfyanonymous
26fb2c68e8 Add a way to disable cropping in the CLIPVisionEncode node. 2024-11-28 20:24:47 -05:00
comfyanonymous
bf2650a80e Fast previews for ltxv. 2024-11-28 06:46:15 -05:00
Chenlei Hu
53646e0f32 Update web content to release v1.4.13 (#5807) 2024-11-28 04:59:06 -05:00
Chenlei Hu
20879c78f9 Remove internal model download endpoint (#5432) 2024-11-28 04:57:06 -05:00
comfyanonymous
b666539595 Remove print. 2024-11-27 20:28:39 -05:00
comfyanonymous
95d8713482 Missing parentheses. 2024-11-27 13:45:32 -05:00
comfyanonymous
0d4e29f13f LTXV model merging node. 2024-11-27 01:43:31 -05:00
comfyanonymous
497db6212f Alternative fix for #5767 2024-11-26 17:53:04 -05:00
lky
24dc581dc3 fix multi add makedirs error (#5786)
try to start multiple comfyui server at the same time, and this got error
2024-11-26 15:34:19 -05:00
comfyanonymous
4c82741b54 Support official SD3.5 Controlnets. 2024-11-26 11:31:25 -05:00
comfyanonymous
15c39ea757 Support for the official mochi lora format. 2024-11-26 03:34:36 -05:00
comfyanonymous
b7143b74ce Flux inpaint model does not work in fp16. 2024-11-26 01:33:01 -05:00
comfyanonymous
61196d8857 Add option to inference the diffusion model in fp32 and fp64. 2024-11-25 05:00:23 -05:00
comfyanonymous
b4526d3fc3 Skip layer guidance now works on hydit model. 2024-11-24 05:54:30 -05:00
40476
3d802710e7 Update README.md (#5707) 2024-11-24 04:12:07 -05:00
spacepxl
7126ecffde set LTX min length to 1 for t2i (#5750)
At length=1, the LTX model can do txt2img and img2img with no other changes required.
2024-11-23 21:33:08 -05:00
comfyanonymous
ab885b33ba Skip layer guidance node now works on LTX-Video. 2024-11-23 10:33:05 -05:00
comfyanonymous
839ed3368e Some improvements to the lowvram unloading. 2024-11-22 20:59:15 -05:00
comfyanonymous
6e8cdcd3cb Fix some tiled VAE decoding issues with LTX-Video. 2024-11-22 18:00:34 -05:00
comfyanonymous
e5c3f4b87f LTXV lowvram fixes. 2024-11-22 17:17:11 -05:00
comfyanonymous
bc6be6c11e Some fixes to the lowvram system. 2024-11-22 16:40:04 -05:00
comfyanonymous
94323a26a7 Remove prints. 2024-11-22 10:51:31 -05:00
comfyanonymous
5818f6cf51 Remove print. 2024-11-22 10:49:15 -05:00
comfyanonymous
0b734de449 Add LTX-Video support to the Readme. 2024-11-22 09:24:20 -05:00
comfyanonymous
5e16f1d24b Support Lightricks LTX-Video model. 2024-11-22 08:46:39 -05:00
comfyanonymous
2fd9c1308a Fix mask issue in some attention functions. 2024-11-22 02:10:09 -05:00
comfyanonymous
8f0009aad0 Support new flux model variants. 2024-11-21 08:38:23 -05:00
comfyanonymous
41444b5236 Add some new weight patching functionality.
Add a way to reshape lora weights.

Allow weight patches to all weight not just .weight and .bias

Add a way for a lora to set a weight to a specific value.
2024-11-21 07:19:17 -05:00
comfyanonymous
772e620e32 Update readme. 2024-11-20 20:42:51 -05:00
comfyanonymous
07f6eeaa13 Fix mask issue with attention_xformers. 2024-11-20 17:07:46 -05:00
comfyanonymous
22535d0589 Skip layer guidance now works on stable audio model. 2024-11-20 07:33:06 -05:00
comfyanonymous
898615122f Rename add_noise_mask -> noise_mask. 2024-11-19 15:31:09 -05:00
comfyanonymous
156a28786b Add boolean to InpaintModelConditioning to disable the noise mask. 2024-11-19 07:31:29 -05:00
Yoland Yan
f498d855ba Add terminal size fallback (#5623) 2024-11-19 03:34:20 -05:00
comfyanonymous
b699a15062 Refactor inpaint/ip2p code. 2024-11-19 03:25:25 -05:00
Chenlei Hu
9cc90ee3eb Update UI screenshot in README (#5666)
* Update UI ScreenShot in README

* Remove legacy UI screenshot file

* nit

* nit
2024-11-18 16:50:34 -05:00
comfyanonymous
9a0a5d32ee Add a skip layer guidance node that can also skip single layers.
This one should work for skipping the single layers of models like Flux
and Auraflow.

If you want to see how these models work and how many double/single layers
they have see the "ModelMerge*" nodes for the specific model.
2024-11-18 02:20:43 -05:00
comfyanonymous
d9f90965c8 Support block replace patches in auraflow. 2024-11-17 08:19:59 -05:00
comfyanonymous
41886af138 Add transformer options blocks replace patch to mochi. 2024-11-16 20:48:14 -05:00
Chenlei Hu
22a1d7ce78 Fix 3.8 compatibility in user_manager.py (#5645) 2024-11-16 20:42:21 -05:00
Chenlei Hu
4ac401af2b Update web content to release v1.3.44 (#5620)
* Update web content to release v1.3.44

* nit
2024-11-15 20:17:15 -05:00
comfyanonymous
5fb59c8475 Add a node to block merge auraflow models. 2024-11-15 12:47:55 -05:00
comfyanonymous
122c9ca1ce Add advanced model merging node for mochi. 2024-11-14 07:51:20 -05:00
comfyanonymous
3b9a6cf2b1 Fix issue with 3d masks. 2024-11-13 07:18:30 -05:00
comfyanonymous
3748e7ef7a Fix regression. 2024-11-13 04:24:48 -05:00
comfyanonymous
8ebf2d8831 Add block replace transformer_options to flux. 2024-11-12 08:00:39 -05:00
Bratzmeister
a72d152b0c fix --cuda-device arg for AMD/HIP devices (#5586)
* fix --cuda-device arg for AMD/HIP devices

CUDA_VISIBLE_DEVICES is ignored for HIP devices/backend. Instead it uses HIP_VISIBLE_DEVICES. Setting this environment variable has no side effect for CUDA/NVIDIA so it can safely be set in any case and vice versa.

* deleted accidental if
2024-11-12 06:53:36 -05:00
comfyanonymous
eb476e6ea9 Allow 1D masks for 1D latents. 2024-11-11 14:44:52 -05:00
Dr.Lt.Data
2d28b0b479 improve: add descriptions for clip loaders (#5576) 2024-11-11 05:37:23 -05:00
comfyanonymous
8b275ce5be Support auto detecting some zsnr anime checkpoints. 2024-11-11 05:34:11 -05:00
comfyanonymous
2a18e98ccf Refactor so that zsnr can be set in the sampling_settings. 2024-11-11 04:55:56 -05:00
comfyanonymous
8a5281006f Fix some custom nodes. 2024-11-10 22:41:00 -05:00
comfyanonymous
bdeb1c171c Fast previews for mochi. 2024-11-10 03:39:35 -05:00
comfyanonymous
9c1ed58ef2 proper fix for sag. 2024-11-10 00:10:45 -05:00
comfyanonymous
8b90e50979 Properly handle and reshape masks when used on 3d latents. 2024-11-09 15:30:19 -05:00
pythongosssss
6ee066a14f Live terminal output (#5396)
* Add /logs/raw and /logs/subscribe for getting logs on frontend
Hijacks stderr/stdout to send all output data to the client on flush

* Use existing send sync method

* Fix get_logs should return string

* Fix bug

* pass no server

* fix tests

* Fix output flush on linux
2024-11-08 19:13:34 -05:00
DenOfEquity
dd5b57e3d7 fix for SAG with Kohya HRFix/ Deep Shrink (#5546)
now works with arbitrary downscale factors
2024-11-08 18:16:29 -05:00
comfyanonymous
75a818c720 Move mochi latent node to: latent/video. 2024-11-08 08:33:44 -05:00
comfyanonymous
2865f913f7 Free memory before doing tiled decode. 2024-11-07 04:01:24 -05:00
comfyanonymous
b49616f951 Make VAEDecodeTiled node work with video VAEs. 2024-11-07 03:47:12 -05:00
comfyanonymous
5e29e7a488 Remove scaled_fp8 key after reading it to silence warning. 2024-11-06 04:56:42 -05:00
comfyanonymous
8afb97cd3f Fix unknown VAE being detected as the mochi VAE. 2024-11-05 03:43:27 -05:00
contentis
69694f40b3 fix dynamic shape export (#5490) 2024-11-04 14:59:28 -05:00
Chenlei Hu
c49025f01b Allow POST /userdata/{file} endpoint to return full file info (#5446)
* Refactor listuserdata

* Full info param

* Add tests

* Fix mock

* Add full_info support for move user file
2024-11-04 13:57:21 -05:00
comfyanonymous
696672905f Add mochi support to readme. 2024-11-04 04:55:07 -05:00
comfyanonymous
6c9dbde7de Fix mochi all in one checkpoint t5xxl key names. 2024-11-03 01:40:42 -05:00
comfyanonymous
ee8abf0cff Update folder paths: "clip" -> "text_encoders"
You can still use models/clip but the folder might get removed eventually
on new installs of ComfyUI.
2024-11-02 15:35:38 -04:00
comfyanonymous
fabf449feb Mochi VAE encoder. 2024-11-01 17:33:09 -04:00
Uriel Deveaud
cc9cf6d1bd Rename some nodes in Display Name Mappings (nodes.py) (#5439)
* Update nodes_images.py

Nodes menu has inconsistency in names, some with spaces between words, other not.

* Update nodes.py

Include the node mapping name line for Image Crop Node

* Update nodes_images.py

* Rename image nodes

add space between words for consistency > Display name mappings
2024-10-31 15:18:05 -04:00
Aarni Koskela
1c8286a44b Avoid SyntaxWarning in UniPC docstring (#5442) 2024-10-31 15:17:26 -04:00
comfyanonymous
1af4a47fd1 Bump up mac version for attention upcast bug workaround. 2024-10-31 15:15:31 -04:00
Uriel Deveaud
f2aaa0a475 Rename ImageCrop to Image Crop (#5424)
* Update nodes_images.py

Nodes menu has inconsistency in names, some with spaces between words, other not.

* Update nodes.py

Include the node mapping name line for Image Crop Node

* Update nodes_images.py
2024-10-31 00:35:34 -04:00
comfyanonymous
daa1565b93 Fix diffusers flux controlnet regression. 2024-10-30 13:11:34 -04:00
comfyanonymous
09fdb2b269 Support SD3.5 medium diffusers format weights and loras. 2024-10-30 04:24:00 -04:00
Chenlei Hu
65a8659182 Update web content to release v1.3.26 (#5413)
* Update web content to release v1.3.26

* nit
2024-10-29 14:14:06 -04:00
comfyanonymous
770ab200f2 Cleanup SkipLayerGuidanceSD3 node. 2024-10-29 10:11:46 -04:00
Dango233
954683d0db SLG first implementation for SD3.5 (#5404)
* SLG first implementation for SD3.5

* * Simplify and align with comfy style
2024-10-29 09:59:21 -04:00
comfyanonymous
30c0c81351 Add a way to patch blocks in SD3. 2024-10-29 00:48:32 -04:00
comfyanonymous
13b0ff8a6f Update SD3 code. 2024-10-28 21:58:52 -04:00
comfyanonymous
c320801187 Remove useless line. 2024-10-28 17:41:12 -04:00
Chenlei Hu
c0b0cfaeec Update web content to release v1.3.21 (#5351)
* Update web content to release v1.3.21

* nit
2024-10-28 14:29:38 -04:00
comfyanonymous
669d9e4c67 Set default shift on mochi to 6.0 2024-10-27 22:21:04 -04:00
comfyanonymous
9ee0a6553a float16 inference is a bit broken on mochi. 2024-10-27 04:56:40 -04:00
comfyanonymous
5cbb01bc2f Basic Genmo Mochi video model support.
To use:
"Load CLIP" node with t5xxl + type mochi
"Load Diffusion Model" node with the mochi dit file.
"Load VAE" with the mochi vae file.

EmptyMochiLatentVideo node for the latent.
euler + linear_quadratic in the KSampler node.
2024-10-26 06:54:00 -04:00
comfyanonymous
c3ffbae067 Make LatentUpscale nodes work on 3d latents. 2024-10-26 01:50:51 -04:00
comfyanonymous
d605677b33 Make euler_ancestral work on flow models (credit: Ashen). 2024-10-25 19:53:44 -04:00
Chenlei Hu
ce759b7db6 Revert download to .tmp in frontend_management (#5369) 2024-10-25 19:26:13 -04:00
comfyanonymous
52810907e2 Add a model merge node for SD3.5 large. 2024-10-24 16:46:21 -04:00
PsychoLogicAu
af8cf79a2d support SimpleTuner lycoris lora for SD3 (#5340) 2024-10-24 01:18:32 -04:00
comfyanonymous
66b0961a46 Fix ControlLora issue with last commit. 2024-10-23 17:02:40 -04:00
comfyanonymous
754597c8a9 Clean up some controlnet code.
Remove self.device which was useless.
2024-10-23 14:19:05 -04:00
comfyanonymous
915fdb5745 Fix lowvram edge case. 2024-10-22 16:34:50 -04:00
contentis
5a8a48931a remove attention abstraction (#5324) 2024-10-22 14:02:38 -04:00
comfyanonymous
8ce2a1052c Optimizations to --fast and scaled fp8. 2024-10-22 02:12:28 -04:00
comfyanonymous
f82314fcfc Fix duplicate sigmas on beta scheduler. 2024-10-21 20:19:45 -04:00
comfyanonymous
0075c6d096 Mixed precision diffusion models with scaled fp8.
This change allows supports for diffusion models where all the linears are
scaled fp8 while the other weights are the original precision.
2024-10-21 18:12:51 -04:00
comfyanonymous
83ca891118 Support scaled fp8 t5xxl model. 2024-10-20 22:27:00 -04:00
comfyanonymous
f9f9faface Fixed model merging issue with scaled fp8. 2024-10-20 06:24:31 -04:00
comfyanonymous
471cd3eace fp8 casting is fast on GPUs that support fp8 compute. 2024-10-20 00:54:47 -04:00
comfyanonymous
a68bbafddb Support diffusion models with scaled fp8 weights. 2024-10-19 23:47:42 -04:00
comfyanonymous
73e3a9e676 Clamp output when rounding weight to prevent Nan. 2024-10-19 19:07:10 -04:00
comfyanonymous
518c0dc2fe Add tooltips to LoraSave node. 2024-10-18 06:01:09 -04:00
comfyanonymous
ce0542e10b Add a note that python 3.13 is not yet supported to the README. 2024-10-17 19:27:37 -04:00
comfyanonymous
8473019d40 Pytorch can be shipped with numpy 2 now. 2024-10-17 19:15:17 -04:00
Xiaodong Xie
89f15894dd Ignore more network related errors during websocket communication. (#5269)
Intermittent network issues during websocket communication should not crash ComfyUi process.

Co-authored-by: Xiaodong Xie <xie.xiaodong@frever.com>
2024-10-17 18:31:45 -04:00
comfyanonymous
67158994a4 Use the lowvram cast_to function for everything. 2024-10-17 17:25:56 -04:00
comfyanonymous
7390ff3b1e Add missing import. 2024-10-16 14:58:30 -04:00
comfyanonymous
0bedfb26af Revert "Fix Transformers FutureWarning (#5140)"
This reverts commit 95b7cf9bbe.
2024-10-16 12:36:19 -04:00
comfyanonymous
f71cfd2687 Add an experimental node to sharpen latents.
Can be used with LatentApplyOperationCFG for interesting results.
2024-10-16 05:25:31 -04:00
Alex "mcmonkey" Goodwin
c695c4af7f Frontend Manager: avoid redundant gh calls for static versions (#5152)
* Frontend Manager: avoid redundant gh calls for static versions

* actually, removing old tmpdir isn't needed

I tested - downloader code handles this case well already
(also rmdir was wrong func anyway, needed shutil.rmtree if it had content)

* add code comment
2024-10-16 03:35:37 -04:00
comfyanonymous
0dbba9f751 Add some latent operation nodes.
This is a port of the ModelSamplerTonemapNoiseTest from the experiments
repo.

To replicate that node use LatentOperationTonemapReinhard and
LatentApplyOperationCFG together.
2024-10-15 15:00:36 -04:00
comfyanonymous
f584758271 Cleanup some useless lines. 2024-10-14 21:02:39 -04:00
svdc
95b7cf9bbe Fix Transformers FutureWarning (#5140)
* Update sd1_clip.py

Fix Transformers FutureWarning

* Update sd1_clip.py

Fix comment
2024-10-14 20:12:20 -04:00
comfyanonymous
191a0d56b4 Switch default packaging workflows to python 3.12 2024-10-13 06:59:31 -04:00
comfyanonymous
3c60ecd7a8 Fix fp8 ops staying enabled. 2024-10-12 14:10:13 -04:00
comfyanonymous
7ae6626723 Remove useless argument. 2024-10-12 07:16:21 -04:00
comfyanonymous
6632365e16 model_options consistency between functions.
weight_dtype -> dtype
2024-10-11 20:51:19 -04:00
Kadir Nar
ad07796777 🐛 Add device to variable c (#5210) 2024-10-11 20:37:50 -04:00
comfyanonymous
1b80895285 Make clip loader nodes support loading sd3 t5xxl in lower precision.
Add attention mask support in the SD3 text encoder code.
2024-10-10 15:06:15 -04:00
Dr.Lt.Data
5f9d5a244b Hotfix for the div zero occurrence when memory_used_encode is 0 (#5121)
https://github.com/comfyanonymous/ComfyUI/issues/5069#issuecomment-2382656368
2024-10-09 23:34:34 -04:00
Chenlei Hu
14eba07acd Update web content to release v1.3.11 (#5189)
* Update web content to release v1.3.11

* nit
2024-10-09 22:37:04 -04:00
Jonathan Avila
4b2f0d9413 Increase maximum macOS version to 15.0.1 when forcing upcast attention (#5191) 2024-10-09 22:21:41 -04:00
Yoland Yan
25eac1d780 Change runner label for the new runners (#5197) 2024-10-09 20:08:57 -04:00
comfyanonymous
e38c94228b Add a weight_dtype fp8_e4m3fn_fast to the Diffusion Model Loader node.
This is used to load weights in fp8 and use fp8 matrix multiplication.
2024-10-09 19:43:17 -04:00
comfyanonymous
203942c8b2 Fix flux doras with diffusers keys. 2024-10-08 19:03:40 -04:00
Brendan Hoar
3c72c89a52 Update folder_paths.py - try/catch for special file_name values (#5187)
Somehow managed to drop a file called "nul" into a windows checkpoints subdirectory. This caused all sorts of havoc with many nodes that needed the list of checkpoints.
2024-10-08 15:04:32 -04:00
Chenlei Hu
614377abd6 Update web content to release v1.2.64 (#5124) 2024-10-07 17:15:29 -04:00
comfyanonymous
8dfa0cc552 Make SD3 fast previews a little better. 2024-10-07 09:19:59 -04:00
comfyanonymous
e5ecdfdd2d Make fast previews for SDXL a little better by adding a bias. 2024-10-06 19:27:04 -04:00
comfyanonymous
7d29fbf74b Slightly improve the fast previews for flux by adding a bias. 2024-10-06 17:55:46 -04:00
Lex
2c641e64ad IS_CHANGED should be a classmethod (#5159) 2024-10-06 05:47:51 -04:00
comfyanonymous
7d2467e830 Some minor cleanups. 2024-10-05 13:22:39 -04:00
comfyanonymous
6f021d8aa0 Let --verbose have an argument for the log level. 2024-10-04 10:05:34 -04:00
comfyanonymous
d854ed0bcf Allow using SD3 type te output on flux model. 2024-10-03 09:44:54 -04:00
comfyanonymous
abcd006b8c Allow more permutations of clip/t5 in dual clip loader. 2024-10-03 09:26:11 -04:00
comfyanonymous
d985d1d7dc CLIP Loader node now supports clip_l and clip_g only for SD3. 2024-10-02 04:25:17 -04:00
comfyanonymous
d1cdf51e1b Refactor some of the TE detection code. 2024-10-01 07:08:41 -04:00
comfyanonymous
b4626ab93e Add simpletuner lycoris format for SD unet. 2024-09-30 06:03:27 -04:00
comfyanonymous
a9e459c2a4 Use torch.nn.functional.linear in RGB preview code.
Add an optional bias to the latent RGB preview code.
2024-09-29 11:27:49 -04:00
comfyanonymous
3bb4dec720 Fix issue with loras, lowvram and --fast fp8. 2024-09-28 14:42:32 -04:00
City
8733191563 Flux torch.compile fix (#5082) 2024-09-27 22:07:51 -04:00
comfyanonymous
83b01f960a Add backend option to TorchCompileModel.
If you want to use the cudagraphs backend you need to: --disable-cuda-malloc

If you get other backends working feel free to make a PR to add them.
2024-09-27 02:12:37 -04:00
comfyanonymous
d72e871cfa Add a note that the experimental model downloader api will be removed. 2024-09-26 03:17:52 -04:00
comfyanonymous
037c3159b6 Move some nodes out of _for_testing. 2024-09-25 08:41:22 -04:00
comfyanonymous
bdd4a22a2e Fix flux TE not loading t5 embeddings. 2024-09-24 22:57:22 -04:00
comfyanonymous
fdf37566ef Add batch size to EmptyLatentAudio. 2024-09-24 04:32:55 -04:00
Alex "mcmonkey" Goodwin
08c8968482 Internal download API: Add proper validated directory input (#4981)
* add internal /folder_paths route

returns a json maps of folder paths

* (minor) format download_models.py

* initial folder path input on download api

* actually, require folder_path and clean up some code

* partial tests update

* fix & logging

* also download to a tmp file not the live file

to avoid compounding errors from network failure

* update tests again

* test tweaks

* workaround the first tests blocker

* fix file handling in tests

* rewrite test for create_model_path

* minor doc fix

* avoid 'mock_directory'

use temp dir to avoid accidental fs pollution from tests
2024-09-24 03:50:45 -04:00
chaObserv
479a427a48 Add dpmpp_2m_cfg_pp (#4992) 2024-09-24 02:42:56 -04:00
comfyanonymous
3a0eeee320 Make --listen listen on both ipv4 and ipv6 at the same time by default. 2024-09-23 04:38:19 -04:00
comfyanonymous
447da7ea86 Support listening on multiple addresses. 2024-09-23 04:36:59 -04:00
comfyanonymous
9c41bc8d10 Remove useless line. 2024-09-23 02:32:29 -04:00
Robin Huang
6ad0ddbae4 Run unit tests on Windows/MacOS as well. (#5018)
* Run unit tests on Windows as well.

* Test on mac.

* Continue running on error.

* Compared normalized paths to work cross platform.

* Only test common set of mimetypes across operating systems.
2024-09-22 05:01:39 -04:00
RandomGitUser321
a55142f904 Add ws.close() to the websocket examples (#5020)
* add ws.close() to websocket examples

* add and explain ws.close() in websocket examples
2024-09-22 04:59:10 -04:00
comfyanonymous
5718ef69bb Add total and free ram to /system_stats. 2024-09-22 03:42:11 -04:00
RandomGitUser321
13ecf10a92 Added to the websockets_api_example.py to show how to decode latent previews from the binary stream (#5016)
* Update websockets_api_example.py

* even more simplfied
2024-09-22 02:30:44 -04:00
comfyanonymous
7a415f47a9 Add an optional VAE input to the ControlNetApplyAdvanced node.
Deprecate the other controlnet nodes.
2024-09-22 01:24:52 -04:00
Chenlei Hu
89fa2fca24 Update web content to release v1.2.60 (#5017)
* Update web content to release v1.2.60

* Remove dist.zip
2024-09-21 23:28:54 -04:00
comfyanonymous
364b69e931 Make SD3 empty latent image zeros.
This shouldn't change anything. The reason it was not zeros is because it
did matter in early versions of the code.
2024-09-21 09:13:10 -04:00
comfyanonymous
dc96a1ae19 Load controlnet in fp8 if weights are in fp8. 2024-09-21 04:50:12 -04:00
comfyanonymous
2d810b081e Add load_controlnet_state_dict function. 2024-09-21 01:51:51 -04:00
comfyanonymous
9f7e9f0547 Add an error message when a controlnet needs a VAE but none is given. 2024-09-21 01:33:18 -04:00
comfyanonymous
a355f38ecc Make the SD3 controlnet node the default one. 2024-09-21 01:32:46 -04:00
huchenlei
38c69080c7 Add docstring 2024-09-20 03:16:23 -04:00
comfyanonymous
70a708d726 Fix model merging issue. 2024-09-20 02:31:44 -04:00
yoinked
e7d4782736 add laplace scheduler [2407.03297] (#4990)
* add laplace scheduler [2407.03297]

* should be here instead lol

* better settings
2024-09-19 23:23:09 -04:00
Alex "mcmonkey" Goodwin
3326bdfd4e add internal /folder_paths route (#4980)
returns a json maps of folder paths
2024-09-19 09:52:55 -04:00
Alex "mcmonkey" Goodwin
68bb885d22 add 'is_default' to model paths config (#4979)
* add 'is_default' to model paths config

including impl and doc in example file

* update weirdly overspecific test expectations

* oh there's two

* sigh
2024-09-19 08:59:55 -04:00
comfyanonymous
ad66f7c7d8 Add model_options to load_controlnet function. 2024-09-19 08:23:35 -04:00
Simon Lui
de8e8e3b0d Fix xpu Pytorch nightly build from calling optimize which doesn't exist. (#4978) 2024-09-19 05:11:42 -04:00
Alex "mcmonkey" Goodwin
a1e71cfad1 very simple strong-cache on model list (#4969)
* very simple strong-cache on model list

* store the cache after validation too

* only cache object_info for now

* use a 'with' context
2024-09-19 04:40:14 -04:00
comfyanonymous
0bfc7cc998 Create the temp directory on ComfyUI startup instead. 2024-09-18 09:55:57 -04:00
Tom
7183fd1665 Add route to list model types (#4846)
* Add list models route

* Better readable model types list
2024-09-17 04:22:05 -04:00
Alex "mcmonkey" Goodwin
254838f23c add simple error check to model loading (#4950) 2024-09-17 03:57:17 -04:00
pharmapsychotic
0b7dfa986d Improve tiling calculations to reduce number of tiles that need to be processed. (#4944) 2024-09-17 03:51:10 -04:00
comfyanonymous
d514bb38ee Add some option to model_options for the text encoder.
load_device, offload_device and the initial_device can now be set.
2024-09-17 03:49:54 -04:00
comfyanonymous
0849c80e2a get_key_patches now works without unloading the model. 2024-09-17 01:57:59 -04:00
comfyanonymous
56e8f5e4fd VAEDecodeAudio now does some normalization on the audio. 2024-09-16 00:30:36 -04:00
comfyanonymous
e813abbb2c Long CLIP L support for SDXL, SD3 and Flux.
Use the *CLIPLoader nodes.
2024-09-15 07:59:38 -04:00
JettHu
5e68a4ce67 Reduce repeated calls of INPUT_TYPES in cache (#4922) 2024-09-15 01:03:09 -04:00
comfyanonymous
ca08597670 Make the inpaint controlnet node work with non inpaint ones. 2024-09-14 09:17:13 -04:00
comfyanonymous
f48e390032 Support AliMama SD3 and Flux inpaint controlnets.
Use the ControlNetInpaintingAliMamaApply node.
2024-09-14 09:05:16 -04:00
Chenlei Hu
369a6dd2c4 Remove empty spaces in user_manager.py (#4917) 2024-09-13 23:30:44 -04:00
comfyanonymous
b3ce8fb9fd Revert "Reduce repeated calls of get_immediate_node_signature for ancestors in cache (#4871)"
This reverts commit f6b7194f64.
2024-09-13 23:24:47 -04:00
comfyanonymous
cf80d28689 Support loading controlnets with different input. 2024-09-13 09:54:37 -04:00
Acly
6fb44c4b7c Make adding links/nodes to ExecutionList non-recursive (#4886)
Graphs with 300+ chained nodes run into maximum recursion depth error (limit is 1000 in CPython)
2024-09-13 08:25:11 -04:00
Chenlei Hu
d2247c1e61 Normalize path returned by /userdata to always use / as separator (#4906) 2024-09-13 03:45:31 -04:00
Chenlei Hu
cb12ad7049 Add full_info flag in /userdata endpoint to list out file size and last modified timestamp (#4905)
* Add full_info flag in /userdata endpoint to list out file size and last modified timestamp

* nit
2024-09-13 02:40:59 -04:00
JettHu
f6b7194f64 Reduce repeated calls of get_immediate_node_signature for ancestors in cache (#4871) 2024-09-12 23:02:52 -04:00
comfyanonymous
7c6eb4fb29 Set some nodes as DEPRECATED. 2024-09-12 20:27:07 -04:00
Robin Huang
b962db9952 Add cli arg to override user directory (#4856)
* Override user directory.

* Use overridden user directory.

* Remove prints.

* Remove references to global user_files.

* Remove unused replace_folder function.

* Remove newline.

* Remove global during get_user_directory.

* Add validation.
2024-09-12 08:10:27 -04:00
comfyanonymous
d0b7ab88ba Add a simple experimental TorchCompileModel node.
It probably only works on Linux.

For maximum speed on Flux with Nvidia 40 series/ada and newer try using
this node with fp8_e4m3fn and the --fast argument.
2024-09-12 05:24:25 -04:00
Yoland Yan
405b529545 Minor: update tests-unit README.md (#4896) 2024-09-12 04:53:08 -04:00
comfyanonymous
9d720187f1 types -> comfy_types to fix import issue. 2024-09-12 03:57:46 -04:00
Robin Huang
d247bc5a9c Expand variables in base_path for extra_config_paths.yaml. (#4893)
* Expand variables in base_path for extra_config_paths.yaml.

* Fix comments.
2024-09-12 01:52:06 -04:00
comfyanonymous
9f4daca9d9 Doesn't really make sense for cfg_pp sampler to call regular one. 2024-09-11 02:51:36 -04:00
yoinked
b5d0f2a908 Add CFG++ to DPM++ 2S Ancestral (#3871)
* Update sampling.py

* Update samplers.py

* my bad

* "fix" the sampler

* Update samplers.py

* i named it wrong

* minor sampling improvements

mainly using a dynamic rho value (hey this sounds a lot like smea!!!)

* revert rho change

rho? r? its just 1/2
2024-09-11 02:49:44 -04:00
bymyself
e760bf5c40 Add content-type filter method to folder_paths (#4054)
* Add content-type filter method to folder_paths

* Add unit tests

* Hardcode webp content-type

* Annotate content_types as Literal["image", "video", "audio"]
2024-09-11 02:00:07 -04:00
comfyanonymous
36c83cdbba Limit origin check to when host is loopback.
This should still prevent the exploit without breaking things for people
who use reverse proxies.
2024-09-11 01:06:37 -04:00
Yoland Yan
81778a7feb [🗻 Mount Fuji Commit] Add unit tests for folder path utilities (#4869)
All past 30 min of comtts are done on the top of Mt Fuji
By Comfy, Robin, and Yoland
All other comfy org members died on the way

Introduced unit tests to verify the correctness of various folder path
utility functions such as `get_directory_by_type`, `annotated_filepath`,
and `recursive_search` among others. These tests cover scenarios
including directory retrieval, filepath annotation, recursive file
searches, and filtering files by extensions, enhancing the robustness
and reliability of the codebase.
2024-09-10 00:44:49 -04:00
comfyanonymous
bc94662b31 Cleanup. 2024-09-10 00:43:37 -04:00
Robin Huang
9fa8faa44a Expand user directory for basepath in extra_models_paths.yaml (#4857)
* Expand user path.

* Add test.

* Add unit test for expanding base path.

* Simplify unit test.

* Remove comment.

* Remove comment.

* Checkpoints.

* Refactor.
2024-09-10 00:33:44 -04:00
comfyanonymous
9a7444e39f Add diffusion_models to the extra_model_paths.yaml.example 2024-09-10 00:21:33 -04:00
comfyanonymous
54fca4a218 If host does not contain a port only compare the hostnames. 2024-09-09 16:28:23 -04:00
Chenlei Hu
cd4955367e Add back CI action for tests-ui (#4859) 2024-09-09 04:32:55 -04:00
david02871
8354203d95 Add .venv to gitignore (#4756) 2024-09-09 04:31:18 -04:00
comfyanonymous
e0b41243b4 Fix issue where sometimes origin doesn't contain the port. 2024-09-09 03:18:17 -04:00
Alex "mcmonkey" Goodwin
619263d4a6 allow current timestamp in save image prefix (#4030) 2024-09-09 02:55:51 -04:00
comfyanonymous
e3b0402bb7 Ignore origin domain when it's empty. 2024-09-09 01:04:56 -04:00
Darion
967867d48c fix: url decode filename from API (#4801) 2024-09-08 21:02:32 -04:00
comfyanonymous
cbaac71bf5 Fix issue with last commit. 2024-09-08 19:35:23 -04:00
comfyanonymous
3ab3516e46 By default only accept requests where origin header matches the host.
Browsers are dumb and let any website do requests to localhost this should
prevent this without breaking things. CORS prevents the javascript from
reading the response but they can still write it.

At the moment this is only enabled when the --enable-cors-header argument
is not used.
2024-09-08 18:17:29 -04:00
comfyanonymous
9c5fca75f4 Fix lora issue. 2024-09-08 10:10:47 -04:00
guill
a5da4d0b3e Fix error with ExecutionBlocker and OUTPUT_IS_LIST (#4836)
This change resolves an error when a node with OUTPUT_IS_LIST=(True,)
receives an ExecutionBlocker. I've also added a unit test for this case.
2024-09-08 09:48:47 -04:00
comfyanonymous
32a60a7bac Support onetrainer text encoder Flux lora. 2024-09-08 09:31:41 -04:00
Jim Winkens
bb52934ba4 Fix import issue (#4815) 2024-09-07 05:28:32 -04:00
comfyanonymous
8aabd7c8c0 SaveLora node can now save "full diff" lora format.
This isn't actually a lora format and is saving the full diff of the
weights in a format that can be used in the lora loader nodes.
2024-09-07 03:21:02 -04:00
comfyanonymous
a09b29ca11 Add an option to the SaveLora node to store the bias diff. 2024-09-07 03:03:30 -04:00
comfyanonymous
9bfee68773 LoraSave node now supports generating text encoder loras.
text_encoder_diff should be connected to a CLIPMergeSubtract node.

model_diff and text_encoder_diff are optional inputs so you can create
model only loras, text encoder only loras or a lora that contains both.
2024-09-07 02:30:12 -04:00
comfyanonymous
ea77750759 Support a generic Comfy format for text encoder loras.
This is a format with keys like:
text_encoders.clip_l.transformer.text_model.encoder.layers.9.self_attn.v_proj.lora_up.weight

Instead of waiting for me to add support for specific lora formats you can
convert your text encoder loras to this format instead.

If you want to see an example save a text encoder lora with the SaveLora
node with the commit right after this one.
2024-09-07 02:20:39 -04:00
comfyanonymous
c27ebeb1c2 Fix onnx export not working on flux. 2024-09-06 03:21:52 -04:00
guill
0c7c98a965 Nodes using UNIQUE_ID as input are NOT_IDEMPOTENT (#4793)
As suggested by @ltdrdata, we can automatically consider nodes that take
the UNIQUE_ID hidden input to be NOT_IDEMPOTENT.
2024-09-05 19:33:02 -04:00
comfyanonymous
dc2eb75b85 Update stable release workflow to latest pytorch with cuda 12.4. 2024-09-05 19:21:52 -04:00
Chenlei Hu
fa34efe3bd Update frontend to v1.2.47 (#4798)
* Update web content to release v1.2.47

* Update shortcut list
2024-09-05 18:56:01 -04:00
comfyanonymous
5cbaa9e07c Mistoline flux controlnet support. 2024-09-05 00:05:17 -04:00
comfyanonymous
c7427375ee Prioritize freeing partially offloaded models first. 2024-09-04 19:47:32 -04:00
comfyanonymous
22d1241a50 Add an experimental LoraSave node to extract model loras.
The model_diff input should be connected to the output of a
ModelMergeSubtract node.
2024-09-04 16:38:38 -04:00
Jedrzej Kosinski
f04229b84d Add emb_patch support to UNetModel forward (#4779) 2024-09-04 14:35:15 -04:00
Silver
f067ad15d1 Make live preview size a configurable launch argument (#4649)
* Make live preview size a configurable launch argument

* Remove import from testing phase

* Update cli_args.py
2024-09-03 19:16:38 -04:00
comfyanonymous
483004dd1d Support newer glora format. 2024-09-03 17:02:19 -04:00
comfyanonymous
00a5d08103 Lower fp8 lora memory usage. 2024-09-03 01:25:05 -04:00
comfyanonymous
d043997d30 Flux onetrainer lora. 2024-09-02 08:22:15 -04:00
Alex "mcmonkey" Goodwin
f1c2301697 fix typo in stale-issues (#4735) 2024-09-01 17:44:49 -04:00
comfyanonymous
8d31a6632f Speed up inference on nvidia 10 series on Linux. 2024-09-01 17:29:31 -04:00
comfyanonymous
b643eae08b Make minimum_inference_memory() depend on --reserve-vram 2024-09-01 01:18:34 -04:00
comfyanonymous
baa6b4dc36 Update manual install instructions. 2024-08-31 04:37:23 -04:00
Alex "mcmonkey" Goodwin
d4aeefc297 add github action to automatically handle stale user support issues (#4683)
* add github action to automatically handle stale user support issues

* improve stale message

* remove token part
2024-08-31 01:57:18 -04:00
comfyanonymous
587e7ca654 Remove github buttons. 2024-08-31 01:53:10 -04:00
Chenlei Hu
c90459eba0 Update ComfyUI_frontend to 1.2.40 (#4691)
* Update ComfyUI_frontend to 1.2.40

* Add files
2024-08-30 19:32:10 -04:00
Vedat Baday
04278afb10 feat: return import_failed from init_extra_nodes function (#4694) 2024-08-30 19:26:47 -04:00
comfyanonymous
935ae153e1 Cleanup. 2024-08-30 12:53:59 -04:00
Chenlei Hu
e91662e784 Get logs endpoint & system_stats additions (#4690)
* Add route for getting output logs

* Include ComfyUI version

* Move to own function

* Changed to memory logger

* Unify logger setup logic

* Fix get version git fallback

---------

Co-authored-by: pythongosssss <125205205+pythongosssss@users.noreply.github.com>
2024-08-30 12:46:37 -04:00
comfyanonymous
63fafaef45 Fix potential issue with hydit controlnets. 2024-08-30 04:58:41 -04:00
Alex "mcmonkey" Goodwin
ec28cd9136 swap legacy sdv15 link (#4682)
* swap legacy sdv15 link

* swap v15 ckpt examples to safetensors

* link the fp16 copy of the model by default
2024-08-29 19:48:48 -04:00
comfyanonymous
6eb5d64522 Fix glora lowvram issue. 2024-08-29 19:07:23 -04:00
comfyanonymous
10a79e9898 Implement model part of flux union controlnet. 2024-08-29 18:41:22 -04:00
comfyanonymous
ea3f39bd69 InstantX depth flux controlnet. 2024-08-29 02:14:19 -04:00
comfyanonymous
b33cd61070 InstantX canny controlnet. 2024-08-28 19:02:50 -04:00
Dr.Lt.Data
34eda0f853 fix: remove redundant useless loop (#4656)
fix: potential error of undefined variable

https://github.com/comfyanonymous/ComfyUI/discussions/4650
2024-08-28 17:46:30 -04:00
comfyanonymous
d31e226650 Unify RMSNorm code. 2024-08-28 16:56:38 -04:00
comfyanonymous
b79fd7d92c ComfyUI supports more than just stable diffusion. 2024-08-28 16:12:24 -04:00
comfyanonymous
38c22e631a Fix case where model was not properly unloaded in merging workflows. 2024-08-27 19:03:51 -04:00
Chenlei Hu
6bbdcd28ae Support weight padding on diff weight patch (#4576) 2024-08-27 13:55:37 -04:00
comfyanonymous
ab130001a8 Do RMSNorm in native type. 2024-08-27 02:41:56 -04:00
Chenlei Hu
ca4b8f30e0 Cleanup empty dir if frontend zip download failed (#4574) 2024-08-27 02:07:25 -04:00
Robin Huang
70b84058c1 Add relative file path to the progress report. (#4621) 2024-08-27 02:06:12 -04:00
comfyanonymous
2ca8f6e23d Make the stochastic fp8 rounding reproducible. 2024-08-26 15:12:06 -04:00
comfyanonymous
7985ff88b9 Use less memory in float8 lora patching by doing calculations in fp16. 2024-08-26 14:45:58 -04:00
comfyanonymous
c6812947e9 Fix potential memory leak. 2024-08-26 02:07:32 -04:00
comfyanonymous
9230f65823 Fix some controlnets OOMing when loading. 2024-08-25 05:54:29 -04:00
guill
6ab1e6fd4a [Bug #4529] Fix graph partial validation failure (#4588)
Currently, if a graph partially fails validation (i.e. some outputs are
valid while others have links from missing nodes), the execution loop
could get an exception resulting in server lockup.

This isn't actually possible to reproduce via the default UI, but is a
potential issue for people using the API to construct invalid graphs.
2024-08-24 15:34:58 -04:00
comfyanonymous
07dcbc3a3e Clarify how to use high quality previews. 2024-08-24 02:31:03 -04:00
comfyanonymous
8ae23d8e80 Fix onnx export. 2024-08-23 17:52:47 -04:00
comfyanonymous
7df42b9a23 Fix dora. 2024-08-23 04:58:59 -04:00
comfyanonymous
5d8bbb7281 Cleanup. 2024-08-23 04:06:27 -04:00
comfyanonymous
2c1d2375d6 Fix. 2024-08-23 04:04:55 -04:00
Simon Lui
64ccb3c7e3 Rework IPEX check for future inclusion of XPU into Pytorch upstream and do a bit more optimization of ipex.optimize(). (#4562) 2024-08-23 03:59:57 -04:00
Scorpinaus
9465b23432 Added SD15_Inpaint_Diffusers model support for unet_config_from_diffusers_unet function (#4565) 2024-08-23 03:57:08 -04:00
Chenlei Hu
bb4416dd5b Fix task.status.status_str caused by #2666 (#4551)
* Fix task.status.status_str caused by 2666 regression

* fix

* fix
2024-08-22 17:38:30 -04:00
comfyanonymous
c0b0da264b Missing imports. 2024-08-22 17:20:51 -04:00
comfyanonymous
c26ca27207 Move calculate function to comfy.lora 2024-08-22 17:12:00 -04:00
comfyanonymous
7c6bb84016 Code cleanups. 2024-08-22 17:05:12 -04:00
comfyanonymous
c54d3ed5e6 Fix issue with models staying loaded in memory. 2024-08-22 15:58:20 -04:00
comfyanonymous
c7ee4b37a1 Try to fix some lora issues. 2024-08-22 15:32:18 -04:00
David
7b70b266d8 Generalize MacOS version check for force-upcast-attention (#4548)
This code automatically forces upcasting attention for MacOS versions 14.5 and 14.6. My computer returns the string "14.6.1" for `platform.mac_ver()[0]`, so this generalizes the comparison to catch more versions.

I am running MacOS Sonoma 14.6.1 (latest version) and was seeing black image generation on previously functional workflows after recent software updates. This PR solved the issue for me.

See comfyanonymous/ComfyUI#3521
2024-08-22 13:24:21 -04:00
comfyanonymous
8f60d093ba Fix issue. 2024-08-22 10:38:24 -04:00
guill
dafbe321d2 Fix a bug where cached outputs affected IS_CHANGED (#4535)
This change fixes a bug where non-constant values could be passed to the
IS_CHANGED function. This would result in workflows taking an extra
execution before they acted as if they were cached.

The actual change is like 4 characters -- the rest is adding unit tests.
2024-08-21 23:38:46 -04:00
comfyanonymous
5f84ea63e8 Add a shortcut to the nightly package to run with --fast. 2024-08-21 23:36:58 -04:00
comfyanonymous
843a7ff70c fp16 is actually faster than fp32 on a GTX 1080. 2024-08-21 23:23:50 -04:00
comfyanonymous
a60620dcea Fix slow performance on 10 series Nvidia GPUs. 2024-08-21 16:39:02 -04:00
comfyanonymous
015f73dc49 Try a different type of flux fp16 fix. 2024-08-21 16:17:15 -04:00
comfyanonymous
904bf58e7d Make --fast work on pytorch nightly. 2024-08-21 14:01:41 -04:00
Svein Ove Aas
5f50263088 Replace use of .view with .reshape (#4522)
When generating images with fp8_e4_m3 Flux and batch size >1, using --fast, ComfyUI throws a "view size is not compatible with input tensor's size and stride" error pointing at the first of these two calls to view.

As reshape is semantically equivalent to view except for working on a broader set of inputs, there should be no downside to changing this. The only difference is that it clones the underlying data in cases where .view would error out. I have confirmed that the output still looks as expected, but cannot confirm that no mutable use is made of the tensors anywhere.

Note that --fast is only marginally faster than the default.
2024-08-21 11:21:48 -04:00
Alex "mcmonkey" Goodwin
5e806f555d add a get models list api route (#4519)
* get models list api route

* remove copypasta
2024-08-21 02:04:42 -04:00
Robin Huang
f07e5bb522 Add GET /internal/files. (#4295)
* Create internal route table.

* List files.

* Add GET /internal/files.

Retrieves list of files in models, output, and user directories.

* Refactor file names.

* Use typing_extensions for Python 3.8

* Fix tests.

* Remove print statements.

* Update README.

* Add output and user to valid directory test.

* Add missing type hints.
2024-08-21 01:25:06 -04:00
comfyanonymous
03ec517afb Remove useless line, adjust windows default reserved vram. 2024-08-21 00:47:19 -04:00
Chenlei Hu
f257fc999f Add optional deprecated/experimental flag to node class (#4506)
* Add optional deprecated flag to node class

* nit

* Add experimental flag
2024-08-21 00:01:34 -04:00
Chenlei Hu
bb50e69839 Update frontend to 1.2.30 (#4513) 2024-08-21 00:00:49 -04:00
comfyanonymous
510f3438c1 Speed up fp8 matrix mult by using better code. 2024-08-20 22:53:26 -04:00
comfyanonymous
ea63b1c092 Simpletrainer lycoris format. 2024-08-20 12:05:13 -04:00
comfyanonymous
9953f22fce Add --fast argument to enable experimental optimizations.
Optimizations that might break things/lower quality will be put behind
this flag first and might be enabled by default in the future.

Currently the only optimization is float8_e4m3fn matrix multiplication on
4000/ADA series Nvidia cards or later. If you have one of these cards you
will see a speed boost when using fp8_e4m3fn flux for example.
2024-08-20 11:55:51 -04:00
comfyanonymous
d1a6bd6845 Support loading long clipl model with the CLIP loader node. 2024-08-20 10:46:36 -04:00
comfyanonymous
83dbac28eb Properly set if clip text pooled projection instead of using hack. 2024-08-20 10:46:36 -04:00
comfyanonymous
538cb068bc Make cast_to a nop if weight is already good. 2024-08-20 10:46:36 -04:00
comfyanonymous
1b3eee672c Fix potential issue with multi devices. 2024-08-20 10:46:36 -04:00
Chenlei Hu
5a69f84c3c Update README.md (Add shield badges) (#4490) 2024-08-19 18:25:20 -04:00
comfyanonymous
9eee470244 New load_text_encoder_state_dicts function.
Now you can load text encoders straight from a list of state dicts.
2024-08-19 17:36:35 -04:00
comfyanonymous
045377ea89 Add a --reserve-vram argument if you don't want comfy to use all of it.
--reserve-vram 1.0 for example will make ComfyUI try to keep 1GB vram free.

This can also be useful if workflows are failing because of OOM errors but
in that case please report it if --reserve-vram improves your situation.
2024-08-19 17:16:18 -04:00
comfyanonymous
4d341b78e8 Bug fixes. 2024-08-19 16:28:55 -04:00
comfyanonymous
6138f92084 Use better dtype for the lowvram lora system. 2024-08-19 15:35:25 -04:00
comfyanonymous
be0726c1ed Remove duplication. 2024-08-19 15:26:50 -04:00
comfyanonymous
766ae119a8 CheckpointSave node name. 2024-08-19 15:06:12 -04:00
Yoland Yan
fc90ceb6ba Update issue template config.yml to direct frontend issues to frontend repos (#4486)
* Update config.yml

* Typos
2024-08-19 13:41:30 -04:00
comfyanonymous
4506ddc86a Better subnormal fp8 stochastic rounding. Thanks Ashen. 2024-08-19 13:38:03 -04:00
comfyanonymous
20ace7c853 Code cleanup. 2024-08-19 12:48:59 -04:00
Chenlei Hu
b29b3b86c5 Update README to include frontend section (#4468)
* Update README to include frontend section

* nit
2024-08-19 07:12:32 -04:00
comfyanonymous
22ec02afc0 Handle subnormal numbers in float8 rounding. 2024-08-19 05:51:08 -04:00
comfyanonymous
39f114c44b Less broken non blocking? 2024-08-18 16:53:17 -04:00
comfyanonymous
6730f3e1a3 Disable non blocking.
It fixed some perf issues but caused other issues that need to be debugged.
2024-08-18 14:38:09 -04:00
comfyanonymous
73332160c8 Enable non blocking transfers in lowvram mode. 2024-08-18 10:29:33 -04:00
comfyanonymous
2622c55aff Automatically use RF variant of dpmpp_2s_ancestral if RF model. 2024-08-18 00:47:25 -04:00
Ashen
1beb348ee2 dpmpp_2s_ancestral_RF for rectified flow (Flux, SD3 and Auraflow). 2024-08-18 00:33:30 -04:00
bymyself
9aa39e743c Add new shortcuts to readme (#4442) 2024-08-17 23:52:56 -04:00
comfyanonymous
d31df04c8a Indentation. 2024-08-17 23:00:44 -04:00
Xrvk
e68763f40c Add Flux model support for InstantX style controlnet residuals (#4444)
* Add Flux model support for InstantX style controlnet residuals

* Refactor Flux controlnet residual step to a separate method

* Rollback minor change

* New format for applying controlnet residuals: input->double_blocks, output->single_blocks

* Adjust XLabs Flux controlnet to fit new syntax of applying Flux controlnet residuals

* Remove unnecessary import and minor style change
2024-08-17 22:58:23 -04:00
comfyanonymous
310ad09258 Add a ModelSave node. 2024-08-17 21:43:07 -04:00
comfyanonymous
4f7a3cb6fb unet -> diffusion_models. 2024-08-17 21:31:04 -04:00
comfyanonymous
bb222ceddb Fix loras having a weak effect when applied on fp8. 2024-08-17 15:20:17 -04:00
comfyanonymous
14af129c55 Improve execution UX.
Some branches with VAELoader -> VAEDecode -> Preview were being executed
last. With this change they will be executed earlier.
2024-08-17 11:37:21 -04:00
comfyanonymous
fca42836f2 Add model_options for text encoder. 2024-08-17 11:17:20 -04:00
comfyanonymous
858d51f91a Fix VAEDecode -> Preview not being executed first. 2024-08-17 04:08:54 -04:00
comfyanonymous
cd5017c1c9 calculate_weight function to use a different dtype. 2024-08-17 01:06:08 -04:00
comfyanonymous
83f343146a Fix potential lowvram issue. 2024-08-16 17:12:42 -04:00
Chenlei Hu
b021cf67c7 Update frontend to 1.2.26 (#4415) 2024-08-16 15:25:02 -04:00
Matthew Turnshek
1770fc77ed Implement support for taef1 latent previews (#4409)
* add taef1 handling to several places

* remove guess_latent_channels and add latent_channels info directly to flux model

* remove TODO

* fix numbers
2024-08-16 12:53:13 -04:00
comfyanonymous
05a9f3faa1 Log a warning when there's an issue with IS_CHANGED. 2024-08-16 08:50:17 -04:00
comfyanonymous
86c5970ac0 Fix custom nodes hooking the map_node_over_list and breaking things. 2024-08-16 08:40:31 -04:00
Chenlei Hu
bfc214f434 Use new TS frontend uncompressed (#4379)
* Swap frontend uncompressed

* Add uncompressed files
2024-08-15 16:50:25 -04:00
comfyanonymous
3f5939add6 Tell github not to count the web directory in language stats. 2024-08-15 13:48:56 -04:00
comfyanonymous
5960f946a9 Move a few files from comfy -> comfy_execution.
Python code in the comfy folder should not import things from outside it.
2024-08-15 11:21:14 -04:00
guill
5cfe38f41c Execution Model Inversion (#2666)
* Execution Model Inversion

This PR inverts the execution model -- from recursively calling nodes to
using a topological sort of the nodes. This change allows for
modification of the node graph during execution. This allows for two
major advantages:

    1. The implementation of lazy evaluation in nodes. For example, if a
    "Mix Images" node has a mix factor of exactly 0.0, the second image
    input doesn't even need to be evaluated (and visa-versa if the mix
    factor is 1.0).

    2. Dynamic expansion of nodes. This allows for the creation of dynamic
    "node groups". Specifically, custom nodes can return subgraphs that
    replace the original node in the graph. This is an incredibly
    powerful concept. Using this functionality, it was easy to
    implement:
        a. Components (a.k.a. node groups)
        b. Flow control (i.e. while loops) via tail recursion
        c. All-in-one nodes that replicate the WebUI functionality
        d. and more
    All of those were able to be implemented entirely via custom nodes,
    so those features are *not* a part of this PR. (There are some
    front-end changes that should occur before that functionality is
    made widely available, particularly around variant sockets.)

The custom nodes associated with this PR can be found at:
https://github.com/BadCafeCode/execution-inversion-demo-comfyui

Note that some of them require that variant socket types ("*") be
enabled.

* Allow `input_info` to be of type `None`

* Handle errors (like OOM) more gracefully

* Add a command-line argument to enable variants

This allows the use of nodes that have sockets of type '*' without
applying a patch to the code.

* Fix an overly aggressive assertion.

This could happen when attempting to evaluate `IS_CHANGED` for a node
during the creation of the cache (in order to create the cache key).

* Fix Pyright warnings

* Add execution model unit tests

* Fix issue with unused literals

Behavior should now match the master branch with regard to undeclared
inputs. Undeclared inputs that are socket connections will be used while
undeclared inputs that are literals will be ignored.

* Make custom VALIDATE_INPUTS skip normal validation

Additionally, if `VALIDATE_INPUTS` takes an argument named `input_types`,
that variable will be a dictionary of the socket type of all incoming
connections. If that argument exists, normal socket type validation will
not occur. This removes the last hurdle for enabling variant types
entirely from custom nodes, so I've removed that command-line option.

I've added appropriate unit tests for these changes.

* Fix example in unit test

This wouldn't have caused any issues in the unit test, but it would have
bugged the UI if someone copy+pasted it into their own node pack.

* Use fstrings instead of '%' formatting syntax

* Use custom exception types.

* Display an error for dependency cycles

Previously, dependency cycles that were created during node expansion
would cause the application to quit (due to an uncaught exception). Now,
we'll throw a proper error to the UI. We also make an attempt to 'blame'
the most relevant node in the UI.

* Add docs on when ExecutionBlocker should be used

* Remove unused functionality

* Rename ExecutionResult.SLEEPING to PENDING

* Remove superfluous function parameter

* Pass None for uneval inputs instead of default

This applies to `VALIDATE_INPUTS`, `check_lazy_status`, and lazy values
in evaluation functions.

* Add a test for mixed node expansion

This test ensures that a node that returns a combination of expanded
subgraphs and literal values functions correctly.

* Raise exception for bad get_node calls.

* Minor refactor of IsChangedCache.get

* Refactor `map_node_over_list` function

* Fix ui output for duplicated nodes

* Add documentation on `check_lazy_status`

* Add file for execution model unit tests

* Clean up Javascript code as per review

* Improve documentation

Converted some comments to docstrings as per review

* Add a new unit test for mixed lazy results

This test validates that when an output list is fed to a lazy node, the
node will properly evaluate previous nodes that are needed by any inputs
to the lazy node.

No code in the execution model has been changed. The test already
passes.

* Allow kwargs in VALIDATE_INPUTS functions

When kwargs are used, validation is skipped for all inputs as if they
had been mentioned explicitly.

* List cached nodes in `execution_cached` message

This was previously just bugged in this PR.
2024-08-15 11:21:11 -04:00
comfyanonymous
0f9c2a7822 Try to fix SDXL OOM issue on some configurations. 2024-08-14 23:08:54 -04:00
comfyanonymous
153d0a8142 Add a update/update_comfyui_stable.bat to the standalones. 2024-08-14 22:29:23 -04:00
Chenlei Hu
ab4dd19b91 Remove legacy ui test files (#4316) 2024-08-14 21:01:06 -04:00
comfyanonymous
f1d6cef71c Revert "Disable cuda malloc by default."
This reverts commit 50bf66e5c4.
2024-08-14 08:38:07 -04:00
comfyanonymous
33fb282d5c Fix issue. 2024-08-14 02:51:47 -04:00
comfyanonymous
50bf66e5c4 Disable cuda malloc by default. 2024-08-14 02:49:25 -04:00
pythongosssss
e60e19b175 Add support for simple tooltips (#3842)
* Add support for simple tooltips

* Fix overflow

* Add tooltips for nodes in the default workflow

* new line

* Prevent potential crash

* PR feedback

* Hide tooltip when clicking (e.g. combo widget)

* Refactor tooltips, add node level support

* Fix

* move

* Fix test (and undo last change)

* Fixed indent

* Fix dom widgets, dont show tooltip if not over canvas
2024-08-14 01:22:10 -04:00
comfyanonymous
a5af64d3ce Revert "Not sure if this actually changes anything but it can't hurt."
This reverts commit 34608de2e9.
2024-08-14 01:05:17 -04:00
Robin Huang
3e52e0364c Add model downloading endpoint. (#4248)
* Add model downloading endpoint.

* Move client session init to async function.

* Break up large function.

* Send "download_progress" as websocket event.

* Fixed

* Fixed.

* Use async mock.

* Move server set up to right before run call.

* Validate that model subdirectory cannot contain relative paths.

* Add download_model test checking for invalid paths.

* Remove DS_Store.

* Consolidate DownloadStatus and DownloadModelResult

* Add progress_interval as an optional parameter.

* Use tuple type from annotations.

* Use pydantic.

* Update comment.

* Revert "Use pydantic."

This reverts commit 7461e8eb00.

* Add new line.

* Add newline EOF.

* Validate model filename as well.

* Add comment to not reply on internal.

* Restrict downloading to safetensor files only.
2024-08-13 15:48:52 -04:00
comfyanonymous
34608de2e9 Not sure if this actually changes anything but it can't hurt. 2024-08-13 13:29:16 -04:00
comfyanonymous
39fb74c5bd Fix bug when model cannot be partially unloaded. 2024-08-13 03:57:55 -04:00
comfyanonymous
74e124f4d7 Fix some issues with TE being in lowvram mode. 2024-08-12 23:42:21 -04:00
comfyanonymous
a562c17e8a load_unet -> load_diffusion_model with a model_options argument. 2024-08-12 23:20:57 -04:00
comfyanonymous
5942c17d55 Order of operations matters. 2024-08-12 21:56:18 -04:00
comfyanonymous
c032b11e07 xlabs Flux controlnet implementation. (#4260)
* xlabs Flux controlnet.

* Fix not working on old python.

* Remove comment.
2024-08-12 21:22:22 -04:00
comfyanonymous
b8ffb2937f Memory tweaks. 2024-08-12 15:07:11 -04:00
Vladimir Semyonov
ce37c11164 add DS_Store to gitignore (#4324) 2024-08-12 12:32:34 -04:00
Alex "mcmonkey" Goodwin
b5c3906b38 Automatically link the Comfy CI page on PRs (#4326)
also use_prior_commit so it doesn't get a janked merge commit instead of the real one
2024-08-12 12:32:16 -04:00
comfyanonymous
5d43e75e5b Fix some issues with the model sometimes not getting patched. 2024-08-12 12:27:54 -04:00
comfyanonymous
517f4a94e4 Fix some lora loading slowdowns. 2024-08-12 11:50:32 -04:00
comfyanonymous
52a471c5c7 Change name of log. 2024-08-12 10:35:06 -04:00
comfyanonymous
ad76574cb8 Fix some potential issues with the previous commits. 2024-08-12 00:23:29 -04:00
comfyanonymous
9acfe4df41 Support loading directly to vram with CLIPLoader node. 2024-08-12 00:06:01 -04:00
comfyanonymous
9829b013ea Fix mistake in last commit. 2024-08-12 00:00:17 -04:00
comfyanonymous
5c69cde037 Load TE model straight to vram if certain conditions are met. 2024-08-11 23:52:43 -04:00
comfyanonymous
e9589d6d92 Add a way to set model dtype and ops from load_checkpoint_guess_config. 2024-08-11 08:50:34 -04:00
comfyanonymous
0d82a798a5 Remove the ckpt_path from load_state_dict_guess_config. 2024-08-11 08:37:35 -04:00
ljleb
925fff26fd alternative to load_checkpoint_guess_config that accepts a loaded state dict (#4249)
* make alternative fn

* add back ckpt path as 2nd argument?
2024-08-11 08:36:52 -04:00
comfyanonymous
75b9b55b22 Fix issues with #4302 and support loading diffusers format flux. 2024-08-10 21:28:24 -04:00
Jaret Burkett
1765f1c60c FLUX: Added full diffusers mapping for FLUX.1 schnell and dev. Adds full LoRA support from diffusers LoRAs. (#4302) 2024-08-10 21:26:41 -04:00
comfyanonymous
1de69fe4d5 Fix some issues with inference slowing down. 2024-08-10 16:21:25 -04:00
comfyanonymous
ae197f651b Speed up hunyuan dit inference a bit. 2024-08-10 07:36:27 -04:00
comfyanonymous
1b5b8ca81a Fix regression. 2024-08-09 21:45:21 -04:00
comfyanonymous
6678d5cf65 Fix regression. 2024-08-09 14:02:38 -04:00
TTPlanetPig
e172564eea Update controlnet.py to fix the default controlnet weight as constant (#4285) 2024-08-09 13:40:05 -04:00
comfyanonymous
a3cc326748 Better fix for lowvram issue. 2024-08-09 12:16:25 -04:00
comfyanonymous
86a97e91fc Fix controlnet regression. 2024-08-09 12:08:58 -04:00
comfyanonymous
5acdadc9f3 Fix issue with some lowvram weights. 2024-08-09 03:58:28 -04:00
comfyanonymous
55ad9d5f8c Fix regression. 2024-08-09 03:36:40 -04:00
comfyanonymous
a9f04edc58 Implement text encoder part of HunyuanDiT loras. 2024-08-09 03:21:10 -04:00
comfyanonymous
a475ec2300 Cleanup HunyuanDit controlnets.
Use the: ControlNetApply SD3 and HunyuanDiT node.
2024-08-09 02:59:34 -04:00
来新璐
06eb9fb426 feat: add support for HunYuanDit ControlNet (#4245)
* add support for HunYuanDit ControlNet

* fix hunyuandit controlnet

* fix typo in hunyuandit controlnet

* fix typo in hunyuandit controlnet

* fix code format style

* add control_weight support for HunyuanDit Controlnet

* use control_weights in HunyuanDit Controlnet

* fix typo
2024-08-09 02:59:24 -04:00
comfyanonymous
413322645e Raw torch is faster than einops? 2024-08-08 22:09:29 -04:00
comfyanonymous
11200de970 Cleaner code. 2024-08-08 20:07:09 -04:00
comfyanonymous
037c38eb0f Try to improve inference speed on some machines. 2024-08-08 17:29:27 -04:00
comfyanonymous
1e11d2d1f5 Better prints. 2024-08-08 17:29:27 -04:00
Alex "mcmonkey" Goodwin
65ea6be38f PullRequest CI Run: use pull_request_target to allow the CI Dashboard to work (#4277)
'_target' allows secrets to pass through, and we're just using the secret that allows uploading to the dashboard and are manually vetting PRs before running this workflow anyway
2024-08-08 17:20:48 -04:00
Alex "mcmonkey" Goodwin
5df6f57b5d minor fix on copypasta action name (#4276)
my bad sorry
2024-08-08 16:30:59 -04:00
Alex "mcmonkey" Goodwin
6588bfdef9 add GitHub workflow for CI tests of PRs (#4275)
When the 'Run-CI-Test' label is added to a PR, it will be tested by the CI, on a small matrix of stable versions.
2024-08-08 16:24:49 -04:00
Alex "mcmonkey" Goodwin
50ed2879ef Add full CI test matrix GitHub Workflow (#4274)
automatically runs a matrix of full GPU-enabled tests on all new commits to the ComfyUI master branch
2024-08-08 15:40:07 -04:00
comfyanonymous
66d4233210 Fix. 2024-08-08 15:16:51 -04:00
comfyanonymous
591010b7ef Support diffusers text attention flux loras. 2024-08-08 14:45:52 -04:00
comfyanonymous
08f92d55e9 Partial model shift support. 2024-08-08 14:45:06 -04:00
comfyanonymous
8115d8cce9 Add Flux fp16 support hack. 2024-08-07 15:08:39 -04:00
comfyanonymous
6969fc9ba4 Make supported_dtypes a priority list. 2024-08-07 15:00:06 -04:00
comfyanonymous
cb7c4b4be3 Workaround for lora OOM on lowvram mode. 2024-08-07 14:30:54 -04:00
comfyanonymous
1208863eca Fix "Comfy" lora keys.
They are in this format now:
diffusion_model.full.model.key.name.lora_up.weight
2024-08-07 13:49:31 -04:00
comfyanonymous
e1c528196e Fix bundled embed. 2024-08-07 13:30:45 -04:00
comfyanonymous
17030fd4c0 Support for "Comfy" lora format.
The keys are just: model.full.model.key.name.lora_up.weight

It is supported by all comfyui supported models.

Now people can just convert loras to this format instead of having to ask
for me to implement them.
2024-08-07 13:18:32 -04:00
comfyanonymous
c19dcd362f Controlnet code refactor. 2024-08-07 12:59:28 -04:00
comfyanonymous
1c08bf35b4 Support format for embeddings bundled in loras. 2024-08-07 03:45:25 -04:00
PhilWun
2a02546e20 Add type hints to folder_paths.py (#4191)
* add type hints to folder_paths.py

* replace deprecated standard collections type hints

* fix type error when using Python 3.8
2024-08-06 21:59:34 -04:00
359 changed files with 660257 additions and 45242 deletions

View File

@@ -33,12 +33,12 @@ def pull(repo, remote_name='origin', branch='master'):
user = repo.default_signature
tree = repo.index.write_tree()
commit = repo.create_commit('HEAD',
user,
user,
'Merge!',
tree,
[repo.head.target, remote_master_id])
repo.create_commit('HEAD',
user,
user,
'Merge!',
tree,
[repo.head.target, remote_master_id])
# We need to do this or git CLI will think we are still merging.
repo.state_cleanup()
else:
@@ -75,6 +75,25 @@ else:
print("pulling latest changes")
pull(repo)
if "--stable" in sys.argv:
def latest_tag(repo):
versions = []
for k in repo.references:
try:
prefix = "refs/tags/v"
if k.startswith(prefix):
version = list(map(int, k[len(prefix):].split(".")))
versions.append((version[0] * 10000000000 + version[1] * 100000 + version[2], k))
except:
pass
versions.sort()
if len(versions) > 0:
return versions[-1][1]
return None
latest_tag = latest_tag(repo)
if latest_tag is not None:
repo.checkout(latest_tag)
print("Done!")
self_update = True
@@ -115,3 +134,13 @@ if not os.path.exists(req_path) or not files_equal(repo_req_path, req_path):
shutil.copy(repo_req_path, req_path)
except:
pass
stable_update_script = os.path.join(repo_path, ".ci/update_windows/update_comfyui_stable.bat")
stable_update_script_to = os.path.join(cur_path, "update_comfyui_stable.bat")
try:
if not file_size(stable_update_script_to) > 10:
shutil.copy(stable_update_script, stable_update_script_to)
except:
pass

View File

@@ -0,0 +1,8 @@
@echo off
..\python_embeded\python.exe .\update.py ..\ComfyUI\ --stable
if exist update_new.py (
move /y update_new.py update.py
echo Running updater again since it got updated.
..\python_embeded\python.exe .\update.py ..\ComfyUI\ --skip_self_update --stable
)
if "%~1"=="" pause

View File

@@ -14,7 +14,7 @@ run_cpu.bat
IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints
You can download the stable diffusion 1.5 one from: https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt
You can download the stable diffusion 1.5 one from: https://huggingface.co/Comfy-Org/stable-diffusion-v1-5-archive/blob/main/v1-5-pruned-emaonly-fp16.safetensors
RECOMMENDED WAY TO UPDATE:

View File

@@ -0,0 +1,2 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast
pause

2
.gitattributes vendored Normal file
View File

@@ -0,0 +1,2 @@
/web/assets/** linguist-generated
/web/** linguist-vendored

View File

@@ -1,5 +1,8 @@
blank_issues_enabled: true
contact_links:
- name: ComfyUI Frontend Issues
url: https://github.com/Comfy-Org/ComfyUI_frontend/issues
about: Issues related to the ComfyUI frontend (display issues, user interaction bugs), please go to the frontend repo to file the issue
- name: ComfyUI Matrix Space
url: https://app.element.io/#/room/%23comfyui_space%3Amatrix.org
about: The ComfyUI Matrix Space is available for support and general discussion related to ComfyUI (Matrix is like Discord but open source).

View File

@@ -0,0 +1,53 @@
# This is the GitHub Workflow that drives full-GPU-enabled tests of pull requests to ComfyUI, when the 'Run-CI-Test' label is added
# Results are reported as checkmarks on the commits, as well as onto https://ci.comfy.org/
name: Pull Request CI Workflow Runs
on:
pull_request_target:
types: [labeled]
jobs:
pr-test-stable:
if: ${{ github.event.label.name == 'Run-CI-Test' }}
strategy:
fail-fast: false
matrix:
os: [macos, linux, windows]
python_version: ["3.9", "3.10", "3.11", "3.12"]
cuda_version: ["12.1"]
torch_version: ["stable"]
include:
- os: macos
runner_label: [self-hosted, macOS]
flags: "--use-pytorch-cross-attention"
- os: linux
runner_label: [self-hosted, Linux]
flags: ""
- os: windows
runner_label: [self-hosted, Windows]
flags: ""
runs-on: ${{ matrix.runner_label }}
steps:
- name: Test Workflows
uses: comfy-org/comfy-action@main
with:
os: ${{ matrix.os }}
python_version: ${{ matrix.python_version }}
torch_version: ${{ matrix.torch_version }}
google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
comfyui_flags: ${{ matrix.flags }}
use_prior_commit: 'true'
comment:
if: ${{ github.event.label.name == 'Run-CI-Test' }}
runs-on: ubuntu-latest
permissions:
pull-requests: write
steps:
- uses: actions/github-script@v6
with:
script: |
github.rest.issues.createComment({
issue_number: context.issue.number,
owner: context.repo.owner,
repo: context.repo.repo,
body: '(Automated Bot Message) CI Tests are running, you can view the results at https://ci.comfy.org/?branch=${{ github.event.pull_request.number }}%2Fmerge'
})

View File

@@ -3,8 +3,8 @@ name: Python Linting
on: [push, pull_request]
jobs:
pylint:
name: Run Pylint
ruff:
name: Run Ruff
runs-on: ubuntu-latest
steps:
@@ -16,8 +16,8 @@ jobs:
with:
python-version: 3.x
- name: Install Pylint
run: pip install pylint
- name: Install Ruff
run: pip install ruff
- name: Run Pylint
run: pylint --rcfile=.pylintrc $(find . -type f -name "*.py")
- name: Run Ruff
run: ruff check .

View File

@@ -12,17 +12,17 @@ on:
description: 'CUDA version'
required: true
type: string
default: "121"
default: "124"
python_minor:
description: 'Python minor version'
required: true
type: string
default: "11"
default: "12"
python_patch:
description: 'Python patch version'
required: true
type: string
default: "9"
default: "7"
jobs:

21
.github/workflows/stale-issues.yml vendored Normal file
View File

@@ -0,0 +1,21 @@
name: 'Close stale issues'
on:
schedule:
# Run daily at 430 am PT
- cron: '30 11 * * *'
permissions:
issues: write
jobs:
stale:
runs-on: ubuntu-latest
steps:
- uses: actions/stale@v9
with:
stale-issue-message: "This issue is being marked stale because it has not had any activity for 30 days. Reply below within 7 days if your issue still isn't solved, and it will be left open. Otherwise, the issue will be closed automatically."
days-before-stale: 30
days-before-close: 7
stale-issue-label: 'Stale'
only-labels: 'User Support'
exempt-all-assignees: true
exempt-all-milestones: true

View File

@@ -1,76 +0,0 @@
# This is a temporary action during frontend TS migration.
# This file should be removed after TS migration is completed.
# The browser test is here to ensure TS repo is working the same way as the
# current JS code.
# If you are adding UI feature, please sync your changes to the TS repo:
# huchenlei/ComfyUI_frontend and update test expectation files accordingly.
name: Playwright Browser Tests CI
on:
push:
branches: [ main, master ]
pull_request:
branches: [ main, master ]
jobs:
test:
runs-on: ubuntu-latest
steps:
- name: Checkout ComfyUI
uses: actions/checkout@v4
with:
repository: "comfyanonymous/ComfyUI"
path: "ComfyUI"
- name: Checkout ComfyUI_frontend
uses: actions/checkout@v4
with:
repository: "huchenlei/ComfyUI_frontend"
path: "ComfyUI_frontend"
ref: "fcc54d803e5b6a9b08a462a1d94899318c96dcbb"
- uses: actions/setup-node@v3
with:
node-version: lts/*
- uses: actions/setup-python@v4
with:
python-version: '3.8'
- name: Install requirements
run: |
python -m pip install --upgrade pip
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install -r requirements.txt
pip install wait-for-it
working-directory: ComfyUI
- name: Start ComfyUI server
run: |
python main.py --cpu 2>&1 | tee console_output.log &
wait-for-it --service 127.0.0.1:8188 -t 600
working-directory: ComfyUI
- name: Install ComfyUI_frontend dependencies
run: |
npm ci
working-directory: ComfyUI_frontend
- name: Install Playwright Browsers
run: npx playwright install --with-deps
working-directory: ComfyUI_frontend
- name: Run Playwright tests
run: npx playwright test
working-directory: ComfyUI_frontend
- name: Check for unhandled exceptions in server log
run: |
if grep -qE "Exception|Error" console_output.log; then
echo "Unhandled exception/error found in server log."
exit 1
fi
working-directory: ComfyUI
- uses: actions/upload-artifact@v4
if: always()
with:
name: playwright-report
path: ComfyUI_frontend/playwright-report/
retention-days: 30
- uses: actions/upload-artifact@v4
if: always()
with:
name: console-output
path: ComfyUI/console_output.log
retention-days: 30

96
.github/workflows/test-ci.yml vendored Normal file
View File

@@ -0,0 +1,96 @@
# This is the GitHub Workflow that drives automatic full-GPU-enabled tests of all new commits to the master branch of ComfyUI
# Results are reported as checkmarks on the commits, as well as onto https://ci.comfy.org/
name: Full Comfy CI Workflow Runs
on:
push:
branches:
- master
paths-ignore:
- 'app/**'
- 'input/**'
- 'output/**'
- 'notebooks/**'
- 'script_examples/**'
- '.github/**'
- 'web/**'
workflow_dispatch:
jobs:
test-stable:
strategy:
fail-fast: false
matrix:
# os: [macos, linux, windows]
os: [macos, linux]
python_version: ["3.9", "3.10", "3.11", "3.12"]
cuda_version: ["12.1"]
torch_version: ["stable"]
include:
- os: macos
runner_label: [self-hosted, macOS]
flags: "--use-pytorch-cross-attention"
- os: linux
runner_label: [self-hosted, Linux]
flags: ""
# - os: windows
# runner_label: [self-hosted, Windows]
# flags: ""
runs-on: ${{ matrix.runner_label }}
steps:
- name: Test Workflows
uses: comfy-org/comfy-action@main
with:
os: ${{ matrix.os }}
python_version: ${{ matrix.python_version }}
torch_version: ${{ matrix.torch_version }}
google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
comfyui_flags: ${{ matrix.flags }}
# test-win-nightly:
# strategy:
# fail-fast: true
# matrix:
# os: [windows]
# python_version: ["3.9", "3.10", "3.11", "3.12"]
# cuda_version: ["12.1"]
# torch_version: ["nightly"]
# include:
# - os: windows
# runner_label: [self-hosted, Windows]
# flags: ""
# runs-on: ${{ matrix.runner_label }}
# steps:
# - name: Test Workflows
# uses: comfy-org/comfy-action@main
# with:
# os: ${{ matrix.os }}
# python_version: ${{ matrix.python_version }}
# torch_version: ${{ matrix.torch_version }}
# google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
# comfyui_flags: ${{ matrix.flags }}
test-unix-nightly:
strategy:
fail-fast: false
matrix:
os: [macos, linux]
python_version: ["3.11"]
cuda_version: ["12.1"]
torch_version: ["nightly"]
include:
- os: macos
runner_label: [self-hosted, macOS]
flags: "--use-pytorch-cross-attention"
- os: linux
runner_label: [self-hosted, Linux]
flags: ""
runs-on: ${{ matrix.runner_label }}
steps:
- name: Test Workflows
uses: comfy-org/comfy-action@main
with:
os: ${{ matrix.os }}
python_version: ${{ matrix.python_version }}
torch_version: ${{ matrix.torch_version }}
google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
comfyui_flags: ${{ matrix.flags }}

45
.github/workflows/test-launch.yml vendored Normal file
View File

@@ -0,0 +1,45 @@
name: Test server launches without errors
on:
push:
branches: [ main, master ]
pull_request:
branches: [ main, master ]
jobs:
test:
runs-on: ubuntu-latest
steps:
- name: Checkout ComfyUI
uses: actions/checkout@v4
with:
repository: "comfyanonymous/ComfyUI"
path: "ComfyUI"
- uses: actions/setup-python@v4
with:
python-version: '3.8'
- name: Install requirements
run: |
python -m pip install --upgrade pip
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install -r requirements.txt
pip install wait-for-it
working-directory: ComfyUI
- name: Start ComfyUI server
run: |
python main.py --cpu 2>&1 | tee console_output.log &
wait-for-it --service 127.0.0.1:8188 -t 30
working-directory: ComfyUI
- name: Check for unhandled exceptions in server log
run: |
if grep -qE "Exception|Error" console_output.log; then
echo "Unhandled exception/error found in server log."
exit 1
fi
working-directory: ComfyUI
- uses: actions/upload-artifact@v4
if: always()
with:
name: console-output
path: ComfyUI/console_output.log
retention-days: 30

View File

@@ -1,29 +1,29 @@
name: Tests CI
name: Unit Tests
on: [push, pull_request]
on:
push:
branches: [ main, master ]
pull_request:
branches: [ main, master ]
jobs:
test:
runs-on: ubuntu-latest
strategy:
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
runs-on: ${{ matrix.os }}
continue-on-error: true
steps:
- uses: actions/checkout@v4
- uses: actions/setup-node@v3
- name: Set up Python
uses: actions/setup-python@v4
with:
node-version: 18
- uses: actions/setup-python@v4
with:
python-version: '3.10'
- name: Install requirements
run: |
python -m pip install --upgrade pip
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install -r requirements.txt
- name: Run Tests
run: |
npm ci
npm run test:generate
npm test -- --verbose
working-directory: ./tests-ui
- name: Run Unit Tests
run: |
pip install -r tests-unit/requirements.txt

View File

@@ -12,7 +12,7 @@ on:
description: 'extra dependencies'
required: false
type: string
default: "\"numpy<2\""
default: ""
cu:
description: 'cuda version'
required: true
@@ -23,13 +23,13 @@ on:
description: 'python minor version'
required: true
type: string
default: "11"
default: "12"
python_patch:
description: 'python patch version'
required: true
type: string
default: "9"
default: "7"
# push:
# branches:
# - master

View File

@@ -67,6 +67,7 @@ jobs:
mkdir update
cp -r ComfyUI/.ci/update_windows/* ./update/
cp -r ComfyUI/.ci/windows_base_files/* ./
cp -r ComfyUI/.ci/windows_nightly_base_files/* ./
echo "call update_comfyui.bat nopause
..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2

View File

@@ -13,13 +13,13 @@ on:
description: 'python minor version'
required: true
type: string
default: "11"
default: "12"
python_patch:
description: 'python patch version'
required: true
type: string
default: "9"
default: "7"
# push:
# branches:
# - master

4
.gitignore vendored
View File

@@ -12,10 +12,12 @@ extra_model_paths.yaml
.vscode/
.idea/
venv/
.venv/
/web/extensions/*
!/web/extensions/logging.js.example
!/web/extensions/core/
/tests-ui/data/object_info.json
/user/
*.log
web_custom_versions/
web_custom_versions/
.DS_Store

View File

@@ -1,3 +0,0 @@
[MESSAGES CONTROL]
disable=all
enable=eval-used

View File

@@ -1 +1,23 @@
* @comfyanonymous
# Admins
* @comfyanonymous
# Note: Github teams syntax cannot be used here as the repo is not owned by Comfy-Org.
# Inlined the team members for now.
# Maintainers
*.md @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
/tests/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
/tests-unit/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
/notebooks/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
/script_examples/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
/.github/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
# Python web server
/api_server/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
# Frontend assets
/web/ @huchenlei @webfiltered @pythongosssss
# Extra nodes
/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink

155
README.md
View File

@@ -1,8 +1,35 @@
ComfyUI
=======
The most powerful and modular stable diffusion GUI and backend.
-----------
![ComfyUI Screenshot](comfyui_screenshot.png)
<div align="center">
# ComfyUI
**The most powerful and modular diffusion model GUI and backend.**
[![Website][website-shield]][website-url]
[![Dynamic JSON Badge][discord-shield]][discord-url]
[![Matrix][matrix-shield]][matrix-url]
<br>
[![][github-release-shield]][github-release-link]
[![][github-release-date-shield]][github-release-link]
[![][github-downloads-shield]][github-downloads-link]
[![][github-downloads-latest-shield]][github-downloads-link]
[matrix-shield]: https://img.shields.io/badge/Matrix-000000?style=flat&logo=matrix&logoColor=white
[matrix-url]: https://app.element.io/#/room/%23comfyui_space%3Amatrix.org
[website-shield]: https://img.shields.io/badge/ComfyOrg-4285F4?style=flat
[website-url]: https://www.comfy.org/
<!-- Workaround to display total user from https://github.com/badges/shields/issues/4500#issuecomment-2060079995 -->
[discord-shield]: https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fdiscord.com%2Fapi%2Finvites%2Fcomfyorg%3Fwith_counts%3Dtrue&query=%24.approximate_member_count&logo=discord&logoColor=white&label=Discord&color=green&suffix=%20total
[discord-url]: https://www.comfy.org/discord
[github-release-shield]: https://img.shields.io/github/v/release/comfyanonymous/ComfyUI?style=flat&sort=semver
[github-release-link]: https://github.com/comfyanonymous/ComfyUI/releases
[github-release-date-shield]: https://img.shields.io/github/release-date/comfyanonymous/ComfyUI?style=flat
[github-downloads-shield]: https://img.shields.io/github/downloads/comfyanonymous/ComfyUI/total?style=flat
[github-downloads-latest-shield]: https://img.shields.io/github/downloads/comfyanonymous/ComfyUI/latest/total?style=flat&label=downloads%40latest
[github-downloads-link]: https://github.com/comfyanonymous/ComfyUI/releases
![ComfyUI Screenshot](https://github.com/user-attachments/assets/7ccaf2c1-9b72-41ae-9a89-5688c94b7abe)
</div>
This ui will let you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. For some workflow examples and see what ComfyUI can do you can check out:
### [ComfyUI Examples](https://comfyanonymous.github.io/ComfyUI_examples/)
@@ -12,7 +39,9 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
## Features
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/), [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/), [SD3](https://comfyanonymous.github.io/ComfyUI_examples/sd3/) and [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
- [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/)
- [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
- Asynchronous Queue system
- Many optimizations: Only re-executes the parts of the workflow that changes between executions.
- Smart memory management: can automatically run models on GPUs with as low as 1GB vram.
@@ -46,32 +75,39 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
| Keybind | Explanation |
|------------------------------------|--------------------------------------------------------------------------------------------------------------------|
| Ctrl + Enter | Queue up current graph for generation |
| Ctrl + Shift + Enter | Queue up current graph as first for generation |
| Ctrl + Z/Ctrl + Y | Undo/Redo |
| Ctrl + S | Save workflow |
| Ctrl + O | Load workflow |
| Ctrl + A | Select all nodes |
| Alt + C | Collapse/uncollapse selected nodes |
| Ctrl + M | Mute/unmute selected nodes |
| Ctrl + B | Bypass selected nodes (acts like the node was removed from the graph and the wires reconnected through) |
| Delete/Backspace | Delete selected nodes |
| Ctrl + Backspace | Delete the current graph |
| Space | Move the canvas around when held and moving the cursor |
| Ctrl/Shift + Click | Add clicked node to selection |
| Ctrl + C/Ctrl + V | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) |
| Ctrl + C/Ctrl + Shift + V | Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) |
| Shift + Drag | Move multiple selected nodes at the same time |
| Ctrl + D | Load default graph |
| Alt + `+` | Canvas Zoom in |
| Alt + `-` | Canvas Zoom out |
| Ctrl + Shift + LMB + Vertical drag | Canvas Zoom in/out |
| Q | Toggle visibility of the queue |
| H | Toggle visibility of history |
| R | Refresh graph |
| `Ctrl` + `Enter` | Queue up current graph for generation |
| `Ctrl` + `Shift` + `Enter` | Queue up current graph as first for generation |
| `Ctrl` + `Alt` + `Enter` | Cancel current generation |
| `Ctrl` + `Z`/`Ctrl` + `Y` | Undo/Redo |
| `Ctrl` + `S` | Save workflow |
| `Ctrl` + `O` | Load workflow |
| `Ctrl` + `A` | Select all nodes |
| `Alt `+ `C` | Collapse/uncollapse selected nodes |
| `Ctrl` + `M` | Mute/unmute selected nodes |
| `Ctrl` + `B` | Bypass selected nodes (acts like the node was removed from the graph and the wires reconnected through) |
| `Delete`/`Backspace` | Delete selected nodes |
| `Ctrl` + `Backspace` | Delete the current graph |
| `Space` | Move the canvas around when held and moving the cursor |
| `Ctrl`/`Shift` + `Click` | Add clicked node to selection |
| `Ctrl` + `C`/`Ctrl` + `V` | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) |
| `Ctrl` + `C`/`Ctrl` + `Shift` + `V` | Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) |
| `Shift` + `Drag` | Move multiple selected nodes at the same time |
| `Ctrl` + `D` | Load default graph |
| `Alt` + `+` | Canvas Zoom in |
| `Alt` + `-` | Canvas Zoom out |
| `Ctrl` + `Shift` + LMB + Vertical drag | Canvas Zoom in/out |
| `P` | Pin/Unpin selected nodes |
| `Ctrl` + `G` | Group selected nodes |
| `Q` | Toggle visibility of the queue |
| `H` | Toggle visibility of history |
| `R` | Refresh graph |
| `F` | Show/Hide menu |
| `.` | Fit view to selection (Whole graph when nothing is selected) |
| Double-Click LMB | Open node quick search palette |
| `Shift` + Drag | Move multiple wires at once |
| `Ctrl` + `Alt` + LMB | Disconnect all wires from clicked slot |
Ctrl can also be replaced with Cmd instead for macOS users
`Ctrl` can also be replaced with `Cmd` instead for macOS users
# Installing
@@ -95,6 +131,8 @@ To run it on services like paperspace, kaggle or colab you can use my [Jupyter N
## Manual Install (Windows, Linux)
Note that some dependencies do not yet support python 3.13 so using 3.12 is recommended.
Git clone this repo.
Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints
@@ -105,17 +143,17 @@ Put your VAE in: models/vae
### AMD GPUs (Linux only)
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.0```
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.2```
This is the command to install the nightly with ROCm 6.0 which might have some performance improvements:
This is the command to install the nightly with ROCm 6.2 which might have some performance improvements:
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.1```
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.2.4```
### NVIDIA
Nvidia users should install stable pytorch using this command:
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu121```
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu124```
This is the command to install pytorch nightly instead which might have performance improvements:
@@ -177,6 +215,14 @@ For 6700, 6600 and maybe other RDNA2 or older: ```HSA_OVERRIDE_GFX_VERSION=10.3.
For AMD 7600 and maybe other RDNA3 cards: ```HSA_OVERRIDE_GFX_VERSION=11.0.0 python main.py```
### AMD ROCm Tips
You can enable experimental memory efficient attention on pytorch 2.5 in ComfyUI on RDNA3 and potentially other AMD GPUs using this command:
```TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 python main.py --use-pytorch-cross-attention```
You can also try setting this env variable `PYTORCH_TUNABLEOP_ENABLED=1` which might speed things up at the cost of a very slow initial run.
# Notes
Only parts of the graph that have an output with all the correct inputs will be executed.
@@ -200,7 +246,7 @@ To use a textual inversion concepts/embeddings in a text prompt put them in the
Use ```--preview-method auto``` to enable previews.
The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth) (for SD1.x and SD2.x) and [taesdxl_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesdxl_decoder.pth) (for SDXL) models and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI to enable high-quality previews.
The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_decoder.pth, taesdxl_decoder.pth, taesd3_decoder.pth and taef1_decoder.pth](https://github.com/madebyollin/taesd/) and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI and launch it with `--preview-method taesd` to enable high-quality previews.
## How to use TLS/SSL?
Generate a self-signed certificate (not appropriate for shared/production use) and key by running the command: `openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem -sha256 -days 3650 -nodes -subj "/C=XX/ST=StateName/L=CityName/O=CompanyName/OU=CompanySectionName/CN=CommonNameOrHostname"`
@@ -216,6 +262,47 @@ Use `--tls-keyfile key.pem --tls-certfile cert.pem` to enable TLS/SSL, the app w
See also: [https://www.comfy.org/](https://www.comfy.org/)
## Frontend Development
As of August 15, 2024, we have transitioned to a new frontend, which is now hosted in a separate repository: [ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend). This repository now hosts the compiled JS (from TS/Vue) under the `web/` directory.
### Reporting Issues and Requesting Features
For any bugs, issues, or feature requests related to the frontend, please use the [ComfyUI Frontend repository](https://github.com/Comfy-Org/ComfyUI_frontend). This will help us manage and address frontend-specific concerns more efficiently.
### Using the Latest Frontend
The new frontend is now the default for ComfyUI. However, please note:
1. The frontend in the main ComfyUI repository is updated weekly.
2. Daily releases are available in the separate frontend repository.
To use the most up-to-date frontend version:
1. For the latest daily release, launch ComfyUI with this command line argument:
```
--front-end-version Comfy-Org/ComfyUI_frontend@latest
```
2. For a specific version, replace `latest` with the desired version number:
```
--front-end-version Comfy-Org/ComfyUI_frontend@1.2.2
```
This approach allows you to easily switch between the stable weekly release and the cutting-edge daily updates, or even specific versions for testing purposes.
### Accessing the Legacy Frontend
If you need to use the legacy frontend for any reason, you can access it using the following command line argument:
```
--front-end-version Comfy-Org/ComfyUI_legacy_frontend@latest
```
This will use a snapshot of the legacy frontend preserved in the [ComfyUI Legacy Frontend repository](https://github.com/Comfy-Org/ComfyUI_legacy_frontend).
# QA
### Which GPU should I buy for this?

0
api_server/__init__.py Normal file
View File

View File

View File

@@ -0,0 +1,3 @@
# ComfyUI Internal Routes
All routes under the `/internal` path are designated for **internal use by ComfyUI only**. These routes are not intended for use by external applications may change at any time without notice.

View File

View File

@@ -0,0 +1,75 @@
from aiohttp import web
from typing import Optional
from folder_paths import models_dir, user_directory, output_directory, folder_names_and_paths
from api_server.services.file_service import FileService
from api_server.services.terminal_service import TerminalService
import app.logger
class InternalRoutes:
'''
The top level web router for internal routes: /internal/*
The endpoints here should NOT be depended upon. It is for ComfyUI frontend use only.
Check README.md for more information.
'''
def __init__(self, prompt_server):
self.routes: web.RouteTableDef = web.RouteTableDef()
self._app: Optional[web.Application] = None
self.file_service = FileService({
"models": models_dir,
"user": user_directory,
"output": output_directory
})
self.prompt_server = prompt_server
self.terminal_service = TerminalService(prompt_server)
def setup_routes(self):
@self.routes.get('/files')
async def list_files(request):
directory_key = request.query.get('directory', '')
try:
file_list = self.file_service.list_files(directory_key)
return web.json_response({"files": file_list})
except ValueError as e:
return web.json_response({"error": str(e)}, status=400)
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
@self.routes.get('/logs')
async def get_logs(request):
return web.json_response("".join([(l["t"] + " - " + l["m"]) for l in app.logger.get_logs()]))
@self.routes.get('/logs/raw')
async def get_raw_logs(request):
self.terminal_service.update_size()
return web.json_response({
"entries": list(app.logger.get_logs()),
"size": {"cols": self.terminal_service.cols, "rows": self.terminal_service.rows}
})
@self.routes.patch('/logs/subscribe')
async def subscribe_logs(request):
json_data = await request.json()
client_id = json_data["clientId"]
enabled = json_data["enabled"]
if enabled:
self.terminal_service.subscribe(client_id)
else:
self.terminal_service.unsubscribe(client_id)
return web.Response(status=200)
@self.routes.get('/folder_paths')
async def get_folder_paths(request):
response = {}
for key in folder_names_and_paths:
response[key] = folder_names_and_paths[key][0]
return web.json_response(response)
def get_app(self):
if self._app is None:
self._app = web.Application()
self.setup_routes()
self._app.add_routes(self.routes)
return self._app

View File

View File

@@ -0,0 +1,13 @@
from typing import Dict, List, Optional
from api_server.utils.file_operations import FileSystemOperations, FileSystemItem
class FileService:
def __init__(self, allowed_directories: Dict[str, str], file_system_ops: Optional[FileSystemOperations] = None):
self.allowed_directories: Dict[str, str] = allowed_directories
self.file_system_ops: FileSystemOperations = file_system_ops or FileSystemOperations()
def list_files(self, directory_key: str) -> List[FileSystemItem]:
if directory_key not in self.allowed_directories:
raise ValueError("Invalid directory key")
directory_path: str = self.allowed_directories[directory_key]
return self.file_system_ops.walk_directory(directory_path)

View File

@@ -0,0 +1,60 @@
from app.logger import on_flush
import os
import shutil
class TerminalService:
def __init__(self, server):
self.server = server
self.cols = None
self.rows = None
self.subscriptions = set()
on_flush(self.send_messages)
def get_terminal_size(self):
try:
size = os.get_terminal_size()
return (size.columns, size.lines)
except OSError:
try:
size = shutil.get_terminal_size()
return (size.columns, size.lines)
except OSError:
return (80, 24) # fallback to 80x24
def update_size(self):
columns, lines = self.get_terminal_size()
changed = False
if columns != self.cols:
self.cols = columns
changed = True
if lines != self.rows:
self.rows = lines
changed = True
if changed:
return {"cols": self.cols, "rows": self.rows}
return None
def subscribe(self, client_id):
self.subscriptions.add(client_id)
def unsubscribe(self, client_id):
self.subscriptions.discard(client_id)
def send_messages(self, entries):
if not len(entries) or not len(self.subscriptions):
return
new_size = self.update_size()
for client_id in self.subscriptions.copy(): # prevent: Set changed size during iteration
if client_id not in self.server.sockets:
# Automatically unsub if the socket has disconnected
self.unsubscribe(client_id)
continue
self.server.send_sync("logs", {"entries": entries, "size": new_size}, client_id)

View File

@@ -0,0 +1,42 @@
import os
from typing import List, Union, TypedDict, Literal
from typing_extensions import TypeGuard
class FileInfo(TypedDict):
name: str
path: str
type: Literal["file"]
size: int
class DirectoryInfo(TypedDict):
name: str
path: str
type: Literal["directory"]
FileSystemItem = Union[FileInfo, DirectoryInfo]
def is_file_info(item: FileSystemItem) -> TypeGuard[FileInfo]:
return item["type"] == "file"
class FileSystemOperations:
@staticmethod
def walk_directory(directory: str) -> List[FileSystemItem]:
file_list: List[FileSystemItem] = []
for root, dirs, files in os.walk(directory):
for name in files:
file_path = os.path.join(root, name)
relative_path = os.path.relpath(file_path, directory)
file_list.append({
"name": name,
"path": relative_path,
"type": "file",
"size": os.path.getsize(file_path)
})
for name in dirs:
dir_path = os.path.join(root, name)
relative_path = os.path.relpath(dir_path, directory)
file_list.append({
"name": name,
"path": relative_path,
"type": "directory"
})
return file_list

View File

@@ -8,7 +8,7 @@ import zipfile
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
from typing import TypedDict
from typing import TypedDict, Optional
import requests
from typing_extensions import NotRequired
@@ -132,12 +132,13 @@ class FrontendManager:
return match_result.group(1), match_result.group(2), match_result.group(3)
@classmethod
def init_frontend_unsafe(cls, version_string: str) -> str:
def init_frontend_unsafe(cls, version_string: str, provider: Optional[FrontEndProvider] = None) -> str:
"""
Initializes the frontend for the specified version.
Args:
version_string (str): The version string.
provider (FrontEndProvider, optional): The provider to use. Defaults to None.
Returns:
str: The path to the initialized frontend.
@@ -150,7 +151,16 @@ class FrontendManager:
return cls.DEFAULT_FRONTEND_PATH
repo_owner, repo_name, version = cls.parse_version_string(version_string)
provider = FrontEndProvider(repo_owner, repo_name)
if version.startswith("v"):
expected_path = str(Path(cls.CUSTOM_FRONTENDS_ROOT) / f"{repo_owner}_{repo_name}" / version.lstrip("v"))
if os.path.exists(expected_path):
logging.info(f"Using existing copy of specific frontend version tag: {repo_owner}/{repo_name}@{version}")
return expected_path
logging.info(f"Initializing frontend: {repo_owner}/{repo_name}@{version}, requesting version details from GitHub...")
provider = provider or FrontEndProvider(repo_owner, repo_name)
release = provider.get_release(version)
semantic_version = release["tag_name"].lstrip("v")
@@ -158,15 +168,21 @@ class FrontendManager:
Path(cls.CUSTOM_FRONTENDS_ROOT) / provider.folder_name / semantic_version
)
if not os.path.exists(web_root):
os.makedirs(web_root, exist_ok=True)
logging.info(
"Downloading frontend(%s) version(%s) to (%s)",
provider.folder_name,
semantic_version,
web_root,
)
logging.debug(release)
download_release_asset_zip(release, destination_path=web_root)
try:
os.makedirs(web_root, exist_ok=True)
logging.info(
"Downloading frontend(%s) version(%s) to (%s)",
provider.folder_name,
semantic_version,
web_root,
)
logging.debug(release)
download_release_asset_zip(release, destination_path=web_root)
finally:
# Clean up the directory if it is empty, i.e. the download failed
if not os.listdir(web_root):
os.rmdir(web_root)
return web_root
@classmethod

73
app/logger.py Normal file
View File

@@ -0,0 +1,73 @@
from collections import deque
from datetime import datetime
import io
import logging
import sys
import threading
logs = None
stdout_interceptor = None
stderr_interceptor = None
class LogInterceptor(io.TextIOWrapper):
def __init__(self, stream, *args, **kwargs):
buffer = stream.buffer
encoding = stream.encoding
super().__init__(buffer, *args, **kwargs, encoding=encoding, line_buffering=stream.line_buffering)
self._lock = threading.Lock()
self._flush_callbacks = []
self._logs_since_flush = []
def write(self, data):
entry = {"t": datetime.now().isoformat(), "m": data}
with self._lock:
self._logs_since_flush.append(entry)
# Simple handling for cr to overwrite the last output if it isnt a full line
# else logs just get full of progress messages
if isinstance(data, str) and data.startswith("\r") and not logs[-1]["m"].endswith("\n"):
logs.pop()
logs.append(entry)
super().write(data)
def flush(self):
super().flush()
for cb in self._flush_callbacks:
cb(self._logs_since_flush)
self._logs_since_flush = []
def on_flush(self, callback):
self._flush_callbacks.append(callback)
def get_logs():
return logs
def on_flush(callback):
if stdout_interceptor is not None:
stdout_interceptor.on_flush(callback)
if stderr_interceptor is not None:
stderr_interceptor.on_flush(callback)
def setup_logger(log_level: str = 'INFO', capacity: int = 300):
global logs
if logs:
return
# Override output streams and log to buffer
logs = deque(maxlen=capacity)
global stdout_interceptor
global stderr_interceptor
stdout_interceptor = sys.stdout = LogInterceptor(sys.stdout)
stderr_interceptor = sys.stderr = LogInterceptor(sys.stderr)
# Setup default global logger
logger = logging.getLogger()
logger.setLevel(log_level)
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(logging.Formatter("%(message)s"))
logger.addHandler(stream_handler)

184
app/model_manager.py Normal file
View File

@@ -0,0 +1,184 @@
from __future__ import annotations
import os
import base64
import json
import time
import logging
import folder_paths
import glob
import comfy.utils
from aiohttp import web
from PIL import Image
from io import BytesIO
from folder_paths import map_legacy, filter_files_extensions, filter_files_content_types
class ModelFileManager:
def __init__(self) -> None:
self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {}
def get_cache(self, key: str, default=None) -> tuple[list[dict], dict[str, float], float] | None:
return self.cache.get(key, default)
def set_cache(self, key: str, value: tuple[list[dict], dict[str, float], float]):
self.cache[key] = value
def clear_cache(self):
self.cache.clear()
def add_routes(self, routes):
# NOTE: This is an experiment to replace `/models`
@routes.get("/experiment/models")
async def get_model_folders(request):
model_types = list(folder_paths.folder_names_and_paths.keys())
folder_black_list = ["configs", "custom_nodes"]
output_folders: list[dict] = []
for folder in model_types:
if folder in folder_black_list:
continue
output_folders.append({"name": folder, "folders": folder_paths.get_folder_paths(folder)})
return web.json_response(output_folders)
# NOTE: This is an experiment to replace `/models/{folder}`
@routes.get("/experiment/models/{folder}")
async def get_all_models(request):
folder = request.match_info.get("folder", None)
if not folder in folder_paths.folder_names_and_paths:
return web.Response(status=404)
files = self.get_model_file_list(folder)
return web.json_response(files)
@routes.get("/experiment/models/preview/{folder}/{path_index}/{filename:.*}")
async def get_model_preview(request):
folder_name = request.match_info.get("folder", None)
path_index = int(request.match_info.get("path_index", None))
filename = request.match_info.get("filename", None)
if not folder_name in folder_paths.folder_names_and_paths:
return web.Response(status=404)
folders = folder_paths.folder_names_and_paths[folder_name]
folder = folders[0][path_index]
full_filename = os.path.join(folder, filename)
previews = self.get_model_previews(full_filename)
default_preview = previews[0] if len(previews) > 0 else None
if default_preview is None or (isinstance(default_preview, str) and not os.path.isfile(default_preview)):
return web.Response(status=404)
try:
with Image.open(default_preview) as img:
img_bytes = BytesIO()
img.save(img_bytes, format="WEBP")
img_bytes.seek(0)
return web.Response(body=img_bytes.getvalue(), content_type="image/webp")
except:
return web.Response(status=404)
def get_model_file_list(self, folder_name: str):
folder_name = map_legacy(folder_name)
folders = folder_paths.folder_names_and_paths[folder_name]
output_list: list[dict] = []
for index, folder in enumerate(folders[0]):
if not os.path.isdir(folder):
continue
out = self.cache_model_file_list_(folder)
if out is None:
out = self.recursive_search_models_(folder, index)
self.set_cache(folder, out)
output_list.extend(out[0])
return output_list
def cache_model_file_list_(self, folder: str):
model_file_list_cache = self.get_cache(folder)
if model_file_list_cache is None:
return None
if not os.path.isdir(folder):
return None
if os.path.getmtime(folder) != model_file_list_cache[1]:
return None
for x in model_file_list_cache[1]:
time_modified = model_file_list_cache[1][x]
folder = x
if os.path.getmtime(folder) != time_modified:
return None
return model_file_list_cache
def recursive_search_models_(self, directory: str, pathIndex: int) -> tuple[list[str], dict[str, float], float]:
if not os.path.isdir(directory):
return [], {}, time.perf_counter()
excluded_dir_names = [".git"]
# TODO use settings
include_hidden_files = False
result: list[str] = []
dirs: dict[str, float] = {}
for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
if not include_hidden_files:
subdirs[:] = [d for d in subdirs if not d.startswith(".")]
filenames = [f for f in filenames if not f.startswith(".")]
filenames = filter_files_extensions(filenames, folder_paths.supported_pt_extensions)
for file_name in filenames:
try:
relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
result.append(relative_path)
except:
logging.warning(f"Warning: Unable to access {file_name}. Skipping this file.")
continue
for d in subdirs:
path: str = os.path.join(dirpath, d)
try:
dirs[path] = os.path.getmtime(path)
except FileNotFoundError:
logging.warning(f"Warning: Unable to access {path}. Skipping this path.")
continue
return [{"name": f, "pathIndex": pathIndex} for f in result], dirs, time.perf_counter()
def get_model_previews(self, filepath: str) -> list[str | BytesIO]:
dirname = os.path.dirname(filepath)
if not os.path.exists(dirname):
return []
basename = os.path.splitext(filepath)[0]
match_files = glob.glob(f"{basename}.*", recursive=False)
image_files = filter_files_content_types(match_files, "image")
safetensors_file = next(filter(lambda x: x.endswith(".safetensors"), match_files), None)
safetensors_metadata = {}
result: list[str | BytesIO] = []
for filename in image_files:
_basename = os.path.splitext(filename)[0]
if _basename == basename:
result.append(filename)
if _basename == f"{basename}.preview":
result.append(filename)
if safetensors_file:
safetensors_filepath = os.path.join(dirname, safetensors_file)
header = comfy.utils.safetensors_header(safetensors_filepath, max_size=8*1024*1024)
if header:
safetensors_metadata = json.loads(header)
safetensors_images = safetensors_metadata.get("__metadata__", {}).get("ssmd_cover_images", None)
if safetensors_images:
safetensors_images = json.loads(safetensors_images)
for image in safetensors_images:
result.append(BytesIO(base64.b64decode(image)))
return result
def __exit__(self, exc_type, exc_value, traceback):
self.clear_cache()

View File

@@ -1,38 +1,58 @@
from __future__ import annotations
import json
import os
import re
import uuid
import glob
import shutil
import logging
from aiohttp import web
from urllib import parse
from comfy.cli_args import args
from folder_paths import user_directory
import folder_paths
from .app_settings import AppSettings
from typing import TypedDict
default_user = "default"
users_file = os.path.join(user_directory, "users.json")
class FileInfo(TypedDict):
path: str
size: int
modified: int
def get_file_info(path: str, relative_to: str) -> FileInfo:
return {
"path": os.path.relpath(path, relative_to).replace(os.sep, '/'),
"size": os.path.getsize(path),
"modified": os.path.getmtime(path)
}
class UserManager():
def __init__(self):
global user_directory
user_directory = folder_paths.get_user_directory()
self.settings = AppSettings(self)
if not os.path.exists(user_directory):
os.mkdir(user_directory)
os.makedirs(user_directory, exist_ok=True)
if not args.multi_user:
print("****** User settings have been changed to be stored on the server instead of browser storage. ******")
print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")
if args.multi_user:
if os.path.isfile(users_file):
with open(users_file) as f:
if os.path.isfile(self.get_users_file()):
with open(self.get_users_file()) as f:
self.users = json.load(f)
else:
self.users = {}
else:
self.users = {"default": "default"}
def get_users_file(self):
return os.path.join(folder_paths.get_user_directory(), "users.json")
def get_request_user_id(self, request):
user = "default"
if args.multi_user and "comfy-user" in request.headers:
@@ -44,7 +64,7 @@ class UserManager():
return user
def get_request_user_filepath(self, request, file, type="userdata", create_dir=True):
global user_directory
user_directory = folder_paths.get_user_directory()
if type == "userdata":
root_dir = user_directory
@@ -59,6 +79,10 @@ class UserManager():
return None
if file is not None:
# Check if filename is url encoded
if "%" in file:
file = parse.unquote(file)
# prevent leaving /{type}/{user}
path = os.path.abspath(os.path.join(user_root, file))
if os.path.commonpath((user_root, path)) != user_root:
@@ -80,8 +104,7 @@ class UserManager():
self.users[user_id] = name
global users_file
with open(users_file, "w") as f:
with open(self.get_users_file(), "w") as f:
json.dump(self.users, f)
return user_id
@@ -112,25 +135,65 @@ class UserManager():
@routes.get("/userdata")
async def listuserdata(request):
"""
List user data files in a specified directory.
This endpoint allows listing files in a user's data directory, with options for recursion,
full file information, and path splitting.
Query Parameters:
- dir (required): The directory to list files from.
- recurse (optional): If "true", recursively list files in subdirectories.
- full_info (optional): If "true", return detailed file information (path, size, modified time).
- split (optional): If "true", split file paths into components (only applies when full_info is false).
Returns:
- 400: If 'dir' parameter is missing.
- 403: If the requested path is not allowed.
- 404: If the requested directory does not exist.
- 200: JSON response with the list of files or file information.
The response format depends on the query parameters:
- Default: List of relative file paths.
- full_info=true: List of dictionaries with file details.
- split=true (and full_info=false): List of lists, each containing path components.
"""
directory = request.rel_url.query.get('dir', '')
if not directory:
return web.Response(status=400)
return web.Response(status=400, text="Directory not provided")
path = self.get_request_user_filepath(request, directory)
if not path:
return web.Response(status=403)
return web.Response(status=403, text="Invalid directory")
if not os.path.exists(path):
return web.Response(status=404)
return web.Response(status=404, text="Directory not found")
recurse = request.rel_url.query.get('recurse', '').lower() == "true"
results = glob.glob(os.path.join(
glob.escape(path), '**/*'), recursive=recurse)
results = [os.path.relpath(x, path) for x in results if os.path.isfile(x)]
full_info = request.rel_url.query.get('full_info', '').lower() == "true"
split_path = request.rel_url.query.get('split', '').lower() == "true"
if split_path:
results = [[x] + x.split(os.sep) for x in results]
# Use different patterns based on whether we're recursing or not
if recurse:
pattern = os.path.join(glob.escape(path), '**', '*')
else:
pattern = os.path.join(glob.escape(path), '*')
def process_full_path(full_path: str) -> FileInfo | str | list[str]:
if full_info:
return get_file_info(full_path, path)
rel_path = os.path.relpath(full_path, path).replace(os.sep, '/')
if split_path:
return [rel_path] + rel_path.split('/')
return rel_path
results = [
process_full_path(full_path)
for full_path in glob.glob(pattern, recursive=recurse)
if os.path.isfile(full_path)
]
return web.json_response(results)
@@ -138,14 +201,14 @@ class UserManager():
file = request.match_info.get(param, None)
if not file:
return web.Response(status=400)
path = self.get_request_user_filepath(request, file)
if not path:
return web.Response(status=403)
if check_exists and not os.path.exists(path):
return web.Response(status=404)
return path
@routes.get("/userdata/{file}")
@@ -153,25 +216,56 @@ class UserManager():
path = get_user_data_path(request, check_exists=True)
if not isinstance(path, str):
return path
return web.FileResponse(path)
@routes.post("/userdata/{file}")
async def post_userdata(request):
"""
Upload or update a user data file.
This endpoint handles file uploads to a user's data directory, with options for
controlling overwrite behavior and response format.
Query Parameters:
- overwrite (optional): If "false", prevents overwriting existing files. Defaults to "true".
- full_info (optional): If "true", returns detailed file information (path, size, modified time).
If "false", returns only the relative file path.
Path Parameters:
- file: The target file path (URL encoded if necessary).
Returns:
- 400: If 'file' parameter is missing.
- 403: If the requested path is not allowed.
- 409: If overwrite=false and the file already exists.
- 200: JSON response with either:
- Full file information (if full_info=true)
- Relative file path (if full_info=false)
The request body should contain the raw file content to be written.
"""
path = get_user_data_path(request)
if not isinstance(path, str):
return path
overwrite = request.query["overwrite"] != "false"
overwrite = request.query.get("overwrite", 'true') != "false"
full_info = request.query.get('full_info', 'false').lower() == "true"
if not overwrite and os.path.exists(path):
return web.Response(status=409)
return web.Response(status=409, text="File already exists")
body = await request.read()
with open(path, "wb") as f:
f.write(body)
resp = os.path.relpath(path, self.get_request_user_filepath(request, None))
user_path = self.get_request_user_filepath(request, None)
if full_info:
resp = get_file_info(path, user_path)
else:
resp = os.path.relpath(path, user_path)
return web.json_response(resp)
@routes.delete("/userdata/{file}")
@@ -181,25 +275,56 @@ class UserManager():
return path
os.remove(path)
return web.Response(status=204)
@routes.post("/userdata/{file}/move/{dest}")
async def move_userdata(request):
"""
Move or rename a user data file.
This endpoint handles moving or renaming files within a user's data directory, with options for
controlling overwrite behavior and response format.
Path Parameters:
- file: The source file path (URL encoded if necessary)
- dest: The destination file path (URL encoded if necessary)
Query Parameters:
- overwrite (optional): If "false", prevents overwriting existing files. Defaults to "true".
- full_info (optional): If "true", returns detailed file information (path, size, modified time).
If "false", returns only the relative file path.
Returns:
- 400: If either 'file' or 'dest' parameter is missing
- 403: If either requested path is not allowed
- 404: If the source file does not exist
- 409: If overwrite=false and the destination file already exists
- 200: JSON response with either:
- Full file information (if full_info=true)
- Relative file path (if full_info=false)
"""
source = get_user_data_path(request, check_exists=True)
if not isinstance(source, str):
return source
dest = get_user_data_path(request, check_exists=False, param="dest")
if not isinstance(source, str):
return dest
overwrite = request.query["overwrite"] != "false"
if not overwrite and os.path.exists(dest):
return web.Response(status=409)
print(f"moving '{source}' -> '{dest}'")
overwrite = request.query.get("overwrite", 'true') != "false"
full_info = request.query.get('full_info', 'false').lower() == "true"
if not overwrite and os.path.exists(dest):
return web.Response(status=409, text="File already exists")
logging.info(f"moving '{source}' -> '{dest}'")
shutil.move(source, dest)
resp = os.path.relpath(dest, self.get_request_user_filepath(request, None))
user_path = self.get_request_user_filepath(request, None)
if full_info:
resp = get_file_info(dest, user_path)
else:
resp = os.path.relpath(dest, user_path)
return web.json_response(resp)

View File

@@ -2,11 +2,9 @@
#and modified
import torch
import torch as th
import torch.nn as nn
from ..ldm.modules.diffusionmodules.util import (
zero_module,
timestep_embedding,
)
@@ -415,7 +413,6 @@ class ControlNet(nn.Module):
out_output = []
out_middle = []
hs = []
if self.num_classes is not None:
assert y.shape[0] == x.shape[0]
emb = emb + self.label_emb(y)

120
comfy/cldm/dit_embedder.py Normal file
View File

@@ -0,0 +1,120 @@
import math
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
from torch import Tensor
from comfy.ldm.modules.diffusionmodules.mmdit import DismantledBlock, PatchEmbed, VectorEmbedder, TimestepEmbedder, get_2d_sincos_pos_embed_torch
class ControlNetEmbedder(nn.Module):
def __init__(
self,
img_size: int,
patch_size: int,
in_chans: int,
attention_head_dim: int,
num_attention_heads: int,
adm_in_channels: int,
num_layers: int,
main_model_double: int,
double_y_emb: bool,
device: torch.device,
dtype: torch.dtype,
pos_embed_max_size: Optional[int] = None,
operations = None,
):
super().__init__()
self.main_model_double = main_model_double
self.dtype = dtype
self.hidden_size = num_attention_heads * attention_head_dim
self.patch_size = patch_size
self.x_embedder = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=self.hidden_size,
strict_img_size=pos_embed_max_size is None,
device=device,
dtype=dtype,
operations=operations,
)
self.t_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype, device=device, operations=operations)
self.double_y_emb = double_y_emb
if self.double_y_emb:
self.orig_y_embedder = VectorEmbedder(
adm_in_channels, self.hidden_size, dtype, device, operations=operations
)
self.y_embedder = VectorEmbedder(
self.hidden_size, self.hidden_size, dtype, device, operations=operations
)
else:
self.y_embedder = VectorEmbedder(
adm_in_channels, self.hidden_size, dtype, device, operations=operations
)
self.transformer_blocks = nn.ModuleList(
DismantledBlock(
hidden_size=self.hidden_size, num_heads=num_attention_heads, qkv_bias=True,
dtype=dtype, device=device, operations=operations
)
for _ in range(num_layers)
)
# self.use_y_embedder = pooled_projection_dim != self.time_text_embed.text_embedder.linear_1.in_features
# TODO double check this logic when 8b
self.use_y_embedder = True
self.controlnet_blocks = nn.ModuleList([])
for _ in range(len(self.transformer_blocks)):
controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
self.controlnet_blocks.append(controlnet_block)
self.pos_embed_input = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=self.hidden_size,
strict_img_size=False,
device=device,
dtype=dtype,
operations=operations,
)
def forward(
self,
x: torch.Tensor,
timesteps: torch.Tensor,
y: Optional[torch.Tensor] = None,
context: Optional[torch.Tensor] = None,
hint = None,
) -> Tuple[Tensor, List[Tensor]]:
x_shape = list(x.shape)
x = self.x_embedder(x)
if not self.double_y_emb:
h = (x_shape[-2] + 1) // self.patch_size
w = (x_shape[-1] + 1) // self.patch_size
x += get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=x.device)
c = self.t_embedder(timesteps, dtype=x.dtype)
if y is not None and self.y_embedder is not None:
if self.double_y_emb:
y = self.orig_y_embedder(y)
y = self.y_embedder(y)
c = c + y
x = x + self.pos_embed_input(hint)
block_out = ()
repeat = math.ceil(self.main_model_double / len(self.transformer_blocks))
for i in range(len(self.transformer_blocks)):
out = self.transformer_blocks[i](x, c)
if not self.double_y_emb:
x = out
block_out += (self.controlnet_blocks[i](out),) * repeat
return {"output": block_out}

View File

@@ -1,11 +1,12 @@
import torch
from typing import Dict, Optional
from typing import Optional
import comfy.ldm.modules.diffusionmodules.mmdit
class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
def __init__(
self,
num_blocks = None,
control_latent_channels = None,
dtype = None,
device = None,
operations = None,
@@ -17,10 +18,13 @@ class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
for _ in range(len(self.joint_blocks)):
self.controlnet_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype))
if control_latent_channels is None:
control_latent_channels = self.in_channels
self.pos_embed_input = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(
None,
self.patch_size,
self.in_channels,
control_latent_channels,
self.hidden_size,
bias=True,
strict_img_size=False,

View File

@@ -36,7 +36,7 @@ class EnumAction(argparse.Action):
parser = argparse.ArgumentParser()
parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0,::", help="Specify the IP address to listen on (default: 127.0.0.1). You can give a list of ip addresses by separating them with a comma like: 127.2.2.2,127.3.3.3 If --listen is provided without an argument, it defaults to 0.0.0.0,:: (listens on all ipv4 and ipv6)")
parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
parser.add_argument("--tls-keyfile", type=str, help="Path to TLS (SSL) key file. Enables TLS, makes app accessible at https://... requires --tls-certfile to function")
parser.add_argument("--tls-certfile", type=str, help="Path to TLS (SSL) certificate file. Enables TLS, makes app accessible at https://... requires --tls-keyfile to function")
@@ -60,8 +60,10 @@ fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If
fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
fpunet_group = parser.add_mutually_exclusive_group()
fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.")
fpunet_group.add_argument("--fp16-unet", action="store_true", help="Store unet weights in fp16.")
fpunet_group.add_argument("--fp32-unet", action="store_true", help="Run the diffusion model in fp32.")
fpunet_group.add_argument("--fp64-unet", action="store_true", help="Run the diffusion model in fp64.")
fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the diffusion model in bf16.")
fpunet_group.add_argument("--fp16-unet", action="store_true", help="Run the diffusion model in fp16")
fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.")
fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")
@@ -92,10 +94,17 @@ class LatentPreviewMethod(enum.Enum):
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
cache_group = parser.add_mutually_exclusive_group()
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.")
parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
@@ -112,10 +121,14 @@ vram_group.add_argument("--lowvram", action="store_true", help="Split the unet i
vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reverved depending on your OS.")
parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
parser.add_argument("--fast", action="store_true", help="Enable some untested and potentially quality deteriorating optimizations.")
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
@@ -126,7 +139,7 @@ parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Dis
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
parser.add_argument("--verbose", action="store_true", help="Enables more debug prints.")
parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
# The default built-in provider hosted under web/
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
@@ -161,6 +174,8 @@ parser.add_argument(
help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.",
)
parser.add_argument("--user-directory", type=is_valid_directory, default=None, help="Set the ComfyUI user directory with an absolute path.")
if comfy.options.args_parsing:
args = parser.parse_args()
else:
@@ -171,10 +186,3 @@ if args.windows_standalone_build:
if args.disable_auto_launch:
args.auto_launch = False
import logging
logging_level = logging.INFO
if args.verbose:
logging_level = logging.DEBUG
logging.basicConfig(format="%(message)s", level=logging_level)

View File

@@ -23,6 +23,7 @@ class CLIPAttention(torch.nn.Module):
ACTIVATIONS = {"quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
"gelu": torch.nn.functional.gelu,
"gelu_pytorch_tanh": lambda a: torch.nn.functional.gelu(a, approximate="tanh"),
}
class CLIPMLP(torch.nn.Module):
@@ -88,10 +89,11 @@ class CLIPTextModel_(torch.nn.Module):
heads = config_dict["num_attention_heads"]
intermediate_size = config_dict["intermediate_size"]
intermediate_activation = config_dict["hidden_act"]
num_positions = config_dict["max_position_embeddings"]
self.eos_token_id = config_dict["eos_token_id"]
super().__init__()
self.embeddings = CLIPEmbeddings(embed_dim, dtype=dtype, device=device, operations=operations)
self.embeddings = CLIPEmbeddings(embed_dim, num_positions=num_positions, dtype=dtype, device=device, operations=operations)
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
@@ -123,7 +125,6 @@ class CLIPTextModel(torch.nn.Module):
self.text_model = CLIPTextModel_(config_dict, dtype, device, operations)
embed_dim = config_dict["hidden_size"]
self.text_projection = operations.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
self.text_projection.weight.copy_(torch.eye(embed_dim))
self.dtype = dtype
def get_input_embeddings(self):
@@ -139,27 +140,35 @@ class CLIPTextModel(torch.nn.Module):
class CLIPVisionEmbeddings(torch.nn.Module):
def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, dtype=None, device=None, operations=None):
def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, model_type="", dtype=None, device=None, operations=None):
super().__init__()
self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device))
num_patches = (image_size // patch_size) ** 2
if model_type == "siglip_vision_model":
self.class_embedding = None
patch_bias = True
else:
num_patches = num_patches + 1
self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device))
patch_bias = False
self.patch_embedding = operations.Conv2d(
in_channels=num_channels,
out_channels=embed_dim,
kernel_size=patch_size,
stride=patch_size,
bias=False,
bias=patch_bias,
dtype=dtype,
device=device
)
num_patches = (image_size // patch_size) ** 2
num_positions = num_patches + 1
self.position_embedding = operations.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
self.position_embedding = operations.Embedding(num_patches, embed_dim, dtype=dtype, device=device)
def forward(self, pixel_values):
embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)
return torch.cat([comfy.ops.cast_to_input(self.class_embedding, embeds).expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + comfy.ops.cast_to_input(self.position_embedding.weight, embeds)
if self.class_embedding is not None:
embeds = torch.cat([comfy.ops.cast_to_input(self.class_embedding, embeds).expand(pixel_values.shape[0], 1, -1), embeds], dim=1)
return embeds + comfy.ops.cast_to_input(self.position_embedding.weight, embeds)
class CLIPVision(torch.nn.Module):
@@ -170,9 +179,15 @@ class CLIPVision(torch.nn.Module):
heads = config_dict["num_attention_heads"]
intermediate_size = config_dict["intermediate_size"]
intermediate_activation = config_dict["hidden_act"]
model_type = config_dict["model_type"]
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], dtype=dtype, device=device, operations=operations)
self.pre_layrnorm = operations.LayerNorm(embed_dim)
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations)
if model_type == "siglip_vision_model":
self.pre_layrnorm = lambda a: a
self.output_layernorm = True
else:
self.pre_layrnorm = operations.LayerNorm(embed_dim)
self.output_layernorm = False
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
self.post_layernorm = operations.LayerNorm(embed_dim)
@@ -181,14 +196,21 @@ class CLIPVision(torch.nn.Module):
x = self.pre_layrnorm(x)
#TODO: attention_mask?
x, i = self.encoder(x, mask=None, intermediate_output=intermediate_output)
pooled_output = self.post_layernorm(x[:, 0, :])
if self.output_layernorm:
x = self.post_layernorm(x)
pooled_output = x
else:
pooled_output = self.post_layernorm(x[:, 0, :])
return x, i, pooled_output
class CLIPVisionModelProjection(torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
self.vision_model = CLIPVision(config_dict, dtype, device, operations)
self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False)
if "projection_dim" in config_dict:
self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False)
else:
self.visual_projection = lambda a: a
def forward(self, *args, **kwargs):
x = self.vision_model(*args, **kwargs)

View File

@@ -16,13 +16,18 @@ class Output:
def __setitem__(self, key, item):
setattr(self, key, item)
def clip_preprocess(image, size=224):
mean = torch.tensor([ 0.48145466,0.4578275,0.40821073], device=image.device, dtype=image.dtype)
std = torch.tensor([0.26862954,0.26130258,0.27577711], device=image.device, dtype=image.dtype)
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
std = torch.tensor(std, device=image.device, dtype=image.dtype)
image = image.movedim(-1, 1)
if not (image.shape[2] == size and image.shape[3] == size):
scale = (size / min(image.shape[2], image.shape[3]))
image = torch.nn.functional.interpolate(image, size=(round(scale * image.shape[2]), round(scale * image.shape[3])), mode="bicubic", antialias=True)
if crop:
scale = (size / min(image.shape[2], image.shape[3]))
scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3]))
else:
scale_size = (size, size)
image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
h = (image.shape[2] - size)//2
w = (image.shape[3] - size)//2
image = image[:,:,h:h+size,w:w+size]
@@ -35,6 +40,8 @@ class ClipVisionModel():
config = json.load(f)
self.image_size = config.get("image_size", 224)
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
self.load_device = comfy.model_management.text_encoder_device()
offload_device = comfy.model_management.text_encoder_offload_device()
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
@@ -49,9 +56,9 @@ class ClipVisionModel():
def get_sd(self):
return self.model.state_dict()
def encode_image(self, image):
def encode_image(self, image, crop=True):
comfy.model_management.load_model_gpu(self.patcher)
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size).float()
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
out = self.model(pixel_values=pixel_values, intermediate_output=-2)
outputs = Output()
@@ -94,7 +101,9 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
if sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
elif sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
else:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
@@ -109,8 +118,7 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
keys = list(sd.keys())
for k in keys:
if k not in u:
t = sd.pop(k)
del t
sd.pop(k)
return clip
def load(ckpt_path):

View File

@@ -0,0 +1,13 @@
{
"num_channels": 3,
"hidden_act": "gelu_pytorch_tanh",
"hidden_size": 1152,
"image_size": 384,
"intermediate_size": 4304,
"model_type": "siglip_vision_model",
"num_attention_heads": 16,
"num_hidden_layers": 27,
"patch_size": 14,
"image_mean": [0.5, 0.5, 0.5],
"image_std": [0.5, 0.5, 0.5]
}

View File

@@ -0,0 +1,43 @@
# Comfy Typing
## Type hinting for ComfyUI Node development
This module provides type hinting and concrete convenience types for node developers.
If cloned to the custom_nodes directory of ComfyUI, types can be imported using:
```python
from comfy_types import IO, ComfyNodeABC, CheckLazyMixin
class ExampleNode(ComfyNodeABC):
@classmethod
def INPUT_TYPES(s) -> InputTypeDict:
return {"required": {}}
```
Full example is in [examples/example_nodes.py](examples/example_nodes.py).
# Types
A few primary types are documented below. More complete information is available via the docstrings on each type.
## `IO`
A string enum of built-in and a few custom data types. Includes the following special types and their requisite plumbing:
- `ANY`: `"*"`
- `NUMBER`: `"FLOAT,INT"`
- `PRIMITIVE`: `"STRING,FLOAT,INT,BOOLEAN"`
## `ComfyNodeABC`
An abstract base class for nodes, offering type-hinting / autocomplete, and somewhat-alright docstrings.
### Type hinting for `INPUT_TYPES`
![INPUT_TYPES auto-completion in Visual Studio Code](examples/input_types.png)
### `INPUT_TYPES` return dict
![INPUT_TYPES return value type hinting in Visual Studio Code](examples/required_hint.png)
### Options for individual inputs
![INPUT_TYPES return value option auto-completion in Visual Studio Code](examples/input_options.png)

View File

@@ -1,5 +1,6 @@
import torch
from typing import Callable, Protocol, TypedDict, Optional, List
from .node_typing import IO, InputTypeDict, ComfyNodeABC, CheckLazyMixin
class UnetApplyFunction(Protocol):
@@ -30,3 +31,15 @@ class UnetParams(TypedDict):
UnetWrapperFunction = Callable[[UnetApplyFunction, UnetParams], torch.Tensor]
__all__ = [
"UnetWrapperFunction",
UnetApplyConds.__name__,
UnetParams.__name__,
UnetApplyFunction.__name__,
IO.__name__,
InputTypeDict.__name__,
ComfyNodeABC.__name__,
CheckLazyMixin.__name__,
]

View File

@@ -0,0 +1,28 @@
from comfy_types import IO, ComfyNodeABC, InputTypeDict
from inspect import cleandoc
class ExampleNode(ComfyNodeABC):
"""An example node that just adds 1 to an input integer.
* Requires an IDE configured with analysis paths etc to be worth looking at.
* Not intended for use in ComfyUI.
"""
DESCRIPTION = cleandoc(__doc__)
CATEGORY = "examples"
@classmethod
def INPUT_TYPES(s) -> InputTypeDict:
return {
"required": {
"input_int": (IO.INT, {"defaultInput": True}),
}
}
RETURN_TYPES = (IO.INT,)
RETURN_NAMES = ("input_plus_one",)
FUNCTION = "execute"
def execute(self, input_int: int):
return (input_int + 1,)

Binary file not shown.

After

Width:  |  Height:  |  Size: 19 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 16 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 19 KiB

View File

@@ -0,0 +1,274 @@
"""Comfy-specific type hinting"""
from __future__ import annotations
from typing import Literal, TypedDict
from abc import ABC, abstractmethod
from enum import Enum
class StrEnum(str, Enum):
"""Base class for string enums. Python's StrEnum is not available until 3.11."""
def __str__(self) -> str:
return self.value
class IO(StrEnum):
"""Node input/output data types.
Includes functionality for ``"*"`` (`ANY`) and ``"MULTI,TYPES"``.
"""
STRING = "STRING"
IMAGE = "IMAGE"
MASK = "MASK"
LATENT = "LATENT"
BOOLEAN = "BOOLEAN"
INT = "INT"
FLOAT = "FLOAT"
CONDITIONING = "CONDITIONING"
SAMPLER = "SAMPLER"
SIGMAS = "SIGMAS"
GUIDER = "GUIDER"
NOISE = "NOISE"
CLIP = "CLIP"
CONTROL_NET = "CONTROL_NET"
VAE = "VAE"
MODEL = "MODEL"
CLIP_VISION = "CLIP_VISION"
CLIP_VISION_OUTPUT = "CLIP_VISION_OUTPUT"
STYLE_MODEL = "STYLE_MODEL"
GLIGEN = "GLIGEN"
UPSCALE_MODEL = "UPSCALE_MODEL"
AUDIO = "AUDIO"
WEBCAM = "WEBCAM"
POINT = "POINT"
FACE_ANALYSIS = "FACE_ANALYSIS"
BBOX = "BBOX"
SEGS = "SEGS"
ANY = "*"
"""Always matches any type, but at a price.
Causes some functionality issues (e.g. reroutes, link types), and should be avoided whenever possible.
"""
NUMBER = "FLOAT,INT"
"""A float or an int - could be either"""
PRIMITIVE = "STRING,FLOAT,INT,BOOLEAN"
"""Could be any of: string, float, int, or bool"""
def __ne__(self, value: object) -> bool:
if self == "*" or value == "*":
return False
if not isinstance(value, str):
return True
a = frozenset(self.split(","))
b = frozenset(value.split(","))
return not (b.issubset(a) or a.issubset(b))
class InputTypeOptions(TypedDict):
"""Provides type hinting for the return type of the INPUT_TYPES node function.
Due to IDE limitations with unions, for now all options are available for all types (e.g. `label_on` is hinted even when the type is not `IO.BOOLEAN`).
Comfy Docs: https://docs.comfy.org/essentials/custom_node_datatypes
"""
default: bool | str | float | int | list | tuple
"""The default value of the widget"""
defaultInput: bool
"""Defaults to an input slot rather than a widget"""
forceInput: bool
"""`defaultInput` and also don't allow converting to a widget"""
lazy: bool
"""Declares that this input uses lazy evaluation"""
rawLink: bool
"""When a link exists, rather than receiving the evaluated value, you will receive the link (i.e. `["nodeId", <outputIndex>]`). Designed for node expansion."""
tooltip: str
"""Tooltip for the input (or widget), shown on pointer hover"""
# class InputTypeNumber(InputTypeOptions):
# default: float | int
min: float
"""The minimum value of a number (``FLOAT`` | ``INT``)"""
max: float
"""The maximum value of a number (``FLOAT`` | ``INT``)"""
step: float
"""The amount to increment or decrement a widget by when stepping up/down (``FLOAT`` | ``INT``)"""
round: float
"""Floats are rounded by this value (``FLOAT``)"""
# class InputTypeBoolean(InputTypeOptions):
# default: bool
label_on: str
"""The label to use in the UI when the bool is True (``BOOLEAN``)"""
label_on: str
"""The label to use in the UI when the bool is False (``BOOLEAN``)"""
# class InputTypeString(InputTypeOptions):
# default: str
multiline: bool
"""Use a multiline text box (``STRING``)"""
placeholder: str
"""Placeholder text to display in the UI when empty (``STRING``)"""
# Deprecated:
# defaultVal: str
dynamicPrompts: bool
"""Causes the front-end to evaluate dynamic prompts (``STRING``)"""
class HiddenInputTypeDict(TypedDict):
"""Provides type hinting for the hidden entry of node INPUT_TYPES."""
node_id: Literal["UNIQUE_ID"]
"""UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages)."""
unique_id: Literal["UNIQUE_ID"]
"""UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages)."""
prompt: Literal["PROMPT"]
"""PROMPT is the complete prompt sent by the client to the server. See the prompt object for a full description."""
extra_pnginfo: Literal["EXTRA_PNGINFO"]
"""EXTRA_PNGINFO is a dictionary that will be copied into the metadata of any .png files saved. Custom nodes can store additional information in this dictionary for saving (or as a way to communicate with a downstream node)."""
dynprompt: Literal["DYNPROMPT"]
"""DYNPROMPT is an instance of comfy_execution.graph.DynamicPrompt. It differs from PROMPT in that it may mutate during the course of execution in response to Node Expansion."""
class InputTypeDict(TypedDict):
"""Provides type hinting for node INPUT_TYPES.
Comfy Docs: https://docs.comfy.org/essentials/custom_node_more_on_inputs
"""
required: dict[str, tuple[IO, InputTypeOptions]]
"""Describes all inputs that must be connected for the node to execute."""
optional: dict[str, tuple[IO, InputTypeOptions]]
"""Describes inputs which do not need to be connected."""
hidden: HiddenInputTypeDict
"""Offers advanced functionality and server-client communication.
Comfy Docs: https://docs.comfy.org/essentials/custom_node_more_on_inputs#hidden-inputs
"""
class ComfyNodeABC(ABC):
"""Abstract base class for Comfy nodes. Includes the names and expected types of attributes.
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview
"""
DESCRIPTION: str
"""Node description, shown as a tooltip when hovering over the node.
Usage::
# Explicitly define the description
DESCRIPTION = "Example description here."
# Use the docstring of the node class.
DESCRIPTION = cleandoc(__doc__)
"""
CATEGORY: str
"""The category of the node, as per the "Add Node" menu.
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#category
"""
EXPERIMENTAL: bool
"""Flags a node as experimental, informing users that it may change or not work as expected."""
DEPRECATED: bool
"""Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
@classmethod
@abstractmethod
def INPUT_TYPES(s) -> InputTypeDict:
"""Defines node inputs.
* Must include the ``required`` key, which describes all inputs that must be connected for the node to execute.
* The ``optional`` key can be added to describe inputs which do not need to be connected.
* The ``hidden`` key offers some advanced functionality. More info at: https://docs.comfy.org/essentials/custom_node_more_on_inputs#hidden-inputs
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#input-types
"""
return {"required": {}}
OUTPUT_NODE: bool
"""Flags this node as an output node, causing any inputs it requires to be executed.
If a node is not connected to any output nodes, that node will not be executed. Usage::
OUTPUT_NODE = True
From the docs:
By default, a node is not considered an output. Set ``OUTPUT_NODE = True`` to specify that it is.
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#output-node
"""
INPUT_IS_LIST: bool
"""A flag indicating if this node implements the additional code necessary to deal with OUTPUT_IS_LIST nodes.
All inputs of ``type`` will become ``list[type]``, regardless of how many items are passed in. This also affects ``check_lazy_status``.
From the docs:
A node can also override the default input behaviour and receive the whole list in a single call. This is done by setting a class attribute `INPUT_IS_LIST` to ``True``.
Comfy Docs: https://docs.comfy.org/essentials/custom_node_lists#list-processing
"""
OUTPUT_IS_LIST: tuple[bool]
"""A tuple indicating which node outputs are lists, but will be connected to nodes that expect individual items.
Connected nodes that do not implement `INPUT_IS_LIST` will be executed once for every item in the list.
A ``tuple[bool]``, where the items match those in `RETURN_TYPES`::
RETURN_TYPES = (IO.INT, IO.INT, IO.STRING)
OUTPUT_IS_LIST = (True, True, False) # The string output will be handled normally
From the docs:
In order to tell Comfy that the list being returned should not be wrapped, but treated as a series of data for sequential processing,
the node should provide a class attribute `OUTPUT_IS_LIST`, which is a ``tuple[bool]``, of the same length as `RETURN_TYPES`,
specifying which outputs which should be so treated.
Comfy Docs: https://docs.comfy.org/essentials/custom_node_lists#list-processing
"""
RETURN_TYPES: tuple[IO]
"""A tuple representing the outputs of this node.
Usage::
RETURN_TYPES = (IO.INT, "INT", "CUSTOM_TYPE")
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#return-types
"""
RETURN_NAMES: tuple[str]
"""The output slot names for each item in `RETURN_TYPES`, e.g. ``RETURN_NAMES = ("count", "filter_string")``
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#return-names
"""
OUTPUT_TOOLTIPS: tuple[str]
"""A tuple of strings to use as tooltips for node outputs, one for each item in `RETURN_TYPES`."""
FUNCTION: str
"""The name of the function to execute as a literal string, e.g. `FUNCTION = "execute"`
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#function
"""
class CheckLazyMixin:
"""Provides a basic check_lazy_status implementation and type hinting for nodes that use lazy inputs."""
def check_lazy_status(self, **kwargs) -> list[str]:
"""Returns a list of input names that should be evaluated.
This basic mixin impl. requires all inputs.
:kwargs: All node inputs will be included here. If the input is ``None``, it should be assumed that it has not yet been evaluated. \
When using ``INPUT_IS_LIST = True``, unevaluated will instead be ``(None,)``.
Params should match the nodes execution ``FUNCTION`` (self, and all inputs by name).
Will be executed repeatedly until it returns an empty list, or all requested items were already evaluated (and sent as params).
Comfy Docs: https://docs.comfy.org/essentials/custom_node_lazy_evaluation#defining-check-lazy-status
"""
need = [name for name in kwargs if kwargs[name] is None]
return need

View File

@@ -1,4 +1,24 @@
"""
This file is part of ComfyUI.
Copyright (C) 2024 Comfy
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import torch
from enum import Enum
import math
import os
import logging
@@ -13,6 +33,12 @@ import comfy.cldm.cldm
import comfy.t2i_adapter.adapter
import comfy.ldm.cascade.controlnet
import comfy.cldm.mmdit
import comfy.ldm.hydit.controlnet
import comfy.ldm.flux.controlnet
import comfy.cldm.dit_embedder
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from comfy.hooks import HookGroup
def broadcast_image_to(tensor, target_batch_size, batched_number):
@@ -33,8 +59,12 @@ def broadcast_image_to(tensor, target_batch_size, batched_number):
else:
return torch.cat([tensor] * batched_number, dim=0)
class StrengthType(Enum):
CONSTANT = 1
LINEAR_UP = 2
class ControlBase:
def __init__(self, device=None):
def __init__(self):
self.cond_hint_original = None
self.cond_hint = None
self.strength = 1.0
@@ -46,18 +76,26 @@ class ControlBase:
self.compression_ratio = 8
self.upscale_algorithm = 'nearest-exact'
self.extra_args = {}
if device is None:
device = comfy.model_management.get_torch_device()
self.device = device
self.previous_controlnet = None
self.extra_conds = []
self.strength_type = StrengthType.CONSTANT
self.concat_mask = False
self.extra_concat_orig = []
self.extra_concat = None
self.extra_hooks: HookGroup = None
self.preprocess_image = lambda a: a
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None):
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
self.cond_hint_original = cond_hint
self.strength = strength
self.timestep_percent_range = timestep_percent_range
if self.latent_format is not None:
if vae is None:
logging.warning("WARNING: no VAE provided to the controlnet apply node when this controlnet requires one.")
self.vae = vae
self.extra_concat_orig = extra_concat.copy()
if self.concat_mask and len(self.extra_concat_orig) == 0:
self.extra_concat_orig.append(torch.tensor([[[[1.0]]]]))
return self
def pre_run(self, model, percent_to_timestep_function):
@@ -72,9 +110,9 @@ class ControlBase:
def cleanup(self):
if self.previous_controlnet is not None:
self.previous_controlnet.cleanup()
if self.cond_hint is not None:
del self.cond_hint
self.cond_hint = None
self.cond_hint = None
self.extra_concat = None
self.timestep_range = None
def get_models(self):
@@ -82,6 +120,14 @@ class ControlBase:
if self.previous_controlnet is not None:
out += self.previous_controlnet.get_models()
return out
def get_extra_hooks(self):
out = []
if self.extra_hooks is not None:
out.append(self.extra_hooks)
if self.previous_controlnet is not None:
out += self.previous_controlnet.get_extra_hooks()
return out
def copy_to(self, c):
c.cond_hint_original = self.cond_hint_original
@@ -93,6 +139,12 @@ class ControlBase:
c.latent_format = self.latent_format
c.extra_args = self.extra_args.copy()
c.vae = self.vae
c.extra_conds = self.extra_conds.copy()
c.strength_type = self.strength_type
c.concat_mask = self.concat_mask
c.extra_concat_orig = self.extra_concat_orig.copy()
c.extra_hooks = self.extra_hooks.clone() if self.extra_hooks else None
c.preprocess_image = self.preprocess_image
def inference_memory_requirements(self, dtype):
if self.previous_controlnet is not None:
@@ -113,9 +165,12 @@ class ControlBase:
if x not in applied_to: #memory saving strategy, allow shared tensors and only apply strength to shared tensors once
applied_to.add(x)
x *= self.strength
if self.strength_type == StrengthType.CONSTANT:
x *= self.strength
elif self.strength_type == StrengthType.LINEAR_UP:
x *= (self.strength ** float(len(control_output) - i))
if x.dtype != output_dtype:
if output_dtype is not None and x.dtype != output_dtype:
x = x.to(output_dtype)
out[key].append(x)
@@ -142,8 +197,8 @@ class ControlBase:
class ControlNet(ControlBase):
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None):
super().__init__(device)
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False, preprocess_image=lambda a: a):
super().__init__()
self.control_model = control_model
self.load_device = load_device
if control_model is not None:
@@ -154,11 +209,15 @@ class ControlNet(ControlBase):
self.model_sampling_current = None
self.manual_cast_dtype = manual_cast_dtype
self.latent_format = latent_format
self.extra_conds += extra_conds
self.strength_type = strength_type
self.concat_mask = concat_mask
self.preprocess_image = preprocess_image
def get_control(self, x_noisy, t, cond, batched_number):
def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
control_prev = None
if self.previous_controlnet is not None:
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number, transformer_options)
if self.timestep_range is not None:
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
@@ -171,7 +230,6 @@ class ControlNet(ControlBase):
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype
output_dtype = x_noisy.dtype
if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
if self.cond_hint is not None:
del self.cond_hint
@@ -179,26 +237,41 @@ class ControlNet(ControlBase):
compression_ratio = self.compression_ratio
if self.vae is not None:
compression_ratio *= self.vae.downscale_ratio
else:
if self.latent_format is not None:
raise ValueError("This Controlnet needs a VAE but none was provided, please use a ControlNetApply node with a VAE input and connect it.")
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
self.cond_hint = self.preprocess_image(self.cond_hint)
if self.vae is not None:
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
self.cond_hint = self.vae.encode(self.cond_hint.movedim(1, -1))
comfy.model_management.load_models_gpu(loaded_models)
if self.latent_format is not None:
self.cond_hint = self.latent_format.process_in(self.cond_hint)
self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype)
if len(self.extra_concat_orig) > 0:
to_concat = []
for c in self.extra_concat_orig:
c = c.to(self.cond_hint.device)
c = comfy.utils.common_upscale(c, self.cond_hint.shape[3], self.cond_hint.shape[2], self.upscale_algorithm, "center")
to_concat.append(comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[0]))
self.cond_hint = torch.cat([self.cond_hint] + to_concat, dim=1)
self.cond_hint = self.cond_hint.to(device=x_noisy.device, dtype=dtype)
if x_noisy.shape[0] != self.cond_hint.shape[0]:
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
context = cond.get('crossattn_controlnet', cond['c_crossattn'])
y = cond.get('y', None)
if y is not None:
y = y.to(dtype)
extra = self.extra_args.copy()
for c in self.extra_conds:
temp = cond.get(c, None)
if temp is not None:
extra[c] = temp.to(dtype)
timestep = self.model_sampling_current.timestep(t)
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y, **self.extra_args)
return self.control_merge(control, control_prev, output_dtype)
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra)
return self.control_merge(control, control_prev, output_dtype=None)
def copy(self):
c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
@@ -224,7 +297,6 @@ class ControlLoraOps:
class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp):
def __init__(self, in_features: int, out_features: int, bias: bool = True,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.in_features = in_features
self.out_features = out_features
@@ -282,10 +354,11 @@ class ControlLoraOps:
class ControlLora(ControlNet):
def __init__(self, control_weights, global_average_pooling=False, device=None):
ControlBase.__init__(self, device)
def __init__(self, control_weights, global_average_pooling=False, model_options={}): #TODO? model_options
ControlBase.__init__(self)
self.control_weights = control_weights
self.global_average_pooling = global_average_pooling
self.extra_conds += ["y"]
def pre_run(self, model, percent_to_timestep_function):
super().pre_run(model, percent_to_timestep_function)
@@ -308,7 +381,6 @@ class ControlLora(ControlNet):
self.control_model.to(comfy.model_management.get_torch_device())
diffusion_model = model.diffusion_model
sd = diffusion_model.state_dict()
cm = self.control_model.state_dict()
for k in sd:
weight = sd[k]
@@ -338,43 +410,191 @@ class ControlLora(ControlNet):
def inference_memory_requirements(self, dtype):
return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
def load_controlnet_mmdit(sd):
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
model_config = comfy.model_detection.model_config_from_unet(new_sd, "", True)
num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
for k in sd:
new_sd[k] = sd[k]
def controlnet_config(sd, model_options={}):
model_config = comfy.model_detection.model_config_from_unet(sd, "", True)
supported_inference_dtypes = model_config.supported_inference_dtypes
unet_dtype = model_options.get("dtype", None)
if unet_dtype is None:
weight_dtype = comfy.utils.weight_dtype(sd)
supported_inference_dtypes = list(model_config.supported_inference_dtypes)
if weight_dtype is not None:
supported_inference_dtypes.append(weight_dtype)
unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
controlnet_config = model_config.unet_config
unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
load_device = comfy.model_management.get_torch_device()
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
if manual_cast_dtype is not None:
operations = comfy.ops.manual_cast
else:
operations = comfy.ops.disable_weight_init
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **controlnet_config)
missing, unexpected = control_model.load_state_dict(new_sd, strict=False)
operations = model_options.get("custom_operations", None)
if operations is None:
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
offload_device = comfy.model_management.unet_offload_device()
return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device
def controlnet_load_state_dict(control_model, sd):
missing, unexpected = control_model.load_state_dict(sd, strict=False)
if len(missing) > 0:
logging.warning("missing controlnet keys: {}".format(missing))
if len(unexpected) > 0:
logging.debug("unexpected controlnet keys: {}".format(unexpected))
return control_model
def load_controlnet_mmdit(sd, model_options={}):
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options)
num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
for k in sd:
new_sd[k] = sd[k]
concat_mask = False
control_latent_channels = new_sd.get("pos_embed_input.proj.weight").shape[1]
if control_latent_channels == 17: #inpaint controlnet
concat_mask = True
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, control_latent_channels=control_latent_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
control_model = controlnet_load_state_dict(control_model, new_sd)
latent_format = comfy.latent_formats.SD3()
latent_format.shift_factor = 0 #SD3 controlnet weirdness
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
return control
def load_controlnet(ckpt_path, model=None):
controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
class ControlNetSD35(ControlNet):
def pre_run(self, model, percent_to_timestep_function):
if self.control_model.double_y_emb:
missing, unexpected = self.control_model.orig_y_embedder.load_state_dict(model.diffusion_model.y_embedder.state_dict(), strict=False)
else:
missing, unexpected = self.control_model.x_embedder.load_state_dict(model.diffusion_model.x_embedder.state_dict(), strict=False)
super().pre_run(model, percent_to_timestep_function)
def copy(self):
c = ControlNetSD35(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
c.control_model = self.control_model
c.control_model_wrapped = self.control_model_wrapped
self.copy_to(c)
return c
def load_controlnet_sd35(sd, model_options={}):
control_type = -1
if "control_type" in sd:
control_type = round(sd.pop("control_type").item())
# blur_cnet = control_type == 0
canny_cnet = control_type == 1
depth_cnet = control_type == 2
new_sd = {}
for k in comfy.utils.MMDIT_MAP_BASIC:
if k[1] in sd:
new_sd[k[0]] = sd.pop(k[1])
for k in sd:
new_sd[k] = sd[k]
sd = new_sd
y_emb_shape = sd["y_embedder.mlp.0.weight"].shape
depth = y_emb_shape[0] // 64
hidden_size = 64 * depth
num_heads = depth
head_dim = hidden_size // num_heads
num_blocks = comfy.model_detection.count_blocks(new_sd, 'transformer_blocks.{}.')
load_device = comfy.model_management.get_torch_device()
offload_device = comfy.model_management.unet_offload_device()
unet_dtype = comfy.model_management.unet_dtype(model_params=-1)
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
operations = model_options.get("custom_operations", None)
if operations is None:
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
control_model = comfy.cldm.dit_embedder.ControlNetEmbedder(img_size=None,
patch_size=2,
in_chans=16,
num_layers=num_blocks,
main_model_double=depth,
double_y_emb=y_emb_shape[0] == y_emb_shape[1],
attention_head_dim=head_dim,
num_attention_heads=num_heads,
adm_in_channels=2048,
device=offload_device,
dtype=unet_dtype,
operations=operations)
control_model = controlnet_load_state_dict(control_model, sd)
latent_format = comfy.latent_formats.SD3()
preprocess_image = lambda a: a
if canny_cnet:
preprocess_image = lambda a: (a * 255 * 0.5 + 0.5)
elif depth_cnet:
preprocess_image = lambda a: 1.0 - a
control = ControlNetSD35(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, preprocess_image=preprocess_image)
return control
def load_controlnet_hunyuandit(controlnet_data, model_options={}):
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data, model_options=model_options)
control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=offload_device, dtype=unet_dtype)
control_model = controlnet_load_state_dict(control_model, controlnet_data)
latent_format = comfy.latent_formats.SDXL()
extra_conds = ['text_embedding_mask', 'encoder_hidden_states_t5', 'text_embedding_mask_t5', 'image_meta_size', 'style', 'cos_cis_img', 'sin_cis_img']
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT)
return control
def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False, model_options={}):
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
control_model = controlnet_load_state_dict(control_model, sd)
extra_conds = ['y', 'guidance']
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
return control
def load_controlnet_flux_instantx(sd, model_options={}):
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options)
for k in sd:
new_sd[k] = sd[k]
num_union_modes = 0
union_cnet = "controlnet_mode_embedder.weight"
if union_cnet in new_sd:
num_union_modes = new_sd[union_cnet].shape[0]
control_latent_channels = new_sd.get("pos_embed_input.weight").shape[1] // 4
concat_mask = False
if control_latent_channels == 17:
concat_mask = True
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(latent_input=True, num_union_modes=num_union_modes, control_latent_channels=control_latent_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
control_model = controlnet_load_state_dict(control_model, new_sd)
latent_format = comfy.latent_formats.Flux()
extra_conds = ['y', 'guidance']
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
return control
def convert_mistoline(sd):
return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
def load_controlnet_state_dict(state_dict, model=None, model_options={}):
controlnet_data = state_dict
if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT
return load_controlnet_hunyuandit(controlnet_data, model_options=model_options)
if "lora_controlnet" in controlnet_data:
return ControlLora(controlnet_data)
return ControlLora(controlnet_data, model_options=model_options)
controlnet_config = None
supported_inference_dtypes = None
@@ -429,8 +649,18 @@ def load_controlnet(ckpt_path, model=None):
if len(leftover_keys) > 0:
logging.warning("leftover keys: {}".format(leftover_keys))
controlnet_data = new_sd
elif "controlnet_blocks.0.weight" in controlnet_data: #SD3 diffusers format
return load_controlnet_mmdit(controlnet_data)
elif "controlnet_blocks.0.weight" in controlnet_data:
if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
return load_controlnet_flux_xlabs_mistoline(controlnet_data, model_options=model_options)
elif "pos_embed_input.proj.weight" in controlnet_data:
if "transformer_blocks.0.adaLN_modulation.1.bias" in controlnet_data:
return load_controlnet_sd35(controlnet_data, model_options=model_options) #Stability sd3.5 format
else:
return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet
elif "controlnet_x_embedder.weight" in controlnet_data:
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)
pth_key = 'control_model.zero_convs.0.0.weight'
pth = False
@@ -442,26 +672,38 @@ def load_controlnet(ckpt_path, model=None):
elif key in controlnet_data:
prefix = ""
else:
net = load_t2i_adapter(controlnet_data)
net = load_t2i_adapter(controlnet_data, model_options=model_options)
if net is None:
logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
logging.error("error could not detect control model type.")
return net
if controlnet_config is None:
model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True)
supported_inference_dtypes = model_config.supported_inference_dtypes
supported_inference_dtypes = list(model_config.supported_inference_dtypes)
controlnet_config = model_config.unet_config
unet_dtype = model_options.get("dtype", None)
if unet_dtype is None:
weight_dtype = comfy.utils.weight_dtype(controlnet_data)
if supported_inference_dtypes is None:
supported_inference_dtypes = [comfy.model_management.unet_dtype()]
if weight_dtype is not None:
supported_inference_dtypes.append(weight_dtype)
unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
load_device = comfy.model_management.get_torch_device()
if supported_inference_dtypes is None:
unet_dtype = comfy.model_management.unet_dtype()
else:
unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
if manual_cast_dtype is not None:
controlnet_config["operations"] = comfy.ops.manual_cast
operations = model_options.get("custom_operations", None)
if operations is None:
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype)
controlnet_config["operations"] = operations
controlnet_config["dtype"] = unet_dtype
controlnet_config["device"] = comfy.model_management.unet_offload_device()
controlnet_config.pop("out_channels")
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
@@ -495,22 +737,32 @@ def load_controlnet(ckpt_path, model=None):
if len(unexpected) > 0:
logging.debug("unexpected controlnet keys: {}".format(unexpected))
global_average_pooling = False
filename = os.path.splitext(ckpt_path)[0]
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
global_average_pooling = True
global_average_pooling = model_options.get("global_average_pooling", False)
control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
return control
def load_controlnet(ckpt_path, model=None, model_options={}):
if "global_average_pooling" not in model_options:
filename = os.path.splitext(ckpt_path)[0]
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
model_options["global_average_pooling"] = True
cnet = load_controlnet_state_dict(comfy.utils.load_torch_file(ckpt_path, safe_load=True), model=model, model_options=model_options)
if cnet is None:
logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
return cnet
class T2IAdapter(ControlBase):
def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None):
super().__init__(device)
super().__init__()
self.t2i_model = t2i_model
self.channels_in = channels_in
self.control_input = None
self.compression_ratio = compression_ratio
self.upscale_algorithm = upscale_algorithm
if device is None:
device = comfy.model_management.get_torch_device()
self.device = device
def scale_image_to(self, width, height):
unshuffle_amount = self.t2i_model.unshuffle_amount
@@ -518,10 +770,10 @@ class T2IAdapter(ControlBase):
height = math.ceil(height / unshuffle_amount) * unshuffle_amount
return width, height
def get_control(self, x_noisy, t, cond, batched_number):
def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
control_prev = None
if self.previous_controlnet is not None:
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number, transformer_options)
if self.timestep_range is not None:
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
@@ -558,7 +810,7 @@ class T2IAdapter(ControlBase):
self.copy_to(c)
return c
def load_t2i_adapter(t2i_data):
def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
compression_ratio = 8
upscale_algorithm = 'nearest-exact'
@@ -569,7 +821,7 @@ def load_t2i_adapter(t2i_data):
for i in range(4):
for j in range(2):
prefix_replace["adapter.body.{}.resnets.{}.".format(i, j)] = "body.{}.".format(i * 2 + j)
prefix_replace["adapter.body.{}.".format(i, j)] = "body.{}.".format(i * 2)
prefix_replace["adapter.body.{}.".format(i, )] = "body.{}.".format(i * 2)
prefix_replace["adapter."] = ""
t2i_data = comfy.utils.state_dict_prefix_replace(t2i_data, prefix_replace)
keys = t2i_data.keys()

View File

@@ -157,16 +157,23 @@ vae_conversion_map_attn = [
]
def reshape_weight_for_sd(w):
def reshape_weight_for_sd(w, conv3d=False):
# convert HF linear weights to SD conv2d weights
return w.reshape(*w.shape, 1, 1)
if conv3d:
return w.reshape(*w.shape, 1, 1, 1)
else:
return w.reshape(*w.shape, 1, 1)
def convert_vae_state_dict(vae_state_dict):
mapping = {k: k for k in vae_state_dict.keys()}
conv3d = False
for k, v in mapping.items():
for sd_part, hf_part in vae_conversion_map:
v = v.replace(hf_part, sd_part)
if v.endswith(".conv.weight"):
if not conv3d and vae_state_dict[k].ndim == 5:
conv3d = True
mapping[k] = v
for k, v in mapping.items():
if "attentions" in k:
@@ -179,7 +186,7 @@ def convert_vae_state_dict(vae_state_dict):
for weight_name in weights_to_convert:
if f"mid.attn_1.{weight_name}.weight" in k:
logging.debug(f"Reshaping {k} for SD format")
new_state_dict[k] = reshape_weight_for_sd(v)
new_state_dict[k] = reshape_weight_for_sd(v, conv3d=conv3d)
return new_state_dict

View File

@@ -22,7 +22,7 @@ def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_dire
if text_encoder2_path is not None:
text_encoder_paths.append(text_encoder2_path)
unet = comfy.sd.load_unet(unet_path)
unet = comfy.sd.load_diffusion_model(unet_path)
clip = None
if output_clip:

View File

@@ -1,10 +1,9 @@
#code taken from: https://github.com/wl-zhao/UniPC and modified
import torch
import torch.nn.functional as F
import math
from tqdm.auto import trange, tqdm
from tqdm.auto import trange
class NoiseScheduleVP:
@@ -16,7 +15,7 @@ class NoiseScheduleVP:
continuous_beta_0=0.1,
continuous_beta_1=20.,
):
"""Create a wrapper class for the forward SDE (VP type).
r"""Create a wrapper class for the forward SDE (VP type).
***
Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t.
@@ -704,7 +703,6 @@ class UniPC:
):
# t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
# t_T = self.noise_schedule.T if t_start is None else t_start
device = x.device
steps = len(timesteps) - 1
if method == 'multistep':
assert steps >= order

67
comfy/float.py Normal file
View File

@@ -0,0 +1,67 @@
import torch
def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None):
mantissa_scaled = torch.where(
normal_mask,
(abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0) * (2**MANTISSA_BITS),
(abs_x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS)))
)
mantissa_scaled += torch.rand(mantissa_scaled.size(), dtype=mantissa_scaled.dtype, layout=mantissa_scaled.layout, device=mantissa_scaled.device, generator=generator)
return mantissa_scaled.floor() / (2**MANTISSA_BITS)
#Not 100% sure about this
def manual_stochastic_round_to_float8(x, dtype, generator=None):
if dtype == torch.float8_e4m3fn:
EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 4, 3, 7
elif dtype == torch.float8_e5m2:
EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 5, 2, 15
else:
raise ValueError("Unsupported dtype")
x = x.half()
sign = torch.sign(x)
abs_x = x.abs()
sign = torch.where(abs_x == 0, 0, sign)
# Combine exponent calculation and clamping
exponent = torch.clamp(
torch.floor(torch.log2(abs_x)) + EXPONENT_BIAS,
0, 2**EXPONENT_BITS - 1
)
# Combine mantissa calculation and rounding
normal_mask = ~(exponent == 0)
abs_x[:] = calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=generator)
sign *= torch.where(
normal_mask,
(2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + abs_x),
(2.0 ** (-EXPONENT_BIAS + 1)) * abs_x
)
inf = torch.finfo(dtype)
torch.clamp(sign, min=inf.min, max=inf.max, out=sign)
return sign
def stochastic_rounding(value, dtype, seed=0):
if dtype == torch.float32:
return value.to(dtype=torch.float32)
if dtype == torch.float16:
return value.to(dtype=torch.float16)
if dtype == torch.bfloat16:
return value.to(dtype=torch.bfloat16)
if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
generator = torch.Generator(device=value.device)
generator.manual_seed(seed)
output = torch.empty_like(value, dtype=dtype)
num_slices = max(1, (value.numel() / (4096 * 4096)))
slice_size = max(1, round(value.shape[0] / num_slices))
for i in range(0, value.shape[0], slice_size):
output[i:i+slice_size].copy_(manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator))
return output
return value.to(dtype=dtype)

View File

@@ -1,3 +1,4 @@
import math
import torch
from torch import nn
from .ldm.modules.attention import CrossAttention

690
comfy/hooks.py Normal file
View File

@@ -0,0 +1,690 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Callable
import enum
import math
import torch
import numpy as np
import itertools
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher, PatcherInjection
from comfy.model_base import BaseModel
from comfy.sd import CLIP
import comfy.lora
import comfy.model_management
import comfy.patcher_extension
from node_helpers import conditioning_set_values
class EnumHookMode(enum.Enum):
MinVram = "minvram"
MaxSpeed = "maxspeed"
class EnumHookType(enum.Enum):
Weight = "weight"
Patch = "patch"
ObjectPatch = "object_patch"
AddModels = "add_models"
Callbacks = "callbacks"
Wrappers = "wrappers"
SetInjections = "add_injections"
class EnumWeightTarget(enum.Enum):
Model = "model"
Clip = "clip"
class _HookRef:
pass
# NOTE: this is an example of how the should_register function should look
def default_should_register(hook: 'Hook', model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
return True
class Hook:
def __init__(self, hook_type: EnumHookType=None, hook_ref: _HookRef=None, hook_id: str=None,
hook_keyframe: 'HookKeyframeGroup'=None):
self.hook_type = hook_type
self.hook_ref = hook_ref if hook_ref else _HookRef()
self.hook_id = hook_id
self.hook_keyframe = hook_keyframe if hook_keyframe else HookKeyframeGroup()
self.custom_should_register = default_should_register
self.auto_apply_to_nonpositive = False
@property
def strength(self):
return self.hook_keyframe.strength
def initialize_timesteps(self, model: 'BaseModel'):
self.reset()
self.hook_keyframe.initialize_timesteps(model)
def reset(self):
self.hook_keyframe.reset()
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: Hook = subtype()
c.hook_type = self.hook_type
c.hook_ref = self.hook_ref
c.hook_id = self.hook_id
c.hook_keyframe = self.hook_keyframe
c.custom_should_register = self.custom_should_register
# TODO: make this do something
c.auto_apply_to_nonpositive = self.auto_apply_to_nonpositive
return c
def should_register(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
return self.custom_should_register(self, model, model_options, target, registered)
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
raise NotImplementedError("add_hook_patches should be defined for Hook subclasses")
def on_apply(self, model: 'ModelPatcher', transformer_options: dict[str]):
pass
def on_unapply(self, model: 'ModelPatcher', transformer_options: dict[str]):
pass
def __eq__(self, other: 'Hook'):
return self.__class__ == other.__class__ and self.hook_ref == other.hook_ref
def __hash__(self):
return hash(self.hook_ref)
class WeightHook(Hook):
def __init__(self, strength_model=1.0, strength_clip=1.0):
super().__init__(hook_type=EnumHookType.Weight)
self.weights: dict = None
self.weights_clip: dict = None
self.need_weight_init = True
self._strength_model = strength_model
self._strength_clip = strength_clip
@property
def strength_model(self):
return self._strength_model * self.strength
@property
def strength_clip(self):
return self._strength_clip * self.strength
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
if not self.should_register(model, model_options, target, registered):
return False
weights = None
if target == EnumWeightTarget.Model:
strength = self._strength_model
else:
strength = self._strength_clip
if self.need_weight_init:
key_map = {}
if target == EnumWeightTarget.Model:
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
else:
key_map = comfy.lora.model_lora_keys_clip(model.model, key_map)
weights = comfy.lora.load_lora(self.weights, key_map, log_missing=False)
else:
if target == EnumWeightTarget.Model:
weights = self.weights
else:
weights = self.weights_clip
model.add_hook_patches(hook=self, patches=weights, strength_patch=strength)
registered.append(self)
return True
# TODO: add logs about any keys that were not applied
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: WeightHook = super().clone(subtype)
c.weights = self.weights
c.weights_clip = self.weights_clip
c.need_weight_init = self.need_weight_init
c._strength_model = self._strength_model
c._strength_clip = self._strength_clip
return c
class PatchHook(Hook):
def __init__(self):
super().__init__(hook_type=EnumHookType.Patch)
self.patches: dict = None
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: PatchHook = super().clone(subtype)
c.patches = self.patches
return c
# TODO: add functionality
class ObjectPatchHook(Hook):
def __init__(self):
super().__init__(hook_type=EnumHookType.ObjectPatch)
self.object_patches: dict = None
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: ObjectPatchHook = super().clone(subtype)
c.object_patches = self.object_patches
return c
# TODO: add functionality
class AddModelsHook(Hook):
def __init__(self, key: str=None, models: list['ModelPatcher']=None):
super().__init__(hook_type=EnumHookType.AddModels)
self.key = key
self.models = models
self.append_when_same = True
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: AddModelsHook = super().clone(subtype)
c.key = self.key
c.models = self.models.copy() if self.models else self.models
c.append_when_same = self.append_when_same
return c
# TODO: add functionality
class CallbackHook(Hook):
def __init__(self, key: str=None, callback: Callable=None):
super().__init__(hook_type=EnumHookType.Callbacks)
self.key = key
self.callback = callback
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: CallbackHook = super().clone(subtype)
c.key = self.key
c.callback = self.callback
return c
# TODO: add functionality
class WrapperHook(Hook):
def __init__(self, wrappers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None):
super().__init__(hook_type=EnumHookType.Wrappers)
self.wrappers_dict = wrappers_dict
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: WrapperHook = super().clone(subtype)
c.wrappers_dict = self.wrappers_dict
return c
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
if not self.should_register(model, model_options, target, registered):
return False
add_model_options = {"transformer_options": self.wrappers_dict}
comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
registered.append(self)
return True
class SetInjectionsHook(Hook):
def __init__(self, key: str=None, injections: list['PatcherInjection']=None):
super().__init__(hook_type=EnumHookType.SetInjections)
self.key = key
self.injections = injections
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: SetInjectionsHook = super().clone(subtype)
c.key = self.key
c.injections = self.injections.copy() if self.injections else self.injections
return c
def add_hook_injections(self, model: 'ModelPatcher'):
# TODO: add functionality
pass
class HookGroup:
def __init__(self):
self.hooks: list[Hook] = []
def add(self, hook: Hook):
if hook not in self.hooks:
self.hooks.append(hook)
def contains(self, hook: Hook):
return hook in self.hooks
def clone(self):
c = HookGroup()
for hook in self.hooks:
c.add(hook.clone())
return c
def clone_and_combine(self, other: 'HookGroup'):
c = self.clone()
if other is not None:
for hook in other.hooks:
c.add(hook.clone())
return c
def set_keyframes_on_hooks(self, hook_kf: 'HookKeyframeGroup'):
if hook_kf is None:
hook_kf = HookKeyframeGroup()
else:
hook_kf = hook_kf.clone()
for hook in self.hooks:
hook.hook_keyframe = hook_kf
def get_dict_repr(self):
d: dict[EnumHookType, dict[Hook, None]] = {}
for hook in self.hooks:
with_type = d.setdefault(hook.hook_type, {})
with_type[hook] = None
return d
def get_hooks_for_clip_schedule(self):
scheduled_hooks: dict[WeightHook, list[tuple[tuple[float,float], HookKeyframe]]] = {}
for hook in self.hooks:
# only care about WeightHooks, for now
if hook.hook_type == EnumHookType.Weight:
hook_schedule = []
# if no hook keyframes, assign default value
if len(hook.hook_keyframe.keyframes) == 0:
hook_schedule.append(((0.0, 1.0), None))
scheduled_hooks[hook] = hook_schedule
continue
# find ranges of values
prev_keyframe = hook.hook_keyframe.keyframes[0]
for keyframe in hook.hook_keyframe.keyframes:
if keyframe.start_percent > prev_keyframe.start_percent and not math.isclose(keyframe.strength, prev_keyframe.strength):
hook_schedule.append(((prev_keyframe.start_percent, keyframe.start_percent), prev_keyframe))
prev_keyframe = keyframe
elif keyframe.start_percent == prev_keyframe.start_percent:
prev_keyframe = keyframe
# create final range, assuming last start_percent was not 1.0
if not math.isclose(prev_keyframe.start_percent, 1.0):
hook_schedule.append(((prev_keyframe.start_percent, 1.0), prev_keyframe))
scheduled_hooks[hook] = hook_schedule
# hooks should not have their schedules in a list of tuples
all_ranges: list[tuple[float, float]] = []
for range_kfs in scheduled_hooks.values():
for t_range, keyframe in range_kfs:
all_ranges.append(t_range)
# turn list of ranges into boundaries
boundaries_set = set(itertools.chain.from_iterable(all_ranges))
boundaries_set.add(0.0)
boundaries = sorted(boundaries_set)
real_ranges = [(boundaries[i], boundaries[i + 1]) for i in range(len(boundaries) - 1)]
# with real ranges defined, give appropriate hooks w/ keyframes for each range
scheduled_keyframes: list[tuple[tuple[float,float], list[tuple[WeightHook, HookKeyframe]]]] = []
for t_range in real_ranges:
hooks_schedule = []
for hook, val in scheduled_hooks.items():
keyframe = None
# check if is a keyframe that works for the current t_range
for stored_range, stored_kf in val:
# if stored start is less than current end, then fits - give it assigned keyframe
if stored_range[0] < t_range[1] and stored_range[1] > t_range[0]:
keyframe = stored_kf
break
hooks_schedule.append((hook, keyframe))
scheduled_keyframes.append((t_range, hooks_schedule))
return scheduled_keyframes
def reset(self):
for hook in self.hooks:
hook.reset()
@staticmethod
def combine_all_hooks(hooks_list: list['HookGroup'], require_count=0) -> 'HookGroup':
actual: list[HookGroup] = []
for group in hooks_list:
if group is not None:
actual.append(group)
if len(actual) < require_count:
raise Exception(f"Need at least {require_count} hooks to combine, but only had {len(actual)}.")
# if no hooks, then return None
if len(actual) == 0:
return None
# if only 1 hook, just return itself without cloning
elif len(actual) == 1:
return actual[0]
final_hook: HookGroup = None
for hook in actual:
if final_hook is None:
final_hook = hook.clone()
else:
final_hook = final_hook.clone_and_combine(hook)
return final_hook
class HookKeyframe:
def __init__(self, strength: float, start_percent=0.0, guarantee_steps=1):
self.strength = strength
# scheduling
self.start_percent = float(start_percent)
self.start_t = 999999999.9
self.guarantee_steps = guarantee_steps
def clone(self):
c = HookKeyframe(strength=self.strength,
start_percent=self.start_percent, guarantee_steps=self.guarantee_steps)
c.start_t = self.start_t
return c
class HookKeyframeGroup:
def __init__(self):
self.keyframes: list[HookKeyframe] = []
self._current_keyframe: HookKeyframe = None
self._current_used_steps = 0
self._current_index = 0
self._current_strength = None
self._curr_t = -1.
# properties shadow those of HookWeightsKeyframe
@property
def strength(self):
if self._current_keyframe is not None:
return self._current_keyframe.strength
return 1.0
def reset(self):
self._current_keyframe = None
self._current_used_steps = 0
self._current_index = 0
self._current_strength = None
self.curr_t = -1.
self._set_first_as_current()
def add(self, keyframe: HookKeyframe):
# add to end of list, then sort
self.keyframes.append(keyframe)
self.keyframes = get_sorted_list_via_attr(self.keyframes, "start_percent")
self._set_first_as_current()
def _set_first_as_current(self):
if len(self.keyframes) > 0:
self._current_keyframe = self.keyframes[0]
else:
self._current_keyframe = None
def has_index(self, index: int):
return index >= 0 and index < len(self.keyframes)
def is_empty(self):
return len(self.keyframes) == 0
def clone(self):
c = HookKeyframeGroup()
for keyframe in self.keyframes:
c.keyframes.append(keyframe.clone())
c._set_first_as_current()
return c
def initialize_timesteps(self, model: 'BaseModel'):
for keyframe in self.keyframes:
keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent)
def prepare_current_keyframe(self, curr_t: float) -> bool:
if self.is_empty():
return False
if curr_t == self._curr_t:
return False
prev_index = self._current_index
prev_strength = self._current_strength
# if met guaranteed steps, look for next keyframe in case need to switch
if self._current_used_steps >= self._current_keyframe.guarantee_steps:
# if has next index, loop through and see if need to switch
if self.has_index(self._current_index+1):
for i in range(self._current_index+1, len(self.keyframes)):
eval_c = self.keyframes[i]
# check if start_t is greater or equal to curr_t
# NOTE: t is in terms of sigmas, not percent, so bigger number = earlier step in sampling
if eval_c.start_t >= curr_t:
self._current_index = i
self._current_strength = eval_c.strength
self._current_keyframe = eval_c
self._current_used_steps = 0
# if guarantee_steps greater than zero, stop searching for other keyframes
if self._current_keyframe.guarantee_steps > 0:
break
# if eval_c is outside the percent range, stop looking further
else: break
# update steps current context is used
self._current_used_steps += 1
# update current timestep this was performed on
self._curr_t = curr_t
# return True if keyframe changed, False if no change
return prev_index != self._current_index and prev_strength != self._current_strength
class InterpolationMethod:
LINEAR = "linear"
EASE_IN = "ease_in"
EASE_OUT = "ease_out"
EASE_IN_OUT = "ease_in_out"
_LIST = [LINEAR, EASE_IN, EASE_OUT, EASE_IN_OUT]
@classmethod
def get_weights(cls, num_from: float, num_to: float, length: int, method: str, reverse=False):
diff = num_to - num_from
if method == cls.LINEAR:
weights = torch.linspace(num_from, num_to, length)
elif method == cls.EASE_IN:
index = torch.linspace(0, 1, length)
weights = diff * np.power(index, 2) + num_from
elif method == cls.EASE_OUT:
index = torch.linspace(0, 1, length)
weights = diff * (1 - np.power(1 - index, 2)) + num_from
elif method == cls.EASE_IN_OUT:
index = torch.linspace(0, 1, length)
weights = diff * ((1 - np.cos(index * np.pi)) / 2) + num_from
else:
raise ValueError(f"Unrecognized interpolation method '{method}'.")
if reverse:
weights = weights.flip(dims=(0,))
return weights
def get_sorted_list_via_attr(objects: list, attr: str) -> list:
if not objects:
return objects
elif len(objects) <= 1:
return [x for x in objects]
# now that we know we have to sort, do it following these rules:
# a) if objects have same value of attribute, maintain their relative order
# b) perform sorting of the groups of objects with same attributes
unique_attrs = {}
for o in objects:
val_attr = getattr(o, attr)
attr_list: list = unique_attrs.get(val_attr, list())
attr_list.append(o)
if val_attr not in unique_attrs:
unique_attrs[val_attr] = attr_list
# now that we have the unique attr values grouped together in relative order, sort them by key
sorted_attrs = dict(sorted(unique_attrs.items()))
# now flatten out the dict into a list to return
sorted_list = []
for object_list in sorted_attrs.values():
sorted_list.extend(object_list)
return sorted_list
def create_hook_lora(lora: dict[str, torch.Tensor], strength_model: float, strength_clip: float):
hook_group = HookGroup()
hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip)
hook_group.add(hook)
hook.weights = lora
return hook_group
def create_hook_model_as_lora(weights_model, weights_clip, strength_model: float, strength_clip: float):
hook_group = HookGroup()
hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip)
hook_group.add(hook)
patches_model = None
patches_clip = None
if weights_model is not None:
patches_model = {}
for key in weights_model:
patches_model[key] = ("model_as_lora", (weights_model[key],))
if weights_clip is not None:
patches_clip = {}
for key in weights_clip:
patches_clip[key] = ("model_as_lora", (weights_clip[key],))
hook.weights = patches_model
hook.weights_clip = patches_clip
hook.need_weight_init = False
return hook_group
def get_patch_weights_from_model(model: 'ModelPatcher', discard_model_sampling=True):
if model is None:
return None
patches_model: dict[str, torch.Tensor] = model.model.state_dict()
if discard_model_sampling:
# do not include ANY model_sampling components of the model that should act as a patch
for key in list(patches_model.keys()):
if key.startswith("model_sampling"):
patches_model.pop(key, None)
return patches_model
# NOTE: this function shows how to register weight hooks directly on the ModelPatchers
def load_hook_lora_for_models(model: 'ModelPatcher', clip: 'CLIP', lora: dict[str, torch.Tensor],
strength_model: float, strength_clip: float):
key_map = {}
if model is not None:
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
if clip is not None:
key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
hook_group = HookGroup()
hook = WeightHook()
hook_group.add(hook)
loaded: dict[str] = comfy.lora.load_lora(lora, key_map)
if model is not None:
new_modelpatcher = model.clone()
k = new_modelpatcher.add_hook_patches(hook=hook, patches=loaded, strength_patch=strength_model)
else:
k = ()
new_modelpatcher = None
if clip is not None:
new_clip = clip.clone()
k1 = new_clip.patcher.add_hook_patches(hook=hook, patches=loaded, strength_patch=strength_clip)
else:
k1 = ()
new_clip = None
k = set(k)
k1 = set(k1)
for x in loaded:
if (x not in k) and (x not in k1):
print(f"NOT LOADED {x}")
return (new_modelpatcher, new_clip, hook_group)
def _combine_hooks_from_values(c_dict: dict[str, HookGroup], values: dict[str, HookGroup], cache: dict[tuple[HookGroup, HookGroup], HookGroup]):
hooks_key = 'hooks'
# if hooks only exist in one dict, do what's needed so that it ends up in c_dict
if hooks_key not in values:
return
if hooks_key not in c_dict:
hooks_value = values.get(hooks_key, None)
if hooks_value is not None:
c_dict[hooks_key] = hooks_value
return
# otherwise, need to combine with minimum duplication via cache
hooks_tuple = (c_dict[hooks_key], values[hooks_key])
cached_hooks = cache.get(hooks_tuple, None)
if cached_hooks is None:
new_hooks = hooks_tuple[0].clone_and_combine(hooks_tuple[1])
cache[hooks_tuple] = new_hooks
c_dict[hooks_key] = new_hooks
else:
c_dict[hooks_key] = cache[hooks_tuple]
def conditioning_set_values_with_hooks(conditioning, values={}, append_hooks=True):
c = []
hooks_combine_cache: dict[tuple[HookGroup, HookGroup], HookGroup] = {}
for t in conditioning:
n = [t[0], t[1].copy()]
for k in values:
if append_hooks and k == 'hooks':
_combine_hooks_from_values(n[1], values, hooks_combine_cache)
else:
n[1][k] = values[k]
c.append(n)
return c
def set_hooks_for_conditioning(cond, hooks: HookGroup, append_hooks=True):
if hooks is None:
return cond
return conditioning_set_values_with_hooks(cond, {'hooks': hooks}, append_hooks=append_hooks)
def set_timesteps_for_conditioning(cond, timestep_range: tuple[float,float]):
if timestep_range is None:
return cond
return conditioning_set_values(cond, {"start_percent": timestep_range[0],
"end_percent": timestep_range[1]})
def set_mask_for_conditioning(cond, mask: torch.Tensor, set_cond_area: str, strength: float):
if mask is None:
return cond
set_area_to_bounds = False
if set_cond_area != 'default':
set_area_to_bounds = True
if len(mask.shape) < 3:
mask = mask.unsqueeze(0)
return conditioning_set_values(cond, {'mask': mask,
'set_area_to_bounds': set_area_to_bounds,
'mask_strength': strength})
def combine_conditioning(conds: list):
combined_conds = []
for cond in conds:
combined_conds.extend(cond)
return combined_conds
def combine_with_new_conds(conds: list, new_conds: list):
combined_conds = []
for c, new_c in zip(conds, new_conds):
combined_conds.append(combine_conditioning([c, new_c]))
return combined_conds
def set_conds_props(conds: list, strength: float, set_cond_area: str,
mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
final_conds = []
for c in conds:
# first, apply lora_hook to conditioning, if provided
c = set_hooks_for_conditioning(c, hooks, append_hooks=append_hooks)
# next, apply mask to conditioning
c = set_mask_for_conditioning(cond=c, mask=mask, strength=strength, set_cond_area=set_cond_area)
# apply timesteps, if present
c = set_timesteps_for_conditioning(cond=c, timestep_range=timesteps_range)
# finally, apply mask to conditioning and store
final_conds.append(c)
return final_conds
def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1.0, set_cond_area: str="default",
mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
combined_conds = []
for c, masked_c in zip(conds, new_conds):
# first, apply lora_hook to new conditioning, if provided
masked_c = set_hooks_for_conditioning(masked_c, hooks, append_hooks=append_hooks)
# next, apply mask to new conditioning, if provided
masked_c = set_mask_for_conditioning(cond=masked_c, mask=mask, set_cond_area=set_cond_area, strength=strength)
# apply timesteps, if present
masked_c = set_timesteps_for_conditioning(cond=masked_c, timestep_range=timesteps_range)
# finally, combine with existing conditioning and store
combined_conds.append(combine_conditioning([c, masked_c]))
return combined_conds
def set_default_conds_and_combine(conds: list, new_conds: list,
hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
combined_conds = []
for c, new_c in zip(conds, new_conds):
# first, apply lora_hook to new conditioning, if provided
new_c = set_hooks_for_conditioning(new_c, hooks, append_hooks=append_hooks)
# next, add default_cond key to cond so that during sampling, it can be identified
new_c = conditioning_set_values(new_c, {'default': True})
# apply timesteps, if present
new_c = set_timesteps_for_conditioning(cond=new_c, timestep_range=timesteps_range)
# finally, combine with existing conditioning and store
combined_conds.append(combine_conditioning([c, new_c]))
return combined_conds

View File

@@ -11,7 +11,6 @@ import numpy as np
# Transfer from the input time (sigma) used in EDM to that (t) used in DEIS.
def edm2t(edm_steps, epsilon_s=1e-3, sigma_min=0.002, sigma_max=80):
vp_sigma = lambda beta_d, beta_min: lambda t: (np.e ** (0.5 * beta_d * (t ** 2) + beta_min * t) - 1) ** 0.5
vp_sigma_inv = lambda beta_d, beta_min: lambda sigma: ((beta_min ** 2 + 2 * beta_d * (sigma ** 2 + 1).log()).sqrt() - beta_min) / beta_d
vp_beta_d = 2 * (np.log(torch.tensor(sigma_min).cpu() ** 2 + 1) / epsilon_s - np.log(torch.tensor(sigma_max).cpu() ** 2 + 1)) / (epsilon_s - 1)
vp_beta_min = np.log(torch.tensor(sigma_max).cpu() ** 2 + 1) - 0.5 * vp_beta_d

View File

@@ -9,6 +9,7 @@ from tqdm.auto import trange, tqdm
from . import utils
from . import deis
import comfy.model_patcher
import comfy.model_sampling
def append_zero(x):
return torch.cat([x, x.new_zeros([1])])
@@ -43,6 +44,17 @@ def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
return append_zero(sigmas)
def get_sigmas_laplace(n, sigma_min, sigma_max, mu=0., beta=0.5, device='cpu'):
"""Constructs the noise schedule proposed by Tiankai et al. (2024). """
epsilon = 1e-5 # avoid log(0)
x = torch.linspace(0, 1, n, device=device)
clamp = lambda x: torch.clamp(x, min=sigma_min, max=sigma_max)
lmb = mu - beta * torch.sign(0.5-x) * torch.log(1 - 2 * torch.abs(0.5-x) + epsilon)
sigmas = clamp(torch.exp(lmb))
return sigmas
def to_d(x, sigma, denoised):
"""Converts a denoiser output to a Karras ODE derivative."""
return (x - denoised) / utils.append_dims(sigma, x.ndim)
@@ -152,6 +164,8 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None,
@torch.no_grad()
def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
if isinstance(model.inner_model.inner_model.model_sampling, comfy.model_sampling.CONST):
return sample_euler_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
"""Ancestral sampling with Euler method steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
@@ -161,14 +175,42 @@ def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
d = to_d(x, sigmas[i], denoised)
# Euler method
dt = sigma_down - sigmas[i]
x = x + d * dt
if sigmas[i + 1] > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
if sigma_down == 0:
x = denoised
else:
d = to_d(x, sigmas[i], denoised)
# Euler method
dt = sigma_down - sigmas[i]
x = x + d * dt + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
return x
@torch.no_grad()
def sample_euler_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1., noise_sampler=None):
"""Ancestral sampling with Euler method steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
# sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigmas[i + 1] == 0:
x = denoised
else:
downstep_ratio = 1 + (sigmas[i + 1] / sigmas[i] - 1) * eta
sigma_down = sigmas[i + 1] * downstep_ratio
alpha_ip1 = 1 - sigmas[i + 1]
alpha_down = 1 - sigma_down
renoise_coeff = (sigmas[i + 1]**2 - sigma_down**2 * alpha_ip1**2 / alpha_down**2)**0.5
# Euler method
sigma_down_i_ratio = sigma_down / sigmas[i]
x = sigma_down_i_ratio * x + (1 - sigma_down_i_ratio) * denoised
if eta > 0:
x = (alpha_ip1 / alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
return x
@torch.no_grad()
def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
@@ -243,6 +285,9 @@ def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None,
@torch.no_grad()
def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
if isinstance(model.inner_model.inner_model.model_sampling, comfy.model_sampling.CONST):
return sample_dpm_2_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
"""Ancestral sampling with DPM-Solver second-order steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
@@ -269,6 +314,38 @@ def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
return x
@torch.no_grad()
def sample_dpm_2_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with DPM-Solver second-order steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta
sigma_down = sigmas[i+1] * downstep_ratio
alpha_ip1 = 1 - sigmas[i+1]
alpha_down = 1 - sigma_down
renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
d = to_d(x, sigmas[i], denoised)
if sigma_down == 0:
# Euler method
dt = sigma_down - sigmas[i]
x = x + d * dt
else:
# DPM-Solver-2
sigma_mid = sigmas[i].log().lerp(sigma_down.log(), 0.5).exp()
dt_1 = sigma_mid - sigmas[i]
dt_2 = sigma_down - sigmas[i]
x_2 = x + d * dt_1
denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
d_2 = to_d(x_2, sigma_mid, denoised_2)
x = x + d_2 * dt_2
x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
return x
def linear_multistep_coeff(order, t, i, j):
if order - 1 > i:
@@ -509,6 +586,9 @@ def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callbac
@torch.no_grad()
def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
if isinstance(model.inner_model.inner_model.model_sampling, comfy.model_sampling.CONST):
return sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
@@ -541,6 +621,55 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None,
return x
@torch.no_grad()
def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda lbda: (lbda.exp() + 1) ** -1
lambda_fn = lambda sigma: ((1-sigma)/sigma).log()
# logged_x = x.unsqueeze(0)
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta
sigma_down = sigmas[i+1] * downstep_ratio
alpha_ip1 = 1 - sigmas[i+1]
alpha_down = 1 - sigma_down
renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5
# sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigmas[i + 1] == 0:
# Euler method
d = to_d(x, sigmas[i], denoised)
dt = sigma_down - sigmas[i]
x = x + d * dt
else:
# DPM-Solver++(2S)
if sigmas[i] == 1.0:
sigma_s = 0.9999
else:
t_i, t_down = lambda_fn(sigmas[i]), lambda_fn(sigma_down)
r = 1 / 2
h = t_down - t_i
s = t_i + r * h
sigma_s = sigma_fn(s)
# sigma_s = sigmas[i+1]
sigma_s_i_ratio = sigma_s / sigmas[i]
u = sigma_s_i_ratio * x + (1 - sigma_s_i_ratio) * denoised
D_i = model(u, sigma_s * s_in, **extra_args)
sigma_down_i_ratio = sigma_down / sigmas[i]
x = sigma_down_i_ratio * x + (1 - sigma_down_i_ratio) * D_i
# print("sigma_i", sigmas[i], "sigma_ip1", sigmas[i+1],"sigma_down", sigma_down, "sigma_down_i_ratio", sigma_down_i_ratio, "sigma_s_i_ratio", sigma_s_i_ratio, "renoise_coeff", renoise_coeff)
# Noise addition
if sigmas[i + 1] > 0 and eta > 0:
x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
# logged_x = torch.cat((logged_x, x.unsqueeze(0)), dim=0)
return x
@torch.no_grad()
def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
"""DPM-Solver++ (stochastic)."""
@@ -1016,7 +1145,6 @@ def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disabl
d = to_d(x, sigma_hat, temp[0])
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
dt = sigmas[i + 1] - sigma_hat
# Euler method
x = denoised + d * sigmas[i + 1]
return x
@@ -1043,8 +1171,81 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
d = to_d(x, sigmas[i], temp[0])
# Euler method
dt = sigma_down - sigmas[i]
x = denoised + d * sigma_down
if sigmas[i + 1] > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
return x
@torch.no_grad()
def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
temp = [0]
def post_cfg_function(args):
temp[0] = args["uncond_denoised"]
return args["denoised"]
model_options = extra_args.get("model_options", {}).copy()
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda t: t.neg().exp()
t_fn = lambda sigma: sigma.log().neg()
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigma_down == 0:
# Euler method
d = to_d(x, sigmas[i], temp[0])
x = denoised + d * sigma_down
else:
# DPM-Solver++(2S)
t, t_next = t_fn(sigmas[i]), t_fn(sigma_down)
# r = torch.sinh(1 + (2 - eta) * (t_next - t) / (t - t_fn(sigma_up))) works only on non-cfgpp, weird
r = 1 / 2
h = t_next - t
s = t + r * h
x_2 = (sigma_fn(s) / sigma_fn(t)) * (x + (denoised - temp[0])) - (-h * r).expm1() * denoised
denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
x = (sigma_fn(t_next) / sigma_fn(t)) * (x + (denoised - temp[0])) - (-h).expm1() * denoised_2
# Noise addition
if sigmas[i + 1] > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
return x
@torch.no_grad()
def sample_dpmpp_2m_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
"""DPM-Solver++(2M)."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
t_fn = lambda sigma: sigma.log().neg()
old_uncond_denoised = None
uncond_denoised = None
def post_cfg_function(args):
nonlocal uncond_denoised
uncond_denoised = args["uncond_denoised"]
return args["denoised"]
model_options = extra_args.get("model_options", {}).copy()
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
h = t_next - t
if old_uncond_denoised is None or sigmas[i + 1] == 0:
denoised_mix = -torch.exp(-h) * uncond_denoised
else:
h_last = t - t_fn(sigmas[i - 1])
r = h_last / h
denoised_mix = -torch.exp(-h) * uncond_denoised - torch.expm1(-h) * (1 / (2 * r)) * (denoised - old_uncond_denoised)
x = denoised + denoised_mix + torch.exp(-h) * x
old_uncond_denoised = uncond_denoised
return x

View File

@@ -4,6 +4,7 @@ class LatentFormat:
scale_factor = 1.0
latent_channels = 4
latent_rgb_factors = None
latent_rgb_factors_bias = None
taesd_decoder_name = None
def process_in(self, latent):
@@ -30,11 +31,13 @@ class SDXL(LatentFormat):
def __init__(self):
self.latent_rgb_factors = [
# R G B
[ 0.3920, 0.4054, 0.4549],
[-0.2634, -0.0196, 0.0653],
[ 0.0568, 0.1687, -0.0755],
[-0.3112, -0.2359, -0.2076]
[ 0.3651, 0.4232, 0.4341],
[-0.2533, -0.0042, 0.1068],
[ 0.1076, 0.1111, -0.0362],
[-0.3165, -0.2492, -0.2188]
]
self.latent_rgb_factors_bias = [ 0.1084, -0.0175, -0.0011]
self.taesd_decoder_name = "taesdxl_decoder"
class SDXL_Playground_2_5(LatentFormat):
@@ -112,23 +115,24 @@ class SD3(LatentFormat):
self.scale_factor = 1.5305
self.shift_factor = 0.0609
self.latent_rgb_factors = [
[-0.0645, 0.0177, 0.1052],
[ 0.0028, 0.0312, 0.0650],
[ 0.1848, 0.0762, 0.0360],
[ 0.0944, 0.0360, 0.0889],
[ 0.0897, 0.0506, -0.0364],
[-0.0020, 0.1203, 0.0284],
[ 0.0855, 0.0118, 0.0283],
[-0.0539, 0.0658, 0.1047],
[-0.0057, 0.0116, 0.0700],
[-0.0412, 0.0281, -0.0039],
[ 0.1106, 0.1171, 0.1220],
[-0.0248, 0.0682, -0.0481],
[ 0.0815, 0.0846, 0.1207],
[-0.0120, -0.0055, -0.0867],
[-0.0749, -0.0634, -0.0456],
[-0.1418, -0.1457, -0.1259]
[-0.0922, -0.0175, 0.0749],
[ 0.0311, 0.0633, 0.0954],
[ 0.1994, 0.0927, 0.0458],
[ 0.0856, 0.0339, 0.0902],
[ 0.0587, 0.0272, -0.0496],
[-0.0006, 0.1104, 0.0309],
[ 0.0978, 0.0306, 0.0427],
[-0.0042, 0.1038, 0.1358],
[-0.0194, 0.0020, 0.0669],
[-0.0488, 0.0130, -0.0268],
[ 0.0922, 0.0988, 0.0951],
[-0.0278, 0.0524, -0.0542],
[ 0.0332, 0.0456, 0.0895],
[-0.0069, -0.0030, -0.0810],
[-0.0596, -0.0465, -0.0293],
[-0.1448, -0.1463, -0.1189]
]
self.latent_rgb_factors_bias = [0.2394, 0.2135, 0.1925]
self.taesd_decoder_name = "taesd3_decoder"
def process_in(self, latent):
@@ -141,30 +145,234 @@ class StableAudio1(LatentFormat):
latent_channels = 64
class Flux(SD3):
latent_channels = 16
def __init__(self):
self.scale_factor = 0.3611
self.shift_factor = 0.1159
self.latent_rgb_factors =[
[-0.0404, 0.0159, 0.0609],
[ 0.0043, 0.0298, 0.0850],
[ 0.0328, -0.0749, -0.0503],
[-0.0245, 0.0085, 0.0549],
[ 0.0966, 0.0894, 0.0530],
[ 0.0035, 0.0399, 0.0123],
[ 0.0583, 0.1184, 0.1262],
[-0.0191, -0.0206, -0.0306],
[-0.0324, 0.0055, 0.1001],
[ 0.0955, 0.0659, -0.0545],
[-0.0504, 0.0231, -0.0013],
[ 0.0500, -0.0008, -0.0088],
[ 0.0982, 0.0941, 0.0976],
[-0.1233, -0.0280, -0.0897],
[-0.0005, -0.0530, -0.0020],
[-0.1273, -0.0932, -0.0680]
[-0.0346, 0.0244, 0.0681],
[ 0.0034, 0.0210, 0.0687],
[ 0.0275, -0.0668, -0.0433],
[-0.0174, 0.0160, 0.0617],
[ 0.0859, 0.0721, 0.0329],
[ 0.0004, 0.0383, 0.0115],
[ 0.0405, 0.0861, 0.0915],
[-0.0236, -0.0185, -0.0259],
[-0.0245, 0.0250, 0.1180],
[ 0.1008, 0.0755, -0.0421],
[-0.0515, 0.0201, 0.0011],
[ 0.0428, -0.0012, -0.0036],
[ 0.0817, 0.0765, 0.0749],
[-0.1264, -0.0522, -0.1103],
[-0.0280, -0.0881, -0.0499],
[-0.1262, -0.0982, -0.0778]
]
self.latent_rgb_factors_bias = [-0.0329, -0.0718, -0.0851]
self.taesd_decoder_name = "taef1_decoder"
def process_in(self, latent):
return (latent - self.shift_factor) * self.scale_factor
def process_out(self, latent):
return (latent / self.scale_factor) + self.shift_factor
class Mochi(LatentFormat):
latent_channels = 12
def __init__(self):
self.scale_factor = 1.0
self.latents_mean = torch.tensor([-0.06730895953510081, -0.038011381506090416, -0.07477820912866141,
-0.05565264470995561, 0.012767231469026969, -0.04703542746246419,
0.043896967884726704, -0.09346305707025976, -0.09918314763016893,
-0.008729793427399178, -0.011931556316503654, -0.0321993391887285]).view(1, self.latent_channels, 1, 1, 1)
self.latents_std = torch.tensor([0.9263795028493863, 0.9248894543193766, 0.9393059390890617,
0.959253732819592, 0.8244560132752793, 0.917259975397747,
0.9294154431013696, 1.3720942357788521, 0.881393668867029,
0.9168315692124348, 0.9185249279345552, 0.9274757570805041]).view(1, self.latent_channels, 1, 1, 1)
self.latent_rgb_factors =[
[-0.0069, -0.0045, 0.0018],
[ 0.0154, -0.0692, -0.0274],
[ 0.0333, 0.0019, 0.0206],
[-0.1390, 0.0628, 0.1678],
[-0.0725, 0.0134, -0.1898],
[ 0.0074, -0.0270, -0.0209],
[-0.0176, -0.0277, -0.0221],
[ 0.5294, 0.5204, 0.3852],
[-0.0326, -0.0446, -0.0143],
[-0.0659, 0.0153, -0.0153],
[ 0.0185, -0.0217, 0.0014],
[-0.0396, -0.0495, -0.0281]
]
self.latent_rgb_factors_bias = [-0.0940, -0.1418, -0.1453]
self.taesd_decoder_name = None #TODO
def process_in(self, latent):
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
latents_std = self.latents_std.to(latent.device, latent.dtype)
return (latent - latents_mean) * self.scale_factor / latents_std
def process_out(self, latent):
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
latents_std = self.latents_std.to(latent.device, latent.dtype)
return latent * latents_std / self.scale_factor + latents_mean
class LTXV(LatentFormat):
latent_channels = 128
def __init__(self):
self.latent_rgb_factors = [
[ 1.1202e-02, -6.3815e-04, -1.0021e-02],
[ 8.6031e-02, 6.5813e-02, 9.5409e-04],
[-1.2576e-02, -7.5734e-03, -4.0528e-03],
[ 9.4063e-03, -2.1688e-03, 2.6093e-03],
[ 3.7636e-03, 1.2765e-02, 9.1548e-03],
[ 2.1024e-02, -5.2973e-03, 3.4373e-03],
[-8.8896e-03, -1.9703e-02, -1.8761e-02],
[-1.3160e-02, -1.0523e-02, 1.9709e-03],
[-1.5152e-03, -6.9891e-03, -7.5810e-03],
[-1.7247e-03, 4.6560e-04, -3.3839e-03],
[ 1.3617e-02, 4.7077e-03, -2.0045e-03],
[ 1.0256e-02, 7.7318e-03, 1.3948e-02],
[-1.6108e-02, -6.2151e-03, 1.1561e-03],
[ 7.3407e-03, 1.5628e-02, 4.4865e-04],
[ 9.5357e-04, -2.9518e-03, -1.4760e-02],
[ 1.9143e-02, 1.0868e-02, 1.2264e-02],
[ 4.4575e-03, 3.6682e-05, -6.8508e-03],
[-4.5681e-04, 3.2570e-03, 7.7929e-03],
[ 3.3902e-02, 3.3405e-02, 3.7454e-02],
[-2.3001e-02, -2.4877e-03, -3.1033e-03],
[ 5.0265e-02, 3.8841e-02, 3.3539e-02],
[-4.1018e-03, -1.1095e-03, 1.5859e-03],
[-1.2689e-01, -1.3107e-01, -2.1005e-01],
[ 2.6276e-02, 1.4189e-02, -3.5963e-03],
[-4.8679e-03, 8.8486e-03, 7.8029e-03],
[-1.6610e-03, -4.8597e-03, -5.2060e-03],
[-2.1010e-03, 2.3610e-03, 9.3796e-03],
[-2.2482e-02, -2.1305e-02, -1.5087e-02],
[-1.5753e-02, -1.0646e-02, -6.5083e-03],
[-4.6975e-03, 5.0288e-03, -6.7390e-03],
[ 1.1951e-02, 2.0712e-02, 1.6191e-02],
[-6.3704e-03, -8.4827e-03, -9.5483e-03],
[ 7.2610e-03, -9.9326e-03, -2.2978e-02],
[-9.1904e-04, 6.2882e-03, 9.5720e-03],
[-3.7178e-02, -3.7123e-02, -5.6713e-02],
[-1.3373e-01, -1.0720e-01, -5.3801e-02],
[-5.3702e-03, 8.1256e-03, 8.8397e-03],
[-1.5247e-01, -2.1437e-01, -2.1843e-01],
[ 3.1441e-02, 7.0335e-03, -9.7541e-03],
[ 2.1528e-03, -8.9817e-03, -2.1023e-02],
[ 3.8461e-03, -5.8957e-03, -1.5014e-02],
[-4.3470e-03, -1.2940e-02, -1.5972e-02],
[-5.4781e-03, -1.0842e-02, -3.0204e-03],
[-6.5347e-03, 3.0806e-03, -1.0163e-02],
[-5.0414e-03, -7.1503e-03, -8.9686e-04],
[-8.5851e-03, -2.4351e-03, 1.0674e-03],
[-9.0016e-03, -9.6493e-03, 1.5692e-03],
[ 5.0914e-03, 1.2099e-02, 1.9968e-02],
[ 1.3758e-02, 1.1669e-02, 8.1958e-03],
[-1.0518e-02, -1.1575e-02, -4.1307e-03],
[-2.8410e-02, -3.1266e-02, -2.2149e-02],
[ 2.9336e-03, 3.6511e-02, 1.8717e-02],
[-1.6703e-02, -1.6696e-02, -4.4529e-03],
[ 4.8818e-02, 4.0063e-02, 8.7410e-03],
[-1.5066e-02, -5.7328e-04, 2.9785e-03],
[-1.7613e-02, -8.1034e-03, 1.3086e-02],
[-9.2633e-03, 1.0803e-02, -6.3489e-03],
[ 3.0851e-03, 4.7750e-04, 1.2347e-02],
[-2.2785e-02, -2.3043e-02, -2.6005e-02],
[-2.4787e-02, -1.5389e-02, -2.2104e-02],
[-2.3572e-02, 1.0544e-03, 1.2361e-02],
[-7.8915e-03, -1.2271e-03, -6.0968e-03],
[-1.1478e-02, -1.2543e-03, 6.2679e-03],
[-5.4229e-02, 2.6644e-02, 6.3394e-03],
[ 4.4216e-03, -7.3338e-03, -1.0464e-02],
[-4.5013e-03, 1.6082e-03, 1.4420e-02],
[ 1.3673e-02, 8.8877e-03, 4.1253e-03],
[-1.0145e-02, 9.0072e-03, 1.5695e-02],
[-5.6234e-03, 1.1847e-03, 8.1261e-03],
[-3.7171e-03, -5.3538e-03, 1.2590e-03],
[ 2.9476e-02, 2.1424e-02, 3.0424e-02],
[-3.4925e-02, -2.4340e-02, -2.5316e-02],
[-3.4127e-02, -2.2406e-02, -1.0589e-02],
[-1.7342e-02, -1.3249e-02, -1.0719e-02],
[-2.1478e-03, -8.6051e-03, -2.9878e-03],
[ 1.2089e-03, -4.2391e-03, -6.8569e-03],
[ 9.0411e-04, -6.6886e-03, -6.7547e-05],
[ 1.6048e-02, -1.0057e-02, -2.8929e-02],
[ 1.2290e-03, 1.0163e-02, 1.8861e-02],
[ 1.7264e-02, 2.7257e-04, 1.3785e-02],
[-1.3482e-02, -3.6427e-03, 6.7481e-04],
[ 4.6782e-03, -5.2423e-03, 2.4467e-03],
[-5.9113e-03, -6.2244e-03, -1.8162e-03],
[ 1.5496e-02, 1.4582e-02, 1.9514e-03],
[ 7.4958e-03, 1.5886e-03, -8.2305e-03],
[ 1.9086e-02, 1.6360e-03, -3.9674e-03],
[-5.7021e-03, -2.7307e-03, -4.1066e-03],
[ 1.7450e-03, 1.4602e-02, 2.5794e-02],
[-8.2788e-04, 2.2902e-03, 4.5161e-03],
[ 1.1632e-02, 8.9193e-03, -7.2813e-03],
[ 7.5721e-03, 2.6784e-03, 1.1393e-02],
[ 5.1939e-03, 3.6903e-03, 1.4049e-02],
[-1.8383e-02, -2.2529e-02, -2.4477e-02],
[ 5.8842e-04, -5.7874e-03, -1.4770e-02],
[-1.6125e-02, -8.6101e-03, -1.4533e-02],
[ 2.0540e-02, 2.0729e-02, 6.4338e-03],
[ 3.3587e-03, -1.1226e-02, -1.6444e-02],
[-1.4742e-03, -1.0489e-02, 1.7097e-03],
[ 2.8130e-02, 2.3546e-02, 3.2791e-02],
[-1.8532e-02, -1.2842e-02, -8.7756e-03],
[-8.0533e-03, -1.0771e-02, -1.7536e-02],
[-3.9009e-03, 1.6150e-02, 3.3359e-02],
[-7.4554e-03, -1.4154e-02, -6.1910e-03],
[ 3.4734e-03, -1.1370e-02, -1.0581e-02],
[ 1.1476e-02, 3.9281e-03, 2.8231e-03],
[ 7.1639e-03, -1.4741e-03, -3.8066e-03],
[ 2.2250e-03, -8.7552e-03, -9.5719e-03],
[ 2.4146e-02, 2.1696e-02, 2.8056e-02],
[-5.4365e-03, -2.4291e-02, -1.7802e-02],
[ 7.4263e-03, 1.0510e-02, 1.2705e-02],
[ 6.2669e-03, 6.2658e-03, 1.9211e-02],
[ 1.6378e-02, 9.4933e-03, 6.6971e-03],
[ 1.7173e-02, 2.3601e-02, 2.3296e-02],
[-1.4568e-02, -9.8279e-03, -1.1556e-02],
[ 1.4431e-02, 1.4430e-02, 6.6362e-03],
[-6.8230e-03, 1.8863e-02, 1.4555e-02],
[ 6.1156e-03, 3.4700e-03, -2.6662e-03],
[-2.6983e-03, -5.9402e-03, -9.2276e-03],
[ 1.0235e-02, 7.4173e-03, -7.6243e-03],
[-1.3255e-02, 1.9322e-02, -9.2153e-04],
[ 2.4222e-03, -4.8039e-03, -1.5759e-02],
[ 2.6244e-02, 2.5951e-02, 2.0249e-02],
[ 1.5711e-02, 1.8498e-02, 2.7407e-03],
[-2.1714e-03, 4.7214e-03, -2.2443e-02],
[-7.4747e-03, 7.4166e-03, 1.4430e-02],
[-8.3906e-03, -7.9776e-03, 9.7927e-03],
[ 3.8321e-02, 9.6622e-03, -1.9268e-02],
[-1.4605e-02, -6.7032e-03, 3.9675e-03]
]
self.latent_rgb_factors_bias = [-0.0571, -0.1657, -0.2512]
class HunyuanVideo(LatentFormat):
latent_channels = 16
scale_factor = 0.476986
latent_rgb_factors = [
[-0.0395, -0.0331, 0.0445],
[ 0.0696, 0.0795, 0.0518],
[ 0.0135, -0.0945, -0.0282],
[ 0.0108, -0.0250, -0.0765],
[-0.0209, 0.0032, 0.0224],
[-0.0804, -0.0254, -0.0639],
[-0.0991, 0.0271, -0.0669],
[-0.0646, -0.0422, -0.0400],
[-0.0696, -0.0595, -0.0894],
[-0.0799, -0.0208, -0.0375],
[ 0.1166, 0.1627, 0.0962],
[ 0.1165, 0.0432, 0.0407],
[-0.2315, -0.1920, -0.1355],
[-0.0270, 0.0401, -0.0821],
[-0.0616, -0.0997, -0.0727],
[ 0.0249, -0.0469, -0.1703]
]
latent_rgb_factors_bias = [ 0.0259, -0.0192, -0.0761]

View File

@@ -2,7 +2,7 @@
import torch
from torch import nn
from typing import Literal, Dict, Any
from typing import Literal
import math
import comfy.ops
ops = comfy.ops.disable_weight_init
@@ -97,7 +97,7 @@ def get_activation(activation: Literal["elu", "snake", "none"], antialias=False,
raise ValueError(f"Unknown activation {activation}")
if antialias:
act = Activation1d(act)
act = Activation1d(act) # noqa: F821 Activation1d is not defined
return act

View File

@@ -158,7 +158,6 @@ class RotaryEmbedding(nn.Module):
def forward(self, t):
# device = self.inv_freq.device
device = t.device
dtype = t.dtype
# t = t.to(torch.float32)
@@ -170,7 +169,7 @@ class RotaryEmbedding(nn.Module):
if self.scale is None:
return freqs, 1.
power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base # noqa: F821 seq_len is not defined
scale = comfy.ops.cast_to_input(self.scale, t) ** rearrange(power, 'n -> n 1')
scale = torch.cat((scale, scale), dim = -1)
@@ -229,9 +228,9 @@ class FeedForward(nn.Module):
linear_in = GLU(dim, inner_dim, activation, dtype=dtype, device=device, operations=operations)
else:
linear_in = nn.Sequential(
Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
rearrange('b n d -> b d n') if use_conv else nn.Identity(),
operations.Linear(dim, inner_dim, bias = not no_bias, dtype=dtype, device=device) if not use_conv else operations.Conv1d(dim, inner_dim, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias, dtype=dtype, device=device),
Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
rearrange('b n d -> b d n') if use_conv else nn.Identity(),
activation
)
@@ -246,9 +245,9 @@ class FeedForward(nn.Module):
self.ff = nn.Sequential(
linear_in,
Rearrange('b d n -> b n d') if use_conv else nn.Identity(),
rearrange('b d n -> b n d') if use_conv else nn.Identity(),
linear_out,
Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
rearrange('b n d -> b d n') if use_conv else nn.Identity(),
)
def forward(self, x):
@@ -346,18 +345,13 @@ class Attention(nn.Module):
# determine masking
masks = []
final_attn_mask = None # The mask that will be applied to the attention matrix, taking all masks into account
if input_mask is not None:
input_mask = rearrange(input_mask, 'b j -> b 1 1 j')
masks.append(~input_mask)
# Other masks will be added here later
if len(masks) > 0:
final_attn_mask = ~or_reduce(masks)
n, device = q.shape[-2], q.device
n = q.shape[-2]
causal = self.causal if causal is None else causal
@@ -612,7 +606,9 @@ class ContinuousTransformer(nn.Module):
return_info = False,
**kwargs
):
patches_replace = kwargs.get("transformer_options", {}).get("patches_replace", {})
batch, seq, device = *x.shape[:2], x.device
context = kwargs["context"]
info = {
"hidden_states": [],
@@ -643,9 +639,19 @@ class ContinuousTransformer(nn.Module):
if self.use_sinusoidal_emb or self.use_abs_pos_emb:
x = x + self.pos_emb(x)
blocks_replace = patches_replace.get("dit", {})
# Iterate over the transformer layers
for layer in self.layers:
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
for i, layer in enumerate(self.layers):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb}, {"original_block": block_wrap})
x = out["img"]
else:
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context)
# x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
if return_info:
@@ -874,7 +880,6 @@ class AudioDiffusionTransformer(nn.Module):
mask=None,
return_info=False,
control=None,
transformer_options={},
**kwargs):
return self._forward(
x,

View File

@@ -2,8 +2,8 @@
import torch
import torch.nn as nn
from torch import Tensor, einsum
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
from torch import Tensor
from typing import List, Union
from einops import rearrange
import math
import comfy.ops

View File

@@ -147,7 +147,6 @@ class DoubleAttention(nn.Module):
bsz, seqlen1, _ = c.shape
bsz, seqlen2, _ = x.shape
seqlen = seqlen1 + seqlen2
cq, ck, cv = self.w1q(c), self.w1k(c), self.w1v(c)
cq = cq.view(bsz, seqlen1, self.n_heads, self.head_dim)
@@ -437,7 +436,8 @@ class MMDiT(nn.Module):
pos_encoding = pos_encoding[:,from_h:from_h+h,from_w:from_w+w]
return x + pos_encoding.reshape(1, -1, self.positional_encoding.shape[-1])
def forward(self, x, timestep, context, **kwargs):
def forward(self, x, timestep, context, transformer_options={}, **kwargs):
patches_replace = transformer_options.get("patches_replace", {})
# patchify x, add PE
b, c, h, w = x.shape
@@ -458,15 +458,36 @@ class MMDiT(nn.Module):
global_cond = self.t_embedder(t, x.dtype) # B, D
blocks_replace = patches_replace.get("dit", {})
if len(self.double_layers) > 0:
for layer in self.double_layers:
c, x = layer(c, x, global_cond, **kwargs)
for i, layer in enumerate(self.double_layers):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["txt"], out["img"] = layer(args["txt"],
args["img"],
args["vec"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond}, {"original_block": block_wrap})
c = out["txt"]
x = out["img"]
else:
c, x = layer(c, x, global_cond, **kwargs)
if len(self.single_layers) > 0:
c_len = c.size(1)
cx = torch.cat([c, x], dim=1)
for layer in self.single_layers:
cx = layer(cx, global_cond, **kwargs)
for i, layer in enumerate(self.single_layers):
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = layer(args["img"], args["vec"])
return out
out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond}, {"original_block": block_wrap})
cx = out["img"]
else:
cx = layer(cx, global_cond, **kwargs)
x = cx[:, c_len:]

View File

@@ -16,7 +16,6 @@
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import torch
import torchvision
from torch import nn
from .common import LayerNorm2d_op

View File

@@ -1,8 +1,30 @@
import torch
import comfy.ops
def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
if padding_mode == "circular" and torch.jit.is_tracing() or torch.jit.is_scripting():
if padding_mode == "circular" and (torch.jit.is_tracing() or torch.jit.is_scripting()):
padding_mode = "reflect"
pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0]
pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1]
return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode)
pad = ()
for i in range(img.ndim - 2):
pad = (0, (patch_size[i] - img.shape[i + 2] % patch_size[i]) % patch_size[i]) + pad
return torch.nn.functional.pad(img, pad, mode=padding_mode)
try:
rms_norm_torch = torch.nn.functional.rms_norm
except:
rms_norm_torch = None
def rms_norm(x, weight=None, eps=1e-6):
if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
if weight is None:
return rms_norm_torch(x, (x.shape[-1],), eps=eps)
else:
return rms_norm_torch(x, weight.shape, weight=comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
else:
r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
if weight is None:
return r
else:
return r * comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device)

View File

@@ -0,0 +1,203 @@
#Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
#modified to support different types of flux controlnets
import torch
import math
from torch import Tensor, nn
from einops import rearrange, repeat
from .layers import (timestep_embedding)
from .model import Flux
import comfy.ldm.common_dit
class MistolineCondDownsamplBlock(nn.Module):
def __init__(self, dtype=None, device=None, operations=None):
super().__init__()
self.encoder = nn.Sequential(
operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
)
def forward(self, x):
return self.encoder(x)
class MistolineControlnetBlock(nn.Module):
def __init__(self, hidden_size, dtype=None, device=None, operations=None):
super().__init__()
self.linear = operations.Linear(hidden_size, hidden_size, dtype=dtype, device=device)
self.act = nn.SiLU()
def forward(self, x):
return self.act(self.linear(x))
class ControlNetFlux(Flux):
def __init__(self, latent_input=False, num_union_modes=0, mistoline=False, control_latent_channels=None, image_model=None, dtype=None, device=None, operations=None, **kwargs):
super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
self.main_model_double = 19
self.main_model_single = 38
self.mistoline = mistoline
# add ControlNet blocks
if self.mistoline:
control_block = lambda : MistolineControlnetBlock(self.hidden_size, dtype=dtype, device=device, operations=operations)
else:
control_block = lambda : operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
self.controlnet_blocks = nn.ModuleList([])
for _ in range(self.params.depth):
self.controlnet_blocks.append(control_block())
self.controlnet_single_blocks = nn.ModuleList([])
for _ in range(self.params.depth_single_blocks):
self.controlnet_single_blocks.append(control_block())
self.num_union_modes = num_union_modes
self.controlnet_mode_embedder = None
if self.num_union_modes > 0:
self.controlnet_mode_embedder = operations.Embedding(self.num_union_modes, self.hidden_size, dtype=dtype, device=device)
self.gradient_checkpointing = False
self.latent_input = latent_input
if control_latent_channels is None:
control_latent_channels = self.in_channels
else:
control_latent_channels *= 2 * 2 #patch size
self.pos_embed_input = operations.Linear(control_latent_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
if not self.latent_input:
if self.mistoline:
self.input_cond_block = MistolineCondDownsamplBlock(dtype=dtype, device=device, operations=operations)
else:
self.input_hint_block = nn.Sequential(
operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
)
def forward_orig(
self,
img: Tensor,
img_ids: Tensor,
controlnet_cond: Tensor,
txt: Tensor,
txt_ids: Tensor,
timesteps: Tensor,
y: Tensor,
guidance: Tensor = None,
control_type: Tensor = None,
) -> Tensor:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
# running on sequences img
img = self.img_in(img)
controlnet_cond = self.pos_embed_input(controlnet_cond)
img = img + controlnet_cond
vec = self.time_in(timestep_embedding(timesteps, 256))
if self.params.guidance_embed:
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
vec = vec + self.vector_in(y)
txt = self.txt_in(txt)
if self.controlnet_mode_embedder is not None and len(control_type) > 0:
control_cond = self.controlnet_mode_embedder(torch.tensor(control_type, device=img.device), out_dtype=img.dtype).unsqueeze(0).repeat((txt.shape[0], 1, 1))
txt = torch.cat([control_cond, txt], dim=1)
txt_ids = torch.cat([txt_ids[:,:1], txt_ids], dim=1)
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
controlnet_double = ()
for i in range(len(self.double_blocks)):
img, txt = self.double_blocks[i](img=img, txt=txt, vec=vec, pe=pe)
controlnet_double = controlnet_double + (self.controlnet_blocks[i](img),)
img = torch.cat((txt, img), 1)
controlnet_single = ()
for i in range(len(self.single_blocks)):
img = self.single_blocks[i](img, vec=vec, pe=pe)
controlnet_single = controlnet_single + (self.controlnet_single_blocks[i](img[:, txt.shape[1] :, ...]),)
repeat = math.ceil(self.main_model_double / len(controlnet_double))
if self.latent_input:
out_input = ()
for x in controlnet_double:
out_input += (x,) * repeat
else:
out_input = (controlnet_double * repeat)
out = {"input": out_input[:self.main_model_double]}
if len(controlnet_single) > 0:
repeat = math.ceil(self.main_model_single / len(controlnet_single))
out_output = ()
if self.latent_input:
for x in controlnet_single:
out_output += (x,) * repeat
else:
out_output = (controlnet_single * repeat)
out["output"] = out_output[:self.main_model_single]
return out
def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
patch_size = 2
if self.latent_input:
hint = comfy.ldm.common_dit.pad_to_patch_size(hint, (patch_size, patch_size))
elif self.mistoline:
hint = hint * 2.0 - 1.0
hint = self.input_cond_block(hint)
else:
hint = hint * 2.0 - 1.0
hint = self.input_hint_block(hint)
hint = rearrange(hint, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
bs, c, h, w = x.shape
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
h_len = ((h + (patch_size // 2)) // patch_size)
w_len = ((w + (patch_size // 2)) // patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance, control_type=kwargs.get("control_type", []))

View File

@@ -2,11 +2,12 @@ import math
from dataclasses import dataclass
import torch
from einops import rearrange
from torch import Tensor, nn
from .math import attention, rope
import comfy.ops
import comfy.ldm.common_dit
class EmbedND(nn.Module):
def __init__(self, dim: int, theta: int, axes_dim: list):
@@ -36,9 +37,7 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
"""
t = time_factor * t
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
t.device
)
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
@@ -48,7 +47,6 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
embedding = embedding.to(t)
return embedding
class MLPEmbedder(nn.Module):
def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None, operations=None):
super().__init__()
@@ -66,10 +64,7 @@ class RMSNorm(torch.nn.Module):
self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
def forward(self, x: Tensor):
x_dtype = x.dtype
x = x.float()
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
return (x * rrms).to(dtype=x_dtype) * comfy.ops.cast_to(self.scale, dtype=x_dtype, device=x.device)
return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6)
class QKNorm(torch.nn.Module):
@@ -94,14 +89,6 @@ class SelfAttention(nn.Module):
self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
def forward(self, x: Tensor, pe: Tensor) -> Tensor:
qkv = self.qkv(x)
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
q, k = self.norm(q, k, v)
x = attention(q, k, v, pe=pe)
x = self.proj(x)
return x
@dataclass
class ModulationOut:
@@ -127,7 +114,7 @@ class Modulation(nn.Module):
class DoubleStreamBlock(nn.Module):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, dtype=None, device=None, operations=None):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
@@ -154,8 +141,9 @@ class DoubleStreamBlock(nn.Module):
nn.GELU(approximate="tanh"),
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
self.flipped_img_txt = flipped_img_txt
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor):
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None):
img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec)
@@ -163,31 +151,44 @@ class DoubleStreamBlock(nn.Module):
img_modulated = self.img_norm1(img)
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
img_qkv = self.img_attn.qkv(img_modulated)
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
# prepare txt for attention
txt_modulated = self.txt_norm1(txt)
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
txt_qkv = self.txt_attn.qkv(txt_modulated)
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
# run actual attention
q = torch.cat((txt_q, img_q), dim=2)
k = torch.cat((txt_k, img_k), dim=2)
v = torch.cat((txt_v, img_v), dim=2)
if self.flipped_img_txt:
# run actual attention
attn = attention(torch.cat((img_q, txt_q), dim=2),
torch.cat((img_k, txt_k), dim=2),
torch.cat((img_v, txt_v), dim=2),
pe=pe, mask=attn_mask)
attn = attention(q, k, v, pe=pe)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
else:
# run actual attention
attn = attention(torch.cat((txt_q, img_q), dim=2),
torch.cat((txt_k, img_k), dim=2),
torch.cat((txt_v, img_v), dim=2),
pe=pe, mask=attn_mask)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
# calculate the img bloks
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
# calculate the txt bloks
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
if txt.dtype == torch.float16:
txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
return img, txt
@@ -227,19 +228,22 @@ class SingleStreamBlock(nn.Module):
self.mlp_act = nn.GELU(approximate="tanh")
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None) -> Tensor:
mod, _ = self.modulation(vec)
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k = self.norm(q, k, v)
# compute attention
attn = attention(q, k, v, pe=pe)
attn = attention(q, k, v, pe=pe, mask=attn_mask)
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
return x + mod.gate * output
x += mod.gate * output
if x.dtype == torch.float16:
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
return x
class LastLayer(nn.Module):

View File

@@ -1,14 +1,15 @@
import torch
from einops import rearrange
from torch import Tensor
from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
q, k = apply_rope(q, k, pe)
heads = q.shape[1]
x = optimized_attention(q, k, v, heads, skip_reshape=True)
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
return x
@@ -33,3 +34,4 @@ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

View File

@@ -4,6 +4,8 @@ from dataclasses import dataclass
import torch
from torch import Tensor, nn
from einops import rearrange, repeat
import comfy.ldm.common_dit
from .layers import (
DoubleStreamBlock,
@@ -14,12 +16,10 @@ from .layers import (
timestep_embedding,
)
from einops import rearrange, repeat
import comfy.ldm.common_dit
@dataclass
class FluxParams:
in_channels: int
out_channels: int
vec_in_dim: int
context_in_dim: int
hidden_size: int
@@ -29,6 +29,7 @@ class FluxParams:
depth_single_blocks: int
axes_dim: list
theta: int
patch_size: int
qkv_bias: bool
guidance_embed: bool
@@ -38,13 +39,14 @@ class Flux(nn.Module):
Transformer model for flow matching on sequences.
"""
def __init__(self, image_model=None, dtype=None, device=None, operations=None, **kwargs):
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
super().__init__()
self.dtype = dtype
params = FluxParams(**kwargs)
self.params = params
self.in_channels = params.in_channels * 2 * 2
self.out_channels = self.in_channels
self.patch_size = params.patch_size
self.in_channels = params.in_channels * params.patch_size * params.patch_size
self.out_channels = params.out_channels * params.patch_size * params.patch_size
if params.hidden_size % params.num_heads != 0:
raise ValueError(
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
@@ -83,7 +85,8 @@ class Flux(nn.Module):
]
)
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
if final_layer:
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
def forward_orig(
self,
@@ -94,7 +97,11 @@ class Flux(nn.Module):
timesteps: Tensor,
y: Tensor,
guidance: Tensor = None,
control = None,
transformer_options={},
attn_mask: Tensor = None,
) -> Tensor:
patches_replace = transformer_options.get("patches_replace", {})
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
@@ -106,26 +113,82 @@ class Flux(nn.Module):
raise ValueError("Didn't get guidance strength for guidance distilled model.")
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
vec = vec + self.vector_in(y)
vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
txt = self.txt_in(txt)
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
for block in self.double_blocks:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.double_blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"], out["txt"] = block(img=args["img"],
txt=args["txt"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"))
return out
out = blocks_replace[("double_block", i)]({"img": img,
"txt": txt,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask},
{"original_block": block_wrap})
txt = out["txt"]
img = out["img"]
else:
img, txt = block(img=img,
txt=txt,
vec=vec,
pe=pe,
attn_mask=attn_mask)
if control is not None: # Controlnet
control_i = control.get("input")
if i < len(control_i):
add = control_i[i]
if add is not None:
img += add
img = torch.cat((txt, img), 1)
for block in self.single_blocks:
img = block(img, vec=vec, pe=pe)
for i, block in enumerate(self.single_blocks):
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"))
return out
out = blocks_replace[("single_block", i)]({"img": img,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask},
{"original_block": block_wrap})
img = out["img"]
else:
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
if control is not None: # Controlnet
control_o = control.get("output")
if i < len(control_o):
add = control_o[i]
if add is not None:
img[:, txt.shape[1] :, ...] += add
img = img[:, txt.shape[1] :, ...]
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
return img
def forward(self, x, timestep, context, y, guidance, **kwargs):
def forward(self, x, timestep, context, y, guidance, control=None, transformer_options={}, **kwargs):
bs, c, h, w = x.shape
patch_size = 2
patch_size = self.patch_size
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
@@ -133,10 +196,10 @@ class Flux(nn.Module):
h_len = ((h + (patch_size // 2)) // patch_size)
w_len = ((w + (patch_size // 2)) // patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]

25
comfy/ldm/flux/redux.py Normal file
View File

@@ -0,0 +1,25 @@
import torch
import comfy.ops
ops = comfy.ops.manual_cast
class ReduxImageEncoder(torch.nn.Module):
def __init__(
self,
redux_dim: int = 1152,
txt_in_features: int = 4096,
device=None,
dtype=None,
) -> None:
super().__init__()
self.redux_dim = redux_dim
self.device = device
self.dtype = dtype
self.redux_up = ops.Linear(redux_dim, txt_in_features * 3, dtype=dtype)
self.redux_down = ops.Linear(txt_in_features * 3, txt_in_features, dtype=dtype)
def forward(self, sigclip_embeds) -> torch.Tensor:
projected_x = self.redux_down(torch.nn.functional.silu(self.redux_up(sigclip_embeds)))
return projected_x

View File

@@ -0,0 +1,557 @@
#original code from https://github.com/genmoai/models under apache 2.0 license
#adapted to ComfyUI
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
# from flash_attn import flash_attn_varlen_qkvpacked_func
from comfy.ldm.modules.attention import optimized_attention
from .layers import (
FeedForward,
PatchEmbed,
RMSNorm,
TimestepEmbedder,
)
from .rope_mixed import (
compute_mixed_rotation,
create_position_matrix,
)
from .temporal_rope import apply_rotary_emb_qk_real
from .utils import (
AttentionPool,
modulate,
)
import comfy.ldm.common_dit
import comfy.ops
def modulated_rmsnorm(x, scale, eps=1e-6):
# Normalize and modulate
x_normed = comfy.ldm.common_dit.rms_norm(x, eps=eps)
x_modulated = x_normed * (1 + scale.unsqueeze(1))
return x_modulated
def residual_tanh_gated_rmsnorm(x, x_res, gate, eps=1e-6):
# Apply tanh to gate
tanh_gate = torch.tanh(gate).unsqueeze(1)
# Normalize and apply gated scaling
x_normed = comfy.ldm.common_dit.rms_norm(x_res, eps=eps) * tanh_gate
# Apply residual connection
output = x + x_normed
return output
class AsymmetricAttention(nn.Module):
def __init__(
self,
dim_x: int,
dim_y: int,
num_heads: int = 8,
qkv_bias: bool = True,
qk_norm: bool = False,
attn_drop: float = 0.0,
update_y: bool = True,
out_bias: bool = True,
attend_to_padding: bool = False,
softmax_scale: Optional[float] = None,
device: Optional[torch.device] = None,
dtype=None,
operations=None,
):
super().__init__()
self.dim_x = dim_x
self.dim_y = dim_y
self.num_heads = num_heads
self.head_dim = dim_x // num_heads
self.attn_drop = attn_drop
self.update_y = update_y
self.attend_to_padding = attend_to_padding
self.softmax_scale = softmax_scale
if dim_x % num_heads != 0:
raise ValueError(
f"dim_x={dim_x} should be divisible by num_heads={num_heads}"
)
# Input layers.
self.qkv_bias = qkv_bias
self.qkv_x = operations.Linear(dim_x, 3 * dim_x, bias=qkv_bias, device=device, dtype=dtype)
# Project text features to match visual features (dim_y -> dim_x)
self.qkv_y = operations.Linear(dim_y, 3 * dim_x, bias=qkv_bias, device=device, dtype=dtype)
# Query and key normalization for stability.
assert qk_norm
self.q_norm_x = RMSNorm(self.head_dim, device=device, dtype=dtype)
self.k_norm_x = RMSNorm(self.head_dim, device=device, dtype=dtype)
self.q_norm_y = RMSNorm(self.head_dim, device=device, dtype=dtype)
self.k_norm_y = RMSNorm(self.head_dim, device=device, dtype=dtype)
# Output layers. y features go back down from dim_x -> dim_y.
self.proj_x = operations.Linear(dim_x, dim_x, bias=out_bias, device=device, dtype=dtype)
self.proj_y = (
operations.Linear(dim_x, dim_y, bias=out_bias, device=device, dtype=dtype)
if update_y
else nn.Identity()
)
def forward(
self,
x: torch.Tensor, # (B, N, dim_x)
y: torch.Tensor, # (B, L, dim_y)
scale_x: torch.Tensor, # (B, dim_x), modulation for pre-RMSNorm.
scale_y: torch.Tensor, # (B, dim_y), modulation for pre-RMSNorm.
crop_y,
**rope_rotation,
) -> Tuple[torch.Tensor, torch.Tensor]:
rope_cos = rope_rotation.get("rope_cos")
rope_sin = rope_rotation.get("rope_sin")
# Pre-norm for visual features
x = modulated_rmsnorm(x, scale_x) # (B, M, dim_x) where M = N / cp_group_size
# Process visual features
# qkv_x = self.qkv_x(x) # (B, M, 3 * dim_x)
# assert qkv_x.dtype == torch.bfloat16
# qkv_x = all_to_all_collect_tokens(
# qkv_x, self.num_heads
# ) # (3, B, N, local_h, head_dim)
# Process text features
y = modulated_rmsnorm(y, scale_y) # (B, L, dim_y)
q_y, k_y, v_y = self.qkv_y(y).view(y.shape[0], y.shape[1], 3, self.num_heads, -1).unbind(2) # (B, N, local_h, head_dim)
q_y = self.q_norm_y(q_y)
k_y = self.k_norm_y(k_y)
# Split qkv_x into q, k, v
q_x, k_x, v_x = self.qkv_x(x).view(x.shape[0], x.shape[1], 3, self.num_heads, -1).unbind(2) # (B, N, local_h, head_dim)
q_x = self.q_norm_x(q_x)
q_x = apply_rotary_emb_qk_real(q_x, rope_cos, rope_sin)
k_x = self.k_norm_x(k_x)
k_x = apply_rotary_emb_qk_real(k_x, rope_cos, rope_sin)
q = torch.cat([q_x, q_y[:, :crop_y]], dim=1).transpose(1, 2)
k = torch.cat([k_x, k_y[:, :crop_y]], dim=1).transpose(1, 2)
v = torch.cat([v_x, v_y[:, :crop_y]], dim=1).transpose(1, 2)
xy = optimized_attention(q,
k,
v, self.num_heads, skip_reshape=True)
x, y = torch.tensor_split(xy, (q_x.shape[1],), dim=1)
x = self.proj_x(x)
o = torch.zeros(y.shape[0], q_y.shape[1], y.shape[-1], device=y.device, dtype=y.dtype)
o[:, :y.shape[1]] = y
y = self.proj_y(o)
# print("ox", x)
# print("oy", y)
return x, y
class AsymmetricJointBlock(nn.Module):
def __init__(
self,
hidden_size_x: int,
hidden_size_y: int,
num_heads: int,
*,
mlp_ratio_x: float = 8.0, # Ratio of hidden size to d_model for MLP for visual tokens.
mlp_ratio_y: float = 4.0, # Ratio of hidden size to d_model for MLP for text tokens.
update_y: bool = True, # Whether to update text tokens in this block.
device: Optional[torch.device] = None,
dtype=None,
operations=None,
**block_kwargs,
):
super().__init__()
self.update_y = update_y
self.hidden_size_x = hidden_size_x
self.hidden_size_y = hidden_size_y
self.mod_x = operations.Linear(hidden_size_x, 4 * hidden_size_x, device=device, dtype=dtype)
if self.update_y:
self.mod_y = operations.Linear(hidden_size_x, 4 * hidden_size_y, device=device, dtype=dtype)
else:
self.mod_y = operations.Linear(hidden_size_x, hidden_size_y, device=device, dtype=dtype)
# Self-attention:
self.attn = AsymmetricAttention(
hidden_size_x,
hidden_size_y,
num_heads=num_heads,
update_y=update_y,
device=device,
dtype=dtype,
operations=operations,
**block_kwargs,
)
# MLP.
mlp_hidden_dim_x = int(hidden_size_x * mlp_ratio_x)
assert mlp_hidden_dim_x == int(1536 * 8)
self.mlp_x = FeedForward(
in_features=hidden_size_x,
hidden_size=mlp_hidden_dim_x,
multiple_of=256,
ffn_dim_multiplier=None,
device=device,
dtype=dtype,
operations=operations,
)
# MLP for text not needed in last block.
if self.update_y:
mlp_hidden_dim_y = int(hidden_size_y * mlp_ratio_y)
self.mlp_y = FeedForward(
in_features=hidden_size_y,
hidden_size=mlp_hidden_dim_y,
multiple_of=256,
ffn_dim_multiplier=None,
device=device,
dtype=dtype,
operations=operations,
)
def forward(
self,
x: torch.Tensor,
c: torch.Tensor,
y: torch.Tensor,
**attn_kwargs,
):
"""Forward pass of a block.
Args:
x: (B, N, dim) tensor of visual tokens
c: (B, dim) tensor of conditioned features
y: (B, L, dim) tensor of text tokens
num_frames: Number of frames in the video. N = num_frames * num_spatial_tokens
Returns:
x: (B, N, dim) tensor of visual tokens after block
y: (B, L, dim) tensor of text tokens after block
"""
N = x.size(1)
c = F.silu(c)
mod_x = self.mod_x(c)
scale_msa_x, gate_msa_x, scale_mlp_x, gate_mlp_x = mod_x.chunk(4, dim=1)
mod_y = self.mod_y(c)
if self.update_y:
scale_msa_y, gate_msa_y, scale_mlp_y, gate_mlp_y = mod_y.chunk(4, dim=1)
else:
scale_msa_y = mod_y
# Self-attention block.
x_attn, y_attn = self.attn(
x,
y,
scale_x=scale_msa_x,
scale_y=scale_msa_y,
**attn_kwargs,
)
assert x_attn.size(1) == N
x = residual_tanh_gated_rmsnorm(x, x_attn, gate_msa_x)
if self.update_y:
y = residual_tanh_gated_rmsnorm(y, y_attn, gate_msa_y)
# MLP block.
x = self.ff_block_x(x, scale_mlp_x, gate_mlp_x)
if self.update_y:
y = self.ff_block_y(y, scale_mlp_y, gate_mlp_y)
return x, y
def ff_block_x(self, x, scale_x, gate_x):
x_mod = modulated_rmsnorm(x, scale_x)
x_res = self.mlp_x(x_mod)
x = residual_tanh_gated_rmsnorm(x, x_res, gate_x) # Sandwich norm
return x
def ff_block_y(self, y, scale_y, gate_y):
y_mod = modulated_rmsnorm(y, scale_y)
y_res = self.mlp_y(y_mod)
y = residual_tanh_gated_rmsnorm(y, y_res, gate_y) # Sandwich norm
return y
class FinalLayer(nn.Module):
"""
The final layer of DiT.
"""
def __init__(
self,
hidden_size,
patch_size,
out_channels,
device: Optional[torch.device] = None,
dtype=None,
operations=None,
):
super().__init__()
self.norm_final = operations.LayerNorm(
hidden_size, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype
)
self.mod = operations.Linear(hidden_size, 2 * hidden_size, device=device, dtype=dtype)
self.linear = operations.Linear(
hidden_size, patch_size * patch_size * out_channels, device=device, dtype=dtype
)
def forward(self, x, c):
c = F.silu(c)
shift, scale = self.mod(c).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class AsymmDiTJoint(nn.Module):
"""
Diffusion model with a Transformer backbone.
Ingests text embeddings instead of a label.
"""
def __init__(
self,
*,
patch_size=2,
in_channels=4,
hidden_size_x=1152,
hidden_size_y=1152,
depth=48,
num_heads=16,
mlp_ratio_x=8.0,
mlp_ratio_y=4.0,
use_t5: bool = False,
t5_feat_dim: int = 4096,
t5_token_length: int = 256,
learn_sigma=True,
patch_embed_bias: bool = True,
timestep_mlp_bias: bool = True,
attend_to_padding: bool = False,
timestep_scale: Optional[float] = None,
use_extended_posenc: bool = False,
posenc_preserve_area: bool = False,
rope_theta: float = 10000.0,
image_model=None,
device: Optional[torch.device] = None,
dtype=None,
operations=None,
**block_kwargs,
):
super().__init__()
self.dtype = dtype
self.learn_sigma = learn_sigma
self.in_channels = in_channels
self.out_channels = in_channels * 2 if learn_sigma else in_channels
self.patch_size = patch_size
self.num_heads = num_heads
self.hidden_size_x = hidden_size_x
self.hidden_size_y = hidden_size_y
self.head_dim = (
hidden_size_x // num_heads
) # Head dimension and count is determined by visual.
self.attend_to_padding = attend_to_padding
self.use_extended_posenc = use_extended_posenc
self.posenc_preserve_area = posenc_preserve_area
self.use_t5 = use_t5
self.t5_token_length = t5_token_length
self.t5_feat_dim = t5_feat_dim
self.rope_theta = (
rope_theta # Scaling factor for frequency computation for temporal RoPE.
)
self.x_embedder = PatchEmbed(
patch_size=patch_size,
in_chans=in_channels,
embed_dim=hidden_size_x,
bias=patch_embed_bias,
dtype=dtype,
device=device,
operations=operations
)
# Conditionings
# Timestep
self.t_embedder = TimestepEmbedder(
hidden_size_x, bias=timestep_mlp_bias, timestep_scale=timestep_scale, dtype=dtype, device=device, operations=operations
)
if self.use_t5:
# Caption Pooling (T5)
self.t5_y_embedder = AttentionPool(
t5_feat_dim, num_heads=8, output_dim=hidden_size_x, dtype=dtype, device=device, operations=operations
)
# Dense Embedding Projection (T5)
self.t5_yproj = operations.Linear(
t5_feat_dim, hidden_size_y, bias=True, dtype=dtype, device=device
)
# Initialize pos_frequencies as an empty parameter.
self.pos_frequencies = nn.Parameter(
torch.empty(3, self.num_heads, self.head_dim // 2, dtype=dtype, device=device)
)
assert not self.attend_to_padding
# for depth 48:
# b = 0: AsymmetricJointBlock, update_y=True
# b = 1: AsymmetricJointBlock, update_y=True
# ...
# b = 46: AsymmetricJointBlock, update_y=True
# b = 47: AsymmetricJointBlock, update_y=False. No need to update text features.
blocks = []
for b in range(depth):
# Joint multi-modal block
update_y = b < depth - 1
block = AsymmetricJointBlock(
hidden_size_x,
hidden_size_y,
num_heads,
mlp_ratio_x=mlp_ratio_x,
mlp_ratio_y=mlp_ratio_y,
update_y=update_y,
attend_to_padding=attend_to_padding,
device=device,
dtype=dtype,
operations=operations,
**block_kwargs,
)
blocks.append(block)
self.blocks = nn.ModuleList(blocks)
self.final_layer = FinalLayer(
hidden_size_x, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations
)
def embed_x(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: (B, C=12, T, H, W) tensor of visual tokens
Returns:
x: (B, C=3072, N) tensor of visual tokens with positional embedding.
"""
return self.x_embedder(x) # Convert BcTHW to BCN
def prepare(
self,
x: torch.Tensor,
sigma: torch.Tensor,
t5_feat: torch.Tensor,
t5_mask: torch.Tensor,
):
"""Prepare input and conditioning embeddings."""
# Visual patch embeddings with positional encoding.
T, H, W = x.shape[-3:]
pH, pW = H // self.patch_size, W // self.patch_size
x = self.embed_x(x) # (B, N, D), where N = T * H * W / patch_size ** 2
assert x.ndim == 3
pH, pW = H // self.patch_size, W // self.patch_size
N = T * pH * pW
assert x.size(1) == N
pos = create_position_matrix(
T, pH=pH, pW=pW, device=x.device, dtype=torch.float32
) # (N, 3)
rope_cos, rope_sin = compute_mixed_rotation(
freqs=comfy.ops.cast_to(self.pos_frequencies, dtype=x.dtype, device=x.device), pos=pos
) # Each are (N, num_heads, dim // 2)
c_t = self.t_embedder(1 - sigma, out_dtype=x.dtype) # (B, D)
t5_y_pool = self.t5_y_embedder(t5_feat, t5_mask) # (B, D)
c = c_t + t5_y_pool
y_feat = self.t5_yproj(t5_feat) # (B, L, t5_feat_dim) --> (B, L, D)
return x, c, y_feat, rope_cos, rope_sin
def forward(
self,
x: torch.Tensor,
timestep: torch.Tensor,
context: List[torch.Tensor],
attention_mask: List[torch.Tensor],
num_tokens=256,
packed_indices: Dict[str, torch.Tensor] = None,
rope_cos: torch.Tensor = None,
rope_sin: torch.Tensor = None,
control=None, transformer_options={}, **kwargs
):
patches_replace = transformer_options.get("patches_replace", {})
y_feat = context
y_mask = attention_mask
sigma = timestep
"""Forward pass of DiT.
Args:
x: (B, C, T, H, W) tensor of spatial inputs (images or latent representations of images)
sigma: (B,) tensor of noise standard deviations
y_feat: List((B, L, y_feat_dim) tensor of caption token features. For SDXL text encoders: L=77, y_feat_dim=2048)
y_mask: List((B, L) boolean tensor indicating which tokens are not padding)
packed_indices: Dict with keys for Flash Attention. Result of compute_packed_indices.
"""
B, _, T, H, W = x.shape
x, c, y_feat, rope_cos, rope_sin = self.prepare(
x, sigma, y_feat, y_mask
)
del y_mask
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"], out["txt"] = block(
args["img"],
args["vec"],
args["txt"],
rope_cos=args["rope_cos"],
rope_sin=args["rope_sin"],
crop_y=args["num_tokens"]
)
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens}, {"original_block": block_wrap})
y_feat = out["txt"]
x = out["img"]
else:
x, y_feat = block(
x,
c,
y_feat,
rope_cos=rope_cos,
rope_sin=rope_sin,
crop_y=num_tokens,
) # (B, M, D), (B, L, D)
del y_feat # Final layers don't use dense text features.
x = self.final_layer(x, c) # (B, M, patch_size ** 2 * out_channels)
x = rearrange(
x,
"B (T hp wp) (p1 p2 c) -> B c T (hp p1) (wp p2)",
T=T,
hp=H // self.patch_size,
wp=W // self.patch_size,
p1=self.patch_size,
p2=self.patch_size,
c=self.out_channels,
)
return -x

View File

@@ -0,0 +1,164 @@
#original code from https://github.com/genmoai/models under apache 2.0 license
#adapted to ComfyUI
import collections.abc
import math
from itertools import repeat
from typing import Callable, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import comfy.ldm.common_dit
# From PyTorch internals
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
return tuple(x)
return tuple(repeat(x, n))
return parse
to_2tuple = _ntuple(2)
class TimestepEmbedder(nn.Module):
def __init__(
self,
hidden_size: int,
frequency_embedding_size: int = 256,
*,
bias: bool = True,
timestep_scale: Optional[float] = None,
dtype=None,
device=None,
operations=None,
):
super().__init__()
self.mlp = nn.Sequential(
operations.Linear(frequency_embedding_size, hidden_size, bias=bias, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(hidden_size, hidden_size, bias=bias, dtype=dtype, device=device),
)
self.frequency_embedding_size = frequency_embedding_size
self.timestep_scale = timestep_scale
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
half = dim // 2
freqs = torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
freqs.mul_(-math.log(max_period) / half).exp_()
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat(
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
)
return embedding
def forward(self, t, out_dtype):
if self.timestep_scale is not None:
t = t * self.timestep_scale
t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype=out_dtype)
t_emb = self.mlp(t_freq)
return t_emb
class FeedForward(nn.Module):
def __init__(
self,
in_features: int,
hidden_size: int,
multiple_of: int,
ffn_dim_multiplier: Optional[float],
device: Optional[torch.device] = None,
dtype=None,
operations=None,
):
super().__init__()
# keep parameter count and computation constant compared to standard FFN
hidden_size = int(2 * hidden_size / 3)
# custom dim factor multiplier
if ffn_dim_multiplier is not None:
hidden_size = int(ffn_dim_multiplier * hidden_size)
hidden_size = multiple_of * ((hidden_size + multiple_of - 1) // multiple_of)
self.hidden_dim = hidden_size
self.w1 = operations.Linear(in_features, 2 * hidden_size, bias=False, device=device, dtype=dtype)
self.w2 = operations.Linear(hidden_size, in_features, bias=False, device=device, dtype=dtype)
def forward(self, x):
x, gate = self.w1(x).chunk(2, dim=-1)
x = self.w2(F.silu(x) * gate)
return x
class PatchEmbed(nn.Module):
def __init__(
self,
patch_size: int = 16,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer: Optional[Callable] = None,
flatten: bool = True,
bias: bool = True,
dynamic_img_pad: bool = False,
dtype=None,
device=None,
operations=None,
):
super().__init__()
self.patch_size = to_2tuple(patch_size)
self.flatten = flatten
self.dynamic_img_pad = dynamic_img_pad
self.proj = operations.Conv2d(
in_chans,
embed_dim,
kernel_size=patch_size,
stride=patch_size,
bias=bias,
device=device,
dtype=dtype,
)
assert norm_layer is None
self.norm = (
norm_layer(embed_dim, device=device) if norm_layer else nn.Identity()
)
def forward(self, x):
B, _C, T, H, W = x.shape
if not self.dynamic_img_pad:
assert H % self.patch_size[0] == 0, f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})."
assert W % self.patch_size[1] == 0, f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
else:
pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
x = F.pad(x, (0, pad_w, 0, pad_h))
x = rearrange(x, "B C T H W -> (B T) C H W", B=B, T=T)
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size, padding_mode='circular')
x = self.proj(x)
# Flatten temporal and spatial dimensions.
if not self.flatten:
raise NotImplementedError("Must flatten output.")
x = rearrange(x, "(B T) C H W -> B (T H W) C", B=B, T=T)
x = self.norm(x)
return x
class RMSNorm(torch.nn.Module):
def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
super().__init__()
self.eps = eps
self.weight = torch.nn.Parameter(torch.empty(hidden_size, device=device, dtype=dtype))
self.register_parameter("bias", None)
def forward(self, x):
return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)

View File

@@ -0,0 +1,88 @@
#original code from https://github.com/genmoai/models under apache 2.0 license
# import functools
import math
import torch
def centers(start: float, stop, num, dtype=None, device=None):
"""linspace through bin centers.
Args:
start (float): Start of the range.
stop (float): End of the range.
num (int): Number of points.
dtype (torch.dtype): Data type of the points.
device (torch.device): Device of the points.
Returns:
centers (Tensor): Centers of the bins. Shape: (num,).
"""
edges = torch.linspace(start, stop, num + 1, dtype=dtype, device=device)
return (edges[:-1] + edges[1:]) / 2
# @functools.lru_cache(maxsize=1)
def create_position_matrix(
T: int,
pH: int,
pW: int,
device: torch.device,
dtype: torch.dtype,
*,
target_area: float = 36864,
):
"""
Args:
T: int - Temporal dimension
pH: int - Height dimension after patchify
pW: int - Width dimension after patchify
Returns:
pos: [T * pH * pW, 3] - position matrix
"""
# Create 1D tensors for each dimension
t = torch.arange(T, dtype=dtype)
# Positionally interpolate to area 36864.
# (3072x3072 frame with 16x16 patches = 192x192 latents).
# This automatically scales rope positions when the resolution changes.
# We use a large target area so the model is more sensitive
# to changes in the learned pos_frequencies matrix.
scale = math.sqrt(target_area / (pW * pH))
w = centers(-pW * scale / 2, pW * scale / 2, pW)
h = centers(-pH * scale / 2, pH * scale / 2, pH)
# Use meshgrid to create 3D grids
grid_t, grid_h, grid_w = torch.meshgrid(t, h, w, indexing="ij")
# Stack and reshape the grids.
pos = torch.stack([grid_t, grid_h, grid_w], dim=-1) # [T, pH, pW, 3]
pos = pos.view(-1, 3) # [T * pH * pW, 3]
pos = pos.to(dtype=dtype, device=device)
return pos
def compute_mixed_rotation(
freqs: torch.Tensor,
pos: torch.Tensor,
):
"""
Project each 3-dim position into per-head, per-head-dim 1D frequencies.
Args:
freqs: [3, num_heads, num_freqs] - learned rotation frequency (for t, row, col) for each head position
pos: [N, 3] - position of each token
num_heads: int
Returns:
freqs_cos: [N, num_heads, num_freqs] - cosine components
freqs_sin: [N, num_heads, num_freqs] - sine components
"""
assert freqs.ndim == 3
freqs_sum = torch.einsum("Nd,dhf->Nhf", pos.to(freqs), freqs)
freqs_cos = torch.cos(freqs_sum)
freqs_sin = torch.sin(freqs_sum)
return freqs_cos, freqs_sin

View File

@@ -0,0 +1,34 @@
#original code from https://github.com/genmoai/models under apache 2.0 license
# Based on Llama3 Implementation.
import torch
def apply_rotary_emb_qk_real(
xqk: torch.Tensor,
freqs_cos: torch.Tensor,
freqs_sin: torch.Tensor,
) -> torch.Tensor:
"""
Apply rotary embeddings to input tensors using the given frequency tensor without complex numbers.
Args:
xqk (torch.Tensor): Query and/or Key tensors to apply rotary embeddings. Shape: (B, S, *, num_heads, D)
Can be either just query or just key, or both stacked along some batch or * dim.
freqs_cos (torch.Tensor): Precomputed cosine frequency tensor.
freqs_sin (torch.Tensor): Precomputed sine frequency tensor.
Returns:
torch.Tensor: The input tensor with rotary embeddings applied.
"""
# Split the last dimension into even and odd parts
xqk_even = xqk[..., 0::2]
xqk_odd = xqk[..., 1::2]
# Apply rotation
cos_part = (xqk_even * freqs_cos - xqk_odd * freqs_sin).type_as(xqk)
sin_part = (xqk_even * freqs_sin + xqk_odd * freqs_cos).type_as(xqk)
# Interleave the results back into the original shape
out = torch.stack([cos_part, sin_part], dim=-1).flatten(-2)
return out

View File

@@ -0,0 +1,102 @@
#original code from https://github.com/genmoai/models under apache 2.0 license
#adapted to ComfyUI
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch.Tensor:
"""
Pool tokens in x using mask.
NOTE: We assume x does not require gradients.
Args:
x: (B, L, D) tensor of tokens.
mask: (B, L) boolean tensor indicating which tokens are not padding.
Returns:
pooled: (B, D) tensor of pooled tokens.
"""
assert x.size(1) == mask.size(1) # Expected mask to have same length as tokens.
assert x.size(0) == mask.size(0) # Expected mask to have same batch size as tokens.
mask = mask[:, :, None].to(dtype=x.dtype)
mask = mask / mask.sum(dim=1, keepdim=True).clamp(min=1)
pooled = (x * mask).sum(dim=1, keepdim=keepdim)
return pooled
class AttentionPool(nn.Module):
def __init__(
self,
embed_dim: int,
num_heads: int,
output_dim: int = None,
device: Optional[torch.device] = None,
dtype=None,
operations=None,
):
"""
Args:
spatial_dim (int): Number of tokens in sequence length.
embed_dim (int): Dimensionality of input tokens.
num_heads (int): Number of attention heads.
output_dim (int): Dimensionality of output tokens. Defaults to embed_dim.
"""
super().__init__()
self.num_heads = num_heads
self.to_kv = operations.Linear(embed_dim, 2 * embed_dim, device=device, dtype=dtype)
self.to_q = operations.Linear(embed_dim, embed_dim, device=device, dtype=dtype)
self.to_out = operations.Linear(embed_dim, output_dim or embed_dim, device=device, dtype=dtype)
def forward(self, x, mask):
"""
Args:
x (torch.Tensor): (B, L, D) tensor of input tokens.
mask (torch.Tensor): (B, L) boolean tensor indicating which tokens are not padding.
NOTE: We assume x does not require gradients.
Returns:
x (torch.Tensor): (B, D) tensor of pooled tokens.
"""
D = x.size(2)
# Construct attention mask, shape: (B, 1, num_queries=1, num_keys=1+L).
attn_mask = mask[:, None, None, :].bool() # (B, 1, 1, L).
attn_mask = F.pad(attn_mask, (1, 0), value=True) # (B, 1, 1, 1+L).
# Average non-padding token features. These will be used as the query.
x_pool = pool_tokens(x, mask, keepdim=True) # (B, 1, D)
# Concat pooled features to input sequence.
x = torch.cat([x_pool, x], dim=1) # (B, L+1, D)
# Compute queries, keys, values. Only the mean token is used to create a query.
kv = self.to_kv(x) # (B, L+1, 2 * D)
q = self.to_q(x[:, 0]) # (B, D)
# Extract heads.
head_dim = D // self.num_heads
kv = kv.unflatten(2, (2, self.num_heads, head_dim)) # (B, 1+L, 2, H, head_dim)
kv = kv.transpose(1, 3) # (B, H, 2, 1+L, head_dim)
k, v = kv.unbind(2) # (B, H, 1+L, head_dim)
q = q.unflatten(1, (self.num_heads, head_dim)) # (B, H, head_dim)
q = q.unsqueeze(2) # (B, H, 1, head_dim)
# Compute attention.
x = F.scaled_dot_product_attention(
q, k, v, attn_mask=attn_mask, dropout_p=0.0
) # (B, H, 1, head_dim)
# Concatenate heads and run output.
x = x.squeeze(2).flatten(1, 2) # (B, D = H * head_dim)
x = self.to_out(x)
return x

View File

@@ -0,0 +1,711 @@
#original code from https://github.com/genmoai/models under apache 2.0 license
#adapted to ComfyUI
from typing import List, Optional, Tuple, Union
from functools import partial
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from comfy.ldm.modules.attention import optimized_attention
import comfy.ops
ops = comfy.ops.disable_weight_init
# import mochi_preview.dit.joint_model.context_parallel as cp
# from mochi_preview.vae.cp_conv import cp_pass_frames, gather_all_frames
def cast_tuple(t, length=1):
return t if isinstance(t, tuple) else ((t,) * length)
class GroupNormSpatial(ops.GroupNorm):
"""
GroupNorm applied per-frame.
"""
def forward(self, x: torch.Tensor, *, chunk_size: int = 8):
B, C, T, H, W = x.shape
x = rearrange(x, "B C T H W -> (B T) C H W")
# Run group norm in chunks.
output = torch.empty_like(x)
for b in range(0, B * T, chunk_size):
output[b : b + chunk_size] = super().forward(x[b : b + chunk_size])
return rearrange(output, "(B T) C H W -> B C T H W", B=B, T=T)
class PConv3d(ops.Conv3d):
def __init__(
self,
in_channels,
out_channels,
kernel_size: Union[int, Tuple[int, int, int]],
stride: Union[int, Tuple[int, int, int]],
causal: bool = True,
context_parallel: bool = True,
**kwargs,
):
self.causal = causal
self.context_parallel = context_parallel
kernel_size = cast_tuple(kernel_size, 3)
stride = cast_tuple(stride, 3)
height_pad = (kernel_size[1] - 1) // 2
width_pad = (kernel_size[2] - 1) // 2
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
dilation=(1, 1, 1),
padding=(0, height_pad, width_pad),
**kwargs,
)
def forward(self, x: torch.Tensor):
# Compute padding amounts.
context_size = self.kernel_size[0] - 1
if self.causal:
pad_front = context_size
pad_back = 0
else:
pad_front = context_size // 2
pad_back = context_size - pad_front
# Apply padding.
assert self.padding_mode == "replicate" # DEBUG
mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
x = F.pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode)
return super().forward(x)
class Conv1x1(ops.Linear):
"""*1x1 Conv implemented with a linear layer."""
def __init__(self, in_features: int, out_features: int, *args, **kwargs):
super().__init__(in_features, out_features, *args, **kwargs)
def forward(self, x: torch.Tensor):
"""Forward pass.
Args:
x: Input tensor. Shape: [B, C, *] or [B, *, C].
Returns:
x: Output tensor. Shape: [B, C', *] or [B, *, C'].
"""
x = x.movedim(1, -1)
x = super().forward(x)
x = x.movedim(-1, 1)
return x
class DepthToSpaceTime(nn.Module):
def __init__(
self,
temporal_expansion: int,
spatial_expansion: int,
):
super().__init__()
self.temporal_expansion = temporal_expansion
self.spatial_expansion = spatial_expansion
# When printed, this module should show the temporal and spatial expansion factors.
def extra_repr(self):
return f"texp={self.temporal_expansion}, sexp={self.spatial_expansion}"
def forward(self, x: torch.Tensor):
"""Forward pass.
Args:
x: Input tensor. Shape: [B, C, T, H, W].
Returns:
x: Rearranged tensor. Shape: [B, C/(st*s*s), T*st, H*s, W*s].
"""
x = rearrange(
x,
"B (C st sh sw) T H W -> B C (T st) (H sh) (W sw)",
st=self.temporal_expansion,
sh=self.spatial_expansion,
sw=self.spatial_expansion,
)
# cp_rank, _ = cp.get_cp_rank_size()
if self.temporal_expansion > 1: # and cp_rank == 0:
# Drop the first self.temporal_expansion - 1 frames.
# This is because we always want the 3x3x3 conv filter to only apply
# to the first frame, and the first frame doesn't need to be repeated.
assert all(x.shape)
x = x[:, :, self.temporal_expansion - 1 :]
assert all(x.shape)
return x
def norm_fn(
in_channels: int,
affine: bool = True,
):
return GroupNormSpatial(affine=affine, num_groups=32, num_channels=in_channels)
class ResBlock(nn.Module):
"""Residual block that preserves the spatial dimensions."""
def __init__(
self,
channels: int,
*,
affine: bool = True,
attn_block: Optional[nn.Module] = None,
causal: bool = True,
prune_bottleneck: bool = False,
padding_mode: str,
bias: bool = True,
):
super().__init__()
self.channels = channels
assert causal
self.stack = nn.Sequential(
norm_fn(channels, affine=affine),
nn.SiLU(inplace=True),
PConv3d(
in_channels=channels,
out_channels=channels // 2 if prune_bottleneck else channels,
kernel_size=(3, 3, 3),
stride=(1, 1, 1),
padding_mode=padding_mode,
bias=bias,
causal=causal,
),
norm_fn(channels, affine=affine),
nn.SiLU(inplace=True),
PConv3d(
in_channels=channels // 2 if prune_bottleneck else channels,
out_channels=channels,
kernel_size=(3, 3, 3),
stride=(1, 1, 1),
padding_mode=padding_mode,
bias=bias,
causal=causal,
),
)
self.attn_block = attn_block if attn_block else nn.Identity()
def forward(self, x: torch.Tensor):
"""Forward pass.
Args:
x: Input tensor. Shape: [B, C, T, H, W].
"""
residual = x
x = self.stack(x)
x = x + residual
del residual
return self.attn_block(x)
class Attention(nn.Module):
def __init__(
self,
dim: int,
head_dim: int = 32,
qkv_bias: bool = False,
out_bias: bool = True,
qk_norm: bool = True,
) -> None:
super().__init__()
self.head_dim = head_dim
self.num_heads = dim // head_dim
self.qk_norm = qk_norm
self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)
self.out = nn.Linear(dim, dim, bias=out_bias)
def forward(
self,
x: torch.Tensor,
) -> torch.Tensor:
"""Compute temporal self-attention.
Args:
x: Input tensor. Shape: [B, C, T, H, W].
chunk_size: Chunk size for large tensors.
Returns:
x: Output tensor. Shape: [B, C, T, H, W].
"""
B, _, T, H, W = x.shape
if T == 1:
# No attention for single frame.
x = x.movedim(1, -1) # [B, C, T, H, W] -> [B, T, H, W, C]
qkv = self.qkv(x)
_, _, x = qkv.chunk(3, dim=-1) # Throw away queries and keys.
x = self.out(x)
return x.movedim(-1, 1) # [B, T, H, W, C] -> [B, C, T, H, W]
# 1D temporal attention.
x = rearrange(x, "B C t h w -> (B h w) t C")
qkv = self.qkv(x)
# Input: qkv with shape [B, t, 3 * num_heads * head_dim]
# Output: x with shape [B, num_heads, t, head_dim]
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, self.head_dim).transpose(1, 3).unbind(2)
if self.qk_norm:
q = F.normalize(q, p=2, dim=-1)
k = F.normalize(k, p=2, dim=-1)
x = optimized_attention(q, k, v, self.num_heads, skip_reshape=True)
assert x.size(0) == q.size(0)
x = self.out(x)
x = rearrange(x, "(B h w) t C -> B C t h w", B=B, h=H, w=W)
return x
class AttentionBlock(nn.Module):
def __init__(
self,
dim: int,
**attn_kwargs,
) -> None:
super().__init__()
self.norm = norm_fn(dim)
self.attn = Attention(dim, **attn_kwargs)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x + self.attn(self.norm(x))
class CausalUpsampleBlock(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
num_res_blocks: int,
*,
temporal_expansion: int = 2,
spatial_expansion: int = 2,
**block_kwargs,
):
super().__init__()
blocks = []
for _ in range(num_res_blocks):
blocks.append(block_fn(in_channels, **block_kwargs))
self.blocks = nn.Sequential(*blocks)
self.temporal_expansion = temporal_expansion
self.spatial_expansion = spatial_expansion
# Change channels in the final convolution layer.
self.proj = Conv1x1(
in_channels,
out_channels * temporal_expansion * (spatial_expansion**2),
)
self.d2st = DepthToSpaceTime(
temporal_expansion=temporal_expansion, spatial_expansion=spatial_expansion
)
def forward(self, x):
x = self.blocks(x)
x = self.proj(x)
x = self.d2st(x)
return x
def block_fn(channels, *, affine: bool = True, has_attention: bool = False, **block_kwargs):
attn_block = AttentionBlock(channels) if has_attention else None
return ResBlock(channels, affine=affine, attn_block=attn_block, **block_kwargs)
class DownsampleBlock(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
num_res_blocks,
*,
temporal_reduction=2,
spatial_reduction=2,
**block_kwargs,
):
"""
Downsample block for the VAE encoder.
Args:
in_channels: Number of input channels.
out_channels: Number of output channels.
num_res_blocks: Number of residual blocks.
temporal_reduction: Temporal reduction factor.
spatial_reduction: Spatial reduction factor.
"""
super().__init__()
layers = []
# Change the channel count in the strided convolution.
# This lets the ResBlock have uniform channel count,
# as in ConvNeXt.
assert in_channels != out_channels
layers.append(
PConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=(temporal_reduction, spatial_reduction, spatial_reduction),
stride=(temporal_reduction, spatial_reduction, spatial_reduction),
# First layer in each block always uses replicate padding
padding_mode="replicate",
bias=block_kwargs["bias"],
)
)
for _ in range(num_res_blocks):
layers.append(block_fn(out_channels, **block_kwargs))
self.layers = nn.Sequential(*layers)
def forward(self, x):
return self.layers(x)
def add_fourier_features(inputs: torch.Tensor, start=6, stop=8, step=1):
num_freqs = (stop - start) // step
assert inputs.ndim == 5
C = inputs.size(1)
# Create Base 2 Fourier features.
freqs = torch.arange(start, stop, step, dtype=inputs.dtype, device=inputs.device)
assert num_freqs == len(freqs)
w = torch.pow(2.0, freqs) * (2 * torch.pi) # [num_freqs]
C = inputs.shape[1]
w = w.repeat(C)[None, :, None, None, None] # [1, C * num_freqs, 1, 1, 1]
# Interleaved repeat of input channels to match w.
h = inputs.repeat_interleave(num_freqs, dim=1) # [B, C * num_freqs, T, H, W]
# Scale channels by frequency.
h = w * h
return torch.cat(
[
inputs,
torch.sin(h),
torch.cos(h),
],
dim=1,
)
class FourierFeatures(nn.Module):
def __init__(self, start: int = 6, stop: int = 8, step: int = 1):
super().__init__()
self.start = start
self.stop = stop
self.step = step
def forward(self, inputs):
"""Add Fourier features to inputs.
Args:
inputs: Input tensor. Shape: [B, C, T, H, W]
Returns:
h: Output tensor. Shape: [B, (1 + 2 * num_freqs) * C, T, H, W]
"""
return add_fourier_features(inputs, self.start, self.stop, self.step)
class Decoder(nn.Module):
def __init__(
self,
*,
out_channels: int = 3,
latent_dim: int,
base_channels: int,
channel_multipliers: List[int],
num_res_blocks: List[int],
temporal_expansions: Optional[List[int]] = None,
spatial_expansions: Optional[List[int]] = None,
has_attention: List[bool],
output_norm: bool = True,
nonlinearity: str = "silu",
output_nonlinearity: str = "silu",
causal: bool = True,
**block_kwargs,
):
super().__init__()
self.input_channels = latent_dim
self.base_channels = base_channels
self.channel_multipliers = channel_multipliers
self.num_res_blocks = num_res_blocks
self.output_nonlinearity = output_nonlinearity
assert nonlinearity == "silu"
assert causal
ch = [mult * base_channels for mult in channel_multipliers]
self.num_up_blocks = len(ch) - 1
assert len(num_res_blocks) == self.num_up_blocks + 2
blocks = []
first_block = [
ops.Conv3d(latent_dim, ch[-1], kernel_size=(1, 1, 1))
] # Input layer.
# First set of blocks preserve channel count.
for _ in range(num_res_blocks[-1]):
first_block.append(
block_fn(
ch[-1],
has_attention=has_attention[-1],
causal=causal,
**block_kwargs,
)
)
blocks.append(nn.Sequential(*first_block))
assert len(temporal_expansions) == len(spatial_expansions) == self.num_up_blocks
assert len(num_res_blocks) == len(has_attention) == self.num_up_blocks + 2
upsample_block_fn = CausalUpsampleBlock
for i in range(self.num_up_blocks):
block = upsample_block_fn(
ch[-i - 1],
ch[-i - 2],
num_res_blocks=num_res_blocks[-i - 2],
has_attention=has_attention[-i - 2],
temporal_expansion=temporal_expansions[-i - 1],
spatial_expansion=spatial_expansions[-i - 1],
causal=causal,
**block_kwargs,
)
blocks.append(block)
assert not output_norm
# Last block. Preserve channel count.
last_block = []
for _ in range(num_res_blocks[0]):
last_block.append(
block_fn(
ch[0], has_attention=has_attention[0], causal=causal, **block_kwargs
)
)
blocks.append(nn.Sequential(*last_block))
self.blocks = nn.ModuleList(blocks)
self.output_proj = Conv1x1(ch[0], out_channels)
def forward(self, x):
"""Forward pass.
Args:
x: Latent tensor. Shape: [B, input_channels, t, h, w]. Scaled [-1, 1].
Returns:
x: Reconstructed video tensor. Shape: [B, C, T, H, W]. Scaled to [-1, 1].
T + 1 = (t - 1) * 4.
H = h * 16, W = w * 16.
"""
for block in self.blocks:
x = block(x)
if self.output_nonlinearity == "silu":
x = F.silu(x, inplace=not self.training)
else:
assert (
not self.output_nonlinearity
) # StyleGAN3 omits the to-RGB nonlinearity.
return self.output_proj(x).contiguous()
class LatentDistribution:
def __init__(self, mean: torch.Tensor, logvar: torch.Tensor):
"""Initialize latent distribution.
Args:
mean: Mean of the distribution. Shape: [B, C, T, H, W].
logvar: Logarithm of variance of the distribution. Shape: [B, C, T, H, W].
"""
assert mean.shape == logvar.shape
self.mean = mean
self.logvar = logvar
def sample(self, temperature=1.0, generator: torch.Generator = None, noise=None):
if temperature == 0.0:
return self.mean
if noise is None:
noise = torch.randn(self.mean.shape, device=self.mean.device, dtype=self.mean.dtype, generator=generator)
else:
assert noise.device == self.mean.device
noise = noise.to(self.mean.dtype)
if temperature != 1.0:
raise NotImplementedError(f"Temperature {temperature} is not supported.")
# Just Gaussian sample with no scaling of variance.
return noise * torch.exp(self.logvar * 0.5) + self.mean
def mode(self):
return self.mean
class Encoder(nn.Module):
def __init__(
self,
*,
in_channels: int,
base_channels: int,
channel_multipliers: List[int],
num_res_blocks: List[int],
latent_dim: int,
temporal_reductions: List[int],
spatial_reductions: List[int],
prune_bottlenecks: List[bool],
has_attentions: List[bool],
affine: bool = True,
bias: bool = True,
input_is_conv_1x1: bool = False,
padding_mode: str,
):
super().__init__()
self.temporal_reductions = temporal_reductions
self.spatial_reductions = spatial_reductions
self.base_channels = base_channels
self.channel_multipliers = channel_multipliers
self.num_res_blocks = num_res_blocks
self.latent_dim = latent_dim
self.fourier_features = FourierFeatures()
ch = [mult * base_channels for mult in channel_multipliers]
num_down_blocks = len(ch) - 1
assert len(num_res_blocks) == num_down_blocks + 2
layers = (
[ops.Conv3d(in_channels, ch[0], kernel_size=(1, 1, 1), bias=True)]
if not input_is_conv_1x1
else [Conv1x1(in_channels, ch[0])]
)
assert len(prune_bottlenecks) == num_down_blocks + 2
assert len(has_attentions) == num_down_blocks + 2
block = partial(block_fn, padding_mode=padding_mode, affine=affine, bias=bias)
for _ in range(num_res_blocks[0]):
layers.append(block(ch[0], has_attention=has_attentions[0], prune_bottleneck=prune_bottlenecks[0]))
prune_bottlenecks = prune_bottlenecks[1:]
has_attentions = has_attentions[1:]
assert len(temporal_reductions) == len(spatial_reductions) == len(ch) - 1
for i in range(num_down_blocks):
layer = DownsampleBlock(
ch[i],
ch[i + 1],
num_res_blocks=num_res_blocks[i + 1],
temporal_reduction=temporal_reductions[i],
spatial_reduction=spatial_reductions[i],
prune_bottleneck=prune_bottlenecks[i],
has_attention=has_attentions[i],
affine=affine,
bias=bias,
padding_mode=padding_mode,
)
layers.append(layer)
# Additional blocks.
for _ in range(num_res_blocks[-1]):
layers.append(block(ch[-1], has_attention=has_attentions[-1], prune_bottleneck=prune_bottlenecks[-1]))
self.layers = nn.Sequential(*layers)
# Output layers.
self.output_norm = norm_fn(ch[-1])
self.output_proj = Conv1x1(ch[-1], 2 * latent_dim, bias=False)
@property
def temporal_downsample(self):
return math.prod(self.temporal_reductions)
@property
def spatial_downsample(self):
return math.prod(self.spatial_reductions)
def forward(self, x) -> LatentDistribution:
"""Forward pass.
Args:
x: Input video tensor. Shape: [B, C, T, H, W]. Scaled to [-1, 1]
Returns:
means: Latent tensor. Shape: [B, latent_dim, t, h, w]. Scaled [-1, 1].
h = H // 8, w = W // 8, t - 1 = (T - 1) // 6
logvar: Shape: [B, latent_dim, t, h, w].
"""
assert x.ndim == 5, f"Expected 5D input, got {x.shape}"
x = self.fourier_features(x)
x = self.layers(x)
x = self.output_norm(x)
x = F.silu(x, inplace=True)
x = self.output_proj(x)
means, logvar = torch.chunk(x, 2, dim=1)
assert means.ndim == 5
assert logvar.shape == means.shape
assert means.size(1) == self.latent_dim
return LatentDistribution(means, logvar)
class VideoVAE(nn.Module):
def __init__(self):
super().__init__()
self.encoder = Encoder(
in_channels=15,
base_channels=64,
channel_multipliers=[1, 2, 4, 6],
num_res_blocks=[3, 3, 4, 6, 3],
latent_dim=12,
temporal_reductions=[1, 2, 3],
spatial_reductions=[2, 2, 2],
prune_bottlenecks=[False, False, False, False, False],
has_attentions=[False, True, True, True, True],
affine=True,
bias=True,
input_is_conv_1x1=True,
padding_mode="replicate"
)
self.decoder = Decoder(
out_channels=3,
base_channels=128,
channel_multipliers=[1, 2, 4, 6],
temporal_expansions=[1, 2, 3],
spatial_expansions=[2, 2, 2],
num_res_blocks=[3, 3, 4, 6, 3],
latent_dim=12,
has_attention=[False, False, False, False, False],
padding_mode="replicate",
output_norm=False,
nonlinearity="silu",
output_nonlinearity="silu",
causal=True,
)
def encode(self, x):
return self.encoder(x).mode()
def decode(self, x):
return self.decoder(x)

View File

@@ -0,0 +1,330 @@
#Based on Flux code because of weird hunyuan video code license.
import torch
import comfy.ldm.flux.layers
import comfy.ldm.modules.diffusionmodules.mmdit
from comfy.ldm.modules.attention import optimized_attention
from dataclasses import dataclass
from einops import repeat
from torch import Tensor, nn
from comfy.ldm.flux.layers import (
DoubleStreamBlock,
EmbedND,
LastLayer,
MLPEmbedder,
SingleStreamBlock,
timestep_embedding
)
import comfy.ldm.common_dit
@dataclass
class HunyuanVideoParams:
in_channels: int
out_channels: int
vec_in_dim: int
context_in_dim: int
hidden_size: int
mlp_ratio: float
num_heads: int
depth: int
depth_single_blocks: int
axes_dim: list
theta: int
patch_size: list
qkv_bias: bool
guidance_embed: bool
class SelfAttentionRef(nn.Module):
def __init__(self, dim: int, qkv_bias: bool = False, dtype=None, device=None, operations=None):
super().__init__()
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
class TokenRefinerBlock(nn.Module):
def __init__(
self,
hidden_size,
heads,
dtype=None,
device=None,
operations=None
):
super().__init__()
self.heads = heads
mlp_hidden_dim = hidden_size * 4
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device),
)
self.norm1 = operations.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)
self.self_attn = SelfAttentionRef(hidden_size, True, dtype=dtype, device=device, operations=operations)
self.norm2 = operations.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)
self.mlp = nn.Sequential(
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
def forward(self, x, c, mask):
mod1, mod2 = self.adaLN_modulation(c).chunk(2, dim=1)
norm_x = self.norm1(x)
qkv = self.self_attn.qkv(norm_x)
q, k, v = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, self.heads, -1).permute(2, 0, 3, 1, 4)
attn = optimized_attention(q, k, v, self.heads, mask=mask, skip_reshape=True)
x = x + self.self_attn.proj(attn) * mod1.unsqueeze(1)
x = x + self.mlp(self.norm2(x)) * mod2.unsqueeze(1)
return x
class IndividualTokenRefiner(nn.Module):
def __init__(
self,
hidden_size,
heads,
num_blocks,
dtype=None,
device=None,
operations=None
):
super().__init__()
self.blocks = nn.ModuleList(
[
TokenRefinerBlock(
hidden_size=hidden_size,
heads=heads,
dtype=dtype,
device=device,
operations=operations
)
for _ in range(num_blocks)
]
)
def forward(self, x, c, mask):
m = None
if mask is not None:
m = mask.view(mask.shape[0], 1, 1, mask.shape[1]).repeat(1, 1, mask.shape[1], 1)
m = m + m.transpose(2, 3)
for block in self.blocks:
x = block(x, c, m)
return x
class TokenRefiner(nn.Module):
def __init__(
self,
text_dim,
hidden_size,
heads,
num_blocks,
dtype=None,
device=None,
operations=None
):
super().__init__()
self.input_embedder = operations.Linear(text_dim, hidden_size, bias=True, dtype=dtype, device=device)
self.t_embedder = MLPEmbedder(256, hidden_size, dtype=dtype, device=device, operations=operations)
self.c_embedder = MLPEmbedder(text_dim, hidden_size, dtype=dtype, device=device, operations=operations)
self.individual_token_refiner = IndividualTokenRefiner(hidden_size, heads, num_blocks, dtype=dtype, device=device, operations=operations)
def forward(
self,
x,
timesteps,
mask,
):
t = self.t_embedder(timestep_embedding(timesteps, 256, time_factor=1.0).to(x.dtype))
# m = mask.float().unsqueeze(-1)
# c = (x.float() * m).sum(dim=1) / m.sum(dim=1) #TODO: the following works when the x.shape is the same length as the tokens but might break otherwise
c = x.sum(dim=1) / x.shape[1]
c = t + self.c_embedder(c.to(x.dtype))
x = self.input_embedder(x)
x = self.individual_token_refiner(x, c, mask)
return x
class HunyuanVideo(nn.Module):
"""
Transformer model for flow matching on sequences.
"""
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
super().__init__()
self.dtype = dtype
params = HunyuanVideoParams(**kwargs)
self.params = params
self.patch_size = params.patch_size
self.in_channels = params.in_channels
self.out_channels = params.out_channels
if params.hidden_size % params.num_heads != 0:
raise ValueError(
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
)
pe_dim = params.hidden_size // params.num_heads
if sum(params.axes_dim) != pe_dim:
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
self.hidden_size = params.hidden_size
self.num_heads = params.num_heads
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
self.img_in = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(None, self.patch_size, self.in_channels, self.hidden_size, conv3d=True, dtype=dtype, device=device, operations=operations)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
self.guidance_in = (
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
)
self.txt_in = TokenRefiner(params.context_in_dim, self.hidden_size, self.num_heads, 2, dtype=dtype, device=device, operations=operations)
self.double_blocks = nn.ModuleList(
[
DoubleStreamBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
flipped_img_txt=True,
dtype=dtype, device=device, operations=operations
)
for _ in range(params.depth)
]
)
self.single_blocks = nn.ModuleList(
[
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
for _ in range(params.depth_single_blocks)
]
)
if final_layer:
self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations)
def forward_orig(
self,
img: Tensor,
img_ids: Tensor,
txt: Tensor,
txt_ids: Tensor,
txt_mask: Tensor,
timesteps: Tensor,
y: Tensor,
guidance: Tensor = None,
control=None,
transformer_options={},
) -> Tensor:
patches_replace = transformer_options.get("patches_replace", {})
initial_shape = list(img.shape)
# running on sequences img
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
if self.params.guidance_embed:
if guidance is None:
raise ValueError("Didn't get guidance strength for guidance distilled model.")
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
if txt_mask is not None and not torch.is_floating_point(txt_mask):
txt_mask = (txt_mask - 1).to(img.dtype) * torch.finfo(img.dtype).max
txt = self.txt_in(txt, timesteps, txt_mask)
ids = torch.cat((img_ids, txt_ids), dim=1)
pe = self.pe_embedder(ids)
img_len = img.shape[1]
if txt_mask is not None:
attn_mask_len = img_len + txt.shape[1]
attn_mask = torch.zeros((1, 1, attn_mask_len), dtype=img.dtype, device=img.device)
attn_mask[:, 0, img_len:] = txt_mask
else:
attn_mask = None
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.double_blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"])
return out
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap})
txt = out["txt"]
img = out["img"]
else:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask)
if control is not None: # Controlnet
control_i = control.get("input")
if i < len(control_i):
add = control_i[i]
if add is not None:
img += add
img = torch.cat((img, txt), 1)
for i, block in enumerate(self.single_blocks):
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"])
return out
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap})
img = out["img"]
else:
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
if control is not None: # Controlnet
control_o = control.get("output")
if i < len(control_o):
add = control_o[i]
if add is not None:
img[:, : img_len] += add
img = img[:, : img_len]
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
shape = initial_shape[-3:]
for i in range(len(shape)):
shape[i] = shape[i] // self.patch_size[i]
img = img.reshape([img.shape[0]] + shape + [self.out_channels] + self.patch_size)
img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
img = img.reshape(initial_shape)
return img
def forward(self, x, timestep, context, y, guidance, attention_mask=None, control=None, transformer_options={}, **kwargs):
bs, c, t, h, w = x.shape
patch_size = self.patch_size
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, control, transformer_options)
return out

View File

@@ -47,7 +47,7 @@ def reshape_for_broadcast(freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], x
def rotate_half(x):
x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
@@ -78,10 +78,9 @@ def apply_rotary_emb(
xk_out = None
if isinstance(freqs_cis, tuple):
cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D]
cos, sin = cos.to(xq.device), sin.to(xq.device)
xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
xq_out = (xq * cos + rotate_half(xq) * sin)
if xk is not None:
xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
xk_out = (xk * cos + rotate_half(xk) * sin)
else:
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [B, S, H, D//2]
freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device) # [S, D//2] --> [1, S, 1, D//2]

View File

@@ -0,0 +1,311 @@
import torch
import torch.nn as nn
from comfy.ldm.modules.diffusionmodules.mmdit import (
TimestepEmbedder,
PatchEmbed,
)
from .poolers import AttentionPool
import comfy.latent_formats
from .models import HunYuanDiTBlock, calc_rope
class HunYuanControlNet(nn.Module):
"""
HunYuanDiT: Diffusion model with a Transformer backbone.
Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers.
Inherit PeftAdapterMixin to be compatible with the PEFT training pipeline.
Parameters
----------
args: argparse.Namespace
The arguments parsed by argparse.
input_size: tuple
The size of the input image.
patch_size: int
The size of the patch.
in_channels: int
The number of input channels.
hidden_size: int
The hidden size of the transformer backbone.
depth: int
The number of transformer blocks.
num_heads: int
The number of attention heads.
mlp_ratio: float
The ratio of the hidden size of the MLP in the transformer block.
log_fn: callable
The logging function.
"""
def __init__(
self,
input_size: tuple = 128,
patch_size: int = 2,
in_channels: int = 4,
hidden_size: int = 1408,
depth: int = 40,
num_heads: int = 16,
mlp_ratio: float = 4.3637,
text_states_dim=1024,
text_states_dim_t5=2048,
text_len=77,
text_len_t5=256,
qk_norm=True, # See http://arxiv.org/abs/2302.05442 for details.
size_cond=False,
use_style_cond=False,
learn_sigma=True,
norm="layer",
log_fn: callable = print,
attn_precision=None,
dtype=None,
device=None,
operations=None,
**kwargs,
):
super().__init__()
self.log_fn = log_fn
self.depth = depth
self.learn_sigma = learn_sigma
self.in_channels = in_channels
self.out_channels = in_channels * 2 if learn_sigma else in_channels
self.patch_size = patch_size
self.num_heads = num_heads
self.hidden_size = hidden_size
self.text_states_dim = text_states_dim
self.text_states_dim_t5 = text_states_dim_t5
self.text_len = text_len
self.text_len_t5 = text_len_t5
self.size_cond = size_cond
self.use_style_cond = use_style_cond
self.norm = norm
self.dtype = dtype
self.latent_format = comfy.latent_formats.SDXL
self.mlp_t5 = nn.Sequential(
nn.Linear(
self.text_states_dim_t5,
self.text_states_dim_t5 * 4,
bias=True,
dtype=dtype,
device=device,
),
nn.SiLU(),
nn.Linear(
self.text_states_dim_t5 * 4,
self.text_states_dim,
bias=True,
dtype=dtype,
device=device,
),
)
# learnable replace
self.text_embedding_padding = nn.Parameter(
torch.randn(
self.text_len + self.text_len_t5,
self.text_states_dim,
dtype=dtype,
device=device,
)
)
# Attention pooling
pooler_out_dim = 1024
self.pooler = AttentionPool(
self.text_len_t5,
self.text_states_dim_t5,
num_heads=8,
output_dim=pooler_out_dim,
dtype=dtype,
device=device,
operations=operations,
)
# Dimension of the extra input vectors
self.extra_in_dim = pooler_out_dim
if self.size_cond:
# Image size and crop size conditions
self.extra_in_dim += 6 * 256
if self.use_style_cond:
# Here we use a default learned embedder layer for future extension.
self.style_embedder = nn.Embedding(
1, hidden_size, dtype=dtype, device=device
)
self.extra_in_dim += hidden_size
# Text embedding for `add`
self.x_embedder = PatchEmbed(
input_size,
patch_size,
in_channels,
hidden_size,
dtype=dtype,
device=device,
operations=operations,
)
self.t_embedder = TimestepEmbedder(
hidden_size, dtype=dtype, device=device, operations=operations
)
self.extra_embedder = nn.Sequential(
operations.Linear(
self.extra_in_dim, hidden_size * 4, dtype=dtype, device=device
),
nn.SiLU(),
operations.Linear(
hidden_size * 4, hidden_size, bias=True, dtype=dtype, device=device
),
)
# HUnYuanDiT Blocks
self.blocks = nn.ModuleList(
[
HunYuanDiTBlock(
hidden_size=hidden_size,
c_emb_size=hidden_size,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
text_states_dim=self.text_states_dim,
qk_norm=qk_norm,
norm_type=self.norm,
skip=False,
attn_precision=attn_precision,
dtype=dtype,
device=device,
operations=operations,
)
for _ in range(19)
]
)
# Input zero linear for the first block
self.before_proj = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
# Output zero linear for the every block
self.after_proj_list = nn.ModuleList(
[
operations.Linear(
self.hidden_size, self.hidden_size, dtype=dtype, device=device
)
for _ in range(len(self.blocks))
]
)
def forward(
self,
x,
hint,
timesteps,
context,#encoder_hidden_states=None,
text_embedding_mask=None,
encoder_hidden_states_t5=None,
text_embedding_mask_t5=None,
image_meta_size=None,
style=None,
return_dict=False,
**kwarg,
):
"""
Forward pass of the encoder.
Parameters
----------
x: torch.Tensor
(B, D, H, W)
t: torch.Tensor
(B)
encoder_hidden_states: torch.Tensor
CLIP text embedding, (B, L_clip, D)
text_embedding_mask: torch.Tensor
CLIP text embedding mask, (B, L_clip)
encoder_hidden_states_t5: torch.Tensor
T5 text embedding, (B, L_t5, D)
text_embedding_mask_t5: torch.Tensor
T5 text embedding mask, (B, L_t5)
image_meta_size: torch.Tensor
(B, 6)
style: torch.Tensor
(B)
cos_cis_img: torch.Tensor
sin_cis_img: torch.Tensor
return_dict: bool
Whether to return a dictionary.
"""
condition = hint
if condition.shape[0] == 1:
condition = torch.repeat_interleave(condition, x.shape[0], dim=0)
text_states = context # 2,77,1024
text_states_t5 = encoder_hidden_states_t5 # 2,256,2048
text_states_mask = text_embedding_mask.bool() # 2,77
text_states_t5_mask = text_embedding_mask_t5.bool() # 2,256
b_t5, l_t5, c_t5 = text_states_t5.shape
text_states_t5 = self.mlp_t5(text_states_t5.view(-1, c_t5)).view(b_t5, l_t5, -1)
padding = comfy.ops.cast_to_input(self.text_embedding_padding, text_states)
text_states[:, -self.text_len :] = torch.where(
text_states_mask[:, -self.text_len :].unsqueeze(2),
text_states[:, -self.text_len :],
padding[: self.text_len],
)
text_states_t5[:, -self.text_len_t5 :] = torch.where(
text_states_t5_mask[:, -self.text_len_t5 :].unsqueeze(2),
text_states_t5[:, -self.text_len_t5 :],
padding[self.text_len :],
)
text_states = torch.cat([text_states, text_states_t5], dim=1) # 2,2051024
# _, _, oh, ow = x.shape
# th, tw = oh // self.patch_size, ow // self.patch_size
# Get image RoPE embedding according to `reso`lution.
freqs_cis_img = calc_rope(
x, self.patch_size, self.hidden_size // self.num_heads
) # (cos_cis_img, sin_cis_img)
# ========================= Build time and image embedding =========================
t = self.t_embedder(timesteps, dtype=self.dtype)
x = self.x_embedder(x)
# ========================= Concatenate all extra vectors =========================
# Build text tokens with pooling
extra_vec = self.pooler(encoder_hidden_states_t5)
# Build image meta size tokens if applicable
# if image_meta_size is not None:
# image_meta_size = timestep_embedding(image_meta_size.view(-1), 256) # [B * 6, 256]
# if image_meta_size.dtype != self.dtype:
# image_meta_size = image_meta_size.half()
# image_meta_size = image_meta_size.view(-1, 6 * 256)
# extra_vec = torch.cat([extra_vec, image_meta_size], dim=1) # [B, D + 6 * 256]
# Build style tokens
if style is not None:
style_embedding = self.style_embedder(style)
extra_vec = torch.cat([extra_vec, style_embedding], dim=1)
# Concatenate all extra vectors
c = t + self.extra_embedder(extra_vec) # [B, D]
# ========================= Deal with Condition =========================
condition = self.x_embedder(condition)
# ========================= Forward pass through HunYuanDiT blocks =========================
controls = []
x = x + self.before_proj(condition) # add condition
for layer, block in enumerate(self.blocks):
x = block(x, c, text_states, freqs_cis_img)
controls.append(self.after_proj_list[layer](x)) # zero linear for output
return {"output": controls}

View File

@@ -1,8 +1,6 @@
from typing import Any
import torch
import torch.nn as nn
import torch.nn.functional as F
import comfy.ops
from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed, RMSNorm
@@ -21,6 +19,7 @@ def calc_rope(x, patch_size, head_size):
sub_args = [start, stop, (th, tw)]
# head_size = HUNYUAN_DIT_CONFIG['DiT-g/2']['hidden_size'] // HUNYUAN_DIT_CONFIG['DiT-g/2']['num_heads']
rope = get_2d_rotary_pos_embed(head_size, *sub_args)
rope = (rope[0].to(x), rope[1].to(x))
return rope
@@ -91,6 +90,8 @@ class HunYuanDiTBlock(nn.Module):
# Long Skip Connection
if self.skip_linear is not None:
cat = torch.cat([x, skip], dim=-1)
if cat.dtype != x.dtype:
cat = cat.to(x.dtype)
cat = self.skip_norm(cat)
x = self.skip_linear(cat)
@@ -247,9 +248,6 @@ class HunYuanDiT(nn.Module):
operations.Linear(hidden_size * 4, hidden_size, bias=True, dtype=dtype, device=device),
)
# Image embedding
num_patches = self.x_embedder.num_patches
# HUnYuanDiT Blocks
self.blocks = nn.ModuleList([
HunYuanDiTBlock(hidden_size=hidden_size,
@@ -284,7 +282,7 @@ class HunYuanDiT(nn.Module):
style=None,
return_dict=False,
control=None,
transformer_options=None,
transformer_options={},
):
"""
Forward pass of the encoder.
@@ -312,8 +310,7 @@ class HunYuanDiT(nn.Module):
return_dict: bool
Whether to return a dictionary.
"""
#import pdb
#pdb.set_trace()
patches_replace = transformer_options.get("patches_replace", {})
encoder_hidden_states = context
text_states = encoder_hidden_states # 2,77,1024
text_states_t5 = encoder_hidden_states_t5 # 2,256,2048
@@ -361,18 +358,33 @@ class HunYuanDiT(nn.Module):
# Concatenate all extra vectors
c = t + self.extra_embedder(extra_vec) # [B, D]
blocks_replace = patches_replace.get("dit", {})
controls = None
if control:
controls = control.get("output", None)
# ========================= Forward pass through HunYuanDiT blocks =========================
skips = []
for layer, block in enumerate(self.blocks):
if layer > self.depth // 2:
if controls is not None:
skip = skips.pop() + controls.pop()
skip = skips.pop() + controls.pop().to(dtype=x.dtype)
else:
skip = skips.pop()
x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)
else:
x = block(x, c, text_states, freqs_cis_img) # (N, L, D)
skip = None
if ("double_block", layer) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], args["vec"], args["txt"], args["pe"], args["skip"])
return out
out = blocks_replace[("double_block", layer)]({"img": x, "txt": text_states, "vec": c, "pe": freqs_cis_img, "skip": skip}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)
if layer < (self.depth // 2 - 1):
skips.append(x)

View File

@@ -1,6 +1,5 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.attention import optimized_attention
import comfy.ops

View File

@@ -0,0 +1,527 @@
import torch
from torch import nn
import comfy.ldm.modules.attention
from comfy.ldm.genmo.joint_model.layers import RMSNorm
import comfy.ldm.common_dit
from einops import rearrange
import math
from typing import Dict, Optional, Tuple
from .symmetric_patchifier import SymmetricPatchifier
def get_timestep_embedding(
timesteps: torch.Tensor,
embedding_dim: int,
flip_sin_to_cos: bool = False,
downscale_freq_shift: float = 1,
scale: float = 1,
max_period: int = 10000,
):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
Args
timesteps (torch.Tensor):
a 1-D Tensor of N indices, one per batch element. These may be fractional.
embedding_dim (int):
the dimension of the output.
flip_sin_to_cos (bool):
Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
downscale_freq_shift (float):
Controls the delta between frequencies between dimensions
scale (float):
Scaling factor applied to the embeddings.
max_period (int):
Controls the maximum frequency of the embeddings
Returns
torch.Tensor: an [N x dim] Tensor of positional embeddings.
"""
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
half_dim = embedding_dim // 2
exponent = -math.log(max_period) * torch.arange(
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
)
exponent = exponent / (half_dim - downscale_freq_shift)
emb = torch.exp(exponent)
emb = timesteps[:, None].float() * emb[None, :]
# scale embeddings
emb = scale * emb
# concat sine and cosine embeddings
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
# flip sine and cosine embeddings
if flip_sin_to_cos:
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
# zero pad
if embedding_dim % 2 == 1:
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
class TimestepEmbedding(nn.Module):
def __init__(
self,
in_channels: int,
time_embed_dim: int,
act_fn: str = "silu",
out_dim: int = None,
post_act_fn: Optional[str] = None,
cond_proj_dim=None,
sample_proj_bias=True,
dtype=None, device=None, operations=None,
):
super().__init__()
self.linear_1 = operations.Linear(in_channels, time_embed_dim, sample_proj_bias, dtype=dtype, device=device)
if cond_proj_dim is not None:
self.cond_proj = operations.Linear(cond_proj_dim, in_channels, bias=False, dtype=dtype, device=device)
else:
self.cond_proj = None
self.act = nn.SiLU()
if out_dim is not None:
time_embed_dim_out = out_dim
else:
time_embed_dim_out = time_embed_dim
self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias, dtype=dtype, device=device)
if post_act_fn is None:
self.post_act = None
# else:
# self.post_act = get_activation(post_act_fn)
def forward(self, sample, condition=None):
if condition is not None:
sample = sample + self.cond_proj(condition)
sample = self.linear_1(sample)
if self.act is not None:
sample = self.act(sample)
sample = self.linear_2(sample)
if self.post_act is not None:
sample = self.post_act(sample)
return sample
class Timesteps(nn.Module):
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.downscale_freq_shift = downscale_freq_shift
self.scale = scale
def forward(self, timesteps):
t_emb = get_timestep_embedding(
timesteps,
self.num_channels,
flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.downscale_freq_shift,
scale=self.scale,
)
return t_emb
class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
"""
For PixArt-Alpha.
Reference:
https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
"""
def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False, dtype=None, device=None, operations=None):
super().__init__()
self.outdim = size_emb_dim
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, device=device, operations=operations)
def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
return timesteps_emb
class AdaLayerNormSingle(nn.Module):
r"""
Norm layer adaptive layer norm single (adaLN-single).
As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
Parameters:
embedding_dim (`int`): The size of each embedding vector.
use_additional_conditions (`bool`): To use additional conditions for normalization or not.
"""
def __init__(self, embedding_dim: int, use_additional_conditions: bool = False, dtype=None, device=None, operations=None):
super().__init__()
self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions, dtype=dtype, device=device, operations=operations
)
self.silu = nn.SiLU()
self.linear = operations.Linear(embedding_dim, 6 * embedding_dim, bias=True, dtype=dtype, device=device)
def forward(
self,
timestep: torch.Tensor,
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
batch_size: Optional[int] = None,
hidden_dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# No modulation happening here.
added_cond_kwargs = added_cond_kwargs or {"resolution": None, "aspect_ratio": None}
embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
return self.linear(self.silu(embedded_timestep)), embedded_timestep
class PixArtAlphaTextProjection(nn.Module):
"""
Projects caption embeddings. Also handles dropout for classifier-free guidance.
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
"""
def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", dtype=None, device=None, operations=None):
super().__init__()
if out_features is None:
out_features = hidden_size
self.linear_1 = operations.Linear(in_features=in_features, out_features=hidden_size, bias=True, dtype=dtype, device=device)
if act_fn == "gelu_tanh":
self.act_1 = nn.GELU(approximate="tanh")
elif act_fn == "silu":
self.act_1 = nn.SiLU()
else:
raise ValueError(f"Unknown activation function: {act_fn}")
self.linear_2 = operations.Linear(in_features=hidden_size, out_features=out_features, bias=True, dtype=dtype, device=device)
def forward(self, caption):
hidden_states = self.linear_1(caption)
hidden_states = self.act_1(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
class GELU_approx(nn.Module):
def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=None):
super().__init__()
self.proj = operations.Linear(dim_in, dim_out, dtype=dtype, device=device)
def forward(self, x):
return torch.nn.functional.gelu(self.proj(x), approximate="tanh")
class FeedForward(nn.Module):
def __init__(self, dim, dim_out, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=None):
super().__init__()
inner_dim = int(dim * mult)
project_in = GELU_approx(dim, inner_dim, dtype=dtype, device=device, operations=operations)
self.net = nn.Sequential(
project_in,
nn.Dropout(dropout),
operations.Linear(inner_dim, dim_out, dtype=dtype, device=device)
)
def forward(self, x):
return self.net(x)
def apply_rotary_emb(input_tensor, freqs_cis): #TODO: remove duplicate funcs and pick the best/fastest one
cos_freqs = freqs_cis[0]
sin_freqs = freqs_cis[1]
t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2)
t1, t2 = t_dup.unbind(dim=-1)
t_dup = torch.stack((-t2, t1), dim=-1)
input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)")
out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs
return out
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None):
super().__init__()
inner_dim = dim_head * heads
context_dim = query_dim if context_dim is None else context_dim
self.attn_precision = attn_precision
self.heads = heads
self.dim_head = dim_head
self.q_norm = RMSNorm(inner_dim, dtype=dtype, device=device)
self.k_norm = RMSNorm(inner_dim, dtype=dtype, device=device)
self.to_q = operations.Linear(query_dim, inner_dim, bias=True, dtype=dtype, device=device)
self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
def forward(self, x, context=None, mask=None, pe=None):
q = self.to_q(x)
context = x if context is None else context
k = self.to_k(context)
v = self.to_v(context)
q = self.q_norm(q)
k = self.k_norm(k)
if pe is not None:
q = apply_rotary_emb(q, pe)
k = apply_rotary_emb(k, pe)
if mask is None:
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
else:
out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)
return self.to_out(out)
class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None):
super().__init__()
self.attn_precision = attn_precision
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, context_dim=None, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations)
self.ff = FeedForward(dim, dim_out=dim, glu=True, dtype=dtype, device=device, operations=operations)
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations)
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe) * gate_msa
x += self.attn2(x, context=context, mask=attention_mask)
y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp
x += self.ff(y) * gate_mlp
return x
def get_fractional_positions(indices_grid, max_pos):
fractional_positions = torch.stack(
[
indices_grid[:, i] / max_pos[i]
for i in range(3)
],
dim=-1,
)
return fractional_positions
def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]):
dtype = torch.float32 #self.dtype
fractional_positions = get_fractional_positions(indices_grid, max_pos)
start = 1
end = theta
device = fractional_positions.device
indices = theta ** (
torch.linspace(
math.log(start, theta),
math.log(end, theta),
dim // 6,
device=device,
dtype=dtype,
)
)
indices = indices.to(dtype=dtype)
indices = indices * math.pi / 2
freqs = (
(indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
.transpose(-1, -2)
.flatten(2)
)
cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
if dim % 6 != 0:
cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6])
sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6])
cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
return cos_freq.to(out_dtype), sin_freq.to(out_dtype)
class LTXVModel(torch.nn.Module):
def __init__(self,
in_channels=128,
cross_attention_dim=2048,
attention_head_dim=64,
num_attention_heads=32,
caption_channels=4096,
num_layers=28,
positional_embedding_theta=10000.0,
positional_embedding_max_pos=[20, 2048, 2048],
dtype=None, device=None, operations=None, **kwargs):
super().__init__()
self.generator = None
self.dtype = dtype
self.out_channels = in_channels
self.inner_dim = num_attention_heads * attention_head_dim
self.patchify_proj = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device)
self.adaln_single = AdaLayerNormSingle(
self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=operations
)
# self.adaln_single.linear = operations.Linear(self.inner_dim, 4 * self.inner_dim, bias=True, dtype=dtype, device=device)
self.caption_projection = PixArtAlphaTextProjection(
in_features=caption_channels, hidden_size=self.inner_dim, dtype=dtype, device=device, operations=operations
)
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
self.inner_dim,
num_attention_heads,
attention_head_dim,
context_dim=cross_attention_dim,
# attn_precision=attn_precision,
dtype=dtype, device=device, operations=operations
)
for d in range(num_layers)
]
)
self.scale_shift_table = nn.Parameter(torch.empty(2, self.inner_dim, dtype=dtype, device=device))
self.norm_out = operations.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.proj_out = operations.Linear(self.inner_dim, self.out_channels, dtype=dtype, device=device)
self.patchifier = SymmetricPatchifier(1)
def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, guiding_latent_noise_scale=0, transformer_options={}, **kwargs):
patches_replace = transformer_options.get("patches_replace", {})
indices_grid = self.patchifier.get_grid(
orig_num_frames=x.shape[2],
orig_height=x.shape[3],
orig_width=x.shape[4],
batch_size=x.shape[0],
scale_grid=((1 / frame_rate) * 8, 32, 32),
device=x.device,
)
if guiding_latent is not None:
ts = torch.ones([x.shape[0], 1, x.shape[2], x.shape[3], x.shape[4]], device=x.device, dtype=x.dtype)
input_ts = timestep.view([timestep.shape[0]] + [1] * (x.ndim - 1))
ts *= input_ts
ts[:, :, 0] = guiding_latent_noise_scale * (input_ts[:, :, 0] ** 2)
timestep = self.patchifier.patchify(ts)
input_x = x.clone()
x[:, :, 0] = guiding_latent[:, :, 0]
if guiding_latent_noise_scale > 0:
if self.generator is None:
self.generator = torch.Generator(device=x.device).manual_seed(42)
elif self.generator.device != x.device:
self.generator = torch.Generator(device=x.device).set_state(self.generator.get_state())
noise_shape = [guiding_latent.shape[0], guiding_latent.shape[1], 1, guiding_latent.shape[3], guiding_latent.shape[4]]
scale = guiding_latent_noise_scale * (input_ts ** 2)
guiding_noise = scale * torch.randn(size=noise_shape, device=x.device, generator=self.generator)
x[:, :, 0] = guiding_noise[:, :, 0] + x[:, :, 0] * (1.0 - scale[:, :, 0])
orig_shape = list(x.shape)
x = self.patchifier.patchify(x)
x = self.patchify_proj(x)
timestep = timestep * 1000.0
attention_mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1]))
attention_mask = attention_mask.masked_fill(attention_mask.to(torch.bool), float("-inf")) # not sure about this
# attention_mask = (context != 0).any(dim=2).to(dtype=x.dtype)
pe = precompute_freqs_cis(indices_grid, dim=self.inner_dim, out_dtype=x.dtype)
batch_size = x.shape[0]
timestep, embedded_timestep = self.adaln_single(
timestep.flatten(),
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=x.dtype,
)
# Second dimension is 1 or number of tokens (if timestep_per_token)
timestep = timestep.view(batch_size, -1, timestep.shape[-1])
embedded_timestep = embedded_timestep.view(
batch_size, -1, embedded_timestep.shape[-1]
)
# 2. Blocks
if self.caption_projection is not None:
batch_size = x.shape[0]
context = self.caption_projection(context)
context = context.view(
batch_size, -1, x.shape[-1]
)
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.transformer_blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(
x,
context=context,
attention_mask=attention_mask,
timestep=timestep,
pe=pe
)
# 3. Output
scale_shift_values = (
self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None]
)
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
x = self.norm_out(x)
# Modulation
x = x * (1 + scale) + shift
x = self.proj_out(x)
x = self.patchifier.unpatchify(
latents=x,
output_height=orig_shape[3],
output_width=orig_shape[4],
output_num_frames=orig_shape[2],
out_channels=orig_shape[1] // math.prod(self.patchifier.patch_size),
)
if guiding_latent is not None:
x[:, :, 0] = (input_x[:, :, 0] - guiding_latent[:, :, 0]) / input_ts[:, :, 0]
# print("res", x)
return x

View File

@@ -0,0 +1,105 @@
from abc import ABC, abstractmethod
from typing import Tuple
import torch
from einops import rearrange
from torch import Tensor
def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
dims_to_append = target_dims - x.ndim
if dims_to_append < 0:
raise ValueError(
f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
)
elif dims_to_append == 0:
return x
return x[(...,) + (None,) * dims_to_append]
class Patchifier(ABC):
def __init__(self, patch_size: int):
super().__init__()
self._patch_size = (1, patch_size, patch_size)
@abstractmethod
def patchify(
self, latents: Tensor, frame_rates: Tensor, scale_grid: bool
) -> Tuple[Tensor, Tensor]:
pass
@abstractmethod
def unpatchify(
self,
latents: Tensor,
output_height: int,
output_width: int,
output_num_frames: int,
out_channels: int,
) -> Tuple[Tensor, Tensor]:
pass
@property
def patch_size(self):
return self._patch_size
def get_grid(
self, orig_num_frames, orig_height, orig_width, batch_size, scale_grid, device
):
f = orig_num_frames // self._patch_size[0]
h = orig_height // self._patch_size[1]
w = orig_width // self._patch_size[2]
grid_h = torch.arange(h, dtype=torch.float32, device=device)
grid_w = torch.arange(w, dtype=torch.float32, device=device)
grid_f = torch.arange(f, dtype=torch.float32, device=device)
grid = torch.meshgrid(grid_f, grid_h, grid_w)
grid = torch.stack(grid, dim=0)
grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
if scale_grid is not None:
for i in range(3):
if isinstance(scale_grid[i], Tensor):
scale = append_dims(scale_grid[i], grid.ndim - 1)
else:
scale = scale_grid[i]
grid[:, i, ...] = grid[:, i, ...] * scale * self._patch_size[i]
grid = rearrange(grid, "b c f h w -> b c (f h w)", b=batch_size)
return grid
class SymmetricPatchifier(Patchifier):
def patchify(
self,
latents: Tensor,
) -> Tuple[Tensor, Tensor]:
latents = rearrange(
latents,
"b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
p1=self._patch_size[0],
p2=self._patch_size[1],
p3=self._patch_size[2],
)
return latents
def unpatchify(
self,
latents: Tensor,
output_height: int,
output_width: int,
output_num_frames: int,
out_channels: int,
) -> Tuple[Tensor, Tensor]:
output_height = output_height // self._patch_size[1]
output_width = output_width // self._patch_size[2]
latents = rearrange(
latents,
"b (f h w) (c p q) -> b c f (h p) (w q) ",
f=output_num_frames,
h=output_height,
w=output_width,
p=self._patch_size[1],
q=self._patch_size[2],
)
return latents

View File

@@ -0,0 +1,64 @@
from typing import Tuple, Union
import torch
import torch.nn as nn
import comfy.ops
ops = comfy.ops.disable_weight_init
class CausalConv3d(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size: int = 3,
stride: Union[int, Tuple[int]] = 1,
dilation: int = 1,
groups: int = 1,
**kwargs,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
kernel_size = (kernel_size, kernel_size, kernel_size)
self.time_kernel_size = kernel_size[0]
dilation = (dilation, 1, 1)
height_pad = kernel_size[1] // 2
width_pad = kernel_size[2] // 2
padding = (0, height_pad, width_pad)
self.conv = ops.Conv3d(
in_channels,
out_channels,
kernel_size,
stride=stride,
dilation=dilation,
padding=padding,
padding_mode="zeros",
groups=groups,
)
def forward(self, x, causal: bool = True):
if causal:
first_frame_pad = x[:, :, :1, :, :].repeat(
(1, 1, self.time_kernel_size - 1, 1, 1)
)
x = torch.concatenate((first_frame_pad, x), dim=2)
else:
first_frame_pad = x[:, :, :1, :, :].repeat(
(1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
)
last_frame_pad = x[:, :, -1:, :, :].repeat(
(1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
)
x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2)
x = self.conv(x)
return x
@property
def weight(self):
return self.conv.weight

View File

@@ -0,0 +1,907 @@
import torch
from torch import nn
from functools import partial
import math
from einops import rearrange
from typing import Optional, Tuple, Union
from .conv_nd_factory import make_conv_nd, make_linear_nd
from .pixel_norm import PixelNorm
from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings
import comfy.ops
ops = comfy.ops.disable_weight_init
class Encoder(nn.Module):
r"""
The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
Args:
dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3):
The number of dimensions to use in convolutions.
in_channels (`int`, *optional*, defaults to 3):
The number of input channels.
out_channels (`int`, *optional*, defaults to 3):
The number of output channels.
blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`):
The blocks to use. Each block is a tuple of the block name and the number of layers.
base_channels (`int`, *optional*, defaults to 128):
The number of output channels for the first convolutional layer.
norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups for normalization.
patch_size (`int`, *optional*, defaults to 1):
The patch size to use. Should be a power of 2.
norm_layer (`str`, *optional*, defaults to `group_norm`):
The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
latent_log_var (`str`, *optional*, defaults to `per_channel`):
The number of channels for the log variance. Can be either `per_channel`, `uniform`, or `none`.
"""
def __init__(
self,
dims: Union[int, Tuple[int, int]] = 3,
in_channels: int = 3,
out_channels: int = 3,
blocks=[("res_x", 1)],
base_channels: int = 128,
norm_num_groups: int = 32,
patch_size: Union[int, Tuple[int]] = 1,
norm_layer: str = "group_norm", # group_norm, pixel_norm
latent_log_var: str = "per_channel",
):
super().__init__()
self.patch_size = patch_size
self.norm_layer = norm_layer
self.latent_channels = out_channels
self.latent_log_var = latent_log_var
self.blocks_desc = blocks
in_channels = in_channels * patch_size**2
output_channel = base_channels
self.conv_in = make_conv_nd(
dims=dims,
in_channels=in_channels,
out_channels=output_channel,
kernel_size=3,
stride=1,
padding=1,
causal=True,
)
self.down_blocks = nn.ModuleList([])
for block_name, block_params in blocks:
input_channel = output_channel
if isinstance(block_params, int):
block_params = {"num_layers": block_params}
if block_name == "res_x":
block = UNetMidBlock3D(
dims=dims,
in_channels=input_channel,
num_layers=block_params["num_layers"],
resnet_eps=1e-6,
resnet_groups=norm_num_groups,
norm_layer=norm_layer,
)
elif block_name == "res_x_y":
output_channel = block_params.get("multiplier", 2) * output_channel
block = ResnetBlock3D(
dims=dims,
in_channels=input_channel,
out_channels=output_channel,
eps=1e-6,
groups=norm_num_groups,
norm_layer=norm_layer,
)
elif block_name == "compress_time":
block = make_conv_nd(
dims=dims,
in_channels=input_channel,
out_channels=output_channel,
kernel_size=3,
stride=(2, 1, 1),
causal=True,
)
elif block_name == "compress_space":
block = make_conv_nd(
dims=dims,
in_channels=input_channel,
out_channels=output_channel,
kernel_size=3,
stride=(1, 2, 2),
causal=True,
)
elif block_name == "compress_all":
block = make_conv_nd(
dims=dims,
in_channels=input_channel,
out_channels=output_channel,
kernel_size=3,
stride=(2, 2, 2),
causal=True,
)
elif block_name == "compress_all_x_y":
output_channel = block_params.get("multiplier", 2) * output_channel
block = make_conv_nd(
dims=dims,
in_channels=input_channel,
out_channels=output_channel,
kernel_size=3,
stride=(2, 2, 2),
causal=True,
)
else:
raise ValueError(f"unknown block: {block_name}")
self.down_blocks.append(block)
# out
if norm_layer == "group_norm":
self.conv_norm_out = nn.GroupNorm(
num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6
)
elif norm_layer == "pixel_norm":
self.conv_norm_out = PixelNorm()
elif norm_layer == "layer_norm":
self.conv_norm_out = LayerNorm(output_channel, eps=1e-6)
self.conv_act = nn.SiLU()
conv_out_channels = out_channels
if latent_log_var == "per_channel":
conv_out_channels *= 2
elif latent_log_var == "uniform":
conv_out_channels += 1
elif latent_log_var != "none":
raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
self.conv_out = make_conv_nd(
dims, output_channel, conv_out_channels, 3, padding=1, causal=True
)
self.gradient_checkpointing = False
def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
r"""The forward method of the `Encoder` class."""
sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
sample = self.conv_in(sample)
checkpoint_fn = (
partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
if self.gradient_checkpointing and self.training
else lambda x: x
)
for down_block in self.down_blocks:
sample = checkpoint_fn(down_block)(sample)
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
if self.latent_log_var == "uniform":
last_channel = sample[:, -1:, ...]
num_dims = sample.dim()
if num_dims == 4:
# For shape (B, C, H, W)
repeated_last_channel = last_channel.repeat(
1, sample.shape[1] - 2, 1, 1
)
sample = torch.cat([sample, repeated_last_channel], dim=1)
elif num_dims == 5:
# For shape (B, C, F, H, W)
repeated_last_channel = last_channel.repeat(
1, sample.shape[1] - 2, 1, 1, 1
)
sample = torch.cat([sample, repeated_last_channel], dim=1)
else:
raise ValueError(f"Invalid input shape: {sample.shape}")
return sample
class Decoder(nn.Module):
r"""
The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.
Args:
dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3):
The number of dimensions to use in convolutions.
in_channels (`int`, *optional*, defaults to 3):
The number of input channels.
out_channels (`int`, *optional*, defaults to 3):
The number of output channels.
blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`):
The blocks to use. Each block is a tuple of the block name and the number of layers.
base_channels (`int`, *optional*, defaults to 128):
The number of output channels for the first convolutional layer.
norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups for normalization.
patch_size (`int`, *optional*, defaults to 1):
The patch size to use. Should be a power of 2.
norm_layer (`str`, *optional*, defaults to `group_norm`):
The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
causal (`bool`, *optional*, defaults to `True`):
Whether to use causal convolutions or not.
"""
def __init__(
self,
dims,
in_channels: int = 3,
out_channels: int = 3,
blocks=[("res_x", 1)],
base_channels: int = 128,
layers_per_block: int = 2,
norm_num_groups: int = 32,
patch_size: int = 1,
norm_layer: str = "group_norm",
causal: bool = True,
timestep_conditioning: bool = False,
):
super().__init__()
self.patch_size = patch_size
self.layers_per_block = layers_per_block
out_channels = out_channels * patch_size**2
self.causal = causal
self.blocks_desc = blocks
# Compute output channel to be product of all channel-multiplier blocks
output_channel = base_channels
for block_name, block_params in list(reversed(blocks)):
block_params = block_params if isinstance(block_params, dict) else {}
if block_name == "res_x_y":
output_channel = output_channel * block_params.get("multiplier", 2)
if block_name == "compress_all":
output_channel = output_channel * block_params.get("multiplier", 1)
self.conv_in = make_conv_nd(
dims,
in_channels,
output_channel,
kernel_size=3,
stride=1,
padding=1,
causal=True,
)
self.up_blocks = nn.ModuleList([])
for block_name, block_params in list(reversed(blocks)):
input_channel = output_channel
if isinstance(block_params, int):
block_params = {"num_layers": block_params}
if block_name == "res_x":
block = UNetMidBlock3D(
dims=dims,
in_channels=input_channel,
num_layers=block_params["num_layers"],
resnet_eps=1e-6,
resnet_groups=norm_num_groups,
norm_layer=norm_layer,
inject_noise=block_params.get("inject_noise", False),
timestep_conditioning=timestep_conditioning,
)
elif block_name == "attn_res_x":
block = UNetMidBlock3D(
dims=dims,
in_channels=input_channel,
num_layers=block_params["num_layers"],
resnet_groups=norm_num_groups,
norm_layer=norm_layer,
inject_noise=block_params.get("inject_noise", False),
timestep_conditioning=timestep_conditioning,
attention_head_dim=block_params["attention_head_dim"],
)
elif block_name == "res_x_y":
output_channel = output_channel // block_params.get("multiplier", 2)
block = ResnetBlock3D(
dims=dims,
in_channels=input_channel,
out_channels=output_channel,
eps=1e-6,
groups=norm_num_groups,
norm_layer=norm_layer,
inject_noise=block_params.get("inject_noise", False),
timestep_conditioning=False,
)
elif block_name == "compress_time":
block = DepthToSpaceUpsample(
dims=dims, in_channels=input_channel, stride=(2, 1, 1)
)
elif block_name == "compress_space":
block = DepthToSpaceUpsample(
dims=dims, in_channels=input_channel, stride=(1, 2, 2)
)
elif block_name == "compress_all":
output_channel = output_channel // block_params.get("multiplier", 1)
block = DepthToSpaceUpsample(
dims=dims,
in_channels=input_channel,
stride=(2, 2, 2),
residual=block_params.get("residual", False),
out_channels_reduction_factor=block_params.get("multiplier", 1),
)
else:
raise ValueError(f"unknown layer: {block_name}")
self.up_blocks.append(block)
if norm_layer == "group_norm":
self.conv_norm_out = nn.GroupNorm(
num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6
)
elif norm_layer == "pixel_norm":
self.conv_norm_out = PixelNorm()
elif norm_layer == "layer_norm":
self.conv_norm_out = LayerNorm(output_channel, eps=1e-6)
self.conv_act = nn.SiLU()
self.conv_out = make_conv_nd(
dims, output_channel, out_channels, 3, padding=1, causal=True
)
self.gradient_checkpointing = False
self.timestep_conditioning = timestep_conditioning
if timestep_conditioning:
self.timestep_scale_multiplier = nn.Parameter(
torch.tensor(1000.0, dtype=torch.float32)
)
self.last_time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(
output_channel * 2, 0, operations=ops,
)
self.last_scale_shift_table = nn.Parameter(torch.empty(2, output_channel))
# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
def forward(
self,
sample: torch.FloatTensor,
timestep: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
r"""The forward method of the `Decoder` class."""
batch_size = sample.shape[0]
sample = self.conv_in(sample, causal=self.causal)
checkpoint_fn = (
partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
if self.gradient_checkpointing and self.training
else lambda x: x
)
scaled_timestep = None
if self.timestep_conditioning:
assert (
timestep is not None
), "should pass timestep with timestep_conditioning=True"
scaled_timestep = timestep * self.timestep_scale_multiplier
for up_block in self.up_blocks:
if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
sample = checkpoint_fn(up_block)(
sample, causal=self.causal, timestep=scaled_timestep
)
else:
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
sample = self.conv_norm_out(sample)
if self.timestep_conditioning:
embedded_timestep = self.last_time_embedder(
timestep=scaled_timestep.flatten(),
resolution=None,
aspect_ratio=None,
batch_size=sample.shape[0],
hidden_dtype=sample.dtype,
)
embedded_timestep = embedded_timestep.view(
batch_size, embedded_timestep.shape[-1], 1, 1, 1
)
ada_values = self.last_scale_shift_table[
None, ..., None, None, None
] + embedded_timestep.reshape(
batch_size,
2,
-1,
embedded_timestep.shape[-3],
embedded_timestep.shape[-2],
embedded_timestep.shape[-1],
)
shift, scale = ada_values.unbind(dim=1)
sample = sample * (1 + scale) + shift
sample = self.conv_act(sample)
sample = self.conv_out(sample, causal=self.causal)
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
return sample
class UNetMidBlock3D(nn.Module):
"""
A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks.
Args:
in_channels (`int`): The number of input channels.
dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
resnet_groups (`int`, *optional*, defaults to 32):
The number of groups to use in the group normalization layers of the resnet blocks.
Returns:
`torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
in_channels, height, width)`.
"""
def __init__(
self,
dims: Union[int, Tuple[int, int]],
in_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_groups: int = 32,
norm_layer: str = "group_norm",
inject_noise: bool = False,
timestep_conditioning: bool = False,
):
super().__init__()
resnet_groups = (
resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
)
self.timestep_conditioning = timestep_conditioning
if timestep_conditioning:
self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(
in_channels * 4, 0, operations=ops,
)
self.res_blocks = nn.ModuleList(
[
ResnetBlock3D(
dims=dims,
in_channels=in_channels,
out_channels=in_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
norm_layer=norm_layer,
inject_noise=inject_noise,
timestep_conditioning=timestep_conditioning,
)
for _ in range(num_layers)
]
)
def forward(
self, hidden_states: torch.FloatTensor, causal: bool = True, timestep: Optional[torch.Tensor] = None
) -> torch.FloatTensor:
timestep_embed = None
if self.timestep_conditioning:
assert (
timestep is not None
), "should pass timestep with timestep_conditioning=True"
batch_size = hidden_states.shape[0]
timestep_embed = self.time_embedder(
timestep=timestep.flatten(),
resolution=None,
aspect_ratio=None,
batch_size=batch_size,
hidden_dtype=hidden_states.dtype,
)
timestep_embed = timestep_embed.view(
batch_size, timestep_embed.shape[-1], 1, 1, 1
)
for resnet in self.res_blocks:
hidden_states = resnet(hidden_states, causal=causal, timestep=timestep_embed)
return hidden_states
class DepthToSpaceUpsample(nn.Module):
def __init__(
self, dims, in_channels, stride, residual=False, out_channels_reduction_factor=1
):
super().__init__()
self.stride = stride
self.out_channels = (
math.prod(stride) * in_channels // out_channels_reduction_factor
)
self.conv = make_conv_nd(
dims=dims,
in_channels=in_channels,
out_channels=self.out_channels,
kernel_size=3,
stride=1,
causal=True,
)
self.residual = residual
self.out_channels_reduction_factor = out_channels_reduction_factor
def forward(self, x, causal: bool = True, timestep: Optional[torch.Tensor] = None):
if self.residual:
# Reshape and duplicate the input to match the output shape
x_in = rearrange(
x,
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
p1=self.stride[0],
p2=self.stride[1],
p3=self.stride[2],
)
num_repeat = math.prod(self.stride) // self.out_channels_reduction_factor
x_in = x_in.repeat(1, num_repeat, 1, 1, 1)
if self.stride[0] == 2:
x_in = x_in[:, :, 1:, :, :]
x = self.conv(x, causal=causal)
x = rearrange(
x,
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
p1=self.stride[0],
p2=self.stride[1],
p3=self.stride[2],
)
if self.stride[0] == 2:
x = x[:, :, 1:, :, :]
if self.residual:
x = x + x_in
return x
class LayerNorm(nn.Module):
def __init__(self, dim, eps, elementwise_affine=True) -> None:
super().__init__()
self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
def forward(self, x):
x = rearrange(x, "b c d h w -> b d h w c")
x = self.norm(x)
x = rearrange(x, "b d h w c -> b c d h w")
return x
class ResnetBlock3D(nn.Module):
r"""
A Resnet block.
Parameters:
in_channels (`int`): The number of channels in the input.
out_channels (`int`, *optional*, default to be `None`):
The number of output channels for the first conv layer. If None, same as `in_channels`.
dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
"""
def __init__(
self,
dims: Union[int, Tuple[int, int]],
in_channels: int,
out_channels: Optional[int] = None,
dropout: float = 0.0,
groups: int = 32,
eps: float = 1e-6,
norm_layer: str = "group_norm",
inject_noise: bool = False,
timestep_conditioning: bool = False,
):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.inject_noise = inject_noise
if norm_layer == "group_norm":
self.norm1 = nn.GroupNorm(
num_groups=groups, num_channels=in_channels, eps=eps, affine=True
)
elif norm_layer == "pixel_norm":
self.norm1 = PixelNorm()
elif norm_layer == "layer_norm":
self.norm1 = LayerNorm(in_channels, eps=eps, elementwise_affine=True)
self.non_linearity = nn.SiLU()
self.conv1 = make_conv_nd(
dims,
in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1,
causal=True,
)
if inject_noise:
self.per_channel_scale1 = nn.Parameter(torch.zeros((in_channels, 1, 1)))
if norm_layer == "group_norm":
self.norm2 = nn.GroupNorm(
num_groups=groups, num_channels=out_channels, eps=eps, affine=True
)
elif norm_layer == "pixel_norm":
self.norm2 = PixelNorm()
elif norm_layer == "layer_norm":
self.norm2 = LayerNorm(out_channels, eps=eps, elementwise_affine=True)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = make_conv_nd(
dims,
out_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1,
causal=True,
)
if inject_noise:
self.per_channel_scale2 = nn.Parameter(torch.zeros((in_channels, 1, 1)))
self.conv_shortcut = (
make_linear_nd(
dims=dims, in_channels=in_channels, out_channels=out_channels
)
if in_channels != out_channels
else nn.Identity()
)
self.norm3 = (
LayerNorm(in_channels, eps=eps, elementwise_affine=True)
if in_channels != out_channels
else nn.Identity()
)
self.timestep_conditioning = timestep_conditioning
if timestep_conditioning:
self.scale_shift_table = nn.Parameter(
torch.randn(4, in_channels) / in_channels**0.5
)
def _feed_spatial_noise(
self, hidden_states: torch.FloatTensor, per_channel_scale: torch.FloatTensor
) -> torch.FloatTensor:
spatial_shape = hidden_states.shape[-2:]
device = hidden_states.device
dtype = hidden_states.dtype
# similar to the "explicit noise inputs" method in style-gan
spatial_noise = torch.randn(spatial_shape, device=device, dtype=dtype)[None]
scaled_noise = (spatial_noise * per_channel_scale)[None, :, None, ...]
hidden_states = hidden_states + scaled_noise
return hidden_states
def forward(
self,
input_tensor: torch.FloatTensor,
causal: bool = True,
timestep: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
hidden_states = input_tensor
batch_size = hidden_states.shape[0]
hidden_states = self.norm1(hidden_states)
if self.timestep_conditioning:
assert (
timestep is not None
), "should pass timestep with timestep_conditioning=True"
ada_values = self.scale_shift_table[
None, ..., None, None, None
] + timestep.reshape(
batch_size,
4,
-1,
timestep.shape[-3],
timestep.shape[-2],
timestep.shape[-1],
)
shift1, scale1, shift2, scale2 = ada_values.unbind(dim=1)
hidden_states = hidden_states * (1 + scale1) + shift1
hidden_states = self.non_linearity(hidden_states)
hidden_states = self.conv1(hidden_states, causal=causal)
if self.inject_noise:
hidden_states = self._feed_spatial_noise(
hidden_states, self.per_channel_scale1
)
hidden_states = self.norm2(hidden_states)
if self.timestep_conditioning:
hidden_states = hidden_states * (1 + scale2) + shift2
hidden_states = self.non_linearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states, causal=causal)
if self.inject_noise:
hidden_states = self._feed_spatial_noise(
hidden_states, self.per_channel_scale2
)
input_tensor = self.norm3(input_tensor)
batch_size = input_tensor.shape[0]
input_tensor = self.conv_shortcut(input_tensor)
output_tensor = input_tensor + hidden_states
return output_tensor
def patchify(x, patch_size_hw, patch_size_t=1):
if patch_size_hw == 1 and patch_size_t == 1:
return x
if x.dim() == 4:
x = rearrange(
x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw
)
elif x.dim() == 5:
x = rearrange(
x,
"b c (f p) (h q) (w r) -> b (c p r q) f h w",
p=patch_size_t,
q=patch_size_hw,
r=patch_size_hw,
)
else:
raise ValueError(f"Invalid input shape: {x.shape}")
return x
def unpatchify(x, patch_size_hw, patch_size_t=1):
if patch_size_hw == 1 and patch_size_t == 1:
return x
if x.dim() == 4:
x = rearrange(
x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw
)
elif x.dim() == 5:
x = rearrange(
x,
"b (c p r q) f h w -> b c (f p) (h q) (w r)",
p=patch_size_t,
q=patch_size_hw,
r=patch_size_hw,
)
return x
class processor(nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("std-of-means", torch.empty(128))
self.register_buffer("mean-of-means", torch.empty(128))
self.register_buffer("mean-of-stds", torch.empty(128))
self.register_buffer("mean-of-stds_over_std-of-means", torch.empty(128))
self.register_buffer("channel", torch.empty(128))
def un_normalize(self, x):
return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)
def normalize(self, x):
return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)
class VideoVAE(nn.Module):
def __init__(self, version=0):
super().__init__()
if version == 0:
config = {
"_class_name": "CausalVideoAutoencoder",
"dims": 3,
"in_channels": 3,
"out_channels": 3,
"latent_channels": 128,
"blocks": [
["res_x", 4],
["compress_all", 1],
["res_x_y", 1],
["res_x", 3],
["compress_all", 1],
["res_x_y", 1],
["res_x", 3],
["compress_all", 1],
["res_x", 3],
["res_x", 4],
],
"scaling_factor": 1.0,
"norm_layer": "pixel_norm",
"patch_size": 4,
"latent_log_var": "uniform",
"use_quant_conv": False,
"causal_decoder": False,
}
else:
config = {
"_class_name": "CausalVideoAutoencoder",
"dims": 3,
"in_channels": 3,
"out_channels": 3,
"latent_channels": 128,
"decoder_blocks": [
["res_x", {"num_layers": 5, "inject_noise": True}],
["compress_all", {"residual": True, "multiplier": 2}],
["res_x", {"num_layers": 6, "inject_noise": True}],
["compress_all", {"residual": True, "multiplier": 2}],
["res_x", {"num_layers": 7, "inject_noise": True}],
["compress_all", {"residual": True, "multiplier": 2}],
["res_x", {"num_layers": 8, "inject_noise": False}]
],
"encoder_blocks": [
["res_x", {"num_layers": 4}],
["compress_all", {}],
["res_x_y", 1],
["res_x", {"num_layers": 3}],
["compress_all", {}],
["res_x_y", 1],
["res_x", {"num_layers": 3}],
["compress_all", {}],
["res_x", {"num_layers": 3}],
["res_x", {"num_layers": 4}]
],
"scaling_factor": 1.0,
"norm_layer": "pixel_norm",
"patch_size": 4,
"latent_log_var": "uniform",
"use_quant_conv": False,
"causal_decoder": False,
"timestep_conditioning": True,
}
double_z = config.get("double_z", True)
latent_log_var = config.get(
"latent_log_var", "per_channel" if double_z else "none"
)
self.encoder = Encoder(
dims=config["dims"],
in_channels=config.get("in_channels", 3),
out_channels=config["latent_channels"],
blocks=config.get("encoder_blocks", config.get("encoder_blocks", config.get("blocks"))),
patch_size=config.get("patch_size", 1),
latent_log_var=latent_log_var,
norm_layer=config.get("norm_layer", "group_norm"),
)
self.decoder = Decoder(
dims=config["dims"],
in_channels=config["latent_channels"],
out_channels=config.get("out_channels", 3),
blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))),
patch_size=config.get("patch_size", 1),
norm_layer=config.get("norm_layer", "group_norm"),
causal=config.get("causal_decoder", False),
timestep_conditioning=config.get("timestep_conditioning", False),
)
self.timestep_conditioning = config.get("timestep_conditioning", False)
self.per_channel_statistics = processor()
def encode(self, x):
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
return self.per_channel_statistics.normalize(means)
def decode(self, x, timestep=0.05, noise_scale=0.025):
if self.timestep_conditioning: #TODO: seed
x = torch.randn_like(x) * noise_scale + (1.0 - noise_scale) * x
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=timestep)

View File

@@ -0,0 +1,82 @@
from typing import Tuple, Union
from .dual_conv3d import DualConv3d
from .causal_conv3d import CausalConv3d
import comfy.ops
ops = comfy.ops.disable_weight_init
def make_conv_nd(
dims: Union[int, Tuple[int, int]],
in_channels: int,
out_channels: int,
kernel_size: int,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
causal=False,
):
if dims == 2:
return ops.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
)
elif dims == 3:
if causal:
return CausalConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
)
return ops.Conv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
)
elif dims == (2, 1):
return DualConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias=bias,
)
else:
raise ValueError(f"unsupported dimensions: {dims}")
def make_linear_nd(
dims: int,
in_channels: int,
out_channels: int,
bias=True,
):
if dims == 2:
return ops.Conv2d(
in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
)
elif dims == 3 or dims == (2, 1):
return ops.Conv3d(
in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
)
else:
raise ValueError(f"unsupported dimensions: {dims}")

View File

@@ -0,0 +1,195 @@
import math
from typing import Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
class DualConv3d(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride: Union[int, Tuple[int, int, int]] = 1,
padding: Union[int, Tuple[int, int, int]] = 0,
dilation: Union[int, Tuple[int, int, int]] = 1,
groups=1,
bias=True,
):
super(DualConv3d, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
# Ensure kernel_size, stride, padding, and dilation are tuples of length 3
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size, kernel_size)
if kernel_size == (1, 1, 1):
raise ValueError(
"kernel_size must be greater than 1. Use make_linear_nd instead."
)
if isinstance(stride, int):
stride = (stride, stride, stride)
if isinstance(padding, int):
padding = (padding, padding, padding)
if isinstance(dilation, int):
dilation = (dilation, dilation, dilation)
# Set parameters for convolutions
self.groups = groups
self.bias = bias
# Define the size of the channels after the first convolution
intermediate_channels = (
out_channels if in_channels < out_channels else in_channels
)
# Define parameters for the first convolution
self.weight1 = nn.Parameter(
torch.Tensor(
intermediate_channels,
in_channels // groups,
1,
kernel_size[1],
kernel_size[2],
)
)
self.stride1 = (1, stride[1], stride[2])
self.padding1 = (0, padding[1], padding[2])
self.dilation1 = (1, dilation[1], dilation[2])
if bias:
self.bias1 = nn.Parameter(torch.Tensor(intermediate_channels))
else:
self.register_parameter("bias1", None)
# Define parameters for the second convolution
self.weight2 = nn.Parameter(
torch.Tensor(
out_channels, intermediate_channels // groups, kernel_size[0], 1, 1
)
)
self.stride2 = (stride[0], 1, 1)
self.padding2 = (padding[0], 0, 0)
self.dilation2 = (dilation[0], 1, 1)
if bias:
self.bias2 = nn.Parameter(torch.Tensor(out_channels))
else:
self.register_parameter("bias2", None)
# Initialize weights and biases
self.reset_parameters()
def reset_parameters(self):
nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5))
nn.init.kaiming_uniform_(self.weight2, a=math.sqrt(5))
if self.bias:
fan_in1, _ = nn.init._calculate_fan_in_and_fan_out(self.weight1)
bound1 = 1 / math.sqrt(fan_in1)
nn.init.uniform_(self.bias1, -bound1, bound1)
fan_in2, _ = nn.init._calculate_fan_in_and_fan_out(self.weight2)
bound2 = 1 / math.sqrt(fan_in2)
nn.init.uniform_(self.bias2, -bound2, bound2)
def forward(self, x, use_conv3d=False, skip_time_conv=False):
if use_conv3d:
return self.forward_with_3d(x=x, skip_time_conv=skip_time_conv)
else:
return self.forward_with_2d(x=x, skip_time_conv=skip_time_conv)
def forward_with_3d(self, x, skip_time_conv):
# First convolution
x = F.conv3d(
x,
self.weight1,
self.bias1,
self.stride1,
self.padding1,
self.dilation1,
self.groups,
)
if skip_time_conv:
return x
# Second convolution
x = F.conv3d(
x,
self.weight2,
self.bias2,
self.stride2,
self.padding2,
self.dilation2,
self.groups,
)
return x
def forward_with_2d(self, x, skip_time_conv):
b, c, d, h, w = x.shape
# First 2D convolution
x = rearrange(x, "b c d h w -> (b d) c h w")
# Squeeze the depth dimension out of weight1 since it's 1
weight1 = self.weight1.squeeze(2)
# Select stride, padding, and dilation for the 2D convolution
stride1 = (self.stride1[1], self.stride1[2])
padding1 = (self.padding1[1], self.padding1[2])
dilation1 = (self.dilation1[1], self.dilation1[2])
x = F.conv2d(x, weight1, self.bias1, stride1, padding1, dilation1, self.groups)
_, _, h, w = x.shape
if skip_time_conv:
x = rearrange(x, "(b d) c h w -> b c d h w", b=b)
return x
# Second convolution which is essentially treated as a 1D convolution across the 'd' dimension
x = rearrange(x, "(b d) c h w -> (b h w) c d", b=b)
# Reshape weight2 to match the expected dimensions for conv1d
weight2 = self.weight2.squeeze(-1).squeeze(-1)
# Use only the relevant dimension for stride, padding, and dilation for the 1D convolution
stride2 = self.stride2[0]
padding2 = self.padding2[0]
dilation2 = self.dilation2[0]
x = F.conv1d(x, weight2, self.bias2, stride2, padding2, dilation2, self.groups)
x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w)
return x
@property
def weight(self):
return self.weight2
def test_dual_conv3d_consistency():
# Initialize parameters
in_channels = 3
out_channels = 5
kernel_size = (3, 3, 3)
stride = (2, 2, 2)
padding = (1, 1, 1)
# Create an instance of the DualConv3d class
dual_conv3d = DualConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias=True,
)
# Example input tensor
test_input = torch.randn(1, 3, 10, 10, 10)
# Perform forward passes with both 3D and 2D settings
output_conv3d = dual_conv3d(test_input, use_conv3d=True)
output_2d = dual_conv3d(test_input, use_conv3d=False)
# Assert that the outputs from both methods are sufficiently close
assert torch.allclose(
output_conv3d, output_2d, atol=1e-6
), "Outputs are not consistent between 3D and 2D convolutions."

View File

@@ -0,0 +1,12 @@
import torch
from torch import nn
class PixelNorm(nn.Module):
def __init__(self, dim=1, eps=1e-8):
super(PixelNorm, self).__init__()
self.dim = dim
self.eps = eps
def forward(self, x):
return x / torch.sqrt(torch.mean(x**2, dim=self.dim, keepdim=True) + self.eps)

View File

@@ -1,10 +1,12 @@
import logging
import math
import torch
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, Tuple, Union
from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
from comfy.ldm.util import instantiate_from_config
from comfy.ldm.util import get_obj_from_str, instantiate_from_config
from comfy.ldm.modules.ema import LitEma
import comfy.ops
@@ -52,7 +54,7 @@ class AbstractAutoencoder(torch.nn.Module):
if self.use_ema:
self.model_ema = LitEma(self, decay=ema_decay)
logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
logging.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
def get_input(self, batch) -> Any:
raise NotImplementedError()
@@ -68,14 +70,14 @@ class AbstractAutoencoder(torch.nn.Module):
self.model_ema.store(self.parameters())
self.model_ema.copy_to(self)
if context is not None:
logpy.info(f"{context}: Switched to EMA weights")
logging.info(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.parameters())
if context is not None:
logpy.info(f"{context}: Restored training weights")
logging.info(f"{context}: Restored training weights")
def encode(self, *args, **kwargs) -> torch.Tensor:
raise NotImplementedError("encode()-method of abstract base class called")
@@ -84,7 +86,7 @@ class AbstractAutoencoder(torch.nn.Module):
raise NotImplementedError("decode()-method of abstract base class called")
def instantiate_optimizer_from_config(self, params, lr, cfg):
logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
logging.info(f"loading >>> {cfg['target']} <<< optimizer from config")
return get_obj_from_str(cfg["target"])(
params, lr=lr, **cfg.get("params", dict())
)
@@ -112,7 +114,7 @@ class AutoencodingEngine(AbstractAutoencoder):
self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
self.regularization: AbstractRegularizer = instantiate_from_config(
self.regularization = instantiate_from_config(
regularizer_config
)
@@ -160,12 +162,19 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
},
**kwargs,
)
self.quant_conv = comfy.ops.disable_weight_init.Conv2d(
if ddconfig.get("conv3d", False):
conv_op = comfy.ops.disable_weight_init.Conv3d
else:
conv_op = comfy.ops.disable_weight_init.Conv2d
self.quant_conv = conv_op(
(1 + ddconfig["double_z"]) * ddconfig["z_channels"],
(1 + ddconfig["double_z"]) * embed_dim,
1,
)
self.post_quant_conv = comfy.ops.disable_weight_init.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.post_quant_conv = conv_op(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim
def get_autoencoder_params(self) -> list:

View File

@@ -15,6 +15,9 @@ if model_management.xformers_enabled():
import xformers
import xformers.ops
if model_management.sage_attention_enabled():
from sageattention import sageattn
from comfy.cli_args import args
import comfy.ops
ops = comfy.ops.disable_weight_init
@@ -157,8 +160,6 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
b, _, dim_head = query.shape
dim_head //= heads
scale = dim_head ** -0.5
if skip_reshape:
query = query.reshape(b * heads, -1, dim_head)
value = value.reshape(b * heads, -1, dim_head)
@@ -177,9 +178,8 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
bytes_per_token = torch.finfo(query.dtype).bits//8
batch_x_heads, q_tokens, _ = query.shape
_, _, k_tokens = key.shape
qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
mem_free_total, mem_free_torch = model_management.get_free_memory(query.device, True)
mem_free_total, _ = model_management.get_free_memory(query.device, True)
kv_chunk_size_min = None
kv_chunk_size = None
@@ -230,7 +230,6 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
scale = dim_head ** -0.5
h = heads
if skip_reshape:
q, k, v = map(
lambda t: t.reshape(b * heads, -1, dim_head),
@@ -299,7 +298,10 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
if len(mask.shape) == 2:
s1 += mask[i:end]
else:
s1 += mask[:, i:end]
if mask.shape[1] == 1:
s1 += mask
else:
s1 += mask[:, i:end]
s2 = s1.softmax(dim=-1).to(v.dtype)
del s1
@@ -341,12 +343,9 @@ except:
pass
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
if skip_reshape:
b, _, _, dim_head = q.shape
else:
b, _, dim_head = q.shape
dim_head //= heads
b = q.shape[0]
dim_head = q.shape[-1]
# check to make sure xformers isn't broken
disabled_xformers = False
if BROKEN_XFORMERS:
@@ -358,41 +357,57 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
disabled_xformers = True
if disabled_xformers:
return attention_pytorch(q, k, v, heads, mask)
return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape)
if skip_reshape:
q, k, v = map(
lambda t: t.reshape(b * heads, -1, dim_head),
# b h k d -> b k h d
q, k, v = map(
lambda t: t.permute(0, 2, 1, 3),
(q, k, v),
)
# actually do the reshaping
else:
dim_head //= heads
q, k, v = map(
lambda t: t.reshape(b, -1, heads, dim_head),
(q, k, v),
)
if mask is not None:
pad = 8 - q.shape[1] % 8
mask_out = torch.empty([q.shape[0], q.shape[1], q.shape[1] + pad], dtype=q.dtype, device=q.device)
mask_out[:, :, :mask.shape[-1]] = mask
mask = mask_out[:, :, :mask.shape[-1]]
# add a singleton batch dimension
if mask.ndim == 2:
mask = mask.unsqueeze(0)
# add a singleton heads dimension
if mask.ndim == 3:
mask = mask.unsqueeze(1)
# pad to a multiple of 8
pad = 8 - mask.shape[-1] % 8
# the xformers docs says that it's allowed to have a mask of shape (1, Nq, Nk)
# but when using separated heads, the shape has to be (B, H, Nq, Nk)
# in flux, this matrix ends up being over 1GB
# here, we create a mask with the same batch/head size as the input mask (potentially singleton or full)
mask_out = torch.empty([mask.shape[0], mask.shape[1], q.shape[1], mask.shape[-1] + pad], dtype=q.dtype, device=q.device)
mask_out[..., :mask.shape[-1]] = mask
# doesn't this remove the padding again??
mask = mask_out[..., :mask.shape[-1]]
mask = mask.expand(b, heads, -1, -1)
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
if skip_reshape:
out = (
out.unsqueeze(0)
.reshape(b, heads, -1, dim_head)
.permute(0, 2, 1, 3)
.reshape(b, -1, heads * dim_head)
)
else:
out = (
out.reshape(b, -1, heads * dim_head)
)
out = (
out.reshape(b, -1, heads * dim_head)
)
return out
if model_management.is_nvidia(): #pytorch 2.3 and up seem to have this issue.
SDP_BATCH_LIMIT = 2**15
else:
#TODO: other GPUs ?
SDP_BATCH_LIMIT = 2**31
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
if skip_reshape:
b, _, _, dim_head = q.shape
@@ -404,27 +419,85 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
(q, k, v),
)
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
)
if mask is not None:
# add a batch dimension if there isn't already one
if mask.ndim == 2:
mask = mask.unsqueeze(0)
# add a heads dimension if there isn't already one
if mask.ndim == 3:
mask = mask.unsqueeze(1)
if SDP_BATCH_LIMIT >= b:
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
)
else:
out = torch.empty((b, q.shape[2], heads * dim_head), dtype=q.dtype, layout=q.layout, device=q.device)
for i in range(0, b, SDP_BATCH_LIMIT):
m = mask
if mask is not None:
if mask.shape[0] > 1:
m = mask[i : i + SDP_BATCH_LIMIT]
out[i : i + SDP_BATCH_LIMIT] = torch.nn.functional.scaled_dot_product_attention(
q[i : i + SDP_BATCH_LIMIT],
k[i : i + SDP_BATCH_LIMIT],
v[i : i + SDP_BATCH_LIMIT],
attn_mask=m,
dropout_p=0.0, is_causal=False
).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
return out
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
if skip_reshape:
b, _, _, dim_head = q.shape
tensor_layout="HND"
else:
b, _, dim_head = q.shape
dim_head //= heads
q, k, v = map(
lambda t: t.view(b, -1, heads, dim_head),
(q, k, v),
)
tensor_layout="NHD"
if mask is not None:
# add a batch dimension if there isn't already one
if mask.ndim == 2:
mask = mask.unsqueeze(0)
# add a heads dimension if there isn't already one
if mask.ndim == 3:
mask = mask.unsqueeze(1)
out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
if tensor_layout == "HND":
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
)
else:
out = out.reshape(b, -1, heads * dim_head)
return out
optimized_attention = attention_basic
if model_management.xformers_enabled():
logging.info("Using xformers cross attention")
if model_management.sage_attention_enabled():
logging.info("Using sage attention")
optimized_attention = attention_sage
elif model_management.xformers_enabled():
logging.info("Using xformers attention")
optimized_attention = attention_xformers
elif model_management.pytorch_attention_enabled():
logging.info("Using pytorch cross attention")
logging.info("Using pytorch attention")
optimized_attention = attention_pytorch
else:
if args.use_split_cross_attention:
logging.info("Using split optimization for cross attention")
logging.info("Using split optimization for attention")
optimized_attention = attention_split
else:
logging.info("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
logging.info("Using sub quadratic optimization for attention, if you have memory or speed issues try using: --use-split-cross-attention")
optimized_attention = attention_sub_quad
optimized_attention_masked = optimized_attention

View File

@@ -1,11 +1,10 @@
import logging
import math
from typing import Dict, Optional
from functools import partial
from typing import Dict, Optional, List
import numpy as np
import torch
import torch.nn as nn
from .. import attention
from ..attention import optimized_attention
from einops import rearrange, repeat
from .util import timestep_embedding
import comfy.ops
@@ -72,45 +71,33 @@ class PatchEmbed(nn.Module):
strict_img_size: bool = True,
dynamic_img_pad: bool = True,
padding_mode='circular',
conv3d=False,
dtype=None,
device=None,
operations=None,
):
super().__init__()
self.patch_size = (patch_size, patch_size)
try:
len(patch_size)
self.patch_size = patch_size
except:
if conv3d:
self.patch_size = (patch_size, patch_size, patch_size)
else:
self.patch_size = (patch_size, patch_size)
self.padding_mode = padding_mode
if img_size is not None:
self.img_size = (img_size, img_size)
self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size)])
self.num_patches = self.grid_size[0] * self.grid_size[1]
else:
self.img_size = None
self.grid_size = None
self.num_patches = None
# flatten spatial dim and transpose to channels last, kept for bwd compat
self.flatten = flatten
self.strict_img_size = strict_img_size
self.dynamic_img_pad = dynamic_img_pad
self.proj = operations.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, dtype=dtype, device=device)
if conv3d:
self.proj = operations.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, dtype=dtype, device=device)
else:
self.proj = operations.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, dtype=dtype, device=device)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
B, C, H, W = x.shape
# if self.img_size is not None:
# if self.strict_img_size:
# _assert(H == self.img_size[0], f"Input height ({H}) doesn't match model ({self.img_size[0]}).")
# _assert(W == self.img_size[1], f"Input width ({W}) doesn't match model ({self.img_size[1]}).")
# elif not self.dynamic_img_pad:
# _assert(
# H % self.patch_size[0] == 0,
# f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})."
# )
# _assert(
# W % self.patch_size[1] == 0,
# f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
# )
if self.dynamic_img_pad:
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size, padding_mode=self.padding_mode)
x = self.proj(x)
@@ -266,8 +253,6 @@ def split_qkv(qkv, head_dim):
qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, -1, head_dim).movedim(2, 0)
return qkv[0], qkv[1], qkv[2]
def optimized_attention(qkv, num_heads):
return attention.optimized_attention(qkv[0], qkv[1], qkv[2], num_heads)
class SelfAttention(nn.Module):
ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")
@@ -326,9 +311,9 @@ class SelfAttention(nn.Module):
return x
def forward(self, x: torch.Tensor) -> torch.Tensor:
qkv = self.pre_attention(x)
q, k, v = self.pre_attention(x)
x = optimized_attention(
qkv, num_heads=self.num_heads
q, k, v, heads=self.num_heads
)
x = self.post_attention(x)
return x
@@ -355,29 +340,9 @@ class RMSNorm(torch.nn.Module):
else:
self.register_parameter("weight", None)
def _norm(self, x):
"""
Apply the RMSNorm normalization to the input tensor.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The normalized tensor.
"""
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
"""
Forward pass through the RMSNorm layer.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The output tensor after applying RMSNorm.
"""
x = self._norm(x)
if self.learnable_scale:
return x * self.weight.to(device=x.device, dtype=x.dtype)
else:
return x
return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)
class SwiGLUFeedForward(nn.Module):
@@ -437,6 +402,7 @@ class DismantledBlock(nn.Module):
scale_mod_only: bool = False,
swiglu: bool = False,
qk_norm: Optional[str] = None,
x_block_self_attn: bool = False,
dtype=None,
device=None,
operations=None,
@@ -460,6 +426,24 @@ class DismantledBlock(nn.Module):
device=device,
operations=operations
)
if x_block_self_attn:
assert not pre_only
assert not scale_mod_only
self.x_block_self_attn = True
self.attn2 = SelfAttention(
dim=hidden_size,
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_mode=attn_mode,
pre_only=False,
qk_norm=qk_norm,
rmsnorm=rmsnorm,
dtype=dtype,
device=device,
operations=operations
)
else:
self.x_block_self_attn = False
if not pre_only:
if not rmsnorm:
self.norm2 = operations.LayerNorm(
@@ -486,7 +470,11 @@ class DismantledBlock(nn.Module):
multiple_of=256,
)
self.scale_mod_only = scale_mod_only
if not scale_mod_only:
if x_block_self_attn:
assert not pre_only
assert not scale_mod_only
n_mods = 9
elif not scale_mod_only:
n_mods = 6 if not pre_only else 2
else:
n_mods = 4 if not pre_only else 1
@@ -547,14 +535,64 @@ class DismantledBlock(nn.Module):
)
return x
def pre_attention_x(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
assert self.x_block_self_attn
(
shift_msa,
scale_msa,
gate_msa,
shift_mlp,
scale_mlp,
gate_mlp,
shift_msa2,
scale_msa2,
gate_msa2,
) = self.adaLN_modulation(c).chunk(9, dim=1)
x_norm = self.norm1(x)
qkv = self.attn.pre_attention(modulate(x_norm, shift_msa, scale_msa))
qkv2 = self.attn2.pre_attention(modulate(x_norm, shift_msa2, scale_msa2))
return qkv, qkv2, (
x,
gate_msa,
shift_mlp,
scale_mlp,
gate_mlp,
gate_msa2,
)
def post_attention_x(self, attn, attn2, x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2):
assert not self.pre_only
attn1 = self.attn.post_attention(attn)
attn2 = self.attn2.post_attention(attn2)
out1 = gate_msa.unsqueeze(1) * attn1
out2 = gate_msa2.unsqueeze(1) * attn2
x = x + out1
x = x + out2
x = x + gate_mlp.unsqueeze(1) * self.mlp(
modulate(self.norm2(x), shift_mlp, scale_mlp)
)
return x
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
assert not self.pre_only
qkv, intermediates = self.pre_attention(x, c)
attn = optimized_attention(
qkv,
num_heads=self.attn.num_heads,
)
return self.post_attention(attn, *intermediates)
if self.x_block_self_attn:
qkv, qkv2, intermediates = self.pre_attention_x(x, c)
attn, _ = optimized_attention(
qkv[0], qkv[1], qkv[2],
num_heads=self.attn.num_heads,
)
attn2, _ = optimized_attention(
qkv2[0], qkv2[1], qkv2[2],
num_heads=self.attn2.num_heads,
)
return self.post_attention_x(attn, attn2, *intermediates)
else:
qkv, intermediates = self.pre_attention(x, c)
attn = optimized_attention(
qkv[0], qkv[1], qkv[2],
heads=self.attn.num_heads,
)
return self.post_attention(attn, *intermediates)
def block_mixing(*args, use_checkpoint=True, **kwargs):
@@ -569,7 +607,10 @@ def block_mixing(*args, use_checkpoint=True, **kwargs):
def _block_mixing(context, x, context_block, x_block, c):
context_qkv, context_intermediates = context_block.pre_attention(context, c)
x_qkv, x_intermediates = x_block.pre_attention(x, c)
if x_block.x_block_self_attn:
x_qkv, x_qkv2, x_intermediates = x_block.pre_attention_x(x, c)
else:
x_qkv, x_intermediates = x_block.pre_attention(x, c)
o = []
for t in range(3):
@@ -577,8 +618,8 @@ def _block_mixing(context, x, context_block, x_block, c):
qkv = tuple(o)
attn = optimized_attention(
qkv,
num_heads=x_block.attn.num_heads,
qkv[0], qkv[1], qkv[2],
heads=x_block.attn.num_heads,
)
context_attn, x_attn = (
attn[:, : context_qkv[0].shape[1]],
@@ -590,7 +631,14 @@ def _block_mixing(context, x, context_block, x_block, c):
else:
context = None
x = x_block.post_attention(x_attn, *x_intermediates)
if x_block.x_block_self_attn:
attn2 = optimized_attention(
x_qkv2[0], x_qkv2[1], x_qkv2[2],
heads=x_block.attn2.num_heads,
)
x = x_block.post_attention_x(x_attn, attn2, *x_intermediates)
else:
x = x_block.post_attention(x_attn, *x_intermediates)
return context, x
@@ -605,8 +653,13 @@ class JointBlock(nn.Module):
super().__init__()
pre_only = kwargs.pop("pre_only")
qk_norm = kwargs.pop("qk_norm", None)
x_block_self_attn = kwargs.pop("x_block_self_attn", False)
self.context_block = DismantledBlock(*args, pre_only=pre_only, qk_norm=qk_norm, **kwargs)
self.x_block = DismantledBlock(*args, pre_only=False, qk_norm=qk_norm, **kwargs)
self.x_block = DismantledBlock(*args,
pre_only=False,
qk_norm=qk_norm,
x_block_self_attn=x_block_self_attn,
**kwargs)
def forward(self, *args, **kwargs):
return block_mixing(
@@ -662,7 +715,7 @@ class SelfAttentionContext(nn.Module):
def forward(self, x):
qkv = self.qkv(x)
q, k, v = split_qkv(qkv, self.dim_head)
x = optimized_attention((q.reshape(q.shape[0], q.shape[1], -1), k, v), self.heads)
x = optimized_attention(q.reshape(q.shape[0], q.shape[1], -1), k, v, heads=self.heads)
return self.proj(x)
class ContextProcessorBlock(nn.Module):
@@ -721,9 +774,12 @@ class MMDiT(nn.Module):
qk_norm: Optional[str] = None,
qkv_bias: bool = True,
context_processor_layers = None,
x_block_self_attn: bool = False,
x_block_self_attn_layers: Optional[List[int]] = [],
context_size = 4096,
num_blocks = None,
final_layer = True,
skip_blocks = False,
dtype = None, #TODO
device = None,
operations = None,
@@ -738,6 +794,7 @@ class MMDiT(nn.Module):
self.pos_embed_scaling_factor = pos_embed_scaling_factor
self.pos_embed_offset = pos_embed_offset
self.pos_embed_max_size = pos_embed_max_size
self.x_block_self_attn_layers = x_block_self_attn_layers
# hidden_size = default(hidden_size, 64 * depth)
# num_heads = default(num_heads, hidden_size // 64)
@@ -795,26 +852,28 @@ class MMDiT(nn.Module):
self.pos_embed = None
self.use_checkpoint = use_checkpoint
self.joint_blocks = nn.ModuleList(
[
JointBlock(
self.hidden_size,
num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
attn_mode=attn_mode,
pre_only=(i == num_blocks - 1) and final_layer,
rmsnorm=rmsnorm,
scale_mod_only=scale_mod_only,
swiglu=swiglu,
qk_norm=qk_norm,
dtype=dtype,
device=device,
operations=operations
)
for i in range(num_blocks)
]
)
if not skip_blocks:
self.joint_blocks = nn.ModuleList(
[
JointBlock(
self.hidden_size,
num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
attn_mode=attn_mode,
pre_only=(i == num_blocks - 1) and final_layer,
rmsnorm=rmsnorm,
scale_mod_only=scale_mod_only,
swiglu=swiglu,
qk_norm=qk_norm,
x_block_self_attn=(i in self.x_block_self_attn_layers) or x_block_self_attn,
dtype=dtype,
device=device,
operations=operations,
)
for i in range(num_blocks)
]
)
if final_layer:
self.final_layer = FinalLayer(self.hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
@@ -877,7 +936,9 @@ class MMDiT(nn.Module):
c_mod: torch.Tensor,
context: Optional[torch.Tensor] = None,
control = None,
transformer_options = {},
) -> torch.Tensor:
patches_replace = transformer_options.get("patches_replace", {})
if self.register_length > 0:
context = torch.cat(
(
@@ -889,14 +950,25 @@ class MMDiT(nn.Module):
# context is B, L', D
# x is B, L, D
blocks_replace = patches_replace.get("dit", {})
blocks = len(self.joint_blocks)
for i in range(blocks):
context, x = self.joint_blocks[i](
context,
x,
c=c_mod,
use_checkpoint=self.use_checkpoint,
)
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": c_mod}, {"original_block": block_wrap})
context = out["txt"]
x = out["img"]
else:
context, x = self.joint_blocks[i](
context,
x,
c=c_mod,
use_checkpoint=self.use_checkpoint,
)
if control is not None:
control_o = control.get("output")
if i < len(control_o):
@@ -914,6 +986,7 @@ class MMDiT(nn.Module):
y: Optional[torch.Tensor] = None,
context: Optional[torch.Tensor] = None,
control = None,
transformer_options = {},
) -> torch.Tensor:
"""
Forward pass of DiT.
@@ -935,7 +1008,7 @@ class MMDiT(nn.Module):
if context is not None:
context = self.context_embedder(context)
x = self.forward_core_with_concat(x, c, context, control)
x = self.forward_core_with_concat(x, c, context, control, transformer_options)
x = self.unpatchify(x, hw=hw) # (N, out_channels, H, W)
return x[:,:,:hw[-2],:hw[-1]]
@@ -949,7 +1022,8 @@ class OpenAISignatureMMDITWrapper(MMDiT):
context: Optional[torch.Tensor] = None,
y: Optional[torch.Tensor] = None,
control = None,
transformer_options = {},
**kwargs,
) -> torch.Tensor:
return super().forward(x, timesteps, context=context, y=y, control=control)
return super().forward(x, timesteps, context=context, y=y, control=control, transformer_options=transformer_options)

View File

@@ -3,7 +3,6 @@ import math
import torch
import torch.nn as nn
import numpy as np
from typing import Optional, Any
import logging
from comfy import model_management
@@ -44,51 +43,100 @@ def Normalize(in_channels, num_groups=32):
return ops.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
class VideoConv3d(nn.Module):
def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding_mode='replicate', padding=1, **kwargs):
super().__init__()
self.padding_mode = padding_mode
if padding != 0:
padding = (padding, padding, padding, padding, kernel_size - 1, 0)
else:
kwargs["padding"] = padding
self.padding = padding
self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs)
def forward(self, x):
if self.padding != 0:
x = torch.nn.functional.pad(x, self.padding, mode=self.padding_mode)
return self.conv(x)
def interpolate_up(x, scale_factor):
try:
return torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="nearest")
except: #operation not implemented for bf16
orig_shape = list(x.shape)
out_shape = orig_shape[:2]
for i in range(len(orig_shape) - 2):
out_shape.append(round(orig_shape[i + 2] * scale_factor[i]))
out = torch.empty(out_shape, dtype=x.dtype, layout=x.layout, device=x.device)
split = 8
l = out.shape[1] // split
for i in range(0, out.shape[1], l):
out[:,i:i+l] = torch.nn.functional.interpolate(x[:,i:i+l].to(torch.float32), scale_factor=scale_factor, mode="nearest").to(x.dtype)
return out
class Upsample(nn.Module):
def __init__(self, in_channels, with_conv):
def __init__(self, in_channels, with_conv, conv_op=ops.Conv2d, scale_factor=2.0):
super().__init__()
self.with_conv = with_conv
self.scale_factor = scale_factor
if self.with_conv:
self.conv = ops.Conv2d(in_channels,
self.conv = conv_op(in_channels,
in_channels,
kernel_size=3,
stride=1,
padding=1)
def forward(self, x):
try:
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
except: #operation not implemented for bf16
b, c, h, w = x.shape
out = torch.empty((b, c, h*2, w*2), dtype=x.dtype, layout=x.layout, device=x.device)
split = 8
l = out.shape[1] // split
for i in range(0, out.shape[1], l):
out[:,i:i+l] = torch.nn.functional.interpolate(x[:,i:i+l].to(torch.float32), scale_factor=2.0, mode="nearest").to(x.dtype)
del x
x = out
scale_factor = self.scale_factor
if isinstance(scale_factor, (int, float)):
scale_factor = (scale_factor,) * (x.ndim - 2)
if x.ndim == 5 and scale_factor[0] > 1.0:
t = x.shape[2]
if t > 1:
a, b = x.split((1, t - 1), dim=2)
del x
b = interpolate_up(b, scale_factor)
else:
a = x
a = interpolate_up(a.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2)
if t > 1:
x = torch.cat((a, b), dim=2)
else:
x = a
else:
x = interpolate_up(x, scale_factor)
if self.with_conv:
x = self.conv(x)
return x
class Downsample(nn.Module):
def __init__(self, in_channels, with_conv):
def __init__(self, in_channels, with_conv, stride=2, conv_op=ops.Conv2d):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves
self.conv = ops.Conv2d(in_channels,
self.conv = conv_op(in_channels,
in_channels,
kernel_size=3,
stride=2,
stride=stride,
padding=0)
def forward(self, x):
if self.with_conv:
pad = (0,1,0,1)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
if x.ndim == 4:
pad = (0, 1, 0, 1)
mode = "constant"
x = torch.nn.functional.pad(x, pad, mode=mode, value=0)
elif x.ndim == 5:
pad = (1, 1, 1, 1, 2, 0)
mode = "replicate"
x = torch.nn.functional.pad(x, pad, mode=mode)
x = self.conv(x)
else:
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
@@ -97,7 +145,7 @@ class Downsample(nn.Module):
class ResnetBlock(nn.Module):
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
dropout, temb_channels=512):
dropout, temb_channels=512, conv_op=ops.Conv2d):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
@@ -106,7 +154,7 @@ class ResnetBlock(nn.Module):
self.swish = torch.nn.SiLU(inplace=True)
self.norm1 = Normalize(in_channels)
self.conv1 = ops.Conv2d(in_channels,
self.conv1 = conv_op(in_channels,
out_channels,
kernel_size=3,
stride=1,
@@ -116,20 +164,20 @@ class ResnetBlock(nn.Module):
out_channels)
self.norm2 = Normalize(out_channels)
self.dropout = torch.nn.Dropout(dropout, inplace=True)
self.conv2 = ops.Conv2d(out_channels,
self.conv2 = conv_op(out_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = ops.Conv2d(in_channels,
self.conv_shortcut = conv_op(in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
else:
self.nin_shortcut = ops.Conv2d(in_channels,
self.nin_shortcut = conv_op(in_channels,
out_channels,
kernel_size=1,
stride=1,
@@ -163,7 +211,6 @@ def slice_attention(q, k, v):
mem_free_total = model_management.get_free_memory(q.device)
gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
modifier = 3 if q.element_size() == 2 else 2.5
mem_required = tensor_size * modifier
@@ -196,21 +243,25 @@ def slice_attention(q, k, v):
def normal_attention(q, k, v):
# compute attention
b,c,h,w = q.shape
orig_shape = q.shape
b = orig_shape[0]
c = orig_shape[1]
q = q.reshape(b,c,h*w)
q = q.permute(0,2,1) # b,hw,c
k = k.reshape(b,c,h*w) # b,c,hw
v = v.reshape(b,c,h*w)
q = q.reshape(b, c, -1)
q = q.permute(0, 2, 1) # b,hw,c
k = k.reshape(b, c, -1) # b,c,hw
v = v.reshape(b, c, -1)
r1 = slice_attention(q, k, v)
h_ = r1.reshape(b,c,h,w)
h_ = r1.reshape(orig_shape)
del r1
return h_
def xformers_attention(q, k, v):
# compute attention
B, C, H, W = q.shape
orig_shape = q.shape
B = orig_shape[0]
C = orig_shape[1]
q, k, v = map(
lambda t: t.view(B, C, -1).transpose(1, 2).contiguous(),
(q, k, v),
@@ -218,14 +269,16 @@ def xformers_attention(q, k, v):
try:
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
out = out.transpose(1, 2).reshape(B, C, H, W)
except NotImplementedError as e:
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
out = out.transpose(1, 2).reshape(orig_shape)
except NotImplementedError:
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(orig_shape)
return out
def pytorch_attention(q, k, v):
# compute attention
B, C, H, W = q.shape
orig_shape = q.shape
B = orig_shape[0]
C = orig_shape[1]
q, k, v = map(
lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
(q, k, v),
@@ -233,35 +286,35 @@ def pytorch_attention(q, k, v):
try:
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
out = out.transpose(2, 3).reshape(B, C, H, W)
except model_management.OOM_EXCEPTION as e:
out = out.transpose(2, 3).reshape(orig_shape)
except model_management.OOM_EXCEPTION:
logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(orig_shape)
return out
class AttnBlock(nn.Module):
def __init__(self, in_channels):
def __init__(self, in_channels, conv_op=ops.Conv2d):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = ops.Conv2d(in_channels,
self.q = conv_op(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.k = ops.Conv2d(in_channels,
self.k = conv_op(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.v = ops.Conv2d(in_channels,
self.v = conv_op(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = ops.Conv2d(in_channels,
self.proj_out = conv_op(in_channels,
in_channels,
kernel_size=1,
stride=1,
@@ -291,8 +344,8 @@ class AttnBlock(nn.Module):
return x+h_
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
return AttnBlock(in_channels)
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None, conv_op=ops.Conv2d):
return AttnBlock(in_channels, conv_op=conv_op)
class Model(nn.Module):
@@ -451,6 +504,7 @@ class Encoder(nn.Module):
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
conv3d=False, time_compress=None,
**ignore_kwargs):
super().__init__()
if use_linear_attn: attn_type = "linear"
@@ -461,8 +515,15 @@ class Encoder(nn.Module):
self.resolution = resolution
self.in_channels = in_channels
if conv3d:
conv_op = VideoConv3d
mid_attn_conv_op = ops.Conv3d
else:
conv_op = ops.Conv2d
mid_attn_conv_op = ops.Conv2d
# downsampling
self.conv_in = ops.Conv2d(in_channels,
self.conv_in = conv_op(in_channels,
self.ch,
kernel_size=3,
stride=1,
@@ -481,15 +542,20 @@ class Encoder(nn.Module):
block.append(ResnetBlock(in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
dropout=dropout,
conv_op=conv_op))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
attn.append(make_attn(block_in, attn_type=attn_type, conv_op=conv_op))
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions-1:
down.downsample = Downsample(block_in, resamp_with_conv)
stride = 2
if time_compress is not None:
if (self.num_resolutions - 1 - i_level) > math.log2(time_compress):
stride = (1, 2, 2)
down.downsample = Downsample(block_in, resamp_with_conv, stride=stride, conv_op=conv_op)
curr_res = curr_res // 2
self.down.append(down)
@@ -498,16 +564,18 @@ class Encoder(nn.Module):
self.mid.block_1 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
dropout=dropout,
conv_op=conv_op)
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type, conv_op=mid_attn_conv_op)
self.mid.block_2 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
dropout=dropout,
conv_op=conv_op)
# end
self.norm_out = Normalize(block_in)
self.conv_out = ops.Conv2d(block_in,
self.conv_out = conv_op(block_in,
2*z_channels if double_z else z_channels,
kernel_size=3,
stride=1,
@@ -545,9 +613,10 @@ class Decoder(nn.Module):
conv_out_op=ops.Conv2d,
resnet_op=ResnetBlock,
attn_op=AttnBlock,
conv3d=False,
time_compress=None,
**ignorekwargs):
super().__init__()
if use_linear_attn: attn_type = "linear"
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
@@ -557,8 +626,15 @@ class Decoder(nn.Module):
self.give_pre_end = give_pre_end
self.tanh_out = tanh_out
# compute in_ch_mult, block_in and curr_res at lowest res
in_ch_mult = (1,)+tuple(ch_mult)
if conv3d:
conv_op = VideoConv3d
conv_out_op = VideoConv3d
mid_attn_conv_op = ops.Conv3d
else:
conv_op = ops.Conv2d
mid_attn_conv_op = ops.Conv2d
# compute block_in and curr_res at lowest res
block_in = ch*ch_mult[self.num_resolutions-1]
curr_res = resolution // 2**(self.num_resolutions-1)
self.z_shape = (1,z_channels,curr_res,curr_res)
@@ -566,7 +642,7 @@ class Decoder(nn.Module):
self.z_shape, np.prod(self.z_shape)))
# z to block_in
self.conv_in = ops.Conv2d(z_channels,
self.conv_in = conv_op(z_channels,
block_in,
kernel_size=3,
stride=1,
@@ -577,12 +653,14 @@ class Decoder(nn.Module):
self.mid.block_1 = resnet_op(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
self.mid.attn_1 = attn_op(block_in)
dropout=dropout,
conv_op=conv_op)
self.mid.attn_1 = attn_op(block_in, conv_op=mid_attn_conv_op)
self.mid.block_2 = resnet_op(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
dropout=dropout,
conv_op=conv_op)
# upsampling
self.up = nn.ModuleList()
@@ -594,15 +672,21 @@ class Decoder(nn.Module):
block.append(resnet_op(in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
dropout=dropout,
conv_op=conv_op))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(attn_op(block_in))
attn.append(attn_op(block_in, conv_op=conv_op))
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv)
scale_factor = 2.0
if time_compress is not None:
if i_level > math.log2(time_compress):
scale_factor = (1.0, 2.0, 2.0)
up.upsample = Upsample(block_in, resamp_with_conv, conv_op=conv_op, scale_factor=scale_factor)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order

View File

@@ -9,12 +9,12 @@ import logging
from .util import (
checkpoint,
avg_pool_nd,
zero_module,
timestep_embedding,
AlphaBlender,
)
from ..attention import SpatialTransformer, SpatialVideoTransformer, default
from comfy.ldm.util import exists
import comfy.patcher_extension
import comfy.ops
ops = comfy.ops.disable_weight_init
@@ -47,6 +47,15 @@ def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, out
elif isinstance(layer, Upsample):
x = layer(x, output_shape=output_shape)
else:
if "patches" in transformer_options and "forward_timestep_embed_patch" in transformer_options["patches"]:
found_patched = False
for class_type, handler in transformer_options["patches"]["forward_timestep_embed_patch"]:
if isinstance(layer, class_type):
x = handler(layer, x, emb, context, transformer_options, output_shape, time_context, num_video_frames, image_only_indicator)
found_patched = True
break
if found_patched:
continue
x = layer(x)
return x
@@ -819,6 +828,13 @@ class UNetModel(nn.Module):
)
def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timesteps, context, y, control, transformer_options, **kwargs)
def _forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.
@@ -842,6 +858,11 @@ class UNetModel(nn.Module):
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
emb = self.time_embed(t_emb)
if "emb_patch" in transformer_patches:
patch = transformer_patches["emb_patch"]
for p in patch:
emb = p(emb, self.model_channels, transformer_options)
if self.num_classes is not None:
assert y.shape[0] == x.shape[0]
emb = emb + self.label_emb(y)

View File

@@ -4,7 +4,6 @@ import numpy as np
from functools import partial
from .util import extract_into_tensor, make_beta_schedule
from comfy.ldm.util import default
class AbstractLowScaleModel(nn.Module):

View File

@@ -8,7 +8,6 @@
# thanks!
import os
import math
import torch
import torch.nn as nn

View File

@@ -30,10 +30,10 @@ class DiagonalGaussianDistribution(object):
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
self.var = self.std = torch.zeros_like(self.mean, device=self.parameters.device)
def sample(self):
x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
x = self.mean + self.std * torch.randn(self.mean.shape, device=self.parameters.device)
return x
def kl(self, other=None):

View File

@@ -22,7 +22,6 @@ except ImportError:
from typing import Optional, NamedTuple, List
from typing_extensions import Protocol
from torch import Tensor
from typing import List
from comfy import model_management
@@ -172,7 +171,7 @@ def _get_attention_scores_no_kv_chunking(
del attn_scores
except model_management.OOM_EXCEPTION:
logging.warning("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
attn_scores -= attn_scores.max(dim=-1, keepdim=True).values
attn_scores -= attn_scores.max(dim=-1, keepdim=True).values # noqa: F821 attn_scores is not defined
torch.exp(attn_scores, out=attn_scores)
summed = torch.sum(attn_scores, dim=-1, keepdim=True)
attn_scores /= summed
@@ -234,6 +233,8 @@ def efficient_dot_product_attention(
def get_mask_chunk(chunk_idx: int) -> Tensor:
if mask is None:
return None
if mask.shape[1] == 1:
return mask
chunk = min(query_chunk_size, q_tokens)
return mask[:,chunk_idx:chunk_idx + chunk]

View File

@@ -1,5 +1,5 @@
import functools
from typing import Callable, Iterable, Union
from typing import Iterable, Union
import torch
from einops import rearrange, repeat
@@ -194,6 +194,7 @@ def make_time_attn(
attn_kwargs=None,
alpha: float = 0,
merge_strategy: str = "learned",
conv_op=ops.Conv2d,
):
return partialclass(
AttnVideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy

View File

@@ -133,7 +133,6 @@ class AdamWwithEMAandWings(optim.Optimizer):
exp_avgs = []
exp_avg_sqs = []
ema_params_with_grad = []
state_sums = []
max_exp_avg_sqs = []
state_steps = []
amsgrad = group['amsgrad']

View File

@@ -1,5 +1,27 @@
"""
This file is part of ComfyUI.
Copyright (C) 2024 Comfy
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
from __future__ import annotations
import comfy.utils
import comfy.model_management
import comfy.model_base
import logging
import torch
LORA_CLIP_MAP = {
"mlp.fc1": "mlp_fc1",
@@ -11,7 +33,7 @@ LORA_CLIP_MAP = {
}
def load_lora(lora, to_load):
def load_lora(lora, to_load, log_missing=True):
patch_dict = {}
loaded_keys = set()
for x in to_load:
@@ -27,10 +49,20 @@ def load_lora(lora, to_load):
dora_scale = lora[dora_scale_name]
loaded_keys.add(dora_scale_name)
reshape_name = "{}.reshape_weight".format(x)
reshape = None
if reshape_name in lora.keys():
try:
reshape = lora[reshape_name].tolist()
loaded_keys.add(reshape_name)
except:
pass
regular_lora = "{}.lora_up.weight".format(x)
diffusers_lora = "{}_lora.up.weight".format(x)
diffusers2_lora = "{}.lora_B.weight".format(x)
diffusers3_lora = "{}.lora.up.weight".format(x)
mochi_lora = "{}.lora_B".format(x)
transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
A_name = None
@@ -50,6 +82,10 @@ def load_lora(lora, to_load):
A_name = diffusers3_lora
B_name = "{}.lora.down.weight".format(x)
mid_name = None
elif mochi_lora in lora.keys():
A_name = mochi_lora
B_name = "{}.lora_A".format(x)
mid_name = None
elif transformers_lora in lora.keys():
A_name = transformers_lora
B_name ="{}.lora_linear_layer.down.weight".format(x)
@@ -60,7 +96,7 @@ def load_lora(lora, to_load):
if mid_name is not None and mid_name in lora.keys():
mid = lora[mid_name]
loaded_keys.add(mid_name)
patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale))
patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale, reshape))
loaded_keys.add(A_name)
loaded_keys.add(B_name)
@@ -171,17 +207,28 @@ def load_lora(lora, to_load):
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,))
loaded_keys.add(diff_bias_name)
for x in lora.keys():
if x not in loaded_keys:
logging.warning("lora key not loaded: {}".format(x))
set_weight_name = "{}.set_weight".format(x)
set_weight = lora.get(set_weight_name, None)
if set_weight is not None:
patch_dict[to_load[x]] = ("set", (set_weight,))
loaded_keys.add(set_weight_name)
if log_missing:
for x in lora.keys():
if x not in loaded_keys:
logging.warning("lora key not loaded: {}".format(x))
return patch_dict
def model_lora_keys_clip(model, key_map={}):
sdk = model.state_dict().keys()
for k in sdk:
if k.endswith(".weight"):
key_map["text_encoders.{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
clip_l_present = False
clip_g_present = False
for b in range(32): #TODO: clean up
for c in LORA_CLIP_MAP:
k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
@@ -205,6 +252,7 @@ def model_lora_keys_clip(model, key_map={}):
k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
if k in sdk:
clip_g_present = True
if clip_l_present:
lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
key_map[lora_key] = k
@@ -218,11 +266,25 @@ def model_lora_keys_clip(model, key_map={}):
lora_key = "lora_prior_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #cascade lora: TODO put lora key prefix in the model config
key_map[lora_key] = k
for k in sdk: #OneTrainer SD3 lora
if k.startswith("t5xxl.transformer.") and k.endswith(".weight"):
l_key = k[len("t5xxl.transformer."):-len(".weight")]
lora_key = "lora_te3_{}".format(l_key.replace(".", "_"))
key_map[lora_key] = k
for k in sdk:
if k.endswith(".weight"):
if k.startswith("t5xxl.transformer."):#OneTrainer SD3 and Flux lora
l_key = k[len("t5xxl.transformer."):-len(".weight")]
t5_index = 1
if clip_g_present:
t5_index += 1
if clip_l_present:
t5_index += 1
if t5_index == 2:
key_map["lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))] = k #OneTrainer Flux
t5_index += 1
key_map["lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))] = k
elif k.startswith("hydit_clip.transformer.bert."): #HunyuanDiT Lora
l_key = k[len("hydit_clip.transformer.bert."):-len(".weight")]
lora_key = "lora_te1_{}".format(l_key.replace(".", "_"))
key_map[lora_key] = k
k = "clip_g.transformer.text_projection.weight"
if k in sdk:
@@ -241,10 +303,14 @@ def model_lora_keys_unet(model, key_map={}):
sdk = sd.keys()
for k in sdk:
if k.startswith("diffusion_model.") and k.endswith(".weight"):
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
key_map["lora_unet_{}".format(key_lora)] = k
key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config
if k.startswith("diffusion_model."):
if k.endswith(".weight"):
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
key_map["lora_unet_{}".format(key_lora)] = k
key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config
key_map["{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
else:
key_map["{}".format(k)] = k #generic lora format for not .weight without any weird key names
diffusers_keys = comfy.utils.unet_to_diffusers(model.model_config.unet_config)
for k in diffusers_keys:
@@ -252,6 +318,7 @@ def model_lora_keys_unet(model, key_map={}):
unet_key = "diffusion_model.{}".format(diffusers_keys[k])
key_lora = k[:-len(".weight")].replace(".", "_")
key_map["lora_unet_{}".format(key_lora)] = unet_key
key_map["lycoris_{}".format(key_lora)] = unet_key #simpletuner lycoris format
diffusers_lora_prefix = ["", "unet."]
for p in diffusers_lora_prefix:
@@ -274,6 +341,10 @@ def model_lora_keys_unet(model, key_map={}):
key_lora = "lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_")) #OneTrainer lora
key_map[key_lora] = to
key_lora = "lycoris_{}".format(k[:-len(".weight")].replace(".", "_")) #simpletuner lycoris format
key_map[key_lora] = to
if isinstance(model, comfy.model_base.AuraFlow): #Diffusers lora AuraFlow
diffusers_keys = comfy.utils.auraflow_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
for k in diffusers_keys:
@@ -293,7 +364,286 @@ def model_lora_keys_unet(model, key_map={}):
for k in diffusers_keys:
if k.endswith(".weight"):
to = diffusers_keys[k]
key_lora = "transformer.{}".format(k[:-len(".weight")]) #simpletrainer and probably regular diffusers flux lora format
key_map[key_lora] = to
key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format
key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer
if isinstance(model, comfy.model_base.GenmoMochi):
for k in sdk:
if k.startswith("diffusion_model.") and k.endswith(".weight"): #Official Mochi lora format
key_lora = k[len("diffusion_model."):-len(".weight")]
key_map["{}".format(key_lora)] = k
if isinstance(model, comfy.model_base.HunyuanVideo):
for k in sdk:
if k.startswith("diffusion_model.") and k.endswith(".weight"):
# diffusion-pipe lora format
key_lora = k
key_lora = key_lora.replace("_mod.lin.", "_mod.linear.").replace("_attn.qkv.", "_attn_qkv.").replace("_attn.proj.", "_attn_proj.")
key_lora = key_lora.replace("mlp.0.", "mlp.fc1.").replace("mlp.2.", "mlp.fc2.")
key_lora = key_lora.replace(".modulation.lin.", ".modulation.linear.")
key_lora = key_lora[len("diffusion_model."):-len(".weight")]
key_map["transformer.{}".format(key_lora)] = k
key_map["diffusion_model.{}".format(key_lora)] = k # Old loras
return key_map
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function):
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
lora_diff *= alpha
weight_calc = weight + function(lora_diff).type(weight.dtype)
weight_norm = (
weight_calc.transpose(0, 1)
.reshape(weight_calc.shape[1], -1)
.norm(dim=1, keepdim=True)
.reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
.transpose(0, 1)
)
weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
if strength != 1.0:
weight_calc -= weight
weight += strength * (weight_calc)
else:
weight[:] = weight_calc
return weight
def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:
"""
Pad a tensor to a new shape with zeros.
Args:
tensor (torch.Tensor): The original tensor to be padded.
new_shape (List[int]): The desired shape of the padded tensor.
Returns:
torch.Tensor: A new tensor padded with zeros to the specified shape.
Note:
If the new shape is smaller than the original tensor in any dimension,
the original tensor will be truncated in that dimension.
"""
if any([new_shape[i] < tensor.shape[i] for i in range(len(new_shape))]):
raise ValueError("The new shape must be larger than the original tensor in all dimensions")
if len(new_shape) != len(tensor.shape):
raise ValueError("The new shape must have the same number of dimensions as the original tensor")
# Create a new tensor filled with zeros
padded_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device)
# Create slicing tuples for both tensors
orig_slices = tuple(slice(0, dim) for dim in tensor.shape)
new_slices = tuple(slice(0, dim) for dim in tensor.shape)
# Copy the original tensor into the new tensor
padded_tensor[new_slices] = tensor[orig_slices]
return padded_tensor
def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, original_weights=None):
for p in patches:
strength = p[0]
v = p[1]
strength_model = p[2]
offset = p[3]
function = p[4]
if function is None:
function = lambda a: a
old_weight = None
if offset is not None:
old_weight = weight
weight = weight.narrow(offset[0], offset[1], offset[2])
if strength_model != 1.0:
weight *= strength_model
if isinstance(v, list):
v = (calculate_weight(v[1:], v[0][1](comfy.model_management.cast_to_device(v[0][0], weight.device, intermediate_dtype, copy=True), inplace=True), key, intermediate_dtype=intermediate_dtype), )
if len(v) == 1:
patch_type = "diff"
elif len(v) == 2:
patch_type = v[0]
v = v[1]
if patch_type == "diff":
diff: torch.Tensor = v[0]
# An extra flag to pad the weight if the diff's shape is larger than the weight
do_pad_weight = len(v) > 1 and v[1]['pad_weight']
if do_pad_weight and diff.shape != weight.shape:
logging.info("Pad weight {} from {} to shape: {}".format(key, weight.shape, diff.shape))
weight = pad_tensor_to_shape(weight, diff.shape)
if strength != 0.0:
if diff.shape != weight.shape:
logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, diff.shape, weight.shape))
else:
weight += function(strength * comfy.model_management.cast_to_device(diff, weight.device, weight.dtype))
elif patch_type == "set":
weight.copy_(v[0])
elif patch_type == "model_as_lora":
target_weight: torch.Tensor = v[0]
diff_weight = comfy.model_management.cast_to_device(target_weight, weight.device, intermediate_dtype) - \
comfy.model_management.cast_to_device(original_weights[key][0][0], weight.device, intermediate_dtype)
weight += function(strength * comfy.model_management.cast_to_device(diff_weight, weight.device, weight.dtype))
elif patch_type == "lora": #lora/locon
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype)
dora_scale = v[4]
reshape = v[5]
if reshape is not None:
weight = pad_tensor_to_shape(weight, reshape)
if v[2] is not None:
alpha = v[2] / mat2.shape[0]
else:
alpha = 1.0
if v[3] is not None:
#locon mid weights, hopefully the math is fine because I didn't properly test it
mat3 = comfy.model_management.cast_to_device(v[3], weight.device, intermediate_dtype)
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
try:
lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
if dora_scale is not None:
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "lokr":
w1 = v[0]
w2 = v[1]
w1_a = v[3]
w1_b = v[4]
w2_a = v[5]
w2_b = v[6]
t2 = v[7]
dora_scale = v[8]
dim = None
if w1 is None:
dim = w1_b.shape[0]
w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype))
else:
w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype)
if w2 is None:
dim = w2_b.shape[0]
if t2 is None:
w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype))
else:
w2 = torch.einsum('i j k l, j r, i p -> p r k l',
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype))
else:
w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype)
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
if v[2] is not None and dim is not None:
alpha = v[2] / dim
else:
alpha = 1.0
try:
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
if dora_scale is not None:
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "loha":
w1a = v[0]
w1b = v[1]
if v[2] is not None:
alpha = v[2] / w1b.shape[0]
else:
alpha = 1.0
w2a = v[3]
w2b = v[4]
dora_scale = v[7]
if v[5] is not None: #cp decomposition
t1 = v[5]
t2 = v[6]
m1 = torch.einsum('i j k l, j r, i p -> p r k l',
comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype))
m2 = torch.einsum('i j k l, j r, i p -> p r k l',
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype))
else:
m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype))
m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype))
try:
lora_diff = (m1 * m2).reshape(weight.shape)
if dora_scale is not None:
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "glora":
dora_scale = v[5]
old_glora = False
if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]:
rank = v[0].shape[0]
old_glora = True
if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]:
if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]:
pass
else:
old_glora = False
rank = v[1].shape[0]
a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
if v[4] is not None:
alpha = v[4] / rank
else:
alpha = 1.0
try:
if old_glora:
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora
else:
if weight.dim() > 2:
lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
else:
lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
lora_diff += torch.mm(b1, b2).reshape(weight.shape)
if dora_scale is not None:
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
else:
logging.warning("patch type not recognized {} {}".format(patch_type, key))
if old_weight is not None:
weight = old_weight
return weight

17
comfy/lora_convert.py Normal file
View File

@@ -0,0 +1,17 @@
import torch
def convert_lora_bfl_control(sd): #BFL loras for Flux
sd_out = {}
for k in sd:
k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.scale.set_weight"))
sd_out[k_to] = sd[k]
sd_out["diffusion_model.img_in.reshape_weight"] = torch.tensor([sd["img_in.lora_B.weight"].shape[0], sd["img_in.lora_A.weight"].shape[1]])
return sd_out
def convert_lora(sd):
if "img_in.lora_A.weight" in sd and "single_blocks.0.norm.key_norm.scale" in sd:
return convert_lora_bfl_control(sd)
return sd

Some files were not shown because too many files have changed in this diff Show More