Compare commits
249 Commits
v0.3.13
...
not_requir
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
03df573995 | ||
|
|
9ad792f927 | ||
|
|
6fc5dbd52a | ||
|
|
3e8155f7a3 | ||
|
|
8a438115fb | ||
|
|
a14c2fc356 | ||
|
|
9ee6ca99d8 | ||
|
|
bb495cc9b8 | ||
|
|
e51d9ba5fc | ||
|
|
c87a06f934 | ||
|
|
1714a4c158 | ||
|
|
73ecb75a3d | ||
|
|
22ad513c72 | ||
|
|
ed945a1790 | ||
|
|
f9207c6936 | ||
|
|
8ad7477647 | ||
|
|
98bdca4cb2 | ||
|
|
a26da20a76 | ||
|
|
e346d8584e | ||
|
|
ab31b64412 | ||
|
|
fe29739c68 | ||
|
|
e8345a9b7b | ||
|
|
8c6b9f4481 | ||
|
|
cc7e023a4a | ||
|
|
2f7d8159c3 | ||
|
|
70d7242e57 | ||
|
|
49b732afd5 | ||
|
|
3bfe4e5276 | ||
|
|
89e4ea0175 | ||
|
|
3a100b9a55 | ||
|
|
721253cb05 | ||
|
|
3d2e3a6f29 | ||
|
|
2222cf67fd | ||
|
|
ab5413351e | ||
|
|
2b71aab299 | ||
|
|
301e26b131 | ||
|
|
548457bac4 | ||
|
|
0b4584c741 | ||
|
|
a3100c8452 | ||
|
|
832fc02330 | ||
|
|
2d17d8910c | ||
|
|
a40fcfc2d5 | ||
|
|
0a1f8869c9 | ||
|
|
3661c833bc | ||
|
|
84fdaf7b0e | ||
|
|
8edc1f44c1 | ||
|
|
eade1551bb | ||
|
|
581a9991ff | ||
|
|
e471c726e5 | ||
|
|
75c1c757d9 | ||
|
|
ce9b084279 | ||
|
|
2206246055 | ||
|
|
d9fa9d307f | ||
|
|
83e839a89b | ||
|
|
0cf2274699 | ||
|
|
0956107170 | ||
|
|
a4a956dbbd | ||
|
|
8b9ce4ed18 | ||
|
|
3872b43d4b | ||
|
|
32ca0805b7 | ||
|
|
11f1b41bab | ||
|
|
3b19fc76e3 | ||
|
|
50614f1b79 | ||
|
|
6dc7b0bfe3 | ||
|
|
e8e990d6b8 | ||
|
|
2e24a15905 | ||
|
|
fd5297131f | ||
|
|
55a1b09ddc | ||
|
|
3c3988df45 | ||
|
|
7ebd8087ff | ||
|
|
c624c29d66 | ||
|
|
a2448fc527 | ||
|
|
6a0daa79b6 | ||
|
|
9c98c6358b | ||
|
|
7aceb9f91c | ||
|
|
35504e2f93 | ||
|
|
299436cfed | ||
|
|
52e566d2bc | ||
|
|
9b6cd9b874 | ||
|
|
3fc688aebd | ||
|
|
f4411250f3 | ||
|
|
d2a0fb6bb0 | ||
|
|
01015bff16 | ||
|
|
2330754b0e | ||
|
|
bc219a6487 | ||
|
|
94689766ad | ||
|
|
cfbe4b49ca | ||
|
|
ca8efab79f | ||
|
|
65ea778a5e | ||
|
|
db9f2a34fc | ||
|
|
7946049794 | ||
|
|
6f6349b6a7 | ||
|
|
1f138dd382 | ||
|
|
b779349b55 | ||
|
|
35e2dcf5d7 | ||
|
|
67c7184b74 | ||
|
|
6f8e766509 | ||
|
|
e1da98a14a | ||
|
|
a73410aafa | ||
|
|
9aac21f894 | ||
|
|
528d1b3563 | ||
|
|
2bc4b5968f | ||
|
|
7395b0c0d1 | ||
|
|
0952569493 | ||
|
|
29832b3b61 | ||
|
|
be4e760648 | ||
|
|
c3d9cc4592 | ||
|
|
84cc9cb528 | ||
|
|
ebbb920163 | ||
|
|
d60fe0af4a | ||
|
|
5dbd250965 | ||
|
|
4ab1875283 | ||
|
|
11b1f27cb1 | ||
|
|
70e15fd743 | ||
|
|
e1474150de | ||
|
|
e62d72e8ca | ||
|
|
1650cda030 | ||
|
|
a13125840c | ||
|
|
dfa36e6855 | ||
|
|
0124be4d93 | ||
|
|
29a70ca101 | ||
|
|
0bef826a98 | ||
|
|
85ef295069 | ||
|
|
5d84607bf3 | ||
|
|
c1909f350f | ||
|
|
52b3469606 | ||
|
|
889519971f | ||
|
|
76739c23c3 | ||
|
|
a80bc822a2 | ||
|
|
872780d236 | ||
|
|
6d45ffbe23 | ||
|
|
77633ba77d | ||
|
|
30e6cfb1a0 | ||
|
|
dc134b2fdb | ||
|
|
369b079ff6 | ||
|
|
9c9a7f012a | ||
|
|
93fedd92fe | ||
|
|
745b13649b | ||
|
|
2b140654c7 | ||
|
|
65042f7d39 | ||
|
|
7c7c70c400 | ||
|
|
8362199ee7 | ||
|
|
f86c724ef2 | ||
|
|
d6e5d487ad | ||
|
|
6752a826f6 | ||
|
|
04cf0ccb51 | ||
|
|
9af6320ec9 | ||
|
|
6f81cd8973 | ||
|
|
4dc6709307 | ||
|
|
4d55f16ae8 | ||
|
|
cf0b549d48 | ||
|
|
eb4543474b | ||
|
|
1804397952 | ||
|
|
f4dac8ab6f | ||
|
|
b07f116dea | ||
|
|
714f728820 | ||
|
|
92d8d15300 | ||
|
|
89253e9fe5 | ||
|
|
3ea3bc8546 | ||
|
|
8e69e2ddfd | ||
|
|
0270a0b41c | ||
|
|
26c7baf789 | ||
|
|
c37f15f98e | ||
|
|
4bca7367f3 | ||
|
|
b6fefe686b | ||
|
|
fa62287f1f | ||
|
|
0844998db3 | ||
|
|
4ced06b879 | ||
|
|
cb06e9669b | ||
|
|
0c32f82298 | ||
|
|
189da3726d | ||
|
|
9a66bb972d | ||
|
|
ea0f939df3 | ||
|
|
f37551c1d2 | ||
|
|
63023011b9 | ||
|
|
f40076096e | ||
|
|
96d891cb94 | ||
|
|
4553891bbd | ||
|
|
ace899e71a | ||
|
|
aff16532d4 | ||
|
|
b50ab153f9 | ||
|
|
072db3bea6 | ||
|
|
a6deca6d9a | ||
|
|
41c30e92e7 | ||
|
|
f579a740dd | ||
|
|
d37272532c | ||
|
|
12da6ef581 | ||
|
|
29d4384a75 | ||
|
|
c5be423d6b | ||
|
|
b4d3652d88 | ||
|
|
5715be2ca9 | ||
|
|
0d4d9222c6 | ||
|
|
afc85cdeb6 | ||
|
|
acc152b674 | ||
|
|
b07258cef2 | ||
|
|
31e54b7052 | ||
|
|
8c0bae50c3 | ||
|
|
530412cb9d | ||
|
|
61c8c70c6e | ||
|
|
d0399f4343 | ||
|
|
e2919d38b4 | ||
|
|
93c8607d51 | ||
|
|
b3d6ae15b3 | ||
|
|
2e21122aab | ||
|
|
1cd6cd6080 | ||
|
|
d7b4bf21a2 | ||
|
|
042a905c37 | ||
|
|
019c7029ea | ||
|
|
8773ccf74d | ||
|
|
1d5d6586f3 | ||
|
|
35740259de | ||
|
|
ab888e1e0b | ||
|
|
d9f0fcdb0c | ||
|
|
b124256817 | ||
|
|
af4b7c91be | ||
|
|
e57d2282d1 | ||
|
|
4027466c80 | ||
|
|
095d867147 | ||
|
|
caeb27c3a5 | ||
|
|
3d06e1c555 | ||
|
|
43a74c0de1 | ||
|
|
af93c8d1ee | ||
|
|
832e3f5ca3 | ||
|
|
079eccc92a | ||
|
|
b6951768c4 | ||
|
|
fca304debf | ||
|
|
14880e6dba | ||
|
|
f1059b0b82 | ||
|
|
debabccb84 | ||
|
|
37cd448529 | ||
|
|
94f21f9301 | ||
|
|
60653004e5 | ||
|
|
a57d635c5f | ||
|
|
016b219dcc | ||
|
|
8ac2dddeed | ||
|
|
3e880ac709 | ||
|
|
e5ea112a90 | ||
|
|
8d88bfaff9 | ||
|
|
ed4d92b721 | ||
|
|
932ae8d9ca | ||
|
|
44e19a28d3 | ||
|
|
0a0df5f136 | ||
|
|
24d6871e47 | ||
|
|
9e1d301129 | ||
|
|
768e035868 | ||
|
|
669e0497ea | ||
|
|
541dc08547 | ||
|
|
8d8dc9a262 | ||
|
|
2f98c24360 |
@@ -0,0 +1,2 @@
|
||||
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation
|
||||
pause
|
||||
2
.github/workflows/stable-release.yml
vendored
2
.github/workflows/stable-release.yml
vendored
@@ -22,7 +22,7 @@ on:
|
||||
description: 'Python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "8"
|
||||
default: "9"
|
||||
|
||||
|
||||
jobs:
|
||||
|
||||
58
.github/workflows/update-frontend.yml
vendored
58
.github/workflows/update-frontend.yml
vendored
@@ -1,58 +0,0 @@
|
||||
name: Update Frontend Release
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
version:
|
||||
description: "Frontend version to update to (e.g., 1.0.0)"
|
||||
required: true
|
||||
type: string
|
||||
|
||||
jobs:
|
||||
update-frontend:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
|
||||
steps:
|
||||
- name: Checkout ComfyUI
|
||||
uses: actions/checkout@v4
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.10'
|
||||
- name: Install requirements
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
||||
pip install -r requirements.txt
|
||||
pip install wait-for-it
|
||||
# Frontend asset will be downloaded to ComfyUI/web_custom_versions/Comfy-Org_ComfyUI_frontend/{version}
|
||||
- name: Start ComfyUI server
|
||||
run: |
|
||||
python main.py --cpu --front-end-version Comfy-Org/ComfyUI_frontend@${{ github.event.inputs.version }} 2>&1 | tee console_output.log &
|
||||
wait-for-it --service 127.0.0.1:8188 -t 30
|
||||
- name: Configure Git
|
||||
run: |
|
||||
git config --global user.name "GitHub Action"
|
||||
git config --global user.email "action@github.com"
|
||||
# Replace existing frontend content with the new version and remove .js.map files
|
||||
# See https://github.com/Comfy-Org/ComfyUI_frontend/issues/2145 for why we remove .js.map files
|
||||
- name: Update frontend content
|
||||
run: |
|
||||
rm -rf web/
|
||||
cp -r web_custom_versions/Comfy-Org_ComfyUI_frontend/${{ github.event.inputs.version }} web/
|
||||
rm web/**/*.js.map
|
||||
- name: Create Pull Request
|
||||
uses: peter-evans/create-pull-request@v7
|
||||
with:
|
||||
token: ${{ secrets.PR_BOT_PAT }}
|
||||
commit-message: "Update frontend to v${{ github.event.inputs.version }}"
|
||||
title: "Frontend Update: v${{ github.event.inputs.version }}"
|
||||
body: |
|
||||
Automated PR to update frontend content to version ${{ github.event.inputs.version }}
|
||||
|
||||
This PR was created automatically by the frontend update workflow.
|
||||
branch: release-${{ github.event.inputs.version }}
|
||||
base: master
|
||||
labels: Frontend,dependencies
|
||||
@@ -29,7 +29,7 @@ on:
|
||||
description: 'python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "8"
|
||||
default: "9"
|
||||
# push:
|
||||
# branches:
|
||||
# - master
|
||||
|
||||
@@ -7,7 +7,7 @@ on:
|
||||
description: 'cuda version'
|
||||
required: true
|
||||
type: string
|
||||
default: "126"
|
||||
default: "128"
|
||||
|
||||
python_minor:
|
||||
description: 'python minor version'
|
||||
@@ -19,7 +19,7 @@ on:
|
||||
description: 'python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "1"
|
||||
default: "2"
|
||||
# push:
|
||||
# branches:
|
||||
# - master
|
||||
@@ -34,7 +34,7 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
fetch-depth: 30
|
||||
persist-credentials: false
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
@@ -74,7 +74,7 @@ jobs:
|
||||
pause" > ./update/update_comfyui_and_python_dependencies.bat
|
||||
cd ..
|
||||
|
||||
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI_windows_portable_nightly_pytorch
|
||||
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI_windows_portable_nightly_pytorch
|
||||
mv ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI/ComfyUI_windows_portable_nvidia_or_cpu_nightly_pytorch.7z
|
||||
|
||||
cd ComfyUI_windows_portable_nightly_pytorch
|
||||
|
||||
@@ -19,7 +19,7 @@ on:
|
||||
description: 'python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "8"
|
||||
default: "9"
|
||||
# push:
|
||||
# branches:
|
||||
# - master
|
||||
|
||||
10
CODEOWNERS
10
CODEOWNERS
@@ -11,14 +11,14 @@
|
||||
/notebooks/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
|
||||
/script_examples/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
|
||||
/.github/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
|
||||
/requirements.txt @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
|
||||
/pyproject.toml @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
|
||||
|
||||
# Python web server
|
||||
/api_server/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
|
||||
/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
|
||||
/utils/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
|
||||
|
||||
# Frontend assets
|
||||
/web/ @huchenlei @webfiltered @pythongosssss @yoland68 @robinjhuang
|
||||
|
||||
# Extra nodes
|
||||
/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink
|
||||
# Node developers
|
||||
/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered
|
||||
/comfy/comfy_types/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered
|
||||
|
||||
62
README.md
62
README.md
@@ -1,7 +1,7 @@
|
||||
<div align="center">
|
||||
|
||||
# ComfyUI
|
||||
**The most powerful and modular diffusion model GUI and backend.**
|
||||
**The most powerful and modular visual AI engine and application.**
|
||||
|
||||
|
||||
[![Website][website-shield]][website-url]
|
||||
@@ -31,10 +31,24 @@
|
||||

|
||||
</div>
|
||||
|
||||
This ui will let you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. For some workflow examples and see what ComfyUI can do you can check out:
|
||||
### [ComfyUI Examples](https://comfyanonymous.github.io/ComfyUI_examples/)
|
||||
ComfyUI lets you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. Available on Windows, Linux, and macOS.
|
||||
|
||||
## Get Started
|
||||
|
||||
#### [Desktop Application](https://www.comfy.org/download)
|
||||
- The easiest way to get started.
|
||||
- Available on Windows & macOS.
|
||||
|
||||
#### [Windows Portable Package](#installing)
|
||||
- Get the latest commits and completely portable.
|
||||
- Available on Windows.
|
||||
|
||||
#### [Manual Install](#manual-install-windows-linux)
|
||||
Supports all operating systems and GPU types (NVIDIA, AMD, Intel, Apple Silicon, Ascend).
|
||||
|
||||
## [Examples](https://comfyanonymous.github.io/ComfyUI_examples/)
|
||||
See what ComfyUI can do with the [example workflows](https://comfyanonymous.github.io/ComfyUI_examples/).
|
||||
|
||||
### [Installing ComfyUI](#installing)
|
||||
|
||||
## Features
|
||||
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
|
||||
@@ -47,12 +61,16 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
|
||||
- [AuraFlow](https://comfyanonymous.github.io/ComfyUI_examples/aura_flow/)
|
||||
- [HunyuanDiT](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_dit/)
|
||||
- [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
|
||||
- [Lumina Image 2.0](https://comfyanonymous.github.io/ComfyUI_examples/lumina2/)
|
||||
- Video Models
|
||||
- [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
|
||||
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
|
||||
- [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/)
|
||||
- [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
|
||||
- [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/)
|
||||
- [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
|
||||
- 3D Models
|
||||
- [Hunyuan3D 2.0](https://docs.comfy.org/tutorials/3d/hunyuan3D-2)
|
||||
- [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
|
||||
- Asynchronous Queue system
|
||||
- Many optimizations: Only re-executes the parts of the workflow that changes between executions.
|
||||
@@ -120,7 +138,7 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
|
||||
|
||||
# Installing
|
||||
|
||||
## Windows
|
||||
## Windows Portable
|
||||
|
||||
There is a portable standalone build for Windows that should work for running on Nvidia GPUs or for running on your CPU only on the [releases page](https://github.com/comfyanonymous/ComfyUI/releases).
|
||||
|
||||
@@ -130,6 +148,8 @@ Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you
|
||||
|
||||
If you have trouble extracting it, right click the file -> properties -> unblock
|
||||
|
||||
If you have a 50 series Blackwell card like a 5090 or 5080 see [this discussion thread](https://github.com/comfyanonymous/ComfyUI/discussions/6643)
|
||||
|
||||
#### How do I share models between another UI and ComfyUI?
|
||||
|
||||
See the [Config file](extra_model_paths.yaml.example) to set the search paths for models. In the standalone windows build you can find this file in the ComfyUI directory. Rename this file to extra_model_paths.yaml and edit it with your favorite text editor.
|
||||
@@ -138,9 +158,18 @@ See the [Config file](extra_model_paths.yaml.example) to set the search paths fo
|
||||
|
||||
To run it on services like paperspace, kaggle or colab you can use my [Jupyter Notebook](notebooks/comfyui_colab.ipynb)
|
||||
|
||||
|
||||
## [comfy-cli](https://docs.comfy.org/comfy-cli/getting-started)
|
||||
|
||||
You can install and start ComfyUI using comfy-cli:
|
||||
```bash
|
||||
pip install comfy-cli
|
||||
comfy install
|
||||
```
|
||||
|
||||
## Manual Install (Windows, Linux)
|
||||
|
||||
Note that some dependencies do not yet support python 3.13 so using 3.12 is recommended.
|
||||
python 3.13 is supported but using 3.12 is recommended because some custom nodes and their dependencies might not support it yet.
|
||||
|
||||
Git clone this repo.
|
||||
|
||||
@@ -152,7 +181,7 @@ Put your VAE in: models/vae
|
||||
### AMD GPUs (Linux only)
|
||||
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
|
||||
|
||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.2```
|
||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.2.4```
|
||||
|
||||
This is the command to install the nightly with ROCm 6.3 which might have some performance improvements:
|
||||
|
||||
@@ -186,11 +215,11 @@ Additional discussion and help can be found [here](https://github.com/comfyanony
|
||||
|
||||
Nvidia users should install stable pytorch using this command:
|
||||
|
||||
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu124```
|
||||
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu126```
|
||||
|
||||
This is the command to install pytorch nightly instead which might have performance improvements:
|
||||
This is the command to install pytorch nightly instead which supports the new blackwell 50xx series GPUs and might have performance improvements.
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu126```
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128```
|
||||
|
||||
#### Troubleshooting
|
||||
|
||||
@@ -234,6 +263,13 @@ For models compatible with Ascend Extension for PyTorch (torch_npu). To get star
|
||||
3. Next, install the necessary packages for torch-npu by adhering to the platform-specific instructions on the [Installation](https://ascend.github.io/docs/sources/pytorch/install.html#pytorch) page.
|
||||
4. Finally, adhere to the [ComfyUI manual installation](#manual-install-windows-linux) guide for Linux. Once all components are installed, you can run ComfyUI as described earlier.
|
||||
|
||||
#### Cambricon MLUs
|
||||
|
||||
For models compatible with Cambricon Extension for PyTorch (torch_mlu). Here's a step-by-step guide tailored to your platform and installation method:
|
||||
|
||||
1. Install the Cambricon CNToolkit by adhering to the platform-specific instructions on the [Installation](https://www.cambricon.com/docs/sdk_1.15.0/cntoolkit_3.7.2/cntoolkit_install_3.7.2/index.html)
|
||||
2. Next, install the PyTorch(torch_mlu) following the instructions on the [Installation](https://www.cambricon.com/docs/sdk_1.15.0/cambricon_pytorch_1.17.0/user_guide_1.9/index.html)
|
||||
3. Launch ComfyUI by running `python main.py`
|
||||
|
||||
# Running
|
||||
|
||||
@@ -290,6 +326,8 @@ Use `--tls-keyfile key.pem --tls-certfile cert.pem` to enable TLS/SSL, the app w
|
||||
|
||||
## Support and dev channel
|
||||
|
||||
[Discord](https://comfy.org/discord): Try the #help or #feedback channels.
|
||||
|
||||
[Matrix space: #comfyui_space:matrix.org](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) (it's like discord but open source).
|
||||
|
||||
See also: [https://www.comfy.org/](https://www.comfy.org/)
|
||||
@@ -306,7 +344,7 @@ For any bugs, issues, or feature requests related to the frontend, please use th
|
||||
|
||||
The new frontend is now the default for ComfyUI. However, please note:
|
||||
|
||||
1. The frontend in the main ComfyUI repository is updated weekly.
|
||||
1. The frontend in the main ComfyUI repository is updated fortnightly.
|
||||
2. Daily releases are available in the separate frontend repository.
|
||||
|
||||
To use the most up-to-date frontend version:
|
||||
@@ -323,7 +361,7 @@ To use the most up-to-date frontend version:
|
||||
--front-end-version Comfy-Org/ComfyUI_frontend@1.2.2
|
||||
```
|
||||
|
||||
This approach allows you to easily switch between the stable weekly release and the cutting-edge daily updates, or even specific versions for testing purposes.
|
||||
This approach allows you to easily switch between the stable fortnightly release and the cutting-edge daily updates, or even specific versions for testing purposes.
|
||||
|
||||
### Accessing the Legacy Frontend
|
||||
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from aiohttp import web
|
||||
from typing import Optional
|
||||
from folder_paths import models_dir, user_directory, output_directory, folder_names_and_paths
|
||||
from api_server.services.file_service import FileService
|
||||
from folder_paths import folder_names_and_paths, get_directory_by_type
|
||||
from api_server.services.terminal_service import TerminalService
|
||||
import app.logger
|
||||
import os
|
||||
|
||||
class InternalRoutes:
|
||||
'''
|
||||
@@ -15,26 +15,10 @@ class InternalRoutes:
|
||||
def __init__(self, prompt_server):
|
||||
self.routes: web.RouteTableDef = web.RouteTableDef()
|
||||
self._app: Optional[web.Application] = None
|
||||
self.file_service = FileService({
|
||||
"models": models_dir,
|
||||
"user": user_directory,
|
||||
"output": output_directory
|
||||
})
|
||||
self.prompt_server = prompt_server
|
||||
self.terminal_service = TerminalService(prompt_server)
|
||||
|
||||
def setup_routes(self):
|
||||
@self.routes.get('/files')
|
||||
async def list_files(request):
|
||||
directory_key = request.query.get('directory', '')
|
||||
try:
|
||||
file_list = self.file_service.list_files(directory_key)
|
||||
return web.json_response({"files": file_list})
|
||||
except ValueError as e:
|
||||
return web.json_response({"error": str(e)}, status=400)
|
||||
except Exception as e:
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
@self.routes.get('/logs')
|
||||
async def get_logs(request):
|
||||
return web.json_response("".join([(l["t"] + " - " + l["m"]) for l in app.logger.get_logs()]))
|
||||
@@ -67,6 +51,20 @@ class InternalRoutes:
|
||||
response[key] = folder_names_and_paths[key][0]
|
||||
return web.json_response(response)
|
||||
|
||||
@self.routes.get('/files/{directory_type}')
|
||||
async def get_files(request: web.Request) -> web.Response:
|
||||
directory_type = request.match_info['directory_type']
|
||||
if directory_type not in ("output", "input", "temp"):
|
||||
return web.json_response({"error": "Invalid directory type"}, status=400)
|
||||
|
||||
directory = get_directory_by_type(directory_type)
|
||||
sorted_files = sorted(
|
||||
(entry for entry in os.scandir(directory) if entry.is_file()),
|
||||
key=lambda entry: -entry.stat().st_mtime
|
||||
)
|
||||
return web.json_response([entry.name for entry in sorted_files], status=200)
|
||||
|
||||
|
||||
def get_app(self):
|
||||
if self._app is None:
|
||||
self._app = web.Application()
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
from typing import Dict, List, Optional
|
||||
from api_server.utils.file_operations import FileSystemOperations, FileSystemItem
|
||||
|
||||
class FileService:
|
||||
def __init__(self, allowed_directories: Dict[str, str], file_system_ops: Optional[FileSystemOperations] = None):
|
||||
self.allowed_directories: Dict[str, str] = allowed_directories
|
||||
self.file_system_ops: FileSystemOperations = file_system_ops or FileSystemOperations()
|
||||
|
||||
def list_files(self, directory_key: str) -> List[FileSystemItem]:
|
||||
if directory_key not in self.allowed_directories:
|
||||
raise ValueError("Invalid directory key")
|
||||
directory_path: str = self.allowed_directories[directory_key]
|
||||
return self.file_system_ops.walk_directory(directory_path)
|
||||
@@ -9,8 +9,14 @@ class AppSettings():
|
||||
self.user_manager = user_manager
|
||||
|
||||
def get_settings(self, request):
|
||||
file = self.user_manager.get_request_user_filepath(
|
||||
request, "comfy.settings.json")
|
||||
try:
|
||||
file = self.user_manager.get_request_user_filepath(
|
||||
request,
|
||||
"comfy.settings.json"
|
||||
)
|
||||
except KeyError as e:
|
||||
logging.error("User settings not found.")
|
||||
raise web.HTTPUnauthorized() from e
|
||||
if os.path.isfile(file):
|
||||
try:
|
||||
with open(file) as f:
|
||||
|
||||
@@ -3,16 +3,69 @@ import argparse
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import tempfile
|
||||
import zipfile
|
||||
import importlib
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import TypedDict, Optional
|
||||
from importlib.metadata import version
|
||||
|
||||
import requests
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
from comfy.cli_args import DEFAULT_VERSION_STRING
|
||||
import app.logger
|
||||
|
||||
# The path to the requirements.txt file
|
||||
req_path = Path(__file__).parents[1] / "requirements.txt"
|
||||
|
||||
|
||||
def frontend_install_warning_message():
|
||||
"""The warning message to display when the frontend version is not up to date."""
|
||||
|
||||
extra = ""
|
||||
if sys.flags.no_user_site:
|
||||
extra = "-s "
|
||||
return f"""
|
||||
Please install the updated requirements.txt file by running:
|
||||
{sys.executable} {extra}-m pip install -r {req_path}
|
||||
|
||||
This error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.
|
||||
|
||||
If you are on the portable package you can run: update\\update_comfyui.bat to solve this problem
|
||||
""".strip()
|
||||
|
||||
|
||||
def check_frontend_version():
|
||||
"""Check if the frontend version is up to date."""
|
||||
|
||||
def parse_version(version: str) -> tuple[int, int, int]:
|
||||
return tuple(map(int, version.split(".")))
|
||||
|
||||
try:
|
||||
frontend_version_str = version("comfyui-frontend-package")
|
||||
frontend_version = parse_version(frontend_version_str)
|
||||
with open(req_path, "r", encoding="utf-8") as f:
|
||||
required_frontend = parse_version(f.readline().split("=")[-1])
|
||||
if frontend_version < required_frontend:
|
||||
app.logger.log_startup_warning(
|
||||
f"""
|
||||
________________________________________________________________________
|
||||
WARNING WARNING WARNING WARNING WARNING
|
||||
|
||||
Installed frontend version {".".join(map(str, frontend_version))} is lower than the recommended version {".".join(map(str, required_frontend))}.
|
||||
|
||||
{frontend_install_warning_message()}
|
||||
________________________________________________________________________
|
||||
""".strip()
|
||||
)
|
||||
else:
|
||||
logging.info("ComfyUI frontend version: {}".format(frontend_version_str))
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to check frontend version: {e}")
|
||||
|
||||
|
||||
REQUEST_TIMEOUT = 10 # seconds
|
||||
@@ -109,9 +162,28 @@ def download_release_asset_zip(release: Release, destination_path: str) -> None:
|
||||
|
||||
|
||||
class FrontendManager:
|
||||
DEFAULT_FRONTEND_PATH = str(Path(__file__).parents[1] / "web")
|
||||
CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")
|
||||
|
||||
@classmethod
|
||||
def default_frontend_path(cls) -> str:
|
||||
try:
|
||||
import comfyui_frontend_package
|
||||
|
||||
return str(importlib.resources.files(comfyui_frontend_package) / "static")
|
||||
except ImportError:
|
||||
logging.error(
|
||||
f"""
|
||||
********** ERROR ***********
|
||||
|
||||
comfyui-frontend-package is not installed.
|
||||
|
||||
{frontend_install_warning_message()}
|
||||
|
||||
********** ERROR ***********
|
||||
""".strip()
|
||||
)
|
||||
sys.exit(-1)
|
||||
|
||||
@classmethod
|
||||
def parse_version_string(cls, value: str) -> tuple[str, str, str]:
|
||||
"""
|
||||
@@ -132,7 +204,9 @@ class FrontendManager:
|
||||
return match_result.group(1), match_result.group(2), match_result.group(3)
|
||||
|
||||
@classmethod
|
||||
def init_frontend_unsafe(cls, version_string: str, provider: Optional[FrontEndProvider] = None) -> str:
|
||||
def init_frontend_unsafe(
|
||||
cls, version_string: str, provider: Optional[FrontEndProvider] = None
|
||||
) -> str:
|
||||
"""
|
||||
Initializes the frontend for the specified version.
|
||||
|
||||
@@ -148,17 +222,26 @@ class FrontendManager:
|
||||
main error source might be request timeout or invalid URL.
|
||||
"""
|
||||
if version_string == DEFAULT_VERSION_STRING:
|
||||
return cls.DEFAULT_FRONTEND_PATH
|
||||
check_frontend_version()
|
||||
return cls.default_frontend_path()
|
||||
|
||||
repo_owner, repo_name, version = cls.parse_version_string(version_string)
|
||||
|
||||
if version.startswith("v"):
|
||||
expected_path = str(Path(cls.CUSTOM_FRONTENDS_ROOT) / f"{repo_owner}_{repo_name}" / version.lstrip("v"))
|
||||
expected_path = str(
|
||||
Path(cls.CUSTOM_FRONTENDS_ROOT)
|
||||
/ f"{repo_owner}_{repo_name}"
|
||||
/ version.lstrip("v")
|
||||
)
|
||||
if os.path.exists(expected_path):
|
||||
logging.info(f"Using existing copy of specific frontend version tag: {repo_owner}/{repo_name}@{version}")
|
||||
logging.info(
|
||||
f"Using existing copy of specific frontend version tag: {repo_owner}/{repo_name}@{version}"
|
||||
)
|
||||
return expected_path
|
||||
|
||||
logging.info(f"Initializing frontend: {repo_owner}/{repo_name}@{version}, requesting version details from GitHub...")
|
||||
logging.info(
|
||||
f"Initializing frontend: {repo_owner}/{repo_name}@{version}, requesting version details from GitHub..."
|
||||
)
|
||||
|
||||
provider = provider or FrontEndProvider(repo_owner, repo_name)
|
||||
release = provider.get_release(version)
|
||||
@@ -201,4 +284,5 @@ class FrontendManager:
|
||||
except Exception as e:
|
||||
logging.error("Failed to initialize frontend: %s", e)
|
||||
logging.info("Falling back to the default frontend.")
|
||||
return cls.DEFAULT_FRONTEND_PATH
|
||||
check_frontend_version()
|
||||
return cls.default_frontend_path()
|
||||
|
||||
@@ -82,3 +82,17 @@ def setup_logger(log_level: str = 'INFO', capacity: int = 300, use_stdout: bool
|
||||
logger.addHandler(stdout_handler)
|
||||
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
|
||||
STARTUP_WARNINGS = []
|
||||
|
||||
|
||||
def log_startup_warning(msg):
|
||||
logging.warning(msg)
|
||||
STARTUP_WARNINGS.append(msg)
|
||||
|
||||
|
||||
def print_startup_warnings():
|
||||
for s in STARTUP_WARNINGS:
|
||||
logging.warning(s)
|
||||
STARTUP_WARNINGS.clear()
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import argparse
|
||||
import enum
|
||||
import os
|
||||
from typing import Optional
|
||||
import comfy.options
|
||||
|
||||
|
||||
@@ -80,6 +79,7 @@ fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Stor
|
||||
fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).")
|
||||
fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text encoder weights in fp16.")
|
||||
fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
|
||||
fpte_group.add_argument("--bf16-text-enc", action="store_true", help="Store text encoder weights in bf16.")
|
||||
|
||||
parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.")
|
||||
|
||||
@@ -101,12 +101,14 @@ parser.add_argument("--preview-size", type=int, default=512, help="Sets the maxi
|
||||
cache_group = parser.add_mutually_exclusive_group()
|
||||
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
|
||||
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
|
||||
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
|
||||
|
||||
attn_group = parser.add_mutually_exclusive_group()
|
||||
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
|
||||
attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
|
||||
attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
|
||||
attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.")
|
||||
attn_group.add_argument("--use-flash-attention", action="store_true", help="Use FlashAttention.")
|
||||
|
||||
parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
|
||||
|
||||
@@ -130,7 +132,13 @@ parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha
|
||||
|
||||
parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
|
||||
parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
|
||||
parser.add_argument("--fast", action="store_true", help="Enable some untested and potentially quality deteriorating optimizations.")
|
||||
|
||||
class PerformanceFeature(enum.Enum):
|
||||
Fp16Accumulation = "fp16_accumulation"
|
||||
Fp8MatrixMultiplication = "fp8_matrix_mult"
|
||||
CublasOps = "cublas_ops"
|
||||
|
||||
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult cublas_ops")
|
||||
|
||||
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
|
||||
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
|
||||
@@ -161,13 +169,14 @@ parser.add_argument(
|
||||
""",
|
||||
)
|
||||
|
||||
def is_valid_directory(path: Optional[str]) -> Optional[str]:
|
||||
"""Validate if the given path is a directory."""
|
||||
if path is None:
|
||||
return None
|
||||
|
||||
def is_valid_directory(path: str) -> str:
|
||||
"""Validate if the given path is a directory, and check permissions."""
|
||||
if not os.path.exists(path):
|
||||
raise argparse.ArgumentTypeError(f"The path '{path}' does not exist.")
|
||||
if not os.path.isdir(path):
|
||||
raise argparse.ArgumentTypeError(f"{path} is not a valid directory.")
|
||||
raise argparse.ArgumentTypeError(f"'{path}' is not a directory.")
|
||||
if not os.access(path, os.R_OK):
|
||||
raise argparse.ArgumentTypeError(f"You do not have read permissions for '{path}'.")
|
||||
return path
|
||||
|
||||
parser.add_argument(
|
||||
@@ -179,6 +188,8 @@ parser.add_argument(
|
||||
|
||||
parser.add_argument("--user-directory", type=is_valid_directory, default=None, help="Set the ComfyUI user directory with an absolute path. Overrides --base-directory.")
|
||||
|
||||
parser.add_argument("--enable-compress-response-body", action="store_true", help="Enable compressing response body.")
|
||||
|
||||
if comfy.options.args_parsing:
|
||||
args = parser.parse_args()
|
||||
else:
|
||||
@@ -189,3 +200,17 @@ if args.windows_standalone_build:
|
||||
|
||||
if args.disable_auto_launch:
|
||||
args.auto_launch = False
|
||||
|
||||
if args.force_fp16:
|
||||
args.fp16_unet = True
|
||||
|
||||
|
||||
# '--fast' is not provided, use an empty set
|
||||
if args.fast is None:
|
||||
args.fast = set()
|
||||
# '--fast' is provided with an empty list, enable all optimizations
|
||||
elif args.fast == []:
|
||||
args.fast = set(PerformanceFeature)
|
||||
# '--fast' is provided with a list of performance features, use that list
|
||||
else:
|
||||
args.fast = set(args.fast)
|
||||
|
||||
@@ -97,14 +97,19 @@ class CLIPTextModel_(torch.nn.Module):
|
||||
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
|
||||
self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32):
|
||||
x = self.embeddings(input_tokens, dtype=dtype)
|
||||
def forward(self, input_tokens=None, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32):
|
||||
if embeds is not None:
|
||||
x = embeds + comfy.ops.cast_to(self.embeddings.position_embedding.weight, dtype=dtype, device=embeds.device)
|
||||
else:
|
||||
x = self.embeddings(input_tokens, dtype=dtype)
|
||||
|
||||
mask = None
|
||||
if attention_mask is not None:
|
||||
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
|
||||
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
|
||||
mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max)
|
||||
|
||||
causal_mask = torch.full((x.shape[1], x.shape[1]), -torch.finfo(x.dtype).max, dtype=x.dtype, device=x.device).triu_(1)
|
||||
|
||||
causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
|
||||
if mask is not None:
|
||||
mask += causal_mask
|
||||
else:
|
||||
@@ -115,7 +120,10 @@ class CLIPTextModel_(torch.nn.Module):
|
||||
if i is not None and final_layer_norm_intermediate:
|
||||
i = self.final_layer_norm(i)
|
||||
|
||||
pooled_output = x[torch.arange(x.shape[0], device=x.device), (torch.round(input_tokens).to(dtype=torch.int, device=x.device) == self.eos_token_id).int().argmax(dim=-1),]
|
||||
if num_tokens is not None:
|
||||
pooled_output = x[list(range(x.shape[0])), list(map(lambda a: a - 1, num_tokens))]
|
||||
else:
|
||||
pooled_output = x[torch.arange(x.shape[0], device=x.device), (torch.round(input_tokens).to(dtype=torch.int, device=x.device) == self.eos_token_id).int().argmax(dim=-1),]
|
||||
return x, i, pooled_output
|
||||
|
||||
class CLIPTextModel(torch.nn.Module):
|
||||
@@ -203,6 +211,15 @@ class CLIPVision(torch.nn.Module):
|
||||
pooled_output = self.post_layernorm(x[:, 0, :])
|
||||
return x, i, pooled_output
|
||||
|
||||
class LlavaProjector(torch.nn.Module):
|
||||
def __init__(self, in_dim, out_dim, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.linear_1 = operations.Linear(in_dim, out_dim, bias=True, device=device, dtype=dtype)
|
||||
self.linear_2 = operations.Linear(out_dim, out_dim, bias=True, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear_2(torch.nn.functional.gelu(self.linear_1(x[:, 1:])))
|
||||
|
||||
class CLIPVisionModelProjection(torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
@@ -212,7 +229,16 @@ class CLIPVisionModelProjection(torch.nn.Module):
|
||||
else:
|
||||
self.visual_projection = lambda a: a
|
||||
|
||||
if "llava3" == config_dict.get("projector_type", None):
|
||||
self.multi_modal_projector = LlavaProjector(config_dict["hidden_size"], 4096, dtype, device, operations)
|
||||
else:
|
||||
self.multi_modal_projector = None
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
x = self.vision_model(*args, **kwargs)
|
||||
out = self.visual_projection(x[2])
|
||||
return (x[0], x[1], out)
|
||||
projected = None
|
||||
if self.multi_modal_projector is not None:
|
||||
projected = self.multi_modal_projector(x[1])
|
||||
|
||||
return (x[0], x[1], out, projected)
|
||||
|
||||
@@ -9,6 +9,7 @@ import comfy.model_patcher
|
||||
import comfy.model_management
|
||||
import comfy.utils
|
||||
import comfy.clip_model
|
||||
import comfy.image_encoders.dino2
|
||||
|
||||
class Output:
|
||||
def __getitem__(self, key):
|
||||
@@ -34,6 +35,12 @@ def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], s
|
||||
image = torch.clip((255. * image), 0, 255).round() / 255.0
|
||||
return (image - mean.view([3,1,1])) / std.view([3,1,1])
|
||||
|
||||
IMAGE_ENCODERS = {
|
||||
"clip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
|
||||
"siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
|
||||
"dinov2": comfy.image_encoders.dino2.Dinov2Model,
|
||||
}
|
||||
|
||||
class ClipVisionModel():
|
||||
def __init__(self, json_config):
|
||||
with open(json_config) as f:
|
||||
@@ -42,10 +49,11 @@ class ClipVisionModel():
|
||||
self.image_size = config.get("image_size", 224)
|
||||
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
|
||||
self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
|
||||
model_class = IMAGE_ENCODERS.get(config.get("model_type", "clip_vision_model"))
|
||||
self.load_device = comfy.model_management.text_encoder_device()
|
||||
offload_device = comfy.model_management.text_encoder_offload_device()
|
||||
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
|
||||
self.model = comfy.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, comfy.ops.manual_cast)
|
||||
self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast)
|
||||
self.model.eval()
|
||||
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
@@ -65,6 +73,7 @@ class ClipVisionModel():
|
||||
outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device())
|
||||
outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device())
|
||||
outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
|
||||
outputs["mm_projected"] = out[3]
|
||||
return outputs
|
||||
|
||||
def convert_to_transformers(sd, prefix):
|
||||
@@ -101,12 +110,21 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
|
||||
elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
|
||||
elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
|
||||
embed_shape = sd["vision_model.embeddings.position_embedding.weight"].shape[0]
|
||||
if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
|
||||
elif sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
|
||||
if embed_shape == 729:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
|
||||
elif embed_shape == 1024:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_512.json")
|
||||
elif embed_shape == 577:
|
||||
if "multi_modal_projector.linear_1.bias" in sd:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336_llava.json")
|
||||
else:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
|
||||
else:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
|
||||
elif "embeddings.patch_embeddings.projection.weight" in sd:
|
||||
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json")
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
19
comfy/clip_vision_config_vitl_336_llava.json
Normal file
19
comfy/clip_vision_config_vitl_336_llava.json
Normal file
@@ -0,0 +1,19 @@
|
||||
{
|
||||
"attention_dropout": 0.0,
|
||||
"dropout": 0.0,
|
||||
"hidden_act": "quick_gelu",
|
||||
"hidden_size": 1024,
|
||||
"image_size": 336,
|
||||
"initializer_factor": 1.0,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 4096,
|
||||
"layer_norm_eps": 1e-5,
|
||||
"model_type": "clip_vision_model",
|
||||
"num_attention_heads": 16,
|
||||
"num_channels": 3,
|
||||
"num_hidden_layers": 24,
|
||||
"patch_size": 14,
|
||||
"projection_dim": 768,
|
||||
"projector_type": "llava3",
|
||||
"torch_dtype": "float32"
|
||||
}
|
||||
13
comfy/clip_vision_siglip_512.json
Normal file
13
comfy/clip_vision_siglip_512.json
Normal file
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"num_channels": 3,
|
||||
"hidden_act": "gelu_pytorch_tanh",
|
||||
"hidden_size": 1152,
|
||||
"image_size": 512,
|
||||
"intermediate_size": 4304,
|
||||
"model_type": "siglip_vision_model",
|
||||
"num_attention_heads": 16,
|
||||
"num_hidden_layers": 27,
|
||||
"patch_size": 16,
|
||||
"image_mean": [0.5, 0.5, 0.5],
|
||||
"image_std": [0.5, 0.5, 0.5]
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
import torch
|
||||
from typing import Callable, Protocol, TypedDict, Optional, List
|
||||
from .node_typing import IO, InputTypeDict, ComfyNodeABC, CheckLazyMixin
|
||||
from .node_typing import IO, InputTypeDict, ComfyNodeABC, CheckLazyMixin, FileLocator
|
||||
|
||||
|
||||
class UnetApplyFunction(Protocol):
|
||||
@@ -42,4 +42,5 @@ __all__ = [
|
||||
InputTypeDict.__name__,
|
||||
ComfyNodeABC.__name__,
|
||||
CheckLazyMixin.__name__,
|
||||
FileLocator.__name__,
|
||||
]
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Literal, TypedDict
|
||||
from typing_extensions import NotRequired
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
|
||||
@@ -26,6 +27,7 @@ class IO(StrEnum):
|
||||
BOOLEAN = "BOOLEAN"
|
||||
INT = "INT"
|
||||
FLOAT = "FLOAT"
|
||||
COMBO = "COMBO"
|
||||
CONDITIONING = "CONDITIONING"
|
||||
SAMPLER = "SAMPLER"
|
||||
SIGMAS = "SIGMAS"
|
||||
@@ -67,90 +69,139 @@ class IO(StrEnum):
|
||||
return not (b.issubset(a) or a.issubset(b))
|
||||
|
||||
|
||||
class RemoteInputOptions(TypedDict):
|
||||
route: str
|
||||
"""The route to the remote source."""
|
||||
refresh_button: bool
|
||||
"""Specifies whether to show a refresh button in the UI below the widget."""
|
||||
control_after_refresh: Literal["first", "last"]
|
||||
"""Specifies the control after the refresh button is clicked. If "first", the first item will be automatically selected, and so on."""
|
||||
timeout: int
|
||||
"""The maximum amount of time to wait for a response from the remote source in milliseconds."""
|
||||
max_retries: int
|
||||
"""The maximum number of retries before aborting the request."""
|
||||
refresh: int
|
||||
"""The TTL of the remote input's value in milliseconds. Specifies the interval at which the remote input's value is refreshed."""
|
||||
|
||||
|
||||
class MultiSelectOptions(TypedDict):
|
||||
placeholder: NotRequired[str]
|
||||
"""The placeholder text to display in the multi-select widget when no items are selected."""
|
||||
chip: NotRequired[bool]
|
||||
"""Specifies whether to use chips instead of comma separated values for the multi-select widget."""
|
||||
|
||||
|
||||
class InputTypeOptions(TypedDict):
|
||||
"""Provides type hinting for the return type of the INPUT_TYPES node function.
|
||||
|
||||
Due to IDE limitations with unions, for now all options are available for all types (e.g. `label_on` is hinted even when the type is not `IO.BOOLEAN`).
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_datatypes
|
||||
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/datatypes
|
||||
"""
|
||||
|
||||
default: bool | str | float | int | list | tuple
|
||||
default: NotRequired[bool | str | float | int | list | tuple]
|
||||
"""The default value of the widget"""
|
||||
defaultInput: bool
|
||||
"""Defaults to an input slot rather than a widget"""
|
||||
forceInput: bool
|
||||
"""`defaultInput` and also don't allow converting to a widget"""
|
||||
lazy: bool
|
||||
defaultInput: NotRequired[bool]
|
||||
"""@deprecated in v1.16 frontend. v1.16 frontend allows input socket and widget to co-exist.
|
||||
- defaultInput on required inputs should be dropped.
|
||||
- defaultInput on optional inputs should be replaced with forceInput.
|
||||
Ref: https://github.com/Comfy-Org/ComfyUI_frontend/pull/3364
|
||||
"""
|
||||
forceInput: NotRequired[bool]
|
||||
"""Forces the input to be an input slot rather than a widget even a widget is available for the input type."""
|
||||
lazy: NotRequired[bool]
|
||||
"""Declares that this input uses lazy evaluation"""
|
||||
rawLink: bool
|
||||
rawLink: NotRequired[bool]
|
||||
"""When a link exists, rather than receiving the evaluated value, you will receive the link (i.e. `["nodeId", <outputIndex>]`). Designed for node expansion."""
|
||||
tooltip: str
|
||||
tooltip: NotRequired[str]
|
||||
"""Tooltip for the input (or widget), shown on pointer hover"""
|
||||
# class InputTypeNumber(InputTypeOptions):
|
||||
# default: float | int
|
||||
min: float
|
||||
min: NotRequired[float]
|
||||
"""The minimum value of a number (``FLOAT`` | ``INT``)"""
|
||||
max: float
|
||||
max: NotRequired[float]
|
||||
"""The maximum value of a number (``FLOAT`` | ``INT``)"""
|
||||
step: float
|
||||
step: NotRequired[float]
|
||||
"""The amount to increment or decrement a widget by when stepping up/down (``FLOAT`` | ``INT``)"""
|
||||
round: float
|
||||
round: NotRequired[float]
|
||||
"""Floats are rounded by this value (``FLOAT``)"""
|
||||
# class InputTypeBoolean(InputTypeOptions):
|
||||
# default: bool
|
||||
label_on: str
|
||||
label_on: NotRequired[str]
|
||||
"""The label to use in the UI when the bool is True (``BOOLEAN``)"""
|
||||
label_on: str
|
||||
label_off: NotRequired[str]
|
||||
"""The label to use in the UI when the bool is False (``BOOLEAN``)"""
|
||||
# class InputTypeString(InputTypeOptions):
|
||||
# default: str
|
||||
multiline: bool
|
||||
multiline: NotRequired[bool]
|
||||
"""Use a multiline text box (``STRING``)"""
|
||||
placeholder: str
|
||||
placeholder: NotRequired[str]
|
||||
"""Placeholder text to display in the UI when empty (``STRING``)"""
|
||||
# Deprecated:
|
||||
# defaultVal: str
|
||||
dynamicPrompts: bool
|
||||
dynamicPrompts: NotRequired[bool]
|
||||
"""Causes the front-end to evaluate dynamic prompts (``STRING``)"""
|
||||
# class InputTypeCombo(InputTypeOptions):
|
||||
image_upload: NotRequired[bool]
|
||||
"""Specifies whether the input should have an image upload button and image preview attached to it. Requires that the input's name is `image`."""
|
||||
image_folder: NotRequired[Literal["input", "output", "temp"]]
|
||||
"""Specifies which folder to get preview images from if the input has the ``image_upload`` flag.
|
||||
"""
|
||||
remote: NotRequired[RemoteInputOptions]
|
||||
"""Specifies the configuration for a remote input.
|
||||
Available after ComfyUI frontend v1.9.7
|
||||
https://github.com/Comfy-Org/ComfyUI_frontend/pull/2422"""
|
||||
control_after_generate: NotRequired[bool]
|
||||
"""Specifies whether a control widget should be added to the input, adding options to automatically change the value after each prompt is queued. Currently only used for INT and COMBO types."""
|
||||
options: NotRequired[list[str | int | float]]
|
||||
"""COMBO type only. Specifies the selectable options for the combo widget.
|
||||
Prefer:
|
||||
["COMBO", {"options": ["Option 1", "Option 2", "Option 3"]}]
|
||||
Over:
|
||||
[["Option 1", "Option 2", "Option 3"]]
|
||||
"""
|
||||
multi_select: NotRequired[MultiSelectOptions]
|
||||
"""COMBO type only. Specifies the configuration for a multi-select widget.
|
||||
Available after ComfyUI frontend v1.13.4
|
||||
https://github.com/Comfy-Org/ComfyUI_frontend/pull/2987"""
|
||||
|
||||
|
||||
class HiddenInputTypeDict(TypedDict):
|
||||
"""Provides type hinting for the hidden entry of node INPUT_TYPES."""
|
||||
|
||||
node_id: Literal["UNIQUE_ID"]
|
||||
node_id: NotRequired[Literal["UNIQUE_ID"]]
|
||||
"""UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages)."""
|
||||
unique_id: Literal["UNIQUE_ID"]
|
||||
unique_id: NotRequired[Literal["UNIQUE_ID"]]
|
||||
"""UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages)."""
|
||||
prompt: Literal["PROMPT"]
|
||||
prompt: NotRequired[Literal["PROMPT"]]
|
||||
"""PROMPT is the complete prompt sent by the client to the server. See the prompt object for a full description."""
|
||||
extra_pnginfo: Literal["EXTRA_PNGINFO"]
|
||||
extra_pnginfo: NotRequired[Literal["EXTRA_PNGINFO"]]
|
||||
"""EXTRA_PNGINFO is a dictionary that will be copied into the metadata of any .png files saved. Custom nodes can store additional information in this dictionary for saving (or as a way to communicate with a downstream node)."""
|
||||
dynprompt: Literal["DYNPROMPT"]
|
||||
dynprompt: NotRequired[Literal["DYNPROMPT"]]
|
||||
"""DYNPROMPT is an instance of comfy_execution.graph.DynamicPrompt. It differs from PROMPT in that it may mutate during the course of execution in response to Node Expansion."""
|
||||
|
||||
|
||||
class InputTypeDict(TypedDict):
|
||||
"""Provides type hinting for node INPUT_TYPES.
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_more_on_inputs
|
||||
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/more_on_inputs
|
||||
"""
|
||||
|
||||
required: dict[str, tuple[IO, InputTypeOptions]]
|
||||
required: NotRequired[dict[str, tuple[IO, InputTypeOptions]]]
|
||||
"""Describes all inputs that must be connected for the node to execute."""
|
||||
optional: dict[str, tuple[IO, InputTypeOptions]]
|
||||
optional: NotRequired[dict[str, tuple[IO, InputTypeOptions]]]
|
||||
"""Describes inputs which do not need to be connected."""
|
||||
hidden: HiddenInputTypeDict
|
||||
hidden: NotRequired[HiddenInputTypeDict]
|
||||
"""Offers advanced functionality and server-client communication.
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_more_on_inputs#hidden-inputs
|
||||
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/more_on_inputs#hidden-inputs
|
||||
"""
|
||||
|
||||
|
||||
class ComfyNodeABC(ABC):
|
||||
"""Abstract base class for Comfy nodes. Includes the names and expected types of attributes.
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview
|
||||
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview
|
||||
"""
|
||||
|
||||
DESCRIPTION: str
|
||||
@@ -167,7 +218,7 @@ class ComfyNodeABC(ABC):
|
||||
CATEGORY: str
|
||||
"""The category of the node, as per the "Add Node" menu.
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#category
|
||||
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#category
|
||||
"""
|
||||
EXPERIMENTAL: bool
|
||||
"""Flags a node as experimental, informing users that it may change or not work as expected."""
|
||||
@@ -181,9 +232,9 @@ class ComfyNodeABC(ABC):
|
||||
|
||||
* Must include the ``required`` key, which describes all inputs that must be connected for the node to execute.
|
||||
* The ``optional`` key can be added to describe inputs which do not need to be connected.
|
||||
* The ``hidden`` key offers some advanced functionality. More info at: https://docs.comfy.org/essentials/custom_node_more_on_inputs#hidden-inputs
|
||||
* The ``hidden`` key offers some advanced functionality. More info at: https://docs.comfy.org/custom-nodes/backend/more_on_inputs#hidden-inputs
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#input-types
|
||||
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#input-types
|
||||
"""
|
||||
return {"required": {}}
|
||||
|
||||
@@ -198,7 +249,7 @@ class ComfyNodeABC(ABC):
|
||||
|
||||
By default, a node is not considered an output. Set ``OUTPUT_NODE = True`` to specify that it is.
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#output-node
|
||||
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#output-node
|
||||
"""
|
||||
INPUT_IS_LIST: bool
|
||||
"""A flag indicating if this node implements the additional code necessary to deal with OUTPUT_IS_LIST nodes.
|
||||
@@ -209,7 +260,7 @@ class ComfyNodeABC(ABC):
|
||||
|
||||
A node can also override the default input behaviour and receive the whole list in a single call. This is done by setting a class attribute `INPUT_IS_LIST` to ``True``.
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_lists#list-processing
|
||||
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lists#list-processing
|
||||
"""
|
||||
OUTPUT_IS_LIST: tuple[bool]
|
||||
"""A tuple indicating which node outputs are lists, but will be connected to nodes that expect individual items.
|
||||
@@ -227,7 +278,7 @@ class ComfyNodeABC(ABC):
|
||||
the node should provide a class attribute `OUTPUT_IS_LIST`, which is a ``tuple[bool]``, of the same length as `RETURN_TYPES`,
|
||||
specifying which outputs which should be so treated.
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_lists#list-processing
|
||||
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lists#list-processing
|
||||
"""
|
||||
|
||||
RETURN_TYPES: tuple[IO]
|
||||
@@ -237,19 +288,19 @@ class ComfyNodeABC(ABC):
|
||||
|
||||
RETURN_TYPES = (IO.INT, "INT", "CUSTOM_TYPE")
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#return-types
|
||||
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#return-types
|
||||
"""
|
||||
RETURN_NAMES: tuple[str]
|
||||
"""The output slot names for each item in `RETURN_TYPES`, e.g. ``RETURN_NAMES = ("count", "filter_string")``
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#return-names
|
||||
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#return-names
|
||||
"""
|
||||
OUTPUT_TOOLTIPS: tuple[str]
|
||||
"""A tuple of strings to use as tooltips for node outputs, one for each item in `RETURN_TYPES`."""
|
||||
FUNCTION: str
|
||||
"""The name of the function to execute as a literal string, e.g. `FUNCTION = "execute"`
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#function
|
||||
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#function
|
||||
"""
|
||||
|
||||
|
||||
@@ -267,8 +318,19 @@ class CheckLazyMixin:
|
||||
Params should match the nodes execution ``FUNCTION`` (self, and all inputs by name).
|
||||
Will be executed repeatedly until it returns an empty list, or all requested items were already evaluated (and sent as params).
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_lazy_evaluation#defining-check-lazy-status
|
||||
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lazy_evaluation#defining-check-lazy-status
|
||||
"""
|
||||
|
||||
need = [name for name in kwargs if kwargs[name] is None]
|
||||
return need
|
||||
|
||||
|
||||
class FileLocator(TypedDict):
|
||||
"""Provides type hinting for the file location"""
|
||||
|
||||
filename: str
|
||||
"""The filename of the file."""
|
||||
subfolder: str
|
||||
"""The subfolder of the file."""
|
||||
type: Literal["input", "output", "temp"]
|
||||
"""The root folder of the file."""
|
||||
|
||||
@@ -418,10 +418,7 @@ def controlnet_config(sd, model_options={}):
|
||||
weight_dtype = comfy.utils.weight_dtype(sd)
|
||||
|
||||
supported_inference_dtypes = list(model_config.supported_inference_dtypes)
|
||||
if weight_dtype is not None:
|
||||
supported_inference_dtypes.append(weight_dtype)
|
||||
|
||||
unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
|
||||
unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes, weight_dtype=weight_dtype)
|
||||
|
||||
load_device = comfy.model_management.get_torch_device()
|
||||
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
||||
@@ -689,10 +686,7 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
|
||||
if supported_inference_dtypes is None:
|
||||
supported_inference_dtypes = [comfy.model_management.unet_dtype()]
|
||||
|
||||
if weight_dtype is not None:
|
||||
supported_inference_dtypes.append(weight_dtype)
|
||||
|
||||
unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
|
||||
unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes, weight_dtype=weight_dtype)
|
||||
|
||||
load_device = comfy.model_management.get_torch_device()
|
||||
|
||||
|
||||
141
comfy/image_encoders/dino2.py
Normal file
141
comfy/image_encoders/dino2.py
Normal file
@@ -0,0 +1,141 @@
|
||||
import torch
|
||||
from comfy.text_encoders.bert import BertAttention
|
||||
import comfy.model_management
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
|
||||
|
||||
class Dino2AttentionOutput(torch.nn.Module):
|
||||
def __init__(self, input_dim, output_dim, layer_norm_eps, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.dense = operations.Linear(input_dim, output_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
return self.dense(x)
|
||||
|
||||
|
||||
class Dino2AttentionBlock(torch.nn.Module):
|
||||
def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.attention = BertAttention(embed_dim, heads, dtype, device, operations)
|
||||
self.output = Dino2AttentionOutput(embed_dim, embed_dim, layer_norm_eps, dtype, device, operations)
|
||||
|
||||
def forward(self, x, mask, optimized_attention):
|
||||
return self.output(self.attention(x, mask, optimized_attention))
|
||||
|
||||
|
||||
class LayerScale(torch.nn.Module):
|
||||
def __init__(self, dim, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.lambda1 = torch.nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
|
||||
|
||||
def forward(self, x):
|
||||
return x * comfy.model_management.cast_to_device(self.lambda1, x.device, x.dtype)
|
||||
|
||||
|
||||
class SwiGLUFFN(torch.nn.Module):
|
||||
def __init__(self, dim, dtype, device, operations):
|
||||
super().__init__()
|
||||
in_features = out_features = dim
|
||||
hidden_features = int(dim * 4)
|
||||
hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
|
||||
|
||||
self.weights_in = operations.Linear(in_features, 2 * hidden_features, bias=True, device=device, dtype=dtype)
|
||||
self.weights_out = operations.Linear(hidden_features, out_features, bias=True, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.weights_in(x)
|
||||
x1, x2 = x.chunk(2, dim=-1)
|
||||
x = torch.nn.functional.silu(x1) * x2
|
||||
return self.weights_out(x)
|
||||
|
||||
|
||||
class Dino2Block(torch.nn.Module):
|
||||
def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations)
|
||||
self.layer_scale1 = LayerScale(dim, dtype, device, operations)
|
||||
self.layer_scale2 = LayerScale(dim, dtype, device, operations)
|
||||
self.mlp = SwiGLUFFN(dim, dtype, device, operations)
|
||||
self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||
self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x, optimized_attention):
|
||||
x = x + self.layer_scale1(self.attention(self.norm1(x), None, optimized_attention))
|
||||
x = x + self.layer_scale2(self.mlp(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
class Dino2Encoder(torch.nn.Module):
|
||||
def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations) for _ in range(num_layers)])
|
||||
|
||||
def forward(self, x, intermediate_output=None):
|
||||
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
|
||||
|
||||
if intermediate_output is not None:
|
||||
if intermediate_output < 0:
|
||||
intermediate_output = len(self.layer) + intermediate_output
|
||||
|
||||
intermediate = None
|
||||
for i, l in enumerate(self.layer):
|
||||
x = l(x, optimized_attention)
|
||||
if i == intermediate_output:
|
||||
intermediate = x.clone()
|
||||
return x, intermediate
|
||||
|
||||
|
||||
class Dino2PatchEmbeddings(torch.nn.Module):
|
||||
def __init__(self, dim, num_channels=3, patch_size=14, image_size=518, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.projection = operations.Conv2d(
|
||||
in_channels=num_channels,
|
||||
out_channels=dim,
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
bias=True,
|
||||
dtype=dtype,
|
||||
device=device
|
||||
)
|
||||
|
||||
def forward(self, pixel_values):
|
||||
return self.projection(pixel_values).flatten(2).transpose(1, 2)
|
||||
|
||||
|
||||
class Dino2Embeddings(torch.nn.Module):
|
||||
def __init__(self, dim, dtype, device, operations):
|
||||
super().__init__()
|
||||
patch_size = 14
|
||||
image_size = 518
|
||||
|
||||
self.patch_embeddings = Dino2PatchEmbeddings(dim, patch_size=patch_size, image_size=image_size, dtype=dtype, device=device, operations=operations)
|
||||
self.position_embeddings = torch.nn.Parameter(torch.empty(1, (image_size // patch_size) ** 2 + 1, dim, dtype=dtype, device=device))
|
||||
self.cls_token = torch.nn.Parameter(torch.empty(1, 1, dim, dtype=dtype, device=device))
|
||||
self.mask_token = torch.nn.Parameter(torch.empty(1, dim, dtype=dtype, device=device))
|
||||
|
||||
def forward(self, pixel_values):
|
||||
x = self.patch_embeddings(pixel_values)
|
||||
# TODO: mask_token?
|
||||
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
|
||||
x = x + comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype)
|
||||
return x
|
||||
|
||||
|
||||
class Dinov2Model(torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
num_layers = config_dict["num_hidden_layers"]
|
||||
dim = config_dict["hidden_size"]
|
||||
heads = config_dict["num_attention_heads"]
|
||||
layer_norm_eps = config_dict["layer_norm_eps"]
|
||||
|
||||
self.embeddings = Dino2Embeddings(dim, dtype, device, operations)
|
||||
self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations)
|
||||
self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
|
||||
x = self.embeddings(pixel_values)
|
||||
x, i = self.encoder(x, intermediate_output=intermediate_output)
|
||||
x = self.layernorm(x)
|
||||
pooled_output = x[:, 0, :]
|
||||
return x, i, pooled_output, None
|
||||
21
comfy/image_encoders/dino2_giant.json
Normal file
21
comfy/image_encoders/dino2_giant.json
Normal file
@@ -0,0 +1,21 @@
|
||||
{
|
||||
"attention_probs_dropout_prob": 0.0,
|
||||
"drop_path_rate": 0.0,
|
||||
"hidden_act": "gelu",
|
||||
"hidden_dropout_prob": 0.0,
|
||||
"hidden_size": 1536,
|
||||
"image_size": 518,
|
||||
"initializer_range": 0.02,
|
||||
"layer_norm_eps": 1e-06,
|
||||
"layerscale_value": 1.0,
|
||||
"mlp_ratio": 4,
|
||||
"model_type": "dinov2",
|
||||
"num_attention_heads": 24,
|
||||
"num_channels": 3,
|
||||
"num_hidden_layers": 40,
|
||||
"patch_size": 14,
|
||||
"qkv_bias": true,
|
||||
"use_swiglu_ffn": true,
|
||||
"image_mean": [0.485, 0.456, 0.406],
|
||||
"image_std": [0.229, 0.224, 0.225]
|
||||
}
|
||||
@@ -688,10 +688,10 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=N
|
||||
if len(sigmas) <= 1:
|
||||
return x
|
||||
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
sigma_fn = lambda t: t.neg().exp()
|
||||
t_fn = lambda sigma: sigma.log().neg()
|
||||
@@ -762,10 +762,10 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
||||
if solver_type not in {'heun', 'midpoint'}:
|
||||
raise ValueError('solver_type must be \'heun\' or \'midpoint\'')
|
||||
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
old_denoised = None
|
||||
@@ -808,10 +808,10 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
||||
if len(sigmas) <= 1:
|
||||
return x
|
||||
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
denoised_1, denoised_2 = None, None
|
||||
@@ -858,7 +858,7 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
||||
def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
if len(sigmas) <= 1:
|
||||
return x
|
||||
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
||||
return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)
|
||||
@@ -867,7 +867,7 @@ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
|
||||
def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
|
||||
if len(sigmas) <= 1:
|
||||
return x
|
||||
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
||||
return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
|
||||
@@ -876,7 +876,7 @@ def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
|
||||
def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
|
||||
if len(sigmas) <= 1:
|
||||
return x
|
||||
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
||||
return sample_dpmpp_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=r)
|
||||
@@ -1267,7 +1267,7 @@ def sample_dpmpp_2m_cfg_pp(model, x, sigmas, extra_args=None, callback=None, dis
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None, cfg_pp=False):
|
||||
def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None, eta=1., cfg_pp=False):
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
@@ -1289,53 +1289,60 @@ def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None
|
||||
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
if s_churn > 0:
|
||||
gamma = min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.0
|
||||
sigma_hat = sigmas[i] * (gamma + 1)
|
||||
else:
|
||||
gamma = 0
|
||||
sigma_hat = sigmas[i]
|
||||
|
||||
if gamma > 0:
|
||||
eps = torch.randn_like(x) * s_noise
|
||||
x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5
|
||||
denoised = model(x, sigma_hat * s_in, **extra_args)
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
|
||||
if callback is not None:
|
||||
callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigma_hat, "denoised": denoised})
|
||||
if sigmas[i + 1] == 0 or old_denoised is None:
|
||||
callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised})
|
||||
if sigma_down == 0 or old_denoised is None:
|
||||
# Euler method
|
||||
if cfg_pp:
|
||||
d = to_d(x, sigma_hat, uncond_denoised)
|
||||
x = denoised + d * sigmas[i + 1]
|
||||
d = to_d(x, sigmas[i], uncond_denoised)
|
||||
x = denoised + d * sigma_down
|
||||
else:
|
||||
d = to_d(x, sigma_hat, denoised)
|
||||
dt = sigmas[i + 1] - sigma_hat
|
||||
d = to_d(x, sigmas[i], denoised)
|
||||
dt = sigma_down - sigmas[i]
|
||||
x = x + d * dt
|
||||
else:
|
||||
# Second order multistep method in https://arxiv.org/pdf/2308.02157
|
||||
t, t_next, t_prev = t_fn(sigmas[i]), t_fn(sigmas[i + 1]), t_fn(sigmas[i - 1])
|
||||
t, t_next, t_prev = t_fn(sigmas[i]), t_fn(sigma_down), t_fn(sigmas[i - 1])
|
||||
h = t_next - t
|
||||
c2 = (t_prev - t) / h
|
||||
|
||||
phi1_val, phi2_val = phi1_fn(-h), phi2_fn(-h)
|
||||
b1 = torch.nan_to_num(phi1_val - 1.0 / c2 * phi2_val, nan=0.0)
|
||||
b2 = torch.nan_to_num(1.0 / c2 * phi2_val, nan=0.0)
|
||||
b1 = torch.nan_to_num(phi1_val - phi2_val / c2, nan=0.0)
|
||||
b2 = torch.nan_to_num(phi2_val / c2, nan=0.0)
|
||||
|
||||
if cfg_pp:
|
||||
x = x + (denoised - uncond_denoised)
|
||||
x = sigma_fn(h) * x + h * (b1 * uncond_denoised + b2 * old_denoised)
|
||||
else:
|
||||
x = sigma_fn(h) * x + h * (b1 * denoised + b2 * old_denoised)
|
||||
|
||||
x = (sigma_fn(t_next) / sigma_fn(t)) * x + h * (b1 * denoised + b2 * old_denoised)
|
||||
# Noise addition
|
||||
if sigmas[i + 1] > 0:
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
||||
|
||||
old_denoised = denoised
|
||||
if cfg_pp:
|
||||
old_denoised = uncond_denoised
|
||||
else:
|
||||
old_denoised = denoised
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
|
||||
return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise, noise_sampler=noise_sampler, cfg_pp=False)
|
||||
def sample_res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None):
|
||||
return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=0., cfg_pp=False)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_res_multistep_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
|
||||
return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise, noise_sampler=noise_sampler, cfg_pp=True)
|
||||
def sample_res_multistep_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None):
|
||||
return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=0., cfg_pp=True)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_res_multistep_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=eta, cfg_pp=False)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_res_multistep_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=eta, cfg_pp=True)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2.):
|
||||
@@ -1359,3 +1366,157 @@ def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None,
|
||||
x = x + d_bar * dt
|
||||
old_d = d
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None, noise_scaler=None, max_stage=3):
|
||||
"""
|
||||
Extended Reverse-Time SDE solver (VE ER-SDE-Solver-3). Arxiv: https://arxiv.org/abs/2309.06169.
|
||||
Code reference: https://github.com/QinpengCui/ER-SDE-Solver/blob/main/er_sde_solver.py.
|
||||
"""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
def default_noise_scaler(sigma):
|
||||
return sigma * ((sigma ** 0.3).exp() + 10.0)
|
||||
noise_scaler = default_noise_scaler if noise_scaler is None else noise_scaler
|
||||
num_integration_points = 200.0
|
||||
point_indice = torch.arange(0, num_integration_points, dtype=torch.float32, device=x.device)
|
||||
|
||||
old_denoised = None
|
||||
old_denoised_d = None
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
stage_used = min(max_stage, i + 1)
|
||||
if sigmas[i + 1] == 0:
|
||||
x = denoised
|
||||
elif stage_used == 1:
|
||||
r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i])
|
||||
x = r * x + (1 - r) * denoised
|
||||
else:
|
||||
r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i])
|
||||
x = r * x + (1 - r) * denoised
|
||||
|
||||
dt = sigmas[i + 1] - sigmas[i]
|
||||
sigma_step_size = -dt / num_integration_points
|
||||
sigma_pos = sigmas[i + 1] + point_indice * sigma_step_size
|
||||
scaled_pos = noise_scaler(sigma_pos)
|
||||
|
||||
# Stage 2
|
||||
s = torch.sum(1 / scaled_pos) * sigma_step_size
|
||||
denoised_d = (denoised - old_denoised) / (sigmas[i] - sigmas[i - 1])
|
||||
x = x + (dt + s * noise_scaler(sigmas[i + 1])) * denoised_d
|
||||
|
||||
if stage_used >= 3:
|
||||
# Stage 3
|
||||
s_u = torch.sum((sigma_pos - sigmas[i]) / scaled_pos) * sigma_step_size
|
||||
denoised_u = (denoised_d - old_denoised_d) / ((sigmas[i] - sigmas[i - 2]) / 2)
|
||||
x = x + ((dt ** 2) / 2 + s_u * noise_scaler(sigmas[i + 1])) * denoised_u
|
||||
old_denoised_d = denoised_d
|
||||
|
||||
if s_noise != 0 and sigmas[i + 1] > 0:
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (sigmas[i + 1] ** 2 - sigmas[i] ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0)
|
||||
old_denoised = denoised
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
|
||||
'''
|
||||
SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VE Data Prediction) stage 2
|
||||
Arxiv: https://arxiv.org/abs/2305.14267
|
||||
'''
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
inject_noise = eta > 0 and s_noise > 0
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
if sigmas[i + 1] == 0:
|
||||
x = denoised
|
||||
else:
|
||||
t, t_next = -sigmas[i].log(), -sigmas[i + 1].log()
|
||||
h = t_next - t
|
||||
h_eta = h * (eta + 1)
|
||||
s = t + r * h
|
||||
fac = 1 / (2 * r)
|
||||
sigma_s = s.neg().exp()
|
||||
|
||||
coeff_1, coeff_2 = (-r * h_eta).expm1(), (-h_eta).expm1()
|
||||
if inject_noise:
|
||||
noise_coeff_1 = (-2 * r * h * eta).expm1().neg().sqrt()
|
||||
noise_coeff_2 = ((-2 * r * h * eta).expm1() - (-2 * h * eta).expm1()).sqrt()
|
||||
noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s), noise_sampler(sigma_s, sigmas[i + 1])
|
||||
|
||||
# Step 1
|
||||
x_2 = (coeff_1 + 1) * x - coeff_1 * denoised
|
||||
if inject_noise:
|
||||
x_2 = x_2 + sigma_s * (noise_coeff_1 * noise_1) * s_noise
|
||||
denoised_2 = model(x_2, sigma_s * s_in, **extra_args)
|
||||
|
||||
# Step 2
|
||||
denoised_d = (1 - fac) * denoised + fac * denoised_2
|
||||
x = (coeff_2 + 1) * x - coeff_2 * denoised_d
|
||||
if inject_noise:
|
||||
x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3):
|
||||
'''
|
||||
SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VE Data Prediction) stage 3
|
||||
Arxiv: https://arxiv.org/abs/2305.14267
|
||||
'''
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
inject_noise = eta > 0 and s_noise > 0
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
if sigmas[i + 1] == 0:
|
||||
x = denoised
|
||||
else:
|
||||
t, t_next = -sigmas[i].log(), -sigmas[i + 1].log()
|
||||
h = t_next - t
|
||||
h_eta = h * (eta + 1)
|
||||
s_1 = t + r_1 * h
|
||||
s_2 = t + r_2 * h
|
||||
sigma_s_1, sigma_s_2 = s_1.neg().exp(), s_2.neg().exp()
|
||||
|
||||
coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1()
|
||||
if inject_noise:
|
||||
noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt()
|
||||
noise_coeff_2 = ((-2 * r_1 * h * eta).expm1() - (-2 * r_2 * h * eta).expm1()).sqrt()
|
||||
noise_coeff_3 = ((-2 * r_2 * h * eta).expm1() - (-2 * h * eta).expm1()).sqrt()
|
||||
noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1])
|
||||
|
||||
# Step 1
|
||||
x_2 = (coeff_1 + 1) * x - coeff_1 * denoised
|
||||
if inject_noise:
|
||||
x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
|
||||
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
|
||||
|
||||
# Step 2
|
||||
x_3 = (coeff_2 + 1) * x - coeff_2 * denoised + (r_2 / r_1) * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised)
|
||||
if inject_noise:
|
||||
x_3 = x_3 + sigma_s_2 * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
|
||||
denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args)
|
||||
|
||||
# Step 3
|
||||
x = (coeff_3 + 1) * x - coeff_3 * denoised + (1. / r_2) * (coeff_3 / h_eta + 1) * (denoised_3 - denoised)
|
||||
if inject_noise:
|
||||
x = x + sigmas[i + 1] * (noise_coeff_3 * noise_1 + noise_coeff_2 * noise_2 + noise_coeff_1 * noise_3) * s_noise
|
||||
return x
|
||||
|
||||
@@ -407,3 +407,62 @@ class Cosmos1CV8x8x8(LatentFormat):
|
||||
]
|
||||
|
||||
latent_rgb_factors_bias = [-0.1223, -0.1889, -0.1976]
|
||||
|
||||
class Wan21(LatentFormat):
|
||||
latent_channels = 16
|
||||
latent_dimensions = 3
|
||||
|
||||
latent_rgb_factors = [
|
||||
[-0.1299, -0.1692, 0.2932],
|
||||
[ 0.0671, 0.0406, 0.0442],
|
||||
[ 0.3568, 0.2548, 0.1747],
|
||||
[ 0.0372, 0.2344, 0.1420],
|
||||
[ 0.0313, 0.0189, -0.0328],
|
||||
[ 0.0296, -0.0956, -0.0665],
|
||||
[-0.3477, -0.4059, -0.2925],
|
||||
[ 0.0166, 0.1902, 0.1975],
|
||||
[-0.0412, 0.0267, -0.1364],
|
||||
[-0.1293, 0.0740, 0.1636],
|
||||
[ 0.0680, 0.3019, 0.1128],
|
||||
[ 0.0032, 0.0581, 0.0639],
|
||||
[-0.1251, 0.0927, 0.1699],
|
||||
[ 0.0060, -0.0633, 0.0005],
|
||||
[ 0.3477, 0.2275, 0.2950],
|
||||
[ 0.1984, 0.0913, 0.1861]
|
||||
]
|
||||
|
||||
latent_rgb_factors_bias = [-0.1835, -0.0868, -0.3360]
|
||||
|
||||
def __init__(self):
|
||||
self.scale_factor = 1.0
|
||||
self.latents_mean = torch.tensor([
|
||||
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
|
||||
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
|
||||
]).view(1, self.latent_channels, 1, 1, 1)
|
||||
self.latents_std = torch.tensor([
|
||||
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
|
||||
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
|
||||
]).view(1, self.latent_channels, 1, 1, 1)
|
||||
|
||||
|
||||
self.taesd_decoder_name = None #TODO
|
||||
|
||||
def process_in(self, latent):
|
||||
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
|
||||
latents_std = self.latents_std.to(latent.device, latent.dtype)
|
||||
return (latent - latents_mean) * self.scale_factor / latents_std
|
||||
|
||||
def process_out(self, latent):
|
||||
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
|
||||
latents_std = self.latents_std.to(latent.device, latent.dtype)
|
||||
return latent * latents_std / self.scale_factor + latents_mean
|
||||
|
||||
class Hunyuan3Dv2(LatentFormat):
|
||||
latent_channels = 64
|
||||
latent_dimensions = 1
|
||||
scale_factor = 0.9990943042622529
|
||||
|
||||
class Hunyuan3Dv2mini(LatentFormat):
|
||||
latent_channels = 64
|
||||
latent_dimensions = 1
|
||||
scale_factor = 1.0188137142395404
|
||||
|
||||
@@ -19,6 +19,10 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.autograd import Function
|
||||
import comfy.ops
|
||||
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
|
||||
class vector_quantize(Function):
|
||||
@staticmethod
|
||||
@@ -121,15 +125,15 @@ class ResBlock(nn.Module):
|
||||
self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
|
||||
self.depthwise = nn.Sequential(
|
||||
nn.ReplicationPad2d(1),
|
||||
nn.Conv2d(c, c, kernel_size=3, groups=c)
|
||||
ops.Conv2d(c, c, kernel_size=3, groups=c)
|
||||
)
|
||||
|
||||
# channelwise
|
||||
self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
|
||||
self.channelwise = nn.Sequential(
|
||||
nn.Linear(c, c_hidden),
|
||||
ops.Linear(c, c_hidden),
|
||||
nn.GELU(),
|
||||
nn.Linear(c_hidden, c),
|
||||
ops.Linear(c_hidden, c),
|
||||
)
|
||||
|
||||
self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)
|
||||
@@ -171,16 +175,16 @@ class StageA(nn.Module):
|
||||
# Encoder blocks
|
||||
self.in_block = nn.Sequential(
|
||||
nn.PixelUnshuffle(2),
|
||||
nn.Conv2d(3 * 4, c_levels[0], kernel_size=1)
|
||||
ops.Conv2d(3 * 4, c_levels[0], kernel_size=1)
|
||||
)
|
||||
down_blocks = []
|
||||
for i in range(levels):
|
||||
if i > 0:
|
||||
down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
|
||||
down_blocks.append(ops.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
|
||||
block = ResBlock(c_levels[i], c_levels[i] * 4)
|
||||
down_blocks.append(block)
|
||||
down_blocks.append(nn.Sequential(
|
||||
nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False),
|
||||
ops.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False),
|
||||
nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1
|
||||
))
|
||||
self.down_blocks = nn.Sequential(*down_blocks)
|
||||
@@ -191,7 +195,7 @@ class StageA(nn.Module):
|
||||
|
||||
# Decoder blocks
|
||||
up_blocks = [nn.Sequential(
|
||||
nn.Conv2d(c_latent, c_levels[-1], kernel_size=1)
|
||||
ops.Conv2d(c_latent, c_levels[-1], kernel_size=1)
|
||||
)]
|
||||
for i in range(levels):
|
||||
for j in range(bottleneck_blocks if i == 0 else 1):
|
||||
@@ -199,11 +203,11 @@ class StageA(nn.Module):
|
||||
up_blocks.append(block)
|
||||
if i < levels - 1:
|
||||
up_blocks.append(
|
||||
nn.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2,
|
||||
ops.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2,
|
||||
padding=1))
|
||||
self.up_blocks = nn.Sequential(*up_blocks)
|
||||
self.out_block = nn.Sequential(
|
||||
nn.Conv2d(c_levels[0], 3 * 4, kernel_size=1),
|
||||
ops.Conv2d(c_levels[0], 3 * 4, kernel_size=1),
|
||||
nn.PixelShuffle(2),
|
||||
)
|
||||
|
||||
@@ -232,17 +236,17 @@ class Discriminator(nn.Module):
|
||||
super().__init__()
|
||||
d = max(depth - 3, 3)
|
||||
layers = [
|
||||
nn.utils.spectral_norm(nn.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)),
|
||||
nn.utils.spectral_norm(ops.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)),
|
||||
nn.LeakyReLU(0.2),
|
||||
]
|
||||
for i in range(depth - 1):
|
||||
c_in = c_hidden // (2 ** max((d - i), 0))
|
||||
c_out = c_hidden // (2 ** max((d - 1 - i), 0))
|
||||
layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))
|
||||
layers.append(nn.utils.spectral_norm(ops.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))
|
||||
layers.append(nn.InstanceNorm2d(c_out))
|
||||
layers.append(nn.LeakyReLU(0.2))
|
||||
self.encoder = nn.Sequential(*layers)
|
||||
self.shuffle = nn.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1)
|
||||
self.shuffle = ops.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1)
|
||||
self.logits = nn.Sigmoid()
|
||||
|
||||
def forward(self, x, cond=None):
|
||||
|
||||
@@ -19,6 +19,9 @@ import torch
|
||||
import torchvision
|
||||
from torch import nn
|
||||
|
||||
import comfy.ops
|
||||
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
# EfficientNet
|
||||
class EfficientNetEncoder(nn.Module):
|
||||
@@ -26,7 +29,7 @@ class EfficientNetEncoder(nn.Module):
|
||||
super().__init__()
|
||||
self.backbone = torchvision.models.efficientnet_v2_s().features.eval()
|
||||
self.mapper = nn.Sequential(
|
||||
nn.Conv2d(1280, c_latent, kernel_size=1, bias=False),
|
||||
ops.Conv2d(1280, c_latent, kernel_size=1, bias=False),
|
||||
nn.BatchNorm2d(c_latent, affine=False), # then normalize them to have mean 0 and std 1
|
||||
)
|
||||
self.mean = nn.Parameter(torch.tensor([0.485, 0.456, 0.406]))
|
||||
@@ -34,7 +37,7 @@ class EfficientNetEncoder(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
x = x * 0.5 + 0.5
|
||||
x = (x - self.mean.view([3,1,1])) / self.std.view([3,1,1])
|
||||
x = (x - self.mean.view([3,1,1]).to(device=x.device, dtype=x.dtype)) / self.std.view([3,1,1]).to(device=x.device, dtype=x.dtype)
|
||||
o = self.mapper(self.backbone(x))
|
||||
return o
|
||||
|
||||
@@ -44,39 +47,39 @@ class Previewer(nn.Module):
|
||||
def __init__(self, c_in=16, c_hidden=512, c_out=3):
|
||||
super().__init__()
|
||||
self.blocks = nn.Sequential(
|
||||
nn.Conv2d(c_in, c_hidden, kernel_size=1), # 16 channels to 512 channels
|
||||
ops.Conv2d(c_in, c_hidden, kernel_size=1), # 16 channels to 512 channels
|
||||
nn.GELU(),
|
||||
nn.BatchNorm2d(c_hidden),
|
||||
|
||||
nn.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1),
|
||||
ops.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1),
|
||||
nn.GELU(),
|
||||
nn.BatchNorm2d(c_hidden),
|
||||
|
||||
nn.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2), # 16 -> 32
|
||||
ops.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2), # 16 -> 32
|
||||
nn.GELU(),
|
||||
nn.BatchNorm2d(c_hidden // 2),
|
||||
|
||||
nn.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1),
|
||||
ops.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1),
|
||||
nn.GELU(),
|
||||
nn.BatchNorm2d(c_hidden // 2),
|
||||
|
||||
nn.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2), # 32 -> 64
|
||||
ops.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2), # 32 -> 64
|
||||
nn.GELU(),
|
||||
nn.BatchNorm2d(c_hidden // 4),
|
||||
|
||||
nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
|
||||
ops.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
|
||||
nn.GELU(),
|
||||
nn.BatchNorm2d(c_hidden // 4),
|
||||
|
||||
nn.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2), # 64 -> 128
|
||||
ops.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2), # 64 -> 128
|
||||
nn.GELU(),
|
||||
nn.BatchNorm2d(c_hidden // 4),
|
||||
|
||||
nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
|
||||
ops.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
|
||||
nn.GELU(),
|
||||
nn.BatchNorm2d(c_hidden // 4),
|
||||
|
||||
nn.Conv2d(c_hidden // 4, c_out, kernel_size=1),
|
||||
ops.Conv2d(c_hidden // 4, c_out, kernel_size=1),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import torch
|
||||
import comfy.ops
|
||||
import comfy.rmsnorm
|
||||
|
||||
|
||||
def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
|
||||
if padding_mode == "circular" and (torch.jit.is_tracing() or torch.jit.is_scripting()):
|
||||
@@ -11,20 +12,5 @@ def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
|
||||
|
||||
return torch.nn.functional.pad(img, pad, mode=padding_mode)
|
||||
|
||||
try:
|
||||
rms_norm_torch = torch.nn.functional.rms_norm
|
||||
except:
|
||||
rms_norm_torch = None
|
||||
|
||||
def rms_norm(x, weight=None, eps=1e-6):
|
||||
if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
|
||||
if weight is None:
|
||||
return rms_norm_torch(x, (x.shape[-1],), eps=eps)
|
||||
else:
|
||||
return rms_norm_torch(x, weight.shape, weight=comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
|
||||
else:
|
||||
r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
|
||||
if weight is None:
|
||||
return r
|
||||
else:
|
||||
return r * comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device)
|
||||
rms_norm = comfy.rmsnorm.rms_norm
|
||||
|
||||
@@ -105,7 +105,9 @@ class Modulation(nn.Module):
|
||||
self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, vec: Tensor) -> tuple:
|
||||
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
|
||||
if vec.ndim == 2:
|
||||
vec = vec[:, None, :]
|
||||
out = self.lin(nn.functional.silu(vec)).chunk(self.multiplier, dim=-1)
|
||||
|
||||
return (
|
||||
ModulationOut(*out[:3]),
|
||||
@@ -113,6 +115,20 @@ class Modulation(nn.Module):
|
||||
)
|
||||
|
||||
|
||||
def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
|
||||
if modulation_dims is None:
|
||||
if m_add is not None:
|
||||
return tensor * m_mult + m_add
|
||||
else:
|
||||
return tensor * m_mult
|
||||
else:
|
||||
for d in modulation_dims:
|
||||
tensor[:, d[0]:d[1]] *= m_mult[:, d[2]]
|
||||
if m_add is not None:
|
||||
tensor[:, d[0]:d[1]] += m_add[:, d[2]]
|
||||
return tensor
|
||||
|
||||
|
||||
class DoubleStreamBlock(nn.Module):
|
||||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
@@ -143,20 +159,20 @@ class DoubleStreamBlock(nn.Module):
|
||||
)
|
||||
self.flipped_img_txt = flipped_img_txt
|
||||
|
||||
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None):
|
||||
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None):
|
||||
img_mod1, img_mod2 = self.img_mod(vec)
|
||||
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
||||
|
||||
# prepare image for attention
|
||||
img_modulated = self.img_norm1(img)
|
||||
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
|
||||
img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
|
||||
img_qkv = self.img_attn.qkv(img_modulated)
|
||||
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
|
||||
|
||||
# prepare txt for attention
|
||||
txt_modulated = self.txt_norm1(txt)
|
||||
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
|
||||
txt_modulated = apply_mod(txt_modulated, (1 + txt_mod1.scale), txt_mod1.shift, modulation_dims_txt)
|
||||
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
||||
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
||||
@@ -179,12 +195,12 @@ class DoubleStreamBlock(nn.Module):
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
|
||||
|
||||
# calculate the img bloks
|
||||
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
|
||||
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
|
||||
img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
|
||||
img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)
|
||||
|
||||
# calculate the txt bloks
|
||||
txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
|
||||
txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
|
||||
txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt)
|
||||
txt += apply_mod(self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims_txt)), txt_mod2.gate, None, modulation_dims_txt)
|
||||
|
||||
if txt.dtype == torch.float16:
|
||||
txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
|
||||
@@ -228,9 +244,9 @@ class SingleStreamBlock(nn.Module):
|
||||
self.mlp_act = nn.GELU(approximate="tanh")
|
||||
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None) -> Tensor:
|
||||
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None) -> Tensor:
|
||||
mod, _ = self.modulation(vec)
|
||||
qkv, mlp = torch.split(self.linear1((1 + mod.scale) * self.pre_norm(x) + mod.shift), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||
|
||||
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
q, k = self.norm(q, k, v)
|
||||
@@ -239,7 +255,7 @@ class SingleStreamBlock(nn.Module):
|
||||
attn = attention(q, k, v, pe=pe, mask=attn_mask)
|
||||
# compute activation in mlp stream, cat again and run second linear layer
|
||||
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
||||
x += mod.gate * output
|
||||
x += apply_mod(output, mod.gate, None, modulation_dims)
|
||||
if x.dtype == torch.float16:
|
||||
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
|
||||
return x
|
||||
@@ -252,8 +268,11 @@ class LastLayer(nn.Module):
|
||||
self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
|
||||
self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device))
|
||||
|
||||
def forward(self, x: Tensor, vec: Tensor) -> Tensor:
|
||||
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
|
||||
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
|
||||
def forward(self, x: Tensor, vec: Tensor, modulation_dims=None) -> Tensor:
|
||||
if vec.ndim == 2:
|
||||
vec = vec[:, None, :]
|
||||
|
||||
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=-1)
|
||||
x = apply_mod(self.norm_final(x), (1 + scale), shift, modulation_dims)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
@@ -10,10 +10,11 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
|
||||
q_shape = q.shape
|
||||
k_shape = k.shape
|
||||
|
||||
q = q.float().reshape(*q.shape[:-1], -1, 1, 2)
|
||||
k = k.float().reshape(*k.shape[:-1], -1, 1, 2)
|
||||
q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
|
||||
k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
|
||||
if pe is not None:
|
||||
q = q.to(dtype=pe.dtype).reshape(*q.shape[:-1], -1, 1, 2)
|
||||
k = k.to(dtype=pe.dtype).reshape(*k.shape[:-1], -1, 1, 2)
|
||||
q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
|
||||
k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
|
||||
|
||||
heads = q.shape[1]
|
||||
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
|
||||
@@ -22,7 +23,7 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
|
||||
|
||||
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
||||
assert dim % 2 == 0
|
||||
if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu():
|
||||
if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu() or comfy.model_management.is_directml_enabled():
|
||||
device = torch.device("cpu")
|
||||
else:
|
||||
device = pos.device
|
||||
@@ -36,8 +37,8 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
||||
|
||||
|
||||
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
|
||||
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
|
||||
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
|
||||
xq_ = xq.to(dtype=freqs_cis.dtype).reshape(*xq.shape[:-1], -1, 1, 2)
|
||||
xk_ = xk.to(dtype=freqs_cis.dtype).reshape(*xk.shape[:-1], -1, 1, 2)
|
||||
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
||||
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
||||
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
||||
|
||||
@@ -115,8 +115,11 @@ class Flux(nn.Module):
|
||||
vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
|
||||
txt = self.txt_in(txt)
|
||||
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
if img_ids is not None:
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
else:
|
||||
pe = None
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
for i, block in enumerate(self.double_blocks):
|
||||
|
||||
828
comfy/ldm/hidream/model.py
Normal file
828
comfy/ldm/hidream/model.py
Normal file
@@ -0,0 +1,828 @@
|
||||
from typing import Optional, Tuple, List
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import einops
|
||||
from einops import repeat
|
||||
|
||||
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
|
||||
import torch.nn.functional as F
|
||||
|
||||
from comfy.ldm.flux.math import apply_rope
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
import comfy.model_management
|
||||
|
||||
# Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py
|
||||
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
|
||||
assert dim % 2 == 0, "The dimension must be even."
|
||||
|
||||
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
|
||||
omega = 1.0 / (theta**scale)
|
||||
|
||||
batch_size, seq_length = pos.shape
|
||||
out = torch.einsum("...n,d->...nd", pos, omega)
|
||||
cos_out = torch.cos(out)
|
||||
sin_out = torch.sin(out)
|
||||
|
||||
stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
|
||||
out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
|
||||
return out.float()
|
||||
|
||||
|
||||
# Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
|
||||
class EmbedND(nn.Module):
|
||||
def __init__(self, theta: int, axes_dim: List[int]):
|
||||
super().__init__()
|
||||
self.theta = theta
|
||||
self.axes_dim = axes_dim
|
||||
|
||||
def forward(self, ids: torch.Tensor) -> torch.Tensor:
|
||||
n_axes = ids.shape[-1]
|
||||
emb = torch.cat(
|
||||
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
|
||||
dim=-3,
|
||||
)
|
||||
return emb.unsqueeze(2)
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
patch_size=2,
|
||||
in_channels=4,
|
||||
out_channels=1024,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.out_channels = out_channels
|
||||
self.proj = operations.Linear(in_channels * patch_size * patch_size, out_channels, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, latent):
|
||||
latent = self.proj(latent)
|
||||
return latent
|
||||
|
||||
|
||||
class PooledEmbed(nn.Module):
|
||||
def __init__(self, text_emb_dim, hidden_size, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.pooled_embedder = TimestepEmbedding(in_channels=text_emb_dim, time_embed_dim=hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(self, pooled_embed):
|
||||
return self.pooled_embedder(pooled_embed)
|
||||
|
||||
|
||||
class TimestepEmbed(nn.Module):
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(self, timesteps, wdtype):
|
||||
t_emb = self.time_proj(timesteps).to(dtype=wdtype)
|
||||
t_emb = self.timestep_embedder(t_emb)
|
||||
return t_emb
|
||||
|
||||
|
||||
class OutEmbed(nn.Module):
|
||||
def __init__(self, hidden_size, patch_size, out_channels, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
def forward(self, x, adaln_input):
|
||||
shift, scale = self.adaLN_modulation(adaln_input).chunk(2, dim=1)
|
||||
x = self.norm_final(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
|
||||
return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2])
|
||||
|
||||
|
||||
class HiDreamAttnProcessor_flashattn:
|
||||
"""Attention processor used typically in processing the SD3-like self-attention projections."""
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn,
|
||||
image_tokens: torch.FloatTensor,
|
||||
image_tokens_masks: Optional[torch.FloatTensor] = None,
|
||||
text_tokens: Optional[torch.FloatTensor] = None,
|
||||
rope: torch.FloatTensor = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
dtype = image_tokens.dtype
|
||||
batch_size = image_tokens.shape[0]
|
||||
|
||||
query_i = attn.q_rms_norm(attn.to_q(image_tokens)).to(dtype=dtype)
|
||||
key_i = attn.k_rms_norm(attn.to_k(image_tokens)).to(dtype=dtype)
|
||||
value_i = attn.to_v(image_tokens)
|
||||
|
||||
inner_dim = key_i.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query_i = query_i.view(batch_size, -1, attn.heads, head_dim)
|
||||
key_i = key_i.view(batch_size, -1, attn.heads, head_dim)
|
||||
value_i = value_i.view(batch_size, -1, attn.heads, head_dim)
|
||||
if image_tokens_masks is not None:
|
||||
key_i = key_i * image_tokens_masks.view(batch_size, -1, 1, 1)
|
||||
|
||||
if not attn.single:
|
||||
query_t = attn.q_rms_norm_t(attn.to_q_t(text_tokens)).to(dtype=dtype)
|
||||
key_t = attn.k_rms_norm_t(attn.to_k_t(text_tokens)).to(dtype=dtype)
|
||||
value_t = attn.to_v_t(text_tokens)
|
||||
|
||||
query_t = query_t.view(batch_size, -1, attn.heads, head_dim)
|
||||
key_t = key_t.view(batch_size, -1, attn.heads, head_dim)
|
||||
value_t = value_t.view(batch_size, -1, attn.heads, head_dim)
|
||||
|
||||
num_image_tokens = query_i.shape[1]
|
||||
num_text_tokens = query_t.shape[1]
|
||||
query = torch.cat([query_i, query_t], dim=1)
|
||||
key = torch.cat([key_i, key_t], dim=1)
|
||||
value = torch.cat([value_i, value_t], dim=1)
|
||||
else:
|
||||
query = query_i
|
||||
key = key_i
|
||||
value = value_i
|
||||
|
||||
if query.shape[-1] == rope.shape[-3] * 2:
|
||||
query, key = apply_rope(query, key, rope)
|
||||
else:
|
||||
query_1, query_2 = query.chunk(2, dim=-1)
|
||||
key_1, key_2 = key.chunk(2, dim=-1)
|
||||
query_1, key_1 = apply_rope(query_1, key_1, rope)
|
||||
query = torch.cat([query_1, query_2], dim=-1)
|
||||
key = torch.cat([key_1, key_2], dim=-1)
|
||||
|
||||
hidden_states = attention(query, key, value)
|
||||
|
||||
if not attn.single:
|
||||
hidden_states_i, hidden_states_t = torch.split(hidden_states, [num_image_tokens, num_text_tokens], dim=1)
|
||||
hidden_states_i = attn.to_out(hidden_states_i)
|
||||
hidden_states_t = attn.to_out_t(hidden_states_t)
|
||||
return hidden_states_i, hidden_states_t
|
||||
else:
|
||||
hidden_states = attn.to_out(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
class HiDreamAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
query_dim: int,
|
||||
heads: int = 8,
|
||||
dim_head: int = 64,
|
||||
upcast_attention: bool = False,
|
||||
upcast_softmax: bool = False,
|
||||
scale_qk: bool = True,
|
||||
eps: float = 1e-5,
|
||||
processor = None,
|
||||
out_dim: int = None,
|
||||
single: bool = False,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
# super(Attention, self).__init__()
|
||||
super().__init__()
|
||||
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
|
||||
self.query_dim = query_dim
|
||||
self.upcast_attention = upcast_attention
|
||||
self.upcast_softmax = upcast_softmax
|
||||
self.out_dim = out_dim if out_dim is not None else query_dim
|
||||
|
||||
self.scale_qk = scale_qk
|
||||
self.scale = dim_head**-0.5 if self.scale_qk else 1.0
|
||||
|
||||
self.heads = out_dim // dim_head if out_dim is not None else heads
|
||||
self.sliceable_head_dim = heads
|
||||
self.single = single
|
||||
|
||||
linear_cls = operations.Linear
|
||||
self.linear_cls = linear_cls
|
||||
self.to_q = linear_cls(query_dim, self.inner_dim, dtype=dtype, device=device)
|
||||
self.to_k = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device)
|
||||
self.to_v = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device)
|
||||
self.to_out = linear_cls(self.inner_dim, self.out_dim, dtype=dtype, device=device)
|
||||
self.q_rms_norm = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device)
|
||||
self.k_rms_norm = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device)
|
||||
|
||||
if not single:
|
||||
self.to_q_t = linear_cls(query_dim, self.inner_dim, dtype=dtype, device=device)
|
||||
self.to_k_t = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device)
|
||||
self.to_v_t = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device)
|
||||
self.to_out_t = linear_cls(self.inner_dim, self.out_dim, dtype=dtype, device=device)
|
||||
self.q_rms_norm_t = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device)
|
||||
self.k_rms_norm_t = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device)
|
||||
|
||||
self.processor = processor
|
||||
|
||||
def forward(
|
||||
self,
|
||||
norm_image_tokens: torch.FloatTensor,
|
||||
image_tokens_masks: torch.FloatTensor = None,
|
||||
norm_text_tokens: torch.FloatTensor = None,
|
||||
rope: torch.FloatTensor = None,
|
||||
) -> torch.Tensor:
|
||||
return self.processor(
|
||||
self,
|
||||
image_tokens = norm_image_tokens,
|
||||
image_tokens_masks = image_tokens_masks,
|
||||
text_tokens = norm_text_tokens,
|
||||
rope = rope,
|
||||
)
|
||||
|
||||
|
||||
class FeedForwardSwiGLU(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
multiple_of: int = 256,
|
||||
ffn_dim_multiplier: Optional[float] = None,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
hidden_dim = int(2 * hidden_dim / 3)
|
||||
# custom dim factor multiplier
|
||||
if ffn_dim_multiplier is not None:
|
||||
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
|
||||
hidden_dim = multiple_of * (
|
||||
(hidden_dim + multiple_of - 1) // multiple_of
|
||||
)
|
||||
|
||||
self.w1 = operations.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device)
|
||||
self.w2 = operations.Linear(hidden_dim, dim, bias=False, dtype=dtype, device=device)
|
||||
self.w3 = operations.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
return self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
|
||||
|
||||
|
||||
# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
|
||||
class MoEGate(nn.Module):
|
||||
def __init__(self, embed_dim, num_routed_experts=4, num_activated_experts=2, aux_loss_alpha=0.01, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.top_k = num_activated_experts
|
||||
self.n_routed_experts = num_routed_experts
|
||||
|
||||
self.scoring_func = 'softmax'
|
||||
self.alpha = aux_loss_alpha
|
||||
self.seq_aux = False
|
||||
|
||||
# topk selection algorithm
|
||||
self.norm_topk_prob = False
|
||||
self.gating_dim = embed_dim
|
||||
self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim), dtype=dtype, device=device))
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self) -> None:
|
||||
pass
|
||||
# import torch.nn.init as init
|
||||
# init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
||||
|
||||
def forward(self, hidden_states):
|
||||
bsz, seq_len, h = hidden_states.shape
|
||||
|
||||
### compute gating score
|
||||
hidden_states = hidden_states.view(-1, h)
|
||||
logits = F.linear(hidden_states, comfy.model_management.cast_to(self.weight, dtype=hidden_states.dtype, device=hidden_states.device), None)
|
||||
if self.scoring_func == 'softmax':
|
||||
scores = logits.softmax(dim=-1)
|
||||
else:
|
||||
raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}')
|
||||
|
||||
### select top-k experts
|
||||
topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
|
||||
|
||||
### norm gate to sum 1
|
||||
if self.top_k > 1 and self.norm_topk_prob:
|
||||
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
|
||||
topk_weight = topk_weight / denominator
|
||||
|
||||
aux_loss = None
|
||||
return topk_idx, topk_weight, aux_loss
|
||||
|
||||
|
||||
# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
|
||||
class MOEFeedForwardSwiGLU(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
num_routed_experts: int,
|
||||
num_activated_experts: int,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.shared_experts = FeedForwardSwiGLU(dim, hidden_dim // 2, dtype=dtype, device=device, operations=operations)
|
||||
self.experts = nn.ModuleList([FeedForwardSwiGLU(dim, hidden_dim, dtype=dtype, device=device, operations=operations) for i in range(num_routed_experts)])
|
||||
self.gate = MoEGate(
|
||||
embed_dim = dim,
|
||||
num_routed_experts = num_routed_experts,
|
||||
num_activated_experts = num_activated_experts,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
self.num_activated_experts = num_activated_experts
|
||||
|
||||
def forward(self, x):
|
||||
wtype = x.dtype
|
||||
identity = x
|
||||
orig_shape = x.shape
|
||||
topk_idx, topk_weight, aux_loss = self.gate(x)
|
||||
x = x.view(-1, x.shape[-1])
|
||||
flat_topk_idx = topk_idx.view(-1)
|
||||
if True: # self.training: # TODO: check which branch performs faster
|
||||
x = x.repeat_interleave(self.num_activated_experts, dim=0)
|
||||
y = torch.empty_like(x, dtype=wtype)
|
||||
for i, expert in enumerate(self.experts):
|
||||
y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(dtype=wtype)
|
||||
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
|
||||
y = y.view(*orig_shape).to(dtype=wtype)
|
||||
#y = AddAuxiliaryLoss.apply(y, aux_loss)
|
||||
else:
|
||||
y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
|
||||
y = y + self.shared_experts(identity)
|
||||
return y
|
||||
|
||||
@torch.no_grad()
|
||||
def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
|
||||
expert_cache = torch.zeros_like(x)
|
||||
idxs = flat_expert_indices.argsort()
|
||||
tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
|
||||
token_idxs = idxs // self.num_activated_experts
|
||||
for i, end_idx in enumerate(tokens_per_expert):
|
||||
start_idx = 0 if i == 0 else tokens_per_expert[i-1]
|
||||
if start_idx == end_idx:
|
||||
continue
|
||||
expert = self.experts[i]
|
||||
exp_token_idx = token_idxs[start_idx:end_idx]
|
||||
expert_tokens = x[exp_token_idx]
|
||||
expert_out = expert(expert_tokens)
|
||||
expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
|
||||
|
||||
# for fp16 and other dtype
|
||||
expert_cache = expert_cache.to(expert_out.dtype)
|
||||
expert_cache.scatter_reduce_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out, reduce='sum')
|
||||
return expert_cache
|
||||
|
||||
|
||||
class TextProjection(nn.Module):
|
||||
def __init__(self, in_features, hidden_size, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.linear = operations.Linear(in_features=in_features, out_features=hidden_size, bias=False, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, caption):
|
||||
hidden_states = self.linear(caption)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BlockType:
|
||||
TransformerBlock = 1
|
||||
SingleTransformerBlock = 2
|
||||
|
||||
|
||||
class HiDreamImageSingleTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
num_routed_experts: int = 4,
|
||||
num_activated_experts: int = 2,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
# 1. Attention
|
||||
self.norm1_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
|
||||
self.attn1 = HiDreamAttention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
processor = HiDreamAttnProcessor_flashattn(),
|
||||
single = True,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
# 3. Feed-forward
|
||||
self.norm3_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
|
||||
if num_routed_experts > 0:
|
||||
self.ff_i = MOEFeedForwardSwiGLU(
|
||||
dim = dim,
|
||||
hidden_dim = 4 * dim,
|
||||
num_routed_experts = num_routed_experts,
|
||||
num_activated_experts = num_activated_experts,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
else:
|
||||
self.ff_i = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
image_tokens: torch.FloatTensor,
|
||||
image_tokens_masks: Optional[torch.FloatTensor] = None,
|
||||
text_tokens: Optional[torch.FloatTensor] = None,
|
||||
adaln_input: Optional[torch.FloatTensor] = None,
|
||||
rope: torch.FloatTensor = None,
|
||||
|
||||
) -> torch.FloatTensor:
|
||||
wtype = image_tokens.dtype
|
||||
shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = \
|
||||
self.adaLN_modulation(adaln_input)[:,None].chunk(6, dim=-1)
|
||||
|
||||
# 1. MM-Attention
|
||||
norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype)
|
||||
norm_image_tokens = norm_image_tokens * (1 + scale_msa_i) + shift_msa_i
|
||||
attn_output_i = self.attn1(
|
||||
norm_image_tokens,
|
||||
image_tokens_masks,
|
||||
rope = rope,
|
||||
)
|
||||
image_tokens = gate_msa_i * attn_output_i + image_tokens
|
||||
|
||||
# 2. Feed-forward
|
||||
norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype)
|
||||
norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i) + shift_mlp_i
|
||||
ff_output_i = gate_mlp_i * self.ff_i(norm_image_tokens.to(dtype=wtype))
|
||||
image_tokens = ff_output_i + image_tokens
|
||||
return image_tokens
|
||||
|
||||
|
||||
class HiDreamImageTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
num_routed_experts: int = 4,
|
||||
num_activated_experts: int = 2,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(dim, 12 * dim, bias=True, dtype=dtype, device=device)
|
||||
)
|
||||
# nn.init.zeros_(self.adaLN_modulation[1].weight)
|
||||
# nn.init.zeros_(self.adaLN_modulation[1].bias)
|
||||
|
||||
# 1. Attention
|
||||
self.norm1_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
|
||||
self.norm1_t = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
|
||||
self.attn1 = HiDreamAttention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
processor = HiDreamAttnProcessor_flashattn(),
|
||||
single = False,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
# 3. Feed-forward
|
||||
self.norm3_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
|
||||
if num_routed_experts > 0:
|
||||
self.ff_i = MOEFeedForwardSwiGLU(
|
||||
dim = dim,
|
||||
hidden_dim = 4 * dim,
|
||||
num_routed_experts = num_routed_experts,
|
||||
num_activated_experts = num_activated_experts,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
else:
|
||||
self.ff_i = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim, dtype=dtype, device=device, operations=operations)
|
||||
self.norm3_t = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False)
|
||||
self.ff_t = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
image_tokens: torch.FloatTensor,
|
||||
image_tokens_masks: Optional[torch.FloatTensor] = None,
|
||||
text_tokens: Optional[torch.FloatTensor] = None,
|
||||
adaln_input: Optional[torch.FloatTensor] = None,
|
||||
rope: torch.FloatTensor = None,
|
||||
) -> torch.FloatTensor:
|
||||
wtype = image_tokens.dtype
|
||||
shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i, \
|
||||
shift_msa_t, scale_msa_t, gate_msa_t, shift_mlp_t, scale_mlp_t, gate_mlp_t = \
|
||||
self.adaLN_modulation(adaln_input)[:,None].chunk(12, dim=-1)
|
||||
|
||||
# 1. MM-Attention
|
||||
norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype)
|
||||
norm_image_tokens = norm_image_tokens * (1 + scale_msa_i) + shift_msa_i
|
||||
norm_text_tokens = self.norm1_t(text_tokens).to(dtype=wtype)
|
||||
norm_text_tokens = norm_text_tokens * (1 + scale_msa_t) + shift_msa_t
|
||||
|
||||
attn_output_i, attn_output_t = self.attn1(
|
||||
norm_image_tokens,
|
||||
image_tokens_masks,
|
||||
norm_text_tokens,
|
||||
rope = rope,
|
||||
)
|
||||
|
||||
image_tokens = gate_msa_i * attn_output_i + image_tokens
|
||||
text_tokens = gate_msa_t * attn_output_t + text_tokens
|
||||
|
||||
# 2. Feed-forward
|
||||
norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype)
|
||||
norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i) + shift_mlp_i
|
||||
norm_text_tokens = self.norm3_t(text_tokens).to(dtype=wtype)
|
||||
norm_text_tokens = norm_text_tokens * (1 + scale_mlp_t) + shift_mlp_t
|
||||
|
||||
ff_output_i = gate_mlp_i * self.ff_i(norm_image_tokens)
|
||||
ff_output_t = gate_mlp_t * self.ff_t(norm_text_tokens)
|
||||
image_tokens = ff_output_i + image_tokens
|
||||
text_tokens = ff_output_t + text_tokens
|
||||
return image_tokens, text_tokens
|
||||
|
||||
|
||||
class HiDreamImageBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
num_routed_experts: int = 4,
|
||||
num_activated_experts: int = 2,
|
||||
block_type: BlockType = BlockType.TransformerBlock,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
block_classes = {
|
||||
BlockType.TransformerBlock: HiDreamImageTransformerBlock,
|
||||
BlockType.SingleTransformerBlock: HiDreamImageSingleTransformerBlock,
|
||||
}
|
||||
self.block = block_classes[block_type](
|
||||
dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
num_routed_experts,
|
||||
num_activated_experts,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
image_tokens: torch.FloatTensor,
|
||||
image_tokens_masks: Optional[torch.FloatTensor] = None,
|
||||
text_tokens: Optional[torch.FloatTensor] = None,
|
||||
adaln_input: torch.FloatTensor = None,
|
||||
rope: torch.FloatTensor = None,
|
||||
) -> torch.FloatTensor:
|
||||
return self.block(
|
||||
image_tokens,
|
||||
image_tokens_masks,
|
||||
text_tokens,
|
||||
adaln_input,
|
||||
rope,
|
||||
)
|
||||
|
||||
|
||||
class HiDreamImageTransformer2DModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: Optional[int] = None,
|
||||
in_channels: int = 64,
|
||||
out_channels: Optional[int] = None,
|
||||
num_layers: int = 16,
|
||||
num_single_layers: int = 32,
|
||||
attention_head_dim: int = 128,
|
||||
num_attention_heads: int = 20,
|
||||
caption_channels: List[int] = None,
|
||||
text_emb_dim: int = 2048,
|
||||
num_routed_experts: int = 4,
|
||||
num_activated_experts: int = 2,
|
||||
axes_dims_rope: Tuple[int, int] = (32, 32),
|
||||
max_resolution: Tuple[int, int] = (128, 128),
|
||||
llama_layers: List[int] = None,
|
||||
image_model=None,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
self.patch_size = patch_size
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
self.num_layers = num_layers
|
||||
self.num_single_layers = num_single_layers
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
self.out_channels = out_channels or in_channels
|
||||
self.inner_dim = self.num_attention_heads * self.attention_head_dim
|
||||
self.llama_layers = llama_layers
|
||||
|
||||
self.t_embedder = TimestepEmbed(self.inner_dim, dtype=dtype, device=device, operations=operations)
|
||||
self.p_embedder = PooledEmbed(text_emb_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
|
||||
self.x_embedder = PatchEmbed(
|
||||
patch_size = patch_size,
|
||||
in_channels = in_channels,
|
||||
out_channels = self.inner_dim,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
self.pe_embedder = EmbedND(theta=10000, axes_dim=axes_dims_rope)
|
||||
|
||||
self.double_stream_blocks = nn.ModuleList(
|
||||
[
|
||||
HiDreamImageBlock(
|
||||
dim = self.inner_dim,
|
||||
num_attention_heads = self.num_attention_heads,
|
||||
attention_head_dim = self.attention_head_dim,
|
||||
num_routed_experts = num_routed_experts,
|
||||
num_activated_experts = num_activated_experts,
|
||||
block_type = BlockType.TransformerBlock,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
for i in range(self.num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.single_stream_blocks = nn.ModuleList(
|
||||
[
|
||||
HiDreamImageBlock(
|
||||
dim = self.inner_dim,
|
||||
num_attention_heads = self.num_attention_heads,
|
||||
attention_head_dim = self.attention_head_dim,
|
||||
num_routed_experts = num_routed_experts,
|
||||
num_activated_experts = num_activated_experts,
|
||||
block_type = BlockType.SingleTransformerBlock,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
for i in range(self.num_single_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.final_layer = OutEmbed(self.inner_dim, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
caption_channels = [caption_channels[1], ] * (num_layers + num_single_layers) + [caption_channels[0], ]
|
||||
caption_projection = []
|
||||
for caption_channel in caption_channels:
|
||||
caption_projection.append(TextProjection(in_features=caption_channel, hidden_size=self.inner_dim, dtype=dtype, device=device, operations=operations))
|
||||
self.caption_projection = nn.ModuleList(caption_projection)
|
||||
self.max_seq = max_resolution[0] * max_resolution[1] // (patch_size * patch_size)
|
||||
|
||||
def expand_timesteps(self, timesteps, batch_size, device):
|
||||
if not torch.is_tensor(timesteps):
|
||||
is_mps = device.type == "mps"
|
||||
if isinstance(timesteps, float):
|
||||
dtype = torch.float32 if is_mps else torch.float64
|
||||
else:
|
||||
dtype = torch.int32 if is_mps else torch.int64
|
||||
timesteps = torch.tensor([timesteps], dtype=dtype, device=device)
|
||||
elif len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(device)
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timesteps = timesteps.expand(batch_size)
|
||||
return timesteps
|
||||
|
||||
def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]]) -> List[torch.Tensor]:
|
||||
x_arr = []
|
||||
for i, img_size in enumerate(img_sizes):
|
||||
pH, pW = img_size
|
||||
x_arr.append(
|
||||
einops.rearrange(x[i, :pH*pW].reshape(1, pH, pW, -1), 'B H W (p1 p2 C) -> B C (H p1) (W p2)',
|
||||
p1=self.patch_size, p2=self.patch_size)
|
||||
)
|
||||
x = torch.cat(x_arr, dim=0)
|
||||
return x
|
||||
|
||||
def patchify(self, x, max_seq, img_sizes=None):
|
||||
pz2 = self.patch_size * self.patch_size
|
||||
if isinstance(x, torch.Tensor):
|
||||
B = x.shape[0]
|
||||
device = x.device
|
||||
dtype = x.dtype
|
||||
else:
|
||||
B = len(x)
|
||||
device = x[0].device
|
||||
dtype = x[0].dtype
|
||||
x_masks = torch.zeros((B, max_seq), dtype=dtype, device=device)
|
||||
|
||||
if img_sizes is not None:
|
||||
for i, img_size in enumerate(img_sizes):
|
||||
x_masks[i, 0:img_size[0] * img_size[1]] = 1
|
||||
x = einops.rearrange(x, 'B C S p -> B S (p C)', p=pz2)
|
||||
elif isinstance(x, torch.Tensor):
|
||||
pH, pW = x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size
|
||||
x = einops.rearrange(x, 'B C (H p1) (W p2) -> B (H W) (p1 p2 C)', p1=self.patch_size, p2=self.patch_size)
|
||||
img_sizes = [[pH, pW]] * B
|
||||
x_masks = None
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return x, x_masks, img_sizes
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
y: Optional[torch.Tensor] = None,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states_llama3=None,
|
||||
control = None,
|
||||
transformer_options = {},
|
||||
) -> torch.Tensor:
|
||||
hidden_states = x
|
||||
timesteps = t
|
||||
pooled_embeds = y
|
||||
T5_encoder_hidden_states = context
|
||||
|
||||
img_sizes = None
|
||||
|
||||
# spatial forward
|
||||
batch_size = hidden_states.shape[0]
|
||||
hidden_states_type = hidden_states.dtype
|
||||
|
||||
# 0. time
|
||||
timesteps = self.expand_timesteps(timesteps, batch_size, hidden_states.device)
|
||||
timesteps = self.t_embedder(timesteps, hidden_states_type)
|
||||
p_embedder = self.p_embedder(pooled_embeds)
|
||||
adaln_input = timesteps + p_embedder
|
||||
|
||||
hidden_states, image_tokens_masks, img_sizes = self.patchify(hidden_states, self.max_seq, img_sizes)
|
||||
if image_tokens_masks is None:
|
||||
pH, pW = img_sizes[0]
|
||||
img_ids = torch.zeros(pH, pW, 3, device=hidden_states.device)
|
||||
img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH, device=hidden_states.device)[:, None]
|
||||
img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW, device=hidden_states.device)[None, :]
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
|
||||
# T5_encoder_hidden_states = encoder_hidden_states[0]
|
||||
encoder_hidden_states = encoder_hidden_states_llama3.movedim(1, 0)
|
||||
encoder_hidden_states = [encoder_hidden_states[k] for k in self.llama_layers]
|
||||
|
||||
if self.caption_projection is not None:
|
||||
new_encoder_hidden_states = []
|
||||
for i, enc_hidden_state in enumerate(encoder_hidden_states):
|
||||
enc_hidden_state = self.caption_projection[i](enc_hidden_state)
|
||||
enc_hidden_state = enc_hidden_state.view(batch_size, -1, hidden_states.shape[-1])
|
||||
new_encoder_hidden_states.append(enc_hidden_state)
|
||||
encoder_hidden_states = new_encoder_hidden_states
|
||||
T5_encoder_hidden_states = self.caption_projection[-1](T5_encoder_hidden_states)
|
||||
T5_encoder_hidden_states = T5_encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
|
||||
encoder_hidden_states.append(T5_encoder_hidden_states)
|
||||
|
||||
txt_ids = torch.zeros(
|
||||
batch_size,
|
||||
encoder_hidden_states[-1].shape[1] + encoder_hidden_states[-2].shape[1] + encoder_hidden_states[0].shape[1],
|
||||
3,
|
||||
device=img_ids.device, dtype=img_ids.dtype
|
||||
)
|
||||
ids = torch.cat((img_ids, txt_ids), dim=1)
|
||||
rope = self.pe_embedder(ids)
|
||||
|
||||
# 2. Blocks
|
||||
block_id = 0
|
||||
initial_encoder_hidden_states = torch.cat([encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1)
|
||||
initial_encoder_hidden_states_seq_len = initial_encoder_hidden_states.shape[1]
|
||||
for bid, block in enumerate(self.double_stream_blocks):
|
||||
cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
|
||||
cur_encoder_hidden_states = torch.cat([initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1)
|
||||
hidden_states, initial_encoder_hidden_states = block(
|
||||
image_tokens = hidden_states,
|
||||
image_tokens_masks = image_tokens_masks,
|
||||
text_tokens = cur_encoder_hidden_states,
|
||||
adaln_input = adaln_input,
|
||||
rope = rope,
|
||||
)
|
||||
initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len]
|
||||
block_id += 1
|
||||
|
||||
image_tokens_seq_len = hidden_states.shape[1]
|
||||
hidden_states = torch.cat([hidden_states, initial_encoder_hidden_states], dim=1)
|
||||
hidden_states_seq_len = hidden_states.shape[1]
|
||||
if image_tokens_masks is not None:
|
||||
encoder_attention_mask_ones = torch.ones(
|
||||
(batch_size, initial_encoder_hidden_states.shape[1] + cur_llama31_encoder_hidden_states.shape[1]),
|
||||
device=image_tokens_masks.device, dtype=image_tokens_masks.dtype
|
||||
)
|
||||
image_tokens_masks = torch.cat([image_tokens_masks, encoder_attention_mask_ones], dim=1)
|
||||
|
||||
for bid, block in enumerate(self.single_stream_blocks):
|
||||
cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
|
||||
hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1)
|
||||
hidden_states = block(
|
||||
image_tokens=hidden_states,
|
||||
image_tokens_masks=image_tokens_masks,
|
||||
text_tokens=None,
|
||||
adaln_input=adaln_input,
|
||||
rope=rope,
|
||||
)
|
||||
hidden_states = hidden_states[:, :hidden_states_seq_len]
|
||||
block_id += 1
|
||||
|
||||
hidden_states = hidden_states[:, :image_tokens_seq_len, ...]
|
||||
output = self.final_layer(hidden_states, adaln_input)
|
||||
output = self.unpatchify(output, img_sizes)
|
||||
return -output
|
||||
135
comfy/ldm/hunyuan3d/model.py
Normal file
135
comfy/ldm/hunyuan3d/model.py
Normal file
@@ -0,0 +1,135 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from comfy.ldm.flux.layers import (
|
||||
DoubleStreamBlock,
|
||||
LastLayer,
|
||||
MLPEmbedder,
|
||||
SingleStreamBlock,
|
||||
timestep_embedding,
|
||||
)
|
||||
|
||||
|
||||
class Hunyuan3Dv2(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=64,
|
||||
context_in_dim=1536,
|
||||
hidden_size=1024,
|
||||
mlp_ratio=4.0,
|
||||
num_heads=16,
|
||||
depth=16,
|
||||
depth_single_blocks=32,
|
||||
qkv_bias=True,
|
||||
guidance_embed=False,
|
||||
image_model=None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
|
||||
if hidden_size % num_heads != 0:
|
||||
raise ValueError(
|
||||
f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}"
|
||||
)
|
||||
|
||||
self.max_period = 1000 # While reimplementing the model I noticed that they messed up. This 1000 value was meant to be the time_factor but they set the max_period instead
|
||||
self.latent_in = operations.Linear(in_channels, hidden_size, bias=True, dtype=dtype, device=device)
|
||||
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.guidance_in = (
|
||||
MLPEmbedder(in_dim=256, hidden_dim=hidden_size, dtype=dtype, device=device, operations=operations) if guidance_embed else None
|
||||
)
|
||||
self.cond_in = operations.Linear(context_in_dim, hidden_size, dtype=dtype, device=device)
|
||||
self.double_blocks = nn.ModuleList(
|
||||
[
|
||||
DoubleStreamBlock(
|
||||
hidden_size,
|
||||
num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
for _ in range(depth)
|
||||
]
|
||||
)
|
||||
self.single_blocks = nn.ModuleList(
|
||||
[
|
||||
SingleStreamBlock(
|
||||
hidden_size,
|
||||
num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
for _ in range(depth_single_blocks)
|
||||
]
|
||||
)
|
||||
self.final_layer = LastLayer(hidden_size, 1, in_channels, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(self, x, timestep, context, guidance=None, transformer_options={}, **kwargs):
|
||||
x = x.movedim(-1, -2)
|
||||
timestep = 1.0 - timestep
|
||||
txt = context
|
||||
img = self.latent_in(x)
|
||||
|
||||
vec = self.time_in(timestep_embedding(timestep, 256, self.max_period).to(dtype=img.dtype))
|
||||
if self.guidance_in is not None:
|
||||
if guidance is not None:
|
||||
vec = vec + self.guidance_in(timestep_embedding(guidance, 256, self.max_period).to(img.dtype))
|
||||
|
||||
txt = self.cond_in(txt)
|
||||
pe = None
|
||||
attn_mask = None
|
||||
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
for i, block in enumerate(self.double_blocks):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"], out["txt"] = block(img=args["img"],
|
||||
txt=args["txt"],
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"))
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": img,
|
||||
"txt": txt,
|
||||
"vec": vec,
|
||||
"pe": pe,
|
||||
"attn_mask": attn_mask},
|
||||
{"original_block": block_wrap})
|
||||
txt = out["txt"]
|
||||
img = out["img"]
|
||||
else:
|
||||
img, txt = block(img=img,
|
||||
txt=txt,
|
||||
vec=vec,
|
||||
pe=pe,
|
||||
attn_mask=attn_mask)
|
||||
|
||||
img = torch.cat((txt, img), 1)
|
||||
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
if ("single_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"],
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"))
|
||||
return out
|
||||
|
||||
out = blocks_replace[("single_block", i)]({"img": img,
|
||||
"vec": vec,
|
||||
"pe": pe,
|
||||
"attn_mask": attn_mask},
|
||||
{"original_block": block_wrap})
|
||||
img = out["img"]
|
||||
else:
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
|
||||
|
||||
img = img[:, txt.shape[1]:, ...]
|
||||
img = self.final_layer(img, vec)
|
||||
return img.movedim(-2, -1) * (-1.0)
|
||||
587
comfy/ldm/hunyuan3d/vae.py
Normal file
587
comfy/ldm/hunyuan3d/vae.py
Normal file
@@ -0,0 +1,587 @@
|
||||
# Original: https://github.com/Tencent/Hunyuan3D-2/blob/main/hy3dgen/shapegen/models/autoencoders/model.py
|
||||
# Since the header on their VAE source file was a bit confusing we asked for permission to use this code from tencent under the GPL license used in ComfyUI.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
from typing import Union, Tuple, List, Callable, Optional
|
||||
|
||||
import numpy as np
|
||||
from einops import repeat, rearrange
|
||||
from tqdm import tqdm
|
||||
import logging
|
||||
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
def generate_dense_grid_points(
|
||||
bbox_min: np.ndarray,
|
||||
bbox_max: np.ndarray,
|
||||
octree_resolution: int,
|
||||
indexing: str = "ij",
|
||||
):
|
||||
length = bbox_max - bbox_min
|
||||
num_cells = octree_resolution
|
||||
|
||||
x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
|
||||
y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
|
||||
z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
|
||||
[xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
|
||||
xyz = np.stack((xs, ys, zs), axis=-1)
|
||||
grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]
|
||||
|
||||
return xyz, grid_size, length
|
||||
|
||||
|
||||
class VanillaVolumeDecoder:
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
latents: torch.FloatTensor,
|
||||
geo_decoder: Callable,
|
||||
bounds: Union[Tuple[float], List[float], float] = 1.01,
|
||||
num_chunks: int = 10000,
|
||||
octree_resolution: int = None,
|
||||
enable_pbar: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
device = latents.device
|
||||
dtype = latents.dtype
|
||||
batch_size = latents.shape[0]
|
||||
|
||||
# 1. generate query points
|
||||
if isinstance(bounds, float):
|
||||
bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
|
||||
|
||||
bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
|
||||
xyz_samples, grid_size, length = generate_dense_grid_points(
|
||||
bbox_min=bbox_min,
|
||||
bbox_max=bbox_max,
|
||||
octree_resolution=octree_resolution,
|
||||
indexing="ij"
|
||||
)
|
||||
xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3)
|
||||
|
||||
# 2. latents to 3d volume
|
||||
batch_logits = []
|
||||
for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), desc="Volume Decoding",
|
||||
disable=not enable_pbar):
|
||||
chunk_queries = xyz_samples[start: start + num_chunks, :]
|
||||
chunk_queries = repeat(chunk_queries, "p c -> b p c", b=batch_size)
|
||||
logits = geo_decoder(queries=chunk_queries, latents=latents)
|
||||
batch_logits.append(logits)
|
||||
|
||||
grid_logits = torch.cat(batch_logits, dim=1)
|
||||
grid_logits = grid_logits.view((batch_size, *grid_size)).float()
|
||||
|
||||
return grid_logits
|
||||
|
||||
|
||||
class FourierEmbedder(nn.Module):
|
||||
"""The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
|
||||
each feature dimension of `x[..., i]` into:
|
||||
[
|
||||
sin(x[..., i]),
|
||||
sin(f_1*x[..., i]),
|
||||
sin(f_2*x[..., i]),
|
||||
...
|
||||
sin(f_N * x[..., i]),
|
||||
cos(x[..., i]),
|
||||
cos(f_1*x[..., i]),
|
||||
cos(f_2*x[..., i]),
|
||||
...
|
||||
cos(f_N * x[..., i]),
|
||||
x[..., i] # only present if include_input is True.
|
||||
], here f_i is the frequency.
|
||||
|
||||
Denote the space is [0 / num_freqs, 1 / num_freqs, 2 / num_freqs, 3 / num_freqs, ..., (num_freqs - 1) / num_freqs].
|
||||
If logspace is True, then the frequency f_i is [2^(0 / num_freqs), ..., 2^(i / num_freqs), ...];
|
||||
Otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)].
|
||||
|
||||
Args:
|
||||
num_freqs (int): the number of frequencies, default is 6;
|
||||
logspace (bool): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...],
|
||||
otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)];
|
||||
input_dim (int): the input dimension, default is 3;
|
||||
include_input (bool): include the input tensor or not, default is True.
|
||||
|
||||
Attributes:
|
||||
frequencies (torch.Tensor): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...],
|
||||
otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1);
|
||||
|
||||
out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1),
|
||||
otherwise, it is input_dim * num_freqs * 2.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_freqs: int = 6,
|
||||
logspace: bool = True,
|
||||
input_dim: int = 3,
|
||||
include_input: bool = True,
|
||||
include_pi: bool = True) -> None:
|
||||
|
||||
"""The initialization"""
|
||||
|
||||
super().__init__()
|
||||
|
||||
if logspace:
|
||||
frequencies = 2.0 ** torch.arange(
|
||||
num_freqs,
|
||||
dtype=torch.float32
|
||||
)
|
||||
else:
|
||||
frequencies = torch.linspace(
|
||||
1.0,
|
||||
2.0 ** (num_freqs - 1),
|
||||
num_freqs,
|
||||
dtype=torch.float32
|
||||
)
|
||||
|
||||
if include_pi:
|
||||
frequencies *= torch.pi
|
||||
|
||||
self.register_buffer("frequencies", frequencies, persistent=False)
|
||||
self.include_input = include_input
|
||||
self.num_freqs = num_freqs
|
||||
|
||||
self.out_dim = self.get_dims(input_dim)
|
||||
|
||||
def get_dims(self, input_dim):
|
||||
temp = 1 if self.include_input or self.num_freqs == 0 else 0
|
||||
out_dim = input_dim * (self.num_freqs * 2 + temp)
|
||||
|
||||
return out_dim
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
""" Forward process.
|
||||
|
||||
Args:
|
||||
x: tensor of shape [..., dim]
|
||||
|
||||
Returns:
|
||||
embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)]
|
||||
where temp is 1 if include_input is True and 0 otherwise.
|
||||
"""
|
||||
|
||||
if self.num_freqs > 0:
|
||||
embed = (x[..., None].contiguous() * self.frequencies.to(device=x.device, dtype=x.dtype)).view(*x.shape[:-1], -1)
|
||||
if self.include_input:
|
||||
return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
|
||||
else:
|
||||
return torch.cat((embed.sin(), embed.cos()), dim=-1)
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
class CrossAttentionProcessor:
|
||||
def __call__(self, attn, q, k, v):
|
||||
out = F.scaled_dot_product_attention(q, k, v)
|
||||
return out
|
||||
|
||||
|
||||
class DropPath(nn.Module):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||
"""
|
||||
|
||||
def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
|
||||
super(DropPath, self).__init__()
|
||||
self.drop_prob = drop_prob
|
||||
self.scale_by_keep = scale_by_keep
|
||||
|
||||
def forward(self, x):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||
|
||||
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
|
||||
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
||||
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
|
||||
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
|
||||
'survival rate' as the argument.
|
||||
|
||||
"""
|
||||
if self.drop_prob == 0. or not self.training:
|
||||
return x
|
||||
keep_prob = 1 - self.drop_prob
|
||||
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
||||
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
||||
if keep_prob > 0.0 and self.scale_by_keep:
|
||||
random_tensor.div_(keep_prob)
|
||||
return x * random_tensor
|
||||
|
||||
def extra_repr(self):
|
||||
return f'drop_prob={round(self.drop_prob, 3):0.3f}'
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(
|
||||
self, *,
|
||||
width: int,
|
||||
expand_ratio: int = 4,
|
||||
output_width: int = None,
|
||||
drop_path_rate: float = 0.0
|
||||
):
|
||||
super().__init__()
|
||||
self.width = width
|
||||
self.c_fc = ops.Linear(width, width * expand_ratio)
|
||||
self.c_proj = ops.Linear(width * expand_ratio, output_width if output_width is not None else width)
|
||||
self.gelu = nn.GELU()
|
||||
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
return self.drop_path(self.c_proj(self.gelu(self.c_fc(x))))
|
||||
|
||||
|
||||
class QKVMultiheadCrossAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
heads: int,
|
||||
width=None,
|
||||
qk_norm=False,
|
||||
norm_layer=ops.LayerNorm
|
||||
):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
|
||||
self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
|
||||
|
||||
self.attn_processor = CrossAttentionProcessor()
|
||||
|
||||
def forward(self, q, kv):
|
||||
_, n_ctx, _ = q.shape
|
||||
bs, n_data, width = kv.shape
|
||||
attn_ch = width // self.heads // 2
|
||||
q = q.view(bs, n_ctx, self.heads, -1)
|
||||
kv = kv.view(bs, n_data, self.heads, -1)
|
||||
k, v = torch.split(kv, attn_ch, dim=-1)
|
||||
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
|
||||
out = self.attn_processor(self, q, k, v)
|
||||
out = out.transpose(1, 2).reshape(bs, n_ctx, -1)
|
||||
return out
|
||||
|
||||
|
||||
class MultiheadCrossAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
width: int,
|
||||
heads: int,
|
||||
qkv_bias: bool = True,
|
||||
data_width: Optional[int] = None,
|
||||
norm_layer=ops.LayerNorm,
|
||||
qk_norm: bool = False,
|
||||
kv_cache: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.width = width
|
||||
self.heads = heads
|
||||
self.data_width = width if data_width is None else data_width
|
||||
self.c_q = ops.Linear(width, width, bias=qkv_bias)
|
||||
self.c_kv = ops.Linear(self.data_width, width * 2, bias=qkv_bias)
|
||||
self.c_proj = ops.Linear(width, width)
|
||||
self.attention = QKVMultiheadCrossAttention(
|
||||
heads=heads,
|
||||
width=width,
|
||||
norm_layer=norm_layer,
|
||||
qk_norm=qk_norm
|
||||
)
|
||||
self.kv_cache = kv_cache
|
||||
self.data = None
|
||||
|
||||
def forward(self, x, data):
|
||||
x = self.c_q(x)
|
||||
if self.kv_cache:
|
||||
if self.data is None:
|
||||
self.data = self.c_kv(data)
|
||||
logging.info('Save kv cache,this should be called only once for one mesh')
|
||||
data = self.data
|
||||
else:
|
||||
data = self.c_kv(data)
|
||||
x = self.attention(x, data)
|
||||
x = self.c_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class ResidualCrossAttentionBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
width: int,
|
||||
heads: int,
|
||||
mlp_expand_ratio: int = 4,
|
||||
data_width: Optional[int] = None,
|
||||
qkv_bias: bool = True,
|
||||
norm_layer=ops.LayerNorm,
|
||||
qk_norm: bool = False
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if data_width is None:
|
||||
data_width = width
|
||||
|
||||
self.attn = MultiheadCrossAttention(
|
||||
width=width,
|
||||
heads=heads,
|
||||
data_width=data_width,
|
||||
qkv_bias=qkv_bias,
|
||||
norm_layer=norm_layer,
|
||||
qk_norm=qk_norm
|
||||
)
|
||||
self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
|
||||
self.ln_2 = norm_layer(data_width, elementwise_affine=True, eps=1e-6)
|
||||
self.ln_3 = norm_layer(width, elementwise_affine=True, eps=1e-6)
|
||||
self.mlp = MLP(width=width, expand_ratio=mlp_expand_ratio)
|
||||
|
||||
def forward(self, x: torch.Tensor, data: torch.Tensor):
|
||||
x = x + self.attn(self.ln_1(x), self.ln_2(data))
|
||||
x = x + self.mlp(self.ln_3(x))
|
||||
return x
|
||||
|
||||
|
||||
class QKVMultiheadAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
heads: int,
|
||||
width=None,
|
||||
qk_norm=False,
|
||||
norm_layer=ops.LayerNorm
|
||||
):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
|
||||
self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
|
||||
|
||||
def forward(self, qkv):
|
||||
bs, n_ctx, width = qkv.shape
|
||||
attn_ch = width // self.heads // 3
|
||||
qkv = qkv.view(bs, n_ctx, self.heads, -1)
|
||||
q, k, v = torch.split(qkv, attn_ch, dim=-1)
|
||||
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
|
||||
out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
|
||||
return out
|
||||
|
||||
|
||||
class MultiheadAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
width: int,
|
||||
heads: int,
|
||||
qkv_bias: bool,
|
||||
norm_layer=ops.LayerNorm,
|
||||
qk_norm: bool = False,
|
||||
drop_path_rate: float = 0.0
|
||||
):
|
||||
super().__init__()
|
||||
self.width = width
|
||||
self.heads = heads
|
||||
self.c_qkv = ops.Linear(width, width * 3, bias=qkv_bias)
|
||||
self.c_proj = ops.Linear(width, width)
|
||||
self.attention = QKVMultiheadAttention(
|
||||
heads=heads,
|
||||
width=width,
|
||||
norm_layer=norm_layer,
|
||||
qk_norm=qk_norm
|
||||
)
|
||||
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.c_qkv(x)
|
||||
x = self.attention(x)
|
||||
x = self.drop_path(self.c_proj(x))
|
||||
return x
|
||||
|
||||
|
||||
class ResidualAttentionBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
width: int,
|
||||
heads: int,
|
||||
qkv_bias: bool = True,
|
||||
norm_layer=ops.LayerNorm,
|
||||
qk_norm: bool = False,
|
||||
drop_path_rate: float = 0.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.attn = MultiheadAttention(
|
||||
width=width,
|
||||
heads=heads,
|
||||
qkv_bias=qkv_bias,
|
||||
norm_layer=norm_layer,
|
||||
qk_norm=qk_norm,
|
||||
drop_path_rate=drop_path_rate
|
||||
)
|
||||
self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
|
||||
self.mlp = MLP(width=width, drop_path_rate=drop_path_rate)
|
||||
self.ln_2 = norm_layer(width, elementwise_affine=True, eps=1e-6)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
x = x + self.attn(self.ln_1(x))
|
||||
x = x + self.mlp(self.ln_2(x))
|
||||
return x
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
width: int,
|
||||
layers: int,
|
||||
heads: int,
|
||||
qkv_bias: bool = True,
|
||||
norm_layer=ops.LayerNorm,
|
||||
qk_norm: bool = False,
|
||||
drop_path_rate: float = 0.0
|
||||
):
|
||||
super().__init__()
|
||||
self.width = width
|
||||
self.layers = layers
|
||||
self.resblocks = nn.ModuleList(
|
||||
[
|
||||
ResidualAttentionBlock(
|
||||
width=width,
|
||||
heads=heads,
|
||||
qkv_bias=qkv_bias,
|
||||
norm_layer=norm_layer,
|
||||
qk_norm=qk_norm,
|
||||
drop_path_rate=drop_path_rate
|
||||
)
|
||||
for _ in range(layers)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
for block in self.resblocks:
|
||||
x = block(x)
|
||||
return x
|
||||
|
||||
|
||||
class CrossAttentionDecoder(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
out_channels: int,
|
||||
fourier_embedder: FourierEmbedder,
|
||||
width: int,
|
||||
heads: int,
|
||||
mlp_expand_ratio: int = 4,
|
||||
downsample_ratio: int = 1,
|
||||
enable_ln_post: bool = True,
|
||||
qkv_bias: bool = True,
|
||||
qk_norm: bool = False,
|
||||
label_type: str = "binary"
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.enable_ln_post = enable_ln_post
|
||||
self.fourier_embedder = fourier_embedder
|
||||
self.downsample_ratio = downsample_ratio
|
||||
self.query_proj = ops.Linear(self.fourier_embedder.out_dim, width)
|
||||
if self.downsample_ratio != 1:
|
||||
self.latents_proj = ops.Linear(width * downsample_ratio, width)
|
||||
if self.enable_ln_post == False:
|
||||
qk_norm = False
|
||||
self.cross_attn_decoder = ResidualCrossAttentionBlock(
|
||||
width=width,
|
||||
mlp_expand_ratio=mlp_expand_ratio,
|
||||
heads=heads,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_norm=qk_norm
|
||||
)
|
||||
|
||||
if self.enable_ln_post:
|
||||
self.ln_post = ops.LayerNorm(width)
|
||||
self.output_proj = ops.Linear(width, out_channels)
|
||||
self.label_type = label_type
|
||||
self.count = 0
|
||||
|
||||
def forward(self, queries=None, query_embeddings=None, latents=None):
|
||||
if query_embeddings is None:
|
||||
query_embeddings = self.query_proj(self.fourier_embedder(queries).to(latents.dtype))
|
||||
self.count += query_embeddings.shape[1]
|
||||
if self.downsample_ratio != 1:
|
||||
latents = self.latents_proj(latents)
|
||||
x = self.cross_attn_decoder(query_embeddings, latents)
|
||||
if self.enable_ln_post:
|
||||
x = self.ln_post(x)
|
||||
occ = self.output_proj(x)
|
||||
return occ
|
||||
|
||||
|
||||
class ShapeVAE(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
embed_dim: int,
|
||||
width: int,
|
||||
heads: int,
|
||||
num_decoder_layers: int,
|
||||
geo_decoder_downsample_ratio: int = 1,
|
||||
geo_decoder_mlp_expand_ratio: int = 4,
|
||||
geo_decoder_ln_post: bool = True,
|
||||
num_freqs: int = 8,
|
||||
include_pi: bool = True,
|
||||
qkv_bias: bool = True,
|
||||
qk_norm: bool = False,
|
||||
label_type: str = "binary",
|
||||
drop_path_rate: float = 0.0,
|
||||
scale_factor: float = 1.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.geo_decoder_ln_post = geo_decoder_ln_post
|
||||
|
||||
self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
|
||||
|
||||
self.post_kl = ops.Linear(embed_dim, width)
|
||||
|
||||
self.transformer = Transformer(
|
||||
width=width,
|
||||
layers=num_decoder_layers,
|
||||
heads=heads,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_norm=qk_norm,
|
||||
drop_path_rate=drop_path_rate
|
||||
)
|
||||
|
||||
self.geo_decoder = CrossAttentionDecoder(
|
||||
fourier_embedder=self.fourier_embedder,
|
||||
out_channels=1,
|
||||
mlp_expand_ratio=geo_decoder_mlp_expand_ratio,
|
||||
downsample_ratio=geo_decoder_downsample_ratio,
|
||||
enable_ln_post=self.geo_decoder_ln_post,
|
||||
width=width // geo_decoder_downsample_ratio,
|
||||
heads=heads // geo_decoder_downsample_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_norm=qk_norm,
|
||||
label_type=label_type,
|
||||
)
|
||||
|
||||
self.volume_decoder = VanillaVolumeDecoder()
|
||||
self.scale_factor = scale_factor
|
||||
|
||||
def decode(self, latents, **kwargs):
|
||||
latents = self.post_kl(latents.movedim(-2, -1))
|
||||
latents = self.transformer(latents)
|
||||
|
||||
bounds = kwargs.get("bounds", 1.01)
|
||||
num_chunks = kwargs.get("num_chunks", 8000)
|
||||
octree_resolution = kwargs.get("octree_resolution", 256)
|
||||
enable_pbar = kwargs.get("enable_pbar", True)
|
||||
|
||||
grid_logits = self.volume_decoder(latents, self.geo_decoder, bounds=bounds, num_chunks=num_chunks, octree_resolution=octree_resolution, enable_pbar=enable_pbar)
|
||||
return grid_logits.movedim(-2, -1)
|
||||
|
||||
def encode(self, x):
|
||||
return None
|
||||
@@ -227,6 +227,7 @@ class HunyuanVideo(nn.Module):
|
||||
timesteps: Tensor,
|
||||
y: Tensor,
|
||||
guidance: Tensor = None,
|
||||
guiding_frame_index=None,
|
||||
control=None,
|
||||
transformer_options={},
|
||||
) -> Tensor:
|
||||
@@ -237,7 +238,17 @@ class HunyuanVideo(nn.Module):
|
||||
img = self.img_in(img)
|
||||
vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))
|
||||
|
||||
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
|
||||
if guiding_frame_index is not None:
|
||||
token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0))
|
||||
vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
|
||||
vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1)
|
||||
frame_tokens = (initial_shape[-1] // self.patch_size[-1]) * (initial_shape[-2] // self.patch_size[-2])
|
||||
modulation_dims = [(0, frame_tokens, 0), (frame_tokens, None, 1)]
|
||||
modulation_dims_txt = [(0, None, 1)]
|
||||
else:
|
||||
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
|
||||
modulation_dims = None
|
||||
modulation_dims_txt = None
|
||||
|
||||
if self.params.guidance_embed:
|
||||
if guidance is not None:
|
||||
@@ -264,14 +275,14 @@ class HunyuanVideo(nn.Module):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"])
|
||||
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt}, {"original_block": block_wrap})
|
||||
txt = out["txt"]
|
||||
img = out["img"]
|
||||
else:
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask)
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_i = control.get("input")
|
||||
@@ -286,13 +297,13 @@ class HunyuanVideo(nn.Module):
|
||||
if ("single_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"])
|
||||
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims}, {"original_block": block_wrap})
|
||||
img = out["img"]
|
||||
else:
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_o = control.get("output")
|
||||
@@ -303,17 +314,17 @@ class HunyuanVideo(nn.Module):
|
||||
|
||||
img = img[:, : img_len]
|
||||
|
||||
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
||||
img = self.final_layer(img, vec, modulation_dims=modulation_dims) # (N, T, patch_size ** 2 * out_channels)
|
||||
|
||||
shape = initial_shape[-3:]
|
||||
for i in range(len(shape)):
|
||||
shape[i] = shape[i] // self.patch_size[i]
|
||||
img = img.reshape([img.shape[0]] + shape + [self.out_channels] + self.patch_size)
|
||||
img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
|
||||
img = img.reshape(initial_shape)
|
||||
img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
|
||||
return img
|
||||
|
||||
def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, control=None, transformer_options={}, **kwargs):
|
||||
def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, control=None, transformer_options={}, **kwargs):
|
||||
bs, c, t, h, w = x.shape
|
||||
patch_size = self.patch_size
|
||||
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
|
||||
@@ -325,5 +336,5 @@ class HunyuanVideo(nn.Module):
|
||||
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
|
||||
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, control, transformer_options)
|
||||
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, control, transformer_options)
|
||||
return out
|
||||
|
||||
@@ -7,7 +7,7 @@ from einops import rearrange
|
||||
import math
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
from .symmetric_patchifier import SymmetricPatchifier
|
||||
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
|
||||
|
||||
|
||||
def get_timestep_embedding(
|
||||
@@ -377,12 +377,16 @@ class LTXVModel(torch.nn.Module):
|
||||
|
||||
positional_embedding_theta=10000.0,
|
||||
positional_embedding_max_pos=[20, 2048, 2048],
|
||||
causal_temporal_positioning=False,
|
||||
vae_scale_factors=(8, 32, 32),
|
||||
dtype=None, device=None, operations=None, **kwargs):
|
||||
super().__init__()
|
||||
self.generator = None
|
||||
self.vae_scale_factors = vae_scale_factors
|
||||
self.dtype = dtype
|
||||
self.out_channels = in_channels
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
self.causal_temporal_positioning = causal_temporal_positioning
|
||||
|
||||
self.patchify_proj = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
@@ -416,42 +420,23 @@ class LTXVModel(torch.nn.Module):
|
||||
|
||||
self.patchifier = SymmetricPatchifier(1)
|
||||
|
||||
def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, guiding_latent_noise_scale=0, transformer_options={}, **kwargs):
|
||||
def forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
|
||||
indices_grid = self.patchifier.get_grid(
|
||||
orig_num_frames=x.shape[2],
|
||||
orig_height=x.shape[3],
|
||||
orig_width=x.shape[4],
|
||||
batch_size=x.shape[0],
|
||||
scale_grid=((1 / frame_rate) * 8, 32, 32),
|
||||
device=x.device,
|
||||
)
|
||||
|
||||
if guiding_latent is not None:
|
||||
ts = torch.ones([x.shape[0], 1, x.shape[2], x.shape[3], x.shape[4]], device=x.device, dtype=x.dtype)
|
||||
input_ts = timestep.view([timestep.shape[0]] + [1] * (x.ndim - 1))
|
||||
ts *= input_ts
|
||||
ts[:, :, 0] = guiding_latent_noise_scale * (input_ts[:, :, 0] ** 2)
|
||||
timestep = self.patchifier.patchify(ts)
|
||||
input_x = x.clone()
|
||||
x[:, :, 0] = guiding_latent[:, :, 0]
|
||||
if guiding_latent_noise_scale > 0:
|
||||
if self.generator is None:
|
||||
self.generator = torch.Generator(device=x.device).manual_seed(42)
|
||||
elif self.generator.device != x.device:
|
||||
self.generator = torch.Generator(device=x.device).set_state(self.generator.get_state())
|
||||
|
||||
noise_shape = [guiding_latent.shape[0], guiding_latent.shape[1], 1, guiding_latent.shape[3], guiding_latent.shape[4]]
|
||||
scale = guiding_latent_noise_scale * (input_ts ** 2)
|
||||
guiding_noise = scale * torch.randn(size=noise_shape, device=x.device, generator=self.generator)
|
||||
|
||||
x[:, :, 0] = guiding_noise[:, :, 0] + x[:, :, 0] * (1.0 - scale[:, :, 0])
|
||||
|
||||
|
||||
orig_shape = list(x.shape)
|
||||
|
||||
x = self.patchifier.patchify(x)
|
||||
x, latent_coords = self.patchifier.patchify(x)
|
||||
pixel_coords = latent_to_pixel_coords(
|
||||
latent_coords=latent_coords,
|
||||
scale_factors=self.vae_scale_factors,
|
||||
causal_fix=self.causal_temporal_positioning,
|
||||
)
|
||||
|
||||
if keyframe_idxs is not None:
|
||||
pixel_coords[:, :, -keyframe_idxs.shape[2]:] = keyframe_idxs
|
||||
|
||||
fractional_coords = pixel_coords.to(torch.float32)
|
||||
fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate)
|
||||
|
||||
x = self.patchify_proj(x)
|
||||
timestep = timestep * 1000.0
|
||||
@@ -459,7 +444,7 @@ class LTXVModel(torch.nn.Module):
|
||||
if attention_mask is not None and not torch.is_floating_point(attention_mask):
|
||||
attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max
|
||||
|
||||
pe = precompute_freqs_cis(indices_grid, dim=self.inner_dim, out_dtype=x.dtype)
|
||||
pe = precompute_freqs_cis(fractional_coords, dim=self.inner_dim, out_dtype=x.dtype)
|
||||
|
||||
batch_size = x.shape[0]
|
||||
timestep, embedded_timestep = self.adaln_single(
|
||||
@@ -519,8 +504,4 @@ class LTXVModel(torch.nn.Module):
|
||||
out_channels=orig_shape[1] // math.prod(self.patchifier.patch_size),
|
||||
)
|
||||
|
||||
if guiding_latent is not None:
|
||||
x[:, :, 0] = (input_x[:, :, 0] - guiding_latent[:, :, 0]) / input_ts[:, :, 0]
|
||||
|
||||
# print("res", x)
|
||||
return x
|
||||
|
||||
@@ -6,16 +6,29 @@ from einops import rearrange
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
|
||||
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
|
||||
dims_to_append = target_dims - x.ndim
|
||||
if dims_to_append < 0:
|
||||
raise ValueError(
|
||||
f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
|
||||
)
|
||||
elif dims_to_append == 0:
|
||||
return x
|
||||
return x[(...,) + (None,) * dims_to_append]
|
||||
def latent_to_pixel_coords(
|
||||
latent_coords: Tensor, scale_factors: Tuple[int, int, int], causal_fix: bool = False
|
||||
) -> Tensor:
|
||||
"""
|
||||
Converts latent coordinates to pixel coordinates by scaling them according to the VAE's
|
||||
configuration.
|
||||
Args:
|
||||
latent_coords (Tensor): A tensor of shape [batch_size, 3, num_latents]
|
||||
containing the latent corner coordinates of each token.
|
||||
scale_factors (Tuple[int, int, int]): The scale factors of the VAE's latent space.
|
||||
causal_fix (bool): Whether to take into account the different temporal scale
|
||||
of the first frame. Default = False for backwards compatibility.
|
||||
Returns:
|
||||
Tensor: A tensor of pixel coordinates corresponding to the input latent coordinates.
|
||||
"""
|
||||
pixel_coords = (
|
||||
latent_coords
|
||||
* torch.tensor(scale_factors, device=latent_coords.device)[None, :, None]
|
||||
)
|
||||
if causal_fix:
|
||||
# Fix temporal scale for first frame to 1 due to causality
|
||||
pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - scale_factors[0]).clamp(min=0)
|
||||
return pixel_coords
|
||||
|
||||
|
||||
class Patchifier(ABC):
|
||||
@@ -44,29 +57,26 @@ class Patchifier(ABC):
|
||||
def patch_size(self):
|
||||
return self._patch_size
|
||||
|
||||
def get_grid(
|
||||
self, orig_num_frames, orig_height, orig_width, batch_size, scale_grid, device
|
||||
def get_latent_coords(
|
||||
self, latent_num_frames, latent_height, latent_width, batch_size, device
|
||||
):
|
||||
f = orig_num_frames // self._patch_size[0]
|
||||
h = orig_height // self._patch_size[1]
|
||||
w = orig_width // self._patch_size[2]
|
||||
grid_h = torch.arange(h, dtype=torch.float32, device=device)
|
||||
grid_w = torch.arange(w, dtype=torch.float32, device=device)
|
||||
grid_f = torch.arange(f, dtype=torch.float32, device=device)
|
||||
grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing='ij')
|
||||
grid = torch.stack(grid, dim=0)
|
||||
grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
|
||||
|
||||
if scale_grid is not None:
|
||||
for i in range(3):
|
||||
if isinstance(scale_grid[i], Tensor):
|
||||
scale = append_dims(scale_grid[i], grid.ndim - 1)
|
||||
else:
|
||||
scale = scale_grid[i]
|
||||
grid[:, i, ...] = grid[:, i, ...] * scale * self._patch_size[i]
|
||||
|
||||
grid = rearrange(grid, "b c f h w -> b c (f h w)", b=batch_size)
|
||||
return grid
|
||||
"""
|
||||
Return a tensor of shape [batch_size, 3, num_patches] containing the
|
||||
top-left corner latent coordinates of each latent patch.
|
||||
The tensor is repeated for each batch element.
|
||||
"""
|
||||
latent_sample_coords = torch.meshgrid(
|
||||
torch.arange(0, latent_num_frames, self._patch_size[0], device=device),
|
||||
torch.arange(0, latent_height, self._patch_size[1], device=device),
|
||||
torch.arange(0, latent_width, self._patch_size[2], device=device),
|
||||
indexing="ij",
|
||||
)
|
||||
latent_sample_coords = torch.stack(latent_sample_coords, dim=0)
|
||||
latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
|
||||
latent_coords = rearrange(
|
||||
latent_coords, "b c f h w -> b c (f h w)", b=batch_size
|
||||
)
|
||||
return latent_coords
|
||||
|
||||
|
||||
class SymmetricPatchifier(Patchifier):
|
||||
@@ -74,6 +84,8 @@ class SymmetricPatchifier(Patchifier):
|
||||
self,
|
||||
latents: Tensor,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
b, _, f, h, w = latents.shape
|
||||
latent_coords = self.get_latent_coords(f, h, w, b, latents.device)
|
||||
latents = rearrange(
|
||||
latents,
|
||||
"b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
|
||||
@@ -81,7 +93,7 @@ class SymmetricPatchifier(Patchifier):
|
||||
p2=self._patch_size[1],
|
||||
p3=self._patch_size[2],
|
||||
)
|
||||
return latents
|
||||
return latents, latent_coords
|
||||
|
||||
def unpatchify(
|
||||
self,
|
||||
|
||||
@@ -15,6 +15,7 @@ class CausalConv3d(nn.Module):
|
||||
stride: Union[int, Tuple[int]] = 1,
|
||||
dilation: int = 1,
|
||||
groups: int = 1,
|
||||
spatial_padding_mode: str = "zeros",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -38,7 +39,7 @@ class CausalConv3d(nn.Module):
|
||||
stride=stride,
|
||||
dilation=dilation,
|
||||
padding=padding,
|
||||
padding_mode="zeros",
|
||||
padding_mode=spatial_padding_mode,
|
||||
groups=groups,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
from __future__ import annotations
|
||||
import torch
|
||||
from torch import nn
|
||||
from functools import partial
|
||||
import math
|
||||
from einops import rearrange
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from .conv_nd_factory import make_conv_nd, make_linear_nd
|
||||
from .pixel_norm import PixelNorm
|
||||
from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings
|
||||
import comfy.ops
|
||||
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
class Encoder(nn.Module):
|
||||
@@ -32,7 +34,7 @@ class Encoder(nn.Module):
|
||||
norm_layer (`str`, *optional*, defaults to `group_norm`):
|
||||
The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
|
||||
latent_log_var (`str`, *optional*, defaults to `per_channel`):
|
||||
The number of channels for the log variance. Can be either `per_channel`, `uniform`, or `none`.
|
||||
The number of channels for the log variance. Can be either `per_channel`, `uniform`, `constant` or `none`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -40,12 +42,13 @@ class Encoder(nn.Module):
|
||||
dims: Union[int, Tuple[int, int]] = 3,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
blocks=[("res_x", 1)],
|
||||
blocks: List[Tuple[str, int | dict]] = [("res_x", 1)],
|
||||
base_channels: int = 128,
|
||||
norm_num_groups: int = 32,
|
||||
patch_size: Union[int, Tuple[int]] = 1,
|
||||
norm_layer: str = "group_norm", # group_norm, pixel_norm
|
||||
latent_log_var: str = "per_channel",
|
||||
spatial_padding_mode: str = "zeros",
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
@@ -65,6 +68,7 @@ class Encoder(nn.Module):
|
||||
stride=1,
|
||||
padding=1,
|
||||
causal=True,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
@@ -82,6 +86,7 @@ class Encoder(nn.Module):
|
||||
resnet_eps=1e-6,
|
||||
resnet_groups=norm_num_groups,
|
||||
norm_layer=norm_layer,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "res_x_y":
|
||||
output_channel = block_params.get("multiplier", 2) * output_channel
|
||||
@@ -92,6 +97,7 @@ class Encoder(nn.Module):
|
||||
eps=1e-6,
|
||||
groups=norm_num_groups,
|
||||
norm_layer=norm_layer,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_time":
|
||||
block = make_conv_nd(
|
||||
@@ -101,6 +107,7 @@ class Encoder(nn.Module):
|
||||
kernel_size=3,
|
||||
stride=(2, 1, 1),
|
||||
causal=True,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_space":
|
||||
block = make_conv_nd(
|
||||
@@ -110,6 +117,7 @@ class Encoder(nn.Module):
|
||||
kernel_size=3,
|
||||
stride=(1, 2, 2),
|
||||
causal=True,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_all":
|
||||
block = make_conv_nd(
|
||||
@@ -119,6 +127,7 @@ class Encoder(nn.Module):
|
||||
kernel_size=3,
|
||||
stride=(2, 2, 2),
|
||||
causal=True,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_all_x_y":
|
||||
output_channel = block_params.get("multiplier", 2) * output_channel
|
||||
@@ -129,6 +138,34 @@ class Encoder(nn.Module):
|
||||
kernel_size=3,
|
||||
stride=(2, 2, 2),
|
||||
causal=True,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_all_res":
|
||||
output_channel = block_params.get("multiplier", 2) * output_channel
|
||||
block = SpaceToDepthDownsample(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
stride=(2, 2, 2),
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_space_res":
|
||||
output_channel = block_params.get("multiplier", 2) * output_channel
|
||||
block = SpaceToDepthDownsample(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
stride=(1, 2, 2),
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_time_res":
|
||||
output_channel = block_params.get("multiplier", 2) * output_channel
|
||||
block = SpaceToDepthDownsample(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
stride=(2, 1, 1),
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unknown block: {block_name}")
|
||||
@@ -152,10 +189,18 @@ class Encoder(nn.Module):
|
||||
conv_out_channels *= 2
|
||||
elif latent_log_var == "uniform":
|
||||
conv_out_channels += 1
|
||||
elif latent_log_var == "constant":
|
||||
conv_out_channels += 1
|
||||
elif latent_log_var != "none":
|
||||
raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
|
||||
self.conv_out = make_conv_nd(
|
||||
dims, output_channel, conv_out_channels, 3, padding=1, causal=True
|
||||
dims,
|
||||
output_channel,
|
||||
conv_out_channels,
|
||||
3,
|
||||
padding=1,
|
||||
causal=True,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
@@ -197,6 +242,15 @@ class Encoder(nn.Module):
|
||||
sample = torch.cat([sample, repeated_last_channel], dim=1)
|
||||
else:
|
||||
raise ValueError(f"Invalid input shape: {sample.shape}")
|
||||
elif self.latent_log_var == "constant":
|
||||
sample = sample[:, :-1, ...]
|
||||
approx_ln_0 = (
|
||||
-30
|
||||
) # this is the minimal clamp value in DiagonalGaussianDistribution objects
|
||||
sample = torch.cat(
|
||||
[sample, torch.ones_like(sample, device=sample.device) * approx_ln_0],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
return sample
|
||||
|
||||
@@ -231,7 +285,7 @@ class Decoder(nn.Module):
|
||||
dims,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
blocks=[("res_x", 1)],
|
||||
blocks: List[Tuple[str, int | dict]] = [("res_x", 1)],
|
||||
base_channels: int = 128,
|
||||
layers_per_block: int = 2,
|
||||
norm_num_groups: int = 32,
|
||||
@@ -239,6 +293,7 @@ class Decoder(nn.Module):
|
||||
norm_layer: str = "group_norm",
|
||||
causal: bool = True,
|
||||
timestep_conditioning: bool = False,
|
||||
spatial_padding_mode: str = "zeros",
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
@@ -264,6 +319,7 @@ class Decoder(nn.Module):
|
||||
stride=1,
|
||||
padding=1,
|
||||
causal=True,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
@@ -283,6 +339,7 @@ class Decoder(nn.Module):
|
||||
norm_layer=norm_layer,
|
||||
inject_noise=block_params.get("inject_noise", False),
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "attn_res_x":
|
||||
block = UNetMidBlock3D(
|
||||
@@ -294,6 +351,7 @@ class Decoder(nn.Module):
|
||||
inject_noise=block_params.get("inject_noise", False),
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
attention_head_dim=block_params["attention_head_dim"],
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "res_x_y":
|
||||
output_channel = output_channel // block_params.get("multiplier", 2)
|
||||
@@ -306,14 +364,21 @@ class Decoder(nn.Module):
|
||||
norm_layer=norm_layer,
|
||||
inject_noise=block_params.get("inject_noise", False),
|
||||
timestep_conditioning=False,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_time":
|
||||
block = DepthToSpaceUpsample(
|
||||
dims=dims, in_channels=input_channel, stride=(2, 1, 1)
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
stride=(2, 1, 1),
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_space":
|
||||
block = DepthToSpaceUpsample(
|
||||
dims=dims, in_channels=input_channel, stride=(1, 2, 2)
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
stride=(1, 2, 2),
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_all":
|
||||
output_channel = output_channel // block_params.get("multiplier", 1)
|
||||
@@ -323,6 +388,7 @@ class Decoder(nn.Module):
|
||||
stride=(2, 2, 2),
|
||||
residual=block_params.get("residual", False),
|
||||
out_channels_reduction_factor=block_params.get("multiplier", 1),
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unknown layer: {block_name}")
|
||||
@@ -340,7 +406,13 @@ class Decoder(nn.Module):
|
||||
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = make_conv_nd(
|
||||
dims, output_channel, out_channels, 3, padding=1, causal=True
|
||||
dims,
|
||||
output_channel,
|
||||
out_channels,
|
||||
3,
|
||||
padding=1,
|
||||
causal=True,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
@@ -433,6 +505,12 @@ class UNetMidBlock3D(nn.Module):
|
||||
resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
|
||||
resnet_groups (`int`, *optional*, defaults to 32):
|
||||
The number of groups to use in the group normalization layers of the resnet blocks.
|
||||
norm_layer (`str`, *optional*, defaults to `group_norm`):
|
||||
The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
|
||||
inject_noise (`bool`, *optional*, defaults to `False`):
|
||||
Whether to inject noise into the hidden states.
|
||||
timestep_conditioning (`bool`, *optional*, defaults to `False`):
|
||||
Whether to condition the hidden states on the timestep.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
|
||||
@@ -451,6 +529,7 @@ class UNetMidBlock3D(nn.Module):
|
||||
norm_layer: str = "group_norm",
|
||||
inject_noise: bool = False,
|
||||
timestep_conditioning: bool = False,
|
||||
spatial_padding_mode: str = "zeros",
|
||||
):
|
||||
super().__init__()
|
||||
resnet_groups = (
|
||||
@@ -476,13 +555,17 @@ class UNetMidBlock3D(nn.Module):
|
||||
norm_layer=norm_layer,
|
||||
inject_noise=inject_noise,
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.FloatTensor, causal: bool = True, timestep: Optional[torch.Tensor] = None
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
causal: bool = True,
|
||||
timestep: Optional[torch.Tensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
timestep_embed = None
|
||||
if self.timestep_conditioning:
|
||||
@@ -507,9 +590,62 @@ class UNetMidBlock3D(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SpaceToDepthDownsample(nn.Module):
|
||||
def __init__(self, dims, in_channels, out_channels, stride, spatial_padding_mode):
|
||||
super().__init__()
|
||||
self.stride = stride
|
||||
self.group_size = in_channels * math.prod(stride) // out_channels
|
||||
self.conv = make_conv_nd(
|
||||
dims=dims,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels // math.prod(stride),
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
causal=True,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
|
||||
def forward(self, x, causal: bool = True):
|
||||
if self.stride[0] == 2:
|
||||
x = torch.cat(
|
||||
[x[:, :, :1, :, :], x], dim=2
|
||||
) # duplicate first frames for padding
|
||||
|
||||
# skip connection
|
||||
x_in = rearrange(
|
||||
x,
|
||||
"b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
|
||||
p1=self.stride[0],
|
||||
p2=self.stride[1],
|
||||
p3=self.stride[2],
|
||||
)
|
||||
x_in = rearrange(x_in, "b (c g) d h w -> b c g d h w", g=self.group_size)
|
||||
x_in = x_in.mean(dim=2)
|
||||
|
||||
# conv
|
||||
x = self.conv(x, causal=causal)
|
||||
x = rearrange(
|
||||
x,
|
||||
"b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
|
||||
p1=self.stride[0],
|
||||
p2=self.stride[1],
|
||||
p3=self.stride[2],
|
||||
)
|
||||
|
||||
x = x + x_in
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class DepthToSpaceUpsample(nn.Module):
|
||||
def __init__(
|
||||
self, dims, in_channels, stride, residual=False, out_channels_reduction_factor=1
|
||||
self,
|
||||
dims,
|
||||
in_channels,
|
||||
stride,
|
||||
residual=False,
|
||||
out_channels_reduction_factor=1,
|
||||
spatial_padding_mode="zeros",
|
||||
):
|
||||
super().__init__()
|
||||
self.stride = stride
|
||||
@@ -523,6 +659,7 @@ class DepthToSpaceUpsample(nn.Module):
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
causal=True,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
self.residual = residual
|
||||
self.out_channels_reduction_factor = out_channels_reduction_factor
|
||||
@@ -558,7 +695,7 @@ class DepthToSpaceUpsample(nn.Module):
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(self, dim, eps, elementwise_affine=True) -> None:
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
|
||||
self.norm = ops.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
|
||||
|
||||
def forward(self, x):
|
||||
x = rearrange(x, "b c d h w -> b d h w c")
|
||||
@@ -591,6 +728,7 @@ class ResnetBlock3D(nn.Module):
|
||||
norm_layer: str = "group_norm",
|
||||
inject_noise: bool = False,
|
||||
timestep_conditioning: bool = False,
|
||||
spatial_padding_mode: str = "zeros",
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
@@ -617,6 +755,7 @@ class ResnetBlock3D(nn.Module):
|
||||
stride=1,
|
||||
padding=1,
|
||||
causal=True,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
|
||||
if inject_noise:
|
||||
@@ -641,6 +780,7 @@ class ResnetBlock3D(nn.Module):
|
||||
stride=1,
|
||||
padding=1,
|
||||
causal=True,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
|
||||
if inject_noise:
|
||||
@@ -801,9 +941,44 @@ class processor(nn.Module):
|
||||
return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)
|
||||
|
||||
class VideoVAE(nn.Module):
|
||||
def __init__(self, version=0):
|
||||
def __init__(self, version=0, config=None):
|
||||
super().__init__()
|
||||
|
||||
if config is None:
|
||||
config = self.guess_config(version)
|
||||
|
||||
self.timestep_conditioning = config.get("timestep_conditioning", False)
|
||||
double_z = config.get("double_z", True)
|
||||
latent_log_var = config.get(
|
||||
"latent_log_var", "per_channel" if double_z else "none"
|
||||
)
|
||||
|
||||
self.encoder = Encoder(
|
||||
dims=config["dims"],
|
||||
in_channels=config.get("in_channels", 3),
|
||||
out_channels=config["latent_channels"],
|
||||
blocks=config.get("encoder_blocks", config.get("encoder_blocks", config.get("blocks"))),
|
||||
patch_size=config.get("patch_size", 1),
|
||||
latent_log_var=latent_log_var,
|
||||
norm_layer=config.get("norm_layer", "group_norm"),
|
||||
spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
|
||||
)
|
||||
|
||||
self.decoder = Decoder(
|
||||
dims=config["dims"],
|
||||
in_channels=config["latent_channels"],
|
||||
out_channels=config.get("out_channels", 3),
|
||||
blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))),
|
||||
patch_size=config.get("patch_size", 1),
|
||||
norm_layer=config.get("norm_layer", "group_norm"),
|
||||
causal=config.get("causal_decoder", False),
|
||||
timestep_conditioning=self.timestep_conditioning,
|
||||
spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
|
||||
)
|
||||
|
||||
self.per_channel_statistics = processor()
|
||||
|
||||
def guess_config(self, version):
|
||||
if version == 0:
|
||||
config = {
|
||||
"_class_name": "CausalVideoAutoencoder",
|
||||
@@ -830,7 +1005,7 @@ class VideoVAE(nn.Module):
|
||||
"use_quant_conv": False,
|
||||
"causal_decoder": False,
|
||||
}
|
||||
else:
|
||||
elif version == 1:
|
||||
config = {
|
||||
"_class_name": "CausalVideoAutoencoder",
|
||||
"dims": 3,
|
||||
@@ -866,37 +1041,47 @@ class VideoVAE(nn.Module):
|
||||
"causal_decoder": False,
|
||||
"timestep_conditioning": True,
|
||||
}
|
||||
|
||||
double_z = config.get("double_z", True)
|
||||
latent_log_var = config.get(
|
||||
"latent_log_var", "per_channel" if double_z else "none"
|
||||
)
|
||||
|
||||
self.encoder = Encoder(
|
||||
dims=config["dims"],
|
||||
in_channels=config.get("in_channels", 3),
|
||||
out_channels=config["latent_channels"],
|
||||
blocks=config.get("encoder_blocks", config.get("encoder_blocks", config.get("blocks"))),
|
||||
patch_size=config.get("patch_size", 1),
|
||||
latent_log_var=latent_log_var,
|
||||
norm_layer=config.get("norm_layer", "group_norm"),
|
||||
)
|
||||
|
||||
self.decoder = Decoder(
|
||||
dims=config["dims"],
|
||||
in_channels=config["latent_channels"],
|
||||
out_channels=config.get("out_channels", 3),
|
||||
blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))),
|
||||
patch_size=config.get("patch_size", 1),
|
||||
norm_layer=config.get("norm_layer", "group_norm"),
|
||||
causal=config.get("causal_decoder", False),
|
||||
timestep_conditioning=config.get("timestep_conditioning", False),
|
||||
)
|
||||
|
||||
self.timestep_conditioning = config.get("timestep_conditioning", False)
|
||||
self.per_channel_statistics = processor()
|
||||
else:
|
||||
config = {
|
||||
"_class_name": "CausalVideoAutoencoder",
|
||||
"dims": 3,
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 128,
|
||||
"encoder_blocks": [
|
||||
["res_x", {"num_layers": 4}],
|
||||
["compress_space_res", {"multiplier": 2}],
|
||||
["res_x", {"num_layers": 6}],
|
||||
["compress_time_res", {"multiplier": 2}],
|
||||
["res_x", {"num_layers": 6}],
|
||||
["compress_all_res", {"multiplier": 2}],
|
||||
["res_x", {"num_layers": 2}],
|
||||
["compress_all_res", {"multiplier": 2}],
|
||||
["res_x", {"num_layers": 2}]
|
||||
],
|
||||
"decoder_blocks": [
|
||||
["res_x", {"num_layers": 5, "inject_noise": False}],
|
||||
["compress_all", {"residual": True, "multiplier": 2}],
|
||||
["res_x", {"num_layers": 5, "inject_noise": False}],
|
||||
["compress_all", {"residual": True, "multiplier": 2}],
|
||||
["res_x", {"num_layers": 5, "inject_noise": False}],
|
||||
["compress_all", {"residual": True, "multiplier": 2}],
|
||||
["res_x", {"num_layers": 5, "inject_noise": False}]
|
||||
],
|
||||
"scaling_factor": 1.0,
|
||||
"norm_layer": "pixel_norm",
|
||||
"patch_size": 4,
|
||||
"latent_log_var": "uniform",
|
||||
"use_quant_conv": False,
|
||||
"causal_decoder": False,
|
||||
"timestep_conditioning": True
|
||||
}
|
||||
return config
|
||||
|
||||
def encode(self, x):
|
||||
frames_count = x.shape[2]
|
||||
if ((frames_count - 1) % 8) != 0:
|
||||
raise ValueError("Invalid number of frames: Encode input must have 1 + 8 * x frames (e.g., 1, 9, 17, ...). Please check your input.")
|
||||
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
|
||||
return self.per_channel_statistics.normalize(means)
|
||||
|
||||
|
||||
@@ -17,7 +17,11 @@ def make_conv_nd(
|
||||
groups=1,
|
||||
bias=True,
|
||||
causal=False,
|
||||
spatial_padding_mode="zeros",
|
||||
temporal_padding_mode="zeros",
|
||||
):
|
||||
if not (spatial_padding_mode == temporal_padding_mode or causal):
|
||||
raise NotImplementedError("spatial and temporal padding modes must be equal")
|
||||
if dims == 2:
|
||||
return ops.Conv2d(
|
||||
in_channels=in_channels,
|
||||
@@ -28,6 +32,7 @@ def make_conv_nd(
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=bias,
|
||||
padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif dims == 3:
|
||||
if causal:
|
||||
@@ -40,6 +45,7 @@ def make_conv_nd(
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=bias,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
return ops.Conv3d(
|
||||
in_channels=in_channels,
|
||||
@@ -50,6 +56,7 @@ def make_conv_nd(
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=bias,
|
||||
padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif dims == (2, 1):
|
||||
return DualConv3d(
|
||||
@@ -59,6 +66,7 @@ def make_conv_nd(
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
bias=bias,
|
||||
padding_mode=spatial_padding_mode,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
||||
@@ -18,11 +18,13 @@ class DualConv3d(nn.Module):
|
||||
dilation: Union[int, Tuple[int, int, int]] = 1,
|
||||
groups=1,
|
||||
bias=True,
|
||||
padding_mode="zeros",
|
||||
):
|
||||
super(DualConv3d, self).__init__()
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.padding_mode = padding_mode
|
||||
# Ensure kernel_size, stride, padding, and dilation are tuples of length 3
|
||||
if isinstance(kernel_size, int):
|
||||
kernel_size = (kernel_size, kernel_size, kernel_size)
|
||||
@@ -108,6 +110,7 @@ class DualConv3d(nn.Module):
|
||||
self.padding1,
|
||||
self.dilation1,
|
||||
self.groups,
|
||||
padding_mode=self.padding_mode,
|
||||
)
|
||||
|
||||
if skip_time_conv:
|
||||
@@ -122,6 +125,7 @@ class DualConv3d(nn.Module):
|
||||
self.padding2,
|
||||
self.dilation2,
|
||||
self.groups,
|
||||
padding_mode=self.padding_mode,
|
||||
)
|
||||
|
||||
return x
|
||||
@@ -137,7 +141,16 @@ class DualConv3d(nn.Module):
|
||||
stride1 = (self.stride1[1], self.stride1[2])
|
||||
padding1 = (self.padding1[1], self.padding1[2])
|
||||
dilation1 = (self.dilation1[1], self.dilation1[2])
|
||||
x = F.conv2d(x, weight1, self.bias1, stride1, padding1, dilation1, self.groups)
|
||||
x = F.conv2d(
|
||||
x,
|
||||
weight1,
|
||||
self.bias1,
|
||||
stride1,
|
||||
padding1,
|
||||
dilation1,
|
||||
self.groups,
|
||||
padding_mode=self.padding_mode,
|
||||
)
|
||||
|
||||
_, _, h, w = x.shape
|
||||
|
||||
@@ -154,7 +167,16 @@ class DualConv3d(nn.Module):
|
||||
stride2 = self.stride2[0]
|
||||
padding2 = self.padding2[0]
|
||||
dilation2 = self.dilation2[0]
|
||||
x = F.conv1d(x, weight2, self.bias2, stride2, padding2, dilation2, self.groups)
|
||||
x = F.conv1d(
|
||||
x,
|
||||
weight2,
|
||||
self.bias2,
|
||||
stride2,
|
||||
padding2,
|
||||
dilation2,
|
||||
self.groups,
|
||||
padding_mode=self.padding_mode,
|
||||
)
|
||||
x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w)
|
||||
|
||||
return x
|
||||
|
||||
622
comfy/ldm/lumina/model.py
Normal file
622
comfy/ldm/lumina/model.py
Normal file
@@ -0,0 +1,622 @@
|
||||
# Code from: https://github.com/Alpha-VLLM/Lumina-Image-2.0/blob/main/models/model.py
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, RMSNorm
|
||||
from comfy.ldm.modules.attention import optimized_attention_masked
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
|
||||
|
||||
def modulate(x, scale):
|
||||
return x * (1 + scale.unsqueeze(1))
|
||||
|
||||
#############################################################################
|
||||
# Core NextDiT Model #
|
||||
#############################################################################
|
||||
|
||||
|
||||
class JointAttention(nn.Module):
|
||||
"""Multi-head attention module."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
n_heads: int,
|
||||
n_kv_heads: Optional[int],
|
||||
qk_norm: bool,
|
||||
operation_settings={},
|
||||
):
|
||||
"""
|
||||
Initialize the Attention module.
|
||||
|
||||
Args:
|
||||
dim (int): Number of input dimensions.
|
||||
n_heads (int): Number of heads.
|
||||
n_kv_heads (Optional[int]): Number of kv heads, if using GQA.
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads
|
||||
self.n_local_heads = n_heads
|
||||
self.n_local_kv_heads = self.n_kv_heads
|
||||
self.n_rep = self.n_local_heads // self.n_local_kv_heads
|
||||
self.head_dim = dim // n_heads
|
||||
|
||||
self.qkv = operation_settings.get("operations").Linear(
|
||||
dim,
|
||||
(n_heads + self.n_kv_heads + self.n_kv_heads) * self.head_dim,
|
||||
bias=False,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
)
|
||||
self.out = operation_settings.get("operations").Linear(
|
||||
n_heads * self.head_dim,
|
||||
dim,
|
||||
bias=False,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
)
|
||||
|
||||
if qk_norm:
|
||||
self.q_norm = RMSNorm(self.head_dim, elementwise_affine=True, **operation_settings)
|
||||
self.k_norm = RMSNorm(self.head_dim, elementwise_affine=True, **operation_settings)
|
||||
else:
|
||||
self.q_norm = self.k_norm = nn.Identity()
|
||||
|
||||
@staticmethod
|
||||
def apply_rotary_emb(
|
||||
x_in: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Apply rotary embeddings to input tensors using the given frequency
|
||||
tensor.
|
||||
|
||||
This function applies rotary embeddings to the given query 'xq' and
|
||||
key 'xk' tensors using the provided frequency tensor 'freqs_cis'. The
|
||||
input tensors are reshaped as complex numbers, and the frequency tensor
|
||||
is reshaped for broadcasting compatibility. The resulting tensors
|
||||
contain rotary embeddings and are returned as real tensors.
|
||||
|
||||
Args:
|
||||
x_in (torch.Tensor): Query or Key tensor to apply rotary embeddings.
|
||||
freqs_cis (torch.Tensor): Precomputed frequency tensor for complex
|
||||
exponentials.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor
|
||||
and key tensor with rotary embeddings.
|
||||
"""
|
||||
|
||||
t_ = x_in.reshape(*x_in.shape[:-1], -1, 1, 2)
|
||||
t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
|
||||
return t_out.reshape(*x_in.shape)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_mask: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
|
||||
Args:
|
||||
x:
|
||||
x_mask:
|
||||
freqs_cis:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
bsz, seqlen, _ = x.shape
|
||||
|
||||
xq, xk, xv = torch.split(
|
||||
self.qkv(x),
|
||||
[
|
||||
self.n_local_heads * self.head_dim,
|
||||
self.n_local_kv_heads * self.head_dim,
|
||||
self.n_local_kv_heads * self.head_dim,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
|
||||
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
|
||||
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
|
||||
|
||||
xq = self.q_norm(xq)
|
||||
xk = self.k_norm(xk)
|
||||
|
||||
xq = JointAttention.apply_rotary_emb(xq, freqs_cis=freqs_cis)
|
||||
xk = JointAttention.apply_rotary_emb(xk, freqs_cis=freqs_cis)
|
||||
|
||||
n_rep = self.n_local_heads // self.n_local_kv_heads
|
||||
if n_rep >= 1:
|
||||
xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
||||
xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
||||
output = optimized_attention_masked(xq.movedim(1, 2), xk.movedim(1, 2), xv.movedim(1, 2), self.n_local_heads, x_mask, skip_reshape=True)
|
||||
|
||||
return self.out(output)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
multiple_of: int,
|
||||
ffn_dim_multiplier: Optional[float],
|
||||
operation_settings={},
|
||||
):
|
||||
"""
|
||||
Initialize the FeedForward module.
|
||||
|
||||
Args:
|
||||
dim (int): Input dimension.
|
||||
hidden_dim (int): Hidden dimension of the feedforward layer.
|
||||
multiple_of (int): Value to ensure hidden dimension is a multiple
|
||||
of this value.
|
||||
ffn_dim_multiplier (float, optional): Custom multiplier for hidden
|
||||
dimension. Defaults to None.
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
# custom dim factor multiplier
|
||||
if ffn_dim_multiplier is not None:
|
||||
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
|
||||
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
||||
|
||||
self.w1 = operation_settings.get("operations").Linear(
|
||||
dim,
|
||||
hidden_dim,
|
||||
bias=False,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
)
|
||||
self.w2 = operation_settings.get("operations").Linear(
|
||||
hidden_dim,
|
||||
dim,
|
||||
bias=False,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
)
|
||||
self.w3 = operation_settings.get("operations").Linear(
|
||||
dim,
|
||||
hidden_dim,
|
||||
bias=False,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
)
|
||||
|
||||
# @torch.compile
|
||||
def _forward_silu_gating(self, x1, x3):
|
||||
return F.silu(x1) * x3
|
||||
|
||||
def forward(self, x):
|
||||
return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
|
||||
|
||||
|
||||
class JointTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
layer_id: int,
|
||||
dim: int,
|
||||
n_heads: int,
|
||||
n_kv_heads: int,
|
||||
multiple_of: int,
|
||||
ffn_dim_multiplier: float,
|
||||
norm_eps: float,
|
||||
qk_norm: bool,
|
||||
modulation=True,
|
||||
operation_settings={},
|
||||
) -> None:
|
||||
"""
|
||||
Initialize a TransformerBlock.
|
||||
|
||||
Args:
|
||||
layer_id (int): Identifier for the layer.
|
||||
dim (int): Embedding dimension of the input features.
|
||||
n_heads (int): Number of attention heads.
|
||||
n_kv_heads (Optional[int]): Number of attention heads in key and
|
||||
value features (if using GQA), or set to None for the same as
|
||||
query.
|
||||
multiple_of (int):
|
||||
ffn_dim_multiplier (float):
|
||||
norm_eps (float):
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.head_dim = dim // n_heads
|
||||
self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, operation_settings=operation_settings)
|
||||
self.feed_forward = FeedForward(
|
||||
dim=dim,
|
||||
hidden_dim=4 * dim,
|
||||
multiple_of=multiple_of,
|
||||
ffn_dim_multiplier=ffn_dim_multiplier,
|
||||
operation_settings=operation_settings,
|
||||
)
|
||||
self.layer_id = layer_id
|
||||
self.attention_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
|
||||
self.ffn_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
|
||||
|
||||
self.attention_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
|
||||
self.ffn_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
|
||||
|
||||
self.modulation = modulation
|
||||
if modulation:
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operation_settings.get("operations").Linear(
|
||||
min(dim, 1024),
|
||||
4 * dim,
|
||||
bias=True,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
),
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_mask: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
adaln_input: Optional[torch.Tensor]=None,
|
||||
):
|
||||
"""
|
||||
Perform a forward pass through the TransformerBlock.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor.
|
||||
freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor after applying attention and
|
||||
feedforward layers.
|
||||
|
||||
"""
|
||||
if self.modulation:
|
||||
assert adaln_input is not None
|
||||
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
|
||||
|
||||
x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2(
|
||||
self.attention(
|
||||
modulate(self.attention_norm1(x), scale_msa),
|
||||
x_mask,
|
||||
freqs_cis,
|
||||
)
|
||||
)
|
||||
x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
|
||||
self.feed_forward(
|
||||
modulate(self.ffn_norm1(x), scale_mlp),
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert adaln_input is None
|
||||
x = x + self.attention_norm2(
|
||||
self.attention(
|
||||
self.attention_norm1(x),
|
||||
x_mask,
|
||||
freqs_cis,
|
||||
)
|
||||
)
|
||||
x = x + self.ffn_norm2(
|
||||
self.feed_forward(
|
||||
self.ffn_norm1(x),
|
||||
)
|
||||
)
|
||||
return x
|
||||
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
"""
|
||||
The final layer of NextDiT.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, patch_size, out_channels, operation_settings={}):
|
||||
super().__init__()
|
||||
self.norm_final = operation_settings.get("operations").LayerNorm(
|
||||
hidden_size,
|
||||
elementwise_affine=False,
|
||||
eps=1e-6,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
)
|
||||
self.linear = operation_settings.get("operations").Linear(
|
||||
hidden_size,
|
||||
patch_size * patch_size * out_channels,
|
||||
bias=True,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
)
|
||||
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operation_settings.get("operations").Linear(
|
||||
min(hidden_size, 1024),
|
||||
hidden_size,
|
||||
bias=True,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
),
|
||||
)
|
||||
|
||||
def forward(self, x, c):
|
||||
scale = self.adaLN_modulation(c)
|
||||
x = modulate(self.norm_final(x), scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
class NextDiT(nn.Module):
|
||||
"""
|
||||
Diffusion model with a Transformer backbone.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 2,
|
||||
in_channels: int = 4,
|
||||
dim: int = 4096,
|
||||
n_layers: int = 32,
|
||||
n_refiner_layers: int = 2,
|
||||
n_heads: int = 32,
|
||||
n_kv_heads: Optional[int] = None,
|
||||
multiple_of: int = 256,
|
||||
ffn_dim_multiplier: Optional[float] = None,
|
||||
norm_eps: float = 1e-5,
|
||||
qk_norm: bool = False,
|
||||
cap_feat_dim: int = 5120,
|
||||
axes_dims: List[int] = (16, 56, 56),
|
||||
axes_lens: List[int] = (1, 512, 512),
|
||||
image_model=None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels
|
||||
self.patch_size = patch_size
|
||||
|
||||
self.x_embedder = operation_settings.get("operations").Linear(
|
||||
in_features=patch_size * patch_size * in_channels,
|
||||
out_features=dim,
|
||||
bias=True,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
)
|
||||
|
||||
self.noise_refiner = nn.ModuleList(
|
||||
[
|
||||
JointTransformerBlock(
|
||||
layer_id,
|
||||
dim,
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
multiple_of,
|
||||
ffn_dim_multiplier,
|
||||
norm_eps,
|
||||
qk_norm,
|
||||
modulation=True,
|
||||
operation_settings=operation_settings,
|
||||
)
|
||||
for layer_id in range(n_refiner_layers)
|
||||
]
|
||||
)
|
||||
self.context_refiner = nn.ModuleList(
|
||||
[
|
||||
JointTransformerBlock(
|
||||
layer_id,
|
||||
dim,
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
multiple_of,
|
||||
ffn_dim_multiplier,
|
||||
norm_eps,
|
||||
qk_norm,
|
||||
modulation=False,
|
||||
operation_settings=operation_settings,
|
||||
)
|
||||
for layer_id in range(n_refiner_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.t_embedder = TimestepEmbedder(min(dim, 1024), **operation_settings)
|
||||
self.cap_embedder = nn.Sequential(
|
||||
RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, **operation_settings),
|
||||
operation_settings.get("operations").Linear(
|
||||
cap_feat_dim,
|
||||
dim,
|
||||
bias=True,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
),
|
||||
)
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
JointTransformerBlock(
|
||||
layer_id,
|
||||
dim,
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
multiple_of,
|
||||
ffn_dim_multiplier,
|
||||
norm_eps,
|
||||
qk_norm,
|
||||
operation_settings=operation_settings,
|
||||
)
|
||||
for layer_id in range(n_layers)
|
||||
]
|
||||
)
|
||||
self.norm_final = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
|
||||
self.final_layer = FinalLayer(dim, patch_size, self.out_channels, operation_settings=operation_settings)
|
||||
|
||||
assert (dim // n_heads) == sum(axes_dims)
|
||||
self.axes_dims = axes_dims
|
||||
self.axes_lens = axes_lens
|
||||
self.rope_embedder = EmbedND(dim=dim // n_heads, theta=10000.0, axes_dim=axes_dims)
|
||||
self.dim = dim
|
||||
self.n_heads = n_heads
|
||||
|
||||
def unpatchify(
|
||||
self, x: torch.Tensor, img_size: List[Tuple[int, int]], cap_size: List[int], return_tensor=False
|
||||
) -> List[torch.Tensor]:
|
||||
"""
|
||||
x: (N, T, patch_size**2 * C)
|
||||
imgs: (N, H, W, C)
|
||||
"""
|
||||
pH = pW = self.patch_size
|
||||
imgs = []
|
||||
for i in range(x.size(0)):
|
||||
H, W = img_size[i]
|
||||
begin = cap_size[i]
|
||||
end = begin + (H // pH) * (W // pW)
|
||||
imgs.append(
|
||||
x[i][begin:end]
|
||||
.view(H // pH, W // pW, pH, pW, self.out_channels)
|
||||
.permute(4, 0, 2, 1, 3)
|
||||
.flatten(3, 4)
|
||||
.flatten(1, 2)
|
||||
)
|
||||
|
||||
if return_tensor:
|
||||
imgs = torch.stack(imgs, dim=0)
|
||||
return imgs
|
||||
|
||||
def patchify_and_embed(
|
||||
self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]:
|
||||
bsz = len(x)
|
||||
pH = pW = self.patch_size
|
||||
device = x[0].device
|
||||
dtype = x[0].dtype
|
||||
|
||||
if cap_mask is not None:
|
||||
l_effective_cap_len = cap_mask.sum(dim=1).tolist()
|
||||
else:
|
||||
l_effective_cap_len = [num_tokens] * bsz
|
||||
|
||||
if cap_mask is not None and not torch.is_floating_point(cap_mask):
|
||||
cap_mask = (cap_mask - 1).to(dtype) * torch.finfo(dtype).max
|
||||
|
||||
img_sizes = [(img.size(1), img.size(2)) for img in x]
|
||||
l_effective_img_len = [(H // pH) * (W // pW) for (H, W) in img_sizes]
|
||||
|
||||
max_seq_len = max(
|
||||
(cap_len+img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len))
|
||||
)
|
||||
max_cap_len = max(l_effective_cap_len)
|
||||
max_img_len = max(l_effective_img_len)
|
||||
|
||||
position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.int32, device=device)
|
||||
|
||||
for i in range(bsz):
|
||||
cap_len = l_effective_cap_len[i]
|
||||
img_len = l_effective_img_len[i]
|
||||
H, W = img_sizes[i]
|
||||
H_tokens, W_tokens = H // pH, W // pW
|
||||
assert H_tokens * W_tokens == img_len
|
||||
|
||||
position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device)
|
||||
position_ids[i, cap_len:cap_len+img_len, 0] = cap_len
|
||||
row_ids = torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
|
||||
col_ids = torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
|
||||
position_ids[i, cap_len:cap_len+img_len, 1] = row_ids
|
||||
position_ids[i, cap_len:cap_len+img_len, 2] = col_ids
|
||||
|
||||
freqs_cis = self.rope_embedder(position_ids).movedim(1, 2).to(dtype)
|
||||
|
||||
# build freqs_cis for cap and image individually
|
||||
cap_freqs_cis_shape = list(freqs_cis.shape)
|
||||
# cap_freqs_cis_shape[1] = max_cap_len
|
||||
cap_freqs_cis_shape[1] = cap_feats.shape[1]
|
||||
cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
|
||||
|
||||
img_freqs_cis_shape = list(freqs_cis.shape)
|
||||
img_freqs_cis_shape[1] = max_img_len
|
||||
img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
|
||||
|
||||
for i in range(bsz):
|
||||
cap_len = l_effective_cap_len[i]
|
||||
img_len = l_effective_img_len[i]
|
||||
cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
|
||||
img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len:cap_len+img_len]
|
||||
|
||||
# refine context
|
||||
for layer in self.context_refiner:
|
||||
cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis)
|
||||
|
||||
# refine image
|
||||
flat_x = []
|
||||
for i in range(bsz):
|
||||
img = x[i]
|
||||
C, H, W = img.size()
|
||||
img = img.view(C, H // pH, pH, W // pW, pW).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1)
|
||||
flat_x.append(img)
|
||||
x = flat_x
|
||||
padded_img_embed = torch.zeros(bsz, max_img_len, x[0].shape[-1], device=device, dtype=x[0].dtype)
|
||||
padded_img_mask = torch.zeros(bsz, max_img_len, dtype=dtype, device=device)
|
||||
for i in range(bsz):
|
||||
padded_img_embed[i, :l_effective_img_len[i]] = x[i]
|
||||
padded_img_mask[i, l_effective_img_len[i]:] = -torch.finfo(dtype).max
|
||||
|
||||
padded_img_embed = self.x_embedder(padded_img_embed)
|
||||
padded_img_mask = padded_img_mask.unsqueeze(1)
|
||||
for layer in self.noise_refiner:
|
||||
padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t)
|
||||
|
||||
if cap_mask is not None:
|
||||
mask = torch.zeros(bsz, max_seq_len, dtype=dtype, device=device)
|
||||
mask[:, :max_cap_len] = cap_mask[:, :max_cap_len]
|
||||
else:
|
||||
mask = None
|
||||
|
||||
padded_full_embed = torch.zeros(bsz, max_seq_len, self.dim, device=device, dtype=x[0].dtype)
|
||||
for i in range(bsz):
|
||||
cap_len = l_effective_cap_len[i]
|
||||
img_len = l_effective_img_len[i]
|
||||
|
||||
padded_full_embed[i, :cap_len] = cap_feats[i, :cap_len]
|
||||
padded_full_embed[i, cap_len:cap_len+img_len] = padded_img_embed[i, :img_len]
|
||||
|
||||
return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis
|
||||
|
||||
# def forward(self, x, t, cap_feats, cap_mask):
|
||||
def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
|
||||
t = 1.0 - timesteps
|
||||
cap_feats = context
|
||||
cap_mask = attention_mask
|
||||
bs, c, h, w = x.shape
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
|
||||
"""
|
||||
Forward pass of NextDiT.
|
||||
t: (N,) tensor of diffusion timesteps
|
||||
y: (N,) tensor of text tokens/features
|
||||
"""
|
||||
|
||||
t = self.t_embedder(t, dtype=x.dtype) # (N, D)
|
||||
adaln_input = t
|
||||
|
||||
cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute
|
||||
|
||||
x_is_tensor = isinstance(x, torch.Tensor)
|
||||
x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens)
|
||||
freqs_cis = freqs_cis.to(x.device)
|
||||
|
||||
for layer in self.layers:
|
||||
x = layer(x, mask, freqs_cis, adaln_input)
|
||||
|
||||
x = self.final_layer(x, adaln_input)
|
||||
x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:,:,:h,:w]
|
||||
|
||||
return -x
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import math
|
||||
import sys
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn, einsum
|
||||
@@ -16,7 +18,18 @@ if model_management.xformers_enabled():
|
||||
import xformers.ops
|
||||
|
||||
if model_management.sage_attention_enabled():
|
||||
from sageattention import sageattn
|
||||
try:
|
||||
from sageattention import sageattn
|
||||
except ModuleNotFoundError:
|
||||
logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
|
||||
exit(-1)
|
||||
|
||||
if model_management.flash_attention_enabled():
|
||||
try:
|
||||
from flash_attn import flash_attn_func
|
||||
except ModuleNotFoundError:
|
||||
logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
|
||||
exit(-1)
|
||||
|
||||
from comfy.cli_args import args
|
||||
import comfy.ops
|
||||
@@ -24,38 +37,24 @@ ops = comfy.ops.disable_weight_init
|
||||
|
||||
FORCE_UPCAST_ATTENTION_DTYPE = model_management.force_upcast_attention_dtype()
|
||||
|
||||
def get_attn_precision(attn_precision):
|
||||
def get_attn_precision(attn_precision, current_dtype):
|
||||
if args.dont_upcast_attention:
|
||||
return None
|
||||
if FORCE_UPCAST_ATTENTION_DTYPE is not None:
|
||||
return FORCE_UPCAST_ATTENTION_DTYPE
|
||||
|
||||
if FORCE_UPCAST_ATTENTION_DTYPE is not None and current_dtype in FORCE_UPCAST_ATTENTION_DTYPE:
|
||||
return FORCE_UPCAST_ATTENTION_DTYPE[current_dtype]
|
||||
return attn_precision
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
|
||||
def uniq(arr):
|
||||
return{el: True for el in arr}.keys()
|
||||
|
||||
|
||||
def default(val, d):
|
||||
if exists(val):
|
||||
return val
|
||||
return d
|
||||
|
||||
|
||||
def max_neg_value(t):
|
||||
return -torch.finfo(t.dtype).max
|
||||
|
||||
|
||||
def init_(tensor):
|
||||
dim = tensor.shape[-1]
|
||||
std = 1 / math.sqrt(dim)
|
||||
tensor.uniform_(-std, std)
|
||||
return tensor
|
||||
|
||||
|
||||
# feedforward
|
||||
class GEGLU(nn.Module):
|
||||
def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=ops):
|
||||
@@ -90,7 +89,7 @@ def Normalize(in_channels, dtype=None, device=None):
|
||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
|
||||
|
||||
def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||
attn_precision = get_attn_precision(attn_precision)
|
||||
attn_precision = get_attn_precision(attn_precision, q.dtype)
|
||||
|
||||
if skip_reshape:
|
||||
b, _, _, dim_head = q.shape
|
||||
@@ -159,7 +158,7 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
||||
|
||||
|
||||
def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||
attn_precision = get_attn_precision(attn_precision)
|
||||
attn_precision = get_attn_precision(attn_precision, query.dtype)
|
||||
|
||||
if skip_reshape:
|
||||
b, _, _, dim_head = query.shape
|
||||
@@ -229,7 +228,7 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
|
||||
return hidden_states
|
||||
|
||||
def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||
attn_precision = get_attn_precision(attn_precision)
|
||||
attn_precision = get_attn_precision(attn_precision, q.dtype)
|
||||
|
||||
if skip_reshape:
|
||||
b, _, _, dim_head = q.shape
|
||||
@@ -472,7 +471,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
||||
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||
if skip_reshape:
|
||||
b, _, _, dim_head = q.shape
|
||||
tensor_layout="HND"
|
||||
tensor_layout = "HND"
|
||||
else:
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
@@ -480,7 +479,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
|
||||
lambda t: t.view(b, -1, heads, dim_head),
|
||||
(q, k, v),
|
||||
)
|
||||
tensor_layout="NHD"
|
||||
tensor_layout = "NHD"
|
||||
|
||||
if mask is not None:
|
||||
# add a batch dimension if there isn't already one
|
||||
@@ -490,7 +489,17 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
|
||||
if mask.ndim == 3:
|
||||
mask = mask.unsqueeze(1)
|
||||
|
||||
out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
|
||||
try:
|
||||
out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
|
||||
except Exception as e:
|
||||
logging.error("Error running sage attention: {}, using pytorch attention instead.".format(e))
|
||||
if tensor_layout == "NHD":
|
||||
q, k, v = map(
|
||||
lambda t: t.transpose(1, 2),
|
||||
(q, k, v),
|
||||
)
|
||||
return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape)
|
||||
|
||||
if tensor_layout == "HND":
|
||||
if not skip_output_reshape:
|
||||
out = (
|
||||
@@ -504,6 +513,63 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
|
||||
return out
|
||||
|
||||
|
||||
try:
|
||||
@torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
|
||||
def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
|
||||
return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal)
|
||||
|
||||
|
||||
@flash_attn_wrapper.register_fake
|
||||
def flash_attn_fake(q, k, v, dropout_p=0.0, causal=False):
|
||||
# Output shape is the same as q
|
||||
return q.new_empty(q.shape)
|
||||
except AttributeError as error:
|
||||
FLASH_ATTN_ERROR = error
|
||||
|
||||
def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
|
||||
assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}"
|
||||
|
||||
|
||||
def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||
if skip_reshape:
|
||||
b, _, _, dim_head = q.shape
|
||||
else:
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
q, k, v = map(
|
||||
lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
|
||||
(q, k, v),
|
||||
)
|
||||
|
||||
if mask is not None:
|
||||
# add a batch dimension if there isn't already one
|
||||
if mask.ndim == 2:
|
||||
mask = mask.unsqueeze(0)
|
||||
# add a heads dimension if there isn't already one
|
||||
if mask.ndim == 3:
|
||||
mask = mask.unsqueeze(1)
|
||||
|
||||
try:
|
||||
assert mask is None
|
||||
out = flash_attn_wrapper(
|
||||
q.transpose(1, 2),
|
||||
k.transpose(1, 2),
|
||||
v.transpose(1, 2),
|
||||
dropout_p=0.0,
|
||||
causal=False,
|
||||
).transpose(1, 2)
|
||||
except Exception as e:
|
||||
logging.warning(f"Flash Attention failed, using default SDPA: {e}")
|
||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||
if not skip_output_reshape:
|
||||
out = (
|
||||
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
optimized_attention = attention_basic
|
||||
|
||||
if model_management.sage_attention_enabled():
|
||||
@@ -512,6 +578,9 @@ if model_management.sage_attention_enabled():
|
||||
elif model_management.xformers_enabled():
|
||||
logging.info("Using xformers attention")
|
||||
optimized_attention = attention_xformers
|
||||
elif model_management.flash_attention_enabled():
|
||||
logging.info("Using Flash Attention")
|
||||
optimized_attention = attention_flash
|
||||
elif model_management.pytorch_attention_enabled():
|
||||
logging.info("Using pytorch attention")
|
||||
optimized_attention = attention_pytorch
|
||||
@@ -778,6 +847,7 @@ class SpatialTransformer(nn.Module):
|
||||
if not isinstance(context, list):
|
||||
context = [context] * len(self.transformer_blocks)
|
||||
b, c, h, w = x.shape
|
||||
transformer_options["activations_shape"] = list(x.shape)
|
||||
x_in = x
|
||||
x = self.norm(x)
|
||||
if not self.use_linear:
|
||||
@@ -893,6 +963,7 @@ class SpatialVideoTransformer(SpatialTransformer):
|
||||
transformer_options={}
|
||||
) -> torch.Tensor:
|
||||
_, _, h, w = x.shape
|
||||
transformer_options["activations_shape"] = list(x.shape)
|
||||
x_in = x
|
||||
spatial_context = None
|
||||
if exists(context):
|
||||
|
||||
@@ -321,7 +321,7 @@ class SelfAttention(nn.Module):
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(
|
||||
self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6, device=None, dtype=None
|
||||
self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6, device=None, dtype=None, **kwargs
|
||||
):
|
||||
"""
|
||||
Initialize the RMSNorm normalization layer.
|
||||
|
||||
@@ -297,7 +297,7 @@ def vae_attention():
|
||||
if model_management.xformers_enabled_vae():
|
||||
logging.info("Using xformers attention in VAE")
|
||||
return xformers_attention
|
||||
elif model_management.pytorch_attention_enabled():
|
||||
elif model_management.pytorch_attention_enabled_vae():
|
||||
logging.info("Using pytorch attention in VAE")
|
||||
return pytorch_attention
|
||||
else:
|
||||
|
||||
485
comfy/ldm/wan/model.py
Normal file
485
comfy/ldm/wan/model.py
Normal file
@@ -0,0 +1,485 @@
|
||||
# original version: https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/model.py
|
||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import repeat
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
from comfy.ldm.flux.math import apply_rope
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.model_management
|
||||
|
||||
|
||||
def sinusoidal_embedding_1d(dim, position):
|
||||
# preprocess
|
||||
assert dim % 2 == 0
|
||||
half = dim // 2
|
||||
position = position.type(torch.float32)
|
||||
|
||||
# calculation
|
||||
sinusoid = torch.outer(
|
||||
position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
|
||||
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
|
||||
return x
|
||||
|
||||
|
||||
class WanSelfAttention(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim,
|
||||
num_heads,
|
||||
window_size=(-1, -1),
|
||||
qk_norm=True,
|
||||
eps=1e-6, operation_settings={}):
|
||||
assert dim % num_heads == 0
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.window_size = window_size
|
||||
self.qk_norm = qk_norm
|
||||
self.eps = eps
|
||||
|
||||
# layers
|
||||
self.q = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.k = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.v = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.o = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.norm_q = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
|
||||
self.norm_k = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
|
||||
|
||||
def forward(self, x, freqs):
|
||||
r"""
|
||||
Args:
|
||||
x(Tensor): Shape [B, L, num_heads, C / num_heads]
|
||||
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
|
||||
"""
|
||||
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
|
||||
|
||||
# query, key, value function
|
||||
def qkv_fn(x):
|
||||
q = self.norm_q(self.q(x)).view(b, s, n, d)
|
||||
k = self.norm_k(self.k(x)).view(b, s, n, d)
|
||||
v = self.v(x).view(b, s, n * d)
|
||||
return q, k, v
|
||||
|
||||
q, k, v = qkv_fn(x)
|
||||
q, k = apply_rope(q, k, freqs)
|
||||
|
||||
x = optimized_attention(
|
||||
q.view(b, s, n * d),
|
||||
k.view(b, s, n * d),
|
||||
v,
|
||||
heads=self.num_heads,
|
||||
)
|
||||
|
||||
x = self.o(x)
|
||||
return x
|
||||
|
||||
|
||||
class WanT2VCrossAttention(WanSelfAttention):
|
||||
|
||||
def forward(self, x, context):
|
||||
r"""
|
||||
Args:
|
||||
x(Tensor): Shape [B, L1, C]
|
||||
context(Tensor): Shape [B, L2, C]
|
||||
"""
|
||||
# compute query, key, value
|
||||
q = self.norm_q(self.q(x))
|
||||
k = self.norm_k(self.k(context))
|
||||
v = self.v(context)
|
||||
|
||||
# compute attention
|
||||
x = optimized_attention(q, k, v, heads=self.num_heads)
|
||||
|
||||
x = self.o(x)
|
||||
return x
|
||||
|
||||
|
||||
class WanI2VCrossAttention(WanSelfAttention):
|
||||
|
||||
def __init__(self,
|
||||
dim,
|
||||
num_heads,
|
||||
window_size=(-1, -1),
|
||||
qk_norm=True,
|
||||
eps=1e-6, operation_settings={}):
|
||||
super().__init__(dim, num_heads, window_size, qk_norm, eps, operation_settings=operation_settings)
|
||||
|
||||
self.k_img = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.v_img = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
# self.alpha = nn.Parameter(torch.zeros((1, )))
|
||||
self.norm_k_img = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
|
||||
|
||||
def forward(self, x, context):
|
||||
r"""
|
||||
Args:
|
||||
x(Tensor): Shape [B, L1, C]
|
||||
context(Tensor): Shape [B, L2, C]
|
||||
"""
|
||||
context_img = context[:, :257]
|
||||
context = context[:, 257:]
|
||||
|
||||
# compute query, key, value
|
||||
q = self.norm_q(self.q(x))
|
||||
k = self.norm_k(self.k(context))
|
||||
v = self.v(context)
|
||||
k_img = self.norm_k_img(self.k_img(context_img))
|
||||
v_img = self.v_img(context_img)
|
||||
img_x = optimized_attention(q, k_img, v_img, heads=self.num_heads)
|
||||
# compute attention
|
||||
x = optimized_attention(q, k, v, heads=self.num_heads)
|
||||
|
||||
# output
|
||||
x = x + img_x
|
||||
x = self.o(x)
|
||||
return x
|
||||
|
||||
|
||||
WAN_CROSSATTENTION_CLASSES = {
|
||||
't2v_cross_attn': WanT2VCrossAttention,
|
||||
'i2v_cross_attn': WanI2VCrossAttention,
|
||||
}
|
||||
|
||||
|
||||
class WanAttentionBlock(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
cross_attn_type,
|
||||
dim,
|
||||
ffn_dim,
|
||||
num_heads,
|
||||
window_size=(-1, -1),
|
||||
qk_norm=True,
|
||||
cross_attn_norm=False,
|
||||
eps=1e-6, operation_settings={}):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.ffn_dim = ffn_dim
|
||||
self.num_heads = num_heads
|
||||
self.window_size = window_size
|
||||
self.qk_norm = qk_norm
|
||||
self.cross_attn_norm = cross_attn_norm
|
||||
self.eps = eps
|
||||
|
||||
# layers
|
||||
self.norm1 = operation_settings.get("operations").LayerNorm(dim, eps, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
|
||||
eps, operation_settings=operation_settings)
|
||||
self.norm3 = operation_settings.get("operations").LayerNorm(
|
||||
dim, eps,
|
||||
elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if cross_attn_norm else nn.Identity()
|
||||
self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim,
|
||||
num_heads,
|
||||
(-1, -1),
|
||||
qk_norm,
|
||||
eps, operation_settings=operation_settings)
|
||||
self.norm2 = operation_settings.get("operations").LayerNorm(dim, eps, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.ffn = nn.Sequential(
|
||||
operation_settings.get("operations").Linear(dim, ffn_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), nn.GELU(approximate='tanh'),
|
||||
operation_settings.get("operations").Linear(ffn_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
|
||||
|
||||
# modulation
|
||||
self.modulation = nn.Parameter(torch.empty(1, 6, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
e,
|
||||
freqs,
|
||||
context,
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
x(Tensor): Shape [B, L, C]
|
||||
e(Tensor): Shape [B, 6, C]
|
||||
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
|
||||
"""
|
||||
# assert e.dtype == torch.float32
|
||||
|
||||
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
|
||||
# assert e[0].dtype == torch.float32
|
||||
|
||||
# self-attention
|
||||
y = self.self_attn(
|
||||
self.norm1(x) * (1 + e[1]) + e[0],
|
||||
freqs)
|
||||
|
||||
x = x + y * e[2]
|
||||
|
||||
# cross-attention & ffn
|
||||
x = x + self.cross_attn(self.norm3(x), context)
|
||||
y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3])
|
||||
x = x + y * e[5]
|
||||
return x
|
||||
|
||||
|
||||
class Head(nn.Module):
|
||||
|
||||
def __init__(self, dim, out_dim, patch_size, eps=1e-6, operation_settings={}):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.out_dim = out_dim
|
||||
self.patch_size = patch_size
|
||||
self.eps = eps
|
||||
|
||||
# layers
|
||||
out_dim = math.prod(patch_size) * out_dim
|
||||
self.norm = operation_settings.get("operations").LayerNorm(dim, eps, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.head = operation_settings.get("operations").Linear(dim, out_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
# modulation
|
||||
self.modulation = nn.Parameter(torch.empty(1, 2, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
|
||||
|
||||
def forward(self, x, e):
|
||||
r"""
|
||||
Args:
|
||||
x(Tensor): Shape [B, L1, C]
|
||||
e(Tensor): Shape [B, C]
|
||||
"""
|
||||
# assert e.dtype == torch.float32
|
||||
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e.unsqueeze(1)).chunk(2, dim=1)
|
||||
x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
|
||||
return x
|
||||
|
||||
|
||||
class MLPProj(torch.nn.Module):
|
||||
|
||||
def __init__(self, in_dim, out_dim, operation_settings={}):
|
||||
super().__init__()
|
||||
|
||||
self.proj = torch.nn.Sequential(
|
||||
operation_settings.get("operations").LayerNorm(in_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), operation_settings.get("operations").Linear(in_dim, in_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
|
||||
torch.nn.GELU(), operation_settings.get("operations").Linear(in_dim, out_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
|
||||
operation_settings.get("operations").LayerNorm(out_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
|
||||
|
||||
def forward(self, image_embeds):
|
||||
clip_extra_context_tokens = self.proj(image_embeds)
|
||||
return clip_extra_context_tokens
|
||||
|
||||
|
||||
class WanModel(torch.nn.Module):
|
||||
r"""
|
||||
Wan diffusion backbone supporting both text-to-video and image-to-video.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model_type='t2v',
|
||||
patch_size=(1, 2, 2),
|
||||
text_len=512,
|
||||
in_dim=16,
|
||||
dim=2048,
|
||||
ffn_dim=8192,
|
||||
freq_dim=256,
|
||||
text_dim=4096,
|
||||
out_dim=16,
|
||||
num_heads=16,
|
||||
num_layers=32,
|
||||
window_size=(-1, -1),
|
||||
qk_norm=True,
|
||||
cross_attn_norm=True,
|
||||
eps=1e-6,
|
||||
image_model=None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
):
|
||||
r"""
|
||||
Initialize the diffusion model backbone.
|
||||
|
||||
Args:
|
||||
model_type (`str`, *optional*, defaults to 't2v'):
|
||||
Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
|
||||
patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
|
||||
3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
|
||||
text_len (`int`, *optional*, defaults to 512):
|
||||
Fixed length for text embeddings
|
||||
in_dim (`int`, *optional*, defaults to 16):
|
||||
Input video channels (C_in)
|
||||
dim (`int`, *optional*, defaults to 2048):
|
||||
Hidden dimension of the transformer
|
||||
ffn_dim (`int`, *optional*, defaults to 8192):
|
||||
Intermediate dimension in feed-forward network
|
||||
freq_dim (`int`, *optional*, defaults to 256):
|
||||
Dimension for sinusoidal time embeddings
|
||||
text_dim (`int`, *optional*, defaults to 4096):
|
||||
Input dimension for text embeddings
|
||||
out_dim (`int`, *optional*, defaults to 16):
|
||||
Output video channels (C_out)
|
||||
num_heads (`int`, *optional*, defaults to 16):
|
||||
Number of attention heads
|
||||
num_layers (`int`, *optional*, defaults to 32):
|
||||
Number of transformer blocks
|
||||
window_size (`tuple`, *optional*, defaults to (-1, -1)):
|
||||
Window size for local attention (-1 indicates global attention)
|
||||
qk_norm (`bool`, *optional*, defaults to True):
|
||||
Enable query/key normalization
|
||||
cross_attn_norm (`bool`, *optional*, defaults to False):
|
||||
Enable cross-attention normalization
|
||||
eps (`float`, *optional*, defaults to 1e-6):
|
||||
Epsilon value for normalization layers
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
|
||||
|
||||
assert model_type in ['t2v', 'i2v']
|
||||
self.model_type = model_type
|
||||
|
||||
self.patch_size = patch_size
|
||||
self.text_len = text_len
|
||||
self.in_dim = in_dim
|
||||
self.dim = dim
|
||||
self.ffn_dim = ffn_dim
|
||||
self.freq_dim = freq_dim
|
||||
self.text_dim = text_dim
|
||||
self.out_dim = out_dim
|
||||
self.num_heads = num_heads
|
||||
self.num_layers = num_layers
|
||||
self.window_size = window_size
|
||||
self.qk_norm = qk_norm
|
||||
self.cross_attn_norm = cross_attn_norm
|
||||
self.eps = eps
|
||||
|
||||
# embeddings
|
||||
self.patch_embedding = operations.Conv3d(
|
||||
in_dim, dim, kernel_size=patch_size, stride=patch_size, device=operation_settings.get("device"), dtype=torch.float32)
|
||||
self.text_embedding = nn.Sequential(
|
||||
operations.Linear(text_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), nn.GELU(approximate='tanh'),
|
||||
operations.Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
|
||||
|
||||
self.time_embedding = nn.Sequential(
|
||||
operations.Linear(freq_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), nn.SiLU(), operations.Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
|
||||
self.time_projection = nn.Sequential(nn.SiLU(), operations.Linear(dim, dim * 6, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
|
||||
|
||||
# blocks
|
||||
cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
|
||||
self.blocks = nn.ModuleList([
|
||||
WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
|
||||
window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
# head
|
||||
self.head = Head(dim, out_dim, patch_size, eps, operation_settings=operation_settings)
|
||||
|
||||
d = dim // num_heads
|
||||
self.rope_embedder = EmbedND(dim=d, theta=10000.0, axes_dim=[d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)])
|
||||
|
||||
if model_type == 'i2v':
|
||||
self.img_emb = MLPProj(1280, dim, operation_settings=operation_settings)
|
||||
else:
|
||||
self.img_emb = None
|
||||
|
||||
def forward_orig(
|
||||
self,
|
||||
x,
|
||||
t,
|
||||
context,
|
||||
clip_fea=None,
|
||||
freqs=None,
|
||||
transformer_options={},
|
||||
):
|
||||
r"""
|
||||
Forward pass through the diffusion model
|
||||
|
||||
Args:
|
||||
x (Tensor):
|
||||
List of input video tensors with shape [B, C_in, F, H, W]
|
||||
t (Tensor):
|
||||
Diffusion timesteps tensor of shape [B]
|
||||
context (List[Tensor]):
|
||||
List of text embeddings each with shape [B, L, C]
|
||||
seq_len (`int`):
|
||||
Maximum sequence length for positional encoding
|
||||
clip_fea (Tensor, *optional*):
|
||||
CLIP image features for image-to-video mode
|
||||
y (List[Tensor], *optional*):
|
||||
Conditional video inputs for image-to-video mode, same shape as x
|
||||
|
||||
Returns:
|
||||
List[Tensor]:
|
||||
List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
|
||||
"""
|
||||
# embeddings
|
||||
x = self.patch_embedding(x.float()).to(x.dtype)
|
||||
grid_sizes = x.shape[2:]
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
|
||||
# time embeddings
|
||||
e = self.time_embedding(
|
||||
sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype))
|
||||
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
|
||||
|
||||
# context
|
||||
context = self.text_embedding(context)
|
||||
|
||||
if clip_fea is not None and self.img_emb is not None:
|
||||
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
|
||||
context = torch.concat([context_clip, context], dim=1)
|
||||
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
for i, block in enumerate(self.blocks):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = block(x, e=e0, freqs=freqs, context=context)
|
||||
|
||||
# head
|
||||
x = self.head(x, e)
|
||||
|
||||
# unpatchify
|
||||
x = self.unpatchify(x, grid_sizes)
|
||||
return x
|
||||
|
||||
def forward(self, x, timestep, context, clip_fea=None, transformer_options={},**kwargs):
|
||||
bs, c, t, h, w = x.shape
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
||||
patch_size = self.patch_size
|
||||
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
|
||||
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
|
||||
w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
|
||||
img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
|
||||
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
|
||||
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
|
||||
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
|
||||
|
||||
freqs = self.rope_embedder(img_ids).movedim(1, 2)
|
||||
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options)[:, :, :t, :h, :w]
|
||||
|
||||
def unpatchify(self, x, grid_sizes):
|
||||
r"""
|
||||
Reconstruct video tensors from patch embeddings.
|
||||
|
||||
Args:
|
||||
x (List[Tensor]):
|
||||
List of patchified features, each with shape [L, C_out * prod(patch_size)]
|
||||
grid_sizes (Tensor):
|
||||
Original spatial-temporal grid dimensions before patching,
|
||||
shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
|
||||
|
||||
Returns:
|
||||
List[Tensor]:
|
||||
Reconstructed video tensors with shape [L, C_out, F, H / 8, W / 8]
|
||||
"""
|
||||
|
||||
c = self.out_dim
|
||||
u = x
|
||||
b = u.shape[0]
|
||||
u = u[:, :math.prod(grid_sizes)].view(b, *grid_sizes, *self.patch_size, c)
|
||||
u = torch.einsum('bfhwpqrc->bcfphqwr', u)
|
||||
u = u.reshape(b, c, *[i * j for i, j in zip(grid_sizes, self.patch_size)])
|
||||
return u
|
||||
567
comfy/ldm/wan/vae.py
Normal file
567
comfy/ldm/wan/vae.py
Normal file
@@ -0,0 +1,567 @@
|
||||
# original version: https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/vae.py
|
||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from comfy.ldm.modules.diffusionmodules.model import vae_attention
|
||||
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
CACHE_T = 2
|
||||
|
||||
|
||||
class CausalConv3d(ops.Conv3d):
|
||||
"""
|
||||
Causal 3d convolusion.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._padding = (self.padding[2], self.padding[2], self.padding[1],
|
||||
self.padding[1], 2 * self.padding[0], 0)
|
||||
self.padding = (0, 0, 0)
|
||||
|
||||
def forward(self, x, cache_x=None):
|
||||
padding = list(self._padding)
|
||||
if cache_x is not None and self._padding[4] > 0:
|
||||
cache_x = cache_x.to(x.device)
|
||||
x = torch.cat([cache_x, x], dim=2)
|
||||
padding[4] -= cache_x.shape[2]
|
||||
x = F.pad(x, padding)
|
||||
|
||||
return super().forward(x)
|
||||
|
||||
|
||||
class RMS_norm(nn.Module):
|
||||
|
||||
def __init__(self, dim, channel_first=True, images=True, bias=False):
|
||||
super().__init__()
|
||||
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
|
||||
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
|
||||
|
||||
self.channel_first = channel_first
|
||||
self.scale = dim**0.5
|
||||
self.gamma = nn.Parameter(torch.ones(shape))
|
||||
self.bias = nn.Parameter(torch.zeros(shape)) if bias else None
|
||||
|
||||
def forward(self, x):
|
||||
return F.normalize(
|
||||
x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma.to(x) + (self.bias.to(x) if self.bias is not None else 0)
|
||||
|
||||
|
||||
class Upsample(nn.Upsample):
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Fix bfloat16 support for nearest neighbor interpolation.
|
||||
"""
|
||||
return super().forward(x.float()).type_as(x)
|
||||
|
||||
|
||||
class Resample(nn.Module):
|
||||
|
||||
def __init__(self, dim, mode):
|
||||
assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
|
||||
'downsample3d')
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.mode = mode
|
||||
|
||||
# layers
|
||||
if mode == 'upsample2d':
|
||||
self.resample = nn.Sequential(
|
||||
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
|
||||
ops.Conv2d(dim, dim // 2, 3, padding=1))
|
||||
elif mode == 'upsample3d':
|
||||
self.resample = nn.Sequential(
|
||||
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
|
||||
ops.Conv2d(dim, dim // 2, 3, padding=1))
|
||||
self.time_conv = CausalConv3d(
|
||||
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
|
||||
|
||||
elif mode == 'downsample2d':
|
||||
self.resample = nn.Sequential(
|
||||
nn.ZeroPad2d((0, 1, 0, 1)),
|
||||
ops.Conv2d(dim, dim, 3, stride=(2, 2)))
|
||||
elif mode == 'downsample3d':
|
||||
self.resample = nn.Sequential(
|
||||
nn.ZeroPad2d((0, 1, 0, 1)),
|
||||
ops.Conv2d(dim, dim, 3, stride=(2, 2)))
|
||||
self.time_conv = CausalConv3d(
|
||||
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
|
||||
|
||||
else:
|
||||
self.resample = nn.Identity()
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
b, c, t, h, w = x.size()
|
||||
if self.mode == 'upsample3d':
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
if feat_cache[idx] is None:
|
||||
feat_cache[idx] = 'Rep'
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[
|
||||
idx] is not None and feat_cache[idx] != 'Rep':
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
if cache_x.shape[2] < 2 and feat_cache[
|
||||
idx] is not None and feat_cache[idx] == 'Rep':
|
||||
cache_x = torch.cat([
|
||||
torch.zeros_like(cache_x).to(cache_x.device),
|
||||
cache_x
|
||||
],
|
||||
dim=2)
|
||||
if feat_cache[idx] == 'Rep':
|
||||
x = self.time_conv(x)
|
||||
else:
|
||||
x = self.time_conv(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
|
||||
x = x.reshape(b, 2, c, t, h, w)
|
||||
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
|
||||
3)
|
||||
x = x.reshape(b, c, t * 2, h, w)
|
||||
t = x.shape[2]
|
||||
x = rearrange(x, 'b c t h w -> (b t) c h w')
|
||||
x = self.resample(x)
|
||||
x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
|
||||
|
||||
if self.mode == 'downsample3d':
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
if feat_cache[idx] is None:
|
||||
feat_cache[idx] = x.clone()
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
|
||||
cache_x = x[:, :, -1:, :, :].clone()
|
||||
# if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
|
||||
# # cache last frame of last two chunk
|
||||
# cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
||||
|
||||
x = self.time_conv(
|
||||
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
return x
|
||||
|
||||
def init_weight(self, conv):
|
||||
conv_weight = conv.weight
|
||||
nn.init.zeros_(conv_weight)
|
||||
c1, c2, t, h, w = conv_weight.size()
|
||||
one_matrix = torch.eye(c1, c2)
|
||||
init_matrix = one_matrix
|
||||
nn.init.zeros_(conv_weight)
|
||||
#conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
|
||||
conv_weight.data[:, :, 1, 0, 0] = init_matrix #* 0.5
|
||||
conv.weight.data.copy_(conv_weight)
|
||||
nn.init.zeros_(conv.bias.data)
|
||||
|
||||
def init_weight2(self, conv):
|
||||
conv_weight = conv.weight.data
|
||||
nn.init.zeros_(conv_weight)
|
||||
c1, c2, t, h, w = conv_weight.size()
|
||||
init_matrix = torch.eye(c1 // 2, c2)
|
||||
#init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
|
||||
conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
|
||||
conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
|
||||
conv.weight.data.copy_(conv_weight)
|
||||
nn.init.zeros_(conv.bias.data)
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
|
||||
def __init__(self, in_dim, out_dim, dropout=0.0):
|
||||
super().__init__()
|
||||
self.in_dim = in_dim
|
||||
self.out_dim = out_dim
|
||||
|
||||
# layers
|
||||
self.residual = nn.Sequential(
|
||||
RMS_norm(in_dim, images=False), nn.SiLU(),
|
||||
CausalConv3d(in_dim, out_dim, 3, padding=1),
|
||||
RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
|
||||
CausalConv3d(out_dim, out_dim, 3, padding=1))
|
||||
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
|
||||
if in_dim != out_dim else nn.Identity()
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
h = self.shortcut(x)
|
||||
for layer in self.residual:
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
x = layer(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
return x + h
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
|
||||
"""
|
||||
Causal self-attention with a single head.
|
||||
"""
|
||||
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
# layers
|
||||
self.norm = RMS_norm(dim)
|
||||
self.to_qkv = ops.Conv2d(dim, dim * 3, 1)
|
||||
self.proj = ops.Conv2d(dim, dim, 1)
|
||||
self.optimized_attention = vae_attention()
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
b, c, t, h, w = x.size()
|
||||
x = rearrange(x, 'b c t h w -> (b t) c h w')
|
||||
x = self.norm(x)
|
||||
# compute query, key, value
|
||||
|
||||
q, k, v = self.to_qkv(x).chunk(3, dim=1)
|
||||
x = self.optimized_attention(q, k, v)
|
||||
|
||||
# output
|
||||
x = self.proj(x)
|
||||
x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
|
||||
return x + identity
|
||||
|
||||
|
||||
class Encoder3d(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim=128,
|
||||
z_dim=4,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
temperal_downsample=[True, True, False],
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.z_dim = z_dim
|
||||
self.dim_mult = dim_mult
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attn_scales = attn_scales
|
||||
self.temperal_downsample = temperal_downsample
|
||||
|
||||
# dimensions
|
||||
dims = [dim * u for u in [1] + dim_mult]
|
||||
scale = 1.0
|
||||
|
||||
# init block
|
||||
self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
|
||||
|
||||
# downsample blocks
|
||||
downsamples = []
|
||||
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
||||
# residual (+attention) blocks
|
||||
for _ in range(num_res_blocks):
|
||||
downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
||||
if scale in attn_scales:
|
||||
downsamples.append(AttentionBlock(out_dim))
|
||||
in_dim = out_dim
|
||||
|
||||
# downsample block
|
||||
if i != len(dim_mult) - 1:
|
||||
mode = 'downsample3d' if temperal_downsample[
|
||||
i] else 'downsample2d'
|
||||
downsamples.append(Resample(out_dim, mode=mode))
|
||||
scale /= 2.0
|
||||
self.downsamples = nn.Sequential(*downsamples)
|
||||
|
||||
# middle blocks
|
||||
self.middle = nn.Sequential(
|
||||
ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim),
|
||||
ResidualBlock(out_dim, out_dim, dropout))
|
||||
|
||||
# output blocks
|
||||
self.head = nn.Sequential(
|
||||
RMS_norm(out_dim, images=False), nn.SiLU(),
|
||||
CausalConv3d(out_dim, z_dim, 3, padding=1))
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
x = self.conv1(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = self.conv1(x)
|
||||
|
||||
## downsamples
|
||||
for layer in self.downsamples:
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## middle
|
||||
for layer in self.middle:
|
||||
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## head
|
||||
for layer in self.head:
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
x = layer(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
|
||||
class Decoder3d(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim=128,
|
||||
z_dim=4,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
temperal_upsample=[False, True, True],
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.z_dim = z_dim
|
||||
self.dim_mult = dim_mult
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attn_scales = attn_scales
|
||||
self.temperal_upsample = temperal_upsample
|
||||
|
||||
# dimensions
|
||||
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
||||
scale = 1.0 / 2**(len(dim_mult) - 2)
|
||||
|
||||
# init block
|
||||
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
|
||||
|
||||
# middle blocks
|
||||
self.middle = nn.Sequential(
|
||||
ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
|
||||
ResidualBlock(dims[0], dims[0], dropout))
|
||||
|
||||
# upsample blocks
|
||||
upsamples = []
|
||||
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
||||
# residual (+attention) blocks
|
||||
if i == 1 or i == 2 or i == 3:
|
||||
in_dim = in_dim // 2
|
||||
for _ in range(num_res_blocks + 1):
|
||||
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
||||
if scale in attn_scales:
|
||||
upsamples.append(AttentionBlock(out_dim))
|
||||
in_dim = out_dim
|
||||
|
||||
# upsample block
|
||||
if i != len(dim_mult) - 1:
|
||||
mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
|
||||
upsamples.append(Resample(out_dim, mode=mode))
|
||||
scale *= 2.0
|
||||
self.upsamples = nn.Sequential(*upsamples)
|
||||
|
||||
# output blocks
|
||||
self.head = nn.Sequential(
|
||||
RMS_norm(out_dim, images=False), nn.SiLU(),
|
||||
CausalConv3d(out_dim, 3, 3, padding=1))
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
## conv1
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
x = self.conv1(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = self.conv1(x)
|
||||
|
||||
## middle
|
||||
for layer in self.middle:
|
||||
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## upsamples
|
||||
for layer in self.upsamples:
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## head
|
||||
for layer in self.head:
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
x = layer(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
|
||||
def count_conv3d(model):
|
||||
count = 0
|
||||
for m in model.modules():
|
||||
if isinstance(m, CausalConv3d):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
|
||||
class WanVAE(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim=128,
|
||||
z_dim=4,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
temperal_downsample=[True, True, False],
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.z_dim = z_dim
|
||||
self.dim_mult = dim_mult
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attn_scales = attn_scales
|
||||
self.temperal_downsample = temperal_downsample
|
||||
self.temperal_upsample = temperal_downsample[::-1]
|
||||
|
||||
# modules
|
||||
self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
|
||||
attn_scales, self.temperal_downsample, dropout)
|
||||
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
|
||||
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
|
||||
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
|
||||
attn_scales, self.temperal_upsample, dropout)
|
||||
|
||||
def forward(self, x):
|
||||
mu, log_var = self.encode(x)
|
||||
z = self.reparameterize(mu, log_var)
|
||||
x_recon = self.decode(z)
|
||||
return x_recon, mu, log_var
|
||||
|
||||
def encode(self, x):
|
||||
self.clear_cache()
|
||||
## cache
|
||||
t = x.shape[2]
|
||||
iter_ = 1 + (t - 1) // 4
|
||||
## 对encode输入的x,按时间拆分为1、4、4、4....
|
||||
for i in range(iter_):
|
||||
self._enc_conv_idx = [0]
|
||||
if i == 0:
|
||||
out = self.encoder(
|
||||
x[:, :, :1, :, :],
|
||||
feat_cache=self._enc_feat_map,
|
||||
feat_idx=self._enc_conv_idx)
|
||||
else:
|
||||
out_ = self.encoder(
|
||||
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
|
||||
feat_cache=self._enc_feat_map,
|
||||
feat_idx=self._enc_conv_idx)
|
||||
out = torch.cat([out, out_], 2)
|
||||
mu, log_var = self.conv1(out).chunk(2, dim=1)
|
||||
self.clear_cache()
|
||||
return mu
|
||||
|
||||
def decode(self, z):
|
||||
self.clear_cache()
|
||||
# z: [b,c,t,h,w]
|
||||
|
||||
iter_ = z.shape[2]
|
||||
x = self.conv2(z)
|
||||
for i in range(iter_):
|
||||
self._conv_idx = [0]
|
||||
if i == 0:
|
||||
out = self.decoder(
|
||||
x[:, :, i:i + 1, :, :],
|
||||
feat_cache=self._feat_map,
|
||||
feat_idx=self._conv_idx)
|
||||
else:
|
||||
out_ = self.decoder(
|
||||
x[:, :, i:i + 1, :, :],
|
||||
feat_cache=self._feat_map,
|
||||
feat_idx=self._conv_idx)
|
||||
out = torch.cat([out, out_], 2)
|
||||
self.clear_cache()
|
||||
return out
|
||||
|
||||
def reparameterize(self, mu, log_var):
|
||||
std = torch.exp(0.5 * log_var)
|
||||
eps = torch.randn_like(std)
|
||||
return eps * std + mu
|
||||
|
||||
def sample(self, imgs, deterministic=False):
|
||||
mu, log_var = self.encode(imgs)
|
||||
if deterministic:
|
||||
return mu
|
||||
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
|
||||
return mu + std * torch.randn_like(std)
|
||||
|
||||
def clear_cache(self):
|
||||
self._conv_num = count_conv3d(self.decoder)
|
||||
self._conv_idx = [0]
|
||||
self._feat_map = [None] * self._conv_num
|
||||
#cache encode
|
||||
self._enc_conv_num = count_conv3d(self.encoder)
|
||||
self._enc_conv_idx = [0]
|
||||
self._enc_feat_map = [None] * self._enc_conv_num
|
||||
@@ -307,7 +307,6 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
if k.endswith(".weight"):
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
|
||||
key_map["lora_unet_{}".format(key_lora)] = k
|
||||
key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config
|
||||
key_map["{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
|
||||
else:
|
||||
key_map["{}".format(k)] = k #generic lora format for not .weight without any weird key names
|
||||
@@ -327,6 +326,13 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
diffusers_lora_key = diffusers_lora_key[:-2]
|
||||
key_map[diffusers_lora_key] = unet_key
|
||||
|
||||
if isinstance(model, comfy.model_base.StableCascade_C):
|
||||
for k in sdk:
|
||||
if k.startswith("diffusion_model."):
|
||||
if k.endswith(".weight"):
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
|
||||
key_map["lora_prior_unet_{}".format(key_lora)] = k
|
||||
|
||||
if isinstance(model, comfy.model_base.SD3): #Diffusers lora SD3
|
||||
diffusers_keys = comfy.utils.mmdit_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
|
||||
for k in diffusers_keys:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import torch
|
||||
import comfy.utils
|
||||
|
||||
|
||||
def convert_lora_bfl_control(sd): #BFL loras for Flux
|
||||
@@ -11,7 +12,13 @@ def convert_lora_bfl_control(sd): #BFL loras for Flux
|
||||
return sd_out
|
||||
|
||||
|
||||
def convert_lora_wan_fun(sd): #Wan Fun loras
|
||||
return comfy.utils.state_dict_prefix_replace(sd, {"lora_unet__": "lora_unet_"})
|
||||
|
||||
|
||||
def convert_lora(sd):
|
||||
if "img_in.lora_A.weight" in sd and "single_blocks.0.norm.key_norm.scale" in sd:
|
||||
return convert_lora_bfl_control(sd)
|
||||
if "lora_unet__blocks_0_cross_attn_k.lora_down.weight" in sd:
|
||||
return convert_lora_wan_fun(sd)
|
||||
return sd
|
||||
|
||||
@@ -34,6 +34,10 @@ import comfy.ldm.flux.model
|
||||
import comfy.ldm.lightricks.model
|
||||
import comfy.ldm.hunyuan_video.model
|
||||
import comfy.ldm.cosmos.model
|
||||
import comfy.ldm.lumina.model
|
||||
import comfy.ldm.wan.model
|
||||
import comfy.ldm.hunyuan3d.model
|
||||
import comfy.ldm.hidream.model
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
@@ -56,6 +60,7 @@ class ModelType(Enum):
|
||||
FLOW = 6
|
||||
V_PREDICTION_CONTINUOUS = 7
|
||||
FLUX = 8
|
||||
IMG_TO_IMG = 9
|
||||
|
||||
|
||||
from comfy.model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling, ModelSamplingContinuousV
|
||||
@@ -86,6 +91,8 @@ def model_sampling(model_config, model_type):
|
||||
elif model_type == ModelType.FLUX:
|
||||
c = comfy.model_sampling.CONST
|
||||
s = comfy.model_sampling.ModelSamplingFlux
|
||||
elif model_type == ModelType.IMG_TO_IMG:
|
||||
c = comfy.model_sampling.IMG_TO_IMG
|
||||
|
||||
class ModelSampling(s, c):
|
||||
pass
|
||||
@@ -106,7 +113,7 @@ class BaseModel(torch.nn.Module):
|
||||
|
||||
if not unet_config.get("disable_unet_model_creation", False):
|
||||
if model_config.custom_operations is None:
|
||||
fp8 = model_config.optimizations.get("fp8", model_config.scaled_fp8 is not None)
|
||||
fp8 = model_config.optimizations.get("fp8", False)
|
||||
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
|
||||
else:
|
||||
operations = model_config.custom_operations
|
||||
@@ -137,6 +144,7 @@ class BaseModel(torch.nn.Module):
|
||||
def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
|
||||
sigma = t
|
||||
xc = self.model_sampling.calculate_input(sigma, x)
|
||||
|
||||
if c_concat is not None:
|
||||
xc = torch.cat([xc] + [c_concat], dim=1)
|
||||
|
||||
@@ -159,15 +167,16 @@ class BaseModel(torch.nn.Module):
|
||||
extra = extra.to(dtype)
|
||||
extra_conds[o] = extra
|
||||
|
||||
t = self.process_timestep(t, x=x, **extra_conds)
|
||||
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
|
||||
return self.model_sampling.calculate_denoised(sigma, model_output, x)
|
||||
|
||||
def process_timestep(self, timestep, **kwargs):
|
||||
return timestep
|
||||
|
||||
def get_dtype(self):
|
||||
return self.diffusion_model.dtype
|
||||
|
||||
def is_adm(self):
|
||||
return self.adm_channels > 0
|
||||
|
||||
def encode_adm(self, **kwargs):
|
||||
return None
|
||||
|
||||
@@ -186,6 +195,11 @@ class BaseModel(torch.nn.Module):
|
||||
|
||||
if concat_latent_image.shape[1:] != noise.shape[1:]:
|
||||
concat_latent_image = utils.common_upscale(concat_latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
if noise.ndim == 5:
|
||||
if concat_latent_image.shape[-3] < noise.shape[-3]:
|
||||
concat_latent_image = torch.nn.functional.pad(concat_latent_image, (0, 0, 0, 0, 0, noise.shape[-3] - concat_latent_image.shape[-3]), "constant", 0)
|
||||
else:
|
||||
concat_latent_image = concat_latent_image[:, :, :noise.shape[-3]]
|
||||
|
||||
concat_latent_image = utils.resize_to_batch_size(concat_latent_image, noise.shape[0])
|
||||
|
||||
@@ -214,6 +228,11 @@ class BaseModel(torch.nn.Module):
|
||||
cond_concat.append(self.blank_inpaint_image_like(noise))
|
||||
elif ck == "mask_inverted":
|
||||
cond_concat.append(torch.zeros_like(noise)[:, :1])
|
||||
if ck == "concat_image":
|
||||
if concat_latent_image is not None:
|
||||
cond_concat.append(concat_latent_image.to(device))
|
||||
else:
|
||||
cond_concat.append(torch.zeros_like(noise))
|
||||
data = torch.cat(cond_concat, dim=1)
|
||||
return data
|
||||
return None
|
||||
@@ -587,6 +606,19 @@ class SDXL_instructpix2pix(IP2P, SDXL):
|
||||
else:
|
||||
self.process_ip2p_image_in = lambda image: image #diffusers ip2p
|
||||
|
||||
class Lotus(BaseModel):
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn)
|
||||
device = kwargs["device"]
|
||||
task_emb = torch.tensor([1, 0]).float().to(device)
|
||||
task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)]).unsqueeze(0)
|
||||
out['y'] = comfy.conds.CONDRegular(task_emb)
|
||||
return out
|
||||
|
||||
def __init__(self, model_config, model_type=ModelType.IMG_TO_IMG, device=None):
|
||||
super().__init__(model_config, model_type, device=device)
|
||||
|
||||
class StableCascade_C(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
|
||||
@@ -846,17 +878,26 @@ class LTXV(BaseModel):
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
|
||||
guiding_latent = kwargs.get("guiding_latent", None)
|
||||
if guiding_latent is not None:
|
||||
out['guiding_latent'] = comfy.conds.CONDRegular(guiding_latent)
|
||||
|
||||
guiding_latent_noise_scale = kwargs.get("guiding_latent_noise_scale", None)
|
||||
if guiding_latent_noise_scale is not None:
|
||||
out["guiding_latent_noise_scale"] = comfy.conds.CONDConstant(guiding_latent_noise_scale)
|
||||
|
||||
out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))
|
||||
|
||||
denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
|
||||
if denoise_mask is not None:
|
||||
out["denoise_mask"] = comfy.conds.CONDRegular(denoise_mask)
|
||||
|
||||
keyframe_idxs = kwargs.get("keyframe_idxs", None)
|
||||
if keyframe_idxs is not None:
|
||||
out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs)
|
||||
|
||||
return out
|
||||
|
||||
def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
|
||||
if denoise_mask is None:
|
||||
return timestep
|
||||
return self.diffusion_model.patchifier.patchify(((denoise_mask) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1)))[:, :1])[0]
|
||||
|
||||
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
|
||||
return latent_image
|
||||
|
||||
class HunyuanVideo(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo)
|
||||
@@ -876,8 +917,32 @@ class HunyuanVideo(BaseModel):
|
||||
guidance = kwargs.get("guidance", 6.0)
|
||||
if guidance is not None:
|
||||
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
|
||||
|
||||
guiding_frame_index = kwargs.get("guiding_frame_index", None)
|
||||
if guiding_frame_index is not None:
|
||||
out['guiding_frame_index'] = comfy.conds.CONDRegular(torch.FloatTensor([guiding_frame_index]))
|
||||
|
||||
return out
|
||||
|
||||
def scale_latent_inpaint(self, latent_image, **kwargs):
|
||||
return latent_image
|
||||
|
||||
class HunyuanVideoI2V(HunyuanVideo):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device)
|
||||
self.concat_keys = ("concat_image", "mask_inverted")
|
||||
|
||||
def scale_latent_inpaint(self, latent_image, **kwargs):
|
||||
return super().scale_latent_inpaint(latent_image=latent_image, **kwargs)
|
||||
|
||||
class HunyuanVideoSkyreelsI2V(HunyuanVideo):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device)
|
||||
self.concat_keys = ("concat_image",)
|
||||
|
||||
def scale_latent_inpaint(self, latent_image, **kwargs):
|
||||
return super().scale_latent_inpaint(latent_image=latent_image, **kwargs)
|
||||
|
||||
class CosmosVideo(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.EDM, image_to_video=False, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.cosmos.model.GeneralDIT)
|
||||
@@ -904,3 +969,108 @@ class CosmosVideo(BaseModel):
|
||||
latent_image = latent_image + noise
|
||||
latent_image = self.model_sampling.calculate_input(torch.tensor([sigma_noise_augmentation], device=latent_image.device, dtype=latent_image.dtype), latent_image)
|
||||
return latent_image * ((sigma ** 2 + self.model_sampling.sigma_data ** 2) ** 0.5)
|
||||
|
||||
class Lumina2(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiT)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
if attention_mask is not None:
|
||||
if torch.numel(attention_mask) != attention_mask.sum():
|
||||
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||
out['num_tokens'] = comfy.conds.CONDConstant(max(1, torch.sum(attention_mask).item()))
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
return out
|
||||
|
||||
class WAN21(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
|
||||
self.image_to_video = image_to_video
|
||||
|
||||
def concat_cond(self, **kwargs):
|
||||
noise = kwargs.get("noise", None)
|
||||
extra_channels = self.diffusion_model.patch_embedding.weight.shape[1] - noise.shape[1]
|
||||
if extra_channels == 0:
|
||||
return None
|
||||
|
||||
image = kwargs.get("concat_latent_image", None)
|
||||
device = kwargs["device"]
|
||||
|
||||
if image is None:
|
||||
shape_image = list(noise.shape)
|
||||
shape_image[1] = extra_channels
|
||||
image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device)
|
||||
else:
|
||||
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
for i in range(0, image.shape[1], 16):
|
||||
image[:, i: i + 16] = self.process_latent_in(image[:, i: i + 16])
|
||||
image = utils.resize_to_batch_size(image, noise.shape[0])
|
||||
|
||||
if not self.image_to_video or extra_channels == image.shape[1]:
|
||||
return image
|
||||
|
||||
if image.shape[1] > (extra_channels - 4):
|
||||
image = image[:, :(extra_channels - 4)]
|
||||
|
||||
mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
|
||||
if mask is None:
|
||||
mask = torch.zeros_like(noise)[:, :4]
|
||||
else:
|
||||
if mask.shape[1] != 4:
|
||||
mask = torch.mean(mask, dim=1, keepdim=True)
|
||||
mask = 1.0 - mask
|
||||
mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
if mask.shape[-3] < noise.shape[-3]:
|
||||
mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0)
|
||||
if mask.shape[1] == 1:
|
||||
mask = mask.repeat(1, 4, 1, 1, 1)
|
||||
mask = utils.resize_to_batch_size(mask, noise.shape[0])
|
||||
|
||||
return torch.cat((mask, image), dim=1)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
|
||||
clip_vision_output = kwargs.get("clip_vision_output", None)
|
||||
if clip_vision_output is not None:
|
||||
out['clip_fea'] = comfy.conds.CONDRegular(clip_vision_output.penultimate_hidden_states)
|
||||
return out
|
||||
|
||||
class Hunyuan3Dv2(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
|
||||
guidance = kwargs.get("guidance", 5.0)
|
||||
if guidance is not None:
|
||||
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
|
||||
return out
|
||||
|
||||
class HiDream(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hidream.model.HiDreamImageTransformer2DModel)
|
||||
|
||||
def encode_adm(self, **kwargs):
|
||||
return kwargs["pooled_output"]
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
conditioning_llama3 = kwargs.get("conditioning_llama3", None)
|
||||
if conditioning_llama3 is not None:
|
||||
out['encoder_hidden_states_llama3'] = comfy.conds.CONDRegular(conditioning_llama3)
|
||||
return out
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
import comfy.supported_models
|
||||
import comfy.supported_models_base
|
||||
import comfy.utils
|
||||
@@ -33,7 +34,7 @@ def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
|
||||
return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack, time_stack_cross
|
||||
return None
|
||||
|
||||
def detect_unet_config(state_dict, key_prefix):
|
||||
def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
state_dict_keys = list(state_dict.keys())
|
||||
|
||||
if '{}joint_blocks.0.context_block.attn.qkv.weight'.format(key_prefix) in state_dict_keys: #mmdit model
|
||||
@@ -136,7 +137,7 @@ def detect_unet_config(state_dict, key_prefix):
|
||||
if '{}txt_in.individual_token_refiner.blocks.0.norm1.weight'.format(key_prefix) in state_dict_keys: #Hunyuan Video
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "hunyuan_video"
|
||||
dit_config["in_channels"] = 16
|
||||
dit_config["in_channels"] = state_dict['{}img_in.proj.weight'.format(key_prefix)].shape[1] #SkyReels img2video has 32 input channels
|
||||
dit_config["patch_size"] = [1, 2, 2]
|
||||
dit_config["out_channels"] = 16
|
||||
dit_config["vec_in_dim"] = 768
|
||||
@@ -153,7 +154,7 @@ def detect_unet_config(state_dict, key_prefix):
|
||||
dit_config["guidance_embed"] = len(guidance_keys) > 0
|
||||
return dit_config
|
||||
|
||||
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys: #Flux
|
||||
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and '{}img_in.weight'.format(key_prefix) in state_dict_keys: #Flux
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "flux"
|
||||
dit_config["in_channels"] = 16
|
||||
@@ -210,6 +211,8 @@ def detect_unet_config(state_dict, key_prefix):
|
||||
if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "ltxv"
|
||||
if metadata is not None and "config" in metadata:
|
||||
dit_config.update(json.loads(metadata["config"]).get("transformer", {}))
|
||||
return dit_config
|
||||
|
||||
if '{}t_block.1.weight'.format(key_prefix) in state_dict_keys: # PixArt
|
||||
@@ -239,7 +242,7 @@ def detect_unet_config(state_dict, key_prefix):
|
||||
dit_config["micro_condition"] = False
|
||||
return dit_config
|
||||
|
||||
if '{}blocks.block0.blocks.0.block.attn.to_q.0.weight'.format(key_prefix) in state_dict_keys:
|
||||
if '{}blocks.block0.blocks.0.block.attn.to_q.0.weight'.format(key_prefix) in state_dict_keys: # Cosmos
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "cosmos"
|
||||
dit_config["max_img_h"] = 240
|
||||
@@ -284,6 +287,76 @@ def detect_unet_config(state_dict, key_prefix):
|
||||
dit_config["extra_per_block_abs_pos_emb_type"] = "learnable"
|
||||
return dit_config
|
||||
|
||||
if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys: # Lumina 2
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "lumina2"
|
||||
dit_config["patch_size"] = 2
|
||||
dit_config["in_channels"] = 16
|
||||
dit_config["dim"] = 2304
|
||||
dit_config["cap_feat_dim"] = 2304
|
||||
dit_config["n_layers"] = 26
|
||||
dit_config["n_heads"] = 24
|
||||
dit_config["n_kv_heads"] = 8
|
||||
dit_config["qk_norm"] = True
|
||||
dit_config["axes_dims"] = [32, 32, 32]
|
||||
dit_config["axes_lens"] = [300, 512, 512]
|
||||
return dit_config
|
||||
|
||||
if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "wan2.1"
|
||||
dim = state_dict['{}head.modulation'.format(key_prefix)].shape[-1]
|
||||
dit_config["dim"] = dim
|
||||
dit_config["num_heads"] = dim // 128
|
||||
dit_config["ffn_dim"] = state_dict['{}blocks.0.ffn.0.weight'.format(key_prefix)].shape[0]
|
||||
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')
|
||||
dit_config["patch_size"] = (1, 2, 2)
|
||||
dit_config["freq_dim"] = 256
|
||||
dit_config["window_size"] = (-1, -1)
|
||||
dit_config["qk_norm"] = True
|
||||
dit_config["cross_attn_norm"] = True
|
||||
dit_config["eps"] = 1e-6
|
||||
dit_config["in_dim"] = state_dict['{}patch_embedding.weight'.format(key_prefix)].shape[1]
|
||||
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["model_type"] = "i2v"
|
||||
else:
|
||||
dit_config["model_type"] = "t2v"
|
||||
return dit_config
|
||||
|
||||
if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D
|
||||
in_shape = state_dict['{}latent_in.weight'.format(key_prefix)].shape
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "hunyuan3d2"
|
||||
dit_config["in_channels"] = in_shape[1]
|
||||
dit_config["context_in_dim"] = state_dict['{}cond_in.weight'.format(key_prefix)].shape[1]
|
||||
dit_config["hidden_size"] = in_shape[0]
|
||||
dit_config["mlp_ratio"] = 4.0
|
||||
dit_config["num_heads"] = 16
|
||||
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
|
||||
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
|
||||
dit_config["qkv_bias"] = True
|
||||
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
|
||||
return dit_config
|
||||
|
||||
if '{}caption_projection.0.linear.weight'.format(key_prefix) in state_dict_keys: # HiDream
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "hidream"
|
||||
dit_config["attention_head_dim"] = 128
|
||||
dit_config["axes_dims_rope"] = [64, 32, 32]
|
||||
dit_config["caption_channels"] = [4096, 4096]
|
||||
dit_config["max_resolution"] = [128, 128]
|
||||
dit_config["in_channels"] = 16
|
||||
dit_config["llama_layers"] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31]
|
||||
dit_config["num_attention_heads"] = 20
|
||||
dit_config["num_routed_experts"] = 4
|
||||
dit_config["num_activated_experts"] = 2
|
||||
dit_config["num_layers"] = 16
|
||||
dit_config["num_single_layers"] = 32
|
||||
dit_config["out_channels"] = 16
|
||||
dit_config["patch_size"] = 2
|
||||
dit_config["text_emb_dim"] = 2048
|
||||
return dit_config
|
||||
|
||||
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
|
||||
return None
|
||||
|
||||
@@ -418,8 +491,8 @@ def model_config_from_unet_config(unet_config, state_dict=None):
|
||||
logging.error("no match {}".format(unet_config))
|
||||
return None
|
||||
|
||||
def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False):
|
||||
unet_config = detect_unet_config(state_dict, unet_key_prefix)
|
||||
def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False, metadata=None):
|
||||
unet_config = detect_unet_config(state_dict, unet_key_prefix, metadata=metadata)
|
||||
if unet_config is None:
|
||||
return None
|
||||
model_config = model_config_from_unet_config(unet_config, state_dict)
|
||||
@@ -432,6 +505,10 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
|
||||
model_config.scaled_fp8 = scaled_fp8_weight.dtype
|
||||
if model_config.scaled_fp8 == torch.float32:
|
||||
model_config.scaled_fp8 = torch.float8_e4m3fn
|
||||
if scaled_fp8_weight.nelement() == 2:
|
||||
model_config.optimizations["fp8"] = False
|
||||
else:
|
||||
model_config.optimizations["fp8"] = True
|
||||
|
||||
return model_config
|
||||
|
||||
@@ -624,8 +701,13 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
|
||||
'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
LotusD = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': 4,
|
||||
'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0],
|
||||
'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, 'num_heads': 8,
|
||||
'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p, SD15_diffusers_inpaint]
|
||||
supported_models = [LotusD, SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p, SD15_diffusers_inpaint]
|
||||
|
||||
for unet_config in supported_models:
|
||||
matches = True
|
||||
|
||||
@@ -19,7 +19,7 @@
|
||||
import psutil
|
||||
import logging
|
||||
from enum import Enum
|
||||
from comfy.cli_args import args
|
||||
from comfy.cli_args import args, PerformanceFeature
|
||||
import torch
|
||||
import sys
|
||||
import platform
|
||||
@@ -46,11 +46,39 @@ cpu_state = CPUState.GPU
|
||||
|
||||
total_vram = 0
|
||||
|
||||
def get_supported_float8_types():
|
||||
float8_types = []
|
||||
try:
|
||||
float8_types.append(torch.float8_e4m3fn)
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
float8_types.append(torch.float8_e4m3fnuz)
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
float8_types.append(torch.float8_e5m2)
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
float8_types.append(torch.float8_e5m2fnuz)
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
float8_types.append(torch.float8_e8m0fnu)
|
||||
except:
|
||||
pass
|
||||
return float8_types
|
||||
|
||||
FLOAT8_TYPES = get_supported_float8_types()
|
||||
|
||||
xpu_available = False
|
||||
torch_version = ""
|
||||
try:
|
||||
torch_version = torch.version.__version__
|
||||
xpu_available = (int(torch_version[0]) < 2 or (int(torch_version[0]) == 2 and int(torch_version[2]) <= 4)) and torch.xpu.is_available()
|
||||
temp = torch_version.split(".")
|
||||
torch_version_numeric = (int(temp[0]), int(temp[1]))
|
||||
xpu_available = (torch_version_numeric[0] < 2 or (torch_version_numeric[0] == 2 and torch_version_numeric[1] <= 4)) and torch.xpu.is_available()
|
||||
except:
|
||||
pass
|
||||
|
||||
@@ -93,6 +121,13 @@ try:
|
||||
except:
|
||||
npu_available = False
|
||||
|
||||
try:
|
||||
import torch_mlu # noqa: F401
|
||||
_ = torch.mlu.device_count()
|
||||
mlu_available = torch.mlu.is_available()
|
||||
except:
|
||||
mlu_available = False
|
||||
|
||||
if args.cpu:
|
||||
cpu_state = CPUState.CPU
|
||||
|
||||
@@ -110,6 +145,12 @@ def is_ascend_npu():
|
||||
return True
|
||||
return False
|
||||
|
||||
def is_mlu():
|
||||
global mlu_available
|
||||
if mlu_available:
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_torch_device():
|
||||
global directml_enabled
|
||||
global cpu_state
|
||||
@@ -125,6 +166,8 @@ def get_torch_device():
|
||||
return torch.device("xpu", torch.xpu.current_device())
|
||||
elif is_ascend_npu():
|
||||
return torch.device("npu", torch.npu.current_device())
|
||||
elif is_mlu():
|
||||
return torch.device("mlu", torch.mlu.current_device())
|
||||
else:
|
||||
return torch.device(torch.cuda.current_device())
|
||||
|
||||
@@ -151,6 +194,12 @@ def get_total_memory(dev=None, torch_total_too=False):
|
||||
_, mem_total_npu = torch.npu.mem_get_info(dev)
|
||||
mem_total_torch = mem_reserved
|
||||
mem_total = mem_total_npu
|
||||
elif is_mlu():
|
||||
stats = torch.mlu.memory_stats(dev)
|
||||
mem_reserved = stats['reserved_bytes.all.current']
|
||||
_, mem_total_mlu = torch.mlu.mem_get_info(dev)
|
||||
mem_total_torch = mem_reserved
|
||||
mem_total = mem_total_mlu
|
||||
else:
|
||||
stats = torch.cuda.memory_stats(dev)
|
||||
mem_reserved = stats['reserved_bytes.all.current']
|
||||
@@ -163,12 +212,21 @@ def get_total_memory(dev=None, torch_total_too=False):
|
||||
else:
|
||||
return mem_total
|
||||
|
||||
def mac_version():
|
||||
try:
|
||||
return tuple(int(n) for n in platform.mac_ver()[0].split("."))
|
||||
except:
|
||||
return None
|
||||
|
||||
total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
|
||||
total_ram = psutil.virtual_memory().total / (1024 * 1024)
|
||||
logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
|
||||
|
||||
try:
|
||||
logging.info("pytorch version: {}".format(torch_version))
|
||||
mac_ver = mac_version()
|
||||
if mac_ver is not None:
|
||||
logging.info("Mac Version {}".format(mac_ver))
|
||||
except:
|
||||
pass
|
||||
|
||||
@@ -218,7 +276,7 @@ def is_amd():
|
||||
|
||||
MIN_WEIGHT_MEMORY_RATIO = 0.4
|
||||
if is_nvidia():
|
||||
MIN_WEIGHT_MEMORY_RATIO = 0.1
|
||||
MIN_WEIGHT_MEMORY_RATIO = 0.0
|
||||
|
||||
ENABLE_PYTORCH_ATTENTION = False
|
||||
if args.use_pytorch_cross_attention:
|
||||
@@ -227,22 +285,45 @@ if args.use_pytorch_cross_attention:
|
||||
|
||||
try:
|
||||
if is_nvidia():
|
||||
if int(torch_version[0]) >= 2:
|
||||
if torch_version_numeric[0] >= 2:
|
||||
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
if is_intel_xpu() or is_ascend_npu():
|
||||
if is_intel_xpu() or is_ascend_npu() or is_mlu():
|
||||
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
try:
|
||||
if is_amd():
|
||||
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
|
||||
logging.info("AMD arch: {}".format(arch))
|
||||
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||
if torch_version_numeric[0] >= 2 and torch_version_numeric[1] >= 7: # works on 2.6 but doesn't actually seem to improve much
|
||||
if any((a in arch) for a in ["gfx1100", "gfx1101"]): # TODO: more arches
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
if ENABLE_PYTORCH_ATTENTION:
|
||||
torch.backends.cuda.enable_math_sdp(True)
|
||||
torch.backends.cuda.enable_flash_sdp(True)
|
||||
torch.backends.cuda.enable_mem_efficient_sdp(True)
|
||||
|
||||
|
||||
PRIORITIZE_FP16 = False # TODO: remove and replace with something that shows exactly which dtype is faster than the other
|
||||
try:
|
||||
if int(torch_version[0]) == 2 and int(torch_version[2]) >= 5:
|
||||
if is_nvidia() and PerformanceFeature.Fp16Accumulation in args.fast:
|
||||
torch.backends.cuda.matmul.allow_fp16_accumulation = True
|
||||
PRIORITIZE_FP16 = True # TODO: limit to cards where it actually boosts performance
|
||||
logging.info("Enabled fp16 accumulation.")
|
||||
except:
|
||||
pass
|
||||
|
||||
try:
|
||||
if torch_version_numeric[0] == 2 and torch_version_numeric[1] >= 5:
|
||||
torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
|
||||
except:
|
||||
logging.warning("Warning, could not set allow_fp16_bf16_reduction_math_sdp")
|
||||
@@ -256,15 +337,10 @@ elif args.highvram or args.gpu_only:
|
||||
vram_state = VRAMState.HIGH_VRAM
|
||||
|
||||
FORCE_FP32 = False
|
||||
FORCE_FP16 = False
|
||||
if args.force_fp32:
|
||||
logging.info("Forcing FP32, if this improves things please report it.")
|
||||
FORCE_FP32 = True
|
||||
|
||||
if args.force_fp16:
|
||||
logging.info("Forcing FP16.")
|
||||
FORCE_FP16 = True
|
||||
|
||||
if lowvram_available:
|
||||
if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
|
||||
vram_state = set_vram_to
|
||||
@@ -297,6 +373,8 @@ def get_torch_device_name(device):
|
||||
return "{} {}".format(device, torch.xpu.get_device_name(device))
|
||||
elif is_ascend_npu():
|
||||
return "{} {}".format(device, torch.npu.get_device_name(device))
|
||||
elif is_mlu():
|
||||
return "{} {}".format(device, torch.mlu.get_device_name(device))
|
||||
else:
|
||||
return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
|
||||
|
||||
@@ -538,7 +616,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
||||
loaded_memory = loaded_model.model_loaded_memory()
|
||||
current_free_mem = get_free_memory(torch_dev) + loaded_memory
|
||||
|
||||
lowvram_model_memory = max(64 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
|
||||
lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
|
||||
lowvram_model_memory = max(0.1, lowvram_model_memory - loaded_memory)
|
||||
|
||||
if vram_set_state == VRAMState.NO_VRAM:
|
||||
@@ -632,7 +710,7 @@ def unet_inital_load_device(parameters, dtype):
|
||||
def maximum_vram_for_weights(device=None):
|
||||
return (get_total_memory(device) * 0.88 - minimum_inference_memory())
|
||||
|
||||
def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
|
||||
def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32], weight_dtype=None):
|
||||
if model_params < 0:
|
||||
model_params = 1000000000000000000000
|
||||
if args.fp32_unet:
|
||||
@@ -649,13 +727,8 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
|
||||
return torch.float8_e5m2
|
||||
|
||||
fp8_dtype = None
|
||||
try:
|
||||
for dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
|
||||
if dtype in supported_dtypes:
|
||||
fp8_dtype = dtype
|
||||
break
|
||||
except:
|
||||
pass
|
||||
if weight_dtype in FLOAT8_TYPES:
|
||||
fp8_dtype = weight_dtype
|
||||
|
||||
if fp8_dtype is not None:
|
||||
if supports_fp8_compute(device): #if fp8 compute is supported the casting is most likely not expensive
|
||||
@@ -665,6 +738,10 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
|
||||
if model_params * 2 > free_model_memory:
|
||||
return fp8_dtype
|
||||
|
||||
if PRIORITIZE_FP16 or weight_dtype == torch.float16:
|
||||
if torch.float16 in supported_dtypes and should_use_fp16(device=device, model_params=model_params):
|
||||
return torch.float16
|
||||
|
||||
for dt in supported_dtypes:
|
||||
if dt == torch.float16 and should_use_fp16(device=device, model_params=model_params):
|
||||
if torch.float16 in supported_dtypes:
|
||||
@@ -697,6 +774,9 @@ def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.flo
|
||||
return None
|
||||
|
||||
fp16_supported = should_use_fp16(inference_device, prioritize_performance=True)
|
||||
if PRIORITIZE_FP16 and fp16_supported and torch.float16 in supported_dtypes:
|
||||
return torch.float16
|
||||
|
||||
for dt in supported_dtypes:
|
||||
if dt == torch.float16 and fp16_supported:
|
||||
return torch.float16
|
||||
@@ -743,6 +823,8 @@ def text_encoder_dtype(device=None):
|
||||
return torch.float8_e5m2
|
||||
elif args.fp16_text_enc:
|
||||
return torch.float16
|
||||
elif args.bf16_text_enc:
|
||||
return torch.bfloat16
|
||||
elif args.fp32_text_enc:
|
||||
return torch.float32
|
||||
|
||||
@@ -873,6 +955,9 @@ def cast_to_device(tensor, device, dtype, copy=False):
|
||||
def sage_attention_enabled():
|
||||
return args.use_sage_attention
|
||||
|
||||
def flash_attention_enabled():
|
||||
return args.use_flash_attention
|
||||
|
||||
def xformers_enabled():
|
||||
global directml_enabled
|
||||
global cpu_state
|
||||
@@ -882,6 +967,8 @@ def xformers_enabled():
|
||||
return False
|
||||
if is_ascend_npu():
|
||||
return False
|
||||
if is_mlu():
|
||||
return False
|
||||
if directml_enabled:
|
||||
return False
|
||||
return XFORMERS_IS_AVAILABLE
|
||||
@@ -898,6 +985,11 @@ def pytorch_attention_enabled():
|
||||
global ENABLE_PYTORCH_ATTENTION
|
||||
return ENABLE_PYTORCH_ATTENTION
|
||||
|
||||
def pytorch_attention_enabled_vae():
|
||||
if is_amd():
|
||||
return False # enabling pytorch attention on AMD currently causes crash when doing high res
|
||||
return pytorch_attention_enabled()
|
||||
|
||||
def pytorch_attention_flash_attention():
|
||||
global ENABLE_PYTORCH_ATTENTION
|
||||
if ENABLE_PYTORCH_ATTENTION:
|
||||
@@ -908,23 +1000,21 @@ def pytorch_attention_flash_attention():
|
||||
return True
|
||||
if is_ascend_npu():
|
||||
return True
|
||||
if is_mlu():
|
||||
return True
|
||||
if is_amd():
|
||||
return True #if you have pytorch attention enabled on AMD it probably supports at least mem efficient attention
|
||||
return False
|
||||
|
||||
def mac_version():
|
||||
try:
|
||||
return tuple(int(n) for n in platform.mac_ver()[0].split("."))
|
||||
except:
|
||||
return None
|
||||
|
||||
def force_upcast_attention_dtype():
|
||||
upcast = args.force_upcast_attention
|
||||
|
||||
macos_version = mac_version()
|
||||
if macos_version is not None and ((14, 5) <= macos_version <= (15, 2)): # black image bug on recent versions of macOS
|
||||
if macos_version is not None and ((14, 5) <= macos_version < (16,)): # black image bug on recent versions of macOS
|
||||
upcast = True
|
||||
|
||||
if upcast:
|
||||
return torch.float32
|
||||
return {torch.float16: torch.float32}
|
||||
else:
|
||||
return None
|
||||
|
||||
@@ -954,6 +1044,13 @@ def get_free_memory(dev=None, torch_free_too=False):
|
||||
mem_free_npu, _ = torch.npu.mem_get_info(dev)
|
||||
mem_free_torch = mem_reserved - mem_active
|
||||
mem_free_total = mem_free_npu + mem_free_torch
|
||||
elif is_mlu():
|
||||
stats = torch.mlu.memory_stats(dev)
|
||||
mem_active = stats['active_bytes.all.current']
|
||||
mem_reserved = stats['reserved_bytes.all.current']
|
||||
mem_free_mlu, _ = torch.mlu.mem_get_info(dev)
|
||||
mem_free_torch = mem_reserved - mem_active
|
||||
mem_free_total = mem_free_mlu + mem_free_torch
|
||||
else:
|
||||
stats = torch.cuda.memory_stats(dev)
|
||||
mem_active = stats['active_bytes.all.current']
|
||||
@@ -990,21 +1087,26 @@ def is_device_mps(device):
|
||||
def is_device_cuda(device):
|
||||
return is_device_type(device, 'cuda')
|
||||
|
||||
def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
|
||||
def is_directml_enabled():
|
||||
global directml_enabled
|
||||
if directml_enabled:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
|
||||
if device is not None:
|
||||
if is_device_cpu(device):
|
||||
return False
|
||||
|
||||
if FORCE_FP16:
|
||||
if args.force_fp16:
|
||||
return True
|
||||
|
||||
if FORCE_FP32:
|
||||
return False
|
||||
|
||||
if directml_enabled:
|
||||
return False
|
||||
if is_directml_enabled():
|
||||
return True
|
||||
|
||||
if (device is not None and is_device_mps(device)) or mps_mode():
|
||||
return True
|
||||
@@ -1018,6 +1120,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
if is_ascend_npu():
|
||||
return True
|
||||
|
||||
if is_mlu():
|
||||
return True
|
||||
|
||||
if torch.version.hip:
|
||||
return True
|
||||
|
||||
@@ -1075,13 +1180,28 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
if is_intel_xpu():
|
||||
return True
|
||||
|
||||
if is_ascend_npu():
|
||||
return True
|
||||
|
||||
if is_amd():
|
||||
arch = torch.cuda.get_device_properties(device).gcnArchName
|
||||
if any((a in arch) for a in ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]): # RDNA2 and older don't support bf16
|
||||
if manual_cast:
|
||||
return True
|
||||
return False
|
||||
|
||||
props = torch.cuda.get_device_properties(device)
|
||||
|
||||
if is_mlu():
|
||||
if props.major > 3:
|
||||
return True
|
||||
|
||||
if props.major >= 8:
|
||||
return True
|
||||
|
||||
bf16_works = torch.cuda.is_bf16_supported()
|
||||
|
||||
if bf16_works or manual_cast:
|
||||
if bf16_works and manual_cast:
|
||||
free_model_memory = maximum_vram_for_weights(device)
|
||||
if (not prioritize_performance) or model_params * 4 > free_model_memory:
|
||||
return True
|
||||
@@ -1100,11 +1220,11 @@ def supports_fp8_compute(device=None):
|
||||
if props.minor < 9:
|
||||
return False
|
||||
|
||||
if int(torch_version[0]) < 2 or (int(torch_version[0]) == 2 and int(torch_version[2]) < 3):
|
||||
if torch_version_numeric[0] < 2 or (torch_version_numeric[0] == 2 and torch_version_numeric[1] < 3):
|
||||
return False
|
||||
|
||||
if WINDOWS:
|
||||
if (int(torch_version[0]) == 2 and int(torch_version[2]) < 4):
|
||||
if (torch_version_numeric[0] == 2 and torch_version_numeric[1] < 4):
|
||||
return False
|
||||
|
||||
return True
|
||||
@@ -1117,6 +1237,8 @@ def soft_empty_cache(force=False):
|
||||
torch.xpu.empty_cache()
|
||||
elif is_ascend_npu():
|
||||
torch.npu.empty_cache()
|
||||
elif is_mlu():
|
||||
torch.mlu.empty_cache()
|
||||
elif torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
@@ -96,8 +96,28 @@ def wipe_lowvram_weight(m):
|
||||
if hasattr(m, "prev_comfy_cast_weights"):
|
||||
m.comfy_cast_weights = m.prev_comfy_cast_weights
|
||||
del m.prev_comfy_cast_weights
|
||||
m.weight_function = None
|
||||
m.bias_function = None
|
||||
|
||||
if hasattr(m, "weight_function"):
|
||||
m.weight_function = []
|
||||
|
||||
if hasattr(m, "bias_function"):
|
||||
m.bias_function = []
|
||||
|
||||
def move_weight_functions(m, device):
|
||||
if device is None:
|
||||
return 0
|
||||
|
||||
memory = 0
|
||||
if hasattr(m, "weight_function"):
|
||||
for f in m.weight_function:
|
||||
if hasattr(f, "move_to"):
|
||||
memory += f.move_to(device=device)
|
||||
|
||||
if hasattr(m, "bias_function"):
|
||||
for f in m.bias_function:
|
||||
if hasattr(f, "move_to"):
|
||||
memory += f.move_to(device=device)
|
||||
return memory
|
||||
|
||||
class LowVramPatch:
|
||||
def __init__(self, key, patches):
|
||||
@@ -192,11 +212,13 @@ class ModelPatcher:
|
||||
self.backup = {}
|
||||
self.object_patches = {}
|
||||
self.object_patches_backup = {}
|
||||
self.weight_wrapper_patches = {}
|
||||
self.model_options = {"transformer_options":{}}
|
||||
self.model_size()
|
||||
self.load_device = load_device
|
||||
self.offload_device = offload_device
|
||||
self.weight_inplace_update = weight_inplace_update
|
||||
self.force_cast_weights = False
|
||||
self.patches_uuid = uuid.uuid4()
|
||||
self.parent = None
|
||||
|
||||
@@ -250,11 +272,14 @@ class ModelPatcher:
|
||||
n.patches_uuid = self.patches_uuid
|
||||
|
||||
n.object_patches = self.object_patches.copy()
|
||||
n.weight_wrapper_patches = self.weight_wrapper_patches.copy()
|
||||
n.model_options = copy.deepcopy(self.model_options)
|
||||
n.backup = self.backup
|
||||
n.object_patches_backup = self.object_patches_backup
|
||||
n.parent = self
|
||||
|
||||
n.force_cast_weights = self.force_cast_weights
|
||||
|
||||
# attachments
|
||||
n.attachments = {}
|
||||
for k in self.attachments:
|
||||
@@ -402,6 +427,16 @@ class ModelPatcher:
|
||||
def add_object_patch(self, name, obj):
|
||||
self.object_patches[name] = obj
|
||||
|
||||
def set_model_compute_dtype(self, dtype):
|
||||
self.add_object_patch("manual_cast_dtype", dtype)
|
||||
if dtype is not None:
|
||||
self.force_cast_weights = True
|
||||
self.patches_uuid = uuid.uuid4() #TODO: optimize by preventing a full model reload for this
|
||||
|
||||
def add_weight_wrapper(self, name, function):
|
||||
self.weight_wrapper_patches[name] = self.weight_wrapper_patches.get(name, []) + [function]
|
||||
self.patches_uuid = uuid.uuid4()
|
||||
|
||||
def get_model_object(self, name: str) -> torch.nn.Module:
|
||||
"""Retrieves a nested attribute from an object using dot notation considering
|
||||
object patches.
|
||||
@@ -566,6 +601,9 @@ class ModelPatcher:
|
||||
|
||||
lowvram_weight = False
|
||||
|
||||
weight_key = "{}.weight".format(n)
|
||||
bias_key = "{}.bias".format(n)
|
||||
|
||||
if not full_load and hasattr(m, "comfy_cast_weights"):
|
||||
if mem_counter + module_mem >= lowvram_model_memory:
|
||||
lowvram_weight = True
|
||||
@@ -573,34 +611,46 @@ class ModelPatcher:
|
||||
if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
|
||||
continue
|
||||
|
||||
weight_key = "{}.weight".format(n)
|
||||
bias_key = "{}.bias".format(n)
|
||||
|
||||
cast_weight = self.force_cast_weights
|
||||
if lowvram_weight:
|
||||
if hasattr(m, "comfy_cast_weights"):
|
||||
m.weight_function = []
|
||||
m.bias_function = []
|
||||
|
||||
if weight_key in self.patches:
|
||||
if force_patch_weights:
|
||||
self.patch_weight_to_device(weight_key)
|
||||
else:
|
||||
m.weight_function = LowVramPatch(weight_key, self.patches)
|
||||
m.weight_function = [LowVramPatch(weight_key, self.patches)]
|
||||
patch_counter += 1
|
||||
if bias_key in self.patches:
|
||||
if force_patch_weights:
|
||||
self.patch_weight_to_device(bias_key)
|
||||
else:
|
||||
m.bias_function = LowVramPatch(bias_key, self.patches)
|
||||
m.bias_function = [LowVramPatch(bias_key, self.patches)]
|
||||
patch_counter += 1
|
||||
|
||||
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
||||
m.comfy_cast_weights = True
|
||||
cast_weight = True
|
||||
else:
|
||||
if hasattr(m, "comfy_cast_weights"):
|
||||
if m.comfy_cast_weights:
|
||||
wipe_lowvram_weight(m)
|
||||
wipe_lowvram_weight(m)
|
||||
|
||||
if full_load or mem_counter + module_mem < lowvram_model_memory:
|
||||
mem_counter += module_mem
|
||||
load_completely.append((module_mem, n, m, params))
|
||||
|
||||
if cast_weight and hasattr(m, "comfy_cast_weights"):
|
||||
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
||||
m.comfy_cast_weights = True
|
||||
|
||||
if weight_key in self.weight_wrapper_patches:
|
||||
m.weight_function.extend(self.weight_wrapper_patches[weight_key])
|
||||
|
||||
if bias_key in self.weight_wrapper_patches:
|
||||
m.bias_function.extend(self.weight_wrapper_patches[bias_key])
|
||||
|
||||
mem_counter += move_weight_functions(m, device_to)
|
||||
|
||||
load_completely.sort(reverse=True)
|
||||
for x in load_completely:
|
||||
n = x[1]
|
||||
@@ -662,6 +712,7 @@ class ModelPatcher:
|
||||
self.unpatch_hooks()
|
||||
if self.model.model_lowvram:
|
||||
for m in self.model.modules():
|
||||
move_weight_functions(m, device_to)
|
||||
wipe_lowvram_weight(m)
|
||||
|
||||
self.model.model_lowvram = False
|
||||
@@ -696,6 +747,7 @@ class ModelPatcher:
|
||||
|
||||
def partially_unload(self, device_to, memory_to_free=0):
|
||||
with self.use_ejected():
|
||||
hooks_unpatched = False
|
||||
memory_freed = 0
|
||||
patch_counter = 0
|
||||
unload_list = self._load_list()
|
||||
@@ -719,6 +771,10 @@ class ModelPatcher:
|
||||
move_weight = False
|
||||
break
|
||||
|
||||
if not hooks_unpatched:
|
||||
self.unpatch_hooks()
|
||||
hooks_unpatched = True
|
||||
|
||||
if bk.inplace_update:
|
||||
comfy.utils.copy_to_param(self.model, key, bk.weight)
|
||||
else:
|
||||
@@ -728,15 +784,19 @@ class ModelPatcher:
|
||||
weight_key = "{}.weight".format(n)
|
||||
bias_key = "{}.bias".format(n)
|
||||
if move_weight:
|
||||
cast_weight = self.force_cast_weights
|
||||
m.to(device_to)
|
||||
module_mem += move_weight_functions(m, device_to)
|
||||
if lowvram_possible:
|
||||
if weight_key in self.patches:
|
||||
m.weight_function = LowVramPatch(weight_key, self.patches)
|
||||
m.weight_function.append(LowVramPatch(weight_key, self.patches))
|
||||
patch_counter += 1
|
||||
if bias_key in self.patches:
|
||||
m.bias_function = LowVramPatch(bias_key, self.patches)
|
||||
m.bias_function.append(LowVramPatch(bias_key, self.patches))
|
||||
patch_counter += 1
|
||||
cast_weight = True
|
||||
|
||||
if cast_weight:
|
||||
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
||||
m.comfy_cast_weights = True
|
||||
m.comfy_patched_weights = False
|
||||
@@ -1034,7 +1094,6 @@ class ModelPatcher:
|
||||
|
||||
def patch_hooks(self, hooks: comfy.hooks.HookGroup):
|
||||
with self.use_ejected():
|
||||
self.unpatch_hooks()
|
||||
if hooks is not None:
|
||||
model_sd_keys = list(self.model_state_dict().keys())
|
||||
memory_counter = None
|
||||
@@ -1045,12 +1104,16 @@ class ModelPatcher:
|
||||
# if have cached weights for hooks, use it
|
||||
cached_weights = self.cached_hook_patches.get(hooks, None)
|
||||
if cached_weights is not None:
|
||||
model_sd_keys_set = set(model_sd_keys)
|
||||
for key in cached_weights:
|
||||
if key not in model_sd_keys:
|
||||
logging.warning(f"Cached hook could not patch. Key does not exist in model: {key}")
|
||||
continue
|
||||
self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter)
|
||||
model_sd_keys_set.remove(key)
|
||||
self.unpatch_hooks(model_sd_keys_set)
|
||||
else:
|
||||
self.unpatch_hooks()
|
||||
relevant_patches = self.get_combined_hook_patches(hooks=hooks)
|
||||
original_weights = None
|
||||
if len(relevant_patches) > 0:
|
||||
@@ -1061,6 +1124,8 @@ class ModelPatcher:
|
||||
continue
|
||||
self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights,
|
||||
memory_counter=memory_counter)
|
||||
else:
|
||||
self.unpatch_hooks()
|
||||
self.current_hooks = hooks
|
||||
|
||||
def patch_cached_hook_weights(self, cached_weights: dict, key: str, memory_counter: MemoryCounter):
|
||||
@@ -1117,17 +1182,23 @@ class ModelPatcher:
|
||||
del out_weight
|
||||
del weight
|
||||
|
||||
def unpatch_hooks(self) -> None:
|
||||
def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None:
|
||||
with self.use_ejected():
|
||||
if len(self.hook_backup) == 0:
|
||||
self.current_hooks = None
|
||||
return
|
||||
keys = list(self.hook_backup.keys())
|
||||
for k in keys:
|
||||
comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
|
||||
if whitelist_keys_set:
|
||||
for k in keys:
|
||||
if k in whitelist_keys_set:
|
||||
comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
|
||||
self.hook_backup.pop(k)
|
||||
else:
|
||||
for k in keys:
|
||||
comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
|
||||
|
||||
self.hook_backup.clear()
|
||||
self.current_hooks = None
|
||||
self.hook_backup.clear()
|
||||
self.current_hooks = None
|
||||
|
||||
def clean_hooks(self):
|
||||
self.unpatch_hooks()
|
||||
|
||||
@@ -31,6 +31,7 @@ class EPS:
|
||||
return model_input - model_output * sigma
|
||||
|
||||
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
|
||||
if max_denoise:
|
||||
noise = noise * torch.sqrt(1.0 + sigma ** 2.0)
|
||||
else:
|
||||
@@ -61,11 +62,22 @@ class CONST:
|
||||
return model_input - model_output * sigma
|
||||
|
||||
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
|
||||
return sigma * noise + (1.0 - sigma) * latent_image
|
||||
|
||||
def inverse_noise_scaling(self, sigma, latent):
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (latent.ndim - 1))
|
||||
return latent / (1.0 - sigma)
|
||||
|
||||
class X0(EPS):
|
||||
def calculate_denoised(self, sigma, model_output, model_input):
|
||||
return model_output
|
||||
|
||||
class IMG_TO_IMG(X0):
|
||||
def calculate_input(self, sigma, noise):
|
||||
return noise
|
||||
|
||||
|
||||
class ModelSamplingDiscrete(torch.nn.Module):
|
||||
def __init__(self, model_config=None, zsnr=None):
|
||||
super().__init__()
|
||||
|
||||
96
comfy/ops.py
96
comfy/ops.py
@@ -17,9 +17,11 @@
|
||||
"""
|
||||
|
||||
import torch
|
||||
import logging
|
||||
import comfy.model_management
|
||||
from comfy.cli_args import args
|
||||
from comfy.cli_args import args, PerformanceFeature
|
||||
import comfy.float
|
||||
import comfy.rmsnorm
|
||||
|
||||
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
|
||||
|
||||
@@ -38,21 +40,23 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
|
||||
bias = None
|
||||
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
||||
if s.bias is not None:
|
||||
has_function = s.bias_function is not None
|
||||
has_function = len(s.bias_function) > 0
|
||||
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function)
|
||||
if has_function:
|
||||
bias = s.bias_function(bias)
|
||||
for f in s.bias_function:
|
||||
bias = f(bias)
|
||||
|
||||
has_function = s.weight_function is not None
|
||||
has_function = len(s.weight_function) > 0
|
||||
weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function)
|
||||
if has_function:
|
||||
weight = s.weight_function(weight)
|
||||
for f in s.weight_function:
|
||||
weight = f(weight)
|
||||
return weight, bias
|
||||
|
||||
class CastWeightBiasOp:
|
||||
comfy_cast_weights = False
|
||||
weight_function = None
|
||||
bias_function = None
|
||||
weight_function = []
|
||||
bias_function = []
|
||||
|
||||
class disable_weight_init:
|
||||
class Linear(torch.nn.Linear, CastWeightBiasOp):
|
||||
@@ -64,7 +68,7 @@ class disable_weight_init:
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if self.comfy_cast_weights:
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
@@ -78,7 +82,7 @@ class disable_weight_init:
|
||||
return self._conv_forward(input, weight, bias)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if self.comfy_cast_weights:
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
@@ -92,7 +96,7 @@ class disable_weight_init:
|
||||
return self._conv_forward(input, weight, bias)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if self.comfy_cast_weights:
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
@@ -106,7 +110,7 @@ class disable_weight_init:
|
||||
return self._conv_forward(input, weight, bias)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if self.comfy_cast_weights:
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
@@ -120,12 +124,11 @@ class disable_weight_init:
|
||||
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if self.comfy_cast_weights:
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
|
||||
class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp):
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
@@ -139,7 +142,26 @@ class disable_weight_init:
|
||||
return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if self.comfy_cast_weights:
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
class RMSNorm(comfy.rmsnorm.RMSNorm, CastWeightBiasOp):
|
||||
def reset_parameters(self):
|
||||
self.bias = None
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
if self.weight is not None:
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
else:
|
||||
weight = None
|
||||
return comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
|
||||
# return torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
@@ -160,7 +182,7 @@ class disable_weight_init:
|
||||
output_padding, self.groups, self.dilation)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if self.comfy_cast_weights:
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
@@ -181,7 +203,7 @@ class disable_weight_init:
|
||||
output_padding, self.groups, self.dilation)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if self.comfy_cast_weights:
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
@@ -199,7 +221,7 @@ class disable_weight_init:
|
||||
return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if self.comfy_cast_weights:
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
if "out_dtype" in kwargs:
|
||||
@@ -241,6 +263,9 @@ class manual_cast(disable_weight_init):
|
||||
class ConvTranspose1d(disable_weight_init.ConvTranspose1d):
|
||||
comfy_cast_weights = True
|
||||
|
||||
class RMSNorm(disable_weight_init.RMSNorm):
|
||||
comfy_cast_weights = True
|
||||
|
||||
class Embedding(disable_weight_init.Embedding):
|
||||
comfy_cast_weights = True
|
||||
|
||||
@@ -307,6 +332,7 @@ class fp8_ops(manual_cast):
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
|
||||
def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
|
||||
logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))
|
||||
class scaled_fp8_op(manual_cast):
|
||||
class Linear(manual_cast.Linear):
|
||||
def __init__(self, *args, **kwargs):
|
||||
@@ -354,14 +380,46 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
|
||||
|
||||
return scaled_fp8_op
|
||||
|
||||
CUBLAS_IS_AVAILABLE = False
|
||||
try:
|
||||
from cublas_ops import CublasLinear
|
||||
CUBLAS_IS_AVAILABLE = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if CUBLAS_IS_AVAILABLE:
|
||||
class cublas_ops(disable_weight_init):
|
||||
class Linear(CublasLinear, disable_weight_init.Linear):
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
return super().forward(input)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
|
||||
fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
|
||||
if scaled_fp8 is not None:
|
||||
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute, scale_input=True, override_dtype=scaled_fp8)
|
||||
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)
|
||||
|
||||
if fp8_compute and (fp8_optimizations or args.fast) and not disable_fast_fp8:
|
||||
if (
|
||||
fp8_compute and
|
||||
(fp8_optimizations or PerformanceFeature.Fp8MatrixMultiplication in args.fast) and
|
||||
not disable_fast_fp8
|
||||
):
|
||||
return fp8_ops
|
||||
|
||||
if (
|
||||
PerformanceFeature.CublasOps in args.fast and
|
||||
CUBLAS_IS_AVAILABLE and
|
||||
weight_dtype == torch.float16 and
|
||||
(compute_dtype == torch.float16 or compute_dtype is None)
|
||||
):
|
||||
logging.info("Using cublas ops")
|
||||
return cublas_ops
|
||||
|
||||
if compute_dtype is None or weight_dtype == compute_dtype:
|
||||
return disable_weight_init
|
||||
|
||||
|
||||
@@ -48,6 +48,7 @@ def get_all_callbacks(call_type: str, transformer_options: dict, is_model_option
|
||||
|
||||
class WrappersMP:
|
||||
OUTER_SAMPLE = "outer_sample"
|
||||
PREPARE_SAMPLING = "prepare_sampling"
|
||||
SAMPLER_SAMPLE = "sampler_sample"
|
||||
CALC_COND_BATCH = "calc_cond_batch"
|
||||
APPLY_MODEL = "apply_model"
|
||||
|
||||
54
comfy/rmsnorm.py
Normal file
54
comfy/rmsnorm.py
Normal file
@@ -0,0 +1,54 @@
|
||||
import torch
|
||||
import comfy.model_management
|
||||
import numbers
|
||||
|
||||
RMSNorm = None
|
||||
|
||||
try:
|
||||
rms_norm_torch = torch.nn.functional.rms_norm
|
||||
RMSNorm = torch.nn.RMSNorm
|
||||
except:
|
||||
rms_norm_torch = None
|
||||
|
||||
|
||||
def rms_norm(x, weight=None, eps=1e-6):
|
||||
if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
|
||||
if weight is None:
|
||||
return rms_norm_torch(x, (x.shape[-1],), eps=eps)
|
||||
else:
|
||||
return rms_norm_torch(x, weight.shape, weight=comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
|
||||
else:
|
||||
r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
|
||||
if weight is None:
|
||||
return r
|
||||
else:
|
||||
return r * comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device)
|
||||
|
||||
|
||||
if RMSNorm is None:
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
normalized_shape,
|
||||
eps=None,
|
||||
elementwise_affine=True,
|
||||
device=None,
|
||||
dtype=None,
|
||||
):
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
if isinstance(normalized_shape, numbers.Integral):
|
||||
# mypy error: incompatible types in assignment
|
||||
normalized_shape = (normalized_shape,) # type: ignore[assignment]
|
||||
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
|
||||
self.eps = eps
|
||||
self.elementwise_affine = elementwise_affine
|
||||
if self.elementwise_affine:
|
||||
self.weight = torch.nn.Parameter(
|
||||
torch.empty(self.normalized_shape, **factory_kwargs)
|
||||
)
|
||||
else:
|
||||
self.register_parameter("weight", None)
|
||||
|
||||
def forward(self, x):
|
||||
return rms_norm(x, self.weight, self.eps)
|
||||
@@ -106,6 +106,13 @@ def cleanup_additional_models(models):
|
||||
|
||||
|
||||
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
|
||||
executor = comfy.patcher_extension.WrapperExecutor.new_executor(
|
||||
_prepare_sampling,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True)
|
||||
)
|
||||
return executor.execute(model, noise_shape, conds, model_options=model_options)
|
||||
|
||||
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
|
||||
real_model: BaseModel = None
|
||||
models, inference_memory = get_additional_models(conds, model.model_dtype())
|
||||
models += get_additional_models_from_model_options(model_options)
|
||||
|
||||
@@ -19,6 +19,12 @@ import comfy.hooks
|
||||
import scipy.stats
|
||||
import numpy
|
||||
|
||||
|
||||
def add_area_dims(area, num_dims):
|
||||
while (len(area) // 2) < num_dims:
|
||||
area = [2147483648] + area[:len(area) // 2] + [0] + area[len(area) // 2:]
|
||||
return area
|
||||
|
||||
def get_area_and_mult(conds, x_in, timestep_in):
|
||||
dims = tuple(x_in.shape[2:])
|
||||
area = None
|
||||
@@ -34,6 +40,10 @@ def get_area_and_mult(conds, x_in, timestep_in):
|
||||
return None
|
||||
if 'area' in conds:
|
||||
area = list(conds['area'])
|
||||
area = add_area_dims(area, len(dims))
|
||||
if (len(area) // 2) > len(dims):
|
||||
area = area[:len(dims)] + area[len(area) // 2:(len(area) // 2) + len(dims)]
|
||||
|
||||
if 'strength' in conds:
|
||||
strength = conds['strength']
|
||||
|
||||
@@ -50,7 +60,7 @@ def get_area_and_mult(conds, x_in, timestep_in):
|
||||
if "mask_strength" in conds:
|
||||
mask_strength = conds["mask_strength"]
|
||||
mask = conds['mask']
|
||||
assert(mask.shape[1:] == x_in.shape[2:])
|
||||
assert (mask.shape[1:] == x_in.shape[2:])
|
||||
|
||||
mask = mask[:input_x.shape[0]]
|
||||
if area is not None:
|
||||
@@ -64,16 +74,17 @@ def get_area_and_mult(conds, x_in, timestep_in):
|
||||
mult = mask * strength
|
||||
|
||||
if 'mask' not in conds and area is not None:
|
||||
rr = 8
|
||||
fuzz = 8
|
||||
for i in range(len(dims)):
|
||||
rr = min(fuzz, mult.shape[2 + i] // 4)
|
||||
if area[len(dims) + i] != 0:
|
||||
for t in range(rr):
|
||||
m = mult.narrow(i + 2, t, 1)
|
||||
m *= ((1.0/rr) * (t + 1))
|
||||
m *= ((1.0 / rr) * (t + 1))
|
||||
if (area[i] + area[len(dims) + i]) < x_in.shape[i + 2]:
|
||||
for t in range(rr):
|
||||
m = mult.narrow(i + 2, area[i] - 1 - t, 1)
|
||||
m *= ((1.0/rr) * (t + 1))
|
||||
m *= ((1.0 / rr) * (t + 1))
|
||||
|
||||
conditioning = {}
|
||||
model_conds = conds["model_conds"]
|
||||
@@ -548,25 +559,37 @@ def resolve_areas_and_cond_masks(conditions, h, w, device):
|
||||
logging.warning("WARNING: The comfy.samplers.resolve_areas_and_cond_masks function is deprecated please use the resolve_areas_and_cond_masks_multidim one instead.")
|
||||
return resolve_areas_and_cond_masks_multidim(conditions, [h, w], device)
|
||||
|
||||
def create_cond_with_same_area_if_none(conds, c): #TODO: handle dim != 2
|
||||
def create_cond_with_same_area_if_none(conds, c):
|
||||
if 'area' not in c:
|
||||
return
|
||||
|
||||
def area_inside(a, area_cmp):
|
||||
a = add_area_dims(a, len(area_cmp) // 2)
|
||||
area_cmp = add_area_dims(area_cmp, len(a) // 2)
|
||||
|
||||
a_l = len(a) // 2
|
||||
area_cmp_l = len(area_cmp) // 2
|
||||
for i in range(min(a_l, area_cmp_l)):
|
||||
if a[a_l + i] < area_cmp[area_cmp_l + i]:
|
||||
return False
|
||||
for i in range(min(a_l, area_cmp_l)):
|
||||
if (a[i] + a[a_l + i]) > (area_cmp[i] + area_cmp[area_cmp_l + i]):
|
||||
return False
|
||||
return True
|
||||
|
||||
c_area = c['area']
|
||||
smallest = None
|
||||
for x in conds:
|
||||
if 'area' in x:
|
||||
a = x['area']
|
||||
if c_area[2] >= a[2] and c_area[3] >= a[3]:
|
||||
if a[0] + a[2] >= c_area[0] + c_area[2]:
|
||||
if a[1] + a[3] >= c_area[1] + c_area[3]:
|
||||
if smallest is None:
|
||||
smallest = x
|
||||
elif 'area' not in smallest:
|
||||
smallest = x
|
||||
else:
|
||||
if smallest['area'][0] * smallest['area'][1] > a[0] * a[1]:
|
||||
smallest = x
|
||||
if area_inside(c_area, a):
|
||||
if smallest is None:
|
||||
smallest = x
|
||||
elif 'area' not in smallest:
|
||||
smallest = x
|
||||
else:
|
||||
if math.prod(smallest['area'][:len(smallest['area']) // 2]) > math.prod(a[:len(a) // 2]):
|
||||
smallest = x
|
||||
else:
|
||||
if smallest is None:
|
||||
smallest = x
|
||||
@@ -686,7 +709,8 @@ class Sampler:
|
||||
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
|
||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
||||
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "gradient_estimation"]
|
||||
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
|
||||
"gradient_estimation", "er_sde", "seeds_2", "seeds_3"]
|
||||
|
||||
class KSAMPLER(Sampler):
|
||||
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
|
||||
|
||||
111
comfy/sd.py
111
comfy/sd.py
@@ -1,4 +1,5 @@
|
||||
from __future__ import annotations
|
||||
import json
|
||||
import torch
|
||||
from enum import Enum
|
||||
import logging
|
||||
@@ -12,6 +13,8 @@ from .ldm.audio.autoencoder import AudioOobleckVAE
|
||||
import comfy.ldm.genmo.vae.model
|
||||
import comfy.ldm.lightricks.vae.causal_video_autoencoder
|
||||
import comfy.ldm.cosmos.vae
|
||||
import comfy.ldm.wan.vae
|
||||
import comfy.ldm.hunyuan3d.vae
|
||||
import yaml
|
||||
import math
|
||||
|
||||
@@ -36,6 +39,9 @@ import comfy.text_encoders.genmo
|
||||
import comfy.text_encoders.lt
|
||||
import comfy.text_encoders.hunyuan_video
|
||||
import comfy.text_encoders.cosmos
|
||||
import comfy.text_encoders.lumina2
|
||||
import comfy.text_encoders.wan
|
||||
import comfy.text_encoders.hidream
|
||||
|
||||
import comfy.model_patcher
|
||||
import comfy.lora
|
||||
@@ -131,8 +137,8 @@ class CLIP:
|
||||
def clip_layer(self, layer_idx):
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
def tokenize(self, text, return_word_ids=False):
|
||||
return self.tokenizer.tokenize_with_weights(text, return_word_ids)
|
||||
def tokenize(self, text, return_word_ids=False, **kwargs):
|
||||
return self.tokenizer.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||
|
||||
def add_hooks_to_dict(self, pooled_dict: dict[str]):
|
||||
if self.apply_hooks_to_conds:
|
||||
@@ -246,7 +252,7 @@ class CLIP:
|
||||
return self.patcher.get_key_patches()
|
||||
|
||||
class VAE:
|
||||
def __init__(self, sd=None, device=None, config=None, dtype=None):
|
||||
def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None):
|
||||
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
|
||||
sd = diffusers_convert.convert_vae_state_dict(sd)
|
||||
|
||||
@@ -260,6 +266,7 @@ class VAE:
|
||||
self.process_input = lambda image: image * 2.0 - 1.0
|
||||
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
self.working_dtypes = [torch.bfloat16, torch.float32]
|
||||
self.disable_offload = False
|
||||
|
||||
self.downscale_index_formula = None
|
||||
self.upscale_index_formula = None
|
||||
@@ -332,6 +339,7 @@ class VAE:
|
||||
self.process_output = lambda audio: audio
|
||||
self.process_input = lambda audio: audio
|
||||
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
self.disable_offload = True
|
||||
elif "blocks.2.blocks.3.stack.5.weight" in sd or "decoder.blocks.2.blocks.3.stack.5.weight" in sd or "layers.4.layers.1.attn_block.attn.qkv.weight" in sd or "encoder.layers.4.layers.1.attn_block.attn.qkv.weight" in sd: #genmo mochi vae
|
||||
if "blocks.2.blocks.3.stack.5.weight" in sd:
|
||||
sd = comfy.utils.state_dict_prefix_replace(sd, {"": "decoder."})
|
||||
@@ -354,7 +362,12 @@ class VAE:
|
||||
version = 0
|
||||
elif tensor_conv1.shape[0] == 1024:
|
||||
version = 1
|
||||
self.first_stage_model = comfy.ldm.lightricks.vae.causal_video_autoencoder.VideoVAE(version=version)
|
||||
if "encoder.down_blocks.1.conv.conv.bias" in sd:
|
||||
version = 2
|
||||
vae_config = None
|
||||
if metadata is not None and "config" in metadata:
|
||||
vae_config = json.loads(metadata["config"]).get("vae", None)
|
||||
self.first_stage_model = comfy.ldm.lightricks.vae.causal_video_autoencoder.VideoVAE(version=version, config=vae_config)
|
||||
self.latent_channels = 128
|
||||
self.latent_dim = 3
|
||||
self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
|
||||
@@ -391,6 +404,29 @@ class VAE:
|
||||
self.memory_used_decode = lambda shape, dtype: (50 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
|
||||
self.memory_used_encode = lambda shape, dtype: (50 * (round((shape[2] + 7) / 8) * 8) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
|
||||
self.working_dtypes = [torch.bfloat16, torch.float32]
|
||||
elif "decoder.middle.0.residual.0.gamma" in sd:
|
||||
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
|
||||
self.upscale_index_formula = (4, 8, 8)
|
||||
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
|
||||
self.downscale_index_formula = (4, 8, 8)
|
||||
self.latent_dim = 3
|
||||
self.latent_channels = 16
|
||||
ddconfig = {"dim": 96, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
|
||||
self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
|
||||
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype)
|
||||
elif "geo_decoder.cross_attn_decoder.ln_1.bias" in sd:
|
||||
self.latent_dim = 1
|
||||
ln_post = "geo_decoder.ln_post.weight" in sd
|
||||
inner_size = sd["geo_decoder.output_proj.weight"].shape[1]
|
||||
downsample_ratio = sd["post_kl.weight"].shape[0] // inner_size
|
||||
mlp_expand = sd["geo_decoder.cross_attn_decoder.mlp.c_fc.weight"].shape[0] // inner_size
|
||||
self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype) # TODO
|
||||
self.memory_used_decode = lambda shape, dtype: (1024 * 1024 * 1024 * 2.0) * model_management.dtype_size(dtype) # TODO
|
||||
ddconfig = {"embed_dim": 64, "num_freqs": 8, "include_pi": False, "heads": 16, "width": 1024, "num_decoder_layers": 16, "qkv_bias": False, "qk_norm": True, "geo_decoder_mlp_expand_ratio": mlp_expand, "geo_decoder_downsample_ratio": downsample_ratio, "geo_decoder_ln_post": ln_post}
|
||||
self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE(**ddconfig)
|
||||
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
else:
|
||||
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
|
||||
self.first_stage_model = None
|
||||
@@ -419,6 +455,10 @@ class VAE:
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
|
||||
logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
|
||||
|
||||
def throw_exception_if_invalid(self):
|
||||
if self.first_stage_model is None:
|
||||
raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")
|
||||
|
||||
def vae_encode_crop_pixels(self, pixels):
|
||||
downscale_ratio = self.spacial_compression_encode()
|
||||
|
||||
@@ -473,18 +513,19 @@ class VAE:
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
|
||||
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
|
||||
|
||||
def decode(self, samples_in):
|
||||
def decode(self, samples_in, vae_options={}):
|
||||
self.throw_exception_if_invalid()
|
||||
pixel_samples = None
|
||||
try:
|
||||
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
||||
free_memory = model_management.get_free_memory(self.device)
|
||||
batch_number = int(free_memory / memory_used)
|
||||
batch_number = max(1, batch_number)
|
||||
|
||||
for x in range(0, samples_in.shape[0], batch_number):
|
||||
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
|
||||
out = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
|
||||
out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).float())
|
||||
if pixel_samples is None:
|
||||
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
|
||||
pixel_samples[x:x+batch_number] = out
|
||||
@@ -504,8 +545,9 @@ class VAE:
|
||||
return pixel_samples
|
||||
|
||||
def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
|
||||
self.throw_exception_if_invalid()
|
||||
memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
||||
dims = samples.ndim - 2
|
||||
args = {}
|
||||
if tile_x is not None:
|
||||
@@ -532,13 +574,14 @@ class VAE:
|
||||
return output.movedim(1, -1)
|
||||
|
||||
def encode(self, pixel_samples):
|
||||
self.throw_exception_if_invalid()
|
||||
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
|
||||
pixel_samples = pixel_samples.movedim(-1, 1)
|
||||
if self.latent_dim == 3 and pixel_samples.ndim < 5:
|
||||
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
|
||||
try:
|
||||
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
||||
free_memory = model_management.get_free_memory(self.device)
|
||||
batch_number = int(free_memory / max(1, memory_used))
|
||||
batch_number = max(1, batch_number)
|
||||
@@ -564,6 +607,7 @@ class VAE:
|
||||
return samples
|
||||
|
||||
def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
|
||||
self.throw_exception_if_invalid()
|
||||
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
|
||||
dims = self.latent_dim
|
||||
pixel_samples = pixel_samples.movedim(-1, 1)
|
||||
@@ -571,7 +615,7 @@ class VAE:
|
||||
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
|
||||
|
||||
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) # TODO: calculate mem required for tile
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
||||
|
||||
args = {}
|
||||
if tile_x is not None:
|
||||
@@ -657,6 +701,8 @@ class CLIPType(Enum):
|
||||
HUNYUAN_VIDEO = 9
|
||||
PIXART = 10
|
||||
COSMOS = 11
|
||||
LUMINA2 = 12
|
||||
WAN = 13
|
||||
|
||||
|
||||
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
||||
@@ -675,6 +721,7 @@ class TEModel(Enum):
|
||||
T5_BASE = 6
|
||||
LLAMA3_8 = 7
|
||||
T5_XXL_OLD = 8
|
||||
GEMMA_2_2B = 9
|
||||
|
||||
def detect_te_model(sd):
|
||||
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
|
||||
@@ -693,6 +740,8 @@ def detect_te_model(sd):
|
||||
return TEModel.T5_XXL_OLD
|
||||
if "encoder.block.0.layer.0.SelfAttention.k.weight" in sd:
|
||||
return TEModel.T5_BASE
|
||||
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
|
||||
return TEModel.GEMMA_2_2B
|
||||
if "model.layers.0.post_attention_layernorm.weight" in sd:
|
||||
return TEModel.LLAMA3_8
|
||||
return None
|
||||
@@ -730,6 +779,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
if "text_projection" in clip_data[i]:
|
||||
clip_data[i]["text_projection.weight"] = clip_data[i]["text_projection"].transpose(0, 1) #old models saved with the CLIPSave node
|
||||
|
||||
tokenizer_data = {}
|
||||
clip_target = EmptyClass()
|
||||
clip_target.params = {}
|
||||
if len(clip_data) == 1:
|
||||
@@ -757,6 +807,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
elif clip_type == CLIPType.PIXART:
|
||||
clip_target.clip = comfy.text_encoders.pixart_t5.pixart_te(**t5xxl_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.pixart_t5.PixArtTokenizer
|
||||
elif clip_type == CLIPType.WAN:
|
||||
clip_target.clip = comfy.text_encoders.wan.te(**t5xxl_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.wan.WanT5Tokenizer
|
||||
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
||||
else: #CLIPType.MOCHI
|
||||
clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer
|
||||
@@ -769,6 +823,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
elif te_model == TEModel.T5_BASE:
|
||||
clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
|
||||
clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
|
||||
elif te_model == TEModel.GEMMA_2_2B:
|
||||
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
|
||||
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
||||
else:
|
||||
if clip_type == CLIPType.SD3:
|
||||
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=False, t5=False)
|
||||
@@ -796,9 +854,11 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
elif len(clip_data) == 3:
|
||||
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(**t5xxl_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||
elif len(clip_data) == 4:
|
||||
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data), **llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
|
||||
|
||||
parameters = 0
|
||||
tokenizer_data = {}
|
||||
for c in clip_data:
|
||||
parameters += comfy.utils.calculate_parameters(c)
|
||||
tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)
|
||||
@@ -845,13 +905,13 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
|
||||
return (model, clip, vae)
|
||||
|
||||
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}):
|
||||
sd = comfy.utils.load_torch_file(ckpt_path)
|
||||
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options)
|
||||
sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
|
||||
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata)
|
||||
if out is None:
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
|
||||
return out
|
||||
|
||||
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}):
|
||||
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None):
|
||||
clip = None
|
||||
clipvision = None
|
||||
vae = None
|
||||
@@ -863,19 +923,24 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
||||
weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)
|
||||
load_device = model_management.get_torch_device()
|
||||
|
||||
model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix)
|
||||
model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata)
|
||||
if model_config is None:
|
||||
return None
|
||||
logging.warning("Warning, This is not a checkpoint file, trying to load it as a diffusion model only.")
|
||||
diffusion_model = load_diffusion_model_state_dict(sd, model_options={})
|
||||
if diffusion_model is None:
|
||||
return None
|
||||
return (diffusion_model, None, VAE(sd={}), None) # The VAE object is there to throw an exception if it's actually used'
|
||||
|
||||
|
||||
unet_weight_dtype = list(model_config.supported_inference_dtypes)
|
||||
if weight_dtype is not None and model_config.scaled_fp8 is None:
|
||||
unet_weight_dtype.append(weight_dtype)
|
||||
if model_config.scaled_fp8 is not None:
|
||||
weight_dtype = None
|
||||
|
||||
model_config.custom_operations = model_options.get("custom_operations", None)
|
||||
unet_dtype = model_options.get("dtype", model_options.get("weight_dtype", None))
|
||||
|
||||
if unet_dtype is None:
|
||||
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
|
||||
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype, weight_dtype=weight_dtype)
|
||||
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
||||
@@ -892,7 +957,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
||||
if output_vae:
|
||||
vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
|
||||
vae_sd = model_config.process_vae_state_dict(vae_sd)
|
||||
vae = VAE(sd=vae_sd)
|
||||
vae = VAE(sd=vae_sd, metadata=metadata)
|
||||
|
||||
if output_clip:
|
||||
clip_target = model_config.clip_target(state_dict=sd)
|
||||
@@ -966,11 +1031,11 @@ def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffuse
|
||||
|
||||
offload_device = model_management.unet_offload_device()
|
||||
unet_weight_dtype = list(model_config.supported_inference_dtypes)
|
||||
if weight_dtype is not None and model_config.scaled_fp8 is None:
|
||||
unet_weight_dtype.append(weight_dtype)
|
||||
if model_config.scaled_fp8 is not None:
|
||||
weight_dtype = None
|
||||
|
||||
if dtype is None:
|
||||
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
|
||||
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype, weight_dtype=weight_dtype)
|
||||
else:
|
||||
unet_dtype = dtype
|
||||
|
||||
|
||||
@@ -82,7 +82,8 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
LAYERS = [
|
||||
"last",
|
||||
"pooled",
|
||||
"hidden"
|
||||
"hidden",
|
||||
"all"
|
||||
]
|
||||
def __init__(self, device="cpu", max_length=77,
|
||||
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,
|
||||
@@ -93,6 +94,8 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
|
||||
if textmodel_json_config is None:
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
|
||||
if "model_name" not in model_options:
|
||||
model_options = {**model_options, "model_name": "clip_l"}
|
||||
|
||||
if isinstance(textmodel_json_config, dict):
|
||||
config = textmodel_json_config
|
||||
@@ -100,6 +103,10 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
with open(textmodel_json_config) as f:
|
||||
config = json.load(f)
|
||||
|
||||
te_model_options = model_options.get("{}_model_config".format(model_options.get("model_name", "")), {})
|
||||
for k, v in te_model_options.items():
|
||||
config[k] = v
|
||||
|
||||
operations = model_options.get("custom_operations", None)
|
||||
scaled_fp8 = None
|
||||
|
||||
@@ -147,7 +154,9 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
def set_clip_options(self, options):
|
||||
layer_idx = options.get("layer", self.layer_idx)
|
||||
self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
|
||||
if layer_idx is None or abs(layer_idx) > self.num_layers:
|
||||
if self.layer == "all":
|
||||
pass
|
||||
elif layer_idx is None or abs(layer_idx) > self.num_layers:
|
||||
self.layer = "last"
|
||||
else:
|
||||
self.layer = "hidden"
|
||||
@@ -158,71 +167,98 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
self.layer_idx = self.options_default[1]
|
||||
self.return_projected_pooled = self.options_default[2]
|
||||
|
||||
def set_up_textual_embeddings(self, tokens, current_embeds):
|
||||
out_tokens = []
|
||||
next_new_token = token_dict_size = current_embeds.weight.shape[0]
|
||||
embedding_weights = []
|
||||
def process_tokens(self, tokens, device):
|
||||
end_token = self.special_tokens.get("end", None)
|
||||
if end_token is None:
|
||||
cmp_token = self.special_tokens.get("pad", -1)
|
||||
else:
|
||||
cmp_token = end_token
|
||||
|
||||
embeds_out = []
|
||||
attention_masks = []
|
||||
num_tokens = []
|
||||
|
||||
for x in tokens:
|
||||
attention_mask = []
|
||||
tokens_temp = []
|
||||
other_embeds = []
|
||||
eos = False
|
||||
index = 0
|
||||
for y in x:
|
||||
if isinstance(y, numbers.Integral):
|
||||
tokens_temp += [int(y)]
|
||||
else:
|
||||
if y.shape[0] == current_embeds.weight.shape[1]:
|
||||
embedding_weights += [y]
|
||||
tokens_temp += [next_new_token]
|
||||
next_new_token += 1
|
||||
if eos:
|
||||
attention_mask.append(0)
|
||||
else:
|
||||
logging.warning("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored {} != {}".format(y.shape[0], current_embeds.weight.shape[1]))
|
||||
while len(tokens_temp) < len(x):
|
||||
tokens_temp += [self.special_tokens["pad"]]
|
||||
out_tokens += [tokens_temp]
|
||||
attention_mask.append(1)
|
||||
token = int(y)
|
||||
tokens_temp += [token]
|
||||
if not eos and token == cmp_token:
|
||||
if end_token is None:
|
||||
attention_mask[-1] = 0
|
||||
eos = True
|
||||
else:
|
||||
other_embeds.append((index, y))
|
||||
index += 1
|
||||
|
||||
n = token_dict_size
|
||||
if len(embedding_weights) > 0:
|
||||
new_embedding = self.operations.Embedding(next_new_token + 1, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype)
|
||||
new_embedding.weight[:token_dict_size] = current_embeds.weight
|
||||
for x in embedding_weights:
|
||||
new_embedding.weight[n] = x
|
||||
n += 1
|
||||
self.transformer.set_input_embeddings(new_embedding)
|
||||
tokens_embed = torch.tensor([tokens_temp], device=device, dtype=torch.long)
|
||||
tokens_embed = self.transformer.get_input_embeddings()(tokens_embed, out_dtype=torch.float32)
|
||||
index = 0
|
||||
pad_extra = 0
|
||||
for o in other_embeds:
|
||||
emb = o[1]
|
||||
if torch.is_tensor(emb):
|
||||
emb = {"type": "embedding", "data": emb}
|
||||
|
||||
processed_tokens = []
|
||||
for x in out_tokens:
|
||||
processed_tokens += [list(map(lambda a: n if a == -1 else a, x))] #The EOS token should always be the largest one
|
||||
emb_type = emb.get("type", None)
|
||||
if emb_type == "embedding":
|
||||
emb = emb.get("data", None)
|
||||
else:
|
||||
if hasattr(self.transformer, "preprocess_embed"):
|
||||
emb = self.transformer.preprocess_embed(emb, device=device)
|
||||
else:
|
||||
emb = None
|
||||
|
||||
return processed_tokens
|
||||
if emb is None:
|
||||
index += -1
|
||||
continue
|
||||
|
||||
ind = index + o[0]
|
||||
emb = emb.view(1, -1, emb.shape[-1]).to(device=device, dtype=torch.float32)
|
||||
emb_shape = emb.shape[1]
|
||||
if emb.shape[-1] == tokens_embed.shape[-1]:
|
||||
tokens_embed = torch.cat([tokens_embed[:, :ind], emb, tokens_embed[:, ind:]], dim=1)
|
||||
attention_mask = attention_mask[:ind] + [1] * emb_shape + attention_mask[ind:]
|
||||
index += emb_shape - 1
|
||||
else:
|
||||
index += -1
|
||||
pad_extra += emb_shape
|
||||
logging.warning("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored {} != {}".format(emb.shape[-1], tokens_embed.shape[-1]))
|
||||
|
||||
if pad_extra > 0:
|
||||
padd_embed = self.transformer.get_input_embeddings()(torch.tensor([[self.special_tokens["pad"]] * pad_extra], device=device, dtype=torch.long), out_dtype=torch.float32)
|
||||
tokens_embed = torch.cat([tokens_embed, padd_embed], dim=1)
|
||||
attention_mask = attention_mask + [0] * pad_extra
|
||||
|
||||
embeds_out.append(tokens_embed)
|
||||
attention_masks.append(attention_mask)
|
||||
num_tokens.append(sum(attention_mask))
|
||||
|
||||
return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens
|
||||
|
||||
def forward(self, tokens):
|
||||
backup_embeds = self.transformer.get_input_embeddings()
|
||||
device = backup_embeds.weight.device
|
||||
tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
|
||||
tokens = torch.LongTensor(tokens).to(device)
|
||||
|
||||
attention_mask = None
|
||||
if self.enable_attention_masks or self.zero_out_masked or self.return_attention_masks:
|
||||
attention_mask = torch.zeros_like(tokens)
|
||||
end_token = self.special_tokens.get("end", None)
|
||||
if end_token is None:
|
||||
cmp_token = self.special_tokens.get("pad", -1)
|
||||
else:
|
||||
cmp_token = end_token
|
||||
|
||||
for x in range(attention_mask.shape[0]):
|
||||
for y in range(attention_mask.shape[1]):
|
||||
attention_mask[x, y] = 1
|
||||
if tokens[x, y] == cmp_token:
|
||||
if end_token is None:
|
||||
attention_mask[x, y] = 0
|
||||
break
|
||||
device = self.transformer.get_input_embeddings().weight.device
|
||||
embeds, attention_mask, num_tokens = self.process_tokens(tokens, device)
|
||||
|
||||
attention_mask_model = None
|
||||
if self.enable_attention_masks:
|
||||
attention_mask_model = attention_mask
|
||||
|
||||
outputs = self.transformer(tokens, attention_mask_model, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32)
|
||||
self.transformer.set_input_embeddings(backup_embeds)
|
||||
if self.layer == "all":
|
||||
intermediate_output = "all"
|
||||
else:
|
||||
intermediate_output = self.layer_idx
|
||||
|
||||
outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32)
|
||||
|
||||
if self.layer == "last":
|
||||
z = outputs[0].float()
|
||||
@@ -421,11 +457,11 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
|
||||
return embed_out
|
||||
|
||||
class SDTokenizer:
|
||||
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, tokenizer_data={}):
|
||||
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, tokenizer_data={}, tokenizer_args={}):
|
||||
if tokenizer_path is None:
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
|
||||
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path)
|
||||
self.max_length = max_length
|
||||
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
|
||||
self.max_length = tokenizer_data.get("{}_max_length".format(embedding_key), max_length)
|
||||
self.min_length = min_length
|
||||
self.end_token = None
|
||||
|
||||
@@ -482,7 +518,7 @@ class SDTokenizer:
|
||||
return (embed, leftover)
|
||||
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
||||
'''
|
||||
Takes a prompt and converts it to a list of (token, weight, word id) elements.
|
||||
Tokens can both be integer tokens and pre computed CLIP tensors.
|
||||
@@ -585,13 +621,18 @@ class SDTokenizer:
|
||||
return {}
|
||||
|
||||
class SD1Tokenizer:
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer):
|
||||
self.clip_name = clip_name
|
||||
self.clip = "clip_{}".format(self.clip_name)
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer, name=None):
|
||||
if name is not None:
|
||||
self.clip_name = name
|
||||
self.clip = "{}".format(self.clip_name)
|
||||
else:
|
||||
self.clip_name = clip_name
|
||||
self.clip = "clip_{}".format(self.clip_name)
|
||||
|
||||
tokenizer = tokenizer_data.get("{}_tokenizer_class".format(self.clip), tokenizer)
|
||||
setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data))
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
||||
out = {}
|
||||
out[self.clip_name] = getattr(self, self.clip).tokenize_with_weights(text, return_word_ids)
|
||||
return out
|
||||
@@ -600,7 +641,7 @@ class SD1Tokenizer:
|
||||
return getattr(self, self.clip).untokenize(token_weight_pair)
|
||||
|
||||
def state_dict(self):
|
||||
return {}
|
||||
return getattr(self, self.clip).state_dict()
|
||||
|
||||
class SD1CheckpointClipModel(SDClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
@@ -618,6 +659,7 @@ class SD1ClipModel(torch.nn.Module):
|
||||
self.clip = "clip_{}".format(self.clip_name)
|
||||
|
||||
clip_model = model_options.get("{}_class".format(self.clip), clip_model)
|
||||
model_options = {**model_options, "model_name": self.clip}
|
||||
setattr(self, self.clip, clip_model(device=device, dtype=dtype, model_options=model_options, **kwargs))
|
||||
|
||||
self.dtypes = set()
|
||||
|
||||
@@ -9,6 +9,7 @@ class SDXLClipG(sd1_clip.SDClipModel):
|
||||
layer_idx=-2
|
||||
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
|
||||
model_options = {**model_options, "model_name": "clip_g"}
|
||||
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
|
||||
special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False, return_projected_pooled=True, model_options=model_options)
|
||||
|
||||
@@ -17,16 +18,15 @@ class SDXLClipG(sd1_clip.SDClipModel):
|
||||
|
||||
class SDXLClipGTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g')
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g', tokenizer_data=tokenizer_data)
|
||||
|
||||
|
||||
class SDXLTokenizer:
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
|
||||
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
|
||||
self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory)
|
||||
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
||||
out = {}
|
||||
out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
|
||||
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
|
||||
@@ -41,8 +41,7 @@ class SDXLTokenizer:
|
||||
class SDXLClipModel(torch.nn.Module):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__()
|
||||
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
|
||||
self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options)
|
||||
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options)
|
||||
self.clip_g = SDXLClipG(device=device, dtype=dtype, model_options=model_options)
|
||||
self.dtypes = set([dtype])
|
||||
|
||||
@@ -75,7 +74,7 @@ class SDXLRefinerClipModel(sd1_clip.SD1ClipModel):
|
||||
|
||||
class StableCascadeClipGTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(tokenizer_path, pad_with_end=True, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g')
|
||||
super().__init__(tokenizer_path, pad_with_end=True, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g', tokenizer_data=tokenizer_data)
|
||||
|
||||
class StableCascadeTokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
@@ -84,6 +83,7 @@ class StableCascadeTokenizer(sd1_clip.SD1Tokenizer):
|
||||
class StableCascadeClipG(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None, model_options={}):
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
|
||||
model_options = {**model_options, "model_name": "clip_g"}
|
||||
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
|
||||
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True, return_projected_pooled=True, model_options=model_options)
|
||||
|
||||
|
||||
@@ -15,6 +15,8 @@ import comfy.text_encoders.genmo
|
||||
import comfy.text_encoders.lt
|
||||
import comfy.text_encoders.hunyuan_video
|
||||
import comfy.text_encoders.cosmos
|
||||
import comfy.text_encoders.lumina2
|
||||
import comfy.text_encoders.wan
|
||||
|
||||
from . import supported_models_base
|
||||
from . import latent_formats
|
||||
@@ -504,6 +506,22 @@ class SDXL_instructpix2pix(SDXL):
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.SDXL_instructpix2pix(self, model_type=self.model_type(state_dict, prefix), device=device)
|
||||
|
||||
class LotusD(SD20):
|
||||
unet_config = {
|
||||
"model_channels": 320,
|
||||
"use_linear_in_transformer": True,
|
||||
"use_temporal_attention": False,
|
||||
"adm_in_channels": 4,
|
||||
"in_channels": 4,
|
||||
}
|
||||
|
||||
unet_extra_config = {
|
||||
"num_classes": 'sequential'
|
||||
}
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.Lotus(self, device=device)
|
||||
|
||||
class SD3(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"in_channels": 16,
|
||||
@@ -760,7 +778,7 @@ class LTXV(supported_models_base.BASE):
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.LTXV
|
||||
|
||||
memory_usage_factor = 2.7
|
||||
memory_usage_factor = 5.5 # TODO: img2vid is about 2x vs txt2vid
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
@@ -824,6 +842,26 @@ class HunyuanVideo(supported_models_base.BASE):
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}llama.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer, comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**hunyuan_detect))
|
||||
|
||||
class HunyuanVideoI2V(HunyuanVideo):
|
||||
unet_config = {
|
||||
"image_model": "hunyuan_video",
|
||||
"in_channels": 33,
|
||||
}
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.HunyuanVideoI2V(self, device=device)
|
||||
return out
|
||||
|
||||
class HunyuanVideoSkyreelsI2V(HunyuanVideo):
|
||||
unet_config = {
|
||||
"image_model": "hunyuan_video",
|
||||
"in_channels": 32,
|
||||
}
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.HunyuanVideoSkyreelsI2V(self, device=device)
|
||||
return out
|
||||
|
||||
class CosmosT2V(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "cosmos",
|
||||
@@ -865,6 +903,158 @@ class CosmosI2V(CosmosT2V):
|
||||
out = model_base.CosmosVideo(self, image_to_video=True, device=device)
|
||||
return out
|
||||
|
||||
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo, CosmosT2V, CosmosI2V]
|
||||
class Lumina2(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "lumina2",
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"multiplier": 1.0,
|
||||
"shift": 6.0,
|
||||
}
|
||||
|
||||
memory_usage_factor = 1.2
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.Flux
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.Lumina2(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}gemma2_2b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.lumina2.LuminaTokenizer, comfy.text_encoders.lumina2.te(**hunyuan_detect))
|
||||
|
||||
class WAN21_T2V(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
"model_type": "t2v",
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 8.0,
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.Wan21
|
||||
|
||||
memory_usage_factor = 1.0
|
||||
|
||||
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def __init__(self, unet_config):
|
||||
super().__init__(unet_config)
|
||||
self.memory_usage_factor = self.unet_config.get("dim", 2000) / 2000
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.WAN21(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}umt5xxl.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.wan.WanT5Tokenizer, comfy.text_encoders.wan.te(**t5_detect))
|
||||
|
||||
class WAN21_I2V(WAN21_T2V):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
"model_type": "i2v",
|
||||
"in_dim": 36,
|
||||
}
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.WAN21(self, image_to_video=True, device=device)
|
||||
return out
|
||||
|
||||
class WAN21_FunControl2V(WAN21_T2V):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
"model_type": "i2v",
|
||||
"in_dim": 48,
|
||||
}
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.WAN21(self, image_to_video=False, device=device)
|
||||
return out
|
||||
|
||||
class Hunyuan3Dv2(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "hunyuan3d2",
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
|
||||
sampling_settings = {
|
||||
"multiplier": 1.0,
|
||||
"shift": 1.0,
|
||||
}
|
||||
|
||||
memory_usage_factor = 3.5
|
||||
|
||||
clip_vision_prefix = "conditioner.main_image_encoder.model."
|
||||
vae_key_prefix = ["vae."]
|
||||
|
||||
latent_format = latent_formats.Hunyuan3Dv2
|
||||
|
||||
def process_unet_state_dict_for_saving(self, state_dict):
|
||||
replace_prefix = {"": "model."}
|
||||
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.Hunyuan3Dv2(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return None
|
||||
|
||||
class Hunyuan3Dv2mini(Hunyuan3Dv2):
|
||||
unet_config = {
|
||||
"image_model": "hunyuan3d2",
|
||||
"depth": 8,
|
||||
}
|
||||
|
||||
latent_format = latent_formats.Hunyuan3Dv2mini
|
||||
|
||||
class HiDream(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "hidream",
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 3.0,
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
}
|
||||
|
||||
# memory_usage_factor = 1.2 # TODO
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.Flux
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.HiDream(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return None # TODO
|
||||
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
||||
@@ -11,7 +11,7 @@ class PT5XlModel(sd1_clip.SDClipModel):
|
||||
class PT5XlTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_pile_tokenizer"), "tokenizer.model")
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='pile_t5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, pad_token=1)
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='pile_t5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, pad_token=1, tokenizer_data=tokenizer_data)
|
||||
|
||||
class AuraT5Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
|
||||
@@ -93,8 +93,11 @@ class BertEmbeddings(torch.nn.Module):
|
||||
|
||||
self.LayerNorm = operations.LayerNorm(embed_dim, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, input_tokens, token_type_ids=None, dtype=None):
|
||||
x = self.word_embeddings(input_tokens, out_dtype=dtype)
|
||||
def forward(self, input_tokens, embeds=None, token_type_ids=None, dtype=None):
|
||||
if embeds is not None:
|
||||
x = embeds
|
||||
else:
|
||||
x = self.word_embeddings(input_tokens, out_dtype=dtype)
|
||||
x += comfy.ops.cast_to_input(self.position_embeddings.weight[:x.shape[1]], x)
|
||||
if token_type_ids is not None:
|
||||
x += self.token_type_embeddings(token_type_ids, out_dtype=x.dtype)
|
||||
@@ -113,12 +116,12 @@ class BertModel_(torch.nn.Module):
|
||||
self.embeddings = BertEmbeddings(config_dict["vocab_size"], config_dict["max_position_embeddings"], config_dict["type_vocab_size"], config_dict["pad_token_id"], embed_dim, layer_norm_eps, dtype, device, operations)
|
||||
self.encoder = BertEncoder(config_dict["num_hidden_layers"], embed_dim, config_dict["intermediate_size"], config_dict["num_attention_heads"], layer_norm_eps, dtype, device, operations)
|
||||
|
||||
def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
|
||||
x = self.embeddings(input_tokens, dtype=dtype)
|
||||
def forward(self, input_tokens, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
|
||||
x = self.embeddings(input_tokens, embeds=embeds, dtype=dtype)
|
||||
mask = None
|
||||
if attention_mask is not None:
|
||||
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
|
||||
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
|
||||
mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max)
|
||||
|
||||
x, i = self.encoder(x, mask, intermediate_output)
|
||||
return x, i
|
||||
|
||||
@@ -22,7 +22,7 @@ class CosmosT5XXL(sd1_clip.SD1ClipModel):
|
||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
|
||||
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=1024, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512)
|
||||
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=1024, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, tokenizer_data=tokenizer_data)
|
||||
|
||||
|
||||
class CosmosT5Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
|
||||
@@ -9,16 +9,15 @@ import os
|
||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
|
||||
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
|
||||
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, tokenizer_data=tokenizer_data)
|
||||
|
||||
|
||||
class FluxTokenizer:
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
|
||||
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
|
||||
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
|
||||
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
||||
out = {}
|
||||
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
|
||||
out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids)
|
||||
@@ -35,8 +34,7 @@ class FluxClipModel(torch.nn.Module):
|
||||
def __init__(self, dtype_t5=None, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__()
|
||||
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
|
||||
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
|
||||
self.clip_l = clip_l_class(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
|
||||
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
|
||||
self.t5xxl = comfy.text_encoders.sd3_clip.T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
|
||||
self.dtypes = set([dtype, dtype_t5])
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ class MochiT5XXL(sd1_clip.SD1ClipModel):
|
||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
|
||||
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
|
||||
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, tokenizer_data=tokenizer_data)
|
||||
|
||||
|
||||
class MochiT5Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
|
||||
150
comfy/text_encoders/hidream.py
Normal file
150
comfy/text_encoders/hidream.py
Normal file
@@ -0,0 +1,150 @@
|
||||
from . import hunyuan_video
|
||||
from . import sd3_clip
|
||||
from comfy import sd1_clip
|
||||
from comfy import sdxl_clip
|
||||
import comfy.model_management
|
||||
import torch
|
||||
import logging
|
||||
|
||||
|
||||
class HiDreamTokenizer:
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
self.t5xxl = sd3_clip.T5XXLTokenizer(embedding_directory=embedding_directory, min_length=128, tokenizer_data=tokenizer_data)
|
||||
self.llama = hunyuan_video.LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=128, pad_token=128009, tokenizer_data=tokenizer_data)
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
||||
out = {}
|
||||
out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
|
||||
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
|
||||
out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids)
|
||||
out["llama"] = self.llama.tokenize_with_weights(text, return_word_ids)
|
||||
return out
|
||||
|
||||
def untokenize(self, token_weight_pair):
|
||||
return self.clip_g.untokenize(token_weight_pair)
|
||||
|
||||
def state_dict(self):
|
||||
return {}
|
||||
|
||||
|
||||
class HiDreamTEModel(torch.nn.Module):
|
||||
def __init__(self, clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__()
|
||||
self.dtypes = set()
|
||||
if clip_l:
|
||||
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=True, model_options=model_options)
|
||||
self.dtypes.add(dtype)
|
||||
else:
|
||||
self.clip_l = None
|
||||
|
||||
if clip_g:
|
||||
self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype, model_options=model_options)
|
||||
self.dtypes.add(dtype)
|
||||
else:
|
||||
self.clip_g = None
|
||||
|
||||
if t5:
|
||||
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
|
||||
self.t5xxl = sd3_clip.T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options, attention_mask=True)
|
||||
self.dtypes.add(dtype_t5)
|
||||
else:
|
||||
self.t5xxl = None
|
||||
|
||||
if llama:
|
||||
dtype_llama = comfy.model_management.pick_weight_dtype(dtype_llama, dtype, device)
|
||||
if "vocab_size" not in model_options:
|
||||
model_options["vocab_size"] = 128256
|
||||
self.llama = hunyuan_video.LLAMAModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None, special_tokens={"start": 128000, "pad": 128009})
|
||||
self.dtypes.add(dtype_llama)
|
||||
else:
|
||||
self.llama = None
|
||||
|
||||
logging.debug("Created HiDream text encoder with: clip_l {}, clip_g {}, t5xxl {}:{}, llama {}:{}".format(clip_l, clip_g, t5, dtype_t5, llama, dtype_llama))
|
||||
|
||||
def set_clip_options(self, options):
|
||||
if self.clip_l is not None:
|
||||
self.clip_l.set_clip_options(options)
|
||||
if self.clip_g is not None:
|
||||
self.clip_g.set_clip_options(options)
|
||||
if self.t5xxl is not None:
|
||||
self.t5xxl.set_clip_options(options)
|
||||
if self.llama is not None:
|
||||
self.llama.set_clip_options(options)
|
||||
|
||||
def reset_clip_options(self):
|
||||
if self.clip_l is not None:
|
||||
self.clip_l.reset_clip_options()
|
||||
if self.clip_g is not None:
|
||||
self.clip_g.reset_clip_options()
|
||||
if self.t5xxl is not None:
|
||||
self.t5xxl.reset_clip_options()
|
||||
if self.llama is not None:
|
||||
self.llama.reset_clip_options()
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
token_weight_pairs_l = token_weight_pairs["l"]
|
||||
token_weight_pairs_g = token_weight_pairs["g"]
|
||||
token_weight_pairs_t5 = token_weight_pairs["t5xxl"]
|
||||
token_weight_pairs_llama = token_weight_pairs["llama"]
|
||||
lg_out = None
|
||||
pooled = None
|
||||
extra = {}
|
||||
|
||||
if len(token_weight_pairs_g) > 0 or len(token_weight_pairs_l) > 0:
|
||||
if self.clip_l is not None:
|
||||
lg_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
|
||||
else:
|
||||
l_pooled = torch.zeros((1, 768), device=comfy.model_management.intermediate_device())
|
||||
|
||||
if self.clip_g is not None:
|
||||
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
|
||||
else:
|
||||
g_pooled = torch.zeros((1, 1280), device=comfy.model_management.intermediate_device())
|
||||
|
||||
pooled = torch.cat((l_pooled, g_pooled), dim=-1)
|
||||
|
||||
if self.t5xxl is not None:
|
||||
t5_output = self.t5xxl.encode_token_weights(token_weight_pairs_t5)
|
||||
t5_out, t5_pooled = t5_output[:2]
|
||||
|
||||
if self.llama is not None:
|
||||
ll_output = self.llama.encode_token_weights(token_weight_pairs_llama)
|
||||
ll_out, ll_pooled = ll_output[:2]
|
||||
ll_out = ll_out[:, 1:]
|
||||
|
||||
if t5_out is None:
|
||||
t5_out = torch.zeros((1, 1, 4096), device=comfy.model_management.intermediate_device())
|
||||
|
||||
if ll_out is None:
|
||||
ll_out = torch.zeros((1, 32, 1, 4096), device=comfy.model_management.intermediate_device())
|
||||
|
||||
if pooled is None:
|
||||
pooled = torch.zeros((1, 768 + 1280), device=comfy.model_management.intermediate_device())
|
||||
|
||||
extra["conditioning_llama3"] = ll_out
|
||||
return t5_out, pooled, extra
|
||||
|
||||
def load_sd(self, sd):
|
||||
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
|
||||
return self.clip_g.load_sd(sd)
|
||||
elif "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
|
||||
return self.clip_l.load_sd(sd)
|
||||
elif "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd:
|
||||
return self.t5xxl.load_sd(sd)
|
||||
else:
|
||||
return self.llama.load_sd(sd)
|
||||
|
||||
|
||||
def hidream_clip(clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None):
|
||||
class HiDreamTEModel_(HiDreamTEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
|
||||
model_options = model_options.copy()
|
||||
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
|
||||
if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options:
|
||||
model_options = model_options.copy()
|
||||
model_options["llama_scaled_fp8"] = llama_scaled_fp8
|
||||
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, dtype_t5=dtype_t5, dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
|
||||
return HiDreamTEModel_
|
||||
@@ -4,6 +4,7 @@ import comfy.text_encoders.llama
|
||||
from transformers import LlamaTokenizerFast
|
||||
import torch
|
||||
import os
|
||||
import numbers
|
||||
|
||||
|
||||
def llama_detect(state_dict, prefix=""):
|
||||
@@ -20,33 +21,49 @@ def llama_detect(state_dict, prefix=""):
|
||||
|
||||
|
||||
class LLAMA3Tokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=256):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=256, pad_token=128258):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "llama_tokenizer")
|
||||
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='llama', tokenizer_class=LlamaTokenizerFast, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, pad_token=128258, end_token=128009, min_length=min_length)
|
||||
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='llama', tokenizer_class=LlamaTokenizerFast, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, pad_token=pad_token, min_length=min_length, tokenizer_data=tokenizer_data)
|
||||
|
||||
class LLAMAModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}):
|
||||
def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}, special_tokens={"start": 128000, "pad": 128258}):
|
||||
llama_scaled_fp8 = model_options.get("llama_scaled_fp8", None)
|
||||
if llama_scaled_fp8 is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["scaled_fp8"] = llama_scaled_fp8
|
||||
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 128000, "pad": 128258}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Llama2, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
textmodel_json_config = {}
|
||||
vocab_size = model_options.get("vocab_size", None)
|
||||
if vocab_size is not None:
|
||||
textmodel_json_config["vocab_size"] = vocab_size
|
||||
|
||||
model_options = {**model_options, "model_name": "llama"}
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens=special_tokens, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Llama2, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
|
||||
class HunyuanVideoTokenizer:
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
|
||||
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
|
||||
self.llama_template = """<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n""" # 95 tokens
|
||||
self.llama = LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=1)
|
||||
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
self.llama_template = """<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>""" # 95 tokens
|
||||
self.llama = LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=1, tokenizer_data=tokenizer_data)
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, image_embeds=None, image_interleave=1, **kwargs):
|
||||
out = {}
|
||||
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
|
||||
|
||||
llama_text = "{}{}".format(self.llama_template, text)
|
||||
out["llama"] = self.llama.tokenize_with_weights(llama_text, return_word_ids)
|
||||
if llama_template is None:
|
||||
llama_text = self.llama_template.format(text)
|
||||
else:
|
||||
llama_text = llama_template.format(text)
|
||||
llama_text_tokens = self.llama.tokenize_with_weights(llama_text, return_word_ids)
|
||||
embed_count = 0
|
||||
for r in llama_text_tokens:
|
||||
for i in range(len(r)):
|
||||
if r[i][0] == 128257:
|
||||
if image_embeds is not None and embed_count < image_embeds.shape[0]:
|
||||
r[i] = ({"type": "embedding", "data": image_embeds[embed_count], "original_type": "image", "image_interleave": image_interleave},) + r[i][1:]
|
||||
embed_count += 1
|
||||
out["llama"] = llama_text_tokens
|
||||
return out
|
||||
|
||||
def untokenize(self, token_weight_pair):
|
||||
@@ -60,8 +77,7 @@ class HunyuanVideoClipModel(torch.nn.Module):
|
||||
def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__()
|
||||
dtype_llama = comfy.model_management.pick_weight_dtype(dtype_llama, dtype, device)
|
||||
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
|
||||
self.clip_l = clip_l_class(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
|
||||
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
|
||||
self.llama = LLAMAModel(device=device, dtype=dtype_llama, model_options=model_options)
|
||||
self.dtypes = set([dtype, dtype_llama])
|
||||
|
||||
@@ -80,20 +96,51 @@ class HunyuanVideoClipModel(torch.nn.Module):
|
||||
llama_out, llama_pooled, llama_extra_out = self.llama.encode_token_weights(token_weight_pairs_llama)
|
||||
|
||||
template_end = 0
|
||||
for i, v in enumerate(token_weight_pairs_llama[0]):
|
||||
if v[0] == 128007: # <|end_header_id|>
|
||||
template_end = i
|
||||
extra_template_end = 0
|
||||
extra_sizes = 0
|
||||
user_end = 9999999999999
|
||||
images = []
|
||||
|
||||
tok_pairs = token_weight_pairs_llama[0]
|
||||
for i, v in enumerate(tok_pairs):
|
||||
elem = v[0]
|
||||
if not torch.is_tensor(elem):
|
||||
if isinstance(elem, numbers.Integral):
|
||||
if elem == 128006:
|
||||
if tok_pairs[i + 1][0] == 882:
|
||||
if tok_pairs[i + 2][0] == 128007:
|
||||
template_end = i + 2
|
||||
user_end = -1
|
||||
if elem == 128009 and user_end == -1:
|
||||
user_end = i + 1
|
||||
else:
|
||||
if elem.get("original_type") == "image":
|
||||
elem_size = elem.get("data").shape[0]
|
||||
if template_end > 0:
|
||||
if user_end == -1:
|
||||
extra_template_end += elem_size - 1
|
||||
else:
|
||||
image_start = i + extra_sizes
|
||||
image_end = i + elem_size + extra_sizes
|
||||
images.append((image_start, image_end, elem.get("image_interleave", 1)))
|
||||
extra_sizes += elem_size - 1
|
||||
|
||||
if llama_out.shape[1] > (template_end + 2):
|
||||
if token_weight_pairs_llama[0][template_end + 1][0] == 271:
|
||||
if tok_pairs[template_end + 1][0] == 271:
|
||||
template_end += 2
|
||||
llama_out = llama_out[:, template_end:]
|
||||
llama_extra_out["attention_mask"] = llama_extra_out["attention_mask"][:, template_end:]
|
||||
llama_output = llama_out[:, template_end + extra_sizes:user_end + extra_sizes + extra_template_end]
|
||||
llama_extra_out["attention_mask"] = llama_extra_out["attention_mask"][:, template_end + extra_sizes:user_end + extra_sizes + extra_template_end]
|
||||
if llama_extra_out["attention_mask"].sum() == torch.numel(llama_extra_out["attention_mask"]):
|
||||
llama_extra_out.pop("attention_mask") # attention mask is useless if no masked elements
|
||||
|
||||
if len(images) > 0:
|
||||
out = []
|
||||
for i in images:
|
||||
out.append(llama_out[:, i[0]: i[1]: i[2]])
|
||||
llama_output = torch.cat(out + [llama_output], dim=1)
|
||||
|
||||
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
|
||||
return llama_out, l_pooled, llama_extra_out
|
||||
return llama_output, l_pooled, llama_extra_out
|
||||
|
||||
def load_sd(self, sd):
|
||||
if "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
|
||||
|
||||
@@ -9,24 +9,26 @@ import torch
|
||||
class HyditBertModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hydit_clip.json")
|
||||
model_options = {**model_options, "model_name": "hydit_clip"}
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 101, "end": 102, "pad": 0}, model_class=BertModel, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)
|
||||
|
||||
class HyditBertTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hydit_clip_tokenizer")
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1024, embedding_key='chinese_roberta', tokenizer_class=BertTokenizer, pad_to_max_length=False, max_length=512, min_length=77)
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1024, embedding_key='chinese_roberta', tokenizer_class=BertTokenizer, pad_to_max_length=False, max_length=512, min_length=77, tokenizer_data=tokenizer_data)
|
||||
|
||||
|
||||
class MT5XLModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mt5_config_xl.json")
|
||||
model_options = {**model_options, "model_name": "mt5xl"}
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)
|
||||
|
||||
class MT5XLTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
#tokenizer_path = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "mt5_tokenizer"), "spiece.model")
|
||||
tokenizer = tokenizer_data.get("spiece_model", None)
|
||||
super().__init__(tokenizer, pad_with_end=False, embedding_size=2048, embedding_key='mt5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
|
||||
super().__init__(tokenizer, pad_with_end=False, embedding_size=2048, embedding_key='mt5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, tokenizer_data=tokenizer_data)
|
||||
|
||||
def state_dict(self):
|
||||
return {"spiece_model": self.tokenizer.serialize_model()}
|
||||
@@ -35,9 +37,9 @@ class HyditTokenizer:
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
mt5_tokenizer_data = tokenizer_data.get("mt5xl.spiece_model", None)
|
||||
self.hydit_clip = HyditBertTokenizer(embedding_directory=embedding_directory)
|
||||
self.mt5xl = MT5XLTokenizer(tokenizer_data={"spiece_model": mt5_tokenizer_data}, embedding_directory=embedding_directory)
|
||||
self.mt5xl = MT5XLTokenizer(tokenizer_data={**tokenizer_data, "spiece_model": mt5_tokenizer_data}, embedding_directory=embedding_directory)
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
||||
out = {}
|
||||
out["hydit_clip"] = self.hydit_clip.tokenize_with_weights(text, return_word_ids)
|
||||
out["mt5xl"] = self.mt5xl.tokenize_with_weights(text, return_word_ids)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Any
|
||||
|
||||
@@ -21,15 +20,41 @@ class Llama2Config:
|
||||
max_position_embeddings: int = 8192
|
||||
rms_norm_eps: float = 1e-5
|
||||
rope_theta: float = 500000.0
|
||||
transformer_type: str = "llama"
|
||||
head_dim = 128
|
||||
rms_norm_add = False
|
||||
mlp_activation = "silu"
|
||||
|
||||
@dataclass
|
||||
class Gemma2_2B_Config:
|
||||
vocab_size: int = 256000
|
||||
hidden_size: int = 2304
|
||||
intermediate_size: int = 9216
|
||||
num_hidden_layers: int = 26
|
||||
num_attention_heads: int = 8
|
||||
num_key_value_heads: int = 4
|
||||
max_position_embeddings: int = 8192
|
||||
rms_norm_eps: float = 1e-6
|
||||
rope_theta: float = 10000.0
|
||||
transformer_type: str = "gemma2"
|
||||
head_dim = 256
|
||||
rms_norm_add = True
|
||||
mlp_activation = "gelu_pytorch_tanh"
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-5, device=None, dtype=None):
|
||||
def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
|
||||
self.add = add
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)
|
||||
w = self.weight
|
||||
if self.add:
|
||||
w = w + 1.0
|
||||
|
||||
return comfy.ldm.common_dit.rms_norm(x, w, self.eps)
|
||||
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
@@ -68,13 +93,15 @@ class Attention(nn.Module):
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.num_kv_heads = config.num_key_value_heads
|
||||
self.hidden_size = config.hidden_size
|
||||
self.head_dim = self.hidden_size // self.num_heads
|
||||
|
||||
self.head_dim = config.head_dim
|
||||
self.inner_size = self.num_heads * self.head_dim
|
||||
|
||||
ops = ops or nn
|
||||
self.q_proj = ops.Linear(config.hidden_size, config.hidden_size, bias=False, device=device, dtype=dtype)
|
||||
self.q_proj = ops.Linear(config.hidden_size, self.inner_size, bias=False, device=device, dtype=dtype)
|
||||
self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False, device=device, dtype=dtype)
|
||||
self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False, device=device, dtype=dtype)
|
||||
self.o_proj = ops.Linear(config.hidden_size, config.hidden_size, bias=False, device=device, dtype=dtype)
|
||||
self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -84,7 +111,6 @@ class Attention(nn.Module):
|
||||
optimized_attention=None,
|
||||
):
|
||||
batch_size, seq_length, _ = hidden_states.shape
|
||||
|
||||
xq = self.q_proj(hidden_states)
|
||||
xk = self.k_proj(hidden_states)
|
||||
xv = self.v_proj(hidden_states)
|
||||
@@ -108,9 +134,13 @@ class MLP(nn.Module):
|
||||
self.gate_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype)
|
||||
self.up_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype)
|
||||
self.down_proj = ops.Linear(config.intermediate_size, config.hidden_size, bias=False, device=device, dtype=dtype)
|
||||
if config.mlp_activation == "silu":
|
||||
self.activation = torch.nn.functional.silu
|
||||
elif config.mlp_activation == "gelu_pytorch_tanh":
|
||||
self.activation = lambda a: torch.nn.functional.gelu(a, approximate="tanh")
|
||||
|
||||
def forward(self, x):
|
||||
return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||
return self.down_proj(self.activation(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
|
||||
@@ -146,6 +176,45 @@ class TransformerBlock(nn.Module):
|
||||
|
||||
return x
|
||||
|
||||
class TransformerBlockGemma2(nn.Module):
|
||||
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
|
||||
super().__init__()
|
||||
self.self_attn = Attention(config, device=device, dtype=dtype, ops=ops)
|
||||
self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
||||
self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
||||
self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
freqs_cis: Optional[torch.Tensor] = None,
|
||||
optimized_attention=None,
|
||||
):
|
||||
# Self Attention
|
||||
residual = x
|
||||
x = self.input_layernorm(x)
|
||||
x = self.self_attn(
|
||||
hidden_states=x,
|
||||
attention_mask=attention_mask,
|
||||
freqs_cis=freqs_cis,
|
||||
optimized_attention=optimized_attention,
|
||||
)
|
||||
|
||||
x = self.post_attention_layernorm(x)
|
||||
x = residual + x
|
||||
|
||||
# MLP
|
||||
residual = x
|
||||
x = self.pre_feedforward_layernorm(x)
|
||||
x = self.mlp(x)
|
||||
x = self.post_feedforward_layernorm(x)
|
||||
x = residual + x
|
||||
|
||||
return x
|
||||
|
||||
class Llama2_(nn.Module):
|
||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
@@ -158,17 +227,30 @@ class Llama2_(nn.Module):
|
||||
device=device,
|
||||
dtype=dtype
|
||||
)
|
||||
if self.config.transformer_type == "gemma2":
|
||||
transformer = TransformerBlockGemma2
|
||||
self.normalize_in = True
|
||||
else:
|
||||
transformer = TransformerBlock
|
||||
self.normalize_in = False
|
||||
|
||||
self.layers = nn.ModuleList([
|
||||
TransformerBlock(config, device=device, dtype=dtype, ops=ops)
|
||||
transformer(config, device=device, dtype=dtype, ops=ops)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
||||
# self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
|
||||
x = self.embed_tokens(x, out_dtype=dtype)
|
||||
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
|
||||
if embeds is not None:
|
||||
x = embeds
|
||||
else:
|
||||
x = self.embed_tokens(x, out_dtype=dtype)
|
||||
|
||||
freqs_cis = precompute_freqs_cis(self.config.hidden_size // self.config.num_attention_heads,
|
||||
if self.normalize_in:
|
||||
x *= self.config.hidden_size ** 0.5
|
||||
|
||||
freqs_cis = precompute_freqs_cis(self.config.head_dim,
|
||||
x.shape[1],
|
||||
self.config.rope_theta,
|
||||
device=x.device)
|
||||
@@ -186,11 +268,17 @@ class Llama2_(nn.Module):
|
||||
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
|
||||
|
||||
intermediate = None
|
||||
all_intermediate = None
|
||||
if intermediate_output is not None:
|
||||
if intermediate_output < 0:
|
||||
if intermediate_output == "all":
|
||||
all_intermediate = []
|
||||
intermediate_output = None
|
||||
elif intermediate_output < 0:
|
||||
intermediate_output = len(self.layers) + intermediate_output
|
||||
|
||||
for i, layer in enumerate(self.layers):
|
||||
if all_intermediate is not None:
|
||||
all_intermediate.append(x.unsqueeze(1).clone())
|
||||
x = layer(
|
||||
x=x,
|
||||
attention_mask=mask,
|
||||
@@ -201,21 +289,18 @@ class Llama2_(nn.Module):
|
||||
intermediate = x.clone()
|
||||
|
||||
x = self.norm(x)
|
||||
if all_intermediate is not None:
|
||||
all_intermediate.append(x.unsqueeze(1).clone())
|
||||
|
||||
if all_intermediate is not None:
|
||||
intermediate = torch.cat(all_intermediate, dim=1)
|
||||
|
||||
if intermediate is not None and final_layer_norm_intermediate:
|
||||
intermediate = self.norm(intermediate)
|
||||
|
||||
return x, intermediate
|
||||
|
||||
|
||||
class Llama2(torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Llama2Config(**config_dict)
|
||||
self.num_layers = config.num_hidden_layers
|
||||
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class BaseLlama:
|
||||
def get_input_embeddings(self):
|
||||
return self.model.embed_tokens
|
||||
|
||||
@@ -224,3 +309,23 @@ class Llama2(torch.nn.Module):
|
||||
|
||||
def forward(self, input_ids, *args, **kwargs):
|
||||
return self.model(input_ids, *args, **kwargs)
|
||||
|
||||
|
||||
class Llama2(BaseLlama, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Llama2Config(**config_dict)
|
||||
self.num_layers = config.num_hidden_layers
|
||||
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
|
||||
class Gemma2_2B(BaseLlama, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Gemma2_2B_Config(**config_dict)
|
||||
self.num_layers = config.num_hidden_layers
|
||||
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
@@ -1,30 +1,27 @@
|
||||
from comfy import sd1_clip
|
||||
import os
|
||||
|
||||
class LongClipTokenizer_(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(max_length=248, embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
|
||||
class LongClipModel_(sd1_clip.SDClipModel):
|
||||
def __init__(self, *args, **kwargs):
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "long_clipl.json")
|
||||
super().__init__(*args, textmodel_json_config=textmodel_json_config, **kwargs)
|
||||
|
||||
class LongClipTokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, tokenizer=LongClipTokenizer_)
|
||||
|
||||
class LongClipModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options, clip_model=LongClipModel_, **kwargs)
|
||||
|
||||
def model_options_long_clip(sd, tokenizer_data, model_options):
|
||||
w = sd.get("clip_l.text_model.embeddings.position_embedding.weight", None)
|
||||
if w is None:
|
||||
w = sd.get("clip_g.text_model.embeddings.position_embedding.weight", None)
|
||||
else:
|
||||
model_name = "clip_g"
|
||||
|
||||
if w is None:
|
||||
w = sd.get("text_model.embeddings.position_embedding.weight", None)
|
||||
if w is not None and w.shape[0] == 248:
|
||||
if w is not None:
|
||||
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
|
||||
model_name = "clip_g"
|
||||
elif "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
|
||||
model_name = "clip_l"
|
||||
else:
|
||||
model_name = "clip_l"
|
||||
|
||||
if w is not None:
|
||||
tokenizer_data = tokenizer_data.copy()
|
||||
model_options = model_options.copy()
|
||||
tokenizer_data["clip_l_tokenizer_class"] = LongClipTokenizer_
|
||||
model_options["clip_l_class"] = LongClipModel_
|
||||
model_config = model_options.get("model_config", {})
|
||||
model_config["max_position_embeddings"] = w.shape[0]
|
||||
model_options["{}_model_config".format(model_name)] = model_config
|
||||
tokenizer_data["{}_max_length".format(model_name)] = w.shape[0]
|
||||
return tokenizer_data, model_options
|
||||
|
||||
@@ -6,7 +6,7 @@ import comfy.text_encoders.genmo
|
||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
|
||||
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128) #pad to 128?
|
||||
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128, tokenizer_data=tokenizer_data) #pad to 128?
|
||||
|
||||
|
||||
class LTXVT5Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
|
||||
39
comfy/text_encoders/lumina2.py
Normal file
39
comfy/text_encoders/lumina2.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from comfy import sd1_clip
|
||||
from .spiece_tokenizer import SPieceTokenizer
|
||||
import comfy.text_encoders.llama
|
||||
|
||||
|
||||
class Gemma2BTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer = tokenizer_data.get("spiece_model", None)
|
||||
super().__init__(tokenizer, pad_with_end=False, embedding_size=2304, embedding_key='gemma2_2b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
|
||||
|
||||
def state_dict(self):
|
||||
return {"spiece_model": self.tokenizer.serialize_model()}
|
||||
|
||||
|
||||
class LuminaTokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma2_2b", tokenizer=Gemma2BTokenizer)
|
||||
|
||||
|
||||
class Gemma2_2BModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma2_2B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
|
||||
class LuminaModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__(device=device, dtype=dtype, name="gemma2_2b", clip_model=Gemma2_2BModel, model_options=model_options)
|
||||
|
||||
|
||||
def te(dtype_llama=None, llama_scaled_fp8=None):
|
||||
class LuminaTEModel_(LuminaModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
|
||||
model_options = model_options.copy()
|
||||
model_options["scaled_fp8"] = llama_scaled_fp8
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return LuminaTEModel_
|
||||
@@ -24,7 +24,7 @@ class PixArtT5XXL(sd1_clip.SD1ClipModel):
|
||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
|
||||
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1) # no padding
|
||||
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_data=tokenizer_data) # no padding
|
||||
|
||||
class PixArtTokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
|
||||
@@ -11,7 +11,7 @@ class T5BaseModel(sd1_clip.SDClipModel):
|
||||
class T5BaseTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=768, embedding_key='t5base', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128)
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=768, embedding_key='t5base', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128, tokenizer_data=tokenizer_data)
|
||||
|
||||
class SAT5Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
|
||||
@@ -12,7 +12,7 @@ class SD2ClipHModel(sd1_clip.SDClipModel):
|
||||
|
||||
class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024)
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024, embedding_key='clip_h', tokenizer_data=tokenizer_data)
|
||||
|
||||
class SD2Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
|
||||
@@ -15,6 +15,7 @@ class T5XXLModel(sd1_clip.SDClipModel):
|
||||
model_options = model_options.copy()
|
||||
model_options["scaled_fp8"] = t5xxl_scaled_fp8
|
||||
|
||||
model_options = {**model_options, "model_name": "t5xxl"}
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
|
||||
@@ -31,19 +32,18 @@ def t5_xxl_detect(state_dict, prefix=""):
|
||||
return out
|
||||
|
||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=77):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
|
||||
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77)
|
||||
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=min_length, tokenizer_data=tokenizer_data)
|
||||
|
||||
|
||||
class SD3Tokenizer:
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
|
||||
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
|
||||
self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory)
|
||||
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
|
||||
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
||||
out = {}
|
||||
out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
|
||||
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
|
||||
@@ -61,8 +61,7 @@ class SD3ClipModel(torch.nn.Module):
|
||||
super().__init__()
|
||||
self.dtypes = set()
|
||||
if clip_l:
|
||||
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
|
||||
self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options)
|
||||
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options)
|
||||
self.dtypes.add(dtype)
|
||||
else:
|
||||
self.clip_l = None
|
||||
|
||||
@@ -1,21 +1,21 @@
|
||||
import torch
|
||||
|
||||
class SPieceTokenizer:
|
||||
add_eos = True
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained(path):
|
||||
return SPieceTokenizer(path)
|
||||
def from_pretrained(path, **kwargs):
|
||||
return SPieceTokenizer(path, **kwargs)
|
||||
|
||||
def __init__(self, tokenizer_path):
|
||||
def __init__(self, tokenizer_path, add_bos=False, add_eos=True):
|
||||
self.add_bos = add_bos
|
||||
self.add_eos = add_eos
|
||||
import sentencepiece
|
||||
if torch.is_tensor(tokenizer_path):
|
||||
tokenizer_path = tokenizer_path.numpy().tobytes()
|
||||
|
||||
if isinstance(tokenizer_path, bytes):
|
||||
self.tokenizer = sentencepiece.SentencePieceProcessor(model_proto=tokenizer_path, add_eos=self.add_eos)
|
||||
self.tokenizer = sentencepiece.SentencePieceProcessor(model_proto=tokenizer_path, add_bos=self.add_bos, add_eos=self.add_eos)
|
||||
else:
|
||||
self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=tokenizer_path, add_eos=self.add_eos)
|
||||
self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=tokenizer_path, add_bos=self.add_bos, add_eos=self.add_eos)
|
||||
|
||||
def get_vocab(self):
|
||||
out = {}
|
||||
|
||||
@@ -203,7 +203,7 @@ class T5Stack(torch.nn.Module):
|
||||
mask = None
|
||||
if attention_mask is not None:
|
||||
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
|
||||
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
|
||||
mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max)
|
||||
|
||||
intermediate = None
|
||||
optimized_attention = optimized_attention_for_device(x.device, mask=attention_mask is not None, small_input=True)
|
||||
@@ -239,8 +239,11 @@ class T5(torch.nn.Module):
|
||||
def set_input_embeddings(self, embeddings):
|
||||
self.shared = embeddings
|
||||
|
||||
def forward(self, input_ids, *args, **kwargs):
|
||||
x = self.shared(input_ids, out_dtype=kwargs.get("dtype", torch.float32))
|
||||
def forward(self, input_ids, attention_mask, embeds=None, num_tokens=None, **kwargs):
|
||||
if input_ids is None:
|
||||
x = embeds
|
||||
else:
|
||||
x = self.shared(input_ids, out_dtype=kwargs.get("dtype", torch.float32))
|
||||
if self.dtype not in [torch.float32, torch.float16, torch.bfloat16]:
|
||||
x = torch.nan_to_num(x) #Fix for fp8 T5 base
|
||||
return self.encoder(x, *args, **kwargs)
|
||||
return self.encoder(x, attention_mask=attention_mask, **kwargs)
|
||||
|
||||
22
comfy/text_encoders/umt5_config_xxl.json
Normal file
22
comfy/text_encoders/umt5_config_xxl.json
Normal file
@@ -0,0 +1,22 @@
|
||||
{
|
||||
"d_ff": 10240,
|
||||
"d_kv": 64,
|
||||
"d_model": 4096,
|
||||
"decoder_start_token_id": 0,
|
||||
"dropout_rate": 0.1,
|
||||
"eos_token_id": 1,
|
||||
"dense_act_fn": "gelu_pytorch_tanh",
|
||||
"initializer_factor": 1.0,
|
||||
"is_encoder_decoder": true,
|
||||
"is_gated_act": true,
|
||||
"layer_norm_epsilon": 1e-06,
|
||||
"model_type": "umt5",
|
||||
"num_decoder_layers": 24,
|
||||
"num_heads": 64,
|
||||
"num_layers": 24,
|
||||
"output_past": true,
|
||||
"pad_token_id": 0,
|
||||
"relative_attention_num_buckets": 32,
|
||||
"tie_word_embeddings": false,
|
||||
"vocab_size": 256384
|
||||
}
|
||||
37
comfy/text_encoders/wan.py
Normal file
37
comfy/text_encoders/wan.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from comfy import sd1_clip
|
||||
from .spiece_tokenizer import SPieceTokenizer
|
||||
import comfy.text_encoders.t5
|
||||
import os
|
||||
|
||||
class UMT5XXlModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "umt5_config_xxl.json")
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True, model_options=model_options)
|
||||
|
||||
class UMT5XXlTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer = tokenizer_data.get("spiece_model", None)
|
||||
super().__init__(tokenizer, pad_with_end=False, embedding_size=4096, embedding_key='umt5xxl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=0, tokenizer_data=tokenizer_data)
|
||||
|
||||
def state_dict(self):
|
||||
return {"spiece_model": self.tokenizer.serialize_model()}
|
||||
|
||||
|
||||
class WanT5Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="umt5xxl", tokenizer=UMT5XXlTokenizer)
|
||||
|
||||
class WanT5Model(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options, name="umt5xxl", clip_model=UMT5XXlModel, **kwargs)
|
||||
|
||||
def te(dtype_t5=None, t5xxl_scaled_fp8=None):
|
||||
class WanTEModel(WanT5Model):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if t5xxl_scaled_fp8 is not None and "scaled_fp8" not in model_options:
|
||||
model_options = model_options.copy()
|
||||
model_options["scaled_fp8"] = t5xxl_scaled_fp8
|
||||
if dtype_t5 is not None:
|
||||
dtype = dtype_t5
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return WanTEModel
|
||||
@@ -46,19 +46,25 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in
|
||||
else:
|
||||
logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")
|
||||
|
||||
def load_torch_file(ckpt, safe_load=False, device=None):
|
||||
def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
||||
if device is None:
|
||||
device = torch.device("cpu")
|
||||
metadata = None
|
||||
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
|
||||
try:
|
||||
sd = safetensors.torch.load_file(ckpt, device=device.type)
|
||||
with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
|
||||
sd = {}
|
||||
for k in f.keys():
|
||||
sd[k] = f.get_tensor(k)
|
||||
if return_metadata:
|
||||
metadata = f.metadata()
|
||||
except Exception as e:
|
||||
if len(e.args) > 0:
|
||||
message = e.args[0]
|
||||
if "HeaderTooLarge" in message:
|
||||
raise ValueError("{}\n\nFile path: {}\n\nThe safetensors file is corrupt or invalid. Make sure this is actually a safetensors file and not a ckpt or pt or other filetype.".format(message, ckpt))
|
||||
if "MetadataIncompleteBuffer" in message:
|
||||
raise ValueError("{}\n\nFile path: {}\n\nThe safetensors file is incomplete. Check the file size and make sure you have copied/downloaded it correctly.".format(message, ckpt))
|
||||
raise ValueError("{}\n\nFile path: {}\n\nThe safetensors file is corrupt/incomplete. Check the file size and make sure you have copied/downloaded it correctly.".format(message, ckpt))
|
||||
raise e
|
||||
else:
|
||||
if safe_load or ALWAYS_SAFE_LOAD:
|
||||
@@ -77,7 +83,7 @@ def load_torch_file(ckpt, safe_load=False, device=None):
|
||||
sd = pl_sd
|
||||
else:
|
||||
sd = pl_sd
|
||||
return sd
|
||||
return (sd, metadata) if return_metadata else sd
|
||||
|
||||
def save_torch_file(sd, ckpt, metadata=None):
|
||||
if metadata is not None:
|
||||
|
||||
@@ -316,3 +316,156 @@ class LRUCache(BasicCache):
|
||||
self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
|
||||
return self
|
||||
|
||||
|
||||
class DependencyAwareCache(BasicCache):
|
||||
"""
|
||||
A cache implementation that tracks dependencies between nodes and manages
|
||||
their execution and caching accordingly. It extends the BasicCache class.
|
||||
Nodes are removed from this cache once all of their descendants have been
|
||||
executed.
|
||||
"""
|
||||
|
||||
def __init__(self, key_class):
|
||||
"""
|
||||
Initialize the DependencyAwareCache.
|
||||
|
||||
Args:
|
||||
key_class: The class used for generating cache keys.
|
||||
"""
|
||||
super().__init__(key_class)
|
||||
self.descendants = {} # Maps node_id -> set of descendant node_ids
|
||||
self.ancestors = {} # Maps node_id -> set of ancestor node_ids
|
||||
self.executed_nodes = set() # Tracks nodes that have been executed
|
||||
|
||||
def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||
"""
|
||||
Clear the entire cache and rebuild the dependency graph.
|
||||
|
||||
Args:
|
||||
dynprompt: The dynamic prompt object containing node information.
|
||||
node_ids: List of node IDs to initialize the cache for.
|
||||
is_changed_cache: Flag indicating if the cache has changed.
|
||||
"""
|
||||
# Clear all existing cache data
|
||||
self.cache.clear()
|
||||
self.subcaches.clear()
|
||||
self.descendants.clear()
|
||||
self.ancestors.clear()
|
||||
self.executed_nodes.clear()
|
||||
|
||||
# Call the parent method to initialize the cache with the new prompt
|
||||
super().set_prompt(dynprompt, node_ids, is_changed_cache)
|
||||
|
||||
# Rebuild the dependency graph
|
||||
self._build_dependency_graph(dynprompt, node_ids)
|
||||
|
||||
def _build_dependency_graph(self, dynprompt, node_ids):
|
||||
"""
|
||||
Build the dependency graph for all nodes.
|
||||
|
||||
Args:
|
||||
dynprompt: The dynamic prompt object containing node information.
|
||||
node_ids: List of node IDs to build the graph for.
|
||||
"""
|
||||
self.descendants.clear()
|
||||
self.ancestors.clear()
|
||||
for node_id in node_ids:
|
||||
self.descendants[node_id] = set()
|
||||
self.ancestors[node_id] = set()
|
||||
|
||||
for node_id in node_ids:
|
||||
inputs = dynprompt.get_node(node_id)["inputs"]
|
||||
for input_data in inputs.values():
|
||||
if is_link(input_data): # Check if the input is a link to another node
|
||||
ancestor_id = input_data[0]
|
||||
self.descendants[ancestor_id].add(node_id)
|
||||
self.ancestors[node_id].add(ancestor_id)
|
||||
|
||||
def set(self, node_id, value):
|
||||
"""
|
||||
Mark a node as executed and store its value in the cache.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node to store.
|
||||
value: The value to store for the node.
|
||||
"""
|
||||
self._set_immediate(node_id, value)
|
||||
self.executed_nodes.add(node_id)
|
||||
self._cleanup_ancestors(node_id)
|
||||
|
||||
def get(self, node_id):
|
||||
"""
|
||||
Retrieve the cached value for a node.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node to retrieve.
|
||||
|
||||
Returns:
|
||||
The cached value for the node.
|
||||
"""
|
||||
return self._get_immediate(node_id)
|
||||
|
||||
def ensure_subcache_for(self, node_id, children_ids):
|
||||
"""
|
||||
Ensure a subcache exists for a node and update dependencies.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the parent node.
|
||||
children_ids: List of child node IDs to associate with the parent node.
|
||||
|
||||
Returns:
|
||||
The subcache object for the node.
|
||||
"""
|
||||
subcache = super()._ensure_subcache(node_id, children_ids)
|
||||
for child_id in children_ids:
|
||||
self.descendants[node_id].add(child_id)
|
||||
self.ancestors[child_id].add(node_id)
|
||||
return subcache
|
||||
|
||||
def _cleanup_ancestors(self, node_id):
|
||||
"""
|
||||
Check if ancestors of a node can be removed from the cache.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node whose ancestors are to be checked.
|
||||
"""
|
||||
for ancestor_id in self.ancestors.get(node_id, []):
|
||||
if ancestor_id in self.executed_nodes:
|
||||
# Remove ancestor if all its descendants have been executed
|
||||
if all(descendant in self.executed_nodes for descendant in self.descendants[ancestor_id]):
|
||||
self._remove_node(ancestor_id)
|
||||
|
||||
def _remove_node(self, node_id):
|
||||
"""
|
||||
Remove a node from the cache.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node to remove.
|
||||
"""
|
||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||
if cache_key in self.cache:
|
||||
del self.cache[cache_key]
|
||||
subcache_key = self.cache_key_set.get_subcache_key(node_id)
|
||||
if subcache_key in self.subcaches:
|
||||
del self.subcaches[subcache_key]
|
||||
|
||||
def clean_unused(self):
|
||||
"""
|
||||
Clean up unused nodes. This is a no-op for this cache implementation.
|
||||
"""
|
||||
pass
|
||||
|
||||
def recursive_debug_dump(self):
|
||||
"""
|
||||
Dump the cache and dependency graph for debugging.
|
||||
|
||||
Returns:
|
||||
A list containing the cache state and dependency graph.
|
||||
"""
|
||||
result = super().recursive_debug_dump()
|
||||
result.append({
|
||||
"descendants": self.descendants,
|
||||
"ancestors": self.ancestors,
|
||||
"executed_nodes": list(self.executed_nodes),
|
||||
})
|
||||
return result
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torchaudio
|
||||
import torch
|
||||
import comfy.model_management
|
||||
@@ -10,6 +12,7 @@ import random
|
||||
import hashlib
|
||||
import node_helpers
|
||||
from comfy.cli_args import args
|
||||
from comfy.comfy_types import FileLocator
|
||||
|
||||
class EmptyLatentAudio:
|
||||
def __init__(self):
|
||||
@@ -164,7 +167,7 @@ class SaveAudio:
|
||||
def save_audio(self, audio, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
|
||||
filename_prefix += self.prefix_append
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
|
||||
results = list()
|
||||
results: list[FileLocator] = []
|
||||
|
||||
metadata = {}
|
||||
if not args.disable_metadata:
|
||||
|
||||
45
comfy_extras/nodes_cfg.py
Normal file
45
comfy_extras/nodes_cfg.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import torch
|
||||
|
||||
# https://github.com/WeichenFan/CFG-Zero-star
|
||||
def optimized_scale(positive, negative):
|
||||
positive_flat = positive.reshape(positive.shape[0], -1)
|
||||
negative_flat = negative.reshape(negative.shape[0], -1)
|
||||
|
||||
# Calculate dot production
|
||||
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
|
||||
|
||||
# Squared norm of uncondition
|
||||
squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
|
||||
|
||||
# st_star = v_cond^T * v_uncond / ||v_uncond||^2
|
||||
st_star = dot_product / squared_norm
|
||||
|
||||
return st_star.reshape([positive.shape[0]] + [1] * (positive.ndim - 1))
|
||||
|
||||
class CFGZeroStar:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"model": ("MODEL",),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
RETURN_NAMES = ("patched_model",)
|
||||
FUNCTION = "patch"
|
||||
CATEGORY = "advanced/guidance"
|
||||
|
||||
def patch(self, model):
|
||||
m = model.clone()
|
||||
def cfg_zero_star(args):
|
||||
guidance_scale = args['cond_scale']
|
||||
x = args['input']
|
||||
cond_p = args['cond_denoised']
|
||||
uncond_p = args['uncond_denoised']
|
||||
out = args["denoised"]
|
||||
alpha = optimized_scale(x - cond_p, x - uncond_p)
|
||||
|
||||
return out + uncond_p * (alpha - 1.0) + guidance_scale * uncond_p * (1.0 - alpha)
|
||||
m.set_model_sampler_post_cfg_function(cfg_zero_star)
|
||||
return (m, )
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"CFGZeroStar": CFGZeroStar
|
||||
}
|
||||
@@ -454,7 +454,7 @@ class SamplerCustom:
|
||||
return {"required":
|
||||
{"model": ("MODEL",),
|
||||
"add_noise": ("BOOLEAN", {"default": True}),
|
||||
"noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
|
||||
"noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True}),
|
||||
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
|
||||
"positive": ("CONDITIONING", ),
|
||||
"negative": ("CONDITIONING", ),
|
||||
@@ -605,10 +605,16 @@ class DisableNoise:
|
||||
class RandomNoise(DisableNoise):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":{
|
||||
"noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
|
||||
}
|
||||
}
|
||||
return {
|
||||
"required": {
|
||||
"noise_seed": ("INT", {
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 0xffffffffffffffff,
|
||||
"control_after_generate": True,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
def get_noise(self, noise_seed):
|
||||
return (Noise_RandomNoise(noise_seed),)
|
||||
|
||||
32
comfy_extras/nodes_hidream.py
Normal file
32
comfy_extras/nodes_hidream.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import folder_paths
|
||||
import comfy.sd
|
||||
import comfy.model_management
|
||||
|
||||
|
||||
class QuadrupleCLIPLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
|
||||
"clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
|
||||
"clip_name3": (folder_paths.get_filename_list("text_encoders"), ),
|
||||
"clip_name4": (folder_paths.get_filename_list("text_encoders"), )
|
||||
}}
|
||||
RETURN_TYPES = ("CLIP",)
|
||||
FUNCTION = "load_clip"
|
||||
|
||||
CATEGORY = "advanced/loaders"
|
||||
|
||||
DESCRIPTION = "[Recipes]\n\nhidream: long clip-l, long clip-g, t5xxl, llama_8b_3.1_instruct"
|
||||
|
||||
def load_clip(self, clip_name1, clip_name2, clip_name3, clip_name4):
|
||||
clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
|
||||
clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
|
||||
clip_path3 = folder_paths.get_full_path_or_raise("text_encoders", clip_name3)
|
||||
clip_path4 = folder_paths.get_full_path_or_raise("text_encoders", clip_name4)
|
||||
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3, clip_path4], embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||
return (clip,)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"QuadrupleCLIPLoader": QuadrupleCLIPLoader,
|
||||
}
|
||||
@@ -1,4 +1,5 @@
|
||||
import nodes
|
||||
import node_helpers
|
||||
import torch
|
||||
import comfy.model_management
|
||||
|
||||
@@ -38,7 +39,83 @@ class EmptyHunyuanLatentVideo:
|
||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
return ({"samples":latent}, )
|
||||
|
||||
PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = (
|
||||
"<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the video by detailing the following aspects according to the reference image: "
|
||||
"1. The main content and theme of the video."
|
||||
"2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
|
||||
"3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
|
||||
"4. background environment, light, style and atmosphere."
|
||||
"5. camera angles, movements, and transitions used in the video:<|eot_id|>\n\n"
|
||||
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
|
||||
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
|
||||
class TextEncodeHunyuanVideo_ImageToVideo:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"clip": ("CLIP", ),
|
||||
"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
|
||||
"prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}),
|
||||
"image_interleave": ("INT", {"default": 2, "min": 1, "max": 512, "tooltip": "How much the image influences things vs the text prompt. Higher number means more influence from the text prompt."}),
|
||||
}}
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "advanced/conditioning"
|
||||
|
||||
def encode(self, clip, clip_vision_output, prompt, image_interleave):
|
||||
tokens = clip.tokenize(prompt, llama_template=PROMPT_TEMPLATE_ENCODE_VIDEO_I2V, image_embeds=clip_vision_output.mm_projected, image_interleave=image_interleave)
|
||||
return (clip.encode_from_tokens_scheduled(tokens), )
|
||||
|
||||
class HunyuanImageToVideo:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"positive": ("CONDITIONING", ),
|
||||
"vae": ("VAE", ),
|
||||
"width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"length": ("INT", {"default": 53, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||
"guidance_type": (["v1 (concat)", "v2 (replace)"], )
|
||||
},
|
||||
"optional": {"start_image": ("IMAGE", ),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING", "LATENT")
|
||||
RETURN_NAMES = ("positive", "latent")
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "conditioning/video_models"
|
||||
|
||||
def encode(self, positive, vae, width, height, length, batch_size, guidance_type, start_image=None):
|
||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
out_latent = {}
|
||||
|
||||
if start_image is not None:
|
||||
start_image = comfy.utils.common_upscale(start_image[:length, :, :, :3].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
|
||||
concat_latent_image = vae.encode(start_image)
|
||||
mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
|
||||
mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
|
||||
|
||||
if guidance_type == "v1 (concat)":
|
||||
cond = {"concat_latent_image": concat_latent_image, "concat_mask": mask}
|
||||
else:
|
||||
cond = {'guiding_frame_index': 0}
|
||||
latent[:, :, :concat_latent_image.shape[2]] = concat_latent_image
|
||||
out_latent["noise_mask"] = mask
|
||||
|
||||
positive = node_helpers.conditioning_set_values(positive, cond)
|
||||
|
||||
out_latent["samples"] = latent
|
||||
return (positive, out_latent)
|
||||
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
|
||||
"TextEncodeHunyuanVideo_ImageToVideo": TextEncodeHunyuanVideo_ImageToVideo,
|
||||
"EmptyHunyuanLatentVideo": EmptyHunyuanLatentVideo,
|
||||
"HunyuanImageToVideo": HunyuanImageToVideo,
|
||||
}
|
||||
|
||||
634
comfy_extras/nodes_hunyuan3d.py
Normal file
634
comfy_extras/nodes_hunyuan3d.py
Normal file
@@ -0,0 +1,634 @@
|
||||
import torch
|
||||
import os
|
||||
import json
|
||||
import struct
|
||||
import numpy as np
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import get_1d_sincos_pos_embed_from_grid_torch
|
||||
import folder_paths
|
||||
import comfy.model_management
|
||||
from comfy.cli_args import args
|
||||
|
||||
|
||||
class EmptyLatentHunyuan3Dv2:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"resolution": ("INT", {"default": 3072, "min": 1, "max": 8192}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
|
||||
}}
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "generate"
|
||||
|
||||
CATEGORY = "latent/3d"
|
||||
|
||||
def generate(self, resolution, batch_size):
|
||||
latent = torch.zeros([batch_size, 64, resolution], device=comfy.model_management.intermediate_device())
|
||||
return ({"samples": latent, "type": "hunyuan3dv2"}, )
|
||||
|
||||
|
||||
class Hunyuan3Dv2Conditioning:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"clip_vision_output": ("CLIP_VISION_OUTPUT",),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
|
||||
RETURN_NAMES = ("positive", "negative")
|
||||
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "conditioning/video_models"
|
||||
|
||||
def encode(self, clip_vision_output):
|
||||
embeds = clip_vision_output.last_hidden_state
|
||||
positive = [[embeds, {}]]
|
||||
negative = [[torch.zeros_like(embeds), {}]]
|
||||
return (positive, negative)
|
||||
|
||||
|
||||
class Hunyuan3Dv2ConditioningMultiView:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {},
|
||||
"optional": {"front": ("CLIP_VISION_OUTPUT",),
|
||||
"left": ("CLIP_VISION_OUTPUT",),
|
||||
"back": ("CLIP_VISION_OUTPUT",),
|
||||
"right": ("CLIP_VISION_OUTPUT",), }}
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
|
||||
RETURN_NAMES = ("positive", "negative")
|
||||
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "conditioning/video_models"
|
||||
|
||||
def encode(self, front=None, left=None, back=None, right=None):
|
||||
all_embeds = [front, left, back, right]
|
||||
out = []
|
||||
pos_embeds = None
|
||||
for i, e in enumerate(all_embeds):
|
||||
if e is not None:
|
||||
if pos_embeds is None:
|
||||
pos_embeds = get_1d_sincos_pos_embed_from_grid_torch(e.last_hidden_state.shape[-1], torch.arange(4))
|
||||
out.append(e.last_hidden_state + pos_embeds[i].reshape(1, 1, -1))
|
||||
|
||||
embeds = torch.cat(out, dim=1)
|
||||
positive = [[embeds, {}]]
|
||||
negative = [[torch.zeros_like(embeds), {}]]
|
||||
return (positive, negative)
|
||||
|
||||
|
||||
class VOXEL:
|
||||
def __init__(self, data):
|
||||
self.data = data
|
||||
|
||||
|
||||
class VAEDecodeHunyuan3D:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"samples": ("LATENT", ),
|
||||
"vae": ("VAE", ),
|
||||
"num_chunks": ("INT", {"default": 8000, "min": 1000, "max": 500000}),
|
||||
"octree_resolution": ("INT", {"default": 256, "min": 16, "max": 512}),
|
||||
}}
|
||||
RETURN_TYPES = ("VOXEL",)
|
||||
FUNCTION = "decode"
|
||||
|
||||
CATEGORY = "latent/3d"
|
||||
|
||||
def decode(self, vae, samples, num_chunks, octree_resolution):
|
||||
voxels = VOXEL(vae.decode(samples["samples"], vae_options={"num_chunks": num_chunks, "octree_resolution": octree_resolution}))
|
||||
return (voxels, )
|
||||
|
||||
|
||||
def voxel_to_mesh(voxels, threshold=0.5, device=None):
|
||||
if device is None:
|
||||
device = torch.device("cpu")
|
||||
voxels = voxels.to(device)
|
||||
|
||||
binary = (voxels > threshold).float()
|
||||
padded = torch.nn.functional.pad(binary, (1, 1, 1, 1, 1, 1), 'constant', 0)
|
||||
|
||||
D, H, W = binary.shape
|
||||
|
||||
neighbors = torch.tensor([
|
||||
[0, 0, 1],
|
||||
[0, 0, -1],
|
||||
[0, 1, 0],
|
||||
[0, -1, 0],
|
||||
[1, 0, 0],
|
||||
[-1, 0, 0]
|
||||
], device=device)
|
||||
|
||||
z, y, x = torch.meshgrid(
|
||||
torch.arange(D, device=device),
|
||||
torch.arange(H, device=device),
|
||||
torch.arange(W, device=device),
|
||||
indexing='ij'
|
||||
)
|
||||
voxel_indices = torch.stack([z.flatten(), y.flatten(), x.flatten()], dim=1)
|
||||
|
||||
solid_mask = binary.flatten() > 0
|
||||
solid_indices = voxel_indices[solid_mask]
|
||||
|
||||
corner_offsets = [
|
||||
torch.tensor([
|
||||
[0, 0, 1], [0, 1, 1], [1, 1, 1], [1, 0, 1]
|
||||
], device=device),
|
||||
torch.tensor([
|
||||
[0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0]
|
||||
], device=device),
|
||||
torch.tensor([
|
||||
[0, 1, 0], [1, 1, 0], [1, 1, 1], [0, 1, 1]
|
||||
], device=device),
|
||||
torch.tensor([
|
||||
[0, 0, 0], [0, 0, 1], [1, 0, 1], [1, 0, 0]
|
||||
], device=device),
|
||||
torch.tensor([
|
||||
[1, 0, 1], [1, 1, 1], [1, 1, 0], [1, 0, 0]
|
||||
], device=device),
|
||||
torch.tensor([
|
||||
[0, 1, 0], [0, 1, 1], [0, 0, 1], [0, 0, 0]
|
||||
], device=device)
|
||||
]
|
||||
|
||||
all_vertices = []
|
||||
all_indices = []
|
||||
|
||||
vertex_count = 0
|
||||
|
||||
for face_idx, offset in enumerate(neighbors):
|
||||
neighbor_indices = solid_indices + offset
|
||||
|
||||
padded_indices = neighbor_indices + 1
|
||||
|
||||
is_exposed = padded[
|
||||
padded_indices[:, 0],
|
||||
padded_indices[:, 1],
|
||||
padded_indices[:, 2]
|
||||
] == 0
|
||||
|
||||
if not is_exposed.any():
|
||||
continue
|
||||
|
||||
exposed_indices = solid_indices[is_exposed]
|
||||
|
||||
corners = corner_offsets[face_idx].unsqueeze(0)
|
||||
|
||||
face_vertices = exposed_indices.unsqueeze(1) + corners
|
||||
|
||||
all_vertices.append(face_vertices.reshape(-1, 3))
|
||||
|
||||
num_faces = exposed_indices.shape[0]
|
||||
face_indices = torch.arange(
|
||||
vertex_count,
|
||||
vertex_count + 4 * num_faces,
|
||||
device=device
|
||||
).reshape(-1, 4)
|
||||
|
||||
all_indices.append(torch.stack([face_indices[:, 0], face_indices[:, 1], face_indices[:, 2]], dim=1))
|
||||
all_indices.append(torch.stack([face_indices[:, 0], face_indices[:, 2], face_indices[:, 3]], dim=1))
|
||||
|
||||
vertex_count += 4 * num_faces
|
||||
|
||||
if len(all_vertices) > 0:
|
||||
vertices = torch.cat(all_vertices, dim=0)
|
||||
faces = torch.cat(all_indices, dim=0)
|
||||
else:
|
||||
vertices = torch.zeros((1, 3))
|
||||
faces = torch.zeros((1, 3))
|
||||
|
||||
v_min = 0
|
||||
v_max = max(voxels.shape)
|
||||
|
||||
vertices = vertices - (v_min + v_max) / 2
|
||||
|
||||
scale = (v_max - v_min) / 2
|
||||
if scale > 0:
|
||||
vertices = vertices / scale
|
||||
|
||||
vertices = torch.fliplr(vertices)
|
||||
return vertices, faces
|
||||
|
||||
def voxel_to_mesh_surfnet(voxels, threshold=0.5, device=None):
|
||||
if device is None:
|
||||
device = torch.device("cpu")
|
||||
voxels = voxels.to(device)
|
||||
|
||||
D, H, W = voxels.shape
|
||||
|
||||
padded = torch.nn.functional.pad(voxels, (1, 1, 1, 1, 1, 1), 'constant', 0)
|
||||
z, y, x = torch.meshgrid(
|
||||
torch.arange(D, device=device),
|
||||
torch.arange(H, device=device),
|
||||
torch.arange(W, device=device),
|
||||
indexing='ij'
|
||||
)
|
||||
cell_positions = torch.stack([z.flatten(), y.flatten(), x.flatten()], dim=1)
|
||||
|
||||
corner_offsets = torch.tensor([
|
||||
[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0],
|
||||
[0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1]
|
||||
], device=device)
|
||||
|
||||
corner_values = torch.zeros((cell_positions.shape[0], 8), device=device)
|
||||
for c, (dz, dy, dx) in enumerate(corner_offsets):
|
||||
corner_values[:, c] = padded[
|
||||
cell_positions[:, 0] + dz,
|
||||
cell_positions[:, 1] + dy,
|
||||
cell_positions[:, 2] + dx
|
||||
]
|
||||
|
||||
corner_signs = corner_values > threshold
|
||||
has_inside = torch.any(corner_signs, dim=1)
|
||||
has_outside = torch.any(~corner_signs, dim=1)
|
||||
contains_surface = has_inside & has_outside
|
||||
|
||||
active_cells = cell_positions[contains_surface]
|
||||
active_signs = corner_signs[contains_surface]
|
||||
active_values = corner_values[contains_surface]
|
||||
|
||||
if active_cells.shape[0] == 0:
|
||||
return torch.zeros((0, 3), device=device), torch.zeros((0, 3), dtype=torch.long, device=device)
|
||||
|
||||
edges = torch.tensor([
|
||||
[0, 1], [0, 2], [0, 4], [1, 3],
|
||||
[1, 5], [2, 3], [2, 6], [3, 7],
|
||||
[4, 5], [4, 6], [5, 7], [6, 7]
|
||||
], device=device)
|
||||
|
||||
cell_vertices = {}
|
||||
progress = comfy.utils.ProgressBar(100)
|
||||
|
||||
for edge_idx, (e1, e2) in enumerate(edges):
|
||||
progress.update(1)
|
||||
crossing = active_signs[:, e1] != active_signs[:, e2]
|
||||
if not crossing.any():
|
||||
continue
|
||||
|
||||
cell_indices = torch.nonzero(crossing, as_tuple=True)[0]
|
||||
|
||||
v1 = active_values[cell_indices, e1]
|
||||
v2 = active_values[cell_indices, e2]
|
||||
|
||||
t = torch.zeros_like(v1, device=device)
|
||||
denom = v2 - v1
|
||||
valid = denom != 0
|
||||
t[valid] = (threshold - v1[valid]) / denom[valid]
|
||||
t[~valid] = 0.5
|
||||
|
||||
p1 = corner_offsets[e1].float()
|
||||
p2 = corner_offsets[e2].float()
|
||||
|
||||
intersection = p1.unsqueeze(0) + t.unsqueeze(1) * (p2.unsqueeze(0) - p1.unsqueeze(0))
|
||||
|
||||
for i, point in zip(cell_indices.tolist(), intersection):
|
||||
if i not in cell_vertices:
|
||||
cell_vertices[i] = []
|
||||
cell_vertices[i].append(point)
|
||||
|
||||
# Calculate the final vertices as the average of intersection points for each cell
|
||||
vertices = []
|
||||
vertex_lookup = {}
|
||||
|
||||
vert_progress_mod = round(len(cell_vertices)/50)
|
||||
|
||||
for i, points in cell_vertices.items():
|
||||
if not i % vert_progress_mod:
|
||||
progress.update(1)
|
||||
|
||||
if points:
|
||||
vertex = torch.stack(points).mean(dim=0)
|
||||
vertex = vertex + active_cells[i].float()
|
||||
vertex_lookup[tuple(active_cells[i].tolist())] = len(vertices)
|
||||
vertices.append(vertex)
|
||||
|
||||
if not vertices:
|
||||
return torch.zeros((0, 3), device=device), torch.zeros((0, 3), dtype=torch.long, device=device)
|
||||
|
||||
final_vertices = torch.stack(vertices)
|
||||
|
||||
inside_corners_mask = active_signs
|
||||
outside_corners_mask = ~active_signs
|
||||
|
||||
inside_counts = inside_corners_mask.sum(dim=1, keepdim=True).float()
|
||||
outside_counts = outside_corners_mask.sum(dim=1, keepdim=True).float()
|
||||
|
||||
inside_pos = torch.zeros((active_cells.shape[0], 3), device=device)
|
||||
outside_pos = torch.zeros((active_cells.shape[0], 3), device=device)
|
||||
|
||||
for i in range(8):
|
||||
mask_inside = inside_corners_mask[:, i].unsqueeze(1)
|
||||
mask_outside = outside_corners_mask[:, i].unsqueeze(1)
|
||||
inside_pos += corner_offsets[i].float().unsqueeze(0) * mask_inside
|
||||
outside_pos += corner_offsets[i].float().unsqueeze(0) * mask_outside
|
||||
|
||||
inside_pos /= inside_counts
|
||||
outside_pos /= outside_counts
|
||||
gradients = inside_pos - outside_pos
|
||||
|
||||
pos_dirs = torch.tensor([
|
||||
[1, 0, 0],
|
||||
[0, 1, 0],
|
||||
[0, 0, 1]
|
||||
], device=device)
|
||||
|
||||
cross_products = [
|
||||
torch.linalg.cross(pos_dirs[i].float(), pos_dirs[j].float())
|
||||
for i in range(3) for j in range(i+1, 3)
|
||||
]
|
||||
|
||||
faces = []
|
||||
all_keys = set(vertex_lookup.keys())
|
||||
|
||||
face_progress_mod = round(len(active_cells)/38*3)
|
||||
|
||||
for pair_idx, (i, j) in enumerate([(0,1), (0,2), (1,2)]):
|
||||
dir_i = pos_dirs[i]
|
||||
dir_j = pos_dirs[j]
|
||||
cross_product = cross_products[pair_idx]
|
||||
|
||||
ni_positions = active_cells + dir_i
|
||||
nj_positions = active_cells + dir_j
|
||||
diag_positions = active_cells + dir_i + dir_j
|
||||
|
||||
alignments = torch.matmul(gradients, cross_product)
|
||||
|
||||
valid_quads = []
|
||||
quad_indices = []
|
||||
|
||||
for idx, active_cell in enumerate(active_cells):
|
||||
if not idx % face_progress_mod:
|
||||
progress.update(1)
|
||||
cell_key = tuple(active_cell.tolist())
|
||||
ni_key = tuple(ni_positions[idx].tolist())
|
||||
nj_key = tuple(nj_positions[idx].tolist())
|
||||
diag_key = tuple(diag_positions[idx].tolist())
|
||||
|
||||
if cell_key in all_keys and ni_key in all_keys and nj_key in all_keys and diag_key in all_keys:
|
||||
v0 = vertex_lookup[cell_key]
|
||||
v1 = vertex_lookup[ni_key]
|
||||
v2 = vertex_lookup[nj_key]
|
||||
v3 = vertex_lookup[diag_key]
|
||||
|
||||
valid_quads.append((v0, v1, v2, v3))
|
||||
quad_indices.append(idx)
|
||||
|
||||
for q_idx, (v0, v1, v2, v3) in enumerate(valid_quads):
|
||||
cell_idx = quad_indices[q_idx]
|
||||
if alignments[cell_idx] > 0:
|
||||
faces.append(torch.tensor([v0, v1, v3], device=device, dtype=torch.long))
|
||||
faces.append(torch.tensor([v0, v3, v2], device=device, dtype=torch.long))
|
||||
else:
|
||||
faces.append(torch.tensor([v0, v3, v1], device=device, dtype=torch.long))
|
||||
faces.append(torch.tensor([v0, v2, v3], device=device, dtype=torch.long))
|
||||
|
||||
if faces:
|
||||
faces = torch.stack(faces)
|
||||
else:
|
||||
faces = torch.zeros((0, 3), dtype=torch.long, device=device)
|
||||
|
||||
v_min = 0
|
||||
v_max = max(D, H, W)
|
||||
|
||||
final_vertices = final_vertices - (v_min + v_max) / 2
|
||||
|
||||
scale = (v_max - v_min) / 2
|
||||
if scale > 0:
|
||||
final_vertices = final_vertices / scale
|
||||
|
||||
final_vertices = torch.fliplr(final_vertices)
|
||||
|
||||
return final_vertices, faces
|
||||
|
||||
class MESH:
|
||||
def __init__(self, vertices, faces):
|
||||
self.vertices = vertices
|
||||
self.faces = faces
|
||||
|
||||
|
||||
class VoxelToMeshBasic:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"voxel": ("VOXEL", ),
|
||||
"threshold": ("FLOAT", {"default": 0.6, "min": -1.0, "max": 1.0, "step": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("MESH",)
|
||||
FUNCTION = "decode"
|
||||
|
||||
CATEGORY = "3d"
|
||||
|
||||
def decode(self, voxel, threshold):
|
||||
vertices = []
|
||||
faces = []
|
||||
for x in voxel.data:
|
||||
v, f = voxel_to_mesh(x, threshold=threshold, device=None)
|
||||
vertices.append(v)
|
||||
faces.append(f)
|
||||
|
||||
return (MESH(torch.stack(vertices), torch.stack(faces)), )
|
||||
|
||||
class VoxelToMesh:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"voxel": ("VOXEL", ),
|
||||
"algorithm": (["surface net", "basic"], ),
|
||||
"threshold": ("FLOAT", {"default": 0.6, "min": -1.0, "max": 1.0, "step": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("MESH",)
|
||||
FUNCTION = "decode"
|
||||
|
||||
CATEGORY = "3d"
|
||||
|
||||
def decode(self, voxel, algorithm, threshold):
|
||||
vertices = []
|
||||
faces = []
|
||||
|
||||
if algorithm == "basic":
|
||||
mesh_function = voxel_to_mesh
|
||||
elif algorithm == "surface net":
|
||||
mesh_function = voxel_to_mesh_surfnet
|
||||
|
||||
for x in voxel.data:
|
||||
v, f = mesh_function(x, threshold=threshold, device=None)
|
||||
vertices.append(v)
|
||||
faces.append(f)
|
||||
|
||||
return (MESH(torch.stack(vertices), torch.stack(faces)), )
|
||||
|
||||
|
||||
def save_glb(vertices, faces, filepath, metadata=None):
|
||||
"""
|
||||
Save PyTorch tensor vertices and faces as a GLB file without external dependencies.
|
||||
|
||||
Parameters:
|
||||
vertices: torch.Tensor of shape (N, 3) - The vertex coordinates
|
||||
faces: torch.Tensor of shape (M, 3) - The face indices (triangle faces)
|
||||
filepath: str - Output filepath (should end with .glb)
|
||||
"""
|
||||
|
||||
# Convert tensors to numpy arrays
|
||||
vertices_np = vertices.cpu().numpy().astype(np.float32)
|
||||
faces_np = faces.cpu().numpy().astype(np.uint32)
|
||||
|
||||
vertices_buffer = vertices_np.tobytes()
|
||||
indices_buffer = faces_np.tobytes()
|
||||
|
||||
def pad_to_4_bytes(buffer):
|
||||
padding_length = (4 - (len(buffer) % 4)) % 4
|
||||
return buffer + b'\x00' * padding_length
|
||||
|
||||
vertices_buffer_padded = pad_to_4_bytes(vertices_buffer)
|
||||
indices_buffer_padded = pad_to_4_bytes(indices_buffer)
|
||||
|
||||
buffer_data = vertices_buffer_padded + indices_buffer_padded
|
||||
|
||||
vertices_byte_length = len(vertices_buffer)
|
||||
vertices_byte_offset = 0
|
||||
indices_byte_length = len(indices_buffer)
|
||||
indices_byte_offset = len(vertices_buffer_padded)
|
||||
|
||||
gltf = {
|
||||
"asset": {"version": "2.0", "generator": "ComfyUI"},
|
||||
"buffers": [
|
||||
{
|
||||
"byteLength": len(buffer_data)
|
||||
}
|
||||
],
|
||||
"bufferViews": [
|
||||
{
|
||||
"buffer": 0,
|
||||
"byteOffset": vertices_byte_offset,
|
||||
"byteLength": vertices_byte_length,
|
||||
"target": 34962 # ARRAY_BUFFER
|
||||
},
|
||||
{
|
||||
"buffer": 0,
|
||||
"byteOffset": indices_byte_offset,
|
||||
"byteLength": indices_byte_length,
|
||||
"target": 34963 # ELEMENT_ARRAY_BUFFER
|
||||
}
|
||||
],
|
||||
"accessors": [
|
||||
{
|
||||
"bufferView": 0,
|
||||
"byteOffset": 0,
|
||||
"componentType": 5126, # FLOAT
|
||||
"count": len(vertices_np),
|
||||
"type": "VEC3",
|
||||
"max": vertices_np.max(axis=0).tolist(),
|
||||
"min": vertices_np.min(axis=0).tolist()
|
||||
},
|
||||
{
|
||||
"bufferView": 1,
|
||||
"byteOffset": 0,
|
||||
"componentType": 5125, # UNSIGNED_INT
|
||||
"count": faces_np.size,
|
||||
"type": "SCALAR"
|
||||
}
|
||||
],
|
||||
"meshes": [
|
||||
{
|
||||
"primitives": [
|
||||
{
|
||||
"attributes": {
|
||||
"POSITION": 0
|
||||
},
|
||||
"indices": 1,
|
||||
"mode": 4 # TRIANGLES
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"nodes": [
|
||||
{
|
||||
"mesh": 0
|
||||
}
|
||||
],
|
||||
"scenes": [
|
||||
{
|
||||
"nodes": [0]
|
||||
}
|
||||
],
|
||||
"scene": 0
|
||||
}
|
||||
|
||||
if metadata is not None:
|
||||
gltf["asset"]["extras"] = metadata
|
||||
|
||||
# Convert the JSON to bytes
|
||||
gltf_json = json.dumps(gltf).encode('utf8')
|
||||
|
||||
def pad_json_to_4_bytes(buffer):
|
||||
padding_length = (4 - (len(buffer) % 4)) % 4
|
||||
return buffer + b' ' * padding_length
|
||||
|
||||
gltf_json_padded = pad_json_to_4_bytes(gltf_json)
|
||||
|
||||
# Create the GLB header
|
||||
# Magic glTF
|
||||
glb_header = struct.pack('<4sII', b'glTF', 2, 12 + 8 + len(gltf_json_padded) + 8 + len(buffer_data))
|
||||
|
||||
# Create JSON chunk header (chunk type 0)
|
||||
json_chunk_header = struct.pack('<II', len(gltf_json_padded), 0x4E4F534A) # "JSON" in little endian
|
||||
|
||||
# Create BIN chunk header (chunk type 1)
|
||||
bin_chunk_header = struct.pack('<II', len(buffer_data), 0x004E4942) # "BIN\0" in little endian
|
||||
|
||||
# Write the GLB file
|
||||
with open(filepath, 'wb') as f:
|
||||
f.write(glb_header)
|
||||
f.write(json_chunk_header)
|
||||
f.write(gltf_json_padded)
|
||||
f.write(bin_chunk_header)
|
||||
f.write(buffer_data)
|
||||
|
||||
return filepath
|
||||
|
||||
|
||||
class SaveGLB:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"mesh": ("MESH", ),
|
||||
"filename_prefix": ("STRING", {"default": "mesh/ComfyUI"}), },
|
||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, }
|
||||
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "save"
|
||||
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "3d"
|
||||
|
||||
def save(self, mesh, filename_prefix, prompt=None, extra_pnginfo=None):
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory())
|
||||
results = []
|
||||
|
||||
metadata = {}
|
||||
if not args.disable_metadata:
|
||||
if prompt is not None:
|
||||
metadata["prompt"] = json.dumps(prompt)
|
||||
if extra_pnginfo is not None:
|
||||
for x in extra_pnginfo:
|
||||
metadata[x] = json.dumps(extra_pnginfo[x])
|
||||
|
||||
for i in range(mesh.vertices.shape[0]):
|
||||
f = f"{filename}_{counter:05}_.glb"
|
||||
save_glb(mesh.vertices[i], mesh.faces[i], os.path.join(full_output_folder, f), metadata)
|
||||
results.append({
|
||||
"filename": f,
|
||||
"subfolder": subfolder,
|
||||
"type": "output"
|
||||
})
|
||||
counter += 1
|
||||
return {"ui": {"3d": results}}
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"EmptyLatentHunyuan3Dv2": EmptyLatentHunyuan3Dv2,
|
||||
"Hunyuan3Dv2Conditioning": Hunyuan3Dv2Conditioning,
|
||||
"Hunyuan3Dv2ConditioningMultiView": Hunyuan3Dv2ConditioningMultiView,
|
||||
"VAEDecodeHunyuan3D": VAEDecodeHunyuan3D,
|
||||
"VoxelToMeshBasic": VoxelToMeshBasic,
|
||||
"VoxelToMesh": VoxelToMesh,
|
||||
"SaveGLB": SaveGLB,
|
||||
}
|
||||
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import nodes
|
||||
import folder_paths
|
||||
from comfy.cli_args import args
|
||||
@@ -9,6 +11,8 @@ import numpy as np
|
||||
import json
|
||||
import os
|
||||
|
||||
from comfy.comfy_types import FileLocator
|
||||
|
||||
MAX_RESOLUTION = nodes.MAX_RESOLUTION
|
||||
|
||||
class ImageCrop:
|
||||
@@ -99,7 +103,7 @@ class SaveAnimatedWEBP:
|
||||
method = self.methods.get(method)
|
||||
filename_prefix += self.prefix_append
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
|
||||
results = list()
|
||||
results: list[FileLocator] = []
|
||||
pil_images = []
|
||||
for image in images:
|
||||
i = 255. * image.cpu().numpy()
|
||||
|
||||
@@ -19,15 +19,10 @@ class Load3D():
|
||||
"image": ("LOAD_3D", {}),
|
||||
"width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
||||
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
||||
"material": (["original", "normal", "wireframe", "depth"],),
|
||||
"bg_color": ("STRING", {"default": "#000000", "multiline": False}),
|
||||
"light_intensity": ("INT", {"default": 10, "min": 1, "max": 20, "step": 1}),
|
||||
"up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
|
||||
"fov": ("INT", {"default": 75, "min": 10, "max": 150, "step": 1}),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("IMAGE", "MASK", "STRING")
|
||||
RETURN_NAMES = ("image", "mask", "mesh_path")
|
||||
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE")
|
||||
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart")
|
||||
|
||||
FUNCTION = "process"
|
||||
EXPERIMENTAL = True
|
||||
@@ -35,22 +30,18 @@ class Load3D():
|
||||
CATEGORY = "3d"
|
||||
|
||||
def process(self, model_file, image, **kwargs):
|
||||
if isinstance(image, dict):
|
||||
image_path = folder_paths.get_annotated_filepath(image['image'])
|
||||
mask_path = folder_paths.get_annotated_filepath(image['mask'])
|
||||
image_path = folder_paths.get_annotated_filepath(image['image'])
|
||||
mask_path = folder_paths.get_annotated_filepath(image['mask'])
|
||||
normal_path = folder_paths.get_annotated_filepath(image['normal'])
|
||||
lineart_path = folder_paths.get_annotated_filepath(image['lineart'])
|
||||
|
||||
load_image_node = nodes.LoadImage()
|
||||
output_image, ignore_mask = load_image_node.load_image(image=image_path)
|
||||
ignore_image, output_mask = load_image_node.load_image(image=mask_path)
|
||||
load_image_node = nodes.LoadImage()
|
||||
output_image, ignore_mask = load_image_node.load_image(image=image_path)
|
||||
ignore_image, output_mask = load_image_node.load_image(image=mask_path)
|
||||
normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
|
||||
lineart_image, ignore_mask3 = load_image_node.load_image(image=lineart_path)
|
||||
|
||||
return output_image, output_mask, model_file,
|
||||
else:
|
||||
# to avoid the format is not dict which will happen the FE code is not compatibility to core,
|
||||
# we need to this to double-check, it can be removed after merged FE into the core
|
||||
image_path = folder_paths.get_annotated_filepath(image)
|
||||
load_image_node = nodes.LoadImage()
|
||||
output_image, output_mask = load_image_node.load_image(image=image_path)
|
||||
return output_image, output_mask, model_file,
|
||||
return output_image, output_mask, model_file, normal_image, lineart_image
|
||||
|
||||
class Load3DAnimation():
|
||||
@classmethod
|
||||
@@ -66,16 +57,10 @@ class Load3DAnimation():
|
||||
"image": ("LOAD_3D_ANIMATION", {}),
|
||||
"width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
||||
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
||||
"material": (["original", "normal", "wireframe", "depth"],),
|
||||
"bg_color": ("STRING", {"default": "#000000", "multiline": False}),
|
||||
"light_intensity": ("INT", {"default": 10, "min": 1, "max": 20, "step": 1}),
|
||||
"up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
|
||||
"animation_speed": (["0.1", "0.5", "1", "1.5", "2"], {"default": "1"}),
|
||||
"fov": ("INT", {"default": 75, "min": 10, "max": 150, "step": 1}),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("IMAGE", "MASK", "STRING")
|
||||
RETURN_NAMES = ("image", "mask", "mesh_path")
|
||||
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE")
|
||||
RETURN_NAMES = ("image", "mask", "mesh_path", "normal")
|
||||
|
||||
FUNCTION = "process"
|
||||
EXPERIMENTAL = True
|
||||
@@ -83,31 +68,40 @@ class Load3DAnimation():
|
||||
CATEGORY = "3d"
|
||||
|
||||
def process(self, model_file, image, **kwargs):
|
||||
if isinstance(image, dict):
|
||||
image_path = folder_paths.get_annotated_filepath(image['image'])
|
||||
mask_path = folder_paths.get_annotated_filepath(image['mask'])
|
||||
image_path = folder_paths.get_annotated_filepath(image['image'])
|
||||
mask_path = folder_paths.get_annotated_filepath(image['mask'])
|
||||
normal_path = folder_paths.get_annotated_filepath(image['normal'])
|
||||
|
||||
load_image_node = nodes.LoadImage()
|
||||
output_image, ignore_mask = load_image_node.load_image(image=image_path)
|
||||
ignore_image, output_mask = load_image_node.load_image(image=mask_path)
|
||||
load_image_node = nodes.LoadImage()
|
||||
output_image, ignore_mask = load_image_node.load_image(image=image_path)
|
||||
ignore_image, output_mask = load_image_node.load_image(image=mask_path)
|
||||
normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
|
||||
|
||||
return output_image, output_mask, model_file,
|
||||
else:
|
||||
image_path = folder_paths.get_annotated_filepath(image)
|
||||
load_image_node = nodes.LoadImage()
|
||||
output_image, output_mask = load_image_node.load_image(image=image_path)
|
||||
return output_image, output_mask, model_file,
|
||||
return output_image, output_mask, model_file, normal_image
|
||||
|
||||
class Preview3D():
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"model_file": ("STRING", {"default": "", "multiline": False}),
|
||||
"material": (["original", "normal", "wireframe", "depth"],),
|
||||
"bg_color": ("STRING", {"default": "#000000", "multiline": False}),
|
||||
"light_intensity": ("INT", {"default": 10, "min": 1, "max": 20, "step": 1}),
|
||||
"up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
|
||||
"fov": ("INT", {"default": 75, "min": 10, "max": 150, "step": 1}),
|
||||
}}
|
||||
|
||||
OUTPUT_NODE = True
|
||||
RETURN_TYPES = ()
|
||||
|
||||
CATEGORY = "3d"
|
||||
|
||||
FUNCTION = "process"
|
||||
EXPERIMENTAL = True
|
||||
|
||||
def process(self, model_file, **kwargs):
|
||||
return {"ui": {"model_file": [model_file]}, "result": ()}
|
||||
|
||||
class Preview3DAnimation():
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"model_file": ("STRING", {"default": "", "multiline": False}),
|
||||
}}
|
||||
|
||||
OUTPUT_NODE = True
|
||||
@@ -124,11 +118,13 @@ class Preview3D():
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"Load3D": Load3D,
|
||||
"Load3DAnimation": Load3DAnimation,
|
||||
"Preview3D": Preview3D
|
||||
"Preview3D": Preview3D,
|
||||
"Preview3DAnimation": Preview3DAnimation
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"Load3D": "Load 3D",
|
||||
"Load3DAnimation": "Load 3D - Animation",
|
||||
"Preview3D": "Preview 3D"
|
||||
"Preview3D": "Preview 3D",
|
||||
"Preview3DAnimation": "Preview 3D - Animation"
|
||||
}
|
||||
|
||||
29
comfy_extras/nodes_lotus.py
Normal file
29
comfy_extras/nodes_lotus.py
Normal file
File diff suppressed because one or more lines are too long
@@ -1,9 +1,14 @@
|
||||
import io
|
||||
import nodes
|
||||
import node_helpers
|
||||
import torch
|
||||
import comfy.model_management
|
||||
import comfy.model_sampling
|
||||
import comfy.utils
|
||||
import math
|
||||
import numpy as np
|
||||
import av
|
||||
from comfy.ldm.lightricks.symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
|
||||
|
||||
class EmptyLTXVLatentVideo:
|
||||
@classmethod
|
||||
@@ -33,7 +38,6 @@ class LTXVImgToVideo:
|
||||
"height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
|
||||
"length": ("INT", {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||
"image_noise_scale": ("FLOAT", {"default": 0.15, "min": 0, "max": 1.0, "step": 0.01, "tooltip": "Amount of noise to apply on conditioning image latent."})
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
|
||||
@@ -42,16 +46,219 @@ class LTXVImgToVideo:
|
||||
CATEGORY = "conditioning/video_models"
|
||||
FUNCTION = "generate"
|
||||
|
||||
def generate(self, positive, negative, image, vae, width, height, length, batch_size, image_noise_scale):
|
||||
def generate(self, positive, negative, image, vae, width, height, length, batch_size):
|
||||
pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
encode_pixels = pixels[:, :, :, :3]
|
||||
t = vae.encode(encode_pixels)
|
||||
positive = node_helpers.conditioning_set_values(positive, {"guiding_latent": t, "guiding_latent_noise_scale": image_noise_scale})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"guiding_latent": t, "guiding_latent_noise_scale": image_noise_scale})
|
||||
|
||||
latent = torch.zeros([batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32], device=comfy.model_management.intermediate_device())
|
||||
latent[:, :, :t.shape[2]] = t
|
||||
return (positive, negative, {"samples": latent}, )
|
||||
|
||||
conditioning_latent_frames_mask = torch.ones(
|
||||
(batch_size, 1, latent.shape[2], 1, 1),
|
||||
dtype=torch.float32,
|
||||
device=latent.device,
|
||||
)
|
||||
conditioning_latent_frames_mask[:, :, :t.shape[2]] = 0
|
||||
|
||||
return (positive, negative, {"samples": latent, "noise_mask": conditioning_latent_frames_mask}, )
|
||||
|
||||
|
||||
def conditioning_get_any_value(conditioning, key, default=None):
|
||||
for t in conditioning:
|
||||
if key in t[1]:
|
||||
return t[1][key]
|
||||
return default
|
||||
|
||||
|
||||
def get_noise_mask(latent):
|
||||
noise_mask = latent.get("noise_mask", None)
|
||||
latent_image = latent["samples"]
|
||||
if noise_mask is None:
|
||||
batch_size, _, latent_length, _, _ = latent_image.shape
|
||||
noise_mask = torch.ones(
|
||||
(batch_size, 1, latent_length, 1, 1),
|
||||
dtype=torch.float32,
|
||||
device=latent_image.device,
|
||||
)
|
||||
else:
|
||||
noise_mask = noise_mask.clone()
|
||||
return noise_mask
|
||||
|
||||
def get_keyframe_idxs(cond):
|
||||
keyframe_idxs = conditioning_get_any_value(cond, "keyframe_idxs", None)
|
||||
if keyframe_idxs is None:
|
||||
return None, 0
|
||||
num_keyframes = torch.unique(keyframe_idxs[:, 0]).shape[0]
|
||||
return keyframe_idxs, num_keyframes
|
||||
|
||||
class LTXVAddGuide:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"positive": ("CONDITIONING", ),
|
||||
"negative": ("CONDITIONING", ),
|
||||
"vae": ("VAE",),
|
||||
"latent": ("LATENT",),
|
||||
"image": ("IMAGE", {"tooltip": "Image or video to condition the latent video on. Must be 8*n + 1 frames."
|
||||
"If the video is not 8*n + 1 frames, it will be cropped to the nearest 8*n + 1 frames."}),
|
||||
"frame_idx": ("INT", {"default": 0, "min": -9999, "max": 9999,
|
||||
"tooltip": "Frame index to start the conditioning at. For single-frame images or "
|
||||
"videos with 1-8 frames, any frame_idx value is acceptable. For videos with 9+ "
|
||||
"frames, frame_idx must be divisible by 8, otherwise it will be rounded down to "
|
||||
"the nearest multiple of 8. Negative values are counted from the end of the video."}),
|
||||
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
|
||||
RETURN_NAMES = ("positive", "negative", "latent")
|
||||
|
||||
CATEGORY = "conditioning/video_models"
|
||||
FUNCTION = "generate"
|
||||
|
||||
def __init__(self):
|
||||
self._num_prefix_frames = 2
|
||||
self._patchifier = SymmetricPatchifier(1)
|
||||
|
||||
def encode(self, vae, latent_width, latent_height, images, scale_factors):
|
||||
time_scale_factor, width_scale_factor, height_scale_factor = scale_factors
|
||||
images = images[:(images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1]
|
||||
pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="disabled").movedim(1, -1)
|
||||
encode_pixels = pixels[:, :, :, :3]
|
||||
t = vae.encode(encode_pixels)
|
||||
return encode_pixels, t
|
||||
|
||||
def get_latent_index(self, cond, latent_length, guide_length, frame_idx, scale_factors):
|
||||
time_scale_factor, _, _ = scale_factors
|
||||
_, num_keyframes = get_keyframe_idxs(cond)
|
||||
latent_count = latent_length - num_keyframes
|
||||
frame_idx = frame_idx if frame_idx >= 0 else max((latent_count - 1) * time_scale_factor + 1 + frame_idx, 0)
|
||||
if guide_length > 1:
|
||||
frame_idx = frame_idx // time_scale_factor * time_scale_factor # frame index must be divisible by 8
|
||||
|
||||
latent_idx = (frame_idx + time_scale_factor - 1) // time_scale_factor
|
||||
|
||||
return frame_idx, latent_idx
|
||||
|
||||
def add_keyframe_index(self, cond, frame_idx, guiding_latent, scale_factors):
|
||||
keyframe_idxs, _ = get_keyframe_idxs(cond)
|
||||
_, latent_coords = self._patchifier.patchify(guiding_latent)
|
||||
pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, True)
|
||||
pixel_coords[:, 0] += frame_idx
|
||||
if keyframe_idxs is None:
|
||||
keyframe_idxs = pixel_coords
|
||||
else:
|
||||
keyframe_idxs = torch.cat([keyframe_idxs, pixel_coords], dim=2)
|
||||
return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs})
|
||||
|
||||
def append_keyframe(self, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors):
|
||||
positive = self.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors)
|
||||
negative = self.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors)
|
||||
|
||||
mask = torch.full(
|
||||
(noise_mask.shape[0], 1, guiding_latent.shape[2], 1, 1),
|
||||
1.0 - strength,
|
||||
dtype=noise_mask.dtype,
|
||||
device=noise_mask.device,
|
||||
)
|
||||
|
||||
latent_image = torch.cat([latent_image, guiding_latent], dim=2)
|
||||
noise_mask = torch.cat([noise_mask, mask], dim=2)
|
||||
return positive, negative, latent_image, noise_mask
|
||||
|
||||
def replace_latent_frames(self, latent_image, noise_mask, guiding_latent, latent_idx, strength):
|
||||
cond_length = guiding_latent.shape[2]
|
||||
assert latent_image.shape[2] >= latent_idx + cond_length, "Conditioning frames exceed the length of the latent sequence."
|
||||
|
||||
mask = torch.full(
|
||||
(noise_mask.shape[0], 1, cond_length, 1, 1),
|
||||
1.0 - strength,
|
||||
dtype=noise_mask.dtype,
|
||||
device=noise_mask.device,
|
||||
)
|
||||
|
||||
latent_image = latent_image.clone()
|
||||
noise_mask = noise_mask.clone()
|
||||
|
||||
latent_image[:, :, latent_idx : latent_idx + cond_length] = guiding_latent
|
||||
noise_mask[:, :, latent_idx : latent_idx + cond_length] = mask
|
||||
|
||||
return latent_image, noise_mask
|
||||
|
||||
def generate(self, positive, negative, vae, latent, image, frame_idx, strength):
|
||||
scale_factors = vae.downscale_index_formula
|
||||
latent_image = latent["samples"]
|
||||
noise_mask = get_noise_mask(latent)
|
||||
|
||||
_, _, latent_length, latent_height, latent_width = latent_image.shape
|
||||
image, t = self.encode(vae, latent_width, latent_height, image, scale_factors)
|
||||
|
||||
frame_idx, latent_idx = self.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
|
||||
assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."
|
||||
|
||||
num_prefix_frames = min(self._num_prefix_frames, t.shape[2])
|
||||
|
||||
positive, negative, latent_image, noise_mask = self.append_keyframe(
|
||||
positive,
|
||||
negative,
|
||||
frame_idx,
|
||||
latent_image,
|
||||
noise_mask,
|
||||
t[:, :, :num_prefix_frames],
|
||||
strength,
|
||||
scale_factors,
|
||||
)
|
||||
|
||||
latent_idx += num_prefix_frames
|
||||
|
||||
t = t[:, :, num_prefix_frames:]
|
||||
if t.shape[2] == 0:
|
||||
return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)
|
||||
|
||||
latent_image, noise_mask = self.replace_latent_frames(
|
||||
latent_image,
|
||||
noise_mask,
|
||||
t,
|
||||
latent_idx,
|
||||
strength,
|
||||
)
|
||||
|
||||
return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)
|
||||
|
||||
|
||||
class LTXVCropGuides:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"positive": ("CONDITIONING", ),
|
||||
"negative": ("CONDITIONING", ),
|
||||
"latent": ("LATENT",),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
|
||||
RETURN_NAMES = ("positive", "negative", "latent")
|
||||
|
||||
CATEGORY = "conditioning/video_models"
|
||||
FUNCTION = "crop"
|
||||
|
||||
def __init__(self):
|
||||
self._patchifier = SymmetricPatchifier(1)
|
||||
|
||||
def crop(self, positive, negative, latent):
|
||||
latent_image = latent["samples"].clone()
|
||||
noise_mask = get_noise_mask(latent)
|
||||
|
||||
_, num_keyframes = get_keyframe_idxs(positive)
|
||||
if num_keyframes == 0:
|
||||
return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)
|
||||
|
||||
latent_image = latent_image[:, :, :-num_keyframes]
|
||||
noise_mask = noise_mask[:, :, :-num_keyframes]
|
||||
|
||||
positive = node_helpers.conditioning_set_values(positive, {"keyframe_idxs": None})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"keyframe_idxs": None})
|
||||
|
||||
return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)
|
||||
|
||||
|
||||
class LTXVConditioning:
|
||||
@@ -174,6 +381,76 @@ class LTXVScheduler:
|
||||
|
||||
return (sigmas,)
|
||||
|
||||
def encode_single_frame(output_file, image_array: np.ndarray, crf):
|
||||
container = av.open(output_file, "w", format="mp4")
|
||||
try:
|
||||
stream = container.add_stream(
|
||||
"h264", rate=1, options={"crf": str(crf), "preset": "veryfast"}
|
||||
)
|
||||
stream.height = image_array.shape[0]
|
||||
stream.width = image_array.shape[1]
|
||||
av_frame = av.VideoFrame.from_ndarray(image_array, format="rgb24").reformat(
|
||||
format="yuv420p"
|
||||
)
|
||||
container.mux(stream.encode(av_frame))
|
||||
container.mux(stream.encode())
|
||||
finally:
|
||||
container.close()
|
||||
|
||||
|
||||
def decode_single_frame(video_file):
|
||||
container = av.open(video_file)
|
||||
try:
|
||||
stream = next(s for s in container.streams if s.type == "video")
|
||||
frame = next(container.decode(stream))
|
||||
finally:
|
||||
container.close()
|
||||
return frame.to_ndarray(format="rgb24")
|
||||
|
||||
|
||||
def preprocess(image: torch.Tensor, crf=29):
|
||||
if crf == 0:
|
||||
return image
|
||||
|
||||
image_array = (image[:(image.shape[0] // 2) * 2, :(image.shape[1] // 2) * 2] * 255.0).byte().cpu().numpy()
|
||||
with io.BytesIO() as output_file:
|
||||
encode_single_frame(output_file, image_array, crf)
|
||||
video_bytes = output_file.getvalue()
|
||||
with io.BytesIO(video_bytes) as video_file:
|
||||
image_array = decode_single_frame(video_file)
|
||||
tensor = torch.tensor(image_array, dtype=image.dtype, device=image.device) / 255.0
|
||||
return tensor
|
||||
|
||||
|
||||
class LTXVPreprocess:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
"img_compression": (
|
||||
"INT",
|
||||
{
|
||||
"default": 35,
|
||||
"min": 0,
|
||||
"max": 100,
|
||||
"tooltip": "Amount of compression to apply on image.",
|
||||
},
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
FUNCTION = "preprocess"
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
RETURN_NAMES = ("output_image",)
|
||||
CATEGORY = "image"
|
||||
|
||||
def preprocess(self, image, img_compression):
|
||||
output_images = []
|
||||
for i in range(image.shape[0]):
|
||||
output_images.append(preprocess(image[i], img_compression))
|
||||
return (torch.stack(output_images),)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"EmptyLTXVLatentVideo": EmptyLTXVLatentVideo,
|
||||
@@ -181,4 +458,7 @@ NODE_CLASS_MAPPINGS = {
|
||||
"ModelSamplingLTXV": ModelSamplingLTXV,
|
||||
"LTXVConditioning": LTXVConditioning,
|
||||
"LTXVScheduler": LTXVScheduler,
|
||||
"LTXVAddGuide": LTXVAddGuide,
|
||||
"LTXVPreprocess": LTXVPreprocess,
|
||||
"LTXVCropGuides": LTXVCropGuides,
|
||||
}
|
||||
|
||||
104
comfy_extras/nodes_lumina2.py
Normal file
104
comfy_extras/nodes_lumina2.py
Normal file
@@ -0,0 +1,104 @@
|
||||
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
|
||||
import torch
|
||||
|
||||
|
||||
class RenormCFG:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"cfg_trunc": ("FLOAT", {"default": 100, "min": 0.0, "max": 100.0, "step": 0.01}),
|
||||
"renorm_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
|
||||
CATEGORY = "advanced/model"
|
||||
|
||||
def patch(self, model, cfg_trunc, renorm_cfg):
|
||||
def renorm_cfg_func(args):
|
||||
cond_denoised = args["cond_denoised"]
|
||||
uncond_denoised = args["uncond_denoised"]
|
||||
cond_scale = args["cond_scale"]
|
||||
timestep = args["timestep"]
|
||||
x_orig = args["input"]
|
||||
in_channels = model.model.diffusion_model.in_channels
|
||||
|
||||
if timestep[0] < cfg_trunc:
|
||||
cond_eps, uncond_eps = cond_denoised[:, :in_channels], uncond_denoised[:, :in_channels]
|
||||
cond_rest, _ = cond_denoised[:, in_channels:], uncond_denoised[:, in_channels:]
|
||||
half_eps = uncond_eps + cond_scale * (cond_eps - uncond_eps)
|
||||
half_rest = cond_rest
|
||||
|
||||
if float(renorm_cfg) > 0.0:
|
||||
ori_pos_norm = torch.linalg.vector_norm(cond_eps
|
||||
, dim=tuple(range(1, len(cond_eps.shape))), keepdim=True
|
||||
)
|
||||
max_new_norm = ori_pos_norm * float(renorm_cfg)
|
||||
new_pos_norm = torch.linalg.vector_norm(
|
||||
half_eps, dim=tuple(range(1, len(half_eps.shape))), keepdim=True
|
||||
)
|
||||
if new_pos_norm >= max_new_norm:
|
||||
half_eps = half_eps * (max_new_norm / new_pos_norm)
|
||||
else:
|
||||
cond_eps, uncond_eps = cond_denoised[:, :in_channels], uncond_denoised[:, :in_channels]
|
||||
cond_rest, _ = cond_denoised[:, in_channels:], uncond_denoised[:, in_channels:]
|
||||
half_eps = cond_eps
|
||||
half_rest = cond_rest
|
||||
|
||||
cfg_result = torch.cat([half_eps, half_rest], dim=1)
|
||||
|
||||
# cfg_result = uncond_denoised + (cond_denoised - uncond_denoised) * cond_scale
|
||||
|
||||
return x_orig - cfg_result
|
||||
|
||||
m = model.clone()
|
||||
m.set_model_sampler_cfg_function(renorm_cfg_func)
|
||||
return (m, )
|
||||
|
||||
|
||||
class CLIPTextEncodeLumina2(ComfyNodeABC):
|
||||
SYSTEM_PROMPT = {
|
||||
"superior": "You are an assistant designed to generate superior images with the superior "\
|
||||
"degree of image-text alignment based on textual prompts or user prompts.",
|
||||
"alignment": "You are an assistant designed to generate high-quality images with the "\
|
||||
"highest degree of image-text alignment based on textual prompts."
|
||||
}
|
||||
SYSTEM_PROMPT_TIP = "Lumina2 provide two types of system prompts:" \
|
||||
"Superior: You are an assistant designed to generate superior images with the superior "\
|
||||
"degree of image-text alignment based on textual prompts or user prompts. "\
|
||||
"Alignment: You are an assistant designed to generate high-quality images with the highest "\
|
||||
"degree of image-text alignment based on textual prompts."
|
||||
@classmethod
|
||||
def INPUT_TYPES(s) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"system_prompt": (list(CLIPTextEncodeLumina2.SYSTEM_PROMPT.keys()), {"tooltip": CLIPTextEncodeLumina2.SYSTEM_PROMPT_TIP}),
|
||||
"user_prompt": (IO.STRING, {"multiline": True, "dynamicPrompts": True, "tooltip": "The text to be encoded."}),
|
||||
"clip": (IO.CLIP, {"tooltip": "The CLIP model used for encoding the text."})
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = (IO.CONDITIONING,)
|
||||
OUTPUT_TOOLTIPS = ("A conditioning containing the embedded text used to guide the diffusion model.",)
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "conditioning"
|
||||
DESCRIPTION = "Encodes a system prompt and a user prompt using a CLIP model into an embedding that can be used to guide the diffusion model towards generating specific images."
|
||||
|
||||
def encode(self, clip, user_prompt, system_prompt):
|
||||
if clip is None:
|
||||
raise RuntimeError("ERROR: clip input is invalid: None\n\nIf the clip is from a checkpoint loader node your checkpoint does not contain a valid clip or text encoder model.")
|
||||
system_prompt = CLIPTextEncodeLumina2.SYSTEM_PROMPT[system_prompt]
|
||||
prompt = f'{system_prompt} <Prompt Start> {user_prompt}'
|
||||
tokens = clip.tokenize(prompt)
|
||||
return (clip.encode_from_tokens_scheduled(tokens), )
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"CLIPTextEncodeLumina2": CLIPTextEncodeLumina2,
|
||||
"RenormCFG": RenormCFG
|
||||
}
|
||||
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"CLIPTextEncodeLumina2": "CLIP Text Encode for Lumina2",
|
||||
}
|
||||
@@ -2,6 +2,7 @@ import numpy as np
|
||||
import scipy.ndimage
|
||||
import torch
|
||||
import comfy.utils
|
||||
import node_helpers
|
||||
|
||||
from nodes import MAX_RESOLUTION
|
||||
|
||||
@@ -87,6 +88,7 @@ class ImageCompositeMasked:
|
||||
CATEGORY = "image"
|
||||
|
||||
def composite(self, destination, source, x, y, resize_source, mask = None):
|
||||
destination, source = node_helpers.image_alpha_fix(destination, source)
|
||||
destination = destination.clone().movedim(-1, 1)
|
||||
output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1)
|
||||
return (output,)
|
||||
|
||||
@@ -3,6 +3,8 @@ import comfy.model_sampling
|
||||
import comfy.latent_formats
|
||||
import nodes
|
||||
import torch
|
||||
import node_helpers
|
||||
|
||||
|
||||
class LCM(comfy.model_sampling.EPS):
|
||||
def calculate_denoised(self, sigma, model_output, model_input):
|
||||
@@ -18,10 +20,6 @@ class LCM(comfy.model_sampling.EPS):
|
||||
|
||||
return c_out * x0 + c_skip * model_input
|
||||
|
||||
class X0(comfy.model_sampling.EPS):
|
||||
def calculate_denoised(self, sigma, model_output, model_input):
|
||||
return model_output
|
||||
|
||||
class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete):
|
||||
original_timesteps = 50
|
||||
|
||||
@@ -54,7 +52,7 @@ class ModelSamplingDiscrete:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"sampling": (["eps", "v_prediction", "lcm", "x0"],),
|
||||
"sampling": (["eps", "v_prediction", "lcm", "x0", "img_to_img"],),
|
||||
"zsnr": ("BOOLEAN", {"default": False}),
|
||||
}}
|
||||
|
||||
@@ -75,7 +73,9 @@ class ModelSamplingDiscrete:
|
||||
sampling_type = LCM
|
||||
sampling_base = ModelSamplingDiscreteDistilled
|
||||
elif sampling == "x0":
|
||||
sampling_type = X0
|
||||
sampling_type = comfy.model_sampling.X0
|
||||
elif sampling == "img_to_img":
|
||||
sampling_type = comfy.model_sampling.IMG_TO_IMG
|
||||
|
||||
class ModelSamplingAdvanced(sampling_base, sampling_type):
|
||||
pass
|
||||
@@ -294,6 +294,24 @@ class RescaleCFG:
|
||||
m.set_model_sampler_cfg_function(rescale_cfg)
|
||||
return (m, )
|
||||
|
||||
class ModelComputeDtype:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"dtype": (["default", "fp32", "fp16", "bf16"],),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
|
||||
CATEGORY = "advanced/debug/model"
|
||||
|
||||
def patch(self, model, dtype):
|
||||
m = model.clone()
|
||||
m.set_model_compute_dtype(node_helpers.string_to_torch_dtype(dtype))
|
||||
return (m, )
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ModelSamplingDiscrete": ModelSamplingDiscrete,
|
||||
"ModelSamplingContinuousEDM": ModelSamplingContinuousEDM,
|
||||
@@ -303,4 +321,5 @@ NODE_CLASS_MAPPINGS = {
|
||||
"ModelSamplingAuraFlow": ModelSamplingAuraFlow,
|
||||
"ModelSamplingFlux": ModelSamplingFlux,
|
||||
"RescaleCFG": RescaleCFG,
|
||||
"ModelComputeDtype": ModelComputeDtype,
|
||||
}
|
||||
|
||||
@@ -196,6 +196,78 @@ class ModelMergeLTXV(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
||||
|
||||
return {"required": arg_dict}
|
||||
|
||||
class ModelMergeCosmos7B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
||||
CATEGORY = "advanced/model_merging/model_specific"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
arg_dict = { "model1": ("MODEL",),
|
||||
"model2": ("MODEL",)}
|
||||
|
||||
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
|
||||
|
||||
arg_dict["pos_embedder."] = argument
|
||||
arg_dict["extra_pos_embedder."] = argument
|
||||
arg_dict["x_embedder."] = argument
|
||||
arg_dict["t_embedder."] = argument
|
||||
arg_dict["affline_norm."] = argument
|
||||
|
||||
|
||||
for i in range(28):
|
||||
arg_dict["blocks.block{}.".format(i)] = argument
|
||||
|
||||
arg_dict["final_layer."] = argument
|
||||
|
||||
return {"required": arg_dict}
|
||||
|
||||
class ModelMergeCosmos14B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
||||
CATEGORY = "advanced/model_merging/model_specific"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
arg_dict = { "model1": ("MODEL",),
|
||||
"model2": ("MODEL",)}
|
||||
|
||||
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
|
||||
|
||||
arg_dict["pos_embedder."] = argument
|
||||
arg_dict["extra_pos_embedder."] = argument
|
||||
arg_dict["x_embedder."] = argument
|
||||
arg_dict["t_embedder."] = argument
|
||||
arg_dict["affline_norm."] = argument
|
||||
|
||||
|
||||
for i in range(36):
|
||||
arg_dict["blocks.block{}.".format(i)] = argument
|
||||
|
||||
arg_dict["final_layer."] = argument
|
||||
|
||||
return {"required": arg_dict}
|
||||
|
||||
class ModelMergeWAN2_1(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
||||
CATEGORY = "advanced/model_merging/model_specific"
|
||||
DESCRIPTION = "1.3B model has 30 blocks, 14B model has 40 blocks. Image to video model has the extra img_emb."
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
arg_dict = { "model1": ("MODEL",),
|
||||
"model2": ("MODEL",)}
|
||||
|
||||
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
|
||||
|
||||
arg_dict["patch_embedding."] = argument
|
||||
arg_dict["time_embedding."] = argument
|
||||
arg_dict["time_projection."] = argument
|
||||
arg_dict["text_embedding."] = argument
|
||||
arg_dict["img_emb."] = argument
|
||||
|
||||
for i in range(40):
|
||||
arg_dict["blocks.{}.".format(i)] = argument
|
||||
|
||||
arg_dict["head."] = argument
|
||||
|
||||
return {"required": arg_dict}
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ModelMergeSD1": ModelMergeSD1,
|
||||
"ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks
|
||||
@@ -206,4 +278,7 @@ NODE_CLASS_MAPPINGS = {
|
||||
"ModelMergeSD35_Large": ModelMergeSD35_Large,
|
||||
"ModelMergeMochiPreview": ModelMergeMochiPreview,
|
||||
"ModelMergeLTXV": ModelMergeLTXV,
|
||||
"ModelMergeCosmos7B": ModelMergeCosmos7B,
|
||||
"ModelMergeCosmos14B": ModelMergeCosmos14B,
|
||||
"ModelMergeWAN2_1": ModelMergeWAN2_1,
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ import torch
|
||||
import comfy.model_management
|
||||
|
||||
from kornia.morphology import dilation, erosion, opening, closing, gradient, top_hat, bottom_hat
|
||||
import kornia.color
|
||||
|
||||
|
||||
class Morphology:
|
||||
@@ -40,8 +41,45 @@ class Morphology:
|
||||
img_out = output.to(comfy.model_management.intermediate_device()).movedim(1, -1)
|
||||
return (img_out,)
|
||||
|
||||
|
||||
class ImageRGBToYUV:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "image": ("IMAGE",),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("IMAGE", "IMAGE", "IMAGE")
|
||||
RETURN_NAMES = ("Y", "U", "V")
|
||||
FUNCTION = "execute"
|
||||
|
||||
CATEGORY = "image/batch"
|
||||
|
||||
def execute(self, image):
|
||||
out = kornia.color.rgb_to_ycbcr(image.movedim(-1, 1)).movedim(1, -1)
|
||||
return (out[..., 0:1].expand_as(image), out[..., 1:2].expand_as(image), out[..., 2:3].expand_as(image))
|
||||
|
||||
class ImageYUVToRGB:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"Y": ("IMAGE",),
|
||||
"U": ("IMAGE",),
|
||||
"V": ("IMAGE",),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "execute"
|
||||
|
||||
CATEGORY = "image/batch"
|
||||
|
||||
def execute(self, Y, U, V):
|
||||
image = torch.cat([torch.mean(Y, dim=-1, keepdim=True), torch.mean(U, dim=-1, keepdim=True), torch.mean(V, dim=-1, keepdim=True)], dim=-1)
|
||||
out = kornia.color.ycbcr_to_rgb(image.movedim(-1, 1)).movedim(1, -1)
|
||||
return (out,)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"Morphology": Morphology,
|
||||
"ImageRGBToYUV": ImageRGBToYUV,
|
||||
"ImageYUVToRGB": ImageYUVToRGB,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user