Compare commits
275 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9a616b81c1 | ||
|
|
3bed56bb13 | ||
|
|
4e402b11c6 | ||
|
|
48272448ad | ||
|
|
f7695b5f9e | ||
|
|
452179fe4f | ||
|
|
bf9a90a145 | ||
|
|
c1b92b719d | ||
|
|
cdc3b97dd5 | ||
|
|
8d4e06324f | ||
|
|
57e8bf6a9f | ||
|
|
0ee322ec5f | ||
|
|
79d5ceae6e | ||
|
|
2d5b3e0078 | ||
|
|
8e4118c0de | ||
|
|
3fc6ebcdd7 | ||
|
|
20a560eb97 | ||
|
|
82c5308561 | ||
|
|
26fb2c68e8 | ||
|
|
bf2650a80e | ||
|
|
53646e0f32 | ||
|
|
20879c78f9 | ||
|
|
b666539595 | ||
|
|
95d8713482 | ||
|
|
0d4e29f13f | ||
|
|
497db6212f | ||
|
|
24dc581dc3 | ||
|
|
4c82741b54 | ||
|
|
15c39ea757 | ||
|
|
b7143b74ce | ||
|
|
61196d8857 | ||
|
|
b4526d3fc3 | ||
|
|
3d802710e7 | ||
|
|
7126ecffde | ||
|
|
ab885b33ba | ||
|
|
839ed3368e | ||
|
|
6e8cdcd3cb | ||
|
|
e5c3f4b87f | ||
|
|
bc6be6c11e | ||
|
|
94323a26a7 | ||
|
|
5818f6cf51 | ||
|
|
0b734de449 | ||
|
|
5e16f1d24b | ||
|
|
2fd9c1308a | ||
|
|
8f0009aad0 | ||
|
|
41444b5236 | ||
|
|
772e620e32 | ||
|
|
07f6eeaa13 | ||
|
|
22535d0589 | ||
|
|
898615122f | ||
|
|
156a28786b | ||
|
|
f498d855ba | ||
|
|
b699a15062 | ||
|
|
9cc90ee3eb | ||
|
|
9a0a5d32ee | ||
|
|
d9f90965c8 | ||
|
|
41886af138 | ||
|
|
22a1d7ce78 | ||
|
|
4ac401af2b | ||
|
|
5fb59c8475 | ||
|
|
122c9ca1ce | ||
|
|
3b9a6cf2b1 | ||
|
|
3748e7ef7a | ||
|
|
8ebf2d8831 | ||
|
|
a72d152b0c | ||
|
|
eb476e6ea9 | ||
|
|
2d28b0b479 | ||
|
|
8b275ce5be | ||
|
|
2a18e98ccf | ||
|
|
8a5281006f | ||
|
|
bdeb1c171c | ||
|
|
9c1ed58ef2 | ||
|
|
8b90e50979 | ||
|
|
6ee066a14f | ||
|
|
dd5b57e3d7 | ||
|
|
75a818c720 | ||
|
|
2865f913f7 | ||
|
|
b49616f951 | ||
|
|
5e29e7a488 | ||
|
|
8afb97cd3f | ||
|
|
69694f40b3 | ||
|
|
c49025f01b | ||
|
|
696672905f | ||
|
|
6c9dbde7de | ||
|
|
ee8abf0cff | ||
|
|
fabf449feb | ||
|
|
cc9cf6d1bd | ||
|
|
1c8286a44b | ||
|
|
1af4a47fd1 | ||
|
|
f2aaa0a475 | ||
|
|
daa1565b93 | ||
|
|
09fdb2b269 | ||
|
|
65a8659182 | ||
|
|
770ab200f2 | ||
|
|
954683d0db | ||
|
|
30c0c81351 | ||
|
|
13b0ff8a6f | ||
|
|
c320801187 | ||
|
|
c0b0cfaeec | ||
|
|
669d9e4c67 | ||
|
|
9ee0a6553a | ||
|
|
5cbb01bc2f | ||
|
|
c3ffbae067 | ||
|
|
d605677b33 | ||
|
|
ce759b7db6 | ||
|
|
52810907e2 | ||
|
|
af8cf79a2d | ||
|
|
66b0961a46 | ||
|
|
754597c8a9 | ||
|
|
915fdb5745 | ||
|
|
5a8a48931a | ||
|
|
8ce2a1052c | ||
|
|
f82314fcfc | ||
|
|
0075c6d096 | ||
|
|
83ca891118 | ||
|
|
f9f9faface | ||
|
|
471cd3eace | ||
|
|
a68bbafddb | ||
|
|
73e3a9e676 | ||
|
|
518c0dc2fe | ||
|
|
ce0542e10b | ||
|
|
8473019d40 | ||
|
|
89f15894dd | ||
|
|
67158994a4 | ||
|
|
7390ff3b1e | ||
|
|
0bedfb26af | ||
|
|
f71cfd2687 | ||
|
|
c695c4af7f | ||
|
|
0dbba9f751 | ||
|
|
f584758271 | ||
|
|
95b7cf9bbe | ||
|
|
191a0d56b4 | ||
|
|
3c60ecd7a8 | ||
|
|
7ae6626723 | ||
|
|
6632365e16 | ||
|
|
ad07796777 | ||
|
|
1b80895285 | ||
|
|
5f9d5a244b | ||
|
|
14eba07acd | ||
|
|
4b2f0d9413 | ||
|
|
25eac1d780 | ||
|
|
e38c94228b | ||
|
|
203942c8b2 | ||
|
|
3c72c89a52 | ||
|
|
614377abd6 | ||
|
|
8dfa0cc552 | ||
|
|
e5ecdfdd2d | ||
|
|
7d29fbf74b | ||
|
|
2c641e64ad | ||
|
|
7d2467e830 | ||
|
|
6f021d8aa0 | ||
|
|
d854ed0bcf | ||
|
|
abcd006b8c | ||
|
|
d985d1d7dc | ||
|
|
d1cdf51e1b | ||
|
|
b4626ab93e | ||
|
|
a9e459c2a4 | ||
|
|
3bb4dec720 | ||
|
|
8733191563 | ||
|
|
83b01f960a | ||
|
|
d72e871cfa | ||
|
|
037c3159b6 | ||
|
|
bdd4a22a2e | ||
|
|
fdf37566ef | ||
|
|
08c8968482 | ||
|
|
479a427a48 | ||
|
|
3a0eeee320 | ||
|
|
447da7ea86 | ||
|
|
9c41bc8d10 | ||
|
|
6ad0ddbae4 | ||
|
|
a55142f904 | ||
|
|
5718ef69bb | ||
|
|
13ecf10a92 | ||
|
|
7a415f47a9 | ||
|
|
89fa2fca24 | ||
|
|
364b69e931 | ||
|
|
dc96a1ae19 | ||
|
|
2d810b081e | ||
|
|
9f7e9f0547 | ||
|
|
a355f38ecc | ||
|
|
38c69080c7 | ||
|
|
70a708d726 | ||
|
|
e7d4782736 | ||
|
|
3326bdfd4e | ||
|
|
68bb885d22 | ||
|
|
ad66f7c7d8 | ||
|
|
de8e8e3b0d | ||
|
|
a1e71cfad1 | ||
|
|
0bfc7cc998 | ||
|
|
7183fd1665 | ||
|
|
254838f23c | ||
|
|
0b7dfa986d | ||
|
|
d514bb38ee | ||
|
|
0849c80e2a | ||
|
|
56e8f5e4fd | ||
|
|
e813abbb2c | ||
|
|
5e68a4ce67 | ||
|
|
ca08597670 | ||
|
|
f48e390032 | ||
|
|
369a6dd2c4 | ||
|
|
b3ce8fb9fd | ||
|
|
cf80d28689 | ||
|
|
6fb44c4b7c | ||
|
|
d2247c1e61 | ||
|
|
cb12ad7049 | ||
|
|
f6b7194f64 | ||
|
|
7c6eb4fb29 | ||
|
|
b962db9952 | ||
|
|
d0b7ab88ba | ||
|
|
405b529545 | ||
|
|
9d720187f1 | ||
|
|
d247bc5a9c | ||
|
|
9f4daca9d9 | ||
|
|
b5d0f2a908 | ||
|
|
e760bf5c40 | ||
|
|
36c83cdbba | ||
|
|
81778a7feb | ||
|
|
bc94662b31 | ||
|
|
9fa8faa44a | ||
|
|
9a7444e39f | ||
|
|
54fca4a218 | ||
|
|
cd4955367e | ||
|
|
8354203d95 | ||
|
|
e0b41243b4 | ||
|
|
619263d4a6 | ||
|
|
e3b0402bb7 | ||
|
|
967867d48c | ||
|
|
cbaac71bf5 | ||
|
|
3ab3516e46 | ||
|
|
9c5fca75f4 | ||
|
|
a5da4d0b3e | ||
|
|
32a60a7bac | ||
|
|
bb52934ba4 | ||
|
|
8aabd7c8c0 | ||
|
|
a09b29ca11 | ||
|
|
9bfee68773 | ||
|
|
ea77750759 | ||
|
|
c27ebeb1c2 | ||
|
|
0c7c98a965 | ||
|
|
dc2eb75b85 | ||
|
|
fa34efe3bd | ||
|
|
5cbaa9e07c | ||
|
|
c7427375ee | ||
|
|
22d1241a50 | ||
|
|
f04229b84d | ||
|
|
f067ad15d1 | ||
|
|
483004dd1d | ||
|
|
00a5d08103 | ||
|
|
d043997d30 | ||
|
|
f1c2301697 | ||
|
|
8d31a6632f | ||
|
|
b643eae08b | ||
|
|
baa6b4dc36 | ||
|
|
d4aeefc297 | ||
|
|
587e7ca654 | ||
|
|
c90459eba0 | ||
|
|
04278afb10 | ||
|
|
935ae153e1 | ||
|
|
e91662e784 | ||
|
|
63fafaef45 | ||
|
|
ec28cd9136 | ||
|
|
6eb5d64522 | ||
|
|
10a79e9898 | ||
|
|
ea3f39bd69 | ||
|
|
b33cd61070 | ||
|
|
34eda0f853 | ||
|
|
d31e226650 | ||
|
|
b79fd7d92c | ||
|
|
38c22e631a | ||
|
|
6bbdcd28ae | ||
|
|
ab130001a8 | ||
|
|
ca4b8f30e0 | ||
|
|
70b84058c1 | ||
|
|
2ca8f6e23d | ||
|
|
7985ff88b9 |
@@ -14,7 +14,7 @@ run_cpu.bat
|
||||
|
||||
IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints
|
||||
|
||||
You can download the stable diffusion 1.5 one from: https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt
|
||||
You can download the stable diffusion 1.5 one from: https://huggingface.co/Comfy-Org/stable-diffusion-v1-5-archive/blob/main/v1-5-pruned-emaonly-fp16.safetensors
|
||||
|
||||
|
||||
RECOMMENDED WAY TO UPDATE:
|
||||
|
||||
2
.github/workflows/pullrequest-ci-run.yml
vendored
2
.github/workflows/pullrequest-ci-run.yml
vendored
@@ -23,7 +23,7 @@ jobs:
|
||||
runner_label: [self-hosted, Linux]
|
||||
flags: ""
|
||||
- os: windows
|
||||
runner_label: [self-hosted, win]
|
||||
runner_label: [self-hosted, Windows]
|
||||
flags: ""
|
||||
runs-on: ${{ matrix.runner_label }}
|
||||
steps:
|
||||
|
||||
6
.github/workflows/stable-release.yml
vendored
6
.github/workflows/stable-release.yml
vendored
@@ -12,17 +12,17 @@ on:
|
||||
description: 'CUDA version'
|
||||
required: true
|
||||
type: string
|
||||
default: "121"
|
||||
default: "124"
|
||||
python_minor:
|
||||
description: 'Python minor version'
|
||||
required: true
|
||||
type: string
|
||||
default: "11"
|
||||
default: "12"
|
||||
python_patch:
|
||||
description: 'Python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "9"
|
||||
default: "7"
|
||||
|
||||
|
||||
jobs:
|
||||
|
||||
21
.github/workflows/stale-issues.yml
vendored
Normal file
21
.github/workflows/stale-issues.yml
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
name: 'Close stale issues'
|
||||
on:
|
||||
schedule:
|
||||
# Run daily at 11:30 UTC (4:30 am PDT / 3:30 am PST); GitHub Actions cron schedules are evaluated in UTC
|
||||
- cron: '30 11 * * *'
|
||||
permissions:
|
||||
issues: write
|
||||
|
||||
jobs:
|
||||
stale:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/stale@v9
|
||||
with:
|
||||
stale-issue-message: "This issue is being marked stale because it has not had any activity for 30 days. Reply below within 7 days if your issue still isn't solved, and it will be left open. Otherwise, the issue will be closed automatically."
|
||||
days-before-stale: 30
|
||||
days-before-close: 7
|
||||
stale-issue-label: 'Stale'
|
||||
only-labels: 'User Support'
|
||||
exempt-all-assignees: true
|
||||
exempt-all-milestones: true
|
||||
4
.github/workflows/test-ci.yml
vendored
4
.github/workflows/test-ci.yml
vendored
@@ -32,7 +32,7 @@ jobs:
|
||||
runner_label: [self-hosted, Linux]
|
||||
flags: ""
|
||||
- os: windows
|
||||
runner_label: [self-hosted, win]
|
||||
runner_label: [self-hosted, Windows]
|
||||
flags: ""
|
||||
runs-on: ${{ matrix.runner_label }}
|
||||
steps:
|
||||
@@ -55,7 +55,7 @@ jobs:
|
||||
torch_version: ["nightly"]
|
||||
include:
|
||||
- os: windows
|
||||
runner_label: [self-hosted, win]
|
||||
runner_label: [self-hosted, Windows]
|
||||
flags: ""
|
||||
runs-on: ${{ matrix.runner_label }}
|
||||
steps:
|
||||
|
||||
2
.github/workflows/test-launch.yml
vendored
2
.github/workflows/test-launch.yml
vendored
@@ -28,7 +28,7 @@ jobs:
|
||||
- name: Start ComfyUI server
|
||||
run: |
|
||||
python main.py --cpu 2>&1 | tee console_output.log &
|
||||
wait-for-it --service 127.0.0.1:8188 -t 600
|
||||
wait-for-it --service 127.0.0.1:8188 -t 30
|
||||
working-directory: ComfyUI
|
||||
- name: Check for unhandled exceptions in server log
|
||||
run: |
|
||||
|
||||
30
.github/workflows/test-unit.yml
vendored
Normal file
30
.github/workflows/test-unit.yml
vendored
Normal file
@@ -0,0 +1,30 @@
|
||||
name: Unit Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main, master ]
|
||||
pull_request:
|
||||
branches: [ main, master ]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest, windows-latest, macos-latest]
|
||||
runs-on: ${{ matrix.os }}
|
||||
continue-on-error: true
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.10'
|
||||
- name: Install requirements
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
||||
pip install -r requirements.txt
|
||||
- name: Run Unit Tests
|
||||
run: |
|
||||
pip install -r tests-unit/requirements.txt
|
||||
python -m pytest tests-unit
|
||||
@@ -12,7 +12,7 @@ on:
|
||||
description: 'extra dependencies'
|
||||
required: false
|
||||
type: string
|
||||
default: "\"numpy<2\""
|
||||
default: ""
|
||||
cu:
|
||||
description: 'cuda version'
|
||||
required: true
|
||||
@@ -23,13 +23,13 @@ on:
|
||||
description: 'python minor version'
|
||||
required: true
|
||||
type: string
|
||||
default: "11"
|
||||
default: "12"
|
||||
|
||||
python_patch:
|
||||
description: 'python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "9"
|
||||
default: "7"
|
||||
# push:
|
||||
# branches:
|
||||
# - master
|
||||
|
||||
@@ -13,13 +13,13 @@ on:
|
||||
description: 'python minor version'
|
||||
required: true
|
||||
type: string
|
||||
default: "11"
|
||||
default: "12"
|
||||
|
||||
python_patch:
|
||||
description: 'python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "9"
|
||||
default: "7"
|
||||
# push:
|
||||
# branches:
|
||||
# - master
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -12,6 +12,7 @@ extra_model_paths.yaml
|
||||
.vscode/
|
||||
.idea/
|
||||
venv/
|
||||
.venv/
|
||||
/web/extensions/*
|
||||
!/web/extensions/logging.js.example
|
||||
!/web/extensions/core/
|
||||
|
||||
80
README.md
80
README.md
@@ -1,7 +1,7 @@
|
||||
<div align="center">
|
||||
|
||||
# ComfyUI
|
||||
**The most powerful and modular stable diffusion GUI and backend.**
|
||||
**The most powerful and modular diffusion model GUI and backend.**
|
||||
|
||||
|
||||
[![Website][website-shield]][website-url]
|
||||
@@ -28,7 +28,7 @@
|
||||
[github-downloads-latest-shield]: https://img.shields.io/github/downloads/comfyanonymous/ComfyUI/latest/total?style=flat&label=downloads%40latest
|
||||
[github-downloads-link]: https://github.com/comfyanonymous/ComfyUI/releases
|
||||
|
||||

|
||||

|
||||
</div>
|
||||
|
||||
This ui will let you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. For some workflow examples and see what ComfyUI can do you can check out:
|
||||
@@ -39,7 +39,9 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
|
||||
## Features
|
||||
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
|
||||
- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/), [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/), [SD3](https://comfyanonymous.github.io/ComfyUI_examples/sd3/) and [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
|
||||
- [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/)
|
||||
- [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
|
||||
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
|
||||
- Asynchronous Queue system
|
||||
- Many optimizations: Only re-executes the parts of the workflow that changes between executions.
|
||||
- Smart memory management: can automatically run models on GPUs with as low as 1GB vram.
|
||||
@@ -73,35 +75,37 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
|
||||
|
||||
| Keybind | Explanation |
|
||||
|------------------------------------|--------------------------------------------------------------------------------------------------------------------|
|
||||
| Ctrl + Enter | Queue up current graph for generation |
|
||||
| Ctrl + Shift + Enter | Queue up current graph as first for generation |
|
||||
| Ctrl + Alt + Enter | Cancel current generation |
|
||||
| Ctrl + Z/Ctrl + Y | Undo/Redo |
|
||||
| Ctrl + S | Save workflow |
|
||||
| Ctrl + O | Load workflow |
|
||||
| Ctrl + A | Select all nodes |
|
||||
| Alt + C | Collapse/uncollapse selected nodes |
|
||||
| Ctrl + M | Mute/unmute selected nodes |
|
||||
| Ctrl + B | Bypass selected nodes (acts like the node was removed from the graph and the wires reconnected through) |
|
||||
| Delete/Backspace | Delete selected nodes |
|
||||
| Ctrl + Backspace | Delete the current graph |
|
||||
| Space | Move the canvas around when held and moving the cursor |
|
||||
| Ctrl/Shift + Click | Add clicked node to selection |
|
||||
| Ctrl + C/Ctrl + V | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) |
|
||||
| Ctrl + C/Ctrl + Shift + V | Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) |
|
||||
| Shift + Drag | Move multiple selected nodes at the same time |
|
||||
| Ctrl + D | Load default graph |
|
||||
| Alt + `+` | Canvas Zoom in |
|
||||
| Alt + `-` | Canvas Zoom out |
|
||||
| Ctrl + Shift + LMB + Vertical drag | Canvas Zoom in/out |
|
||||
| Q | Toggle visibility of the queue |
|
||||
| H | Toggle visibility of history |
|
||||
| R | Refresh graph |
|
||||
| `Ctrl` + `Enter` | Queue up current graph for generation |
|
||||
| `Ctrl` + `Shift` + `Enter` | Queue up current graph as first for generation |
|
||||
| `Ctrl` + `Alt` + `Enter` | Cancel current generation |
|
||||
| `Ctrl` + `Z`/`Ctrl` + `Y` | Undo/Redo |
|
||||
| `Ctrl` + `S` | Save workflow |
|
||||
| `Ctrl` + `O` | Load workflow |
|
||||
| `Ctrl` + `A` | Select all nodes |
|
||||
| `Alt` + `C` | Collapse/uncollapse selected nodes |
|
||||
| `Ctrl` + `M` | Mute/unmute selected nodes |
|
||||
| `Ctrl` + `B` | Bypass selected nodes (acts like the node was removed from the graph and the wires reconnected through) |
|
||||
| `Delete`/`Backspace` | Delete selected nodes |
|
||||
| `Ctrl` + `Backspace` | Delete the current graph |
|
||||
| `Space` | Move the canvas around when held and moving the cursor |
|
||||
| `Ctrl`/`Shift` + `Click` | Add clicked node to selection |
|
||||
| `Ctrl` + `C`/`Ctrl` + `V` | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) |
|
||||
| `Ctrl` + `C`/`Ctrl` + `Shift` + `V` | Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) |
|
||||
| `Shift` + `Drag` | Move multiple selected nodes at the same time |
|
||||
| `Ctrl` + `D` | Load default graph |
|
||||
| `Alt` + `+` | Canvas Zoom in |
|
||||
| `Alt` + `-` | Canvas Zoom out |
|
||||
| `Ctrl` + `Shift` + LMB + Vertical drag | Canvas Zoom in/out |
|
||||
| `P` | Pin/Unpin selected nodes |
|
||||
| `Ctrl` + `G` | Group selected nodes |
|
||||
| `Q` | Toggle visibility of the queue |
|
||||
| `H` | Toggle visibility of history |
|
||||
| `R` | Refresh graph |
|
||||
| Double-Click LMB | Open node quick search palette |
|
||||
| Shift + Drag | Move multiple wires at once |
|
||||
| Ctrl + Alt + LMB | Disconnect all wires from clicked slot |
|
||||
| `Shift` + Drag | Move multiple wires at once |
|
||||
| `Ctrl` + `Alt` + LMB | Disconnect all wires from clicked slot |
|
||||
|
||||
Ctrl can also be replaced with Cmd instead for macOS users
|
||||
`Ctrl` can also be replaced with `Cmd` instead for macOS users
|
||||
|
||||
# Installing
|
||||
|
||||
@@ -125,6 +129,8 @@ To run it on services like paperspace, kaggle or colab you can use my [Jupyter N
|
||||
|
||||
## Manual Install (Windows, Linux)
|
||||
|
||||
Note that some dependencies do not yet support python 3.13 so using 3.12 is recommended.
|
||||
|
||||
Git clone this repo.
|
||||
|
||||
Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints
|
||||
@@ -135,17 +141,17 @@ Put your VAE in: models/vae
|
||||
### AMD GPUs (Linux only)
|
||||
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
|
||||
|
||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.0```
|
||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.2```
|
||||
|
||||
This is the command to install the nightly with ROCm 6.0 which might have some performance improvements:
|
||||
This is the command to install the nightly with ROCm 6.2 which might have some performance improvements:
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.1```
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.2```
|
||||
|
||||
### NVIDIA
|
||||
|
||||
Nvidia users should install stable pytorch using this command:
|
||||
|
||||
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu121```
|
||||
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu124```
|
||||
|
||||
This is the command to install pytorch nightly instead which might have performance improvements:
|
||||
|
||||
@@ -207,6 +213,14 @@ For 6700, 6600 and maybe other RDNA2 or older: ```HSA_OVERRIDE_GFX_VERSION=10.3.
|
||||
|
||||
For AMD 7600 and maybe other RDNA3 cards: ```HSA_OVERRIDE_GFX_VERSION=11.0.0 python main.py```
|
||||
|
||||
### AMD ROCm Tips
|
||||
|
||||
You can enable experimental memory efficient attention on pytorch 2.5 in ComfyUI on RDNA3 and potentially other AMD GPUs using this command:
|
||||
|
||||
```TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 python main.py --use-pytorch-cross-attention```
|
||||
|
||||
You can also try setting this env variable `PYTORCH_TUNABLEOP_ENABLED=1` which might speed things up at the cost of a very slow initial run.
|
||||
|
||||
# Notes
|
||||
|
||||
Only parts of the graph that have an output with all the correct inputs will be executed.
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
from aiohttp import web
|
||||
from typing import Optional
|
||||
from folder_paths import models_dir, user_directory, output_directory
|
||||
from folder_paths import models_dir, user_directory, output_directory, folder_names_and_paths
|
||||
from api_server.services.file_service import FileService
|
||||
from api_server.services.terminal_service import TerminalService
|
||||
import app.logger
|
||||
|
||||
class InternalRoutes:
|
||||
'''
|
||||
The top level web router for internal routes: /internal/*
|
||||
The endpoints here should NOT be depended upon. It is for ComfyUI frontend use only.
|
||||
Check README.md for more information.
|
||||
|
||||
'''
|
||||
def __init__(self):
|
||||
|
||||
def __init__(self, prompt_server):
|
||||
self.routes: web.RouteTableDef = web.RouteTableDef()
|
||||
self._app: Optional[web.Application] = None
|
||||
self.file_service = FileService({
|
||||
@@ -18,6 +20,8 @@ class InternalRoutes:
|
||||
"user": user_directory,
|
||||
"output": output_directory
|
||||
})
|
||||
self.prompt_server = prompt_server
|
||||
self.terminal_service = TerminalService(prompt_server)
|
||||
|
||||
def setup_routes(self):
|
||||
@self.routes.get('/files')
|
||||
@@ -31,6 +35,37 @@ class InternalRoutes:
|
||||
except Exception as e:
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
@self.routes.get('/logs')
|
||||
async def get_logs(request):
|
||||
return web.json_response("".join([(l["t"] + " - " + l["m"]) for l in app.logger.get_logs()]))
|
||||
|
||||
@self.routes.get('/logs/raw')
|
||||
async def get_logs(request):
|
||||
self.terminal_service.update_size()
|
||||
return web.json_response({
|
||||
"entries": list(app.logger.get_logs()),
|
||||
"size": {"cols": self.terminal_service.cols, "rows": self.terminal_service.rows}
|
||||
})
|
||||
|
||||
@self.routes.patch('/logs/subscribe')
|
||||
async def subscribe_logs(request):
|
||||
json_data = await request.json()
|
||||
client_id = json_data["clientId"]
|
||||
enabled = json_data["enabled"]
|
||||
if enabled:
|
||||
self.terminal_service.subscribe(client_id)
|
||||
else:
|
||||
self.terminal_service.unsubscribe(client_id)
|
||||
|
||||
return web.Response(status=200)
|
||||
|
||||
|
||||
@self.routes.get('/folder_paths')
|
||||
async def get_folder_paths(request):
|
||||
response = {}
|
||||
for key in folder_names_and_paths:
|
||||
response[key] = folder_names_and_paths[key][0]
|
||||
return web.json_response(response)
|
||||
|
||||
def get_app(self):
|
||||
if self._app is None:
|
||||
|
||||
60
api_server/services/terminal_service.py
Normal file
60
api_server/services/terminal_service.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from app.logger import on_flush
|
||||
import os
|
||||
import shutil
|
||||
|
||||
|
||||
class TerminalService:
    """Forward captured log entries and terminal dimensions to subscribed clients.

    The service registers `send_messages` with `on_flush` on construction.
    The `server` object must expose a `sockets` container that supports
    membership tests on client ids, and a `send_sync(event, data, client_id)`
    method.
    """

    def __init__(self, server):
        self.server = server
        # Last known terminal dimensions; None until update_size() first runs.
        self.cols = None
        self.rows = None
        # Client ids that asked to receive "logs" events.
        self.subscriptions = set()
        on_flush(self.send_messages)

    def get_terminal_size(self):
        """Return the terminal size as a ``(columns, lines)`` tuple."""
        try:
            size = os.get_terminal_size()
        except OSError:
            # shutil.get_terminal_size() handles its own lookup failures and
            # returns the fallback, so no second except clause is needed.
            size = shutil.get_terminal_size(fallback=(80, 24))
        return (size.columns, size.lines)

    def update_size(self):
        """Refresh the cached dimensions.

        Returns:
            ``{"cols": ..., "rows": ...}`` when either dimension changed since
            the previous call, otherwise ``None``.
        """
        columns, lines = self.get_terminal_size()
        changed = False

        if columns != self.cols:
            self.cols = columns
            changed = True

        if lines != self.rows:
            self.rows = lines
            changed = True

        if changed:
            return {"cols": self.cols, "rows": self.rows}

        return None

    def subscribe(self, client_id):
        """Start sending log events to `client_id`."""
        self.subscriptions.add(client_id)

    def unsubscribe(self, client_id):
        """Stop sending log events to `client_id`; unknown ids are ignored."""
        self.subscriptions.discard(client_id)

    def send_messages(self, entries):
        """Send `entries` to every subscribed client that is still connected.

        Does nothing when there are no entries or no subscribers. The payload's
        ``size`` is ``None`` unless the terminal dimensions changed since the
        last `update_size` call.
        """
        if not entries or not self.subscriptions:
            return

        new_size = self.update_size()

        # Iterate over a copy: unsubscribe() mutates the set during the loop.
        for client_id in self.subscriptions.copy():
            if client_id not in self.server.sockets:
                # Automatically unsub if the socket has disconnected
                self.unsubscribe(client_id)
                continue

            self.server.send_sync("logs", {"entries": entries, "size": new_size}, client_id)
|
||||
@@ -8,7 +8,7 @@ import zipfile
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import TypedDict
|
||||
from typing import TypedDict, Optional
|
||||
|
||||
import requests
|
||||
from typing_extensions import NotRequired
|
||||
@@ -132,12 +132,13 @@ class FrontendManager:
|
||||
return match_result.group(1), match_result.group(2), match_result.group(3)
|
||||
|
||||
@classmethod
|
||||
def init_frontend_unsafe(cls, version_string: str) -> str:
|
||||
def init_frontend_unsafe(cls, version_string: str, provider: Optional[FrontEndProvider] = None) -> str:
|
||||
"""
|
||||
Initializes the frontend for the specified version.
|
||||
|
||||
Args:
|
||||
version_string (str): The version string.
|
||||
provider (FrontEndProvider, optional): The provider to use. Defaults to None.
|
||||
|
||||
Returns:
|
||||
str: The path to the initialized frontend.
|
||||
@@ -150,7 +151,16 @@ class FrontendManager:
|
||||
return cls.DEFAULT_FRONTEND_PATH
|
||||
|
||||
repo_owner, repo_name, version = cls.parse_version_string(version_string)
|
||||
provider = FrontEndProvider(repo_owner, repo_name)
|
||||
|
||||
if version.startswith("v"):
|
||||
expected_path = str(Path(cls.CUSTOM_FRONTENDS_ROOT) / f"{repo_owner}_{repo_name}" / version.lstrip("v"))
|
||||
if os.path.exists(expected_path):
|
||||
logging.info(f"Using existing copy of specific frontend version tag: {repo_owner}/{repo_name}@{version}")
|
||||
return expected_path
|
||||
|
||||
logging.info(f"Initializing frontend: {repo_owner}/{repo_name}@{version}, requesting version details from GitHub...")
|
||||
|
||||
provider = provider or FrontEndProvider(repo_owner, repo_name)
|
||||
release = provider.get_release(version)
|
||||
|
||||
semantic_version = release["tag_name"].lstrip("v")
|
||||
@@ -158,15 +168,21 @@ class FrontendManager:
|
||||
Path(cls.CUSTOM_FRONTENDS_ROOT) / provider.folder_name / semantic_version
|
||||
)
|
||||
if not os.path.exists(web_root):
|
||||
os.makedirs(web_root, exist_ok=True)
|
||||
logging.info(
|
||||
"Downloading frontend(%s) version(%s) to (%s)",
|
||||
provider.folder_name,
|
||||
semantic_version,
|
||||
web_root,
|
||||
)
|
||||
logging.debug(release)
|
||||
download_release_asset_zip(release, destination_path=web_root)
|
||||
try:
|
||||
os.makedirs(web_root, exist_ok=True)
|
||||
logging.info(
|
||||
"Downloading frontend(%s) version(%s) to (%s)",
|
||||
provider.folder_name,
|
||||
semantic_version,
|
||||
web_root,
|
||||
)
|
||||
logging.debug(release)
|
||||
download_release_asset_zip(release, destination_path=web_root)
|
||||
finally:
|
||||
# Clean up the directory if it is empty, i.e. the download failed
|
||||
if not os.listdir(web_root):
|
||||
os.rmdir(web_root)
|
||||
|
||||
return web_root
|
||||
|
||||
@classmethod
|
||||
|
||||
73
app/logger.py
Normal file
73
app/logger.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from collections import deque
|
||||
from datetime import datetime
|
||||
import io
|
||||
import logging
|
||||
import sys
|
||||
import threading
|
||||
|
||||
logs = None
|
||||
stdout_interceptor = None
|
||||
stderr_interceptor = None
|
||||
|
||||
|
||||
class LogInterceptor(io.TextIOWrapper):
    """Text stream that records everything written to it before passing it on.

    Each write is stored as ``{"t": <ISO timestamp>, "m": <data>}`` in the
    module-level `logs` buffer and in a per-stream list of entries written
    since the last flush. Callbacks registered with `on_flush` receive that
    list on every `flush`.
    """

    def __init__(self, stream, *args, **kwargs):
        """Wrap `stream`, reusing its binary buffer, encoding and line buffering."""
        buffer = stream.buffer
        encoding = stream.encoding
        super().__init__(buffer, *args, **kwargs, encoding=encoding, line_buffering=stream.line_buffering)
        self._lock = threading.Lock()
        self._flush_callbacks = []
        self._logs_since_flush = []

    def write(self, data):
        """Record `data` as a log entry, then write it to the wrapped stream.

        Returns:
            Whatever ``TextIOWrapper.write`` returns (the number of characters
            written).
        """
        entry = {"t": datetime.now().isoformat(), "m": data}
        with self._lock:
            self._logs_since_flush.append(entry)

            # Simple handling for cr to overwrite the last output if it isn't a
            # full line, else logs just get full of progress messages.
            # `logs` is checked first: with an empty buffer there is nothing
            # to overwrite, and logs[-1] would raise IndexError.
            if isinstance(data, str) and data.startswith("\r") and logs and not logs[-1]["m"].endswith("\n"):
                logs.pop()
            logs.append(entry)
        return super().write(data)

    def flush(self):
        """Flush the wrapped stream and hand the pending entries to every callback."""
        super().flush()
        # Swap the pending list under the same lock write() appends with, so an
        # entry written concurrently is kept for the next flush, not dropped.
        with self._lock:
            pending, self._logs_since_flush = self._logs_since_flush, []
        for cb in self._flush_callbacks:
            cb(pending)

    def on_flush(self, callback):
        """Register `callback` to be called with the list of entries on each flush."""
        self._flush_callbacks.append(callback)
|
||||
|
||||
|
||||
def get_logs():
    """Return the module-level `logs` buffer (``None`` until `setup_logger` assigns it)."""
    return logs
|
||||
|
||||
|
||||
def on_flush(callback):
    """Register `callback` on each output interceptor that is currently installed.

    Interceptors that are still ``None`` are skipped.
    """
    for interceptor in (stdout_interceptor, stderr_interceptor):
        if interceptor is None:
            continue
        interceptor.on_flush(callback)
|
||||
|
||||
def setup_logger(log_level: str = 'INFO', capacity: int = 300):
    """Install the log-capturing streams and configure the root logger.

    Replaces ``sys.stdout`` and ``sys.stderr`` with `LogInterceptor` wrappers,
    creates the module-level `logs` buffer, and attaches a message-only
    stream handler to the root logger. Calls after the first one do nothing.

    Args:
        log_level: Level applied to the root logger.
        capacity: Maximum number of entries kept in the `logs` buffer.
    """
    global logs
    global stdout_interceptor
    global stderr_interceptor

    # Compare against None, not truthiness: an empty deque is falsy, so a
    # second call before any output would wrap the streams again and add a
    # duplicate handler.
    if logs is not None:
        return

    # Override output streams and log to buffer
    logs = deque(maxlen=capacity)

    stdout_interceptor = sys.stdout = LogInterceptor(sys.stdout)
    stderr_interceptor = sys.stderr = LogInterceptor(sys.stderr)

    # Setup default global logger
    logger = logging.getLogger()
    logger.setLevel(log_level)

    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(logging.Formatter("%(message)s"))
    logger.addHandler(stream_handler)
|
||||
@@ -1,38 +1,58 @@
|
||||
from __future__ import annotations
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
import glob
|
||||
import shutil
|
||||
import logging
|
||||
from aiohttp import web
|
||||
from urllib import parse
|
||||
from comfy.cli_args import args
|
||||
from folder_paths import user_directory
|
||||
import folder_paths
|
||||
from .app_settings import AppSettings
|
||||
from typing import TypedDict
|
||||
|
||||
default_user = "default"
|
||||
users_file = os.path.join(user_directory, "users.json")
|
||||
|
||||
|
||||
class FileInfo(TypedDict):
    """Description of one file, as produced by `get_file_info`."""

    # Path relative to the listing root, using '/' separators.
    path: str
    # File size as reported by os.path.getsize().
    size: int
    # Modification time from os.path.getmtime(), which returns a float
    # (seconds since the epoch), not an int.
    modified: float
|
||||
|
||||
|
||||
def get_file_info(path: str, relative_to: str) -> FileInfo:
    """Build the `FileInfo` mapping for the file at `path`.

    Args:
        path: Location of the file on disk.
        relative_to: Directory the reported path is made relative to.

    Returns:
        A mapping with the '/'-separated relative path, the file size and the
        modification time.
    """
    relative_path = os.path.relpath(path, relative_to)
    portable_path = relative_path.replace(os.sep, '/')
    size_in_bytes = os.path.getsize(path)
    modified_at = os.path.getmtime(path)
    return {"path": portable_path, "size": size_in_bytes, "modified": modified_at}
|
||||
|
||||
|
||||
class UserManager():
|
||||
def __init__(self):
|
||||
global user_directory
|
||||
user_directory = folder_paths.get_user_directory()
|
||||
|
||||
self.settings = AppSettings(self)
|
||||
if not os.path.exists(user_directory):
|
||||
os.mkdir(user_directory)
|
||||
os.makedirs(user_directory, exist_ok=True)
|
||||
if not args.multi_user:
|
||||
print("****** User settings have been changed to be stored on the server instead of browser storage. ******")
|
||||
print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")
|
||||
|
||||
if args.multi_user:
|
||||
if os.path.isfile(users_file):
|
||||
with open(users_file) as f:
|
||||
if os.path.isfile(self.get_users_file()):
|
||||
with open(self.get_users_file()) as f:
|
||||
self.users = json.load(f)
|
||||
else:
|
||||
self.users = {}
|
||||
else:
|
||||
self.users = {"default": "default"}
|
||||
|
||||
def get_users_file(self):
|
||||
return os.path.join(folder_paths.get_user_directory(), "users.json")
|
||||
|
||||
def get_request_user_id(self, request):
|
||||
user = "default"
|
||||
if args.multi_user and "comfy-user" in request.headers:
|
||||
@@ -44,7 +64,7 @@ class UserManager():
|
||||
return user
|
||||
|
||||
def get_request_user_filepath(self, request, file, type="userdata", create_dir=True):
|
||||
global user_directory
|
||||
user_directory = folder_paths.get_user_directory()
|
||||
|
||||
if type == "userdata":
|
||||
root_dir = user_directory
|
||||
@@ -59,6 +79,10 @@ class UserManager():
|
||||
return None
|
||||
|
||||
if file is not None:
|
||||
# Check if filename is url encoded
|
||||
if "%" in file:
|
||||
file = parse.unquote(file)
|
||||
|
||||
# prevent leaving /{type}/{user}
|
||||
path = os.path.abspath(os.path.join(user_root, file))
|
||||
if os.path.commonpath((user_root, path)) != user_root:
|
||||
@@ -80,8 +104,7 @@ class UserManager():
|
||||
|
||||
self.users[user_id] = name
|
||||
|
||||
global users_file
|
||||
with open(users_file, "w") as f:
|
||||
with open(self.get_users_file(), "w") as f:
|
||||
json.dump(self.users, f)
|
||||
|
||||
return user_id
|
||||
@@ -112,25 +135,65 @@ class UserManager():
|
||||
|
||||
@routes.get("/userdata")
async def listuserdata(request):
    """
    Enumerate the files stored below one directory of the requesting user.

    Query Parameters:
    - dir (required): Directory, relative to the user's data root, to list.
    - recurse (optional): "true" also descends into subdirectories.
    - full_info (optional): "true" returns one file-info record per file
      (as produced by get_file_info) instead of a plain path.
    - split (optional): "true" returns, per file, a list holding the
      relative path followed by its '/'-separated components. Ignored
      when full_info is "true".

    Returns:
    - 400: 'dir' is missing or empty.
    - 403: The directory does not resolve to an allowed path.
    - 404: The resolved path does not exist.
    - 200: JSON list shaped according to the flags above; relative paths
      always use '/' as separator.
    """
    query = request.rel_url.query

    def flag(name):
        # A flag is on only for the literal value "true" (case-insensitive).
        return query.get(name, '').lower() == "true"

    directory = query.get('dir', '')
    if not directory:
        return web.Response(status=400, text="Directory not provided")

    path = self.get_request_user_filepath(request, directory)
    if not path:
        return web.Response(status=403, text="Invalid directory")

    if not os.path.exists(path):
        return web.Response(status=404, text="Directory not found")

    recurse = flag('recurse')
    full_info = flag('full_info')
    split_path = flag('split')

    # '**' is only added when recursing; otherwise list direct children.
    pattern_tail = ('**', '*') if recurse else ('*',)
    pattern = os.path.join(glob.escape(path), *pattern_tail)

    results = []
    for candidate in glob.glob(pattern, recursive=recurse):
        if not os.path.isfile(candidate):
            continue  # directories are never reported

        if full_info:
            results.append(get_file_info(candidate, path))
            continue

        relative = os.path.relpath(candidate, path).replace(os.sep, '/')
        if split_path:
            results.append([relative] + relative.split('/'))
        else:
            results.append(relative)

    return web.json_response(results)
|
||||
|
||||
@@ -138,14 +201,14 @@ class UserManager():
|
||||
file = request.match_info.get(param, None)
|
||||
if not file:
|
||||
return web.Response(status=400)
|
||||
|
||||
|
||||
path = self.get_request_user_filepath(request, file)
|
||||
if not path:
|
||||
return web.Response(status=403)
|
||||
|
||||
|
||||
if check_exists and not os.path.exists(path):
|
||||
return web.Response(status=404)
|
||||
|
||||
|
||||
return path
|
||||
|
||||
@routes.get("/userdata/{file}")
|
||||
@@ -153,25 +216,56 @@ class UserManager():
|
||||
path = get_user_data_path(request, check_exists=True)
|
||||
if not isinstance(path, str):
|
||||
return path
|
||||
|
||||
|
||||
return web.FileResponse(path)
|
||||
|
||||
@routes.post("/userdata/{file}")
async def post_userdata(request):
    """
    Store the raw request body as a file in the user's data directory.

    Path Parameters:
    - file: Target file path (URL encoded if necessary).

    Query Parameters:
    - overwrite (optional): "false" refuses to replace an existing file.
      Any other value, or omitting it, allows overwriting.
    - full_info (optional): "true" answers with the file-info record from
      get_file_info; otherwise the answer is the path relative to the
      user's data root.

    Returns:
    - 400 / 403: Passed through from get_user_data_path when the 'file'
      parameter is missing or the path is not allowed.
    - 409: overwrite=false and the target already exists.
    - 200: JSON response as selected by full_info.
    """
    target = get_user_data_path(request)
    if not isinstance(target, str):
        # Not a path: get_user_data_path produced an error response.
        return target

    params = request.query
    allow_overwrite = params.get("overwrite", 'true') != "false"
    want_full_info = params.get('full_info', 'false').lower() == "true"

    refuse_existing = (not allow_overwrite) and os.path.exists(target)
    if refuse_existing:
        return web.Response(status=409, text="File already exists")

    payload = await request.read()

    with open(target, "wb") as out_file:
        out_file.write(payload)

    user_root = self.get_request_user_filepath(request, None)
    result = (
        get_file_info(target, user_root)
        if want_full_info
        else os.path.relpath(target, user_root)
    )

    return web.json_response(result)
|
||||
|
||||
@routes.delete("/userdata/{file}")
|
||||
@@ -181,25 +275,56 @@ class UserManager():
|
||||
return path
|
||||
|
||||
os.remove(path)
|
||||
|
||||
|
||||
return web.Response(status=204)
|
||||
|
||||
@routes.post("/userdata/{file}/move/{dest}")
async def move_userdata(request):
    """
    Move or rename a user data file.

    This endpoint handles moving or renaming files within a user's data directory, with options for
    controlling overwrite behavior and response format.

    Path Parameters:
    - file: The source file path (URL encoded if necessary)
    - dest: The destination file path (URL encoded if necessary)

    Query Parameters:
    - overwrite (optional): If "false", prevents overwriting existing files. Defaults to "true".
    - full_info (optional): If "true", returns detailed file information (path, size, modified time).
      If "false", returns only the relative file path.

    Returns:
    - 400: If either 'file' or 'dest' parameter is missing
    - 403: If either requested path is not allowed
    - 404: If the source file does not exist
    - 409: If overwrite=false and the destination file already exists
    - 200: JSON response with either:
      - Full file information (if full_info=true)
      - Relative file path (if full_info=false)
    """
    # get_user_data_path returns a str path on success and a ready-made
    # error response otherwise, so a non-str result is returned as is.
    source = get_user_data_path(request, check_exists=True)
    if not isinstance(source, str):
        return source

    dest = get_user_data_path(request, check_exists=False, param="dest")
    # Validate `dest` itself.  The previous guard re-tested `source`, which
    # is always a str at this point, so an error response for the
    # destination was never returned and reached os.path.exists /
    # shutil.move instead.
    if not isinstance(dest, str):
        return dest

    overwrite = request.query.get("overwrite", 'true') != "false"
    full_info = request.query.get('full_info', 'false').lower() == "true"

    if not overwrite and os.path.exists(dest):
        return web.Response(status=409, text="File already exists")

    logging.info(f"moving '{source}' -> '{dest}'")
    shutil.move(source, dest)

    # The reply describes the file at its new location, relative to the
    # requesting user's data root.
    user_path = self.get_request_user_filepath(request, None)
    if full_info:
        resp = get_file_info(dest, user_path)
    else:
        resp = os.path.relpath(dest, user_path)

    return web.json_response(resp)
|
||||
|
||||
122
comfy/cldm/dit_embedder.py
Normal file
122
comfy/cldm/dit_embedder.py
Normal file
@@ -0,0 +1,122 @@
|
||||
import math
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
from torch import Tensor
|
||||
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import DismantledBlock, PatchEmbed, VectorEmbedder, TimestepEmbedder, get_2d_sincos_pos_embed_torch
|
||||
|
||||
|
||||
class ControlNetEmbedder(nn.Module):
    """ControlNet built from a patch embedder and a stack of ``DismantledBlock`` layers.

    The input and the control hint are each patch-embedded and summed, the
    result is run through ``num_layers`` transformer blocks conditioned on
    the timestep (and optionally ``y``) embedding, and every block output is
    projected by its own linear layer to form the control signal.
    """

    def __init__(
        self,
        img_size: int,
        patch_size: int,
        in_chans: int,
        attention_head_dim: int,
        num_attention_heads: int,
        adm_in_channels: int,
        num_layers: int,
        main_model_double: int,
        double_y_emb: bool,
        device: torch.device,
        dtype: torch.dtype,
        pos_embed_max_size: Optional[int] = None,
        operations = None,
    ):
        """Build the embedders, transformer blocks and output projections.

        Args:
            img_size: Forwarded to both ``PatchEmbed`` modules.
            patch_size: Patch edge length; also used in ``forward`` to size
                the positional-embedding grid.
            in_chans: Input channels of both patch embedders.
            attention_head_dim: Per-head width; the hidden size is
                ``num_attention_heads * attention_head_dim``.
            num_attention_heads: Number of attention heads per block.
            adm_in_channels: Width of the raw ``y`` conditioning vector.
            num_layers: Number of transformer blocks (and of projections).
            main_model_double: Block count that ``forward`` spreads the
                outputs over -- presumably the number of double blocks in
                the controlled model; confirm against the caller.
            double_y_emb: When true ``y`` goes through two stacked vector
                embedders; it also switches off the sin-cos positional
                embedding and block chaining in ``forward``.
            device: Device handed to every submodule.
            dtype: Dtype handed to every submodule.
            pos_embed_max_size: Only tested against ``None`` here; when
                unset the input patch embedder uses strict image sizes.
            operations: Namespace providing ``Linear``; it is dereferenced
                unconditionally, so the ``None`` default cannot actually be
                used.
        """
        super().__init__()
        self.main_model_double = main_model_double
        self.dtype = dtype
        self.hidden_size = num_attention_heads * attention_head_dim
        self.patch_size = patch_size
        self.x_embedder = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=self.hidden_size,
            strict_img_size=pos_embed_max_size is None,
            device=device,
            dtype=dtype,
            operations=operations,
        )

        self.t_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype, device=device, operations=operations)

        self.double_y_emb = double_y_emb
        if self.double_y_emb:
            # Two-stage embedding: raw vector -> hidden size -> hidden size.
            self.orig_y_embedder = VectorEmbedder(
                adm_in_channels, self.hidden_size, dtype, device, operations=operations
            )
            self.y_embedder = VectorEmbedder(
                self.hidden_size, self.hidden_size, dtype, device, operations=operations
            )
        else:
            self.y_embedder = VectorEmbedder(
                adm_in_channels, self.hidden_size, dtype, device, operations=operations
            )

        self.transformer_blocks = nn.ModuleList(
            DismantledBlock(
                hidden_size=self.hidden_size, num_heads=num_attention_heads, qkv_bias=True,
                dtype=dtype, device=device, operations=operations
            )
            for _ in range(num_layers)
        )

        # self.use_y_embedder = pooled_projection_dim != self.time_text_embed.text_embedder.linear_1.in_features
        # TODO double check this logic when 8b
        self.use_y_embedder = True

        # One output projection per transformer block.
        self.controlnet_blocks = nn.ModuleList(
            operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
            for _ in range(num_layers)
        )

        # Separate patch embedder for the control hint.
        self.pos_embed_input = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=self.hidden_size,
            strict_img_size=False,
            device=device,
            dtype=dtype,
            operations=operations,
        )

    def forward(
        self,
        x: torch.Tensor,
        timesteps: torch.Tensor,
        y: Optional[torch.Tensor] = None,
        context: Optional[torch.Tensor] = None,
        hint = None,
    ) -> dict:
        """Compute the control outputs for one step.

        Args:
            x: Input tensor; its last two axes size the positional grid.
            timesteps: Timesteps fed to the timestep embedder.
            y: Optional conditioning vector added to the timestep embedding.
            context: Accepted for interface compatibility; not used here.
            hint: Control input passed to ``pos_embed_input``.
                NOTE(review): defaults to ``None`` but is used
                unconditionally -- confirm callers always supply it.

        Returns:
            ``{"output": outputs}`` where ``outputs`` is a tuple holding each
            block's projected output repeated
            ``ceil(main_model_double / num_blocks)`` times in a row.
        """
        x_shape = list(x.shape)
        x = self.x_embedder(x)
        if not self.double_y_emb:
            # Patch-grid size derived from the last two input axes.
            h = (x_shape[-2] + 1) // self.patch_size
            w = (x_shape[-1] + 1) // self.patch_size
            x += get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=x.device)
        c = self.t_embedder(timesteps, dtype=x.dtype)
        if y is not None and self.y_embedder is not None:
            if self.double_y_emb:
                y = self.orig_y_embedder(y)
            y = self.y_embedder(y)
            c = c + y

        x = x + self.pos_embed_input(hint)

        block_out = ()

        repeat = math.ceil(self.main_model_double / len(self.transformer_blocks))
        for block, projection in zip(self.transformer_blocks, self.controlnet_blocks):
            out = block(x, c)
            if not self.double_y_emb:
                # Chain the blocks; with double_y_emb every block instead
                # receives the same embedded input.
                x = out
            block_out += (projection(out),) * repeat

        return {"output": block_out}
|
||||
@@ -6,6 +6,7 @@ class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
|
||||
def __init__(
|
||||
self,
|
||||
num_blocks = None,
|
||||
control_latent_channels = None,
|
||||
dtype = None,
|
||||
device = None,
|
||||
operations = None,
|
||||
@@ -17,10 +18,13 @@ class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
|
||||
for _ in range(len(self.joint_blocks)):
|
||||
self.controlnet_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype))
|
||||
|
||||
if control_latent_channels is None:
|
||||
control_latent_channels = self.in_channels
|
||||
|
||||
self.pos_embed_input = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(
|
||||
None,
|
||||
self.patch_size,
|
||||
self.in_channels,
|
||||
control_latent_channels,
|
||||
self.hidden_size,
|
||||
bias=True,
|
||||
strict_img_size=False,
|
||||
|
||||
@@ -36,7 +36,7 @@ class EnumAction(argparse.Action):
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
|
||||
parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0,::", help="Specify the IP address to listen on (default: 127.0.0.1). You can give a list of ip addresses by separating them with a comma like: 127.2.2.2,127.3.3.3 If --listen is provided without an argument, it defaults to 0.0.0.0,:: (listens on all ipv4 and ipv6)")
|
||||
parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
|
||||
parser.add_argument("--tls-keyfile", type=str, help="Path to TLS (SSL) key file. Enables TLS, makes app accessible at https://... requires --tls-certfile to function")
|
||||
parser.add_argument("--tls-certfile", type=str, help="Path to TLS (SSL) certificate file. Enables TLS, makes app accessible at https://... requires --tls-keyfile to function")
|
||||
@@ -60,8 +60,10 @@ fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If
|
||||
fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
|
||||
|
||||
fpunet_group = parser.add_mutually_exclusive_group()
|
||||
fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.")
|
||||
fpunet_group.add_argument("--fp16-unet", action="store_true", help="Store unet weights in fp16.")
|
||||
fpunet_group.add_argument("--fp32-unet", action="store_true", help="Run the diffusion model in fp32.")
|
||||
fpunet_group.add_argument("--fp64-unet", action="store_true", help="Run the diffusion model in fp64.")
|
||||
fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the diffusion model in bf16.")
|
||||
fpunet_group.add_argument("--fp16-unet", action="store_true", help="Run the diffusion model in fp16")
|
||||
fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.")
|
||||
fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")
|
||||
|
||||
@@ -92,6 +94,8 @@ class LatentPreviewMethod(enum.Enum):
|
||||
|
||||
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
|
||||
|
||||
parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
|
||||
|
||||
cache_group = parser.add_mutually_exclusive_group()
|
||||
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
|
||||
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
|
||||
@@ -134,7 +138,7 @@ parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Dis
|
||||
|
||||
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
|
||||
|
||||
parser.add_argument("--verbose", action="store_true", help="Enables more debug prints.")
|
||||
parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
|
||||
|
||||
# The default built-in provider hosted under web/
|
||||
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
|
||||
@@ -169,6 +173,8 @@ parser.add_argument(
|
||||
help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.",
|
||||
)
|
||||
|
||||
parser.add_argument("--user-directory", type=is_valid_directory, default=None, help="Set the ComfyUI user directory with an absolute path.")
|
||||
|
||||
if comfy.options.args_parsing:
|
||||
args = parser.parse_args()
|
||||
else:
|
||||
@@ -179,10 +185,3 @@ if args.windows_standalone_build:
|
||||
|
||||
if args.disable_auto_launch:
|
||||
args.auto_launch = False
|
||||
|
||||
import logging
|
||||
logging_level = logging.INFO
|
||||
if args.verbose:
|
||||
logging_level = logging.DEBUG
|
||||
|
||||
logging.basicConfig(format="%(message)s", level=logging_level)
|
||||
|
||||
@@ -23,6 +23,7 @@ class CLIPAttention(torch.nn.Module):
|
||||
|
||||
ACTIVATIONS = {"quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
|
||||
"gelu": torch.nn.functional.gelu,
|
||||
"gelu_pytorch_tanh": lambda a: torch.nn.functional.gelu(a, approximate="tanh"),
|
||||
}
|
||||
|
||||
class CLIPMLP(torch.nn.Module):
|
||||
@@ -139,27 +140,35 @@ class CLIPTextModel(torch.nn.Module):
|
||||
|
||||
|
||||
class CLIPVisionEmbeddings(torch.nn.Module):
    """Patch, class-token and position embeddings for the vision tower.

    For ``model_type == "siglip_vision_model"`` there is no class token and
    the patch convolution has a bias; for every other model type a class
    token is prepended (one extra position) and the convolution is
    bias-free.
    """

    def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, model_type="", dtype=None, device=None, operations=None):
        super().__init__()

        has_class_token = model_type != "siglip_vision_model"

        # One position per patch, plus one for the class token if present.
        position_count = (image_size // patch_size) ** 2
        if has_class_token:
            position_count += 1
            self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device))
        else:
            self.class_embedding = None

        self.patch_embedding = operations.Conv2d(
            in_channels=num_channels,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=not has_class_token,
            dtype=dtype,
            device=device
        )

        self.position_embedding = operations.Embedding(position_count, embed_dim, dtype=dtype, device=device)

    def forward(self, pixel_values):
        """Embed ``pixel_values`` and add the position embedding table."""
        tokens = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)
        if self.class_embedding is not None:
            batch = pixel_values.shape[0]
            class_token = comfy.ops.cast_to_input(self.class_embedding, tokens).expand(batch, 1, -1)
            tokens = torch.cat([class_token, tokens], dim=1)
        positions = comfy.ops.cast_to_input(self.position_embedding.weight, tokens)
        return tokens + positions
|
||||
|
||||
|
||||
class CLIPVision(torch.nn.Module):
|
||||
@@ -170,9 +179,15 @@ class CLIPVision(torch.nn.Module):
|
||||
heads = config_dict["num_attention_heads"]
|
||||
intermediate_size = config_dict["intermediate_size"]
|
||||
intermediate_activation = config_dict["hidden_act"]
|
||||
model_type = config_dict["model_type"]
|
||||
|
||||
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], dtype=dtype, device=device, operations=operations)
|
||||
self.pre_layrnorm = operations.LayerNorm(embed_dim)
|
||||
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations)
|
||||
if model_type == "siglip_vision_model":
|
||||
self.pre_layrnorm = lambda a: a
|
||||
self.output_layernorm = True
|
||||
else:
|
||||
self.pre_layrnorm = operations.LayerNorm(embed_dim)
|
||||
self.output_layernorm = False
|
||||
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
|
||||
self.post_layernorm = operations.LayerNorm(embed_dim)
|
||||
|
||||
@@ -181,14 +196,21 @@ class CLIPVision(torch.nn.Module):
|
||||
x = self.pre_layrnorm(x)
|
||||
#TODO: attention_mask?
|
||||
x, i = self.encoder(x, mask=None, intermediate_output=intermediate_output)
|
||||
pooled_output = self.post_layernorm(x[:, 0, :])
|
||||
if self.output_layernorm:
|
||||
x = self.post_layernorm(x)
|
||||
pooled_output = x
|
||||
else:
|
||||
pooled_output = self.post_layernorm(x[:, 0, :])
|
||||
return x, i, pooled_output
|
||||
|
||||
class CLIPVisionModelProjection(torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.vision_model = CLIPVision(config_dict, dtype, device, operations)
|
||||
self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False)
|
||||
if "projection_dim" in config_dict:
|
||||
self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False)
|
||||
else:
|
||||
self.visual_projection = lambda a: a
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
x = self.vision_model(*args, **kwargs)
|
||||
|
||||
@@ -16,13 +16,18 @@ class Output:
|
||||
def __setitem__(self, key, item):
|
||||
setattr(self, key, item)
|
||||
|
||||
def clip_preprocess(image, size=224):
|
||||
mean = torch.tensor([ 0.48145466,0.4578275,0.40821073], device=image.device, dtype=image.dtype)
|
||||
std = torch.tensor([0.26862954,0.26130258,0.27577711], device=image.device, dtype=image.dtype)
|
||||
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
|
||||
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
|
||||
std = torch.tensor(std, device=image.device, dtype=image.dtype)
|
||||
image = image.movedim(-1, 1)
|
||||
if not (image.shape[2] == size and image.shape[3] == size):
|
||||
scale = (size / min(image.shape[2], image.shape[3]))
|
||||
image = torch.nn.functional.interpolate(image, size=(round(scale * image.shape[2]), round(scale * image.shape[3])), mode="bicubic", antialias=True)
|
||||
if crop:
|
||||
scale = (size / min(image.shape[2], image.shape[3]))
|
||||
scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3]))
|
||||
else:
|
||||
scale_size = (size, size)
|
||||
|
||||
image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
|
||||
h = (image.shape[2] - size)//2
|
||||
w = (image.shape[3] - size)//2
|
||||
image = image[:,:,h:h+size,w:w+size]
|
||||
@@ -35,6 +40,8 @@ class ClipVisionModel():
|
||||
config = json.load(f)
|
||||
|
||||
self.image_size = config.get("image_size", 224)
|
||||
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
|
||||
self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
|
||||
self.load_device = comfy.model_management.text_encoder_device()
|
||||
offload_device = comfy.model_management.text_encoder_offload_device()
|
||||
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
|
||||
@@ -49,9 +56,9 @@ class ClipVisionModel():
|
||||
def get_sd(self):
|
||||
return self.model.state_dict()
|
||||
|
||||
def encode_image(self, image):
|
||||
def encode_image(self, image, crop=True):
|
||||
comfy.model_management.load_model_gpu(self.patcher)
|
||||
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size).float()
|
||||
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
|
||||
out = self.model(pixel_values=pixel_values, intermediate_output=-2)
|
||||
|
||||
outputs = Output()
|
||||
@@ -94,7 +101,9 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
|
||||
elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
|
||||
elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
|
||||
if sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
|
||||
if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
|
||||
elif sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
|
||||
else:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
|
||||
@@ -109,8 +118,7 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
|
||||
keys = list(sd.keys())
|
||||
for k in keys:
|
||||
if k not in u:
|
||||
t = sd.pop(k)
|
||||
del t
|
||||
sd.pop(k)
|
||||
return clip
|
||||
|
||||
def load(ckpt_path):
|
||||
|
||||
13
comfy/clip_vision_siglip_384.json
Normal file
13
comfy/clip_vision_siglip_384.json
Normal file
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"num_channels": 3,
|
||||
"hidden_act": "gelu_pytorch_tanh",
|
||||
"hidden_size": 1152,
|
||||
"image_size": 384,
|
||||
"intermediate_size": 4304,
|
||||
"model_type": "siglip_vision_model",
|
||||
"num_attention_heads": 16,
|
||||
"num_hidden_layers": 27,
|
||||
"patch_size": 14,
|
||||
"image_mean": [0.5, 0.5, 0.5],
|
||||
"image_std": [0.5, 0.5, 0.5]
|
||||
}
|
||||
43
comfy/comfy_types/README.md
Normal file
43
comfy/comfy_types/README.md
Normal file
@@ -0,0 +1,43 @@
|
||||
# Comfy Typing
|
||||
## Type hinting for ComfyUI Node development
|
||||
|
||||
This module provides type hinting and concrete convenience types for node developers.
|
||||
If cloned to the custom_nodes directory of ComfyUI, types can be imported using:
|
||||
|
||||
```python
|
||||
from comfy_types import IO, ComfyNodeABC, CheckLazyMixin, InputTypeDict
|
||||
|
||||
class ExampleNode(ComfyNodeABC):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s) -> InputTypeDict:
|
||||
return {"required": {}}
|
||||
```
|
||||
|
||||
Full example is in [examples/example_nodes.py](examples/example_nodes.py).
|
||||
|
||||
# Types
|
||||
A few primary types are documented below. More complete information is available via the docstrings on each type.
|
||||
|
||||
## `IO`
|
||||
|
||||
A string enum of built-in and a few custom data types. Includes the following special types and their requisite plumbing:
|
||||
|
||||
- `ANY`: `"*"`
|
||||
- `NUMBER`: `"FLOAT,INT"`
|
||||
- `PRIMITIVE`: `"STRING,FLOAT,INT,BOOLEAN"`
|
||||
|
||||
## `ComfyNodeABC`
|
||||
|
||||
An abstract base class for nodes, offering type-hinting / autocomplete, and somewhat-alright docstrings.
|
||||
|
||||
### Type hinting for `INPUT_TYPES`
|
||||
|
||||

|
||||
|
||||
### `INPUT_TYPES` return dict
|
||||
|
||||

|
||||
|
||||
### Options for individual inputs
|
||||
|
||||

|
||||
@@ -1,5 +1,6 @@
|
||||
import torch
|
||||
from typing import Callable, Protocol, TypedDict, Optional, List
|
||||
from .node_typing import IO, InputTypeDict, ComfyNodeABC, CheckLazyMixin
|
||||
|
||||
|
||||
class UnetApplyFunction(Protocol):
|
||||
@@ -30,3 +31,15 @@ class UnetParams(TypedDict):
|
||||
|
||||
|
||||
UnetWrapperFunction = Callable[[UnetApplyFunction, UnetParams], torch.Tensor]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"UnetWrapperFunction",
|
||||
UnetApplyConds.__name__,
|
||||
UnetParams.__name__,
|
||||
UnetApplyFunction.__name__,
|
||||
IO.__name__,
|
||||
InputTypeDict.__name__,
|
||||
ComfyNodeABC.__name__,
|
||||
CheckLazyMixin.__name__,
|
||||
]
|
||||
28
comfy/comfy_types/examples/example_nodes.py
Normal file
28
comfy/comfy_types/examples/example_nodes.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from comfy_types import IO, ComfyNodeABC, InputTypeDict
|
||||
from inspect import cleandoc
|
||||
|
||||
|
||||
class ExampleNode(ComfyNodeABC):
    """An example node that just adds 1 to an input integer.

    * Requires an IDE configured with analysis paths etc to be worth looking at.
    * Not intended for use in ComfyUI.
    """

    # The description is the class docstring with its indentation removed.
    DESCRIPTION = cleandoc(__doc__)
    CATEGORY = "examples"

    @classmethod
    def INPUT_TYPES(s) -> InputTypeDict:
        # A single required integer input, flagged with defaultInput.
        required_inputs = {
            "input_int": (IO.INT, {"defaultInput": True}),
        }
        return {"required": required_inputs}

    RETURN_TYPES = (IO.INT,)
    RETURN_NAMES = ("input_plus_one",)
    FUNCTION = "execute"

    def execute(self, input_int: int):
        # Results are returned as a tuple, one entry per RETURN_TYPES item.
        incremented = input_int + 1
        return (incremented,)
|
||||
BIN
comfy/comfy_types/examples/input_options.png
Normal file
BIN
comfy/comfy_types/examples/input_options.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 19 KiB |
BIN
comfy/comfy_types/examples/input_types.png
Normal file
BIN
comfy/comfy_types/examples/input_types.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 16 KiB |
BIN
comfy/comfy_types/examples/required_hint.png
Normal file
BIN
comfy/comfy_types/examples/required_hint.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 19 KiB |
274
comfy/comfy_types/node_typing.py
Normal file
274
comfy/comfy_types/node_typing.py
Normal file
@@ -0,0 +1,274 @@
|
||||
"""Comfy-specific type hinting"""
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Literal, TypedDict
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class StrEnum(str, Enum):
    """String-valued enum base; stands in for ``enum.StrEnum``, which only exists from Python 3.11."""

    def __str__(self) -> str:
        # Render a member as its plain value rather than ``Class.MEMBER``.
        member_value: str = self.value
        return member_value
|
||||
|
||||
|
||||
class IO(StrEnum):
|
||||
"""Node input/output data types.
|
||||
|
||||
Includes functionality for ``"*"`` (`ANY`) and ``"MULTI,TYPES"``.
|
||||
"""
|
||||
|
||||
STRING = "STRING"
|
||||
IMAGE = "IMAGE"
|
||||
MASK = "MASK"
|
||||
LATENT = "LATENT"
|
||||
BOOLEAN = "BOOLEAN"
|
||||
INT = "INT"
|
||||
FLOAT = "FLOAT"
|
||||
CONDITIONING = "CONDITIONING"
|
||||
SAMPLER = "SAMPLER"
|
||||
SIGMAS = "SIGMAS"
|
||||
GUIDER = "GUIDER"
|
||||
NOISE = "NOISE"
|
||||
CLIP = "CLIP"
|
||||
CONTROL_NET = "CONTROL_NET"
|
||||
VAE = "VAE"
|
||||
MODEL = "MODEL"
|
||||
CLIP_VISION = "CLIP_VISION"
|
||||
CLIP_VISION_OUTPUT = "CLIP_VISION_OUTPUT"
|
||||
STYLE_MODEL = "STYLE_MODEL"
|
||||
GLIGEN = "GLIGEN"
|
||||
UPSCALE_MODEL = "UPSCALE_MODEL"
|
||||
AUDIO = "AUDIO"
|
||||
WEBCAM = "WEBCAM"
|
||||
POINT = "POINT"
|
||||
FACE_ANALYSIS = "FACE_ANALYSIS"
|
||||
BBOX = "BBOX"
|
||||
SEGS = "SEGS"
|
||||
|
||||
ANY = "*"
|
||||
"""Always matches any type, but at a price.
|
||||
|
||||
Causes some functionality issues (e.g. reroutes, link types), and should be avoided whenever possible.
|
||||
"""
|
||||
NUMBER = "FLOAT,INT"
|
||||
"""A float or an int - could be either"""
|
||||
PRIMITIVE = "STRING,FLOAT,INT,BOOLEAN"
|
||||
"""Could be any of: string, float, int, or bool"""
|
||||
|
||||
def __ne__(self, value: object) -> bool:
|
||||
if self == "*" or value == "*":
|
||||
return False
|
||||
if not isinstance(value, str):
|
||||
return True
|
||||
a = frozenset(self.split(","))
|
||||
b = frozenset(value.split(","))
|
||||
return not (b.issubset(a) or a.issubset(b))
|
||||
|
||||
|
||||
class InputTypeOptions(TypedDict, total=False):
    """Provides type hinting for the options dict of an input declared in INPUT_TYPES.

    Every key is optional (``total=False``): a node only lists the options it needs,
    e.g. ``{"defaultInput": True}``.

    Due to IDE limitations with unions, for now all options are available for all types (e.g. `label_on` is hinted even when the type is not `IO.BOOLEAN`).

    Comfy Docs: https://docs.comfy.org/essentials/custom_node_datatypes
    """

    default: bool | str | float | int | list | tuple
    """The default value of the widget"""
    defaultInput: bool
    """Defaults to an input slot rather than a widget"""
    forceInput: bool
    """`defaultInput` and also don't allow converting to a widget"""
    lazy: bool
    """Declares that this input uses lazy evaluation"""
    rawLink: bool
    """When a link exists, rather than receiving the evaluated value, you will receive the link (i.e. `["nodeId", <outputIndex>]`). Designed for node expansion."""
    tooltip: str
    """Tooltip for the input (or widget), shown on pointer hover"""
    # class InputTypeNumber(InputTypeOptions):
    # default: float | int
    min: float
    """The minimum value of a number (``FLOAT`` | ``INT``)"""
    max: float
    """The maximum value of a number (``FLOAT`` | ``INT``)"""
    step: float
    """The amount to increment or decrement a widget by when stepping up/down (``FLOAT`` | ``INT``)"""
    round: float
    """Floats are rounded by this value (``FLOAT``)"""
    # class InputTypeBoolean(InputTypeOptions):
    # default: bool
    label_on: str
    """The label to use in the UI when the bool is True (``BOOLEAN``)"""
    # This key was previously declared as a second `label_on`, which left no key for the False label.
    label_off: str
    """The label to use in the UI when the bool is False (``BOOLEAN``)"""
    # class InputTypeString(InputTypeOptions):
    # default: str
    multiline: bool
    """Use a multiline text box (``STRING``)"""
    placeholder: str
    """Placeholder text to display in the UI when empty (``STRING``)"""
    # Deprecated:
    # defaultVal: str
    dynamicPrompts: bool
    """Causes the front-end to evaluate dynamic prompts (``STRING``)"""
|
||||
|
||||
|
||||
class HiddenInputTypeDict(TypedDict, total=False):
    """Provides type hinting for the hidden entry of node INPUT_TYPES.

    Every key is optional (``total=False``): a node only lists the hidden inputs
    it wants to receive. ``node_id`` and ``unique_id`` both request the same
    ``"UNIQUE_ID"`` value, so requiring every key at once would be meaningless.
    """

    node_id: Literal["UNIQUE_ID"]
    """UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages)."""
    unique_id: Literal["UNIQUE_ID"]
    """UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages)."""
    prompt: Literal["PROMPT"]
    """PROMPT is the complete prompt sent by the client to the server. See the prompt object for a full description."""
    extra_pnginfo: Literal["EXTRA_PNGINFO"]
    """EXTRA_PNGINFO is a dictionary that will be copied into the metadata of any .png files saved. Custom nodes can store additional information in this dictionary for saving (or as a way to communicate with a downstream node)."""
    dynprompt: Literal["DYNPROMPT"]
    """DYNPROMPT is an instance of comfy_execution.graph.DynamicPrompt. It differs from PROMPT in that it may mutate during the course of execution in response to Node Expansion."""
|
||||
|
||||
|
||||
class _InputTypeDictOptionalKeys(TypedDict, total=False):
    """Keys of `InputTypeDict` that a node may omit (kept in a ``total=False`` base so they are not required)."""

    optional: dict[str, tuple[IO, InputTypeOptions]]
    """Describes inputs which do not need to be connected."""
    hidden: HiddenInputTypeDict
    """Offers advanced functionality and server-client communication.

    Comfy Docs: https://docs.comfy.org/essentials/custom_node_more_on_inputs#hidden-inputs
    """


class InputTypeDict(_InputTypeDictOptionalKeys):
    """Provides type hinting for node INPUT_TYPES.

    Only ``required`` must be present; ``optional`` and ``hidden`` are inherited
    from a ``total=False`` base, so ``{"required": {}}`` is a valid value.

    Comfy Docs: https://docs.comfy.org/essentials/custom_node_more_on_inputs
    """

    required: dict[str, tuple[IO, InputTypeOptions]]
    """Describes all inputs that must be connected for the node to execute."""
|
||||
|
||||
|
||||
class ComfyNodeABC(ABC):
    """Abstract base class for Comfy nodes. Includes the names and expected types of attributes.

    Subclasses must implement the `INPUT_TYPES` classmethod; the remaining
    attributes are declared here for type hinting only and have no default value.

    Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview
    """

    DESCRIPTION: str
    """Node description, shown as a tooltip when hovering over the node.

    Usage::

        # Explicitly define the description
        DESCRIPTION = "Example description here."

        # Use the docstring of the node class.
        DESCRIPTION = cleandoc(__doc__)
    """
    CATEGORY: str
    """The category of the node, as per the "Add Node" menu.

    Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#category
    """
    EXPERIMENTAL: bool
    """Flags a node as experimental, informing users that it may change or not work as expected."""
    DEPRECATED: bool
    """Flags a node as deprecated, indicating to users that they should find alternatives to this node."""

    @classmethod
    @abstractmethod
    def INPUT_TYPES(s) -> InputTypeDict:
        """Defines node inputs.

        * Must include the ``required`` key, which describes all inputs that must be connected for the node to execute.
        * The ``optional`` key can be added to describe inputs which do not need to be connected.
        * The ``hidden`` key offers some advanced functionality. More info at: https://docs.comfy.org/essentials/custom_node_more_on_inputs#hidden-inputs

        Returns:
            The base implementation returns ``{"required": {}}``, i.e. a node with no inputs.

        Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#input-types
        """
        return {"required": {}}

    OUTPUT_NODE: bool
    """Flags this node as an output node, causing any inputs it requires to be executed.

    If a node is not connected to any output nodes, that node will not be executed. Usage::

        OUTPUT_NODE = True

    From the docs:

        By default, a node is not considered an output. Set ``OUTPUT_NODE = True`` to specify that it is.

    Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#output-node
    """
    INPUT_IS_LIST: bool
    """A flag indicating if this node implements the additional code necessary to deal with OUTPUT_IS_LIST nodes.

    All inputs of ``type`` will become ``list[type]``, regardless of how many items are passed in. This also affects ``check_lazy_status``.

    From the docs:

        A node can also override the default input behaviour and receive the whole list in a single call. This is done by setting a class attribute `INPUT_IS_LIST` to ``True``.

    Comfy Docs: https://docs.comfy.org/essentials/custom_node_lists#list-processing
    """
    # Variadic tuple: one flag per output, so the length follows RETURN_TYPES rather than being fixed at 1.
    OUTPUT_IS_LIST: tuple[bool, ...]
    """A tuple indicating which node outputs are lists, but will be connected to nodes that expect individual items.

    Connected nodes that do not implement `INPUT_IS_LIST` will be executed once for every item in the list.

    A ``tuple[bool, ...]``, where the items match those in `RETURN_TYPES`::

        RETURN_TYPES = (IO.INT, IO.INT, IO.STRING)
        OUTPUT_IS_LIST = (True, True, False)  # The string output will be handled normally

    From the docs:

        In order to tell Comfy that the list being returned should not be wrapped, but treated as a series of data for sequential processing,
        the node should provide a class attribute `OUTPUT_IS_LIST`, which is a ``tuple[bool]``, of the same length as `RETURN_TYPES`,
        specifying which outputs which should be so treated.

    Comfy Docs: https://docs.comfy.org/essentials/custom_node_lists#list-processing
    """

    # Variadic tuple: a node may declare any number of outputs.
    RETURN_TYPES: tuple[IO, ...]
    """A tuple representing the outputs of this node.

    Usage::

        RETURN_TYPES = (IO.INT, "INT", "CUSTOM_TYPE")

    Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#return-types
    """
    RETURN_NAMES: tuple[str, ...]
    """The output slot names for each item in `RETURN_TYPES`, e.g. ``RETURN_NAMES = ("count", "filter_string")``

    Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#return-names
    """
    OUTPUT_TOOLTIPS: tuple[str, ...]
    """A tuple of strings to use as tooltips for node outputs, one for each item in `RETURN_TYPES`."""
    FUNCTION: str
    """The name of the function to execute as a literal string, e.g. `FUNCTION = "execute"`

    Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#function
    """
|
||||
|
||||
|
||||
class CheckLazyMixin:
    """Mixin supplying a default ``check_lazy_status`` (with type hints) for nodes that declare lazy inputs."""

    def check_lazy_status(self, **kwargs) -> list[str]:
        """Report which inputs still need to be evaluated.

        This default implementation asks for every input whose current value is ``None``.

        :kwargs: Every node input, by name. A value of ``None`` is treated as "not yet evaluated".
            With ``INPUT_IS_LIST = True`` an unevaluated input arrives as ``(None,)`` instead.

        The parameters should mirror those of the node's execution ``FUNCTION`` (self, plus all inputs by name).
        It is called repeatedly until it returns an empty list, or until everything it asked for
        has already been evaluated (and passed in as params).

        Comfy Docs: https://docs.comfy.org/essentials/custom_node_lazy_evaluation#defining-check-lazy-status
        """
        pending = []
        for input_name, input_value in kwargs.items():
            if input_value is None:
                pending.append(input_name)
        return pending
|
||||
@@ -34,7 +34,11 @@ import comfy.t2i_adapter.adapter
|
||||
import comfy.ldm.cascade.controlnet
|
||||
import comfy.cldm.mmdit
|
||||
import comfy.ldm.hydit.controlnet
|
||||
import comfy.ldm.flux.controlnet_xlabs
|
||||
import comfy.ldm.flux.controlnet
|
||||
import comfy.cldm.dit_embedder
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from comfy.hooks import HookGroup
|
||||
|
||||
|
||||
def broadcast_image_to(tensor, target_batch_size, batched_number):
|
||||
@@ -60,7 +64,7 @@ class StrengthType(Enum):
|
||||
LINEAR_UP = 2
|
||||
|
||||
class ControlBase:
|
||||
def __init__(self, device=None):
|
||||
def __init__(self):
|
||||
self.cond_hint_original = None
|
||||
self.cond_hint = None
|
||||
self.strength = 1.0
|
||||
@@ -72,20 +76,26 @@ class ControlBase:
|
||||
self.compression_ratio = 8
|
||||
self.upscale_algorithm = 'nearest-exact'
|
||||
self.extra_args = {}
|
||||
|
||||
if device is None:
|
||||
device = comfy.model_management.get_torch_device()
|
||||
self.device = device
|
||||
self.previous_controlnet = None
|
||||
self.extra_conds = []
|
||||
self.strength_type = StrengthType.CONSTANT
|
||||
self.concat_mask = False
|
||||
self.extra_concat_orig = []
|
||||
self.extra_concat = None
|
||||
self.extra_hooks: HookGroup = None
|
||||
self.preprocess_image = lambda a: a
|
||||
|
||||
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None):
|
||||
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
|
||||
self.cond_hint_original = cond_hint
|
||||
self.strength = strength
|
||||
self.timestep_percent_range = timestep_percent_range
|
||||
if self.latent_format is not None:
|
||||
if vae is None:
|
||||
logging.warning("WARNING: no VAE provided to the controlnet apply node when this controlnet requires one.")
|
||||
self.vae = vae
|
||||
self.extra_concat_orig = extra_concat.copy()
|
||||
if self.concat_mask and len(self.extra_concat_orig) == 0:
|
||||
self.extra_concat_orig.append(torch.tensor([[[[1.0]]]]))
|
||||
return self
|
||||
|
||||
def pre_run(self, model, percent_to_timestep_function):
|
||||
@@ -100,9 +110,9 @@ class ControlBase:
|
||||
def cleanup(self):
|
||||
if self.previous_controlnet is not None:
|
||||
self.previous_controlnet.cleanup()
|
||||
if self.cond_hint is not None:
|
||||
del self.cond_hint
|
||||
self.cond_hint = None
|
||||
|
||||
self.cond_hint = None
|
||||
self.extra_concat = None
|
||||
self.timestep_range = None
|
||||
|
||||
def get_models(self):
|
||||
@@ -110,6 +120,14 @@ class ControlBase:
|
||||
if self.previous_controlnet is not None:
|
||||
out += self.previous_controlnet.get_models()
|
||||
return out
|
||||
|
||||
def get_extra_hooks(self):
    """Collect this control's extra hooks followed by those of every chained previous controlnet.

    Returns:
        A new list: ``self.extra_hooks`` (if set) first, then whatever
        ``self.previous_controlnet.get_extra_hooks()`` returns (if a previous controlnet exists).
    """
    hooks = [] if self.extra_hooks is None else [self.extra_hooks]
    previous = self.previous_controlnet
    if previous is not None:
        # Recurse through the chain so each link contributes its own hooks.
        hooks.extend(previous.get_extra_hooks())
    return hooks
|
||||
|
||||
def copy_to(self, c):
|
||||
c.cond_hint_original = self.cond_hint_original
|
||||
@@ -123,6 +141,10 @@ class ControlBase:
|
||||
c.vae = self.vae
|
||||
c.extra_conds = self.extra_conds.copy()
|
||||
c.strength_type = self.strength_type
|
||||
c.concat_mask = self.concat_mask
|
||||
c.extra_concat_orig = self.extra_concat_orig.copy()
|
||||
c.extra_hooks = self.extra_hooks.clone() if self.extra_hooks else None
|
||||
c.preprocess_image = self.preprocess_image
|
||||
|
||||
def inference_memory_requirements(self, dtype):
|
||||
if self.previous_controlnet is not None:
|
||||
@@ -148,7 +170,7 @@ class ControlBase:
|
||||
elif self.strength_type == StrengthType.LINEAR_UP:
|
||||
x *= (self.strength ** float(len(control_output) - i))
|
||||
|
||||
if x.dtype != output_dtype:
|
||||
if output_dtype is not None and x.dtype != output_dtype:
|
||||
x = x.to(output_dtype)
|
||||
|
||||
out[key].append(x)
|
||||
@@ -175,8 +197,8 @@ class ControlBase:
|
||||
|
||||
|
||||
class ControlNet(ControlBase):
|
||||
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT):
|
||||
super().__init__(device)
|
||||
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False, preprocess_image=lambda a: a):
|
||||
super().__init__()
|
||||
self.control_model = control_model
|
||||
self.load_device = load_device
|
||||
if control_model is not None:
|
||||
@@ -189,11 +211,13 @@ class ControlNet(ControlBase):
|
||||
self.latent_format = latent_format
|
||||
self.extra_conds += extra_conds
|
||||
self.strength_type = strength_type
|
||||
self.concat_mask = concat_mask
|
||||
self.preprocess_image = preprocess_image
|
||||
|
||||
def get_control(self, x_noisy, t, cond, batched_number):
|
||||
def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
|
||||
control_prev = None
|
||||
if self.previous_controlnet is not None:
|
||||
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
|
||||
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number, transformer_options)
|
||||
|
||||
if self.timestep_range is not None:
|
||||
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
|
||||
@@ -206,7 +230,6 @@ class ControlNet(ControlBase):
|
||||
if self.manual_cast_dtype is not None:
|
||||
dtype = self.manual_cast_dtype
|
||||
|
||||
output_dtype = x_noisy.dtype
|
||||
if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
|
||||
if self.cond_hint is not None:
|
||||
del self.cond_hint
|
||||
@@ -214,14 +237,26 @@ class ControlNet(ControlBase):
|
||||
compression_ratio = self.compression_ratio
|
||||
if self.vae is not None:
|
||||
compression_ratio *= self.vae.downscale_ratio
|
||||
else:
|
||||
if self.latent_format is not None:
|
||||
raise ValueError("This Controlnet needs a VAE but none was provided, please use a ControlNetApply node with a VAE input and connect it.")
|
||||
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
|
||||
self.cond_hint = self.preprocess_image(self.cond_hint)
|
||||
if self.vae is not None:
|
||||
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
|
||||
self.cond_hint = self.vae.encode(self.cond_hint.movedim(1, -1))
|
||||
comfy.model_management.load_models_gpu(loaded_models)
|
||||
if self.latent_format is not None:
|
||||
self.cond_hint = self.latent_format.process_in(self.cond_hint)
|
||||
self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype)
|
||||
if len(self.extra_concat_orig) > 0:
|
||||
to_concat = []
|
||||
for c in self.extra_concat_orig:
|
||||
c = c.to(self.cond_hint.device)
|
||||
c = comfy.utils.common_upscale(c, self.cond_hint.shape[3], self.cond_hint.shape[2], self.upscale_algorithm, "center")
|
||||
to_concat.append(comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[0]))
|
||||
self.cond_hint = torch.cat([self.cond_hint] + to_concat, dim=1)
|
||||
|
||||
self.cond_hint = self.cond_hint.to(device=x_noisy.device, dtype=dtype)
|
||||
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
||||
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
||||
|
||||
@@ -236,7 +271,7 @@ class ControlNet(ControlBase):
|
||||
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
|
||||
|
||||
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra)
|
||||
return self.control_merge(control, control_prev, output_dtype)
|
||||
return self.control_merge(control, control_prev, output_dtype=None)
|
||||
|
||||
def copy(self):
|
||||
c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
|
||||
@@ -320,8 +355,8 @@ class ControlLoraOps:
|
||||
|
||||
|
||||
class ControlLora(ControlNet):
|
||||
def __init__(self, control_weights, global_average_pooling=False, device=None):
|
||||
ControlBase.__init__(self, device)
|
||||
def __init__(self, control_weights, global_average_pooling=False, model_options={}): #TODO? model_options
|
||||
ControlBase.__init__(self)
|
||||
self.control_weights = control_weights
|
||||
self.global_average_pooling = global_average_pooling
|
||||
self.extra_conds += ["y"]
|
||||
@@ -377,19 +412,25 @@ class ControlLora(ControlNet):
|
||||
def inference_memory_requirements(self, dtype):
|
||||
return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
|
||||
|
||||
def controlnet_config(sd):
|
||||
def controlnet_config(sd, model_options={}):
|
||||
model_config = comfy.model_detection.model_config_from_unet(sd, "", True)
|
||||
|
||||
supported_inference_dtypes = model_config.supported_inference_dtypes
|
||||
unet_dtype = model_options.get("dtype", None)
|
||||
if unet_dtype is None:
|
||||
weight_dtype = comfy.utils.weight_dtype(sd)
|
||||
|
||||
supported_inference_dtypes = list(model_config.supported_inference_dtypes)
|
||||
if weight_dtype is not None:
|
||||
supported_inference_dtypes.append(weight_dtype)
|
||||
|
||||
unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
|
||||
|
||||
controlnet_config = model_config.unet_config
|
||||
unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
|
||||
load_device = comfy.model_management.get_torch_device()
|
||||
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
||||
if manual_cast_dtype is not None:
|
||||
operations = comfy.ops.manual_cast
|
||||
else:
|
||||
operations = comfy.ops.disable_weight_init
|
||||
|
||||
operations = model_options.get("custom_operations", None)
|
||||
if operations is None:
|
||||
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
|
||||
|
||||
offload_device = comfy.model_management.unet_offload_device()
|
||||
return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device
|
||||
@@ -404,24 +445,106 @@ def controlnet_load_state_dict(control_model, sd):
|
||||
logging.debug("unexpected controlnet keys: {}".format(unexpected))
|
||||
return control_model
|
||||
|
||||
def load_controlnet_mmdit(sd):
|
||||
|
||||
def load_controlnet_mmdit(sd, model_options={}):
|
||||
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd)
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options)
|
||||
num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
|
||||
for k in sd:
|
||||
new_sd[k] = sd[k]
|
||||
|
||||
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||
concat_mask = False
|
||||
control_latent_channels = new_sd.get("pos_embed_input.proj.weight").shape[1]
|
||||
if control_latent_channels == 17: #inpaint controlnet
|
||||
concat_mask = True
|
||||
|
||||
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, control_latent_channels=control_latent_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||
control_model = controlnet_load_state_dict(control_model, new_sd)
|
||||
|
||||
latent_format = comfy.latent_formats.SD3()
|
||||
latent_format.shift_factor = 0 #SD3 controlnet weirdness
|
||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
||||
return control
|
||||
|
||||
|
||||
def load_controlnet_hunyuandit(controlnet_data):
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data)
|
||||
class ControlNetSD35(ControlNet):
    """ControlNet subclass that copies embedder weights from the base model before each run."""

    def pre_run(self, model, percent_to_timestep_function):
        """Load one embedder's weights from ``model.diffusion_model`` into the control model, then run the parent ``pre_run``.

        When ``control_model.double_y_emb`` is truthy the base model's ``y_embedder``
        state is loaded into ``orig_y_embedder``; otherwise the base model's
        ``x_embedder`` state is loaded into ``x_embedder``. Loading is non-strict.
        """
        cnet = self.control_model
        if cnet.double_y_emb:
            target = cnet.orig_y_embedder
            source = model.diffusion_model.y_embedder
        else:
            target = cnet.x_embedder
            source = model.diffusion_model.x_embedder
        # The missing/unexpected key report is intentionally not inspected.
        _missing, _unexpected = target.load_state_dict(source.state_dict(), strict=False)
        super().pre_run(model, percent_to_timestep_function)

    def copy(self):
        """Return a new ``ControlNetSD35`` sharing this instance's control model, with settings copied via ``copy_to``."""
        clone = ControlNetSD35(
            None,
            global_average_pooling=self.global_average_pooling,
            load_device=self.load_device,
            manual_cast_dtype=self.manual_cast_dtype,
        )
        clone.control_model = self.control_model
        clone.control_model_wrapped = self.control_model_wrapped
        self.copy_to(clone)
        return clone
|
||||
|
||||
def load_controlnet_sd35(sd, model_options={}):
    """Build a ``ControlNetSD35`` from a state dict.

    Args:
        sd: The controlnet state dict. NOTE: it is mutated - ``"control_type"`` and every
            key listed in ``comfy.utils.MMDIT_MAP_BASIC`` are popped from it.
        model_options: Only read, never modified; ``"custom_operations"`` overrides the picked ops.

    Returns:
        The constructed ``ControlNetSD35``.
    """
    # Optional tensor entry selecting an image preprocessing variant (-1 when absent).
    if "control_type" in sd:
        control_type = round(sd.pop("control_type").item())
    else:
        control_type = -1

    # control_type 0 (blur) needs no special handling.
    is_canny = control_type == 1
    is_depth = control_type == 2

    # Rename mapped keys first, then carry over everything that is left.
    remapped = {}
    for pair in comfy.utils.MMDIT_MAP_BASIC:
        target_key, source_key = pair[0], pair[1]
        if source_key in sd:
            remapped[target_key] = sd.pop(source_key)
    for leftover_key in sd:
        remapped[leftover_key] = sd[leftover_key]
    sd = remapped

    # Derive the architecture from the y-embedder weight shape.
    y_emb_shape = sd["y_embedder.mlp.0.weight"].shape
    depth = y_emb_shape[0] // 64
    num_heads = depth
    hidden_size = 64 * depth
    head_dim = hidden_size // num_heads
    num_blocks = comfy.model_detection.count_blocks(remapped, 'transformer_blocks.{}.')

    load_device = comfy.model_management.get_torch_device()
    offload_device = comfy.model_management.unet_offload_device()
    unet_dtype = comfy.model_management.unet_dtype(model_params=-1)
    manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)

    operations = model_options.get("custom_operations", None)
    if operations is None:
        operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)

    embedder_kwargs = dict(
        img_size=None,
        patch_size=2,
        in_chans=16,
        num_layers=num_blocks,
        main_model_double=depth,
        double_y_emb=y_emb_shape[0] == y_emb_shape[1],
        attention_head_dim=head_dim,
        num_attention_heads=num_heads,
        adm_in_channels=2048,
        device=offload_device,
        dtype=unet_dtype,
        operations=operations,
    )
    control_model = comfy.cldm.dit_embedder.ControlNetEmbedder(**embedder_kwargs)
    control_model = controlnet_load_state_dict(control_model, sd)

    latent_format = comfy.latent_formats.SD3()

    # Pick the hint-image transform for the detected control type.
    if is_canny:
        preprocess_image = lambda a: (a * 255 * 0.5 + 0.5)
    elif is_depth:
        preprocess_image = lambda a: 1.0 - a
    else:
        preprocess_image = lambda a: a

    return ControlNetSD35(
        control_model,
        compression_ratio=1,
        latent_format=latent_format,
        load_device=load_device,
        manual_cast_dtype=manual_cast_dtype,
        preprocess_image=preprocess_image,
    )
|
||||
|
||||
|
||||
|
||||
def load_controlnet_hunyuandit(controlnet_data, model_options={}):
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data, model_options=model_options)
|
||||
|
||||
control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=offload_device, dtype=unet_dtype)
|
||||
control_model = controlnet_load_state_dict(control_model, controlnet_data)
|
||||
@@ -431,22 +554,49 @@ def load_controlnet_hunyuandit(controlnet_data):
|
||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT)
|
||||
return control
|
||||
|
||||
def load_controlnet_flux_xlabs(sd):
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd)
|
||||
control_model = comfy.ldm.flux.controlnet_xlabs.ControlNetFlux(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||
def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False, model_options={}):
    """Build a ``ControlNet`` around a ``ControlNetFlux`` model loaded from ``sd``.

    Args:
        sd: The controlnet state dict.
        mistoline: Forwarded to ``ControlNetFlux`` as its ``mistoline`` argument.
        model_options: Forwarded to ``controlnet_config``.

    Returns:
        The constructed ``ControlNet``, using the ``'y'`` and ``'guidance'`` extra conds.
    """
    config_result = controlnet_config(sd, model_options=model_options)
    model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = config_result

    flux_model = comfy.ldm.flux.controlnet.ControlNetFlux(
        mistoline=mistoline,
        operations=operations,
        device=offload_device,
        dtype=unet_dtype,
        **model_config.unet_config,
    )
    flux_model = controlnet_load_state_dict(flux_model, sd)

    return ControlNet(
        flux_model,
        load_device=load_device,
        manual_cast_dtype=manual_cast_dtype,
        extra_conds=['y', 'guidance'],
    )
|
||||
|
||||
def load_controlnet_flux_instantx(sd, model_options={}):
|
||||
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options)
|
||||
for k in sd:
|
||||
new_sd[k] = sd[k]
|
||||
|
||||
def load_controlnet(ckpt_path, model=None):
|
||||
controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
|
||||
num_union_modes = 0
|
||||
union_cnet = "controlnet_mode_embedder.weight"
|
||||
if union_cnet in new_sd:
|
||||
num_union_modes = new_sd[union_cnet].shape[0]
|
||||
|
||||
control_latent_channels = new_sd.get("pos_embed_input.weight").shape[1] // 4
|
||||
concat_mask = False
|
||||
if control_latent_channels == 17:
|
||||
concat_mask = True
|
||||
|
||||
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(latent_input=True, num_union_modes=num_union_modes, control_latent_channels=control_latent_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||
control_model = controlnet_load_state_dict(control_model, new_sd)
|
||||
|
||||
latent_format = comfy.latent_formats.Flux()
|
||||
extra_conds = ['y', 'guidance']
|
||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||
return control
|
||||
|
||||
def convert_mistoline(sd):
    """Return ``sd`` with the ``single_controlnet_blocks.`` key prefix replaced by ``controlnet_single_blocks.``."""
    prefix_renames = {"single_controlnet_blocks.": "controlnet_single_blocks."}
    return comfy.utils.state_dict_prefix_replace(sd, prefix_renames)
|
||||
|
||||
|
||||
def load_controlnet_state_dict(state_dict, model=None, model_options={}):
|
||||
controlnet_data = state_dict
|
||||
if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT
|
||||
return load_controlnet_hunyuandit(controlnet_data)
|
||||
return load_controlnet_hunyuandit(controlnet_data, model_options=model_options)
|
||||
|
||||
if "lora_controlnet" in controlnet_data:
|
||||
return ControlLora(controlnet_data)
|
||||
return ControlLora(controlnet_data, model_options=model_options)
|
||||
|
||||
controlnet_config = None
|
||||
supported_inference_dtypes = None
|
||||
@@ -501,11 +651,18 @@ def load_controlnet(ckpt_path, model=None):
|
||||
if len(leftover_keys) > 0:
|
||||
logging.warning("leftover keys: {}".format(leftover_keys))
|
||||
controlnet_data = new_sd
|
||||
elif "controlnet_blocks.0.weight" in controlnet_data: #SD3 diffusers format
|
||||
elif "controlnet_blocks.0.weight" in controlnet_data:
|
||||
if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
|
||||
return load_controlnet_flux_xlabs(controlnet_data)
|
||||
else:
|
||||
return load_controlnet_mmdit(controlnet_data)
|
||||
return load_controlnet_flux_xlabs_mistoline(controlnet_data, model_options=model_options)
|
||||
elif "pos_embed_input.proj.weight" in controlnet_data:
|
||||
if "transformer_blocks.0.adaLN_modulation.1.bias" in controlnet_data:
|
||||
return load_controlnet_sd35(controlnet_data, model_options=model_options) #Stability sd3.5 format
|
||||
else:
|
||||
return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet
|
||||
elif "controlnet_x_embedder.weight" in controlnet_data:
|
||||
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
|
||||
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
|
||||
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)
|
||||
|
||||
pth_key = 'control_model.zero_convs.0.0.weight'
|
||||
pth = False
|
||||
@@ -517,25 +674,36 @@ def load_controlnet(ckpt_path, model=None):
|
||||
elif key in controlnet_data:
|
||||
prefix = ""
|
||||
else:
|
||||
net = load_t2i_adapter(controlnet_data)
|
||||
net = load_t2i_adapter(controlnet_data, model_options=model_options)
|
||||
if net is None:
|
||||
logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
|
||||
logging.error("error could not detect control model type.")
|
||||
return net
|
||||
|
||||
if controlnet_config is None:
|
||||
model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True)
|
||||
supported_inference_dtypes = model_config.supported_inference_dtypes
|
||||
supported_inference_dtypes = list(model_config.supported_inference_dtypes)
|
||||
controlnet_config = model_config.unet_config
|
||||
|
||||
unet_dtype = model_options.get("dtype", None)
|
||||
if unet_dtype is None:
|
||||
weight_dtype = comfy.utils.weight_dtype(controlnet_data)
|
||||
|
||||
if supported_inference_dtypes is None:
|
||||
supported_inference_dtypes = [comfy.model_management.unet_dtype()]
|
||||
|
||||
if weight_dtype is not None:
|
||||
supported_inference_dtypes.append(weight_dtype)
|
||||
|
||||
unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
|
||||
|
||||
load_device = comfy.model_management.get_torch_device()
|
||||
if supported_inference_dtypes is None:
|
||||
unet_dtype = comfy.model_management.unet_dtype()
|
||||
else:
|
||||
unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
|
||||
|
||||
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
||||
if manual_cast_dtype is not None:
|
||||
controlnet_config["operations"] = comfy.ops.manual_cast
|
||||
operations = model_options.get("custom_operations", None)
|
||||
if operations is None:
|
||||
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype)
|
||||
|
||||
controlnet_config["operations"] = operations
|
||||
controlnet_config["dtype"] = unet_dtype
|
||||
controlnet_config["device"] = comfy.model_management.unet_offload_device()
|
||||
controlnet_config.pop("out_channels")
|
||||
@@ -571,22 +739,32 @@ def load_controlnet(ckpt_path, model=None):
|
||||
if len(unexpected) > 0:
|
||||
logging.debug("unexpected controlnet keys: {}".format(unexpected))
|
||||
|
||||
global_average_pooling = False
|
||||
filename = os.path.splitext(ckpt_path)[0]
|
||||
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
|
||||
global_average_pooling = True
|
||||
|
||||
global_average_pooling = model_options.get("global_average_pooling", False)
|
||||
control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
||||
return control
|
||||
|
||||
def load_controlnet(ckpt_path, model=None, model_options={}):
    """Load a control model from the checkpoint file at ``ckpt_path``.

    The checkpoint is read with ``comfy.utils.load_torch_file`` and handed to
    ``load_controlnet_state_dict``.

    Args:
        ckpt_path: Path of the checkpoint file.
        model: Forwarded unchanged to ``load_controlnet_state_dict``.
        model_options: Loader options. When it has no
            ``"global_average_pooling"`` entry, one is derived from the file
            name (``*_shuffle`` / ``*_shuffle_fp16`` enable it).

    Returns:
        Whatever ``load_controlnet_state_dict`` returns; ``None`` is logged as
        an error and passed through.
    """
    # Work on a private copy. Previously the auto-detected flag was written
    # into the caller's dict and, when no dict was passed, into the shared
    # default argument, so a single "_shuffle" checkpoint silently enabled
    # global average pooling for every later call that used the default.
    model_options = dict(model_options)

    if "global_average_pooling" not in model_options:
        filename = os.path.splitext(ckpt_path)[0]
        if filename.endswith(("_shuffle", "_shuffle_fp16")):  # TODO: smarter way of enabling global_average_pooling
            model_options["global_average_pooling"] = True

    cnet = load_controlnet_state_dict(
        comfy.utils.load_torch_file(ckpt_path, safe_load=True),
        model=model,
        model_options=model_options,
    )
    if cnet is None:
        logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
    return cnet
|
||||
|
||||
class T2IAdapter(ControlBase):
|
||||
def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None):
|
||||
super().__init__(device)
|
||||
super().__init__()
|
||||
self.t2i_model = t2i_model
|
||||
self.channels_in = channels_in
|
||||
self.control_input = None
|
||||
self.compression_ratio = compression_ratio
|
||||
self.upscale_algorithm = upscale_algorithm
|
||||
if device is None:
|
||||
device = comfy.model_management.get_torch_device()
|
||||
self.device = device
|
||||
|
||||
def scale_image_to(self, width, height):
|
||||
unshuffle_amount = self.t2i_model.unshuffle_amount
|
||||
@@ -594,10 +772,10 @@ class T2IAdapter(ControlBase):
|
||||
height = math.ceil(height / unshuffle_amount) * unshuffle_amount
|
||||
return width, height
|
||||
|
||||
def get_control(self, x_noisy, t, cond, batched_number):
|
||||
def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
|
||||
control_prev = None
|
||||
if self.previous_controlnet is not None:
|
||||
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
|
||||
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number, transformer_options)
|
||||
|
||||
if self.timestep_range is not None:
|
||||
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
|
||||
@@ -634,7 +812,7 @@ class T2IAdapter(ControlBase):
|
||||
self.copy_to(c)
|
||||
return c
|
||||
|
||||
def load_t2i_adapter(t2i_data):
|
||||
def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
|
||||
compression_ratio = 8
|
||||
upscale_algorithm = 'nearest-exact'
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ class NoiseScheduleVP:
|
||||
continuous_beta_0=0.1,
|
||||
continuous_beta_1=20.,
|
||||
):
|
||||
"""Create a wrapper class for the forward SDE (VP type).
|
||||
r"""Create a wrapper class for the forward SDE (VP type).
|
||||
|
||||
***
|
||||
Update: We support discrete-time diffusion models by implementing a piecewise linear interpolation for log_alpha_t.
|
||||
|
||||
@@ -1,7 +1,17 @@
|
||||
import torch
|
||||
|
||||
def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None):
    """Return the randomly rounded mantissa fraction for each element of ``abs_x``.

    Elements selected by ``normal_mask`` use the normal-number scaling
    (leading one removed); the rest use the subnormal scaling. Uniform noise
    in [0, 1) drawn from ``generator`` is added before flooring, and the
    result is divided by ``2**MANTISSA_BITS``.
    """
    mantissa_steps = 2**MANTISSA_BITS

    normal_scaled = (abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0) * mantissa_steps
    subnormal_scaled = abs_x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS))
    scaled = torch.where(normal_mask, normal_scaled, subnormal_scaled)

    # Adding noise then flooring rounds up with probability equal to the fractional part.
    noise = torch.rand(
        scaled.size(),
        dtype=scaled.dtype,
        layout=scaled.layout,
        device=scaled.device,
        generator=generator,
    )
    scaled += noise
    return scaled.floor() / mantissa_steps
|
||||
|
||||
#Not 100% sure about this
|
||||
def manual_stochastic_round_to_float8(x, dtype):
|
||||
def manual_stochastic_round_to_float8(x, dtype, generator=None):
|
||||
if dtype == torch.float8_e4m3fn:
|
||||
EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 4, 3, 7
|
||||
elif dtype == torch.float8_e5m2:
|
||||
@@ -9,44 +19,35 @@ def manual_stochastic_round_to_float8(x, dtype):
|
||||
else:
|
||||
raise ValueError("Unsupported dtype")
|
||||
|
||||
x = x.half()
|
||||
sign = torch.sign(x)
|
||||
abs_x = x.abs()
|
||||
sign = torch.where(abs_x == 0, 0, sign)
|
||||
|
||||
# Combine exponent calculation and clamping
|
||||
exponent = torch.clamp(
|
||||
torch.floor(torch.log2(abs_x)).to(torch.int32) + EXPONENT_BIAS,
|
||||
torch.floor(torch.log2(abs_x)) + EXPONENT_BIAS,
|
||||
0, 2**EXPONENT_BITS - 1
|
||||
)
|
||||
|
||||
# Combine mantissa calculation and rounding
|
||||
# min_normal = 2.0 ** (-EXPONENT_BIAS + 1)
|
||||
# zero_mask = (abs_x == 0)
|
||||
# subnormal_mask = (exponent == 0) & (abs_x != 0)
|
||||
normal_mask = ~(exponent == 0)
|
||||
|
||||
mantissa_scaled = torch.where(
|
||||
abs_x[:] = calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=generator)
|
||||
|
||||
sign *= torch.where(
|
||||
normal_mask,
|
||||
(abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0) * (2**MANTISSA_BITS),
|
||||
(abs_x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS)))
|
||||
)
|
||||
mantissa_floor = mantissa_scaled.floor()
|
||||
mantissa = torch.where(
|
||||
torch.rand_like(mantissa_scaled) < (mantissa_scaled - mantissa_floor),
|
||||
(mantissa_floor + 1) / (2**MANTISSA_BITS),
|
||||
mantissa_floor / (2**MANTISSA_BITS)
|
||||
)
|
||||
result = torch.where(
|
||||
normal_mask,
|
||||
sign * (2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + mantissa),
|
||||
sign * (2.0 ** (-EXPONENT_BIAS + 1)) * mantissa
|
||||
(2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + abs_x),
|
||||
(2.0 ** (-EXPONENT_BIAS + 1)) * abs_x
|
||||
)
|
||||
|
||||
result = torch.where(abs_x == 0, 0, result)
|
||||
return result.to(dtype=dtype)
|
||||
inf = torch.finfo(dtype)
|
||||
torch.clamp(sign, min=inf.min, max=inf.max, out=sign)
|
||||
return sign
|
||||
|
||||
|
||||
|
||||
def stochastic_rounding(value, dtype):
|
||||
def stochastic_rounding(value, dtype, seed=0):
|
||||
if dtype == torch.float32:
|
||||
return value.to(dtype=torch.float32)
|
||||
if dtype == torch.float16:
|
||||
@@ -54,6 +55,13 @@ def stochastic_rounding(value, dtype):
|
||||
if dtype == torch.bfloat16:
|
||||
return value.to(dtype=torch.bfloat16)
|
||||
if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
|
||||
return manual_stochastic_round_to_float8(value, dtype)
|
||||
generator = torch.Generator(device=value.device)
|
||||
generator.manual_seed(seed)
|
||||
output = torch.empty_like(value, dtype=dtype)
|
||||
num_slices = max(1, (value.numel() / (4096 * 4096)))
|
||||
slice_size = max(1, round(value.shape[0] / num_slices))
|
||||
for i in range(0, value.shape[0], slice_size):
|
||||
output[i:i+slice_size].copy_(manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator))
|
||||
return output
|
||||
|
||||
return value.to(dtype=dtype)
|
||||
|
||||
690
comfy/hooks.py
Normal file
690
comfy/hooks.py
Normal file
@@ -0,0 +1,690 @@
|
||||
from __future__ import annotations
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
import enum
|
||||
import math
|
||||
import torch
|
||||
import numpy as np
|
||||
import itertools
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_patcher import ModelPatcher, PatcherInjection
|
||||
from comfy.model_base import BaseModel
|
||||
from comfy.sd import CLIP
|
||||
import comfy.lora
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
from node_helpers import conditioning_set_values
|
||||
|
||||
class EnumHookMode(enum.Enum):
    # Two named modes identified by string values. How they are consumed is
    # not visible in this file section; presumably a memory-vs-speed
    # trade-off for applying hooks -- confirm against the users of this enum.
    MinVram = "minvram"
    MaxSpeed = "maxspeed"
|
||||
|
||||
class EnumHookType(enum.Enum):
    # Tag stored in Hook.hook_type; each Hook subclass in this file passes
    # its own member to Hook.__init__.
    Weight = "weight"
    Patch = "patch"
    ObjectPatch = "object_patch"
    AddModels = "add_models"
    Callbacks = "callbacks"
    Wrappers = "wrappers"
    # NOTE: the member is named "Set..." but its value is "add_injections".
    SetInjections = "add_injections"
|
||||
|
||||
class EnumWeightTarget(enum.Enum):
    # Passed as `target` to Hook.add_hook_patches / should_register to say
    # whether the model or the clip patcher is being processed.
    Model = "model"
    Clip = "clip"
|
||||
|
||||
class _HookRef:
    # Empty identity token. Hook stores one as `hook_ref`, shares it with its
    # clones, and bases __eq__/__hash__ on it.
    pass
|
||||
|
||||
# NOTE: this is an example of how the should_register function should look
|
||||
def default_should_register(hook: 'Hook', model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
    """Default `custom_should_register` callback: always allow registration.

    Hook.should_register calls the callback with exactly these five
    arguments, so replacements must accept the same signature.
    """
    return True
|
||||
|
||||
|
||||
class Hook:
    """Base class for hooks.

    Identity is defined by ``hook_ref``: clones share the reference, so a hook
    and its clones compare equal and hash identically.
    """

    def __init__(self, hook_type: EnumHookType=None, hook_ref: _HookRef=None, hook_id: str=None,
                 hook_keyframe: 'HookKeyframeGroup'=None):
        self.hook_type = hook_type
        # Falsy arguments are replaced by freshly created defaults.
        self.hook_ref = hook_ref or _HookRef()
        self.hook_id = hook_id
        self.hook_keyframe = hook_keyframe or HookKeyframeGroup()
        self.custom_should_register = default_should_register
        self.auto_apply_to_nonpositive = False

    @property
    def strength(self):
        """Strength reported by the attached keyframe group."""
        return self.hook_keyframe.strength

    def initialize_timesteps(self, model: 'BaseModel'):
        """Reset the hook, then let the keyframe group initialize its timesteps from ``model``."""
        self.reset()
        self.hook_keyframe.initialize_timesteps(model)

    def reset(self):
        """Reset the attached keyframe group."""
        self.hook_keyframe.reset()

    def clone(self, subtype: Callable=None):
        """Create a new ``subtype()`` (default: own type) and copy the base attributes onto it.

        ``hook_ref`` and ``hook_keyframe`` are shared with the clone, not copied.
        """
        factory = type(self) if subtype is None else subtype
        c: Hook = factory()
        c.hook_type = self.hook_type
        c.hook_ref = self.hook_ref
        c.hook_id = self.hook_id
        c.hook_keyframe = self.hook_keyframe
        c.custom_should_register = self.custom_should_register
        # TODO: make this do something
        c.auto_apply_to_nonpositive = self.auto_apply_to_nonpositive
        return c

    def should_register(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
        """Delegate the registration decision to ``custom_should_register``."""
        decide = self.custom_should_register
        return decide(self, model, model_options, target, registered)

    def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
        """Must be overridden by subclasses; the base implementation always raises."""
        raise NotImplementedError("add_hook_patches should be defined for Hook subclasses")

    def on_apply(self, model: 'ModelPatcher', transformer_options: dict[str]):
        """No-op in the base class."""
        pass

    def on_unapply(self, model: 'ModelPatcher', transformer_options: dict[str]):
        """No-op in the base class."""
        pass

    def __eq__(self, other: 'Hook'):
        if not (self.__class__ == other.__class__):
            return False
        return self.hook_ref == other.hook_ref

    def __hash__(self):
        return hash(self.hook_ref)
|
||||
|
||||
class WeightHook(Hook):
    """Hook carrying weight patches with separate model and clip strengths."""

    def __init__(self, strength_model=1.0, strength_clip=1.0):
        super().__init__(hook_type=EnumHookType.Weight)
        self.weights: dict = None
        self.weights_clip: dict = None
        # When True, `weights` is run through comfy.lora.load_lora on registration.
        self.need_weight_init = True
        self._strength_model = strength_model
        self._strength_clip = strength_clip

    @property
    def strength_model(self):
        """Base model strength scaled by the current keyframe strength."""
        return self._strength_model * self.strength

    @property
    def strength_clip(self):
        """Base clip strength scaled by the current keyframe strength."""
        return self._strength_clip * self.strength

    def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
        """Register this hook's patches on ``model``.

        Returns False without touching anything when ``should_register``
        declines; otherwise adds the patches, appends the hook to
        ``registered`` and returns True.
        """
        if not self.should_register(model, model_options, target, registered):
            return False

        for_model = target == EnumWeightTarget.Model
        # The un-scaled base strengths are used here, not the keyframe-scaled properties.
        strength = self._strength_model if for_model else self._strength_clip

        if self.need_weight_init:
            build_key_map = comfy.lora.model_lora_keys_unet if for_model else comfy.lora.model_lora_keys_clip
            key_map = build_key_map(model.model, {})
            weights = comfy.lora.load_lora(self.weights, key_map, log_missing=False)
        else:
            weights = self.weights if for_model else self.weights_clip

        model.add_hook_patches(hook=self, patches=weights, strength_patch=strength)
        registered.append(self)
        return True
        # TODO: add logs about any keys that were not applied

    def clone(self, subtype: Callable=None):
        """Clone the hook; weight dicts are shared with the clone, not copied."""
        c: WeightHook = super().clone(type(self) if subtype is None else subtype)
        c.weights = self.weights
        c.weights_clip = self.weights_clip
        c.need_weight_init = self.need_weight_init
        c._strength_model = self._strength_model
        c._strength_clip = self._strength_clip
        return c
|
||||
|
||||
class PatchHook(Hook):
    """Hook holding a ``patches`` dict (no registration logic implemented yet)."""

    def __init__(self):
        super().__init__(hook_type=EnumHookType.Patch)
        self.patches: dict = None

    def clone(self, subtype: Callable=None):
        """Clone the hook; the ``patches`` dict is shared with the clone."""
        c: PatchHook = super().clone(type(self) if subtype is None else subtype)
        c.patches = self.patches
        return c
|
||||
# TODO: add functionality
|
||||
|
||||
class ObjectPatchHook(Hook):
    """Hook holding an ``object_patches`` dict (no registration logic implemented yet)."""

    def __init__(self):
        super().__init__(hook_type=EnumHookType.ObjectPatch)
        self.object_patches: dict = None

    def clone(self, subtype: Callable=None):
        """Clone the hook; the ``object_patches`` dict is shared with the clone."""
        c: ObjectPatchHook = super().clone(type(self) if subtype is None else subtype)
        c.object_patches = self.object_patches
        return c
|
||||
# TODO: add functionality
|
||||
|
||||
class AddModelsHook(Hook):
    """Hook holding a key and a list of additional models."""

    def __init__(self, key: str=None, models: list['ModelPatcher']=None):
        super().__init__(hook_type=EnumHookType.AddModels)
        self.key = key
        self.models = models
        self.append_when_same = True

    def clone(self, subtype: Callable=None):
        """Clone the hook; a non-empty ``models`` list is shallow-copied."""
        c: AddModelsHook = super().clone(type(self) if subtype is None else subtype)
        c.key = self.key
        # Empty or None values are passed through as-is rather than copied.
        if self.models:
            c.models = self.models.copy()
        else:
            c.models = self.models
        c.append_when_same = self.append_when_same
        return c
    # TODO: add functionality
|
||||
# TODO: add functionality
|
||||
|
||||
class CallbackHook(Hook):
    """Hook holding a key and a single callback."""

    def __init__(self, key: str=None, callback: Callable=None):
        super().__init__(hook_type=EnumHookType.Callbacks)
        self.key = key
        self.callback = callback

    def clone(self, subtype: Callable=None):
        """Clone the hook, carrying over ``key`` and ``callback``."""
        c: CallbackHook = super().clone(type(self) if subtype is None else subtype)
        c.key = self.key
        c.callback = self.callback
        return c
    # TODO: add functionality
|
||||
# TODO: add functionality
|
||||
|
||||
class WrapperHook(Hook):
    """Hook that merges a nested wrappers dict into ``model_options``."""

    def __init__(self, wrappers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None):
        super().__init__(hook_type=EnumHookType.Wrappers)
        self.wrappers_dict = wrappers_dict

    def clone(self, subtype: Callable=None):
        """Clone the hook; ``wrappers_dict`` is shared with the clone."""
        c: WrapperHook = super().clone(type(self) if subtype is None else subtype)
        c.wrappers_dict = self.wrappers_dict
        return c

    def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
        """Merge the wrappers under ``"transformer_options"`` into ``model_options`` in place.

        Returns False when ``should_register`` declines, True otherwise.
        """
        if not self.should_register(model, model_options, target, registered):
            return False
        # copy_dict1=False: model_options itself is the merge target.
        comfy.patcher_extension.merge_nested_dicts(
            model_options,
            {"transformer_options": self.wrappers_dict},
            copy_dict1=False,
        )
        registered.append(self)
        return True
|
||||
|
||||
class SetInjectionsHook(Hook):
    """Hook holding a key and a list of injections."""

    def __init__(self, key: str=None, injections: list['PatcherInjection']=None):
        super().__init__(hook_type=EnumHookType.SetInjections)
        self.key = key
        self.injections = injections

    def clone(self, subtype: Callable=None):
        """Clone the hook; a non-empty ``injections`` list is shallow-copied."""
        c: SetInjectionsHook = super().clone(type(self) if subtype is None else subtype)
        c.key = self.key
        # Empty or None values are passed through as-is rather than copied.
        if self.injections:
            c.injections = self.injections.copy()
        else:
            c.injections = self.injections
        return c

    def add_hook_injections(self, model: 'ModelPatcher'):
        """Placeholder; currently does nothing."""
        # TODO: add functionality
        pass
|
||||
|
||||
class HookGroup:
    """Ordered collection of hooks without duplicates (by hook equality)."""

    def __init__(self):
        self.hooks: list[Hook] = []

    def add(self, hook: Hook):
        """Append ``hook`` unless an equal hook is already present."""
        if hook in self.hooks:
            return
        self.hooks.append(hook)

    def contains(self, hook: Hook):
        """Return True when an equal hook is in the group."""
        return hook in self.hooks

    def clone(self):
        """Return a new group containing clones of every hook."""
        c = HookGroup()
        for hook in self.hooks:
            c.add(hook.clone())
        return c

    def clone_and_combine(self, other: 'HookGroup'):
        """Return a clone of this group extended with clones of ``other``'s hooks (``other`` may be None)."""
        c = self.clone()
        if other is None:
            return c
        for hook in other.hooks:
            c.add(hook.clone())
        return c

    def set_keyframes_on_hooks(self, hook_kf: 'HookKeyframeGroup'):
        """Assign one shared keyframe group to all hooks: a clone of ``hook_kf``, or an empty group for None."""
        hook_kf = HookKeyframeGroup() if hook_kf is None else hook_kf.clone()
        for hook in self.hooks:
            hook.hook_keyframe = hook_kf

    def get_dict_repr(self):
        """Return ``{hook_type: {hook: None, ...}}`` grouping the hooks by type."""
        d: dict[EnumHookType, dict[Hook, None]] = {}
        for hook in self.hooks:
            d.setdefault(hook.hook_type, {})[hook] = None
        return d

    def get_hooks_for_clip_schedule(self):
        """Build, per percent range, the keyframe each weight hook uses.

        Returns a list of ``((start, end), [(hook, keyframe_or_None), ...])``.
        """
        scheduled_hooks: dict[WeightHook, list[tuple[tuple[float,float], HookKeyframe]]] = {}
        for hook in self.hooks:
            # only care about WeightHooks, for now
            if hook.hook_type != EnumHookType.Weight:
                continue
            hook_schedule = []
            keyframes = hook.hook_keyframe.keyframes
            # if no hook keyframes, assign default value
            if len(keyframes) == 0:
                hook_schedule.append(((0.0, 1.0), None))
                scheduled_hooks[hook] = hook_schedule
                continue
            # find ranges of values: a range is closed whenever a later keyframe changes the strength
            prev_keyframe = keyframes[0]
            for keyframe in keyframes:
                later = keyframe.start_percent > prev_keyframe.start_percent
                if later and not math.isclose(keyframe.strength, prev_keyframe.strength):
                    hook_schedule.append(((prev_keyframe.start_percent, keyframe.start_percent), prev_keyframe))
                    prev_keyframe = keyframe
                elif keyframe.start_percent == prev_keyframe.start_percent:
                    # same start: the keyframe listed last wins
                    prev_keyframe = keyframe
            # create final range, assuming last start_percent was not 1.0
            if not math.isclose(prev_keyframe.start_percent, 1.0):
                hook_schedule.append(((prev_keyframe.start_percent, 1.0), prev_keyframe))
            scheduled_hooks[hook] = hook_schedule

        # hooks now have their schedules in a list of tuples; collect every range
        all_ranges: list[tuple[float, float]] = [
            t_range
            for range_kfs in scheduled_hooks.values()
            for t_range, _ in range_kfs
        ]
        # turn list of ranges into boundaries
        boundaries_set = set(itertools.chain.from_iterable(all_ranges))
        boundaries_set.add(0.0)
        boundaries = sorted(boundaries_set)
        real_ranges = list(zip(boundaries, boundaries[1:]))

        # with real ranges defined, give appropriate hooks w/ keyframes for each range
        scheduled_keyframes: list[tuple[tuple[float,float], list[tuple[WeightHook, HookKeyframe]]]] = []
        for t_range in real_ranges:
            range_start, range_end = t_range
            hooks_schedule = []
            for hook, stored in scheduled_hooks.items():
                # first stored range overlapping the current range supplies the keyframe
                keyframe = next(
                    (stored_kf for stored_range, stored_kf in stored
                     if stored_range[0] < range_end and stored_range[1] > range_start),
                    None,
                )
                hooks_schedule.append((hook, keyframe))
            scheduled_keyframes.append((t_range, hooks_schedule))
        return scheduled_keyframes

    def reset(self):
        """Reset every hook in the group."""
        for hook in self.hooks:
            hook.reset()

    @staticmethod
    def combine_all_hooks(hooks_list: list['HookGroup'], require_count=0) -> 'HookGroup':
        """Combine the non-None groups of ``hooks_list`` into one group.

        Returns None when there are none, the single group itself (not a
        clone) when there is exactly one, and a combined clone otherwise.
        Raises Exception when fewer than ``require_count`` groups are given.
        """
        actual: list[HookGroup] = [group for group in hooks_list if group is not None]
        if len(actual) < require_count:
            raise Exception(f"Need at least {require_count} hooks to combine, but only had {len(actual)}.")
        # if no hooks, then return None
        if len(actual) == 0:
            return None
        # if only 1 hook, just return itself without cloning
        if len(actual) == 1:
            return actual[0]
        final_hook: HookGroup = actual[0].clone()
        for group in actual[1:]:
            final_hook = final_hook.clone_and_combine(group)
        return final_hook
|
||||
|
||||
|
||||
class HookKeyframe:
    """A strength value that starts applying at ``start_percent``."""

    def __init__(self, strength: float, start_percent=0.0, guarantee_steps=1):
        self.strength = strength
        # scheduling
        self.start_percent = float(start_percent)
        # Placeholder until HookKeyframeGroup.initialize_timesteps assigns the real value.
        self.start_t = 999999999.9
        self.guarantee_steps = guarantee_steps

    def clone(self):
        """Return an independent copy, including the current ``start_t``."""
        duplicate = HookKeyframe(
            strength=self.strength,
            start_percent=self.start_percent,
            guarantee_steps=self.guarantee_steps,
        )
        duplicate.start_t = self.start_t
        return duplicate
|
||||
|
||||
class HookKeyframeGroup:
    """Keyframes sorted by ``start_percent`` plus the state needed to step through them."""

    def __init__(self):
        self.keyframes: list[HookKeyframe] = []
        self._current_keyframe: HookKeyframe = None
        self._current_used_steps = 0
        self._current_index = 0
        self._current_strength = None
        # Last timestep passed to prepare_current_keyframe; -1. means "none yet".
        self._curr_t = -1.

    # properties shadow those of HookWeightsKeyframe
    @property
    def strength(self):
        """Strength of the current keyframe, or 1.0 when there is none."""
        if self._current_keyframe is not None:
            return self._current_keyframe.strength
        return 1.0

    def reset(self):
        """Return to the initial stepping state, with the first keyframe current."""
        self._current_keyframe = None
        self._current_used_steps = 0
        self._current_index = 0
        self._current_strength = None
        # Fix: this used to assign `self.curr_t`, a different attribute, so the
        # cached timestep read by prepare_current_keyframe was never cleared
        # and the first call after a reset could be skipped as a repeat.
        self._curr_t = -1.
        self._set_first_as_current()

    def add(self, keyframe: HookKeyframe):
        """Insert ``keyframe``, keep the list sorted by ``start_percent``, and make the first keyframe current."""
        # add to end of list, then sort
        self.keyframes.append(keyframe)
        self.keyframes = get_sorted_list_via_attr(self.keyframes, "start_percent")
        self._set_first_as_current()

    def _set_first_as_current(self):
        if len(self.keyframes) > 0:
            self._current_keyframe = self.keyframes[0]
        else:
            self._current_keyframe = None

    def has_index(self, index: int):
        """Return True when ``index`` is a valid position in ``keyframes``."""
        return index >= 0 and index < len(self.keyframes)

    def is_empty(self):
        """Return True when the group has no keyframes."""
        return len(self.keyframes) == 0

    def clone(self):
        """Return a new group holding clones of all keyframes, in fresh stepping state."""
        c = HookKeyframeGroup()
        for keyframe in self.keyframes:
            c.keyframes.append(keyframe.clone())
        c._set_first_as_current()
        return c

    def initialize_timesteps(self, model: 'BaseModel'):
        """Set every keyframe's ``start_t`` via ``model.model_sampling.percent_to_sigma``."""
        for keyframe in self.keyframes:
            keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent)

    def prepare_current_keyframe(self, curr_t: float) -> bool:
        """Advance to the keyframe that applies at ``curr_t``.

        Returns True only when both the keyframe index and the recorded
        strength changed; returns False for an empty group or when ``curr_t``
        equals the previously processed timestep.
        """
        if self.is_empty():
            return False
        if curr_t == self._curr_t:
            return False
        prev_index = self._current_index
        prev_strength = self._current_strength
        # if met guaranteed steps, look for next keyframe in case need to switch
        if self._current_used_steps >= self._current_keyframe.guarantee_steps:
            # if has next index, loop through and see if need to switch
            if self.has_index(self._current_index+1):
                for i in range(self._current_index+1, len(self.keyframes)):
                    eval_c = self.keyframes[i]
                    # check if start_t is greater or equal to curr_t
                    # NOTE: t is in terms of sigmas, not percent, so bigger number = earlier step in sampling
                    if eval_c.start_t >= curr_t:
                        self._current_index = i
                        self._current_strength = eval_c.strength
                        self._current_keyframe = eval_c
                        self._current_used_steps = 0
                        # if guarantee_steps greater than zero, stop searching for other keyframes
                        if self._current_keyframe.guarantee_steps > 0:
                            break
                    # if eval_c is outside the percent range, stop looking further
                    else:
                        break
        # update steps current context is used
        self._current_used_steps += 1
        # update current timestep this was performed on
        self._curr_t = curr_t
        # return True if keyframe changed, False if no change
        return prev_index != self._current_index and prev_strength != self._current_strength
|
||||
|
||||
|
||||
class InterpolationMethod:
    """Names of the supported interpolation curves and a helper to sample them."""

    LINEAR = "linear"
    EASE_IN = "ease_in"
    EASE_OUT = "ease_out"
    EASE_IN_OUT = "ease_in_out"

    _LIST = [LINEAR, EASE_IN, EASE_OUT, EASE_IN_OUT]

    @classmethod
    def get_weights(cls, num_from: float, num_to: float, length: int, method: str, reverse=False):
        """Return ``length`` weights going from ``num_from`` to ``num_to``.

        Args:
            num_from: First value of the curve.
            num_to: Last value of the curve.
            length: Number of samples.
            method: One of the curve names defined on this class.
            reverse: When True, the resulting weights are flipped.

        Returns:
            A 1-D ``torch.Tensor`` of ``length`` weights.

        Raises:
            ValueError: If ``method`` is not a recognized curve name.
        """
        diff = num_to - num_from
        # The easing curves previously applied NumPy ufuncs to torch tensors,
        # relying on an implicit tensor -> ndarray -> tensor round-trip. The
        # same formulas are now computed with torch directly.
        if method == cls.LINEAR:
            weights = torch.linspace(num_from, num_to, length)
        elif method == cls.EASE_IN:
            index = torch.linspace(0, 1, length)
            weights = diff * torch.pow(index, 2) + num_from
        elif method == cls.EASE_OUT:
            index = torch.linspace(0, 1, length)
            weights = diff * (1 - torch.pow(1 - index, 2)) + num_from
        elif method == cls.EASE_IN_OUT:
            index = torch.linspace(0, 1, length)
            weights = diff * ((1 - torch.cos(index * math.pi)) / 2) + num_from
        else:
            raise ValueError(f"Unrecognized interpolation method '{method}'.")
        if reverse:
            weights = weights.flip(dims=(0,))
        return weights
|
||||
|
||||
def get_sorted_list_via_attr(objects: list, attr: str) -> list:
    """Return ``objects`` sorted ascending by the attribute named ``attr``.

    Objects with equal attribute values keep their relative order. A falsy
    input (``None`` or an empty list) is returned as-is; otherwise a new list
    is returned and the input is left untouched.
    """
    if not objects:
        return objects
    # sorted() is guaranteed stable, which is exactly the "group equal values
    # in original order, then order the groups" rule the previous hand-rolled
    # implementation enforced. It also no longer requires the attribute
    # values to be hashable.
    return sorted(objects, key=lambda obj: getattr(obj, attr))
|
||||
|
||||
def create_hook_lora(lora: dict[str, torch.Tensor], strength_model: float, strength_clip: float):
    """Wrap ``lora`` in a single WeightHook and return the HookGroup containing it."""
    lora_hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip)
    lora_hook.weights = lora

    group = HookGroup()
    group.add(lora_hook)
    return group
|
||||
|
||||
def create_hook_model_as_lora(weights_model, weights_clip, strength_model: float, strength_clip: float):
    """Create a HookGroup whose single WeightHook applies full weights as ``"model_as_lora"`` patches.

    Either weights mapping may be None, in which case the matching patch
    dict on the hook is None as well.
    """
    def as_patches(weights):
        # Each entry becomes ("model_as_lora", (weight,)); None stays None.
        if weights is None:
            return None
        return {key: ("model_as_lora", (weights[key],)) for key in weights}

    group = HookGroup()
    hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip)
    group.add(hook)

    hook.weights = as_patches(weights_model)
    hook.weights_clip = as_patches(weights_clip)
    # Patches are already in final form, so the lora key mapping step is skipped.
    hook.need_weight_init = False
    return group
|
||||
|
||||
def get_patch_weights_from_model(model: 'ModelPatcher', discard_model_sampling=True):
    """Return ``model.model.state_dict()``, optionally without ``model_sampling*`` keys.

    Returns None when ``model`` is None. The filtering removes entries from
    the returned mapping in place.
    """
    if model is None:
        return None

    state = model.model.state_dict()
    if not discard_model_sampling:
        return state

    # do not include ANY model_sampling components of the model that should act as a patch
    sampling_keys = [name for name in state.keys() if name.startswith("model_sampling")]
    for name in sampling_keys:
        state.pop(name, None)
    return state
|
||||
|
||||
# NOTE: this function shows how to register weight hooks directly on the ModelPatchers
|
||||
def load_hook_lora_for_models(model: 'ModelPatcher', clip: 'CLIP', lora: dict[str, torch.Tensor],
                              strength_model: float, strength_clip: float):
    """Register ``lora`` as hook patches on clones of ``model`` and ``clip``.

    Either of ``model`` / ``clip`` may be None; the matching element of the
    result is then None too.

    Returns:
        ``(new_modelpatcher, new_clip, hook_group)`` where ``hook_group``
        contains the single WeightHook the patches were registered under.
    """
    key_map = {}
    if model is not None:
        key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
    if clip is not None:
        key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)

    hook_group = HookGroup()
    hook = WeightHook()
    hook_group.add(hook)
    loaded: dict[str] = comfy.lora.load_lora(lora, key_map)

    new_modelpatcher = None
    model_keys = ()
    if model is not None:
        new_modelpatcher = model.clone()
        model_keys = new_modelpatcher.add_hook_patches(hook=hook, patches=loaded, strength_patch=strength_model)

    new_clip = None
    clip_keys = ()
    if clip is not None:
        new_clip = clip.clone()
        clip_keys = new_clip.patcher.add_hook_patches(hook=hook, patches=loaded, strength_patch=strength_clip)

    # Report every loaded key that neither patcher accepted.
    model_keys = set(model_keys)
    clip_keys = set(clip_keys)
    for x in loaded:
        if not (x in model_keys or x in clip_keys):
            print(f"NOT LOADED {x}")
    return (new_modelpatcher, new_clip, hook_group)
|
||||
|
||||
def _combine_hooks_from_values(c_dict: dict[str, HookGroup], values: dict[str, HookGroup], cache: dict[tuple[HookGroup, HookGroup], HookGroup]):
    """Merge ``values['hooks']`` into ``c_dict['hooks']`` in place.

    Args:
        c_dict: dict that receives the resulting hooks entry.
        values: dict that may carry a ``'hooks'`` entry to merge in.
        cache: maps ``(existing_hooks, new_hooks)`` pairs to their combined
            result so the same pair is only combined once.
    """
    hooks_key = 'hooks'
    # if hooks only exist in one dict, do what's needed so that it ends up in c_dict
    if hooks_key not in values:
        return
    if hooks_key not in c_dict:
        hooks_value = values[hooks_key]
        if hooks_value is not None:
            c_dict[hooks_key] = hooks_value
        return
    # otherwise, need to combine with minimum duplication via cache
    hooks_tuple = (c_dict[hooks_key], values[hooks_key])
    cached_hooks = cache.get(hooks_tuple, None)
    if cached_hooks is None:
        cached_hooks = hooks_tuple[0].clone_and_combine(hooks_tuple[1])
        cache[hooks_tuple] = cached_hooks
    c_dict[hooks_key] = cached_hooks
|
||||
|
||||
def conditioning_set_values_with_hooks(conditioning, values=None, append_hooks=True):
    """Return a copy of ``conditioning`` with ``values`` applied to each entry.

    Each entry is rebuilt as ``[t[0], t[1].copy()]``, so the input entries and
    their dicts are not modified (the copy of the dict is shallow).

    Args:
        conditioning: iterable of ``[first, options_dict]`` entries.
        values: keys/values to set on every entry's dict; ``None`` means none.
        append_hooks: when true, a ``'hooks'`` value is combined with any hooks
            already on the entry instead of replacing them.

    Returns:
        A new list of ``[first, options_dict]`` entries.
    """
    if values is None:
        values = {}
    c = []
    # shared across entries so identical hook pairs are combined only once
    hooks_combine_cache: dict[tuple[HookGroup, HookGroup], HookGroup] = {}
    for t in conditioning:
        n = [t[0], t[1].copy()]
        for k in values:
            if append_hooks and k == 'hooks':
                _combine_hooks_from_values(n[1], values, hooks_combine_cache)
            else:
                n[1][k] = values[k]
        c.append(n)

    return c
|
||||
|
||||
def set_hooks_for_conditioning(cond, hooks: HookGroup, append_hooks=True):
    """Attach ``hooks`` to ``cond`` under the ``'hooks'`` key.

    Returns ``cond`` itself when ``hooks`` is ``None``; otherwise the result of
    ``conditioning_set_values_with_hooks``.
    """
    if hooks is None:
        return cond
    return conditioning_set_values_with_hooks(cond, {'hooks': hooks}, append_hooks=append_hooks)
|
||||
|
||||
def set_timesteps_for_conditioning(cond, timestep_range: tuple[float,float]):
    """Set ``start_percent``/``end_percent`` on ``cond`` from ``timestep_range``.

    Returns ``cond`` itself when ``timestep_range`` is ``None``.
    """
    if timestep_range is None:
        return cond
    return conditioning_set_values(cond, {"start_percent": timestep_range[0],
                                          "end_percent": timestep_range[1]})
|
||||
|
||||
def set_mask_for_conditioning(cond, mask: torch.Tensor, set_cond_area: str, strength: float):
    """Set ``mask``, ``set_area_to_bounds`` and ``mask_strength`` on ``cond``.

    Returns ``cond`` itself when ``mask`` is ``None``. ``set_area_to_bounds``
    is true for any ``set_cond_area`` other than ``'default'``.
    """
    if mask is None:
        return cond
    set_area_to_bounds = set_cond_area != 'default'
    # masks with fewer than 3 dims get a leading dimension added
    if len(mask.shape) < 3:
        mask = mask.unsqueeze(0)
    return conditioning_set_values(cond, {'mask': mask,
                                          'set_area_to_bounds': set_area_to_bounds,
                                          'mask_strength': strength})
|
||||
|
||||
def combine_conditioning(conds: list):
    """Concatenate the entries of every conditioning in ``conds`` into one list."""
    return [entry for cond in conds for entry in cond]
|
||||
|
||||
def combine_with_new_conds(conds: list, new_conds: list):
    """Pairwise-combine ``conds`` with ``new_conds``.

    Pairs are formed with ``zip``, so extra items in the longer list are ignored.
    """
    return [combine_conditioning([c, new_c]) for c, new_c in zip(conds, new_conds)]
|
||||
|
||||
def set_conds_props(conds: list, strength: float, set_cond_area: str,
                    mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
    """Apply hooks, mask and timestep range to every conditioning in ``conds``.

    Each step is skipped by its helper when the corresponding argument is ``None``.

    Returns:
        A new list with one processed conditioning per input conditioning.
    """
    final_conds = []
    for c in conds:
        # first, apply lora_hook to conditioning, if provided
        c = set_hooks_for_conditioning(c, hooks, append_hooks=append_hooks)
        # next, apply mask to conditioning
        c = set_mask_for_conditioning(cond=c, mask=mask, strength=strength, set_cond_area=set_cond_area)
        # apply timesteps, if present
        c = set_timesteps_for_conditioning(cond=c, timestep_range=timesteps_range)
        # finally, store the processed conditioning
        final_conds.append(c)
    return final_conds
|
||||
|
||||
def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1.0, set_cond_area: str="default",
                                mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
    """Apply hooks, mask and timestep range to ``new_conds`` and append them to ``conds``.

    ``conds`` and ``new_conds`` are paired with ``zip``; only the new
    conditioning of each pair is modified before the pair is combined.

    Returns:
        A list with one combined conditioning per pair.
    """
    combined_conds = []
    for c, masked_c in zip(conds, new_conds):
        # first, apply lora_hook to new conditioning, if provided
        masked_c = set_hooks_for_conditioning(masked_c, hooks, append_hooks=append_hooks)
        # next, apply mask to new conditioning, if provided
        masked_c = set_mask_for_conditioning(cond=masked_c, mask=mask, set_cond_area=set_cond_area, strength=strength)
        # apply timesteps, if present
        masked_c = set_timesteps_for_conditioning(cond=masked_c, timestep_range=timesteps_range)
        # finally, combine with existing conditioning and store
        combined_conds.append(combine_conditioning([c, masked_c]))
    return combined_conds
|
||||
|
||||
def set_default_conds_and_combine(conds: list, new_conds: list,
                                  hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
    """Mark ``new_conds`` as default conditioning and append them to ``conds``.

    Each new conditioning gets optional hooks, the ``'default': True`` flag and
    an optional timestep range before being combined with its paired
    conditioning from ``conds`` (pairs are formed with ``zip``).

    Returns:
        A list with one combined conditioning per pair.
    """
    combined_conds = []
    for c, new_c in zip(conds, new_conds):
        # first, apply lora_hook to new conditioning, if provided
        new_c = set_hooks_for_conditioning(new_c, hooks, append_hooks=append_hooks)
        # next, add default_cond key to cond so that during sampling, it can be identified
        new_c = conditioning_set_values(new_c, {'default': True})
        # apply timesteps, if present
        new_c = set_timesteps_for_conditioning(cond=new_c, timestep_range=timesteps_range)
        # finally, combine with existing conditioning and store
        combined_conds.append(combine_conditioning([c, new_c]))
    return combined_conds
|
||||
@@ -44,6 +44,17 @@ def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
|
||||
return append_zero(sigmas)
|
||||
|
||||
|
||||
def get_sigmas_laplace(n, sigma_min, sigma_max, mu=0., beta=0.5, device='cpu'):
    """Constructs the noise schedule proposed by Tiankai et al. (2024).

    Args:
        n: number of sigmas to produce.
        sigma_min: lower clamp applied to the schedule.
        sigma_max: upper clamp applied to the schedule.
        mu: offset added to the log-sigma curve.
        beta: scale of the log-sigma curve.
        device: device on which the tensor is created.

    Returns:
        A 1-D tensor of ``n`` sigmas clamped to ``[sigma_min, sigma_max]``.
        Unlike ``get_sigmas_vp`` no trailing zero is appended.
    """
    epsilon = 1e-5  # avoid log(0)
    x = torch.linspace(0, 1, n, device=device)
    lmb = mu - beta * torch.sign(0.5 - x) * torch.log(1 - 2 * torch.abs(0.5 - x) + epsilon)
    return torch.clamp(torch.exp(lmb), min=sigma_min, max=sigma_max)
|
||||
|
||||
|
||||
|
||||
def to_d(x, sigma, denoised):
    """Converts a denoiser output to a Karras ODE derivative."""
    # sigma is expanded with utils.append_dims to x.ndim dimensions before dividing
    return (x - denoised) / utils.append_dims(sigma, x.ndim)
|
||||
@@ -153,6 +164,8 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None,
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
if isinstance(model.inner_model.inner_model.model_sampling, comfy.model_sampling.CONST):
|
||||
return sample_euler_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
|
||||
"""Ancestral sampling with Euler method steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
@@ -162,14 +175,42 @@ def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
|
||||
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
d = to_d(x, sigmas[i], denoised)
|
||||
# Euler method
|
||||
dt = sigma_down - sigmas[i]
|
||||
x = x + d * dt
|
||||
if sigmas[i + 1] > 0:
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
||||
|
||||
if sigma_down == 0:
|
||||
x = denoised
|
||||
else:
|
||||
d = to_d(x, sigmas[i], denoised)
|
||||
# Euler method
|
||||
dt = sigma_down - sigmas[i]
|
||||
x = x + d * dt + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
||||
return x
|
||||
|
||||
@torch.no_grad()
def sample_euler_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1., noise_sampler=None):
    """Ancestral sampling with Euler method steps.

    Variant that ``sample_euler_ancestral`` dispatches to when the model's
    ``model_sampling`` is ``comfy.model_sampling.CONST`` (the ``RF`` suffix
    presumably stands for rectified flow).

    Args:
        model: callable ``model(x, sigma_batch, **extra_args)`` returning the denoised sample.
        x: starting sample.
        sigmas: sequence of noise levels; one step is taken per adjacent pair.
        extra_args: extra keyword arguments forwarded to ``model``.
        callback: optional callable invoked once per step with a progress dict.
        disable: forwarded to ``trange`` to disable the progress bar.
        eta: amount of ancestral noise; no noise is added when ``eta`` is not positive.
        s_noise: multiplier on the injected noise.
        noise_sampler: callable ``(sigma, sigma_next) -> noise``; a default is built from ``x`` when omitted.

    Returns:
        The sample after the final step.
    """
    extra_args = {} if extra_args is None else extra_args
    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})

        if sigmas[i + 1] == 0:
            # final step: take the denoised prediction directly
            x = denoised
        else:
            downstep_ratio = 1 + (sigmas[i + 1] / sigmas[i] - 1) * eta
            sigma_down = sigmas[i + 1] * downstep_ratio
            alpha_ip1 = 1 - sigmas[i + 1]
            alpha_down = 1 - sigma_down
            renoise_coeff = (sigmas[i + 1]**2 - sigma_down**2 * alpha_ip1**2 / alpha_down**2)**0.5
            # Euler method
            sigma_down_i_ratio = sigma_down / sigmas[i]
            x = sigma_down_i_ratio * x + (1 - sigma_down_i_ratio) * denoised
            if eta > 0:
                x = (alpha_ip1 / alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
    return x
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
|
||||
@@ -244,6 +285,9 @@ def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None,
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
if isinstance(model.inner_model.inner_model.model_sampling, comfy.model_sampling.CONST):
|
||||
return sample_dpm_2_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
|
||||
|
||||
"""Ancestral sampling with DPM-Solver second-order steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
@@ -270,6 +314,38 @@ def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
||||
return x
|
||||
|
||||
@torch.no_grad()
def sample_dpm_2_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """Ancestral sampling with DPM-Solver second-order steps.

    Variant that ``sample_dpm_2_ancestral`` dispatches to when the model's
    ``model_sampling`` is ``comfy.model_sampling.CONST`` (the ``RF`` suffix
    presumably stands for rectified flow).

    Args:
        model: callable ``model(x, sigma_batch, **extra_args)`` returning the denoised sample.
        x: starting sample.
        sigmas: sequence of noise levels; one step is taken per adjacent pair.
        extra_args: extra keyword arguments forwarded to ``model``.
        callback: optional callable invoked once per step with a progress dict.
        disable: forwarded to ``trange`` to disable the progress bar.
        eta: controls how far below ``sigmas[i + 1]`` the deterministic step goes.
        s_noise: multiplier on the injected noise.
        noise_sampler: callable ``(sigma, sigma_next) -> noise``; a default is built from ``x`` when omitted.

    Returns:
        The sample after the final step.
    """
    extra_args = {} if extra_args is None else extra_args
    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta
        sigma_down = sigmas[i+1] * downstep_ratio
        alpha_ip1 = 1 - sigmas[i+1]
        alpha_down = 1 - sigma_down
        renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5

        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        d = to_d(x, sigmas[i], denoised)
        if sigma_down == 0:
            # Euler method
            dt = sigma_down - sigmas[i]
            x = x + d * dt
        else:
            # DPM-Solver-2
            # midpoint is taken in log-sigma space
            sigma_mid = sigmas[i].log().lerp(sigma_down.log(), 0.5).exp()
            dt_1 = sigma_mid - sigmas[i]
            dt_2 = sigma_down - sigmas[i]
            x_2 = x + d * dt_1
            denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
            d_2 = to_d(x_2, sigma_mid, denoised_2)
            x = x + d_2 * dt_2
            # NOTE(review): nesting reconstructed from a flattened listing; the
            # renoise step is placed inside this branch — confirm against upstream
            x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
    return x
|
||||
|
||||
def linear_multistep_coeff(order, t, i, j):
|
||||
if order - 1 > i:
|
||||
@@ -1069,7 +1145,6 @@ def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disabl
|
||||
d = to_d(x, sigma_hat, temp[0])
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
|
||||
dt = sigmas[i + 1] - sigma_hat
|
||||
# Euler method
|
||||
x = denoised + d * sigmas[i + 1]
|
||||
return x
|
||||
@@ -1096,8 +1171,81 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
d = to_d(x, sigmas[i], temp[0])
|
||||
# Euler method
|
||||
dt = sigma_down - sigmas[i]
|
||||
x = denoised + d * sigma_down
|
||||
if sigmas[i + 1] > 0:
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
||||
return x
|
||||
@torch.no_grad()
def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """Ancestral sampling with DPM-Solver++(2S) second-order steps.

    A post-CFG function is registered in ``extra_args["model_options"]`` to
    capture the unconditional denoised prediction of each model call, which
    is then used in the update formulas.

    Args:
        model: callable ``model(x, sigma_batch, **extra_args)`` returning the denoised sample.
        x: starting sample.
        sigmas: sequence of noise levels; one step is taken per adjacent pair.
        extra_args: extra keyword arguments forwarded to ``model``.
        callback: optional callable invoked once per step with a progress dict.
        disable: forwarded to ``trange`` to disable the progress bar.
        eta: forwarded to ``get_ancestral_step``.
        s_noise: multiplier on the injected noise.
        noise_sampler: callable ``(sigma, sigma_next) -> noise``; a default is built from ``x`` when omitted.

    Returns:
        The sample after the final step.
    """
    extra_args = {} if extra_args is None else extra_args
    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler

    # one-element list used as a mutable cell written by the post-CFG callback
    temp = [0]
    def post_cfg_function(args):
        temp[0] = args["uncond_denoised"]
        return args["denoised"]

    # work on a copy of the caller's model_options before registering the callback
    model_options = extra_args.get("model_options", {}).copy()
    extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)

    s_in = x.new_ones([x.shape[0]])
    sigma_fn = lambda t: t.neg().exp()
    t_fn = lambda sigma: sigma.log().neg()

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigma_down == 0:
            # Euler method
            d = to_d(x, sigmas[i], temp[0])
            x = denoised + d * sigma_down
        else:
            # DPM-Solver++(2S)
            t, t_next = t_fn(sigmas[i]), t_fn(sigma_down)
            # r = torch.sinh(1 + (2 - eta) * (t_next - t) / (t - t_fn(sigma_up))) works only on non-cfgpp, weird
            r = 1 / 2
            h = t_next - t
            s = t + r * h
            x_2 = (sigma_fn(s) / sigma_fn(t)) * (x + (denoised - temp[0])) - (-h * r).expm1() * denoised
            denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
            x = (sigma_fn(t_next) / sigma_fn(t)) * (x + (denoised - temp[0])) - (-h).expm1() * denoised_2
        # Noise addition
        if sigmas[i + 1] > 0:
            x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
    return x
|
||||
|
||||
@torch.no_grad()
def sample_dpmpp_2m_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
    """DPM-Solver++(2M).

    A post-CFG function is registered in ``extra_args["model_options"]`` to
    capture the unconditional denoised prediction of each model call; the
    prediction of the previous step is kept for the second-order correction.

    Args:
        model: callable ``model(x, sigma_batch, **extra_args)`` returning the denoised sample.
        x: starting sample.
        sigmas: sequence of noise levels; one step is taken per adjacent pair.
        extra_args: extra keyword arguments forwarded to ``model``.
        callback: optional callable invoked once per step with a progress dict.
        disable: forwarded to ``trange`` to disable the progress bar.

    Returns:
        The sample after the final step.
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    t_fn = lambda sigma: sigma.log().neg()

    old_uncond_denoised = None
    uncond_denoised = None
    def post_cfg_function(args):
        nonlocal uncond_denoised
        uncond_denoised = args["uncond_denoised"]
        return args["denoised"]

    # work on a copy of the caller's model_options before registering the callback
    model_options = extra_args.get("model_options", {}).copy()
    extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
        h = t_next - t
        # first step (no history yet) and final step (sigma 0) skip the second-order term
        if old_uncond_denoised is None or sigmas[i + 1] == 0:
            denoised_mix = -torch.exp(-h) * uncond_denoised
        else:
            h_last = t - t_fn(sigmas[i - 1])
            r = h_last / h
            denoised_mix = -torch.exp(-h) * uncond_denoised - torch.expm1(-h) * (1 / (2 * r)) * (denoised - old_uncond_denoised)
        x = denoised + denoised_mix + torch.exp(-h) * x
        old_uncond_denoised = uncond_denoised
    return x
|
||||
|
||||
@@ -4,6 +4,7 @@ class LatentFormat:
|
||||
scale_factor = 1.0
|
||||
latent_channels = 4
|
||||
latent_rgb_factors = None
|
||||
latent_rgb_factors_bias = None
|
||||
taesd_decoder_name = None
|
||||
|
||||
def process_in(self, latent):
|
||||
@@ -30,11 +31,13 @@ class SDXL(LatentFormat):
|
||||
def __init__(self):
|
||||
self.latent_rgb_factors = [
|
||||
# R G B
|
||||
[ 0.3920, 0.4054, 0.4549],
|
||||
[-0.2634, -0.0196, 0.0653],
|
||||
[ 0.0568, 0.1687, -0.0755],
|
||||
[-0.3112, -0.2359, -0.2076]
|
||||
[ 0.3651, 0.4232, 0.4341],
|
||||
[-0.2533, -0.0042, 0.1068],
|
||||
[ 0.1076, 0.1111, -0.0362],
|
||||
[-0.3165, -0.2492, -0.2188]
|
||||
]
|
||||
self.latent_rgb_factors_bias = [ 0.1084, -0.0175, -0.0011]
|
||||
|
||||
self.taesd_decoder_name = "taesdxl_decoder"
|
||||
|
||||
class SDXL_Playground_2_5(LatentFormat):
|
||||
@@ -112,23 +115,24 @@ class SD3(LatentFormat):
|
||||
self.scale_factor = 1.5305
|
||||
self.shift_factor = 0.0609
|
||||
self.latent_rgb_factors = [
|
||||
[-0.0645, 0.0177, 0.1052],
|
||||
[ 0.0028, 0.0312, 0.0650],
|
||||
[ 0.1848, 0.0762, 0.0360],
|
||||
[ 0.0944, 0.0360, 0.0889],
|
||||
[ 0.0897, 0.0506, -0.0364],
|
||||
[-0.0020, 0.1203, 0.0284],
|
||||
[ 0.0855, 0.0118, 0.0283],
|
||||
[-0.0539, 0.0658, 0.1047],
|
||||
[-0.0057, 0.0116, 0.0700],
|
||||
[-0.0412, 0.0281, -0.0039],
|
||||
[ 0.1106, 0.1171, 0.1220],
|
||||
[-0.0248, 0.0682, -0.0481],
|
||||
[ 0.0815, 0.0846, 0.1207],
|
||||
[-0.0120, -0.0055, -0.0867],
|
||||
[-0.0749, -0.0634, -0.0456],
|
||||
[-0.1418, -0.1457, -0.1259]
|
||||
[-0.0922, -0.0175, 0.0749],
|
||||
[ 0.0311, 0.0633, 0.0954],
|
||||
[ 0.1994, 0.0927, 0.0458],
|
||||
[ 0.0856, 0.0339, 0.0902],
|
||||
[ 0.0587, 0.0272, -0.0496],
|
||||
[-0.0006, 0.1104, 0.0309],
|
||||
[ 0.0978, 0.0306, 0.0427],
|
||||
[-0.0042, 0.1038, 0.1358],
|
||||
[-0.0194, 0.0020, 0.0669],
|
||||
[-0.0488, 0.0130, -0.0268],
|
||||
[ 0.0922, 0.0988, 0.0951],
|
||||
[-0.0278, 0.0524, -0.0542],
|
||||
[ 0.0332, 0.0456, 0.0895],
|
||||
[-0.0069, -0.0030, -0.0810],
|
||||
[-0.0596, -0.0465, -0.0293],
|
||||
[-0.1448, -0.1463, -0.1189]
|
||||
]
|
||||
self.latent_rgb_factors_bias = [0.2394, 0.2135, 0.1925]
|
||||
self.taesd_decoder_name = "taesd3_decoder"
|
||||
|
||||
def process_in(self, latent):
|
||||
@@ -146,23 +150,24 @@ class Flux(SD3):
|
||||
self.scale_factor = 0.3611
|
||||
self.shift_factor = 0.1159
|
||||
self.latent_rgb_factors =[
|
||||
[-0.0404, 0.0159, 0.0609],
|
||||
[ 0.0043, 0.0298, 0.0850],
|
||||
[ 0.0328, -0.0749, -0.0503],
|
||||
[-0.0245, 0.0085, 0.0549],
|
||||
[ 0.0966, 0.0894, 0.0530],
|
||||
[ 0.0035, 0.0399, 0.0123],
|
||||
[ 0.0583, 0.1184, 0.1262],
|
||||
[-0.0191, -0.0206, -0.0306],
|
||||
[-0.0324, 0.0055, 0.1001],
|
||||
[ 0.0955, 0.0659, -0.0545],
|
||||
[-0.0504, 0.0231, -0.0013],
|
||||
[ 0.0500, -0.0008, -0.0088],
|
||||
[ 0.0982, 0.0941, 0.0976],
|
||||
[-0.1233, -0.0280, -0.0897],
|
||||
[-0.0005, -0.0530, -0.0020],
|
||||
[-0.1273, -0.0932, -0.0680]
|
||||
[-0.0346, 0.0244, 0.0681],
|
||||
[ 0.0034, 0.0210, 0.0687],
|
||||
[ 0.0275, -0.0668, -0.0433],
|
||||
[-0.0174, 0.0160, 0.0617],
|
||||
[ 0.0859, 0.0721, 0.0329],
|
||||
[ 0.0004, 0.0383, 0.0115],
|
||||
[ 0.0405, 0.0861, 0.0915],
|
||||
[-0.0236, -0.0185, -0.0259],
|
||||
[-0.0245, 0.0250, 0.1180],
|
||||
[ 0.1008, 0.0755, -0.0421],
|
||||
[-0.0515, 0.0201, 0.0011],
|
||||
[ 0.0428, -0.0012, -0.0036],
|
||||
[ 0.0817, 0.0765, 0.0749],
|
||||
[-0.1264, -0.0522, -0.1103],
|
||||
[-0.0280, -0.0881, -0.0499],
|
||||
[-0.1262, -0.0982, -0.0778]
|
||||
]
|
||||
self.latent_rgb_factors_bias = [-0.0329, -0.0718, -0.0851]
|
||||
self.taesd_decoder_name = "taef1_decoder"
|
||||
|
||||
def process_in(self, latent):
|
||||
@@ -170,3 +175,180 @@ class Flux(SD3):
|
||||
|
||||
def process_out(self, latent):
|
||||
return (latent / self.scale_factor) + self.shift_factor
|
||||
|
||||
class Mochi(LatentFormat):
    """Latent format with 12 channels, normalised by per-channel mean and std."""
    latent_channels = 12

    def __init__(self):
        self.scale_factor = 1.0
        # per-channel statistics, reshaped to (1, latent_channels, 1, 1, 1) for broadcasting
        self.latents_mean = torch.tensor([-0.06730895953510081, -0.038011381506090416, -0.07477820912866141,
                                          -0.05565264470995561, 0.012767231469026969, -0.04703542746246419,
                                          0.043896967884726704, -0.09346305707025976, -0.09918314763016893,
                                          -0.008729793427399178, -0.011931556316503654, -0.0321993391887285]).view(1, self.latent_channels, 1, 1, 1)
        self.latents_std = torch.tensor([0.9263795028493863, 0.9248894543193766, 0.9393059390890617,
                                         0.959253732819592, 0.8244560132752793, 0.917259975397747,
                                         0.9294154431013696, 1.3720942357788521, 0.881393668867029,
                                         0.9168315692124348, 0.9185249279345552, 0.9274757570805041]).view(1, self.latent_channels, 1, 1, 1)

        # one [R, G, B] row per latent channel
        self.latent_rgb_factors =[
            [-0.0069, -0.0045,  0.0018],
            [ 0.0154, -0.0692, -0.0274],
            [ 0.0333,  0.0019,  0.0206],
            [-0.1390,  0.0628,  0.1678],
            [-0.0725,  0.0134, -0.1898],
            [ 0.0074, -0.0270, -0.0209],
            [-0.0176, -0.0277, -0.0221],
            [ 0.5294,  0.5204,  0.3852],
            [-0.0326, -0.0446, -0.0143],
            [-0.0659,  0.0153, -0.0153],
            [ 0.0185, -0.0217,  0.0014],
            [-0.0396, -0.0495, -0.0281]
        ]
        self.latent_rgb_factors_bias = [-0.0940, -0.1418, -0.1453]
        self.taesd_decoder_name = None #TODO

    def process_in(self, latent):
        """Return ``(latent - mean) * scale_factor / std``, with statistics cast to the latent's device and dtype."""
        latents_mean = self.latents_mean.to(latent.device, latent.dtype)
        latents_std = self.latents_std.to(latent.device, latent.dtype)
        return (latent - latents_mean) * self.scale_factor / latents_std

    def process_out(self, latent):
        """Inverse of ``process_in``: return ``latent * std / scale_factor + mean``."""
        latents_mean = self.latents_mean.to(latent.device, latent.dtype)
        latents_std = self.latents_std.to(latent.device, latent.dtype)
        return latent * latents_std / self.scale_factor + latents_mean
|
||||
|
||||
class LTXV(LatentFormat):
    """Latent format with 128 channels; only RGB factor data is defined here."""
    latent_channels = 128
    def __init__(self):
        # one [R, G, B] row per latent channel (128 rows)
        self.latent_rgb_factors = [
            [ 1.1202e-02, -6.3815e-04, -1.0021e-02],
            [ 8.6031e-02,  6.5813e-02,  9.5409e-04],
            [-1.2576e-02, -7.5734e-03, -4.0528e-03],
            [ 9.4063e-03, -2.1688e-03,  2.6093e-03],
            [ 3.7636e-03,  1.2765e-02,  9.1548e-03],
            [ 2.1024e-02, -5.2973e-03,  3.4373e-03],
            [-8.8896e-03, -1.9703e-02, -1.8761e-02],
            [-1.3160e-02, -1.0523e-02,  1.9709e-03],
            [-1.5152e-03, -6.9891e-03, -7.5810e-03],
            [-1.7247e-03,  4.6560e-04, -3.3839e-03],
            [ 1.3617e-02,  4.7077e-03, -2.0045e-03],
            [ 1.0256e-02,  7.7318e-03,  1.3948e-02],
            [-1.6108e-02, -6.2151e-03,  1.1561e-03],
            [ 7.3407e-03,  1.5628e-02,  4.4865e-04],
            [ 9.5357e-04, -2.9518e-03, -1.4760e-02],
            [ 1.9143e-02,  1.0868e-02,  1.2264e-02],
            [ 4.4575e-03,  3.6682e-05, -6.8508e-03],
            [-4.5681e-04,  3.2570e-03,  7.7929e-03],
            [ 3.3902e-02,  3.3405e-02,  3.7454e-02],
            [-2.3001e-02, -2.4877e-03, -3.1033e-03],
            [ 5.0265e-02,  3.8841e-02,  3.3539e-02],
            [-4.1018e-03, -1.1095e-03,  1.5859e-03],
            [-1.2689e-01, -1.3107e-01, -2.1005e-01],
            [ 2.6276e-02,  1.4189e-02, -3.5963e-03],
            [-4.8679e-03,  8.8486e-03,  7.8029e-03],
            [-1.6610e-03, -4.8597e-03, -5.2060e-03],
            [-2.1010e-03,  2.3610e-03,  9.3796e-03],
            [-2.2482e-02, -2.1305e-02, -1.5087e-02],
            [-1.5753e-02, -1.0646e-02, -6.5083e-03],
            [-4.6975e-03,  5.0288e-03, -6.7390e-03],
            [ 1.1951e-02,  2.0712e-02,  1.6191e-02],
            [-6.3704e-03, -8.4827e-03, -9.5483e-03],
            [ 7.2610e-03, -9.9326e-03, -2.2978e-02],
            [-9.1904e-04,  6.2882e-03,  9.5720e-03],
            [-3.7178e-02, -3.7123e-02, -5.6713e-02],
            [-1.3373e-01, -1.0720e-01, -5.3801e-02],
            [-5.3702e-03,  8.1256e-03,  8.8397e-03],
            [-1.5247e-01, -2.1437e-01, -2.1843e-01],
            [ 3.1441e-02,  7.0335e-03, -9.7541e-03],
            [ 2.1528e-03, -8.9817e-03, -2.1023e-02],
            [ 3.8461e-03, -5.8957e-03, -1.5014e-02],
            [-4.3470e-03, -1.2940e-02, -1.5972e-02],
            [-5.4781e-03, -1.0842e-02, -3.0204e-03],
            [-6.5347e-03,  3.0806e-03, -1.0163e-02],
            [-5.0414e-03, -7.1503e-03, -8.9686e-04],
            [-8.5851e-03, -2.4351e-03,  1.0674e-03],
            [-9.0016e-03, -9.6493e-03,  1.5692e-03],
            [ 5.0914e-03,  1.2099e-02,  1.9968e-02],
            [ 1.3758e-02,  1.1669e-02,  8.1958e-03],
            [-1.0518e-02, -1.1575e-02, -4.1307e-03],
            [-2.8410e-02, -3.1266e-02, -2.2149e-02],
            [ 2.9336e-03,  3.6511e-02,  1.8717e-02],
            [-1.6703e-02, -1.6696e-02, -4.4529e-03],
            [ 4.8818e-02,  4.0063e-02,  8.7410e-03],
            [-1.5066e-02, -5.7328e-04,  2.9785e-03],
            [-1.7613e-02, -8.1034e-03,  1.3086e-02],
            [-9.2633e-03,  1.0803e-02, -6.3489e-03],
            [ 3.0851e-03,  4.7750e-04,  1.2347e-02],
            [-2.2785e-02, -2.3043e-02, -2.6005e-02],
            [-2.4787e-02, -1.5389e-02, -2.2104e-02],
            [-2.3572e-02,  1.0544e-03,  1.2361e-02],
            [-7.8915e-03, -1.2271e-03, -6.0968e-03],
            [-1.1478e-02, -1.2543e-03,  6.2679e-03],
            [-5.4229e-02,  2.6644e-02,  6.3394e-03],
            [ 4.4216e-03, -7.3338e-03, -1.0464e-02],
            [-4.5013e-03,  1.6082e-03,  1.4420e-02],
            [ 1.3673e-02,  8.8877e-03,  4.1253e-03],
            [-1.0145e-02,  9.0072e-03,  1.5695e-02],
            [-5.6234e-03,  1.1847e-03,  8.1261e-03],
            [-3.7171e-03, -5.3538e-03,  1.2590e-03],
            [ 2.9476e-02,  2.1424e-02,  3.0424e-02],
            [-3.4925e-02, -2.4340e-02, -2.5316e-02],
            [-3.4127e-02, -2.2406e-02, -1.0589e-02],
            [-1.7342e-02, -1.3249e-02, -1.0719e-02],
            [-2.1478e-03, -8.6051e-03, -2.9878e-03],
            [ 1.2089e-03, -4.2391e-03, -6.8569e-03],
            [ 9.0411e-04, -6.6886e-03, -6.7547e-05],
            [ 1.6048e-02, -1.0057e-02, -2.8929e-02],
            [ 1.2290e-03,  1.0163e-02,  1.8861e-02],
            [ 1.7264e-02,  2.7257e-04,  1.3785e-02],
            [-1.3482e-02, -3.6427e-03,  6.7481e-04],
            [ 4.6782e-03, -5.2423e-03,  2.4467e-03],
            [-5.9113e-03, -6.2244e-03, -1.8162e-03],
            [ 1.5496e-02,  1.4582e-02,  1.9514e-03],
            [ 7.4958e-03,  1.5886e-03, -8.2305e-03],
            [ 1.9086e-02,  1.6360e-03, -3.9674e-03],
            [-5.7021e-03, -2.7307e-03, -4.1066e-03],
            [ 1.7450e-03,  1.4602e-02,  2.5794e-02],
            [-8.2788e-04,  2.2902e-03,  4.5161e-03],
            [ 1.1632e-02,  8.9193e-03, -7.2813e-03],
            [ 7.5721e-03,  2.6784e-03,  1.1393e-02],
            [ 5.1939e-03,  3.6903e-03,  1.4049e-02],
            [-1.8383e-02, -2.2529e-02, -2.4477e-02],
            [ 5.8842e-04, -5.7874e-03, -1.4770e-02],
            [-1.6125e-02, -8.6101e-03, -1.4533e-02],
            [ 2.0540e-02,  2.0729e-02,  6.4338e-03],
            [ 3.3587e-03, -1.1226e-02, -1.6444e-02],
            [-1.4742e-03, -1.0489e-02,  1.7097e-03],
            [ 2.8130e-02,  2.3546e-02,  3.2791e-02],
            [-1.8532e-02, -1.2842e-02, -8.7756e-03],
            [-8.0533e-03, -1.0771e-02, -1.7536e-02],
            [-3.9009e-03,  1.6150e-02,  3.3359e-02],
            [-7.4554e-03, -1.4154e-02, -6.1910e-03],
            [ 3.4734e-03, -1.1370e-02, -1.0581e-02],
            [ 1.1476e-02,  3.9281e-03,  2.8231e-03],
            [ 7.1639e-03, -1.4741e-03, -3.8066e-03],
            [ 2.2250e-03, -8.7552e-03, -9.5719e-03],
            [ 2.4146e-02,  2.1696e-02,  2.8056e-02],
            [-5.4365e-03, -2.4291e-02, -1.7802e-02],
            [ 7.4263e-03,  1.0510e-02,  1.2705e-02],
            [ 6.2669e-03,  6.2658e-03,  1.9211e-02],
            [ 1.6378e-02,  9.4933e-03,  6.6971e-03],
            [ 1.7173e-02,  2.3601e-02,  2.3296e-02],
            [-1.4568e-02, -9.8279e-03, -1.1556e-02],
            [ 1.4431e-02,  1.4430e-02,  6.6362e-03],
            [-6.8230e-03,  1.8863e-02,  1.4555e-02],
            [ 6.1156e-03,  3.4700e-03, -2.6662e-03],
            [-2.6983e-03, -5.9402e-03, -9.2276e-03],
            [ 1.0235e-02,  7.4173e-03, -7.6243e-03],
            [-1.3255e-02,  1.9322e-02, -9.2153e-04],
            [ 2.4222e-03, -4.8039e-03, -1.5759e-02],
            [ 2.6244e-02,  2.5951e-02,  2.0249e-02],
            [ 1.5711e-02,  1.8498e-02,  2.7407e-03],
            [-2.1714e-03,  4.7214e-03, -2.2443e-02],
            [-7.4747e-03,  7.4166e-03,  1.4430e-02],
            [-8.3906e-03, -7.9776e-03,  9.7927e-03],
            [ 3.8321e-02,  9.6622e-03, -1.9268e-02],
            [-1.4605e-02, -6.7032e-03,  3.9675e-03]
        ]

        self.latent_rgb_factors_bias = [-0.0571, -0.1657, -0.2512]
|
||||
|
||||
@@ -612,7 +612,9 @@ class ContinuousTransformer(nn.Module):
|
||||
return_info = False,
|
||||
**kwargs
|
||||
):
|
||||
patches_replace = kwargs.get("transformer_options", {}).get("patches_replace", {})
|
||||
batch, seq, device = *x.shape[:2], x.device
|
||||
context = kwargs["context"]
|
||||
|
||||
info = {
|
||||
"hidden_states": [],
|
||||
@@ -643,9 +645,19 @@ class ContinuousTransformer(nn.Module):
|
||||
if self.use_sinusoidal_emb or self.use_abs_pos_emb:
|
||||
x = x + self.pos_emb(x)
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
# Iterate over the transformer layers
|
||||
for layer in self.layers:
|
||||
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
|
||||
for i, layer in enumerate(self.layers):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context)
|
||||
# x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
|
||||
|
||||
if return_info:
|
||||
@@ -874,7 +886,6 @@ class AudioDiffusionTransformer(nn.Module):
|
||||
mask=None,
|
||||
return_info=False,
|
||||
control=None,
|
||||
transformer_options={},
|
||||
**kwargs):
|
||||
return self._forward(
|
||||
x,
|
||||
|
||||
@@ -437,7 +437,8 @@ class MMDiT(nn.Module):
|
||||
pos_encoding = pos_encoding[:,from_h:from_h+h,from_w:from_w+w]
|
||||
return x + pos_encoding.reshape(1, -1, self.positional_encoding.shape[-1])
|
||||
|
||||
def forward(self, x, timestep, context, **kwargs):
|
||||
def forward(self, x, timestep, context, transformer_options={}, **kwargs):
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
# patchify x, add PE
|
||||
b, c, h, w = x.shape
|
||||
|
||||
@@ -458,15 +459,36 @@ class MMDiT(nn.Module):
|
||||
|
||||
global_cond = self.t_embedder(t, x.dtype) # B, D
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
if len(self.double_layers) > 0:
|
||||
for layer in self.double_layers:
|
||||
c, x = layer(c, x, global_cond, **kwargs)
|
||||
for i, layer in enumerate(self.double_layers):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["txt"], out["img"] = layer(args["txt"],
|
||||
args["img"],
|
||||
args["vec"])
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond}, {"original_block": block_wrap})
|
||||
c = out["txt"]
|
||||
x = out["img"]
|
||||
else:
|
||||
c, x = layer(c, x, global_cond, **kwargs)
|
||||
|
||||
if len(self.single_layers) > 0:
|
||||
c_len = c.size(1)
|
||||
cx = torch.cat([c, x], dim=1)
|
||||
for layer in self.single_layers:
|
||||
cx = layer(cx, global_cond, **kwargs)
|
||||
for i, layer in enumerate(self.single_layers):
|
||||
if ("single_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = layer(args["img"], args["vec"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond}, {"original_block": block_wrap})
|
||||
cx = out["img"]
|
||||
else:
|
||||
cx = layer(cx, global_cond, **kwargs)
|
||||
|
||||
x = cx[:, c_len:]
|
||||
|
||||
|
||||
@@ -1,8 +1,27 @@
|
||||
import torch
|
||||
import comfy.ops
|
||||
|
||||
def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
    """Pad the last two dimensions of ``img`` up to multiples of ``patch_size``.

    Padding is only ever added after the existing data (bottom rows and right
    columns); nothing is added when a dimension is already a multiple.

    Args:
        img: tensor whose last two dimensions are padded.
        patch_size: ``(height_multiple, width_multiple)``.
        padding_mode: mode forwarded to ``torch.nn.functional.pad``.

    Returns:
        The padded tensor.
    """
    # "circular" is swapped for "reflect" only while tracing/scripting. The
    # parentheses are required: without them `and` binds tighter than `or`,
    # so a scripted call would be forced to "reflect" whatever mode was asked.
    if padding_mode == "circular" and (torch.jit.is_tracing() or torch.jit.is_scripting()):
        padding_mode = "reflect"
    pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0]
    pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1]
    return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode)
|
||||
|
||||
# Native RMS norm kernel if this torch build provides one, otherwise None.
# getattr with a default replaces a bare `except:` that would also have hidden
# unrelated errors (including KeyboardInterrupt) raised during the lookup.
rms_norm_torch = getattr(torch.nn.functional, "rms_norm", None)
|
||||
|
||||
def rms_norm(x, weight=None, eps=1e-6):
    """RMS-normalize ``x`` over its last dimension, optionally scaled by ``weight``.

    Uses ``rms_norm_torch`` when it is available and the call is not being
    traced or scripted; otherwise computes ``x * rsqrt(mean(x**2) + eps)``
    directly. ``weight`` is cast to the dtype and device of ``x`` before use.
    """
    # Manual path: no native kernel, or running under torch.jit.
    if rms_norm_torch is None or torch.jit.is_tracing() or torch.jit.is_scripting():
        normed = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
        if weight is not None:
            normed = normed * comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device)
        return normed

    # Native path.
    if weight is None:
        return rms_norm_torch(x, (x.shape[-1],), eps=eps)
    scale = comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device)
    return rms_norm_torch(x, weight.shape, weight=scale, eps=eps)
|
||||
|
||||
205
comfy/ldm/flux/controlnet.py
Normal file
205
comfy/ldm/flux/controlnet.py
Normal file
@@ -0,0 +1,205 @@
|
||||
#Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
|
||||
#modified to support different types of flux controlnets
|
||||
|
||||
import torch
|
||||
import math
|
||||
from torch import Tensor, nn
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
|
||||
MLPEmbedder, SingleStreamBlock,
|
||||
timestep_embedding)
|
||||
|
||||
from .model import Flux
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
class MistolineCondDownsamplBlock(nn.Module):
    """Convolutional encoder for a 3-channel hint image.

    Ten convolutions (all producing 16 channels) separated by SiLU, three of
    which use stride 2, so height and width are reduced by a factor of 8.
    """

    def __init__(self, dtype=None, device=None, operations=None):
        super().__init__()
        # (in_channels, kernel_size, padding, stride) for each conv, in order.
        conv_specs = (
            (3, 3, 1, 1),
            (16, 1, 0, 1),
            (16, 3, 1, 1),
            (16, 3, 1, 2),
            (16, 3, 1, 1),
            (16, 3, 1, 2),
            (16, 3, 1, 1),
            (16, 3, 1, 2),
            (16, 1, 0, 1),
            (16, 3, 1, 1),
        )
        stages = []
        for in_ch, kernel, pad, stride in conv_specs:
            stages.append(operations.Conv2d(in_ch, 16, kernel, padding=pad, stride=stride, dtype=dtype, device=device))
            stages.append(nn.SiLU())
        # The final convolution is not followed by an activation.
        stages.pop()
        self.encoder = nn.Sequential(*stages)

    def forward(self, x):
        encoded = self.encoder(x)
        return encoded
|
||||
|
||||
class MistolineControlnetBlock(nn.Module):
    """Square linear projection (hidden_size -> hidden_size) followed by SiLU."""

    def __init__(self, hidden_size, dtype=None, device=None, operations=None):
        super().__init__()
        factory_kwargs = {"dtype": dtype, "device": device}
        self.linear = operations.Linear(hidden_size, hidden_size, **factory_kwargs)
        self.act = nn.SiLU()

    def forward(self, x):
        projected = self.linear(x)
        return self.act(projected)
|
||||
|
||||
|
||||
class ControlNetFlux(Flux):
    """Flux-based ControlNet producing per-block residuals.

    Runs the inherited double/single stream blocks on the image tokens plus an
    embedded control hint, projects each block's image output through a
    dedicated control block, and returns those projections in a dict with an
    ``"input"`` entry (double blocks) and, when single blocks exist, an
    ``"output"`` entry.

    Constructor flags:
        latent_input: the hint is already a latent; it is only padded and
            patchified instead of being passed through a hint encoder.
        num_union_modes: when > 0, an embedding of the control type is
            prepended to the text tokens.
        mistoline: use the Mistoline hint encoder and control blocks.
        control_latent_channels: channel count of the hint latent before
            2x2 patchification; defaults to the model's own ``in_channels``.
    """

    def __init__(self, latent_input=False, num_union_modes=0, mistoline=False, control_latent_channels=None, image_model=None, dtype=None, device=None, operations=None, **kwargs):
        super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)

        # Lengths the "input" / "output" residual tuples are expanded to.
        # NOTE(review): presumably the double/single block counts of the main
        # Flux model these residuals are fed into -- confirm.
        self.main_model_double = 19
        self.main_model_single = 38

        self.mistoline = mistoline
        # add ControlNet blocks
        if self.mistoline:
            control_block = lambda : MistolineControlnetBlock(self.hidden_size, dtype=dtype, device=device, operations=operations)
        else:
            control_block = lambda : operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)

        self.controlnet_blocks = nn.ModuleList([])
        for _ in range(self.params.depth):
            self.controlnet_blocks.append(control_block())

        self.controlnet_single_blocks = nn.ModuleList([])
        for _ in range(self.params.depth_single_blocks):
            self.controlnet_single_blocks.append(control_block())

        self.num_union_modes = num_union_modes
        self.controlnet_mode_embedder = None
        if self.num_union_modes > 0:
            self.controlnet_mode_embedder = operations.Embedding(self.num_union_modes, self.hidden_size, dtype=dtype, device=device)

        self.gradient_checkpointing = False
        self.latent_input = latent_input
        if control_latent_channels is None:
            control_latent_channels = self.in_channels
        else:
            control_latent_channels *= 2 * 2 #patch size

        self.pos_embed_input = operations.Linear(control_latent_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
        if not self.latent_input:
            if self.mistoline:
                self.input_cond_block = MistolineCondDownsamplBlock(dtype=dtype, device=device, operations=operations)
            else:
                self.input_hint_block = nn.Sequential(
                    operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
                    nn.SiLU(),
                    operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
                    nn.SiLU(),
                    operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
                    nn.SiLU(),
                    operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
                    nn.SiLU(),
                    operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
                    nn.SiLU(),
                    operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
                    nn.SiLU(),
                    operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
                    nn.SiLU(),
                    operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
                )

    def _repeat_to_depth(self, residuals, target_depth):
        """Expand a tuple of residuals to exactly ``target_depth`` entries.

        With ``latent_input`` each residual is repeated consecutively
        (a, a, b, b, ...); otherwise the whole sequence is tiled
        (a, b, a, b, ...). The result is truncated to ``target_depth``.
        An empty input yields an empty tuple instead of dividing by zero.
        """
        if len(residuals) == 0:
            return ()
        reps = math.ceil(target_depth / len(residuals))
        if self.latent_input:
            expanded = ()
            for residual in residuals:
                expanded += (residual,) * reps
        else:
            expanded = residuals * reps
        return expanded[:target_depth]

    def forward_orig(
        self,
        img: Tensor,
        img_ids: Tensor,
        controlnet_cond: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor = None,
        control_type: Tensor = None,
    ) -> Tensor:
        """Run the controlnet on already-patchified inputs.

        Args:
            img, txt: 3-dimensional token tensors (checked below).
            img_ids, txt_ids: position ids concatenated for the positional embedder.
            controlnet_cond: patchified hint tokens, added to the embedded image tokens.
            timesteps, y, guidance: conditioning folded into ``vec``.
            control_type: control mode indices for the union embedder; ``None``
                or an empty sequence disables the mode token.

        Returns:
            dict with ``"input"`` (tuple of double-block residuals) and, when
            single blocks exist, ``"output"`` (tuple of single-block residuals).

        Raises:
            ValueError: if ``img`` or ``txt`` is not 3-dimensional.
        """
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # running on sequences img
        img = self.img_in(img)

        controlnet_cond = self.pos_embed_input(controlnet_cond)
        img = img + controlnet_cond
        vec = self.time_in(timestep_embedding(timesteps, 256))
        if self.params.guidance_embed:
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        vec = vec + self.vector_in(y)
        txt = self.txt_in(txt)

        # control_type defaults to None, so it must be checked before len().
        if self.controlnet_mode_embedder is not None and control_type is not None and len(control_type) > 0:
            control_cond = self.controlnet_mode_embedder(torch.tensor(control_type, device=img.device), out_dtype=img.dtype).unsqueeze(0).repeat((txt.shape[0], 1, 1))
            txt = torch.cat([control_cond, txt], dim=1)
            # Reuse the first text position id for the prepended mode token.
            txt_ids = torch.cat([txt_ids[:,:1], txt_ids], dim=1)

        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        controlnet_double = ()

        for i in range(len(self.double_blocks)):
            img, txt = self.double_blocks[i](img=img, txt=txt, vec=vec, pe=pe)
            controlnet_double = controlnet_double + (self.controlnet_blocks[i](img),)

        img = torch.cat((txt, img), 1)

        controlnet_single = ()

        for i in range(len(self.single_blocks)):
            img = self.single_blocks[i](img, vec=vec, pe=pe)
            # Only the image part of the joint sequence is projected.
            controlnet_single = controlnet_single + (self.controlnet_single_blocks[i](img[:, txt.shape[1] :, ...]),)

        out = {"input": self._repeat_to_depth(controlnet_double, self.main_model_double)}
        if len(controlnet_single) > 0:
            out["output"] = self._repeat_to_depth(controlnet_single, self.main_model_single)
        return out

    def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
        """Prepare the hint and image tokens, then call ``forward_orig``.

        The hint is padded (latent input) or rescaled from ``hint * 2 - 1`` and
        encoded (image input), then patchified with 2x2 patches like ``x``.
        ``control_type`` is read from ``kwargs`` and defaults to an empty list.
        """
        patch_size = 2
        if self.latent_input:
            hint = comfy.ldm.common_dit.pad_to_patch_size(hint, (patch_size, patch_size))
        elif self.mistoline:
            hint = hint * 2.0 - 1.0
            hint = self.input_cond_block(hint)
        else:
            hint = hint * 2.0 - 1.0
            hint = self.input_hint_block(hint)

        hint = rearrange(hint, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)

        bs, c, h, w = x.shape
        x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))

        img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)

        # Patch-grid size after padding (rounds up for odd h / w).
        h_len = ((h + (patch_size // 2)) // patch_size)
        w_len = ((w + (patch_size // 2)) // patch_size)
        # Channel 1 holds the row index, channel 2 the column index; channel 0 stays zero.
        img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
        img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
        img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

        txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
        return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance, control_type=kwargs.get("control_type", []))
|
||||
@@ -1,104 +0,0 @@
|
||||
#Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
|
||||
MLPEmbedder, SingleStreamBlock,
|
||||
timestep_embedding)
|
||||
|
||||
from .model import Flux
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
|
||||
class ControlNetFlux(Flux):
    """Flux ControlNet: one linear projection per double block, driven by an image hint."""

    def __init__(self, image_model=None, dtype=None, device=None, operations=None, **kwargs):
        super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)

        # One hidden_size -> hidden_size projection per double block.
        self.controlnet_blocks = nn.ModuleList([])
        for _ in range(self.params.depth):
            self.controlnet_blocks.append(
                operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
            )
        self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
        self.gradient_checkpointing = False

        # Hint encoder: (in_channels, stride) per 3x3 conv; SiLU between convs, none after the last.
        hint_layers = []
        for in_ch, stride in ((3, 1), (16, 1), (16, 2), (16, 1), (16, 2), (16, 1), (16, 2), (16, 1)):
            hint_layers.append(operations.Conv2d(in_ch, 16, 3, padding=1, stride=stride, dtype=dtype, device=device))
            hint_layers.append(nn.SiLU())
        hint_layers.pop()
        self.input_hint_block = nn.Sequential(*hint_layers)

    def forward_orig(
        self,
        img: Tensor,
        img_ids: Tensor,
        controlnet_cond: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor = None,
    ) -> Tensor:
        """Return ``{"input": residuals}`` with the projected double-block outputs tiled to 19 entries."""
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # running on sequences img
        img = self.img_in(img)
        hint_tokens = self.input_hint_block(controlnet_cond)
        hint_tokens = rearrange(hint_tokens, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
        img = img + self.pos_embed_input(hint_tokens)

        vec = self.time_in(timestep_embedding(timesteps, 256))
        if self.params.guidance_embed:
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        vec = vec + self.vector_in(y)
        txt = self.txt_in(txt)

        pe = self.pe_embedder(torch.cat((txt_ids, img_ids), dim=1))

        # Collect the image stream after every double block.
        hidden_states = []
        for block in self.double_blocks:
            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
            hidden_states.append(img)

        residuals = tuple(
            project(state) for state, project in zip(hidden_states, self.controlnet_blocks)
        )

        return {"input": (residuals * 10)[:19]}

    def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
        """Rescale the hint, patchify ``x`` with 2x2 patches, build position ids and run ``forward_orig``."""
        hint = hint * 2.0 - 1.0

        bs, c, h, w = x.shape
        patch_size = 2
        x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))

        img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)

        half_patch = patch_size // 2
        h_len = (h + half_patch) // patch_size
        w_len = (w + half_patch) // patch_size

        row_index = torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)
        col_index = torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)
        img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
        img_ids[..., 1] = img_ids[..., 1] + row_index[:, None]
        img_ids[..., 2] = img_ids[..., 2] + col_index[None, :]
        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

        txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
        return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance)
|
||||
@@ -6,6 +6,7 @@ from torch import Tensor, nn
|
||||
|
||||
from .math import attention, rope
|
||||
import comfy.ops
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
|
||||
class EmbedND(nn.Module):
|
||||
@@ -63,10 +64,7 @@ class RMSNorm(torch.nn.Module):
|
||||
self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
|
||||
|
||||
def forward(self, x: Tensor):
|
||||
x_dtype = x.dtype
|
||||
x = x.float()
|
||||
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
|
||||
return (x * rrms).to(dtype=x_dtype) * comfy.ops.cast_to(self.scale, dtype=x_dtype, device=x.device)
|
||||
return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6)
|
||||
|
||||
|
||||
class QKNorm(torch.nn.Module):
|
||||
|
||||
@@ -20,6 +20,7 @@ import comfy.ldm.common_dit
|
||||
@dataclass
|
||||
class FluxParams:
|
||||
in_channels: int
|
||||
out_channels: int
|
||||
vec_in_dim: int
|
||||
context_in_dim: int
|
||||
hidden_size: int
|
||||
@@ -29,6 +30,7 @@ class FluxParams:
|
||||
depth_single_blocks: int
|
||||
axes_dim: list
|
||||
theta: int
|
||||
patch_size: int
|
||||
qkv_bias: bool
|
||||
guidance_embed: bool
|
||||
|
||||
@@ -43,8 +45,9 @@ class Flux(nn.Module):
|
||||
self.dtype = dtype
|
||||
params = FluxParams(**kwargs)
|
||||
self.params = params
|
||||
self.in_channels = params.in_channels * 2 * 2
|
||||
self.out_channels = self.in_channels
|
||||
self.patch_size = params.patch_size
|
||||
self.in_channels = params.in_channels * params.patch_size * params.patch_size
|
||||
self.out_channels = params.out_channels * params.patch_size * params.patch_size
|
||||
if params.hidden_size % params.num_heads != 0:
|
||||
raise ValueError(
|
||||
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
|
||||
@@ -96,7 +99,9 @@ class Flux(nn.Module):
|
||||
y: Tensor,
|
||||
guidance: Tensor = None,
|
||||
control=None,
|
||||
transformer_options={},
|
||||
) -> Tensor:
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
|
||||
@@ -108,14 +113,25 @@ class Flux(nn.Module):
|
||||
raise ValueError("Didn't get guidance strength for guidance distilled model.")
|
||||
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
|
||||
|
||||
vec = vec + self.vector_in(y)
|
||||
vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
|
||||
txt = self.txt_in(txt)
|
||||
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
for i, block in enumerate(self.double_blocks):
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe}, {"original_block": block_wrap})
|
||||
txt = out["txt"]
|
||||
img = out["img"]
|
||||
else:
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_i = control.get("input")
|
||||
@@ -127,7 +143,16 @@ class Flux(nn.Module):
|
||||
img = torch.cat((txt, img), 1)
|
||||
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
img = block(img, vec=vec, pe=pe)
|
||||
if ("single_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe}, {"original_block": block_wrap})
|
||||
img = out["img"]
|
||||
else:
|
||||
img = block(img, vec=vec, pe=pe)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_o = control.get("output")
|
||||
@@ -141,9 +166,9 @@ class Flux(nn.Module):
|
||||
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
||||
return img
|
||||
|
||||
def forward(self, x, timestep, context, y, guidance, control=None, **kwargs):
|
||||
def forward(self, x, timestep, context, y, guidance, control=None, transformer_options={}, **kwargs):
|
||||
bs, c, h, w = x.shape
|
||||
patch_size = 2
|
||||
patch_size = self.patch_size
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
||||
|
||||
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
||||
@@ -151,10 +176,10 @@ class Flux(nn.Module):
|
||||
h_len = ((h + (patch_size // 2)) // patch_size)
|
||||
w_len = ((w + (patch_size // 2)) // patch_size)
|
||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||
img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
|
||||
img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
|
||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control)
|
||||
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options)
|
||||
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
|
||||
|
||||
25
comfy/ldm/flux/redux.py
Normal file
25
comfy/ldm/flux/redux.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import torch
|
||||
import comfy.ops
|
||||
|
||||
# Operation namespace used to build the Linear layers in this module.
ops = comfy.ops.manual_cast
|
||||
|
||||
class ReduxImageEncoder(torch.nn.Module):
    """Two-layer projection from ``redux_dim`` features to ``txt_in_features``.

    The input is expanded to ``3 * txt_in_features``, passed through SiLU and
    projected down to ``txt_in_features``. ``device`` is stored on the module
    but only ``dtype`` is forwarded to the linear layers.
    """

    def __init__(
        self,
        redux_dim: int = 1152,
        txt_in_features: int = 4096,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__()

        self.redux_dim = redux_dim
        self.device = device
        self.dtype = dtype

        expanded_features = txt_in_features * 3
        self.redux_up = ops.Linear(redux_dim, expanded_features, dtype=dtype)
        self.redux_down = ops.Linear(expanded_features, txt_in_features, dtype=dtype)

    def forward(self, sigclip_embeds) -> torch.Tensor:
        hidden = self.redux_up(sigclip_embeds)
        hidden = torch.nn.functional.silu(hidden)
        return self.redux_down(hidden)
|
||||
559
comfy/ldm/genmo/joint_model/asymm_models_joint.py
Normal file
559
comfy/ldm/genmo/joint_model/asymm_models_joint.py
Normal file
@@ -0,0 +1,559 @@
|
||||
#original code from https://github.com/genmoai/models under apache 2.0 license
|
||||
#adapted to ComfyUI
|
||||
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
# from flash_attn import flash_attn_varlen_qkvpacked_func
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
from .layers import (
|
||||
FeedForward,
|
||||
PatchEmbed,
|
||||
RMSNorm,
|
||||
TimestepEmbedder,
|
||||
)
|
||||
|
||||
from .rope_mixed import (
|
||||
compute_mixed_rotation,
|
||||
create_position_matrix,
|
||||
)
|
||||
from .temporal_rope import apply_rotary_emb_qk_real
|
||||
from .utils import (
|
||||
AttentionPool,
|
||||
modulate,
|
||||
)
|
||||
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.ops
|
||||
|
||||
|
||||
def modulated_rmsnorm(x, scale, eps=1e-6):
    """RMS-normalize ``x`` (no learned weight) and multiply by ``1 + scale``.

    ``scale`` gets a new axis inserted at dim 1 so it broadcasts against ``x``.
    """
    return comfy.ldm.common_dit.rms_norm(x, eps=eps) * (1 + scale.unsqueeze(1))
|
||||
|
||||
|
||||
def residual_tanh_gated_rmsnorm(x, x_res, gate, eps=1e-6):
    """Add an RMS-normalized, tanh-gated residual to ``x``.

    Computes ``x + rms_norm(x_res) * tanh(gate)``, with the gate given a new
    axis at dim 1 so it broadcasts against ``x_res``.
    """
    gated_residual = comfy.ldm.common_dit.rms_norm(x_res, eps=eps) * torch.tanh(gate).unsqueeze(1)
    return x + gated_residual
|
||||
|
||||
class AsymmetricAttention(nn.Module):
    """Joint attention over visual tokens ``x`` and text tokens ``y``.

    Each stream has its own QKV projection and output projection (text is
    projected from ``dim_y`` to ``3 * dim_x`` and, when ``update_y`` is set,
    back down to ``dim_y``), but both streams attend together in a single
    attention call over the concatenated sequence.
    """

    def __init__(
        self,
        dim_x: int,
        dim_y: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        attn_drop: float = 0.0,
        update_y: bool = True,
        out_bias: bool = True,
        attend_to_padding: bool = False,
        softmax_scale: Optional[float] = None,
        device: Optional[torch.device] = None,
        dtype=None,
        operations=None,
    ):
        super().__init__()
        self.dim_x = dim_x
        self.dim_y = dim_y
        self.num_heads = num_heads
        self.head_dim = dim_x // num_heads
        # attn_drop, attend_to_padding and softmax_scale are stored here but
        # not referenced by forward() in this class.
        self.attn_drop = attn_drop
        self.update_y = update_y
        self.attend_to_padding = attend_to_padding
        self.softmax_scale = softmax_scale
        if dim_x % num_heads != 0:
            raise ValueError(
                f"dim_x={dim_x} should be divisible by num_heads={num_heads}"
            )

        # Input layers.
        self.qkv_bias = qkv_bias
        self.qkv_x = operations.Linear(dim_x, 3 * dim_x, bias=qkv_bias, device=device, dtype=dtype)
        # Project text features to match visual features (dim_y -> dim_x)
        self.qkv_y = operations.Linear(dim_y, 3 * dim_x, bias=qkv_bias, device=device, dtype=dtype)

        # Query and key normalization for stability.
        # NOTE(review): the default qk_norm=False fails this assert, so callers
        # must pass qk_norm=True explicitly.
        assert qk_norm
        self.q_norm_x = RMSNorm(self.head_dim, device=device, dtype=dtype)
        self.k_norm_x = RMSNorm(self.head_dim, device=device, dtype=dtype)
        self.q_norm_y = RMSNorm(self.head_dim, device=device, dtype=dtype)
        self.k_norm_y = RMSNorm(self.head_dim, device=device, dtype=dtype)

        # Output layers. y features go back down from dim_x -> dim_y.
        self.proj_x = operations.Linear(dim_x, dim_x, bias=out_bias, device=device, dtype=dtype)
        self.proj_y = (
            operations.Linear(dim_x, dim_y, bias=out_bias, device=device, dtype=dtype)
            if update_y
            else nn.Identity()
        )

    def forward(
        self,
        x: torch.Tensor,  # (B, N, dim_x)
        y: torch.Tensor,  # (B, L, dim_y)
        scale_x: torch.Tensor,  # (B, dim_x), modulation for pre-RMSNorm.
        scale_y: torch.Tensor,  # (B, dim_y), modulation for pre-RMSNorm.
        crop_y,
        **rope_rotation,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Attend jointly over visual and text tokens.

        Args:
            x: visual tokens.
            y: text tokens.
            scale_x, scale_y: modulation applied in the pre-attention RMS norm.
            crop_y: number of leading text tokens that take part in attention;
                output rows for the remaining text positions are zeros before
                ``proj_y`` is applied.
            **rope_rotation: ``rope_cos`` / ``rope_sin`` are read from here and
                applied to the visual queries and keys only.

        Returns:
            ``(x, y)`` after the respective output projections.
        """
        rope_cos = rope_rotation.get("rope_cos")
        rope_sin = rope_rotation.get("rope_sin")
        # Pre-norm for visual features
        x = modulated_rmsnorm(x, scale_x)  # (B, M, dim_x) where M = N / cp_group_size

        # Process visual features
        # qkv_x = self.qkv_x(x)  # (B, M, 3 * dim_x)
        # assert qkv_x.dtype == torch.bfloat16
        # qkv_x = all_to_all_collect_tokens(
        #     qkv_x, self.num_heads
        # )  # (3, B, N, local_h, head_dim)

        # Process text features
        y = modulated_rmsnorm(y, scale_y)  # (B, L, dim_y)
        q_y, k_y, v_y = self.qkv_y(y).view(y.shape[0], y.shape[1], 3, self.num_heads, -1).unbind(2)  # (B, N, local_h, head_dim)

        q_y = self.q_norm_y(q_y)
        k_y = self.k_norm_y(k_y)

        # Split qkv_x into q, k, v
        q_x, k_x, v_x = self.qkv_x(x).view(x.shape[0], x.shape[1], 3, self.num_heads, -1).unbind(2)  # (B, N, local_h, head_dim)
        q_x = self.q_norm_x(q_x)
        q_x = apply_rotary_emb_qk_real(q_x, rope_cos, rope_sin)
        k_x = self.k_norm_x(k_x)
        k_x = apply_rotary_emb_qk_real(k_x, rope_cos, rope_sin)

        # Concatenate visual tokens with the first crop_y text tokens along the
        # sequence axis, then move heads in front of the sequence axis.
        q = torch.cat([q_x, q_y[:, :crop_y]], dim=1).transpose(1, 2)
        k = torch.cat([k_x, k_y[:, :crop_y]], dim=1).transpose(1, 2)
        v = torch.cat([v_x, v_y[:, :crop_y]], dim=1).transpose(1, 2)

        xy = optimized_attention(q,
                                 k,
                                 v, self.num_heads, skip_reshape=True)

        # Split the joint output back at the number of visual tokens.
        x, y = torch.tensor_split(xy, (q_x.shape[1],), dim=1)
        x = self.proj_x(x)
        # Restore the full text length: cropped positions stay zero.
        o = torch.zeros(y.shape[0], q_y.shape[1], y.shape[-1], device=y.device, dtype=y.dtype)
        o[:, :y.shape[1]] = y

        y = self.proj_y(o)
        # print("ox", x)
        # print("oy", y)
        return x, y
|
||||
|
||||
|
||||
class AsymmetricJointBlock(nn.Module):
    """Transformer block with joint attention and separate MLPs for visual and text tokens.

    When ``update_y`` is False the text stream only contributes to attention:
    no text MLP is built and ``forward`` returns ``y`` unchanged.
    """

    def __init__(
        self,
        hidden_size_x: int,
        hidden_size_y: int,
        num_heads: int,
        *,
        mlp_ratio_x: float = 8.0,  # Ratio of hidden size to d_model for MLP for visual tokens.
        mlp_ratio_y: float = 4.0,  # Ratio of hidden size to d_model for MLP for text tokens.
        update_y: bool = True,  # Whether to update text tokens in this block.
        device: Optional[torch.device] = None,
        dtype=None,
        operations=None,
        **block_kwargs,
    ):
        super().__init__()
        self.update_y = update_y
        self.hidden_size_x = hidden_size_x
        self.hidden_size_y = hidden_size_y
        # Modulation projections from the conditioning vector: four chunks
        # (attention scale/gate, MLP scale/gate) per updated stream, or a
        # single attention scale for a text stream that is not updated.
        self.mod_x = operations.Linear(hidden_size_x, 4 * hidden_size_x, device=device, dtype=dtype)
        if self.update_y:
            self.mod_y = operations.Linear(hidden_size_x, 4 * hidden_size_y, device=device, dtype=dtype)
        else:
            self.mod_y = operations.Linear(hidden_size_x, hidden_size_y, device=device, dtype=dtype)

        # Self-attention:
        self.attn = AsymmetricAttention(
            hidden_size_x,
            hidden_size_y,
            num_heads=num_heads,
            update_y=update_y,
            device=device,
            dtype=dtype,
            operations=operations,
            **block_kwargs,
        )

        # MLP.
        mlp_hidden_dim_x = int(hidden_size_x * mlp_ratio_x)
        # NOTE(review): hard-codes hidden_size_x * mlp_ratio_x to 12288; any
        # other configuration fails here.
        assert mlp_hidden_dim_x == int(1536 * 8)
        self.mlp_x = FeedForward(
            in_features=hidden_size_x,
            hidden_size=mlp_hidden_dim_x,
            multiple_of=256,
            ffn_dim_multiplier=None,
            device=device,
            dtype=dtype,
            operations=operations,
        )

        # MLP for text not needed in last block.
        if self.update_y:
            mlp_hidden_dim_y = int(hidden_size_y * mlp_ratio_y)
            self.mlp_y = FeedForward(
                in_features=hidden_size_y,
                hidden_size=mlp_hidden_dim_y,
                multiple_of=256,
                ffn_dim_multiplier=None,
                device=device,
                dtype=dtype,
                operations=operations,
            )

    def forward(
        self,
        x: torch.Tensor,
        c: torch.Tensor,
        y: torch.Tensor,
        **attn_kwargs,
    ):
        """Forward pass of a block.

        Args:
            x: (B, N, dim) tensor of visual tokens
            c: (B, dim) tensor of conditioned features
            y: (B, L, dim) tensor of text tokens
            **attn_kwargs: forwarded unchanged to the attention module

        Returns:
            x: (B, N, dim) tensor of visual tokens after block
            y: (B, L, dim) tensor of text tokens after block
        """
        N = x.size(1)

        c = F.silu(c)
        mod_x = self.mod_x(c)
        scale_msa_x, gate_msa_x, scale_mlp_x, gate_mlp_x = mod_x.chunk(4, dim=1)

        mod_y = self.mod_y(c)
        if self.update_y:
            scale_msa_y, gate_msa_y, scale_mlp_y, gate_mlp_y = mod_y.chunk(4, dim=1)
        else:
            # Text is not updated: only the attention pre-norm scale is needed.
            scale_msa_y = mod_y

        # Self-attention block.
        x_attn, y_attn = self.attn(
            x,
            y,
            scale_x=scale_msa_x,
            scale_y=scale_msa_y,
            **attn_kwargs,
        )

        assert x_attn.size(1) == N
        x = residual_tanh_gated_rmsnorm(x, x_attn, gate_msa_x)
        if self.update_y:
            y = residual_tanh_gated_rmsnorm(y, y_attn, gate_msa_y)

        # MLP block.
        x = self.ff_block_x(x, scale_mlp_x, gate_mlp_x)
        if self.update_y:
            y = self.ff_block_y(y, scale_mlp_y, gate_mlp_y)

        return x, y

    def ff_block_x(self, x, scale_x, gate_x):
        """Modulated pre-norm, visual MLP, then gated normalized residual add."""
        x_mod = modulated_rmsnorm(x, scale_x)
        x_res = self.mlp_x(x_mod)
        x = residual_tanh_gated_rmsnorm(x, x_res, gate_x)  # Sandwich norm
        return x

    def ff_block_y(self, y, scale_y, gate_y):
        """Modulated pre-norm, text MLP, then gated normalized residual add."""
        y_mod = modulated_rmsnorm(y, scale_y)
        y_res = self.mlp_y(y_mod)
        y = residual_tanh_gated_rmsnorm(y, y_res, gate_y)  # Sandwich norm
        return y
|
||||
|
||||
|
||||
class FinalLayer(nn.Module):
    """
    Output head of the DiT: a non-affine LayerNorm modulated by a shift and
    scale derived from the conditioning vector, followed by a linear
    projection to ``patch_size * patch_size * out_channels`` features.
    """

    def __init__(
        self,
        hidden_size,
        patch_size,
        out_channels,
        device: Optional[torch.device] = None,
        dtype=None,
        operations=None,
    ):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        self.norm_final = operations.LayerNorm(
            hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
        )
        self.mod = operations.Linear(hidden_size, 2 * hidden_size, **factory_kwargs)
        patch_features = patch_size * patch_size * out_channels
        self.linear = operations.Linear(hidden_size, patch_features, **factory_kwargs)

    def forward(self, x, c):
        # The conditioning vector yields one shift and one scale chunk.
        shift, scale = self.mod(F.silu(c)).chunk(2, dim=1)
        modulated = modulate(self.norm_final(x), shift, scale)
        return self.linear(modulated)
|
||||
|
||||
|
||||
class AsymmDiTJoint(nn.Module):
    """
    Diffusion model with a Transformer backbone.

    Ingests text embeddings instead of a label.
    """

    def __init__(
        self,
        *,
        patch_size=2,
        in_channels=4,
        hidden_size_x=1152,
        hidden_size_y=1152,
        depth=48,
        num_heads=16,
        mlp_ratio_x=8.0,
        mlp_ratio_y=4.0,
        use_t5: bool = False,
        t5_feat_dim: int = 4096,
        t5_token_length: int = 256,
        learn_sigma=True,
        patch_embed_bias: bool = True,
        timestep_mlp_bias: bool = True,
        attend_to_padding: bool = False,
        timestep_scale: Optional[float] = None,
        use_extended_posenc: bool = False,
        posenc_preserve_area: bool = False,
        rope_theta: float = 10000.0,
        image_model=None,
        device: Optional[torch.device] = None,
        dtype=None,
        operations=None,
        **block_kwargs,
    ):
        """Build the patch embedder, conditioning modules, blocks and output head.

        All arguments are keyword-only.  ``block_kwargs`` is forwarded unchanged
        to every ``AsymmetricJointBlock``.  ``attend_to_padding`` must be False
        (asserted).  ``image_model`` is accepted but not used here.
        """
        super().__init__()

        self.dtype = dtype
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        self.out_channels = in_channels * 2 if learn_sigma else in_channels
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.hidden_size_x = hidden_size_x
        self.hidden_size_y = hidden_size_y
        # Head dimension and count is determined by visual.
        self.head_dim = hidden_size_x // num_heads
        self.attend_to_padding = attend_to_padding
        self.use_extended_posenc = use_extended_posenc
        self.posenc_preserve_area = posenc_preserve_area
        self.use_t5 = use_t5
        self.t5_token_length = t5_token_length
        self.t5_feat_dim = t5_feat_dim
        # Scaling factor for frequency computation for temporal RoPE.
        self.rope_theta = rope_theta

        self.x_embedder = PatchEmbed(
            patch_size=patch_size,
            in_chans=in_channels,
            embed_dim=hidden_size_x,
            bias=patch_embed_bias,
            dtype=dtype,
            device=device,
            operations=operations
        )
        # Conditionings
        # Timestep
        self.t_embedder = TimestepEmbedder(
            hidden_size_x, bias=timestep_mlp_bias, timestep_scale=timestep_scale, dtype=dtype, device=device, operations=operations
        )

        # NOTE(review): prepare() uses t5_y_embedder and t5_yproj unconditionally,
        # so a model built with use_t5=False cannot run prepare()/forward().
        if self.use_t5:
            # Caption Pooling (T5)
            self.t5_y_embedder = AttentionPool(
                t5_feat_dim, num_heads=8, output_dim=hidden_size_x, dtype=dtype, device=device, operations=operations
            )

            # Dense Embedding Projection (T5)
            self.t5_yproj = operations.Linear(
                t5_feat_dim, hidden_size_y, bias=True, dtype=dtype, device=device
            )

        # Initialize pos_frequencies as an empty parameter.
        self.pos_frequencies = nn.Parameter(
            torch.empty(3, self.num_heads, self.head_dim // 2, dtype=dtype, device=device)
        )

        assert not self.attend_to_padding

        # for depth 48:
        #  b =  0: AsymmetricJointBlock, update_y=True
        #  b =  1: AsymmetricJointBlock, update_y=True
        #  ...
        #  b = 46: AsymmetricJointBlock, update_y=True
        #  b = 47: AsymmetricJointBlock, update_y=False. No need to update text features.
        blocks = []
        for b in range(depth):
            # Joint multi-modal block
            update_y = b < depth - 1
            block = AsymmetricJointBlock(
                hidden_size_x,
                hidden_size_y,
                num_heads,
                mlp_ratio_x=mlp_ratio_x,
                mlp_ratio_y=mlp_ratio_y,
                update_y=update_y,
                attend_to_padding=attend_to_padding,
                device=device,
                dtype=dtype,
                operations=operations,
                **block_kwargs,
            )

            blocks.append(block)
        self.blocks = nn.ModuleList(blocks)

        self.final_layer = FinalLayer(
            hidden_size_x, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations
        )

    def embed_x(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, C, T, H, W) tensor of visual inputs.

        Returns:
            x: output of ``self.x_embedder``; ``prepare`` asserts it is 3-D
               with ``size(1) == T * (H // patch_size) * (W // patch_size)``,
               i.e. a (B, N, D) token sequence.
        """
        return self.x_embedder(x)  # Convert BcTHW to BND

    def prepare(
        self,
        x: torch.Tensor,
        sigma: torch.Tensor,
        t5_feat: torch.Tensor,
        t5_mask: torch.Tensor,
    ):
        """Prepare input and conditioning embeddings.

        Returns:
            ``(x, c, y_feat, rope_cos, rope_sin)``: patch tokens, the summed
            timestep + pooled-caption conditioning vector, projected caption
            features, and the cosine / sine RoPE tables.
        """
        # Visual patch embeddings with positional encoding.
        T, H, W = x.shape[-3:]
        pH, pW = H // self.patch_size, W // self.patch_size
        x = self.embed_x(x)  # (B, N, D), where N = T * H * W / patch_size ** 2
        assert x.ndim == 3

        N = T * pH * pW
        assert x.size(1) == N
        pos = create_position_matrix(
            T, pH=pH, pW=pW, device=x.device, dtype=torch.float32
        )  # (N, 3)
        rope_cos, rope_sin = compute_mixed_rotation(
            freqs=comfy.ops.cast_to(self.pos_frequencies, dtype=x.dtype, device=x.device), pos=pos
        )  # Each are (N, num_heads, dim // 2)

        c_t = self.t_embedder(1 - sigma, out_dtype=x.dtype)  # (B, D)

        t5_y_pool = self.t5_y_embedder(t5_feat, t5_mask)  # (B, D)

        c = c_t + t5_y_pool

        y_feat = self.t5_yproj(t5_feat)  # (B, L, t5_feat_dim) --> (B, L, D)

        return x, c, y_feat, rope_cos, rope_sin

    def forward(
        self,
        x: torch.Tensor,
        timestep: torch.Tensor,
        context: List[torch.Tensor],
        attention_mask: List[torch.Tensor],
        num_tokens=256,
        packed_indices: Dict[str, torch.Tensor] = None,
        rope_cos: torch.Tensor = None,
        rope_sin: torch.Tensor = None,
        control=None, transformer_options=None, **kwargs
    ):
        """Forward pass of DiT.

        Args:
            x: (B, C, T, H, W) tensor of spatial inputs (images or latent representations of images)
            timestep: (B,) tensor of noise standard deviations (sigma); ``1 - sigma``
                is fed to the timestep embedder.
            context: caption token features, passed to ``prepare`` as ``t5_feat``.
            attention_mask: caption mask, passed to ``prepare`` as ``t5_mask``.
            num_tokens: forwarded to every block as ``crop_y``.
            packed_indices: accepted for interface compatibility; not used here.
            rope_cos, rope_sin: accepted for interface compatibility; the values
                passed in are replaced by the ones computed in ``prepare``.
            control: accepted for interface compatibility; not used here.
            transformer_options: optional dict.  ``["patches_replace"]["dit"]``
                may map ``("double_block", i)`` to a callable that replaces
                block ``i``.  ``None`` (the default) means no options.

        Returns:
            The negated, un-patchified model output of shape
            (B, out_channels, T, H, W).
        """
        # A fresh dict per call: a literal ``{}`` default would be one shared
        # object across all calls.
        if transformer_options is None:
            transformer_options = {}
        patches_replace = transformer_options.get("patches_replace", {})
        y_feat = context
        y_mask = attention_mask
        sigma = timestep

        # Unpacking also enforces that x is 5-D.
        _, _, T, H, W = x.shape

        x, c, y_feat, rope_cos, rope_sin = self.prepare(
            x, sigma, y_feat, y_mask
        )
        del y_mask

        blocks_replace = patches_replace.get("dit", {})
        for i, block in enumerate(self.blocks):
            if ("double_block", i) in blocks_replace:
                def block_wrap(args):
                    out = {}
                    out["img"], out["txt"] = block(
                        args["img"],
                        args["vec"],
                        args["txt"],
                        rope_cos=args["rope_cos"],
                        rope_sin=args["rope_sin"],
                        crop_y=args["num_tokens"]
                    )
                    return out
                out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens}, {"original_block": block_wrap})
                y_feat = out["txt"]
                x = out["img"]
            else:
                x, y_feat = block(
                    x,
                    c,
                    y_feat,
                    rope_cos=rope_cos,
                    rope_sin=rope_sin,
                    crop_y=num_tokens,
                )  # (B, M, D), (B, L, D)
        del y_feat  # Final layers don't use dense text features.

        x = self.final_layer(x, c)  # (B, M, patch_size ** 2 * out_channels)
        x = rearrange(
            x,
            "B (T hp wp) (p1 p2 c) -> B c T (hp p1) (wp p2)",
            T=T,
            hp=H // self.patch_size,
            wp=W // self.patch_size,
            p1=self.patch_size,
            p2=self.patch_size,
            c=self.out_channels,
        )

        return -x
|
||||
164
comfy/ldm/genmo/joint_model/layers.py
Normal file
164
comfy/ldm/genmo/joint_model/layers.py
Normal file
@@ -0,0 +1,164 @@
|
||||
#original code from https://github.com/genmoai/models under apache 2.0 license
|
||||
#adapted to ComfyUI
|
||||
|
||||
import collections.abc
|
||||
import math
|
||||
from itertools import repeat
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
|
||||
# From PyTorch internals
|
||||
def _ntuple(n):
|
||||
def parse(x):
|
||||
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
|
||||
return tuple(x)
|
||||
return tuple(repeat(x, n))
|
||||
|
||||
return parse
|
||||
|
||||
|
||||
to_2tuple = _ntuple(2)
|
||||
|
||||
|
||||
class TimestepEmbedder(nn.Module):
    """Embed scalar timesteps: sinusoidal features followed by a two-layer MLP."""

    def __init__(
        self,
        hidden_size: int,
        frequency_embedding_size: int = 256,
        *,
        bias: bool = True,
        timestep_scale: Optional[float] = None,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        linear_kwargs = {"bias": bias, "dtype": dtype, "device": device}
        self.mlp = nn.Sequential(
            operations.Linear(frequency_embedding_size, hidden_size, **linear_kwargs),
            nn.SiLU(),
            operations.Linear(hidden_size, hidden_size, **linear_kwargs),
        )
        self.frequency_embedding_size = frequency_embedding_size
        self.timestep_scale = timestep_scale

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """Return ``[cos(t * f), sin(t * f)]`` features of width ``dim``.

        ``f`` holds ``dim // 2`` frequencies ``exp(-ln(max_period) * i / half)``.
        For odd ``dim`` a single zero column is appended.
        """
        half = dim // 2
        exponent = torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
        freqs = torch.exp(exponent * (-math.log(max_period) / half))
        angles = t[:, None].float() * freqs[None]
        parts = [torch.cos(angles), torch.sin(angles)]
        if dim % 2:
            # Pad odd widths with one zero column.
            parts.append(torch.zeros_like(parts[0][:, :1]))
        return torch.cat(parts, dim=-1)

    def forward(self, t, out_dtype):
        """Embed ``t`` (optionally multiplied by ``timestep_scale``) via the MLP."""
        scaled_t = t if self.timestep_scale is None else t * self.timestep_scale
        t_freq = self.timestep_embedding(scaled_t, self.frequency_embedding_size)
        return self.mlp(t_freq.to(dtype=out_dtype))
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
    """Gated feed-forward layer: ``w2(silu(a) * b)`` with ``a, b = w1(x).chunk(2)``."""

    def __init__(
        self,
        in_features: int,
        hidden_size: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
        device: Optional[torch.device] = None,
        dtype=None,
        operations=None,
    ):
        super().__init__()
        # keep parameter count and computation constant compared to standard FFN
        inner_dim = int(2 * hidden_size / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            inner_dim = int(ffn_dim_multiplier * inner_dim)
        # Round up to the next multiple of ``multiple_of``.
        inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)

        self.hidden_dim = inner_dim
        linear_kwargs = {"bias": False, "device": device, "dtype": dtype}
        # w1 produces the value and gate halves in a single projection.
        self.w1 = operations.Linear(in_features, 2 * inner_dim, **linear_kwargs)
        self.w2 = operations.Linear(inner_dim, in_features, **linear_kwargs)

    def forward(self, x):
        value, gate = self.w1(x).chunk(2, dim=-1)
        return self.w2(F.silu(value) * gate)
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
    """Per-frame 2-D patch embedding of a (B, C, T, H, W) tensor.

    Each frame goes through a strided ``Conv2d``; the result is flattened to a
    ``(B, T * H' * W', embed_dim)`` token sequence.
    """

    def __init__(
        self,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten: bool = True,
        bias: bool = True,
        dynamic_img_pad: bool = False,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        self.patch_size = to_2tuple(patch_size)
        self.flatten = flatten
        self.dynamic_img_pad = dynamic_img_pad

        conv_kwargs = {
            "kernel_size": patch_size,
            "stride": patch_size,
            "bias": bias,
            "device": device,
            "dtype": dtype,
        }
        self.proj = operations.Conv2d(in_chans, embed_dim, **conv_kwargs)

        # Custom norm layers are not supported.
        assert norm_layer is None
        if norm_layer:
            self.norm = norm_layer(embed_dim, device=device)
        else:
            self.norm = nn.Identity()

    def forward(self, x):
        """Embed ``x`` of shape (B, C, T, H, W) into (B, T * H' * W', embed_dim)."""
        B, _C, T, H, W = x.shape
        if self.dynamic_img_pad:
            # Pad bottom/right up to the next multiple of the patch size.
            pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
            pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
            x = F.pad(x, (0, pad_w, 0, pad_h))
        else:
            assert H % self.patch_size[0] == 0, f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})."
            assert W % self.patch_size[1] == 0, f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."

        # Fold time into the batch so the 2-D convolution sees single frames.
        frames = rearrange(x, "B C T H W -> (B T) C H W", B=B, T=T)
        frames = comfy.ldm.common_dit.pad_to_patch_size(frames, self.patch_size, padding_mode='circular')
        frames = self.proj(frames)

        # Flatten temporal and spatial dimensions.
        if not self.flatten:
            raise NotImplementedError("Must flatten output.")
        tokens = rearrange(frames, "(B T) C H W -> B (T H W) C", B=B, T=T)

        return self.norm(tokens)
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
    """RMS normalisation with a learnable weight and no bias.

    The weight is allocated uninitialised (``torch.empty``); the computation
    is delegated to ``comfy.ldm.common_dit.rms_norm``.
    """

    def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
        super().__init__()
        self.eps = eps
        weight = torch.empty(hidden_size, device=device, dtype=dtype)
        self.weight = torch.nn.Parameter(weight)
        # Explicitly register the absence of a bias.
        self.register_parameter("bias", None)

    def forward(self, x):
        rms_norm = comfy.ldm.common_dit.rms_norm
        return rms_norm(x, self.weight, self.eps)
|
||||
88
comfy/ldm/genmo/joint_model/rope_mixed.py
Normal file
88
comfy/ldm/genmo/joint_model/rope_mixed.py
Normal file
@@ -0,0 +1,88 @@
|
||||
#original code from https://github.com/genmoai/models under apache 2.0 license
|
||||
|
||||
# import functools
|
||||
import math
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def centers(start: float, stop, num, dtype=None, device=None):
    """linspace through bin centers.

    Args:
        start (float): Start of the range.
        stop (float): End of the range.
        num (int): Number of points.
        dtype (torch.dtype): Data type of the points.
        device (torch.device): Device of the points.

    Returns:
        centers (Tensor): Centers of the bins. Shape: (num,).
    """
    edges = torch.linspace(start, stop, num + 1, dtype=dtype, device=device)
    return (edges[:-1] + edges[1:]) / 2


# @functools.lru_cache(maxsize=1)
def create_position_matrix(
    T: int,
    pH: int,
    pW: int,
    device: torch.device,
    dtype: torch.dtype,
    *,
    target_area: float = 36864,
):
    """
    Args:
        T: int - Temporal dimension
        pH: int - Height dimension after patchify
        pW: int - Width dimension after patchify
        device: device the returned matrix is moved to
        dtype: dtype used for all three coordinate axes and for the result
        target_area: area the spatial grid is rescaled to (keyword-only)

    Returns:
        pos: [T * pH * pW, 3] - position matrix, columns are (t, h, w)
    """
    # Create 1D tensors for each dimension
    t = torch.arange(T, dtype=dtype)

    # Positionally interpolate to area 36864.
    # (3072x3072 frame with 16x16 patches = 192x192 latents).
    # This automatically scales rope positions when the resolution changes.
    # We use a large target area so the model is more sensitive
    # to changes in the learned pos_frequencies matrix.
    scale = math.sqrt(target_area / (pW * pH))
    # Build h and w in the same dtype as t: torch.meshgrid needs matching
    # dtypes, so leaving these at the default dtype broke any other ``dtype``.
    w = centers(-pW * scale / 2, pW * scale / 2, pW, dtype=dtype)
    h = centers(-pH * scale / 2, pH * scale / 2, pH, dtype=dtype)

    # Use meshgrid to create 3D grids
    grid_t, grid_h, grid_w = torch.meshgrid(t, h, w, indexing="ij")

    # Stack and reshape the grids.
    pos = torch.stack([grid_t, grid_h, grid_w], dim=-1)  # [T, pH, pW, 3]
    pos = pos.view(-1, 3)  # [T * pH * pW, 3]
    pos = pos.to(dtype=dtype, device=device)

    return pos
|
||||
|
||||
|
||||
def compute_mixed_rotation(
    freqs: torch.Tensor,
    pos: torch.Tensor,
):
    """
    Project each 3-dim position into per-head, per-head-dim 1D frequencies.

    Args:
        freqs: [3, num_heads, num_freqs] - learned rotation frequency (for t, row, col) for each head position
        pos: [N, 3] - position of each token; converted to the dtype/device of ``freqs``

    Returns:
        freqs_cos: [N, num_heads, num_freqs] - cosine components
        freqs_sin: [N, num_heads, num_freqs] - sine components
    """
    assert freqs.ndim == 3
    # Contract the 3 position coordinates against their frequencies.
    angles = torch.einsum("Nd,dhf->Nhf", pos.to(freqs), freqs)
    return torch.cos(angles), torch.sin(angles)
|
||||
34
comfy/ldm/genmo/joint_model/temporal_rope.py
Normal file
34
comfy/ldm/genmo/joint_model/temporal_rope.py
Normal file
@@ -0,0 +1,34 @@
|
||||
#original code from https://github.com/genmoai/models under apache 2.0 license
|
||||
|
||||
# Based on Llama3 Implementation.
|
||||
import torch
|
||||
|
||||
|
||||
def apply_rotary_emb_qk_real(
    xqk: torch.Tensor,
    freqs_cos: torch.Tensor,
    freqs_sin: torch.Tensor,
) -> torch.Tensor:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor without complex numbers.

    Adjacent (even, odd) pairs along the last dimension are rotated by the
    angle whose cosine / sine are given.

    Args:
        xqk (torch.Tensor): Query and/or Key tensors to apply rotary embeddings. Shape: (B, S, *, num_heads, D)
                            Can be either just query or just key, or both stacked along some batch or * dim.
        freqs_cos (torch.Tensor): Precomputed cosine frequency tensor.
        freqs_sin (torch.Tensor): Precomputed sine frequency tensor.

    Returns:
        torch.Tensor: The input tensor with rotary embeddings applied.
    """
    # Split the last dimension into even and odd parts
    even = xqk[..., 0::2]
    odd = xqk[..., 1::2]

    # Apply rotation
    rotated_even = (even * freqs_cos - odd * freqs_sin).type_as(xqk)
    rotated_odd = (even * freqs_sin + odd * freqs_cos).type_as(xqk)

    # Interleave the results back into the original shape
    return torch.stack([rotated_even, rotated_odd], dim=-1).flatten(-2)
|
||||
102
comfy/ldm/genmo/joint_model/utils.py
Normal file
102
comfy/ldm/genmo/joint_model/utils.py
Normal file
@@ -0,0 +1,102 @@
|
||||
#original code from https://github.com/genmoai/models under apache 2.0 license
|
||||
#adapted to ComfyUI
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def modulate(x, shift, scale):
    """Return ``x * (1 + scale) + shift``, with ``shift``/``scale`` broadcast over dim 1."""
    gain = 1 + scale.unsqueeze(1)
    return x * gain + shift.unsqueeze(1)
|
||||
|
||||
|
||||
def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch.Tensor:
    """
    Pool tokens in x using mask.

    Computes the mean over the tokens selected by ``mask``; a row whose mask
    is all zero pools to zeros (the divisor is clamped to at least 1).

    NOTE: We assume x does not require gradients.

    Args:
        x: (B, L, D) tensor of tokens.
        mask: (B, L) boolean tensor indicating which tokens are not padding.
        keepdim: keep the pooled token dimension, giving (B, 1, D).

    Returns:
        pooled: (B, D) tensor of pooled tokens.
    """
    assert x.size(1) == mask.size(1)  # Expected mask to have same length as tokens.
    assert x.size(0) == mask.size(0)  # Expected mask to have same batch size as tokens.
    weights = mask.unsqueeze(2).to(dtype=x.dtype)
    token_count = weights.sum(dim=1, keepdim=True).clamp(min=1)
    weights = weights / token_count
    return (x * weights).sum(dim=1, keepdim=keepdim)
|
||||
|
||||
|
||||
class AttentionPool(nn.Module):
    """Pool a token sequence into one vector with single-query attention."""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        output_dim: int = None,
        device: Optional[torch.device] = None,
        dtype=None,
        operations=None,
    ):
        """
        Args:
            embed_dim (int): Dimensionality of input tokens.
            num_heads (int): Number of attention heads.
            output_dim (int): Dimensionality of output tokens. Defaults to embed_dim.
        """
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        self.num_heads = num_heads
        self.to_kv = operations.Linear(embed_dim, 2 * embed_dim, **factory_kwargs)
        self.to_q = operations.Linear(embed_dim, embed_dim, **factory_kwargs)
        self.to_out = operations.Linear(embed_dim, output_dim or embed_dim, **factory_kwargs)

    def forward(self, x, mask):
        """
        Args:
            x (torch.Tensor): (B, L, D) tensor of input tokens.
            mask (torch.Tensor): (B, L) boolean tensor indicating which tokens are not padding.

        NOTE: We assume x does not require gradients.

        Returns:
            x (torch.Tensor): (B, D) tensor of pooled tokens.
        """
        D = x.size(2)
        head_dim = D // self.num_heads

        # Attention mask of shape (B, 1, num_queries=1, num_keys=1+L); the
        # prepended pooled token is always attendable.
        attn_mask = F.pad(mask[:, None, None, :].bool(), (1, 0), value=True)

        # Average non-padding token features. These will be used as the query.
        pooled = pool_tokens(x, mask, keepdim=True)  # (B, 1, D)

        # Concat pooled features to input sequence.
        seq = torch.cat([pooled, x], dim=1)  # (B, L+1, D)

        # Compute queries, keys, values. Only the mean token is used to create a query.
        kv = self.to_kv(seq)  # (B, L+1, 2 * D)
        q = self.to_q(seq[:, 0])  # (B, D)

        # Extract heads.
        kv = kv.unflatten(2, (2, self.num_heads, head_dim)).transpose(1, 3)  # (B, H, 2, 1+L, head_dim)
        k, v = kv.unbind(2)  # (B, H, 1+L, head_dim)
        q = q.unflatten(1, (self.num_heads, head_dim)).unsqueeze(2)  # (B, H, 1, head_dim)

        # Compute attention.
        attended = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, dropout_p=0.0
        )  # (B, H, 1, head_dim)

        # Concatenate heads and run output.
        return self.to_out(attended.squeeze(2).flatten(1, 2))
|
||||
711
comfy/ldm/genmo/vae/model.py
Normal file
711
comfy/ldm/genmo/vae/model.py
Normal file
@@ -0,0 +1,711 @@
|
||||
#original code from https://github.com/genmoai/models under apache 2.0 license
|
||||
#adapted to ComfyUI
|
||||
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
from functools import partial
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
# import mochi_preview.dit.joint_model.context_parallel as cp
|
||||
# from mochi_preview.vae.cp_conv import cp_pass_frames, gather_all_frames
|
||||
|
||||
|
||||
def cast_tuple(t, length=1):
    """Return ``t`` unchanged if it is a tuple, else ``t`` repeated ``length`` times."""
    if isinstance(t, tuple):
        return t
    return (t,) * length
|
||||
|
||||
|
||||
class GroupNormSpatial(ops.GroupNorm):
    """
    GroupNorm applied per-frame.
    """

    def forward(self, x: torch.Tensor, *, chunk_size: int = 8):
        """Normalise a (B, C, T, H, W) tensor, ``chunk_size`` frames at a time."""
        B, C, T, H, W = x.shape
        frames = rearrange(x, "B C T H W -> (B T) C H W")
        # Run group norm in chunks.
        normed = torch.empty_like(frames)
        for start in range(0, B * T, chunk_size):
            stop = start + chunk_size
            normed[start:stop] = super().forward(frames[start:stop])
        return rearrange(normed, "(B T) C H W -> B C T H W", B=B, T=T)
|
||||
|
||||
class PConv3d(ops.Conv3d):
    """3-D convolution whose temporal padding is applied manually in ``forward``.

    Height/width use the conv's built-in ``(k - 1) // 2`` padding; the time
    axis is padded in ``forward`` (front only when ``causal``).
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size: Union[int, Tuple[int, int, int]],
        stride: Union[int, Tuple[int, int, int]],
        causal: bool = True,
        context_parallel: bool = True,
        **kwargs,
    ):
        self.causal = causal
        self.context_parallel = context_parallel
        kernel_size = cast_tuple(kernel_size, 3)
        stride = cast_tuple(stride, 3)
        # No temporal padding here; it is added explicitly in forward().
        spatial_padding = ((kernel_size[1] - 1) // 2, (kernel_size[2] - 1) // 2)

        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=(1, 1, 1),
            padding=(0, spatial_padding[0], spatial_padding[1]),
            **kwargs,
        )

    def forward(self, x: torch.Tensor):
        # Compute padding amounts.
        context_size = self.kernel_size[0] - 1
        pad_front = context_size if self.causal else context_size // 2
        pad_back = context_size - pad_front

        # Apply padding.
        assert self.padding_mode == "replicate"  # DEBUG
        mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
        padded = F.pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode)
        return super().forward(padded)
|
||||
|
||||
|
||||
class Conv1x1(ops.Linear):
    """*1x1 Conv implemented with a linear layer."""

    def __init__(self, in_features: int, out_features: int, *args, **kwargs):
        super().__init__(in_features, out_features, *args, **kwargs)

    def forward(self, x: torch.Tensor):
        """Forward pass.

        The channel axis (dim 1) is moved last for the linear layer and moved
        back to dim 1 afterwards.

        Args:
            x: Input tensor. Shape: [B, C, *].

        Returns:
            x: Output tensor. Shape: [B, C', *].
        """
        channels_last = x.movedim(1, -1)
        projected = super().forward(channels_last)
        return projected.movedim(-1, 1)
|
||||
|
||||
|
||||
class DepthToSpaceTime(nn.Module):
    """Move channel blocks into the time, height and width axes."""

    def __init__(
        self,
        temporal_expansion: int,
        spatial_expansion: int,
    ):
        super().__init__()
        self.temporal_expansion = temporal_expansion
        self.spatial_expansion = spatial_expansion

    # When printed, this module should show the temporal and spatial expansion factors.
    def extra_repr(self):
        return f"texp={self.temporal_expansion}, sexp={self.spatial_expansion}"

    def forward(self, x: torch.Tensor):
        """Forward pass.

        Args:
            x: Input tensor. Shape: [B, C, T, H, W].

        Returns:
            x: Rearranged tensor. Shape: [B, C/(st*s*s), T*st, H*s, W*s],
               minus the first ``st - 1`` frames when ``st > 1``.
        """
        st = self.temporal_expansion
        s = self.spatial_expansion
        x = rearrange(
            x,
            "B (C st sh sw) T H W -> B C (T st) (H sh) (W sw)",
            st=st,
            sh=s,
            sw=s,
        )

        # cp_rank, _ = cp.get_cp_rank_size()
        if st > 1:  # and cp_rank == 0:
            # Drop the first self.temporal_expansion - 1 frames.
            # This is because we always want the 3x3x3 conv filter to only apply
            # to the first frame, and the first frame doesn't need to be repeated.
            assert all(x.shape)
            x = x[:, :, st - 1 :]
            assert all(x.shape)

        return x
|
||||
|
||||
|
||||
def norm_fn(
    in_channels: int,
    affine: bool = True,
):
    """Build the ``GroupNormSpatial`` (32 groups) used throughout this file."""
    return GroupNormSpatial(num_groups=32, num_channels=in_channels, affine=affine)
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
    """Residual block that preserves the spatial dimensions."""

    def __init__(
        self,
        channels: int,
        *,
        affine: bool = True,
        attn_block: Optional[nn.Module] = None,
        causal: bool = True,
        prune_bottleneck: bool = False,
        padding_mode: str,
        bias: bool = True,
    ):
        """
        Args:
            channels: Number of input and output channels.
            affine: Whether the group norms carry learnable affine parameters.
            attn_block: Optional module applied after the residual sum
                (identity when omitted).
            causal: Must be True (asserted); forwarded to the convolutions.
            prune_bottleneck: If True, the first convolution halves the channel
                count and the second restores it.
            padding_mode: Padding mode forwarded to both convolutions.
            bias: Whether the convolutions have a bias.
        """
        super().__init__()
        self.channels = channels

        assert causal
        # Width between the two convolutions.
        mid_channels = channels // 2 if prune_bottleneck else channels
        conv_kwargs = dict(
            kernel_size=(3, 3, 3),
            stride=(1, 1, 1),
            padding_mode=padding_mode,
            bias=bias,
            causal=causal,
        )
        self.stack = nn.Sequential(
            norm_fn(channels, affine=affine),
            nn.SiLU(inplace=True),
            PConv3d(
                in_channels=channels,
                out_channels=mid_channels,
                **conv_kwargs,
            ),
            # The second norm sees the first conv's output, so it must be
            # sized for mid_channels, not channels; otherwise the affine
            # weights mismatch the input whenever prune_bottleneck is set.
            norm_fn(mid_channels, affine=affine),
            nn.SiLU(inplace=True),
            PConv3d(
                in_channels=mid_channels,
                out_channels=channels,
                **conv_kwargs,
            ),
        )

        self.attn_block = attn_block if attn_block else nn.Identity()

    def forward(self, x: torch.Tensor):
        """Forward pass.

        Args:
            x: Input tensor. Shape: [B, C, T, H, W].
        """
        residual = x
        x = self.stack(x)
        x = x + residual
        del residual

        return self.attn_block(x)
|
||||
|
||||
|
||||
class Attention(nn.Module):
    """Multi-head self-attention along the time axis of a 5-D tensor."""

    def __init__(
        self,
        dim: int,
        head_dim: int = 32,
        qkv_bias: bool = False,
        out_bias: bool = True,
        qk_norm: bool = True,
    ) -> None:
        super().__init__()
        self.head_dim = head_dim
        self.num_heads = dim // head_dim
        self.qk_norm = qk_norm

        self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)
        self.out = nn.Linear(dim, dim, bias=out_bias)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Compute temporal self-attention.

        Args:
            x: Input tensor. Shape: [B, C, T, H, W].

        Returns:
            x: Output tensor. Shape: [B, C, T, H, W].
        """
        B, _, T, H, W = x.shape

        if T == 1:
            # No attention for single frame: only the value projection matters.
            tokens = x.movedim(1, -1)  # [B, C, T, H, W] -> [B, T, H, W, C]
            value = self.qkv(tokens).chunk(3, dim=-1)[2]  # Throw away queries and keys.
            return self.out(value).movedim(-1, 1)  # [B, T, H, W, C] -> [B, C, T, H, W]

        # 1D temporal attention.
        x = rearrange(x, "B C t h w -> (B h w) t C")
        qkv = self.qkv(x)

        # Input: qkv with shape [B, t, 3 * num_heads * head_dim]
        # Output: x with shape [B, num_heads, t, head_dim]
        n_seq, n_frames = qkv.shape[0], qkv.shape[1]
        heads = qkv.view(n_seq, n_frames, 3, self.num_heads, self.head_dim)
        q, k, v = heads.transpose(1, 3).unbind(2)

        if self.qk_norm:
            q = F.normalize(q, p=2, dim=-1)
            k = F.normalize(k, p=2, dim=-1)

        x = optimized_attention(q, k, v, self.num_heads, skip_reshape=True)

        assert x.size(0) == q.size(0)

        x = self.out(x)
        return rearrange(x, "(B h w) t C -> B C t h w", B=B, h=H, w=W)
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
    """Pre-norm residual wrapper: ``x + attn(norm(x))``."""

    def __init__(
        self,
        dim: int,
        **attn_kwargs,
    ) -> None:
        super().__init__()
        self.norm = norm_fn(dim)
        self.attn = Attention(dim, **attn_kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        attended = self.attn(self.norm(x))
        return x + attended
|
||||
|
||||
|
||||
class CausalUpsampleBlock(nn.Module):
    """Residual blocks, then a 1x1 projection and a depth-to-space-time shuffle."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_res_blocks: int,
        *,
        temporal_expansion: int = 2,
        spatial_expansion: int = 2,
        **block_kwargs,
    ):
        super().__init__()

        res_blocks = [block_fn(in_channels, **block_kwargs) for _ in range(num_res_blocks)]
        self.blocks = nn.Sequential(*res_blocks)

        self.temporal_expansion = temporal_expansion
        self.spatial_expansion = spatial_expansion

        # Change channels in the final convolution layer: the projection emits
        # one channel group per (time, height, width) sub-position.
        expanded_channels = out_channels * temporal_expansion * (spatial_expansion**2)
        self.proj = Conv1x1(in_channels, expanded_channels)

        self.d2st = DepthToSpaceTime(
            temporal_expansion=temporal_expansion, spatial_expansion=spatial_expansion
        )

    def forward(self, x):
        return self.d2st(self.proj(self.blocks(x)))
|
||||
|
||||
|
||||
def block_fn(channels, *, affine: bool = True, has_attention: bool = False, **block_kwargs):
    """Build a ``ResBlock``, attaching an ``AttentionBlock`` when ``has_attention`` is set.

    Args:
        channels: Channel count for the block (and for the attention block, if any).
        affine: Forwarded to ``ResBlock``.
        has_attention: When true, an ``AttentionBlock(channels)`` is created and
            passed as ``attn_block``; otherwise ``attn_block`` is ``None``.
        **block_kwargs: Forwarded unchanged to ``ResBlock``.
    """
    attn_block = None
    if has_attention:
        attn_block = AttentionBlock(channels)
    return ResBlock(channels, affine=affine, attn_block=attn_block, **block_kwargs)
|
||||
|
||||
|
||||
class DownsampleBlock(nn.Module):
    """Downsample block for the VAE encoder: strided convolution, then residual blocks."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_res_blocks,
        *,
        temporal_reduction=2,
        spatial_reduction=2,
        **block_kwargs,
    ):
        """
        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels (must differ from ``in_channels``).
            num_res_blocks: Number of residual blocks.
            temporal_reduction: Temporal reduction factor.
            spatial_reduction: Spatial reduction factor.
            **block_kwargs: Forwarded to ``block_fn``; must contain ``"bias"``,
                which is also used for the strided convolution.
        """
        super().__init__()

        # The channel count changes in the strided convolution so that the
        # residual blocks can keep a uniform channel count, as in ConvNeXt.
        assert in_channels != out_channels

        # Kernel and stride are the same window, so the windows do not overlap.
        window = (temporal_reduction, spatial_reduction, spatial_reduction)
        strided_conv = PConv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=window,
            stride=window,
            # First layer in each block always uses replicate padding
            padding_mode="replicate",
            bias=block_kwargs["bias"],
        )

        res_blocks = [block_fn(out_channels, **block_kwargs) for _ in range(num_res_blocks)]

        self.layers = nn.Sequential(strided_conv, *res_blocks)

    def forward(self, x):
        """Run the strided convolution followed by the residual blocks."""
        return self.layers(x)
|
||||
|
||||
|
||||
def add_fourier_features(inputs: torch.Tensor, start=6, stop=8, step=1):
    """Concatenate base-2 Fourier features onto the channel axis of ``inputs``.

    For every exponent ``f`` in ``arange(start, stop, step)`` each input channel
    is multiplied by ``2**f * 2*pi`` and passed through ``sin`` and ``cos``.

    Args:
        inputs: 5-D tensor with channels on dim 1.
        start: First exponent (inclusive).
        stop: Last exponent (exclusive).
        step: Stride between exponents.

    Returns:
        Tensor with ``(1 + 2 * num_freqs) * C`` channels on dim 1: the original
        inputs, then all sine features, then all cosine features.

    Raises:
        AssertionError: If ``inputs`` is not 5-D.
    """
    assert inputs.ndim == 5
    num_channels = inputs.shape[1]

    # Create Base 2 Fourier features.
    freqs = torch.arange(start, stop, step, dtype=inputs.dtype, device=inputs.device)
    # Count the frequencies from the tensor itself: `(stop - start) // step`
    # undercounts whenever `step` does not evenly divide the range.
    num_freqs = len(freqs)
    w = torch.pow(2.0, freqs) * (2 * torch.pi)  # [num_freqs]
    w = w.repeat(num_channels)[None, :, None, None, None]  # [1, C * num_freqs, 1, 1, 1]

    # Interleaved repeat of input channels to match w.
    h = inputs.repeat_interleave(num_freqs, dim=1)  # [B, C * num_freqs, T, H, W]
    # Scale channels by frequency.
    h = w * h

    return torch.cat(
        [
            inputs,
            torch.sin(h),
            torch.cos(h),
        ],
        dim=1,
    )
|
||||
|
||||
|
||||
class FourierFeatures(nn.Module):
    """Module wrapper that stores an exponent range and applies ``add_fourier_features``."""

    def __init__(self, start: int = 6, stop: int = 8, step: int = 1):
        super().__init__()
        self.start, self.stop, self.step = start, stop, step

    def forward(self, inputs):
        """Add Fourier features to inputs.

        Args:
            inputs: Input tensor. Shape: [B, C, T, H, W]

        Returns:
            h: Output tensor. Shape: [B, (1 + 2 * num_freqs) * C, T, H, W]
        """
        return add_fourier_features(
            inputs,
            start=self.start,
            stop=self.stop,
            step=self.step,
        )
|
||||
|
||||
|
||||
class Decoder(nn.Module):
    """Decoder built from an input stage, a chain of upsample blocks and an output stage.

    All arguments are keyword-only.

    Args:
        out_channels: Channel count produced by the final 1x1 projection.
        latent_dim: Channel count of the latent fed to the first convolution.
        base_channels: Multiplied by each entry of ``channel_multipliers`` to get
            the per-stage channel counts.
        channel_multipliers: One multiplier per stage; ``len - 1`` upsample blocks
            are created.
        num_res_blocks: Residual-block count per stage; needs
            ``len(channel_multipliers) + 1`` entries.
        temporal_expansions: Temporal factor per upsample block. Required despite
            the ``None`` default.
        spatial_expansions: Spatial factor per upsample block. Required despite
            the ``None`` default.
        has_attention: Attention flag per stage; same length as ``num_res_blocks``.
        output_norm: Must be false (asserted), so the default of ``True`` has to
            be overridden by the caller.
        nonlinearity: Must be ``"silu"`` (asserted).
        output_nonlinearity: ``"silu"`` or a falsy value for no output activation.
        causal: Must be true (asserted); forwarded to every block.
        **block_kwargs: Forwarded unchanged to ``block_fn`` / ``CausalUpsampleBlock``.

    Raises:
        ValueError: If either expansion list is left as ``None``.
        AssertionError: If any other configuration constraint above is violated.
    """

    def __init__(
        self,
        *,
        out_channels: int = 3,
        latent_dim: int,
        base_channels: int,
        channel_multipliers: List[int],
        num_res_blocks: List[int],
        temporal_expansions: Optional[List[int]] = None,
        spatial_expansions: Optional[List[int]] = None,
        has_attention: List[bool],
        output_norm: bool = True,
        nonlinearity: str = "silu",
        output_nonlinearity: str = "silu",
        causal: bool = True,
        **block_kwargs,
    ):
        super().__init__()
        self.input_channels = latent_dim
        self.base_channels = base_channels
        self.channel_multipliers = channel_multipliers
        self.num_res_blocks = num_res_blocks
        self.output_nonlinearity = output_nonlinearity

        ch = [mult * base_channels for mult in channel_multipliers]
        self.num_up_blocks = len(ch) - 1

        # Validate the whole configuration up front so a bad config fails before
        # any sub-module is allocated.
        assert nonlinearity == "silu"
        assert causal
        assert not output_norm
        if temporal_expansions is None or spatial_expansions is None:
            # Without this check the failure is an opaque TypeError from len(None).
            raise ValueError(
                "temporal_expansions and spatial_expansions must be provided "
                "(one entry per upsample block)"
            )
        assert len(temporal_expansions) == len(spatial_expansions) == self.num_up_blocks
        assert len(num_res_blocks) == len(has_attention) == self.num_up_blocks + 2

        blocks = []

        # Input layer.
        first_block = [ops.Conv3d(latent_dim, ch[-1], kernel_size=(1, 1, 1))]
        # First set of blocks preserve channel count.
        for _ in range(num_res_blocks[-1]):
            first_block.append(
                block_fn(
                    ch[-1],
                    has_attention=has_attention[-1],
                    causal=causal,
                    **block_kwargs,
                )
            )
        blocks.append(nn.Sequential(*first_block))

        # Walk the stages from the widest (last) to the narrowest (first).
        for i in range(self.num_up_blocks):
            block = CausalUpsampleBlock(
                ch[-i - 1],
                ch[-i - 2],
                num_res_blocks=num_res_blocks[-i - 2],
                has_attention=has_attention[-i - 2],
                temporal_expansion=temporal_expansions[-i - 1],
                spatial_expansion=spatial_expansions[-i - 1],
                causal=causal,
                **block_kwargs,
            )
            blocks.append(block)

        # Last block. Preserve channel count.
        last_block = []
        for _ in range(num_res_blocks[0]):
            last_block.append(
                block_fn(
                    ch[0], has_attention=has_attention[0], causal=causal, **block_kwargs
                )
            )
        blocks.append(nn.Sequential(*last_block))

        self.blocks = nn.ModuleList(blocks)
        self.output_proj = Conv1x1(ch[0], out_channels)

    def forward(self, x):
        """Forward pass.

        Args:
            x: Latent tensor. Shape: [B, input_channels, t, h, w]. Scaled [-1, 1].

        Returns:
            x: Reconstructed video tensor. Shape: [B, C, T, H, W]. Scaled to [-1, 1].
               NOTE(review): the previous docstring stated ``T + 1 = (t - 1) * 4``
               and ``H = h * 16, W = w * 16``; the actual factors depend on the
               configured temporal/spatial expansions - confirm before relying
               on fixed numbers.
        """
        for block in self.blocks:
            x = block(x)

        if self.output_nonlinearity == "silu":
            # In-place activation is only used outside training mode.
            x = F.silu(x, inplace=not self.training)
        else:
            assert (
                not self.output_nonlinearity
            )  # StyleGAN3 omits the to-RGB nonlinearity.

        return self.output_proj(x).contiguous()
|
||||
|
||||
class LatentDistribution:
    """Diagonal Gaussian over latents, parameterized by mean and log-variance."""

    def __init__(self, mean: torch.Tensor, logvar: torch.Tensor):
        """Initialize latent distribution.

        Args:
            mean: Mean of the distribution. Shape: [B, C, T, H, W].
            logvar: Logarithm of variance of the distribution. Shape: [B, C, T, H, W].

        Raises:
            AssertionError: If the two tensors differ in shape.
        """
        assert mean.shape == logvar.shape
        self.mean = mean
        self.logvar = logvar

    def sample(self, temperature=1.0, generator: Optional[torch.Generator] = None, noise=None):
        """Draw a sample from the distribution.

        Args:
            temperature: ``0.0`` returns the mean; ``1.0`` draws a plain Gaussian
                sample. No other value is supported.
            generator: Optional RNG used when ``noise`` is not supplied.
            noise: Optional pre-drawn standard-normal noise on the same device as
                the mean; it is cast to the mean's dtype.

        Returns:
            ``self.mean`` when ``temperature == 0.0``, otherwise
            ``noise * exp(0.5 * logvar) + mean``.

        Raises:
            NotImplementedError: For temperatures other than 0.0 and 1.0.
            AssertionError: If supplied ``noise`` lives on a different device.
        """
        if temperature == 0.0:
            return self.mean

        # Reject unsupported temperatures before drawing noise, so a failed call
        # neither wastes work nor advances the caller's generator.
        if temperature != 1.0:
            raise NotImplementedError(f"Temperature {temperature} is not supported.")

        if noise is None:
            noise = torch.randn(self.mean.shape, device=self.mean.device, dtype=self.mean.dtype, generator=generator)
        else:
            assert noise.device == self.mean.device
            noise = noise.to(self.mean.dtype)

        # Just Gaussian sample with no scaling of variance.
        return noise * torch.exp(self.logvar * 0.5) + self.mean

    def mode(self):
        """Return the mean of the distribution."""
        return self.mean
|
||||
|
||||
class Encoder(nn.Module):
    """Encoder: Fourier features, input projection, residual/downsample stages, latent head.

    All constructor arguments are keyword-only. ``num_res_blocks``,
    ``prune_bottlenecks`` and ``has_attentions`` each need
    ``len(channel_multipliers) + 1`` entries; ``temporal_reductions`` and
    ``spatial_reductions`` need ``len(channel_multipliers) - 1`` entries.
    These constraints are enforced with assertions.
    """

    def __init__(
        self,
        *,
        in_channels: int,
        base_channels: int,
        channel_multipliers: List[int],
        num_res_blocks: List[int],
        latent_dim: int,
        temporal_reductions: List[int],
        spatial_reductions: List[int],
        prune_bottlenecks: List[bool],
        has_attentions: List[bool],
        affine: bool = True,
        bias: bool = True,
        input_is_conv_1x1: bool = False,
        padding_mode: str,
    ):
        super().__init__()
        self.temporal_reductions = temporal_reductions
        self.spatial_reductions = spatial_reductions
        self.base_channels = base_channels
        self.channel_multipliers = channel_multipliers
        self.num_res_blocks = num_res_blocks
        self.latent_dim = latent_dim

        self.fourier_features = FourierFeatures()
        ch = [mult * base_channels for mult in channel_multipliers]
        num_down_blocks = len(ch) - 1
        assert len(num_res_blocks) == num_down_blocks + 2

        # Input projection: either the project's Conv1x1 or a plain 1x1x1 Conv3d.
        if input_is_conv_1x1:
            stem = Conv1x1(in_channels, ch[0])
        else:
            stem = ops.Conv3d(in_channels, ch[0], kernel_size=(1, 1, 1), bias=True)
        layers = [stem]

        assert len(prune_bottlenecks) == num_down_blocks + 2
        assert len(has_attentions) == num_down_blocks + 2
        make_block = partial(block_fn, padding_mode=padding_mode, affine=affine, bias=bias)

        # Entry 0 of the per-stage lists configures the blocks before any downsampling.
        for _ in range(num_res_blocks[0]):
            layers.append(
                make_block(ch[0], has_attention=has_attentions[0], prune_bottleneck=prune_bottlenecks[0])
            )

        assert len(temporal_reductions) == len(spatial_reductions) == len(ch) - 1
        # Downsample stage `stage` uses entry `stage + 1` of the per-stage lists.
        for stage in range(num_down_blocks):
            layers.append(
                DownsampleBlock(
                    ch[stage],
                    ch[stage + 1],
                    num_res_blocks=num_res_blocks[stage + 1],
                    temporal_reduction=temporal_reductions[stage],
                    spatial_reduction=spatial_reductions[stage],
                    prune_bottleneck=prune_bottlenecks[stage + 1],
                    has_attention=has_attentions[stage + 1],
                    affine=affine,
                    bias=bias,
                    padding_mode=padding_mode,
                )
            )

        # Additional blocks, configured by the last entry of the per-stage lists.
        for _ in range(num_res_blocks[-1]):
            layers.append(
                make_block(ch[-1], has_attention=has_attentions[-1], prune_bottleneck=prune_bottlenecks[-1])
            )

        self.layers = nn.Sequential(*layers)

        # Output layers.
        self.output_norm = norm_fn(ch[-1])
        self.output_proj = Conv1x1(ch[-1], 2 * latent_dim, bias=False)

    @property
    def temporal_downsample(self):
        """Product of all temporal reduction factors."""
        return math.prod(self.temporal_reductions)

    @property
    def spatial_downsample(self):
        """Product of all spatial reduction factors."""
        return math.prod(self.spatial_reductions)

    def forward(self, x) -> LatentDistribution:
        """Forward pass.

        Args:
            x: Input video tensor. Shape: [B, C, T, H, W]. Scaled to [-1, 1]

        Returns:
            A ``LatentDistribution`` holding:
            means: Latent tensor. Shape: [B, latent_dim, t, h, w]. Scaled [-1, 1].
                   h = H // 8, w = W // 8, t - 1 = (T - 1) // 6
                   (NOTE(review): these factors presumably describe the
                   reductions configured in ``VideoVAE`` - confirm).
            logvar: Shape: [B, latent_dim, t, h, w].
        """
        assert x.ndim == 5, f"Expected 5D input, got {x.shape}"
        features = self.layers(self.fourier_features(x))

        activated = F.silu(self.output_norm(features), inplace=True)
        projected = self.output_proj(activated)

        # The projection emits 2 * latent_dim channels: first half means, second half logvar.
        means, logvar = torch.chunk(projected, 2, dim=1)

        assert means.ndim == 5
        assert logvar.shape == means.shape
        assert means.size(1) == self.latent_dim

        return LatentDistribution(means, logvar)
|
||||
|
||||
|
||||
class VideoVAE(nn.Module):
    """Video VAE pairing an ``Encoder`` and a ``Decoder`` with fixed hyper-parameters."""

    def __init__(self):
        super().__init__()

        encoder_config = dict(
            in_channels=15,
            base_channels=64,
            channel_multipliers=[1, 2, 4, 6],
            num_res_blocks=[3, 3, 4, 6, 3],
            latent_dim=12,
            temporal_reductions=[1, 2, 3],
            spatial_reductions=[2, 2, 2],
            prune_bottlenecks=[False, False, False, False, False],
            has_attentions=[False, True, True, True, True],
            affine=True,
            bias=True,
            input_is_conv_1x1=True,
            padding_mode="replicate",
        )
        self.encoder = Encoder(**encoder_config)

        decoder_config = dict(
            out_channels=3,
            base_channels=128,
            channel_multipliers=[1, 2, 4, 6],
            temporal_expansions=[1, 2, 3],
            spatial_expansions=[2, 2, 2],
            num_res_blocks=[3, 3, 4, 6, 3],
            latent_dim=12,
            has_attention=[False, False, False, False, False],
            padding_mode="replicate",
            output_norm=False,
            nonlinearity="silu",
            output_nonlinearity="silu",
            causal=True,
        )
        self.decoder = Decoder(**decoder_config)

    def encode(self, x):
        """Encode ``x`` and return the mode (mean) of the resulting latent distribution."""
        latent_dist = self.encoder(x)
        return latent_dist.mode()

    def decode(self, x):
        """Decode the latent ``x`` with the decoder."""
        return self.decoder(x)
|
||||
@@ -287,7 +287,7 @@ class HunYuanDiT(nn.Module):
|
||||
style=None,
|
||||
return_dict=False,
|
||||
control=None,
|
||||
transformer_options=None,
|
||||
transformer_options={},
|
||||
):
|
||||
"""
|
||||
Forward pass of the encoder.
|
||||
@@ -315,8 +315,7 @@ class HunYuanDiT(nn.Module):
|
||||
return_dict: bool
|
||||
Whether to return a dictionary.
|
||||
"""
|
||||
#import pdb
|
||||
#pdb.set_trace()
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
encoder_hidden_states = context
|
||||
text_states = encoder_hidden_states # 2,77,1024
|
||||
text_states_t5 = encoder_hidden_states_t5 # 2,256,2048
|
||||
@@ -364,6 +363,8 @@ class HunYuanDiT(nn.Module):
|
||||
# Concatenate all extra vectors
|
||||
c = t + self.extra_embedder(extra_vec) # [B, D]
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
|
||||
controls = None
|
||||
if control:
|
||||
controls = control.get("output", None)
|
||||
@@ -372,12 +373,23 @@ class HunYuanDiT(nn.Module):
|
||||
for layer, block in enumerate(self.blocks):
|
||||
if layer > self.depth // 2:
|
||||
if controls is not None:
|
||||
skip = skips.pop() + controls.pop()
|
||||
skip = skips.pop() + controls.pop().to(dtype=x.dtype)
|
||||
else:
|
||||
skip = skips.pop()
|
||||
x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)
|
||||
else:
|
||||
x = block(x, c, text_states, freqs_cis_img) # (N, L, D)
|
||||
skip = None
|
||||
|
||||
if ("double_block", layer) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], args["vec"], args["txt"], args["pe"], args["skip"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", layer)]({"img": x, "txt": text_states, "vec": c, "pe": freqs_cis_img, "skip": skip}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)
|
||||
|
||||
|
||||
if layer < (self.depth // 2 - 1):
|
||||
skips.append(x)
|
||||
|
||||
514
comfy/ldm/lightricks/model.py
Normal file
514
comfy/ldm/lightricks/model.py
Normal file
@@ -0,0 +1,514 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import comfy.ldm.modules.attention
|
||||
from comfy.ldm.genmo.joint_model.layers import RMSNorm
|
||||
import comfy.ldm.common_dit
|
||||
from einops import rearrange
|
||||
import math
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
from .symmetric_patchifier import SymmetricPatchifier
|
||||
|
||||
|
||||
def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """
    Create sinusoidal timestep embeddings, matching the implementation in
    Denoising Diffusion Probabilistic Models.

    Args
        timesteps (torch.Tensor):
            a 1-D Tensor of N indices, one per batch element. These may be fractional.
        embedding_dim (int):
            the dimension of the output.
        flip_sin_to_cos (bool):
            Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
        downscale_freq_shift (float):
            Controls the delta between frequencies between dimensions
        scale (float):
            Scaling factor applied to the embeddings.
        max_period (int):
            Controls the maximum frequency of the embeddings
    Returns
        torch.Tensor: an [N x dim] Tensor of positional embeddings.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    positions = torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
    )
    log_freqs = (-math.log(max_period) * positions) / (half_dim - downscale_freq_shift)

    # Outer product of timesteps and frequencies, then the global scale.
    angles = scale * (timesteps[:, None].float() * torch.exp(log_freqs)[None, :])

    sin_part = torch.sin(angles)
    cos_part = torch.cos(angles)
    # Choose the half ordering directly instead of swapping halves afterwards.
    halves = [cos_part, sin_part] if flip_sin_to_cos else [sin_part, cos_part]
    emb = torch.cat(halves, dim=-1)

    # Odd target widths get one trailing zero column.
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
|
||||
|
||||
|
||||
class TimestepEmbedding(nn.Module):
    """Two-layer MLP (Linear, SiLU, Linear) applied to a timestep projection.

    Args:
        in_channels: Width of the incoming projection.
        time_embed_dim: Hidden width, and output width unless ``out_dim`` is given.
        act_fn: Accepted for signature compatibility; the activation is always SiLU.
        out_dim: Optional output width overriding ``time_embed_dim``.
        post_act_fn: Must be ``None``; post-activations are not implemented here.
        cond_proj_dim: When set, a bias-free Linear maps a condition of this
            width to ``in_channels`` so it can be added to the sample.
        sample_proj_bias: Whether the two Linear layers carry a bias.
        dtype, device: Forwarded to every layer constructor.
        operations: Namespace providing the ``Linear`` class to instantiate.

    Raises:
        NotImplementedError: If ``post_act_fn`` is not ``None``.
    """

    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: Optional[int] = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim=None,
        sample_proj_bias=True,
        dtype=None, device=None, operations=None,
    ):
        super().__init__()

        # Previously a non-None post_act_fn left `self.post_act` undefined, so the
        # module was built successfully and then failed with AttributeError inside
        # forward(). Fail at construction with a clear message instead.
        if post_act_fn is not None:
            raise NotImplementedError(f"post_act_fn={post_act_fn!r} is not supported.")

        self.linear_1 = operations.Linear(in_channels, time_embed_dim, sample_proj_bias, dtype=dtype, device=device)

        if cond_proj_dim is not None:
            self.cond_proj = operations.Linear(cond_proj_dim, in_channels, bias=False, dtype=dtype, device=device)
        else:
            self.cond_proj = None

        self.act = nn.SiLU()

        time_embed_dim_out = time_embed_dim if out_dim is None else out_dim
        self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias, dtype=dtype, device=device)

        self.post_act = None

    def forward(self, sample, condition=None):
        """Embed ``sample``, first adding the projected ``condition`` if one is given.

        ``condition`` may only be passed when the module was built with
        ``cond_proj_dim``; otherwise ``self.cond_proj`` is ``None``.
        """
        if condition is not None:
            sample = sample + self.cond_proj(condition)
        sample = self.linear_1(sample)

        if self.act is not None:
            sample = self.act(sample)

        sample = self.linear_2(sample)

        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample
|
||||
|
||||
|
||||
class Timesteps(nn.Module):
    """Parameter-free module that stores settings for ``get_timestep_embedding``.

    Args:
        num_channels: Embedding width passed as ``embedding_dim``.
        flip_sin_to_cos: Forwarded to ``get_timestep_embedding``.
        downscale_freq_shift: Forwarded to ``get_timestep_embedding``.
        scale: Forwarded to ``get_timestep_embedding``.
    """

    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift
        self.scale = scale

    def forward(self, timesteps):
        """Return the sinusoidal embedding of ``timesteps`` using the stored settings."""
        return get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
            scale=self.scale,
        )
|
||||
|
||||
|
||||
class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
    """
    Timestep embedder for PixArt-Alpha.

    In this implementation only the timestep is embedded: ``resolution``,
    ``aspect_ratio`` and ``batch_size`` are accepted by ``forward`` but unused,
    and ``use_additional_conditions`` is not read.

    Reference:
    https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
    """

    def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False, dtype=None, device=None, operations=None):
        super().__init__()

        self.outdim = size_emb_dim
        # A 256-wide sinusoidal projection feeds the learned embedder.
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(
            in_channels=256,
            time_embed_dim=embedding_dim,
            dtype=dtype,
            device=device,
            operations=operations,
        )

    def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
        """Project ``timestep`` sinusoidally, cast to ``hidden_dtype`` and embed it."""
        projected = self.time_proj(timestep).to(dtype=hidden_dtype)
        return self.timestep_embedder(projected)  # (N, D)
|
||||
|
||||
|
||||
class AdaLayerNormSingle(nn.Module):
    r"""
    Norm layer adaptive layer norm single (adaLN-single).

    As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        use_additional_conditions (`bool`): To use additional conditions for normalization or not.
        dtype, device: Forwarded to the layer constructors.
        operations: Namespace providing the ``Linear`` class to instantiate.
    """

    def __init__(self, embedding_dim: int, use_additional_conditions: bool = False, dtype=None, device=None, operations=None):
        super().__init__()

        self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
            embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions, dtype=dtype, device=device, operations=operations
        )

        self.silu = nn.SiLU()
        self.linear = operations.Linear(embedding_dim, 6 * embedding_dim, bias=True, dtype=dtype, device=device)

    def forward(
        self,
        timestep: torch.Tensor,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        batch_size: Optional[int] = None,
        hidden_dtype: Optional[torch.dtype] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Embed ``timestep`` and project it to the modulation width.

        Args:
            timestep: Timestep tensor handed to the embedder.
            added_cond_kwargs: Extra keyword arguments for the embedder; a falsy
                value is replaced by ``resolution=None, aspect_ratio=None``.
            batch_size: Forwarded to the embedder.
            hidden_dtype: Dtype the sinusoidal projection is cast to.

        Returns:
            A pair ``(modulation, embedded_timestep)``: the embedded timestep
            passed through SiLU and a Linear of width ``6 * embedding_dim``, and
            the embedded timestep itself. (The annotation previously declared
            five tensors, which did not match the two that are returned.)
        """
        # No modulation happening here.
        added_cond_kwargs = added_cond_kwargs or {"resolution": None, "aspect_ratio": None}
        embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
        return self.linear(self.silu(embedded_timestep)), embedded_timestep
|
||||
|
||||
class PixArtAlphaTextProjection(nn.Module):
    """
    Projects caption embeddings with a Linear, activation, Linear stack.

    Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py

    Args:
        in_features: Width of the incoming caption embedding.
        hidden_size: Hidden width, and output width unless ``out_features`` is given.
        out_features: Optional output width.
        act_fn: ``"gelu_tanh"`` or ``"silu"``; anything else raises ``ValueError``.
        dtype, device: Forwarded to the Linear constructors.
        operations: Namespace providing the ``Linear`` class to instantiate.
    """

    def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", dtype=None, device=None, operations=None):
        super().__init__()
        out_features = hidden_size if out_features is None else out_features

        self.linear_1 = operations.Linear(in_features=in_features, out_features=hidden_size, bias=True, dtype=dtype, device=device)

        if act_fn not in ("gelu_tanh", "silu"):
            raise ValueError(f"Unknown activation function: {act_fn}")
        self.act_1 = nn.GELU(approximate="tanh") if act_fn == "gelu_tanh" else nn.SiLU()

        self.linear_2 = operations.Linear(in_features=hidden_size, out_features=out_features, bias=True, dtype=dtype, device=device)

    def forward(self, caption):
        """Apply the first Linear, the activation and the second Linear to ``caption``."""
        return self.linear_2(self.act_1(self.linear_1(caption)))
|
||||
|
||||
|
||||
class GELU_approx(nn.Module):
    """Linear projection followed by tanh-approximated GELU.

    Args:
        dim_in: Input width of the projection.
        dim_out: Output width of the projection.
        dtype, device: Forwarded to the Linear constructor.
        operations: Namespace providing the ``Linear`` class to instantiate.
    """

    def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=None):
        super().__init__()
        self.proj = operations.Linear(dim_in, dim_out, dtype=dtype, device=device)

    def forward(self, x):
        """Project ``x`` and apply GELU with the ``tanh`` approximation."""
        projected = self.proj(x)
        return torch.nn.functional.gelu(projected, approximate="tanh")
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
    """MLP: ``GELU_approx`` expansion, dropout, then a Linear to ``dim_out``.

    Args:
        dim: Input width.
        dim_out: Output width.
        mult: Expansion factor; the hidden width is ``int(dim * mult)``.
        glu: Accepted but not used by this implementation.
        dropout: Dropout probability between the two projections.
        dtype, device: Forwarded to the layer constructors.
        operations: Namespace providing the ``Linear`` class to instantiate.
    """

    def __init__(self, dim, dim_out, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=None):
        super().__init__()
        hidden_width = int(dim * mult)

        self.net = nn.Sequential(
            GELU_approx(dim, hidden_width, dtype=dtype, device=device, operations=operations),
            nn.Dropout(dropout),
            operations.Linear(hidden_width, dim_out, dtype=dtype, device=device),
        )

    def forward(self, x):
        """Run ``x`` through the expansion, dropout and output projection."""
        return self.net(x)
|
||||
|
||||
|
||||
def apply_rotary_emb(input_tensor, freqs_cis): #TODO: remove duplicate funcs and pick the best/fastest one
    """Apply a rotary position embedding to ``input_tensor``.

    Args:
        input_tensor: Tensor whose last dimension is treated as consecutive pairs.
        freqs_cis: Indexable pair ``(cos, sin)`` that is multiplied elementwise
            with ``input_tensor``.

    Returns:
        ``input_tensor * cos + rotated * sin``, where ``rotated`` maps each
        consecutive pair ``(a, b)`` of the last dimension to ``(-b, a)``.
    """
    cos_freqs, sin_freqs = freqs_cis[0], freqs_cis[1]

    # Split the last dimension into pairs and rotate each pair by 90 degrees.
    pairs = rearrange(input_tensor, "... (d r) -> ... d r", r=2)
    first, second = pairs.unbind(dim=-1)
    rotated_pairs = torch.stack((-second, first), dim=-1)
    input_tensor_rot = rearrange(rotated_pairs, "... d r -> ... (d r)")

    return input_tensor * cos_freqs + input_tensor_rot * sin_freqs
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
    """Multi-head attention with RMS-normalized queries/keys and optional rotary embedding.

    Acts as self-attention when no ``context`` is passed to ``forward``.

    Args:
        query_dim: Width of the query input and of the output.
        context_dim: Width of the key/value input; defaults to ``query_dim``.
        heads: Number of attention heads.
        dim_head: Width per head; the inner width is ``heads * dim_head``.
        dropout: Dropout probability after the output projection.
        attn_precision: Forwarded to the attention kernel.
        dtype, device: Forwarded to the layer constructors.
        operations: Namespace providing the ``Linear`` class to instantiate.
    """

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None):
        super().__init__()
        inner_dim = dim_head * heads
        if context_dim is None:
            context_dim = query_dim
        self.attn_precision = attn_precision

        self.heads = heads
        self.dim_head = dim_head

        self.q_norm = RMSNorm(inner_dim, dtype=dtype, device=device)
        self.k_norm = RMSNorm(inner_dim, dtype=dtype, device=device)

        self.to_q = operations.Linear(query_dim, inner_dim, bias=True, dtype=dtype, device=device)
        self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
        self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)

        out_proj = operations.Linear(inner_dim, query_dim, dtype=dtype, device=device)
        self.to_out = nn.Sequential(out_proj, nn.Dropout(dropout))

    def forward(self, x, context=None, mask=None, pe=None):
        """Attend from ``x`` to ``context`` (or to ``x`` itself when ``context`` is None).

        Args:
            x: Query input.
            context: Optional key/value input.
            mask: Optional attention mask; selects the masked attention kernel.
            pe: Optional rotary embedding applied to both queries and keys.
        """
        if context is None:
            context = x

        q = self.q_norm(self.to_q(x))
        k = self.k_norm(self.to_k(context))
        v = self.to_v(context)

        if pe is not None:
            q = apply_rotary_emb(q, pe)
            k = apply_rotary_emb(k, pe)

        if mask is not None:
            out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)
        else:
            out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
        return self.to_out(out)
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
    """Transformer block: modulated self-attention, cross-attention, modulated feed-forward.

    Args:
        dim: Token width.
        n_heads: Number of attention heads.
        d_head: Width per attention head.
        context_dim: Width of the cross-attention context.
        attn_precision: Forwarded to both attention layers.
        dtype, device: Forwarded to the layer constructors and the modulation table.
        operations: Namespace providing the layer classes to instantiate.
    """

    def __init__(self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None):
        super().__init__()

        self.attn_precision = attn_precision
        self.attn1 = CrossAttention(
            query_dim=dim, heads=n_heads, dim_head=d_head, context_dim=None,
            attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations,
        )
        self.ff = FeedForward(dim, dim_out=dim, glu=True, dtype=dtype, device=device, operations=operations)

        self.attn2 = CrossAttention(
            query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head,
            attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations,
        )

        # Six rows: shift/scale/gate for attention, then shift/scale/gate for the MLP.
        # Allocated uninitialized (torch.empty).
        self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))

    def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None):
        """Update ``x`` in place through the three sub-layers and return it.

        Args:
            x: Token tensor; it is modified in place by the residual additions.
            context: Cross-attention context.
            attention_mask: Mask handed to the cross-attention layer.
            timestep: Modulation tensor, reshaped to
                ``(x.shape[0], timestep.shape[1], 6, -1)`` and added to the table.
            pe: Rotary embedding used by the self-attention layer only.
        """
        table = self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype)
        per_token = timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (table + per_token).unbind(dim=2)

        attn_input = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa
        x += self.attn1(attn_input, pe=pe) * gate_msa

        x += self.attn2(x, context=context, mask=attention_mask)

        mlp_input = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp
        x += self.ff(mlp_input) * gate_mlp

        return x
|
||||
|
||||
def get_fractional_positions(indices_grid, max_pos):
    """Divide each of the first three coordinate channels by its maximum position.

    Args:
        indices_grid: Tensor whose dim 1 holds at least three coordinate channels.
        max_pos: Indexable with at least three divisors, one per channel.

    Returns:
        The three normalized channels stacked along a new last dimension.
    """
    normalized = []
    for axis in range(3):
        normalized.append(indices_grid[:, axis] / max_pos[axis])
    return torch.stack(normalized, dim=-1)
|
||||
|
||||
|
||||
def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]):
    """Build the cosine and sine tables for the rotary position embedding.

    Args:
        indices_grid: Coordinate grid handed to ``get_fractional_positions``.
        dim: Target width of the last dimension of both tables.
        out_dtype: Dtype of the returned tensors (the math runs in float32).
        theta: Upper end of the geometric frequency range, which starts at 1.
        max_pos: Per-axis divisors handed to ``get_fractional_positions``.

    Returns:
        Tuple ``(cos_freq, sin_freq)`` cast to ``out_dtype``.
    """
    dtype = torch.float32 #self.dtype

    fractional_positions = get_fractional_positions(indices_grid, max_pos)
    device = fractional_positions.device

    # dim // 6 frequencies per axis, spaced evenly in log_theta between 1 and theta.
    start, end = 1, theta
    exponents = torch.linspace(
        math.log(start, theta),
        math.log(end, theta),
        dim // 6,
        device=device,
        dtype=dtype,
    )
    indices = (theta ** exponents).to(dtype=dtype)
    indices = indices * math.pi / 2

    # Map fractional positions to [-1, 1] before multiplying by the frequencies.
    centered = fractional_positions.unsqueeze(-1) * 2 - 1
    freqs = (indices * centered).transpose(-1, -2).flatten(2)

    # Each frequency is duplicated so it covers one rotated pair of features.
    cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
    sin_freq = freqs.sin().repeat_interleave(2, dim=-1)

    leftover = dim % 6
    if leftover != 0:
        # Pad at the front with cos=1 / sin=0 so the tables reach width `dim`.
        cos_padding = torch.ones_like(cos_freq[:, :, :leftover])
        sin_padding = torch.zeros_like(cos_freq[:, :, :leftover])
        cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
        sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
    return cos_freq.to(out_dtype), sin_freq.to(out_dtype)
|
||||
|
||||
|
||||
class LTXVModel(torch.nn.Module):
    """Transformer backbone operating on 5-D video latents.

    The latent is flattened into tokens by a ``SymmetricPatchifier`` with patch
    size 1, projected to ``inner_dim``, run through a stack of
    ``BasicTransformerBlock`` modules, modulated with the timestep embedding
    and projected back to ``in_channels`` channels.
    """

    def __init__(self,
                 in_channels=128,
                 cross_attention_dim=2048,
                 attention_head_dim=64,
                 num_attention_heads=32,
                 caption_channels=4096,
                 num_layers=28,
                 positional_embedding_theta=10000.0,
                 positional_embedding_max_pos=(20, 2048, 2048),
                 dtype=None, device=None, operations=None, **kwargs):
        """Build the projection layers and the transformer stack.

        Args:
            in_channels: Channel count of the input latent; also used as the
                output channel count.
            cross_attention_dim: Forwarded to every block as ``context_dim``.
            attention_head_dim: Per-head width; ``inner_dim`` is
                ``num_attention_heads * attention_head_dim``.
            num_attention_heads: Number of attention heads per block.
            caption_channels: Input feature size of the caption projection.
            num_layers: Number of transformer blocks.
            positional_embedding_theta: Accepted but not used in this
                constructor.
            positional_embedding_max_pos: Accepted but not used in this
                constructor.
            dtype: Forwarded to every created layer/parameter.
            device: Forwarded to every created layer/parameter.
            operations: Object providing the ``Linear`` and ``LayerNorm``
                factories. Required in practice: it is dereferenced
                unconditionally even though the default is ``None``.
            **kwargs: Ignored.
        """
        super().__init__()
        self.dtype = dtype
        self.out_channels = in_channels
        self.inner_dim = num_attention_heads * attention_head_dim

        self.patchify_proj = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device)

        self.adaln_single = AdaLayerNormSingle(
            self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=operations
        )

        self.caption_projection = PixArtAlphaTextProjection(
            in_features=caption_channels, hidden_size=self.inner_dim, dtype=dtype, device=device, operations=operations
        )

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    self.inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    context_dim=cross_attention_dim,
                    dtype=dtype, device=device, operations=operations
                )
                for _ in range(num_layers)
            ]
        )

        # Allocated uninitialised; presumably filled when weights are loaded.
        self.scale_shift_table = nn.Parameter(torch.empty(2, self.inner_dim, dtype=dtype, device=device))
        self.norm_out = operations.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.proj_out = operations.Linear(self.inner_dim, self.out_channels, dtype=dtype, device=device)

        self.patchifier = SymmetricPatchifier(1)

    def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, transformer_options=None, **kwargs):
        """Run the transformer on a video latent.

        Args:
            x: Latent indexed as (batch, channels, frames, height, width).
                It is not modified.
            timestep: Per-sample timestep tensor; multiplied by 1000 before
                being embedded.
            context: Text conditioning, passed through ``caption_projection``.
            attention_mask: Mask for ``context``. Entries equal to 1 become an
                additive bias of 0; every other entry becomes ``-inf``.
            frame_rate: Used to scale the temporal axis of the position grid.
            guiding_latent: Optional latent whose first frame replaces the
                first frame of ``x``; that frame is then given timestep 0.
            transformer_options: Optional dict; ``"patches_replace"`` ->
                ``"dit"`` may map ``("double_block", i)`` to a callable that
                replaces block ``i``.
            **kwargs: Ignored.

        Returns:
            Tensor with the same (frames, height, width) layout as ``x``.
        """
        if transformer_options is None:
            transformer_options = {}
        patches_replace = transformer_options.get("patches_replace", {})

        indices_grid = self.patchifier.get_grid(
            orig_num_frames=x.shape[2],
            orig_height=x.shape[3],
            orig_width=x.shape[4],
            batch_size=x.shape[0],
            scale_grid=((1 / frame_rate) * 8, 32, 32),
            device=x.device,
        )

        if guiding_latent is not None:
            # Build a per-token timestep map in which the first frame is 0.
            ts = torch.ones([x.shape[0], 1, x.shape[2], x.shape[3], x.shape[4]], device=x.device, dtype=x.dtype)
            input_ts = timestep.view([timestep.shape[0]] + [1] * (x.ndim - 1))
            ts *= input_ts
            ts[:, :, 0] = 0.0
            timestep = self.patchifier.patchify(ts)
            # Keep the caller's tensor untouched: remember it for the final
            # first-frame correction and write the guide into a private copy.
            input_x = x
            x = x.clone()
            x[:, :, 0] = guiding_latent[:, :, 0]

        orig_shape = list(x.shape)

        x = self.patchifier.patchify(x)

        x = self.patchify_proj(x)
        timestep = timestep * 1000.0

        attention_mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1]))
        # NOTE: the original author flagged this conversion as uncertain
        # ("not sure about this"); behaviour is kept as-is.
        attention_mask = attention_mask.masked_fill(attention_mask.to(torch.bool), float("-inf"))

        pe = precompute_freqs_cis(indices_grid, dim=self.inner_dim, out_dtype=x.dtype)

        batch_size = x.shape[0]
        timestep, embedded_timestep = self.adaln_single(
            timestep.flatten(),
            {"resolution": None, "aspect_ratio": None},
            batch_size=batch_size,
            hidden_dtype=x.dtype,
        )
        # Second dimension is 1 or number of tokens (if timestep_per_token)
        timestep = timestep.view(batch_size, -1, timestep.shape[-1])
        embedded_timestep = embedded_timestep.view(
            batch_size, -1, embedded_timestep.shape[-1]
        )

        # 2. Blocks
        if self.caption_projection is not None:
            batch_size = x.shape[0]
            context = self.caption_projection(context)
            context = context.view(
                batch_size, -1, x.shape[-1]
            )

        blocks_replace = patches_replace.get("dit", {})
        for i, block in enumerate(self.transformer_blocks):
            if ("double_block", i) in blocks_replace:
                # Adapter exposing the stock block through the dict-in /
                # dict-out calling convention used by replacement patches.
                def block_wrap(args):
                    out = {}
                    out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"])
                    return out

                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe}, {"original_block": block_wrap})
                x = out["img"]
            else:
                x = block(
                    x,
                    context=context,
                    attention_mask=attention_mask,
                    timestep=timestep,
                    pe=pe
                )

        # 3. Output
        scale_shift_values = (
            self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None]
        )
        shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
        x = self.norm_out(x)
        # Modulation
        x = x * (1 + scale) + shift
        x = self.proj_out(x)

        x = self.patchifier.unpatchify(
            latents=x,
            output_height=orig_shape[3],
            output_width=orig_shape[4],
            output_num_frames=orig_shape[2],
            out_channels=orig_shape[1] // math.prod(self.patchifier.patch_size),
        )

        if guiding_latent is not None:
            # Overwrite the first output frame with the value derived from
            # the original input, the guide and the input timestep.
            x[:, :, 0] = (input_x[:, :, 0] - guiding_latent[:, :, 0]) / input_ts[:, :, 0]

        return x
|
||||
105
comfy/ldm/lightricks/symmetric_patchifier.py
Normal file
105
comfy/ldm/lightricks/symmetric_patchifier.py
Normal file
@@ -0,0 +1,105 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
    """Pad ``x`` with trailing singleton axes until it has ``target_dims`` dims.

    Args:
        x: Tensor to extend.
        target_dims: Desired number of dimensions.

    Returns:
        ``x`` itself when it already has ``target_dims`` dimensions, otherwise
        an indexed result with size-1 axes appended at the end.

    Raises:
        ValueError: If ``x`` already has more than ``target_dims`` dimensions.
    """
    missing = target_dims - x.ndim
    if missing == 0:
        return x
    if missing > 0:
        # Indexing with ``None`` inserts a new axis of length 1.
        trailing = (None,) * missing
        return x[(..., *trailing)]
    raise ValueError(
        f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
    )
|
||||
|
||||
|
||||
class Patchifier(ABC):
    """Abstract base for converting video latents to and from token sequences.

    Subclasses implement ``patchify`` / ``unpatchify``; this class stores the
    patch size and builds the position grid used for the tokens.
    """

    def __init__(self, patch_size: int):
        """Store the patch size as a (frames, height, width) triple.

        Args:
            patch_size: Patch edge length applied to both spatial axes; the
                frame axis always uses a patch size of 1.
        """
        super().__init__()
        self._patch_size = (1, patch_size, patch_size)

    @abstractmethod
    def patchify(
        self, latents: Tensor, frame_rates: Tensor, scale_grid: bool
    ) -> Tuple[Tensor, Tensor]:
        """Convert ``latents`` into a token sequence (implemented by subclasses).

        NOTE(review): the concrete ``SymmetricPatchifier`` visible in this file
        only accepts ``latents`` and returns a single tensor; this abstract
        signature is kept unchanged for compatibility.
        """
        pass

    @abstractmethod
    def unpatchify(
        self,
        latents: Tensor,
        output_height: int,
        output_width: int,
        output_num_frames: int,
        out_channels: int,
    ) -> Tuple[Tensor, Tensor]:
        """Convert a token sequence back into a latent (implemented by subclasses)."""
        pass

    @property
    def patch_size(self):
        """The (frames, height, width) patch size triple."""
        return self._patch_size

    def get_grid(
        self, orig_num_frames, orig_height, orig_width, batch_size, scale_grid, device
    ):
        """Build per-token (frame, row, column) coordinates.

        Args:
            orig_num_frames: Number of frames before patching.
            orig_height: Height before patching.
            orig_width: Width before patching.
            batch_size: Number of times the grid is repeated along dim 0.
            scale_grid: ``None`` for no scaling, or a 3-item sequence of
                per-axis scale factors (numbers or tensors). Each coordinate is
                multiplied by its scale and by the patch size of that axis.
            device: Device on which the grid is created.

        Returns:
            float32 tensor of shape ``(batch_size, 3, f * h * w)`` where
            ``f``, ``h``, ``w`` are the patched extents.
        """
        f = orig_num_frames // self._patch_size[0]
        h = orig_height // self._patch_size[1]
        w = orig_width // self._patch_size[2]
        grid_h = torch.arange(h, dtype=torch.float32, device=device)
        grid_w = torch.arange(w, dtype=torch.float32, device=device)
        grid_f = torch.arange(f, dtype=torch.float32, device=device)
        # "ij" is the layout the implicit default already produced; passing
        # it explicitly avoids the deprecation warning and pins the behaviour.
        grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing="ij")
        grid = torch.stack(grid, dim=0)
        grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)

        if scale_grid is not None:
            for i in range(3):
                if isinstance(scale_grid[i], Tensor):
                    # Pad tensor scales with trailing axes so they broadcast
                    # against the (batch, f, h, w) slice.
                    scale = append_dims(scale_grid[i], grid.ndim - 1)
                else:
                    scale = scale_grid[i]
                grid[:, i, ...] = grid[:, i, ...] * scale * self._patch_size[i]

        grid = rearrange(grid, "b c f h w -> b c (f h w)", b=batch_size)
        return grid
|
||||
|
||||
|
||||
class SymmetricPatchifier(Patchifier):
    """Patchifier that folds each patch into the channel axis of a token."""

    def patchify(
        self,
        latents: Tensor,
    ) -> Tensor:
        """Flatten a 5-D latent into a token sequence.

        Args:
            latents: Tensor laid out as (batch, channels, frames, height,
                width); each of the last three extents must be divisible by
                the matching patch size.

        Returns:
            A single tensor of shape ``(batch, f * h * w, c * p1 * p2 * p3)``.
        """
        latents = rearrange(
            latents,
            "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
            p1=self._patch_size[0],
            p2=self._patch_size[1],
            p3=self._patch_size[2],
        )
        return latents

    def unpatchify(
        self,
        latents: Tensor,
        output_height: int,
        output_width: int,
        output_num_frames: int,
        out_channels: int,
    ) -> Tensor:
        """Rebuild a 5-D latent from a token sequence.

        Args:
            latents: Tensor of shape ``(batch, f * h * w, c * p * q)``.
            output_height: Height of the reconstructed latent.
            output_width: Width of the reconstructed latent.
            output_num_frames: Number of frames of the reconstructed latent.
            out_channels: Accepted for interface compatibility; not used by
                this implementation (the channel count is inferred).

        Returns:
            A single tensor laid out as (batch, channels, frames, height,
            width).
        """
        # Convert the requested size into a count of patches per axis.
        output_height = output_height // self._patch_size[1]
        output_width = output_width // self._patch_size[2]
        latents = rearrange(
            latents,
            "b (f h w) (c p q) -> b c f (h p) (w q) ",
            f=output_num_frames,
            h=output_height,
            w=output_width,
            p=self._patch_size[1],
            q=self._patch_size[2],
        )
        return latents
|
||||
64
comfy/ldm/lightricks/vae/causal_conv3d.py
Normal file
64
comfy/ldm/lightricks/vae/causal_conv3d.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from typing import Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
|
||||
class CausalConv3d(nn.Module):
    """3-D convolution that pads the time axis by replicating edge frames.

    Spatial padding is done by the wrapped convolution (zero padding of
    ``kernel_size // 2``); temporal padding is done manually in ``forward`` so
    it can be either causal (past only) or symmetric.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size: int = 3,
        stride: Union[int, Tuple[int]] = 1,
        dilation: int = 1,
        groups: int = 1,
        **kwargs,
    ):
        """Create the wrapped ``ops.Conv3d``.

        Args:
            in_channels: Input channel count.
            out_channels: Output channel count.
            kernel_size: Edge length used for all three kernel axes.
            stride: Stride forwarded to the convolution.
            dilation: Dilation applied to the time axis only.
            groups: Group count forwarded to the convolution.
            **kwargs: Accepted and ignored (e.g. ``padding`` / ``bias`` passed
                by generic factories have no effect here).
        """
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        kernel_size = (kernel_size, kernel_size, kernel_size)
        self.time_kernel_size = kernel_size[0]

        dilation = (dilation, 1, 1)

        height_pad = kernel_size[1] // 2
        width_pad = kernel_size[2] // 2
        # No temporal padding here: `forward` adds replicated frames instead.
        padding = (0, height_pad, width_pad)

        self.conv = ops.Conv3d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            dilation=dilation,
            padding=padding,
            padding_mode="zeros",
            groups=groups,
        )

    def forward(self, x, causal: bool = True):
        """Pad the frame axis (dim 2) and apply the convolution.

        Args:
            x: 5-D tensor whose dim 2 is the frame axis.
            causal: When true, prepend ``time_kernel_size - 1`` copies of the
                first frame. Otherwise pad ``(time_kernel_size - 1) // 2``
                copies of the first frame in front and of the last frame
                behind.

        Returns:
            Output of the wrapped convolution.
        """
        if causal:
            first_frame_pad = x[:, :, :1, :, :].repeat(
                (1, 1, self.time_kernel_size - 1, 1, 1)
            )
            x = torch.cat((first_frame_pad, x), dim=2)
        else:
            half_pad = (self.time_kernel_size - 1) // 2
            first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, half_pad, 1, 1))
            last_frame_pad = x[:, :, -1:, :, :].repeat((1, 1, half_pad, 1, 1))
            x = torch.cat((first_frame_pad, x, last_frame_pad), dim=2)
        x = self.conv(x)
        return x

    @property
    def weight(self):
        """Weight tensor of the wrapped convolution."""
        return self.conv.weight
|
||||
698
comfy/ldm/lightricks/vae/causal_video_autoencoder.py
Normal file
698
comfy/ldm/lightricks/vae/causal_video_autoencoder.py
Normal file
@@ -0,0 +1,698 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from functools import partial
|
||||
import math
|
||||
from einops import rearrange
|
||||
from typing import Any, Mapping, Optional, Tuple, Union, List
|
||||
from .conv_nd_factory import make_conv_nd, make_linear_nd
|
||||
from .pixel_norm import PixelNorm
|
||||
|
||||
|
||||
class Encoder(nn.Module):
    r"""
    The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.

    Args:
        dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3):
            The number of dimensions to use in convolutions.
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels (before patching).
        out_channels (`int`, *optional*, defaults to 3):
            The number of latent channels.
        blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`):
            The blocks to use. Each entry is a block name followed by either a layer count or a dict of
            parameters. Known names: `res_x`, `res_x_y`, `compress_time`, `compress_space`, `compress_all`,
            `compress_all_x_y`.
        base_channels (`int`, *optional*, defaults to 128):
            The number of output channels for the first convolutional layer.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        patch_size (`int`, *optional*, defaults to 1):
            The spatial patch size; input channels are multiplied by `patch_size**2`.
        norm_layer (`str`, *optional*, defaults to `group_norm`):
            The normalization layer to use: `group_norm`, `pixel_norm` or `layer_norm`.
        latent_log_var (`str`, *optional*, defaults to `per_channel`):
            How log-variance channels are produced: `per_channel` doubles the output channels, `uniform`
            adds a single channel, `none` adds nothing.

    Raises:
        ValueError: For an unknown block name, `norm_layer` or `latent_log_var`.
    """

    def __init__(
        self,
        dims: Union[int, Tuple[int, int]] = 3,
        in_channels: int = 3,
        out_channels: int = 3,
        blocks=[("res_x", 1)],  # never mutated here; kept as a list for compatibility
        base_channels: int = 128,
        norm_num_groups: int = 32,
        patch_size: Union[int, Tuple[int]] = 1,
        norm_layer: str = "group_norm",  # group_norm, pixel_norm, layer_norm
        latent_log_var: str = "per_channel",
    ):
        super().__init__()
        # Fail fast: an unrecognised name would otherwise leave
        # `conv_norm_out` undefined and only surface inside `forward`.
        if norm_layer not in ("group_norm", "pixel_norm", "layer_norm"):
            raise ValueError(f"Invalid norm_layer: {norm_layer}")

        self.patch_size = patch_size
        self.norm_layer = norm_layer
        self.latent_channels = out_channels
        self.latent_log_var = latent_log_var
        self.blocks_desc = blocks

        in_channels = in_channels * patch_size**2
        output_channel = base_channels

        self.conv_in = make_conv_nd(
            dims=dims,
            in_channels=in_channels,
            out_channels=output_channel,
            kernel_size=3,
            stride=1,
            padding=1,
            causal=True,
        )

        self.down_blocks = nn.ModuleList([])

        # Stride used by each down-sampling block name.
        compress_strides = {
            "compress_time": (2, 1, 1),
            "compress_space": (1, 2, 2),
            "compress_all": (2, 2, 2),
            "compress_all_x_y": (2, 2, 2),
        }

        for block_name, block_params in blocks:
            input_channel = output_channel
            if isinstance(block_params, int):
                block_params = {"num_layers": block_params}

            if block_name == "res_x":
                block = UNetMidBlock3D(
                    dims=dims,
                    in_channels=input_channel,
                    num_layers=block_params["num_layers"],
                    resnet_eps=1e-6,
                    resnet_groups=norm_num_groups,
                    norm_layer=norm_layer,
                )
            elif block_name == "res_x_y":
                output_channel = block_params.get("multiplier", 2) * output_channel
                block = ResnetBlock3D(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    eps=1e-6,
                    groups=norm_num_groups,
                    norm_layer=norm_layer,
                )
            elif block_name in compress_strides:
                if block_name == "compress_all_x_y":
                    # This variant also widens the channel count.
                    output_channel = block_params.get("multiplier", 2) * output_channel
                block = make_conv_nd(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    kernel_size=3,
                    stride=compress_strides[block_name],
                    causal=True,
                )
            else:
                raise ValueError(f"unknown block: {block_name}")

            self.down_blocks.append(block)

        # out
        if norm_layer == "group_norm":
            self.conv_norm_out = nn.GroupNorm(
                num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6
            )
        elif norm_layer == "pixel_norm":
            self.conv_norm_out = PixelNorm()
        else:  # "layer_norm" (validated above)
            self.conv_norm_out = LayerNorm(output_channel, eps=1e-6)

        self.conv_act = nn.SiLU()

        conv_out_channels = out_channels
        if latent_log_var == "per_channel":
            conv_out_channels *= 2
        elif latent_log_var == "uniform":
            conv_out_channels += 1
        elif latent_log_var != "none":
            raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
        self.conv_out = make_conv_nd(
            dims, output_channel, conv_out_channels, 3, padding=1, causal=True
        )

        self.gradient_checkpointing = False

    def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
        r"""Encode `sample`.

        The input is spatially patchified, passed through the input
        convolution, every down block, the output norm/activation and the
        output convolution. With `latent_log_var == "uniform"` the last output
        channel is repeated so that the result has `2 * out_channels`
        channels.

        Raises:
            ValueError: In `uniform` mode when the result is neither 4-D nor 5-D.
        """
        sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
        sample = self.conv_in(sample)

        # Wrap blocks in activation checkpointing only while training with
        # the flag enabled; otherwise call them directly.
        checkpoint_fn = (
            partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
            if self.gradient_checkpointing and self.training
            else lambda x: x
        )

        for down_block in self.down_blocks:
            sample = checkpoint_fn(down_block)(sample)

        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if self.latent_log_var == "uniform":
            num_dims = sample.dim()
            if num_dims not in (4, 5):
                raise ValueError(f"Invalid input shape: {sample.shape}")
            last_channel = sample[:, -1:, ...]
            # Repeat only along the channel axis; works for both
            # (B, C, H, W) and (B, C, F, H, W).
            repeats = [1, sample.shape[1] - 2] + [1] * (num_dims - 2)
            repeated_last_channel = last_channel.repeat(*repeats)
            sample = torch.cat([sample, repeated_last_channel], dim=1)

        return sample
|
||||
|
||||
|
||||
class Decoder(nn.Module):
    r"""
    The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.

    Args:
        dims (`int` or `Tuple[int, int]`):
            The number of dimensions to use in convolutions (required).
        in_channels (`int`, *optional*, defaults to 3):
            The number of input (latent) channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels (after un-patching).
        blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`):
            The blocks to use, listed in encoder order; they are applied in reverse. Known names: `res_x`,
            `res_x_y`, `compress_time`, `compress_space`, `compress_all`.
        base_channels (`int`, *optional*, defaults to 128):
            The channel count at the output end; the input width is this value times every `res_x_y`
            multiplier.
        layers_per_block (`int`, *optional*, defaults to 2):
            Stored on the instance; not otherwise used by this class.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        patch_size (`int`, *optional*, defaults to 1):
            The spatial patch size; output channels are multiplied by `patch_size**2` before un-patching.
        norm_layer (`str`, *optional*, defaults to `group_norm`):
            The normalization layer to use: `group_norm`, `pixel_norm` or `layer_norm`.
        causal (`bool`, *optional*, defaults to `True`):
            Whether to use causal convolutions or not.

    Raises:
        ValueError: For an unknown block name or `norm_layer`.
    """

    def __init__(
        self,
        dims,
        in_channels: int = 3,
        out_channels: int = 3,
        blocks=[("res_x", 1)],  # never mutated here; kept as a list for compatibility
        base_channels: int = 128,
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        patch_size: int = 1,
        norm_layer: str = "group_norm",
        causal: bool = True,
    ):
        super().__init__()
        # Fail fast: an unrecognised name would otherwise leave
        # `conv_norm_out` undefined and only surface inside `forward`.
        if norm_layer not in ("group_norm", "pixel_norm", "layer_norm"):
            raise ValueError(f"Invalid norm_layer: {norm_layer}")

        self.patch_size = patch_size
        self.layers_per_block = layers_per_block
        out_channels = out_channels * patch_size**2
        self.causal = causal
        self.blocks_desc = blocks

        # Compute output channel to be product of all channel-multiplier blocks
        output_channel = base_channels
        for block_name, block_params in reversed(blocks):
            block_params = block_params if isinstance(block_params, dict) else {}
            if block_name == "res_x_y":
                output_channel = output_channel * block_params.get("multiplier", 2)

        self.conv_in = make_conv_nd(
            dims,
            in_channels,
            output_channel,
            kernel_size=3,
            stride=1,
            padding=1,
            causal=True,
        )

        self.up_blocks = nn.ModuleList([])

        for block_name, block_params in reversed(blocks):
            input_channel = output_channel
            if isinstance(block_params, int):
                block_params = {"num_layers": block_params}

            if block_name == "res_x":
                block = UNetMidBlock3D(
                    dims=dims,
                    in_channels=input_channel,
                    num_layers=block_params["num_layers"],
                    resnet_eps=1e-6,
                    resnet_groups=norm_num_groups,
                    norm_layer=norm_layer,
                )
            elif block_name == "res_x_y":
                # Mirror of the encoder: narrow instead of widen.
                output_channel = output_channel // block_params.get("multiplier", 2)
                block = ResnetBlock3D(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    eps=1e-6,
                    groups=norm_num_groups,
                    norm_layer=norm_layer,
                )
            elif block_name == "compress_time":
                block = DepthToSpaceUpsample(
                    dims=dims, in_channels=input_channel, stride=(2, 1, 1)
                )
            elif block_name == "compress_space":
                block = DepthToSpaceUpsample(
                    dims=dims, in_channels=input_channel, stride=(1, 2, 2)
                )
            elif block_name == "compress_all":
                block = DepthToSpaceUpsample(
                    dims=dims,
                    in_channels=input_channel,
                    stride=(2, 2, 2),
                    residual=block_params.get("residual", False),
                )
            else:
                raise ValueError(f"unknown layer: {block_name}")

            self.up_blocks.append(block)

        if norm_layer == "group_norm":
            self.conv_norm_out = nn.GroupNorm(
                num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6
            )
        elif norm_layer == "pixel_norm":
            self.conv_norm_out = PixelNorm()
        else:  # "layer_norm" (validated above)
            self.conv_norm_out = LayerNorm(output_channel, eps=1e-6)

        self.conv_act = nn.SiLU()
        self.conv_out = make_conv_nd(
            dims, output_channel, out_channels, 3, padding=1, causal=True
        )

        self.gradient_checkpointing = False

    def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
        r"""Decode `sample`.

        The latent goes through the input convolution, is cast to the dtype of
        the first up-block parameter, passes every up block, the output
        norm/activation/convolution and is finally un-patchified. The
        instance's `causal` flag is forwarded to every convolutional stage.
        """
        sample = self.conv_in(sample, causal=self.causal)

        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype

        # Wrap blocks in activation checkpointing only while training with
        # the flag enabled; otherwise call them directly.
        checkpoint_fn = (
            partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
            if self.gradient_checkpointing and self.training
            else lambda x: x
        )

        sample = sample.to(upscale_dtype)

        for up_block in self.up_blocks:
            sample = checkpoint_fn(up_block)(sample, causal=self.causal)

        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample, causal=self.causal)

        sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)

        return sample
|
||||
|
||||
|
||||
class UNetMidBlock3D(nn.Module):
    """
    A stack of channel-preserving `ResnetBlock3D` layers applied in sequence.

    Args:
        dims (`int` or `Tuple[int, int]`): Forwarded to every residual block.
        in_channels (`int`): The number of input (and output) channels.
        dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
        num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
        resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
        resnet_groups (`int`, *optional*, defaults to 32):
            The number of groups for the blocks' normalization layers. When `None`,
            `min(in_channels // 4, 32)` is used.
        norm_layer (`str`, *optional*, defaults to `group_norm`): Forwarded to every residual block.

    Returns:
        `torch.FloatTensor`: The output of the last residual block.
    """

    def __init__(
        self,
        dims: Union[int, Tuple[int, int]],
        in_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_groups: int = 32,
        norm_layer: str = "group_norm",
    ):
        super().__init__()
        if resnet_groups is None:
            resnet_groups = min(in_channels // 4, 32)

        self.res_blocks = nn.ModuleList()
        for _ in range(num_layers):
            self.res_blocks.append(
                ResnetBlock3D(
                    dims=dims,
                    in_channels=in_channels,
                    out_channels=in_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    norm_layer=norm_layer,
                )
            )

    def forward(
        self, hidden_states: torch.FloatTensor, causal: bool = True
    ) -> torch.FloatTensor:
        """Feed `hidden_states` through every residual block, forwarding `causal`."""
        out = hidden_states
        for res_block in self.res_blocks:
            out = res_block(out, causal=causal)
        return out
|
||||
|
||||
|
||||
class DepthToSpaceUpsample(nn.Module):
    """Upsample by convolving to more channels and unfolding them into space/time.

    Args:
        dims: Forwarded to `make_conv_nd`.
        in_channels: Input channel count; the convolution outputs
            `prod(stride) * in_channels` channels.
        stride: 3-item (frames, height, width) upsampling factors.
        residual: When true, an unfolded and channel-repeated copy of the
            input is added to the result.
    """

    def __init__(self, dims, in_channels, stride, residual=False):
        super().__init__()
        self.stride = stride
        self.out_channels = math.prod(stride) * in_channels
        self.conv = make_conv_nd(
            dims=dims,
            in_channels=in_channels,
            out_channels=self.out_channels,
            kernel_size=3,
            stride=1,
            causal=True,
        )
        self.residual = residual

    def _unfold(self, t):
        """Move the stride factors from the channel axis into frames/height/width."""
        return rearrange(
            t,
            "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
            p1=self.stride[0],
            p2=self.stride[1],
            p3=self.stride[2],
        )

    def forward(self, x, causal: bool = True):
        """Upsample `x`; with a temporal factor of 2 the first output frame is dropped."""
        drop_first_frame = self.stride[0] == 2

        skip = None
        if self.residual:
            # Reshape and duplicate the input to match the output shape
            skip = self._unfold(x).repeat(1, math.prod(self.stride), 1, 1, 1)
            if drop_first_frame:
                skip = skip[:, :, 1:, :, :]

        out = self._unfold(self.conv(x, causal=causal))
        if drop_first_frame:
            out = out[:, :, 1:, :, :]

        if skip is None:
            return out
        return out + skip
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
    """`nn.LayerNorm` over the channel axis of a (b, c, d, h, w) tensor.

    Args:
        dim: Normalized size, i.e. the channel count.
        eps: Epsilon forwarded to `nn.LayerNorm`.
        elementwise_affine: Forwarded to `nn.LayerNorm`.
    """

    def __init__(self, dim, eps, elementwise_affine=True) -> None:
        super().__init__()
        self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine)

    def forward(self, x):
        """Normalize `x` over channels and return it in its original layout."""
        # nn.LayerNorm works on the trailing axis, so channels go last first.
        channels_last = rearrange(x, "b c d h w -> b d h w c")
        normalized = self.norm(channels_last)
        return rearrange(normalized, "b d h w c -> b c d h w")
|
||||
|
||||
|
||||
class ResnetBlock3D(nn.Module):
    r"""
    A Resnet block: norm -> SiLU -> conv, twice, plus a shortcut connection.

    Parameters:
        dims (`int` or `Tuple[int, int]`): Forwarded to the convolution factories.
        in_channels (`int`): The number of channels in the input.
        out_channels (`int`, *optional*, default to be `None`):
            The number of output channels. If None, same as `in_channels`.
        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
        groups (`int`, *optional*, default to `32`): The number of groups for the group-norm layers.
        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
        norm_layer (`str`, *optional*, defaults to `group_norm`):
            One of `group_norm`, `pixel_norm` or `layer_norm`.

    Raises:
        ValueError: If `norm_layer` is not one of the supported names.
    """

    def __init__(
        self,
        dims: Union[int, Tuple[int, int]],
        in_channels: int,
        out_channels: Optional[int] = None,
        dropout: float = 0.0,
        groups: int = 32,
        eps: float = 1e-6,
        norm_layer: str = "group_norm",
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels

        self.norm1 = self._make_norm(norm_layer, in_channels, groups, eps)

        self.non_linearity = nn.SiLU()

        self.conv1 = make_conv_nd(
            dims,
            in_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            causal=True,
        )

        self.norm2 = self._make_norm(norm_layer, out_channels, groups, eps)

        self.dropout = torch.nn.Dropout(dropout)

        self.conv2 = make_conv_nd(
            dims,
            out_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            causal=True,
        )

        # The shortcut only needs a projection (and its own norm) when the
        # channel count changes.
        self.conv_shortcut = (
            make_linear_nd(
                dims=dims, in_channels=in_channels, out_channels=out_channels
            )
            if in_channels != out_channels
            else nn.Identity()
        )

        self.norm3 = (
            LayerNorm(in_channels, eps=eps, elementwise_affine=True)
            if in_channels != out_channels
            else nn.Identity()
        )

    @staticmethod
    def _make_norm(norm_layer, num_channels, groups, eps):
        """Build the normalization module named by `norm_layer`."""
        if norm_layer == "group_norm":
            return nn.GroupNorm(
                num_groups=groups, num_channels=num_channels, eps=eps, affine=True
            )
        if norm_layer == "pixel_norm":
            return PixelNorm()
        if norm_layer == "layer_norm":
            return LayerNorm(num_channels, eps=eps, elementwise_affine=True)
        # Previously an unknown name left norm1/norm2 undefined and only
        # failed later inside `forward`.
        raise ValueError(f"Invalid norm_layer: {norm_layer}")

    def forward(
        self,
        input_tensor: torch.FloatTensor,
        causal: bool = True,
    ) -> torch.FloatTensor:
        """Apply the residual branch and add the (possibly projected) input."""
        hidden_states = input_tensor

        hidden_states = self.norm1(hidden_states)
        hidden_states = self.non_linearity(hidden_states)
        hidden_states = self.conv1(hidden_states, causal=causal)

        hidden_states = self.norm2(hidden_states)
        hidden_states = self.non_linearity(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states, causal=causal)

        input_tensor = self.norm3(input_tensor)
        input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = input_tensor + hidden_states

        return output_tensor
|
||||
|
||||
|
||||
def patchify(x, patch_size_hw, patch_size_t=1):
    """Fold spatial (and temporal) patches of `x` into its channel axis.

    Args:
        x: 4-D (b, c, h, w) or 5-D (b, c, f, h, w) tensor.
        patch_size_hw: Patch edge length for height and width.
        patch_size_t: Patch length along frames (5-D input only).

    Returns:
        `x` itself when both patch sizes are 1, otherwise the rearranged
        tensor with patches moved into the channel axis.

    Raises:
        ValueError: If patching is requested for a tensor that is neither
            4-D nor 5-D.
    """
    if patch_size_hw == 1 and patch_size_t == 1:
        return x

    rank = x.dim()
    if rank == 4:
        return rearrange(
            x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw
        )
    if rank == 5:
        return rearrange(
            x,
            "b c (f p) (h q) (w r) -> b (c p r q) f h w",
            p=patch_size_t,
            q=patch_size_hw,
            r=patch_size_hw,
        )
    raise ValueError(f"Invalid input shape: {x.shape}")
|
||||
|
||||
|
||||
def unpatchify(x, patch_size_hw, patch_size_t=1):
    """Unfold patches from the channel axis of `x` back into space (and time).

    Args:
        x: 4-D (b, c, h, w) or 5-D (b, c, f, h, w) tensor whose channel axis
            holds the folded patches.
        patch_size_hw: Patch edge length for height and width.
        patch_size_t: Patch length along frames (5-D input only).

    Returns:
        `x` itself when both patch sizes are 1, otherwise the rearranged
        tensor.

    Raises:
        ValueError: If un-patching is requested for a tensor that is neither
            4-D nor 5-D (mirrors `patchify`; previously such input was
            silently returned unchanged).
    """
    if patch_size_hw == 1 and patch_size_t == 1:
        return x

    if x.dim() == 4:
        x = rearrange(
            x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw
        )
    elif x.dim() == 5:
        x = rearrange(
            x,
            "b (c p r q) f h w -> b c (f p) (h q) (w r)",
            p=patch_size_t,
            q=patch_size_hw,
            r=patch_size_hw,
        )
    else:
        raise ValueError(f"Invalid input shape: {x.shape}")

    return x
|
||||
|
||||
class processor(nn.Module):
    """Holds per-channel statistics buffers and (un-)normalizes 5-D tensors.

    Five buffers of 128 elements are registered uninitialised (presumably
    populated from a checkpoint); only "std-of-means" and "mean-of-means" are
    read by the methods of this class.
    """

    def __init__(self):
        super().__init__()
        for stat_name in (
            "std-of-means",
            "mean-of-means",
            "mean-of-stds",
            "mean-of-stds_over_std-of-means",
            "channel",
        ):
            self.register_buffer(stat_name, torch.empty(128))

    def _channel_stat(self, stat_name, like):
        """Return a buffer shaped (1, C, 1, 1, 1) and converted to match `like`."""
        return self.get_buffer(stat_name).view(1, -1, 1, 1, 1).to(like)

    def un_normalize(self, x):
        """Map `x` back to the un-normalized range: `x * std + mean` per channel."""
        scale = self._channel_stat("std-of-means", x)
        offset = self._channel_stat("mean-of-means", x)
        return (x * scale) + offset

    def normalize(self, x):
        """Normalize `x` per channel: `(x - mean) / std`."""
        offset = self._channel_stat("mean-of-means", x)
        scale = self._channel_stat("std-of-means", x)
        return (x - offset) / scale
|
||||
|
||||
class VideoVAE(nn.Module):
    """Video autoencoder assembled from a fixed, hard-coded configuration.

    Creates an `Encoder`, a `Decoder` and a `processor` holding per-channel
    statistics used to normalize latents.
    """

    def __init__(self):
        super().__init__()
        # Keys that are never read below ("_class_name", "scaling_factor",
        # "use_quant_conv") are kept so the dict stays a verbatim config.
        cfg = {
            "_class_name": "CausalVideoAutoencoder",
            "dims": 3,
            "in_channels": 3,
            "out_channels": 3,
            "latent_channels": 128,
            "blocks": [
                ["res_x", 4],
                ["compress_all", 1],
                ["res_x_y", 1],
                ["res_x", 3],
                ["compress_all", 1],
                ["res_x_y", 1],
                ["res_x", 3],
                ["compress_all", 1],
                ["res_x", 3],
                ["res_x", 4],
            ],
            "scaling_factor": 1.0,
            "norm_layer": "pixel_norm",
            "patch_size": 4,
            "latent_log_var": "uniform",
            "use_quant_conv": False,
            "causal_decoder": False,
        }

        # An explicit "latent_log_var" wins; otherwise it follows "double_z".
        fallback_log_var = "per_channel" if cfg.get("double_z", True) else "none"
        log_var_mode = cfg.get("latent_log_var", fallback_log_var)

        dims = cfg["dims"]
        latent_channels = cfg["latent_channels"]
        shared_blocks = cfg.get("blocks")
        patch_size = cfg.get("patch_size", 1)
        norm_layer = cfg.get("norm_layer", "group_norm")

        self.encoder = Encoder(
            dims=dims,
            in_channels=cfg.get("in_channels", 3),
            out_channels=latent_channels,
            blocks=cfg.get("encoder_blocks", shared_blocks),
            patch_size=patch_size,
            latent_log_var=log_var_mode,
            norm_layer=norm_layer,
        )

        self.decoder = Decoder(
            dims=dims,
            in_channels=latent_channels,
            out_channels=cfg.get("out_channels", 3),
            blocks=cfg.get("decoder_blocks", shared_blocks),
            patch_size=patch_size,
            norm_layer=norm_layer,
            causal=cfg.get("causal_decoder", False),
        )

        self.per_channel_statistics = processor()

    def encode(self, x):
        """Encode `x` and return the normalized first half of the channels.

        The encoder output is split in two along dim 1; the second half is
        discarded.
        """
        means, _ = torch.chunk(self.encoder(x), 2, dim=1)
        return self.per_channel_statistics.normalize(means)

    def decode(self, x):
        """Un-normalize the latent `x` and run it through the decoder."""
        latents = self.per_channel_statistics.un_normalize(x)
        return self.decoder(latents)
|
||||
|
||||
83
comfy/ldm/lightricks/vae/conv_nd_factory.py
Normal file
83
comfy/ldm/lightricks/vae/conv_nd_factory.py
Normal file
@@ -0,0 +1,83 @@
|
||||
from typing import Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from .dual_conv3d import DualConv3d
|
||||
from .causal_conv3d import CausalConv3d
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
def make_conv_nd(
    dims: Union[int, Tuple[int, int]],
    in_channels: int,
    out_channels: int,
    kernel_size: int,
    stride=1,
    padding=0,
    dilation=1,
    groups=1,
    bias=True,
    causal=False,
):
    """Build the convolution layer that corresponds to ``dims``.

    * ``dims == 2``      -> ``ops.Conv2d``
    * ``dims == 3``      -> ``CausalConv3d`` when ``causal`` is truthy,
                            otherwise ``ops.Conv3d``
    * ``dims == (2, 1)`` -> ``DualConv3d``

    Raises:
        ValueError: if ``dims`` is none of the values above.
    """
    # Arguments every supported layer receives.
    base_kwargs = dict(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        bias=bias,
    )
    # Arguments forwarded to the plain and causal convolutions only.
    full_kwargs = dict(base_kwargs, dilation=dilation, groups=groups)

    if dims == 2:
        return ops.Conv2d(**full_kwargs)
    if dims == 3:
        conv_cls = CausalConv3d if causal else ops.Conv3d
        return conv_cls(**full_kwargs)
    if dims == (2, 1):
        # NOTE(review): dilation and groups are not forwarded on this branch,
        # so DualConv3d falls back to its own defaults for them - confirm
        # whether that is intentional.
        return DualConv3d(**base_kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
||||
|
||||
def make_linear_nd(
    dims: int,
    in_channels: int,
    out_channels: int,
    bias=True,
):
    """Build a kernel-size-1 convolution for ``dims``.

    Returns ``ops.Conv2d`` for ``dims == 2`` and ``ops.Conv3d`` for
    ``dims == 3`` or ``dims == (2, 1)``.

    Raises:
        ValueError: if ``dims`` is none of the values above.
    """
    if dims != 2 and dims != 3 and dims != (2, 1):
        raise ValueError(f"unsupported dimensions: {dims}")

    conv_cls = ops.Conv2d if dims == 2 else ops.Conv3d
    return conv_cls(
        in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
    )
|
||||
195
comfy/ldm/lightricks/vae/dual_conv3d.py
Normal file
195
comfy/ldm/lightricks/vae/dual_conv3d.py
Normal file
@@ -0,0 +1,195 @@
|
||||
import math
|
||||
from typing import Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
class DualConv3d(nn.Module):
    """Convolution split into a spatial pass followed by a temporal pass.

    The first convolution uses a ``(1, kH, kW)`` kernel together with the
    height/width components of stride, padding and dilation; the second
    uses a ``(kT, 1, 1)`` kernel with the depth components. ``forward`` can
    run both passes with ``F.conv3d`` or with ``F.conv2d`` + ``F.conv1d``
    on reshaped tensors.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride: Union[int, Tuple[int, int, int]] = 1,
        padding: Union[int, Tuple[int, int, int]] = 0,
        dilation: Union[int, Tuple[int, int, int]] = 1,
        groups=1,
        bias=True,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        # Expand scalar hyper-parameters to 3-tuples.
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size,) * 3
        if kernel_size == (1, 1, 1):
            raise ValueError(
                "kernel_size must be greater than 1. Use make_linear_nd instead."
            )
        stride, padding, dilation = [
            (value,) * 3 if isinstance(value, int) else value
            for value in (stride, padding, dilation)
        ]

        self.groups = groups
        self.bias = bias

        # Channel count between the two convolutions: the larger of in/out.
        mid_channels = max(in_channels, out_channels)

        # First convolution: height/width only.
        self.weight1 = nn.Parameter(
            torch.Tensor(
                mid_channels,
                in_channels // groups,
                1,
                kernel_size[1],
                kernel_size[2],
            )
        )
        self.stride1 = (1, stride[1], stride[2])
        self.padding1 = (0, padding[1], padding[2])
        self.dilation1 = (1, dilation[1], dilation[2])
        self.register_parameter(
            "bias1", nn.Parameter(torch.Tensor(mid_channels)) if bias else None
        )

        # Second convolution: depth only.
        self.weight2 = nn.Parameter(
            torch.Tensor(out_channels, mid_channels // groups, kernel_size[0], 1, 1)
        )
        self.stride2 = (stride[0], 1, 1)
        self.padding2 = (padding[0], 0, 0)
        self.dilation2 = (dilation[0], 1, 1)
        self.register_parameter(
            "bias2", nn.Parameter(torch.Tensor(out_channels)) if bias else None
        )

        # The tensors above are uninitialized until this call.
        self.reset_parameters()

    def reset_parameters(self):
        """Initialize both weights (Kaiming uniform) and, if present, both biases."""
        nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.weight2, a=math.sqrt(5))
        if not self.bias:
            return
        for weight, bias in ((self.weight1, self.bias1), (self.weight2, self.bias2)):
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(bias, -bound, bound)

    def forward(self, x, use_conv3d=False, skip_time_conv=False):
        """Dispatch to the conv3d-based or the conv2d/conv1d-based implementation."""
        impl = self.forward_with_3d if use_conv3d else self.forward_with_2d
        return impl(x=x, skip_time_conv=skip_time_conv)

    def forward_with_3d(self, x, skip_time_conv):
        """Run both passes with ``F.conv3d``; stop after the first if ``skip_time_conv``."""
        spatial = F.conv3d(
            x,
            self.weight1,
            self.bias1,
            self.stride1,
            self.padding1,
            self.dilation1,
            self.groups,
        )
        if skip_time_conv:
            return spatial

        return F.conv3d(
            spatial,
            self.weight2,
            self.bias2,
            self.stride2,
            self.padding2,
            self.dilation2,
            self.groups,
        )

    def forward_with_2d(self, x, skip_time_conv):
        """Run the first pass as ``F.conv2d`` and the second as ``F.conv1d``.

        ``x`` must be 5-D ``(b, c, d, h, w)``; the depth axis is folded into
        the batch for the 2-D pass, and height/width are folded into the
        batch for the 1-D pass.
        """
        b, c, d, h, w = x.shape

        # Spatial pass: every depth slice is treated as an independent image.
        frames = rearrange(x, "b c d h w -> (b d) c h w")
        # weight1's depth extent is 1, so it can be dropped for conv2d.
        kernel_2d = self.weight1.squeeze(2)
        frames = F.conv2d(
            frames,
            kernel_2d,
            self.bias1,
            self.stride1[1:],
            self.padding1[1:],
            self.dilation1[1:],
            self.groups,
        )

        # Spatial size after the first pass; needed to unfold the batch later.
        _, _, h, w = frames.shape

        if skip_time_conv:
            return rearrange(frames, "(b d) c h w -> b c d h w", b=b)

        # Temporal pass: a 1-D convolution along d for every (b, h, w) position.
        columns = rearrange(frames, "(b d) c h w -> (b h w) c d", b=b)
        # weight2's height/width extents are 1, so they can be dropped for conv1d.
        kernel_1d = self.weight2.squeeze(-1).squeeze(-1)
        columns = F.conv1d(
            columns,
            kernel_1d,
            self.bias2,
            self.stride2[0],
            self.padding2[0],
            self.dilation2[0],
            self.groups,
        )
        return rearrange(columns, "(b h w) c d -> b c d h w", b=b, h=h, w=w)

    @property
    def weight(self):
        """The second (depth) convolution's weight."""
        return self.weight2
|
||||
|
||||
|
||||
def test_dual_conv3d_consistency():
    """Check that DualConv3d's 3-D and 2-D forward paths agree on a random input."""
    layer = DualConv3d(
        in_channels=3,
        out_channels=5,
        kernel_size=(3, 3, 3),
        stride=(2, 2, 2),
        padding=(1, 1, 1),
        bias=True,
    )

    sample = torch.randn(1, 3, 10, 10, 10)

    # Same module and input, once through each implementation.
    via_conv3d = layer(sample, use_conv3d=True)
    via_conv2d = layer(sample, use_conv3d=False)

    outputs_match = torch.allclose(via_conv3d, via_conv2d, atol=1e-6)
    assert outputs_match, "Outputs are not consistent between 3D and 2D convolutions."
|
||||
12
comfy/ldm/lightricks/vae/pixel_norm.py
Normal file
12
comfy/ldm/lightricks/vae/pixel_norm.py
Normal file
@@ -0,0 +1,12 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class PixelNorm(nn.Module):
    """Divide the input by its root-mean-square taken along ``dim``.

    ``eps`` is added to the mean of squares before the square root.
    """

    def __init__(self, dim=1, eps=1e-8):
        super().__init__()
        self.dim = dim
        self.eps = eps

    def forward(self, x):
        """Return ``x / sqrt(mean(x**2, dim, keepdim=True) + eps)``."""
        mean_square = torch.mean(x**2, dim=self.dim, keepdim=True)
        rms = torch.sqrt(mean_square + self.eps)
        return x / rms
|
||||
@@ -299,7 +299,10 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
||||
if len(mask.shape) == 2:
|
||||
s1 += mask[i:end]
|
||||
else:
|
||||
s1 += mask[:, i:end]
|
||||
if mask.shape[1] == 1:
|
||||
s1 += mask
|
||||
else:
|
||||
s1 += mask[:, i:end]
|
||||
|
||||
s2 = s1.softmax(dim=-1).to(v.dtype)
|
||||
del s1
|
||||
@@ -372,10 +375,10 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
|
||||
)
|
||||
|
||||
if mask is not None:
|
||||
pad = 8 - q.shape[1] % 8
|
||||
mask_out = torch.empty([q.shape[0], q.shape[1], q.shape[1] + pad], dtype=q.dtype, device=q.device)
|
||||
mask_out[:, :, :mask.shape[-1]] = mask
|
||||
mask = mask_out[:, :, :mask.shape[-1]]
|
||||
pad = 8 - mask.shape[-1] % 8
|
||||
mask_out = torch.empty([q.shape[0], q.shape[2], q.shape[1], mask.shape[-1] + pad], dtype=q.dtype, device=q.device)
|
||||
mask_out[..., :mask.shape[-1]] = mask
|
||||
mask = mask_out[..., :mask.shape[-1]]
|
||||
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
|
||||
|
||||
@@ -393,6 +396,13 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
|
||||
|
||||
return out
|
||||
|
||||
if model_management.is_nvidia(): #pytorch 2.3 and up seem to have this issue.
|
||||
SDP_BATCH_LIMIT = 2**15
|
||||
else:
|
||||
#TODO: other GPUs ?
|
||||
SDP_BATCH_LIMIT = 2**31
|
||||
|
||||
|
||||
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
|
||||
if skip_reshape:
|
||||
b, _, _, dim_head = q.shape
|
||||
@@ -404,10 +414,15 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
||||
(q, k, v),
|
||||
)
|
||||
|
||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||
out = (
|
||||
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||
)
|
||||
if SDP_BATCH_LIMIT >= q.shape[0]:
|
||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||
out = (
|
||||
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||
)
|
||||
else:
|
||||
out = torch.empty((q.shape[0], q.shape[2], heads * dim_head), dtype=q.dtype, layout=q.layout, device=q.device)
|
||||
for i in range(0, q.shape[0], SDP_BATCH_LIMIT):
|
||||
out[i : i + SDP_BATCH_LIMIT] = torch.nn.functional.scaled_dot_product_attention(q[i : i + SDP_BATCH_LIMIT], k[i : i + SDP_BATCH_LIMIT], v[i : i + SDP_BATCH_LIMIT], attn_mask=mask, dropout_p=0.0, is_causal=False).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
|
||||
return out
|
||||
|
||||
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import logging
|
||||
import math
|
||||
from typing import Dict, Optional
|
||||
from typing import Dict, Optional, List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .. import attention
|
||||
from ..attention import optimized_attention
|
||||
from einops import rearrange, repeat
|
||||
from .util import timestep_embedding
|
||||
import comfy.ops
|
||||
@@ -97,7 +97,7 @@ class PatchEmbed(nn.Module):
|
||||
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
B, C, H, W = x.shape
|
||||
# B, C, H, W = x.shape
|
||||
# if self.img_size is not None:
|
||||
# if self.strict_img_size:
|
||||
# _assert(H == self.img_size[0], f"Input height ({H}) doesn't match model ({self.img_size[0]}).")
|
||||
@@ -266,8 +266,6 @@ def split_qkv(qkv, head_dim):
|
||||
qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, -1, head_dim).movedim(2, 0)
|
||||
return qkv[0], qkv[1], qkv[2]
|
||||
|
||||
def optimized_attention(qkv, num_heads):
|
||||
return attention.optimized_attention(qkv[0], qkv[1], qkv[2], num_heads)
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")
|
||||
@@ -326,9 +324,9 @@ class SelfAttention(nn.Module):
|
||||
return x
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
qkv = self.pre_attention(x)
|
||||
q, k, v = self.pre_attention(x)
|
||||
x = optimized_attention(
|
||||
qkv, num_heads=self.num_heads
|
||||
q, k, v, heads=self.num_heads
|
||||
)
|
||||
x = self.post_attention(x)
|
||||
return x
|
||||
@@ -355,29 +353,9 @@ class RMSNorm(torch.nn.Module):
|
||||
else:
|
||||
self.register_parameter("weight", None)
|
||||
|
||||
def _norm(self, x):
|
||||
"""
|
||||
Apply the RMSNorm normalization to the input tensor.
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor.
|
||||
Returns:
|
||||
torch.Tensor: The normalized tensor.
|
||||
"""
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward pass through the RMSNorm layer.
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor.
|
||||
Returns:
|
||||
torch.Tensor: The output tensor after applying RMSNorm.
|
||||
"""
|
||||
x = self._norm(x)
|
||||
if self.learnable_scale:
|
||||
return x * self.weight.to(device=x.device, dtype=x.dtype)
|
||||
else:
|
||||
return x
|
||||
return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)
|
||||
|
||||
|
||||
|
||||
class SwiGLUFeedForward(nn.Module):
|
||||
@@ -437,6 +415,7 @@ class DismantledBlock(nn.Module):
|
||||
scale_mod_only: bool = False,
|
||||
swiglu: bool = False,
|
||||
qk_norm: Optional[str] = None,
|
||||
x_block_self_attn: bool = False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
@@ -460,6 +439,24 @@ class DismantledBlock(nn.Module):
|
||||
device=device,
|
||||
operations=operations
|
||||
)
|
||||
if x_block_self_attn:
|
||||
assert not pre_only
|
||||
assert not scale_mod_only
|
||||
self.x_block_self_attn = True
|
||||
self.attn2 = SelfAttention(
|
||||
dim=hidden_size,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
attn_mode=attn_mode,
|
||||
pre_only=False,
|
||||
qk_norm=qk_norm,
|
||||
rmsnorm=rmsnorm,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
)
|
||||
else:
|
||||
self.x_block_self_attn = False
|
||||
if not pre_only:
|
||||
if not rmsnorm:
|
||||
self.norm2 = operations.LayerNorm(
|
||||
@@ -486,7 +483,11 @@ class DismantledBlock(nn.Module):
|
||||
multiple_of=256,
|
||||
)
|
||||
self.scale_mod_only = scale_mod_only
|
||||
if not scale_mod_only:
|
||||
if x_block_self_attn:
|
||||
assert not pre_only
|
||||
assert not scale_mod_only
|
||||
n_mods = 9
|
||||
elif not scale_mod_only:
|
||||
n_mods = 6 if not pre_only else 2
|
||||
else:
|
||||
n_mods = 4 if not pre_only else 1
|
||||
@@ -547,14 +548,64 @@ class DismantledBlock(nn.Module):
|
||||
)
|
||||
return x
|
||||
|
||||
def pre_attention_x(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
    """Prepare the inputs of both self-attention branches.

    Splits ``self.adaLN_modulation(c)`` into nine chunks along dim 1, applies
    ``norm1`` to ``x`` once, and feeds the result - modulated with each
    branch's own shift/scale - to ``attn.pre_attention`` and
    ``attn2.pre_attention``.

    Returns:
        ``(qkv, qkv2, intermediates)`` where ``intermediates`` is the tuple
        ``(x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2)``.
    """
    assert self.x_block_self_attn

    modulation = self.adaLN_modulation(c).chunk(9, dim=1)
    (
        shift_msa, scale_msa, gate_msa,
        shift_mlp, scale_mlp, gate_mlp,
        shift_msa2, scale_msa2, gate_msa2,
    ) = modulation

    # Both branches share the same normalized input.
    normed = self.norm1(x)
    qkv = self.attn.pre_attention(modulate(normed, shift_msa, scale_msa))
    qkv2 = self.attn2.pre_attention(modulate(normed, shift_msa2, scale_msa2))

    intermediates = (x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2)
    return qkv, qkv2, intermediates
|
||||
|
||||
def post_attention_x(self, attn, attn2, x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2):
    """Add both gated attention outputs and the gated MLP output to ``x``.

    ``attn`` and ``attn2`` are passed through ``self.attn.post_attention``
    and ``self.attn2.post_attention``; each gate is unsqueezed at dim 1
    before being multiplied in. The residual additions happen in the order
    first branch, second branch, MLP.
    """
    assert not self.pre_only

    branch1 = self.attn.post_attention(attn)
    branch2 = self.attn2.post_attention(attn2)

    x = x + gate_msa.unsqueeze(1) * branch1
    x = x + gate_msa2.unsqueeze(1) * branch2

    mlp_input = modulate(self.norm2(x), shift_mlp, scale_mlp)
    return x + gate_mlp.unsqueeze(1) * self.mlp(mlp_input)
|
||||
|
||||
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
||||
assert not self.pre_only
|
||||
qkv, intermediates = self.pre_attention(x, c)
|
||||
attn = optimized_attention(
|
||||
qkv,
|
||||
num_heads=self.attn.num_heads,
|
||||
)
|
||||
return self.post_attention(attn, *intermediates)
|
||||
if self.x_block_self_attn:
|
||||
qkv, qkv2, intermediates = self.pre_attention_x(x, c)
|
||||
attn, _ = optimized_attention(
|
||||
qkv[0], qkv[1], qkv[2],
|
||||
num_heads=self.attn.num_heads,
|
||||
)
|
||||
attn2, _ = optimized_attention(
|
||||
qkv2[0], qkv2[1], qkv2[2],
|
||||
num_heads=self.attn2.num_heads,
|
||||
)
|
||||
return self.post_attention_x(attn, attn2, *intermediates)
|
||||
else:
|
||||
qkv, intermediates = self.pre_attention(x, c)
|
||||
attn = optimized_attention(
|
||||
qkv[0], qkv[1], qkv[2],
|
||||
heads=self.attn.num_heads,
|
||||
)
|
||||
return self.post_attention(attn, *intermediates)
|
||||
|
||||
|
||||
def block_mixing(*args, use_checkpoint=True, **kwargs):
|
||||
@@ -569,7 +620,10 @@ def block_mixing(*args, use_checkpoint=True, **kwargs):
|
||||
def _block_mixing(context, x, context_block, x_block, c):
|
||||
context_qkv, context_intermediates = context_block.pre_attention(context, c)
|
||||
|
||||
x_qkv, x_intermediates = x_block.pre_attention(x, c)
|
||||
if x_block.x_block_self_attn:
|
||||
x_qkv, x_qkv2, x_intermediates = x_block.pre_attention_x(x, c)
|
||||
else:
|
||||
x_qkv, x_intermediates = x_block.pre_attention(x, c)
|
||||
|
||||
o = []
|
||||
for t in range(3):
|
||||
@@ -577,8 +631,8 @@ def _block_mixing(context, x, context_block, x_block, c):
|
||||
qkv = tuple(o)
|
||||
|
||||
attn = optimized_attention(
|
||||
qkv,
|
||||
num_heads=x_block.attn.num_heads,
|
||||
qkv[0], qkv[1], qkv[2],
|
||||
heads=x_block.attn.num_heads,
|
||||
)
|
||||
context_attn, x_attn = (
|
||||
attn[:, : context_qkv[0].shape[1]],
|
||||
@@ -590,7 +644,14 @@ def _block_mixing(context, x, context_block, x_block, c):
|
||||
|
||||
else:
|
||||
context = None
|
||||
x = x_block.post_attention(x_attn, *x_intermediates)
|
||||
if x_block.x_block_self_attn:
|
||||
attn2 = optimized_attention(
|
||||
x_qkv2[0], x_qkv2[1], x_qkv2[2],
|
||||
heads=x_block.attn2.num_heads,
|
||||
)
|
||||
x = x_block.post_attention_x(x_attn, attn2, *x_intermediates)
|
||||
else:
|
||||
x = x_block.post_attention(x_attn, *x_intermediates)
|
||||
return context, x
|
||||
|
||||
|
||||
@@ -605,8 +666,13 @@ class JointBlock(nn.Module):
|
||||
super().__init__()
|
||||
pre_only = kwargs.pop("pre_only")
|
||||
qk_norm = kwargs.pop("qk_norm", None)
|
||||
x_block_self_attn = kwargs.pop("x_block_self_attn", False)
|
||||
self.context_block = DismantledBlock(*args, pre_only=pre_only, qk_norm=qk_norm, **kwargs)
|
||||
self.x_block = DismantledBlock(*args, pre_only=False, qk_norm=qk_norm, **kwargs)
|
||||
self.x_block = DismantledBlock(*args,
|
||||
pre_only=False,
|
||||
qk_norm=qk_norm,
|
||||
x_block_self_attn=x_block_self_attn,
|
||||
**kwargs)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return block_mixing(
|
||||
@@ -662,7 +728,7 @@ class SelfAttentionContext(nn.Module):
|
||||
def forward(self, x):
|
||||
qkv = self.qkv(x)
|
||||
q, k, v = split_qkv(qkv, self.dim_head)
|
||||
x = optimized_attention((q.reshape(q.shape[0], q.shape[1], -1), k, v), self.heads)
|
||||
x = optimized_attention(q.reshape(q.shape[0], q.shape[1], -1), k, v, heads=self.heads)
|
||||
return self.proj(x)
|
||||
|
||||
class ContextProcessorBlock(nn.Module):
|
||||
@@ -721,9 +787,12 @@ class MMDiT(nn.Module):
|
||||
qk_norm: Optional[str] = None,
|
||||
qkv_bias: bool = True,
|
||||
context_processor_layers = None,
|
||||
x_block_self_attn: bool = False,
|
||||
x_block_self_attn_layers: Optional[List[int]] = [],
|
||||
context_size = 4096,
|
||||
num_blocks = None,
|
||||
final_layer = True,
|
||||
skip_blocks = False,
|
||||
dtype = None, #TODO
|
||||
device = None,
|
||||
operations = None,
|
||||
@@ -738,6 +807,7 @@ class MMDiT(nn.Module):
|
||||
self.pos_embed_scaling_factor = pos_embed_scaling_factor
|
||||
self.pos_embed_offset = pos_embed_offset
|
||||
self.pos_embed_max_size = pos_embed_max_size
|
||||
self.x_block_self_attn_layers = x_block_self_attn_layers
|
||||
|
||||
# hidden_size = default(hidden_size, 64 * depth)
|
||||
# num_heads = default(num_heads, hidden_size // 64)
|
||||
@@ -795,26 +865,28 @@ class MMDiT(nn.Module):
|
||||
self.pos_embed = None
|
||||
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.joint_blocks = nn.ModuleList(
|
||||
[
|
||||
JointBlock(
|
||||
self.hidden_size,
|
||||
num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
attn_mode=attn_mode,
|
||||
pre_only=(i == num_blocks - 1) and final_layer,
|
||||
rmsnorm=rmsnorm,
|
||||
scale_mod_only=scale_mod_only,
|
||||
swiglu=swiglu,
|
||||
qk_norm=qk_norm,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
)
|
||||
for i in range(num_blocks)
|
||||
]
|
||||
)
|
||||
if not skip_blocks:
|
||||
self.joint_blocks = nn.ModuleList(
|
||||
[
|
||||
JointBlock(
|
||||
self.hidden_size,
|
||||
num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
attn_mode=attn_mode,
|
||||
pre_only=(i == num_blocks - 1) and final_layer,
|
||||
rmsnorm=rmsnorm,
|
||||
scale_mod_only=scale_mod_only,
|
||||
swiglu=swiglu,
|
||||
qk_norm=qk_norm,
|
||||
x_block_self_attn=(i in self.x_block_self_attn_layers) or x_block_self_attn,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
for i in range(num_blocks)
|
||||
]
|
||||
)
|
||||
|
||||
if final_layer:
|
||||
self.final_layer = FinalLayer(self.hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||
@@ -877,7 +949,9 @@ class MMDiT(nn.Module):
|
||||
c_mod: torch.Tensor,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
control = None,
|
||||
transformer_options = {},
|
||||
) -> torch.Tensor:
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
if self.register_length > 0:
|
||||
context = torch.cat(
|
||||
(
|
||||
@@ -889,14 +963,25 @@ class MMDiT(nn.Module):
|
||||
|
||||
# context is B, L', D
|
||||
# x is B, L, D
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
blocks = len(self.joint_blocks)
|
||||
for i in range(blocks):
|
||||
context, x = self.joint_blocks[i](
|
||||
context,
|
||||
x,
|
||||
c=c_mod,
|
||||
use_checkpoint=self.use_checkpoint,
|
||||
)
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": c_mod}, {"original_block": block_wrap})
|
||||
context = out["txt"]
|
||||
x = out["img"]
|
||||
else:
|
||||
context, x = self.joint_blocks[i](
|
||||
context,
|
||||
x,
|
||||
c=c_mod,
|
||||
use_checkpoint=self.use_checkpoint,
|
||||
)
|
||||
if control is not None:
|
||||
control_o = control.get("output")
|
||||
if i < len(control_o):
|
||||
@@ -914,6 +999,7 @@ class MMDiT(nn.Module):
|
||||
y: Optional[torch.Tensor] = None,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
control = None,
|
||||
transformer_options = {},
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass of DiT.
|
||||
@@ -935,7 +1021,7 @@ class MMDiT(nn.Module):
|
||||
if context is not None:
|
||||
context = self.context_embedder(context)
|
||||
|
||||
x = self.forward_core_with_concat(x, c, context, control)
|
||||
x = self.forward_core_with_concat(x, c, context, control, transformer_options)
|
||||
|
||||
x = self.unpatchify(x, hw=hw) # (N, out_channels, H, W)
|
||||
return x[:,:,:hw[-2],:hw[-1]]
|
||||
@@ -949,7 +1035,8 @@ class OpenAISignatureMMDITWrapper(MMDiT):
|
||||
context: Optional[torch.Tensor] = None,
|
||||
y: Optional[torch.Tensor] = None,
|
||||
control = None,
|
||||
transformer_options = {},
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
return super().forward(x, timesteps, context=context, y=y, control=control)
|
||||
return super().forward(x, timesteps, context=context, y=y, control=control, transformer_options=transformer_options)
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ from .util import (
|
||||
)
|
||||
from ..attention import SpatialTransformer, SpatialVideoTransformer, default
|
||||
from comfy.ldm.util import exists
|
||||
import comfy.patcher_extension
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
@@ -47,6 +48,15 @@ def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, out
|
||||
elif isinstance(layer, Upsample):
|
||||
x = layer(x, output_shape=output_shape)
|
||||
else:
|
||||
if "patches" in transformer_options and "forward_timestep_embed_patch" in transformer_options["patches"]:
|
||||
found_patched = False
|
||||
for class_type, handler in transformer_options["patches"]["forward_timestep_embed_patch"]:
|
||||
if isinstance(layer, class_type):
|
||||
x = handler(layer, x, emb, context, transformer_options, output_shape, time_context, num_video_frames, image_only_indicator)
|
||||
found_patched = True
|
||||
break
|
||||
if found_patched:
|
||||
continue
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
@@ -819,6 +829,13 @@ class UNetModel(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timesteps, context, y, control, transformer_options, **kwargs)
|
||||
|
||||
def _forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
|
||||
"""
|
||||
Apply the model to an input batch.
|
||||
:param x: an [N x C x ...] Tensor of inputs.
|
||||
@@ -842,6 +859,11 @@ class UNetModel(nn.Module):
|
||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
|
||||
emb = self.time_embed(t_emb)
|
||||
|
||||
if "emb_patch" in transformer_patches:
|
||||
patch = transformer_patches["emb_patch"]
|
||||
for p in patch:
|
||||
emb = p(emb, self.model_channels, transformer_options)
|
||||
|
||||
if self.num_classes is not None:
|
||||
assert y.shape[0] == x.shape[0]
|
||||
emb = emb + self.label_emb(y)
|
||||
|
||||
@@ -234,6 +234,8 @@ def efficient_dot_product_attention(
|
||||
def get_mask_chunk(chunk_idx: int) -> Tensor:
|
||||
if mask is None:
|
||||
return None
|
||||
if mask.shape[1] == 1:
|
||||
return mask
|
||||
chunk = min(query_chunk_size, q_tokens)
|
||||
return mask[:,chunk_idx:chunk_idx + chunk]
|
||||
|
||||
|
||||
183
comfy/lora.py
183
comfy/lora.py
@@ -16,6 +16,7 @@
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
import comfy.model_base
|
||||
@@ -32,7 +33,7 @@ LORA_CLIP_MAP = {
|
||||
}
|
||||
|
||||
|
||||
def load_lora(lora, to_load):
|
||||
def load_lora(lora, to_load, log_missing=True):
|
||||
patch_dict = {}
|
||||
loaded_keys = set()
|
||||
for x in to_load:
|
||||
@@ -48,10 +49,20 @@ def load_lora(lora, to_load):
|
||||
dora_scale = lora[dora_scale_name]
|
||||
loaded_keys.add(dora_scale_name)
|
||||
|
||||
reshape_name = "{}.reshape_weight".format(x)
|
||||
reshape = None
|
||||
if reshape_name in lora.keys():
|
||||
try:
|
||||
reshape = lora[reshape_name].tolist()
|
||||
loaded_keys.add(reshape_name)
|
||||
except:
|
||||
pass
|
||||
|
||||
regular_lora = "{}.lora_up.weight".format(x)
|
||||
diffusers_lora = "{}_lora.up.weight".format(x)
|
||||
diffusers2_lora = "{}.lora_B.weight".format(x)
|
||||
diffusers3_lora = "{}.lora.up.weight".format(x)
|
||||
mochi_lora = "{}.lora_B".format(x)
|
||||
transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
|
||||
A_name = None
|
||||
|
||||
@@ -71,6 +82,10 @@ def load_lora(lora, to_load):
|
||||
A_name = diffusers3_lora
|
||||
B_name = "{}.lora.down.weight".format(x)
|
||||
mid_name = None
|
||||
elif mochi_lora in lora.keys():
|
||||
A_name = mochi_lora
|
||||
B_name = "{}.lora_A".format(x)
|
||||
mid_name = None
|
||||
elif transformers_lora in lora.keys():
|
||||
A_name = transformers_lora
|
||||
B_name ="{}.lora_linear_layer.down.weight".format(x)
|
||||
@@ -81,7 +96,7 @@ def load_lora(lora, to_load):
|
||||
if mid_name is not None and mid_name in lora.keys():
|
||||
mid = lora[mid_name]
|
||||
loaded_keys.add(mid_name)
|
||||
patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale))
|
||||
patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale, reshape))
|
||||
loaded_keys.add(A_name)
|
||||
loaded_keys.add(B_name)
|
||||
|
||||
@@ -192,17 +207,28 @@ def load_lora(lora, to_load):
|
||||
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,))
|
||||
loaded_keys.add(diff_bias_name)
|
||||
|
||||
for x in lora.keys():
|
||||
if x not in loaded_keys:
|
||||
logging.warning("lora key not loaded: {}".format(x))
|
||||
set_weight_name = "{}.set_weight".format(x)
|
||||
set_weight = lora.get(set_weight_name, None)
|
||||
if set_weight is not None:
|
||||
patch_dict[to_load[x]] = ("set", (set_weight,))
|
||||
loaded_keys.add(set_weight_name)
|
||||
|
||||
if log_missing:
|
||||
for x in lora.keys():
|
||||
if x not in loaded_keys:
|
||||
logging.warning("lora key not loaded: {}".format(x))
|
||||
|
||||
return patch_dict
|
||||
|
||||
def model_lora_keys_clip(model, key_map={}):
|
||||
sdk = model.state_dict().keys()
|
||||
for k in sdk:
|
||||
if k.endswith(".weight"):
|
||||
key_map["text_encoders.{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
|
||||
|
||||
text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
|
||||
clip_l_present = False
|
||||
clip_g_present = False
|
||||
for b in range(32): #TODO: clean up
|
||||
for c in LORA_CLIP_MAP:
|
||||
k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
|
||||
@@ -226,6 +252,7 @@ def model_lora_keys_clip(model, key_map={}):
|
||||
|
||||
k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
|
||||
if k in sdk:
|
||||
clip_g_present = True
|
||||
if clip_l_present:
|
||||
lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
|
||||
key_map[lora_key] = k
|
||||
@@ -241,10 +268,18 @@ def model_lora_keys_clip(model, key_map={}):
|
||||
|
||||
for k in sdk:
|
||||
if k.endswith(".weight"):
|
||||
if k.startswith("t5xxl.transformer."):#OneTrainer SD3 lora
|
||||
if k.startswith("t5xxl.transformer."):#OneTrainer SD3 and Flux lora
|
||||
l_key = k[len("t5xxl.transformer."):-len(".weight")]
|
||||
lora_key = "lora_te3_{}".format(l_key.replace(".", "_"))
|
||||
key_map[lora_key] = k
|
||||
t5_index = 1
|
||||
if clip_g_present:
|
||||
t5_index += 1
|
||||
if clip_l_present:
|
||||
t5_index += 1
|
||||
if t5_index == 2:
|
||||
key_map["lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))] = k #OneTrainer Flux
|
||||
t5_index += 1
|
||||
|
||||
key_map["lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))] = k
|
||||
elif k.startswith("hydit_clip.transformer.bert."): #HunyuanDiT Lora
|
||||
l_key = k[len("hydit_clip.transformer.bert."):-len(".weight")]
|
||||
lora_key = "lora_te1_{}".format(l_key.replace(".", "_"))
|
||||
@@ -268,11 +303,14 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
sdk = sd.keys()
|
||||
|
||||
for k in sdk:
|
||||
if k.startswith("diffusion_model.") and k.endswith(".weight"):
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
|
||||
key_map["lora_unet_{}".format(key_lora)] = k
|
||||
key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config
|
||||
key_map["{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
|
||||
if k.startswith("diffusion_model."):
|
||||
if k.endswith(".weight"):
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
|
||||
key_map["lora_unet_{}".format(key_lora)] = k
|
||||
key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config
|
||||
key_map["{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
|
||||
else:
|
||||
key_map["{}".format(k)] = k #generic lora format for not .weight without any weird key names
|
||||
|
||||
diffusers_keys = comfy.utils.unet_to_diffusers(model.model_config.unet_config)
|
||||
for k in diffusers_keys:
|
||||
@@ -280,6 +318,7 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
unet_key = "diffusion_model.{}".format(diffusers_keys[k])
|
||||
key_lora = k[:-len(".weight")].replace(".", "_")
|
||||
key_map["lora_unet_{}".format(key_lora)] = unet_key
|
||||
key_map["lycoris_{}".format(key_lora)] = unet_key #simpletuner lycoris format
|
||||
|
||||
diffusers_lora_prefix = ["", "unet."]
|
||||
for p in diffusers_lora_prefix:
|
||||
@@ -302,6 +341,10 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
key_lora = "lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_")) #OneTrainer lora
|
||||
key_map[key_lora] = to
|
||||
|
||||
key_lora = "lycoris_{}".format(k[:-len(".weight")].replace(".", "_")) #simpletuner lycoris format
|
||||
key_map[key_lora] = to
|
||||
|
||||
|
||||
if isinstance(model, comfy.model_base.AuraFlow): #Diffusers lora AuraFlow
|
||||
diffusers_keys = comfy.utils.auraflow_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
|
||||
for k in diffusers_keys:
|
||||
@@ -323,14 +366,21 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
to = diffusers_keys[k]
|
||||
key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format
|
||||
key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
|
||||
key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer
|
||||
|
||||
if isinstance(model, comfy.model_base.GenmoMochi):
|
||||
for k in sdk:
|
||||
if k.startswith("diffusion_model.") and k.endswith(".weight"): #Official Mochi lora format
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")]
|
||||
key_map["{}".format(key_lora)] = k
|
||||
|
||||
return key_map
|
||||
|
||||
|
||||
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype):
|
||||
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function):
|
||||
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
|
||||
lora_diff *= alpha
|
||||
weight_calc = weight + lora_diff.type(weight.dtype)
|
||||
weight_calc = weight + function(lora_diff).type(weight.dtype)
|
||||
weight_norm = (
|
||||
weight_calc.transpose(0, 1)
|
||||
.reshape(weight_calc.shape[1], -1)
|
||||
@@ -347,7 +397,40 @@ def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediat
|
||||
weight[:] = weight_calc
|
||||
return weight
|
||||
|
||||
def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
|
||||
def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:
    """
    Pad a tensor to a new shape with zeros.

    The original values are copied into the low-index corner of the result
    (index 0 up to the original size along every dimension); every other
    element is zero.

    Args:
        tensor (torch.Tensor): The original tensor to be padded.
        new_shape (List[int]): The desired shape of the padded tensor.

    Returns:
        torch.Tensor: A new tensor padded with zeros to the specified shape,
        with the same dtype and device as the input.

    Raises:
        ValueError: If new_shape has a different number of dimensions than
            the tensor, or is smaller than the tensor in any dimension.
            Nothing is ever truncated.
    """
    # Check the rank first: the per-dimension comparison below is only
    # meaningful when both shapes have the same number of dimensions.
    if len(new_shape) != len(tensor.shape):
        raise ValueError("The new shape must have the same number of dimensions as the original tensor")

    if any(new_dim < old_dim for new_dim, old_dim in zip(new_shape, tensor.shape)):
        raise ValueError("The new shape must be larger than the original tensor in all dimensions")

    # Create a new tensor filled with zeros
    padded_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device)

    # Region of the padded tensor that receives the original values
    region = tuple(slice(0, dim) for dim in tensor.shape)
    padded_tensor[region] = tensor

    return padded_tensor
|
||||
|
||||
def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, original_weights=None):
|
||||
for p in patches:
|
||||
strength = p[0]
|
||||
v = p[1]
|
||||
@@ -366,7 +449,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
|
||||
weight *= strength_model
|
||||
|
||||
if isinstance(v, list):
|
||||
v = (calculate_weight(v[1:], v[0].clone(), key, intermediate_dtype=intermediate_dtype), )
|
||||
v = (calculate_weight(v[1:], v[0][1](comfy.model_management.cast_to_device(v[0][0], weight.device, intermediate_dtype, copy=True), inplace=True), key, intermediate_dtype=intermediate_dtype), )
|
||||
|
||||
if len(v) == 1:
|
||||
patch_type = "diff"
|
||||
@@ -375,16 +458,34 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
|
||||
v = v[1]
|
||||
|
||||
if patch_type == "diff":
|
||||
w1 = v[0]
|
||||
diff: torch.Tensor = v[0]
|
||||
# An extra flag to pad the weight if the diff's shape is larger than the weight
|
||||
do_pad_weight = len(v) > 1 and v[1]['pad_weight']
|
||||
if do_pad_weight and diff.shape != weight.shape:
|
||||
logging.info("Pad weight {} from {} to shape: {}".format(key, weight.shape, diff.shape))
|
||||
weight = pad_tensor_to_shape(weight, diff.shape)
|
||||
|
||||
if strength != 0.0:
|
||||
if w1.shape != weight.shape:
|
||||
logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
|
||||
if diff.shape != weight.shape:
|
||||
logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, diff.shape, weight.shape))
|
||||
else:
|
||||
weight += function(strength * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype))
|
||||
weight += function(strength * comfy.model_management.cast_to_device(diff, weight.device, weight.dtype))
|
||||
elif patch_type == "set":
|
||||
weight.copy_(v[0])
|
||||
elif patch_type == "model_as_lora":
|
||||
target_weight: torch.Tensor = v[0]
|
||||
diff_weight = comfy.model_management.cast_to_device(target_weight, weight.device, intermediate_dtype) - \
|
||||
comfy.model_management.cast_to_device(original_weights[key][0][0], weight.device, intermediate_dtype)
|
||||
weight += function(strength * comfy.model_management.cast_to_device(diff_weight, weight.device, weight.dtype))
|
||||
elif patch_type == "lora": #lora/locon
|
||||
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
|
||||
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype)
|
||||
dora_scale = v[4]
|
||||
reshape = v[5]
|
||||
|
||||
if reshape is not None:
|
||||
weight = pad_tensor_to_shape(weight, reshape)
|
||||
|
||||
if v[2] is not None:
|
||||
alpha = v[2] / mat2.shape[0]
|
||||
else:
|
||||
@@ -398,7 +499,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
|
||||
try:
|
||||
lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
|
||||
if dora_scale is not None:
|
||||
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
|
||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||
else:
|
||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||
except Exception as e:
|
||||
@@ -444,7 +545,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
|
||||
try:
|
||||
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
|
||||
if dora_scale is not None:
|
||||
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
|
||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||
else:
|
||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||
except Exception as e:
|
||||
@@ -481,28 +582,48 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
|
||||
try:
|
||||
lora_diff = (m1 * m2).reshape(weight.shape)
|
||||
if dora_scale is not None:
|
||||
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
|
||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||
else:
|
||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||
except Exception as e:
|
||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||
elif patch_type == "glora":
|
||||
if v[4] is not None:
|
||||
alpha = v[4] / v[0].shape[0]
|
||||
else:
|
||||
alpha = 1.0
|
||||
|
||||
dora_scale = v[5]
|
||||
|
||||
old_glora = False
|
||||
if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]:
|
||||
rank = v[0].shape[0]
|
||||
old_glora = True
|
||||
|
||||
if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]:
|
||||
if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]:
|
||||
pass
|
||||
else:
|
||||
old_glora = False
|
||||
rank = v[1].shape[0]
|
||||
|
||||
a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||
|
||||
if v[4] is not None:
|
||||
alpha = v[4] / rank
|
||||
else:
|
||||
alpha = 1.0
|
||||
|
||||
try:
|
||||
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape)
|
||||
if old_glora:
|
||||
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora
|
||||
else:
|
||||
if weight.dim() > 2:
|
||||
lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
|
||||
else:
|
||||
lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
|
||||
lora_diff += torch.mm(b1, b2).reshape(weight.shape)
|
||||
|
||||
if dora_scale is not None:
|
||||
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
|
||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||
else:
|
||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||
except Exception as e:
|
||||
|
||||
17
comfy/lora_convert.py
Normal file
17
comfy/lora_convert.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import torch
|
||||
|
||||
|
||||
def convert_lora_bfl_control(sd): #BFL loras for Flux
    """Return a copy of the state dict `sd` with its keys renamed under `diffusion_model.`.

    Every key gets the "diffusion_model." prefix; ".lora_B.bias" becomes
    ".diff_b" and "_norm.scale" becomes "_norm.scale.set_weight". The values
    are reused as-is (not copied).

    An extra "diffusion_model.img_in.reshape_weight" entry is added: a
    2-element tensor holding dimension 0 of "img_in.lora_B.weight" and
    dimension 1 of "img_in.lora_A.weight". Both keys must be present in `sd`,
    otherwise a KeyError is raised.
    """
    converted = {}
    for name in sd:
        renamed = name.replace(".lora_B.bias", ".diff_b")
        renamed = renamed.replace("_norm.scale", "_norm.scale.set_weight")
        converted["diffusion_model.{}".format(renamed)] = sd[name]

    b_dim0 = sd["img_in.lora_B.weight"].shape[0]
    a_dim1 = sd["img_in.lora_A.weight"].shape[1]
    converted["diffusion_model.img_in.reshape_weight"] = torch.tensor([b_dim0, a_dim1])
    return converted
|
||||
|
||||
|
||||
def convert_lora(sd):
    """Convert `sd` with convert_lora_bfl_control when it has the marker keys.

    The conversion runs when `sd` contains both "img_in.lora_A.weight" and
    "single_blocks.0.norm.key_norm.scale". Any other state dict is returned
    unchanged (the same object).
    """
    has_marker_keys = "img_in.lora_A.weight" in sd and "single_blocks.0.norm.key_norm.scale" in sd
    if not has_marker_keys:
        return sd
    return convert_lora_bfl_control(sd)
|
||||
@@ -24,19 +24,25 @@ from comfy.ldm.cascade.stage_b import StageB
|
||||
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
|
||||
from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
|
||||
import comfy.ldm.genmo.joint_model.asymm_models_joint
|
||||
import comfy.ldm.aura.mmdit
|
||||
import comfy.ldm.hydit.models
|
||||
import comfy.ldm.audio.dit
|
||||
import comfy.ldm.audio.embedders
|
||||
import comfy.ldm.flux.model
|
||||
import comfy.ldm.lightricks.model
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
import comfy.conds
|
||||
import comfy.ops
|
||||
from enum import Enum
|
||||
from . import utils
|
||||
import comfy.latent_formats
|
||||
import math
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
|
||||
class ModelType(Enum):
|
||||
EPS = 1
|
||||
@@ -93,10 +99,12 @@ class BaseModel(torch.nn.Module):
|
||||
self.model_config = model_config
|
||||
self.manual_cast_dtype = model_config.manual_cast_dtype
|
||||
self.device = device
|
||||
self.current_patcher: 'ModelPatcher' = None
|
||||
|
||||
if not unet_config.get("disable_unet_model_creation", False):
|
||||
if model_config.custom_operations is None:
|
||||
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype)
|
||||
fp8 = model_config.optimizations.get("fp8", model_config.scaled_fp8 is not None)
|
||||
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
|
||||
else:
|
||||
operations = model_config.custom_operations
|
||||
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
|
||||
@@ -117,6 +125,13 @@ class BaseModel(torch.nn.Module):
|
||||
self.memory_usage_factor = model_config.memory_usage_factor
|
||||
|
||||
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._apply_model,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.APPLY_MODEL, transformer_options)
|
||||
).execute(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)
|
||||
|
||||
def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
|
||||
sigma = t
|
||||
xc = self.model_sampling.calculate_input(sigma, x)
|
||||
if c_concat is not None:
|
||||
@@ -151,8 +166,7 @@ class BaseModel(torch.nn.Module):
|
||||
def encode_adm(self, **kwargs):
|
||||
return None
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
def concat_cond(self, **kwargs):
|
||||
if len(self.concat_keys) > 0:
|
||||
cond_concat = []
|
||||
denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
|
||||
@@ -191,7 +205,14 @@ class BaseModel(torch.nn.Module):
|
||||
elif ck == "masked_image":
|
||||
cond_concat.append(self.blank_inpaint_image_like(noise))
|
||||
data = torch.cat(cond_concat, dim=1)
|
||||
out['c_concat'] = comfy.conds.CONDNoiseShape(data)
|
||||
return data
|
||||
return None
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
concat_cond = self.concat_cond(**kwargs)
|
||||
if concat_cond is not None:
|
||||
out['c_concat'] = comfy.conds.CONDNoiseShape(concat_cond)
|
||||
|
||||
adm = self.encode_adm(**kwargs)
|
||||
if adm is not None:
|
||||
@@ -244,6 +265,10 @@ class BaseModel(torch.nn.Module):
|
||||
extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))
|
||||
|
||||
unet_state_dict = self.diffusion_model.state_dict()
|
||||
|
||||
if self.model_config.scaled_fp8 is not None:
|
||||
unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8)
|
||||
|
||||
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
|
||||
|
||||
if self.model_type == ModelType.V_PREDICTION:
|
||||
@@ -517,9 +542,7 @@ class SD_X4Upscaler(BaseModel):
|
||||
return out
|
||||
|
||||
class IP2P:
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
|
||||
def concat_cond(self, **kwargs):
|
||||
image = kwargs.get("concat_latent_image", None)
|
||||
noise = kwargs.get("noise", None)
|
||||
device = kwargs["device"]
|
||||
@@ -531,18 +554,15 @@ class IP2P:
|
||||
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
|
||||
image = utils.resize_to_batch_size(image, noise.shape[0])
|
||||
return self.process_ip2p_image_in(image)
|
||||
|
||||
out['c_concat'] = comfy.conds.CONDNoiseShape(self.process_ip2p_image_in(image))
|
||||
adm = self.encode_adm(**kwargs)
|
||||
if adm is not None:
|
||||
out['y'] = comfy.conds.CONDRegular(adm)
|
||||
return out
|
||||
|
||||
class SD15_instructpix2pix(IP2P, BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
|
||||
super().__init__(model_config, model_type, device=device)
|
||||
self.process_ip2p_image_in = lambda image: image
|
||||
|
||||
|
||||
class SDXL_instructpix2pix(IP2P, SDXL):
|
||||
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
|
||||
super().__init__(model_config, model_type, device=device)
|
||||
@@ -703,6 +723,44 @@ class Flux(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.flux.model.Flux)
|
||||
|
||||
def concat_cond(self, **kwargs):
    """Build the extra channels concatenated to the model input.

    Reads "concat_latent_image", "noise", "device" and, for the inpaint case,
    "concat_mask" / "denoise_mask" from kwargs.

    Returns:
        None when the input layer takes no more channels than the model
        outputs; the processed latent image when it takes at most twice as
        many; otherwise the image concatenated with the folded mask along
        dim 1.
    """
    try:
        #Handle Flux control loras dynamically changing the img_in weight.
        num_channels = self.diffusion_model.img_in.weight.shape[1] // (self.diffusion_model.patch_size * self.diffusion_model.patch_size)
    except Exception:
        #Some cases like tensorrt might not have the weights accessible
        num_channels = self.model_config.unet_config["in_channels"]

    out_channels = self.model_config.unet_config["out_channels"]

    if num_channels <= out_channels:
        return None

    image = kwargs.get("concat_latent_image", None)
    noise = kwargs.get("noise", None)
    device = kwargs["device"]

    if image is None:
        image = torch.zeros_like(noise)

    image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
    image = utils.resize_to_batch_size(image, noise.shape[0])
    image = self.process_latent_in(image)
    if num_channels <= out_channels * 2:
        return image

    #inpaint model
    mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
    if mask is None:
        mask = torch.ones_like(noise)[:, :1]

    # Collapse the mask to a single channel, resize it to 8x the noise
    # resolution, then fold every 8x8 patch into the channel axis so the
    # result has 64 channels at the noise resolution.
    mask = torch.mean(mask, dim=1, keepdim=True)
    mask = utils.common_upscale(mask.to(device), noise.shape[-1] * 8, noise.shape[-2] * 8, "bilinear", "center")
    mask = mask.view(mask.shape[0], mask.shape[2] // 8, 8, mask.shape[3] // 8, 8).permute(0, 2, 4, 1, 3).reshape(mask.shape[0], -1, mask.shape[2] // 8, mask.shape[3] // 8)
    mask = utils.resize_to_batch_size(mask, noise.shape[0])
    return torch.cat((image, mask), dim=1)
|
||||
|
||||
def encode_adm(self, **kwargs):
|
||||
return kwargs["pooled_output"]
|
||||
|
||||
@@ -713,3 +771,38 @@ class Flux(BaseModel):
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 3.5)]))
|
||||
return out
|
||||
|
||||
class GenmoMochi(BaseModel):
    """BaseModel subclass that builds its diffusion model with AsymmDiTJoint."""

    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.genmo.joint_model.asymm_models_joint.AsymmDiTJoint)

    def extra_conds(self, **kwargs):
        """Extend the base conds with the attention mask, token count and cross attention.

        'attention_mask' and 'num_tokens' are added only when an
        "attention_mask" kwarg is given; 'num_tokens' is the sum of the mask,
        never less than 1. 'c_crossattn' is added only when "cross_attn" is
        given.
        """
        conds = super().extra_conds(**kwargs)

        mask = kwargs.get("attention_mask", None)
        if mask is not None:
            token_count = max(1, torch.sum(mask).item())
            conds['attention_mask'] = comfy.conds.CONDRegular(mask)
            conds['num_tokens'] = comfy.conds.CONDConstant(token_count)

        text_cond = kwargs.get("cross_attn", None)
        if text_cond is not None:
            conds['c_crossattn'] = comfy.conds.CONDRegular(text_cond)

        return conds
|
||||
|
||||
class LTXV(BaseModel):
    """BaseModel subclass that builds its diffusion model with LTXVModel."""

    def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.model.LTXVModel) #TODO

    def extra_conds(self, **kwargs):
        """Extend the base conds with the optional tensors and the frame rate.

        "attention_mask", "cross_attn" and "guiding_latent" are each added
        (as CONDRegular) only when the kwarg is present and not None.
        'frame_rate' is always added and defaults to 25.
        """
        conds = super().extra_conds(**kwargs)

        # (kwarg name, cond key), in the order the entries are inserted
        optional_conds = (
            ("attention_mask", 'attention_mask'),
            ("cross_attn", 'c_crossattn'),
            ("guiding_latent", 'guiding_latent'),
        )
        for kwarg_name, cond_key in optional_conds:
            value = kwargs.get(kwarg_name, None)
            if value is not None:
                conds[cond_key] = comfy.conds.CONDRegular(value)

        conds['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))
        return conds
|
||||
|
||||
@@ -70,6 +70,11 @@ def detect_unet_config(state_dict, key_prefix):
|
||||
context_processor = '{}context_processor.layers.0.attn.qkv.weight'.format(key_prefix)
|
||||
if context_processor in state_dict_keys:
|
||||
unet_config["context_processor_layers"] = count_blocks(state_dict_keys, '{}context_processor.layers.'.format(key_prefix) + '{}.')
|
||||
unet_config["x_block_self_attn_layers"] = []
|
||||
for key in state_dict_keys:
|
||||
if key.startswith('{}joint_blocks.'.format(key_prefix)) and key.endswith('.x_block.attn2.qkv.weight'):
|
||||
layer = key[len('{}joint_blocks.'.format(key_prefix)):-len('.x_block.attn2.qkv.weight')]
|
||||
unet_config["x_block_self_attn_layers"].append(int(layer))
|
||||
return unet_config
|
||||
|
||||
if '{}clf.1.weight'.format(key_prefix) in state_dict_keys: #stable cascade
|
||||
@@ -132,6 +137,12 @@ def detect_unet_config(state_dict, key_prefix):
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "flux"
|
||||
dit_config["in_channels"] = 16
|
||||
patch_size = 2
|
||||
dit_config["patch_size"] = patch_size
|
||||
in_key = "{}img_in.weight".format(key_prefix)
|
||||
if in_key in state_dict_keys:
|
||||
dit_config["in_channels"] = state_dict[in_key].shape[1] // (patch_size * patch_size)
|
||||
dit_config["out_channels"] = 16
|
||||
dit_config["vec_in_dim"] = 768
|
||||
dit_config["context_in_dim"] = 4096
|
||||
dit_config["hidden_size"] = 3072
|
||||
@@ -145,6 +156,38 @@ def detect_unet_config(state_dict, key_prefix):
|
||||
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
|
||||
return dit_config
|
||||
|
||||
if '{}t5_yproj.weight'.format(key_prefix) in state_dict_keys: #Genmo mochi preview
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "mochi_preview"
|
||||
dit_config["depth"] = 48
|
||||
dit_config["patch_size"] = 2
|
||||
dit_config["num_heads"] = 24
|
||||
dit_config["hidden_size_x"] = 3072
|
||||
dit_config["hidden_size_y"] = 1536
|
||||
dit_config["mlp_ratio_x"] = 4.0
|
||||
dit_config["mlp_ratio_y"] = 4.0
|
||||
dit_config["learn_sigma"] = False
|
||||
dit_config["in_channels"] = 12
|
||||
dit_config["qk_norm"] = True
|
||||
dit_config["qkv_bias"] = False
|
||||
dit_config["out_bias"] = True
|
||||
dit_config["attn_drop"] = 0.0
|
||||
dit_config["patch_embed_bias"] = True
|
||||
dit_config["posenc_preserve_area"] = True
|
||||
dit_config["timestep_mlp_bias"] = True
|
||||
dit_config["attend_to_padding"] = False
|
||||
dit_config["timestep_scale"] = 1000.0
|
||||
dit_config["use_t5"] = True
|
||||
dit_config["t5_feat_dim"] = 4096
|
||||
dit_config["t5_token_length"] = 256
|
||||
dit_config["rope_theta"] = 10000.0
|
||||
return dit_config
|
||||
|
||||
if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "ltxv"
|
||||
return dit_config
|
||||
|
||||
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
|
||||
return None
|
||||
|
||||
@@ -286,9 +329,16 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
|
||||
return None
|
||||
model_config = model_config_from_unet_config(unet_config, state_dict)
|
||||
if model_config is None and use_base_if_no_match:
|
||||
return comfy.supported_models_base.BASE(unet_config)
|
||||
else:
|
||||
return model_config
|
||||
model_config = comfy.supported_models_base.BASE(unet_config)
|
||||
|
||||
scaled_fp8_key = "{}scaled_fp8".format(unet_key_prefix)
|
||||
if scaled_fp8_key in state_dict:
|
||||
scaled_fp8_weight = state_dict.pop(scaled_fp8_key)
|
||||
model_config.scaled_fp8 = scaled_fp8_weight.dtype
|
||||
if model_config.scaled_fp8 == torch.float32:
|
||||
model_config.scaled_fp8 = torch.float8_e4m3fn
|
||||
|
||||
return model_config
|
||||
|
||||
def unet_prefix_from_state_dict(state_dict):
|
||||
candidates = ["model.diffusion_model.", #ldm/sgm models
|
||||
@@ -501,7 +551,11 @@ def model_config_from_diffusers_unet(state_dict):
|
||||
def convert_diffusers_mmdit(state_dict, output_prefix=""):
|
||||
out_sd = {}
|
||||
|
||||
if 'transformer_blocks.0.attn.norm_added_k.weight' in state_dict: #Flux
|
||||
if 'joint_transformer_blocks.0.attn.add_k_proj.weight' in state_dict: #AuraFlow
|
||||
num_joint = count_blocks(state_dict, 'joint_transformer_blocks.{}.')
|
||||
num_single = count_blocks(state_dict, 'single_transformer_blocks.{}.')
|
||||
sd_map = comfy.utils.auraflow_to_diffusers({"n_double_layers": num_joint, "n_layers": num_joint + num_single}, output_prefix=output_prefix)
|
||||
elif 'x_embedder.weight' in state_dict: #Flux
|
||||
depth = count_blocks(state_dict, 'transformer_blocks.{}.')
|
||||
depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.')
|
||||
hidden_size = state_dict["x_embedder.bias"].shape[0]
|
||||
@@ -510,10 +564,6 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
|
||||
num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
|
||||
depth = state_dict["pos_embed.proj.weight"].shape[0] // 64
|
||||
sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix)
|
||||
elif 'joint_transformer_blocks.0.attn.add_k_proj.weight' in state_dict: #AuraFlow
|
||||
num_joint = count_blocks(state_dict, 'joint_transformer_blocks.{}.')
|
||||
num_single = count_blocks(state_dict, 'single_transformer_blocks.{}.')
|
||||
sd_map = comfy.utils.auraflow_to_diffusers({"n_double_layers": num_joint, "n_layers": num_joint + num_single}, output_prefix=output_prefix)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
@@ -23,6 +23,8 @@ from comfy.cli_args import args
|
||||
import torch
|
||||
import sys
|
||||
import platform
|
||||
import weakref
|
||||
import gc
|
||||
|
||||
class VRAMState(Enum):
|
||||
DISABLED = 0 #No vram present: no need to move models to vram
|
||||
@@ -45,6 +47,7 @@ cpu_state = CPUState.GPU
|
||||
total_vram = 0
|
||||
|
||||
xpu_available = False
|
||||
torch_version = ""
|
||||
try:
|
||||
torch_version = torch.version.__version__
|
||||
xpu_available = (int(torch_version[0]) < 2 or (int(torch_version[0]) == 2 and int(torch_version[2]) <= 4)) and torch.xpu.is_available()
|
||||
@@ -144,7 +147,7 @@ total_ram = psutil.virtual_memory().total / (1024 * 1024)
|
||||
logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
|
||||
|
||||
try:
|
||||
logging.info("pytorch version: {}".format(torch.version.__version__))
|
||||
logging.info("pytorch version: {}".format(torch_version))
|
||||
except:
|
||||
pass
|
||||
|
||||
@@ -286,11 +289,27 @@ def module_size(module):
|
||||
|
||||
class LoadedModel:
|
||||
def __init__(self, model):
|
||||
self.model = model
|
||||
self._set_model(model)
|
||||
self.device = model.load_device
|
||||
self.weights_loaded = False
|
||||
self.real_model = None
|
||||
self.currently_used = True
|
||||
self.model_finalizer = None
|
||||
self._patcher_finalizer = None
|
||||
|
||||
def _set_model(self, model):
|
||||
self._model = weakref.ref(model)
|
||||
if model.parent is not None:
|
||||
self._parent_model = weakref.ref(model.parent)
|
||||
self._patcher_finalizer = weakref.finalize(model, self._switch_parent)
|
||||
|
||||
def _switch_parent(self):
|
||||
model = self._parent_model()
|
||||
if model is not None:
|
||||
self._set_model(model)
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
return self._model()
|
||||
|
||||
def model_memory(self):
|
||||
return self.model.model_size()
|
||||
@@ -305,32 +324,23 @@ class LoadedModel:
|
||||
return self.model_memory()
|
||||
|
||||
def model_load(self, lowvram_model_memory=0, force_patch_weights=False):
|
||||
patch_model_to = self.device
|
||||
|
||||
self.model.model_patches_to(self.device)
|
||||
self.model.model_patches_to(self.model.model_dtype())
|
||||
|
||||
load_weights = not self.weights_loaded
|
||||
# if self.model.loaded_size() > 0:
|
||||
use_more_vram = lowvram_model_memory
|
||||
if use_more_vram == 0:
|
||||
use_more_vram = 1e32
|
||||
self.model_use_more_vram(use_more_vram, force_patch_weights=force_patch_weights)
|
||||
real_model = self.model.model
|
||||
|
||||
if self.model.loaded_size() > 0:
|
||||
use_more_vram = lowvram_model_memory
|
||||
if use_more_vram == 0:
|
||||
use_more_vram = 1e32
|
||||
self.model_use_more_vram(use_more_vram)
|
||||
else:
|
||||
try:
|
||||
self.real_model = self.model.patch_model(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, load_weights=load_weights, force_patch_weights=force_patch_weights)
|
||||
except Exception as e:
|
||||
self.model.unpatch_model(self.model.offload_device)
|
||||
self.model_unload()
|
||||
raise e
|
||||
|
||||
if is_intel_xpu() and not args.disable_ipex_optimize and self.real_model is not None:
|
||||
if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and real_model is not None:
|
||||
with torch.no_grad():
|
||||
self.real_model = ipex.optimize(self.real_model.eval(), inplace=True, graph_mode=True, concat_linear=True)
|
||||
real_model = ipex.optimize(real_model.eval(), inplace=True, graph_mode=True, concat_linear=True)
|
||||
|
||||
self.weights_loaded = True
|
||||
return self.real_model
|
||||
self.real_model = weakref.ref(real_model)
|
||||
self.model_finalizer = weakref.finalize(real_model, cleanup_models)
|
||||
return real_model
|
||||
|
||||
def should_reload_model(self, force_patch_weights=False):
|
||||
if force_patch_weights and self.model.lowvram_patch_counter() > 0:
|
||||
@@ -343,18 +353,26 @@ class LoadedModel:
|
||||
freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
|
||||
if freed >= memory_to_free:
|
||||
return False
|
||||
self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
|
||||
self.model.model_patches_to(self.model.offload_device)
|
||||
self.weights_loaded = self.weights_loaded and not unpatch_weights
|
||||
self.model.detach(unpatch_weights)
|
||||
self.model_finalizer.detach()
|
||||
self.model_finalizer = None
|
||||
self.real_model = None
|
||||
return True
|
||||
|
||||
def model_use_more_vram(self, extra_memory):
|
||||
return self.model.partially_load(self.device, extra_memory)
|
||||
def model_use_more_vram(self, extra_memory, force_patch_weights=False):
|
||||
return self.model.partially_load(self.device, extra_memory, force_patch_weights=force_patch_weights)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.model is other.model
|
||||
|
||||
def __del__(self):
|
||||
if self._patcher_finalizer is not None:
|
||||
self._patcher_finalizer.detach()
|
||||
|
||||
def is_dead(self):
|
||||
return self.real_model() is not None and self.model is None
|
||||
|
||||
|
||||
def use_more_memory(extra_memory, loaded_models, device):
|
||||
for m in loaded_models:
|
||||
if m.device == device:
|
||||
@@ -369,12 +387,11 @@ def offloaded_memory(loaded_models, device):
|
||||
offloaded_mem += m.model_offloaded_memory()
|
||||
return offloaded_mem
|
||||
|
||||
def minimum_inference_memory():
|
||||
return (1024 * 1024 * 1024) * 1.2
|
||||
WINDOWS = any(platform.win32_ver())
|
||||
|
||||
EXTRA_RESERVED_VRAM = 200 * 1024 * 1024
|
||||
if any(platform.win32_ver()):
|
||||
EXTRA_RESERVED_VRAM = 500 * 1024 * 1024 #Windows is higher because of the shared vram issue
|
||||
EXTRA_RESERVED_VRAM = 400 * 1024 * 1024
|
||||
if WINDOWS:
|
||||
EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue
|
||||
|
||||
if args.reserve_vram is not None:
|
||||
EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024
|
||||
@@ -383,36 +400,11 @@ if args.reserve_vram is not None:
|
||||
def extra_reserved_memory():
|
||||
return EXTRA_RESERVED_VRAM
|
||||
|
||||
def unload_model_clones(model, unload_weights_only=True, force_unload=True):
|
||||
to_unload = []
|
||||
for i in range(len(current_loaded_models)):
|
||||
if model.is_clone(current_loaded_models[i].model):
|
||||
to_unload = [i] + to_unload
|
||||
|
||||
if len(to_unload) == 0:
|
||||
return True
|
||||
|
||||
same_weights = 0
|
||||
for i in to_unload:
|
||||
if model.clone_has_same_weights(current_loaded_models[i].model):
|
||||
same_weights += 1
|
||||
|
||||
if same_weights == len(to_unload):
|
||||
unload_weight = False
|
||||
else:
|
||||
unload_weight = True
|
||||
|
||||
if not force_unload:
|
||||
if unload_weights_only and unload_weight == False:
|
||||
return None
|
||||
|
||||
for i in to_unload:
|
||||
logging.debug("unload clone {} {}".format(i, unload_weight))
|
||||
current_loaded_models.pop(i).model_unload(unpatch_weights=unload_weight)
|
||||
|
||||
return unload_weight
|
||||
def minimum_inference_memory():
|
||||
return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
|
||||
|
||||
def free_memory(memory_required, device, keep_loaded=[]):
|
||||
cleanup_models_gc()
|
||||
unloaded_model = []
|
||||
can_unload = []
|
||||
unloaded_models = []
|
||||
@@ -420,8 +412,8 @@ def free_memory(memory_required, device, keep_loaded=[]):
|
||||
for i in range(len(current_loaded_models) -1, -1, -1):
|
||||
shift_model = current_loaded_models[i]
|
||||
if shift_model.device == device:
|
||||
if shift_model not in keep_loaded:
|
||||
can_unload.append((sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
|
||||
if shift_model not in keep_loaded and not shift_model.is_dead():
|
||||
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
|
||||
shift_model.currently_used = False
|
||||
|
||||
for x in sorted(can_unload):
|
||||
@@ -449,6 +441,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
|
||||
return unloaded_models
|
||||
|
||||
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
|
||||
cleanup_models_gc()
|
||||
global vram_state
|
||||
|
||||
inference_memory = minimum_inference_memory()
|
||||
@@ -461,11 +454,9 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
||||
models = set(models)
|
||||
|
||||
models_to_load = []
|
||||
models_already_loaded = []
|
||||
|
||||
for x in models:
|
||||
loaded_model = LoadedModel(x)
|
||||
loaded = None
|
||||
|
||||
try:
|
||||
loaded_model_index = current_loaded_models.index(loaded_model)
|
||||
except:
|
||||
@@ -473,51 +464,35 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
||||
|
||||
if loaded_model_index is not None:
|
||||
loaded = current_loaded_models[loaded_model_index]
|
||||
if loaded.should_reload_model(force_patch_weights=force_patch_weights): #TODO: cleanup this model reload logic
|
||||
current_loaded_models.pop(loaded_model_index).model_unload(unpatch_weights=True)
|
||||
loaded = None
|
||||
else:
|
||||
loaded.currently_used = True
|
||||
models_already_loaded.append(loaded)
|
||||
|
||||
if loaded is None:
|
||||
loaded.currently_used = True
|
||||
models_to_load.append(loaded)
|
||||
else:
|
||||
if hasattr(x, "model"):
|
||||
logging.info(f"Requested to load {x.model.__class__.__name__}")
|
||||
models_to_load.append(loaded_model)
|
||||
|
||||
if len(models_to_load) == 0:
|
||||
devs = set(map(lambda a: a.device, models_already_loaded))
|
||||
for d in devs:
|
||||
if d != torch.device("cpu"):
|
||||
free_memory(extra_mem + offloaded_memory(models_already_loaded, d), d, models_already_loaded)
|
||||
free_mem = get_free_memory(d)
|
||||
if free_mem < minimum_memory_required:
|
||||
logging.info("Unloading models for lowram load.") #TODO: partial model unloading when this case happens, also handle the opposite case where models can be unlowvramed.
|
||||
models_to_load = free_memory(minimum_memory_required, d)
|
||||
logging.info("{} models unloaded.".format(len(models_to_load)))
|
||||
else:
|
||||
use_more_memory(free_mem - minimum_memory_required, models_already_loaded, d)
|
||||
if len(models_to_load) == 0:
|
||||
return
|
||||
|
||||
logging.info(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")
|
||||
for loaded_model in models_to_load:
|
||||
to_unload = []
|
||||
for i in range(len(current_loaded_models)):
|
||||
if loaded_model.model.is_clone(current_loaded_models[i].model):
|
||||
to_unload = [i] + to_unload
|
||||
for i in to_unload:
|
||||
current_loaded_models.pop(i).model.detach(unpatch_all=False)
|
||||
|
||||
total_memory_required = {}
|
||||
for loaded_model in models_to_load:
|
||||
unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) #unload clones where the weights are different
|
||||
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
|
||||
|
||||
for loaded_model in models_already_loaded:
|
||||
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
|
||||
|
||||
for loaded_model in models_to_load:
|
||||
weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) #unload the rest of the clones where the weights can stay loaded
|
||||
if weights_unloaded is not None:
|
||||
loaded_model.weights_loaded = not weights_unloaded
|
||||
|
||||
for device in total_memory_required:
|
||||
if device != torch.device("cpu"):
|
||||
free_memory(total_memory_required[device] * 1.1 + extra_mem, device, models_already_loaded)
|
||||
free_memory(total_memory_required[device] * 1.1 + extra_mem, device)
|
||||
|
||||
for device in total_memory_required:
|
||||
if device != torch.device("cpu"):
|
||||
free_mem = get_free_memory(device)
|
||||
if free_mem < minimum_memory_required:
|
||||
models_l = free_memory(minimum_memory_required, device)
|
||||
logging.info("{} models unloaded.".format(len(models_l)))
|
||||
|
||||
for loaded_model in models_to_load:
|
||||
model = loaded_model.model
|
||||
@@ -539,17 +514,8 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
||||
|
||||
cur_loaded_model = loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
|
||||
current_loaded_models.insert(0, loaded_model)
|
||||
|
||||
|
||||
devs = set(map(lambda a: a.device, models_already_loaded))
|
||||
for d in devs:
|
||||
if d != torch.device("cpu"):
|
||||
free_mem = get_free_memory(d)
|
||||
if free_mem > minimum_memory_required:
|
||||
use_more_memory(free_mem - minimum_memory_required, models_already_loaded, d)
|
||||
return
|
||||
|
||||
|
||||
def load_model_gpu(model):
|
||||
return load_models_gpu([model])
|
||||
|
||||
@@ -563,21 +529,35 @@ def loaded_models(only_currently_used=False):
|
||||
output.append(m.model)
|
||||
return output
|
||||
|
||||
def cleanup_models(keep_clone_weights_loaded=False):
|
||||
|
||||
def cleanup_models_gc():
|
||||
do_gc = False
|
||||
for i in range(len(current_loaded_models)):
|
||||
cur = current_loaded_models[i]
|
||||
if cur.is_dead():
|
||||
logging.info("Potential memory leak detected with model {}, doing a full garbage collect, for maximum performance avoid circular references in the model code.".format(cur.real_model().__class__.__name__))
|
||||
do_gc = True
|
||||
break
|
||||
|
||||
if do_gc:
|
||||
gc.collect()
|
||||
soft_empty_cache()
|
||||
|
||||
for i in range(len(current_loaded_models)):
|
||||
cur = current_loaded_models[i]
|
||||
if cur.is_dead():
|
||||
logging.warning("WARNING, memory leak with model {}. Please make sure it is not being referenced from somewhere.".format(cur.real_model().__class__.__name__))
|
||||
|
||||
|
||||
|
||||
def cleanup_models():
|
||||
to_delete = []
|
||||
for i in range(len(current_loaded_models)):
|
||||
#TODO: very fragile function needs improvement
|
||||
num_refs = sys.getrefcount(current_loaded_models[i].model)
|
||||
if num_refs <= 2:
|
||||
if not keep_clone_weights_loaded:
|
||||
to_delete = [i] + to_delete
|
||||
#TODO: find a less fragile way to do this.
|
||||
elif sys.getrefcount(current_loaded_models[i].real_model) <= 3: #references from .real_model + the .model
|
||||
to_delete = [i] + to_delete
|
||||
if current_loaded_models[i].real_model() is None:
|
||||
to_delete = [i] + to_delete
|
||||
|
||||
for i in to_delete:
|
||||
x = current_loaded_models.pop(i)
|
||||
x.model_unload()
|
||||
del x
|
||||
|
||||
def dtype_size(dtype):
|
||||
@@ -621,6 +601,12 @@ def maximum_vram_for_weights(device=None):
|
||||
return (get_total_memory(device) * 0.88 - minimum_inference_memory())
|
||||
|
||||
def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
|
||||
if model_params < 0:
|
||||
model_params = 1000000000000000000000
|
||||
if args.fp32_unet:
|
||||
return torch.float32
|
||||
if args.fp64_unet:
|
||||
return torch.float64
|
||||
if args.bf16_unet:
|
||||
return torch.bfloat16
|
||||
if args.fp16_unet:
|
||||
@@ -640,6 +626,9 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
|
||||
pass
|
||||
|
||||
if fp8_dtype is not None:
|
||||
if supports_fp8_compute(device): #if fp8 compute is supported the casting is most likely not expensive
|
||||
return fp8_dtype
|
||||
|
||||
free_model_memory = maximum_vram_for_weights(device)
|
||||
if model_params * 2 > free_model_memory:
|
||||
return fp8_dtype
|
||||
@@ -664,7 +653,7 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
|
||||
|
||||
# None means no manual cast
|
||||
def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
|
||||
if weight_dtype == torch.float32:
|
||||
if weight_dtype == torch.float32 or weight_dtype == torch.float64:
|
||||
return None
|
||||
|
||||
fp16_supported = should_use_fp16(inference_device, prioritize_performance=False)
|
||||
@@ -833,27 +822,21 @@ def force_channels_last():
|
||||
#TODO
|
||||
return False
|
||||
|
||||
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
|
||||
if device is None or weight.device == device:
|
||||
if not copy:
|
||||
if dtype is None or weight.dtype == dtype:
|
||||
return weight
|
||||
return weight.to(dtype=dtype, copy=copy)
|
||||
|
||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||
r.copy_(weight, non_blocking=non_blocking)
|
||||
return r
|
||||
|
||||
def cast_to_device(tensor, device, dtype, copy=False):
|
||||
device_supports_cast = False
|
||||
if tensor.dtype == torch.float32 or tensor.dtype == torch.float16:
|
||||
device_supports_cast = True
|
||||
elif tensor.dtype == torch.bfloat16:
|
||||
if hasattr(device, 'type') and device.type.startswith("cuda"):
|
||||
device_supports_cast = True
|
||||
elif is_intel_xpu():
|
||||
device_supports_cast = True
|
||||
non_blocking = device_supports_non_blocking(device)
|
||||
return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy)
|
||||
|
||||
non_blocking = device_should_use_non_blocking(device)
|
||||
|
||||
if device_supports_cast:
|
||||
if copy:
|
||||
if tensor.device == device:
|
||||
return tensor.to(dtype, copy=copy, non_blocking=non_blocking)
|
||||
return tensor.to(device, copy=copy, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
|
||||
else:
|
||||
return tensor.to(device, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
|
||||
else:
|
||||
return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking)
|
||||
|
||||
def xformers_enabled():
|
||||
global directml_enabled
|
||||
@@ -892,7 +875,7 @@ def force_upcast_attention_dtype():
|
||||
upcast = args.force_upcast_attention
|
||||
try:
|
||||
macos_version = tuple(int(n) for n in platform.mac_ver()[0].split("."))
|
||||
if (14, 5) <= macos_version < (14, 7): # black image bug on recent versions of MacOS
|
||||
if (14, 5) <= macos_version <= (15, 2): # black image bug on recent versions of macOS
|
||||
upcast = True
|
||||
except:
|
||||
pass
|
||||
@@ -999,7 +982,10 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050", "p40", "p100", "p6", "p4"]
|
||||
for x in nvidia_10_series:
|
||||
if x in props.name.lower():
|
||||
return True
|
||||
if WINDOWS or manual_cast:
|
||||
return True
|
||||
else:
|
||||
return False #weird linux behavior where fp32 is faster
|
||||
|
||||
if manual_cast:
|
||||
free_model_memory = maximum_vram_for_weights(device)
|
||||
@@ -1055,6 +1041,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
return False
|
||||
|
||||
def supports_fp8_compute(device=None):
|
||||
if not is_nvidia():
|
||||
return False
|
||||
|
||||
props = torch.cuda.get_device_properties(device)
|
||||
if props.major >= 9:
|
||||
return True
|
||||
@@ -1062,6 +1051,14 @@ def supports_fp8_compute(device=None):
|
||||
return False
|
||||
if props.minor < 9:
|
||||
return False
|
||||
|
||||
if int(torch_version[0]) < 2 or (int(torch_version[0]) == 2 and int(torch_version[2]) < 3):
|
||||
return False
|
||||
|
||||
if WINDOWS:
|
||||
if (int(torch_version[0]) == 2 and int(torch_version[2]) < 4):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def soft_empty_cache(force=False):
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -2,6 +2,25 @@ import torch
|
||||
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
|
||||
import math
|
||||
|
||||
def rescale_zero_terminal_snr_sigmas(sigmas):
|
||||
alphas_cumprod = 1 / ((sigmas * sigmas) + 1)
|
||||
alphas_bar_sqrt = alphas_cumprod.sqrt()
|
||||
|
||||
# Store old values.
|
||||
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
||||
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
||||
|
||||
# Shift so the last timestep is zero.
|
||||
alphas_bar_sqrt -= (alphas_bar_sqrt_T)
|
||||
|
||||
# Scale so the first timestep is back to the old value.
|
||||
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
||||
|
||||
# Convert alphas_bar_sqrt to betas
|
||||
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
|
||||
alphas_bar[-1] = 4.8973451890853435e-08
|
||||
return ((1 - alphas_bar) / alphas_bar) ** 0.5
|
||||
|
||||
class EPS:
|
||||
def calculate_input(self, sigma, noise):
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
|
||||
@@ -48,7 +67,7 @@ class CONST:
|
||||
return latent / (1.0 - sigma)
|
||||
|
||||
class ModelSamplingDiscrete(torch.nn.Module):
|
||||
def __init__(self, model_config=None):
|
||||
def __init__(self, model_config=None, zsnr=None):
|
||||
super().__init__()
|
||||
|
||||
if model_config is not None:
|
||||
@@ -61,11 +80,14 @@ class ModelSamplingDiscrete(torch.nn.Module):
|
||||
linear_end = sampling_settings.get("linear_end", 0.012)
|
||||
timesteps = sampling_settings.get("timesteps", 1000)
|
||||
|
||||
self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=8e-3)
|
||||
if zsnr is None:
|
||||
zsnr = sampling_settings.get("zsnr", False)
|
||||
|
||||
self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=8e-3, zsnr=zsnr)
|
||||
self.sigma_data = 1.0
|
||||
|
||||
def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
|
||||
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
||||
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3, zsnr=False):
|
||||
if given_betas is not None:
|
||||
betas = given_betas
|
||||
else:
|
||||
@@ -83,6 +105,9 @@ class ModelSamplingDiscrete(torch.nn.Module):
|
||||
# self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
|
||||
|
||||
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
|
||||
if zsnr:
|
||||
sigmas = rescale_zero_terminal_snr_sigmas(sigmas)
|
||||
|
||||
self.set_sigmas(sigmas)
|
||||
|
||||
def set_sigmas(self, sigmas):
|
||||
|
||||
103
comfy/ops.py
103
comfy/ops.py
@@ -19,20 +19,12 @@
|
||||
import torch
|
||||
import comfy.model_management
|
||||
from comfy.cli_args import args
|
||||
import comfy.float
|
||||
|
||||
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
|
||||
if device is None or weight.device == device:
|
||||
if not copy:
|
||||
if dtype is None or weight.dtype == dtype:
|
||||
return weight
|
||||
return weight.to(dtype=dtype, copy=copy)
|
||||
|
||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||
r.copy_(weight, non_blocking=non_blocking)
|
||||
return r
|
||||
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
|
||||
|
||||
def cast_to_input(weight, input, non_blocking=False, copy=True):
|
||||
return cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
|
||||
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
|
||||
|
||||
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
|
||||
if input is not None:
|
||||
@@ -47,12 +39,12 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
|
||||
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
||||
if s.bias is not None:
|
||||
has_function = s.bias_function is not None
|
||||
bias = cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function)
|
||||
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function)
|
||||
if has_function:
|
||||
bias = s.bias_function(bias)
|
||||
|
||||
has_function = s.weight_function is not None
|
||||
weight = cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function)
|
||||
weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function)
|
||||
if has_function:
|
||||
weight = s.weight_function(weight)
|
||||
return weight, bias
|
||||
@@ -258,20 +250,29 @@ def fp8_linear(self, input):
|
||||
if dtype not in [torch.float8_e4m3fn]:
|
||||
return None
|
||||
|
||||
tensor_2d = False
|
||||
if len(input.shape) == 2:
|
||||
tensor_2d = True
|
||||
input = input.unsqueeze(1)
|
||||
|
||||
|
||||
if len(input.shape) == 3:
|
||||
inn = input.reshape(-1, input.shape[2]).to(dtype)
|
||||
non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
|
||||
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype)
|
||||
w = w.t()
|
||||
|
||||
scale_weight = self.scale_weight
|
||||
scale_input = self.scale_input
|
||||
if scale_weight is None:
|
||||
scale_weight = torch.ones((1), device=input.device, dtype=torch.float32)
|
||||
if scale_input is None:
|
||||
scale_input = scale_weight
|
||||
scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
|
||||
else:
|
||||
scale_weight = scale_weight.to(input.device)
|
||||
|
||||
if scale_input is None:
|
||||
scale_input = torch.ones((1), device=input.device, dtype=torch.float32)
|
||||
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
|
||||
inn = input.reshape(-1, input.shape[2]).to(dtype)
|
||||
else:
|
||||
scale_input = scale_input.to(input.device)
|
||||
inn = (input * (1.0 / scale_input).to(input.dtype)).reshape(-1, input.shape[2]).to(dtype)
|
||||
|
||||
if bias is not None:
|
||||
o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
|
||||
@@ -281,7 +282,11 @@ def fp8_linear(self, input):
|
||||
if isinstance(o, tuple):
|
||||
o = o[0]
|
||||
|
||||
if tensor_2d:
|
||||
return o.reshape(input.shape[0], -1)
|
||||
|
||||
return o.reshape((-1, input.shape[1], self.weight.shape[0]))
|
||||
|
||||
return None
|
||||
|
||||
class fp8_ops(manual_cast):
|
||||
@@ -299,11 +304,63 @@ class fp8_ops(manual_cast):
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
|
||||
def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
|
||||
class scaled_fp8_op(manual_cast):
|
||||
class Linear(manual_cast.Linear):
|
||||
def __init__(self, *args, **kwargs):
|
||||
if override_dtype is not None:
|
||||
kwargs['dtype'] = override_dtype
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def reset_parameters(self):
|
||||
if not hasattr(self, 'scale_weight'):
|
||||
self.scale_weight = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
|
||||
|
||||
if not scale_input:
|
||||
self.scale_input = None
|
||||
|
||||
if not hasattr(self, 'scale_input'):
|
||||
self.scale_input = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
if fp8_matrix_mult:
|
||||
out = fp8_linear(self, input)
|
||||
if out is not None:
|
||||
return out
|
||||
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
|
||||
if weight.numel() < input.numel(): #TODO: optimize
|
||||
return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
|
||||
else:
|
||||
return torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
|
||||
|
||||
def convert_weight(self, weight, inplace=False, **kwargs):
|
||||
if inplace:
|
||||
weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype)
|
||||
return weight
|
||||
else:
|
||||
return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype)
|
||||
|
||||
def set_weight(self, weight, inplace_update=False, seed=None, **kwargs):
|
||||
weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
|
||||
if inplace_update:
|
||||
self.weight.data.copy_(weight)
|
||||
else:
|
||||
self.weight = torch.nn.Parameter(weight, requires_grad=False)
|
||||
|
||||
return scaled_fp8_op
|
||||
|
||||
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
|
||||
fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
|
||||
if scaled_fp8 is not None:
|
||||
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute, scale_input=True, override_dtype=scaled_fp8)
|
||||
|
||||
if fp8_compute and (fp8_optimizations or args.fast) and not disable_fast_fp8:
|
||||
return fp8_ops
|
||||
|
||||
def pick_operations(weight_dtype, compute_dtype, load_device=None):
|
||||
if compute_dtype is None or weight_dtype == compute_dtype:
|
||||
return disable_weight_init
|
||||
if args.fast:
|
||||
if comfy.model_management.supports_fp8_compute(load_device):
|
||||
return fp8_ops
|
||||
|
||||
return manual_cast
|
||||
|
||||
156
comfy/patcher_extension.py
Normal file
156
comfy/patcher_extension.py
Normal file
@@ -0,0 +1,156 @@
|
||||
from __future__ import annotations
|
||||
from typing import Callable
|
||||
|
||||
class CallbacksMP:
|
||||
ON_CLONE = "on_clone"
|
||||
ON_LOAD = "on_load_after"
|
||||
ON_DETACH = "on_detach_after"
|
||||
ON_CLEANUP = "on_cleanup"
|
||||
ON_PRE_RUN = "on_pre_run"
|
||||
ON_PREPARE_STATE = "on_prepare_state"
|
||||
ON_APPLY_HOOKS = "on_apply_hooks"
|
||||
ON_REGISTER_ALL_HOOK_PATCHES = "on_register_all_hook_patches"
|
||||
ON_INJECT_MODEL = "on_inject_model"
|
||||
ON_EJECT_MODEL = "on_eject_model"
|
||||
|
||||
# callbacks dict is in the format:
|
||||
# {"call_type": {"key": [Callable1, Callable2, ...]} }
|
||||
@classmethod
|
||||
def init_callbacks(cls) -> dict[str, dict[str, list[Callable]]]:
|
||||
return {}
|
||||
|
||||
def add_callback(call_type: str, callback: Callable, transformer_options: dict, is_model_options=False):
|
||||
add_callback_with_key(call_type, None, callback, transformer_options, is_model_options)
|
||||
|
||||
def add_callback_with_key(call_type: str, key: str, callback: Callable, transformer_options: dict, is_model_options=False):
|
||||
if is_model_options:
|
||||
transformer_options = transformer_options.setdefault("transformer_options", {})
|
||||
callbacks: dict[str, dict[str, list]] = transformer_options.setdefault("callbacks", {})
|
||||
c = callbacks.setdefault(call_type, {}).setdefault(key, [])
|
||||
c.append(callback)
|
||||
|
||||
def get_callbacks_with_key(call_type: str, key: str, transformer_options: dict, is_model_options=False):
|
||||
if is_model_options:
|
||||
transformer_options = transformer_options.get("transformer_options", {})
|
||||
c_list = []
|
||||
callbacks: dict[str, list] = transformer_options.get("callbacks", {})
|
||||
c_list.extend(callbacks.get(call_type, {}).get(key, []))
|
||||
return c_list
|
||||
|
||||
def get_all_callbacks(call_type: str, transformer_options: dict, is_model_options=False):
|
||||
if is_model_options:
|
||||
transformer_options = transformer_options.get("transformer_options", {})
|
||||
c_list = []
|
||||
callbacks: dict[str, list] = transformer_options.get("callbacks", {})
|
||||
for c in callbacks.get(call_type, {}).values():
|
||||
c_list.extend(c)
|
||||
return c_list
|
||||
|
||||
class WrappersMP:
|
||||
OUTER_SAMPLE = "outer_sample"
|
||||
SAMPLER_SAMPLE = "sampler_sample"
|
||||
CALC_COND_BATCH = "calc_cond_batch"
|
||||
APPLY_MODEL = "apply_model"
|
||||
DIFFUSION_MODEL = "diffusion_model"
|
||||
|
||||
# wrappers dict is in the format:
|
||||
# {"wrapper_type": {"key": [Callable1, Callable2, ...]} }
|
||||
@classmethod
|
||||
def init_wrappers(cls) -> dict[str, dict[str, list[Callable]]]:
|
||||
return {}
|
||||
|
||||
def add_wrapper(wrapper_type: str, wrapper: Callable, transformer_options: dict, is_model_options=False):
|
||||
add_wrapper_with_key(wrapper_type, None, wrapper, transformer_options, is_model_options)
|
||||
|
||||
def add_wrapper_with_key(wrapper_type: str, key: str, wrapper: Callable, transformer_options: dict, is_model_options=False):
|
||||
if is_model_options:
|
||||
transformer_options = transformer_options.setdefault("transformer_options", {})
|
||||
wrappers: dict[str, dict[str, list]] = transformer_options.setdefault("wrappers", {})
|
||||
w = wrappers.setdefault(wrapper_type, {}).setdefault(key, [])
|
||||
w.append(wrapper)
|
||||
|
||||
def get_wrappers_with_key(wrapper_type: str, key: str, transformer_options: dict, is_model_options=False):
|
||||
if is_model_options:
|
||||
transformer_options = transformer_options.get("transformer_options", {})
|
||||
w_list = []
|
||||
wrappers: dict[str, list] = transformer_options.get("wrappers", {})
|
||||
w_list.extend(wrappers.get(wrapper_type, {}).get(key, []))
|
||||
return w_list
|
||||
|
||||
def get_all_wrappers(wrapper_type: str, transformer_options: dict, is_model_options=False):
|
||||
if is_model_options:
|
||||
transformer_options = transformer_options.get("transformer_options", {})
|
||||
w_list = []
|
||||
wrappers: dict[str, list] = transformer_options.get("wrappers", {})
|
||||
for w in wrappers.get(wrapper_type, {}).values():
|
||||
w_list.extend(w)
|
||||
return w_list
|
||||
|
||||
class WrapperExecutor:
|
||||
"""Handles call stack of wrappers around a function in an ordered manner."""
|
||||
def __init__(self, original: Callable, class_obj: object, wrappers: list[Callable], idx: int):
|
||||
# NOTE: class_obj exists so that wrappers surrounding a class method can access
|
||||
# the class instance at runtime via executor.class_obj
|
||||
self.original = original
|
||||
self.class_obj = class_obj
|
||||
self.wrappers = wrappers.copy()
|
||||
self.idx = idx
|
||||
self.is_last = idx == len(wrappers)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
"""Calls the next wrapper or original function, whichever is appropriate."""
|
||||
new_executor = self._create_next_executor()
|
||||
return new_executor.execute(*args, **kwargs)
|
||||
|
||||
def execute(self, *args, **kwargs):
|
||||
"""Used to initiate executor internally - DO NOT use this if you received executor in wrapper."""
|
||||
args = list(args)
|
||||
kwargs = dict(kwargs)
|
||||
if self.is_last:
|
||||
return self.original(*args, **kwargs)
|
||||
return self.wrappers[self.idx](self, *args, **kwargs)
|
||||
|
||||
def _create_next_executor(self) -> 'WrapperExecutor':
|
||||
new_idx = self.idx + 1
|
||||
if new_idx > len(self.wrappers):
|
||||
raise Exception(f"Wrapper idx exceeded available wrappers; something went very wrong.")
|
||||
if self.class_obj is None:
|
||||
return WrapperExecutor.new_executor(self.original, self.wrappers, new_idx)
|
||||
return WrapperExecutor.new_class_executor(self.original, self.class_obj, self.wrappers, new_idx)
|
||||
|
||||
@classmethod
|
||||
def new_executor(cls, original: Callable, wrappers: list[Callable], idx=0):
|
||||
return cls(original, class_obj=None, wrappers=wrappers, idx=idx)
|
||||
|
||||
@classmethod
|
||||
def new_class_executor(cls, original: Callable, class_obj: object, wrappers: list[Callable], idx=0):
|
||||
return cls(original, class_obj, wrappers, idx=idx)
|
||||
|
||||
class PatcherInjection:
    """Holds an ``inject`` callable and an ``eject`` callable as plain attributes."""

    def __init__(self, inject: Callable, eject: Callable):
        self.inject, self.eject = inject, eject
|
||||
|
||||
def copy_nested_dicts(input_dict: dict):
    """Return a copy of ``input_dict`` whose dict and list values are copied too.

    Nested dicts are copied recursively; lists get a shallow ``.copy()``, so
    the items inside a list are shared with the input. All other values are
    shared as well.
    """
    duplicate = input_dict.copy()
    for name, entry in input_dict.items():
        if isinstance(entry, dict):
            duplicate[name] = copy_nested_dicts(entry)
            continue
        if isinstance(entry, list):
            duplicate[name] = entry.copy()
    return duplicate
|
||||
|
||||
def merge_nested_dicts(dict1: dict, dict2: dict, copy_dict1=True):
    """Merge ``dict2`` into ``dict1`` and return the result.

    With ``copy_dict1`` true the merge happens on a ``copy_nested_dicts`` copy
    of ``dict1``; otherwise ``dict1`` itself is modified and returned.

    Per key of ``dict2``:
      * list value  - its items are appended to the list already in the result
        (an empty list is created when the key is absent);
      * dict value  - replaced by a recursive merge that uses dict2's sub-dict
        as the base and the value already in the result as the overlay, so
        inside nested dicts the pre-existing scalars win and dict2's list
        items come first;
      * anything else - overwrites the entry in the result.
    """
    merged = copy_nested_dicts(dict1) if copy_dict1 else dict1
    for key, incoming in dict2.items():
        if isinstance(incoming, list):
            merged.setdefault(key, []).extend(incoming)
            continue
        if not isinstance(incoming, dict):
            merged[key] = incoming
            continue
        existing = merged.setdefault(key, {})
        merged[key] = merge_nested_dicts(incoming, existing)
    return merged
|
||||
@@ -1,22 +1,61 @@
|
||||
from __future__ import annotations
|
||||
import uuid
|
||||
import torch
|
||||
import comfy.model_management
|
||||
import comfy.conds
|
||||
import comfy.utils
|
||||
import comfy.hooks
|
||||
import comfy.patcher_extension
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
from comfy.model_base import BaseModel
|
||||
from comfy.controlnet import ControlBase
|
||||
|
||||
def prepare_mask(noise_mask, shape, device):
    """ensures noise mask is of proper dimensions

    Hands ``noise_mask`` and the target ``shape`` to
    ``comfy.utils.reshape_mask`` and moves the result to ``device``.
    """
    return comfy.utils.reshape_mask(noise_mask, shape).to(device)
|
||||
|
||||
def get_models_from_cond(cond, model_type):
    """Collect the ``model_type`` entries found in a list of cond dicts.

    Args:
        cond: iterable of cond dicts.
        model_type: key to look up in each cond dict.

    Returns:
        A flat list. A cond whose entry is itself a list contributes each of
        its items; any other entry contributes the entry itself, exactly once.
    """
    models = []
    for c in cond:
        if model_type not in c:
            continue
        entry = c[model_type]
        if isinstance(entry, list):
            models.extend(entry)
        else:
            models.append(entry)
    return models
|
||||
|
||||
def get_hooks_from_cond(cond, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]]):
    """Register every hook reachable from ``cond`` in ``hooks_dict``.

    Hooks come from two places: the ``'hooks'`` group attached to a cond dict,
    and the ``extra_hooks`` of each ``'control'`` object together with the
    objects reached through its ``previous_controlnet`` links.

    ``hooks_dict`` maps ``hook.hook_type`` to a dict keyed by hook (values are
    always ``None``). It is updated in place and also returned.
    """
    def register(hook):
        hooks_dict.setdefault(hook.hook_type, {})[hook] = None

    # hooks stored directly on the conds; remember controlnets for the second pass
    controlnets: list[ControlBase] = []
    for entry in cond:
        if 'hooks' in entry:
            for hook in entry['hooks'].hooks:
                register(hook)
        if 'control' in entry:
            controlnets.append(entry['control'])

    # follow each distinct controlnet's previous_controlnet chain to its end
    extra_groups = []
    for head in set(controlnets):
        node = head
        while True:
            if node.extra_hooks is not None:
                extra_groups.append(node.extra_hooks)
            if node.previous_controlnet is None:
                break
            node = node.previous_controlnet

    combined = comfy.hooks.HookGroup.combine_all_hooks(extra_groups)
    if combined is not None:
        for hook in combined.hooks:
            register(hook)

    return hooks_dict
|
||||
|
||||
def convert_cond(cond):
|
||||
out = []
|
||||
for c in cond:
|
||||
@@ -26,17 +65,22 @@ def convert_cond(cond):
|
||||
model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0]) #TODO: remove
|
||||
temp["cross_attn"] = c[0]
|
||||
temp["model_conds"] = model_conds
|
||||
temp["uuid"] = uuid.uuid4()
|
||||
out.append(temp)
|
||||
return out
|
||||
|
||||
def get_additional_models(conds, dtype):
|
||||
"""loads additional models in conditioning"""
|
||||
cnets = []
|
||||
cnets: list[ControlBase] = []
|
||||
gligen = []
|
||||
add_models = []
|
||||
hooks: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]] = {}
|
||||
|
||||
for k in conds:
|
||||
cnets += get_models_from_cond(conds[k], "control")
|
||||
gligen += get_models_from_cond(conds[k], "gligen")
|
||||
add_models += get_models_from_cond(conds[k], "additional_models")
|
||||
get_hooks_from_cond(conds[k], hooks)
|
||||
|
||||
control_nets = set(cnets)
|
||||
|
||||
@@ -47,7 +91,9 @@ def get_additional_models(conds, dtype):
|
||||
inference_memory += m.inference_memory_requirements(dtype)
|
||||
|
||||
gligen = [x[1] for x in gligen]
|
||||
models = control_models + gligen
|
||||
hook_models = [x.model for x in hooks.get(comfy.hooks.EnumHookType.AddModels, {}).keys()]
|
||||
models = control_models + gligen + add_models + hook_models
|
||||
|
||||
return models, inference_memory
|
||||
|
||||
def cleanup_additional_models(models):
|
||||
@@ -57,10 +103,11 @@ def cleanup_additional_models(models):
|
||||
m.cleanup()
|
||||
|
||||
|
||||
def prepare_sampling(model, noise_shape, conds):
|
||||
def prepare_sampling(model: 'ModelPatcher', noise_shape, conds):
|
||||
device = model.load_device
|
||||
real_model = None
|
||||
real_model: 'BaseModel' = None
|
||||
models, inference_memory = get_additional_models(conds, model.model_dtype())
|
||||
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
|
||||
memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory
|
||||
minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory
|
||||
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required)
|
||||
@@ -76,3 +123,14 @@ def cleanup_models(conds, models):
|
||||
control_cleanup += get_models_from_cond(conds[k], "control")
|
||||
|
||||
cleanup_additional_models(set(control_cleanup))
|
||||
|
||||
def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
    """Copy the patcher's wrappers/callbacks into ``model_options`` and register cond hooks.

    ``model_options["transformer_options"]`` must already exist; its
    ``"wrappers"`` and ``"callbacks"`` entries are overwritten with nested
    copies of ``model.wrappers`` and ``model.callbacks``.
    """
    # check for hooks in conds - if not registered, see if can be applied
    collected_hooks = {}
    for cond_name in conds:
        get_hooks_from_cond(conds[cond_name], collected_hooks)
    # add wrappers and callbacks from ModelPatcher to transformer_options
    extension = comfy.patcher_extension
    for attr in ("wrappers", "callbacks"):
        model_options["transformer_options"][attr] = extension.copy_nested_dicts(getattr(model, attr))
    # register hooks on model/model_options
    model.register_all_hook_patches(collected_hooks, comfy.hooks.EnumWeightTarget.Model, model_options)
|
||||
|
||||
@@ -1,12 +1,22 @@
|
||||
from __future__ import annotations
|
||||
from .k_diffusion import sampling as k_diffusion_sampling
|
||||
from .extra_samplers import uni_pc
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
from comfy.model_base import BaseModel
|
||||
from comfy.controlnet import ControlBase
|
||||
import torch
|
||||
import collections
|
||||
from comfy import model_management
|
||||
import math
|
||||
import logging
|
||||
import comfy.samplers
|
||||
import comfy.sampler_helpers
|
||||
import scipy
|
||||
import comfy.model_patcher
|
||||
import comfy.patcher_extension
|
||||
import comfy.hooks
|
||||
import scipy.stats
|
||||
import numpy
|
||||
|
||||
def get_area_and_mult(conds, x_in, timestep_in):
|
||||
@@ -70,6 +80,7 @@ def get_area_and_mult(conds, x_in, timestep_in):
|
||||
for c in model_conds:
|
||||
conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)
|
||||
|
||||
hooks = conds.get('hooks', None)
|
||||
control = conds.get('control', None)
|
||||
|
||||
patches = None
|
||||
@@ -85,8 +96,8 @@ def get_area_and_mult(conds, x_in, timestep_in):
|
||||
|
||||
patches['middle_patch'] = [gligen_patch]
|
||||
|
||||
cond_obj = collections.namedtuple('cond_obj', ['input_x', 'mult', 'conditioning', 'area', 'control', 'patches'])
|
||||
return cond_obj(input_x, mult, conditioning, area, control, patches)
|
||||
cond_obj = collections.namedtuple('cond_obj', ['input_x', 'mult', 'conditioning', 'area', 'control', 'patches', 'uuid', 'hooks'])
|
||||
return cond_obj(input_x, mult, conditioning, area, control, patches, conds['uuid'], hooks)
|
||||
|
||||
def cond_equal_size(c1, c2):
|
||||
if c1 is c2:
|
||||
@@ -138,110 +149,184 @@ def cond_cat(c_list):
|
||||
|
||||
return out
|
||||
|
||||
def calc_cond_batch(model, conds, x_in, timestep, model_options):
|
||||
def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]], default_conds: list[list[dict]], x_in, timestep):
    """Queue the default conds for whatever weight the already-queued conds left over.

    For every cond slot a tensor of ones shaped like ``x_in`` is created and
    the ``mult`` of each cond already in ``hooked_to_run`` is subtracted from
    it (inside the cond's ``area`` when it has one). Negative values are
    clamped to zero. If anything non-zero remains, each default cond of that
    slot is prepared with ``comfy.samplers.get_area_and_mult``, given the
    leftover tensor as its ``mult`` and appended to ``hooked_to_run`` under
    its hooks. ``hooked_to_run`` is modified in place; nothing is returned.
    """
    # need to figure out remaining unmasked area for conds
    remaining = [torch.ones_like(x_in) for _ in default_conds]

    # subtract the 'mult' of every finalized cond from the matching slot
    for scheduled in hooked_to_run.values():
        for cond_obj, slot in scheduled:
            # if no default_cond for cond_type, do nothing
            if len(default_conds[slot]) == 0:
                continue
            region = remaining[slot]
            area = cond_obj.area
            if area is not None:
                # first half of area = lengths, second half = start offsets;
                # tensor dims 0 and 1 are left alone
                half = len(area) // 2
                for axis in range(half):
                    region = region.narrow(axis + 2, area[axis + half], area[axis])
            # narrow() yields a view, so this in-place subtraction lands in remaining[slot]
            region -= cond_obj.mult

    for slot, weight in enumerate(remaining):
        # if no default_cond for cond type, do nothing
        if len(default_conds[slot]) == 0:
            continue
        # ReLU to make negatives=0
        torch.nn.functional.relu(weight, inplace=True)
        # if mult is all zeros, then don't add default_cond
        if torch.max(weight) == 0.0:
            continue

        for default_cond in default_conds[slot]:
            # do get_area_and_mult to get all the expected values
            prepared = comfy.samplers.get_area_and_mult(default_cond, x_in, timestep)
            if prepared is None:
                continue
            # replace the prepared cond's mult with the calculated leftover
            prepared = prepared._replace(mult=weight)
            if prepared.hooks is not None:
                model.current_patcher.prepare_hook_patches_current_keyframe(timestep, prepared.hooks)
            hooked_to_run.setdefault(prepared.hooks, []).append((prepared, slot))
|
||||
|
||||
def calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
    """Run ``_calc_cond_batch`` through the CALC_COND_BATCH wrappers found in ``model_options``."""
    wrappers = comfy.patcher_extension.get_all_wrappers(
        comfy.patcher_extension.WrappersMP.CALC_COND_BATCH,
        model_options,
        is_model_options=True,
    )
    chain = comfy.patcher_extension.WrapperExecutor.new_executor(_calc_cond_batch, wrappers)
    return chain.execute(model, conds, x_in, timestep, model_options)
|
||||
|
||||
def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
|
||||
out_conds = []
|
||||
out_counts = []
|
||||
to_run = []
|
||||
# separate conds by matching hooks
|
||||
hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]] = {}
|
||||
default_conds = []
|
||||
has_default_conds = False
|
||||
|
||||
for i in range(len(conds)):
|
||||
out_conds.append(torch.zeros_like(x_in))
|
||||
out_counts.append(torch.ones_like(x_in) * 1e-37)
|
||||
|
||||
cond = conds[i]
|
||||
default_c = []
|
||||
if cond is not None:
|
||||
for x in cond:
|
||||
p = get_area_and_mult(x, x_in, timestep)
|
||||
if 'default' in x:
|
||||
default_c.append(x)
|
||||
has_default_conds = True
|
||||
continue
|
||||
p = comfy.samplers.get_area_and_mult(x, x_in, timestep)
|
||||
if p is None:
|
||||
continue
|
||||
if p.hooks is not None:
|
||||
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks)
|
||||
hooked_to_run.setdefault(p.hooks, list())
|
||||
hooked_to_run[p.hooks] += [(p, i)]
|
||||
default_conds.append(default_c)
|
||||
|
||||
to_run += [(p, i)]
|
||||
if has_default_conds:
|
||||
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep)
|
||||
|
||||
while len(to_run) > 0:
|
||||
first = to_run[0]
|
||||
first_shape = first[0][0].shape
|
||||
to_batch_temp = []
|
||||
for x in range(len(to_run)):
|
||||
if can_concat_cond(to_run[x][0], first[0]):
|
||||
to_batch_temp += [x]
|
||||
model.current_patcher.prepare_state(timestep)
|
||||
|
||||
to_batch_temp.reverse()
|
||||
to_batch = to_batch_temp[:1]
|
||||
# run every hooked_to_run separately
|
||||
for hooks, to_run in hooked_to_run.items():
|
||||
while len(to_run) > 0:
|
||||
first = to_run[0]
|
||||
first_shape = first[0][0].shape
|
||||
to_batch_temp = []
|
||||
for x in range(len(to_run)):
|
||||
if can_concat_cond(to_run[x][0], first[0]):
|
||||
to_batch_temp += [x]
|
||||
|
||||
free_memory = model_management.get_free_memory(x_in.device)
|
||||
for i in range(1, len(to_batch_temp) + 1):
|
||||
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
|
||||
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
|
||||
if model.memory_required(input_shape) * 1.5 < free_memory:
|
||||
to_batch = batch_amount
|
||||
break
|
||||
to_batch_temp.reverse()
|
||||
to_batch = to_batch_temp[:1]
|
||||
|
||||
input_x = []
|
||||
mult = []
|
||||
c = []
|
||||
cond_or_uncond = []
|
||||
area = []
|
||||
control = None
|
||||
patches = None
|
||||
for x in to_batch:
|
||||
o = to_run.pop(x)
|
||||
p = o[0]
|
||||
input_x.append(p.input_x)
|
||||
mult.append(p.mult)
|
||||
c.append(p.conditioning)
|
||||
area.append(p.area)
|
||||
cond_or_uncond.append(o[1])
|
||||
control = p.control
|
||||
patches = p.patches
|
||||
free_memory = model_management.get_free_memory(x_in.device)
|
||||
for i in range(1, len(to_batch_temp) + 1):
|
||||
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
|
||||
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
|
||||
if model.memory_required(input_shape) * 1.5 < free_memory:
|
||||
to_batch = batch_amount
|
||||
break
|
||||
|
||||
batch_chunks = len(cond_or_uncond)
|
||||
input_x = torch.cat(input_x)
|
||||
c = cond_cat(c)
|
||||
timestep_ = torch.cat([timestep] * batch_chunks)
|
||||
input_x = []
|
||||
mult = []
|
||||
c = []
|
||||
cond_or_uncond = []
|
||||
uuids = []
|
||||
area = []
|
||||
control = None
|
||||
patches = None
|
||||
for x in to_batch:
|
||||
o = to_run.pop(x)
|
||||
p = o[0]
|
||||
input_x.append(p.input_x)
|
||||
mult.append(p.mult)
|
||||
c.append(p.conditioning)
|
||||
area.append(p.area)
|
||||
cond_or_uncond.append(o[1])
|
||||
uuids.append(p.uuid)
|
||||
control = p.control
|
||||
patches = p.patches
|
||||
|
||||
if control is not None:
|
||||
c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond))
|
||||
batch_chunks = len(cond_or_uncond)
|
||||
input_x = torch.cat(input_x)
|
||||
c = cond_cat(c)
|
||||
timestep_ = torch.cat([timestep] * batch_chunks)
|
||||
|
||||
transformer_options = {}
|
||||
if 'transformer_options' in model_options:
|
||||
transformer_options = model_options['transformer_options'].copy()
|
||||
transformer_options = model.current_patcher.apply_hooks(hooks=hooks)
|
||||
if 'transformer_options' in model_options:
|
||||
transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options,
|
||||
model_options['transformer_options'],
|
||||
copy_dict1=False)
|
||||
|
||||
if patches is not None:
|
||||
if "patches" in transformer_options:
|
||||
cur_patches = transformer_options["patches"].copy()
|
||||
for p in patches:
|
||||
if p in cur_patches:
|
||||
cur_patches[p] = cur_patches[p] + patches[p]
|
||||
else:
|
||||
cur_patches[p] = patches[p]
|
||||
transformer_options["patches"] = cur_patches
|
||||
if patches is not None:
|
||||
# TODO: replace with merge_nested_dicts function
|
||||
if "patches" in transformer_options:
|
||||
cur_patches = transformer_options["patches"].copy()
|
||||
for p in patches:
|
||||
if p in cur_patches:
|
||||
cur_patches[p] = cur_patches[p] + patches[p]
|
||||
else:
|
||||
cur_patches[p] = patches[p]
|
||||
transformer_options["patches"] = cur_patches
|
||||
else:
|
||||
transformer_options["patches"] = patches
|
||||
|
||||
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
|
||||
transformer_options["uuids"] = uuids[:]
|
||||
transformer_options["sigmas"] = timestep
|
||||
|
||||
c['transformer_options'] = transformer_options
|
||||
|
||||
if control is not None:
|
||||
c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options)
|
||||
|
||||
if 'model_function_wrapper' in model_options:
|
||||
output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
|
||||
else:
|
||||
transformer_options["patches"] = patches
|
||||
output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
|
||||
|
||||
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
|
||||
transformer_options["sigmas"] = timestep
|
||||
|
||||
c['transformer_options'] = transformer_options
|
||||
|
||||
if 'model_function_wrapper' in model_options:
|
||||
output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
|
||||
else:
|
||||
output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
|
||||
|
||||
for o in range(batch_chunks):
|
||||
cond_index = cond_or_uncond[o]
|
||||
a = area[o]
|
||||
if a is None:
|
||||
out_conds[cond_index] += output[o] * mult[o]
|
||||
out_counts[cond_index] += mult[o]
|
||||
else:
|
||||
out_c = out_conds[cond_index]
|
||||
out_cts = out_counts[cond_index]
|
||||
dims = len(a) // 2
|
||||
for i in range(dims):
|
||||
out_c = out_c.narrow(i + 2, a[i + dims], a[i])
|
||||
out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
|
||||
out_c += output[o] * mult[o]
|
||||
out_cts += mult[o]
|
||||
for o in range(batch_chunks):
|
||||
cond_index = cond_or_uncond[o]
|
||||
a = area[o]
|
||||
if a is None:
|
||||
out_conds[cond_index] += output[o] * mult[o]
|
||||
out_counts[cond_index] += mult[o]
|
||||
else:
|
||||
out_c = out_conds[cond_index]
|
||||
out_cts = out_counts[cond_index]
|
||||
dims = len(a) // 2
|
||||
for i in range(dims):
|
||||
out_c = out_c.narrow(i + 2, a[i + dims], a[i])
|
||||
out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
|
||||
out_c += output[o] * mult[o]
|
||||
out_cts += mult[o]
|
||||
|
||||
for i in range(len(out_conds)):
|
||||
out_conds[i] /= out_counts[i]
|
||||
@@ -358,11 +443,35 @@ def beta_scheduler(model_sampling, steps, alpha=0.6, beta=0.6):
|
||||
ts = numpy.rint(scipy.stats.beta.ppf(ts, alpha, beta) * total_timesteps)
|
||||
|
||||
sigs = []
|
||||
last_t = -1
|
||||
for t in ts:
|
||||
sigs += [float(model_sampling.sigmas[int(t)])]
|
||||
if t != last_t:
|
||||
sigs += [float(model_sampling.sigmas[int(t)])]
|
||||
last_t = t
|
||||
sigs += [0.0]
|
||||
return torch.FloatTensor(sigs)
|
||||
|
||||
# from: https://github.com/genmoai/models/blob/main/src/mochi_preview/infer.py#L41
|
||||
def linear_quadratic_schedule(model_sampling, steps, threshold_noise=0.025, linear_steps=None):
    """Build a sigma schedule that is linear at first and quadratic afterwards.

    A rising 0..1 curve is assembled from ``linear_steps`` linear points
    (slope ``threshold_noise / linear_steps``), quadratic points for the
    remaining steps and a final 1.0. It is then flipped with ``1.0 - x`` and
    scaled by ``model_sampling.sigma_max``. ``linear_steps`` defaults to
    ``steps // 2``. For ``steps == 1`` the unscaled schedule is ``[1.0, 0.0]``.

    Returns:
        A float tensor; it has ``steps + 1`` entries when
        ``0 < linear_steps < steps``.
    """
    if steps == 1:
        schedule = [1.0, 0.0]
    else:
        if linear_steps is None:
            linear_steps = steps // 2
        quadratic_steps = steps - linear_steps
        threshold_noise_step_diff = linear_steps - threshold_noise * steps
        quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps ** 2)
        linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps ** 2)
        const = quadratic_coef * (linear_steps ** 2)

        rising = []
        for i in range(linear_steps):
            rising.append(i * threshold_noise / linear_steps)
        for i in range(linear_steps, steps):
            rising.append(quadratic_coef * (i ** 2) + linear_coef * i + const)
        rising.append(1.0)

        # flip so the schedule starts at 1.0 and ends at 0.0
        schedule = []
        for value in rising:
            schedule.append(1.0 - value)
    return torch.FloatTensor(schedule) * model_sampling.sigma_max.cpu()
|
||||
|
||||
def get_mask_aabb(masks):
|
||||
if masks.numel() == 0:
|
||||
return torch.zeros((0, 4), device=masks.device, dtype=torch.int)
|
||||
@@ -476,10 +585,15 @@ def calculate_start_end_timesteps(model, conds):
|
||||
|
||||
timestep_start = None
|
||||
timestep_end = None
|
||||
if 'start_percent' in x:
|
||||
timestep_start = s.percent_to_sigma(x['start_percent'])
|
||||
if 'end_percent' in x:
|
||||
timestep_end = s.percent_to_sigma(x['end_percent'])
|
||||
# handle clip hook schedule, if needed
|
||||
if 'clip_start_percent' in x:
|
||||
timestep_start = s.percent_to_sigma(max(x['clip_start_percent'], x.get('start_percent', 0.0)))
|
||||
timestep_end = s.percent_to_sigma(min(x['clip_end_percent'], x.get('end_percent', 1.0)))
|
||||
else:
|
||||
if 'start_percent' in x:
|
||||
timestep_start = s.percent_to_sigma(x['start_percent'])
|
||||
if 'end_percent' in x:
|
||||
timestep_end = s.percent_to_sigma(x['end_percent'])
|
||||
|
||||
if (timestep_start is not None) or (timestep_end is not None):
|
||||
n = x.copy()
|
||||
@@ -570,8 +684,8 @@ class Sampler:
|
||||
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
|
||||
|
||||
# Sampler names; each name is listed once.
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
                  "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
                  "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
                  "ipndm", "ipndm_v", "deis"]
|
||||
|
||||
class KSAMPLER(Sampler):
|
||||
@@ -649,6 +763,12 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N
|
||||
if k != kk:
|
||||
create_cond_with_same_area_if_none(conds[kk], c)
|
||||
|
||||
for k in conds:
|
||||
for c in conds[k]:
|
||||
if 'hooks' in c:
|
||||
for hook in c['hooks'].hooks:
|
||||
hook.initialize_timesteps(model)
|
||||
|
||||
for k in conds:
|
||||
pre_run_control(model, conds[k])
|
||||
|
||||
@@ -661,9 +781,46 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N
|
||||
|
||||
return conds
|
||||
|
||||
|
||||
def preprocess_conds_hooks(conds: dict[str, list[dict[str]]]):
    """Fold each control object's extra hooks into the ``'hooks'`` of the conds that use it.

    Conds are grouped by their ``(control, hooks)`` pair so that every cond in
    a group receives the very same combined hook group object. Conds without
    a ``'control'`` entry, or whose control reports no extra hooks, are left
    untouched. The cond dicts are modified in place; nothing is returned.
    """
    # determine which ControlNets have extra_hooks that should be combined with normal hooks
    groups: dict[tuple[ControlBase, comfy.hooks.HookGroup], list[dict]] = {}
    for cond_list in (conds[name] for name in conds):
        for cond in cond_list:
            if 'control' not in cond:
                continue
            control: 'ControlBase' = cond['control']
            if len(control.get_extra_hooks()) == 0:
                continue
            groups.setdefault((control, cond.get('hooks', None)), []).append(cond)

    # if nothing to replace, do nothing
    if len(groups) == 0:
        return

    # for optimal sampling performance, common ControlNets + hook combos should have identical hooks
    # on the cond dicts
    for (control, current_hooks), members in groups.items():
        combined = comfy.hooks.HookGroup.combine_all_hooks(control.get_extra_hooks() + [current_hooks])
        # if combined hooks are not None, set as new hooks for all relevant conds
        if combined is None:
            continue
        for cond in members:
            cond['hooks'] = combined
|
||||
|
||||
|
||||
def get_total_hook_groups_in_conds(conds: dict[str, list[dict[str]]]):
    """Count the distinct ``'hooks'`` values across all conds.

    A cond without a ``'hooks'`` entry counts as ``None``, so ``None`` is one
    of the possible distinct values.
    """
    distinct = {cond.get('hooks', None) for name in conds for cond in conds[name]}
    return len(distinct)
|
||||
|
||||
|
||||
class CFGGuider:
|
||||
def __init__(self, model_patcher):
|
||||
self.model_patcher = model_patcher
|
||||
self.model_patcher: 'ModelPatcher' = model_patcher
|
||||
self.model_options = model_patcher.model_options
|
||||
self.original_conds = {}
|
||||
self.cfg = 1.0
|
||||
@@ -690,19 +847,17 @@ class CFGGuider:
|
||||
|
||||
self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed)
|
||||
|
||||
extra_args = {"model_options": self.model_options, "seed":seed}
|
||||
extra_args = {"model_options": comfy.model_patcher.create_model_options_clone(self.model_options), "seed": seed}
|
||||
|
||||
samples = sampler.sample(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
|
||||
executor = comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
sampler.sample,
|
||||
sampler,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE, extra_args["model_options"], is_model_options=True)
|
||||
)
|
||||
samples = executor.execute(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
|
||||
return self.inner_model.process_latent_out(samples.to(torch.float32))
|
||||
|
||||
def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||
if sigmas.shape[-1] == 0:
|
||||
return latent_image
|
||||
|
||||
self.conds = {}
|
||||
for k in self.original_conds:
|
||||
self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))
|
||||
|
||||
def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds)
|
||||
device = self.model_patcher.load_device
|
||||
|
||||
@@ -713,14 +868,48 @@ class CFGGuider:
|
||||
latent_image = latent_image.to(device)
|
||||
sigmas = sigmas.to(device)
|
||||
|
||||
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
||||
try:
|
||||
self.model_patcher.pre_run()
|
||||
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
||||
finally:
|
||||
self.model_patcher.cleanup()
|
||||
|
||||
comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
|
||||
del self.inner_model
|
||||
del self.conds
|
||||
del self.loaded_models
|
||||
return output
|
||||
|
||||
def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||
if sigmas.shape[-1] == 0:
|
||||
return latent_image
|
||||
|
||||
self.conds = {}
|
||||
for k in self.original_conds:
|
||||
self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))
|
||||
preprocess_conds_hooks(self.conds)
|
||||
|
||||
try:
|
||||
orig_model_options = self.model_options
|
||||
self.model_options = comfy.model_patcher.create_model_options_clone(self.model_options)
|
||||
# if one hook type (or just None), then don't bother caching weights for hooks (will never change after first step)
|
||||
orig_hook_mode = self.model_patcher.hook_mode
|
||||
if get_total_hook_groups_in_conds(self.conds) <= 1:
|
||||
self.model_patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
|
||||
comfy.sampler_helpers.prepare_model_patcher(self.model_patcher, self.conds, self.model_options)
|
||||
executor = comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self.outer_sample,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, self.model_options, is_model_options=True)
|
||||
)
|
||||
output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
||||
finally:
|
||||
self.model_options = orig_model_options
|
||||
self.model_patcher.hook_mode = orig_hook_mode
|
||||
self.model_patcher.restore_hook_patches()
|
||||
|
||||
del self.conds
|
||||
return output
|
||||
|
||||
|
||||
def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||
cfg_guider = CFGGuider(model)
|
||||
@@ -729,7 +918,7 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
|
||||
return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
||||
|
||||
|
||||
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "beta", "linear_quadratic"]
|
||||
SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]
|
||||
|
||||
def calculate_sigmas(model_sampling, scheduler_name, steps):
|
||||
@@ -747,6 +936,8 @@ def calculate_sigmas(model_sampling, scheduler_name, steps):
|
||||
sigmas = normal_scheduler(model_sampling, steps, sgm=True)
|
||||
elif scheduler_name == "beta":
|
||||
sigmas = beta_scheduler(model_sampling, steps)
|
||||
elif scheduler_name == "linear_quadratic":
|
||||
sigmas = linear_quadratic_schedule(model_sampling, steps)
|
||||
else:
|
||||
logging.error("error invalid scheduler {}".format(scheduler_name))
|
||||
return sigmas
|
||||
|
||||
289
comfy/sd.py
289
comfy/sd.py
@@ -1,12 +1,16 @@
|
||||
from __future__ import annotations
|
||||
import torch
|
||||
from enum import Enum
|
||||
import logging
|
||||
|
||||
from comfy import model_management
|
||||
from comfy.utils import ProgressBar
|
||||
from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
|
||||
from .ldm.cascade.stage_a import StageA
|
||||
from .ldm.cascade.stage_c_coder import StageC_coder
|
||||
from .ldm.audio.autoencoder import AudioOobleckVAE
|
||||
import comfy.ldm.genmo.vae.model
|
||||
import comfy.ldm.lightricks.vae.causal_video_autoencoder
|
||||
import yaml
|
||||
|
||||
import comfy.utils
|
||||
@@ -25,13 +29,18 @@ import comfy.text_encoders.aura_t5
|
||||
import comfy.text_encoders.hydit
|
||||
import comfy.text_encoders.flux
|
||||
import comfy.text_encoders.long_clipl
|
||||
import comfy.text_encoders.genmo
|
||||
import comfy.text_encoders.lt
|
||||
|
||||
import comfy.model_patcher
|
||||
import comfy.lora
|
||||
import comfy.lora_convert
|
||||
import comfy.hooks
|
||||
import comfy.t2i_adapter.adapter
|
||||
import comfy.supported_models_base
|
||||
import comfy.taesd.taesd
|
||||
|
||||
import comfy.ldm.flux.redux
|
||||
|
||||
def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
|
||||
key_map = {}
|
||||
if model is not None:
|
||||
@@ -39,6 +48,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
|
||||
if clip is not None:
|
||||
key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
|
||||
|
||||
lora = comfy.lora_convert.convert_lora(lora)
|
||||
loaded = comfy.lora.load_lora(lora, key_map)
|
||||
if model is not None:
|
||||
new_modelpatcher = model.clone()
|
||||
@@ -70,14 +80,14 @@ class CLIP:
|
||||
clip = target.clip
|
||||
tokenizer = target.tokenizer
|
||||
|
||||
load_device = model_management.text_encoder_device()
|
||||
offload_device = model_management.text_encoder_offload_device()
|
||||
load_device = model_options.get("load_device", model_management.text_encoder_device())
|
||||
offload_device = model_options.get("offload_device", model_management.text_encoder_offload_device())
|
||||
dtype = model_options.get("dtype", None)
|
||||
if dtype is None:
|
||||
dtype = model_management.text_encoder_dtype(load_device)
|
||||
|
||||
params['dtype'] = dtype
|
||||
params['device'] = model_management.text_encoder_initial_device(load_device, offload_device, parameters * model_management.dtype_size(dtype))
|
||||
params['device'] = model_options.get("initial_device", model_management.text_encoder_initial_device(load_device, offload_device, parameters * model_management.dtype_size(dtype)))
|
||||
params['model_options'] = model_options
|
||||
|
||||
self.cond_stage_model = clip(**(params))
|
||||
@@ -91,9 +101,13 @@ class CLIP:
|
||||
|
||||
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
|
||||
self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
|
||||
self.patcher.is_clip = True
|
||||
self.apply_hooks_to_conds = None
|
||||
if params['device'] == load_device:
|
||||
model_management.load_models_gpu([self.patcher], force_full_load=True)
|
||||
self.layer_idx = None
|
||||
self.use_clip_schedule = False
|
||||
logging.debug("CLIP model load device: {}, offload device: {}, current: {}".format(load_device, offload_device, params['device']))
|
||||
|
||||
def clone(self):
|
||||
@@ -102,6 +116,8 @@ class CLIP:
|
||||
n.cond_stage_model = self.cond_stage_model
|
||||
n.tokenizer = self.tokenizer
|
||||
n.layer_idx = self.layer_idx
|
||||
n.use_clip_schedule = self.use_clip_schedule
|
||||
n.apply_hooks_to_conds = self.apply_hooks_to_conds
|
||||
return n
|
||||
|
||||
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
|
||||
@@ -113,6 +129,69 @@ class CLIP:
|
||||
def tokenize(self, text, return_word_ids=False):
    """Tokenize ``text`` by delegating to this object's tokenizer.

    ``text`` and ``return_word_ids`` are forwarded unchanged to
    ``self.tokenizer.tokenize_with_weights`` and its result is returned as-is.
    """
    weighted_tokens = self.tokenizer.tokenize_with_weights(text, return_word_ids)
    return weighted_tokens
|
||||
|
||||
def add_hooks_to_dict(self, pooled_dict: dict[str]):
    """Attach the hooks stored on this object to ``pooled_dict``.

    When ``self.apply_hooks_to_conds`` is truthy it is stored under the
    ``"hooks"`` key; otherwise the dict is left untouched. The same dict
    object is returned in both cases (it is modified in place).
    """
    hooks = self.apply_hooks_to_conds
    if not hooks:
        return pooled_dict
    pooled_dict["hooks"] = hooks
    return pooled_dict
|
||||
|
||||
def encode_from_tokens_scheduled(self, tokens, unprojected=False, add_dict: dict[str]={}, show_pbar=True):
    """Encode ``tokens`` once per scheduled hook keyframe range.

    Returns a list of ``[cond, pooled_dict]`` pairs.

    * If the patcher has no forced hooks, or ``self.use_clip_schedule`` is
      false, a single unscheduled ``encode_from_tokens`` call is made and
      ``add_dict`` is merged into its pooled dict.
    * Otherwise every entry from ``get_hooks_for_clip_schedule()`` is encoded
      separately with the matching hook keyframes applied. Each pooled dict
      carries ``clip_start_percent`` / ``clip_end_percent`` for its range,
      then ``add_dict``, then the hooks stored on this object.

    ``add_dict`` is only read, never modified. If it contains
    ``start_percent`` / ``end_percent``, ranges lying entirely outside those
    bounds are skipped.
    """
    forced_hooks = self.patcher.forced_hooks

    if forced_hooks is None or not self.use_clip_schedule:
        # no hooks, or clip scheduling disabled: plain encode plus add_dict
        pooled_mode = "unprojected" if unprojected else True
        pooled_dict = self.encode_from_tokens(tokens, return_pooled=pooled_mode, return_dict=True)
        cond = pooled_dict.pop("cond")
        # add/update any keys with the provided add_dict
        pooled_dict.update(add_dict)
        return [[cond, pooled_dict]]

    all_cond_pooled: list[tuple[torch.Tensor, dict[str]]] = []
    scheduled_keyframes = forced_hooks.get_hooks_for_clip_schedule()

    # set up the encoder options before loading the model
    self.cond_stage_model.reset_clip_options()
    if self.layer_idx is not None:
        self.cond_stage_model.set_clip_options({"layer": self.layer_idx})
    if unprojected:
        self.cond_stage_model.set_clip_options({"projected_pooled": False})

    self.load_model()
    forced_hooks.reset()
    self.patcher.patch_hooks(None)
    if show_pbar:
        pbar = ProgressBar(len(scheduled_keyframes))

    for scheduled_opts in scheduled_keyframes:
        t_range = scheduled_opts[0]
        # don't bother encoding any conds outside of start_percent and end_percent bounds
        if "start_percent" in add_dict and t_range[1] < add_dict["start_percent"]:
            continue
        if "end_percent" in add_dict and t_range[0] > add_dict["end_percent"]:
            continue

        # point every hook at the keyframe that belongs to this range
        for hook, keyframe in scheduled_opts[1]:
            hook.hook_keyframe._current_keyframe = keyframe
        # apply appropriate hooks with values that match new hook_keyframe
        self.patcher.patch_hooks(forced_hooks)

        # perform encoding as normal
        encoded = self.cond_stage_model.encode_token_weights(tokens)
        cond, pooled = encoded[:2]
        pooled_dict = {
            "pooled_output": pooled,
            # record which part of the schedule this cond covers
            "clip_start_percent": t_range[0],
            "clip_end_percent": t_range[1],
        }
        # add/update any keys with the provided add_dict
        pooled_dict.update(add_dict)
        # add hooks stored on clip
        self.add_hooks_to_dict(pooled_dict)
        all_cond_pooled.append([cond, pooled_dict])

        if show_pbar:
            pbar.update(1)
        model_management.throw_exception_if_processing_interrupted()

    forced_hooks.reset()
    return all_cond_pooled
|
||||
|
||||
def encode_from_tokens(self, tokens, return_pooled=False, return_dict=False):
|
||||
self.cond_stage_model.reset_clip_options()
|
||||
|
||||
@@ -130,6 +209,7 @@ class CLIP:
|
||||
if len(o) > 2:
|
||||
for k in o[2]:
|
||||
out[k] = o[2][k]
|
||||
self.add_hooks_to_dict(out)
|
||||
return out
|
||||
|
||||
if return_pooled:
|
||||
@@ -170,6 +250,7 @@ class VAE:
|
||||
self.downscale_ratio = 8
|
||||
self.upscale_ratio = 8
|
||||
self.latent_channels = 4
|
||||
self.latent_dim = 2
|
||||
self.output_channels = 3
|
||||
self.process_input = lambda image: image * 2.0 - 1.0
|
||||
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
@@ -239,9 +320,30 @@ class VAE:
|
||||
self.output_channels = 2
|
||||
self.upscale_ratio = 2048
|
||||
self.downscale_ratio = 2048
|
||||
self.latent_dim = 1
|
||||
self.process_output = lambda audio: audio
|
||||
self.process_input = lambda audio: audio
|
||||
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
elif "blocks.2.blocks.3.stack.5.weight" in sd or "decoder.blocks.2.blocks.3.stack.5.weight" in sd or "layers.4.layers.1.attn_block.attn.qkv.weight" in sd or "encoder.layers.4.layers.1.attn_block.attn.qkv.weight" in sd: #genmo mochi vae
|
||||
if "blocks.2.blocks.3.stack.5.weight" in sd:
|
||||
sd = comfy.utils.state_dict_prefix_replace(sd, {"": "decoder."})
|
||||
if "layers.4.layers.1.attn_block.attn.qkv.weight" in sd:
|
||||
sd = comfy.utils.state_dict_prefix_replace(sd, {"": "encoder."})
|
||||
self.first_stage_model = comfy.ldm.genmo.vae.model.VideoVAE()
|
||||
self.latent_channels = 12
|
||||
self.latent_dim = 3
|
||||
self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
|
||||
self.memory_used_encode = lambda shape, dtype: (1.5 * max(shape[2], 7) * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
|
||||
self.upscale_ratio = (lambda a: max(0, a * 6 - 5), 8, 8)
|
||||
self.working_dtypes = [torch.float16, torch.float32]
|
||||
elif "decoder.up_blocks.0.res_blocks.0.conv1.conv.weight" in sd: #lightricks ltxv
|
||||
self.first_stage_model = comfy.ldm.lightricks.vae.causal_video_autoencoder.VideoVAE()
|
||||
self.latent_channels = 128
|
||||
self.latent_dim = 3
|
||||
self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
|
||||
self.memory_used_encode = lambda shape, dtype: (70 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
|
||||
self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 32, 32)
|
||||
self.working_dtypes = [torch.bfloat16, torch.float32]
|
||||
else:
|
||||
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
|
||||
self.first_stage_model = None
|
||||
@@ -297,6 +399,10 @@ class VAE:
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
|
||||
return comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device)
|
||||
|
||||
def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
    """Decode ``samples`` tile by tile over three dimensions.

    Tiles of size ``(tile_t, tile_x, tile_y)`` with the given ``overlap`` are
    handed to ``comfy.utils.tiled_scale_multidim``; each tile is cast to
    ``self.vae_dtype``, moved to ``self.device``, decoded by
    ``self.first_stage_model`` and converted to float. The stitched result is
    passed through ``self.process_output`` before being returned.
    """
    def decode_tile(tile):
        on_device = tile.to(self.vae_dtype).to(self.device)
        return self.first_stage_model.decode(on_device).float()

    stitched = comfy.utils.tiled_scale_multidim(
        samples,
        decode_tile,
        tile=(tile_t, tile_x, tile_y),
        overlap=overlap,
        upscale_amount=self.upscale_ratio,
        out_channels=self.output_channels,
        output_device=self.output_device,
    )
    return self.process_output(stitched)
|
||||
|
||||
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
|
||||
steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
|
||||
steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
|
||||
@@ -315,6 +421,7 @@ class VAE:
|
||||
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device)
|
||||
|
||||
def decode(self, samples_in):
|
||||
pixel_samples = None
|
||||
try:
|
||||
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
|
||||
@@ -322,38 +429,66 @@ class VAE:
|
||||
batch_number = int(free_memory / memory_used)
|
||||
batch_number = max(1, batch_number)
|
||||
|
||||
pixel_samples = torch.empty((samples_in.shape[0], self.output_channels) + tuple(map(lambda a: a * self.upscale_ratio, samples_in.shape[2:])), device=self.output_device)
|
||||
for x in range(0, samples_in.shape[0], batch_number):
|
||||
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
|
||||
pixel_samples[x:x+batch_number] = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
|
||||
out = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
|
||||
if pixel_samples is None:
|
||||
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
|
||||
pixel_samples[x:x+batch_number] = out
|
||||
except model_management.OOM_EXCEPTION as e:
|
||||
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
||||
if len(samples_in.shape) == 3:
|
||||
dims = samples_in.ndim - 2
|
||||
if dims == 1:
|
||||
pixel_samples = self.decode_tiled_1d(samples_in)
|
||||
else:
|
||||
elif dims == 2:
|
||||
pixel_samples = self.decode_tiled_(samples_in)
|
||||
elif dims == 3:
|
||||
tile = 256 // self.spacial_compression_decode()
|
||||
overlap = tile // 4
|
||||
pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
||||
|
||||
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
|
||||
return pixel_samples
|
||||
|
||||
def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16):
|
||||
model_management.load_model_gpu(self.patcher)
|
||||
output = self.decode_tiled_(samples, tile_x, tile_y, overlap)
|
||||
return output.movedim(1,-1)
|
||||
def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None):
    """Decode ``samples`` with the tiled decoder matching their dimensionality.

    The model is loaded first, using ``self.memory_used_decode`` as the memory
    estimate. ``samples.ndim - 2`` then selects the 1D, 2D or 3D tiled
    decoder. Only the tiling arguments that are not ``None`` are forwarded,
    so each decoder keeps its own defaults for the rest.

    Returns the decoder output with dimension 1 moved to the last position.

    Raises:
        ValueError: if ``samples`` does not have 1, 2 or 3 dimensions after
            the first two.
    """
    memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
    model_management.load_models_gpu([self.patcher], memory_required=memory_used)
    dims = samples.ndim - 2

    # forward only the tiling options the caller actually supplied
    args = {}
    if tile_x is not None:
        args["tile_x"] = tile_x
    if tile_y is not None:
        args["tile_y"] = tile_y
    if overlap is not None:
        args["overlap"] = overlap

    if dims == 1:
        # tile_y is dropped for the 1D decoder. With the default tile_y=None
        # the key was never added, so a default is needed to avoid KeyError.
        args.pop("tile_y", None)
        output = self.decode_tiled_1d(samples, **args)
    elif dims == 2:
        output = self.decode_tiled_(samples, **args)
    elif dims == 3:
        output = self.decode_tiled_3d(samples, **args)
    else:
        # previously fell through to an UnboundLocalError on `output`
        raise ValueError("decode_tiled: unsupported latent with {} dimensions after batch and channel".format(dims))
    return output.movedim(1, -1)
|
||||
|
||||
def encode(self, pixel_samples):
|
||||
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
|
||||
pixel_samples = pixel_samples.movedim(-1,1)
|
||||
pixel_samples = pixel_samples.movedim(-1, 1)
|
||||
if self.latent_dim == 3:
|
||||
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
|
||||
try:
|
||||
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
|
||||
free_memory = model_management.get_free_memory(self.device)
|
||||
batch_number = int(free_memory / memory_used)
|
||||
batch_number = int(free_memory / max(1, memory_used))
|
||||
batch_number = max(1, batch_number)
|
||||
samples = torch.empty((pixel_samples.shape[0], self.latent_channels) + tuple(map(lambda a: a // self.downscale_ratio, pixel_samples.shape[2:])), device=self.output_device)
|
||||
samples = None
|
||||
for x in range(0, pixel_samples.shape[0], batch_number):
|
||||
pixels_in = self.process_input(pixel_samples[x:x+batch_number]).to(self.vae_dtype).to(self.device)
|
||||
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
|
||||
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype).to(self.device)
|
||||
out = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
|
||||
if samples is None:
|
||||
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
|
||||
samples[x:x + batch_number] = out
|
||||
|
||||
except model_management.OOM_EXCEPTION as e:
|
||||
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
|
||||
@@ -374,6 +509,12 @@ class VAE:
|
||||
def get_sd(self):
|
||||
return self.first_stage_model.state_dict()
|
||||
|
||||
def spacial_compression_decode(self):
    """Return the upscale factor of the last dimension used when decoding.

    ``self.upscale_ratio`` is either a single number or a per-dimension
    sequence. For a sequence the last entry is returned; for anything that
    cannot be indexed that way, the value itself is returned unchanged.
    """
    try:
        return self.upscale_ratio[-1]
    except (TypeError, IndexError):
        # TypeError: plain number, not subscriptable.
        # IndexError: empty or 0-d container.
        # Narrowed from a bare except so interrupts and exits still propagate.
        return self.upscale_ratio
|
||||
|
||||
class StyleModel:
|
||||
def __init__(self, model, device="cpu"):
|
||||
self.model = model
|
||||
@@ -387,6 +528,8 @@ def load_style_model(ckpt_path):
|
||||
keys = model_data.keys()
|
||||
if "style_embedding" in keys:
|
||||
model = comfy.t2i_adapter.adapter.StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8)
|
||||
elif "redux_down.weight" in keys:
|
||||
model = comfy.ldm.flux.redux.ReduxImageEncoder()
|
||||
else:
|
||||
raise Exception("invalid style model {}".format(ckpt_path))
|
||||
model.load_state_dict(model_data)
|
||||
@@ -399,6 +542,8 @@ class CLIPType(Enum):
|
||||
STABLE_AUDIO = 4
|
||||
HUNYUAN_DIT = 5
|
||||
FLUX = 6
|
||||
MOCHI = 7
|
||||
LTXV = 8
|
||||
|
||||
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
||||
clip_data = []
|
||||
@@ -406,8 +551,46 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
|
||||
clip_data.append(comfy.utils.load_torch_file(p, safe_load=True))
|
||||
return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)
|
||||
|
||||
|
||||
class TEModel(Enum):
    """Text-encoder architectures that can be told apart from a state dict."""

    # CLIP text encoders
    CLIP_L = 1
    CLIP_H = 2
    CLIP_G = 3

    # T5 encoders
    T5_XXL = 4
    T5_XL = 5
    T5_BASE = 6
|
||||
|
||||
def detect_te_model(sd):
    """Guess which text encoder a state dict belongs to from its keys.

    CLIP variants are recognised by the presence of encoder layer 30, 22 or 0
    (checked in that order, so the deepest match wins). T5 variants are
    recognised by the block-23 ``wi_1`` weight, whose last dimension selects
    XXL (4096) or XL (2048), with a fallback to T5 base when block 0 is
    present.

    Returns a ``TEModel`` member, or ``None`` when nothing matches.
    """
    clip_g_key = "text_model.encoder.layers.30.mlp.fc1.weight"
    clip_h_key = "text_model.encoder.layers.22.mlp.fc1.weight"
    clip_l_key = "text_model.encoder.layers.0.mlp.fc1.weight"
    t5_deep_key = "encoder.block.23.layer.1.DenseReluDense.wi_1.weight"
    t5_base_key = "encoder.block.0.layer.0.SelfAttention.k.weight"

    if clip_g_key in sd:
        return TEModel.CLIP_G
    if clip_h_key in sd:
        return TEModel.CLIP_H
    if clip_l_key in sd:
        return TEModel.CLIP_L

    if t5_deep_key in sd:
        last_dim = sd[t5_deep_key].shape[-1]
        if last_dim == 4096:
            return TEModel.T5_XXL
        if last_dim == 2048:
            return TEModel.T5_XL
        # any other size falls through to the T5 base check below

    if t5_base_key in sd:
        return TEModel.T5_BASE
    return None
|
||||
|
||||
|
||||
def t5xxl_detect(clip_data):
    """Run T5-XXL detection on the first state dict that looks like a T5.

    The first entry of ``clip_data`` containing the block-23 ``wi_1`` weight
    is passed to ``comfy.text_encoders.sd3_clip.t5_xxl_detect`` and that
    result is returned. An empty dict is returned when no entry matches.
    """
    weight_name = "encoder.block.23.layer.1.DenseReluDense.wi_1.weight"

    t5_sd = next((sd for sd in clip_data if weight_name in sd), None)
    if t5_sd is None:
        return {}
    return comfy.text_encoders.sd3_clip.t5_xxl_detect(t5_sd)
|
||||
|
||||
|
||||
def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
||||
clip_data = state_dicts
|
||||
|
||||
class EmptyClass:
|
||||
pass
|
||||
|
||||
@@ -421,64 +604,68 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
clip_target = EmptyClass()
|
||||
clip_target.params = {}
|
||||
if len(clip_data) == 1:
|
||||
if "text_model.encoder.layers.30.mlp.fc1.weight" in clip_data[0]:
|
||||
te_model = detect_te_model(clip_data[0])
|
||||
if te_model == TEModel.CLIP_G:
|
||||
if clip_type == CLIPType.STABLE_CASCADE:
|
||||
clip_target.clip = sdxl_clip.StableCascadeClipModel
|
||||
clip_target.tokenizer = sdxl_clip.StableCascadeTokenizer
|
||||
elif clip_type == CLIPType.SD3:
|
||||
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=True, t5=False)
|
||||
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||
else:
|
||||
clip_target.clip = sdxl_clip.SDXLRefinerClipModel
|
||||
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
||||
elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]:
|
||||
elif te_model == TEModel.CLIP_H:
|
||||
clip_target.clip = comfy.text_encoders.sd2_clip.SD2ClipModel
|
||||
clip_target.tokenizer = comfy.text_encoders.sd2_clip.SD2Tokenizer
|
||||
elif "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in clip_data[0]:
|
||||
weight = clip_data[0]["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
|
||||
dtype_t5 = weight.dtype
|
||||
if weight.shape[-1] == 4096:
|
||||
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=dtype_t5)
|
||||
elif te_model == TEModel.T5_XXL:
|
||||
if clip_type == CLIPType.SD3:
|
||||
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, **t5xxl_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||
elif weight.shape[-1] == 2048:
|
||||
clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model
|
||||
clip_target.tokenizer = comfy.text_encoders.aura_t5.AuraT5Tokenizer
|
||||
elif "encoder.block.0.layer.0.SelfAttention.k.weight" in clip_data[0]:
|
||||
elif clip_type == CLIPType.LTXV:
|
||||
clip_target.clip = comfy.text_encoders.lt.ltxv_te(**t5xxl_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.lt.LTXVT5Tokenizer
|
||||
else: #CLIPType.MOCHI
|
||||
clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer
|
||||
elif te_model == TEModel.T5_XL:
|
||||
clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model
|
||||
clip_target.tokenizer = comfy.text_encoders.aura_t5.AuraT5Tokenizer
|
||||
elif te_model == TEModel.T5_BASE:
|
||||
clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
|
||||
clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
|
||||
else:
|
||||
w = clip_data[0].get("text_model.embeddings.position_embedding.weight", None)
|
||||
if w is not None and w.shape[0] == 248:
|
||||
clip_target.clip = comfy.text_encoders.long_clipl.LongClipModel
|
||||
clip_target.tokenizer = comfy.text_encoders.long_clipl.LongClipTokenizer
|
||||
if clip_type == CLIPType.SD3:
|
||||
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=False, t5=False)
|
||||
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||
else:
|
||||
clip_target.clip = sd1_clip.SD1ClipModel
|
||||
clip_target.tokenizer = sd1_clip.SD1Tokenizer
|
||||
elif len(clip_data) == 2:
|
||||
if clip_type == CLIPType.SD3:
|
||||
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=True, t5=False)
|
||||
te_models = [detect_te_model(clip_data[0]), detect_te_model(clip_data[1])]
|
||||
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=TEModel.CLIP_L in te_models, clip_g=TEModel.CLIP_G in te_models, t5=TEModel.T5_XXL in te_models, **t5xxl_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||
elif clip_type == CLIPType.HUNYUAN_DIT:
|
||||
clip_target.clip = comfy.text_encoders.hydit.HyditModel
|
||||
clip_target.tokenizer = comfy.text_encoders.hydit.HyditTokenizer
|
||||
elif clip_type == CLIPType.FLUX:
|
||||
weight_name = "encoder.block.23.layer.1.DenseReluDense.wi_1.weight"
|
||||
weight = clip_data[0].get(weight_name, clip_data[1].get(weight_name, None))
|
||||
dtype_t5 = None
|
||||
if weight is not None:
|
||||
dtype_t5 = weight.dtype
|
||||
|
||||
clip_target.clip = comfy.text_encoders.flux.flux_clip(dtype_t5=dtype_t5)
|
||||
clip_target.clip = comfy.text_encoders.flux.flux_clip(**t5xxl_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.flux.FluxTokenizer
|
||||
else:
|
||||
clip_target.clip = sdxl_clip.SDXLClipModel
|
||||
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
||||
elif len(clip_data) == 3:
|
||||
clip_target.clip = comfy.text_encoders.sd3_clip.SD3ClipModel
|
||||
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(**t5xxl_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||
|
||||
parameters = 0
|
||||
tokenizer_data = {}
|
||||
for c in clip_data:
|
||||
parameters += comfy.utils.calculate_parameters(c)
|
||||
tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)
|
||||
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, model_options=model_options)
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, model_options=model_options)
|
||||
for c in clip_data:
|
||||
m, u = clip.load_sd(c)
|
||||
if len(m) > 0:
|
||||
@@ -544,11 +731,11 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
||||
return None
|
||||
|
||||
unet_weight_dtype = list(model_config.supported_inference_dtypes)
|
||||
if weight_dtype is not None:
|
||||
if weight_dtype is not None and model_config.scaled_fp8 is None:
|
||||
unet_weight_dtype.append(weight_dtype)
|
||||
|
||||
model_config.custom_operations = model_options.get("custom_operations", None)
|
||||
unet_dtype = model_options.get("weight_dtype", None)
|
||||
unet_dtype = model_options.get("dtype", model_options.get("weight_dtype", None))
|
||||
|
||||
if unet_dtype is None:
|
||||
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
|
||||
@@ -562,7 +749,6 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
||||
|
||||
if output_model:
|
||||
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
|
||||
offload_device = model_management.unet_offload_device()
|
||||
model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
|
||||
model.load_model_weights(sd, diffusion_model_prefix)
|
||||
|
||||
@@ -614,6 +800,8 @@ def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffuse
|
||||
sd = temp_sd
|
||||
|
||||
parameters = comfy.utils.calculate_parameters(sd)
|
||||
weight_dtype = comfy.utils.weight_dtype(sd)
|
||||
|
||||
load_device = model_management.get_torch_device()
|
||||
model_config = model_detection.model_config_from_unet(sd, "")
|
||||
|
||||
@@ -640,14 +828,21 @@ def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffuse
|
||||
logging.warning("{} {}".format(diffusers_keys[k], k))
|
||||
|
||||
offload_device = model_management.unet_offload_device()
|
||||
unet_weight_dtype = list(model_config.supported_inference_dtypes)
|
||||
if weight_dtype is not None and model_config.scaled_fp8 is None:
|
||||
unet_weight_dtype.append(weight_dtype)
|
||||
|
||||
if dtype is None:
|
||||
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
|
||||
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
|
||||
else:
|
||||
unet_dtype = dtype
|
||||
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
||||
model_config.custom_operations = model_options.get("custom_operations", None)
|
||||
model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations)
|
||||
if model_options.get("fp8_optimizations", False):
|
||||
model_config.optimizations["fp8"] = True
|
||||
|
||||
model = model_config.get_model(new_sd, "")
|
||||
model = model.to(offload_device)
|
||||
model.load_model_weights(new_sd, "")
|
||||
|
||||
@@ -80,7 +80,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
"pooled",
|
||||
"hidden"
|
||||
]
|
||||
def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
|
||||
def __init__(self, device="cpu", max_length=77,
|
||||
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,
|
||||
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
|
||||
return_projected_pooled=True, return_attention_masks=False, model_options={}): # clip-vit-base-patch32
|
||||
@@ -94,11 +94,20 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
config = json.load(f)
|
||||
|
||||
operations = model_options.get("custom_operations", None)
|
||||
scaled_fp8 = None
|
||||
|
||||
if operations is None:
|
||||
operations = comfy.ops.manual_cast
|
||||
scaled_fp8 = model_options.get("scaled_fp8", None)
|
||||
if scaled_fp8 is not None:
|
||||
operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8)
|
||||
else:
|
||||
operations = comfy.ops.manual_cast
|
||||
|
||||
self.operations = operations
|
||||
self.transformer = model_class(config, dtype, device, self.operations)
|
||||
if scaled_fp8 is not None:
|
||||
self.transformer.scaled_fp8 = torch.nn.Parameter(torch.tensor([], dtype=scaled_fp8))
|
||||
|
||||
self.num_layers = self.transformer.num_layers
|
||||
|
||||
self.max_length = max_length
|
||||
@@ -542,6 +551,7 @@ class SD1Tokenizer:
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer):
|
||||
self.clip_name = clip_name
|
||||
self.clip = "clip_{}".format(self.clip_name)
|
||||
tokenizer = tokenizer_data.get("{}_tokenizer_class".format(self.clip), tokenizer)
|
||||
setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data))
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
||||
@@ -570,6 +580,7 @@ class SD1ClipModel(torch.nn.Module):
|
||||
self.clip_name = clip_name
|
||||
self.clip = "clip_{}".format(self.clip_name)
|
||||
|
||||
clip_model = model_options.get("{}_class".format(self.clip), clip_model)
|
||||
setattr(self, self.clip, clip_model(device=device, dtype=dtype, model_options=model_options, **kwargs))
|
||||
|
||||
self.dtypes = set()
|
||||
|
||||
@@ -22,7 +22,8 @@ class SDXLClipGTokenizer(sd1_clip.SDTokenizer):
|
||||
|
||||
class SDXLTokenizer:
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
|
||||
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
|
||||
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
|
||||
self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory)
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
||||
@@ -40,7 +41,8 @@ class SDXLTokenizer:
|
||||
class SDXLClipModel(torch.nn.Module):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__()
|
||||
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options)
|
||||
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
|
||||
self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options)
|
||||
self.clip_g = SDXLClipG(device=device, dtype=dtype, model_options=model_options)
|
||||
self.dtypes = set([dtype])
|
||||
|
||||
@@ -57,7 +59,8 @@ class SDXLClipModel(torch.nn.Module):
|
||||
token_weight_pairs_l = token_weight_pairs["l"]
|
||||
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
|
||||
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
|
||||
return torch.cat([l_out, g_out], dim=-1), g_pooled
|
||||
cut_to = min(l_out.shape[1], g_out.shape[1])
|
||||
return torch.cat([l_out[:,:cut_to], g_out[:,:cut_to]], dim=-1), g_pooled
|
||||
|
||||
def load_sd(self, sd):
|
||||
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
|
||||
|
||||
@@ -10,6 +10,8 @@ import comfy.text_encoders.sa_t5
|
||||
import comfy.text_encoders.aura_t5
|
||||
import comfy.text_encoders.hydit
|
||||
import comfy.text_encoders.flux
|
||||
import comfy.text_encoders.genmo
|
||||
import comfy.text_encoders.lt
|
||||
|
||||
from . import supported_models_base
|
||||
from . import latent_formats
|
||||
@@ -196,6 +198,8 @@ class SDXL(supported_models_base.BASE):
|
||||
self.sampling_settings["sigma_min"] = float(state_dict["edm_vpred.sigma_min"].item())
|
||||
return model_base.ModelType.V_PREDICTION_EDM
|
||||
elif "v_pred" in state_dict:
|
||||
if "ztsnr" in state_dict: #Some zsnr anime checkpoints
|
||||
self.sampling_settings["zsnr"] = True
|
||||
return model_base.ModelType.V_PREDICTION
|
||||
else:
|
||||
return model_base.ModelType.EPS
|
||||
@@ -529,12 +533,11 @@ class SD3(supported_models_base.BASE):
|
||||
clip_l = True
|
||||
if "{}clip_g.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict:
|
||||
clip_g = True
|
||||
t5_key = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref)
|
||||
if t5_key in state_dict:
|
||||
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
||||
if "dtype_t5" in t5_detect:
|
||||
t5 = True
|
||||
dtype_t5 = state_dict[t5_key].dtype
|
||||
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.sd3_clip.SD3Tokenizer, comfy.text_encoders.sd3_clip.sd3_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.sd3_clip.SD3Tokenizer, comfy.text_encoders.sd3_clip.sd3_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, **t5_detect))
|
||||
|
||||
class StableAudio(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
@@ -653,11 +656,17 @@ class Flux(supported_models_base.BASE):
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
t5_key = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref)
|
||||
dtype_t5 = None
|
||||
if t5_key in state_dict:
|
||||
dtype_t5 = state_dict[t5_key].dtype
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(dtype_t5=dtype_t5))
|
||||
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(**t5_detect))
|
||||
|
||||
class FluxInpaint(Flux):
    """Flux model configuration matched by a 96-channel, guidance-embedded UNet config."""

    # keys and values used to match this configuration
    unet_config = {
        "image_model": "flux",
        "guidance_embed": True,
        "in_channels": 96,
    }

    # dtypes this configuration lists as supported for inference
    supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
class FluxSchnell(Flux):
|
||||
unet_config = {
|
||||
@@ -674,7 +683,63 @@ class FluxSchnell(Flux):
|
||||
out = model_base.Flux(self, model_type=model_base.ModelType.FLOW, device=device)
|
||||
return out
|
||||
|
||||
class GenmoMochi(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "mochi_preview",
|
||||
}
|
||||
|
||||
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, Flux, FluxSchnell]
|
||||
sampling_settings = {
|
||||
"multiplier": 1.0,
|
||||
"shift": 6.0,
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.Mochi
|
||||
|
||||
memory_usage_factor = 2.0 #TODO
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.GenmoMochi(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.genmo.MochiT5Tokenizer, comfy.text_encoders.genmo.mochi_te(**t5_detect))
|
||||
|
||||
class LTXV(supported_models_base.BASE):
    """Supported-model entry for the "ltxv" image model.

    Declares the unet config, sampling shift, latent format, memory factor,
    dtypes and state-dict key prefixes, and builds the model object and the
    text-encoder target.
    """

    unet_config = {
        "image_model": "ltxv",
    }

    sampling_settings = {
        "shift": 2.37,
    }

    unet_extra_config = {}
    latent_format = latent_formats.LTXV

    memory_usage_factor = 2.7

    supported_inference_dtypes = [torch.bfloat16, torch.float32]

    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoders."]

    def get_model(self, state_dict, prefix="", device=None):
        """Return a `model_base.LTXV` built from this config; `state_dict` and `prefix` are unused here."""
        out = model_base.LTXV(self, device=device)
        return out

    def clip_target(self, state_dict={}):
        """Return the ClipTarget pairing the LTXV T5 tokenizer with a text encoder class.

        The encoder class is configured from whatever `t5_xxl_detect` finds
        under the first text-encoder prefix of `state_dict`.
        """
        pref = self.text_encoder_key_prefix[0]
        t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
        return supported_models_base.ClipTarget(comfy.text_encoders.lt.LTXVT5Tokenizer, comfy.text_encoders.lt.ltxv_te(**t5_detect))
|
||||
|
||||
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
||||
@@ -49,6 +49,8 @@ class BASE:
|
||||
|
||||
manual_cast_dtype = None
|
||||
custom_operations = None
|
||||
scaled_fp8 = None
|
||||
optimizations = {"fp8": False}
|
||||
|
||||
@classmethod
|
||||
def matches(s, unet_config, state_dict=None):
|
||||
@@ -71,6 +73,7 @@ class BASE:
|
||||
self.unet_config = unet_config.copy()
|
||||
self.sampling_settings = self.sampling_settings.copy()
|
||||
self.latent_format = self.latent_format()
|
||||
self.optimizations = self.optimizations.copy()
|
||||
for x in self.unet_extra_config:
|
||||
self.unet_config[x] = self.unet_extra_config[x]
|
||||
|
||||
|
||||
@@ -1,24 +1,21 @@
|
||||
from comfy import sd1_clip
|
||||
import comfy.text_encoders.t5
|
||||
import comfy.text_encoders.sd3_clip
|
||||
import comfy.model_management
|
||||
from transformers import T5TokenizerFast
|
||||
import torch
|
||||
import os
|
||||
|
||||
class T5XXLModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, model_options=model_options)
|
||||
|
||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
|
||||
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
|
||||
|
||||
|
||||
class FluxTokenizer:
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
|
||||
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
|
||||
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
|
||||
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
||||
@@ -38,8 +35,9 @@ class FluxClipModel(torch.nn.Module):
|
||||
def __init__(self, dtype_t5=None, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__()
|
||||
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
|
||||
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
|
||||
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
|
||||
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
|
||||
self.clip_l = clip_l_class(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
|
||||
self.t5xxl = comfy.text_encoders.sd3_clip.T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
|
||||
self.dtypes = set([dtype, dtype_t5])
|
||||
|
||||
def set_clip_options(self, options):
|
||||
@@ -64,8 +62,11 @@ class FluxClipModel(torch.nn.Module):
|
||||
else:
|
||||
return self.t5xxl.load_sd(sd)
|
||||
|
||||
def flux_clip(dtype_t5=None):
|
||||
def flux_clip(dtype_t5=None, t5xxl_scaled_fp8=None):
|
||||
class FluxClipModel_(FluxClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
|
||||
model_options = model_options.copy()
|
||||
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
|
||||
super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
|
||||
return FluxClipModel_
|
||||
|
||||
38
comfy/text_encoders/genmo.py
Normal file
38
comfy/text_encoders/genmo.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from comfy import sd1_clip
|
||||
import comfy.text_encoders.sd3_clip
|
||||
import os
|
||||
from transformers import T5TokenizerFast
|
||||
|
||||
|
||||
class T5XXLModel(comfy.text_encoders.sd3_clip.T5XXLModel):
    """The sd3_clip T5XXLModel with `attention_mask` forced to True."""

    def __init__(self, **kwargs):
        # Any caller-supplied attention_mask value is overwritten.
        kwargs["attention_mask"] = True
        super().__init__(**kwargs)
|
||||
|
||||
|
||||
class MochiT5XXL(sd1_clip.SD1ClipModel):
    """SD1ClipModel wrapper registered under the name "t5xxl" that uses this module's T5XXLModel."""

    def __init__(self, device="cpu", dtype=None, model_options={}):
        super().__init__(device=device, dtype=dtype, name="t5xxl", clip_model=T5XXLModel, model_options=model_options)
|
||||
|
||||
|
||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
    """SDTokenizer configured for T5-XXL with a minimum length of 256 tokens."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        # The tokenizer files live in the "t5_tokenizer" directory beside this module.
        # `tokenizer_data` is accepted for interface compatibility but not used here.
        tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
|
||||
|
||||
|
||||
class MochiT5Tokenizer(sd1_clip.SD1Tokenizer):
    """SD1Tokenizer wrapper that exposes this module's T5XXLTokenizer under the clip name "t5xxl"."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
|
||||
|
||||
|
||||
def mochi_te(dtype_t5=None, t5xxl_scaled_fp8=None):
    """Build a MochiT5XXL subclass with the given T5 defaults baked in.

    Args:
        dtype_t5: dtype used when the constructor is called with `dtype=None`.
        t5xxl_scaled_fp8: value stored under "t5xxl_scaled_fp8" in
            `model_options` when the caller has not already set that key.

    Returns:
        The class (not an instance); construct it with
        `(device, dtype, model_options)`.
    """
    class MochiTEModel_(MochiT5XXL):
        def __init__(self, device="cpu", dtype=None, model_options={}):
            if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
                # Copy before writing so the caller's dict (and the shared
                # default) is never mutated.
                model_options = model_options.copy()
                model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
            # NOTE(review): nesting of this check was reconstructed from a
            # rendering that lost indentation -- confirm it is a sibling of
            # the `if` above rather than nested inside it.
            if dtype is None:
                dtype = dtype_t5
            super().__init__(device=device, dtype=dtype, model_options=model_options)
    return MochiTEModel_
|
||||
@@ -6,9 +6,9 @@ class LongClipTokenizer_(sd1_clip.SDTokenizer):
|
||||
super().__init__(max_length=248, embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
|
||||
class LongClipModel_(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
def __init__(self, *args, **kwargs):
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "long_clipl.json")
|
||||
super().__init__(device=device, textmodel_json_config=textmodel_json_config, return_projected_pooled=False, dtype=dtype, model_options=model_options)
|
||||
super().__init__(*args, textmodel_json_config=textmodel_json_config, **kwargs)
|
||||
|
||||
class LongClipTokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
@@ -17,3 +17,14 @@ class LongClipTokenizer(sd1_clip.SD1Tokenizer):
|
||||
class LongClipModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options, clip_model=LongClipModel_, **kwargs)
|
||||
|
||||
def model_options_long_clip(sd, tokenizer_data, model_options):
    """Switch the clip_l tokenizer/model classes to the long-clip variants when `sd` calls for it.

    The position-embedding weight is looked up first under the "clip_l."
    prefix and then without it. If it exists and its first dimension is 248,
    copies of both dicts are returned with "clip_l_tokenizer_class" and
    "clip_l_class" set; otherwise the inputs are returned unchanged.

    Returns:
        Tuple `(tokenizer_data, model_options)`.
    """
    w = sd.get("clip_l.text_model.embeddings.position_embedding.weight", None)
    if w is None:
        w = sd.get("text_model.embeddings.position_embedding.weight", None)

    if w is not None and w.shape[0] == 248:
        # Work on copies so the caller's dicts are left untouched.
        tokenizer_data = tokenizer_data.copy()
        model_options = model_options.copy()
        tokenizer_data["clip_l_tokenizer_class"] = LongClipTokenizer_
        model_options["clip_l_class"] = LongClipModel_

    return tokenizer_data, model_options
|
||||
|
||||
18
comfy/text_encoders/lt.py
Normal file
18
comfy/text_encoders/lt.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from comfy import sd1_clip
|
||||
import os
|
||||
from transformers import T5TokenizerFast
|
||||
import comfy.text_encoders.genmo
|
||||
|
||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
    """SDTokenizer configured for T5-XXL with a minimum length of 128 tokens."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        # The tokenizer files live in the "t5_tokenizer" directory beside this module.
        # `tokenizer_data` is accepted for interface compatibility but not used here.
        tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128) #pad to 128?
|
||||
|
||||
|
||||
class LTXVT5Tokenizer(sd1_clip.SD1Tokenizer):
    """SD1Tokenizer wrapper that exposes this module's T5XXLTokenizer under the clip name "t5xxl"."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
|
||||
|
||||
|
||||
def ltxv_te(*args, **kwargs):
    """Return the text-encoder class for LTXV by forwarding all arguments to `genmo.mochi_te`."""
    return comfy.text_encoders.genmo.mochi_te(*args, **kwargs)
|
||||
@@ -8,9 +8,27 @@ import comfy.model_management
|
||||
import logging
|
||||
|
||||
class T5XXLModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=False, model_options={}):
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, model_options=model_options)
|
||||
t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None)
|
||||
if t5xxl_scaled_fp8 is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["scaled_fp8"] = t5xxl_scaled_fp8
|
||||
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
|
||||
def t5_xxl_detect(state_dict, prefix=""):
    """Collect T5-XXL dtype hints from a state dict.

    Args:
        state_dict: mapping from parameter name to an object with a `.dtype`.
        prefix: string prepended to the key names that are probed.

    Returns:
        A dict that contains "dtype_t5" when
        `<prefix>encoder.final_layer_norm.weight` is present, and
        "t5xxl_scaled_fp8" when `<prefix>scaled_fp8` is present. Keys that
        are not found are simply omitted, so the result can be splatted into
        keyword arguments.
    """
    # (output name, key suffix) pairs, probed in this order.
    probes = (
        ("dtype_t5", "encoder.final_layer_norm.weight"),
        ("t5xxl_scaled_fp8", "scaled_fp8"),
    )

    out = {}
    for out_name, suffix in probes:
        key = "{}{}".format(prefix, suffix)
        if key in state_dict:
            out[out_name] = state_dict[key].dtype

    return out
|
||||
|
||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
@@ -20,7 +38,8 @@ class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||
|
||||
class SD3Tokenizer:
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
|
||||
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
|
||||
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
|
||||
self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory)
|
||||
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
|
||||
|
||||
@@ -38,11 +57,12 @@ class SD3Tokenizer:
|
||||
return {}
|
||||
|
||||
class SD3ClipModel(torch.nn.Module):
|
||||
def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu", dtype=None, model_options={}):
|
||||
def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5_attention_mask=False, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__()
|
||||
self.dtypes = set()
|
||||
if clip_l:
|
||||
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options)
|
||||
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
|
||||
self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options)
|
||||
self.dtypes.add(dtype)
|
||||
else:
|
||||
self.clip_l = None
|
||||
@@ -55,7 +75,8 @@ class SD3ClipModel(torch.nn.Module):
|
||||
|
||||
if t5:
|
||||
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
|
||||
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
|
||||
self.t5_attention_mask = t5_attention_mask
|
||||
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options, attention_mask=self.t5_attention_mask)
|
||||
self.dtypes.add(dtype_t5)
|
||||
else:
|
||||
self.t5xxl = None
|
||||
@@ -85,6 +106,7 @@ class SD3ClipModel(torch.nn.Module):
|
||||
lg_out = None
|
||||
pooled = None
|
||||
out = None
|
||||
extra = {}
|
||||
|
||||
if len(token_weight_pairs_g) > 0 or len(token_weight_pairs_l) > 0:
|
||||
if self.clip_l is not None:
|
||||
@@ -95,7 +117,8 @@ class SD3ClipModel(torch.nn.Module):
|
||||
if self.clip_g is not None:
|
||||
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
|
||||
if lg_out is not None:
|
||||
lg_out = torch.cat([lg_out, g_out], dim=-1)
|
||||
cut_to = min(lg_out.shape[1], g_out.shape[1])
|
||||
lg_out = torch.cat([lg_out[:,:cut_to], g_out[:,:cut_to]], dim=-1)
|
||||
else:
|
||||
lg_out = torch.nn.functional.pad(g_out, (768, 0))
|
||||
else:
|
||||
@@ -108,7 +131,11 @@ class SD3ClipModel(torch.nn.Module):
|
||||
pooled = torch.cat((l_pooled, g_pooled), dim=-1)
|
||||
|
||||
if self.t5xxl is not None:
|
||||
t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pairs_t5)
|
||||
t5_output = self.t5xxl.encode_token_weights(token_weight_pairs_t5)
|
||||
t5_out, t5_pooled = t5_output[:2]
|
||||
if self.t5_attention_mask:
|
||||
extra["attention_mask"] = t5_output[2]["attention_mask"]
|
||||
|
||||
if lg_out is not None:
|
||||
out = torch.cat([lg_out, t5_out], dim=-2)
|
||||
else:
|
||||
@@ -120,7 +147,7 @@ class SD3ClipModel(torch.nn.Module):
|
||||
if pooled is None:
|
||||
pooled = torch.zeros((1, 768 + 1280), device=comfy.model_management.intermediate_device())
|
||||
|
||||
return out, pooled
|
||||
return out, pooled, extra
|
||||
|
||||
def load_sd(self, sd):
|
||||
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
|
||||
@@ -130,8 +157,11 @@ class SD3ClipModel(torch.nn.Module):
|
||||
else:
|
||||
return self.t5xxl.load_sd(sd)
|
||||
|
||||
def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None):
|
||||
def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5xxl_scaled_fp8=None, t5_attention_mask=False):
|
||||
class SD3ClipModel_(SD3ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
|
||||
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
|
||||
model_options = model_options.copy()
|
||||
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
|
||||
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, t5_attention_mask=t5_attention_mask, device=device, dtype=dtype, model_options=model_options)
|
||||
return SD3ClipModel_
|
||||
|
||||
127
comfy/utils.py
127
comfy/utils.py
@@ -68,7 +68,7 @@ def weight_dtype(sd, prefix=""):
|
||||
for k in sd.keys():
|
||||
if k.startswith(prefix):
|
||||
w = sd[k]
|
||||
dtypes[w.dtype] = dtypes.get(w.dtype, 0) + 1
|
||||
dtypes[w.dtype] = dtypes.get(w.dtype, 0) + w.numel()
|
||||
|
||||
if len(dtypes) == 0:
|
||||
return None
|
||||
@@ -316,10 +316,18 @@ MMDIT_MAP_BLOCK = {
|
||||
("context_block.mlp.fc1.weight", "ff_context.net.0.proj.weight"),
|
||||
("context_block.mlp.fc2.bias", "ff_context.net.2.bias"),
|
||||
("context_block.mlp.fc2.weight", "ff_context.net.2.weight"),
|
||||
("context_block.attn.ln_q.weight", "attn.norm_added_q.weight"),
|
||||
("context_block.attn.ln_k.weight", "attn.norm_added_k.weight"),
|
||||
("x_block.adaLN_modulation.1.bias", "norm1.linear.bias"),
|
||||
("x_block.adaLN_modulation.1.weight", "norm1.linear.weight"),
|
||||
("x_block.attn.proj.bias", "attn.to_out.0.bias"),
|
||||
("x_block.attn.proj.weight", "attn.to_out.0.weight"),
|
||||
("x_block.attn.ln_q.weight", "attn.norm_q.weight"),
|
||||
("x_block.attn.ln_k.weight", "attn.norm_k.weight"),
|
||||
("x_block.attn2.proj.bias", "attn2.to_out.0.bias"),
|
||||
("x_block.attn2.proj.weight", "attn2.to_out.0.weight"),
|
||||
("x_block.attn2.ln_q.weight", "attn2.norm_q.weight"),
|
||||
("x_block.attn2.ln_k.weight", "attn2.norm_k.weight"),
|
||||
("x_block.mlp.fc1.bias", "ff.net.0.proj.bias"),
|
||||
("x_block.mlp.fc1.weight", "ff.net.0.proj.weight"),
|
||||
("x_block.mlp.fc2.bias", "ff.net.2.bias"),
|
||||
@@ -349,6 +357,12 @@ def mmdit_to_diffusers(mmdit_config, output_prefix=""):
|
||||
key_map["{}add_k_proj.{}".format(k, end)] = (qkv, (0, offset, offset))
|
||||
key_map["{}add_v_proj.{}".format(k, end)] = (qkv, (0, offset * 2, offset))
|
||||
|
||||
k = "{}.attn2.".format(block_from)
|
||||
qkv = "{}.x_block.attn2.qkv.{}".format(block_to, end)
|
||||
key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, offset))
|
||||
key_map["{}to_k.{}".format(k, end)] = (qkv, (0, offset, offset))
|
||||
key_map["{}to_v.{}".format(k, end)] = (qkv, (0, offset * 2, offset))
|
||||
|
||||
for k in MMDIT_MAP_BLOCK:
|
||||
key_map["{}.{}".format(block_from, k[1])] = "{}.{}".format(block_to, k[0])
|
||||
|
||||
@@ -528,6 +542,8 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
|
||||
("guidance_in.out_layer.weight", "time_text_embed.guidance_embedder.linear_2.weight"),
|
||||
("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift),
|
||||
("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift),
|
||||
("pos_embed_input.bias", "controlnet_x_embedder.bias"),
|
||||
("pos_embed_input.weight", "controlnet_x_embedder.weight"),
|
||||
}
|
||||
|
||||
for k in MAP_BASIC:
|
||||
@@ -688,9 +704,14 @@ def lanczos(samples, width, height):
|
||||
return result.to(samples.device, samples.dtype)
|
||||
|
||||
def common_upscale(samples, width, height, upscale_method, crop):
|
||||
orig_shape = tuple(samples.shape)
|
||||
if len(orig_shape) > 4:
|
||||
samples = samples.reshape(samples.shape[0], samples.shape[1], -1, samples.shape[-2], samples.shape[-1])
|
||||
samples = samples.movedim(2, 1)
|
||||
samples = samples.reshape(-1, orig_shape[1], orig_shape[-2], orig_shape[-1])
|
||||
if crop == "center":
|
||||
old_width = samples.shape[3]
|
||||
old_height = samples.shape[2]
|
||||
old_width = samples.shape[-1]
|
||||
old_height = samples.shape[-2]
|
||||
old_aspect = old_width / old_height
|
||||
new_aspect = width / height
|
||||
x = 0
|
||||
@@ -699,48 +720,87 @@ def common_upscale(samples, width, height, upscale_method, crop):
|
||||
x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
|
||||
elif old_aspect < new_aspect:
|
||||
y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
|
||||
s = samples[:,:,y:old_height-y,x:old_width-x]
|
||||
s = samples.narrow(-2, y, old_height - y * 2).narrow(-1, x, old_width - x * 2)
|
||||
else:
|
||||
s = samples
|
||||
|
||||
if upscale_method == "bislerp":
|
||||
return bislerp(s, width, height)
|
||||
out = bislerp(s, width, height)
|
||||
elif upscale_method == "lanczos":
|
||||
return lanczos(s, width, height)
|
||||
out = lanczos(s, width, height)
|
||||
else:
|
||||
return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
|
||||
out = torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
|
||||
|
||||
if len(orig_shape) == 4:
|
||||
return out
|
||||
|
||||
out = out.reshape((orig_shape[0], -1, orig_shape[1]) + (height, width))
|
||||
return out.movedim(2, 1).reshape(orig_shape[:-2] + (height, width))
|
||||
|
||||
def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
    """Return the number of tiles used to cover a `width` x `height` area.

    A dimension that fits inside one tile counts as a single row/column.
    Otherwise tiles advance by `tile - overlap`, and the trailing `overlap`
    is subtracted from the length before dividing by that stride.

    Callers must keep `overlap` smaller than the tile size for any dimension
    larger than one tile; the stride would otherwise be zero or negative.
    """
    rows = 1 if height <= tile_y else math.ceil((height - overlap) / (tile_y - overlap))
    cols = 1 if width <= tile_x else math.ceil((width - overlap) / (tile_x - overlap))
    return rows * cols
|
||||
|
||||
@torch.inference_mode()
|
||||
def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
|
||||
dims = len(tile)
|
||||
output = torch.empty([samples.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), samples.shape[2:])), device=output_device)
|
||||
|
||||
if not (isinstance(upscale_amount, (tuple, list))):
|
||||
upscale_amount = [upscale_amount] * dims
|
||||
|
||||
if not (isinstance(overlap, (tuple, list))):
|
||||
overlap = [overlap] * dims
|
||||
|
||||
def get_upscale(dim, val):
|
||||
up = upscale_amount[dim]
|
||||
if callable(up):
|
||||
return up(val)
|
||||
else:
|
||||
return up * val
|
||||
|
||||
def mult_list_upscale(a):
|
||||
out = []
|
||||
for i in range(len(a)):
|
||||
out.append(round(get_upscale(i, a[i])))
|
||||
return out
|
||||
|
||||
output = torch.empty([samples.shape[0], out_channels] + mult_list_upscale(samples.shape[2:]), device=output_device)
|
||||
|
||||
for b in range(samples.shape[0]):
|
||||
s = samples[b:b+1]
|
||||
out = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
|
||||
out_div = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
|
||||
|
||||
for it in itertools.product(*map(lambda a: range(0, a[0], a[1] - overlap), zip(s.shape[2:], tile))):
|
||||
# handle entire input fitting in a single tile
|
||||
if all(s.shape[d+2] <= tile[d] for d in range(dims)):
|
||||
output[b:b+1] = function(s).to(output_device)
|
||||
if pbar is not None:
|
||||
pbar.update(1)
|
||||
continue
|
||||
|
||||
out = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
|
||||
out_div = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
|
||||
|
||||
positions = [range(0, s.shape[d+2], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]
|
||||
|
||||
for it in itertools.product(*positions):
|
||||
s_in = s
|
||||
upscaled = []
|
||||
|
||||
for d in range(dims):
|
||||
pos = max(0, min(s.shape[d + 2] - overlap, it[d]))
|
||||
pos = max(0, min(s.shape[d + 2] - (overlap[d] + 1), it[d]))
|
||||
l = min(tile[d], s.shape[d + 2] - pos)
|
||||
s_in = s_in.narrow(d + 2, pos, l)
|
||||
upscaled.append(round(pos * upscale_amount))
|
||||
upscaled.append(round(get_upscale(d, pos)))
|
||||
|
||||
ps = function(s_in).to(output_device)
|
||||
mask = torch.ones_like(ps)
|
||||
feather = round(overlap * upscale_amount)
|
||||
for t in range(feather):
|
||||
for d in range(2, dims + 2):
|
||||
m = mask.narrow(d, t, 1)
|
||||
m *= ((1.0/feather) * (t + 1))
|
||||
m = mask.narrow(d, mask.shape[d] -1 -t, 1)
|
||||
m *= ((1.0/feather) * (t + 1))
|
||||
|
||||
for d in range(2, dims + 2):
|
||||
feather = round(get_upscale(d - 2, overlap[d - 2]))
|
||||
for t in range(feather):
|
||||
a = (t + 1) / feather
|
||||
mask.narrow(d, t, 1).mul_(a)
|
||||
mask.narrow(d, mask.shape[d] - 1 - t, 1).mul_(a)
|
||||
|
||||
o = out
|
||||
o_d = out_div
|
||||
@@ -748,8 +808,8 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_
|
||||
o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2])
|
||||
o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2])
|
||||
|
||||
o += ps * mask
|
||||
o_d += mask
|
||||
o.add_(ps * mask)
|
||||
o_d.add_(mask)
|
||||
|
||||
if pbar is not None:
|
||||
pbar.update(1)
|
||||
@@ -788,3 +848,24 @@ class ProgressBar:
|
||||
|
||||
def update(self, value):
|
||||
self.update_absolute(self.current + value)
|
||||
|
||||
def reshape_mask(input_mask, output_shape):
    """Resize `input_mask` to match `output_shape`.

    `output_shape` is read as `(batch, channels, *spatial)`; the number of
    spatial dims (1, 2 or 3) selects linear, bilinear or trilinear
    interpolation.

    Args:
        input_mask: tensor holding the mask. For 2 spatial dims it is
            reshaped to `(-1, 1, H, W)`; for 3 spatial dims a mask with fewer
            than 5 dims is reshaped to `(1, 1, -1, H, W)`. For 1 spatial dim
            it is passed to `interpolate` as-is.
        output_shape: target shape, as described above.

    Returns:
        The interpolated mask, repeated along dim 1 up to `output_shape[1]`
        when it has fewer channels, then passed through
        `repeat_to_batch_size` with `output_shape[0]`.

    Raises:
        ValueError: if `output_shape` does not have 1, 2 or 3 spatial dims.
    """
    dims = len(output_shape) - 2

    if dims == 1:
        scale_mode = "linear"
    elif dims == 2:
        input_mask = input_mask.reshape((-1, 1, input_mask.shape[-2], input_mask.shape[-1]))
        scale_mode = "bilinear"
    elif dims == 3:
        if len(input_mask.shape) < 5:
            input_mask = input_mask.reshape((1, 1, -1, input_mask.shape[-2], input_mask.shape[-1]))
        scale_mode = "trilinear"
    else:
        # Previously this fell through and failed later on an unbound
        # `scale_mode`; fail here with a message that names the real problem.
        raise ValueError("reshape_mask supports 1, 2 or 3 spatial dimensions, got output_shape of length {}".format(len(output_shape)))

    mask = torch.nn.functional.interpolate(input_mask, size=output_shape[2:], mode=scale_mode)
    if mask.shape[1] < output_shape[1]:
        # Tile the channel dim, then trim to exactly the requested count.
        mask = mask.repeat((1, output_shape[1]) + (1,) * dims)[:, :output_shape[1]]
    mask = comfy.utils.repeat_to_batch_size(mask, output_shape[0])
    return mask
|
||||
|
||||
@@ -1,11 +1,21 @@
|
||||
import itertools
|
||||
from typing import Sequence, Mapping
|
||||
from typing import Sequence, Mapping, Dict
|
||||
from comfy_execution.graph import DynamicPrompt
|
||||
|
||||
import nodes
|
||||
|
||||
from comfy_execution.graph_utils import is_link
|
||||
|
||||
NODE_CLASS_CONTAINS_UNIQUE_ID: Dict[str, bool] = {}
|
||||
|
||||
|
||||
def include_unique_id_in_input(class_type: str) -> bool:
    """Return whether the node class declares a hidden input of type "UNIQUE_ID".

    The answer is memoised per `class_type` in
    `NODE_CLASS_CONTAINS_UNIQUE_ID`, so `INPUT_TYPES()` is called at most
    once for each class type.
    """
    if class_type in NODE_CLASS_CONTAINS_UNIQUE_ID:
        return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]

    class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
    hidden_inputs = class_def.INPUT_TYPES().get("hidden", {})
    has_unique_id = "UNIQUE_ID" in hidden_inputs.values()
    NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] = has_unique_id
    return has_unique_id
|
||||
|
||||
class CacheKeySet:
|
||||
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
||||
self.keys = {}
|
||||
@@ -98,7 +108,7 @@ class CacheKeySetInputSignature(CacheKeySet):
|
||||
class_type = node["class_type"]
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
signature = [class_type, self.is_changed_cache.get(node_id)]
|
||||
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT):
|
||||
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
|
||||
signature.append(node_id)
|
||||
inputs = node["inputs"]
|
||||
for key in sorted(inputs.keys()):
|
||||
|
||||
@@ -99,30 +99,44 @@ class TopologicalSort:
|
||||
self.add_strong_link(from_node_id, from_socket, to_node_id)
|
||||
|
||||
def add_strong_link(self, from_node_id, from_socket, to_node_id):
|
||||
self.add_node(from_node_id)
|
||||
if to_node_id not in self.blocking[from_node_id]:
|
||||
self.blocking[from_node_id][to_node_id] = {}
|
||||
self.blockCount[to_node_id] += 1
|
||||
self.blocking[from_node_id][to_node_id][from_socket] = True
|
||||
if not self.is_cached(from_node_id):
|
||||
self.add_node(from_node_id)
|
||||
if to_node_id not in self.blocking[from_node_id]:
|
||||
self.blocking[from_node_id][to_node_id] = {}
|
||||
self.blockCount[to_node_id] += 1
|
||||
self.blocking[from_node_id][to_node_id][from_socket] = True
|
||||
|
||||
def add_node(self, unique_id, include_lazy=False, subgraph_nodes=None):
|
||||
if unique_id in self.pendingNodes:
|
||||
return
|
||||
self.pendingNodes[unique_id] = True
|
||||
self.blockCount[unique_id] = 0
|
||||
self.blocking[unique_id] = {}
|
||||
def add_node(self, node_unique_id, include_lazy=False, subgraph_nodes=None):
|
||||
node_ids = [node_unique_id]
|
||||
links = []
|
||||
|
||||
inputs = self.dynprompt.get_node(unique_id)["inputs"]
|
||||
for input_name in inputs:
|
||||
value = inputs[input_name]
|
||||
if is_link(value):
|
||||
from_node_id, from_socket = value
|
||||
if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
|
||||
continue
|
||||
input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
|
||||
is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
|
||||
if include_lazy or not is_lazy:
|
||||
self.add_strong_link(from_node_id, from_socket, unique_id)
|
||||
while len(node_ids) > 0:
|
||||
unique_id = node_ids.pop()
|
||||
if unique_id in self.pendingNodes:
|
||||
continue
|
||||
|
||||
self.pendingNodes[unique_id] = True
|
||||
self.blockCount[unique_id] = 0
|
||||
self.blocking[unique_id] = {}
|
||||
|
||||
inputs = self.dynprompt.get_node(unique_id)["inputs"]
|
||||
for input_name in inputs:
|
||||
value = inputs[input_name]
|
||||
if is_link(value):
|
||||
from_node_id, from_socket = value
|
||||
if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
|
||||
continue
|
||||
input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
|
||||
is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
|
||||
if (include_lazy or not is_lazy) and not self.is_cached(from_node_id):
|
||||
node_ids.append(from_node_id)
|
||||
links.append((from_node_id, from_socket, unique_id))
|
||||
|
||||
for link in links:
|
||||
self.add_strong_link(*link)
|
||||
|
||||
def is_cached(self, node_id):
    """Report whether `node_id` already has a cached result; this implementation always returns False."""
    return False
|
||||
|
||||
def get_ready_nodes(self):
    """Return the ids in ``pendingNodes`` whose ``blockCount`` entry is zero."""
    pending, counts = self.pendingNodes, self.blockCount
    return [candidate for candidate in pending if counts[candidate] == 0]
|
||||
@@ -146,11 +160,8 @@ class ExecutionList(TopologicalSort):
|
||||
self.output_cache = output_cache
|
||||
self.staged_node_id = None
|
||||
|
||||
def add_strong_link(self, from_node_id, from_socket, to_node_id):
|
||||
if self.output_cache.get(from_node_id) is not None:
|
||||
# Nothing to do
|
||||
return
|
||||
super().add_strong_link(from_node_id, from_socket, to_node_id)
|
||||
def is_cached(self, node_id):
    """Return True when ``self.output_cache.get(node_id)`` yields a non-None value."""
    return self.output_cache.get(node_id) is not None
|
||||
|
||||
def stage_node_execution(self):
|
||||
assert self.staged_node_id is None
|
||||
|
||||
39
comfy_execution/validation.py
Normal file
39
comfy_execution/validation.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def validate_node_input(
    received_type: str, input_type: str, strict: bool = False
) -> bool:
    """
    Decide whether a received type is acceptable for an input type.

    Both arguments are normally comma-separated type lists such as
    "T1,T2,...". Surrounding whitespace on each entry is ignored.

    strict=True: every entry of received_type has to appear in input_type.
    "STRING" against "STRING,INT" is accepted; "STRING,INT" against "INT"
    is rejected.

    strict=False: one shared entry is enough. "STRING,BOOLEAN" against
    "STRING,INT" is accepted.

    Objects that override ``__ne__`` (the pre-union way of extending types)
    are honoured: the first check is the negation of ``!=``.
    """
    # Deliberately `not a != b` rather than `a == b`, so custom `__ne__`
    # overrides keep deciding the outcome.
    if not received_type != input_type:
        return True

    # Anything that is not a plain string cannot be split and compared.
    if not (isinstance(received_type, str) and isinstance(input_type, str)):
        return False

    offered = {entry.strip() for entry in received_type.split(",")}
    accepted = {entry.strip() for entry in input_type.split(",")}

    if strict:
        # Every offered type must be accepted.
        return offered <= accepted
    # A single common type is sufficient.
    return not offered.isdisjoint(accepted)
|
||||
@@ -16,14 +16,15 @@ class EmptyLatentAudio:
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"seconds": ("FLOAT", {"default": 47.6, "min": 1.0, "max": 1000.0, "step": 0.1})}}
|
||||
return {"required": {"seconds": ("FLOAT", {"default": 47.6, "min": 1.0, "max": 1000.0, "step": 0.1}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
|
||||
}}
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "generate"
|
||||
|
||||
CATEGORY = "latent/audio"
|
||||
|
||||
def generate(self, seconds):
|
||||
batch_size = 1
|
||||
def generate(self, seconds, batch_size):
|
||||
length = round((seconds * 44100 / 2048) / 2) * 2
|
||||
latent = torch.zeros([batch_size, 64, length], device=self.device)
|
||||
return ({"samples":latent, "type": "audio"}, )
|
||||
@@ -58,6 +59,9 @@ class VAEDecodeAudio:
|
||||
|
||||
def decode(self, vae, samples):
|
||||
audio = vae.decode(samples["samples"]).movedim(-1, 1)
|
||||
std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0
|
||||
std[std < 1.0] = 1.0
|
||||
audio /= std
|
||||
return ({"waveform": audio, "sample_rate": 44100}, )
|
||||
|
||||
|
||||
@@ -183,17 +187,10 @@ class PreviewAudio(SaveAudio):
|
||||
}
|
||||
|
||||
class LoadAudio:
|
||||
SUPPORTED_FORMATS = ('.wav', '.mp3', '.ogg', '.flac', '.aiff', '.aif')
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
input_dir = folder_paths.get_input_directory()
|
||||
files = [
|
||||
f for f in os.listdir(input_dir)
|
||||
if (os.path.isfile(os.path.join(input_dir, f))
|
||||
and f.endswith(LoadAudio.SUPPORTED_FORMATS)
|
||||
)
|
||||
]
|
||||
files = folder_paths.filter_files_content_types(os.listdir(input_dir), ["audio", "video"])
|
||||
return {"required": {"audio": (sorted(files), {"audio_upload": True})}}
|
||||
|
||||
CATEGORY = "audio"
|
||||
|
||||
@@ -17,8 +17,7 @@ class CLIPTextEncodeSDXLRefiner:
|
||||
|
||||
def encode(self, clip, ascore, width, height, text):
|
||||
tokens = clip.tokenize(text)
|
||||
cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
|
||||
return ([[cond, {"pooled_output": pooled, "aesthetic_score": ascore, "width": width,"height": height}]], )
|
||||
return (clip.encode_from_tokens_scheduled(tokens, add_dict={"aesthetic_score": ascore, "width": width, "height": height}), )
|
||||
|
||||
class CLIPTextEncodeSDXL:
|
||||
@classmethod
|
||||
@@ -47,8 +46,7 @@ class CLIPTextEncodeSDXL:
|
||||
tokens["l"] += empty["l"]
|
||||
while len(tokens["l"]) > len(tokens["g"]):
|
||||
tokens["g"] += empty["g"]
|
||||
cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
|
||||
return ([[cond, {"pooled_output": pooled, "width": width, "height": height, "crop_w": crop_w, "crop_h": crop_h, "target_width": target_width, "target_height": target_height}]], )
|
||||
return (clip.encode_from_tokens_scheduled(tokens, add_dict={"width": width, "height": height, "crop_w": crop_w, "crop_h": crop_h, "target_width": target_width, "target_height": target_height}), )
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"CLIPTextEncodeSDXLRefiner": CLIPTextEncodeSDXLRefiner,
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from comfy.cldm.control_types import UNION_CONTROLNET_TYPES
|
||||
import nodes
|
||||
import comfy.utils
|
||||
|
||||
class SetUnionControlNetType:
|
||||
@classmethod
|
||||
@@ -22,6 +24,37 @@ class SetUnionControlNetType:
|
||||
|
||||
return (control_net,)
|
||||
|
||||
class ControlNetInpaintingAliMamaApply(nodes.ControlNetApplyAdvanced):
    """ControlNet apply node that can mask the image and pass the mask as an extra concat input."""

    FUNCTION = "apply_inpaint_controlnet"

    CATEGORY = "conditioning/controlnet"

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's required inputs."""
        required = {"positive": ("CONDITIONING", ),
                    "negative": ("CONDITIONING", ),
                    "control_net": ("CONTROL_NET", ),
                    "vae": ("VAE", ),
                    "image": ("IMAGE", ),
                    "mask": ("MASK", ),
                    "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                    "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
                    "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
                    }
        return {"required": required}

    def apply_inpaint_controlnet(self, positive, negative, control_net, vae, image, mask, strength, start_percent, end_percent):
        """Optionally mask ``image`` and delegate to ``apply_controlnet``.

        When ``control_net.concat_mask`` is truthy the mask is inverted,
        resized to the image, rounded, multiplied into the image, and the
        inverted mask is forwarded through ``extra_concat``. Otherwise the
        image is passed on unchanged with an empty ``extra_concat``.
        """
        if control_net.concat_mask:
            # Invert the mask and give it an explicit channel axis.
            mask = 1.0 - mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
            resized = comfy.utils.common_upscale(mask, image.shape[2], image.shape[1], "bilinear", "center").round()
            image = image * resized.movedim(1, -1).repeat(1, 1, 1, image.shape[3])
            extras = [mask]
        else:
            extras = []

        return self.apply_controlnet(positive, negative, control_net, image, strength, start_percent, end_percent, vae=vae, extra_concat=extras)
|
||||
|
||||
|
||||
|
||||
# Maps node identifier strings to the classes that implement them.
NODE_CLASS_MAPPINGS = {
    "SetUnionControlNetType": SetUnionControlNetType,
    "ControlNetInpaintingAliMamaApply": ControlNetInpaintingAliMamaApply,
}
|
||||
|
||||
@@ -90,6 +90,27 @@ class PolyexponentialScheduler:
|
||||
sigmas = k_diffusion_sampling.get_sigmas_polyexponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho)
|
||||
return (sigmas, )
|
||||
|
||||
class LaplaceScheduler:
    """Node that builds a sigma schedule with ``k_diffusion_sampling.get_sigmas_laplace``."""

    RETURN_TYPES = ("SIGMAS",)
    CATEGORY = "sampling/custom_sampling/schedulers"

    FUNCTION = "get_sigmas"

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's required inputs."""
        required = {
            "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
            "sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
            "sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
            "mu": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step":0.1, "round": False}),
            "beta": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step":0.1, "round": False}),
        }
        return {"required": required}

    def get_sigmas(self, steps, sigma_max, sigma_min, mu, beta):
        """Return a one-element tuple holding the schedule computed for the given settings."""
        schedule = k_diffusion_sampling.get_sigmas_laplace(
            n=steps, sigma_min=sigma_min, sigma_max=sigma_max, mu=mu, beta=beta)
        return (schedule, )
|
||||
|
||||
|
||||
class SDTurboScheduler:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@@ -673,6 +694,7 @@ NODE_CLASS_MAPPINGS = {
|
||||
"KarrasScheduler": KarrasScheduler,
|
||||
"ExponentialScheduler": ExponentialScheduler,
|
||||
"PolyexponentialScheduler": PolyexponentialScheduler,
|
||||
"LaplaceScheduler": LaplaceScheduler,
|
||||
"VPScheduler": VPScheduler,
|
||||
"BetaSamplingScheduler": BetaSamplingScheduler,
|
||||
"SDTurboScheduler": SDTurboScheduler,
|
||||
|
||||
@@ -18,10 +18,7 @@ class CLIPTextEncodeFlux:
|
||||
tokens = clip.tokenize(clip_l)
|
||||
tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]
|
||||
|
||||
output = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True)
|
||||
cond = output.pop("cond")
|
||||
output["guidance"] = guidance
|
||||
return ([[cond, output]], )
|
||||
return (clip.encode_from_tokens_scheduled(tokens, add_dict={"guidance": guidance}), )
|
||||
|
||||
class FluxGuidance:
|
||||
@classmethod
|
||||
|
||||
745
comfy_extras/nodes_hooks.py
Normal file
745
comfy_extras/nodes_hooks.py
Normal file
@@ -0,0 +1,745 @@
|
||||
from __future__ import annotations
|
||||
from typing import TYPE_CHECKING, Union
|
||||
import torch
|
||||
from collections.abc import Iterable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
from comfy.sd import CLIP
|
||||
|
||||
import comfy.hooks
|
||||
import comfy.sd
|
||||
import comfy.utils
|
||||
import folder_paths
|
||||
|
||||
###########################################
|
||||
# Mask, Combine, and Hook Conditioning
|
||||
#------------------------------------------
|
||||
class PairConditioningSetProperties:
    """Node that sets properties on a positive/negative conditioning pair."""
    NodeId = 'PairConditioningSetProperties'
    NodeName = 'Cond Pair Set Props'

    EXPERIMENTAL = True
    RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
    RETURN_NAMES = ("positive", "negative")
    CATEGORY = "advanced/hooks/cond pair"
    FUNCTION = "set_properties"

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's required and optional inputs."""
        required = {
            "positive_NEW": ("CONDITIONING", ),
            "negative_NEW": ("CONDITIONING", ),
            "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
            "set_cond_area": (["default", "mask bounds"],),
        }
        optional = {
            "mask": ("MASK", ),
            "hooks": ("HOOKS",),
            "timesteps": ("TIMESTEPS_RANGE",),
        }
        return {"required": required, "optional": optional}

    def set_properties(self, positive_NEW, negative_NEW,
                       strength: float, set_cond_area: str,
                       mask: torch.Tensor=None, hooks: comfy.hooks.HookGroup=None, timesteps: tuple=None):
        """Pass both conds through ``comfy.hooks.set_conds_props`` and return the resulting pair."""
        updated = comfy.hooks.set_conds_props(
            conds=[positive_NEW, negative_NEW],
            strength=strength, set_cond_area=set_cond_area,
            mask=mask, hooks=hooks, timesteps_range=timesteps)
        final_positive, final_negative = updated
        return (final_positive, final_negative)
|
||||
|
||||
class PairConditioningSetPropertiesAndCombine:
    """Node that sets properties on a new conditioning pair and combines it with an existing pair."""
    NodeId = 'PairConditioningSetPropertiesAndCombine'
    NodeName = 'Cond Pair Set Props Combine'

    EXPERIMENTAL = True
    RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
    RETURN_NAMES = ("positive", "negative")
    CATEGORY = "advanced/hooks/cond pair"
    FUNCTION = "set_properties"

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's required and optional inputs."""
        required = {
            "positive": ("CONDITIONING", ),
            "negative": ("CONDITIONING", ),
            "positive_NEW": ("CONDITIONING", ),
            "negative_NEW": ("CONDITIONING", ),
            "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
            "set_cond_area": (["default", "mask bounds"],),
        }
        optional = {
            "mask": ("MASK", ),
            "hooks": ("HOOKS",),
            "timesteps": ("TIMESTEPS_RANGE",),
        }
        return {"required": required, "optional": optional}

    def set_properties(self, positive, negative, positive_NEW, negative_NEW,
                       strength: float, set_cond_area: str,
                       mask: torch.Tensor=None, hooks: comfy.hooks.HookGroup=None, timesteps: tuple=None):
        """Call ``comfy.hooks.set_conds_props_and_combine`` with the existing and new pairs and return the result."""
        combined = comfy.hooks.set_conds_props_and_combine(
            conds=[positive, negative], new_conds=[positive_NEW, negative_NEW],
            strength=strength, set_cond_area=set_cond_area,
            mask=mask, hooks=hooks, timesteps_range=timesteps)
        final_positive, final_negative = combined
        return (final_positive, final_negative)
|
||||
|
||||
class ConditioningSetProperties:
    """Node that sets properties on a single conditioning."""
    NodeId = 'ConditioningSetProperties'
    NodeName = 'Cond Set Props'

    EXPERIMENTAL = True
    RETURN_TYPES = ("CONDITIONING",)
    CATEGORY = "advanced/hooks/cond single"
    FUNCTION = "set_properties"

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's required and optional inputs."""
        required = {
            "cond_NEW": ("CONDITIONING", ),
            "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
            "set_cond_area": (["default", "mask bounds"],),
        }
        optional = {
            "mask": ("MASK", ),
            "hooks": ("HOOKS",),
            "timesteps": ("TIMESTEPS_RANGE",),
        }
        return {"required": required, "optional": optional}

    def set_properties(self, cond_NEW,
                       strength: float, set_cond_area: str,
                       mask: torch.Tensor=None, hooks: comfy.hooks.HookGroup=None, timesteps: tuple=None):
        """Pass the cond through ``comfy.hooks.set_conds_props`` and return the single result."""
        updated = comfy.hooks.set_conds_props(
            conds=[cond_NEW],
            strength=strength, set_cond_area=set_cond_area,
            mask=mask, hooks=hooks, timesteps_range=timesteps)
        # Exactly one entry is expected back, matching the single cond passed in.
        (final_cond,) = updated
        return (final_cond,)
|
||||
|
||||
class ConditioningSetPropertiesAndCombine:
    """Node that sets properties on a new conditioning and combines it with an existing one."""
    NodeId = 'ConditioningSetPropertiesAndCombine'
    NodeName = 'Cond Set Props Combine'

    EXPERIMENTAL = True
    RETURN_TYPES = ("CONDITIONING",)
    CATEGORY = "advanced/hooks/cond single"
    FUNCTION = "set_properties"

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's required and optional inputs."""
        required = {
            "cond": ("CONDITIONING", ),
            "cond_NEW": ("CONDITIONING", ),
            "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
            "set_cond_area": (["default", "mask bounds"],),
        }
        optional = {
            "mask": ("MASK", ),
            "hooks": ("HOOKS",),
            "timesteps": ("TIMESTEPS_RANGE",),
        }
        return {"required": required, "optional": optional}

    def set_properties(self, cond, cond_NEW,
                       strength: float, set_cond_area: str,
                       mask: torch.Tensor=None, hooks: comfy.hooks.HookGroup=None, timesteps: tuple=None):
        """Call ``comfy.hooks.set_conds_props_and_combine`` for one cond and return the single result."""
        combined = comfy.hooks.set_conds_props_and_combine(
            conds=[cond], new_conds=[cond_NEW],
            strength=strength, set_cond_area=set_cond_area,
            mask=mask, hooks=hooks, timesteps_range=timesteps)
        # Exactly one entry is expected back, matching the single cond passed in.
        (final_cond,) = combined
        return (final_cond,)
|
||||
|
||||
class PairConditioningCombine:
    """Node that combines two positive/negative conditioning pairs."""
    NodeId = 'PairConditioningCombine'
    NodeName = 'Cond Pair Combine'

    EXPERIMENTAL = True
    RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
    RETURN_NAMES = ("positive", "negative")
    CATEGORY = "advanced/hooks/cond pair"
    FUNCTION = "combine"

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's required inputs."""
        required = {
            "positive_A": ("CONDITIONING",),
            "negative_A": ("CONDITIONING",),
            "positive_B": ("CONDITIONING",),
            "negative_B": ("CONDITIONING",),
        }
        return {"required": required}

    def combine(self, positive_A, negative_A, positive_B, negative_B):
        """Call ``comfy.hooks.set_conds_props_and_combine`` with pair A as ``conds`` and pair B as ``new_conds``."""
        pair_a = [positive_A, negative_A]
        pair_b = [positive_B, negative_B]
        final_positive, final_negative = comfy.hooks.set_conds_props_and_combine(conds=pair_a, new_conds=pair_b)
        return (final_positive, final_negative,)
|
||||
|
||||
class PairConditioningSetDefaultAndCombine:
    """Node that combines a conditioning pair with a default conditioning pair."""
    NodeId = 'PairConditioningSetDefaultCombine'
    NodeName = 'Cond Pair Set Default Combine'

    EXPERIMENTAL = True
    RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
    RETURN_NAMES = ("positive", "negative")
    CATEGORY = "advanced/hooks/cond pair"
    FUNCTION = "set_default_and_combine"

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's required and optional inputs."""
        required = {
            "positive": ("CONDITIONING",),
            "negative": ("CONDITIONING",),
            "positive_DEFAULT": ("CONDITIONING",),
            "negative_DEFAULT": ("CONDITIONING",),
        }
        optional = {
            "hooks": ("HOOKS",),
        }
        return {"required": required, "optional": optional}

    def set_default_and_combine(self, positive, negative, positive_DEFAULT, negative_DEFAULT,
                                hooks: comfy.hooks.HookGroup=None):
        """Call ``comfy.hooks.set_default_conds_and_combine`` with the pair and its defaults and return the result."""
        combined = comfy.hooks.set_default_conds_and_combine(
            conds=[positive, negative], new_conds=[positive_DEFAULT, negative_DEFAULT],
            hooks=hooks)
        final_positive, final_negative = combined
        return (final_positive, final_negative)
|
||||
|
||||
class ConditioningSetDefaultAndCombine:
    """Node that combines a single conditioning with a default conditioning."""
    NodeId = 'ConditioningSetDefaultCombine'
    NodeName = 'Cond Set Default Combine'

    EXPERIMENTAL = True
    RETURN_TYPES = ("CONDITIONING",)
    CATEGORY = "advanced/hooks/cond single"
    FUNCTION = "set_default_and_combine"

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's required and optional inputs."""
        required = {
            "cond": ("CONDITIONING",),
            "cond_DEFAULT": ("CONDITIONING",),
        }
        optional = {
            "hooks": ("HOOKS",),
        }
        return {"required": required, "optional": optional}

    def set_default_and_combine(self, cond, cond_DEFAULT,
                                hooks: comfy.hooks.HookGroup=None):
        """Call ``comfy.hooks.set_default_conds_and_combine`` for one cond and return the single result."""
        combined = comfy.hooks.set_default_conds_and_combine(
            conds=[cond], new_conds=[cond_DEFAULT],
            hooks=hooks)
        # Exactly one entry is expected back, matching the single cond passed in.
        (final_conditioning,) = combined
        return (final_conditioning,)
|
||||
|
||||
class SetClipHooks:
    """Node that attaches a hook group to a cloned CLIP object."""
    NodeId = 'SetClipHooks'
    NodeName = 'Set CLIP Hooks'

    EXPERIMENTAL = True
    RETURN_TYPES = ("CLIP",)
    CATEGORY = "advanced/hooks/clip"
    FUNCTION = "apply_hooks"

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's required and optional inputs."""
        required = {
            "clip": ("CLIP",),
            "apply_to_conds": ("BOOLEAN", {"default": True}),
            "schedule_clip": ("BOOLEAN", {"default": False})
        }
        optional = {
            "hooks": ("HOOKS",)
        }
        return {"required": required, "optional": optional}

    def apply_hooks(self, clip: 'CLIP', schedule_clip: bool, apply_to_conds: bool, hooks: comfy.hooks.HookGroup=None):
        """Return ``clip`` unchanged when no hooks are given; otherwise return a configured clone.

        The clone gets a copy of ``hooks`` as ``patcher.forced_hooks``, the
        ``use_clip_schedule`` flag, and (when ``apply_to_conds``) the hooks
        stored on ``apply_hooks_to_conds``. With scheduling off, keyframes on
        the forced hooks are reset by passing None.
        """
        if hooks is None:
            return (clip,)

        clip = clip.clone()
        if apply_to_conds:
            clip.apply_hooks_to_conds = hooks
        clip.patcher.forced_hooks = hooks.clone()
        clip.use_clip_schedule = schedule_clip
        if not clip.use_clip_schedule:
            clip.patcher.forced_hooks.set_keyframes_on_hooks(None)
        clip.patcher.register_all_hook_patches(hooks.get_dict_repr(), comfy.hooks.EnumWeightTarget.Clip)
        return (clip,)
|
||||
|
||||
class ConditioningTimestepsRange:
    """Node that emits a timestep range plus the ranges before and after it."""
    NodeId = 'ConditioningTimestepsRange'
    NodeName = 'Timesteps Range'

    EXPERIMENTAL = True
    RETURN_TYPES = ("TIMESTEPS_RANGE", "TIMESTEPS_RANGE", "TIMESTEPS_RANGE")
    RETURN_NAMES = ("TIMESTEPS_RANGE", "BEFORE_RANGE", "AFTER_RANGE")
    CATEGORY = "advanced/hooks"
    FUNCTION = "create_range"

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's required inputs."""
        required = {
            "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
            "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
        }
        return {"required": required}

    def create_range(self, start_percent: float, end_percent: float):
        """Return ``(start, end)``, ``(0.0, start)`` and ``(end, 1.0)`` as a 3-tuple."""
        selected = (start_percent, end_percent)
        before = (0.0, start_percent)
        after = (end_percent, 1.0)
        return (selected, before, after)
|
||||
#------------------------------------------
|
||||
###########################################
|
||||
|
||||
|
||||
###########################################
|
||||
# Create Hooks
|
||||
#------------------------------------------
|
||||
class CreateHookLora:
    """Node that creates LoRA hooks from a LoRA file and appends them to an optional hook group."""
    NodeId = 'CreateHookLora'
    NodeName = 'Create Hook LoRA'

    def __init__(self):
        # When not None: (lora_path, loaded lora) for the most recently loaded file.
        self.loaded_lora = None

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's required and optional inputs."""
        return {
            "required": {
                "lora_name": (folder_paths.get_filename_list("loras"), ),
                "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
                "strength_clip": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
            },
            "optional": {
                "prev_hooks": ("HOOKS",)
            }
        }

    EXPERIMENTAL = True
    RETURN_TYPES = ("HOOKS",)
    CATEGORY = "advanced/hooks/create"
    FUNCTION = "create_hook"

    def create_hook(self, lora_name: str, strength_model: float, strength_clip: float, prev_hooks: comfy.hooks.HookGroup=None):
        """Build LoRA hooks and return them combined with ``prev_hooks``.

        Returns a one-element tuple. When both strengths are zero no file is
        loaded and a copy of ``prev_hooks`` (or a new empty group) is returned.
        The last loaded LoRA is kept on the instance and reused while the
        resolved path stays the same.
        """
        if prev_hooks is None:
            prev_hooks = comfy.hooks.HookGroup()
        # Keep the clone: the original discarded it, so the early return
        # below handed the caller's own group back instead of a copy.
        prev_hooks = prev_hooks.clone()

        if strength_model == 0 and strength_clip == 0:
            return (prev_hooks,)

        lora_path = folder_paths.get_full_path("loras", lora_name)
        lora = None
        if self.loaded_lora is not None:
            cached_path, cached_lora = self.loaded_lora
            if cached_path == lora_path:
                lora = cached_lora
            else:
                # Drop the stale entry before loading the replacement.
                self.loaded_lora = None
            del cached_lora

        if lora is None:
            lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
            self.loaded_lora = (lora_path, lora)

        hooks = comfy.hooks.create_hook_lora(lora=lora, strength_model=strength_model, strength_clip=strength_clip)
        return (prev_hooks.clone_and_combine(hooks),)
|
||||
|
||||
class CreateHookLoraModelOnly(CreateHookLora):
    """Variant of CreateHookLora that exposes no clip strength and passes 0 for it."""
    NodeId = 'CreateHookLoraModelOnly'
    NodeName = 'Create Hook LoRA (MO)'

    EXPERIMENTAL = True
    RETURN_TYPES = ("HOOKS",)
    CATEGORY = "advanced/hooks/create"
    FUNCTION = "create_hook_model_only"

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's required and optional inputs."""
        required = {
            "lora_name": (folder_paths.get_filename_list("loras"), ),
            "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
        }
        optional = {
            "prev_hooks": ("HOOKS",)
        }
        return {"required": required, "optional": optional}

    def create_hook_model_only(self, lora_name: str, strength_model: float, prev_hooks: comfy.hooks.HookGroup=None):
        """Delegate to ``create_hook`` with ``strength_clip`` fixed at 0."""
        return self.create_hook(
            lora_name=lora_name,
            strength_model=strength_model,
            strength_clip=0,
            prev_hooks=prev_hooks,
        )
|
||||
|
||||
class CreateHookModelAsLora:
    """Node that turns a checkpoint's weights into hooks and appends them to an optional hook group."""
    NodeId = 'CreateHookModelAsLora'
    NodeName = 'Create Hook Model as LoRA'

    def __init__(self):
        # when not None, will be in following format:
        # (ckpt_path: str, weights_model: dict, weights_clip: dict)
        self.loaded_weights = None

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's required and optional inputs."""
        return {
            "required": {
                "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
                "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
                "strength_clip": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
            },
            "optional": {
                "prev_hooks": ("HOOKS",)
            }
        }

    EXPERIMENTAL = True
    RETURN_TYPES = ("HOOKS",)
    CATEGORY = "advanced/hooks/create"
    FUNCTION = "create_hook"

    def create_hook(self, ckpt_name: str, strength_model: float, strength_clip: float,
                    prev_hooks: comfy.hooks.HookGroup=None):
        """Build model-as-LoRA hooks and return them combined with ``prev_hooks``.

        Returns a one-element tuple. The patch weights extracted from the last
        loaded checkpoint are kept on the instance and reused while the
        resolved checkpoint path stays the same.
        """
        if prev_hooks is None:
            prev_hooks = comfy.hooks.HookGroup()
        # Keep the clone; the original called clone() and discarded the result.
        prev_hooks = prev_hooks.clone()

        ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
        weights_model = None
        weights_clip = None
        if self.loaded_weights is not None:
            if self.loaded_weights[0] == ckpt_path:
                weights_model = self.loaded_weights[1]
                weights_clip = self.loaded_weights[2]
            else:
                # Drop the stale entry before loading the replacement.
                self.loaded_weights = None

        if weights_model is None:
            out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
            weights_model = comfy.hooks.get_patch_weights_from_model(out[0])
            # out[1] is passed through as-is when falsy, otherwise its patcher is used.
            weights_clip = comfy.hooks.get_patch_weights_from_model(out[1].patcher if out[1] else out[1])
            self.loaded_weights = (ckpt_path, weights_model, weights_clip)

        hooks = comfy.hooks.create_hook_model_as_lora(weights_model=weights_model, weights_clip=weights_clip,
                                                      strength_model=strength_model, strength_clip=strength_clip)
        return (prev_hooks.clone_and_combine(hooks),)
|
||||
|
||||
class CreateHookModelAsLoraModelOnly(CreateHookModelAsLora):
    """Variant of CreateHookModelAsLora that exposes no clip strength and passes 0.0 for it."""
    NodeId = 'CreateHookModelAsLoraModelOnly'
    NodeName = 'Create Hook Model as LoRA (MO)'

    EXPERIMENTAL = True
    RETURN_TYPES = ("HOOKS",)
    CATEGORY = "advanced/hooks/create"
    FUNCTION = "create_hook_model_only"

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's required and optional inputs."""
        required = {
            "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
            "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
        }
        optional = {
            "prev_hooks": ("HOOKS",)
        }
        return {"required": required, "optional": optional}

    def create_hook_model_only(self, ckpt_name: str, strength_model: float,
                               prev_hooks: comfy.hooks.HookGroup=None):
        """Delegate to ``create_hook`` with ``strength_clip`` fixed at 0.0."""
        return self.create_hook(
            ckpt_name=ckpt_name,
            strength_model=strength_model,
            strength_clip=0.0,
            prev_hooks=prev_hooks,
        )
|
||||
#------------------------------------------
|
||||
###########################################
|
||||
|
||||
|
||||
###########################################
|
||||
# Schedule Hooks
|
||||
#------------------------------------------
|
||||
class SetHookKeyframes:
    """Node that assigns a keyframe group to a copy of a hook group."""
    NodeId = 'SetHookKeyframes'
    NodeName = 'Set Hook Keyframes'

    EXPERIMENTAL = True
    RETURN_TYPES = ("HOOKS",)
    CATEGORY = "advanced/hooks/scheduling"
    FUNCTION = "set_hook_keyframes"

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's required and optional inputs."""
        required = {
            "hooks": ("HOOKS",),
        }
        optional = {
            "hook_kf": ("HOOK_KEYFRAMES",),
        }
        return {"required": required, "optional": optional}

    def set_hook_keyframes(self, hooks: comfy.hooks.HookGroup, hook_kf: comfy.hooks.HookKeyframeGroup=None):
        """Return ``hooks`` untouched when ``hook_kf`` is None, else a clone with the keyframes set."""
        if hook_kf is None:
            return (hooks,)
        scheduled = hooks.clone()
        scheduled.set_keyframes_on_hooks(hook_kf=hook_kf)
        return (scheduled,)
|
||||
|
||||
class CreateHookKeyframe:
    """Node that appends one keyframe to a copy of an optional keyframe group."""
    NodeId = 'CreateHookKeyframe'
    NodeName = 'Create Hook Keyframe'

    EXPERIMENTAL = True
    RETURN_TYPES = ("HOOK_KEYFRAMES",)
    RETURN_NAMES = ("HOOK_KF",)
    CATEGORY = "advanced/hooks/scheduling"
    FUNCTION = "create_hook_keyframe"

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's required and optional inputs."""
        required = {
            "strength_mult": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
            "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
        }
        optional = {
            "prev_hook_kf": ("HOOK_KEYFRAMES",),
        }
        return {"required": required, "optional": optional}

    def create_hook_keyframe(self, strength_mult: float, start_percent: float, prev_hook_kf: comfy.hooks.HookKeyframeGroup=None):
        """Clone ``prev_hook_kf`` (or start a new group), add the keyframe, and return the group in a tuple."""
        base_group = comfy.hooks.HookKeyframeGroup() if prev_hook_kf is None else prev_hook_kf
        group = base_group.clone()
        group.add(comfy.hooks.HookKeyframe(strength=strength_mult, start_percent=start_percent))
        return (group,)
|
||||
|
||||
class CreateHookKeyframesInterpolated:
    """Node that appends a run of interpolated keyframes to a copy of an optional keyframe group."""
    NodeId = 'CreateHookKeyframesInterpolated'
    NodeName = 'Create Hook Keyframes Interp.'

    EXPERIMENTAL = True
    RETURN_TYPES = ("HOOK_KEYFRAMES",)
    RETURN_NAMES = ("HOOK_KF",)
    CATEGORY = "advanced/hooks/scheduling"
    FUNCTION = "create_hook_keyframes"

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's required and optional inputs."""
        required = {
            "strength_start": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
            "strength_end": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
            "interpolation": (comfy.hooks.InterpolationMethod._LIST, ),
            "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
            "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
            "keyframes_count": ("INT", {"default": 5, "min": 2, "max": 100, "step": 1}),
            "print_keyframes": ("BOOLEAN", {"default": False}),
        }
        optional = {
            "prev_hook_kf": ("HOOK_KEYFRAMES",),
        }
        return {"required": required, "optional": optional}

    def create_hook_keyframes(self, strength_start: float, strength_end: float, interpolation: str,
                              start_percent: float, end_percent: float, keyframes_count: int,
                              print_keyframes=False, prev_hook_kf: comfy.hooks.HookKeyframeGroup=None):
        """Add ``keyframes_count`` keyframes between the start and end settings.

        Start percents are always produced with the LINEAR method; strengths
        use the requested ``interpolation``. Only the first keyframe added
        gets ``guarantee_steps=1``; the rest get 0. Returns the group in a
        one-element tuple.
        """
        if prev_hook_kf is None:
            prev_hook_kf = comfy.hooks.HookKeyframeGroup()
        prev_hook_kf = prev_hook_kf.clone()

        weights_for = comfy.hooks.InterpolationMethod.get_weights
        percents = weights_for(num_from=start_percent, num_to=end_percent, length=keyframes_count,
                               method=comfy.hooks.InterpolationMethod.LINEAR)
        strengths = weights_for(num_from=strength_start, num_to=strength_end, length=keyframes_count, method=interpolation)

        for position, (percent, strength) in enumerate(zip(percents, strengths)):
            guarantee_steps = 1 if position == 0 else 0
            prev_hook_kf.add(comfy.hooks.HookKeyframe(strength=strength, start_percent=percent, guarantee_steps=guarantee_steps))
            if print_keyframes:
                print(f"Hook Keyframe - start_percent:{percent} = {strength}")
        return (prev_hook_kf,)
|
||||
|
||||
class CreateHookKeyframesFromFloats:
    """Node that appends one keyframe per supplied strength to a copy of an optional keyframe group."""
    NodeId = 'CreateHookKeyframesFromFloats'
    NodeName = 'Create Hook Keyframes From Floats'

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's required and optional inputs."""
        return {
            "required": {
                "floats_strength": ("FLOATS", {"default": -1, "min": -1, "step": 0.001, "forceInput": True}),
                "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
                "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
                "print_keyframes": ("BOOLEAN", {"default": False}),
            },
            "optional": {
                "prev_hook_kf": ("HOOK_KEYFRAMES",),
            }
        }

    EXPERIMENTAL = True
    RETURN_TYPES = ("HOOK_KEYFRAMES",)
    RETURN_NAMES = ("HOOK_KF",)
    CATEGORY = "advanced/hooks/scheduling"
    FUNCTION = "create_hook_keyframes"

    def create_hook_keyframes(self, floats_strength: Union[float, list[float]],
                              start_percent: float, end_percent: float,
                              prev_hook_kf: comfy.hooks.HookKeyframeGroup=None, print_keyframes=False):
        """Add one keyframe per strength, with start percents spread linearly.

        ``floats_strength`` may be a single number or any iterable of
        strengths. Only the first keyframe added gets ``guarantee_steps=1``;
        the rest get 0. Returns the group in a one-element tuple.

        Raises:
            TypeError: if ``floats_strength`` is neither a number nor an iterable.
        """
        # Validate and normalise the input before any hook state is created.
        if isinstance(floats_strength, (float, int)):
            floats_strength = [float(floats_strength)]
        elif isinstance(floats_strength, Iterable):
            # Materialise so that len() works for generators and other unsized iterables.
            floats_strength = list(floats_strength)
        else:
            raise TypeError(f"floats_strength must be either an iterable input or a float, but was {type(floats_strength).__name__}.")

        if prev_hook_kf is None:
            prev_hook_kf = comfy.hooks.HookKeyframeGroup()
        prev_hook_kf = prev_hook_kf.clone()

        percents = comfy.hooks.InterpolationMethod.get_weights(num_from=start_percent, num_to=end_percent, length=len(floats_strength),
                                                               method=comfy.hooks.InterpolationMethod.LINEAR)

        for position, (percent, strength) in enumerate(zip(percents, floats_strength)):
            guarantee_steps = 1 if position == 0 else 0
            prev_hook_kf.add(comfy.hooks.HookKeyframe(strength=strength, start_percent=percent, guarantee_steps=guarantee_steps))
            if print_keyframes:
                print(f"Hook Keyframe - start_percent:{percent} = {strength}")
        return (prev_hook_kf,)
|
||||
#------------------------------------------
|
||||
###########################################
|
||||
|
||||
|
||||
class SetModelHooksOnCond:
    """Node that attaches a hook group to a conditioning input."""

    EXPERIMENTAL = True
    RETURN_TYPES = ("CONDITIONING",)
    CATEGORY = "advanced/hooks/manual"
    FUNCTION = "attach_hook"

    @classmethod
    def INPUT_TYPES(s):
        required_inputs = {
            "conditioning": ("CONDITIONING",),
            "hooks": ("HOOKS",),
        }
        return {"required": required_inputs}

    def attach_hook(self, conditioning, hooks: comfy.hooks.HookGroup):
        """Return a 1-tuple with the result of ``set_hooks_for_conditioning``."""
        hooked_conditioning = comfy.hooks.set_hooks_for_conditioning(conditioning, hooks)
        return (hooked_conditioning,)
|
||||
|
||||
|
||||
###########################################
|
||||
# Combine Hooks
|
||||
#------------------------------------------
|
||||
class CombineHooks:
    """Node that combines up to two optional hook groups into one."""
    NodeId = 'CombineHooks2'
    NodeName = 'Combine Hooks [2]'

    EXPERIMENTAL = True
    RETURN_TYPES = ("HOOKS",)
    CATEGORY = "advanced/hooks/combine"
    FUNCTION = "combine_hooks"

    @classmethod
    def INPUT_TYPES(s):
        # Every input is optional; unconnected ones arrive as None.
        optional_inputs = {f"hooks_{letter}": ("HOOKS",) for letter in "AB"}
        return {"required": {}, "optional": optional_inputs}

    def combine_hooks(self,
                      hooks_A: comfy.hooks.HookGroup=None,
                      hooks_B: comfy.hooks.HookGroup=None):
        """Pass both inputs (possibly None) to ``HookGroup.combine_all_hooks``."""
        combined = comfy.hooks.HookGroup.combine_all_hooks([hooks_A, hooks_B])
        return (combined,)
|
||||
|
||||
class CombineHooksFour:
    """Node that combines up to four optional hook groups into one."""
    NodeId = 'CombineHooks4'
    NodeName = 'Combine Hooks [4]'

    EXPERIMENTAL = True
    RETURN_TYPES = ("HOOKS",)
    CATEGORY = "advanced/hooks/combine"
    FUNCTION = "combine_hooks"

    @classmethod
    def INPUT_TYPES(s):
        # Every input is optional; unconnected ones arrive as None.
        optional_inputs = {f"hooks_{letter}": ("HOOKS",) for letter in "ABCD"}
        return {"required": {}, "optional": optional_inputs}

    def combine_hooks(self,
                      hooks_A: comfy.hooks.HookGroup=None,
                      hooks_B: comfy.hooks.HookGroup=None,
                      hooks_C: comfy.hooks.HookGroup=None,
                      hooks_D: comfy.hooks.HookGroup=None):
        """Pass all four inputs (possibly None) to ``HookGroup.combine_all_hooks``."""
        combined = comfy.hooks.HookGroup.combine_all_hooks([hooks_A, hooks_B, hooks_C, hooks_D])
        return (combined,)
|
||||
|
||||
class CombineHooksEight:
    """Node that combines up to eight optional hook groups into one."""
    NodeId = 'CombineHooks8'
    NodeName = 'Combine Hooks [8]'

    EXPERIMENTAL = True
    RETURN_TYPES = ("HOOKS",)
    CATEGORY = "advanced/hooks/combine"
    FUNCTION = "combine_hooks"

    @classmethod
    def INPUT_TYPES(s):
        # Every input is optional; unconnected ones arrive as None.
        optional_inputs = {f"hooks_{letter}": ("HOOKS",) for letter in "ABCDEFGH"}
        return {"required": {}, "optional": optional_inputs}

    def combine_hooks(self,
                      hooks_A: comfy.hooks.HookGroup=None,
                      hooks_B: comfy.hooks.HookGroup=None,
                      hooks_C: comfy.hooks.HookGroup=None,
                      hooks_D: comfy.hooks.HookGroup=None,
                      hooks_E: comfy.hooks.HookGroup=None,
                      hooks_F: comfy.hooks.HookGroup=None,
                      hooks_G: comfy.hooks.HookGroup=None,
                      hooks_H: comfy.hooks.HookGroup=None):
        """Pass all eight inputs (possibly None) to ``HookGroup.combine_all_hooks``."""
        combined = comfy.hooks.HookGroup.combine_all_hooks(
            [hooks_A, hooks_B, hooks_C, hooks_D, hooks_E, hooks_F, hooks_G, hooks_H]
        )
        return (combined,)
|
||||
#------------------------------------------
|
||||
###########################################
|
||||
|
||||
# Every node class listed here is registered under its own NodeId / NodeName.
node_list = [
    # Create
    CreateHookLora,
    CreateHookLoraModelOnly,
    CreateHookModelAsLora,
    CreateHookModelAsLoraModelOnly,
    # Scheduling
    SetHookKeyframes,
    CreateHookKeyframe,
    CreateHookKeyframesInterpolated,
    CreateHookKeyframesFromFloats,
    # Combine
    CombineHooks,
    CombineHooksFour,
    CombineHooksEight,
    # Attach
    ConditioningSetProperties,
    ConditioningSetPropertiesAndCombine,
    PairConditioningSetProperties,
    PairConditioningSetPropertiesAndCombine,
    ConditioningSetDefaultAndCombine,
    PairConditioningSetDefaultAndCombine,
    PairConditioningCombine,
    SetClipHooks,
    # Other
    ConditioningTimestepsRange,
]

# NodeId -> node class, and NodeId -> display name, in node_list order.
NODE_CLASS_MAPPINGS = {node_cls.NodeId: node_cls for node_cls in node_list}
NODE_DISPLAY_NAME_MAPPINGS = {node_cls.NodeId: node_cls.NodeName for node_cls in node_list}
|
||||
@@ -15,9 +15,7 @@ class CLIPTextEncodeHunyuanDiT:
|
||||
tokens = clip.tokenize(bert)
|
||||
tokens["mt5xl"] = clip.tokenize(mt5xl)["mt5xl"]
|
||||
|
||||
output = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True)
|
||||
cond = output.pop("cond")
|
||||
return ([[cond, output]], )
|
||||
return (clip.encode_from_tokens_scheduled(tokens), )
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
|
||||
@@ -107,7 +107,7 @@ class HypernetworkLoader:
|
||||
CATEGORY = "loaders"
|
||||
|
||||
def load_hypernetwork(self, model, hypernetwork_name, strength):
|
||||
hypernetwork_path = folder_paths.get_full_path("hypernetworks", hypernetwork_name)
|
||||
hypernetwork_path = folder_paths.get_full_path_or_raise("hypernetworks", hypernetwork_name)
|
||||
model_hypernetwork = model.clone()
|
||||
patch = load_hypernetwork_patch(hypernetwork_path, strength)
|
||||
if patch is not None:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import comfy.utils
|
||||
import comfy_extras.nodes_post_processing
|
||||
import torch
|
||||
|
||||
def reshape_latent_to(target_shape, latent):
|
||||
@@ -145,6 +146,131 @@ class LatentBatchSeedBehavior:
|
||||
|
||||
return (samples_out,)
|
||||
|
||||
class LatentApplyOperation:
    """Node that applies a latent operation callable to a latent's "samples" entry."""

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "op"
    CATEGORY = "latent/advanced/operations"
    EXPERIMENTAL = True

    @classmethod
    def INPUT_TYPES(s):
        required_inputs = {
            "samples": ("LATENT",),
            "operation": ("LATENT_OPERATION",),
        }
        return {"required": required_inputs}

    def op(self, samples, operation):
        """Return a shallow copy of ``samples`` whose "samples" entry is ``operation(latent=...)``."""
        result = samples.copy()
        # The operation is called with the latent as the ``latent`` keyword argument.
        result["samples"] = operation(latent=samples["samples"])
        return (result,)
|
||||
|
||||
class LatentApplyOperationCFG:
    """Node that registers a latent operation as a pre-CFG function on a cloned model."""

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    CATEGORY = "latent/advanced/operations"
    EXPERIMENTAL = True

    @classmethod
    def INPUT_TYPES(s):
        required_inputs = {
            "model": ("MODEL",),
            "operation": ("LATENT_OPERATION",),
        }
        return {"required": required_inputs}

    def patch(self, model, operation):
        """Clone ``model`` and install a pre-CFG function that applies ``operation``."""
        patched = model.clone()

        def pre_cfg_function(args):
            conds_out = args["conds_out"]
            if len(conds_out) != 2:
                # Anything other than exactly two entries: transform the first entry directly.
                conds_out[0] = operation(latent=conds_out[0])
                return conds_out
            # Exactly two entries: transform the difference between them,
            # then add the second entry back.
            difference = conds_out[0] - conds_out[1]
            conds_out[0] = operation(latent=difference) + conds_out[1]
            return conds_out

        patched.set_model_sampler_pre_cfg_function(pre_cfg_function)
        return (patched, )
|
||||
|
||||
class LatentOperationTonemapReinhard:
    """Node that produces a Reinhard-style tonemapping latent operation."""

    RETURN_TYPES = ("LATENT_OPERATION",)
    FUNCTION = "op"
    CATEGORY = "latent/advanced/operations"
    EXPERIMENTAL = True

    @classmethod
    def INPUT_TYPES(s):
        multiplier_options = {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}
        return {"required": {"multiplier": ("FLOAT", multiplier_options)}}

    def op(self, multiplier):
        """Return a 1-tuple holding the tonemapping callable configured with ``multiplier``."""
        def tonemap_reinhard(latent, **kwargs):
            # Norm over dim 1, kept as a size-1 dim so it broadcasts against latent.
            magnitude = (torch.linalg.vector_norm(latent, dim=(1)) + 0.0000000001)[:, None]
            direction = latent / magnitude

            mean = torch.mean(magnitude, dim=(1, 2, 3), keepdim=True)
            std = torch.std(magnitude, dim=(1, 2, 3), keepdim=True)
            top = (std * 5 + mean) * multiplier

            # reinhard: x / (x + 1) applied to the magnitude scaled by 1 / top,
            # then scaled back up by top.
            scaled = magnitude * (1.0 / top)
            compressed = (scaled / (scaled + 1.0)) * top

            return direction * compressed
        return (tonemap_reinhard,)
|
||||
|
||||
class LatentOperationSharpen:
    """Node that produces a sharpening latent operation built on a gaussian kernel."""

    RETURN_TYPES = ("LATENT_OPERATION",)
    FUNCTION = "op"
    CATEGORY = "latent/advanced/operations"
    EXPERIMENTAL = True

    @classmethod
    def INPUT_TYPES(s):
        radius_options = {"default": 9, "min": 1, "max": 31, "step": 1}
        sigma_options = {"default": 1.0, "min": 0.1, "max": 10.0, "step": 0.1}
        alpha_options = {"default": 0.1, "min": 0.0, "max": 5.0, "step": 0.01}
        return {"required": {
            "sharpen_radius": ("INT", radius_options),
            "sigma": ("FLOAT", sigma_options),
            "alpha": ("FLOAT", alpha_options),
        }}

    def op(self, sharpen_radius, sigma, alpha):
        """Return a 1-tuple holding the sharpening callable for the given settings."""
        def sharpen(latent, **kwargs):
            # Norm over dim 1, kept as a size-1 dim so it broadcasts against latent.
            luminance = (torch.linalg.vector_norm(latent, dim=(1)) + 1e-6)[:, None]
            unit_latent = latent / luminance
            channels = latent.shape[1]

            kernel_size = sharpen_radius * 2 + 1
            center = kernel_size // 2
            kernel = comfy_extras.nodes_post_processing.gaussian_kernel(kernel_size, sigma, device=luminance.device)

            # Scale the kernel by alpha * -10, then adjust the centre tap so
            # the whole kernel sums to 1.
            kernel *= alpha * -10
            kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0

            pad_widths = (sharpen_radius, sharpen_radius, sharpen_radius, sharpen_radius)
            padded = torch.nn.functional.pad(unit_latent, pad_widths, 'reflect')
            per_channel_kernel = kernel.repeat(channels, 1, 1).unsqueeze(1)
            convolved = torch.nn.functional.conv2d(padded, per_channel_kernel, padding=kernel_size // 2, groups=channels)
            # Crop the reflect-padded border back off.
            sharpened = convolved[:, :, sharpen_radius:-sharpen_radius, sharpen_radius:-sharpen_radius]

            return luminance * sharpened
        return (sharpen,)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"LatentAdd": LatentAdd,
|
||||
"LatentSubtract": LatentSubtract,
|
||||
@@ -152,4 +278,8 @@ NODE_CLASS_MAPPINGS = {
|
||||
"LatentInterpolate": LatentInterpolate,
|
||||
"LatentBatch": LatentBatch,
|
||||
"LatentBatchSeedBehavior": LatentBatchSeedBehavior,
|
||||
"LatentApplyOperation": LatentApplyOperation,
|
||||
"LatentApplyOperationCFG": LatentApplyOperationCFG,
|
||||
"LatentOperationTonemapReinhard": LatentOperationTonemapReinhard,
|
||||
"LatentOperationSharpen": LatentOperationSharpen,
|
||||
}
|
||||
|
||||
119
comfy_extras/nodes_lora_extract.py
Normal file
119
comfy_extras/nodes_lora_extract.py
Normal file
@@ -0,0 +1,119 @@
|
||||
import torch
|
||||
import comfy.model_management
|
||||
import comfy.utils
|
||||
import folder_paths
|
||||
import os
|
||||
import logging
|
||||
from enum import Enum
|
||||
|
||||
# Quantile of the factor values used as the symmetric clamp bound in extract_lora.
CLAMP_QUANTILE = 0.99

def extract_lora(diff, rank):
    """Factor a weight difference into a low-rank (up, down) pair via SVD.

    Args:
        diff: 2-D weight difference ``(out_dim, in_dim)`` or 4-D conv weight
            difference ``(out_dim, in_dim, kh, kw)``.
        rank: requested rank; capped at ``min(in_dim, out_dim)``.

    Returns:
        ``(U, Vh)``. For 2-D input the shapes are ``(out_dim, rank)`` and
        ``(rank, in_dim)``; for 4-D input they are ``(out_dim, rank, 1, 1)``
        and ``(rank, in_dim, kh, kw)``. The singular values are folded into
        ``U``, and both factors are clamped to ``[-q, q]`` where ``q`` is the
        ``CLAMP_QUANTILE`` quantile of their combined values.
    """
    conv2d = (len(diff.shape) == 4)
    kernel_size = None if not conv2d else diff.size()[2:4]
    out_dim, in_dim = diff.size()[0:2]
    rank = min(rank, in_dim, out_dim)

    if conv2d:
        # Collapse (in_dim, kh, kw) into one axis. This covers 1x1 kernels too;
        # squeeze() would also drop out_dim/in_dim when either equals 1 and
        # hand a non-2-D tensor to the SVD.
        diff = diff.flatten(start_dim=1)

    # Reduced SVD: only the leading `rank` components are kept below, so the
    # full square U / Vh matrices are never needed.
    U, S, Vh = torch.linalg.svd(diff.float(), full_matrices=False)
    U = U[:, :rank]
    S = S[:rank]
    U = U @ torch.diag(S)
    Vh = Vh[:rank, :]

    dist = torch.cat([U.flatten(), Vh.flatten()])
    hi_val = torch.quantile(dist, CLAMP_QUANTILE)
    low_val = -hi_val

    U = U.clamp(low_val, hi_val)
    Vh = Vh.clamp(low_val, hi_val)
    if conv2d:
        U = U.reshape(out_dim, rank, 1, 1)
        Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
    return (U, Vh)
|
||||
|
||||
class LORAType(Enum):
    """Kinds of output that calc_lora_model can produce."""
    STANDARD = 0   # handled by the low-rank extraction branch of calc_lora_model
    FULL_DIFF = 1  # handled by the branch that stores the weight difference as-is

# Lower-cased member names mapped to the enum members, in declaration order.
LORA_TYPES = {member.name.lower(): member for member in LORAType}
|
||||
|
||||
def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora_type, bias_diff=False):
    """Convert the weights of a "difference" model into lora entries in ``output_sd``.

    Args:
        model_diff: model patcher whose state dict is read through
            ``model_state_dict(filter_prefix=prefix_model)``.
        rank: rank passed to ``extract_lora`` for ``LORAType.STANDARD``.
        prefix_model: key prefix stripped from every state-dict key.
        prefix_lora: key prefix prepended to every output key.
        output_sd: dict that receives the results; it is modified in place.
        lora_type: ``LORAType.STANDARD`` (low-rank factors) or
            ``LORAType.FULL_DIFF`` (the weight difference stored as ``.diff``).
        bias_diff: when True, also store ``.bias`` tensors as ``.diff_b`` and,
            in STANDARD mode, weights with fewer than 2 dims as ``.diff``.

    Returns:
        The same ``output_sd`` object.
    """
    comfy.model_management.load_models_gpu([model_diff], force_patch_weights=True)
    sd = model_diff.model_state_dict(filter_prefix=prefix_model)

    def to_saved(tensor):
        # Every stored tensor is made contiguous, cast to half precision and moved to the CPU.
        return tensor.contiguous().half().cpu()

    for k in sd:
        if k.endswith(".weight"):
            weight_diff = sd[k]
            lora_key = "{}{}".format(prefix_lora, k[len(prefix_model):-len(".weight")])
            if lora_type == LORAType.STANDARD:
                if weight_diff.ndim < 2:
                    # Weights with fewer than 2 dims cannot be factored; keep them whole if requested.
                    if bias_diff:
                        output_sd["{}.diff".format(lora_key)] = to_saved(weight_diff)
                    continue
                try:
                    out = extract_lora(weight_diff, rank)
                    lora_up = to_saved(out[0])
                    lora_down = to_saved(out[1])
                except Exception as e:
                    # Best effort: skip this key but report the real error.
                    # A bare except would also swallow KeyboardInterrupt/SystemExit.
                    logging.warning("Could not generate lora weights for key {}, is the weight difference a zero? ({})".format(k, e))
                else:
                    # Store both halves together so a failure never leaves an
                    # up weight without its matching down weight.
                    output_sd["{}.lora_up.weight".format(lora_key)] = lora_up
                    output_sd["{}.lora_down.weight".format(lora_key)] = lora_down
            elif lora_type == LORAType.FULL_DIFF:
                output_sd["{}.diff".format(lora_key)] = to_saved(weight_diff)

        elif bias_diff and k.endswith(".bias"):
            output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-len(".bias")])] = to_saved(sd[k])
    return output_sd
|
||||
|
||||
class LoraSave:
    """Output node that converts model / text-encoder differences to a lora and saves it."""

    RETURN_TYPES = ()
    FUNCTION = "save"
    OUTPUT_NODE = True
    CATEGORY = "_for_testing"

    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()

    @classmethod
    def INPUT_TYPES(s):
        required_inputs = {
            "filename_prefix": ("STRING", {"default": "loras/ComfyUI_extracted_lora"}),
            "rank": ("INT", {"default": 8, "min": 1, "max": 4096, "step": 1}),
            "lora_type": (tuple(LORA_TYPES.keys()),),
            "bias_diff": ("BOOLEAN", {"default": True}),
        }
        optional_inputs = {
            "model_diff": ("MODEL", {"tooltip": "The ModelSubtract output to be converted to a lora."}),
            "text_encoder_diff": ("CLIP", {"tooltip": "The CLIPSubtract output to be converted to a lora."}),
        }
        return {"required": required_inputs, "optional": optional_inputs}

    def save(self, filename_prefix, rank, lora_type, bias_diff, model_diff=None, text_encoder_diff=None):
        """Build the lora state dict from the given differences and write it to a .safetensors file.

        Returns an empty dict; nothing is written when both inputs are None.
        """
        nothing_to_save = model_diff is None and text_encoder_diff is None
        if nothing_to_save:
            return {}

        lora_type = LORA_TYPES.get(lora_type)
        save_path_info = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
        full_output_folder, filename, counter, subfolder, filename_prefix = save_path_info

        output_sd = {}
        if model_diff is not None:
            output_sd = calc_lora_model(model_diff, rank, "diffusion_model.", "diffusion_model.",
                                        output_sd, lora_type, bias_diff=bias_diff)
        if text_encoder_diff is not None:
            # For the text encoder, the object passed on is its ``patcher`` attribute.
            output_sd = calc_lora_model(text_encoder_diff.patcher, rank, "", "text_encoders.",
                                        output_sd, lora_type, bias_diff=bias_diff)

        output_checkpoint = os.path.join(full_output_folder, f"{filename}_{counter:05}_.safetensors")

        comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=None)
        return {}
|
||||
|
||||
# Node id -> node class.
NODE_CLASS_MAPPINGS = dict(LoraSave=LoraSave)

# Node id -> display name.
NODE_DISPLAY_NAME_MAPPINGS = dict(LoraSave="Extract and Save Lora")
|
||||
181
comfy_extras/nodes_lt.py
Normal file
181
comfy_extras/nodes_lt.py
Normal file
@@ -0,0 +1,181 @@
|
||||
import nodes
|
||||
import node_helpers
|
||||
import torch
|
||||
import comfy.model_management
|
||||
import comfy.model_sampling
|
||||
import math
|
||||
|
||||
class EmptyLTXVLatentVideo:
    """Node that creates an all-zero LTXV video latent of the requested size."""

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "generate"
    CATEGORY = "latent/video/ltxv"

    @classmethod
    def INPUT_TYPES(s):
        required_inputs = {
            "width": ("INT", {"default": 768, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
            "height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
            "length": ("INT", {"default": 97, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 8}),
            "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
        }
        return {"required": required_inputs}

    def generate(self, width, height, length, batch_size=1):
        """Return ``({"samples": zeros},)`` on the intermediate device.

        The tensor shape is
        ``[batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32]``.
        """
        latent_frames = ((length - 1) // 8) + 1
        latent_shape = [batch_size, 128, latent_frames, height // 32, width // 32]
        target_device = comfy.model_management.intermediate_device()
        return ({"samples": torch.zeros(latent_shape, device=target_device)}, )
|
||||
|
||||
|
||||
class LTXVImgToVideo:
    """Node that prepares LTXV image-to-video conditioning and the starting latent."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"positive": ("CONDITIONING", ),
                             "negative": ("CONDITIONING", ),
                             "vae": ("VAE",),
                             "image": ("IMAGE",),
                             "width": ("INT", {"default": 768, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
                             "height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
                             "length": ("INT", {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8}),
                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}

    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
    RETURN_NAMES = ("positive", "negative", "latent")

    CATEGORY = "conditioning/video_models"
    FUNCTION = "generate"

    def generate(self, positive, negative, image, vae, width, height, length, batch_size):
        """Encode ``image`` and return ``(positive, negative, latent)``.

        The image is resized to ``width`` x ``height`` (bilinear, center crop),
        its first three channels are VAE-encoded, and the encoding is stored as
        ``guiding_latent`` on both conditionings. The returned latent is a zero
        tensor of shape
        ``[batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32]``
        whose leading entries along dim 2 are overwritten with the encoding.
        """
        # Import explicitly: the import block visible for this file covers
        # comfy.model_management and comfy.model_sampling but not comfy.utils,
        # so relying on the attribute only works if another module loaded it first.
        from comfy.utils import common_upscale

        pixels = common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
        # Keep only the first three channels for encoding.
        encode_pixels = pixels[:, :, :, :3]
        t = vae.encode(encode_pixels)
        positive = node_helpers.conditioning_set_values(positive, {"guiding_latent": t})
        negative = node_helpers.conditioning_set_values(negative, {"guiding_latent": t})

        latent = torch.zeros([batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32], device=comfy.model_management.intermediate_device())
        latent[:, :, :t.shape[2]] = t
        return (positive, negative, {"samples": latent}, )
|
||||
|
||||
|
||||
class LTXVConditioning:
    """Node that stores a frame rate on the positive and negative conditioning."""

    RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
    RETURN_NAMES = ("positive", "negative")
    FUNCTION = "append"
    CATEGORY = "conditioning/video_models"

    @classmethod
    def INPUT_TYPES(s):
        required_inputs = {
            "positive": ("CONDITIONING", ),
            "negative": ("CONDITIONING", ),
            "frame_rate": ("FLOAT", {"default": 25.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
        }
        return {"required": required_inputs}

    def append(self, positive, negative, frame_rate):
        """Return both conditionings with ``frame_rate`` set via ``conditioning_set_values``."""
        with_rate_positive = node_helpers.conditioning_set_values(positive, {"frame_rate": frame_rate})
        with_rate_negative = node_helpers.conditioning_set_values(negative, {"frame_rate": frame_rate})
        return (with_rate_positive, with_rate_negative)
|
||||
|
||||
|
||||
class ModelSamplingLTXV:
    """Node that patches a model's ``model_sampling`` object with a token-dependent shift."""

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    CATEGORY = "advanced/model"

    @classmethod
    def INPUT_TYPES(s):
        required_inputs = {
            "model": ("MODEL",),
            "max_shift": ("FLOAT", {"default": 2.05, "min": 0.0, "max": 100.0, "step":0.01}),
            "base_shift": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 100.0, "step":0.01}),
        }
        return {"required": required_inputs, "optional": {"latent": ("LATENT",), }}

    def patch(self, model, max_shift, base_shift, latent=None):
        """Return a clone of ``model`` whose ``model_sampling`` uses the computed shift.

        The shift is a linear function of the token count that equals
        ``base_shift`` at 1024 tokens and ``max_shift`` at 4096 tokens. The
        token count is the product of the latent's dims from index 2 onward,
        or 4096 when no latent is given.
        """
        patched = model.clone()

        tokens = 4096 if latent is None else math.prod(latent["samples"].shape[2:])

        # Line through (1024, base_shift) and (4096, max_shift).
        low_tokens = 1024
        high_tokens = 4096
        slope = (max_shift - base_shift) / (high_tokens - low_tokens)
        intercept = base_shift - slope * low_tokens
        shift = (tokens) * slope + intercept

        class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingFlux, comfy.model_sampling.CONST):
            pass

        model_sampling = ModelSamplingAdvanced(model.model.model_config)
        model_sampling.set_parameters(shift=shift)
        patched.add_object_patch("model_sampling", model_sampling)
        return (patched, )
|
||||
|
||||
|
||||
class LTXVScheduler:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
{"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
|
||||
"max_shift": ("FLOAT", {"default": 2.05, "min": 0.0, "max": 100.0, "step":0.01}),
|
||||
"base_shift": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 100.0, "step":0.01}),
|
||||
"stretch": ("BOOLEAN", {
|
||||
"default": True,
|
||||
"tooltip": "Stretch the sigmas to be in the range [terminal, 1]."
|
||||
}),
|
||||
"terminal": (
|
||||
"FLOAT",
|
||||
{
|
||||
"default": 0.1, "min": 0.0, "max": 0.99, "step": 0.01,
|
||||
"tooltip": "The terminal value of the sigmas after stretching."
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {"latent": ("LATENT",), }
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("SIGMAS",)
|
||||
CATEGORY = "sampling/custom_sampling/schedulers"
|
||||
|
||||
FUNCTION = "get_sigmas"
|
||||
|
||||
def get_sigmas(self, steps, max_shift, base_shift, stretch, terminal, latent=None):
|
||||
if latent is None:
|
||||
tokens = 4096
|
||||
else:
|
||||
tokens = math.prod(latent["samples"].shape[2:])
|
||||
|
||||
sigmas = torch.linspace(1.0, 0.0, steps + 1)
|
||||
|
||||
x1 = 1024
|
||||
x2 = 4096
|
||||
mm = (max_shift - base_shift) / (x2 - x1)
|
||||
b = base_shift - mm * x1
|
||||
sigma_shift = (tokens) * mm + b
|
||||
|
||||
power = 1
|
||||
sigmas = torch.where(
|
||||
sigmas != 0,
|
||||
math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1) ** power),
|
||||
0,
|
||||
)
|
||||
|
||||
# Stretch sigmas so that its final value matches the given terminal value.
|
||||
if stretch:
|
||||
non_zero_mask = sigmas != 0
|
||||
non_zero_sigmas = sigmas[non_zero_mask]
|
||||
one_minus_z = 1.0 - non_zero_sigmas
|
||||
scale_factor = one_minus_z[-1] / (1.0 - terminal)
|
||||
stretched = 1.0 - (one_minus_z / scale_factor)
|
||||
sigmas[non_zero_mask] = stretched
|
||||
|
||||
return (sigmas,)
|
||||
|
||||
|
||||
# Each node is registered under its own class name, in this order.
NODE_CLASS_MAPPINGS = {
    node_cls.__name__: node_cls
    for node_cls in (
        EmptyLTXVLatentVideo,
        LTXVImgToVideo,
        ModelSamplingLTXV,
        LTXVConditioning,
        LTXVScheduler,
    )
}
|
||||
23
comfy_extras/nodes_mochi.py
Normal file
23
comfy_extras/nodes_mochi.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import nodes
|
||||
import torch
|
||||
import comfy.model_management
|
||||
|
||||
class EmptyMochiLatentVideo:
    """Node that creates an all-zero Mochi video latent of the requested size."""

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "generate"
    CATEGORY = "latent/video"

    @classmethod
    def INPUT_TYPES(s):
        required_inputs = {
            "width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
            "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
            "length": ("INT", {"default": 25, "min": 7, "max": nodes.MAX_RESOLUTION, "step": 6}),
            "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
        }
        return {"required": required_inputs}

    def generate(self, width, height, length, batch_size=1):
        """Return ``({"samples": zeros},)`` on the intermediate device.

        The tensor shape is
        ``[batch_size, 12, ((length - 1) // 6) + 1, height // 8, width // 8]``.
        """
        latent_frames = ((length - 1) // 6) + 1
        latent_shape = [batch_size, 12, latent_frames, height // 8, width // 8]
        target_device = comfy.model_management.intermediate_device()
        return ({"samples": torch.zeros(latent_shape, device=target_device)}, )
|
||||
|
||||
# Node id -> node class.
NODE_CLASS_MAPPINGS = dict(EmptyMochiLatentVideo=EmptyMochiLatentVideo)
|
||||
@@ -26,8 +26,8 @@ class X0(comfy.model_sampling.EPS):
|
||||
class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete):
|
||||
original_timesteps = 50
|
||||
|
||||
def __init__(self, model_config=None):
|
||||
super().__init__(model_config)
|
||||
def __init__(self, model_config=None, zsnr=None):
|
||||
super().__init__(model_config, zsnr=zsnr)
|
||||
|
||||
self.skip_steps = self.num_timesteps // self.original_timesteps
|
||||
|
||||
@@ -51,25 +51,6 @@ class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete)
|
||||
return log_sigma.exp().to(timestep.device)
|
||||
|
||||
|
||||
def rescale_zero_terminal_snr_sigmas(sigmas):
    """Rescale a sigma schedule so its last timestep has (almost) zero SNR.

    The schedule is converted to sqrt(alpha_bar), shifted so the last entry is
    zero, rescaled so the first entry keeps its original value, and converted
    back to sigmas. The last alpha_bar is pinned to a tiny positive constant,
    which keeps the final division finite. ``sigmas`` itself is not modified.
    """
    sqrt_alpha_bar = (1 / ((sigmas * sigmas) + 1)).sqrt()

    # Endpoints of the original schedule.
    first = sqrt_alpha_bar[0]
    last = sqrt_alpha_bar[-1]

    # Shift so the last timestep is zero, then scale so the first timestep
    # is back to its old value.
    shifted = sqrt_alpha_bar - last
    rescaled = shifted * (first / (first - last))

    # Back from sqrt(alpha_bar) to alpha_bar, then to sigma.
    alphas_bar = rescaled ** 2
    alphas_bar[-1] = 4.8973451890853435e-08
    return ((1 - alphas_bar) / alphas_bar) ** 0.5
|
||||
|
||||
class ModelSamplingDiscrete:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@@ -100,9 +81,7 @@ class ModelSamplingDiscrete:
|
||||
class ModelSamplingAdvanced(sampling_base, sampling_type):
|
||||
pass
|
||||
|
||||
model_sampling = ModelSamplingAdvanced(model.model.model_config)
|
||||
if zsnr:
|
||||
model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas))
|
||||
model_sampling = ModelSamplingAdvanced(model.model.model_config, zsnr=zsnr)
|
||||
|
||||
m.add_object_patch("model_sampling", model_sampling)
|
||||
return (m, )
|
||||
|
||||
@@ -17,7 +17,7 @@ class PatchModelAddDownscale:
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
|
||||
CATEGORY = "_for_testing"
|
||||
CATEGORY = "model_patches/unet"
|
||||
|
||||
def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method):
|
||||
model_sampling = model.get_model_object("model_sampling")
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user