Compare commits

...

102 Commits

Author SHA1 Message Date
comfyanonymous
55add50220 Bump ComfyUI version to v0.3.12 2025-01-16 18:11:57 -05:00
comfyanonymous
0aa2368e46 Fix some cosmos fp8 issues. 2025-01-16 17:45:37 -05:00
comfyanonymous
cca96a85ae Fix cosmos VAE failing with videos longer than 121 frames. 2025-01-16 16:30:06 -05:00
comfyanonymous
619b8cde74 Bump ComfyUI version to 0.3.11 2025-01-16 14:54:48 -05:00
comfyanonymous
31831e6ef1 Code refactor. 2025-01-16 07:23:54 -05:00
comfyanonymous
88ceb28e20 Tweak hunyuan memory usage factor. 2025-01-16 06:31:03 -05:00
comfyanonymous
23289a6a5c Clean up some debug lines. 2025-01-16 04:24:39 -05:00
comfyanonymous
9d8b6c1f46 More accurate memory estimation for cosmos and hunyuan video. 2025-01-16 03:48:40 -05:00
comfyanonymous
6320d05696 Slightly lower hunyuan video memory usage. 2025-01-16 00:23:01 -05:00
comfyanonymous
25683b5b02 Lower cosmos diffusion model memory usage. 2025-01-15 23:46:42 -05:00
comfyanonymous
4758fb64b9 Lower cosmos VAE memory usage by a bit. 2025-01-15 22:57:52 -05:00
comfyanonymous
008761166f Optimize first attention block in cosmos VAE. 2025-01-15 21:48:46 -05:00
comfyanonymous
bfd5dfd611 3.13 doesn't work yet. 2025-01-15 20:32:44 -05:00
comfyanonymous
55ade36d01 Remove python 3.8 from test-build workflow. 2025-01-15 20:24:55 -05:00
comfyanonymous
2e20e399ea Add minimum numpy version to requirements.txt 2025-01-15 20:19:56 -05:00
comfyanonymous
3baf92d120 CosmosImageToVideoLatent batch_size now does something. 2025-01-15 17:19:59 -05:00
comfyanonymous
1709a8441e Use latest python 3.12.8 the portable release. 2025-01-15 14:50:40 -05:00
comfyanonymous
cba58fff0b Remove unsafe embedding load for very old pytorch. 2025-01-15 04:32:23 -05:00
comfyanonymous
2feb8d0b77 Force safe loading of files in torch format on pytorch 2.4+
If this breaks something for you make an issue.
2025-01-15 03:50:27 -05:00
comfyanonymous
5b657f8c15 Allow setting start and end image in CosmosImageToVideoLatent. 2025-01-15 00:41:35 -05:00
catboxanon
2cdbaf5169 Add SetFirstSigma node (#6459)
Useful for models utilizing ztSNR. See: https://arxiv.org/abs/2409.15997
2025-01-14 19:05:45 -05:00
Pam
c78a45685d Rewrite res_multistep sampler and implement res_multistep_cfg_pp sampler. (#6462) 2025-01-14 18:20:06 -05:00
comfyanonymous
3aaabb12d4 Implement Cosmos Image/Video to World (Video) diffusion models.
Use CosmosImageToVideoLatent to set the input image/video.
2025-01-14 05:14:10 -05:00
comfyanonymous
1f1c7b7b56 Remove useless code. 2025-01-13 03:52:37 -05:00
comfyanonymous
90f349f93d Add res_multistep sampler from the cosmos code.
This sampler should work with all models.
2025-01-12 03:10:07 -05:00
Alexander Piskun
b9d9bcba14 fixed a bug where a relative path was not converted to a full path (#6395)
Signed-off-by: bigcat88 <bigcat88@icloud.com>
2025-01-11 19:19:51 -05:00
Chenlei Hu
42086af123 Merge ruff.toml into pyproject.toml (#6431) 2025-01-11 12:52:46 -05:00
Jedrzej Kosinski
6c9bd11fa3 Hooks Part 2 - TransformerOptionsHook and AdditionalModelsHook (#6377)
* Add 'sigmas' to transformer_options so that downstream code can know about the full scope of current sampling run, fix Hook Keyframes' guarantee_steps=1 inconsistent behavior with sampling split across different Sampling nodes/sampling runs by referencing 'sigmas'

* Cleaned up hooks.py, refactored Hook.should_register and add_hook_patches to use target_dict instead of target so that more information can be provided about the current execution environment if needed

* Refactor WrapperHook into TransformerOptionsHook, as there is no need to separate out Wrappers/Callbacks/Patches into different hook types (all affect transformer_options)

* Refactored HookGroup to also store a dictionary of hooks separated by hook_type, modified necessary code to no longer need to manually separate out hooks by hook_type

* In inner_sample, change "sigmas" to "sampler_sigmas" in transformer_options to not conflict with the "sigmas" that will overwrite "sigmas" in _calc_cond_batch

* Refactored 'registered' to be HookGroup instead of a list of Hooks, made AddModelsHook operational and compliant with should_register result, moved TransformerOptionsHook handling out of ModelPatcher.register_all_hook_patches, support patches in TransformerOptionsHook properly by casting any patches/wrappers/hooks to proper device at sample time

* Made hook clone code sane, made clear ObjectPatchHook and SetInjectionsHook are not yet operational

* Fix performance of hooks when hooks are appended via Cond Pair Set Props nodes by properly caching between positive and negative conds, make hook_patches_backup behave as intended (in the case that something pre-registers WeightHooks on the ModelPatcher instead of registering it at sample time)

* Filter only registered hooks on self.conds in CFGGuider.sample

* Make hook_scope functional for TransformerOptionsHook

* removed 4 whitespace lines to satisfy Ruff,

* Add a get_injections function to ModelPatcher

* Made TransformerOptionsHook contribute to registered hooks properly, added some doc strings and removed a so-far unused variable

* Rename AddModelsHooks to AdditionalModelsHook, rename SetInjectionsHook to InjectionsHook (not yet implemented, but at least getting the naming figured out)

* Clean up a typehint
2025-01-11 12:20:23 -05:00
comfyanonymous
ee8a7ab69d Fast latent preview for Cosmos. 2025-01-11 04:41:24 -05:00
Chenlei Hu
9c773a241b Add pyproject.toml (#6386)
* Add pyproject.toml

* doc

* Static version file

* Add github action to sync version.py

* Change trigger to PR

* Fix commit

* Grant pr write permission

* nit

* nit

* Don't run on fork PRs

* Rename version.py to comfyui_version.py
2025-01-11 03:09:25 -05:00
comfyanonymous
adea2beb5c Add edm option to ModelSamplingContinuousEDM for Cosmos.
You can now use this node with "edm" selected to control the sigma_max and
sigma_min of the Cosmos model sampling.
2025-01-11 02:18:42 -05:00
comfyanonymous
2ff3104f70 WIP support for Nvidia Cosmos 7B and 14B text to world (video) models. 2025-01-10 09:14:16 -05:00
comfyanonymous
129d8908f7 Add argument to skip the output reshaping in the attention functions. 2025-01-10 06:27:37 -05:00
comfyanonymous
ff838657fa Cleaner handling of attention mask in ltxv model code. 2025-01-09 07:12:03 -05:00
comfyanonymous
2307ff6746 Improve some logging messages. 2025-01-08 19:05:22 -05:00
comfyanonymous
d0f3752e33 Properly calculate inner dim for t5 model.
This is required to support some different types of t5 models.
2025-01-07 17:33:03 -05:00
Dr.Lt.Data
c515bdf371 fixed: robust loading comfy.settings.json (#6383)
https://github.com/comfyanonymous/ComfyUI/issues/6371
2025-01-07 16:03:56 -05:00
comfyanonymous
4209edf48d Make a few more samplers deterministic. 2025-01-07 02:12:32 -05:00
Chenlei Hu
d055325783 Document get_attr and get_model_object (#6357)
* Document get_attr and get_model_object

* Update model_patcher.py

* Update model_patcher.py

* Update model_patcher.py
2025-01-06 20:12:22 -05:00
Chenlei Hu
eeab420c70 Update frontend to v1.6.18 (#6368) 2025-01-06 18:42:45 -05:00
comfyanonymous
916d1e14a9 Make ancestral samplers more deterministic. 2025-01-06 03:04:32 -05:00
Jedrzej Kosinski
c496e53519 In inner_sample, change "sigmas" to "sampler_sigmas" in transformer_options to not conflict with the "sigmas" that will overwrite "sigmas" in _calc_cond_batch (#6360) 2025-01-06 01:36:47 -05:00
Yoland Yan
7da85fac3f Update CODEOWNERS (#6338)
Adding yoland and robin to web dir
2025-01-05 04:33:49 -05:00
Chenlei Hu
b65b83af6f Add update-frontend github action (#6336)
* Add update-frontend github action

* Update secrets

* nit
2025-01-05 04:32:11 -05:00
comfyanonymous
c8a3492c22 Make the device an optional parameter in the clip loaders. 2025-01-05 04:29:36 -05:00
comfyanonymous
5cbf79787f Add advanced device option to clip loader nodes.
Right click the "Load CLIP" or DualCLIPLoader node and "Show Advanced".
2025-01-05 01:46:11 -05:00
comfyanonymous
d45ebb63f6 Remove old unused function. 2025-01-04 07:20:54 -05:00
Chenlei Hu
caa6476a69 Update web content to release v1.6.17 (#6337)
* Update web content to release v1.6.17

* Remove js maps
2025-01-03 16:22:08 -05:00
Chenlei Hu
45671cda0b Update web content to release v1.6.16 (#6335)
* Update web content to release v1.6.16
2025-01-03 13:56:46 -05:00
comfyanonymous
8f29664057 Change defaults in nightly package workflow. 2025-01-03 12:12:17 -05:00
Chenlei Hu
0b9839ef43 Update web content to release v1.6.15 (#6324) 2025-01-02 19:20:48 -05:00
Terry Jia
953693b137 add fov and mask for load 3d node (#6308)
* add fov and mask for load 3d node

* some comments
2025-01-02 19:20:34 -05:00
Chenlei Hu
a39ea87bca Update web content to release v1.6.14 (#6312) 2025-01-02 16:18:54 -05:00
comfyanonymous
9e9c8a1c64 Clear cache as often on AMD as Nvidia.
I think the issue this was working around has been solved.

If you notice that this change slows things down or causes stutters on
your AMD GPU with ROCm on Linux please report it.
2025-01-02 08:44:16 -05:00
Andrew Kvochko
0f11d60afb Fix temporal tiling for decoder, remove redundant tiles. (#6306)
This commit fixes the temporal tile size calculation, and removes
a redundant tile at the end of the range when its elements are
completely covered by the previous tile.

Co-authored-by: Andrew Kvochko <a.kvochko@lightricks.com>
2025-01-01 16:29:01 -05:00
comfyanonymous
79eea51a1d Fix and enforce all ruff W rules. 2025-01-01 03:08:33 -05:00
blepping
c0338a46a4 Fix unknown sampler error handling in calculate_sigmas function (#6280)
Modernize calculate_sigmas function
2024-12-31 17:33:50 -05:00
Jedrzej Kosinski
1c99734e5a Add missing model_options param (#6296) 2024-12-31 14:46:55 -05:00
filtered
67758f50f3 Fix custom node type-hinting examples (#6281)
* Fix import in comfy_types doc / sample

* Clarify docstring
2024-12-31 03:41:09 -05:00
Alexander Piskun
02eef72bf5 fixed "verbose" argument (#6289)
Signed-off-by: bigcat88 <bigcat88@icloud.com>
2024-12-31 03:27:09 -05:00
comfyanonymous
b7572b2f87 Fix and enforce no trailing whitespace. 2024-12-31 03:16:37 -05:00
blepping
a90aafafc1 Add kl_optimal scheduler (#6206)
* Add kl_optimal scheduler

* Rename kl_optimal_schedule to kl_optimal_scheduler to be more consistent
2024-12-30 05:09:38 -05:00
comfyanonymous
d9b7cfac7e Fix and enforce new lines at the end of files. 2024-12-30 04:14:59 -05:00
Jedrzej Kosinski
3507870535 Add 'sigmas' to transformer_options so that downstream code can know about the full scope of current sampling run, fix Hook Keyframes' guarantee_steps=1 inconsistent behavior with sampling split across different Sampling nodes/sampling runs by referencing 'sigmas' (#6273) 2024-12-30 03:42:49 -05:00
catboxanon
82ecb02c1e Remove duplicate calls to INPUT_TYPES (#6249) 2024-12-29 20:06:49 -05:00
comfyanonymous
a618f768e0 Auto reshape 2d to 3d latent for single image generation on video model. 2024-12-29 02:26:49 -05:00
comfyanonymous
e1dec3c792 Fix formatting. 2024-12-28 05:33:17 -05:00
Zoltán Dócs
96697c4bc5 serve workflow templates from custom_nodes (#6193)
* add GET /workflow_templates

* serve workflow templates from custom_nodes

* refactor into custom_node_manager, add test

* remove unused import

* revert changes in folder_paths

* Remove trailing whitespace.

* account for multiple custom_nodes paths
2024-12-28 05:30:04 -05:00
comfyanonymous
b504bd606d Add ruff rule for empty line with trailing whitespace. 2024-12-28 05:23:08 -05:00
comfyanonymous
d170292594 Remove some trailing white space. 2024-12-27 18:02:30 -05:00
filtered
9cfd185676 Add option to log non-error output to stdout (#6243)
* nit

* Add option to log non-error output to stdout

- No change to default behaviour
- Adds CLI argument: --log-stdout
- With this arg present, any logging of a level below logging.ERROR will be sent to stdout instead of stderr
2024-12-27 14:40:05 -05:00
comfyanonymous
4b5bcd8ac4 Closer memory estimation for hunyuan dit model. 2024-12-27 07:37:00 -05:00
comfyanonymous
ceb50b2cbf Closer memory estimation for pixart models. 2024-12-27 07:30:09 -05:00
comfyanonymous
160ca08138 Use python 3.9 in launch test instead of 3.8
Fix ruff check.
2024-12-26 20:05:54 -05:00
Huazhong Ji
c4bfdba330 Support ascend npu (#5436)
* support ascend npu

Co-authored-by: YukMingLaw <lymmm2@163.com>
Co-authored-by: starmountain1997 <guozr1997@hotmail.com>
Co-authored-by: Ginray <ginray0215@gmail.com>
2024-12-26 19:36:50 -05:00
comfyanonymous
ee9547ba31 Improve temporal VAE Encode (Tiled) math. 2024-12-26 07:18:49 -05:00
comfyanonymous
19a64d6291 Cleanup some mac related code. 2024-12-25 05:32:51 -05:00
comfyanonymous
b486885e08 Disable bfloat16 on older mac. 2024-12-25 05:18:50 -05:00
comfyanonymous
0229228f3f Clean up the VAE dtypes code. 2024-12-25 04:50:34 -05:00
comfyanonymous
1ed75ab30e Update nightly pytorch instructions in readme for nvidia. 2024-12-25 03:29:03 -05:00
comfyanonymous
99a1fb6027 Make fast fp8 take a bit less peak memory. 2024-12-24 18:05:19 -05:00
comfyanonymous
73e04987f7 Prevent black images in VAE Decode (Tiled) node.
Overlap should be minimum 1 with tiling 2 for tiled temporal VAE decoding.
2024-12-24 07:36:30 -05:00
comfyanonymous
5388df784a Add temporal tiling to VAE Encode (Tiled) node. 2024-12-24 07:10:09 -05:00
Alexander Piskun
26e0ba8f8c Enable External Event Loop Integration for ComfyUI [refactor] (#6114)
* Refactor main.py to support external event loop integration

* added optional "asyncio_loop" argument to allow using existing event loop

---------

Signed-off-by: bigcat88 <bigcat88@icloud.com>
2024-12-24 06:38:52 -05:00
comfyanonymous
bc6dac4327 Add temporal tiling to VAE Decode (Tiled) node.
You can now do tiled VAE decoding on the temporal direction for videos.
2024-12-23 20:03:37 -05:00
Chenlei Hu
f18ebbd316 Use raw dir name to serve static web content (#6107) 2024-12-23 03:29:42 -05:00
comfyanonymous
15564688ed Add a try except block so if torch version is weird it won't crash. 2024-12-23 03:22:48 -05:00
Simon Lui
c6b9c11ef6 Add oneAPI device selector for xpu and some other changes. (#6112)
* Add oneAPI device selector and some other minor changes.

* Fix device selector variable name.

* Flip minor version check sign.

* Undo changes to README.md.
2024-12-23 03:18:32 -05:00
comfyanonymous
e44d0ac7f7 Make --novram completely offload weights.
This flag is mainly used for testing the weight offloading, it shouldn't
actually be used in practice.

Remove useless import.
2024-12-23 01:51:08 -05:00
comfyanonymous
56bc64f351 Comment out some useless code. 2024-12-22 23:51:14 -05:00
zhangp365
f7d83b72e0 fixed a bug in ldm/pixart/blocks.py (#6158) 2024-12-22 23:44:20 -05:00
comfyanonymous
80f07952d2 Fix lowvram issue with ltxv vae. 2024-12-22 23:20:17 -05:00
comfyanonymous
57f330caf9 Relax minimum ratio of weights loaded in memory on nvidia.
This should make it possible to do higher res images/longer videos by
further offloading weights to CPU memory.

Please report an issue if this slows down things on your system.
2024-12-22 03:06:37 -05:00
comfyanonymous
601ff9e3db Add that Hunyuan Video and Pixart are supported to readme.
Clean up the supported models part of the readme.
2024-12-21 11:31:39 -05:00
TechnoByte
341667c4d5 remove minimum step count for AYS (#6137)
The 10 step minimum for the AYS scheduler is pointless, it works well at lower steps, like 8 steps, or even 4 steps.

For example with LCM or DMD2.

Example here: https://i.ibb.co/56CSPMj/image.png
2024-12-21 10:05:09 -05:00
Qiacheng Li
1419dee915 Update README.md for Intel GPUs (#6069) 2024-12-20 18:04:03 -05:00
comfyanonymous
da13b6b827 Get rid of meshgrid warning. 2024-12-20 18:02:12 -05:00
comfyanonymous
c86cd58573 Remove useless code. 2024-12-20 17:50:03 -05:00
comfyanonymous
b5fe39211a Remove some useless code. 2024-12-20 17:43:50 -05:00
comfyanonymous
e946667216 Some fixes/cleanups to pixart code.
Commented out the masking related code because it is never used in this
implementation.
2024-12-20 17:10:52 -05:00
Chenlei Hu
d7969cb070 Replace print with logging (#6138)
* Replace print with logging

* nit

* nit

* nit

* nit

* nit

* nit
2024-12-20 16:24:55 -05:00
City
bddb02660c Add PixArt model support (#6055)
* PixArt initial version

* PixArt Diffusers convert logic

* pos_emb and interpolation logic

* Reduce  duplicate code

* Formatting

* Use optimized attention

* Edit empty token logic

* Basic PixArt LoRA support

* Fix aspect ratio logic

* PixArtAlpha text encode with conds

* Use same detection key logic for PixArt diffusers
2024-12-20 15:25:00 -05:00
169 changed files with 205304 additions and 158626 deletions

View File

@@ -28,7 +28,7 @@ def pull(repo, remote_name='origin', branch='master'):
if repo.index.conflicts is not None:
for conflict in repo.index.conflicts:
print('Conflicts found in:', conflict[0].path)
print('Conflicts found in:', conflict[0].path) # noqa: T201
raise AssertionError('Conflicts, ahhhhh!!')
user = repo.default_signature
@@ -49,18 +49,18 @@ repo_path = str(sys.argv[1])
repo = pygit2.Repository(repo_path)
ident = pygit2.Signature('comfyui', 'comfy@ui')
try:
print("stashing current changes")
print("stashing current changes") # noqa: T201
repo.stash(ident)
except KeyError:
print("nothing to stash")
print("nothing to stash") # noqa: T201
backup_branch_name = 'backup_branch_{}'.format(datetime.today().strftime('%Y-%m-%d_%H_%M_%S'))
print("creating backup branch: {}".format(backup_branch_name))
print("creating backup branch: {}".format(backup_branch_name)) # noqa: T201
try:
repo.branches.local.create(backup_branch_name, repo.head.peel())
except:
pass
print("checking out master branch")
print("checking out master branch") # noqa: T201
branch = repo.lookup_branch('master')
if branch is None:
ref = repo.lookup_reference('refs/remotes/origin/master')
@@ -72,7 +72,7 @@ else:
ref = repo.lookup_reference(branch.name)
repo.checkout(ref)
print("pulling latest changes")
print("pulling latest changes") # noqa: T201
pull(repo)
if "--stable" in sys.argv:
@@ -94,7 +94,7 @@ if "--stable" in sys.argv:
if latest_tag is not None:
repo.checkout(latest_tag)
print("Done!")
print("Done!") # noqa: T201
self_update = True
if len(sys.argv) > 2:

View File

@@ -22,7 +22,7 @@ on:
description: 'Python patch version'
required: true
type: string
default: "7"
default: "8"
jobs:

View File

@@ -18,7 +18,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11"]
python-version: ["3.9", "3.10", "3.11", "3.12"]
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
@@ -28,4 +28,4 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install -r requirements.txt

View File

@@ -17,7 +17,7 @@ jobs:
path: "ComfyUI"
- uses: actions/setup-python@v4
with:
python-version: '3.8'
python-version: '3.9'
- name: Install requirements
run: |
python -m pip install --upgrade pip

58
.github/workflows/update-frontend.yml vendored Normal file
View File

@@ -0,0 +1,58 @@
name: Update Frontend Release
on:
workflow_dispatch:
inputs:
version:
description: "Frontend version to update to (e.g., 1.0.0)"
required: true
type: string
jobs:
update-frontend:
runs-on: ubuntu-latest
permissions:
contents: write
pull-requests: write
steps:
- name: Checkout ComfyUI
uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: '3.10'
- name: Install requirements
run: |
python -m pip install --upgrade pip
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install -r requirements.txt
pip install wait-for-it
# Frontend asset will be downloaded to ComfyUI/web_custom_versions/Comfy-Org_ComfyUI_frontend/{version}
- name: Start ComfyUI server
run: |
python main.py --cpu --front-end-version Comfy-Org/ComfyUI_frontend@${{ github.event.inputs.version }} 2>&1 | tee console_output.log &
wait-for-it --service 127.0.0.1:8188 -t 30
- name: Configure Git
run: |
git config --global user.name "GitHub Action"
git config --global user.email "action@github.com"
# Replace existing frontend content with the new version and remove .js.map files
# See https://github.com/Comfy-Org/ComfyUI_frontend/issues/2145 for why we remove .js.map files
- name: Update frontend content
run: |
rm -rf web/
cp -r web_custom_versions/Comfy-Org_ComfyUI_frontend/${{ github.event.inputs.version }} web/
rm web/**/*.js.map
- name: Create Pull Request
uses: peter-evans/create-pull-request@v7
with:
token: ${{ secrets.PR_BOT_PAT }}
commit-message: "Update frontend to v${{ github.event.inputs.version }}"
title: "Frontend Update: v${{ github.event.inputs.version }}"
body: |
Automated PR to update frontend content to version ${{ github.event.inputs.version }}
This PR was created automatically by the frontend update workflow.
branch: release-${{ github.event.inputs.version }}
base: master
labels: Frontend,dependencies

58
.github/workflows/update-version.yml vendored Normal file
View File

@@ -0,0 +1,58 @@
name: Update Version File
on:
pull_request:
paths:
- "pyproject.toml"
branches:
- master
jobs:
update-version:
runs-on: ubuntu-latest
# Don't run on fork PRs
if: github.event.pull_request.head.repo.full_name == github.repository
permissions:
pull-requests: write
contents: write
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.11"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
- name: Update comfyui_version.py
run: |
# Read version from pyproject.toml and update comfyui_version.py
python -c '
import tomllib
# Read version from pyproject.toml
with open("pyproject.toml", "rb") as f:
config = tomllib.load(f)
version = config["project"]["version"]
# Write version to comfyui_version.py
with open("comfyui_version.py", "w") as f:
f.write("# This file is automatically generated by the build process when version is\n")
f.write("# updated in pyproject.toml.\n")
f.write(f"__version__ = \"{version}\"\n")
'
- name: Commit changes
run: |
git config --local user.name "github-actions"
git config --local user.email "github-actions@github.com"
git fetch origin ${{ github.head_ref }}
git checkout -B ${{ github.head_ref }} origin/${{ github.head_ref }}
git add comfyui_version.py
git diff --quiet && git diff --staged --quiet || git commit -m "chore: Update comfyui_version.py to match pyproject.toml"
git push origin HEAD:${{ github.head_ref }}

View File

@@ -29,7 +29,7 @@ on:
description: 'python patch version'
required: true
type: string
default: "7"
default: "8"
# push:
# branches:
# - master

View File

@@ -7,19 +7,19 @@ on:
description: 'cuda version'
required: true
type: string
default: "124"
default: "126"
python_minor:
description: 'python minor version'
required: true
type: string
default: "12"
default: "13"
python_patch:
description: 'python patch version'
required: true
type: string
default: "4"
default: "1"
# push:
# branches:
# - master

View File

@@ -19,7 +19,7 @@ on:
description: 'python patch version'
required: true
type: string
default: "7"
default: "8"
# push:
# branches:
# - master

View File

@@ -17,7 +17,7 @@
/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
# Frontend assets
/web/ @huchenlei @webfiltered @pythongosssss
/web/ @huchenlei @webfiltered @pythongosssss @yoland68 @robinjhuang
# Extra nodes
/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink

View File

@@ -38,10 +38,21 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
## Features
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/), [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/), [SD3](https://comfyanonymous.github.io/ComfyUI_examples/sd3/) and [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
- [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/)
- [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
- Image Models
- SD1.x, SD2.x,
- [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/)
- [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/)
- [SD3 and SD3.5](https://comfyanonymous.github.io/ComfyUI_examples/sd3/)
- Pixart Alpha and Sigma
- [AuraFlow](https://comfyanonymous.github.io/ComfyUI_examples/aura_flow/)
- [HunyuanDiT](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_dit/)
- [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
- Video Models
- [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
- [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/)
- [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
- [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
- Asynchronous Queue system
- Many optimizations: Only re-executes the parts of the workflow that changes between executions.
- Smart memory management: can automatically run models on GPUs with as low as 1GB vram.
@@ -61,9 +72,6 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
- [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/)
- [Model Merging](https://comfyanonymous.github.io/ComfyUI_examples/model_merging/)
- [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/)
- [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/)
- [AuraFlow](https://comfyanonymous.github.io/ComfyUI_examples/aura_flow/)
- [HunyuanDiT](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_dit/)
- Latent previews with [TAESD](#how-to-show-high-quality-previews)
- Starts up very fast.
- Works fully offline: will never download anything.
@@ -149,6 +157,30 @@ This is the command to install the nightly with ROCm 6.2 which might have some p
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.2.4```
### Intel GPUs (Windows and Linux)
(Option 1) Intel Arc GPU users can install native PyTorch with torch.xpu support using pip (currently available in PyTorch nightly builds). More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html)
1. To install PyTorch nightly, use the following command:
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/xpu```
2. Launch ComfyUI by running `python main.py`
(Option 2) Alternatively, Intel GPUs supported by Intel Extension for PyTorch (IPEX) can leverage IPEX for improved performance.
1. For Intel® Arc™ A-Series Graphics utilizing IPEX, create a conda environment and use the commands below:
```
conda install libuv
pip install torch==2.3.1.post0+cxx11.abi torchvision==0.18.1.post0+cxx11.abi torchaudio==2.3.1.post0+cxx11.abi intel-extension-for-pytorch==2.3.110.post0+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/
```
For other supported Intel GPUs with IPEX, visit [Installation](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu) for more information.
Additional discussion and help can be found [here](https://github.com/comfyanonymous/ComfyUI/discussions/476).
### NVIDIA
Nvidia users should install stable pytorch using this command:
@@ -157,7 +189,7 @@ Nvidia users should install stable pytorch using this command:
This is the command to install pytorch nightly instead which might have performance improvements:
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124```
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu126```
#### Troubleshooting
@@ -177,17 +209,6 @@ After this you should have everything installed and can proceed to running Comfy
### Others:
#### Intel GPUs
Intel GPU support is available for all Intel GPUs supported by Intel's Extension for Pytorch (IPEX) with the support requirements listed in the [Installation](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu) page. Choose your platform and method of install and follow the instructions. The steps are as follows:
1. Start by installing the drivers or kernel listed or newer in the Installation page of IPEX linked above for Windows and Linux if needed.
1. Follow the instructions to install [Intel's oneAPI Basekit](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit-download.html) for your platform.
1. Install the packages for IPEX using the instructions provided in the Installation page for your platform.
1. Follow the [ComfyUI manual installation](#manual-install-windows-linux) instructions for Windows and Linux and run ComfyUI normally as described above after everything is installed.
Additional discussion and help can be found [here](https://github.com/comfyanonymous/ComfyUI/discussions/476).
#### Apple Mac silicon
You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS version.
@@ -203,6 +224,16 @@ You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS ve
```pip install torch-directml``` Then you can launch ComfyUI with: ```python main.py --directml```
#### Ascend NPUs
For models compatible with Ascend Extension for PyTorch (torch_npu). To get started, ensure your environment meets the prerequisites outlined on the [installation](https://ascend.github.io/docs/sources/ascend/quick_install.html) page. Here's a step-by-step guide tailored to your platform and installation method:
1. Begin by installing the recommended or newer kernel version for Linux as specified in the Installation page of torch-npu, if necessary.
2. Proceed with the installation of Ascend Basekit, which includes the driver, firmware, and CANN, following the instructions provided for your specific platform.
3. Next, install the necessary packages for torch-npu by adhering to the platform-specific instructions on the [Installation](https://ascend.github.io/docs/sources/pytorch/install.html#pytorch) page.
4. Finally, adhere to the [ComfyUI manual installation](#manual-install-windows-linux) guide for Linux. Once all components are installed, you can run ComfyUI as described earlier.
# Running
```python main.py```
@@ -308,4 +339,3 @@ This will use a snapshot of the legacy frontend preserved in the [ComfyUI Legacy
### Which GPU should I buy for this?
[See this page for some recommendations](https://github.com/comfyanonymous/ComfyUI/wiki/Which-GPU-should-I-buy-for-ComfyUI)

View File

@@ -10,4 +10,4 @@ class FileService:
if directory_key not in self.allowed_directories:
raise ValueError("Invalid directory key")
directory_path: str = self.allowed_directories[directory_key]
return self.file_system_ops.walk_directory(directory_path)
return self.file_system_ops.walk_directory(directory_path)

View File

@@ -25,10 +25,10 @@ class TerminalService:
def update_size(self):
columns, lines = self.get_terminal_size()
changed = False
if columns != self.cols:
self.cols = columns
changed = True
changed = True
if lines != self.rows:
self.rows = lines
@@ -48,9 +48,9 @@ class TerminalService:
def send_messages(self, entries):
if not len(entries) or not len(self.subscriptions):
return
new_size = self.update_size()
for client_id in self.subscriptions.copy(): # prevent: Set changed size during iteration
if client_id not in self.server.sockets:
# Automatically unsub if the socket has disconnected

View File

@@ -39,4 +39,4 @@ class FileSystemOperations:
"path": relative_path,
"type": "directory"
})
return file_list
return file_list

View File

@@ -1,6 +1,7 @@
import os
import json
from aiohttp import web
import logging
class AppSettings():
@@ -11,8 +12,12 @@ class AppSettings():
file = self.user_manager.get_request_user_filepath(
request, "comfy.settings.json")
if os.path.isfile(file):
with open(file) as f:
return json.load(f)
try:
with open(file) as f:
return json.load(f)
except:
logging.error(f"The user settings file is corrupted: {file}")
return {}
else:
return {}
@@ -51,4 +56,4 @@ class AppSettings():
settings = self.get_settings(request)
settings[setting_id] = await request.json()
self.save_settings(request, settings)
return web.Response(status=200)
return web.Response(status=200)

View File

@@ -0,0 +1,34 @@
from __future__ import annotations
import os
import folder_paths
import glob
from aiohttp import web
class CustomNodeManager:
"""
Placeholder to refactor the custom node management features from ComfyUI-Manager.
Currently it only contains the custom workflow templates feature.
"""
def add_routes(self, routes, webapp, loadedModules):
@routes.get("/workflow_templates")
async def get_workflow_templates(request):
"""Returns a web response that contains the map of custom_nodes names and their associated workflow templates. The ones without templates are omitted."""
files = [
file
for folder in folder_paths.get_folder_paths("custom_nodes")
for file in glob.glob(os.path.join(folder, '*/example_workflows/*.json'))
]
workflow_templates_dict = {} # custom_nodes folder name -> example workflow names
for file in files:
custom_nodes_name = os.path.basename(os.path.dirname(os.path.dirname(file)))
workflow_name = os.path.splitext(os.path.basename(file))[0]
workflow_templates_dict.setdefault(custom_nodes_name, []).append(workflow_name)
return web.json_response(workflow_templates_dict)
# Serve workflow templates from custom nodes.
for module_name, module_dir in loadedModules:
workflows_dir = os.path.join(module_dir, 'example_workflows')
if os.path.exists(workflows_dir):
webapp.add_routes([web.static('/api/workflow_templates/' + module_name, workflows_dir)])

View File

@@ -51,7 +51,7 @@ def on_flush(callback):
if stderr_interceptor is not None:
stderr_interceptor.on_flush(callback)
def setup_logger(log_level: str = 'INFO', capacity: int = 300):
def setup_logger(log_level: str = 'INFO', capacity: int = 300, use_stdout: bool = False):
global logs
if logs:
return
@@ -70,4 +70,15 @@ def setup_logger(log_level: str = 'INFO', capacity: int = 300):
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(logging.Formatter("%(message)s"))
if use_stdout:
# Only errors and critical to stderr
stream_handler.addFilter(lambda record: not record.levelno < logging.ERROR)
# Lesser to stdout
stdout_handler = logging.StreamHandler(sys.stdout)
stdout_handler.setFormatter(logging.Formatter("%(message)s"))
stdout_handler.addFilter(lambda record: record.levelno < logging.ERROR)
logger.addHandler(stdout_handler)
logger.addHandler(stream_handler)

View File

@@ -177,7 +177,7 @@ class ModelFileManager:
safetensors_images = json.loads(safetensors_images)
for image in safetensors_images:
result.append(BytesIO(base64.b64decode(image)))
return result
def __exit__(self, exc_type, exc_value, traceback):

View File

@@ -38,8 +38,8 @@ class UserManager():
if not os.path.exists(user_directory):
os.makedirs(user_directory, exist_ok=True)
if not args.multi_user:
print("****** User settings have been changed to be stored on the server instead of browser storage. ******")
print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")
logging.warning("****** User settings have been changed to be stored on the server instead of browser storage. ******")
logging.warning("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")
if args.multi_user:
if os.path.isfile(self.get_users_file()):

View File

@@ -160,7 +160,6 @@ class ControlNet(nn.Module):
if isinstance(self.num_classes, int):
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
elif self.num_classes == "continuous":
print("setting up linear c_adm embedding layer")
self.label_emb = nn.Linear(1, time_embed_dim)
elif self.num_classes == "sequential":
assert adm_in_channels is not None

View File

@@ -84,7 +84,8 @@ parser.add_argument("--force-channels-last", action="store_true", help="Force ch
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize when loading models with Intel GPUs.")
parser.add_argument("--oneapi-device-selector", type=str, default=None, metavar="SELECTOR_STRING", help="Sets the oneAPI device(s) this instance will use.")
parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize default when loading models with Intel's Extension for Pytorch.")
class LatentPreviewMethod(enum.Enum):
NoPreviews = "none"
@@ -121,7 +122,7 @@ vram_group.add_argument("--lowvram", action="store_true", help="Split the unet i
vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reverved depending on your OS.")
parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.")
parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
@@ -140,6 +141,7 @@ parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Dis
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
parser.add_argument("--log-stdout", action="store_true", help="Send normal process output to stdout instead of stderr (default).")
# The default built-in provider hosted under web/
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"

View File

@@ -5,7 +5,7 @@ This module provides type hinting and concrete convenience types for node develo
If cloned to the custom_nodes directory of ComfyUI, types can be imported using:
```python
from comfy_types import IO, ComfyNodeABC, CheckLazyMixin
from comfy.comfy_types import IO, ComfyNodeABC, CheckLazyMixin
class ExampleNode(ComfyNodeABC):
@classmethod

View File

@@ -1,12 +1,12 @@
from comfy_types import IO, ComfyNodeABC, InputTypeDict
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
from inspect import cleandoc
class ExampleNode(ComfyNodeABC):
"""An example node that just adds 1 to an input integer.
* Requires an IDE configured with analysis paths etc to be worth looking at.
* Not intended for use in ComfyUI.
* Requires a modern IDE to provide any benefit (detail: an IDE configured with analysis paths etc).
* This node is intended as an example for developers only.
"""
DESCRIPTION = cleandoc(__doc__)

View File

@@ -120,7 +120,7 @@ class ControlBase:
if self.previous_controlnet is not None:
out += self.previous_controlnet.get_models()
return out
def get_extra_hooks(self):
out = []
if self.extra_hooks is not None:

View File

@@ -2,6 +2,7 @@
import torch
import math
import logging
from tqdm.auto import trange
@@ -79,7 +80,7 @@ class NoiseScheduleVP:
'linear' or 'cosine' for continuous-time DPMs.
Returns:
A wrapper object of the forward SDE (VP type).
===============================================================
Example:
@@ -207,7 +208,7 @@ def model_wrapper(
arXiv preprint arXiv:2202.00512 (2022).
[2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
arXiv preprint arXiv:2210.02303 (2022).
4. "score": marginal score function. (Trained by denoising score matching).
Note that the score function and the noise prediction model follows a simple relationship:
```
@@ -225,7 +226,7 @@ def model_wrapper(
The input `model` has the following format:
``
model(x, t_input, **model_kwargs) -> noise | x_start | v | score
``
``
The input `classifier_fn` has the following format:
``
@@ -239,12 +240,12 @@ def model_wrapper(
The input `model` has the following format:
``
model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
``
``
And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
[4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
arXiv preprint arXiv:2207.12598 (2022).
The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
or continuous-time labels (i.e. epsilon to T).
@@ -253,7 +254,7 @@ def model_wrapper(
``
def model_fn(x, t_continuous) -> noise:
t_input = get_model_input_time(t_continuous)
return noise_pred(model, x, t_input, **model_kwargs)
return noise_pred(model, x, t_input, **model_kwargs)
``
where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
@@ -358,7 +359,7 @@ class UniPC:
max_val=1.,
variant='bh1',
):
"""Construct a UniPC.
"""Construct a UniPC.
We support both data_prediction and noise_prediction.
"""
@@ -371,7 +372,7 @@ class UniPC:
def dynamic_thresholding_fn(self, x0, t=None):
"""
The dynamic thresholding method.
The dynamic thresholding method.
"""
dims = x0.dim()
p = self.dynamic_thresholding_ratio
@@ -403,7 +404,7 @@ class UniPC:
def model_fn(self, x, t):
"""
Convert the model to the noise prediction model or the data prediction model.
Convert the model to the noise prediction model or the data prediction model.
"""
if self.predict_x0:
return self.data_prediction_fn(x, t)
@@ -460,7 +461,7 @@ class UniPC:
def denoise_to_zero_fn(self, x, s):
"""
Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization.
Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization.
"""
return self.data_prediction_fn(x, s)
@@ -474,7 +475,7 @@ class UniPC:
return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True):
print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
logging.info(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
ns = self.noise_schedule
assert order <= len(model_prev_list)
@@ -509,7 +510,7 @@ class UniPC:
col = torch.ones_like(rks)
for k in range(1, K + 1):
C.append(col)
col = col * rks / (k + 1)
col = col * rks / (k + 1)
C = torch.stack(C, dim=1)
if len(D1s) > 0:
@@ -518,7 +519,6 @@ class UniPC:
A_p = C_inv_p
if use_corrector:
print('using corrector')
C_inv = torch.linalg.inv(C)
A_c = C_inv
@@ -621,12 +621,12 @@ class UniPC:
B_h = torch.expm1(hh)
else:
raise NotImplementedError()
for i in range(1, order + 1):
R.append(torch.pow(rks, i - 1))
b.append(h_phi_k * factorial_i / B_h)
factorial_i *= (i + 1)
h_phi_k = h_phi_k / hh - 1 / factorial_i
h_phi_k = h_phi_k / hh - 1 / factorial_i
R = torch.stack(R)
b = torch.tensor(b, device=x.device)
@@ -870,4 +870,4 @@ def sample_unipc(model, noise, sigmas, extra_args=None, callback=None, disable=F
return x
def sample_unipc_bh2(model, noise, sigmas, extra_args=None, callback=None, disable=False):
return sample_unipc(model, noise, sigmas, extra_args, callback, disable, variant='bh2')
return sample_unipc(model, noise, sigmas, extra_args, callback, disable, variant='bh2')

View File

@@ -5,6 +5,7 @@ import math
import torch
import numpy as np
import itertools
import logging
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher, PatcherInjection
@@ -15,130 +16,171 @@ import comfy.model_management
import comfy.patcher_extension
from node_helpers import conditioning_set_values
# #######################################################################################################
# Hooks explanation
# -------------------
# The purpose of hooks is to allow conds to influence sampling without the need for ComfyUI core code to
# make explicit special cases like it does for ControlNet and GLIGEN.
#
# This is necessary for nodes/features that are intended for use with masked or scheduled conds, or those
# that should run special code when a 'marked' cond is used in sampling.
# #######################################################################################################
class EnumHookMode(enum.Enum):
'''
Priority of hook memory optimization vs. speed, mostly related to WeightHooks.
MinVram: No caching will occur for any operations related to hooks.
MaxSpeed: Excess VRAM (and RAM, once VRAM is sufficiently depleted) will be used to cache hook weights when switching hook groups.
'''
MinVram = "minvram"
MaxSpeed = "maxspeed"
class EnumHookType(enum.Enum):
'''
Hook types, each of which has different expected behavior.
'''
Weight = "weight"
Patch = "patch"
ObjectPatch = "object_patch"
AddModels = "add_models"
Callbacks = "callbacks"
Wrappers = "wrappers"
SetInjections = "add_injections"
AdditionalModels = "add_models"
TransformerOptions = "transformer_options"
Injections = "add_injections"
class EnumWeightTarget(enum.Enum):
Model = "model"
Clip = "clip"
class EnumHookScope(enum.Enum):
'''
Determines if hook should be limited in its influence over sampling.
AllConditioning: hook will affect all conds used in sampling.
HookedOnly: hook will only affect the conds it was attached to.
'''
AllConditioning = "all_conditioning"
HookedOnly = "hooked_only"
class _HookRef:
pass
# NOTE: this is an example of how the should_register function should look
def default_should_register(hook: 'Hook', model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
'''Example for how custom_should_register function can look like.'''
return True
def create_target_dict(target: EnumWeightTarget=None, **kwargs) -> dict[str]:
'''Creates base dictionary for use with Hooks' target param.'''
d = {}
if target is not None:
d['target'] = target
d.update(kwargs)
return d
class Hook:
def __init__(self, hook_type: EnumHookType=None, hook_ref: _HookRef=None, hook_id: str=None,
hook_keyframe: 'HookKeyframeGroup'=None):
hook_keyframe: HookKeyframeGroup=None, hook_scope=EnumHookScope.AllConditioning):
self.hook_type = hook_type
'''Enum identifying the general class of this hook.'''
self.hook_ref = hook_ref if hook_ref else _HookRef()
'''Reference shared between hook clones that have the same value. Should NOT be modified.'''
self.hook_id = hook_id
'''Optional string ID to identify hook; useful if need to consolidate duplicates at registration time.'''
self.hook_keyframe = hook_keyframe if hook_keyframe else HookKeyframeGroup()
'''Keyframe storage that can be referenced to get strength for current sampling step.'''
self.hook_scope = hook_scope
'''Scope of where this hook should apply in terms of the conds used in sampling run.'''
self.custom_should_register = default_should_register
self.auto_apply_to_nonpositive = False
'''Can be overriden with a compatible function to decide if this hook should be registered without the need to override .should_register'''
@property
def strength(self):
return self.hook_keyframe.strength
def initialize_timesteps(self, model: 'BaseModel'):
def initialize_timesteps(self, model: BaseModel):
self.reset()
self.hook_keyframe.initialize_timesteps(model)
def reset(self):
self.hook_keyframe.reset()
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: Hook = subtype()
def clone(self):
c: Hook = self.__class__()
c.hook_type = self.hook_type
c.hook_ref = self.hook_ref
c.hook_id = self.hook_id
c.hook_keyframe = self.hook_keyframe
c.hook_scope = self.hook_scope
c.custom_should_register = self.custom_should_register
# TODO: make this do something
c.auto_apply_to_nonpositive = self.auto_apply_to_nonpositive
return c
def should_register(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
return self.custom_should_register(self, model, model_options, target, registered)
def should_register(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
return self.custom_should_register(self, model, model_options, target_dict, registered)
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
raise NotImplementedError("add_hook_patches should be defined for Hook subclasses")
def on_apply(self, model: 'ModelPatcher', transformer_options: dict[str]):
pass
def on_unapply(self, model: 'ModelPatcher', transformer_options: dict[str]):
pass
def __eq__(self, other: 'Hook'):
def __eq__(self, other: Hook):
return self.__class__ == other.__class__ and self.hook_ref == other.hook_ref
def __hash__(self):
return hash(self.hook_ref)
class WeightHook(Hook):
'''
Hook responsible for tracking weights to be applied to some model/clip.
Note, value of hook_scope is ignored and is treated as HookedOnly.
'''
def __init__(self, strength_model=1.0, strength_clip=1.0):
super().__init__(hook_type=EnumHookType.Weight)
super().__init__(hook_type=EnumHookType.Weight, hook_scope=EnumHookScope.HookedOnly)
self.weights: dict = None
self.weights_clip: dict = None
self.need_weight_init = True
self._strength_model = strength_model
self._strength_clip = strength_clip
self.hook_scope = EnumHookScope.HookedOnly # this value does not matter for WeightHooks, just for docs
@property
def strength_model(self):
return self._strength_model * self.strength
@property
def strength_clip(self):
return self._strength_clip * self.strength
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
if not self.should_register(model, model_options, target, registered):
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
if not self.should_register(model, model_options, target_dict, registered):
return False
weights = None
if target == EnumWeightTarget.Model:
strength = self._strength_model
else:
target = target_dict.get('target', None)
if target == EnumWeightTarget.Clip:
strength = self._strength_clip
else:
strength = self._strength_model
if self.need_weight_init:
key_map = {}
if target == EnumWeightTarget.Model:
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
else:
if target == EnumWeightTarget.Clip:
key_map = comfy.lora.model_lora_keys_clip(model.model, key_map)
else:
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
weights = comfy.lora.load_lora(self.weights, key_map, log_missing=False)
else:
if target == EnumWeightTarget.Model:
weights = self.weights
else:
if target == EnumWeightTarget.Clip:
weights = self.weights_clip
else:
weights = self.weights
model.add_hook_patches(hook=self, patches=weights, strength_patch=strength)
registered.append(self)
registered.add(self)
return True
# TODO: add logs about any keys that were not applied
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: WeightHook = super().clone(subtype)
def clone(self):
c: WeightHook = super().clone()
c.weights = self.weights
c.weights_clip = self.weights_clip
c.need_weight_init = self.need_weight_init
@@ -146,127 +188,158 @@ class WeightHook(Hook):
c._strength_clip = self._strength_clip
return c
class PatchHook(Hook):
def __init__(self):
super().__init__(hook_type=EnumHookType.Patch)
self.patches: dict = None
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: PatchHook = super().clone(subtype)
c.patches = self.patches
return c
# TODO: add functionality
class ObjectPatchHook(Hook):
def __init__(self):
def __init__(self, object_patches: dict[str]=None,
hook_scope=EnumHookScope.AllConditioning):
super().__init__(hook_type=EnumHookType.ObjectPatch)
self.object_patches: dict = None
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: ObjectPatchHook = super().clone(subtype)
self.object_patches = object_patches
self.hook_scope = hook_scope
def clone(self):
c: ObjectPatchHook = super().clone()
c.object_patches = self.object_patches
return c
# TODO: add functionality
class AddModelsHook(Hook):
def __init__(self, key: str=None, models: list['ModelPatcher']=None):
super().__init__(hook_type=EnumHookType.AddModels)
self.key = key
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
raise NotImplementedError("ObjectPatchHook is not supported yet in ComfyUI.")
class AdditionalModelsHook(Hook):
'''
Hook responsible for telling model management any additional models that should be loaded.
Note, value of hook_scope is ignored and is treated as AllConditioning.
'''
def __init__(self, models: list[ModelPatcher]=None, key: str=None):
super().__init__(hook_type=EnumHookType.AdditionalModels)
self.models = models
self.append_when_same = True
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: AddModelsHook = super().clone(subtype)
c.key = self.key
c.models = self.models.copy() if self.models else self.models
c.append_when_same = self.append_when_same
return c
# TODO: add functionality
class CallbackHook(Hook):
def __init__(self, key: str=None, callback: Callable=None):
super().__init__(hook_type=EnumHookType.Callbacks)
self.key = key
self.callback = callback
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: CallbackHook = super().clone(subtype)
def clone(self):
c: AdditionalModelsHook = super().clone()
c.models = self.models.copy() if self.models else self.models
c.key = self.key
c.callback = self.callback
return c
# TODO: add functionality
class WrapperHook(Hook):
def __init__(self, wrappers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None):
super().__init__(hook_type=EnumHookType.Wrappers)
self.wrappers_dict = wrappers_dict
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: WrapperHook = super().clone(subtype)
c.wrappers_dict = self.wrappers_dict
return c
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
if not self.should_register(model, model_options, target, registered):
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
if not self.should_register(model, model_options, target_dict, registered):
return False
add_model_options = {"transformer_options": self.wrappers_dict}
comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
registered.append(self)
registered.add(self)
return True
class SetInjectionsHook(Hook):
def __init__(self, key: str=None, injections: list['PatcherInjection']=None):
super().__init__(hook_type=EnumHookType.SetInjections)
class TransformerOptionsHook(Hook):
'''
Hook responsible for adding wrappers, callbacks, patches, or anything else related to transformer_options.
'''
def __init__(self, transformers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None,
hook_scope=EnumHookScope.AllConditioning):
super().__init__(hook_type=EnumHookType.TransformerOptions)
self.transformers_dict = transformers_dict
self.hook_scope = hook_scope
self._skip_adding = False
'''Internal value used to avoid double load of transformer_options when hook_scope is AllConditioning.'''
def clone(self):
c: TransformerOptionsHook = super().clone()
c.transformers_dict = self.transformers_dict
c._skip_adding = self._skip_adding
return c
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
if not self.should_register(model, model_options, target_dict, registered):
return False
# NOTE: to_load_options will be used to manually load patches/wrappers/callbacks from hooks
self._skip_adding = False
if self.hook_scope == EnumHookScope.AllConditioning:
add_model_options = {"transformer_options": self.transformers_dict,
"to_load_options": self.transformers_dict}
# skip_adding if included in AllConditioning to avoid double loading
self._skip_adding = True
else:
add_model_options = {"to_load_options": self.transformers_dict}
registered.add(self)
comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
return True
def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]):
if not self._skip_adding:
comfy.patcher_extension.merge_nested_dicts(transformer_options, self.transformers_dict, copy_dict1=False)
WrapperHook = TransformerOptionsHook
'''Only here for backwards compatibility, WrapperHook is identical to TransformerOptionsHook.'''
class InjectionsHook(Hook):
def __init__(self, key: str=None, injections: list[PatcherInjection]=None,
hook_scope=EnumHookScope.AllConditioning):
super().__init__(hook_type=EnumHookType.Injections)
self.key = key
self.injections = injections
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: SetInjectionsHook = super().clone(subtype)
self.hook_scope = hook_scope
def clone(self):
c: InjectionsHook = super().clone()
c.key = self.key
c.injections = self.injections.copy() if self.injections else self.injections
return c
def add_hook_injections(self, model: 'ModelPatcher'):
# TODO: add functionality
pass
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
raise NotImplementedError("InjectionsHook is not supported yet in ComfyUI.")
class HookGroup:
'''
Stores groups of hooks, and allows them to be queried by type.
To prevent breaking their functionality, never modify the underlying self.hooks or self._hook_dict vars directly;
always use the provided functions on HookGroup.
'''
def __init__(self):
self.hooks: list[Hook] = []
self._hook_dict: dict[EnumHookType, list[Hook]] = {}
def __len__(self):
return len(self.hooks)
def add(self, hook: Hook):
if hook not in self.hooks:
self.hooks.append(hook)
self._hook_dict.setdefault(hook.hook_type, []).append(hook)
def remove(self, hook: Hook):
if hook in self.hooks:
self.hooks.remove(hook)
self._hook_dict[hook.hook_type].remove(hook)
def get_type(self, hook_type: EnumHookType):
return self._hook_dict.get(hook_type, [])
def contains(self, hook: Hook):
return hook in self.hooks
def is_subset_of(self, other: HookGroup):
self_hooks = set(self.hooks)
other_hooks = set(other.hooks)
return self_hooks.issubset(other_hooks)
def new_with_common_hooks(self, other: HookGroup):
c = HookGroup()
for hook in self.hooks:
if other.contains(hook):
c.add(hook.clone())
return c
def clone(self):
c = HookGroup()
for hook in self.hooks:
c.add(hook.clone())
return c
def clone_and_combine(self, other: 'HookGroup'):
def clone_and_combine(self, other: HookGroup):
c = self.clone()
if other is not None:
for hook in other.hooks:
c.add(hook.clone())
return c
def set_keyframes_on_hooks(self, hook_kf: 'HookKeyframeGroup'):
def set_keyframes_on_hooks(self, hook_kf: HookKeyframeGroup):
if hook_kf is None:
hook_kf = HookKeyframeGroup()
else:
@@ -274,36 +347,29 @@ class HookGroup:
for hook in self.hooks:
hook.hook_keyframe = hook_kf
def get_dict_repr(self):
d: dict[EnumHookType, dict[Hook, None]] = {}
for hook in self.hooks:
with_type = d.setdefault(hook.hook_type, {})
with_type[hook] = None
return d
def get_hooks_for_clip_schedule(self):
scheduled_hooks: dict[WeightHook, list[tuple[tuple[float,float], HookKeyframe]]] = {}
for hook in self.hooks:
# only care about WeightHooks, for now
if hook.hook_type == EnumHookType.Weight:
hook_schedule = []
# if no hook keyframes, assign default value
if len(hook.hook_keyframe.keyframes) == 0:
hook_schedule.append(((0.0, 1.0), None))
scheduled_hooks[hook] = hook_schedule
continue
# find ranges of values
prev_keyframe = hook.hook_keyframe.keyframes[0]
for keyframe in hook.hook_keyframe.keyframes:
if keyframe.start_percent > prev_keyframe.start_percent and not math.isclose(keyframe.strength, prev_keyframe.strength):
hook_schedule.append(((prev_keyframe.start_percent, keyframe.start_percent), prev_keyframe))
prev_keyframe = keyframe
elif keyframe.start_percent == prev_keyframe.start_percent:
prev_keyframe = keyframe
# create final range, assuming last start_percent was not 1.0
if not math.isclose(prev_keyframe.start_percent, 1.0):
hook_schedule.append(((prev_keyframe.start_percent, 1.0), prev_keyframe))
# only care about WeightHooks, for now
for hook in self.get_type(EnumHookType.Weight):
hook: WeightHook
hook_schedule = []
# if no hook keyframes, assign default value
if len(hook.hook_keyframe.keyframes) == 0:
hook_schedule.append(((0.0, 1.0), None))
scheduled_hooks[hook] = hook_schedule
continue
# find ranges of values
prev_keyframe = hook.hook_keyframe.keyframes[0]
for keyframe in hook.hook_keyframe.keyframes:
if keyframe.start_percent > prev_keyframe.start_percent and not math.isclose(keyframe.strength, prev_keyframe.strength):
hook_schedule.append(((prev_keyframe.start_percent, keyframe.start_percent), prev_keyframe))
prev_keyframe = keyframe
elif keyframe.start_percent == prev_keyframe.start_percent:
prev_keyframe = keyframe
# create final range, assuming last start_percent was not 1.0
if not math.isclose(prev_keyframe.start_percent, 1.0):
hook_schedule.append(((prev_keyframe.start_percent, 1.0), prev_keyframe))
scheduled_hooks[hook] = hook_schedule
# hooks should not have their schedules in a list of tuples
all_ranges: list[tuple[float, float]] = []
for range_kfs in scheduled_hooks.values():
@@ -335,7 +401,7 @@ class HookGroup:
hook.reset()
@staticmethod
def combine_all_hooks(hooks_list: list['HookGroup'], require_count=0) -> 'HookGroup':
def combine_all_hooks(hooks_list: list[HookGroup], require_count=0) -> HookGroup:
actual: list[HookGroup] = []
for group in hooks_list:
if group is not None:
@@ -364,10 +430,16 @@ class HookKeyframe:
self.start_percent = float(start_percent)
self.start_t = 999999999.9
self.guarantee_steps = guarantee_steps
def get_effective_guarantee_steps(self, max_sigma: torch.Tensor):
'''If keyframe starts before current sampling range (max_sigma), treat as 0.'''
if self.start_t > max_sigma:
return 0
return self.guarantee_steps
def clone(self):
c = HookKeyframe(strength=self.strength,
start_percent=self.start_percent, guarantee_steps=self.guarantee_steps)
start_percent=self.start_percent, guarantee_steps=self.guarantee_steps)
c.start_t = self.start_t
return c
@@ -394,7 +466,7 @@ class HookKeyframeGroup:
self._current_strength = None
self.curr_t = -1.
self._set_first_as_current()
def add(self, keyframe: HookKeyframe):
# add to end of list, then sort
self.keyframes.append(keyframe)
@@ -406,33 +478,40 @@ class HookKeyframeGroup:
self._current_keyframe = self.keyframes[0]
else:
self._current_keyframe = None
def has_guarantee_steps(self):
for kf in self.keyframes:
if kf.guarantee_steps > 0:
return True
return False
def has_index(self, index: int):
return index >= 0 and index < len(self.keyframes)
def is_empty(self):
return len(self.keyframes) == 0
def clone(self):
c = HookKeyframeGroup()
for keyframe in self.keyframes:
c.keyframes.append(keyframe.clone())
c._set_first_as_current()
return c
def initialize_timesteps(self, model: 'BaseModel'):
def initialize_timesteps(self, model: BaseModel):
for keyframe in self.keyframes:
keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent)
def prepare_current_keyframe(self, curr_t: float) -> bool:
def prepare_current_keyframe(self, curr_t: float, transformer_options: dict[str, torch.Tensor]) -> bool:
if self.is_empty():
return False
if curr_t == self._curr_t:
return False
max_sigma = torch.max(transformer_options["sample_sigmas"])
prev_index = self._current_index
prev_strength = self._current_strength
# if met guaranteed steps, look for next keyframe in case need to switch
if self._current_used_steps >= self._current_keyframe.guarantee_steps:
if self._current_used_steps >= self._current_keyframe.get_effective_guarantee_steps(max_sigma):
# if has next index, loop through and see if need to switch
if self.has_index(self._current_index+1):
for i in range(self._current_index+1, len(self.keyframes)):
@@ -445,7 +524,7 @@ class HookKeyframeGroup:
self._current_keyframe = eval_c
self._current_used_steps = 0
# if guarantee_steps greater than zero, stop searching for other keyframes
if self._current_keyframe.guarantee_steps > 0:
if self._current_keyframe.get_effective_guarantee_steps(max_sigma) > 0:
break
# if eval_c is outside the percent range, stop looking further
else: break
@@ -508,6 +587,17 @@ def get_sorted_list_via_attr(objects: list, attr: str) -> list:
sorted_list.extend(object_list)
return sorted_list
def create_transformer_options_from_hooks(model: ModelPatcher, hooks: HookGroup, transformer_options: dict[str]=None):
# if no hooks or is not a ModelPatcher for sampling, return empty dict
if hooks is None or model.is_clip:
return {}
if transformer_options is None:
transformer_options = {}
for hook in hooks.get_type(EnumHookType.TransformerOptions):
hook: TransformerOptionsHook
hook.on_apply_hooks(model, transformer_options)
return transformer_options
def create_hook_lora(lora: dict[str, torch.Tensor], strength_model: float, strength_clip: float):
hook_group = HookGroup()
hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip)
@@ -534,7 +624,7 @@ def create_hook_model_as_lora(weights_model, weights_clip, strength_model: float
hook.need_weight_init = False
return hook_group
def get_patch_weights_from_model(model: 'ModelPatcher', discard_model_sampling=True):
def get_patch_weights_from_model(model: ModelPatcher, discard_model_sampling=True):
if model is None:
return None
patches_model: dict[str, torch.Tensor] = model.model.state_dict()
@@ -546,7 +636,7 @@ def get_patch_weights_from_model(model: 'ModelPatcher', discard_model_sampling=T
return patches_model
# NOTE: this function shows how to register weight hooks directly on the ModelPatchers
def load_hook_lora_for_models(model: 'ModelPatcher', clip: 'CLIP', lora: dict[str, torch.Tensor],
def load_hook_lora_for_models(model: ModelPatcher, clip: CLIP, lora: dict[str, torch.Tensor],
strength_model: float, strength_clip: float):
key_map = {}
if model is not None:
@@ -564,7 +654,7 @@ def load_hook_lora_for_models(model: 'ModelPatcher', clip: 'CLIP', lora: dict[st
else:
k = ()
new_modelpatcher = None
if clip is not None:
new_clip = clip.clone()
k1 = new_clip.patcher.add_hook_patches(hook=hook, patches=loaded, strength_patch=strength_clip)
@@ -575,7 +665,7 @@ def load_hook_lora_for_models(model: 'ModelPatcher', clip: 'CLIP', lora: dict[st
k1 = set(k1)
for x in loaded:
if (x not in k) and (x not in k1):
print(f"NOT LOADED {x}")
logging.warning(f"NOT LOADED {x}")
return (new_modelpatcher, new_clip, hook_group)
def _combine_hooks_from_values(c_dict: dict[str, HookGroup], values: dict[str, HookGroup], cache: dict[tuple[HookGroup, HookGroup], HookGroup]):
@@ -598,24 +688,26 @@ def _combine_hooks_from_values(c_dict: dict[str, HookGroup], values: dict[str, H
else:
c_dict[hooks_key] = cache[hooks_tuple]
def conditioning_set_values_with_hooks(conditioning, values={}, append_hooks=True):
def conditioning_set_values_with_hooks(conditioning, values={}, append_hooks=True,
cache: dict[tuple[HookGroup, HookGroup], HookGroup]=None):
c = []
hooks_combine_cache: dict[tuple[HookGroup, HookGroup], HookGroup] = {}
if cache is None:
cache = {}
for t in conditioning:
n = [t[0], t[1].copy()]
for k in values:
if append_hooks and k == 'hooks':
_combine_hooks_from_values(n[1], values, hooks_combine_cache)
_combine_hooks_from_values(n[1], values, cache)
else:
n[1][k] = values[k]
c.append(n)
return c
def set_hooks_for_conditioning(cond, hooks: HookGroup, append_hooks=True):
def set_hooks_for_conditioning(cond, hooks: HookGroup, append_hooks=True, cache: dict[tuple[HookGroup, HookGroup], HookGroup]=None):
if hooks is None:
return cond
return conditioning_set_values_with_hooks(cond, {'hooks': hooks}, append_hooks=append_hooks)
return conditioning_set_values_with_hooks(cond, {'hooks': hooks}, append_hooks=append_hooks, cache=cache)
def set_timesteps_for_conditioning(cond, timestep_range: tuple[float,float]):
if timestep_range is None:
@@ -650,9 +742,10 @@ def combine_with_new_conds(conds: list, new_conds: list):
def set_conds_props(conds: list, strength: float, set_cond_area: str,
mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
final_conds = []
cache = {}
for c in conds:
# first, apply lora_hook to conditioning, if provided
c = set_hooks_for_conditioning(c, hooks, append_hooks=append_hooks)
c = set_hooks_for_conditioning(c, hooks, append_hooks=append_hooks, cache=cache)
# next, apply mask to conditioning
c = set_mask_for_conditioning(cond=c, mask=mask, strength=strength, set_cond_area=set_cond_area)
# apply timesteps, if present
@@ -664,9 +757,10 @@ def set_conds_props(conds: list, strength: float, set_cond_area: str,
def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1.0, set_cond_area: str="default",
mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
combined_conds = []
cache = {}
for c, masked_c in zip(conds, new_conds):
# first, apply lora_hook to new conditioning, if provided
masked_c = set_hooks_for_conditioning(masked_c, hooks, append_hooks=append_hooks)
masked_c = set_hooks_for_conditioning(masked_c, hooks, append_hooks=append_hooks, cache=cache)
# next, apply mask to new conditioning, if provided
masked_c = set_mask_for_conditioning(cond=masked_c, mask=mask, set_cond_area=set_cond_area, strength=strength)
# apply timesteps, if present
@@ -678,9 +772,10 @@ def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1.
def set_default_conds_and_combine(conds: list, new_conds: list,
hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
combined_conds = []
cache = {}
for c, new_c in zip(conds, new_conds):
# first, apply lora_hook to new conditioning, if provided
new_c = set_hooks_for_conditioning(new_c, hooks, append_hooks=append_hooks)
new_c = set_hooks_for_conditioning(new_c, hooks, append_hooks=append_hooks, cache=cache)
# next, add default_cond key to cond so that during sampling, it can be identified
new_c = conditioning_set_values(new_c, {'default': True})
# apply timesteps, if present

View File

@@ -70,8 +70,14 @@ def get_ancestral_step(sigma_from, sigma_to, eta=1.):
return sigma_down, sigma_up
def default_noise_sampler(x):
return lambda sigma, sigma_next: torch.randn_like(x)
def default_noise_sampler(x, seed=None):
if seed is not None:
generator = torch.Generator(device=x.device)
generator.manual_seed(seed)
else:
generator = None
return lambda sigma, sigma_next: torch.randn(x.size(), dtype=x.dtype, layout=x.layout, device=x.device, generator=generator)
class BatchedBrownianTree:
@@ -168,7 +174,8 @@ def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
return sample_euler_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
"""Ancestral sampling with Euler method steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
@@ -189,7 +196,8 @@ def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
def sample_euler_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1., noise_sampler=None):
"""Ancestral sampling with Euler method steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
@@ -290,7 +298,8 @@ def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
"""Ancestral sampling with DPM-Solver second-order steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
@@ -318,7 +327,8 @@ def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
def sample_dpm_2_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with DPM-Solver second-order steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
@@ -465,7 +475,7 @@ class DPMSolver(nn.Module):
return x_3, eps_cache
def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0., s_noise=1., noise_sampler=None):
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
noise_sampler = default_noise_sampler(x, seed=self.extra_args.get("seed", None)) if noise_sampler is None else noise_sampler
if not t_end > t_start and eta:
raise ValueError('eta must be 0 for reverse sampling')
@@ -504,7 +514,7 @@ class DPMSolver(nn.Module):
return x
def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None):
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
noise_sampler = default_noise_sampler(x, seed=self.extra_args.get("seed", None)) if noise_sampler is None else noise_sampler
if order not in {2, 3}:
raise ValueError('order should be 2 or 3')
forward = t_end > t_start
@@ -591,7 +601,8 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None,
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda t: t.neg().exp()
t_fn = lambda sigma: sigma.log().neg()
@@ -625,7 +636,8 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None,
def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda lbda: (lbda.exp() + 1) ** -1
lambda_fn = lambda sigma: ((1-sigma)/sigma).log()
@@ -882,7 +894,8 @@ def DDPMSampler_step(x, sigma, sigma_prev, noise, noise_sampler):
def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, step_function=None):
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
@@ -902,7 +915,8 @@ def sample_ddpm(model, x, sigmas, extra_args=None, callback=None, disable=None,
@torch.no_grad()
def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
@@ -1153,7 +1167,8 @@ def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disabl
def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with Euler method steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
temp = [0]
def post_cfg_function(args):
@@ -1179,7 +1194,8 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No
def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
temp = [0]
def post_cfg_function(args):
@@ -1230,7 +1246,7 @@ def sample_dpmpp_2m_cfg_pp(model, x, sigmas, extra_args=None, callback=None, dis
nonlocal uncond_denoised
uncond_denoised = args["uncond_denoised"]
return args["denoised"]
model_options = extra_args.get("model_options", {}).copy()
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
@@ -1249,3 +1265,74 @@ def sample_dpmpp_2m_cfg_pp(model, x, sigmas, extra_args=None, callback=None, dis
x = denoised + denoised_mix + torch.exp(-h) * x
old_uncond_denoised = uncond_denoised
return x
@torch.no_grad()
def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None, cfg_pp=False):
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda t: t.neg().exp()
t_fn = lambda sigma: sigma.log().neg()
phi1_fn = lambda t: torch.expm1(t) / t
phi2_fn = lambda t: (phi1_fn(t) - 1.0) / t
old_denoised = None
uncond_denoised = None
def post_cfg_function(args):
nonlocal uncond_denoised
uncond_denoised = args["uncond_denoised"]
return args["denoised"]
if cfg_pp:
model_options = extra_args.get("model_options", {}).copy()
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
for i in trange(len(sigmas) - 1, disable=disable):
if s_churn > 0:
gamma = min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.0
sigma_hat = sigmas[i] * (gamma + 1)
else:
gamma = 0
sigma_hat = sigmas[i]
if gamma > 0:
eps = torch.randn_like(x) * s_noise
x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5
denoised = model(x, sigma_hat * s_in, **extra_args)
if callback is not None:
callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigma_hat, "denoised": denoised})
if sigmas[i + 1] == 0 or old_denoised is None:
# Euler method
if cfg_pp:
d = to_d(x, sigma_hat, uncond_denoised)
x = denoised + d * sigmas[i + 1]
else:
d = to_d(x, sigma_hat, denoised)
dt = sigmas[i + 1] - sigma_hat
x = x + d * dt
else:
# Second order multistep method in https://arxiv.org/pdf/2308.02157
t, t_next, t_prev = t_fn(sigmas[i]), t_fn(sigmas[i + 1]), t_fn(sigmas[i - 1])
h = t_next - t
c2 = (t_prev - t) / h
phi1_val, phi2_val = phi1_fn(-h), phi2_fn(-h)
b1 = torch.nan_to_num(phi1_val - 1.0 / c2 * phi2_val, nan=0.0)
b2 = torch.nan_to_num(1.0 / c2 * phi2_val, nan=0.0)
if cfg_pp:
x = x + (denoised - uncond_denoised)
x = (sigma_fn(t_next) / sigma_fn(t)) * x + h * (b1 * denoised + b2 * old_denoised)
old_denoised = denoised
return x
@torch.no_grad()
def sample_res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise, noise_sampler=noise_sampler, cfg_pp=False)
@torch.no_grad()
def sample_res_multistep_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise, noise_sampler=noise_sampler, cfg_pp=True)

View File

@@ -3,6 +3,7 @@ import torch
class LatentFormat:
scale_factor = 1.0
latent_channels = 4
latent_dimensions = 2
latent_rgb_factors = None
latent_rgb_factors_bias = None
taesd_decoder_name = None
@@ -143,6 +144,7 @@ class SD3(LatentFormat):
class StableAudio1(LatentFormat):
latent_channels = 64
latent_dimensions = 1
class Flux(SD3):
latent_channels = 16
@@ -178,6 +180,7 @@ class Flux(SD3):
class Mochi(LatentFormat):
latent_channels = 12
latent_dimensions = 3
def __init__(self):
self.scale_factor = 1.0
@@ -219,6 +222,8 @@ class Mochi(LatentFormat):
class LTXV(LatentFormat):
latent_channels = 128
latent_dimensions = 3
def __init__(self):
self.latent_rgb_factors = [
[ 1.1202e-02, -6.3815e-04, -1.0021e-02],
@@ -355,6 +360,7 @@ class LTXV(LatentFormat):
class HunyuanVideo(LatentFormat):
latent_channels = 16
latent_dimensions = 3
scale_factor = 0.476986
latent_rgb_factors = [
[-0.0395, -0.0331, 0.0445],
@@ -376,3 +382,28 @@ class HunyuanVideo(LatentFormat):
]
latent_rgb_factors_bias = [ 0.0259, -0.0192, -0.0761]
class Cosmos1CV8x8x8(LatentFormat):
latent_channels = 16
latent_dimensions = 3
latent_rgb_factors = [
[ 0.1817, 0.2284, 0.2423],
[-0.0586, -0.0862, -0.3108],
[-0.4703, -0.4255, -0.3995],
[ 0.0803, 0.1963, 0.1001],
[-0.0820, -0.1050, 0.0400],
[ 0.2511, 0.3098, 0.2787],
[-0.1830, -0.2117, -0.0040],
[-0.0621, -0.2187, -0.0939],
[ 0.3619, 0.1082, 0.1455],
[ 0.3164, 0.3922, 0.2575],
[ 0.1152, 0.0231, -0.0462],
[-0.1434, -0.3609, -0.3665],
[ 0.0635, 0.1471, 0.1680],
[-0.3635, -0.1963, -0.3248],
[-0.1865, 0.0365, 0.2346],
[ 0.0447, 0.0994, 0.0881]
]
latent_rgb_factors_bias = [-0.1223, -0.1889, -0.1976]

View File

@@ -381,7 +381,6 @@ class MMDiT(nn.Module):
pe_new = pe_as_2d.squeeze(0).permute(1, 2, 0).flatten(0, 1)
self.positional_encoding.data = pe_new.unsqueeze(0).contiguous()
self.h_max, self.w_max = target_dim
print("PE extended to", target_dim)
def pe_selection_index_based_on_dim(self, h, w):
h_p, w_p = h // self.patch_size, w // self.patch_size

View File

@@ -138,7 +138,7 @@ class StageB(nn.Module):
# nn.init.normal_(self.pixels_mapper[2].weight, std=0.02) # conditionings
# torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs
# nn.init.constant_(self.clf[1].weight, 0) # outputs
#
#
# # blocks
# for level_block in self.down_blocks + self.up_blocks:
# for block in level_block:
@@ -148,7 +148,7 @@ class StageB(nn.Module):
# for layer in block.modules():
# if isinstance(layer, nn.Linear):
# nn.init.constant_(layer.weight, 0)
#
#
# def _init_weights(self, m):
# if isinstance(m, (nn.Conv2d, nn.Linear)):
# torch.nn.init.xavier_uniform_(m.weight)

View File

@@ -142,7 +142,7 @@ class StageC(nn.Module):
# nn.init.normal_(self.clip_img_mapper.weight, std=0.02) # conditionings
# torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs
# nn.init.constant_(self.clf[1].weight, 0) # outputs
#
#
# # blocks
# for level_block in self.down_blocks + self.up_blocks:
# for block in level_block:
@@ -152,7 +152,7 @@ class StageC(nn.Module):
# for layer in block.modules():
# if isinstance(layer, nn.Linear):
# nn.init.constant_(layer.weight, 0)
#
#
# def _init_weights(self, m):
# if isinstance(m, (nn.Conv2d, nn.Linear)):
# torch.nn.init.xavier_uniform_(m.weight)

808
comfy/ldm/cosmos/blocks.py Normal file
View File

@@ -0,0 +1,808 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional
import logging
import numpy as np
import torch
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from torch import nn
from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
from comfy.ldm.modules.attention import optimized_attention
def apply_rotary_pos_emb(
t: torch.Tensor,
freqs: torch.Tensor,
) -> torch.Tensor:
t_ = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).float()
t_out = freqs[..., 0] * t_[..., 0] + freqs[..., 1] * t_[..., 1]
t_out = t_out.movedim(-1, -2).reshape(*t.shape).type_as(t)
return t_out
def get_normalization(name: str, channels: int, weight_args={}):
if name == "I":
return nn.Identity()
elif name == "R":
return RMSNorm(channels, elementwise_affine=True, eps=1e-6, **weight_args)
else:
raise ValueError(f"Normalization {name} not found")
class BaseAttentionOp(nn.Module):
def __init__(self):
super().__init__()
class Attention(nn.Module):
"""
Generalized attention impl.
Allowing for both self-attention and cross-attention configurations depending on whether a `context_dim` is provided.
If `context_dim` is None, self-attention is assumed.
Parameters:
query_dim (int): Dimension of each query vector.
context_dim (int, optional): Dimension of each context vector. If None, self-attention is assumed.
heads (int, optional): Number of attention heads. Defaults to 8.
dim_head (int, optional): Dimension of each head. Defaults to 64.
dropout (float, optional): Dropout rate applied to the output of the attention block. Defaults to 0.0.
attn_op (BaseAttentionOp, optional): Custom attention operation to be used instead of the default.
qkv_bias (bool, optional): If True, adds a learnable bias to query, key, and value projections. Defaults to False.
out_bias (bool, optional): If True, adds a learnable bias to the output projection. Defaults to False.
qkv_norm (str, optional): A string representing normalization strategies for query, key, and value projections.
Defaults to "SSI".
qkv_norm_mode (str, optional): A string representing normalization mode for query, key, and value projections.
Defaults to 'per_head'. Only support 'per_head'.
Examples:
>>> attn = Attention(query_dim=128, context_dim=256, heads=4, dim_head=32, dropout=0.1)
>>> query = torch.randn(10, 128) # Batch size of 10
>>> context = torch.randn(10, 256) # Batch size of 10
>>> output = attn(query, context) # Perform the attention operation
Note:
https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
"""
def __init__(
self,
query_dim: int,
context_dim=None,
heads=8,
dim_head=64,
dropout=0.0,
attn_op: Optional[BaseAttentionOp] = None,
qkv_bias: bool = False,
out_bias: bool = False,
qkv_norm: str = "SSI",
qkv_norm_mode: str = "per_head",
backend: str = "transformer_engine",
qkv_format: str = "bshd",
weight_args={},
operations=None,
) -> None:
super().__init__()
self.is_selfattn = context_dim is None # self attention
inner_dim = dim_head * heads
context_dim = query_dim if context_dim is None else context_dim
self.heads = heads
self.dim_head = dim_head
self.qkv_norm_mode = qkv_norm_mode
self.qkv_format = qkv_format
if self.qkv_norm_mode == "per_head":
norm_dim = dim_head
else:
raise ValueError(f"Normalization mode {self.qkv_norm_mode} not found, only support 'per_head'")
self.backend = backend
self.to_q = nn.Sequential(
operations.Linear(query_dim, inner_dim, bias=qkv_bias, **weight_args),
get_normalization(qkv_norm[0], norm_dim),
)
self.to_k = nn.Sequential(
operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args),
get_normalization(qkv_norm[1], norm_dim),
)
self.to_v = nn.Sequential(
operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args),
get_normalization(qkv_norm[2], norm_dim),
)
self.to_out = nn.Sequential(
operations.Linear(inner_dim, query_dim, bias=out_bias, **weight_args),
nn.Dropout(dropout),
)
def cal_qkv(
self, x, context=None, mask=None, rope_emb=None, **kwargs
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
del kwargs
"""
self.to_q, self.to_k, self.to_v are nn.Sequential with projection + normalization layers.
Before 07/24/2024, these modules normalize across all heads.
After 07/24/2024, to support tensor parallelism and follow the common practice in the community,
we support to normalize per head.
To keep the checkpoint copatibility with the previous code,
we keep the nn.Sequential but call the projection and the normalization layers separately.
We use a flag `self.qkv_norm_mode` to control the normalization behavior.
The default value of `self.qkv_norm_mode` is "per_head", which means we normalize per head.
"""
if self.qkv_norm_mode == "per_head":
q = self.to_q[0](x)
context = x if context is None else context
k = self.to_k[0](context)
v = self.to_v[0](context)
q, k, v = map(
lambda t: rearrange(t, "s b (n c) -> b n s c", n=self.heads, c=self.dim_head),
(q, k, v),
)
else:
raise ValueError(f"Normalization mode {self.qkv_norm_mode} not found, only support 'per_head'")
q = self.to_q[1](q)
k = self.to_k[1](k)
v = self.to_v[1](v)
if self.is_selfattn and rope_emb is not None: # only apply to self-attention!
# apply_rotary_pos_emb inlined
q_shape = q.shape
q = q.reshape(*q.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2)
q = rope_emb[..., 0] * q[..., 0] + rope_emb[..., 1] * q[..., 1]
q = q.movedim(-1, -2).reshape(*q_shape).to(x.dtype)
# apply_rotary_pos_emb inlined
k_shape = k.shape
k = k.reshape(*k.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2)
k = rope_emb[..., 0] * k[..., 0] + rope_emb[..., 1] * k[..., 1]
k = k.movedim(-1, -2).reshape(*k_shape).to(x.dtype)
return q, k, v
def forward(
self,
x,
context=None,
mask=None,
rope_emb=None,
**kwargs,
):
"""
Args:
x (Tensor): The query tensor of shape [B, Mq, K]
context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
"""
q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs)
out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True)
del q, k, v
out = rearrange(out, " b n s c -> s b (n c)")
return self.to_out(out)
class FeedForward(nn.Module):
"""
Transformer FFN with optional gating
Parameters:
d_model (int): Dimensionality of input features.
d_ff (int): Dimensionality of the hidden layer.
dropout (float, optional): Dropout rate applied after the activation function. Defaults to 0.1.
activation (callable, optional): The activation function applied after the first linear layer.
Defaults to nn.ReLU().
is_gated (bool, optional): If set to True, incorporates gating mechanism to the feed-forward layer.
Defaults to False.
bias (bool, optional): If set to True, adds a bias to the linear layers. Defaults to True.
Example:
>>> ff = FeedForward(d_model=512, d_ff=2048)
>>> x = torch.randn(64, 10, 512) # Example input tensor
>>> output = ff(x)
>>> print(output.shape) # Expected shape: (64, 10, 512)
"""
def __init__(
self,
d_model: int,
d_ff: int,
dropout: float = 0.1,
activation=nn.ReLU(),
is_gated: bool = False,
bias: bool = False,
weight_args={},
operations=None,
) -> None:
super().__init__()
self.layer1 = operations.Linear(d_model, d_ff, bias=bias, **weight_args)
self.layer2 = operations.Linear(d_ff, d_model, bias=bias, **weight_args)
self.dropout = nn.Dropout(dropout)
self.activation = activation
self.is_gated = is_gated
if is_gated:
self.linear_gate = operations.Linear(d_model, d_ff, bias=False, **weight_args)
def forward(self, x: torch.Tensor):
g = self.activation(self.layer1(x))
if self.is_gated:
x = g * self.linear_gate(x)
else:
x = g
assert self.dropout.p == 0.0, "we skip dropout"
return self.layer2(x)
class GPT2FeedForward(FeedForward):
def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1, bias: bool = False, weight_args={}, operations=None):
super().__init__(
d_model=d_model,
d_ff=d_ff,
dropout=dropout,
activation=nn.GELU(),
is_gated=False,
bias=bias,
weight_args=weight_args,
operations=operations,
)
def forward(self, x: torch.Tensor):
assert self.dropout.p == 0.0, "we skip dropout"
x = self.layer1(x)
x = self.activation(x)
x = self.layer2(x)
return x
def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
class Timesteps(nn.Module):
def __init__(self, num_channels):
super().__init__()
self.num_channels = num_channels
def forward(self, timesteps):
half_dim = self.num_channels // 2
exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
exponent = exponent / (half_dim - 0.0)
emb = torch.exp(exponent)
emb = timesteps[:, None].float() * emb[None, :]
sin_emb = torch.sin(emb)
cos_emb = torch.cos(emb)
emb = torch.cat([cos_emb, sin_emb], dim=-1)
return emb
class TimestepEmbedding(nn.Module):
def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False, weight_args={}, operations=None):
super().__init__()
logging.debug(
f"Using AdaLN LoRA Flag: {use_adaln_lora}. We enable bias if no AdaLN LoRA for backward compatibility."
)
self.linear_1 = operations.Linear(in_features, out_features, bias=not use_adaln_lora, **weight_args)
self.activation = nn.SiLU()
self.use_adaln_lora = use_adaln_lora
if use_adaln_lora:
self.linear_2 = operations.Linear(out_features, 3 * out_features, bias=False, **weight_args)
else:
self.linear_2 = operations.Linear(out_features, out_features, bias=True, **weight_args)
def forward(self, sample: torch.Tensor) -> torch.Tensor:
emb = self.linear_1(sample)
emb = self.activation(emb)
emb = self.linear_2(emb)
if self.use_adaln_lora:
adaln_lora_B_3D = emb
emb_B_D = sample
else:
emb_B_D = emb
adaln_lora_B_3D = None
return emb_B_D, adaln_lora_B_3D
class FourierFeatures(nn.Module):
"""
Implements a layer that generates Fourier features from input tensors, based on randomly sampled
frequencies and phases. This can help in learning high-frequency functions in low-dimensional problems.
[B] -> [B, D]
Parameters:
num_channels (int): The number of Fourier features to generate.
bandwidth (float, optional): The scaling factor for the frequency of the Fourier features. Defaults to 1.
normalize (bool, optional): If set to True, the outputs are scaled by sqrt(2), usually to normalize
the variance of the features. Defaults to False.
Example:
>>> layer = FourierFeatures(num_channels=256, bandwidth=0.5, normalize=True)
>>> x = torch.randn(10, 256) # Example input tensor
>>> output = layer(x)
>>> print(output.shape) # Expected shape: (10, 256)
"""
def __init__(self, num_channels, bandwidth=1, normalize=False):
super().__init__()
self.register_buffer("freqs", 2 * np.pi * bandwidth * torch.randn(num_channels), persistent=True)
self.register_buffer("phases", 2 * np.pi * torch.rand(num_channels), persistent=True)
self.gain = np.sqrt(2) if normalize else 1
def forward(self, x, gain: float = 1.0):
"""
Apply the Fourier feature transformation to the input tensor.
Args:
x (torch.Tensor): The input tensor.
gain (float, optional): An additional gain factor applied during the forward pass. Defaults to 1.
Returns:
torch.Tensor: The transformed tensor, with Fourier features applied.
"""
in_dtype = x.dtype
x = x.to(torch.float32).ger(self.freqs.to(torch.float32)).add(self.phases.to(torch.float32))
x = x.cos().mul(self.gain * gain).to(in_dtype)
return x
class PatchEmbed(nn.Module):
"""
PatchEmbed is a module for embedding patches from an input tensor by applying either 3D or 2D convolutional layers,
depending on the . This module can process inputs with temporal (video) and spatial (image) dimensions,
making it suitable for video and image processing tasks. It supports dividing the input into patches
and embedding each patch into a vector of size `out_channels`.
Parameters:
- spatial_patch_size (int): The size of each spatial patch.
- temporal_patch_size (int): The size of each temporal patch.
- in_channels (int): Number of input channels. Default: 3.
- out_channels (int): The dimension of the embedding vector for each patch. Default: 768.
- bias (bool): If True, adds a learnable bias to the output of the convolutional layers. Default: True.
"""
def __init__(
self,
spatial_patch_size,
temporal_patch_size,
in_channels=3,
out_channels=768,
bias=True,
weight_args={},
operations=None,
):
super().__init__()
self.spatial_patch_size = spatial_patch_size
self.temporal_patch_size = temporal_patch_size
self.proj = nn.Sequential(
Rearrange(
"b c (t r) (h m) (w n) -> b t h w (c r m n)",
r=temporal_patch_size,
m=spatial_patch_size,
n=spatial_patch_size,
),
operations.Linear(
in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=bias, **weight_args
),
)
self.out = nn.Identity()
def forward(self, x):
"""
Forward pass of the PatchEmbed module.
Parameters:
- x (torch.Tensor): The input tensor of shape (B, C, T, H, W) where
B is the batch size,
C is the number of channels,
T is the temporal dimension,
H is the height, and
W is the width of the input.
Returns:
- torch.Tensor: The embedded patches as a tensor, with shape b t h w c.
"""
assert x.dim() == 5
_, _, T, H, W = x.shape
assert H % self.spatial_patch_size == 0 and W % self.spatial_patch_size == 0
assert T % self.temporal_patch_size == 0
x = self.proj(x)
return self.out(x)
class FinalLayer(nn.Module):
"""
The final layer of video DiT.
"""
def __init__(
self,
hidden_size,
spatial_patch_size,
temporal_patch_size,
out_channels,
use_adaln_lora: bool = False,
adaln_lora_dim: int = 256,
weight_args={},
operations=None,
):
super().__init__()
self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **weight_args)
self.linear = operations.Linear(
hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, **weight_args
)
self.hidden_size = hidden_size
self.n_adaln_chunks = 2
self.use_adaln_lora = use_adaln_lora
if use_adaln_lora:
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(hidden_size, adaln_lora_dim, bias=False, **weight_args),
operations.Linear(adaln_lora_dim, self.n_adaln_chunks * hidden_size, bias=False, **weight_args),
)
else:
self.adaLN_modulation = nn.Sequential(
nn.SiLU(), operations.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False, **weight_args)
)
def forward(
self,
x_BT_HW_D,
emb_B_D,
adaln_lora_B_3D: Optional[torch.Tensor] = None,
):
if self.use_adaln_lora:
assert adaln_lora_B_3D is not None
shift_B_D, scale_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D[:, : 2 * self.hidden_size]).chunk(
2, dim=1
)
else:
shift_B_D, scale_B_D = self.adaLN_modulation(emb_B_D).chunk(2, dim=1)
B = emb_B_D.shape[0]
T = x_BT_HW_D.shape[0] // B
shift_BT_D, scale_BT_D = repeat(shift_B_D, "b d -> (b t) d", t=T), repeat(scale_B_D, "b d -> (b t) d", t=T)
x_BT_HW_D = modulate(self.norm_final(x_BT_HW_D), shift_BT_D, scale_BT_D)
x_BT_HW_D = self.linear(x_BT_HW_D)
return x_BT_HW_D
class VideoAttn(nn.Module):
"""
Implements video attention with optional cross-attention capabilities.
This module processes video features while maintaining their spatio-temporal structure. It can perform
self-attention within the video features or cross-attention with external context features.
Parameters:
x_dim (int): Dimension of input feature vectors
context_dim (Optional[int]): Dimension of context features for cross-attention. None for self-attention
num_heads (int): Number of attention heads
bias (bool): Whether to include bias in attention projections. Default: False
qkv_norm_mode (str): Normalization mode for query/key/value projections. Must be "per_head". Default: "per_head"
x_format (str): Format of input tensor. Must be "BTHWD". Default: "BTHWD"
Input shape:
- x: (T, H, W, B, D) video features
- context (optional): (M, B, D) context features for cross-attention
where:
T: temporal dimension
H: height
W: width
B: batch size
D: feature dimension
M: context sequence length
"""
def __init__(
self,
x_dim: int,
context_dim: Optional[int],
num_heads: int,
bias: bool = False,
qkv_norm_mode: str = "per_head",
x_format: str = "BTHWD",
weight_args={},
operations=None,
) -> None:
super().__init__()
self.x_format = x_format
self.attn = Attention(
x_dim,
context_dim,
num_heads,
x_dim // num_heads,
qkv_bias=bias,
qkv_norm="RRI",
out_bias=bias,
qkv_norm_mode=qkv_norm_mode,
qkv_format="sbhd",
weight_args=weight_args,
operations=operations,
)
def forward(
self,
x: torch.Tensor,
context: Optional[torch.Tensor] = None,
crossattn_mask: Optional[torch.Tensor] = None,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Forward pass for video attention.
Args:
x (Tensor): Input tensor of shape (B, T, H, W, D) or (T, H, W, B, D) representing batches of video data.
context (Tensor): Context tensor of shape (B, M, D) or (M, B, D),
where M is the sequence length of the context.
crossattn_mask (Optional[Tensor]): An optional mask for cross-attention mechanisms.
rope_emb_L_1_1_D (Optional[Tensor]):
Rotary positional embedding tensor of shape (L, 1, 1, D). L == THW for current video training.
Returns:
Tensor: The output tensor with applied attention, maintaining the input shape.
"""
x_T_H_W_B_D = x
context_M_B_D = context
T, H, W, B, D = x_T_H_W_B_D.shape
x_THW_B_D = rearrange(x_T_H_W_B_D, "t h w b d -> (t h w) b d")
x_THW_B_D = self.attn(
x_THW_B_D,
context_M_B_D,
crossattn_mask,
rope_emb=rope_emb_L_1_1_D,
)
x_T_H_W_B_D = rearrange(x_THW_B_D, "(t h w) b d -> t h w b d", h=H, w=W)
return x_T_H_W_B_D
def adaln_norm_state(norm_state, x, scale, shift):
normalized = norm_state(x)
return normalized * (1 + scale) + shift
class DITBuildingBlock(nn.Module):
"""
A building block for the DiT (Diffusion Transformer) architecture that supports different types of
attention and MLP operations with adaptive layer normalization.
Parameters:
block_type (str): Type of block - one of:
- "cross_attn"/"ca": Cross-attention
- "full_attn"/"fa": Full self-attention
- "mlp"/"ff": MLP/feedforward block
x_dim (int): Dimension of input features
context_dim (Optional[int]): Dimension of context features for cross-attention
num_heads (int): Number of attention heads
mlp_ratio (float): MLP hidden dimension multiplier. Default: 4.0
bias (bool): Whether to use bias in layers. Default: False
mlp_dropout (float): Dropout rate for MLP. Default: 0.0
qkv_norm_mode (str): QKV normalization mode. Default: "per_head"
x_format (str): Input tensor format. Default: "BTHWD"
use_adaln_lora (bool): Whether to use AdaLN-LoRA. Default: False
adaln_lora_dim (int): Dimension for AdaLN-LoRA. Default: 256
"""
def __init__(
self,
block_type: str,
x_dim: int,
context_dim: Optional[int],
num_heads: int,
mlp_ratio: float = 4.0,
bias: bool = False,
mlp_dropout: float = 0.0,
qkv_norm_mode: str = "per_head",
x_format: str = "BTHWD",
use_adaln_lora: bool = False,
adaln_lora_dim: int = 256,
weight_args={},
operations=None
) -> None:
block_type = block_type.lower()
super().__init__()
self.x_format = x_format
if block_type in ["cross_attn", "ca"]:
self.block = VideoAttn(
x_dim,
context_dim,
num_heads,
bias=bias,
qkv_norm_mode=qkv_norm_mode,
x_format=self.x_format,
weight_args=weight_args,
operations=operations,
)
elif block_type in ["full_attn", "fa"]:
self.block = VideoAttn(
x_dim, None, num_heads, bias=bias, qkv_norm_mode=qkv_norm_mode, x_format=self.x_format, weight_args=weight_args, operations=operations
)
elif block_type in ["mlp", "ff"]:
self.block = GPT2FeedForward(x_dim, int(x_dim * mlp_ratio), dropout=mlp_dropout, bias=bias, weight_args=weight_args, operations=operations)
else:
raise ValueError(f"Unknown block type: {block_type}")
self.block_type = block_type
self.use_adaln_lora = use_adaln_lora
self.norm_state = nn.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6)
self.n_adaln_chunks = 3
if use_adaln_lora:
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(x_dim, adaln_lora_dim, bias=False, **weight_args),
operations.Linear(adaln_lora_dim, self.n_adaln_chunks * x_dim, bias=False, **weight_args),
)
else:
self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, self.n_adaln_chunks * x_dim, bias=False, **weight_args))
def forward(
self,
x: torch.Tensor,
emb_B_D: torch.Tensor,
crossattn_emb: torch.Tensor,
crossattn_mask: Optional[torch.Tensor] = None,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_3D: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Forward pass for dynamically configured blocks with adaptive normalization.
Args:
x (Tensor): Input tensor of shape (B, T, H, W, D) or (T, H, W, B, D).
emb_B_D (Tensor): Embedding tensor for adaptive layer normalization modulation.
crossattn_emb (Tensor): Tensor for cross-attention blocks.
crossattn_mask (Optional[Tensor]): Optional mask for cross-attention.
rope_emb_L_1_1_D (Optional[Tensor]):
Rotary positional embedding tensor of shape (L, 1, 1, D). L == THW for current video training.
Returns:
Tensor: The output tensor after processing through the configured block and adaptive normalization.
"""
if self.use_adaln_lora:
shift_B_D, scale_B_D, gate_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D).chunk(
self.n_adaln_chunks, dim=1
)
else:
shift_B_D, scale_B_D, gate_B_D = self.adaLN_modulation(emb_B_D).chunk(self.n_adaln_chunks, dim=1)
shift_1_1_1_B_D, scale_1_1_1_B_D, gate_1_1_1_B_D = (
shift_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0),
scale_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0),
gate_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0),
)
if self.block_type in ["mlp", "ff"]:
x = x + gate_1_1_1_B_D * self.block(
adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D),
)
elif self.block_type in ["full_attn", "fa"]:
x = x + gate_1_1_1_B_D * self.block(
adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D),
context=None,
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
)
elif self.block_type in ["cross_attn", "ca"]:
x = x + gate_1_1_1_B_D * self.block(
adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D),
context=crossattn_emb,
crossattn_mask=crossattn_mask,
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
)
else:
raise ValueError(f"Unknown block type: {self.block_type}")
return x
class GeneralDITTransformerBlock(nn.Module):
"""
A wrapper module that manages a sequence of DITBuildingBlocks to form a complete transformer layer.
Each block in the sequence is specified by a block configuration string.
Parameters:
x_dim (int): Dimension of input features
context_dim (int): Dimension of context features for cross-attention blocks
num_heads (int): Number of attention heads
block_config (str): String specifying block sequence (e.g. "ca-fa-mlp" for cross-attention,
full-attention, then MLP)
mlp_ratio (float): MLP hidden dimension multiplier. Default: 4.0
x_format (str): Input tensor format. Default: "BTHWD"
use_adaln_lora (bool): Whether to use AdaLN-LoRA. Default: False
adaln_lora_dim (int): Dimension for AdaLN-LoRA. Default: 256
The block_config string uses "-" to separate block types:
- "ca"/"cross_attn": Cross-attention block
- "fa"/"full_attn": Full self-attention block
- "mlp"/"ff": MLP/feedforward block
Example:
block_config = "ca-fa-mlp" creates a sequence of:
1. Cross-attention block
2. Full self-attention block
3. MLP block
"""
def __init__(
self,
x_dim: int,
context_dim: int,
num_heads: int,
block_config: str,
mlp_ratio: float = 4.0,
x_format: str = "BTHWD",
use_adaln_lora: bool = False,
adaln_lora_dim: int = 256,
weight_args={},
operations=None
):
super().__init__()
self.blocks = nn.ModuleList()
self.x_format = x_format
for block_type in block_config.split("-"):
self.blocks.append(
DITBuildingBlock(
block_type,
x_dim,
context_dim,
num_heads,
mlp_ratio,
x_format=self.x_format,
use_adaln_lora=use_adaln_lora,
adaln_lora_dim=adaln_lora_dim,
weight_args=weight_args,
operations=operations,
)
)
def forward(
self,
x: torch.Tensor,
emb_B_D: torch.Tensor,
crossattn_emb: torch.Tensor,
crossattn_mask: Optional[torch.Tensor] = None,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_3D: Optional[torch.Tensor] = None,
) -> torch.Tensor:
for block in self.blocks:
x = block(
x,
emb_B_D,
crossattn_emb,
crossattn_mask,
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
adaln_lora_B_3D=adaln_lora_B_3D,
)
return x

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,377 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The patcher and unpatcher implementation for 2D and 3D data.
The idea of Haar wavelet is to compute LL, LH, HL, HH component as two 1D convolutions.
One on the rows and one on the columns.
For example, in 1D signal, we have [a, b], then the low-freq compoenent is [a + b] / 2 and high-freq is [a - b] / 2.
We can use a 1D convolution with kernel [1, 1] and stride 2 to represent the L component.
For H component, we can use a 1D convolution with kernel [1, -1] and stride 2.
Although in principle, we typically only do additional Haar wavelet over the LL component. But here we do it for all
as we need to support downsampling for more than 2x.
For example, 4x downsampling can be done by 2x Haar and additional 2x Haar, and the shape would be.
[3, 256, 256] -> [12, 128, 128] -> [48, 64, 64]
"""
import torch
import torch.nn.functional as F
from einops import rearrange
_WAVELETS = {
"haar": torch.tensor([0.7071067811865476, 0.7071067811865476]),
"rearrange": torch.tensor([1.0, 1.0]),
}
_PERSISTENT = False
class Patcher(torch.nn.Module):
"""A module to convert image tensors into patches using torch operations.
The main difference from `class Patching` is that this module implements
all operations using torch, rather than python or numpy, for efficiency purpose.
It's bit-wise identical to the Patching module outputs, with the added
benefit of being torch.jit scriptable.
"""
def __init__(self, patch_size=1, patch_method="haar"):
super().__init__()
self.patch_size = patch_size
self.patch_method = patch_method
self.register_buffer(
"wavelets", _WAVELETS[patch_method], persistent=_PERSISTENT
)
self.range = range(int(torch.log2(torch.tensor(self.patch_size)).item()))
self.register_buffer(
"_arange",
torch.arange(_WAVELETS[patch_method].shape[0]),
persistent=_PERSISTENT,
)
for param in self.parameters():
param.requires_grad = False
def forward(self, x):
if self.patch_method == "haar":
return self._haar(x)
elif self.patch_method == "rearrange":
return self._arrange(x)
else:
raise ValueError("Unknown patch method: " + self.patch_method)
def _dwt(self, x, mode="reflect", rescale=False):
dtype = x.dtype
h = self.wavelets.to(device=x.device)
n = h.shape[0]
g = x.shape[1]
hl = h.flip(0).reshape(1, 1, -1).repeat(g, 1, 1)
hh = (h * ((-1) ** self._arange.to(device=x.device))).reshape(1, 1, -1).repeat(g, 1, 1)
hh = hh.to(dtype=dtype)
hl = hl.to(dtype=dtype)
x = F.pad(x, pad=(n - 2, n - 1, n - 2, n - 1), mode=mode).to(dtype)
xl = F.conv2d(x, hl.unsqueeze(2), groups=g, stride=(1, 2))
xh = F.conv2d(x, hh.unsqueeze(2), groups=g, stride=(1, 2))
xll = F.conv2d(xl, hl.unsqueeze(3), groups=g, stride=(2, 1))
xlh = F.conv2d(xl, hh.unsqueeze(3), groups=g, stride=(2, 1))
xhl = F.conv2d(xh, hl.unsqueeze(3), groups=g, stride=(2, 1))
xhh = F.conv2d(xh, hh.unsqueeze(3), groups=g, stride=(2, 1))
out = torch.cat([xll, xlh, xhl, xhh], dim=1)
if rescale:
out = out / 2
return out
def _haar(self, x):
for _ in self.range:
x = self._dwt(x, rescale=True)
return x
def _arrange(self, x):
x = rearrange(
x,
"b c (h p1) (w p2) -> b (c p1 p2) h w",
p1=self.patch_size,
p2=self.patch_size,
).contiguous()
return x
class Patcher3D(Patcher):
"""A 3D discrete wavelet transform for video data, expects 5D tensor, i.e. a batch of videos."""
def __init__(self, patch_size=1, patch_method="haar"):
super().__init__(patch_method=patch_method, patch_size=patch_size)
self.register_buffer(
"patch_size_buffer",
patch_size * torch.ones([1], dtype=torch.int32),
persistent=_PERSISTENT,
)
def _dwt(self, x, wavelet, mode="reflect", rescale=False):
dtype = x.dtype
h = self.wavelets.to(device=x.device)
n = h.shape[0]
g = x.shape[1]
hl = h.flip(0).reshape(1, 1, -1).repeat(g, 1, 1)
hh = (h * ((-1) ** self._arange.to(device=x.device))).reshape(1, 1, -1).repeat(g, 1, 1)
hh = hh.to(dtype=dtype)
hl = hl.to(dtype=dtype)
# Handles temporal axis.
x = F.pad(
x, pad=(max(0, n - 2), n - 1, n - 2, n - 1, n - 2, n - 1), mode=mode
).to(dtype)
xl = F.conv3d(x, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1))
xh = F.conv3d(x, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1))
# Handles spatial axes.
xll = F.conv3d(xl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
xlh = F.conv3d(xl, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
xhl = F.conv3d(xh, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
xhh = F.conv3d(xh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
xlll = F.conv3d(xll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
xllh = F.conv3d(xll, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
xlhl = F.conv3d(xlh, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
xlhh = F.conv3d(xlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
xhll = F.conv3d(xhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
xhlh = F.conv3d(xhl, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
xhhl = F.conv3d(xhh, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
xhhh = F.conv3d(xhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
out = torch.cat([xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh], dim=1)
if rescale:
out = out / (2 * torch.sqrt(torch.tensor(2.0)))
return out
def _haar(self, x):
xi, xv = torch.split(x, [1, x.shape[2] - 1], dim=2)
x = torch.cat([xi.repeat_interleave(self.patch_size, dim=2), xv], dim=2)
for _ in self.range:
x = self._dwt(x, "haar", rescale=True)
return x
def _arrange(self, x):
xi, xv = torch.split(x, [1, x.shape[2] - 1], dim=2)
x = torch.cat([xi.repeat_interleave(self.patch_size, dim=2), xv], dim=2)
x = rearrange(
x,
"b c (t p1) (h p2) (w p3) -> b (c p1 p2 p3) t h w",
p1=self.patch_size,
p2=self.patch_size,
p3=self.patch_size,
).contiguous()
return x
class UnPatcher(torch.nn.Module):
"""A module to convert patches into image tensorsusing torch operations.
The main difference from `class Unpatching` is that this module implements
all operations using torch, rather than python or numpy, for efficiency purpose.
It's bit-wise identical to the Unpatching module outputs, with the added
benefit of being torch.jit scriptable.
"""
def __init__(self, patch_size=1, patch_method="haar"):
super().__init__()
self.patch_size = patch_size
self.patch_method = patch_method
self.register_buffer(
"wavelets", _WAVELETS[patch_method], persistent=_PERSISTENT
)
self.range = range(int(torch.log2(torch.tensor(self.patch_size)).item()))
self.register_buffer(
"_arange",
torch.arange(_WAVELETS[patch_method].shape[0]),
persistent=_PERSISTENT,
)
for param in self.parameters():
param.requires_grad = False
def forward(self, x):
if self.patch_method == "haar":
return self._ihaar(x)
elif self.patch_method == "rearrange":
return self._iarrange(x)
else:
raise ValueError("Unknown patch method: " + self.patch_method)
def _idwt(self, x, wavelet="haar", mode="reflect", rescale=False):
dtype = x.dtype
h = self.wavelets.to(device=x.device)
n = h.shape[0]
g = x.shape[1] // 4
hl = h.flip([0]).reshape(1, 1, -1).repeat([g, 1, 1])
hh = (h * ((-1) ** self._arange.to(device=x.device))).reshape(1, 1, -1).repeat(g, 1, 1)
hh = hh.to(dtype=dtype)
hl = hl.to(dtype=dtype)
xll, xlh, xhl, xhh = torch.chunk(x.to(dtype), 4, dim=1)
# Inverse transform.
yl = torch.nn.functional.conv_transpose2d(
xll, hl.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)
)
yl += torch.nn.functional.conv_transpose2d(
xlh, hh.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)
)
yh = torch.nn.functional.conv_transpose2d(
xhl, hl.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)
)
yh += torch.nn.functional.conv_transpose2d(
xhh, hh.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)
)
y = torch.nn.functional.conv_transpose2d(
yl, hl.unsqueeze(2), groups=g, stride=(1, 2), padding=(0, n - 2)
)
y += torch.nn.functional.conv_transpose2d(
yh, hh.unsqueeze(2), groups=g, stride=(1, 2), padding=(0, n - 2)
)
if rescale:
y = y * 2
return y
def _ihaar(self, x):
for _ in self.range:
x = self._idwt(x, "haar", rescale=True)
return x
def _iarrange(self, x):
x = rearrange(
x,
"b (c p1 p2) h w -> b c (h p1) (w p2)",
p1=self.patch_size,
p2=self.patch_size,
)
return x
class UnPatcher3D(UnPatcher):
"""A 3D inverse discrete wavelet transform for video wavelet decompositions."""
def __init__(self, patch_size=1, patch_method="haar"):
super().__init__(patch_method=patch_method, patch_size=patch_size)
def _idwt(self, x, wavelet="haar", mode="reflect", rescale=False):
dtype = x.dtype
h = self.wavelets.to(device=x.device)
g = x.shape[1] // 8 # split into 8 spatio-temporal filtered tesnors.
hl = h.flip([0]).reshape(1, 1, -1).repeat([g, 1, 1])
hh = (h * ((-1) ** self._arange.to(device=x.device))).reshape(1, 1, -1).repeat(g, 1, 1)
hl = hl.to(dtype=dtype)
hh = hh.to(dtype=dtype)
xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh = torch.chunk(x, 8, dim=1)
del x
# Height height transposed convolutions.
xll = F.conv_transpose3d(
xlll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
del xlll
xll += F.conv_transpose3d(
xllh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
del xllh
xlh = F.conv_transpose3d(
xlhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
del xlhl
xlh += F.conv_transpose3d(
xlhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
del xlhh
xhl = F.conv_transpose3d(
xhll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
del xhll
xhl += F.conv_transpose3d(
xhlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
del xhlh
xhh = F.conv_transpose3d(
xhhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
del xhhl
xhh += F.conv_transpose3d(
xhhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
del xhhh
# Handles width transposed convolutions.
xl = F.conv_transpose3d(
xll, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
)
del xll
xl += F.conv_transpose3d(
xlh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
)
del xlh
xh = F.conv_transpose3d(
xhl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
)
del xhl
xh += F.conv_transpose3d(
xhh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
)
del xhh
# Handles time axis transposed convolutions.
x = F.conv_transpose3d(
xl, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)
)
del xl
x += F.conv_transpose3d(
xh, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)
)
if rescale:
x = x * (2 * torch.sqrt(torch.tensor(2.0)))
return x
def _ihaar(self, x):
for _ in self.range:
x = self._idwt(x, "haar", rescale=True)
x = x[:, :, self.patch_size - 1 :, ...]
return x
def _iarrange(self, x):
x = rearrange(
x,
"b (c p1 p2 p3) t h w -> b c (t p1) (h p2) (w p3)",
p1=self.patch_size,
p2=self.patch_size,
p3=self.patch_size,
)
x = x[:, :, self.patch_size - 1 :, ...]
return x

View File

@@ -0,0 +1,112 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared utilities for the networks module."""
from typing import Any
import torch
from einops import rearrange
import comfy.ops
ops = comfy.ops.disable_weight_init
def time2batch(x: torch.Tensor) -> tuple[torch.Tensor, int]:
batch_size = x.shape[0]
return rearrange(x, "b c t h w -> (b t) c h w"), batch_size
def batch2time(x: torch.Tensor, batch_size: int) -> torch.Tensor:
return rearrange(x, "(b t) c h w -> b c t h w", b=batch_size)
def space2batch(x: torch.Tensor) -> tuple[torch.Tensor, int]:
batch_size, height = x.shape[0], x.shape[-2]
return rearrange(x, "b c t h w -> (b h w) c t"), batch_size, height
def batch2space(x: torch.Tensor, batch_size: int, height: int) -> torch.Tensor:
return rearrange(x, "(b h w) c t -> b c t h w", b=batch_size, h=height)
def cast_tuple(t: Any, length: int = 1) -> Any:
return t if isinstance(t, tuple) else ((t,) * length)
def replication_pad(x):
return torch.cat([x[:, :, :1, ...], x], dim=2)
def divisible_by(num: int, den: int) -> bool:
return (num % den) == 0
def is_odd(n: int) -> bool:
return not divisible_by(n, 2)
def nonlinearity(x):
return x * torch.sigmoid(x)
def Normalize(in_channels, num_groups=32):
return ops.GroupNorm(
num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
)
class CausalNormalize(torch.nn.Module):
def __init__(self, in_channels, num_groups=1):
super().__init__()
self.norm = ops.GroupNorm(
num_groups=num_groups,
num_channels=in_channels,
eps=1e-6,
affine=True,
)
self.num_groups = num_groups
def forward(self, x):
# if num_groups !=1, we apply a spatio-temporal groupnorm for backward compatibility purpose.
# All new models should use num_groups=1, otherwise causality is not guaranteed.
if self.num_groups == 1:
x, batch_size = time2batch(x)
return batch2time(self.norm(x), batch_size)
return self.norm(x)
def exists(v):
return v is not None
def default(*args):
for arg in args:
if exists(arg):
return arg
return None
def round_ste(z: torch.Tensor) -> torch.Tensor:
"""Round with straight through gradients."""
zhat = z.round()
return z + (zhat - z).detach()
def log(t, eps=1e-5):
return t.clamp(min=eps).log()
def entropy(prob):
return (-prob * log(prob)).sum(dim=-1)

514
comfy/ldm/cosmos/model.py Normal file
View File

@@ -0,0 +1,514 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing.
"""
from typing import Optional, Tuple
import torch
from einops import rearrange
from torch import nn
from torchvision import transforms
from enum import Enum
import logging
from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
from .blocks import (
FinalLayer,
GeneralDITTransformerBlock,
PatchEmbed,
TimestepEmbedding,
Timesteps,
)
from .position_embedding import LearnablePosEmbAxis, VideoRopePosition3DEmb
class DataType(Enum):
IMAGE = "image"
VIDEO = "video"
class GeneralDIT(nn.Module):
"""
A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing.
Args:
max_img_h (int): Maximum height of the input images.
max_img_w (int): Maximum width of the input images.
max_frames (int): Maximum number of frames in the video sequence.
in_channels (int): Number of input channels (e.g., RGB channels for color images).
out_channels (int): Number of output channels.
patch_spatial (tuple): Spatial resolution of patches for input processing.
patch_temporal (int): Temporal resolution of patches for input processing.
concat_padding_mask (bool): If True, includes a mask channel in the input to handle padding.
block_config (str): Configuration of the transformer block. See Notes for supported block types.
model_channels (int): Base number of channels used throughout the model.
num_blocks (int): Number of transformer blocks.
num_heads (int): Number of heads in the multi-head attention layers.
mlp_ratio (float): Expansion ratio for MLP blocks.
block_x_format (str): Format of input tensor for transformer blocks ('BTHWD' or 'THWBD').
crossattn_emb_channels (int): Number of embedding channels for cross-attention.
use_cross_attn_mask (bool): Whether to use mask in cross-attention.
pos_emb_cls (str): Type of positional embeddings.
pos_emb_learnable (bool): Whether positional embeddings are learnable.
pos_emb_interpolation (str): Method for interpolating positional embeddings.
affline_emb_norm (bool): Whether to normalize affine embeddings.
use_adaln_lora (bool): Whether to use AdaLN-LoRA.
adaln_lora_dim (int): Dimension for AdaLN-LoRA.
rope_h_extrapolation_ratio (float): Height extrapolation ratio for RoPE.
rope_w_extrapolation_ratio (float): Width extrapolation ratio for RoPE.
rope_t_extrapolation_ratio (float): Temporal extrapolation ratio for RoPE.
extra_per_block_abs_pos_emb (bool): Whether to use extra per-block absolute positional embeddings.
extra_per_block_abs_pos_emb_type (str): Type of extra per-block positional embeddings.
extra_h_extrapolation_ratio (float): Height extrapolation ratio for extra embeddings.
extra_w_extrapolation_ratio (float): Width extrapolation ratio for extra embeddings.
extra_t_extrapolation_ratio (float): Temporal extrapolation ratio for extra embeddings.
Notes:
Supported block types in block_config:
* cross_attn, ca: Cross attention
* full_attn: Full attention on all flattened tokens
* mlp, ff: Feed forward block
"""
def __init__(
self,
max_img_h: int,
max_img_w: int,
max_frames: int,
in_channels: int,
out_channels: int,
patch_spatial: tuple,
patch_temporal: int,
concat_padding_mask: bool = True,
# attention settings
block_config: str = "FA-CA-MLP",
model_channels: int = 768,
num_blocks: int = 10,
num_heads: int = 16,
mlp_ratio: float = 4.0,
block_x_format: str = "BTHWD",
# cross attention settings
crossattn_emb_channels: int = 1024,
use_cross_attn_mask: bool = False,
# positional embedding settings
pos_emb_cls: str = "sincos",
pos_emb_learnable: bool = False,
pos_emb_interpolation: str = "crop",
affline_emb_norm: bool = False, # whether or not to normalize the affine embedding
use_adaln_lora: bool = False,
adaln_lora_dim: int = 256,
rope_h_extrapolation_ratio: float = 1.0,
rope_w_extrapolation_ratio: float = 1.0,
rope_t_extrapolation_ratio: float = 1.0,
extra_per_block_abs_pos_emb: bool = False,
extra_per_block_abs_pos_emb_type: str = "sincos",
extra_h_extrapolation_ratio: float = 1.0,
extra_w_extrapolation_ratio: float = 1.0,
extra_t_extrapolation_ratio: float = 1.0,
image_model=None,
device=None,
dtype=None,
operations=None,
) -> None:
super().__init__()
self.max_img_h = max_img_h
self.max_img_w = max_img_w
self.max_frames = max_frames
self.in_channels = in_channels
self.out_channels = out_channels
self.patch_spatial = patch_spatial
self.patch_temporal = patch_temporal
self.num_heads = num_heads
self.num_blocks = num_blocks
self.model_channels = model_channels
self.use_cross_attn_mask = use_cross_attn_mask
self.concat_padding_mask = concat_padding_mask
# positional embedding settings
self.pos_emb_cls = pos_emb_cls
self.pos_emb_learnable = pos_emb_learnable
self.pos_emb_interpolation = pos_emb_interpolation
self.affline_emb_norm = affline_emb_norm
self.rope_h_extrapolation_ratio = rope_h_extrapolation_ratio
self.rope_w_extrapolation_ratio = rope_w_extrapolation_ratio
self.rope_t_extrapolation_ratio = rope_t_extrapolation_ratio
self.extra_per_block_abs_pos_emb = extra_per_block_abs_pos_emb
self.extra_per_block_abs_pos_emb_type = extra_per_block_abs_pos_emb_type.lower()
self.extra_h_extrapolation_ratio = extra_h_extrapolation_ratio
self.extra_w_extrapolation_ratio = extra_w_extrapolation_ratio
self.extra_t_extrapolation_ratio = extra_t_extrapolation_ratio
self.dtype = dtype
weight_args = {"device": device, "dtype": dtype}
in_channels = in_channels + 1 if concat_padding_mask else in_channels
self.x_embedder = PatchEmbed(
spatial_patch_size=patch_spatial,
temporal_patch_size=patch_temporal,
in_channels=in_channels,
out_channels=model_channels,
bias=False,
weight_args=weight_args,
operations=operations,
)
self.build_pos_embed(device=device, dtype=dtype)
self.block_x_format = block_x_format
self.use_adaln_lora = use_adaln_lora
self.adaln_lora_dim = adaln_lora_dim
self.t_embedder = nn.ModuleList(
[Timesteps(model_channels),
TimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora, weight_args=weight_args, operations=operations),]
)
self.blocks = nn.ModuleDict()
for idx in range(num_blocks):
self.blocks[f"block{idx}"] = GeneralDITTransformerBlock(
x_dim=model_channels,
context_dim=crossattn_emb_channels,
num_heads=num_heads,
block_config=block_config,
mlp_ratio=mlp_ratio,
x_format=self.block_x_format,
use_adaln_lora=use_adaln_lora,
adaln_lora_dim=adaln_lora_dim,
weight_args=weight_args,
operations=operations,
)
if self.affline_emb_norm:
logging.debug("Building affine embedding normalization layer")
self.affline_norm = RMSNorm(model_channels, elementwise_affine=True, eps=1e-6)
else:
self.affline_norm = nn.Identity()
self.final_layer = FinalLayer(
hidden_size=self.model_channels,
spatial_patch_size=self.patch_spatial,
temporal_patch_size=self.patch_temporal,
out_channels=self.out_channels,
use_adaln_lora=self.use_adaln_lora,
adaln_lora_dim=self.adaln_lora_dim,
weight_args=weight_args,
operations=operations,
)
def build_pos_embed(self, device=None, dtype=None):
if self.pos_emb_cls == "rope3d":
cls_type = VideoRopePosition3DEmb
else:
raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}")
logging.debug(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}")
kwargs = dict(
model_channels=self.model_channels,
len_h=self.max_img_h // self.patch_spatial,
len_w=self.max_img_w // self.patch_spatial,
len_t=self.max_frames // self.patch_temporal,
is_learnable=self.pos_emb_learnable,
interpolation=self.pos_emb_interpolation,
head_dim=self.model_channels // self.num_heads,
h_extrapolation_ratio=self.rope_h_extrapolation_ratio,
w_extrapolation_ratio=self.rope_w_extrapolation_ratio,
t_extrapolation_ratio=self.rope_t_extrapolation_ratio,
device=device,
)
self.pos_embedder = cls_type(
**kwargs,
)
if self.extra_per_block_abs_pos_emb:
assert self.extra_per_block_abs_pos_emb_type in [
"learnable",
], f"Unknown extra_per_block_abs_pos_emb_type {self.extra_per_block_abs_pos_emb_type}"
kwargs["h_extrapolation_ratio"] = self.extra_h_extrapolation_ratio
kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio
kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio
kwargs["device"] = device
kwargs["dtype"] = dtype
self.extra_pos_embedder = LearnablePosEmbAxis(
**kwargs,
)
def prepare_embedded_sequence(
self,
x_B_C_T_H_W: torch.Tensor,
fps: Optional[torch.Tensor] = None,
padding_mask: Optional[torch.Tensor] = None,
latent_condition: Optional[torch.Tensor] = None,
latent_condition_sigma: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Prepares an embedded sequence tensor by applying positional embeddings and handling padding masks.
Args:
x_B_C_T_H_W (torch.Tensor): video
fps (Optional[torch.Tensor]): Frames per second tensor to be used for positional embedding when required.
If None, a default value (`self.base_fps`) will be used.
padding_mask (Optional[torch.Tensor]): current it is not used
Returns:
Tuple[torch.Tensor, Optional[torch.Tensor]]:
- A tensor of shape (B, T, H, W, D) with the embedded sequence.
- An optional positional embedding tensor, returned only if the positional embedding class
(`self.pos_emb_cls`) includes 'rope'. Otherwise, None.
Notes:
- If `self.concat_padding_mask` is True, a padding mask channel is concatenated to the input tensor.
- The method of applying positional embeddings depends on the value of `self.pos_emb_cls`.
- If 'rope' is in `self.pos_emb_cls` (case insensitive), the positional embeddings are generated using
the `self.pos_embedder` with the shape [T, H, W].
- If "fps_aware" is in `self.pos_emb_cls`, the positional embeddings are generated using the
`self.pos_embedder` with the fps tensor.
- Otherwise, the positional embeddings are generated without considering fps.
"""
if self.concat_padding_mask:
if padding_mask is not None:
padding_mask = transforms.functional.resize(
padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
)
else:
padding_mask = torch.zeros((x_B_C_T_H_W.shape[0], 1, x_B_C_T_H_W.shape[-2], x_B_C_T_H_W.shape[-1]), dtype=x_B_C_T_H_W.dtype, device=x_B_C_T_H_W.device)
x_B_C_T_H_W = torch.cat(
[x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1
)
x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W)
if self.extra_per_block_abs_pos_emb:
extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device, dtype=x_B_C_T_H_W.dtype)
else:
extra_pos_emb = None
if "rope" in self.pos_emb_cls.lower():
return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device), extra_pos_emb
if "fps_aware" in self.pos_emb_cls:
x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device) # [B, T, H, W, D]
else:
x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, device=x_B_C_T_H_W.device) # [B, T, H, W, D]
return x_B_T_H_W_D, None, extra_pos_emb
def decoder_head(
self,
x_B_T_H_W_D: torch.Tensor,
emb_B_D: torch.Tensor,
crossattn_emb: torch.Tensor,
origin_shape: Tuple[int, int, int, int, int], # [B, C, T, H, W]
crossattn_mask: Optional[torch.Tensor] = None,
adaln_lora_B_3D: Optional[torch.Tensor] = None,
) -> torch.Tensor:
del crossattn_emb, crossattn_mask
B, C, T_before_patchify, H_before_patchify, W_before_patchify = origin_shape
x_BT_HW_D = rearrange(x_B_T_H_W_D, "B T H W D -> (B T) (H W) D")
x_BT_HW_D = self.final_layer(x_BT_HW_D, emb_B_D, adaln_lora_B_3D=adaln_lora_B_3D)
# This is to ensure x_BT_HW_D has the correct shape because
# when we merge T, H, W into one dimension, x_BT_HW_D has shape (B * T * H * W, 1*1, D).
x_BT_HW_D = x_BT_HW_D.view(
B * T_before_patchify // self.patch_temporal,
H_before_patchify // self.patch_spatial * W_before_patchify // self.patch_spatial,
-1,
)
x_B_D_T_H_W = rearrange(
x_BT_HW_D,
"(B T) (H W) (p1 p2 t C) -> B C (T t) (H p1) (W p2)",
p1=self.patch_spatial,
p2=self.patch_spatial,
H=H_before_patchify // self.patch_spatial,
W=W_before_patchify // self.patch_spatial,
t=self.patch_temporal,
B=B,
)
return x_B_D_T_H_W
def forward_before_blocks(
self,
x: torch.Tensor,
timesteps: torch.Tensor,
crossattn_emb: torch.Tensor,
crossattn_mask: Optional[torch.Tensor] = None,
fps: Optional[torch.Tensor] = None,
image_size: Optional[torch.Tensor] = None,
padding_mask: Optional[torch.Tensor] = None,
scalar_feature: Optional[torch.Tensor] = None,
data_type: Optional[DataType] = DataType.VIDEO,
latent_condition: Optional[torch.Tensor] = None,
latent_condition_sigma: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
"""
Args:
x: (B, C, T, H, W) tensor of spatial-temp inputs
timesteps: (B, ) tensor of timesteps
crossattn_emb: (B, N, D) tensor of cross-attention embeddings
crossattn_mask: (B, N) tensor of cross-attention masks
"""
del kwargs
assert isinstance(
data_type, DataType
), f"Expected DataType, got {type(data_type)}. We need discuss this flag later."
original_shape = x.shape
x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence(
x,
fps=fps,
padding_mask=padding_mask,
latent_condition=latent_condition,
latent_condition_sigma=latent_condition_sigma,
)
# logging affline scale information
affline_scale_log_info = {}
timesteps_B_D, adaln_lora_B_3D = self.t_embedder[1](self.t_embedder[0](timesteps.flatten()).to(x.dtype))
affline_emb_B_D = timesteps_B_D
affline_scale_log_info["timesteps_B_D"] = timesteps_B_D.detach()
if scalar_feature is not None:
raise NotImplementedError("Scalar feature is not implemented yet.")
affline_scale_log_info["affline_emb_B_D"] = affline_emb_B_D.detach()
affline_emb_B_D = self.affline_norm(affline_emb_B_D)
if self.use_cross_attn_mask:
if crossattn_mask is not None and not torch.is_floating_point(crossattn_mask):
crossattn_mask = (crossattn_mask - 1).to(x.dtype) * torch.finfo(x.dtype).max
crossattn_mask = crossattn_mask[:, None, None, :] # .to(dtype=torch.bool) # [B, 1, 1, length]
else:
crossattn_mask = None
if self.blocks["block0"].x_format == "THWBD":
x = rearrange(x_B_T_H_W_D, "B T H W D -> T H W B D")
if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = rearrange(
extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "B T H W D -> T H W B D"
)
crossattn_emb = rearrange(crossattn_emb, "B M D -> M B D")
if crossattn_mask:
crossattn_mask = rearrange(crossattn_mask, "B M -> M B")
elif self.blocks["block0"].x_format == "BTHWD":
x = x_B_T_H_W_D
else:
raise ValueError(f"Unknown x_format {self.blocks[0].x_format}")
output = {
"x": x,
"affline_emb_B_D": affline_emb_B_D,
"crossattn_emb": crossattn_emb,
"crossattn_mask": crossattn_mask,
"rope_emb_L_1_1_D": rope_emb_L_1_1_D,
"adaln_lora_B_3D": adaln_lora_B_3D,
"original_shape": original_shape,
"extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
}
return output
def forward(
self,
x: torch.Tensor,
timesteps: torch.Tensor,
context: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
# crossattn_emb: torch.Tensor,
# crossattn_mask: Optional[torch.Tensor] = None,
fps: Optional[torch.Tensor] = None,
image_size: Optional[torch.Tensor] = None,
padding_mask: Optional[torch.Tensor] = None,
scalar_feature: Optional[torch.Tensor] = None,
data_type: Optional[DataType] = DataType.VIDEO,
latent_condition: Optional[torch.Tensor] = None,
latent_condition_sigma: Optional[torch.Tensor] = None,
condition_video_augment_sigma: Optional[torch.Tensor] = None,
**kwargs,
):
"""
Args:
x: (B, C, T, H, W) tensor of spatial-temp inputs
timesteps: (B, ) tensor of timesteps
crossattn_emb: (B, N, D) tensor of cross-attention embeddings
crossattn_mask: (B, N) tensor of cross-attention masks
condition_video_augment_sigma: (B,) used in lvg(long video generation), we add noise with this sigma to
augment condition input, the lvg model will condition on the condition_video_augment_sigma value;
we need forward_before_blocks pass to the forward_before_blocks function.
"""
crossattn_emb = context
crossattn_mask = attention_mask
inputs = self.forward_before_blocks(
x=x,
timesteps=timesteps,
crossattn_emb=crossattn_emb,
crossattn_mask=crossattn_mask,
fps=fps,
image_size=image_size,
padding_mask=padding_mask,
scalar_feature=scalar_feature,
data_type=data_type,
latent_condition=latent_condition,
latent_condition_sigma=latent_condition_sigma,
condition_video_augment_sigma=condition_video_augment_sigma,
**kwargs,
)
x, affline_emb_B_D, crossattn_emb, crossattn_mask, rope_emb_L_1_1_D, adaln_lora_B_3D, original_shape = (
inputs["x"],
inputs["affline_emb_B_D"],
inputs["crossattn_emb"],
inputs["crossattn_mask"],
inputs["rope_emb_L_1_1_D"],
inputs["adaln_lora_B_3D"],
inputs["original_shape"],
)
extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = inputs["extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D"].to(x.dtype)
del inputs
if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
assert (
x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
), f"{x.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape} {original_shape}"
for _, block in self.blocks.items():
assert (
self.blocks["block0"].x_format == block.x_format
), f"First block has x_format {self.blocks[0].x_format}, got {block.x_format}"
if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
x += extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D
x = block(
x,
affline_emb_B_D,
crossattn_emb,
crossattn_mask,
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
adaln_lora_B_3D=adaln_lora_B_3D,
)
x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D")
x_B_D_T_H_W = self.decoder_head(
x_B_T_H_W_D=x_B_T_H_W_D,
emb_B_D=affline_emb_B_D,
crossattn_emb=None,
origin_shape=original_shape,
crossattn_mask=None,
adaln_lora_B_3D=adaln_lora_B_3D,
)
return x_B_D_T_H_W

View File

@@ -0,0 +1,208 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional
import torch
from einops import rearrange, repeat
from torch import nn
import math
def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 0) -> torch.Tensor:
"""
Normalizes the input tensor along specified dimensions such that the average square norm of elements is adjusted.
Args:
x (torch.Tensor): The input tensor to normalize.
dim (list, optional): The dimensions over which to normalize. If None, normalizes over all dimensions except the first.
eps (float, optional): A small constant to ensure numerical stability during division.
Returns:
torch.Tensor: The normalized tensor.
"""
if dim is None:
dim = list(range(1, x.ndim))
norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
norm = torch.add(eps, norm, alpha=math.sqrt(norm.numel() / x.numel()))
return x / norm.to(x.dtype)
class VideoPositionEmb(nn.Module):
def forward(self, x_B_T_H_W_C: torch.Tensor, fps=Optional[torch.Tensor], device=None, dtype=None) -> torch.Tensor:
"""
It delegates the embedding generation to generate_embeddings function.
"""
B_T_H_W_C = x_B_T_H_W_C.shape
embeddings = self.generate_embeddings(B_T_H_W_C, fps=fps, device=device, dtype=dtype)
return embeddings
def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor], device=None):
raise NotImplementedError
class VideoRopePosition3DEmb(VideoPositionEmb):
def __init__(
self,
*, # enforce keyword arguments
head_dim: int,
len_h: int,
len_w: int,
len_t: int,
base_fps: int = 24,
h_extrapolation_ratio: float = 1.0,
w_extrapolation_ratio: float = 1.0,
t_extrapolation_ratio: float = 1.0,
device=None,
**kwargs, # used for compatibility with other positional embeddings; unused in this class
):
del kwargs
super().__init__()
self.register_buffer("seq", torch.arange(max(len_h, len_w, len_t), dtype=torch.float, device=device))
self.base_fps = base_fps
self.max_h = len_h
self.max_w = len_w
dim = head_dim
dim_h = dim // 6 * 2
dim_w = dim_h
dim_t = dim - 2 * dim_h
assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}"
self.register_buffer(
"dim_spatial_range",
torch.arange(0, dim_h, 2, device=device)[: (dim_h // 2)].float() / dim_h,
persistent=False,
)
self.register_buffer(
"dim_temporal_range",
torch.arange(0, dim_t, 2, device=device)[: (dim_t // 2)].float() / dim_t,
persistent=False,
)
self.h_ntk_factor = h_extrapolation_ratio ** (dim_h / (dim_h - 2))
self.w_ntk_factor = w_extrapolation_ratio ** (dim_w / (dim_w - 2))
self.t_ntk_factor = t_extrapolation_ratio ** (dim_t / (dim_t - 2))
def generate_embeddings(
self,
B_T_H_W_C: torch.Size,
fps: Optional[torch.Tensor] = None,
h_ntk_factor: Optional[float] = None,
w_ntk_factor: Optional[float] = None,
t_ntk_factor: Optional[float] = None,
device=None,
dtype=None,
):
"""
Generate embeddings for the given input size.
Args:
B_T_H_W_C (torch.Size): Input tensor size (Batch, Time, Height, Width, Channels).
fps (Optional[torch.Tensor], optional): Frames per second. Defaults to None.
h_ntk_factor (Optional[float], optional): Height NTK factor. If None, uses self.h_ntk_factor.
w_ntk_factor (Optional[float], optional): Width NTK factor. If None, uses self.w_ntk_factor.
t_ntk_factor (Optional[float], optional): Time NTK factor. If None, uses self.t_ntk_factor.
Returns:
Not specified in the original code snippet.
"""
h_ntk_factor = h_ntk_factor if h_ntk_factor is not None else self.h_ntk_factor
w_ntk_factor = w_ntk_factor if w_ntk_factor is not None else self.w_ntk_factor
t_ntk_factor = t_ntk_factor if t_ntk_factor is not None else self.t_ntk_factor
h_theta = 10000.0 * h_ntk_factor
w_theta = 10000.0 * w_ntk_factor
t_theta = 10000.0 * t_ntk_factor
h_spatial_freqs = 1.0 / (h_theta**self.dim_spatial_range.to(device=device))
w_spatial_freqs = 1.0 / (w_theta**self.dim_spatial_range.to(device=device))
temporal_freqs = 1.0 / (t_theta**self.dim_temporal_range.to(device=device))
B, T, H, W, _ = B_T_H_W_C
uniform_fps = (fps is None) or isinstance(fps, (int, float)) or (fps.min() == fps.max())
assert (
uniform_fps or B == 1 or T == 1
), "For video batch, batch size should be 1 for non-uniform fps. For image batch, T should be 1"
assert (
H <= self.max_h and W <= self.max_w
), f"Input dimensions (H={H}, W={W}) exceed the maximum dimensions (max_h={self.max_h}, max_w={self.max_w})"
half_emb_h = torch.outer(self.seq[:H].to(device=device), h_spatial_freqs)
half_emb_w = torch.outer(self.seq[:W].to(device=device), w_spatial_freqs)
# apply sequence scaling in temporal dimension
if fps is None: # image case
half_emb_t = torch.outer(self.seq[:T].to(device=device), temporal_freqs)
else:
half_emb_t = torch.outer(self.seq[:T].to(device=device) / fps * self.base_fps, temporal_freqs)
half_emb_h = torch.stack([torch.cos(half_emb_h), -torch.sin(half_emb_h), torch.sin(half_emb_h), torch.cos(half_emb_h)], dim=-1)
half_emb_w = torch.stack([torch.cos(half_emb_w), -torch.sin(half_emb_w), torch.sin(half_emb_w), torch.cos(half_emb_w)], dim=-1)
half_emb_t = torch.stack([torch.cos(half_emb_t), -torch.sin(half_emb_t), torch.sin(half_emb_t), torch.cos(half_emb_t)], dim=-1)
em_T_H_W_D = torch.cat(
[
repeat(half_emb_t, "t d x -> t h w d x", h=H, w=W),
repeat(half_emb_h, "h d x -> t h w d x", t=T, w=W),
repeat(half_emb_w, "w d x -> t h w d x", t=T, h=H),
]
, dim=-2,
)
return rearrange(em_T_H_W_D, "t h w d (i j) -> (t h w) d i j", i=2, j=2).float()
class LearnablePosEmbAxis(VideoPositionEmb):
def __init__(
self,
*, # enforce keyword arguments
interpolation: str,
model_channels: int,
len_h: int,
len_w: int,
len_t: int,
device=None,
dtype=None,
**kwargs,
):
"""
Args:
interpolation (str): we curretly only support "crop", ideally when we need extrapolation capacity, we should adjust frequency or other more advanced methods. they are not implemented yet.
"""
del kwargs # unused
super().__init__()
self.interpolation = interpolation
assert self.interpolation in ["crop"], f"Unknown interpolation method {self.interpolation}"
self.pos_emb_h = nn.Parameter(torch.empty(len_h, model_channels, device=device, dtype=dtype))
self.pos_emb_w = nn.Parameter(torch.empty(len_w, model_channels, device=device, dtype=dtype))
self.pos_emb_t = nn.Parameter(torch.empty(len_t, model_channels, device=device, dtype=dtype))
def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor], device=None, dtype=None) -> torch.Tensor:
B, T, H, W, _ = B_T_H_W_C
if self.interpolation == "crop":
emb_h_H = self.pos_emb_h[:H].to(device=device, dtype=dtype)
emb_w_W = self.pos_emb_w[:W].to(device=device, dtype=dtype)
emb_t_T = self.pos_emb_t[:T].to(device=device, dtype=dtype)
emb = (
repeat(emb_t_T, "t d-> b t h w d", b=B, h=H, w=W)
+ repeat(emb_h_H, "h d-> b t h w d", b=B, t=T, w=W)
+ repeat(emb_w_W, "w d-> b t h w d", b=B, t=T, h=H)
)
assert list(emb.shape)[:4] == [B, T, H, W], f"bad shape: {list(emb.shape)[:4]} != {B, T, H, W}"
else:
raise ValueError(f"Unknown interpolation method {self.interpolation}")
return normalize(emb, dim=-1, eps=1e-6)

131
comfy/ldm/cosmos/vae.py Normal file
View File

@@ -0,0 +1,131 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The causal continuous video tokenizer with VAE or AE formulation for 3D data.."""
import logging
import torch
from torch import nn
from enum import Enum
import math
from .cosmos_tokenizer.layers3d import (
EncoderFactorized,
DecoderFactorized,
CausalConv3d,
)
class IdentityDistribution(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, parameters):
return parameters, (torch.tensor([0.0]), torch.tensor([0.0]))
class GaussianDistribution(torch.nn.Module):
def __init__(self, min_logvar: float = -30.0, max_logvar: float = 20.0):
super().__init__()
self.min_logvar = min_logvar
self.max_logvar = max_logvar
def sample(self, mean, logvar):
std = torch.exp(0.5 * logvar)
return mean + std * torch.randn_like(mean)
def forward(self, parameters):
mean, logvar = torch.chunk(parameters, 2, dim=1)
logvar = torch.clamp(logvar, self.min_logvar, self.max_logvar)
return self.sample(mean, logvar), (mean, logvar)
class ContinuousFormulation(Enum):
VAE = GaussianDistribution
AE = IdentityDistribution
class CausalContinuousVideoTokenizer(nn.Module):
def __init__(
self, z_channels: int, z_factor: int, latent_channels: int, **kwargs
) -> None:
super().__init__()
self.name = kwargs.get("name", "CausalContinuousVideoTokenizer")
self.latent_channels = latent_channels
self.sigma_data = 0.5
# encoder_name = kwargs.get("encoder", Encoder3DType.BASE.name)
self.encoder = EncoderFactorized(
z_channels=z_factor * z_channels, **kwargs
)
if kwargs.get("temporal_compression", 4) == 4:
kwargs["channels_mult"] = [2, 4]
# decoder_name = kwargs.get("decoder", Decoder3DType.BASE.name)
self.decoder = DecoderFactorized(
z_channels=z_channels, **kwargs
)
self.quant_conv = CausalConv3d(
z_factor * z_channels,
z_factor * latent_channels,
kernel_size=1,
padding=0,
)
self.post_quant_conv = CausalConv3d(
latent_channels, z_channels, kernel_size=1, padding=0
)
# formulation_name = kwargs.get("formulation", ContinuousFormulation.AE.name)
self.distribution = IdentityDistribution() # ContinuousFormulation[formulation_name].value()
num_parameters = sum(param.numel() for param in self.parameters())
logging.debug(f"model={self.name}, num_parameters={num_parameters:,}")
logging.debug(
f"z_channels={z_channels}, latent_channels={self.latent_channels}."
)
latent_temporal_chunk = 16
self.latent_mean = nn.Parameter(torch.zeros([self.latent_channels * latent_temporal_chunk], dtype=torch.float32))
self.latent_std = nn.Parameter(torch.ones([self.latent_channels * latent_temporal_chunk], dtype=torch.float32))
def encode(self, x):
h = self.encoder(x)
moments = self.quant_conv(h)
z, posteriors = self.distribution(moments)
latent_ch = z.shape[1]
latent_t = z.shape[2]
in_dtype = z.dtype
mean = self.latent_mean.view(latent_ch, -1)
std = self.latent_std.view(latent_ch, -1)
mean = mean.repeat(1, math.ceil(latent_t / mean.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
std = std.repeat(1, math.ceil(latent_t / std.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
return ((z - mean) / std) * self.sigma_data
def decode(self, z):
in_dtype = z.dtype
latent_ch = z.shape[1]
latent_t = z.shape[2]
mean = self.latent_mean.view(latent_ch, -1)
std = self.latent_std.view(latent_ch, -1)
mean = mean.repeat(1, math.ceil(latent_t / mean.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
std = std.repeat(1, math.ceil(latent_t / std.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
z = z / self.sigma_data
z = z * std + mean
z = self.post_quant_conv(z)
return self.decoder(z)

View File

@@ -230,8 +230,7 @@ class SingleStreamBlock(nn.Module):
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None) -> Tensor:
mod, _ = self.modulation(vec)
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
qkv, mlp = torch.split(self.linear1((1 + mod.scale) * self.pre_norm(x) + mod.shift), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k = self.norm(q, k, v)

View File

@@ -5,8 +5,15 @@ from torch import Tensor
from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
q, k = apply_rope(q, k, pe)
q_shape = q.shape
k_shape = k.shape
q = q.float().reshape(*q.shape[:-1], -1, 1, 2)
k = k.float().reshape(*k.shape[:-1], -1, 1, 2)
q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
heads = q.shape[1]
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)

View File

@@ -168,7 +168,7 @@ class Flux(nn.Module):
out = blocks_replace[("single_block", i)]({"img": img,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask},
"attn_mask": attn_mask},
{"original_block": block_wrap})
img = out["img"]
else:

View File

@@ -159,7 +159,7 @@ class CrossAttention(nn.Module):
q = q.transpose(-2, -3).contiguous() # q -> B, L1, H, C - B, H, L1, C
k = k.transpose(-2, -3).contiguous() # k -> B, L2, H, C - B, H, C, L2
v = v.transpose(-2, -3).contiguous()
v = v.transpose(-2, -3).contiguous()
context = optimized_attention(q, k, v, self.num_heads, skip_reshape=True, attn_precision=self.attn_precision)

View File

@@ -456,9 +456,8 @@ class LTXVModel(torch.nn.Module):
x = self.patchify_proj(x)
timestep = timestep * 1000.0
attention_mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1]))
attention_mask = attention_mask.masked_fill(attention_mask.to(torch.bool), float("-inf")) # not sure about this
# attention_mask = (context != 0).any(dim=2).to(dtype=x.dtype)
if attention_mask is not None and not torch.is_floating_point(attention_mask):
attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max
pe = precompute_freqs_cis(indices_grid, dim=self.inner_dim, out_dtype=x.dtype)

View File

@@ -53,7 +53,7 @@ class Patchifier(ABC):
grid_h = torch.arange(h, dtype=torch.float32, device=device)
grid_w = torch.arange(w, dtype=torch.float32, device=device)
grid_f = torch.arange(f, dtype=torch.float32, device=device)
grid = torch.meshgrid(grid_f, grid_h, grid_w)
grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing='ij')
grid = torch.stack(grid, dim=0)
grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)

View File

@@ -378,7 +378,7 @@ class Decoder(nn.Module):
assert (
timestep is not None
), "should pass timestep with timestep_conditioning=True"
scaled_timestep = timestep * self.timestep_scale_multiplier
scaled_timestep = timestep * self.timestep_scale_multiplier.to(dtype=sample.dtype, device=sample.device)
for up_block in self.up_blocks:
if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
@@ -403,7 +403,7 @@ class Decoder(nn.Module):
)
ada_values = self.last_scale_shift_table[
None, ..., None, None, None
] + embedded_timestep.reshape(
].to(device=sample.device, dtype=sample.dtype) + embedded_timestep.reshape(
batch_size,
2,
-1,
@@ -697,7 +697,7 @@ class ResnetBlock3D(nn.Module):
), "should pass timestep with timestep_conditioning=True"
ada_values = self.scale_shift_table[
None, ..., None, None, None
] + timestep.reshape(
].to(device=hidden_states.device, dtype=hidden_states.dtype) + timestep.reshape(
batch_size,
4,
-1,
@@ -715,7 +715,7 @@ class ResnetBlock3D(nn.Module):
if self.inject_noise:
hidden_states = self._feed_spatial_noise(
hidden_states, self.per_channel_scale1
hidden_states, self.per_channel_scale1.to(device=hidden_states.device, dtype=hidden_states.dtype)
)
hidden_states = self.norm2(hidden_states)
@@ -731,7 +731,7 @@ class ResnetBlock3D(nn.Module):
if self.inject_noise:
hidden_states = self._feed_spatial_noise(
hidden_states, self.per_channel_scale2
hidden_states, self.per_channel_scale2.to(device=hidden_states.device, dtype=hidden_states.dtype)
)
input_tensor = self.norm3(input_tensor)

View File

@@ -89,7 +89,7 @@ class FeedForward(nn.Module):
def Normalize(in_channels, dtype=None, device=None):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
attn_precision = get_attn_precision(attn_precision)
if skip_reshape:
@@ -142,16 +142,23 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
sim = sim.softmax(dim=-1)
out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
out = (
out.unsqueeze(0)
.reshape(b, heads, -1, dim_head)
.permute(0, 2, 1, 3)
.reshape(b, -1, heads * dim_head)
)
if skip_output_reshape:
out = (
out.unsqueeze(0)
.reshape(b, heads, -1, dim_head)
)
else:
out = (
out.unsqueeze(0)
.reshape(b, heads, -1, dim_head)
.permute(0, 2, 1, 3)
.reshape(b, -1, heads * dim_head)
)
return out
def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False):
def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
attn_precision = get_attn_precision(attn_precision)
if skip_reshape:
@@ -215,11 +222,13 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
)
hidden_states = hidden_states.to(dtype)
hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
if skip_output_reshape:
hidden_states = hidden_states.unflatten(0, (-1, heads))
else:
hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
return hidden_states
def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
attn_precision = get_attn_precision(attn_precision)
if skip_reshape:
@@ -326,12 +335,18 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
del q, k, v
r1 = (
r1.unsqueeze(0)
.reshape(b, heads, -1, dim_head)
.permute(0, 2, 1, 3)
.reshape(b, -1, heads * dim_head)
)
if skip_output_reshape:
r1 = (
r1.unsqueeze(0)
.reshape(b, heads, -1, dim_head)
)
else:
r1 = (
r1.unsqueeze(0)
.reshape(b, heads, -1, dim_head)
.permute(0, 2, 1, 3)
.reshape(b, -1, heads * dim_head)
)
return r1
BROKEN_XFORMERS = False
@@ -342,7 +357,7 @@ try:
except:
pass
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
b = q.shape[0]
dim_head = q.shape[-1]
# check to make sure xformers isn't broken
@@ -395,9 +410,12 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
out = (
out.reshape(b, -1, heads * dim_head)
)
if skip_output_reshape:
out = out.permute(0, 2, 1, 3)
else:
out = (
out.reshape(b, -1, heads * dim_head)
)
return out
@@ -408,7 +426,7 @@ else:
SDP_BATCH_LIMIT = 2**31
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
if skip_reshape:
b, _, _, dim_head = q.shape
else:
@@ -429,9 +447,10 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
if SDP_BATCH_LIMIT >= b:
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
)
if not skip_output_reshape:
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
)
else:
out = torch.empty((b, q.shape[2], heads * dim_head), dtype=q.dtype, layout=q.layout, device=q.device)
for i in range(0, b, SDP_BATCH_LIMIT):
@@ -450,7 +469,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
return out
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
if skip_reshape:
b, _, _, dim_head = q.shape
tensor_layout="HND"
@@ -473,11 +492,15 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
if tensor_layout == "HND":
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
)
if not skip_output_reshape:
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
)
else:
out = out.reshape(b, -1, heads * dim_head)
if skip_output_reshape:
out = out.transpose(1, 2)
else:
out = out.reshape(b, -1, heads * dim_head)
return out

View File

@@ -293,6 +293,17 @@ def pytorch_attention(q, k, v):
return out
def vae_attention():
if model_management.xformers_enabled_vae():
logging.info("Using xformers attention in VAE")
return xformers_attention
elif model_management.pytorch_attention_enabled():
logging.info("Using pytorch attention in VAE")
return pytorch_attention
else:
logging.info("Using split attention in VAE")
return normal_attention
class AttnBlock(nn.Module):
def __init__(self, in_channels, conv_op=ops.Conv2d):
super().__init__()
@@ -320,15 +331,7 @@ class AttnBlock(nn.Module):
stride=1,
padding=0)
if model_management.xformers_enabled_vae():
logging.info("Using xformers attention in VAE")
self.optimized_attention = xformers_attention
elif model_management.pytorch_attention_enabled():
logging.info("Using pytorch attention in VAE")
self.optimized_attention = pytorch_attention
else:
logging.info("Using split attention in VAE")
self.optimized_attention = normal_attention
self.optimized_attention = vae_attention()
def forward(self, x):
h_ = x

View File

@@ -9,6 +9,7 @@
import math
import logging
import torch
import torch.nn as nn
import numpy as np
@@ -130,7 +131,7 @@ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timestep
# add one to get the final alpha values right (the ones from first scale to data during sampling)
steps_out = ddim_timesteps + 1
if verbose:
print(f'Selected timesteps for ddim sampler: {steps_out}')
logging.info(f'Selected timesteps for ddim sampler: {steps_out}')
return steps_out
@@ -142,8 +143,8 @@ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
# according the the formula provided in https://arxiv.org/abs/2010.02502
sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
if verbose:
print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
print(f'For the chosen value of eta, which is {eta}, '
logging.info(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
logging.info(f'For the chosen value of eta, which is {eta}, '
f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
return sigmas, alphas, alphas_prev

View File

@@ -17,10 +17,10 @@ import math
import logging
try:
from typing import Optional, NamedTuple, List, Protocol
from typing import Optional, NamedTuple, List, Protocol
except ImportError:
from typing import Optional, NamedTuple, List
from typing_extensions import Protocol
from typing import Optional, NamedTuple, List
from typing_extensions import Protocol
from typing import List
@@ -261,7 +261,7 @@ def efficient_dot_product_attention(
value=value,
mask=mask,
)
# TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
# and pass slices to be mutated, instead of torch.cat()ing the returned slices
res = torch.cat([

380
comfy/ldm/pixart/blocks.py Normal file
View File

@@ -0,0 +1,380 @@
# Based on:
# https://github.com/PixArt-alpha/PixArt-alpha [Apache 2.0 license]
# https://github.com/PixArt-alpha/PixArt-sigma [Apache 2.0 license]
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, Mlp, timestep_embedding
from comfy.ldm.modules.attention import optimized_attention
# if model_management.xformers_enabled():
# import xformers.ops
# if int((xformers.__version__).split(".")[2].split("+")[0]) >= 28:
# block_diagonal_mask_from_seqlens = xformers.ops.fmha.attn_bias.BlockDiagonalMask.from_seqlens
# else:
# block_diagonal_mask_from_seqlens = xformers.ops.fmha.BlockDiagonalMask.from_seqlens
def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
def t2i_modulate(x, shift, scale):
return x * (1 + scale) + shift
class MultiHeadCrossAttention(nn.Module):
def __init__(self, d_model, num_heads, attn_drop=0., proj_drop=0., dtype=None, device=None, operations=None, **kwargs):
super(MultiHeadCrossAttention, self).__init__()
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
self.d_model = d_model
self.num_heads = num_heads
self.head_dim = d_model // num_heads
self.q_linear = operations.Linear(d_model, d_model, dtype=dtype, device=device)
self.kv_linear = operations.Linear(d_model, d_model*2, dtype=dtype, device=device)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, cond, mask=None):
# query/value: img tokens; key: condition; mask: if padding tokens
B, N, C = x.shape
q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
k, v = kv.unbind(2)
assert mask is None # TODO?
# # TODO: xformers needs separate mask logic here
# if model_management.xformers_enabled():
# attn_bias = None
# if mask is not None:
# attn_bias = block_diagonal_mask_from_seqlens([N] * B, mask)
# x = xformers.ops.memory_efficient_attention(q, k, v, p=0, attn_bias=attn_bias)
# else:
# q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v),)
# attn_mask = None
# mask = torch.ones(())
# if mask is not None and len(mask) > 1:
# # Create equivalent of xformer diagonal block mask, still only correct for square masks
# # But depth doesn't matter as tensors can expand in that dimension
# attn_mask_template = torch.ones(
# [q.shape[2] // B, mask[0]],
# dtype=torch.bool,
# device=q.device
# )
# attn_mask = torch.block_diag(attn_mask_template)
#
# # create a mask on the diagonal for each mask in the batch
# for _ in range(B - 1):
# attn_mask = torch.block_diag(attn_mask, attn_mask_template)
# x = optimized_attention(q, k, v, self.num_heads, mask=attn_mask, skip_reshape=True)
x = optimized_attention(q.view(B, -1, C), k.view(B, -1, C), v.view(B, -1, C), self.num_heads, mask=None)
x = self.proj(x)
x = self.proj_drop(x)
return x
class AttentionKVCompress(nn.Module):
"""Multi-head Attention block with KV token compression and qk norm."""
def __init__(self, dim, num_heads=8, qkv_bias=True, sampling='conv', sr_ratio=1, qk_norm=False, dtype=None, device=None, operations=None, **kwargs):
"""
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
qkv_bias (bool: If True, add a learnable bias to query, key, value.
"""
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
self.sampling=sampling # ['conv', 'ave', 'uniform', 'uniform_every']
self.sr_ratio = sr_ratio
if sr_ratio > 1 and sampling == 'conv':
# Avg Conv Init.
self.sr = operations.Conv2d(dim, dim, groups=dim, kernel_size=sr_ratio, stride=sr_ratio, dtype=dtype, device=device)
# self.sr.weight.data.fill_(1/sr_ratio**2)
# self.sr.bias.data.zero_()
self.norm = operations.LayerNorm(dim, dtype=dtype, device=device)
if qk_norm:
self.q_norm = operations.LayerNorm(dim, dtype=dtype, device=device)
self.k_norm = operations.LayerNorm(dim, dtype=dtype, device=device)
else:
self.q_norm = nn.Identity()
self.k_norm = nn.Identity()
def downsample_2d(self, tensor, H, W, scale_factor, sampling=None):
if sampling is None or scale_factor == 1:
return tensor
B, N, C = tensor.shape
if sampling == 'uniform_every':
return tensor[:, ::scale_factor], int(N // scale_factor)
tensor = tensor.reshape(B, H, W, C).permute(0, 3, 1, 2)
new_H, new_W = int(H / scale_factor), int(W / scale_factor)
new_N = new_H * new_W
if sampling == 'ave':
tensor = F.interpolate(
tensor, scale_factor=1 / scale_factor, mode='nearest'
).permute(0, 2, 3, 1)
elif sampling == 'uniform':
tensor = tensor[:, :, ::scale_factor, ::scale_factor].permute(0, 2, 3, 1)
elif sampling == 'conv':
tensor = self.sr(tensor).reshape(B, C, -1).permute(0, 2, 1)
tensor = self.norm(tensor)
else:
raise ValueError
return tensor.reshape(B, new_N, C).contiguous(), new_N
def forward(self, x, mask=None, HW=None, block_id=None):
B, N, C = x.shape # 2 4096 1152
new_N = N
if HW is None:
H = W = int(N ** 0.5)
else:
H, W = HW
qkv = self.qkv(x).reshape(B, N, 3, C)
q, k, v = qkv.unbind(2)
q = self.q_norm(q)
k = self.k_norm(k)
# KV compression
if self.sr_ratio > 1:
k, new_N = self.downsample_2d(k, H, W, self.sr_ratio, sampling=self.sampling)
v, new_N = self.downsample_2d(v, H, W, self.sr_ratio, sampling=self.sampling)
q = q.reshape(B, N, self.num_heads, C // self.num_heads)
k = k.reshape(B, new_N, self.num_heads, C // self.num_heads)
v = v.reshape(B, new_N, self.num_heads, C // self.num_heads)
if mask is not None:
raise NotImplementedError("Attn mask logic not added for self attention")
# This is never called at the moment
# attn_bias = None
# if mask is not None:
# attn_bias = torch.zeros([B * self.num_heads, q.shape[1], k.shape[1]], dtype=q.dtype, device=q.device)
# attn_bias.masked_fill_(mask.squeeze(1).repeat(self.num_heads, 1, 1) == 0, float('-inf'))
# attention 2
q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v),)
x = optimized_attention(q, k, v, self.num_heads, mask=None, skip_reshape=True)
x = x.view(B, N, C)
x = self.proj(x)
return x
class FinalLayer(nn.Module):
"""
The final layer of PixArt.
"""
def __init__(self, hidden_size, patch_size, out_channels, dtype=None, device=None, operations=None):
super().__init__()
self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device)
)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class T2IFinalLayer(nn.Module):
"""
The final layer of PixArt.
"""
def __init__(self, hidden_size, patch_size, out_channels, dtype=None, device=None, operations=None):
super().__init__()
self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size ** 0.5)
self.out_channels = out_channels
def forward(self, x, t):
shift, scale = (self.scale_shift_table[None].to(dtype=x.dtype, device=x.device) + t[:, None]).chunk(2, dim=1)
x = t2i_modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class MaskFinalLayer(nn.Module):
"""
The final layer of PixArt.
"""
def __init__(self, final_hidden_size, c_emb_size, patch_size, out_channels, dtype=None, device=None, operations=None):
super().__init__()
self.norm_final = operations.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.linear = operations.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(c_emb_size, 2 * final_hidden_size, bias=True, dtype=dtype, device=device)
)
def forward(self, x, t):
shift, scale = self.adaLN_modulation(t).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class DecoderLayer(nn.Module):
"""
The final layer of PixArt.
"""
def __init__(self, hidden_size, decoder_hidden_size, dtype=None, device=None, operations=None):
super().__init__()
self.norm_decoder = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.linear = operations.Linear(hidden_size, decoder_hidden_size, bias=True, dtype=dtype, device=device)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device)
)
def forward(self, x, t):
shift, scale = self.adaLN_modulation(t).chunk(2, dim=1)
x = modulate(self.norm_decoder(x), shift, scale)
x = self.linear(x)
return x
class SizeEmbedder(TimestepEmbedder):
"""
Embeds scalar timesteps into vector representations.
"""
def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
super().__init__(hidden_size=hidden_size, frequency_embedding_size=frequency_embedding_size, operations=operations)
self.mlp = nn.Sequential(
operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
)
self.frequency_embedding_size = frequency_embedding_size
self.outdim = hidden_size
def forward(self, s, bs):
if s.ndim == 1:
s = s[:, None]
assert s.ndim == 2
if s.shape[0] != bs:
s = s.repeat(bs//s.shape[0], 1)
assert s.shape[0] == bs
b, dims = s.shape[0], s.shape[1]
s = rearrange(s, "b d -> (b d)")
s_freq = timestep_embedding(s, self.frequency_embedding_size)
s_emb = self.mlp(s_freq.to(s.dtype))
s_emb = rearrange(s_emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim)
return s_emb
class LabelEmbedder(nn.Module):
"""
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
"""
def __init__(self, num_classes, hidden_size, dropout_prob, dtype=None, device=None, operations=None):
super().__init__()
use_cfg_embedding = dropout_prob > 0
self.embedding_table = operations.Embedding(num_classes + use_cfg_embedding, hidden_size, dtype=dtype, device=device),
self.num_classes = num_classes
self.dropout_prob = dropout_prob
def token_drop(self, labels, force_drop_ids=None):
"""
Drops labels to enable classifier-free guidance.
"""
if force_drop_ids is None:
drop_ids = torch.rand(labels.shape[0]).cuda() < self.dropout_prob
else:
drop_ids = force_drop_ids == 1
labels = torch.where(drop_ids, self.num_classes, labels)
return labels
def forward(self, labels, train, force_drop_ids=None):
use_dropout = self.dropout_prob > 0
if (train and use_dropout) or (force_drop_ids is not None):
labels = self.token_drop(labels, force_drop_ids)
embeddings = self.embedding_table(labels)
return embeddings
class CaptionEmbedder(nn.Module):
"""
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
"""
def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate='tanh'), token_num=120, dtype=None, device=None, operations=None):
super().__init__()
self.y_proj = Mlp(
in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=act_layer,
dtype=dtype, device=device, operations=operations,
)
self.register_buffer("y_embedding", nn.Parameter(torch.randn(token_num, in_channels) / in_channels ** 0.5))
self.uncond_prob = uncond_prob
def token_drop(self, caption, force_drop_ids=None):
"""
Drops labels to enable classifier-free guidance.
"""
if force_drop_ids is None:
drop_ids = torch.rand(caption.shape[0]).cuda() < self.uncond_prob
else:
drop_ids = force_drop_ids == 1
caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
return caption
def forward(self, caption, train, force_drop_ids=None):
if train:
assert caption.shape[2:] == self.y_embedding.shape
use_dropout = self.uncond_prob > 0
if (train and use_dropout) or (force_drop_ids is not None):
caption = self.token_drop(caption, force_drop_ids)
caption = self.y_proj(caption)
return caption
class CaptionEmbedderDoubleBr(nn.Module):
"""
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
"""
def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate='tanh'), token_num=120, dtype=None, device=None, operations=None):
super().__init__()
self.proj = Mlp(
in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=act_layer,
dtype=dtype, device=device, operations=operations,
)
self.embedding = nn.Parameter(torch.randn(1, in_channels) / 10 ** 0.5)
self.y_embedding = nn.Parameter(torch.randn(token_num, in_channels) / 10 ** 0.5)
self.uncond_prob = uncond_prob
def token_drop(self, global_caption, caption, force_drop_ids=None):
"""
Drops labels to enable classifier-free guidance.
"""
if force_drop_ids is None:
drop_ids = torch.rand(global_caption.shape[0]).cuda() < self.uncond_prob
else:
drop_ids = force_drop_ids == 1
global_caption = torch.where(drop_ids[:, None], self.embedding, global_caption)
caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
return global_caption, caption
def forward(self, caption, train, force_drop_ids=None):
assert caption.shape[2: ] == self.y_embedding.shape
global_caption = caption.mean(dim=2).squeeze()
use_dropout = self.uncond_prob > 0
if (train and use_dropout) or (force_drop_ids is not None):
global_caption, caption = self.token_drop(global_caption, caption, force_drop_ids)
y_embed = self.proj(global_caption)
return y_embed, caption

View File

@@ -0,0 +1,256 @@
# Based on:
# https://github.com/PixArt-alpha/PixArt-alpha [Apache 2.0 license]
# https://github.com/PixArt-alpha/PixArt-sigma [Apache 2.0 license]
import torch
import torch.nn as nn
from .blocks import (
t2i_modulate,
CaptionEmbedder,
AttentionKVCompress,
MultiHeadCrossAttention,
T2IFinalLayer,
SizeEmbedder,
)
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, PatchEmbed, Mlp, get_1d_sincos_pos_embed_from_grid_torch
def get_2d_sincos_pos_embed_torch(embed_dim, w, h, pe_interpolation=1.0, base_size=16, device=None, dtype=torch.float32):
grid_h, grid_w = torch.meshgrid(
torch.arange(h, device=device, dtype=dtype) / (h/base_size) / pe_interpolation,
torch.arange(w, device=device, dtype=dtype) / (w/base_size) / pe_interpolation,
indexing='ij'
)
emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_h, device=device, dtype=dtype)
emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_w, device=device, dtype=dtype)
emb = torch.cat([emb_w, emb_h], dim=1) # (H*W, D)
return emb
class PixArtMSBlock(nn.Module):
"""
A PixArt block with adaptive layer norm zero (adaLN-Zero) conditioning.
"""
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., input_size=None,
sampling=None, sr_ratio=1, qk_norm=False, dtype=None, device=None, operations=None, **block_kwargs):
super().__init__()
self.hidden_size = hidden_size
self.norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.attn = AttentionKVCompress(
hidden_size, num_heads=num_heads, qkv_bias=True, sampling=sampling, sr_ratio=sr_ratio,
qk_norm=qk_norm, dtype=dtype, device=device, operations=operations, **block_kwargs
)
self.cross_attn = MultiHeadCrossAttention(
hidden_size, num_heads, dtype=dtype, device=device, operations=operations, **block_kwargs
)
self.norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
# to be compatible with lower version pytorch
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.mlp = Mlp(
in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu,
dtype=dtype, device=device, operations=operations
)
self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5)
def forward(self, x, y, t, mask=None, HW=None, **kwargs):
B, N, C = x.shape
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None].to(dtype=x.dtype, device=x.device) + t.reshape(B, 6, -1)).chunk(6, dim=1)
x = x + (gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=HW))
x = x + self.cross_attn(x, y, mask)
x = x + (gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
return x
### Core PixArt Model ###
class PixArtMS(nn.Module):
"""
Diffusion model with a Transformer backbone.
"""
def __init__(
self,
input_size=32,
patch_size=2,
in_channels=4,
hidden_size=1152,
depth=28,
num_heads=16,
mlp_ratio=4.0,
class_dropout_prob=0.1,
learn_sigma=True,
pred_sigma=True,
drop_path: float = 0.,
caption_channels=4096,
pe_interpolation=None,
pe_precision=None,
config=None,
model_max_length=120,
micro_condition=True,
qk_norm=False,
kv_compress_config=None,
dtype=None,
device=None,
operations=None,
**kwargs,
):
nn.Module.__init__(self)
self.dtype = dtype
self.pred_sigma = pred_sigma
self.in_channels = in_channels
self.out_channels = in_channels * 2 if pred_sigma else in_channels
self.patch_size = patch_size
self.num_heads = num_heads
self.pe_interpolation = pe_interpolation
self.pe_precision = pe_precision
self.hidden_size = hidden_size
self.depth = depth
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.t_block = nn.Sequential(
nn.SiLU(),
operations.Linear(hidden_size, 6 * hidden_size, bias=True, dtype=dtype, device=device)
)
self.x_embedder = PatchEmbed(
patch_size=patch_size,
in_chans=in_channels,
embed_dim=hidden_size,
bias=True,
dtype=dtype,
device=device,
operations=operations
)
self.t_embedder = TimestepEmbedder(
hidden_size, dtype=dtype, device=device, operations=operations,
)
self.y_embedder = CaptionEmbedder(
in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob,
act_layer=approx_gelu, token_num=model_max_length,
dtype=dtype, device=device, operations=operations,
)
self.micro_conditioning = micro_condition
if self.micro_conditioning:
self.csize_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations)
self.ar_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations)
# For fixed sin-cos embedding:
# num_patches = (input_size // patch_size) * (input_size // patch_size)
# self.base_size = input_size // self.patch_size
# self.register_buffer("pos_embed", torch.zeros(1, num_patches, hidden_size))
drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule
if kv_compress_config is None:
kv_compress_config = {
'sampling': None,
'scale_factor': 1,
'kv_compress_layer': [],
}
self.blocks = nn.ModuleList([
PixArtMSBlock(
hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i],
sampling=kv_compress_config['sampling'],
sr_ratio=int(kv_compress_config['scale_factor']) if i in kv_compress_config['kv_compress_layer'] else 1,
qk_norm=qk_norm,
dtype=dtype,
device=device,
operations=operations,
)
for i in range(depth)
])
self.final_layer = T2IFinalLayer(
hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations
)
def forward_orig(self, x, timestep, y, mask=None, c_size=None, c_ar=None, **kwargs):
"""
Original forward pass of PixArt.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
y: (N, 1, 120, C) conditioning
ar: (N, 1): aspect ratio
cs: (N ,2) size conditioning for height/width
"""
B, C, H, W = x.shape
c_res = (H + W) // 2
pe_interpolation = self.pe_interpolation
if pe_interpolation is None or self.pe_precision is not None:
# calculate pe_interpolation on-the-fly
pe_interpolation = round(c_res / (512/8.0), self.pe_precision or 0)
pos_embed = get_2d_sincos_pos_embed_torch(
self.hidden_size,
h=(H // self.patch_size),
w=(W // self.patch_size),
pe_interpolation=pe_interpolation,
base_size=((round(c_res / 64) * 64) // self.patch_size),
device=x.device,
dtype=x.dtype,
).unsqueeze(0)
x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
t = self.t_embedder(timestep, x.dtype) # (N, D)
if self.micro_conditioning and (c_size is not None and c_ar is not None):
bs = x.shape[0]
c_size = self.csize_embedder(c_size, bs) # (N, D)
c_ar = self.ar_embedder(c_ar, bs) # (N, D)
t = t + torch.cat([c_size, c_ar], dim=1)
t0 = self.t_block(t)
y = self.y_embedder(y, self.training) # (N, D)
if mask is not None:
if mask.shape[0] != y.shape[0]:
mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
mask = mask.squeeze(1).squeeze(1)
y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
y_lens = mask.sum(dim=1).tolist()
else:
y_lens = None
y = y.squeeze(1).view(1, -1, x.shape[-1])
for block in self.blocks:
x = block(x, y, t0, y_lens, (H, W), **kwargs) # (N, T, D)
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
x = self.unpatchify(x, H, W) # (N, out_channels, H, W)
return x
def forward(self, x, timesteps, context, c_size=None, c_ar=None, **kwargs):
B, C, H, W = x.shape
# Fallback for missing microconds
if self.micro_conditioning:
if c_size is None:
c_size = torch.tensor([H*8, W*8], dtype=x.dtype, device=x.device).repeat(B, 1)
if c_ar is None:
c_ar = torch.tensor([H/W], dtype=x.dtype, device=x.device).repeat(B, 1)
## Still accepts the input w/o that dim but returns garbage
if len(context.shape) == 3:
context = context.unsqueeze(1)
## run original forward pass
out = self.forward_orig(x, timesteps, context, c_size=c_size, c_ar=c_ar)
## only return EPS
if self.pred_sigma:
return out[:, :self.in_channels]
return out
def unpatchify(self, x, h, w):
"""
x: (N, T, patch_size**2 * C)
imgs: (N, H, W, C)
"""
c = self.out_channels
p = self.x_embedder.patch_size[0]
h = h // self.patch_size
w = w // self.patch_size
assert h * w == x.shape[1]
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
x = torch.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
return imgs

View File

@@ -1,4 +1,5 @@
import importlib
import logging
import torch
from torch import optim
@@ -23,7 +24,7 @@ def log_txt_as_img(wh, xc, size=10):
try:
draw.text((0, 0), lines, fill="black", font=font)
except UnicodeEncodeError:
print("Cant encode string for logging. Skipping.")
logging.warning("Cant encode string for logging. Skipping.")
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
txts.append(txt)
@@ -65,7 +66,7 @@ def mean_flat(tensor):
def count_params(model, verbose=False):
total_params = sum(p.numel() for p in model.parameters())
if verbose:
print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
logging.info(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
return total_params
@@ -193,4 +194,4 @@ class AdamWwithEMAandWings(optim.Optimizer):
for param, ema_param in zip(params_with_grad, ema_params_with_grad):
ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)
return loss
return loss

View File

@@ -344,7 +344,6 @@ def model_lora_keys_unet(model, key_map={}):
key_lora = "lycoris_{}".format(k[:-len(".weight")].replace(".", "_")) #simpletuner lycoris format
key_map[key_lora] = to
if isinstance(model, comfy.model_base.AuraFlow): #Diffusers lora AuraFlow
diffusers_keys = comfy.utils.auraflow_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
for k in diffusers_keys:
@@ -353,6 +352,20 @@ def model_lora_keys_unet(model, key_map={}):
key_lora = "transformer.{}".format(k[:-len(".weight")]) #simpletrainer and probably regular diffusers lora format
key_map[key_lora] = to
if isinstance(model, comfy.model_base.PixArt):
diffusers_keys = comfy.utils.pixart_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
for k in diffusers_keys:
if k.endswith(".weight"):
to = diffusers_keys[k]
key_lora = "transformer.{}".format(k[:-len(".weight")]) #default format
key_map[key_lora] = to
key_lora = "base_model.model.{}".format(k[:-len(".weight")]) #diffusers training script
key_map[key_lora] = to
key_lora = "unet.base_model.model.{}".format(k[:-len(".weight")]) #old reference peft script
key_map[key_lora] = to
if isinstance(model, comfy.model_base.HunyuanDiT):
for k in sdk:
if k.startswith("diffusion_model.") and k.endswith(".weight"):

View File

@@ -26,12 +26,14 @@ from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAug
from comfy.ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
import comfy.ldm.genmo.joint_model.asymm_models_joint
import comfy.ldm.aura.mmdit
import comfy.ldm.pixart.pixartms
import comfy.ldm.hydit.models
import comfy.ldm.audio.dit
import comfy.ldm.audio.embedders
import comfy.ldm.flux.model
import comfy.ldm.lightricks.model
import comfy.ldm.hunyuan_video.model
import comfy.ldm.cosmos.model
import comfy.model_management
import comfy.patcher_extension
@@ -187,9 +189,10 @@ class BaseModel(torch.nn.Module):
if denoise_mask is not None:
if len(denoise_mask.shape) == len(noise.shape):
denoise_mask = denoise_mask[:,:1]
denoise_mask = denoise_mask[:, :1]
denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1]))
num_dim = noise.ndim - 2
denoise_mask = denoise_mask.reshape((-1, 1) + tuple(denoise_mask.shape[-num_dim:]))
if denoise_mask.shape[-2:] != noise.shape[-2:]:
denoise_mask = utils.common_upscale(denoise_mask, noise.shape[-1], noise.shape[-2], "bilinear", "center")
denoise_mask = utils.resize_to_batch_size(denoise_mask.round(), noise.shape[0])
@@ -199,12 +202,16 @@ class BaseModel(torch.nn.Module):
if ck == "mask":
cond_concat.append(denoise_mask.to(device))
elif ck == "masked_image":
cond_concat.append(concat_latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space
cond_concat.append(concat_latent_image.to(device)) # NOTE: the latent_image should be masked by the mask in pixel space
elif ck == "mask_inverted":
cond_concat.append(1.0 - denoise_mask.to(device))
else:
if ck == "mask":
cond_concat.append(torch.ones_like(noise)[:,:1])
cond_concat.append(torch.ones_like(noise)[:, :1])
elif ck == "masked_image":
cond_concat.append(self.blank_inpaint_image_like(noise))
elif ck == "mask_inverted":
cond_concat.append(torch.zeros_like(noise)[:, :1])
data = torch.cat(cond_concat, dim=1)
return data
return None
@@ -292,6 +299,9 @@ class BaseModel(torch.nn.Module):
return blank_image
self.blank_inpaint_image_like = blank_inpaint_image_like
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
return self.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(noise.shape) - 1)), noise, latent_image)
def memory_required(self, input_shape):
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
dtype = self.get_dtype()
@@ -718,6 +728,25 @@ class HunyuanDiT(BaseModel):
out['image_meta_size'] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]]))
return out
class PixArt(BaseModel):
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.pixart.pixartms.PixArtMS)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
width = kwargs.get("width", None)
height = kwargs.get("height", None)
if width is not None and height is not None:
out["c_size"] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width]]))
out["c_ar"] = comfy.conds.CONDRegular(torch.FloatTensor([[kwargs.get("aspect_ratio", height/width)]]))
return out
class Flux(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.flux.model.Flux)
@@ -754,7 +783,6 @@ class Flux(BaseModel):
mask = torch.ones_like(noise)[:, :1]
mask = torch.mean(mask, dim=1, keepdim=True)
print(mask.shape)
mask = utils.common_upscale(mask.to(device), noise.shape[-1] * 8, noise.shape[-2] * 8, "bilinear", "center")
mask = mask.view(mask.shape[0], mask.shape[2] // 8, 8, mask.shape[3] // 8, 8).permute(0, 2, 4, 1, 3).reshape(mask.shape[0], -1, mask.shape[2] // 8, mask.shape[3] // 8)
mask = utils.resize_to_batch_size(mask, noise.shape[0])
@@ -768,7 +796,7 @@ class Flux(BaseModel):
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
# upscale the attention mask, since now we
# upscale the attention mask, since now we
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
shape = kwargs["noise"].shape
@@ -837,3 +865,30 @@ class HunyuanVideo(BaseModel):
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 6.0)]))
return out
class CosmosVideo(BaseModel):
def __init__(self, model_config, model_type=ModelType.EDM, image_to_video=False, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.cosmos.model.GeneralDIT)
self.image_to_video = image_to_video
if self.image_to_video:
self.concat_keys = ("mask_inverted",)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
out['fps'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", None))
return out
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
sigma = sigma.reshape([sigma.shape[0]] + [1] * (len(noise.shape) - 1))
sigma_noise_augmentation = 0 #TODO
if sigma_noise_augmentation != 0:
latent_image = latent_image + noise
latent_image = self.model_sampling.calculate_input(torch.tensor([sigma_noise_augmentation], device=latent_image.device, dtype=latent_image.dtype), latent_image)
return latent_image * ((sigma ** 2 + self.model_sampling.sigma_data ** 2) ** 0.5)

View File

@@ -203,11 +203,87 @@ def detect_unet_config(state_dict, key_prefix):
dit_config["rope_theta"] = 10000.0
return dit_config
if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys and '{}pos_embed.proj.bias'.format(key_prefix) in state_dict_keys:
# PixArt diffusers
return None
if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv
dit_config = {}
dit_config["image_model"] = "ltxv"
return dit_config
if '{}t_block.1.weight'.format(key_prefix) in state_dict_keys: # PixArt
patch_size = 2
dit_config = {}
dit_config["num_heads"] = 16
dit_config["patch_size"] = patch_size
dit_config["hidden_size"] = 1152
dit_config["in_channels"] = 4
dit_config["depth"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')
y_key = "{}y_embedder.y_embedding".format(key_prefix)
if y_key in state_dict_keys:
dit_config["model_max_length"] = state_dict[y_key].shape[0]
pe_key = "{}pos_embed".format(key_prefix)
if pe_key in state_dict_keys:
dit_config["input_size"] = int(math.sqrt(state_dict[pe_key].shape[1])) * patch_size
dit_config["pe_interpolation"] = dit_config["input_size"] // (512//8) # guess
ar_key = "{}ar_embedder.mlp.0.weight".format(key_prefix)
if ar_key in state_dict_keys:
dit_config["image_model"] = "pixart_alpha"
dit_config["micro_condition"] = True
else:
dit_config["image_model"] = "pixart_sigma"
dit_config["micro_condition"] = False
return dit_config
if '{}blocks.block0.blocks.0.block.attn.to_q.0.weight'.format(key_prefix) in state_dict_keys:
dit_config = {}
dit_config["image_model"] = "cosmos"
dit_config["max_img_h"] = 240
dit_config["max_img_w"] = 240
dit_config["max_frames"] = 128
concat_padding_mask = True
dit_config["in_channels"] = (state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[1] // 4) - int(concat_padding_mask)
dit_config["out_channels"] = 16
dit_config["patch_spatial"] = 2
dit_config["patch_temporal"] = 1
dit_config["model_channels"] = state_dict['{}blocks.block0.blocks.0.block.attn.to_q.0.weight'.format(key_prefix)].shape[0]
dit_config["block_config"] = "FA-CA-MLP"
dit_config["concat_padding_mask"] = concat_padding_mask
dit_config["pos_emb_cls"] = "rope3d"
dit_config["pos_emb_learnable"] = False
dit_config["pos_emb_interpolation"] = "crop"
dit_config["block_x_format"] = "THWBD"
dit_config["affline_emb_norm"] = True
dit_config["use_adaln_lora"] = True
dit_config["adaln_lora_dim"] = 256
if dit_config["model_channels"] == 4096:
# 7B
dit_config["num_blocks"] = 28
dit_config["num_heads"] = 32
dit_config["extra_per_block_abs_pos_emb"] = True
dit_config["rope_h_extrapolation_ratio"] = 1.0
dit_config["rope_w_extrapolation_ratio"] = 1.0
dit_config["rope_t_extrapolation_ratio"] = 2.0
dit_config["extra_per_block_abs_pos_emb_type"] = "learnable"
else: # 5120
# 14B
dit_config["num_blocks"] = 36
dit_config["num_heads"] = 40
dit_config["extra_per_block_abs_pos_emb"] = True
dit_config["rope_h_extrapolation_ratio"] = 2.0
dit_config["rope_w_extrapolation_ratio"] = 2.0
dit_config["rope_t_extrapolation_ratio"] = 2.0
dit_config["extra_h_extrapolation_ratio"] = 2.0
dit_config["extra_w_extrapolation_ratio"] = 2.0
dit_config["extra_t_extrapolation_ratio"] = 2.0
dit_config["extra_per_block_abs_pos_emb_type"] = "learnable"
return dit_config
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
return None
@@ -362,6 +438,7 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
def unet_prefix_from_state_dict(state_dict):
candidates = ["model.diffusion_model.", #ldm/sgm models
"model.model.", #audio models
"net.", #cosmos
]
counts = {k: 0 for k in candidates}
for k in state_dict:
@@ -540,12 +617,12 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
'transformer_depth': [0, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': False,
'context_dim': 768, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 1, 1, 1, 1],
'use_temporal_attention': False, 'use_temporal_resblock': False}
SD15_diffusers_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': None,
'dtype': dtype, 'in_channels': 9, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0],
'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, 'num_heads': 8,
'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
'use_temporal_attention': False, 'use_temporal_resblock': False}
'use_temporal_attention': False, 'use_temporal_resblock': False}
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p, SD15_diffusers_inpaint]
@@ -573,6 +650,9 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
num_joint = count_blocks(state_dict, 'joint_transformer_blocks.{}.')
num_single = count_blocks(state_dict, 'single_transformer_blocks.{}.')
sd_map = comfy.utils.auraflow_to_diffusers({"n_double_layers": num_joint, "n_layers": num_joint + num_single}, output_prefix=output_prefix)
elif 'adaln_single.emb.timestep_embedder.linear_1.bias' in state_dict and 'pos_embed.proj.bias' in state_dict: # PixArt
num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
sd_map = comfy.utils.pixart_to_diffusers({"depth": num_blocks}, output_prefix=output_prefix)
elif 'x_embedder.weight' in state_dict: #Flux
depth = count_blocks(state_dict, 'transformer_blocks.{}.')
depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.')

View File

@@ -75,7 +75,7 @@ if args.directml is not None:
try:
import intel_extension_for_pytorch as ipex
_ = torch.xpu.device_count()
xpu_available = torch.xpu.is_available()
xpu_available = xpu_available or torch.xpu.is_available()
except:
xpu_available = xpu_available or (hasattr(torch, "xpu") and torch.xpu.is_available())
@@ -86,6 +86,13 @@ try:
except:
pass
try:
import torch_npu # noqa: F401
_ = torch.npu.device_count()
npu_available = torch.npu.is_available()
except:
npu_available = False
if args.cpu:
cpu_state = CPUState.CPU
@@ -97,6 +104,12 @@ def is_intel_xpu():
return True
return False
def is_ascend_npu():
global npu_available
if npu_available:
return True
return False
def get_torch_device():
global directml_enabled
global cpu_state
@@ -110,6 +123,8 @@ def get_torch_device():
else:
if is_intel_xpu():
return torch.device("xpu", torch.xpu.current_device())
elif is_ascend_npu():
return torch.device("npu", torch.npu.current_device())
else:
return torch.device(torch.cuda.current_device())
@@ -130,6 +145,12 @@ def get_total_memory(dev=None, torch_total_too=False):
mem_reserved = stats['reserved_bytes.all.current']
mem_total_torch = mem_reserved
mem_total = torch.xpu.get_device_properties(dev).total_memory
elif is_ascend_npu():
stats = torch.npu.memory_stats(dev)
mem_reserved = stats['reserved_bytes.all.current']
_, mem_total_npu = torch.npu.mem_get_info(dev)
mem_total_torch = mem_reserved
mem_total = mem_total_npu
else:
stats = torch.cuda.memory_stats(dev)
mem_reserved = stats['reserved_bytes.all.current']
@@ -188,38 +209,44 @@ def is_nvidia():
return True
return False
def is_amd():
global cpu_state
if cpu_state == CPUState.GPU:
if torch.version.hip:
return True
return False
MIN_WEIGHT_MEMORY_RATIO = 0.4
if is_nvidia():
MIN_WEIGHT_MEMORY_RATIO = 0.2
ENABLE_PYTORCH_ATTENTION = False
if args.use_pytorch_cross_attention:
ENABLE_PYTORCH_ATTENTION = True
XFORMERS_IS_AVAILABLE = False
VAE_DTYPES = [torch.float32]
try:
if is_nvidia():
if int(torch_version[0]) >= 2:
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
ENABLE_PYTORCH_ATTENTION = True
if torch.cuda.is_bf16_supported() and torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8:
VAE_DTYPES = [torch.bfloat16] + VAE_DTYPES
if is_intel_xpu():
if is_intel_xpu() or is_ascend_npu():
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
ENABLE_PYTORCH_ATTENTION = True
except:
pass
if is_intel_xpu():
VAE_DTYPES = [torch.bfloat16] + VAE_DTYPES
if args.cpu_vae:
VAE_DTYPES = [torch.float32]
if ENABLE_PYTORCH_ATTENTION:
torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
try:
if int(torch_version[0]) == 2 and int(torch_version[2]) >= 5:
torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
except:
logging.warning("Warning, could not set allow_fp16_bf16_reduction_math_sdp")
if args.lowvram:
set_vram_to = VRAMState.LOW_VRAM
lowvram_available = True
@@ -268,6 +295,8 @@ def get_torch_device_name(device):
return "{}".format(device.type)
elif is_intel_xpu():
return "{} {}".format(device, torch.xpu.get_device_name(device))
elif is_ascend_npu():
return "{} {}".format(device, torch.npu.get_device_name(device))
else:
return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
@@ -509,13 +538,14 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
model_size = loaded_model.model_memory_required(torch_dev)
loaded_memory = loaded_model.model_loaded_memory()
current_free_mem = get_free_memory(torch_dev) + loaded_memory
lowvram_model_memory = max(64 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * 0.4, current_free_mem - minimum_inference_memory()))
lowvram_model_memory = max(64 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
lowvram_model_memory = max(0.1, lowvram_model_memory - loaded_memory)
if model_size <= lowvram_model_memory: #only switch to lowvram if really necessary
lowvram_model_memory = 0
if vram_set_state == VRAMState.NO_VRAM:
lowvram_model_memory = 64 * 1024 * 1024
lowvram_model_memory = 0.1
loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
current_loaded_models.insert(0, loaded_model)
@@ -743,7 +773,6 @@ def vae_offload_device():
return torch.device("cpu")
def vae_dtype(device=None, allowed_dtypes=[]):
global VAE_DTYPES
if args.fp16_vae:
return torch.float16
elif args.bf16_vae:
@@ -752,12 +781,14 @@ def vae_dtype(device=None, allowed_dtypes=[]):
return torch.float32
for d in allowed_dtypes:
if d == torch.float16 and should_use_fp16(device, prioritize_performance=False):
return d
if d in VAE_DTYPES:
if d == torch.float16 and should_use_fp16(device):
return d
return VAE_DTYPES[0]
# NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32
if d == torch.bfloat16 and (not is_amd()) and should_use_bf16(device):
return d
return torch.float32
def get_autocast_device(dev):
if hasattr(dev, 'type'):
@@ -852,6 +883,8 @@ def xformers_enabled():
return False
if is_intel_xpu():
return False
if is_ascend_npu():
return False
if directml_enabled:
return False
return XFORMERS_IS_AVAILABLE
@@ -876,16 +909,23 @@ def pytorch_attention_flash_attention():
return True
if is_intel_xpu():
return True
if is_ascend_npu():
return True
return False
def mac_version():
try:
return tuple(int(n) for n in platform.mac_ver()[0].split("."))
except:
return None
def force_upcast_attention_dtype():
upcast = args.force_upcast_attention
try:
macos_version = tuple(int(n) for n in platform.mac_ver()[0].split("."))
if (14, 5) <= macos_version <= (15, 2): # black image bug on recent versions of macOS
upcast = True
except:
pass
macos_version = mac_version()
if macos_version is not None and ((14, 5) <= macos_version <= (15, 2)): # black image bug on recent versions of macOS
upcast = True
if upcast:
return torch.float32
else:
@@ -910,6 +950,13 @@ def get_free_memory(dev=None, torch_free_too=False):
mem_free_torch = mem_reserved - mem_active
mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved
mem_free_total = mem_free_xpu + mem_free_torch
elif is_ascend_npu():
stats = torch.npu.memory_stats(dev)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_npu, _ = torch.npu.mem_get_info(dev)
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_npu + mem_free_torch
else:
stats = torch.cuda.memory_stats(dev)
mem_active = stats['active_bytes.all.current']
@@ -956,17 +1003,13 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
if FORCE_FP16:
return True
if device is not None:
if is_device_mps(device):
return True
if FORCE_FP32:
return False
if directml_enabled:
return False
if mps_mode():
if (device is not None and is_device_mps(device)) or mps_mode():
return True
if cpu_mode():
@@ -975,6 +1018,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
if is_intel_xpu():
return True
if is_ascend_npu():
return True
if torch.version.hip:
return True
@@ -1015,17 +1061,15 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
if is_device_cpu(device): #TODO ? bf16 works on CPU but is extremely slow
return False
if device is not None:
if is_device_mps(device):
return True
if FORCE_FP32:
return False
if directml_enabled:
return False
if mps_mode():
if (device is not None and is_device_mps(device)) or mps_mode():
if mac_version() < (14,):
return False
return True
if cpu_mode():
@@ -1074,19 +1118,16 @@ def soft_empty_cache(force=False):
torch.mps.empty_cache()
elif is_intel_xpu():
torch.xpu.empty_cache()
elif is_ascend_npu():
torch.npu.empty_cache()
elif torch.cuda.is_available():
if force or is_nvidia(): #This seems to make things worse on ROCm so I only do it for cuda
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
def unload_all_models():
free_memory(1e30, get_torch_device())
def resolve_lowvram_weight(weight, model, key): #TODO: remove
print("WARNING: The comfy.model_management.resolve_lowvram_weight function will be removed soon, please stop using it.")
return weight
#TODO: might be cleaner to put this somewhere else
import threading

View File

@@ -83,7 +83,7 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
def create_model_options_clone(orig_model_options: dict):
return comfy.patcher_extension.copy_nested_dicts(orig_model_options)
def create_hook_patches_clone(orig_hook_patches):
new_hook_patches = {}
for hook_ref in orig_hook_patches:
@@ -141,7 +141,7 @@ class AutoPatcherEjector:
self.was_injected = False
self.prev_skip_injection = False
self.skip_and_inject_on_exit_only = skip_and_inject_on_exit_only
def __enter__(self):
self.was_injected = False
self.prev_skip_injection = self.model.skip_injection
@@ -164,7 +164,7 @@ class MemoryCounter:
self.value = initial
self.minimum = minimum
# TODO: add a safe limit besides 0
def use(self, weight: torch.Tensor):
weight_size = weight.nelement() * weight.element_size()
if self.is_useable(weight_size):
@@ -210,7 +210,7 @@ class ModelPatcher:
self.injections: dict[str, list[PatcherInjection]] = {}
self.hook_patches: dict[comfy.hooks._HookRef] = {}
self.hook_patches_backup: dict[comfy.hooks._HookRef] = {}
self.hook_patches_backup: dict[comfy.hooks._HookRef] = None
self.hook_backup: dict[str, tuple[torch.Tensor, torch.device]] = {}
self.cached_hook_patches: dict[comfy.hooks.HookGroup, dict[str, torch.Tensor]] = {}
self.current_hooks: Optional[comfy.hooks.HookGroup] = None
@@ -282,7 +282,7 @@ class ModelPatcher:
n.injections[k] = i.copy()
# hooks
n.hook_patches = create_hook_patches_clone(self.hook_patches)
n.hook_patches_backup = create_hook_patches_clone(self.hook_patches_backup)
n.hook_patches_backup = create_hook_patches_clone(self.hook_patches_backup) if self.hook_patches_backup else self.hook_patches_backup
for group in self.cached_hook_patches:
n.cached_hook_patches[group] = {}
for k in self.cached_hook_patches[group]:
@@ -402,7 +402,20 @@ class ModelPatcher:
def add_object_patch(self, name, obj):
self.object_patches[name] = obj
def get_model_object(self, name):
def get_model_object(self, name: str) -> torch.nn.Module:
"""Retrieves a nested attribute from an object using dot notation considering
object patches.
Args:
name (str): The attribute path using dot notation (e.g. "model.layer.weight")
Returns:
The value of the requested attribute
Example:
patcher = ModelPatcher()
weight = patcher.get_model_object("layer1.conv.weight")
"""
if name in self.object_patches:
return self.object_patches[name]
else:
@@ -711,7 +724,7 @@ class ModelPatcher:
else:
comfy.utils.set_attr_param(self.model, key, bk.weight)
self.backup.pop(key)
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
if move_weight:
@@ -773,7 +786,7 @@ class ModelPatcher:
return self.model.device
def calculate_weight(self, patches, weight, key, intermediate_dtype=torch.float32):
print("WARNING the ModelPatcher.calculate_weight function is deprecated, please use: comfy.lora.calculate_weight instead")
logging.warning("The ModelPatcher.calculate_weight function is deprecated, please use: comfy.lora.calculate_weight instead")
return comfy.lora.calculate_weight(patches, weight, key, intermediate_dtype=intermediate_dtype)
def cleanup(self):
@@ -789,7 +802,7 @@ class ModelPatcher:
def add_callback_with_key(self, call_type: str, key: str, callback: Callable):
c = self.callbacks.setdefault(call_type, {}).setdefault(key, [])
c.append(callback)
def remove_callbacks_with_key(self, call_type: str, key: str):
c = self.callbacks.get(call_type, {})
if key in c:
@@ -797,7 +810,7 @@ class ModelPatcher:
def get_callbacks(self, call_type: str, key: str):
return self.callbacks.get(call_type, {}).get(key, [])
def get_all_callbacks(self, call_type: str):
c_list = []
for c in self.callbacks.get(call_type, {}).values():
@@ -810,7 +823,7 @@ class ModelPatcher:
def add_wrapper_with_key(self, wrapper_type: str, key: str, wrapper: Callable):
w = self.wrappers.setdefault(wrapper_type, {}).setdefault(key, [])
w.append(wrapper)
def remove_wrappers_with_key(self, wrapper_type: str, key: str):
w = self.wrappers.get(wrapper_type, {})
if key in w:
@@ -831,7 +844,7 @@ class ModelPatcher:
def remove_attachments(self, key: str):
if key in self.attachments:
self.attachments.pop(key)
def get_attachment(self, key: str):
return self.attachments.get(key, None)
@@ -842,6 +855,9 @@ class ModelPatcher:
if key in self.injections:
self.injections.pop(key)
def get_injections(self, key: str):
return self.injections.get(key, None)
def set_additional_models(self, key: str, models: list['ModelPatcher']):
self.additional_models[key] = models
@@ -851,7 +867,7 @@ class ModelPatcher:
def get_additional_models_with_key(self, key: str):
return self.additional_models.get(key, [])
def get_additional_models(self):
all_models = []
for models in self.additional_models.values():
@@ -906,24 +922,25 @@ class ModelPatcher:
self.model.current_patcher = self
for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN):
callback(self)
def prepare_state(self, timestep):
for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE):
callback(self, timestep)
def restore_hook_patches(self):
if len(self.hook_patches_backup) > 0:
if self.hook_patches_backup is not None:
self.hook_patches = self.hook_patches_backup
self.hook_patches_backup = {}
self.hook_patches_backup = None
def set_hook_mode(self, hook_mode: comfy.hooks.EnumHookMode):
self.hook_mode = hook_mode
def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup):
def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup, model_options: dict[str]):
curr_t = t[0]
reset_current_hooks = False
transformer_options = model_options.get("transformer_options", {})
for hook in hook_group.hooks:
changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t)
changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t, transformer_options=transformer_options)
# if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref;
# this will cause the weights to be recalculated when sampling
if changed:
@@ -939,25 +956,26 @@ class ModelPatcher:
if reset_current_hooks:
self.patch_hooks(None)
def register_all_hook_patches(self, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]], target: comfy.hooks.EnumWeightTarget, model_options: dict=None):
def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None,
registered: comfy.hooks.HookGroup = None):
self.restore_hook_patches()
registered_hooks: list[comfy.hooks.Hook] = []
# handle WrapperHooks, if model_options provided
if model_options is not None:
for hook in hooks_dict.get(comfy.hooks.EnumHookType.Wrappers, {}):
hook.add_hook_patches(self, model_options, target, registered_hooks)
if registered is None:
registered = comfy.hooks.HookGroup()
# handle WeightHooks
weight_hooks_to_register: list[comfy.hooks.WeightHook] = []
for hook in hooks_dict.get(comfy.hooks.EnumHookType.Weight, {}):
for hook in hooks.get_type(comfy.hooks.EnumHookType.Weight):
if hook.hook_ref not in self.hook_patches:
weight_hooks_to_register.append(hook)
else:
registered.add(hook)
if len(weight_hooks_to_register) > 0:
# clone hook_patches to become backup so that any non-dynamic hooks will return to their original state
self.hook_patches_backup = create_hook_patches_clone(self.hook_patches)
for hook in weight_hooks_to_register:
hook.add_hook_patches(self, model_options, target, registered_hooks)
hook.add_hook_patches(self, model_options, target_dict, registered)
for callback in self.get_all_callbacks(CallbacksMP.ON_REGISTER_ALL_HOOK_PATCHES):
callback(self, hooks_dict, target)
callback(self, hooks, target_dict, model_options, registered)
return registered
def add_hook_patches(self, hook: comfy.hooks.WeightHook, patches, strength_patch=1.0, strength_model=1.0):
with self.use_ejected():
@@ -975,7 +993,7 @@ class ModelPatcher:
key = k[0]
if len(k) > 2:
function = k[2]
if key in model_sd:
p.add(k)
current_patches: list[tuple] = current_hook_patches.get(key, [])
@@ -1008,11 +1026,11 @@ class ModelPatcher:
def apply_hooks(self, hooks: comfy.hooks.HookGroup, transformer_options: dict=None, force_apply=False):
# TODO: return transformer_options dict with any additions from hooks
if self.current_hooks == hooks and (not force_apply or (not self.is_clip and hooks is None)):
return {}
return comfy.hooks.create_transformer_options_from_hooks(self, hooks, transformer_options)
self.patch_hooks(hooks=hooks)
for callback in self.get_all_callbacks(CallbacksMP.ON_APPLY_HOOKS):
callback(self, hooks)
return {}
return comfy.hooks.create_transformer_options_from_hooks(self, hooks, transformer_options)
def patch_hooks(self, hooks: comfy.hooks.HookGroup):
with self.use_ejected():
@@ -1029,7 +1047,7 @@ class ModelPatcher:
if cached_weights is not None:
for key in cached_weights:
if key not in model_sd_keys:
print(f"WARNING cached hook could not patch. key does not exist in model: {key}")
logging.warning(f"Cached hook could not patch. Key does not exist in model: {key}")
continue
self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter)
else:
@@ -1039,7 +1057,7 @@ class ModelPatcher:
original_weights = self.get_key_patches()
for key in relevant_patches:
if key not in model_sd_keys:
print(f"WARNING cached hook would not patch. key does not exist in model: {key}")
logging.warning(f"Cached hook would not patch. Key does not exist in model: {key}")
continue
self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights,
memory_counter=memory_counter)
@@ -1063,7 +1081,7 @@ class ModelPatcher:
def patch_hook_weight_to_device(self, hooks: comfy.hooks.HookGroup, combined_patches: dict, key: str, original_weights: dict, memory_counter: MemoryCounter):
if key not in combined_patches:
return
weight, set_func, convert_func = get_key_weight(self.model, key)
weight: torch.Tensor
if key not in self.hook_backup:
@@ -1098,7 +1116,7 @@ class ModelPatcher:
del temp_weight
del out_weight
del weight
def unpatch_hooks(self) -> None:
with self.use_ejected():
if len(self.hook_backup) == 0:
@@ -1107,7 +1125,7 @@ class ModelPatcher:
keys = list(self.hook_backup.keys())
for k in keys:
comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
self.hook_backup.clear()
self.current_hooks = None

View File

@@ -255,9 +255,10 @@ def fp8_linear(self, input):
tensor_2d = True
input = input.unsqueeze(1)
input_shape = input.shape
input_dtype = input.dtype
if len(input.shape) == 3:
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype)
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype)
w = w.t()
scale_weight = self.scale_weight
@@ -269,23 +270,24 @@ def fp8_linear(self, input):
if scale_input is None:
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
inn = torch.clamp(input, min=-448, max=448).reshape(-1, input.shape[2]).to(dtype)
input = torch.clamp(input, min=-448, max=448, out=input)
input = input.reshape(-1, input_shape[2]).to(dtype)
else:
scale_input = scale_input.to(input.device)
inn = (input * (1.0 / scale_input).to(input.dtype)).reshape(-1, input.shape[2]).to(dtype)
input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype)
if bias is not None:
o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
else:
o = torch._scaled_mm(inn, w, out_dtype=input.dtype, scale_a=scale_input, scale_b=scale_weight)
o = torch._scaled_mm(input, w, out_dtype=input_dtype, scale_a=scale_input, scale_b=scale_weight)
if isinstance(o, tuple):
o = o[0]
if tensor_2d:
return o.reshape(input.shape[0], -1)
return o.reshape(input_shape[0], -1)
return o.reshape((-1, input.shape[1], self.weight.shape[0]))
return o.reshape((-1, input_shape[1], self.weight.shape[0]))
return None

View File

@@ -96,12 +96,12 @@ class WrapperExecutor:
self.wrappers = wrappers.copy()
self.idx = idx
self.is_last = idx == len(wrappers)
def __call__(self, *args, **kwargs):
"""Calls the next wrapper or original function, whichever is appropriate."""
new_executor = self._create_next_executor()
return new_executor.execute(*args, **kwargs)
def execute(self, *args, **kwargs):
"""Used to initiate executor internally - DO NOT use this if you received executor in wrapper."""
args = list(args)
@@ -121,7 +121,7 @@ class WrapperExecutor:
@classmethod
def new_executor(cls, original: Callable, wrappers: list[Callable], idx=0):
return cls(original, class_obj=None, wrappers=wrappers, idx=idx)
@classmethod
def new_class_executor(cls, original: Callable, class_obj: object, wrappers: list[Callable], idx=0):
return cls(original, class_obj, wrappers, idx=idx)

View File

@@ -13,7 +13,7 @@ def prepare_noise(latent_image, seed, noise_inds=None):
generator = torch.manual_seed(seed)
if noise_inds is None:
return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
unique_inds, inverse = np.unique(noise_inds, return_inverse=True)
noises = []
for i in range(unique_inds[-1]+1):
@@ -25,9 +25,11 @@ def prepare_noise(latent_image, seed, noise_inds=None):
return noises
def fix_empty_latent_channels(model, latent_image):
latent_channels = model.get_model_object("latent_format").latent_channels #Resize the empty latent image so it has the right number of channels
if latent_channels != latent_image.shape[1] and torch.count_nonzero(latent_image) == 0:
latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_channels, dim=1)
latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels
if latent_format.latent_channels != latent_image.shape[1] and torch.count_nonzero(latent_image) == 0:
latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)
if latent_format.latent_dimensions == 3 and latent_image.ndim == 4:
latent_image = latent_image.unsqueeze(2)
return latent_image
def prepare_sampling(model, noise_shape, positive, negative, noise_mask):

View File

@@ -24,15 +24,13 @@ def get_models_from_cond(cond, model_type):
models += [c[model_type]]
return models
def get_hooks_from_cond(cond, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]]):
def get_hooks_from_cond(cond, full_hooks: comfy.hooks.HookGroup):
# get hooks from conds, and collect cnets so they can be checked for extra_hooks
cnets: list[ControlBase] = []
for c in cond:
if 'hooks' in c:
for hook in c['hooks'].hooks:
hook: comfy.hooks.Hook
with_type = hooks_dict.setdefault(hook.hook_type, {})
with_type[hook] = None
full_hooks.add(hook)
if 'control' in c:
cnets.append(c['control'])
@@ -42,7 +40,7 @@ def get_hooks_from_cond(cond, hooks_dict: dict[comfy.hooks.EnumHookType, dict[co
if cnet.previous_controlnet is None:
return _list
return get_extra_hooks_from_cnet(cnet.previous_controlnet, _list)
hooks_list = []
cnets = set(cnets)
for base_cnet in cnets:
@@ -50,10 +48,9 @@ def get_hooks_from_cond(cond, hooks_dict: dict[comfy.hooks.EnumHookType, dict[co
extra_hooks = comfy.hooks.HookGroup.combine_all_hooks(hooks_list)
if extra_hooks is not None:
for hook in extra_hooks.hooks:
with_type = hooks_dict.setdefault(hook.hook_type, {})
with_type[hook] = None
full_hooks.add(hook)
return hooks_dict
return full_hooks
def convert_cond(cond):
out = []
@@ -73,13 +70,11 @@ def get_additional_models(conds, dtype):
cnets: list[ControlBase] = []
gligen = []
add_models = []
hooks: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]] = {}
for k in conds:
cnets += get_models_from_cond(conds[k], "control")
gligen += get_models_from_cond(conds[k], "gligen")
add_models += get_models_from_cond(conds[k], "additional_models")
get_hooks_from_cond(conds[k], hooks)
control_nets = set(cnets)
@@ -90,11 +85,20 @@ def get_additional_models(conds, dtype):
inference_memory += m.inference_memory_requirements(dtype)
gligen = [x[1] for x in gligen]
hook_models = [x.model for x in hooks.get(comfy.hooks.EnumHookType.AddModels, {}).keys()]
models = control_models + gligen + add_models + hook_models
models = control_models + gligen + add_models
return models, inference_memory
def get_additional_models_from_model_options(model_options: dict[str]=None):
"""loads additional models from registered AddModels hooks"""
models = []
if model_options is not None and "registered_hooks" in model_options:
registered: comfy.hooks.HookGroup = model_options["registered_hooks"]
for hook in registered.get_type(comfy.hooks.EnumHookType.AdditionalModels):
hook: comfy.hooks.AdditionalModelsHook
models.extend(hook.models)
return models
def cleanup_additional_models(models):
"""cleanup additional models that were loaded"""
for m in models:
@@ -102,9 +106,10 @@ def cleanup_additional_models(models):
m.cleanup()
def prepare_sampling(model: 'ModelPatcher', noise_shape, conds):
real_model: 'BaseModel' = None
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
real_model: BaseModel = None
models, inference_memory = get_additional_models(conds, model.model_dtype())
models += get_additional_models_from_model_options(model_options)
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory
minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory
@@ -123,12 +128,35 @@ def cleanup_models(conds, models):
cleanup_additional_models(set(control_cleanup))
def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
'''
Registers hooks from conds.
'''
# check for hooks in conds - if not registered, see if can be applied
hooks = {}
hooks = comfy.hooks.HookGroup()
for k in conds:
get_hooks_from_cond(conds[k], hooks)
# add wrappers and callbacks from ModelPatcher to transformer_options
model_options["transformer_options"]["wrappers"] = comfy.patcher_extension.copy_nested_dicts(model.wrappers)
model_options["transformer_options"]["callbacks"] = comfy.patcher_extension.copy_nested_dicts(model.callbacks)
# register hooks on model/model_options
model.register_all_hook_patches(hooks, comfy.hooks.EnumWeightTarget.Model, model_options)
# begin registering hooks
registered = comfy.hooks.HookGroup()
target_dict = comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Model)
# handle all TransformerOptionsHooks
for hook in hooks.get_type(comfy.hooks.EnumHookType.TransformerOptions):
hook: comfy.hooks.TransformerOptionsHook
hook.add_hook_patches(model, model_options, target_dict, registered)
# handle all AddModelsHooks
for hook in hooks.get_type(comfy.hooks.EnumHookType.AdditionalModels):
hook: comfy.hooks.AdditionalModelsHook
hook.add_hook_patches(model, model_options, target_dict, registered)
# handle all WeightHooks by registering on ModelPatcher
model.register_all_hook_patches(hooks, target_dict, model_options, registered)
# add registered_hooks onto model_options for further reference
if len(registered) > 0:
model_options["registered_hooks"] = registered
# merge original wrappers and callbacks with hooked wrappers and callbacks
to_load_options: dict[str] = model_options.setdefault("to_load_options", {})
for wc_name in ["wrappers", "callbacks"]:
comfy.patcher_extension.merge_nested_dicts(to_load_options.setdefault(wc_name, {}), model_options["transformer_options"][wc_name],
copy_dict1=False)
return to_load_options

View File

@@ -1,12 +1,13 @@
from __future__ import annotations
from .k_diffusion import sampling as k_diffusion_sampling
from .extra_samplers import uni_pc
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Callable, NamedTuple
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
from comfy.model_base import BaseModel
from comfy.controlnet import ControlBase
import torch
from functools import partial
import collections
from comfy import model_management
import math
@@ -144,7 +145,7 @@ def cond_cat(c_list):
return out
def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]], default_conds: list[list[dict]], x_in, timestep):
def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]], default_conds: list[list[dict]], x_in, timestep, model_options):
# need to figure out remaining unmasked area for conds
default_mults = []
for _ in default_conds:
@@ -183,7 +184,7 @@ def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.H
# replace p's mult with calculated mult
p = p._replace(mult=mult)
if p.hooks is not None:
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks)
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options)
hooked_to_run.setdefault(p.hooks, list())
hooked_to_run[p.hooks] += [(p, i)]
@@ -218,13 +219,13 @@ def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Te
if p is None:
continue
if p.hooks is not None:
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks)
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options)
hooked_to_run.setdefault(p.hooks, list())
hooked_to_run[p.hooks] += [(p, i)]
default_conds.append(default_c)
if has_default_conds:
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep)
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)
model.current_patcher.prepare_state(timestep)
@@ -375,7 +376,7 @@ class KSamplerX0Inpaint:
if "denoise_mask_function" in model_options:
denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask, extra_options={"model": self.inner_model, "sigmas": self.sigmas})
latent_mask = 1. - denoise_mask
x = x * denoise_mask + self.inner_model.inner_model.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1)), self.noise, self.latent_image) * latent_mask
x = x * denoise_mask + self.inner_model.inner_model.scale_latent_inpaint(x=x, sigma=sigma, noise=self.noise, latent_image=self.latent_image) * latent_mask
out = self.inner_model(x, sigma, model_options=model_options, seed=seed)
if denoise_mask is not None:
out = out * denoise_mask + self.latent_image * latent_mask
@@ -467,6 +468,13 @@ def linear_quadratic_schedule(model_sampling, steps, threshold_noise=0.025, line
sigma_schedule = [1.0 - x for x in sigma_schedule]
return torch.FloatTensor(sigma_schedule) * model_sampling.sigma_max.cpu()
# Referenced from https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15608
def kl_optimal_scheduler(n: int, sigma_min: float, sigma_max: float) -> torch.Tensor:
adj_idxs = torch.arange(n, dtype=torch.float).div_(n - 1)
sigmas = adj_idxs.new_zeros(n + 1)
sigmas[:-1] = (adj_idxs * math.atan(sigma_min) + (1 - adj_idxs) * math.atan(sigma_max)).tan_()
return sigmas
def get_mask_aabb(masks):
if masks.numel() == 0:
return torch.zeros((0, 4), device=masks.device, dtype=torch.int)
@@ -679,7 +687,7 @@ class Sampler:
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
"ipndm", "ipndm_v", "deis"]
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp"]
class KSAMPLER(Sampler):
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
@@ -802,6 +810,33 @@ def preprocess_conds_hooks(conds: dict[str, list[dict[str]]]):
for cond in conds_to_modify:
cond['hooks'] = hooks
def filter_registered_hooks_on_conds(conds: dict[str, list[dict[str]]], model_options: dict[str]):
'''Modify 'hooks' on conds so that only hooks that were registered remain. Properly accounts for
HookGroups that have the same reference.'''
registered: comfy.hooks.HookGroup = model_options.get('registered_hooks', None)
# if None were registered, make sure all hooks are cleaned from conds
if registered is None:
for k in conds:
for kk in conds[k]:
kk.pop('hooks', None)
return
# find conds that contain hooks to be replaced - group by common HookGroup refs
hook_replacement: dict[comfy.hooks.HookGroup, list[dict]] = {}
for k in conds:
for kk in conds[k]:
hooks: comfy.hooks.HookGroup = kk.get('hooks', None)
if hooks is not None:
if not hooks.is_subset_of(registered):
to_replace = hook_replacement.setdefault(hooks, [])
to_replace.append(kk)
# for each hook to replace, create a new proper HookGroup and assign to all common conds
for hooks, conds_to_modify in hook_replacement.items():
new_hooks = hooks.new_with_common_hooks(registered)
if len(new_hooks) == 0:
new_hooks = None
for kk in conds_to_modify:
kk['hooks'] = new_hooks
def get_total_hook_groups_in_conds(conds: dict[str, list[dict[str]]]):
hooks_set = set()
@@ -811,9 +846,58 @@ def get_total_hook_groups_in_conds(conds: dict[str, list[dict[str]]]):
return len(hooks_set)
def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
'''
If any patches from hooks, wrappers, or callbacks have .to to be called, call it.
'''
if model_options is None:
return
to_load_options = model_options.get("to_load_options", None)
if to_load_options is None:
return
casts = []
if device is not None:
casts.append(device)
if dtype is not None:
casts.append(dtype)
# if nothing to apply, do nothing
if len(casts) == 0:
return
# try to call .to on patches
if "patches" in to_load_options:
patches = to_load_options["patches"]
for name in patches:
patch_list = patches[name]
for i in range(len(patch_list)):
if hasattr(patch_list[i], "to"):
for cast in casts:
patch_list[i] = patch_list[i].to(cast)
if "patches_replace" in to_load_options:
patches = to_load_options["patches_replace"]
for name in patches:
patch_list = patches[name]
for k in patch_list:
if hasattr(patch_list[k], "to"):
for cast in casts:
patch_list[k] = patch_list[k].to(cast)
# try to call .to on any wrappers/callbacks
wrappers_and_callbacks = ["wrappers", "callbacks"]
for wc_name in wrappers_and_callbacks:
if wc_name in to_load_options:
wc: dict[str, list] = to_load_options[wc_name]
for wc_dict in wc.values():
for wc_list in wc_dict.values():
for i in range(len(wc_list)):
if hasattr(wc_list[i], "to"):
for cast in casts:
wc_list[i] = wc_list[i].to(cast)
class CFGGuider:
def __init__(self, model_patcher):
self.model_patcher: 'ModelPatcher' = model_patcher
def __init__(self, model_patcher: ModelPatcher):
self.model_patcher = model_patcher
self.model_options = model_patcher.model_options
self.original_conds = {}
self.cfg = 1.0
@@ -840,7 +924,9 @@ class CFGGuider:
self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed)
extra_args = {"model_options": comfy.model_patcher.create_model_options_clone(self.model_options), "seed": seed}
extra_model_options = comfy.model_patcher.create_model_options_clone(self.model_options)
extra_model_options.setdefault("transformer_options", {})["sample_sigmas"] = sigmas
extra_args = {"model_options": extra_model_options, "seed": seed}
executor = comfy.patcher_extension.WrapperExecutor.new_class_executor(
sampler.sample,
@@ -851,7 +937,7 @@ class CFGGuider:
return self.inner_model.process_latent_out(samples.to(torch.float32))
def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds)
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
device = self.model_patcher.load_device
if denoise_mask is not None:
@@ -860,6 +946,7 @@ class CFGGuider:
noise = noise.to(device)
latent_image = latent_image.to(device)
sigmas = sigmas.to(device)
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
try:
self.model_patcher.pre_run()
@@ -889,6 +976,7 @@ class CFGGuider:
if get_total_hook_groups_in_conds(self.conds) <= 1:
self.model_patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
comfy.sampler_helpers.prepare_model_patcher(self.model_patcher, self.conds, self.model_options)
filter_registered_hooks_on_conds(self.conds, self.model_options)
executor = comfy.patcher_extension.WrapperExecutor.new_class_executor(
self.outer_sample,
self,
@@ -896,6 +984,7 @@ class CFGGuider:
)
output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
finally:
cast_to_load_options(self.model_options, device=self.model_patcher.offload_device)
self.model_options = orig_model_options
self.model_patcher.hook_mode = orig_hook_mode
self.model_patcher.restore_hook_patches()
@@ -911,29 +1000,37 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "beta", "linear_quadratic"]
SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]
def calculate_sigmas(model_sampling, scheduler_name, steps):
if scheduler_name == "karras":
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))
elif scheduler_name == "exponential":
sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))
elif scheduler_name == "normal":
sigmas = normal_scheduler(model_sampling, steps)
elif scheduler_name == "simple":
sigmas = simple_scheduler(model_sampling, steps)
elif scheduler_name == "ddim_uniform":
sigmas = ddim_scheduler(model_sampling, steps)
elif scheduler_name == "sgm_uniform":
sigmas = normal_scheduler(model_sampling, steps, sgm=True)
elif scheduler_name == "beta":
sigmas = beta_scheduler(model_sampling, steps)
elif scheduler_name == "linear_quadratic":
sigmas = linear_quadratic_schedule(model_sampling, steps)
else:
logging.error("error invalid scheduler {}".format(scheduler_name))
return sigmas
class SchedulerHandler(NamedTuple):
handler: Callable[..., torch.Tensor]
# Boolean indicates whether to call the handler like:
# scheduler_function(model_sampling, steps) or
# scheduler_function(n, sigma_min: float, sigma_max: float)
use_ms: bool = True
SCHEDULER_HANDLERS = {
"normal": SchedulerHandler(normal_scheduler),
"karras": SchedulerHandler(k_diffusion_sampling.get_sigmas_karras, use_ms=False),
"exponential": SchedulerHandler(k_diffusion_sampling.get_sigmas_exponential, use_ms=False),
"sgm_uniform": SchedulerHandler(partial(normal_scheduler, sgm=True)),
"simple": SchedulerHandler(simple_scheduler),
"ddim_uniform": SchedulerHandler(ddim_scheduler),
"beta": SchedulerHandler(beta_scheduler),
"linear_quadratic": SchedulerHandler(linear_quadratic_schedule),
"kl_optimal": SchedulerHandler(kl_optimal_scheduler, use_ms=False),
}
SCHEDULER_NAMES = list(SCHEDULER_HANDLERS)
def calculate_sigmas(model_sampling: object, scheduler_name: str, steps: int) -> torch.Tensor:
handler = SCHEDULER_HANDLERS.get(scheduler_name)
if handler is None:
err = f"error invalid scheduler {scheduler_name}"
logging.error(err)
raise ValueError(err)
if handler.use_ms:
return handler.handler(model_sampling, steps)
return handler.handler(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))
def sampler_object(name):
if name == "uni_pc":

View File

@@ -11,6 +11,7 @@ from .ldm.cascade.stage_c_coder import StageC_coder
from .ldm.audio.autoencoder import AudioOobleckVAE
import comfy.ldm.genmo.vae.model
import comfy.ldm.lightricks.vae.causal_video_autoencoder
import comfy.ldm.cosmos.vae
import yaml
import math
@@ -27,12 +28,14 @@ import comfy.text_encoders.sd2_clip
import comfy.text_encoders.sd3_clip
import comfy.text_encoders.sa_t5
import comfy.text_encoders.aura_t5
import comfy.text_encoders.pixart_t5
import comfy.text_encoders.hydit
import comfy.text_encoders.flux
import comfy.text_encoders.long_clipl
import comfy.text_encoders.genmo
import comfy.text_encoders.lt
import comfy.text_encoders.hunyuan_video
import comfy.text_encoders.cosmos
import comfy.model_patcher
import comfy.lora
@@ -110,7 +113,7 @@ class CLIP:
model_management.load_models_gpu([self.patcher], force_full_load=True)
self.layer_idx = None
self.use_clip_schedule = False
logging.debug("CLIP model load device: {}, offload device: {}, current: {}".format(load_device, offload_device, params['device']))
logging.info("CLIP/text encoder model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype))
def clone(self):
n = CLIP(no_init=True)
@@ -258,6 +261,9 @@ class VAE:
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
self.working_dtypes = [torch.bfloat16, torch.float32]
self.downscale_index_formula = None
self.upscale_index_formula = None
if config is None:
if "decoder.mid.block_1.mix_factor" in sd:
encoder_config = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
@@ -337,7 +343,9 @@ class VAE:
self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (1.5 * max(shape[2], 7) * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
self.upscale_ratio = (lambda a: max(0, a * 6 - 5), 8, 8)
self.upscale_index_formula = (6, 8, 8)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 5) / 6)), 8, 8)
self.downscale_index_formula = (6, 8, 8)
self.working_dtypes = [torch.float16, torch.float32]
elif "decoder.up_blocks.0.res_blocks.0.conv1.conv.weight" in sd: #lightricks ltxv
tensor_conv1 = sd["decoder.up_blocks.0.res_blocks.0.conv1.conv.weight"]
@@ -352,20 +360,37 @@ class VAE:
self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (70 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 32, 32)
self.upscale_index_formula = (8, 32, 32)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32)
self.downscale_index_formula = (8, 32, 32)
self.working_dtypes = [torch.bfloat16, torch.float32]
elif "decoder.conv_in.conv.weight" in sd:
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
ddconfig["conv3d"] = True
ddconfig["time_compress"] = 4
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
self.upscale_index_formula = (4, 8, 8)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
self.downscale_index_formula = (4, 8, 8)
self.latent_dim = 3
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
self.memory_used_decode = lambda shape, dtype: (1500 * shape[2] * shape[3] * shape[4] * (4 * 8 * 8)) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (900 * max(shape[2], 2) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
elif "decoder.unpatcher3d.wavelets" in sd:
self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 8, 8)
self.upscale_index_formula = (8, 8, 8)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 8, 8)
self.downscale_index_formula = (8, 8, 8)
self.latent_dim = 3
self.latent_channels = 16
ddconfig = {'z_channels': 16, 'latent_channels': self.latent_channels, 'z_factor': 1, 'resolution': 1024, 'in_channels': 3, 'out_channels': 3, 'channels': 128, 'channels_mult': [2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [32], 'dropout': 0.0, 'patch_size': 4, 'num_groups': 1, 'temporal_compression': 8, 'spacial_compression': 8}
self.first_stage_model = comfy.ldm.cosmos.vae.CausalContinuousVideoTokenizer(**ddconfig)
#TODO: these values are a bit off because this is not a standard VAE
self.memory_used_decode = lambda shape, dtype: (50 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (50 * (round((shape[2] + 7) / 8) * 8) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
self.working_dtypes = [torch.bfloat16, torch.float32]
else:
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
self.first_stage_model = None
@@ -392,7 +417,7 @@ class VAE:
self.output_device = model_management.intermediate_device()
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
logging.debug("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
def vae_encode_crop_pixels(self, pixels):
downscale_ratio = self.spacial_compression_encode()
@@ -425,7 +450,7 @@ class VAE:
def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device))
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
@@ -446,7 +471,7 @@ class VAE:
def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)):
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, output_device=self.output_device)
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
def decode(self, samples_in):
pixel_samples = None
@@ -478,7 +503,7 @@ class VAE:
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
return pixel_samples
def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None):
def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
dims = samples.ndim - 2
@@ -496,13 +521,20 @@ class VAE:
elif dims == 2:
output = self.decode_tiled_(samples, **args)
elif dims == 3:
if overlap_t is None:
args["overlap"] = (1, overlap, overlap)
else:
args["overlap"] = (max(1, overlap_t), overlap, overlap)
if tile_t is not None:
args["tile_t"] = max(2, tile_t)
output = self.decode_tiled_3d(samples, **args)
return output.movedim(1, -1)
def encode(self, pixel_samples):
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
pixel_samples = pixel_samples.movedim(-1, 1)
if self.latent_dim == 3:
if self.latent_dim == 3 and pixel_samples.ndim < 5:
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
try:
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
@@ -531,7 +563,7 @@ class VAE:
return samples
def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None):
def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
dims = self.latent_dim
pixel_samples = pixel_samples.movedim(-1, 1)
@@ -555,7 +587,20 @@ class VAE:
elif dims == 2:
samples = self.encode_tiled_(pixel_samples, **args)
elif dims == 3:
samples = self.encode_tiled_3d(pixel_samples, **args)
if tile_t is not None:
tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
else:
tile_t_latent = 9999
args["tile_t"] = self.upscale_ratio[0](tile_t_latent)
if overlap_t is None:
args["overlap"] = (1, overlap, overlap)
else:
args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap)
maximum = pixel_samples.shape[2]
maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum))
samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)
return samples
@@ -574,6 +619,12 @@ class VAE:
except:
return self.downscale_ratio
def temporal_compression_decode(self):
try:
return round(self.upscale_ratio[0](8192) / 8192)
except:
return None
class StyleModel:
def __init__(self, model, device="cpu"):
self.model = model
@@ -604,6 +655,9 @@ class CLIPType(Enum):
MOCHI = 7
LTXV = 8
HUNYUAN_VIDEO = 9
PIXART = 10
COSMOS = 11
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
clip_data = []
@@ -620,6 +674,7 @@ class TEModel(Enum):
T5_XL = 5
T5_BASE = 6
LLAMA3_8 = 7
T5_XXL_OLD = 8
def detect_te_model(sd):
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
@@ -634,6 +689,8 @@ def detect_te_model(sd):
return TEModel.T5_XXL
elif weight.shape[-1] == 2048:
return TEModel.T5_XL
if 'encoder.block.23.layer.1.DenseReluDense.wi.weight' in sd:
return TEModel.T5_XXL_OLD
if "encoder.block.0.layer.0.SelfAttention.k.weight" in sd:
return TEModel.T5_BASE
if "model.layers.0.post_attention_layernorm.weight" in sd:
@@ -643,9 +700,10 @@ def detect_te_model(sd):
def t5xxl_detect(clip_data):
weight_name = "encoder.block.23.layer.1.DenseReluDense.wi_1.weight"
weight_name_old = "encoder.block.23.layer.1.DenseReluDense.wi.weight"
for sd in clip_data:
if weight_name in sd:
if weight_name in sd or weight_name_old in sd:
return comfy.text_encoders.sd3_clip.t5_xxl_detect(sd)
return {}
@@ -696,9 +754,15 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
elif clip_type == CLIPType.LTXV:
clip_target.clip = comfy.text_encoders.lt.ltxv_te(**t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.lt.LTXVT5Tokenizer
elif clip_type == CLIPType.PIXART:
clip_target.clip = comfy.text_encoders.pixart_t5.pixart_te(**t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.pixart_t5.PixArtTokenizer
else: #CLIPType.MOCHI
clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer
elif te_model == TEModel.T5_XXL_OLD:
clip_target.clip = comfy.text_encoders.cosmos.te(**t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.cosmos.CosmosT5Tokenizer
elif te_model == TEModel.T5_XL:
clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model
clip_target.tokenizer = comfy.text_encoders.aura_t5.AuraT5Tokenizer
@@ -857,7 +921,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
if output_model:
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
if inital_load_device != torch.device("cpu"):
logging.info("loaded straight to GPU")
logging.info("loaded diffusion model directly to GPU")
model_management.load_models_gpu([model_patcher], force_full_load=True)
return (model_patcher, clip, vae, clipvision)
@@ -934,11 +998,11 @@ def load_diffusion_model(unet_path, model_options={}):
return model
def load_unet(unet_path, dtype=None):
print("WARNING: the load_unet function has been deprecated and will be removed please switch to: load_diffusion_model")
logging.warning("The load_unet function has been deprecated and will be removed please switch to: load_diffusion_model")
return load_diffusion_model(unet_path, model_options={"dtype": dtype})
def load_unet_state_dict(sd, dtype=None):
print("WARNING: the load_unet_state_dict function has been deprecated and will be removed please switch to: load_diffusion_model_state_dict")
logging.warning("The load_unet_state_dict function has been deprecated and will be removed please switch to: load_diffusion_model_state_dict")
return load_diffusion_model_state_dict(sd, model_options={"dtype": dtype})
def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None, extra_keys={}):

View File

@@ -37,7 +37,10 @@ class ClipTokenWeightEncoder:
sections = len(to_encode)
if has_weights or sections == 0:
to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))
if hasattr(self, "gen_empty_tokens"):
to_encode.append(self.gen_empty_tokens(self.special_tokens, max_token_len))
else:
to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))
o = self.encode(to_encode)
out, pooled = o[:2]
@@ -385,13 +388,10 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
import safetensors.torch
embed = safetensors.torch.load_file(embed_path, device="cpu")
else:
if 'weights_only' in torch.load.__code__.co_varnames:
try:
embed = torch.load(embed_path, weights_only=True, map_location="cpu")
except:
embed_out = safe_load_embed_zip(embed_path)
else:
embed = torch.load(embed_path, map_location="cpu")
try:
embed = torch.load(embed_path, weights_only=True, map_location="cpu")
except:
embed_out = safe_load_embed_zip(embed_path)
except Exception:
logging.warning("{}\n\nerror loading embedding, skipping loading: {}".format(traceback.format_exc(), embedding_name))
return None

View File

@@ -8,11 +8,13 @@ import comfy.text_encoders.sd2_clip
import comfy.text_encoders.sd3_clip
import comfy.text_encoders.sa_t5
import comfy.text_encoders.aura_t5
import comfy.text_encoders.pixart_t5
import comfy.text_encoders.hydit
import comfy.text_encoders.flux
import comfy.text_encoders.genmo
import comfy.text_encoders.lt
import comfy.text_encoders.hunyuan_video
import comfy.text_encoders.cosmos
from . import supported_models_base
from . import latent_formats
@@ -592,6 +594,39 @@ class AuraFlow(supported_models_base.BASE):
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(comfy.text_encoders.aura_t5.AuraT5Tokenizer, comfy.text_encoders.aura_t5.AuraT5Model)
class PixArtAlpha(supported_models_base.BASE):
unet_config = {
"image_model": "pixart_alpha",
}
sampling_settings = {
"beta_schedule" : "sqrt_linear",
"linear_start" : 0.0001,
"linear_end" : 0.02,
"timesteps" : 1000,
}
unet_extra_config = {}
latent_format = latent_formats.SD15
memory_usage_factor = 0.5
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.PixArt(self, device=device)
return out.eval()
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.PixArtT5XXL)
class PixArtSigma(PixArtAlpha):
unet_config = {
"image_model": "pixart_sigma",
}
latent_format = latent_formats.SDXL
class HunyuanDiT(supported_models_base.BASE):
unet_config = {
"image_model": "hydit",
@@ -608,6 +643,8 @@ class HunyuanDiT(supported_models_base.BASE):
latent_format = latent_formats.SDXL
memory_usage_factor = 1.3
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
@@ -751,7 +788,7 @@ class HunyuanVideo(supported_models_base.BASE):
unet_extra_config = {}
latent_format = latent_formats.HunyuanVideo
memory_usage_factor = 2.0 #TODO
memory_usage_factor = 1.8 #TODO
supported_inference_dtypes = [torch.bfloat16, torch.float32]
@@ -787,6 +824,47 @@ class HunyuanVideo(supported_models_base.BASE):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}llama.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer, comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**hunyuan_detect))
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo]
class CosmosT2V(supported_models_base.BASE):
unet_config = {
"image_model": "cosmos",
"in_channels": 16,
}
sampling_settings = {
"sigma_data": 0.5,
"sigma_max": 80.0,
"sigma_min": 0.002,
}
unet_extra_config = {}
latent_format = latent_formats.Cosmos1CV8x8x8
memory_usage_factor = 1.6 #TODO
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] #TODO
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.CosmosVideo(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.cosmos.CosmosT5Tokenizer, comfy.text_encoders.cosmos.te(**t5_detect))
class CosmosI2V(CosmosT2V):
unet_config = {
"image_model": "cosmos",
"in_channels": 17,
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.CosmosVideo(self, image_to_video=True, device=device)
return out
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo, CosmosT2V, CosmosI2V]
models += [SVD_img2vid]

View File

@@ -0,0 +1,42 @@
from comfy import sd1_clip
import comfy.text_encoders.t5
import os
from transformers import T5TokenizerFast
class T5XXLModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_old_config_xxl.json")
t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None)
if t5xxl_scaled_fp8 is not None:
model_options = model_options.copy()
model_options["scaled_fp8"] = t5xxl_scaled_fp8
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, zero_out_masked=attention_mask, model_options=model_options)
class CosmosT5XXL(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, name="t5xxl", clip_model=T5XXLModel, model_options=model_options)
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=1024, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512)
class CosmosT5Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
def te(dtype_t5=None, t5xxl_scaled_fp8=None):
class CosmosTEModel_(CosmosT5XXL):
def __init__(self, device="cpu", dtype=None, model_options={}):
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
if dtype is None:
dtype = dtype_t5
super().__init__(device=device, dtype=dtype, model_options=model_options)
return CosmosTEModel_

View File

@@ -0,0 +1,42 @@
import os
from comfy import sd1_clip
import comfy.text_encoders.t5
import comfy.text_encoders.sd3_clip
from comfy.sd1_clip import gen_empty_tokens
from transformers import T5TokenizerFast
class T5XXLModel(comfy.text_encoders.sd3_clip.T5XXLModel):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def gen_empty_tokens(self, special_tokens, *args, **kwargs):
# PixArt expects the negative to be all pad tokens
special_tokens = special_tokens.copy()
special_tokens.pop("end")
return gen_empty_tokens(special_tokens, *args, **kwargs)
class PixArtT5XXL(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, name="t5xxl", clip_model=T5XXLModel, model_options=model_options)
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1) # no padding
class PixArtTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
def pixart_te(dtype_t5=None, t5xxl_scaled_fp8=None):
class PixArtTEModel_(PixArtT5XXL):
def __init__(self, device="cpu", dtype=None, model_options={}):
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
if dtype is None:
dtype = dtype_t5
super().__init__(device=device, dtype=dtype, model_options=model_options)
return PixArtTEModel_

View File

@@ -227,8 +227,9 @@ class T5(torch.nn.Module):
super().__init__()
self.num_layers = config_dict["num_layers"]
model_dim = config_dict["d_model"]
inner_dim = config_dict["d_kv"] * config_dict["num_heads"]
self.encoder = T5Stack(self.num_layers, model_dim, model_dim, config_dict["d_ff"], config_dict["dense_act_fn"], config_dict["is_gated_act"], config_dict["num_heads"], config_dict["model_type"] != "umt5", dtype, device, operations)
self.encoder = T5Stack(self.num_layers, model_dim, inner_dim, config_dict["d_ff"], config_dict["dense_act_fn"], config_dict["is_gated_act"], config_dict["num_heads"], config_dict["model_type"] != "umt5", dtype, device, operations)
self.dtype = dtype
self.shared = operations.Embedding(config_dict["vocab_size"], model_dim, device=device, dtype=dtype)

View File

@@ -0,0 +1,22 @@
{
"d_ff": 65536,
"d_kv": 128,
"d_model": 1024,
"decoder_start_token_id": 0,
"dropout_rate": 0.1,
"eos_token_id": 1,
"dense_act_fn": "relu",
"initializer_factor": 1.0,
"is_encoder_decoder": true,
"is_gated_act": false,
"layer_norm_epsilon": 1e-06,
"model_type": "t5",
"num_decoder_layers": 24,
"num_heads": 128,
"num_layers": 24,
"output_past": true,
"pad_token_id": 0,
"relative_attention_num_buckets": 32,
"tie_word_embeddings": false,
"vocab_size": 32128
}

View File

@@ -29,17 +29,29 @@ import itertools
from torch.nn.functional import interpolate
from einops import rearrange
ALWAYS_SAFE_LOAD = False
if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated
class ModelCheckpoint:
pass
ModelCheckpoint.__module__ = "pytorch_lightning.callbacks.model_checkpoint"
from numpy.core.multiarray import scalar
from numpy import dtype
from numpy.dtypes import Float64DType
from _codecs import encode
torch.serialization.add_safe_globals([ModelCheckpoint, scalar, dtype, Float64DType, encode])
ALWAYS_SAFE_LOAD = True
logging.info("Checkpoint files will always be loaded safely.")
def load_torch_file(ckpt, safe_load=False, device=None):
if device is None:
device = torch.device("cpu")
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
sd = safetensors.torch.load_file(ckpt, device=device.type)
else:
if safe_load:
if not 'weights_only' in torch.load.__code__.co_varnames:
logging.warning("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
safe_load = False
if safe_load:
if safe_load or ALWAYS_SAFE_LOAD:
pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
else:
pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
@@ -386,6 +398,77 @@ def mmdit_to_diffusers(mmdit_config, output_prefix=""):
return key_map
PIXART_MAP_BASIC = {
("csize_embedder.mlp.0.weight", "adaln_single.emb.resolution_embedder.linear_1.weight"),
("csize_embedder.mlp.0.bias", "adaln_single.emb.resolution_embedder.linear_1.bias"),
("csize_embedder.mlp.2.weight", "adaln_single.emb.resolution_embedder.linear_2.weight"),
("csize_embedder.mlp.2.bias", "adaln_single.emb.resolution_embedder.linear_2.bias"),
("ar_embedder.mlp.0.weight", "adaln_single.emb.aspect_ratio_embedder.linear_1.weight"),
("ar_embedder.mlp.0.bias", "adaln_single.emb.aspect_ratio_embedder.linear_1.bias"),
("ar_embedder.mlp.2.weight", "adaln_single.emb.aspect_ratio_embedder.linear_2.weight"),
("ar_embedder.mlp.2.bias", "adaln_single.emb.aspect_ratio_embedder.linear_2.bias"),
("x_embedder.proj.weight", "pos_embed.proj.weight"),
("x_embedder.proj.bias", "pos_embed.proj.bias"),
("y_embedder.y_embedding", "caption_projection.y_embedding"),
("y_embedder.y_proj.fc1.weight", "caption_projection.linear_1.weight"),
("y_embedder.y_proj.fc1.bias", "caption_projection.linear_1.bias"),
("y_embedder.y_proj.fc2.weight", "caption_projection.linear_2.weight"),
("y_embedder.y_proj.fc2.bias", "caption_projection.linear_2.bias"),
("t_embedder.mlp.0.weight", "adaln_single.emb.timestep_embedder.linear_1.weight"),
("t_embedder.mlp.0.bias", "adaln_single.emb.timestep_embedder.linear_1.bias"),
("t_embedder.mlp.2.weight", "adaln_single.emb.timestep_embedder.linear_2.weight"),
("t_embedder.mlp.2.bias", "adaln_single.emb.timestep_embedder.linear_2.bias"),
("t_block.1.weight", "adaln_single.linear.weight"),
("t_block.1.bias", "adaln_single.linear.bias"),
("final_layer.linear.weight", "proj_out.weight"),
("final_layer.linear.bias", "proj_out.bias"),
("final_layer.scale_shift_table", "scale_shift_table"),
}
PIXART_MAP_BLOCK = {
("scale_shift_table", "scale_shift_table"),
("attn.proj.weight", "attn1.to_out.0.weight"),
("attn.proj.bias", "attn1.to_out.0.bias"),
("mlp.fc1.weight", "ff.net.0.proj.weight"),
("mlp.fc1.bias", "ff.net.0.proj.bias"),
("mlp.fc2.weight", "ff.net.2.weight"),
("mlp.fc2.bias", "ff.net.2.bias"),
("cross_attn.proj.weight" ,"attn2.to_out.0.weight"),
("cross_attn.proj.bias" ,"attn2.to_out.0.bias"),
}
def pixart_to_diffusers(mmdit_config, output_prefix=""):
key_map = {}
depth = mmdit_config.get("depth", 0)
offset = mmdit_config.get("hidden_size", 1152)
for i in range(depth):
block_from = "transformer_blocks.{}".format(i)
block_to = "{}blocks.{}".format(output_prefix, i)
for end in ("weight", "bias"):
s = "{}.attn1.".format(block_from)
qkv = "{}.attn.qkv.{}".format(block_to, end)
key_map["{}to_q.{}".format(s, end)] = (qkv, (0, 0, offset))
key_map["{}to_k.{}".format(s, end)] = (qkv, (0, offset, offset))
key_map["{}to_v.{}".format(s, end)] = (qkv, (0, offset * 2, offset))
s = "{}.attn2.".format(block_from)
q = "{}.cross_attn.q_linear.{}".format(block_to, end)
kv = "{}.cross_attn.kv_linear.{}".format(block_to, end)
key_map["{}to_q.{}".format(s, end)] = q
key_map["{}to_k.{}".format(s, end)] = (kv, (0, 0, offset))
key_map["{}to_v.{}".format(s, end)] = (kv, (0, offset, offset))
for k in PIXART_MAP_BLOCK:
key_map["{}.{}".format(block_from, k[1])] = "{}.{}".format(block_to, k[0])
for k in PIXART_MAP_BASIC:
key_map[k[1]] = "{}{}".format(output_prefix, k[0])
return key_map
def auraflow_to_diffusers(mmdit_config, output_prefix=""):
n_double_layers = mmdit_config.get("n_double_layers", 0)
@@ -622,7 +705,25 @@ def copy_to_param(obj, attr, value):
prev = getattr(obj, attrs[-1])
prev.data.copy_(value)
def get_attr(obj, attr):
def get_attr(obj, attr: str):
"""Retrieves a nested attribute from an object using dot notation.
Args:
obj: The object to get the attribute from
attr (str): The attribute path using dot notation (e.g. "model.layer.weight")
Returns:
The value of the requested attribute
Example:
model = MyModel()
weight = get_attr(model, "layer1.conv.weight")
# Equivalent to: model.layer1.conv.weight
Important:
Always prefer `comfy.model_patcher.ModelPatcher.get_model_object` when
accessing nested model objects under `ModelPatcher.model`.
"""
attrs = attr.split(".")
for name in attrs:
obj = getattr(obj, name)
@@ -631,7 +732,7 @@ def get_attr(obj, attr):
def bislerp(samples, width, height):
def slerp(b1, b2, r):
'''slerps batches b1, b2 according to ratio r, batches should be flat e.g. NxC'''
c = b1.shape[-1]
#norms
@@ -656,16 +757,16 @@ def bislerp(samples, width, height):
res *= (b1_norms * (1.0-r) + b2_norms * r).expand(-1,c)
#edge cases for same or polar opposites
res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5]
res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5]
res[dot < 1e-5 - 1] = (b1 * (1.0-r) + b2 * r)[dot < 1e-5 - 1]
return res
def generate_bilinear_data(length_old, length_new, device):
coords_1 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1,1,1,-1))
coords_1 = torch.nn.functional.interpolate(coords_1, size=(1, length_new), mode="bilinear")
ratios = coords_1 - coords_1.floor()
coords_1 = coords_1.to(torch.int64)
coords_2 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1,1,1,-1)) + 1
coords_2[:,:,:,-1] -= 1
coords_2 = torch.nn.functional.interpolate(coords_2, size=(1, length_new), mode="bilinear")
@@ -676,7 +777,7 @@ def bislerp(samples, width, height):
samples = samples.float()
n,c,h,w = samples.shape
h_new, w_new = (height, width)
#linear w
ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new, samples.device)
coords_1 = coords_1.expand((n, c, h, -1))
@@ -751,7 +852,7 @@ def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
return rows * cols
@torch.inference_mode()
def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", downscale=False, pbar=None):
def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", downscale=False, index_formulas=None, pbar=None):
dims = len(tile)
if not (isinstance(upscale_amount, (tuple, list))):
@@ -760,6 +861,12 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
if not (isinstance(overlap, (tuple, list))):
overlap = [overlap] * dims
if index_formulas is None:
index_formulas = upscale_amount
if not (isinstance(index_formulas, (tuple, list))):
index_formulas = [index_formulas] * dims
def get_upscale(dim, val):
up = upscale_amount[dim]
if callable(up):
@@ -774,10 +881,26 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
else:
return val / up
def get_upscale_pos(dim, val):
up = index_formulas[dim]
if callable(up):
return up(val)
else:
return up * val
def get_downscale_pos(dim, val):
up = index_formulas[dim]
if callable(up):
return up(val)
else:
return val / up
if downscale:
get_scale = get_downscale
get_pos = get_downscale_pos
else:
get_scale = get_upscale
get_pos = get_upscale_pos
def mult_list_upscale(a):
out = []
@@ -800,7 +923,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
out = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
out_div = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
positions = [range(0, s.shape[d+2], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]
positions = [range(0, s.shape[d+2] - overlap[d], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]
for it in itertools.product(*positions):
s_in = s
@@ -810,7 +933,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
pos = max(0, min(s.shape[d + 2] - overlap[d], it[d]))
l = min(tile[d], s.shape[d + 2] - pos)
s_in = s_in.narrow(d + 2, pos, l)
upscaled.append(round(get_scale(d, pos)))
upscaled.append(round(get_pos(d, pos)))
ps = function(s_in).to(output_device)
mask = torch.ones_like(ps)

View File

@@ -54,8 +54,8 @@ class DynamicPrompt:
def get_original_prompt(self):
return self.original_prompt
def get_input_info(class_def, input_name):
valid_inputs = class_def.INPUT_TYPES()
def get_input_info(class_def, input_name, valid_inputs=None):
valid_inputs = valid_inputs or class_def.INPUT_TYPES()
input_info = None
input_category = None
if "required" in valid_inputs and input_name in valid_inputs["required"]:
@@ -131,7 +131,7 @@ class TopologicalSort:
if (include_lazy or not is_lazy) and not self.is_cached(from_node_id):
node_ids.append(from_node_id)
links.append((from_node_id, from_socket, unique_id))
for link in links:
self.add_strong_link(*link)

View File

@@ -1,5 +1,6 @@
import logging
from spandrel import ModelLoader
def load_state_dict(state_dict):
print("WARNING: comfy_extras.chainner_models is deprecated and has been replaced by the spandrel library.")
logging.warning("comfy_extras.chainner_models is deprecated and has been replaced by the spandrel library.")
return ModelLoader().load_from_state_dict(state_dict).eval()

View File

@@ -24,7 +24,7 @@ class AlignYourStepsScheduler:
def INPUT_TYPES(s):
return {"required":
{"model_type": (["SD1", "SDXL", "SVD"], ),
"steps": ("INT", {"default": 10, "min": 10, "max": 10000}),
"steps": ("INT", {"default": 10, "min": 1, "max": 10000}),
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}
}

View File

@@ -0,0 +1,82 @@
import nodes
import torch
import comfy.model_management
import comfy.utils
class EmptyCosmosLatentVideo:
@classmethod
def INPUT_TYPES(s):
return {"required": { "width": ("INT", {"default": 1280, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 704, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 121, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 8}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
CATEGORY = "latent/video"
def generate(self, width, height, length, batch_size=1):
latent = torch.zeros([batch_size, 16, ((length - 1) // 8) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
return ({"samples": latent}, )
def vae_encode_with_padding(vae, image, width, height, length, padding=0):
pixels = comfy.utils.common_upscale(image[..., :3].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
pixel_len = min(pixels.shape[0], length)
padded_length = min(length, (((pixel_len - 1) // 8) + 1 + padding) * 8 - 7)
padded_pixels = torch.ones((padded_length, height, width, 3)) * 0.5
padded_pixels[:pixel_len] = pixels[:pixel_len]
latent_len = ((pixel_len - 1) // 8) + 1
latent_temp = vae.encode(padded_pixels)
return latent_temp[:, :, :latent_len]
class CosmosImageToVideoLatent:
@classmethod
def INPUT_TYPES(s):
return {"required": {"vae": ("VAE", ),
"width": ("INT", {"default": 1280, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 704, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 121, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 8}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
},
"optional": {"start_image": ("IMAGE", ),
"end_image": ("IMAGE", ),
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "encode"
CATEGORY = "conditioning/inpaint"
def encode(self, vae, width, height, length, batch_size, start_image=None, end_image=None):
latent = torch.zeros([1, 16, ((length - 1) // 8) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
if start_image is None and end_image is None:
out_latent = {}
out_latent["samples"] = latent
return (out_latent,)
mask = torch.ones([latent.shape[0], 1, ((length - 1) // 8) + 1, latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device())
if start_image is not None:
latent_temp = vae_encode_with_padding(vae, start_image, width, height, length, padding=1)
latent[:, :, :latent_temp.shape[-3]] = latent_temp
mask[:, :, :latent_temp.shape[-3]] *= 0.0
if end_image is not None:
latent_temp = vae_encode_with_padding(vae, end_image, width, height, length, padding=0)
latent[:, :, -latent_temp.shape[-3]:] = latent_temp
mask[:, :, -latent_temp.shape[-3]:] *= 0.0
out_latent = {}
out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1))
out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
return (out_latent,)
NODE_CLASS_MAPPINGS = {
"EmptyCosmosLatentVideo": EmptyCosmosLatentVideo,
"CosmosImageToVideoLatent": CosmosImageToVideoLatent,
}

View File

@@ -231,6 +231,24 @@ class FlipSigmas:
sigmas[0] = 0.0001
return (sigmas,)
class SetFirstSigma:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"sigmas": ("SIGMAS", ),
"sigma": ("FLOAT", {"default": 136.0, "min": 0.0, "max": 20000.0, "step": 0.001, "round": False}),
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling/sigmas"
FUNCTION = "set_first_sigma"
def set_first_sigma(self, sigmas, sigma):
sigmas = sigmas.clone()
sigmas[0] = sigma
return (sigmas, )
class KSamplerSelect:
@classmethod
def INPUT_TYPES(s):
@@ -710,6 +728,7 @@ NODE_CLASS_MAPPINGS = {
"SplitSigmas": SplitSigmas,
"SplitSigmasDenoise": SplitSigmasDenoise,
"FlipSigmas": FlipSigmas,
"SetFirstSigma": SetFirstSigma,
"CFGGuider": CFGGuider,
"DualCFGGuider": DualCFGGuider,

View File

@@ -162,7 +162,7 @@ NOISE_LEVELS = {
[14.61464119, 7.49001646, 5.85520077, 4.45427561, 3.46139455, 2.84484982, 2.19988537, 1.72759056, 1.36964464, 1.08895338, 0.86115354, 0.69515091, 0.54755926, 0.43325692, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
[14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.45427561, 3.46139455, 2.84484982, 2.19988537, 1.72759056, 1.36964464, 1.08895338, 0.86115354, 0.69515091, 0.54755926, 0.43325692, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
[14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.45427561, 3.46139455, 2.84484982, 2.19988537, 1.72759056, 1.36964464, 1.08895338, 0.89115214, 0.72133851, 0.59516323, 0.4783645, 0.38853383, 0.29807833, 0.22545385, 0.17026083, 0.09824532, 0.02916753],
],
],
1.15: [
[14.61464119, 0.83188516, 0.02916753],
[14.61464119, 1.84880662, 0.59516323, 0.02916753],
@@ -246,7 +246,7 @@ NOISE_LEVELS = {
[14.61464119, 5.85520077, 2.84484982, 1.72759056, 1.162866, 0.83188516, 0.64427125, 0.52423614, 0.43325692, 0.36617002, 0.32104823, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
[14.61464119, 5.85520077, 2.84484982, 1.78698075, 1.24153244, 0.92192322, 0.72133851, 0.57119018, 0.45573691, 0.38853383, 0.34370604, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
[14.61464119, 5.85520077, 2.84484982, 1.78698075, 1.24153244, 0.92192322, 0.72133851, 0.57119018, 0.4783645, 0.41087446, 0.36617002, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
],
],
1.35: [
[14.61464119, 0.69515091, 0.02916753],
[14.61464119, 0.95350921, 0.34370604, 0.02916753],

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Union
import logging
import torch
from collections.abc import Iterable
@@ -32,7 +33,7 @@ class PairConditioningSetProperties:
"timesteps": ("TIMESTEPS_RANGE",),
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
RETURN_NAMES = ("positive", "negative")
@@ -46,7 +47,7 @@ class PairConditioningSetProperties:
strength=strength, set_cond_area=set_cond_area,
mask=mask, hooks=hooks, timesteps_range=timesteps)
return (final_positive, final_negative)
class PairConditioningSetPropertiesAndCombine:
NodeId = 'PairConditioningSetPropertiesAndCombine'
NodeName = 'Cond Pair Set Props Combine'
@@ -67,7 +68,7 @@ class PairConditioningSetPropertiesAndCombine:
"timesteps": ("TIMESTEPS_RANGE",),
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
RETURN_NAMES = ("positive", "negative")
@@ -158,7 +159,7 @@ class PairConditioningCombine:
"negative_B": ("CONDITIONING",),
},
}
EXPERIMENTAL = True
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
RETURN_NAMES = ("positive", "negative")
@@ -185,7 +186,7 @@ class PairConditioningSetDefaultAndCombine:
"hooks": ("HOOKS",),
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
RETURN_NAMES = ("positive", "negative")
@@ -197,7 +198,7 @@ class PairConditioningSetDefaultAndCombine:
final_positive, final_negative = comfy.hooks.set_default_conds_and_combine(conds=[positive, negative], new_conds=[positive_DEFAULT, negative_DEFAULT],
hooks=hooks)
return (final_positive, final_negative)
class ConditioningSetDefaultAndCombine:
NodeId = 'ConditioningSetDefaultCombine'
NodeName = 'Cond Set Default Combine'
@@ -223,7 +224,7 @@ class ConditioningSetDefaultAndCombine:
(final_conditioning,) = comfy.hooks.set_default_conds_and_combine(conds=[cond], new_conds=[cond_DEFAULT],
hooks=hooks)
return (final_conditioning,)
class SetClipHooks:
NodeId = 'SetClipHooks'
NodeName = 'Set CLIP Hooks'
@@ -239,13 +240,13 @@ class SetClipHooks:
"hooks": ("HOOKS",)
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("CLIP",)
CATEGORY = "advanced/hooks/clip"
FUNCTION = "apply_hooks"
def apply_hooks(self, clip: 'CLIP', schedule_clip: bool, apply_to_conds: bool, hooks: comfy.hooks.HookGroup=None):
def apply_hooks(self, clip: CLIP, schedule_clip: bool, apply_to_conds: bool, hooks: comfy.hooks.HookGroup=None):
if hooks is not None:
clip = clip.clone()
if apply_to_conds:
@@ -254,7 +255,7 @@ class SetClipHooks:
clip.use_clip_schedule = schedule_clip
if not clip.use_clip_schedule:
clip.patcher.forced_hooks.set_keyframes_on_hooks(None)
clip.patcher.register_all_hook_patches(hooks.get_dict_repr(), comfy.hooks.EnumWeightTarget.Clip)
clip.patcher.register_all_hook_patches(hooks, comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Clip))
return (clip,)
class ConditioningTimestepsRange:
@@ -268,7 +269,7 @@ class ConditioningTimestepsRange:
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
},
}
EXPERIMENTAL = True
RETURN_TYPES = ("TIMESTEPS_RANGE", "TIMESTEPS_RANGE", "TIMESTEPS_RANGE")
RETURN_NAMES = ("TIMESTEPS_RANGE", "BEFORE_RANGE", "AFTER_RANGE")
@@ -289,7 +290,7 @@ class CreateHookLora:
NodeName = 'Create Hook LoRA'
def __init__(self):
self.loaded_lora = None
@classmethod
def INPUT_TYPES(s):
return {
@@ -302,7 +303,7 @@ class CreateHookLora:
"prev_hooks": ("HOOKS",)
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("HOOKS",)
CATEGORY = "advanced/hooks/create"
@@ -315,7 +316,7 @@ class CreateHookLora:
if strength_model == 0 and strength_clip == 0:
return (prev_hooks,)
lora_path = folder_paths.get_full_path("loras", lora_name)
lora = None
if self.loaded_lora is not None:
@@ -325,7 +326,7 @@ class CreateHookLora:
temp = self.loaded_lora
self.loaded_lora = None
del temp
if lora is None:
lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
self.loaded_lora = (lora_path, lora)
@@ -347,7 +348,7 @@ class CreateHookLoraModelOnly(CreateHookLora):
"prev_hooks": ("HOOKS",)
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("HOOKS",)
CATEGORY = "advanced/hooks/create"
@@ -377,7 +378,7 @@ class CreateHookModelAsLora:
"prev_hooks": ("HOOKS",)
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("HOOKS",)
CATEGORY = "advanced/hooks/create"
@@ -400,7 +401,7 @@ class CreateHookModelAsLora:
temp = self.loaded_weights
self.loaded_weights = None
del temp
if weights_model is None:
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
weights_model = comfy.hooks.get_patch_weights_from_model(out[0])
@@ -425,7 +426,7 @@ class CreateHookModelAsLoraModelOnly(CreateHookModelAsLora):
"prev_hooks": ("HOOKS",)
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("HOOKS",)
CATEGORY = "advanced/hooks/create"
@@ -454,7 +455,7 @@ class SetHookKeyframes:
"hook_kf": ("HOOK_KEYFRAMES",),
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("HOOKS",)
CATEGORY = "advanced/hooks/scheduling"
@@ -480,7 +481,7 @@ class CreateHookKeyframe:
"prev_hook_kf": ("HOOK_KEYFRAMES",),
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("HOOK_KEYFRAMES",)
RETURN_NAMES = ("HOOK_KF",)
@@ -514,7 +515,7 @@ class CreateHookKeyframesInterpolated:
"prev_hook_kf": ("HOOK_KEYFRAMES",),
},
}
EXPERIMENTAL = True
RETURN_TYPES = ("HOOK_KEYFRAMES",)
RETURN_NAMES = ("HOOK_KF",)
@@ -539,7 +540,7 @@ class CreateHookKeyframesInterpolated:
is_first = False
prev_hook_kf.add(comfy.hooks.HookKeyframe(strength=strength, start_percent=percent, guarantee_steps=guarantee_steps))
if print_keyframes:
print(f"Hook Keyframe - start_percent:{percent} = {strength}")
logging.info(f"Hook Keyframe - start_percent:{percent} = {strength}")
return (prev_hook_kf,)
class CreateHookKeyframesFromFloats:
@@ -558,7 +559,7 @@ class CreateHookKeyframesFromFloats:
"prev_hook_kf": ("HOOK_KEYFRAMES",),
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("HOOK_KEYFRAMES",)
RETURN_NAMES = ("HOOK_KF",)
@@ -579,7 +580,7 @@ class CreateHookKeyframesFromFloats:
raise Exception(f"floats_strength must be either an iterable input or a float, but was{type(floats_strength).__repr__}.")
percents = comfy.hooks.InterpolationMethod.get_weights(num_from=start_percent, num_to=end_percent, length=len(floats_strength),
method=comfy.hooks.InterpolationMethod.LINEAR)
is_first = True
for percent, strength in zip(percents, floats_strength):
guarantee_steps = 0
@@ -588,7 +589,7 @@ class CreateHookKeyframesFromFloats:
is_first = False
prev_hook_kf.add(comfy.hooks.HookKeyframe(strength=strength, start_percent=percent, guarantee_steps=guarantee_steps))
if print_keyframes:
print(f"Hook Keyframe - start_percent:{percent} = {strength}")
logging.info(f"Hook Keyframe - start_percent:{percent} = {strength}")
return (prev_hook_kf,)
#------------------------------------------
###########################################
@@ -603,7 +604,7 @@ class SetModelHooksOnCond:
"hooks": ("HOOKS",),
},
}
EXPERIMENTAL = True
RETURN_TYPES = ("CONDITIONING",)
CATEGORY = "advanced/hooks/manual"
@@ -629,7 +630,7 @@ class CombineHooks:
"hooks_B": ("HOOKS",),
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("HOOKS",)
CATEGORY = "advanced/hooks/combine"
@@ -656,7 +657,7 @@ class CombineHooksFour:
"hooks_D": ("HOOKS",),
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("HOOKS",)
CATEGORY = "advanced/hooks/combine"
@@ -689,7 +690,7 @@ class CombineHooksEight:
"hooks_H": ("HOOKS",),
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("HOOKS",)
CATEGORY = "advanced/hooks/combine"

View File

@@ -26,6 +26,7 @@ class Load3D():
"bg_color": ("STRING", {"default": "#000000", "multiline": False}),
"light_intensity": ("INT", {"default": 10, "min": 1, "max": 20, "step": 1}),
"up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
"fov": ("INT", {"default": 75, "min": 10, "max": 150, "step": 1}),
}}
RETURN_TYPES = ("IMAGE", "MASK", "STRING")
@@ -37,13 +38,22 @@ class Load3D():
CATEGORY = "3d"
def process(self, model_file, image, **kwargs):
imagepath = folder_paths.get_annotated_filepath(image)
if isinstance(image, dict):
image_path = folder_paths.get_annotated_filepath(image['image'])
mask_path = folder_paths.get_annotated_filepath(image['mask'])
load_image_node = nodes.LoadImage()
load_image_node = nodes.LoadImage()
output_image, ignore_mask = load_image_node.load_image(image=image_path)
ignore_image, output_mask = load_image_node.load_image(image=mask_path)
output_image, output_mask = load_image_node.load_image(image=imagepath)
return output_image, output_mask, model_file,
return output_image, output_mask, model_file,
else:
# to avoid the format is not dict which will happen the FE code is not compatibility to core,
# we need to this to double-check, it can be removed after merged FE into the core
image_path = folder_paths.get_annotated_filepath(image)
load_image_node = nodes.LoadImage()
output_image, output_mask = load_image_node.load_image(image=image_path)
return output_image, output_mask, model_file,
class Load3DAnimation():
@classmethod
@@ -67,6 +77,7 @@ class Load3DAnimation():
"light_intensity": ("INT", {"default": 10, "min": 1, "max": 20, "step": 1}),
"up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
"animation_speed": (["0.1", "0.5", "1", "1.5", "2"], {"default": "1"}),
"fov": ("INT", {"default": 75, "min": 10, "max": 150, "step": 1}),
}}
RETURN_TYPES = ("IMAGE", "MASK", "STRING")
@@ -78,13 +89,20 @@ class Load3DAnimation():
CATEGORY = "3d"
def process(self, model_file, image, **kwargs):
imagepath = folder_paths.get_annotated_filepath(image)
if isinstance(image, dict):
image_path = folder_paths.get_annotated_filepath(image['image'])
mask_path = folder_paths.get_annotated_filepath(image['mask'])
load_image_node = nodes.LoadImage()
load_image_node = nodes.LoadImage()
output_image, ignore_mask = load_image_node.load_image(image=image_path)
ignore_image, output_mask = load_image_node.load_image(image=mask_path)
output_image, output_mask = load_image_node.load_image(image=imagepath)
return output_image, output_mask, model_file,
return output_image, output_mask, model_file,
else:
image_path = folder_paths.get_annotated_filepath(image)
load_image_node = nodes.LoadImage()
output_image, output_mask = load_image_node.load_image(image=image_path)
return output_image, output_mask, model_file,
class Preview3D():
@classmethod
@@ -98,6 +116,7 @@ class Preview3D():
"bg_color": ("STRING", {"default": "#000000", "multiline": False}),
"light_intensity": ("INT", {"default": 10, "min": 1, "max": 20, "step": 1}),
"up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
"fov": ("INT", {"default": 75, "min": 10, "max": 150, "step": 1}),
}}
OUTPUT_NODE = True
@@ -121,4 +140,4 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"Load3D": "Load 3D",
"Load3DAnimation": "Load 3D - Animation",
"Preview3D": "Preview 3D"
}
}

View File

@@ -305,7 +305,7 @@ class FeatherMask:
output[:, -y, :] *= feather_rate
return (output,)
class GrowMask:
@classmethod
def INPUT_TYPES(cls):
@@ -316,7 +316,7 @@ class GrowMask:
"tapered_corners": ("BOOLEAN", {"default": True}),
},
}
CATEGORY = "mask"
RETURN_TYPES = ("MASK",)

View File

@@ -189,7 +189,7 @@ class ModelSamplingContinuousEDM:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"sampling": (["v_prediction", "edm_playground_v2.5", "eps"],),
"sampling": (["v_prediction", "edm", "edm_playground_v2.5", "eps"],),
"sigma_max": ("FLOAT", {"default": 120.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
"sigma_min": ("FLOAT", {"default": 0.002, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
}}
@@ -206,6 +206,9 @@ class ModelSamplingContinuousEDM:
sigma_data = 1.0
if sampling == "eps":
sampling_type = comfy.model_sampling.EPS
elif sampling == "edm":
sampling_type = comfy.model_sampling.EDM
sigma_data = 0.5
elif sampling == "v_prediction":
sampling_type = comfy.model_sampling.V_PREDICTION
elif sampling == "edm_playground_v2.5":

View File

@@ -46,4 +46,4 @@ NODE_CLASS_MAPPINGS = {
NODE_DISPLAY_NAME_MAPPINGS = {
"Morphology": "ImageMorphology",
}
}

View File

@@ -64,7 +64,7 @@ class Guider_PerpNeg(comfy.samplers.CFGGuider):
def predict_noise(self, x, timestep, model_options={}, seed=None):
# in CFGGuider.predict_noise, we call sampling_function(), which uses cfg_function() to compute pos & neg
# but we'd rather do a single batch of sampling pos, neg, and empty, so we call calc_cond_batch([pos,neg,empty]) directly
positive_cond = self.conds.get("positive", None)
negative_cond = self.conds.get("negative", None)
empty_cond = self.conds.get("empty_negative_prompt", None)
@@ -73,7 +73,7 @@ class Guider_PerpNeg(comfy.samplers.CFGGuider):
comfy.samplers.calc_cond_batch(self.inner_model, [positive_cond, negative_cond, empty_cond], x, timestep, model_options)
cfg_result = perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_empty, self.neg_scale, self.cfg)
# normally this would be done in cfg_function, but we skipped
# normally this would be done in cfg_function, but we skipped
# that for efficiency: we can compute the noise predictions in
# a single call to calc_cond_batch() (rather than two)
# so we replicate the hook here

View File

@@ -0,0 +1,24 @@
from nodes import MAX_RESOLUTION
class CLIPTextEncodePixArtAlpha:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
"height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
# "aspect_ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"text": ("STRING", {"multiline": True, "dynamicPrompts": True}), "clip": ("CLIP", ),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "encode"
CATEGORY = "advanced/conditioning"
DESCRIPTION = "Encodes text and sets the resolution conditioning for PixArt Alpha. Does not apply to PixArt Sigma."
def encode(self, clip, width, height, text):
tokens = clip.tokenize(text)
return (clip.encode_from_tokens_scheduled(tokens, add_dict={"width": width, "height": height}),)
NODE_CLASS_MAPPINGS = {
"CLIPTextEncodePixArtAlpha": CLIPTextEncodePixArtAlpha,
}

View File

@@ -40,7 +40,7 @@ class LatentRebatch:
return slices, indexable[num * batch_size:]
else:
return slices, None
@staticmethod
def slice_batch(batch, num, batch_size):
result = [LatentRebatch.get_slices(x, num, batch_size) for x in batch]
@@ -81,7 +81,7 @@ class LatentRebatch:
if current_batch[0].shape[0] > batch_size:
num = current_batch[0].shape[0] // batch_size
sliced, remainder = self.slice_batch(current_batch, num, batch_size)
for i in range(num):
output_list.append({'samples': sliced[0][i], 'noise_mask': sliced[1][i], 'batch_index': sliced[2][i]})

View File

@@ -40,9 +40,8 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
return do_nothing, do_nothing
gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather
with torch.no_grad():
hsy, wsx = h // sy, w // sx
# For each sy by sx kernel, randomly assign one token to be dst and the rest src
@@ -50,7 +49,7 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
rand_idx = torch.zeros(hsy, wsx, 1, device=metric.device, dtype=torch.int64)
else:
rand_idx = torch.randint(sy*sx, size=(hsy, wsx, 1), device=metric.device)
# The image might not divide sx and sy, so we need to work on a view of the top left if the idx buffer instead
idx_buffer_view = torch.zeros(hsy, wsx, sy*sx, device=metric.device, dtype=torch.int64)
idx_buffer_view.scatter_(dim=2, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=rand_idx.dtype))
@@ -99,7 +98,7 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
src, dst = split(x)
n, t1, c = src.shape
unm = gather(src, dim=-2, index=unm_idx.expand(n, t1 - r, c))
src = gather(src, dim=-2, index=src_idx.expand(n, r, c))
dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)

View File

@@ -30,4 +30,4 @@ NODE_CLASS_MAPPINGS = {
NODE_DISPLAY_NAME_MAPPINGS = {
"WebcamCapture": "Webcam Capture",
}
}

3
comfyui_version.py Normal file
View File

@@ -0,0 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.3.12"

View File

@@ -62,7 +62,7 @@ class IsChangedCache:
class CacheSet:
def __init__(self, lru_size=None):
if lru_size is None or lru_size == 0:
self.init_classic_cache()
self.init_classic_cache()
else:
self.init_lru_cache(lru_size)
self.all = [self.outputs, self.ui, self.objects]
@@ -93,7 +93,7 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
missing_keys = {}
for x in inputs:
input_data = inputs[x]
input_type, input_category, input_info = get_input_info(class_def, x)
input_type, input_category, input_info = get_input_info(class_def, x, valid_inputs)
def mark_missing():
missing_keys[x] = True
input_data_all[x] = (None,)
@@ -138,11 +138,11 @@ def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execut
max_len_input = 0
else:
max_len_input = max(len(x) for x in input_data_all.values())
# get a slice of inputs, repeat last input when list isn't long enough
def slice_dict(d, i):
return {k: v[i if len(v) > i else -1] for k, v in d.items()}
results = []
def process_inputs(inputs, index=None, input_is_list=False):
if allow_interrupt:
@@ -168,7 +168,7 @@ def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execut
process_inputs(input_data_all, 0, input_is_list=input_is_list)
elif max_len_input == 0:
process_inputs({})
else:
else:
for i in range(max_len_input):
input_dict = slice_dict(input_data_all, i)
process_inputs(input_dict, i)
@@ -196,7 +196,6 @@ def merge_result_data(results, obj):
return output
def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb=None):
results = []
uis = []
subgraph_results = []
@@ -226,14 +225,14 @@ def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb
r = tuple([r] * len(obj.RETURN_TYPES))
results.append(r)
subgraph_results.append((None, r))
if has_subgraph:
output = subgraph_results
elif len(results) > 0:
output = merge_result_data(results, obj)
else:
output = []
ui = dict()
ui = dict()
if len(uis) > 0:
ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()}
return output, ui, has_subgraph
@@ -556,7 +555,7 @@ def validate_inputs(prompt, item, validated):
received_types = {}
for x in valid_inputs:
type_input, input_category, extra_info = get_input_info(obj_class, x)
type_input, input_category, extra_info = get_input_info(obj_class, x, class_inputs)
assert extra_info is not None
if x not in inputs:
if input_category == "required":

View File

@@ -58,7 +58,7 @@ class CacheHelper:
if not self.active:
return default
return self.cache.get(key, default)
def set(self, key: str, value: tuple[list[str], dict[str, float], float]) -> None:
if self.active:
self.cache[key] = value
@@ -305,7 +305,7 @@ def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float]
strong_cache = cache_helper.get(folder_name)
if strong_cache is not None:
return strong_cache
global filename_list_cache
global folder_names_and_paths
folder_name = map_legacy(folder_name)

76
main.py
View File

@@ -17,7 +17,7 @@ if __name__ == "__main__":
os.environ['DO_NOT_TRACK'] = '1'
setup_logger(log_level=args.verbose)
setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
def apply_custom_paths():
# extra model paths
@@ -63,7 +63,7 @@ def execute_prestartup_script():
spec.loader.exec_module(module)
return True
except Exception as e:
print(f"Failed to execute startup-script: {script_path} / {e}")
logging.error(f"Failed to execute startup-script: {script_path} / {e}")
return False
if args.disable_all_custom_nodes:
@@ -85,14 +85,14 @@ def execute_prestartup_script():
success = execute_script(script_path)
node_prestartup_times.append((time.perf_counter() - time_before, module_path, success))
if len(node_prestartup_times) > 0:
print("\nPrestartup times for custom nodes:")
logging.info("\nPrestartup times for custom nodes:")
for n in sorted(node_prestartup_times):
if n[2]:
import_message = ""
else:
import_message = " (PRESTARTUP FAILED)"
print("{:6.1f} seconds{}:".format(n[0], import_message), n[1])
print()
logging.info("{:6.1f} seconds{}: {}".format(n[0], import_message, n[1]))
logging.info("")
apply_custom_paths()
execute_prestartup_script()
@@ -114,6 +114,10 @@ if __name__ == "__main__":
os.environ['HIP_VISIBLE_DEVICES'] = str(args.cuda_device)
logging.info("Set cuda device to: {}".format(args.cuda_device))
if args.oneapi_device_selector is not None:
os.environ['ONEAPI_DEVICE_SELECTOR'] = args.oneapi_device_selector
logging.info("Set oneapi device selector to: {}".format(args.oneapi_device_selector))
if args.deterministic:
if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ:
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
@@ -146,9 +150,10 @@ def cuda_malloc_warning():
if cuda_malloc_warning:
logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
def prompt_worker(q, server):
def prompt_worker(q, server_instance):
current_time: float = 0.0
e = execution.PromptExecutor(server, lru_size=args.cache_lru)
e = execution.PromptExecutor(server_instance, lru_size=args.cache_lru)
last_gc_collect = 0
need_gc = False
gc_collect_interval = 10.0
@@ -163,7 +168,7 @@ def prompt_worker(q, server):
item, item_id = queue_item
execution_start_time = time.perf_counter()
prompt_id = item[1]
server.last_prompt_id = prompt_id
server_instance.last_prompt_id = prompt_id
e.execute(item[2], prompt_id, item[3], item[4])
need_gc = True
@@ -173,8 +178,8 @@ def prompt_worker(q, server):
status_str='success' if e.success else 'error',
completed=e.success,
messages=e.status_messages))
if server.client_id is not None:
server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, server.client_id)
if server_instance.client_id is not None:
server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)
current_time = time.perf_counter()
execution_time = current_time - execution_start_time
@@ -201,21 +206,25 @@ def prompt_worker(q, server):
last_gc_collect = current_time
need_gc = False
async def run(server, address='', port=8188, verbose=True, call_on_start=None):
async def run(server_instance, address='', port=8188, verbose=True, call_on_start=None):
addresses = []
for addr in address.split(","):
addresses.append((addr, port))
await asyncio.gather(server.start_multi_address(addresses, call_on_start), server.publish_loop())
await asyncio.gather(
server_instance.start_multi_address(addresses, call_on_start, verbose), server_instance.publish_loop()
)
def hijack_progress(server):
def hijack_progress(server_instance):
def hook(value, total, preview_image):
comfy.model_management.throw_exception_if_processing_interrupted()
progress = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id}
progress = {"value": value, "max": total, "prompt_id": server_instance.last_prompt_id, "node": server_instance.last_node_id}
server.send_sync("progress", progress, server.client_id)
server_instance.send_sync("progress", progress, server_instance.client_id)
if preview_image is not None:
server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id)
server_instance.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server_instance.client_id)
comfy.utils.set_progress_bar_global_hook(hook)
@@ -225,7 +234,11 @@ def cleanup_temp():
shutil.rmtree(temp_dir, ignore_errors=True)
if __name__ == "__main__":
def start_comfyui(asyncio_loop=None):
"""
Starts the ComfyUI server using the provided asyncio event loop or creates a new one.
Returns the event loop, server instance, and a function to start the server asynchronously.
"""
if args.temp_directory:
temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp")
logging.info(f"Setting temp directory to: {temp_dir}")
@@ -239,19 +252,20 @@ if __name__ == "__main__":
except:
pass
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
server = server.PromptServer(loop)
q = execution.PromptQueue(server)
if not asyncio_loop:
asyncio_loop = asyncio.new_event_loop()
asyncio.set_event_loop(asyncio_loop)
prompt_server = server.PromptServer(asyncio_loop)
q = execution.PromptQueue(prompt_server)
nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes)
cuda_malloc_warning()
server.add_routes()
hijack_progress(server)
prompt_server.add_routes()
hijack_progress(prompt_server)
threading.Thread(target=prompt_worker, daemon=True, args=(q, server,)).start()
threading.Thread(target=prompt_worker, daemon=True, args=(q, prompt_server,)).start()
if args.quick_test_for_ci:
exit(0)
@@ -268,9 +282,19 @@ if __name__ == "__main__":
webbrowser.open(f"{scheme}://{address}:{port}")
call_on_start = startup_server
async def start_all():
await prompt_server.setup()
await run(prompt_server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start)
# Returning these so that other code can integrate with the ComfyUI loop and server
return asyncio_loop, prompt_server, start_all
if __name__ == "__main__":
# Running directly, just start ComfyUI.
event_loop, _, start_all_func = start_comfyui()
try:
loop.run_until_complete(server.setup())
loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start))
event_loop.run_until_complete(start_all_func())
except KeyboardInterrupt:
logging.info("\nStopped server")

View File

@@ -32,4 +32,4 @@ def update_windows_updater():
except:
pass
shutil.copy(bat_path, dest_bat_path)
print("Updated the windows standalone package updater.")
print("Updated the windows standalone package updater.") # noqa: T201

View File

@@ -51,7 +51,7 @@ class CLIPTextEncode(ComfyNodeABC):
def INPUT_TYPES(s) -> InputTypeDict:
return {
"required": {
"text": (IO.STRING, {"multiline": True, "dynamicPrompts": True, "tooltip": "The text to be encoded."}),
"text": (IO.STRING, {"multiline": True, "dynamicPrompts": True, "tooltip": "The text to be encoded."}),
"clip": (IO.CLIP, {"tooltip": "The CLIP model used for encoding the text."})
}
}
@@ -65,7 +65,7 @@ class CLIPTextEncode(ComfyNodeABC):
def encode(self, clip, text):
tokens = clip.tokenize(text)
return (clip.encode_from_tokens_scheduled(tokens), )
class ConditioningCombine:
@classmethod
@@ -269,8 +269,8 @@ class VAEDecode:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"samples": ("LATENT", {"tooltip": "The latent to be decoded."}),
"required": {
"samples": ("LATENT", {"tooltip": "The latent to be decoded."}),
"vae": ("VAE", {"tooltip": "The VAE model used for decoding the latent."})
}
}
@@ -293,17 +293,29 @@ class VAEDecodeTiled:
return {"required": {"samples": ("LATENT", ), "vae": ("VAE", ),
"tile_size": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 32}),
"overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32}),
"temporal_size": ("INT", {"default": 64, "min": 8, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to decode at a time."}),
"temporal_overlap": ("INT", {"default": 8, "min": 4, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap."}),
}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "decode"
CATEGORY = "_for_testing"
def decode(self, vae, samples, tile_size, overlap=64):
def decode(self, vae, samples, tile_size, overlap=64, temporal_size=64, temporal_overlap=8):
if tile_size < overlap * 4:
overlap = tile_size // 4
if temporal_size < temporal_overlap * 2:
temporal_overlap = temporal_overlap // 2
temporal_compression = vae.temporal_compression_decode()
if temporal_compression is not None:
temporal_size = max(2, temporal_size // temporal_compression)
temporal_overlap = max(1, min(temporal_size // 2, temporal_overlap // temporal_compression))
else:
temporal_size = None
temporal_overlap = None
compression = vae.spacial_compression_decode()
images = vae.decode_tiled(samples["samples"], tile_x=tile_size // compression, tile_y=tile_size // compression, overlap=overlap // compression)
images = vae.decode_tiled(samples["samples"], tile_x=tile_size // compression, tile_y=tile_size // compression, overlap=overlap // compression, tile_t=temporal_size, overlap_t=temporal_overlap)
if len(images.shape) == 5: #Combine batches
images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
return (images, )
@@ -327,15 +339,17 @@ class VAEEncodeTiled:
return {"required": {"pixels": ("IMAGE", ), "vae": ("VAE", ),
"tile_size": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 64}),
"overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32}),
"temporal_size": ("INT", {"default": 64, "min": 8, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to encode at a time."}),
"temporal_overlap": ("INT", {"default": 8, "min": 4, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap."}),
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "encode"
CATEGORY = "_for_testing"
def encode(self, vae, pixels, tile_size, overlap):
t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, overlap=overlap)
return ({"samples":t}, )
def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, temporal_overlap=8):
t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap)
return ({"samples": t}, )
class VAEEncodeForInpaint:
@classmethod
@@ -536,13 +550,13 @@ class CheckpointLoaderSimple:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"required": {
"ckpt_name": (folder_paths.get_filename_list("checkpoints"), {"tooltip": "The name of the checkpoint (model) to load."}),
}
}
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
OUTPUT_TOOLTIPS = ("The model used for denoising latents.",
"The CLIP model used for encoding text prompts.",
OUTPUT_TOOLTIPS = ("The model used for denoising latents.",
"The CLIP model used for encoding text prompts.",
"The VAE model used for encoding and decoding images to and from latent space.")
FUNCTION = "load_checkpoint"
@@ -619,7 +633,7 @@ class LoraLoader:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"required": {
"model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}),
"clip": ("CLIP", {"tooltip": "The CLIP model the LoRA will be applied to."}),
"lora_name": (folder_paths.get_filename_list("loras"), {"tooltip": "The name of the LoRA."}),
@@ -627,7 +641,7 @@ class LoraLoader:
"strength_clip": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the CLIP model. This value can be negative."}),
}
}
RETURN_TYPES = ("MODEL", "CLIP")
OUTPUT_TOOLTIPS = ("The modified diffusion model.", "The modified CLIP model.")
FUNCTION = "load_lora"
@@ -898,16 +912,19 @@ class CLIPLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv"], ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos"], ),
},
"optional": {
"device": (["default", "cpu"], {"advanced": True}),
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "load_clip"
CATEGORY = "advanced/loaders"
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 / clip-g / clip-l\nstable_audio: t5\nmochi: t5"
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 / clip-g / clip-l\nstable_audio: t5\nmochi: t5\ncosmos: old t5 xxl"
def load_clip(self, clip_name, type="stable_diffusion"):
def load_clip(self, clip_name, type="stable_diffusion", device="default"):
if type == "stable_cascade":
clip_type = comfy.sd.CLIPType.STABLE_CASCADE
elif type == "sd3":
@@ -918,11 +935,17 @@ class CLIPLoader:
clip_type = comfy.sd.CLIPType.MOCHI
elif type == "ltxv":
clip_type = comfy.sd.CLIPType.LTXV
elif type == "pixart":
clip_type = comfy.sd.CLIPType.PIXART
else:
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
model_options = {}
if device == "cpu":
model_options["load_device"] = model_options["offload_device"] = torch.device("cpu")
clip_path = folder_paths.get_full_path_or_raise("text_encoders", clip_name)
clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options)
return (clip,)
class DualCLIPLoader:
@@ -931,6 +954,9 @@ class DualCLIPLoader:
return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
"clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
"type": (["sdxl", "sd3", "flux", "hunyuan_video"], ),
},
"optional": {
"device": (["default", "cpu"], {"advanced": True}),
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "load_clip"
@@ -939,7 +965,7 @@ class DualCLIPLoader:
DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5"
def load_clip(self, clip_name1, clip_name2, type):
def load_clip(self, clip_name1, clip_name2, type, device="default"):
clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
if type == "sdxl":
@@ -951,7 +977,11 @@ class DualCLIPLoader:
elif type == "hunyuan_video":
clip_type = comfy.sd.CLIPType.HUNYUAN_VIDEO
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
model_options = {}
if device == "cpu":
model_options["load_device"] = model_options["offload_device"] = torch.device("cpu")
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options)
return (clip,)
class CLIPVisionLoader:
@@ -1146,7 +1176,7 @@ class EmptyLatentImage:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"required": {
"width": ("INT", {"default": 512, "min": 16, "max": MAX_RESOLUTION, "step": 8, "tooltip": "The width of the latent images in pixels."}),
"height": ("INT", {"default": 512, "min": 16, "max": MAX_RESOLUTION, "step": 8, "tooltip": "The height of the latent images in pixels."}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."})
@@ -1195,7 +1225,7 @@ class LatentFromBatch:
else:
s["batch_index"] = samples["batch_index"][batch_index:batch_index + length]
return (s,)
class RepeatLatentBatch:
@classmethod
def INPUT_TYPES(s):
@@ -1210,7 +1240,7 @@ class RepeatLatentBatch:
def repeat(self, samples, amount):
s = samples.copy()
s_in = samples["samples"]
s["samples"] = s_in.repeat((amount, 1,1,1))
if "noise_mask" in samples and samples["noise_mask"].shape[0] > 1:
masks = samples["noise_mask"]
@@ -1620,15 +1650,15 @@ class LoadImage:
FUNCTION = "load_image"
def load_image(self, image):
image_path = folder_paths.get_annotated_filepath(image)
img = node_helpers.pillow(Image.open, image_path)
output_images = []
output_masks = []
w, h = None, None
excluded_formats = ['MPO']
for i in ImageSequence.Iterator(img):
i = node_helpers.pillow(ImageOps.exif_transpose, i)
@@ -1639,10 +1669,10 @@ class LoadImage:
if len(output_images) == 0:
w = image.size[0]
h = image.size[1]
if image.size[0] != w or image.size[1] != h:
continue
image = np.array(image).astype(np.float32) / 255.0
image = torch.from_numpy(image)[None,]
if 'A' in i.getbands():
@@ -2031,6 +2061,9 @@ NODE_DISPLAY_NAME_MAPPINGS = {
EXTENSION_WEB_DIRS = {}
# Dictionary of successfully loaded module names and associated directories.
LOADED_MODULE_DIRS = {}
def get_module_name(module_path: str) -> str:
"""
@@ -2072,6 +2105,8 @@ def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes
sys.modules[module_name] = module
module_spec.loader.exec_module(module)
LOADED_MODULE_DIRS[module_name] = os.path.abspath(module_dir)
if hasattr(module, "WEB_DIRECTORY") and getattr(module, "WEB_DIRECTORY") is not None:
web_dir = os.path.abspath(os.path.join(module_dir, getattr(module, "WEB_DIRECTORY")))
if os.path.isdir(web_dir):
@@ -2164,6 +2199,7 @@ def init_builtin_extra_nodes():
"nodes_stable3d.py",
"nodes_sdupscale.py",
"nodes_photomaker.py",
"nodes_pixart.py",
"nodes_cond.py",
"nodes_morphology.py",
"nodes_stable_cascade.py",
@@ -2189,6 +2225,7 @@ def init_builtin_extra_nodes():
"nodes_lt.py",
"nodes_hooks.py",
"nodes_load_3d.py",
"nodes_cosmos.py",
]
import_failed = []
@@ -2217,5 +2254,5 @@ def init_extra_nodes(init_custom_nodes=True):
else:
logging.warning("Please do a: pip install -r requirements.txt")
logging.warning("")
return import_failed

23
pyproject.toml Normal file
View File

@@ -0,0 +1,23 @@
[project]
name = "ComfyUI"
version = "0.3.12"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"
[project.urls]
homepage = "https://www.comfy.org/"
repository = "https://github.com/comfyanonymous/ComfyUI"
documentation = "https://docs.comfy.org/"
[tool.ruff]
lint.select = [
"S307", # suspicious-eval-usage
"S102", # exec
"T", # print-usage
"W",
# The "F" series in Ruff stands for "Pyflakes" rules, which catch various Python syntax errors and undefined names.
# See all rules here: https://docs.astral.sh/ruff/rules/#pyflakes-f
"F",
]
exclude = ["*.ipynb"]

View File

@@ -2,6 +2,7 @@ torch
torchsde
torchvision
torchaudio
numpy>=1.25.0
einops
transformers>=4.28.1
tokenizers>=0.13.3

View File

@@ -1,10 +0,0 @@
# Disable all rules by default
lint.ignore = ["ALL"]
# Enable specific rules
lint.select = [
"S307", # suspicious-eval-usage
# The "F" series in Ruff stands for "Pyflakes" rules, which catch various Python syntax errors and undefined names.
# See all rules here: https://docs.astral.sh/ruff/rules/#pyflakes-f
"F",
]

View File

@@ -27,9 +27,11 @@ from comfy.cli_args import args
import comfy.utils
import comfy.model_management
import node_helpers
from comfyui_version import __version__
from app.frontend_management import FrontendManager
from app.user_manager import UserManager
from app.model_manager import ModelFileManager
from app.custom_node_manager import CustomNodeManager
from typing import Optional
from api_server.routes.internal.internal_routes import InternalRoutes
@@ -43,21 +45,6 @@ async def send_socket_catch_exception(function, message):
except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError, BrokenPipeError, ConnectionError) as err:
logging.warning("send error: {}".format(err))
def get_comfyui_version():
comfyui_version = "unknown"
repo_path = os.path.dirname(os.path.realpath(__file__))
try:
import pygit2
repo = pygit2.Repository(repo_path)
comfyui_version = repo.describe(describe_strategy=pygit2.GIT_DESCRIBE_TAGS)
except Exception:
try:
import subprocess
comfyui_version = subprocess.check_output(["git", "describe", "--tags"], cwd=repo_path).decode('utf-8')
except Exception as e:
logging.warning(f"Failed to get ComfyUI version: {e}")
return comfyui_version.strip()
@web.middleware
async def cache_control(request: web.Request, handler):
response: web.Response = await handler(request)
@@ -153,6 +140,7 @@ class PromptServer():
self.user_manager = UserManager()
self.model_file_manager = ModelFileManager()
self.custom_node_manager = CustomNodeManager()
self.internal_routes = InternalRoutes(self)
self.supports = ["custom_nodes_from_web"]
self.prompt_queue = None
@@ -266,7 +254,7 @@ class PromptServer():
def compare_image_hash(filepath, image):
hasher = node_helpers.hasher()
# function to compare hashes of two images to see if it already exists, fix to #3465
if os.path.exists(filepath):
a = hasher()
@@ -516,7 +504,7 @@ class PromptServer():
"os": os.name,
"ram_total": ram_total,
"ram_free": ram_free,
"comfyui_version": get_comfyui_version(),
"comfyui_version": __version__,
"python_version": sys.version,
"pytorch_version": comfy.model_management.torch_version,
"embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded",
@@ -697,6 +685,7 @@ class PromptServer():
def add_routes(self):
self.user_manager.add_routes(self.routes)
self.model_file_manager.add_routes(self.routes)
self.custom_node_manager.add_routes(self.routes, self.app, nodes.LOADED_MODULE_DIRS.items())
self.app.add_subapp('/internal', self.internal_routes.get_app())
# Prefix every route with /api for easier matching for delegation.
@@ -713,10 +702,9 @@ class PromptServer():
self.app.add_routes(api_routes)
self.app.add_routes(self.routes)
# Add routes from web extensions.
for name, dir in nodes.EXTENSION_WEB_DIRS.items():
self.app.add_routes([
web.static('/extensions/' + urllib.parse.quote(name), dir),
])
self.app.add_routes([web.static('/extensions/' + name, dir)])
self.app.add_routes([
web.static('/', self.web_root),
@@ -805,7 +793,7 @@ class PromptServer():
async def start(self, address, port, verbose=True, call_on_start=None):
await self.start_multi_address([(address, port)], call_on_start=call_on_start)
async def start_multi_address(self, addresses, call_on_start=None):
async def start_multi_address(self, addresses, call_on_start=None, verbose=True):
runner = web.AppRunner(self.app, access_log=None)
await runner.setup()
ssl_ctx = None
@@ -816,7 +804,8 @@ class PromptServer():
keyfile=args.tls_keyfile)
scheme = "https"
logging.info("Starting server\n")
if verbose:
logging.info("Starting server\n")
for addr in addresses:
address = addr[0]
port = addr[1]
@@ -832,7 +821,8 @@ class PromptServer():
else:
address_print = address
logging.info("To see the GUI go to: {}://{}:{}".format(scheme, address_print, port))
if verbose:
logging.info("To see the GUI go to: {}://{}:{}".format(scheme, address_print, port))
if call_on_start is not None:
call_on_start(scheme, self.address, self.port)

View File

@@ -0,0 +1,40 @@
import pytest
from aiohttp import web
from unittest.mock import patch
from app.custom_node_manager import CustomNodeManager
pytestmark = (
pytest.mark.asyncio
) # This applies the asyncio mark to all test functions in the module
@pytest.fixture
def custom_node_manager():
return CustomNodeManager()
@pytest.fixture
def app(custom_node_manager):
app = web.Application()
routes = web.RouteTableDef()
custom_node_manager.add_routes(routes, app, [("ComfyUI-TestExtension1", "ComfyUI-TestExtension1")])
app.add_routes(routes)
return app
async def test_get_workflow_templates(aiohttp_client, app, tmp_path):
client = await aiohttp_client(app)
# Setup temporary custom nodes file structure with 1 workflow file
custom_nodes_dir = tmp_path / "custom_nodes"
example_workflows_dir = custom_nodes_dir / "ComfyUI-TestExtension1" / "example_workflows"
example_workflows_dir.mkdir(parents=True)
template_file = example_workflows_dir / "workflow1.json"
template_file.write_text('')
with patch('folder_paths.folder_names_and_paths', {
'custom_nodes': ([str(custom_nodes_dir)], None)
}):
response = await client.get('/workflow_templates')
assert response.status == 200
workflows_dict = await response.json()
assert isinstance(workflows_dict, dict)
assert "ComfyUI-TestExtension1" in workflows_dict
assert isinstance(workflows_dict["ComfyUI-TestExtension1"], list)
assert workflows_dict["ComfyUI-TestExtension1"][0] == "workflow1"

View File

@@ -95,4 +95,4 @@ def test_get_save_image_path(temp_dir):
assert filename == "test"
assert counter == 1
assert subfolder == ""
assert filename_prefix == "test"
assert filename_prefix == "test"

View File

@@ -6,8 +6,8 @@ from folder_paths import filter_files_content_types
@pytest.fixture(scope="module")
def file_extensions():
return {
'image': ['gif', 'heif', 'ico', 'jpeg', 'jpg', 'png', 'pnm', 'ppm', 'svg', 'tiff', 'webp', 'xbm', 'xpm'],
'audio': ['aif', 'aifc', 'aiff', 'au', 'flac', 'm4a', 'mp2', 'mp3', 'ogg', 'snd', 'wav'],
'image': ['gif', 'heif', 'ico', 'jpeg', 'jpg', 'png', 'pnm', 'ppm', 'svg', 'tiff', 'webp', 'xbm', 'xpm'],
'audio': ['aif', 'aifc', 'aiff', 'au', 'flac', 'm4a', 'mp2', 'mp3', 'ogg', 'snd', 'wav'],
'video': ['avi', 'm2v', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ogv', 'qt', 'webm', 'wmv']
}
@@ -49,4 +49,4 @@ def test_handles_no_extension():
def test_handles_no_files():
files = []
assert filter_files_content_types(files, ["image", "audio", "video"]) == []
assert filter_files_content_types(files, ["image", "audio", "video"]) == []

View File

@@ -89,9 +89,9 @@ async def test_routes_added_to_app(aiohttp_client_factory, internal_routes):
client = await aiohttp_client_factory()
try:
resp = await client.get('/files')
print(f"Response received: status {resp.status}")
print(f"Response received: status {resp.status}") # noqa: T201
except Exception as e:
print(f"Exception occurred during GET request: {e}")
print(f"Exception occurred during GET request: {e}") # noqa: T201
raise
assert resp.status != 404, "Route /files does not exist"

Some files were not shown because too many files have changed in this diff Show More