Compare commits
198 Commits
@@ -14,7 +14,7 @@ run_cpu.bat

 IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints

-You can download the stable diffusion 1.5 one from: https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt
+You can download the stable diffusion 1.5 one from: https://huggingface.co/Comfy-Org/stable-diffusion-v1-5-archive/blob/main/v1-5-pruned-emaonly-fp16.safetensors


 RECOMMENDED WAY TO UPDATE:
.github/workflows/pullrequest-ci-run.yml (vendored, 2 changed lines)

@@ -23,7 +23,7 @@ jobs:
         runner_label: [self-hosted, Linux]
         flags: ""
       - os: windows
-        runner_label: [self-hosted, win]
+        runner_label: [self-hosted, Windows]
         flags: ""
     runs-on: ${{ matrix.runner_label }}
     steps:
.github/workflows/stable-release.yml (vendored, 6 changed lines)

@@ -12,17 +12,17 @@ on:
         description: 'CUDA version'
         required: true
         type: string
-        default: "121"
+        default: "124"
       python_minor:
         description: 'Python minor version'
         required: true
         type: string
-        default: "11"
+        default: "12"
       python_patch:
         description: 'Python patch version'
         required: true
         type: string
-        default: "9"
+        default: "7"


 jobs:
.github/workflows/stale-issues.yml (vendored, new file, 21 lines)

@@ -0,0 +1,21 @@
+name: 'Close stale issues'
+on:
+  schedule:
+    # Run daily at 430 am PT
+    - cron: '30 11 * * *'
+permissions:
+  issues: write
+
+jobs:
+  stale:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/stale@v9
+        with:
+          stale-issue-message: "This issue is being marked stale because it has not had any activity for 30 days. Reply below within 7 days if your issue still isn't solved, and it will be left open. Otherwise, the issue will be closed automatically."
+          days-before-stale: 30
+          days-before-close: 7
+          stale-issue-label: 'Stale'
+          only-labels: 'User Support'
+          exempt-all-assignees: true
+          exempt-all-milestones: true
.github/workflows/test-ci.yml (vendored, 4 changed lines)

@@ -32,7 +32,7 @@ jobs:
         runner_label: [self-hosted, Linux]
         flags: ""
       - os: windows
-        runner_label: [self-hosted, win]
+        runner_label: [self-hosted, Windows]
         flags: ""
     runs-on: ${{ matrix.runner_label }}
     steps:
@@ -55,7 +55,7 @@ jobs:
         torch_version: ["nightly"]
       include:
       - os: windows
-        runner_label: [self-hosted, win]
+        runner_label: [self-hosted, Windows]
         flags: ""
     runs-on: ${{ matrix.runner_label }}
     steps:
.github/workflows/test-unit.yml (vendored, new file, 30 lines)

@@ -0,0 +1,30 @@
+name: Unit Tests
+
+on:
+  push:
+    branches: [ main, master ]
+  pull_request:
+    branches: [ main, master ]
+
+jobs:
+  test:
+    strategy:
+      matrix:
+        os: [ubuntu-latest, windows-latest, macos-latest]
+    runs-on: ${{ matrix.os }}
+    continue-on-error: true
+    steps:
+      - uses: actions/checkout@v4
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: '3.10'
+      - name: Install requirements
+        run: |
+          python -m pip install --upgrade pip
+          pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
+          pip install -r requirements.txt
+      - name: Run Unit Tests
+        run: |
+          pip install -r tests-unit/requirements.txt
+          python -m pytest tests-unit
Another workflow file:

@@ -12,7 +12,7 @@ on:
         description: 'extra dependencies'
         required: false
         type: string
-        default: "\"numpy<2\""
+        default: ""
       cu:
         description: 'cuda version'
         required: true
@@ -23,13 +23,13 @@ on:
         description: 'python minor version'
         required: true
         type: string
-        default: "11"
+        default: "12"

       python_patch:
         description: 'python patch version'
         required: true
         type: string
-        default: "9"
+        default: "7"
 # push:
 #   branches:
 #     - master
Another workflow file:

@@ -13,13 +13,13 @@ on:
         description: 'python minor version'
         required: true
         type: string
-        default: "11"
+        default: "12"

       python_patch:
         description: 'python patch version'
         required: true
         type: string
-        default: "9"
+        default: "7"
 # push:
 #   branches:
 #     - master
.gitignore (vendored, 1 changed line)

@@ -12,6 +12,7 @@ extra_model_paths.yaml
 .vscode/
 .idea/
 venv/
+.venv/
 /web/extensions/*
 !/web/extensions/logging.js.example
 !/web/extensions/core/
README.md (17 changed lines)

@@ -1,7 +1,7 @@
 <div align="center">

 # ComfyUI
-**The most powerful and modular stable diffusion GUI and backend.**
+**The most powerful and modular diffusion model GUI and backend.**


 [![Website][website-shield]][website-url]
@@ -40,6 +40,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
 - Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
 - Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/), [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/), [SD3](https://comfyanonymous.github.io/ComfyUI_examples/sd3/) and [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
 - [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
+- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
 - Asynchronous Queue system
 - Many optimizations: Only re-executes the parts of the workflow that changes between executions.
 - Smart memory management: can automatically run models on GPUs with as low as 1GB vram.
@@ -94,6 +95,8 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
 | Alt + `+` | Canvas Zoom in |
 | Alt + `-` | Canvas Zoom out |
 | Ctrl + Shift + LMB + Vertical drag | Canvas Zoom in/out |
+| P | Pin/Unpin selected nodes |
+| Ctrl + G | Group selected nodes |
 | Q | Toggle visibility of the queue |
 | H | Toggle visibility of history |
 | R | Refresh graph |
@@ -125,6 +128,8 @@ To run it on services like paperspace, kaggle or colab you can use my [Jupyter N

 ## Manual Install (Windows, Linux)

+Note that some dependencies do not yet support python 3.13 so using 3.12 is recommended.
+
 Git clone this repo.

 Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints
@@ -135,17 +140,17 @@ Put your VAE in: models/vae
 ### AMD GPUs (Linux only)
 AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:

-```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.0```
+```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.1```

-This is the command to install the nightly with ROCm 6.0 which might have some performance improvements:
+This is the command to install the nightly with ROCm 6.2 which might have some performance improvements:

-```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.1```
+```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.2```

 ### NVIDIA

 Nvidia users should install stable pytorch using this command:

-```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu121```
+```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu124```

 This is the command to install pytorch nightly instead which might have performance improvements:

@@ -230,7 +235,7 @@ To use a textual inversion concepts/embeddings in a text prompt put them in the

 Use ```--preview-method auto``` to enable previews.

-The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth) (for SD1.x and SD2.x) and [taesdxl_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesdxl_decoder.pth) (for SDXL) models and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI to enable high-quality previews.
+The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_decoder.pth, taesdxl_decoder.pth, taesd3_decoder.pth and taef1_decoder.pth](https://github.com/madebyollin/taesd/) and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI and launch it with `--preview-method taesd` to enable high-quality previews.

 ## How to use TLS/SSL?
 Generate a self-signed certificate (not appropriate for shared/production use) and key by running the command: `openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem -sha256 -days 3650 -nodes -subj "/C=XX/ST=StateName/L=CityName/O=CompanyName/OU=CompanySectionName/CN=CommonNameOrHostname"`
@@ -1,7 +1,8 @@
 from aiohttp import web
 from typing import Optional
-from folder_paths import models_dir, user_directory, output_directory
+from folder_paths import models_dir, user_directory, output_directory, folder_names_and_paths
 from api_server.services.file_service import FileService
+import app.logger

 class InternalRoutes:
     '''
@@ -31,6 +32,16 @@ class InternalRoutes:
             except Exception as e:
                 return web.json_response({"error": str(e)}, status=500)

+        @self.routes.get('/logs')
+        async def get_logs(request):
+            return web.json_response(app.logger.get_logs())
+
+        @self.routes.get('/folder_paths')
+        async def get_folder_paths(request):
+            response = {}
+            for key in folder_names_and_paths:
+                response[key] = folder_names_and_paths[key][0]
+            return web.json_response(response)

     def get_app(self):
         if self._app is None:
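A minimal sketch of how the two new internal endpoints above could be exercised from Python. It assumes a ComfyUI server running locally on the default port 8188 and that InternalRoutes is mounted under an /internal prefix; neither assumption is shown in this diff.

```python
# Hypothetical usage sketch for the new /logs and /folder_paths routes added above.
# Assumes a local server on the default port and an "/internal" route prefix.
import requests

BASE = "http://127.0.0.1:8188/internal"

# /logs returns the formatted contents of the in-memory log buffer (see app/logger.py below).
print(requests.get(f"{BASE}/logs").json())

# /folder_paths returns a mapping of folder keys to their registered directories.
for key, paths in requests.get(f"{BASE}/folder_paths").json().items():
    print(key, paths)
```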
@@ -8,7 +8,7 @@ import zipfile
 from dataclasses import dataclass
 from functools import cached_property
 from pathlib import Path
-from typing import TypedDict
+from typing import TypedDict, Optional

 import requests
 from typing_extensions import NotRequired
@@ -132,12 +132,13 @@ class FrontendManager:
         return match_result.group(1), match_result.group(2), match_result.group(3)

     @classmethod
-    def init_frontend_unsafe(cls, version_string: str) -> str:
+    def init_frontend_unsafe(cls, version_string: str, provider: Optional[FrontEndProvider] = None) -> str:
         """
         Initializes the frontend for the specified version.

         Args:
             version_string (str): The version string.
+            provider (FrontEndProvider, optional): The provider to use. Defaults to None.

         Returns:
             str: The path to the initialized frontend.
@@ -150,7 +151,16 @@ class FrontendManager:
             return cls.DEFAULT_FRONTEND_PATH

         repo_owner, repo_name, version = cls.parse_version_string(version_string)
-        provider = FrontEndProvider(repo_owner, repo_name)
+
+        if version.startswith("v"):
+            expected_path = str(Path(cls.CUSTOM_FRONTENDS_ROOT) / f"{repo_owner}_{repo_name}" / version.lstrip("v"))
+            if os.path.exists(expected_path):
+                logging.info(f"Using existing copy of specific frontend version tag: {repo_owner}/{repo_name}@{version}")
+                return expected_path
+
+        logging.info(f"Initializing frontend: {repo_owner}/{repo_name}@{version}, requesting version details from GitHub...")
+
+        provider = provider or FrontEndProvider(repo_owner, repo_name)
         release = provider.get_release(version)

         semantic_version = release["tag_name"].lstrip("v")
@@ -158,15 +168,21 @@ class FrontendManager:
             Path(cls.CUSTOM_FRONTENDS_ROOT) / provider.folder_name / semantic_version
         )
         if not os.path.exists(web_root):
-            os.makedirs(web_root, exist_ok=True)
-            logging.info(
-                "Downloading frontend(%s) version(%s) to (%s)",
-                provider.folder_name,
-                semantic_version,
-                web_root,
-            )
-            logging.debug(release)
-            download_release_asset_zip(release, destination_path=web_root)
+            try:
+                os.makedirs(web_root, exist_ok=True)
+                logging.info(
+                    "Downloading frontend(%s) version(%s) to (%s)",
+                    provider.folder_name,
+                    semantic_version,
+                    web_root,
+                )
+                logging.debug(release)
+                download_release_asset_zip(release, destination_path=web_root)
+            finally:
+                # Clean up the directory if it is empty, i.e. the download failed
+                if not os.listdir(web_root):
+                    os.rmdir(web_root)

         return web_root

     @classmethod
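As a rough illustration of the new fast path in init_frontend_unsafe: for an exact version tag the method now checks for an existing extracted copy before asking GitHub. The concrete version string and the value of CUSTOM_FRONTENDS_ROOT below are assumptions made for this example; only the owner/repo@version format and the owner_repo/version directory layout come from the diff.

```python
# Hypothetical walk-through of the cached-tag shortcut added to init_frontend_unsafe.
# The version string and the CUSTOM_FRONTENDS_ROOT value are assumed here.
from pathlib import Path

CUSTOM_FRONTENDS_ROOT = "web_custom_versions"  # assumed value of cls.CUSTOM_FRONTENDS_ROOT
version_string = "Comfy-Org/ComfyUI_frontend@v1.2.3"  # example "owner/repo@vX.Y.Z" string

owner_repo, version = version_string.split("@")
repo_owner, repo_name = owner_repo.split("/")

# For a "v..." tag the new code checks this directory first and, when it already
# exists, returns it without contacting GitHub or re-downloading the release.
expected_path = str(Path(CUSTOM_FRONTENDS_ROOT) / f"{repo_owner}_{repo_name}" / version.lstrip("v"))
print(expected_path)  # web_custom_versions/Comfy-Org_ComfyUI_frontend/1.2.3
```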
app/logger.py (new file, 31 lines)

@@ -0,0 +1,31 @@
+import logging
+from logging.handlers import MemoryHandler
+from collections import deque
+
+logs = None
+formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+
+
+def get_logs():
+    return "\n".join([formatter.format(x) for x in logs])
+
+
+def setup_logger(log_level: str = 'INFO', capacity: int = 300):
+    global logs
+    if logs:
+        return
+
+    # Setup default global logger
+    logger = logging.getLogger()
+    logger.setLevel(log_level)
+
+    stream_handler = logging.StreamHandler()
+    stream_handler.setFormatter(logging.Formatter("%(message)s"))
+    logger.addHandler(stream_handler)
+
+    # Create a memory handler with a deque as its buffer
+    logs = deque(maxlen=capacity)
+    memory_handler = MemoryHandler(capacity, flushLevel=logging.INFO)
+    memory_handler.buffer = logs
+    memory_handler.setFormatter(formatter)
+    logger.addHandler(memory_handler)
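A short sketch of how the new module behaves, using only the functions it defines; the logger names, messages, and level are made up for the example.

```python
# Illustrative use of the new app/logger.py module; messages and level are examples only.
import logging
from app.logger import setup_logger, get_logs

setup_logger(log_level="DEBUG", capacity=300)  # installs the stream and in-memory handlers once

logging.getLogger("example").info("model loaded")
logging.getLogger("example").warning("low VRAM")

# get_logs() re-formats every record still held in the bounded deque buffer;
# this is the string the new /logs internal route serves to the frontend.
print(get_logs())
```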
@@ -5,17 +5,17 @@ import uuid
 import glob
 import shutil
 from aiohttp import web
+from urllib import parse
 from comfy.cli_args import args
-from folder_paths import user_directory
+import folder_paths
 from .app_settings import AppSettings

 default_user = "default"
-users_file = os.path.join(user_directory, "users.json")


 class UserManager():
     def __init__(self):
-        global user_directory
+        user_directory = folder_paths.get_user_directory()

         self.settings = AppSettings(self)
         if not os.path.exists(user_directory):
@@ -25,14 +25,17 @@ class UserManager():
             print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")

         if args.multi_user:
-            if os.path.isfile(users_file):
-                with open(users_file) as f:
+            if os.path.isfile(self.get_users_file()):
+                with open(self.get_users_file()) as f:
                     self.users = json.load(f)
             else:
                 self.users = {}
         else:
             self.users = {"default": "default"}

+    def get_users_file(self):
+        return os.path.join(folder_paths.get_user_directory(), "users.json")
+
     def get_request_user_id(self, request):
         user = "default"
         if args.multi_user and "comfy-user" in request.headers:
@@ -44,7 +47,7 @@ class UserManager():
         return user

     def get_request_user_filepath(self, request, file, type="userdata", create_dir=True):
-        global user_directory
+        user_directory = folder_paths.get_user_directory()

         if type == "userdata":
             root_dir = user_directory
@@ -59,6 +62,10 @@ class UserManager():
             return None

         if file is not None:
+            # Check if filename is url encoded
+            if "%" in file:
+                file = parse.unquote(file)
+
             # prevent leaving /{type}/{user}
             path = os.path.abspath(os.path.join(user_root, file))
             if os.path.commonpath((user_root, path)) != user_root:
@@ -80,8 +87,7 @@ class UserManager():

         self.users[user_id] = name

-        global users_file
-        with open(users_file, "w") as f:
+        with open(self.get_users_file(), "w") as f:
             json.dump(self.users, f)

         return user_id
@@ -112,25 +118,69 @@ class UserManager():

         @routes.get("/userdata")
         async def listuserdata(request):
+            """
+            List user data files in a specified directory.
+
+            This endpoint allows listing files in a user's data directory, with options for recursion,
+            full file information, and path splitting.
+
+            Query Parameters:
+            - dir (required): The directory to list files from.
+            - recurse (optional): If "true", recursively list files in subdirectories.
+            - full_info (optional): If "true", return detailed file information (path, size, modified time).
+            - split (optional): If "true", split file paths into components (only applies when full_info is false).
+
+            Returns:
+            - 400: If 'dir' parameter is missing.
+            - 403: If the requested path is not allowed.
+            - 404: If the requested directory does not exist.
+            - 200: JSON response with the list of files or file information.
+
+            The response format depends on the query parameters:
+            - Default: List of relative file paths.
+            - full_info=true: List of dictionaries with file details.
+            - split=true (and full_info=false): List of lists, each containing path components.
+            """
             directory = request.rel_url.query.get('dir', '')
             if not directory:
-                return web.Response(status=400)
+                return web.Response(status=400, text="Directory not provided")

             path = self.get_request_user_filepath(request, directory)
             if not path:
-                return web.Response(status=403)
+                return web.Response(status=403, text="Invalid directory")

             if not os.path.exists(path):
-                return web.Response(status=404)
+                return web.Response(status=404, text="Directory not found")

             recurse = request.rel_url.query.get('recurse', '').lower() == "true"
-            results = glob.glob(os.path.join(
-                glob.escape(path), '**/*'), recursive=recurse)
-            results = [os.path.relpath(x, path) for x in results if os.path.isfile(x)]
+            full_info = request.rel_url.query.get('full_info', '').lower() == "true"
+
+            # Use different patterns based on whether we're recursing or not
+            if recurse:
+                pattern = os.path.join(glob.escape(path), '**', '*')
+            else:
+                pattern = os.path.join(glob.escape(path), '*')
+
+            results = glob.glob(pattern, recursive=recurse)
+
+            if full_info:
+                results = [
+                    {
+                        'path': os.path.relpath(x, path).replace(os.sep, '/'),
+                        'size': os.path.getsize(x),
+                        'modified': os.path.getmtime(x)
+                    } for x in results if os.path.isfile(x)
+                ]
+            else:
+                results = [
+                    os.path.relpath(x, path).replace(os.sep, '/')
+                    for x in results
+                    if os.path.isfile(x)
+                ]

             split_path = request.rel_url.query.get('split', '').lower() == "true"
-            if split_path:
-                results = [[x] + x.split(os.sep) for x in results]
+            if split_path and not full_info:
+                results = [[x] + x.split('/') for x in results]

             return web.json_response(results)
@@ -138,14 +188,14 @@ class UserManager():
             file = request.match_info.get(param, None)
             if not file:
                 return web.Response(status=400)

             path = self.get_request_user_filepath(request, file)
             if not path:
                 return web.Response(status=403)

             if check_exists and not os.path.exists(path):
                 return web.Response(status=404)

             return path

         @routes.get("/userdata/{file}")
@@ -153,7 +203,7 @@ class UserManager():
             path = get_user_data_path(request, check_exists=True)
             if not isinstance(path, str):
                 return path

             return web.FileResponse(path)

         @routes.post("/userdata/{file}")
@@ -161,7 +211,7 @@ class UserManager():
             path = get_user_data_path(request)
             if not isinstance(path, str):
                 return path

             overwrite = request.query["overwrite"] != "false"
             if not overwrite and os.path.exists(path):
                 return web.Response(status=409)
@@ -170,7 +220,7 @@ class UserManager():

             with open(path, "wb") as f:
                 f.write(body)

             resp = os.path.relpath(path, self.get_request_user_filepath(request, None))
             return web.json_response(resp)

@@ -181,7 +231,7 @@ class UserManager():
             return path

             os.remove(path)

             return web.Response(status=204)

         @routes.post("/userdata/{file}/move/{dest}")
@@ -189,17 +239,17 @@ class UserManager():
             source = get_user_data_path(request, check_exists=True)
             if not isinstance(source, str):
                 return source

             dest = get_user_data_path(request, check_exists=False, param="dest")
             if not isinstance(source, str):
                 return dest

             overwrite = request.query["overwrite"] != "false"
             if not overwrite and os.path.exists(dest):
                 return web.Response(status=409)

             print(f"moving '{source}' -> '{dest}'")
             shutil.move(source, dest)

             resp = os.path.relpath(dest, self.get_request_user_filepath(request, None))
             return web.json_response(resp)
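For reference, a hedged example of calling the extended listing endpoint with the new full_info flag; the host/port and the "workflows" directory name are assumptions, while the query parameters and response shape come from the docstring in the diff above.

```python
# Example request against the extended /userdata listing endpoint shown above.
# The host/port and the "workflows" subdirectory are assumed for illustration.
import requests

resp = requests.get(
    "http://127.0.0.1:8188/userdata",
    params={"dir": "workflows", "recurse": "true", "full_info": "true"},
)
resp.raise_for_status()

# With full_info=true every entry is a dict carrying 'path', 'size' and 'modified';
# without it the endpoint returns plain relative paths (optionally split into
# components when split=true and full_info is false).
for entry in resp.json():
    print(entry["path"], entry["size"], entry["modified"])
```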
@@ -6,6 +6,7 @@ class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
     def __init__(
         self,
         num_blocks = None,
+        control_latent_channels = None,
         dtype = None,
         device = None,
         operations = None,
@@ -17,10 +18,13 @@ class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
         for _ in range(len(self.joint_blocks)):
             self.controlnet_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype))

+        if control_latent_channels is None:
+            control_latent_channels = self.in_channels
+
         self.pos_embed_input = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(
             None,
             self.patch_size,
-            self.in_channels,
+            control_latent_channels,
             self.hidden_size,
             bias=True,
             strict_img_size=False,
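A small illustrative sketch of why control_latent_channels matters: an SD3 latent has 16 channels, and an inpaint controlnet hint carries one extra mask channel, which is the 17-channel case that load_controlnet_mmdit maps to concat_mask in the controlnet.py hunks further down. The tensor sizes are examples only.

```python
# Illustrative only: how an SD3 inpaint controlnet ends up expecting 17 input channels.
import torch

latent_hint = torch.zeros(1, 16, 128, 128)  # SD3 latent hint: 16 channels, example spatial size
mask = torch.ones(1, 1, 128, 128)           # single-channel inpaint mask

hint = torch.cat([latent_hint, mask], dim=1)  # mirrors what extra_concat produces at runtime
print(hint.shape[1])  # 17 -> load_controlnet_mmdit sets concat_mask = True for this case
```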
@@ -36,7 +36,7 @@ class EnumAction(argparse.Action):

 parser = argparse.ArgumentParser()

-parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
+parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0,::", help="Specify the IP address to listen on (default: 127.0.0.1). You can give a list of ip addresses by separating them with a comma like: 127.2.2.2,127.3.3.3 If --listen is provided without an argument, it defaults to 0.0.0.0,:: (listens on all ipv4 and ipv6)")
 parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
 parser.add_argument("--tls-keyfile", type=str, help="Path to TLS (SSL) key file. Enables TLS, makes app accessible at https://... requires --tls-certfile to function")
 parser.add_argument("--tls-certfile", type=str, help="Path to TLS (SSL) certificate file. Enables TLS, makes app accessible at https://... requires --tls-keyfile to function")
@@ -92,6 +92,8 @@ class LatentPreviewMethod(enum.Enum):

 parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)

+parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
+
 cache_group = parser.add_mutually_exclusive_group()
 cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
 cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
@@ -134,7 +136,7 @@ parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Dis

 parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")

-parser.add_argument("--verbose", action="store_true", help="Enables more debug prints.")
+parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')

 # The default built-in provider hosted under web/
 DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
@@ -169,6 +171,8 @@ parser.add_argument(
     help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.",
 )

+parser.add_argument("--user-directory", type=is_valid_directory, default=None, help="Set the ComfyUI user directory with an absolute path.")
+
 if comfy.options.args_parsing:
     args = parser.parse_args()
 else:
@@ -179,10 +183,3 @@ if args.windows_standalone_build:

 if args.disable_auto_launch:
     args.auto_launch = False
-
-import logging
-logging_level = logging.INFO
-if args.verbose:
-    logging_level = logging.DEBUG
-
-logging.basicConfig(format="%(message)s", level=logging_level)
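A rough sketch of the new comma-separated --listen semantics; the binding loop itself is not part of this diff, so the snippet only shows how such a value could be split before binding.

```python
# Sketch only: the --listen argument may now hold several addresses separated by commas,
# e.g. "127.2.2.2,127.3.3.3"; passing --listen with no value defaults to "0.0.0.0,::".
listen = "0.0.0.0,::"

for address in listen.split(","):
    # The server would create one listening site per address (IPv4 and IPv6 here).
    print(f"would bind on {address.strip()}")
```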
@@ -109,8 +109,7 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
     keys = list(sd.keys())
     for k in keys:
         if k not in u:
-            t = sd.pop(k)
-            del t
+            sd.pop(k)
     return clip

 def load(ckpt_path):
@@ -34,7 +34,7 @@ import comfy.t2i_adapter.adapter
 import comfy.ldm.cascade.controlnet
 import comfy.cldm.mmdit
 import comfy.ldm.hydit.controlnet
-import comfy.ldm.flux.controlnet_xlabs
+import comfy.ldm.flux.controlnet


 def broadcast_image_to(tensor, target_batch_size, batched_number):
@@ -60,7 +60,7 @@ class StrengthType(Enum):
     LINEAR_UP = 2

 class ControlBase:
-    def __init__(self, device=None):
+    def __init__(self):
         self.cond_hint_original = None
         self.cond_hint = None
         self.strength = 1.0
@@ -72,20 +72,24 @@ class ControlBase:
         self.compression_ratio = 8
         self.upscale_algorithm = 'nearest-exact'
         self.extra_args = {}

-        if device is None:
-            device = comfy.model_management.get_torch_device()
-        self.device = device
         self.previous_controlnet = None
         self.extra_conds = []
         self.strength_type = StrengthType.CONSTANT
+        self.concat_mask = False
+        self.extra_concat_orig = []
+        self.extra_concat = None

-    def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None):
+    def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
         self.cond_hint_original = cond_hint
         self.strength = strength
         self.timestep_percent_range = timestep_percent_range
         if self.latent_format is not None:
+            if vae is None:
+                logging.warning("WARNING: no VAE provided to the controlnet apply node when this controlnet requires one.")
             self.vae = vae
+        self.extra_concat_orig = extra_concat.copy()
+        if self.concat_mask and len(self.extra_concat_orig) == 0:
+            self.extra_concat_orig.append(torch.tensor([[[[1.0]]]]))
         return self

     def pre_run(self, model, percent_to_timestep_function):
@@ -100,9 +104,9 @@ class ControlBase:
     def cleanup(self):
         if self.previous_controlnet is not None:
             self.previous_controlnet.cleanup()
-        if self.cond_hint is not None:
-            del self.cond_hint
-            self.cond_hint = None
+
+        self.cond_hint = None
+        self.extra_concat = None
         self.timestep_range = None

     def get_models(self):
@@ -123,6 +127,8 @@ class ControlBase:
         c.vae = self.vae
         c.extra_conds = self.extra_conds.copy()
         c.strength_type = self.strength_type
+        c.concat_mask = self.concat_mask
+        c.extra_concat_orig = self.extra_concat_orig.copy()

     def inference_memory_requirements(self, dtype):
         if self.previous_controlnet is not None:
@@ -148,7 +154,7 @@ class ControlBase:
             elif self.strength_type == StrengthType.LINEAR_UP:
                 x *= (self.strength ** float(len(control_output) - i))

-            if x.dtype != output_dtype:
+            if output_dtype is not None and x.dtype != output_dtype:
                 x = x.to(output_dtype)

             out[key].append(x)
@@ -175,8 +181,8 @@ class ControlBase:


 class ControlNet(ControlBase):
-    def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT):
-        super().__init__(device)
+    def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False):
+        super().__init__()
         self.control_model = control_model
         self.load_device = load_device
         if control_model is not None:
@@ -189,6 +195,7 @@ class ControlNet(ControlBase):
         self.latent_format = latent_format
         self.extra_conds += extra_conds
         self.strength_type = strength_type
+        self.concat_mask = concat_mask

     def get_control(self, x_noisy, t, cond, batched_number):
         control_prev = None
@@ -206,7 +213,6 @@ class ControlNet(ControlBase):
         if self.manual_cast_dtype is not None:
             dtype = self.manual_cast_dtype

-        output_dtype = x_noisy.dtype
         if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
             if self.cond_hint is not None:
                 del self.cond_hint
@@ -214,6 +220,9 @@ class ControlNet(ControlBase):
             compression_ratio = self.compression_ratio
             if self.vae is not None:
                 compression_ratio *= self.vae.downscale_ratio
+            else:
+                if self.latent_format is not None:
+                    raise ValueError("This Controlnet needs a VAE but none was provided, please use a ControlNetApply node with a VAE input and connect it.")
             self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
             if self.vae is not None:
                 loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
@@ -221,7 +230,15 @@ class ControlNet(ControlBase):
                 comfy.model_management.load_models_gpu(loaded_models)
             if self.latent_format is not None:
                 self.cond_hint = self.latent_format.process_in(self.cond_hint)
-            self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype)
+            if len(self.extra_concat_orig) > 0:
+                to_concat = []
+                for c in self.extra_concat_orig:
+                    c = c.to(self.cond_hint.device)
+                    c = comfy.utils.common_upscale(c, self.cond_hint.shape[3], self.cond_hint.shape[2], self.upscale_algorithm, "center")
+                    to_concat.append(comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[0]))
+                self.cond_hint = torch.cat([self.cond_hint] + to_concat, dim=1)
+
+            self.cond_hint = self.cond_hint.to(device=x_noisy.device, dtype=dtype)
         if x_noisy.shape[0] != self.cond_hint.shape[0]:
             self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)

@@ -236,7 +253,7 @@ class ControlNet(ControlBase):
         x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)

         control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra)
-        return self.control_merge(control, control_prev, output_dtype)
+        return self.control_merge(control, control_prev, output_dtype=None)

     def copy(self):
         c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
@@ -320,8 +337,8 @@ class ControlLoraOps:


 class ControlLora(ControlNet):
-    def __init__(self, control_weights, global_average_pooling=False, device=None):
-        ControlBase.__init__(self, device)
+    def __init__(self, control_weights, global_average_pooling=False, model_options={}): #TODO? model_options
+        ControlBase.__init__(self)
         self.control_weights = control_weights
         self.global_average_pooling = global_average_pooling
         self.extra_conds += ["y"]
@@ -377,21 +394,28 @@ class ControlLora(ControlNet):
     def inference_memory_requirements(self, dtype):
         return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)

-def controlnet_config(sd):
+def controlnet_config(sd, model_options={}):
     model_config = comfy.model_detection.model_config_from_unet(sd, "", True)

-    supported_inference_dtypes = model_config.supported_inference_dtypes
+    unet_dtype = model_options.get("dtype", None)
+    if unet_dtype is None:
+        weight_dtype = comfy.utils.weight_dtype(sd)
+
+        supported_inference_dtypes = list(model_config.supported_inference_dtypes)
+        if weight_dtype is not None:
+            supported_inference_dtypes.append(weight_dtype)
+
+        unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)

-    controlnet_config = model_config.unet_config
-    unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
     load_device = comfy.model_management.get_torch_device()
     manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
-    if manual_cast_dtype is not None:
-        operations = comfy.ops.manual_cast
-    else:
-        operations = comfy.ops.disable_weight_init

-    return model_config, operations, load_device, unet_dtype, manual_cast_dtype
+    operations = model_options.get("custom_operations", None)
+    if operations is None:
+        operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
+
+    offload_device = comfy.model_management.unet_offload_device()
+    return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device

 def controlnet_load_state_dict(control_model, sd):
     missing, unexpected = control_model.load_state_dict(sd, strict=False)
@@ -403,26 +427,31 @@ def controlnet_load_state_dict(control_model, sd):
     logging.debug("unexpected controlnet keys: {}".format(unexpected))
     return control_model

-def load_controlnet_mmdit(sd):
+def load_controlnet_mmdit(sd, model_options={}):
     new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
-    model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(new_sd)
+    model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options)
     num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
     for k in sd:
         new_sd[k] = sd[k]

-    control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config)
+    concat_mask = False
+    control_latent_channels = new_sd.get("pos_embed_input.proj.weight").shape[1]
+    if control_latent_channels == 17: #inpaint controlnet
+        concat_mask = True
+
+    control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, control_latent_channels=control_latent_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
     control_model = controlnet_load_state_dict(control_model, new_sd)

     latent_format = comfy.latent_formats.SD3()
     latent_format.shift_factor = 0 #SD3 controlnet weirdness
-    control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
+    control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
     return control


-def load_controlnet_hunyuandit(controlnet_data):
-    model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(controlnet_data)
+def load_controlnet_hunyuandit(controlnet_data, model_options={}):
+    model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data, model_options=model_options)

-    control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=load_device, dtype=unet_dtype)
+    control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=offload_device, dtype=unet_dtype)
     control_model = controlnet_load_state_dict(control_model, controlnet_data)

     latent_format = comfy.latent_formats.SDXL()
@@ -430,22 +459,49 @@ def load_controlnet_hunyuandit(controlnet_data):
     control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT)
     return control

-def load_controlnet_flux_xlabs(sd):
-    model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(sd)
-    control_model = comfy.ldm.flux.controlnet_xlabs.ControlNetFlux(operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config)
+def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False, model_options={}):
+    model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
+    control_model = comfy.ldm.flux.controlnet.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
     control_model = controlnet_load_state_dict(control_model, sd)
     extra_conds = ['y', 'guidance']
|
extra_conds = ['y', 'guidance']
|
||||||
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||||
return control
|
return control
|
||||||
|
|
||||||
|
def load_controlnet_flux_instantx(sd, model_options={}):
|
||||||
|
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
|
||||||
|
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options)
|
||||||
|
for k in sd:
|
||||||
|
new_sd[k] = sd[k]
|
||||||
|
|
||||||
def load_controlnet(ckpt_path, model=None):
|
num_union_modes = 0
|
||||||
controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
|
union_cnet = "controlnet_mode_embedder.weight"
|
||||||
|
if union_cnet in new_sd:
|
||||||
|
num_union_modes = new_sd[union_cnet].shape[0]
|
||||||
|
|
||||||
|
control_latent_channels = new_sd.get("pos_embed_input.weight").shape[1] // 4
|
||||||
|
concat_mask = False
|
||||||
|
if control_latent_channels == 17:
|
||||||
|
concat_mask = True
|
||||||
|
|
||||||
|
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(latent_input=True, num_union_modes=num_union_modes, control_latent_channels=control_latent_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||||
|
control_model = controlnet_load_state_dict(control_model, new_sd)
|
||||||
|
|
||||||
|
latent_format = comfy.latent_formats.Flux()
|
||||||
|
extra_conds = ['y', 'guidance']
|
||||||
|
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||||
|
return control
|
||||||
|
|
||||||
|
def convert_mistoline(sd):
|
||||||
|
return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
|
||||||
|
|
||||||
|
|
||||||
|
def load_controlnet_state_dict(state_dict, model=None, model_options={}):
|
||||||
|
controlnet_data = state_dict
|
||||||
if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT
|
if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT
|
||||||
return load_controlnet_hunyuandit(controlnet_data)
|
return load_controlnet_hunyuandit(controlnet_data, model_options=model_options)
|
||||||
|
|
||||||
if "lora_controlnet" in controlnet_data:
|
if "lora_controlnet" in controlnet_data:
|
||||||
return ControlLora(controlnet_data)
|
return ControlLora(controlnet_data, model_options=model_options)
|
||||||
|
|
||||||
controlnet_config = None
|
controlnet_config = None
|
||||||
supported_inference_dtypes = None
|
supported_inference_dtypes = None
|
||||||
@@ -500,11 +556,15 @@ def load_controlnet(ckpt_path, model=None):
|
|||||||
if len(leftover_keys) > 0:
|
if len(leftover_keys) > 0:
|
||||||
logging.warning("leftover keys: {}".format(leftover_keys))
|
logging.warning("leftover keys: {}".format(leftover_keys))
|
||||||
controlnet_data = new_sd
|
controlnet_data = new_sd
|
||||||
elif "controlnet_blocks.0.weight" in controlnet_data: #SD3 diffusers format
|
elif "controlnet_blocks.0.weight" in controlnet_data:
|
||||||
if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
|
if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
|
||||||
return load_controlnet_flux_xlabs(controlnet_data)
|
return load_controlnet_flux_xlabs_mistoline(controlnet_data, model_options=model_options)
|
||||||
else:
|
elif "pos_embed_input.proj.weight" in controlnet_data:
|
||||||
return load_controlnet_mmdit(controlnet_data)
|
return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet
|
||||||
|
elif "controlnet_x_embedder.weight" in controlnet_data:
|
||||||
|
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
|
||||||
|
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
|
||||||
|
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)
|
||||||
|
|
||||||
pth_key = 'control_model.zero_convs.0.0.weight'
|
pth_key = 'control_model.zero_convs.0.0.weight'
|
||||||
pth = False
|
pth = False
|
||||||
@@ -516,26 +576,38 @@ def load_controlnet(ckpt_path, model=None):
|
|||||||
elif key in controlnet_data:
|
elif key in controlnet_data:
|
||||||
prefix = ""
|
prefix = ""
|
||||||
else:
|
else:
|
||||||
net = load_t2i_adapter(controlnet_data)
|
net = load_t2i_adapter(controlnet_data, model_options=model_options)
|
||||||
if net is None:
|
if net is None:
|
||||||
logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
|
logging.error("error could not detect control model type.")
|
||||||
return net
|
return net
|
||||||
|
|
||||||
if controlnet_config is None:
|
if controlnet_config is None:
|
||||||
model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True)
|
model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True)
|
||||||
supported_inference_dtypes = model_config.supported_inference_dtypes
|
supported_inference_dtypes = list(model_config.supported_inference_dtypes)
|
||||||
controlnet_config = model_config.unet_config
|
controlnet_config = model_config.unet_config
|
||||||
|
|
||||||
|
unet_dtype = model_options.get("dtype", None)
|
||||||
|
if unet_dtype is None:
|
||||||
|
weight_dtype = comfy.utils.weight_dtype(controlnet_data)
|
||||||
|
|
||||||
|
if supported_inference_dtypes is None:
|
||||||
|
supported_inference_dtypes = [comfy.model_management.unet_dtype()]
|
||||||
|
|
||||||
|
if weight_dtype is not None:
|
||||||
|
supported_inference_dtypes.append(weight_dtype)
|
||||||
|
|
||||||
|
unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
|
||||||
|
|
||||||
load_device = comfy.model_management.get_torch_device()
|
load_device = comfy.model_management.get_torch_device()
|
||||||
if supported_inference_dtypes is None:
|
|
||||||
unet_dtype = comfy.model_management.unet_dtype()
|
|
||||||
else:
|
|
||||||
unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
|
|
||||||
|
|
||||||
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
||||||
if manual_cast_dtype is not None:
|
operations = model_options.get("custom_operations", None)
|
||||||
controlnet_config["operations"] = comfy.ops.manual_cast
|
if operations is None:
|
||||||
|
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype)
|
||||||
|
|
||||||
|
controlnet_config["operations"] = operations
|
||||||
controlnet_config["dtype"] = unet_dtype
|
controlnet_config["dtype"] = unet_dtype
|
||||||
|
controlnet_config["device"] = comfy.model_management.unet_offload_device()
|
||||||
controlnet_config.pop("out_channels")
|
controlnet_config.pop("out_channels")
|
||||||
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
|
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
|
||||||
control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
|
control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
|
||||||
@@ -569,22 +641,32 @@ def load_controlnet(ckpt_path, model=None):
|
|||||||
if len(unexpected) > 0:
|
if len(unexpected) > 0:
|
||||||
logging.debug("unexpected controlnet keys: {}".format(unexpected))
|
logging.debug("unexpected controlnet keys: {}".format(unexpected))
|
||||||
|
|
||||||
global_average_pooling = False
|
global_average_pooling = model_options.get("global_average_pooling", False)
|
||||||
filename = os.path.splitext(ckpt_path)[0]
|
|
||||||
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
|
|
||||||
global_average_pooling = True
|
|
||||||
|
|
||||||
control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
||||||
return control
|
return control
|
||||||
|
|
||||||
|
def load_controlnet(ckpt_path, model=None, model_options={}):
|
||||||
|
if "global_average_pooling" not in model_options:
|
||||||
|
filename = os.path.splitext(ckpt_path)[0]
|
||||||
|
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
|
||||||
|
model_options["global_average_pooling"] = True
|
||||||
|
|
||||||
|
cnet = load_controlnet_state_dict(comfy.utils.load_torch_file(ckpt_path, safe_load=True), model=model, model_options=model_options)
|
||||||
|
if cnet is None:
|
||||||
|
logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
|
||||||
|
return cnet
|
||||||
|
|
||||||
class T2IAdapter(ControlBase):
|
class T2IAdapter(ControlBase):
|
||||||
def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None):
|
def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None):
|
||||||
super().__init__(device)
|
super().__init__()
|
||||||
self.t2i_model = t2i_model
|
self.t2i_model = t2i_model
|
||||||
self.channels_in = channels_in
|
self.channels_in = channels_in
|
||||||
self.control_input = None
|
self.control_input = None
|
||||||
self.compression_ratio = compression_ratio
|
self.compression_ratio = compression_ratio
|
||||||
self.upscale_algorithm = upscale_algorithm
|
self.upscale_algorithm = upscale_algorithm
|
||||||
|
if device is None:
|
||||||
|
device = comfy.model_management.get_torch_device()
|
||||||
|
self.device = device
|
||||||
|
|
||||||
def scale_image_to(self, width, height):
|
def scale_image_to(self, width, height):
|
||||||
unshuffle_amount = self.t2i_model.unshuffle_amount
|
unshuffle_amount = self.t2i_model.unshuffle_amount
|
||||||
@@ -632,7 +714,7 @@ class T2IAdapter(ControlBase):
|
|||||||
self.copy_to(c)
|
self.copy_to(c)
|
||||||
return c
|
return c
|
||||||
|
|
||||||
def load_t2i_adapter(t2i_data):
|
def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
|
||||||
compression_ratio = 8
|
compression_ratio = 8
|
||||||
upscale_algorithm = 'nearest-exact'
|
upscale_algorithm = 'nearest-exact'
|
||||||
|
|
||||||
|
|||||||
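The hunks above thread a model_options dict through every controlnet loader. A minimal usage sketch (illustrative only, not part of the diff; the checkpoint path is a placeholder, and the keys shown are the ones the new code reads):

import torch
import comfy.controlnet

# "dtype" overrides the automatic dtype detection added above; "custom_operations"
# and "global_average_pooling" are the other keys consumed by these loaders.
cnet = comfy.controlnet.load_controlnet(
    "controlnets/my_controlnet.safetensors",   # placeholder path
    model_options={"dtype": torch.float16},
)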
@@ -16,7 +16,7 @@ class NoiseScheduleVP:
            continuous_beta_0=0.1,
            continuous_beta_1=20.,
        ):
-        """Create a wrapper class for the forward SDE (VP type).
+        r"""Create a wrapper class for the forward SDE (VP type).

        ***
        Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t.
@@ -1,7 +1,17 @@
 import torch

+def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None):
+    mantissa_scaled = torch.where(
+        normal_mask,
+        (abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0) * (2**MANTISSA_BITS),
+        (abs_x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS)))
+    )
+
+    mantissa_scaled += torch.rand(mantissa_scaled.size(), dtype=mantissa_scaled.dtype, layout=mantissa_scaled.layout, device=mantissa_scaled.device, generator=generator)
+    return mantissa_scaled.floor() / (2**MANTISSA_BITS)
+
 #Not 100% sure about this
-def manual_stochastic_round_to_float8(x, dtype):
+def manual_stochastic_round_to_float8(x, dtype, generator=None):
     if dtype == torch.float8_e4m3fn:
         EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 4, 3, 7
     elif dtype == torch.float8_e5m2:

@@ -9,44 +19,35 @@ def manual_stochastic_round_to_float8(x, dtype):
     else:
         raise ValueError("Unsupported dtype")

+    x = x.half()
     sign = torch.sign(x)
     abs_x = x.abs()
+    sign = torch.where(abs_x == 0, 0, sign)

     # Combine exponent calculation and clamping
     exponent = torch.clamp(
-        torch.floor(torch.log2(abs_x)).to(torch.int32) + EXPONENT_BIAS,
+        torch.floor(torch.log2(abs_x)) + EXPONENT_BIAS,
         0, 2**EXPONENT_BITS - 1
     )

     # Combine mantissa calculation and rounding
-    # min_normal = 2.0 ** (-EXPONENT_BIAS + 1)
-    # zero_mask = (abs_x == 0)
-    # subnormal_mask = (exponent == 0) & (abs_x != 0)
     normal_mask = ~(exponent == 0)

-    mantissa_scaled = torch.where(
-        normal_mask,
-        (abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0) * (2**MANTISSA_BITS),
-        (abs_x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS)))
-    )
-    mantissa_floor = mantissa_scaled.floor()
-    mantissa = torch.where(
-        torch.rand_like(mantissa_scaled) < (mantissa_scaled - mantissa_floor),
-        (mantissa_floor + 1) / (2**MANTISSA_BITS),
-        mantissa_floor / (2**MANTISSA_BITS)
-    )
-    result = torch.where(
-        normal_mask,
-        sign * (2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + mantissa),
-        sign * (2.0 ** (-EXPONENT_BIAS + 1)) * mantissa
-    )
-
-    result = torch.where(abs_x == 0, 0, result)
-    return result.to(dtype=dtype)
+    abs_x[:] = calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=generator)
+
+    sign *= torch.where(
+        normal_mask,
+        (2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + abs_x),
+        (2.0 ** (-EXPONENT_BIAS + 1)) * abs_x
+    )
+
+    inf = torch.finfo(dtype)
+    torch.clamp(sign, min=inf.min, max=inf.max, out=sign)
+    return sign



-def stochastic_rounding(value, dtype):
+def stochastic_rounding(value, dtype, seed=0):
     if dtype == torch.float32:
         return value.to(dtype=torch.float32)
     if dtype == torch.float16:

@@ -54,6 +55,13 @@ def stochastic_rounding(value, dtype):
     if dtype == torch.bfloat16:
         return value.to(dtype=torch.bfloat16)
     if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
-        return manual_stochastic_round_to_float8(value, dtype)
+        generator = torch.Generator(device=value.device)
+        generator.manual_seed(seed)
+        output = torch.empty_like(value, dtype=dtype)
+        num_slices = max(1, (value.numel() / (4096 * 4096)))
+        slice_size = max(1, round(value.shape[0] / num_slices))
+        for i in range(0, value.shape[0], slice_size):
+            output[i:i+slice_size].copy_(manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator))
+        return output

     return value.to(dtype=dtype)
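The rounding above is stochastic: uniform noise is added to the scaled mantissa before flooring, so each value rounds up with probability equal to its fractional remainder and the rounding error cancels in expectation. A small self-contained sketch of that idea on plain Python floats (independent of the diff):

import math, random

def stochastic_round(value, step):
    # Round to a multiple of `step`; round up with probability equal to the
    # fractional remainder, so the expected value of the result equals `value`.
    q = value / step
    lo = math.floor(q)
    return step * (lo + (1 if random.random() < q - lo else 0))

samples = [stochastic_round(0.3, 0.25) for _ in range(100000)]
print(sum(samples) / len(samples))  # close to 0.3, even though 0.3 is not a multiple of 0.25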
@@ -44,6 +44,17 @@ def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
     return append_zero(sigmas)


+def get_sigmas_laplace(n, sigma_min, sigma_max, mu=0., beta=0.5, device='cpu'):
+    """Constructs the noise schedule proposed by Tiankai et al. (2024). """
+    epsilon = 1e-5 # avoid log(0)
+    x = torch.linspace(0, 1, n, device=device)
+    clamp = lambda x: torch.clamp(x, min=sigma_min, max=sigma_max)
+    lmb = mu - beta * torch.sign(0.5-x) * torch.log(1 - 2 * torch.abs(0.5-x) + epsilon)
+    sigmas = clamp(torch.exp(lmb))
+    return sigmas
+
+
 def to_d(x, sigma, denoised):
     """Converts a denoiser output to a Karras ODE derivative."""
     return (x - denoised) / utils.append_dims(sigma, x.ndim)

@@ -153,6 +164,8 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None,

 @torch.no_grad()
 def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
+    if isinstance(model.inner_model.inner_model.model_sampling, comfy.model_sampling.CONST):
+        return sample_euler_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
     """Ancestral sampling with Euler method steps."""
     extra_args = {} if extra_args is None else extra_args
     noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler

@@ -170,6 +183,29 @@ def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
         x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
     return x

+@torch.no_grad()
+def sample_euler_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1., noise_sampler=None):
+    """Ancestral sampling with Euler method steps."""
+    extra_args = {} if extra_args is None else extra_args
+    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
+    s_in = x.new_ones([x.shape[0]])
+    for i in trange(len(sigmas) - 1, disable=disable):
+        denoised = model(x, sigmas[i] * s_in, **extra_args)
+        # sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
+        downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta
+        sigma_down = sigmas[i+1] * downstep_ratio
+        alpha_ip1 = 1 - sigmas[i+1]
+        alpha_down = 1 - sigma_down
+        renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5
+        if callback is not None:
+            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+
+        # Euler method
+        sigma_down_i_ratio = sigma_down / sigmas[i]
+        x = sigma_down_i_ratio * x + (1 - sigma_down_i_ratio) * denoised
+        if sigmas[i + 1] > 0 and eta > 0:
+            x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
+    return x
+
 @torch.no_grad()
 def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):

@@ -1069,7 +1105,6 @@ def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disabl
         d = to_d(x, sigma_hat, temp[0])
         if callback is not None:
             callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
-        dt = sigmas[i + 1] - sigma_hat
         # Euler method
         x = denoised + d * sigmas[i + 1]
     return x

@@ -1096,8 +1131,81 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No
             callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
         d = to_d(x, sigmas[i], temp[0])
         # Euler method
-        dt = sigma_down - sigmas[i]
         x = denoised + d * sigma_down
         if sigmas[i + 1] > 0:
             x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
     return x
+@torch.no_grad()
+def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
+    """Ancestral sampling with DPM-Solver++(2S) second-order steps."""
+    extra_args = {} if extra_args is None else extra_args
+    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
+
+    temp = [0]
+    def post_cfg_function(args):
+        temp[0] = args["uncond_denoised"]
+        return args["denoised"]
+
+    model_options = extra_args.get("model_options", {}).copy()
+    extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
+
+    s_in = x.new_ones([x.shape[0]])
+    sigma_fn = lambda t: t.neg().exp()
+    t_fn = lambda sigma: sigma.log().neg()
+
+    for i in trange(len(sigmas) - 1, disable=disable):
+        denoised = model(x, sigmas[i] * s_in, **extra_args)
+        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
+        if callback is not None:
+            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+        if sigma_down == 0:
+            # Euler method
+            d = to_d(x, sigmas[i], temp[0])
+            x = denoised + d * sigma_down
+        else:
+            # DPM-Solver++(2S)
+            t, t_next = t_fn(sigmas[i]), t_fn(sigma_down)
+            # r = torch.sinh(1 + (2 - eta) * (t_next - t) / (t - t_fn(sigma_up))) works only on non-cfgpp, weird
+            r = 1 / 2
+            h = t_next - t
+            s = t + r * h
+            x_2 = (sigma_fn(s) / sigma_fn(t)) * (x + (denoised - temp[0])) - (-h * r).expm1() * denoised
+            denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
+            x = (sigma_fn(t_next) / sigma_fn(t)) * (x + (denoised - temp[0])) - (-h).expm1() * denoised_2
+        # Noise addition
+        if sigmas[i + 1] > 0:
+            x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
+    return x
+
+@torch.no_grad()
+def sample_dpmpp_2m_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
+    """DPM-Solver++(2M)."""
+    extra_args = {} if extra_args is None else extra_args
+    s_in = x.new_ones([x.shape[0]])
+    t_fn = lambda sigma: sigma.log().neg()
+
+    old_uncond_denoised = None
+    uncond_denoised = None
+    def post_cfg_function(args):
+        nonlocal uncond_denoised
+        uncond_denoised = args["uncond_denoised"]
+        return args["denoised"]
+
+    model_options = extra_args.get("model_options", {}).copy()
+    extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
+
+    for i in trange(len(sigmas) - 1, disable=disable):
+        denoised = model(x, sigmas[i] * s_in, **extra_args)
+        if callback is not None:
+            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+        t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
+        h = t_next - t
+        if old_uncond_denoised is None or sigmas[i + 1] == 0:
+            denoised_mix = -torch.exp(-h) * uncond_denoised
+        else:
+            h_last = t - t_fn(sigmas[i - 1])
+            r = h_last / h
+            denoised_mix = -torch.exp(-h) * uncond_denoised - torch.expm1(-h) * (1 / (2 * r)) * (denoised - old_uncond_denoised)
+        x = denoised + denoised_mix + torch.exp(-h) * x
+        old_uncond_denoised = uncond_denoised
+    return x
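The first hunk adds a Laplace noise schedule. A quick usage sketch (illustrative; it assumes these hunks belong to comfy/k_diffusion/sampling.py and the sigma_min/sigma_max values are placeholders):

import torch
from comfy.k_diffusion.sampling import get_sigmas_laplace

# 20-step schedule; sigmas are clamped to [sigma_min, sigma_max] as in the new function.
sigmas = get_sigmas_laplace(20, sigma_min=0.0292, sigma_max=14.61, mu=0.0, beta=0.5)
print(sigmas.shape, float(sigmas.max()), float(sigmas.min()))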
@@ -4,6 +4,7 @@ class LatentFormat:
     scale_factor = 1.0
     latent_channels = 4
     latent_rgb_factors = None
+    latent_rgb_factors_bias = None
     taesd_decoder_name = None

     def process_in(self, latent):

@@ -30,11 +31,13 @@ class SDXL(LatentFormat):
     def __init__(self):
         self.latent_rgb_factors = [
             #   R        G        B
-            [ 0.3920, 0.4054, 0.4549],
-            [-0.2634, -0.0196, 0.0653],
-            [ 0.0568, 0.1687, -0.0755],
-            [-0.3112, -0.2359, -0.2076]
+            [ 0.3651, 0.4232, 0.4341],
+            [-0.2533, -0.0042, 0.1068],
+            [ 0.1076, 0.1111, -0.0362],
+            [-0.3165, -0.2492, -0.2188]
         ]
+        self.latent_rgb_factors_bias = [ 0.1084, -0.0175, -0.0011]
+
         self.taesd_decoder_name = "taesdxl_decoder"

 class SDXL_Playground_2_5(LatentFormat):

@@ -112,23 +115,24 @@ class SD3(LatentFormat):
         self.scale_factor = 1.5305
         self.shift_factor = 0.0609
         self.latent_rgb_factors = [
-            [-0.0645, 0.0177, 0.1052],
-            [ 0.0028, 0.0312, 0.0650],
-            [ 0.1848, 0.0762, 0.0360],
-            [ 0.0944, 0.0360, 0.0889],
-            [ 0.0897, 0.0506, -0.0364],
-            [-0.0020, 0.1203, 0.0284],
-            [ 0.0855, 0.0118, 0.0283],
-            [-0.0539, 0.0658, 0.1047],
-            [-0.0057, 0.0116, 0.0700],
-            [-0.0412, 0.0281, -0.0039],
-            [ 0.1106, 0.1171, 0.1220],
-            [-0.0248, 0.0682, -0.0481],
-            [ 0.0815, 0.0846, 0.1207],
-            [-0.0120, -0.0055, -0.0867],
-            [-0.0749, -0.0634, -0.0456],
-            [-0.1418, -0.1457, -0.1259]
+            [-0.0922, -0.0175, 0.0749],
+            [ 0.0311, 0.0633, 0.0954],
+            [ 0.1994, 0.0927, 0.0458],
+            [ 0.0856, 0.0339, 0.0902],
+            [ 0.0587, 0.0272, -0.0496],
+            [-0.0006, 0.1104, 0.0309],
+            [ 0.0978, 0.0306, 0.0427],
+            [-0.0042, 0.1038, 0.1358],
+            [-0.0194, 0.0020, 0.0669],
+            [-0.0488, 0.0130, -0.0268],
+            [ 0.0922, 0.0988, 0.0951],
+            [-0.0278, 0.0524, -0.0542],
+            [ 0.0332, 0.0456, 0.0895],
+            [-0.0069, -0.0030, -0.0810],
+            [-0.0596, -0.0465, -0.0293],
+            [-0.1448, -0.1463, -0.1189]
         ]
+        self.latent_rgb_factors_bias = [0.2394, 0.2135, 0.1925]
         self.taesd_decoder_name = "taesd3_decoder"

     def process_in(self, latent):

@@ -146,23 +150,24 @@ class Flux(SD3):
         self.scale_factor = 0.3611
         self.shift_factor = 0.1159
         self.latent_rgb_factors =[
-            [-0.0404, 0.0159, 0.0609],
-            [ 0.0043, 0.0298, 0.0850],
-            [ 0.0328, -0.0749, -0.0503],
-            [-0.0245, 0.0085, 0.0549],
-            [ 0.0966, 0.0894, 0.0530],
-            [ 0.0035, 0.0399, 0.0123],
-            [ 0.0583, 0.1184, 0.1262],
-            [-0.0191, -0.0206, -0.0306],
-            [-0.0324, 0.0055, 0.1001],
-            [ 0.0955, 0.0659, -0.0545],
-            [-0.0504, 0.0231, -0.0013],
-            [ 0.0500, -0.0008, -0.0088],
-            [ 0.0982, 0.0941, 0.0976],
-            [-0.1233, -0.0280, -0.0897],
-            [-0.0005, -0.0530, -0.0020],
-            [-0.1273, -0.0932, -0.0680]
+            [-0.0346, 0.0244, 0.0681],
+            [ 0.0034, 0.0210, 0.0687],
+            [ 0.0275, -0.0668, -0.0433],
+            [-0.0174, 0.0160, 0.0617],
+            [ 0.0859, 0.0721, 0.0329],
+            [ 0.0004, 0.0383, 0.0115],
+            [ 0.0405, 0.0861, 0.0915],
+            [-0.0236, -0.0185, -0.0259],
+            [-0.0245, 0.0250, 0.1180],
+            [ 0.1008, 0.0755, -0.0421],
+            [-0.0515, 0.0201, 0.0011],
+            [ 0.0428, -0.0012, -0.0036],
+            [ 0.0817, 0.0765, 0.0749],
+            [-0.1264, -0.0522, -0.1103],
+            [-0.0280, -0.0881, -0.0499],
+            [-0.1262, -0.0982, -0.0778]
         ]
+        self.latent_rgb_factors_bias = [-0.0329, -0.0718, -0.0851]
         self.taesd_decoder_name = "taef1_decoder"

     def process_in(self, latent):

@@ -170,3 +175,30 @@ class Flux(SD3):

     def process_out(self, latent):
         return (latent / self.scale_factor) + self.shift_factor
+
+
+class Mochi(LatentFormat):
+    latent_channels = 12
+
+    def __init__(self):
+        self.scale_factor = 1.0
+        self.latents_mean = torch.tensor([-0.06730895953510081, -0.038011381506090416, -0.07477820912866141,
+                                          -0.05565264470995561, 0.012767231469026969, -0.04703542746246419,
+                                          0.043896967884726704, -0.09346305707025976, -0.09918314763016893,
+                                          -0.008729793427399178, -0.011931556316503654, -0.0321993391887285]).view(1, self.latent_channels, 1, 1, 1)
+        self.latents_std = torch.tensor([0.9263795028493863, 0.9248894543193766, 0.9393059390890617,
+                                         0.959253732819592, 0.8244560132752793, 0.917259975397747,
+                                         0.9294154431013696, 1.3720942357788521, 0.881393668867029,
+                                         0.9168315692124348, 0.9185249279345552, 0.9274757570805041]).view(1, self.latent_channels, 1, 1, 1)
+
+        self.latent_rgb_factors = None #TODO
+        self.taesd_decoder_name = None #TODO
+
+    def process_in(self, latent):
+        latents_mean = self.latents_mean.to(latent.device, latent.dtype)
+        latents_std = self.latents_std.to(latent.device, latent.dtype)
+        return (latent - latents_mean) * self.scale_factor / latents_std
+
+    def process_out(self, latent):
+        latents_mean = self.latents_mean.to(latent.device, latent.dtype)
+        latents_std = self.latents_std.to(latent.device, latent.dtype)
+        return latent * latents_std / self.scale_factor + latents_mean
|
|||||||
import torch
|
import torch
|
||||||
|
import comfy.ops
|
||||||
|
|
||||||
def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
|
def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
|
||||||
if padding_mode == "circular" and torch.jit.is_tracing() or torch.jit.is_scripting():
|
if padding_mode == "circular" and torch.jit.is_tracing() or torch.jit.is_scripting():
|
||||||
@@ -6,3 +7,21 @@ def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
|
|||||||
pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0]
|
pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0]
|
||||||
pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1]
|
pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1]
|
||||||
return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode)
|
return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode)
|
||||||
|
|
||||||
|
try:
|
||||||
|
rms_norm_torch = torch.nn.functional.rms_norm
|
||||||
|
except:
|
||||||
|
rms_norm_torch = None
|
||||||
|
|
||||||
|
def rms_norm(x, weight=None, eps=1e-6):
|
||||||
|
if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
|
||||||
|
if weight is None:
|
||||||
|
return rms_norm_torch(x, (x.shape[-1],), eps=eps)
|
||||||
|
else:
|
||||||
|
return rms_norm_torch(x, weight.shape, weight=comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
|
||||||
|
else:
|
||||||
|
r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
|
||||||
|
if weight is None:
|
||||||
|
return r
|
||||||
|
else:
|
||||||
|
return r * comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device)
|
||||||
|
|||||||
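The fallback branch of rms_norm above is the textbook definition, x / sqrt(mean(x^2) + eps) times an optional weight, with the fused torch kernel used when available. A standalone check of that equivalence (illustrative, not part of the diff; guarded because F.rms_norm only exists on newer torch builds):

import torch

x = torch.randn(2, 8)
w = torch.ones(8)
manual = x * torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + 1e-6) * w
if hasattr(torch.nn.functional, "rms_norm"):  # fused path, newer torch only
    fused = torch.nn.functional.rms_norm(x, (8,), weight=w, eps=1e-6)
    print(torch.allclose(manual, fused, atol=1e-5))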
205  comfy/ldm/flux/controlnet.py  Normal file
@@ -0,0 +1,205 @@
#Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
#modified to support different types of flux controlnets

import torch
import math
from torch import Tensor, nn
from einops import rearrange, repeat

from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
                     MLPEmbedder, SingleStreamBlock,
                     timestep_embedding)

from .model import Flux
import comfy.ldm.common_dit

class MistolineCondDownsamplBlock(nn.Module):
    def __init__(self, dtype=None, device=None, operations=None):
        super().__init__()
        self.encoder = nn.Sequential(
            operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Conv2d(16, 16, 1, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Conv2d(16, 16, 1, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
        )

    def forward(self, x):
        return self.encoder(x)

class MistolineControlnetBlock(nn.Module):
    def __init__(self, hidden_size, dtype=None, device=None, operations=None):
        super().__init__()
        self.linear = operations.Linear(hidden_size, hidden_size, dtype=dtype, device=device)
        self.act = nn.SiLU()

    def forward(self, x):
        return self.act(self.linear(x))


class ControlNetFlux(Flux):
    def __init__(self, latent_input=False, num_union_modes=0, mistoline=False, control_latent_channels=None, image_model=None, dtype=None, device=None, operations=None, **kwargs):
        super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)

        self.main_model_double = 19
        self.main_model_single = 38

        self.mistoline = mistoline
        # add ControlNet blocks
        if self.mistoline:
            control_block = lambda : MistolineControlnetBlock(self.hidden_size, dtype=dtype, device=device, operations=operations)
        else:
            control_block = lambda : operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)

        self.controlnet_blocks = nn.ModuleList([])
        for _ in range(self.params.depth):
            self.controlnet_blocks.append(control_block())

        self.controlnet_single_blocks = nn.ModuleList([])
        for _ in range(self.params.depth_single_blocks):
            self.controlnet_single_blocks.append(control_block())

        self.num_union_modes = num_union_modes
        self.controlnet_mode_embedder = None
        if self.num_union_modes > 0:
            self.controlnet_mode_embedder = operations.Embedding(self.num_union_modes, self.hidden_size, dtype=dtype, device=device)

        self.gradient_checkpointing = False
        self.latent_input = latent_input
        if control_latent_channels is None:
            control_latent_channels = self.in_channels
        else:
            control_latent_channels *= 2 * 2 #patch size

        self.pos_embed_input = operations.Linear(control_latent_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
        if not self.latent_input:
            if self.mistoline:
                self.input_cond_block = MistolineCondDownsamplBlock(dtype=dtype, device=device, operations=operations)
            else:
                self.input_hint_block = nn.Sequential(
                    operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
                    nn.SiLU(),
                    operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
                    nn.SiLU(),
                    operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
                    nn.SiLU(),
                    operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
                    nn.SiLU(),
                    operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
                    nn.SiLU(),
                    operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
                    nn.SiLU(),
                    operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
                    nn.SiLU(),
                    operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
                )

    def forward_orig(
        self,
        img: Tensor,
        img_ids: Tensor,
        controlnet_cond: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor = None,
        control_type: Tensor = None,
    ) -> Tensor:
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # running on sequences img
        img = self.img_in(img)

        controlnet_cond = self.pos_embed_input(controlnet_cond)
        img = img + controlnet_cond
        vec = self.time_in(timestep_embedding(timesteps, 256))
        if self.params.guidance_embed:
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        vec = vec + self.vector_in(y)
        txt = self.txt_in(txt)

        if self.controlnet_mode_embedder is not None and len(control_type) > 0:
            control_cond = self.controlnet_mode_embedder(torch.tensor(control_type, device=img.device), out_dtype=img.dtype).unsqueeze(0).repeat((txt.shape[0], 1, 1))
            txt = torch.cat([control_cond, txt], dim=1)
            txt_ids = torch.cat([txt_ids[:,:1], txt_ids], dim=1)

        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        controlnet_double = ()

        for i in range(len(self.double_blocks)):
            img, txt = self.double_blocks[i](img=img, txt=txt, vec=vec, pe=pe)
            controlnet_double = controlnet_double + (self.controlnet_blocks[i](img),)

        img = torch.cat((txt, img), 1)

        controlnet_single = ()

        for i in range(len(self.single_blocks)):
            img = self.single_blocks[i](img, vec=vec, pe=pe)
            controlnet_single = controlnet_single + (self.controlnet_single_blocks[i](img[:, txt.shape[1] :, ...]),)

        repeat = math.ceil(self.main_model_double / len(controlnet_double))
        if self.latent_input:
            out_input = ()
            for x in controlnet_double:
                out_input += (x,) * repeat
        else:
            out_input = (controlnet_double * repeat)

        out = {"input": out_input[:self.main_model_double]}
        if len(controlnet_single) > 0:
            repeat = math.ceil(self.main_model_single / len(controlnet_single))
            out_output = ()
            if self.latent_input:
                for x in controlnet_single:
                    out_output += (x,) * repeat
            else:
                out_output = (controlnet_single * repeat)
            out["output"] = out_output[:self.main_model_single]
        return out

    def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
        patch_size = 2
        if self.latent_input:
            hint = comfy.ldm.common_dit.pad_to_patch_size(hint, (patch_size, patch_size))
        elif self.mistoline:
            hint = hint * 2.0 - 1.0
            hint = self.input_cond_block(hint)
        else:
            hint = hint * 2.0 - 1.0
            hint = self.input_hint_block(hint)

        hint = rearrange(hint, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)

        bs, c, h, w = x.shape
        x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))

        img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)

        h_len = ((h + (patch_size // 2)) // patch_size)
        w_len = ((w + (patch_size // 2)) // patch_size)
        img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
        img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
        img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

        txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
        return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance, control_type=kwargs.get("control_type", []))
|
|||||||
#Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch import Tensor, nn
|
|
||||||
from einops import rearrange, repeat
|
|
||||||
|
|
||||||
from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
|
|
||||||
MLPEmbedder, SingleStreamBlock,
|
|
||||||
timestep_embedding)
|
|
||||||
|
|
||||||
from .model import Flux
|
|
||||||
import comfy.ldm.common_dit
|
|
||||||
|
|
||||||
|
|
||||||
class ControlNetFlux(Flux):
|
|
||||||
def __init__(self, image_model=None, dtype=None, device=None, operations=None, **kwargs):
|
|
||||||
super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
|
|
||||||
|
|
||||||
# add ControlNet blocks
|
|
||||||
self.controlnet_blocks = nn.ModuleList([])
|
|
||||||
for _ in range(self.params.depth):
|
|
||||||
controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
|
|
||||||
# controlnet_block = zero_module(controlnet_block)
|
|
||||||
self.controlnet_blocks.append(controlnet_block)
|
|
||||||
self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
|
|
||||||
self.gradient_checkpointing = False
|
|
||||||
self.input_hint_block = nn.Sequential(
|
|
||||||
operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
|
|
||||||
nn.SiLU(),
|
|
||||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
|
||||||
nn.SiLU(),
|
|
||||||
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
|
||||||
nn.SiLU(),
|
|
||||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
|
||||||
nn.SiLU(),
|
|
||||||
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
|
||||||
nn.SiLU(),
|
|
||||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
|
||||||
nn.SiLU(),
|
|
||||||
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
|
||||||
nn.SiLU(),
|
|
||||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward_orig(
|
|
||||||
self,
|
|
||||||
img: Tensor,
|
|
||||||
img_ids: Tensor,
|
|
||||||
controlnet_cond: Tensor,
|
|
||||||
txt: Tensor,
|
|
||||||
txt_ids: Tensor,
|
|
||||||
timesteps: Tensor,
|
|
||||||
y: Tensor,
|
|
||||||
guidance: Tensor = None,
|
|
||||||
) -> Tensor:
|
|
||||||
if img.ndim != 3 or txt.ndim != 3:
|
|
||||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
|
||||||
|
|
||||||
# running on sequences img
|
|
||||||
img = self.img_in(img)
|
|
||||||
controlnet_cond = self.input_hint_block(controlnet_cond)
|
|
||||||
controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
|
||||||
controlnet_cond = self.pos_embed_input(controlnet_cond)
|
|
||||||
img = img + controlnet_cond
|
|
||||||
vec = self.time_in(timestep_embedding(timesteps, 256))
|
|
||||||
if self.params.guidance_embed:
|
|
||||||
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
|
|
||||||
vec = vec + self.vector_in(y)
|
|
||||||
txt = self.txt_in(txt)
|
|
||||||
|
|
||||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
|
||||||
pe = self.pe_embedder(ids)
|
|
||||||
|
|
||||||
block_res_samples = ()
|
|
||||||
|
|
||||||
for block in self.double_blocks:
|
|
||||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
|
||||||
block_res_samples = block_res_samples + (img,)
|
|
||||||
|
|
||||||
controlnet_block_res_samples = ()
|
|
||||||
for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
|
|
||||||
block_res_sample = controlnet_block(block_res_sample)
|
|
||||||
controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
|
|
||||||
|
|
||||||
return {"input": (controlnet_block_res_samples * 10)[:19]}
|
|
||||||
|
|
||||||
def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
|
|
||||||
hint = hint * 2.0 - 1.0
|
|
||||||
|
|
||||||
bs, c, h, w = x.shape
|
|
||||||
patch_size = 2
|
|
||||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
|
||||||
|
|
||||||
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
|
||||||
|
|
||||||
h_len = ((h + (patch_size // 2)) // patch_size)
|
|
||||||
w_len = ((w + (patch_size // 2)) // patch_size)
|
|
||||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
|
||||||
img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
|
|
||||||
img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
|
|
||||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
|
||||||
|
|
||||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
|
||||||
return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance)
|
|
||||||
@@ -6,6 +6,7 @@ from torch import Tensor, nn
|
|||||||
|
|
||||||
from .math import attention, rope
|
from .math import attention, rope
|
||||||
import comfy.ops
|
import comfy.ops
|
||||||
|
import comfy.ldm.common_dit
|
||||||
|
|
||||||
|
|
||||||
class EmbedND(nn.Module):
|
class EmbedND(nn.Module):
|
||||||
@@ -63,10 +64,7 @@ class RMSNorm(torch.nn.Module):
|
|||||||
self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
|
self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
|
||||||
|
|
||||||
def forward(self, x: Tensor):
|
def forward(self, x: Tensor):
|
||||||
x_dtype = x.dtype
|
return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6)
|
||||||
x = x.float()
|
|
||||||
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
|
|
||||||
return (x * rrms).to(dtype=x_dtype) * comfy.ops.cast_to(self.scale, dtype=x_dtype, device=x.device)
|
|
||||||
|
|
||||||
|
|
||||||
class QKNorm(torch.nn.Module):
|
class QKNorm(torch.nn.Module):
|
||||||
|
|||||||
@@ -108,7 +108,7 @@ class Flux(nn.Module):
|
|||||||
raise ValueError("Didn't get guidance strength for guidance distilled model.")
|
raise ValueError("Didn't get guidance strength for guidance distilled model.")
|
||||||
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
|
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
|
||||||
|
|
||||||
vec = vec + self.vector_in(y)
|
vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
|
||||||
txt = self.txt_in(txt)
|
txt = self.txt_in(txt)
|
||||||
|
|
||||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||||
@@ -151,8 +151,8 @@ class Flux(nn.Module):
|
|||||||
h_len = ((h + (patch_size // 2)) // patch_size)
|
h_len = ((h + (patch_size // 2)) // patch_size)
|
||||||
w_len = ((w + (patch_size // 2)) // patch_size)
|
w_len = ((w + (patch_size // 2)) // patch_size)
|
||||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||||
img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
|
img_ids[:, :, 1] = torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
|
||||||
img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
|
img_ids[:, :, 2] = torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||||
|
|
||||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||||
|
|||||||
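The last hunk replaces the additive fill of img_ids with direct assignment. A quick standalone check (illustrative, small grid) that the two constructions produce the same position grid:

import torch

h_len, w_len = 4, 6
a = torch.zeros((h_len, w_len, 3))
a[..., 1] = a[..., 1] + torch.linspace(0, h_len - 1, steps=h_len)[:, None]
a[..., 2] = a[..., 2] + torch.linspace(0, w_len - 1, steps=w_len)[None, :]

b = torch.zeros((h_len, w_len, 3))
b[:, :, 1] = torch.linspace(0, h_len - 1, steps=h_len).unsqueeze(1)
b[:, :, 2] = torch.linspace(0, w_len - 1, steps=w_len).unsqueeze(0)

print(torch.equal(a, b))  # True: same row/column index grid either way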
541  comfy/ldm/genmo/joint_model/asymm_models_joint.py  Normal file
@@ -0,0 +1,541 @@
#original code from https://github.com/genmoai/models under apache 2.0 license
#adapted to ComfyUI

from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
# from flash_attn import flash_attn_varlen_qkvpacked_func
from comfy.ldm.modules.attention import optimized_attention

from .layers import (
    FeedForward,
    PatchEmbed,
    RMSNorm,
    TimestepEmbedder,
)

from .rope_mixed import (
    compute_mixed_rotation,
    create_position_matrix,
)
from .temporal_rope import apply_rotary_emb_qk_real
from .utils import (
    AttentionPool,
    modulate,
)

import comfy.ldm.common_dit
import comfy.ops


def modulated_rmsnorm(x, scale, eps=1e-6):
    # Normalize and modulate
    x_normed = comfy.ldm.common_dit.rms_norm(x, eps=eps)
    x_modulated = x_normed * (1 + scale.unsqueeze(1))

    return x_modulated


def residual_tanh_gated_rmsnorm(x, x_res, gate, eps=1e-6):
    # Apply tanh to gate
    tanh_gate = torch.tanh(gate).unsqueeze(1)

    # Normalize and apply gated scaling
    x_normed = comfy.ldm.common_dit.rms_norm(x_res, eps=eps) * tanh_gate

    # Apply residual connection
    output = x + x_normed

    return output
class AsymmetricAttention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim_x: int,
|
||||||
|
dim_y: int,
|
||||||
|
num_heads: int = 8,
|
||||||
|
qkv_bias: bool = True,
|
||||||
|
qk_norm: bool = False,
|
||||||
|
attn_drop: float = 0.0,
|
||||||
|
update_y: bool = True,
|
||||||
|
out_bias: bool = True,
|
||||||
|
attend_to_padding: bool = False,
|
||||||
|
softmax_scale: Optional[float] = None,
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
dtype=None,
|
||||||
|
operations=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.dim_x = dim_x
|
||||||
|
self.dim_y = dim_y
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_dim = dim_x // num_heads
|
||||||
|
self.attn_drop = attn_drop
|
||||||
|
self.update_y = update_y
|
||||||
|
self.attend_to_padding = attend_to_padding
|
||||||
|
self.softmax_scale = softmax_scale
|
||||||
|
if dim_x % num_heads != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"dim_x={dim_x} should be divisible by num_heads={num_heads}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Input layers.
|
||||||
|
self.qkv_bias = qkv_bias
|
||||||
|
self.qkv_x = operations.Linear(dim_x, 3 * dim_x, bias=qkv_bias, device=device, dtype=dtype)
|
||||||
|
# Project text features to match visual features (dim_y -> dim_x)
|
||||||
|
self.qkv_y = operations.Linear(dim_y, 3 * dim_x, bias=qkv_bias, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
# Query and key normalization for stability.
|
||||||
|
assert qk_norm
|
||||||
|
self.q_norm_x = RMSNorm(self.head_dim, device=device, dtype=dtype)
|
||||||
|
self.k_norm_x = RMSNorm(self.head_dim, device=device, dtype=dtype)
|
||||||
|
self.q_norm_y = RMSNorm(self.head_dim, device=device, dtype=dtype)
|
||||||
|
self.k_norm_y = RMSNorm(self.head_dim, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
# Output layers. y features go back down from dim_x -> dim_y.
|
||||||
|
self.proj_x = operations.Linear(dim_x, dim_x, bias=out_bias, device=device, dtype=dtype)
|
||||||
|
self.proj_y = (
|
||||||
|
operations.Linear(dim_x, dim_y, bias=out_bias, device=device, dtype=dtype)
|
||||||
|
if update_y
|
||||||
|
else nn.Identity()
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor, # (B, N, dim_x)
|
||||||
|
y: torch.Tensor, # (B, L, dim_y)
|
||||||
|
scale_x: torch.Tensor, # (B, dim_x), modulation for pre-RMSNorm.
|
||||||
|
scale_y: torch.Tensor, # (B, dim_y), modulation for pre-RMSNorm.
|
||||||
|
crop_y,
|
||||||
|
**rope_rotation,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
rope_cos = rope_rotation.get("rope_cos")
|
||||||
|
rope_sin = rope_rotation.get("rope_sin")
|
||||||
|
# Pre-norm for visual features
|
||||||
|
x = modulated_rmsnorm(x, scale_x) # (B, M, dim_x) where M = N / cp_group_size
|
||||||
|
|
||||||
|
# Process visual features
|
||||||
|
# qkv_x = self.qkv_x(x) # (B, M, 3 * dim_x)
|
||||||
|
# assert qkv_x.dtype == torch.bfloat16
|
||||||
|
# qkv_x = all_to_all_collect_tokens(
|
||||||
|
# qkv_x, self.num_heads
|
||||||
|
# ) # (3, B, N, local_h, head_dim)
|
||||||
|
|
||||||
|
# Process text features
|
||||||
|
y = modulated_rmsnorm(y, scale_y) # (B, L, dim_y)
|
||||||
|
q_y, k_y, v_y = self.qkv_y(y).view(y.shape[0], y.shape[1], 3, self.num_heads, -1).unbind(2) # (B, N, local_h, head_dim)
|
||||||
|
|
||||||
|
q_y = self.q_norm_y(q_y)
|
||||||
|
k_y = self.k_norm_y(k_y)
|
||||||
|
|
||||||
|
# Split qkv_x into q, k, v
|
||||||
|
q_x, k_x, v_x = self.qkv_x(x).view(x.shape[0], x.shape[1], 3, self.num_heads, -1).unbind(2) # (B, N, local_h, head_dim)
|
||||||
|
q_x = self.q_norm_x(q_x)
|
||||||
|
q_x = apply_rotary_emb_qk_real(q_x, rope_cos, rope_sin)
|
||||||
|
k_x = self.k_norm_x(k_x)
|
||||||
|
k_x = apply_rotary_emb_qk_real(k_x, rope_cos, rope_sin)
|
||||||
|
|
||||||
|
q = torch.cat([q_x, q_y[:, :crop_y]], dim=1).transpose(1, 2)
|
||||||
|
k = torch.cat([k_x, k_y[:, :crop_y]], dim=1).transpose(1, 2)
|
||||||
|
v = torch.cat([v_x, v_y[:, :crop_y]], dim=1).transpose(1, 2)
|
||||||
|
|
||||||
|
xy = optimized_attention(q,
|
||||||
|
k,
|
||||||
|
v, self.num_heads, skip_reshape=True)
|
||||||
|
|
||||||
|
x, y = torch.tensor_split(xy, (q_x.shape[1],), dim=1)
|
||||||
|
x = self.proj_x(x)
|
||||||
|
o = torch.zeros(y.shape[0], q_y.shape[1], y.shape[-1], device=y.device, dtype=y.dtype)
|
||||||
|
o[:, :y.shape[1]] = y
|
||||||
|
|
||||||
|
y = self.proj_y(o)
|
||||||
|
# print("ox", x)
|
||||||
|
# print("oy", y)
|
||||||
|
return x, y
|
||||||
|
|
||||||
|
|
||||||
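AsymmetricAttention keeps separate projections per modality but runs a single attention over the joint sequence: text features are projected up to the visual width, only the first crop_y text tokens are appended to the visual tokens, and after attention the text part is written back into a zero-padded buffer of the original length. A standalone shape sketch with made-up sizes, using F.scaled_dot_product_attention in place of ComfyUI's optimized_attention:

```python
import torch
import torch.nn.functional as F

B, N_vis, L_txt, H, Dh = 1, 16, 6, 4, 8   # illustrative sizes
crop_y = 4                                 # only the first crop_y text tokens take part

q_x, k_x, v_x = (torch.randn(B, N_vis, H, Dh) for _ in range(3))
q_y, k_y, v_y = (torch.randn(B, L_txt, H, Dh) for _ in range(3))

# Joint sequence: visual tokens followed by the cropped text tokens, heads moved to dim 1.
q = torch.cat([q_x, q_y[:, :crop_y]], dim=1).transpose(1, 2)
k = torch.cat([k_x, k_y[:, :crop_y]], dim=1).transpose(1, 2)
v = torch.cat([v_x, v_y[:, :crop_y]], dim=1).transpose(1, 2)

xy = F.scaled_dot_product_attention(q, k, v)            # (B, H, N_vis + crop_y, Dh)
xy = xy.transpose(1, 2).reshape(B, N_vis + crop_y, H * Dh)

# Split back into the visual part and the (cropped) text part.
x_out, y_out = torch.tensor_split(xy, (N_vis,), dim=1)

# Text output is scattered into a zero-padded buffer of the original text length.
y_full = torch.zeros(B, L_txt, H * Dh)
y_full[:, :y_out.shape[1]] = y_out
print(x_out.shape, y_full.shape)  # torch.Size([1, 16, 32]) torch.Size([1, 6, 32])
```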
|
class AsymmetricJointBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size_x: int,
|
||||||
|
hidden_size_y: int,
|
||||||
|
num_heads: int,
|
||||||
|
*,
|
||||||
|
mlp_ratio_x: float = 8.0, # Ratio of hidden size to d_model for MLP for visual tokens.
|
||||||
|
mlp_ratio_y: float = 4.0, # Ratio of hidden size to d_model for MLP for text tokens.
|
||||||
|
update_y: bool = True, # Whether to update text tokens in this block.
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
dtype=None,
|
||||||
|
operations=None,
|
||||||
|
**block_kwargs,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.update_y = update_y
|
||||||
|
self.hidden_size_x = hidden_size_x
|
||||||
|
self.hidden_size_y = hidden_size_y
|
||||||
|
self.mod_x = operations.Linear(hidden_size_x, 4 * hidden_size_x, device=device, dtype=dtype)
|
||||||
|
if self.update_y:
|
||||||
|
self.mod_y = operations.Linear(hidden_size_x, 4 * hidden_size_y, device=device, dtype=dtype)
|
||||||
|
else:
|
||||||
|
self.mod_y = operations.Linear(hidden_size_x, hidden_size_y, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
# Self-attention:
|
||||||
|
self.attn = AsymmetricAttention(
|
||||||
|
hidden_size_x,
|
||||||
|
hidden_size_y,
|
||||||
|
num_heads=num_heads,
|
||||||
|
update_y=update_y,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
operations=operations,
|
||||||
|
**block_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# MLP.
|
||||||
|
mlp_hidden_dim_x = int(hidden_size_x * mlp_ratio_x)
|
||||||
|
assert mlp_hidden_dim_x == int(1536 * 8)
|
||||||
|
self.mlp_x = FeedForward(
|
||||||
|
in_features=hidden_size_x,
|
||||||
|
hidden_size=mlp_hidden_dim_x,
|
||||||
|
multiple_of=256,
|
||||||
|
ffn_dim_multiplier=None,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
|
||||||
|
# MLP for text not needed in last block.
|
||||||
|
if self.update_y:
|
||||||
|
mlp_hidden_dim_y = int(hidden_size_y * mlp_ratio_y)
|
||||||
|
self.mlp_y = FeedForward(
|
||||||
|
in_features=hidden_size_y,
|
||||||
|
hidden_size=mlp_hidden_dim_y,
|
||||||
|
multiple_of=256,
|
||||||
|
ffn_dim_multiplier=None,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
c: torch.Tensor,
|
||||||
|
y: torch.Tensor,
|
||||||
|
**attn_kwargs,
|
||||||
|
):
|
||||||
|
"""Forward pass of a block.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: (B, N, dim) tensor of visual tokens
|
||||||
|
c: (B, dim) tensor of conditioned features
|
||||||
|
y: (B, L, dim) tensor of text tokens
|
||||||
|
num_frames: Number of frames in the video. N = num_frames * num_spatial_tokens
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
x: (B, N, dim) tensor of visual tokens after block
|
||||||
|
y: (B, L, dim) tensor of text tokens after block
|
||||||
|
"""
|
||||||
|
N = x.size(1)
|
||||||
|
|
||||||
|
c = F.silu(c)
|
||||||
|
mod_x = self.mod_x(c)
|
||||||
|
scale_msa_x, gate_msa_x, scale_mlp_x, gate_mlp_x = mod_x.chunk(4, dim=1)
|
||||||
|
|
||||||
|
mod_y = self.mod_y(c)
|
||||||
|
if self.update_y:
|
||||||
|
scale_msa_y, gate_msa_y, scale_mlp_y, gate_mlp_y = mod_y.chunk(4, dim=1)
|
||||||
|
else:
|
||||||
|
scale_msa_y = mod_y
|
||||||
|
|
||||||
|
# Self-attention block.
|
||||||
|
x_attn, y_attn = self.attn(
|
||||||
|
x,
|
||||||
|
y,
|
||||||
|
scale_x=scale_msa_x,
|
||||||
|
scale_y=scale_msa_y,
|
||||||
|
**attn_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert x_attn.size(1) == N
|
||||||
|
x = residual_tanh_gated_rmsnorm(x, x_attn, gate_msa_x)
|
||||||
|
if self.update_y:
|
||||||
|
y = residual_tanh_gated_rmsnorm(y, y_attn, gate_msa_y)
|
||||||
|
|
||||||
|
# MLP block.
|
||||||
|
x = self.ff_block_x(x, scale_mlp_x, gate_mlp_x)
|
||||||
|
if self.update_y:
|
||||||
|
y = self.ff_block_y(y, scale_mlp_y, gate_mlp_y)
|
||||||
|
|
||||||
|
return x, y
|
||||||
|
|
||||||
|
def ff_block_x(self, x, scale_x, gate_x):
|
||||||
|
x_mod = modulated_rmsnorm(x, scale_x)
|
||||||
|
x_res = self.mlp_x(x_mod)
|
||||||
|
x = residual_tanh_gated_rmsnorm(x, x_res, gate_x) # Sandwich norm
|
||||||
|
return x
|
||||||
|
|
||||||
|
def ff_block_y(self, y, scale_y, gate_y):
|
||||||
|
y_mod = modulated_rmsnorm(y, scale_y)
|
||||||
|
y_res = self.mlp_y(y_mod)
|
||||||
|
y = residual_tanh_gated_rmsnorm(y, y_res, gate_y) # Sandwich norm
|
||||||
|
return y
|
||||||
|
|
||||||
|
|
||||||
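Each AsymmetricJointBlock derives its modulation from the pooled conditioning c with one linear layer per stream: four chunks (scale and gate for attention, scale and gate for MLP) for the visual stream, and the same four for the text stream while update_y is True (the final block only needs a single text scale). A minimal sketch of the visual-stream split, with made-up sizes:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_size = 32
mod = nn.Linear(hidden_size, 4 * hidden_size)

c = F.silu(torch.randn(2, hidden_size))   # pooled conditioning (timestep + caption pool)
scale_msa, gate_msa, scale_mlp, gate_mlp = mod(c).chunk(4, dim=1)

# scale_* feed modulated_rmsnorm before attention / MLP,
# gate_* feed residual_tanh_gated_rmsnorm after them.
print(scale_msa.shape, gate_mlp.shape)    # torch.Size([2, 32]) torch.Size([2, 32])
```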
|
class FinalLayer(nn.Module):
|
||||||
|
"""
|
||||||
|
The final layer of DiT.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size,
|
||||||
|
patch_size,
|
||||||
|
out_channels,
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
dtype=None,
|
||||||
|
operations=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.norm_final = operations.LayerNorm(
|
||||||
|
hidden_size, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype
|
||||||
|
)
|
||||||
|
self.mod = operations.Linear(hidden_size, 2 * hidden_size, device=device, dtype=dtype)
|
||||||
|
self.linear = operations.Linear(
|
||||||
|
hidden_size, patch_size * patch_size * out_channels, device=device, dtype=dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x, c):
|
||||||
|
c = F.silu(c)
|
||||||
|
shift, scale = self.mod(c).chunk(2, dim=1)
|
||||||
|
x = modulate(self.norm_final(x), shift, scale)
|
||||||
|
x = self.linear(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class AsymmDiTJoint(nn.Module):
|
||||||
|
"""
|
||||||
|
Diffusion model with a Transformer backbone.
|
||||||
|
|
||||||
|
Ingests text embeddings instead of a label.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
patch_size=2,
|
||||||
|
in_channels=4,
|
||||||
|
hidden_size_x=1152,
|
||||||
|
hidden_size_y=1152,
|
||||||
|
depth=48,
|
||||||
|
num_heads=16,
|
||||||
|
mlp_ratio_x=8.0,
|
||||||
|
mlp_ratio_y=4.0,
|
||||||
|
use_t5: bool = False,
|
||||||
|
t5_feat_dim: int = 4096,
|
||||||
|
t5_token_length: int = 256,
|
||||||
|
learn_sigma=True,
|
||||||
|
patch_embed_bias: bool = True,
|
||||||
|
timestep_mlp_bias: bool = True,
|
||||||
|
attend_to_padding: bool = False,
|
||||||
|
timestep_scale: Optional[float] = None,
|
||||||
|
use_extended_posenc: bool = False,
|
||||||
|
posenc_preserve_area: bool = False,
|
||||||
|
rope_theta: float = 10000.0,
|
||||||
|
image_model=None,
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
dtype=None,
|
||||||
|
operations=None,
|
||||||
|
**block_kwargs,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.dtype = dtype
|
||||||
|
self.learn_sigma = learn_sigma
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = in_channels * 2 if learn_sigma else in_channels
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.hidden_size_x = hidden_size_x
|
||||||
|
self.hidden_size_y = hidden_size_y
|
||||||
|
self.head_dim = (
|
||||||
|
hidden_size_x // num_heads
|
||||||
|
) # Head dimension and count is determined by visual.
|
||||||
|
self.attend_to_padding = attend_to_padding
|
||||||
|
self.use_extended_posenc = use_extended_posenc
|
||||||
|
self.posenc_preserve_area = posenc_preserve_area
|
||||||
|
self.use_t5 = use_t5
|
||||||
|
self.t5_token_length = t5_token_length
|
||||||
|
self.t5_feat_dim = t5_feat_dim
|
||||||
|
self.rope_theta = (
|
||||||
|
rope_theta # Scaling factor for frequency computation for temporal RoPE.
|
||||||
|
)
|
||||||
|
|
||||||
|
self.x_embedder = PatchEmbed(
|
||||||
|
patch_size=patch_size,
|
||||||
|
in_chans=in_channels,
|
||||||
|
embed_dim=hidden_size_x,
|
||||||
|
bias=patch_embed_bias,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations
|
||||||
|
)
|
||||||
|
# Conditionings
|
||||||
|
# Timestep
|
||||||
|
self.t_embedder = TimestepEmbedder(
|
||||||
|
hidden_size_x, bias=timestep_mlp_bias, timestep_scale=timestep_scale, dtype=dtype, device=device, operations=operations
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.use_t5:
|
||||||
|
# Caption Pooling (T5)
|
||||||
|
self.t5_y_embedder = AttentionPool(
|
||||||
|
t5_feat_dim, num_heads=8, output_dim=hidden_size_x, dtype=dtype, device=device, operations=operations
|
||||||
|
)
|
||||||
|
|
||||||
|
# Dense Embedding Projection (T5)
|
||||||
|
self.t5_yproj = operations.Linear(
|
||||||
|
t5_feat_dim, hidden_size_y, bias=True, dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize pos_frequencies as an empty parameter.
|
||||||
|
self.pos_frequencies = nn.Parameter(
|
||||||
|
torch.empty(3, self.num_heads, self.head_dim // 2, dtype=dtype, device=device)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert not self.attend_to_padding
|
||||||
|
|
||||||
|
# for depth 48:
|
||||||
|
# b = 0: AsymmetricJointBlock, update_y=True
|
||||||
|
# b = 1: AsymmetricJointBlock, update_y=True
|
||||||
|
# ...
|
||||||
|
# b = 46: AsymmetricJointBlock, update_y=True
|
||||||
|
# b = 47: AsymmetricJointBlock, update_y=False. No need to update text features.
|
||||||
|
blocks = []
|
||||||
|
for b in range(depth):
|
||||||
|
# Joint multi-modal block
|
||||||
|
update_y = b < depth - 1
|
||||||
|
block = AsymmetricJointBlock(
|
||||||
|
hidden_size_x,
|
||||||
|
hidden_size_y,
|
||||||
|
num_heads,
|
||||||
|
mlp_ratio_x=mlp_ratio_x,
|
||||||
|
mlp_ratio_y=mlp_ratio_y,
|
||||||
|
update_y=update_y,
|
||||||
|
attend_to_padding=attend_to_padding,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
operations=operations,
|
||||||
|
**block_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
blocks.append(block)
|
||||||
|
self.blocks = nn.ModuleList(blocks)
|
||||||
|
|
||||||
|
self.final_layer = FinalLayer(
|
||||||
|
hidden_size_x, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations
|
||||||
|
)
|
||||||
|
|
||||||
|
def embed_x(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x: (B, C=12, T, H, W) tensor of visual tokens
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
x: (B, C=3072, N) tensor of visual tokens with positional embedding.
|
||||||
|
"""
|
||||||
|
return self.x_embedder(x) # Convert BcTHW to BCN
|
||||||
|
|
||||||
|
def prepare(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
sigma: torch.Tensor,
|
||||||
|
t5_feat: torch.Tensor,
|
||||||
|
t5_mask: torch.Tensor,
|
||||||
|
):
|
||||||
|
"""Prepare input and conditioning embeddings."""
|
||||||
|
# Visual patch embeddings with positional encoding.
|
||||||
|
T, H, W = x.shape[-3:]
|
||||||
|
pH, pW = H // self.patch_size, W // self.patch_size
|
||||||
|
x = self.embed_x(x) # (B, N, D), where N = T * H * W / patch_size ** 2
|
||||||
|
assert x.ndim == 3
|
||||||
|
B = x.size(0)
|
||||||
|
|
||||||
|
|
||||||
|
pH, pW = H // self.patch_size, W // self.patch_size
|
||||||
|
N = T * pH * pW
|
||||||
|
assert x.size(1) == N
|
||||||
|
pos = create_position_matrix(
|
||||||
|
T, pH=pH, pW=pW, device=x.device, dtype=torch.float32
|
||||||
|
) # (N, 3)
|
||||||
|
rope_cos, rope_sin = compute_mixed_rotation(
|
||||||
|
freqs=comfy.ops.cast_to(self.pos_frequencies, dtype=x.dtype, device=x.device), pos=pos
|
||||||
|
) # Each are (N, num_heads, dim // 2)
|
||||||
|
|
||||||
|
c_t = self.t_embedder(1 - sigma, out_dtype=x.dtype) # (B, D)
|
||||||
|
|
||||||
|
t5_y_pool = self.t5_y_embedder(t5_feat, t5_mask) # (B, D)
|
||||||
|
|
||||||
|
c = c_t + t5_y_pool
|
||||||
|
|
||||||
|
y_feat = self.t5_yproj(t5_feat) # (B, L, t5_feat_dim) --> (B, L, D)
|
||||||
|
|
||||||
|
return x, c, y_feat, rope_cos, rope_sin
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
timestep: torch.Tensor,
|
||||||
|
context: List[torch.Tensor],
|
||||||
|
attention_mask: List[torch.Tensor],
|
||||||
|
num_tokens=256,
|
||||||
|
packed_indices: Dict[str, torch.Tensor] = None,
|
||||||
|
rope_cos: torch.Tensor = None,
|
||||||
|
rope_sin: torch.Tensor = None,
|
||||||
|
control=None, **kwargs
|
||||||
|
):
|
||||||
|
y_feat = context
|
||||||
|
y_mask = attention_mask
|
||||||
|
sigma = timestep
|
||||||
|
"""Forward pass of DiT.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: (B, C, T, H, W) tensor of spatial inputs (images or latent representations of images)
|
||||||
|
sigma: (B,) tensor of noise standard deviations
|
||||||
|
y_feat: List((B, L, y_feat_dim) tensor of caption token features. For SDXL text encoders: L=77, y_feat_dim=2048)
|
||||||
|
y_mask: List((B, L) boolean tensor indicating which tokens are not padding)
|
||||||
|
packed_indices: Dict with keys for Flash Attention. Result of compute_packed_indices.
|
||||||
|
"""
|
||||||
|
B, _, T, H, W = x.shape
|
||||||
|
|
||||||
|
x, c, y_feat, rope_cos, rope_sin = self.prepare(
|
||||||
|
x, sigma, y_feat, y_mask
|
||||||
|
)
|
||||||
|
del y_mask
|
||||||
|
|
||||||
|
for i, block in enumerate(self.blocks):
|
||||||
|
x, y_feat = block(
|
||||||
|
x,
|
||||||
|
c,
|
||||||
|
y_feat,
|
||||||
|
rope_cos=rope_cos,
|
||||||
|
rope_sin=rope_sin,
|
||||||
|
crop_y=num_tokens,
|
||||||
|
) # (B, M, D), (B, L, D)
|
||||||
|
del y_feat # Final layers don't use dense text features.
|
||||||
|
|
||||||
|
x = self.final_layer(x, c) # (B, M, patch_size ** 2 * out_channels)
|
||||||
|
x = rearrange(
|
||||||
|
x,
|
||||||
|
"B (T hp wp) (p1 p2 c) -> B c T (hp p1) (wp p2)",
|
||||||
|
T=T,
|
||||||
|
hp=H // self.patch_size,
|
||||||
|
wp=W // self.patch_size,
|
||||||
|
p1=self.patch_size,
|
||||||
|
p2=self.patch_size,
|
||||||
|
c=self.out_channels,
|
||||||
|
)
|
||||||
|
|
||||||
|
return -x
|
||||||
164  comfy/ldm/genmo/joint_model/layers.py  (new file)
@@ -0,0 +1,164 @@
|
|||||||
|
#original code from https://github.com/genmoai/models under apache 2.0 license
|
||||||
|
#adapted to ComfyUI
|
||||||
|
|
||||||
|
import collections.abc
|
||||||
|
import math
|
||||||
|
from itertools import repeat
|
||||||
|
from typing import Callable, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from einops import rearrange
|
||||||
|
import comfy.ldm.common_dit
|
||||||
|
|
||||||
|
|
||||||
|
# From PyTorch internals
|
||||||
|
def _ntuple(n):
|
||||||
|
def parse(x):
|
||||||
|
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
|
||||||
|
return tuple(x)
|
||||||
|
return tuple(repeat(x, n))
|
||||||
|
|
||||||
|
return parse
|
||||||
|
|
||||||
|
|
||||||
|
to_2tuple = _ntuple(2)
|
||||||
|
|
||||||
|
|
||||||
|
class TimestepEmbedder(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int,
|
||||||
|
frequency_embedding_size: int = 256,
|
||||||
|
*,
|
||||||
|
bias: bool = True,
|
||||||
|
timestep_scale: Optional[float] = None,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.mlp = nn.Sequential(
|
||||||
|
operations.Linear(frequency_embedding_size, hidden_size, bias=bias, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Linear(hidden_size, hidden_size, bias=bias, dtype=dtype, device=device),
|
||||||
|
)
|
||||||
|
self.frequency_embedding_size = frequency_embedding_size
|
||||||
|
self.timestep_scale = timestep_scale
|
||||||
|
|
||||||
|
    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        half = dim // 2
        freqs = torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
        freqs.mul_(-math.log(max_period) / half).exp_()
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
        return embedding

    def forward(self, t, out_dtype):
        if self.timestep_scale is not None:
            t = t * self.timestep_scale
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype=out_dtype)
        t_emb = self.mlp(t_freq)
        return t_emb
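timestep_embedding is the usual sinusoidal embedding: half the channels are cosines and half are sines, over frequencies that decay geometrically from 1 down to 1/max_period. A standalone version of the same computation for a few noise levels:

```python
import math
import torch

def timestep_embedding(t, dim, max_period=10000):
    # Same math as the static method above, shown on its own.
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = t[:, None].float() * freqs[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

emb = timestep_embedding(torch.tensor([0.0, 0.5, 1.0]), dim=256)
print(emb.shape)    # torch.Size([3, 256])
print(emb[0, :4])   # cosine terms are all 1 for t = 0
```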
|
|
||||||
|
|
||||||
|
class FeedForward(nn.Module):
    def __init__(
        self,
        in_features: int,
        hidden_size: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
        device: Optional[torch.device] = None,
        dtype=None,
        operations=None,
    ):
        super().__init__()
        # keep parameter count and computation constant compared to standard FFN
        hidden_size = int(2 * hidden_size / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_size = int(ffn_dim_multiplier * hidden_size)
        hidden_size = multiple_of * ((hidden_size + multiple_of - 1) // multiple_of)

        self.hidden_dim = hidden_size
        self.w1 = operations.Linear(in_features, 2 * hidden_size, bias=False, device=device, dtype=dtype)
        self.w2 = operations.Linear(hidden_size, in_features, bias=False, device=device, dtype=dtype)

    def forward(self, x):
        x, gate = self.w1(x).chunk(2, dim=-1)
        x = self.w2(F.silu(x) * gate)
        return x
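FeedForward is a SwiGLU-style MLP: w1 produces a value half and a gate half in one projection, SiLU(value) * gate is projected back by w2, and the 2/3 shrink plus multiple_of rounding in __init__ keep the parameter count close to a conventional FFN of the requested width. A tiny standalone equivalent, with plain nn.Linear standing in for the operations wrapper used in this repository:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinySwiGLU(nn.Module):
    def __init__(self, dim, hidden):
        super().__init__()
        self.w1 = nn.Linear(dim, 2 * hidden, bias=False)  # value and gate in one projection
        self.w2 = nn.Linear(hidden, dim, bias=False)

    def forward(self, x):
        v, g = self.w1(x).chunk(2, dim=-1)
        return self.w2(F.silu(v) * g)

ffn = TinySwiGLU(dim=64, hidden=171)          # 171 = int(2 * 256 / 3), before multiple_of rounding
print(ffn(torch.randn(2, 10, 64)).shape)      # torch.Size([2, 10, 64])
```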
|
|
||||||
|
|
||||||
|
class PatchEmbed(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
patch_size: int = 16,
|
||||||
|
in_chans: int = 3,
|
||||||
|
embed_dim: int = 768,
|
||||||
|
norm_layer: Optional[Callable] = None,
|
||||||
|
flatten: bool = True,
|
||||||
|
bias: bool = True,
|
||||||
|
dynamic_img_pad: bool = False,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.patch_size = to_2tuple(patch_size)
|
||||||
|
self.flatten = flatten
|
||||||
|
self.dynamic_img_pad = dynamic_img_pad
|
||||||
|
|
||||||
|
self.proj = operations.Conv2d(
|
||||||
|
in_chans,
|
||||||
|
embed_dim,
|
||||||
|
kernel_size=patch_size,
|
||||||
|
stride=patch_size,
|
||||||
|
bias=bias,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
assert norm_layer is None
|
||||||
|
self.norm = (
|
||||||
|
norm_layer(embed_dim, device=device) if norm_layer else nn.Identity()
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
B, _C, T, H, W = x.shape
|
||||||
|
if not self.dynamic_img_pad:
|
||||||
|
assert H % self.patch_size[0] == 0, f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})."
|
||||||
|
assert W % self.patch_size[1] == 0, f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
|
||||||
|
else:
|
||||||
|
pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
|
||||||
|
pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
|
||||||
|
x = F.pad(x, (0, pad_w, 0, pad_h))
|
||||||
|
|
||||||
|
x = rearrange(x, "B C T H W -> (B T) C H W", B=B, T=T)
|
||||||
|
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size, padding_mode='circular')
|
||||||
|
x = self.proj(x)
|
||||||
|
|
||||||
|
# Flatten temporal and spatial dimensions.
|
||||||
|
if not self.flatten:
|
||||||
|
raise NotImplementedError("Must flatten output.")
|
||||||
|
x = rearrange(x, "(B T) C H W -> B (T H W) C", B=B, T=T)
|
||||||
|
|
||||||
|
x = self.norm(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class RMSNorm(torch.nn.Module):
|
||||||
|
def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
|
||||||
|
super().__init__()
|
||||||
|
self.eps = eps
|
||||||
|
self.weight = torch.nn.Parameter(torch.empty(hidden_size, device=device, dtype=dtype))
|
||||||
|
self.register_parameter("bias", None)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)
|
||||||
88  comfy/ldm/genmo/joint_model/rope_mixed.py  (new file)
@@ -0,0 +1,88 @@
|
|||||||
|
#original code from https://github.com/genmoai/models under apache 2.0 license
|
||||||
|
|
||||||
|
# import functools
|
||||||
|
import math
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def centers(start: float, stop, num, dtype=None, device=None):
|
||||||
|
"""linspace through bin centers.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
start (float): Start of the range.
|
||||||
|
stop (float): End of the range.
|
||||||
|
num (int): Number of points.
|
||||||
|
dtype (torch.dtype): Data type of the points.
|
||||||
|
device (torch.device): Device of the points.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
centers (Tensor): Centers of the bins. Shape: (num,).
|
||||||
|
"""
|
||||||
|
edges = torch.linspace(start, stop, num + 1, dtype=dtype, device=device)
|
||||||
|
return (edges[:-1] + edges[1:]) / 2
|
||||||
|
|
||||||
|
|
||||||
|
# @functools.lru_cache(maxsize=1)
|
||||||
|
def create_position_matrix(
|
||||||
|
T: int,
|
||||||
|
pH: int,
|
||||||
|
pW: int,
|
||||||
|
device: torch.device,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
*,
|
||||||
|
target_area: float = 36864,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
T: int - Temporal dimension
|
||||||
|
pH: int - Height dimension after patchify
|
||||||
|
pW: int - Width dimension after patchify
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
pos: [T * pH * pW, 3] - position matrix
|
||||||
|
"""
|
||||||
|
# Create 1D tensors for each dimension
|
||||||
|
t = torch.arange(T, dtype=dtype)
|
||||||
|
|
||||||
|
# Positionally interpolate to area 36864.
|
||||||
|
# (3072x3072 frame with 16x16 patches = 192x192 latents).
|
||||||
|
# This automatically scales rope positions when the resolution changes.
|
||||||
|
# We use a large target area so the model is more sensitive
|
||||||
|
# to changes in the learned pos_frequencies matrix.
|
||||||
|
scale = math.sqrt(target_area / (pW * pH))
|
||||||
|
w = centers(-pW * scale / 2, pW * scale / 2, pW)
|
||||||
|
h = centers(-pH * scale / 2, pH * scale / 2, pH)
|
||||||
|
|
||||||
|
# Use meshgrid to create 3D grids
|
||||||
|
grid_t, grid_h, grid_w = torch.meshgrid(t, h, w, indexing="ij")
|
||||||
|
|
||||||
|
# Stack and reshape the grids.
|
||||||
|
pos = torch.stack([grid_t, grid_h, grid_w], dim=-1) # [T, pH, pW, 3]
|
||||||
|
pos = pos.view(-1, 3) # [T * pH * pW, 3]
|
||||||
|
pos = pos.to(dtype=dtype, device=device)
|
||||||
|
|
||||||
|
return pos
|
||||||
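create_position_matrix builds one (t, h, w) coordinate per latent token; the spatial axes are bin centers rescaled so that every resolution covers the same target area, which is what lets the learned rope frequencies carry over when the resolution changes. A small standalone check of that scaling, with hypothetical latent sizes:

```python
import math
import torch

def centers(start, stop, num):
    # linspace through bin centers, as in the function above
    edges = torch.linspace(start, stop, num + 1)
    return (edges[:-1] + edges[1:]) / 2

T, pH, pW, target_area = 2, 12, 16, 36864.0

# Spatial coordinates are scaled so any pH x pW latent grid spans the same area.
scale = math.sqrt(target_area / (pW * pH))
t = torch.arange(T, dtype=torch.float32)
h = centers(-pH * scale / 2, pH * scale / 2, pH)
w = centers(-pW * scale / 2, pW * scale / 2, pW)

grid_t, grid_h, grid_w = torch.meshgrid(t, h, w, indexing="ij")
pos = torch.stack([grid_t, grid_h, grid_w], dim=-1).view(-1, 3)   # [T * pH * pW, 3]

print(pos.shape)                # torch.Size([384, 3])
print(pH * scale * pW * scale)  # 36864.0: spatial extent is independent of the grid size
```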
|
|
||||||
|
|
||||||
|
def compute_mixed_rotation(
    freqs: torch.Tensor,
    pos: torch.Tensor,
):
    """
    Project each 3-dim position into per-head, per-head-dim 1D frequencies.

    Args:
        freqs: [3, num_heads, num_freqs] - learned rotation frequency (for t, row, col) for each head position
        pos: [N, 3] - position of each token

    Returns:
        freqs_cos: [N, num_heads, num_freqs] - cosine components
        freqs_sin: [N, num_heads, num_freqs] - sine components
    """
    assert freqs.ndim == 3
    freqs_sum = torch.einsum("Nd,dhf->Nhf", pos.to(freqs), freqs)
    freqs_cos = torch.cos(freqs_sum)
    freqs_sin = torch.sin(freqs_sum)
    return freqs_cos, freqs_sin
34  comfy/ldm/genmo/joint_model/temporal_rope.py  (new file)
@@ -0,0 +1,34 @@
#original code from https://github.com/genmoai/models under apache 2.0 license

# Based on Llama3 Implementation.
import torch


def apply_rotary_emb_qk_real(
    xqk: torch.Tensor,
    freqs_cos: torch.Tensor,
    freqs_sin: torch.Tensor,
) -> torch.Tensor:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor without complex numbers.

    Args:
        xqk (torch.Tensor): Query and/or Key tensors to apply rotary embeddings. Shape: (B, S, *, num_heads, D)
                            Can be either just query or just key, or both stacked along some batch or * dim.
        freqs_cos (torch.Tensor): Precomputed cosine frequency tensor.
        freqs_sin (torch.Tensor): Precomputed sine frequency tensor.

    Returns:
        torch.Tensor: The input tensor with rotary embeddings applied.
    """
    # Split the last dimension into even and odd parts
    xqk_even = xqk[..., 0::2]
    xqk_odd = xqk[..., 1::2]

    # Apply rotation
    cos_part = (xqk_even * freqs_cos - xqk_odd * freqs_sin).type_as(xqk)
    sin_part = (xqk_even * freqs_sin + xqk_odd * freqs_cos).type_as(xqk)

    # Interleave the results back into the original shape
    out = torch.stack([cos_part, sin_part], dim=-1).flatten(-2)
    return out
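Together with compute_mixed_rotation from rope_mixed.py, this is the "mixed" RoPE used by the joint blocks: each (t, h, w) position is projected onto learned per-head frequencies, and query/key channels are rotated pairwise by the resulting angles. An end-to-end sketch with hypothetical shapes, using random values in place of the learned pos_frequencies parameter:

```python
import torch

def compute_mixed_rotation(freqs, pos):
    # freqs: [3, H, Dh//2], pos: [N, 3] -> cos/sin of shape [N, H, Dh//2]
    angles = torch.einsum("Nd,dhf->Nhf", pos.to(freqs), freqs)
    return torch.cos(angles), torch.sin(angles)

N, H, Dh = 384, 4, 16
pos = torch.randn(N, 3)                    # (t, h, w) per token, e.g. from create_position_matrix
freqs = torch.randn(3, H, Dh // 2) * 0.02  # stand-in for the learned pos_frequencies

rope_cos, rope_sin = compute_mixed_rotation(freqs, pos)

q = torch.randn(1, N, H, Dh)               # (B, N, num_heads, head_dim)
q_even, q_odd = q[..., 0::2], q[..., 1::2]
q_rot = torch.stack(
    [q_even * rope_cos - q_odd * rope_sin,
     q_even * rope_sin + q_odd * rope_cos], dim=-1).flatten(-2)

# Rotation preserves per-pair norms, so the overall norm is unchanged.
print(torch.allclose(q.norm(), q_rot.norm(), rtol=1e-4))  # True
```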
102  comfy/ldm/genmo/joint_model/utils.py  (new file)
@@ -0,0 +1,102 @@
#original code from https://github.com/genmoai/models under apache 2.0 license
#adapted to ComfyUI

from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch.Tensor:
    """
    Pool tokens in x using mask.

    NOTE: We assume x does not require gradients.

    Args:
        x: (B, L, D) tensor of tokens.
        mask: (B, L) boolean tensor indicating which tokens are not padding.

    Returns:
        pooled: (B, D) tensor of pooled tokens.
    """
    assert x.size(1) == mask.size(1)  # Expected mask to have same length as tokens.
    assert x.size(0) == mask.size(0)  # Expected mask to have same batch size as tokens.
    mask = mask[:, :, None].to(dtype=x.dtype)
    mask = mask / mask.sum(dim=1, keepdim=True).clamp(min=1)
    pooled = (x * mask).sum(dim=1, keepdim=keepdim)
    return pooled
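pool_tokens is a masked mean: padding positions get weight zero and the remaining weights sum to one per sample (the clamp(min=1) only guards the all-padding case). A tiny numeric check:

```python
import torch

x = torch.tensor([[[1.0, 10.0],
                   [3.0, 30.0],
                   [99.0, 99.0]]])           # (B=1, L=3, D=2); last token is padding
mask = torch.tensor([[True, True, False]])   # (B, L)

w = mask[:, :, None].float()
w = w / w.sum(dim=1, keepdim=True).clamp(min=1)
pooled = (x * w).sum(dim=1)

print(pooled)  # tensor([[ 2., 20.]]): mean of the two real tokens, padding ignored
```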
|
|
||||||
|
|
||||||
|
class AttentionPool(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embed_dim: int,
|
||||||
|
num_heads: int,
|
||||||
|
output_dim: int = None,
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
dtype=None,
|
||||||
|
operations=None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
spatial_dim (int): Number of tokens in sequence length.
|
||||||
|
embed_dim (int): Dimensionality of input tokens.
|
||||||
|
num_heads (int): Number of attention heads.
|
||||||
|
output_dim (int): Dimensionality of output tokens. Defaults to embed_dim.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.to_kv = operations.Linear(embed_dim, 2 * embed_dim, device=device, dtype=dtype)
|
||||||
|
self.to_q = operations.Linear(embed_dim, embed_dim, device=device, dtype=dtype)
|
||||||
|
self.to_out = operations.Linear(embed_dim, output_dim or embed_dim, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
def forward(self, x, mask):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x (torch.Tensor): (B, L, D) tensor of input tokens.
|
||||||
|
mask (torch.Tensor): (B, L) boolean tensor indicating which tokens are not padding.
|
||||||
|
|
||||||
|
NOTE: We assume x does not require gradients.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
x (torch.Tensor): (B, D) tensor of pooled tokens.
|
||||||
|
"""
|
||||||
|
D = x.size(2)
|
||||||
|
|
||||||
|
# Construct attention mask, shape: (B, 1, num_queries=1, num_keys=1+L).
|
||||||
|
attn_mask = mask[:, None, None, :].bool() # (B, 1, 1, L).
|
||||||
|
attn_mask = F.pad(attn_mask, (1, 0), value=True) # (B, 1, 1, 1+L).
|
||||||
|
|
||||||
|
# Average non-padding token features. These will be used as the query.
|
||||||
|
x_pool = pool_tokens(x, mask, keepdim=True) # (B, 1, D)
|
||||||
|
|
||||||
|
# Concat pooled features to input sequence.
|
||||||
|
x = torch.cat([x_pool, x], dim=1) # (B, L+1, D)
|
||||||
|
|
||||||
|
# Compute queries, keys, values. Only the mean token is used to create a query.
|
||||||
|
kv = self.to_kv(x) # (B, L+1, 2 * D)
|
||||||
|
q = self.to_q(x[:, 0]) # (B, D)
|
||||||
|
|
||||||
|
# Extract heads.
|
||||||
|
head_dim = D // self.num_heads
|
||||||
|
kv = kv.unflatten(2, (2, self.num_heads, head_dim)) # (B, 1+L, 2, H, head_dim)
|
||||||
|
kv = kv.transpose(1, 3) # (B, H, 2, 1+L, head_dim)
|
||||||
|
k, v = kv.unbind(2) # (B, H, 1+L, head_dim)
|
||||||
|
q = q.unflatten(1, (self.num_heads, head_dim)) # (B, H, head_dim)
|
||||||
|
q = q.unsqueeze(2) # (B, H, 1, head_dim)
|
||||||
|
|
||||||
|
# Compute attention.
|
||||||
|
x = F.scaled_dot_product_attention(
|
||||||
|
q, k, v, attn_mask=attn_mask, dropout_p=0.0
|
||||||
|
) # (B, H, 1, head_dim)
|
||||||
|
|
||||||
|
# Concatenate heads and run output.
|
||||||
|
x = x.squeeze(2).flatten(1, 2) # (B, D = H * head_dim)
|
||||||
|
x = self.to_out(x)
|
||||||
|
return x
|
||||||
711  comfy/ldm/genmo/vae/model.py  (new file)
@@ -0,0 +1,711 @@
|
|||||||
|
#original code from https://github.com/genmoai/models under apache 2.0 license
|
||||||
|
#adapted to ComfyUI
|
||||||
|
|
||||||
|
from typing import Callable, List, Optional, Tuple, Union
|
||||||
|
from functools import partial
|
||||||
|
import math
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
from comfy.ldm.modules.attention import optimized_attention
|
||||||
|
|
||||||
|
import comfy.ops
|
||||||
|
ops = comfy.ops.disable_weight_init
|
||||||
|
|
||||||
|
# import mochi_preview.dit.joint_model.context_parallel as cp
|
||||||
|
# from mochi_preview.vae.cp_conv import cp_pass_frames, gather_all_frames
|
||||||
|
|
||||||
|
|
||||||
|
def cast_tuple(t, length=1):
|
||||||
|
return t if isinstance(t, tuple) else ((t,) * length)
|
||||||
|
|
||||||
|
|
||||||
|
class GroupNormSpatial(ops.GroupNorm):
|
||||||
|
"""
|
||||||
|
GroupNorm applied per-frame.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, *, chunk_size: int = 8):
|
||||||
|
B, C, T, H, W = x.shape
|
||||||
|
x = rearrange(x, "B C T H W -> (B T) C H W")
|
||||||
|
# Run group norm in chunks.
|
||||||
|
output = torch.empty_like(x)
|
||||||
|
for b in range(0, B * T, chunk_size):
|
||||||
|
output[b : b + chunk_size] = super().forward(x[b : b + chunk_size])
|
||||||
|
return rearrange(output, "(B T) C H W -> B C T H W", B=B, T=T)
|
||||||
|
|
||||||
|
class PConv3d(ops.Conv3d):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size: Union[int, Tuple[int, int, int]],
|
||||||
|
stride: Union[int, Tuple[int, int, int]],
|
||||||
|
causal: bool = True,
|
||||||
|
context_parallel: bool = True,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.causal = causal
|
||||||
|
self.context_parallel = context_parallel
|
||||||
|
kernel_size = cast_tuple(kernel_size, 3)
|
||||||
|
stride = cast_tuple(stride, 3)
|
||||||
|
height_pad = (kernel_size[1] - 1) // 2
|
||||||
|
width_pad = (kernel_size[2] - 1) // 2
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
dilation=(1, 1, 1),
|
||||||
|
padding=(0, height_pad, width_pad),
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
# Compute padding amounts.
|
||||||
|
context_size = self.kernel_size[0] - 1
|
||||||
|
if self.causal:
|
||||||
|
pad_front = context_size
|
||||||
|
pad_back = 0
|
||||||
|
else:
|
||||||
|
pad_front = context_size // 2
|
||||||
|
pad_back = context_size - pad_front
|
||||||
|
|
||||||
|
# Apply padding.
|
||||||
|
assert self.padding_mode == "replicate" # DEBUG
|
||||||
|
mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
|
||||||
|
x = F.pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode)
|
||||||
|
return super().forward(x)
|
||||||
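PConv3d lets the underlying Conv3d handle spatial padding and pads the time axis itself: with causal=True all of the kernel_size[0] - 1 extra frames are replicated at the front, so an output frame never depends on later input frames. A sketch of just that padding decision, with illustrative sizes:

```python
import torch
import torch.nn.functional as F

kT = 3                  # temporal kernel size
context = kT - 1

def pad_temporal(x, causal: bool):
    # x: (B, C, T, H, W); only the T axis is padded here.
    if causal:
        pad_front, pad_back = context, 0
    else:
        pad_front = context // 2
        pad_back = context - pad_front
    return F.pad(x, (0, 0, 0, 0, pad_front, pad_back), mode="replicate")

x = torch.randn(1, 4, 5, 8, 8)
print(pad_temporal(x, causal=True).shape)   # torch.Size([1, 4, 7, 8, 8]): 2 frames added at the front
print(pad_temporal(x, causal=False).shape)  # torch.Size([1, 4, 7, 8, 8]): 1 frame front, 1 back
```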
|
|
||||||
|
|
||||||
|
class Conv1x1(ops.Linear):
|
||||||
|
"""*1x1 Conv implemented with a linear layer."""
|
||||||
|
|
||||||
|
def __init__(self, in_features: int, out_features: int, *args, **kwargs):
|
||||||
|
super().__init__(in_features, out_features, *args, **kwargs)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
"""Forward pass.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input tensor. Shape: [B, C, *] or [B, *, C].
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
x: Output tensor. Shape: [B, C', *] or [B, *, C'].
|
||||||
|
"""
|
||||||
|
x = x.movedim(1, -1)
|
||||||
|
x = super().forward(x)
|
||||||
|
x = x.movedim(-1, 1)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class DepthToSpaceTime(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
temporal_expansion: int,
|
||||||
|
spatial_expansion: int,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.temporal_expansion = temporal_expansion
|
||||||
|
self.spatial_expansion = spatial_expansion
|
||||||
|
|
||||||
|
# When printed, this module should show the temporal and spatial expansion factors.
|
||||||
|
def extra_repr(self):
|
||||||
|
return f"texp={self.temporal_expansion}, sexp={self.spatial_expansion}"
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
"""Forward pass.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input tensor. Shape: [B, C, T, H, W].
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
x: Rearranged tensor. Shape: [B, C/(st*s*s), T*st, H*s, W*s].
|
||||||
|
"""
|
||||||
|
x = rearrange(
|
||||||
|
x,
|
||||||
|
"B (C st sh sw) T H W -> B C (T st) (H sh) (W sw)",
|
||||||
|
st=self.temporal_expansion,
|
||||||
|
sh=self.spatial_expansion,
|
||||||
|
sw=self.spatial_expansion,
|
||||||
|
)
|
||||||
|
|
||||||
|
# cp_rank, _ = cp.get_cp_rank_size()
|
||||||
|
if self.temporal_expansion > 1: # and cp_rank == 0:
|
||||||
|
# Drop the first self.temporal_expansion - 1 frames.
|
||||||
|
# This is because we always want the 3x3x3 conv filter to only apply
|
||||||
|
# to the first frame, and the first frame doesn't need to be repeated.
|
||||||
|
assert all(x.shape)
|
||||||
|
x = x[:, :, self.temporal_expansion - 1 :]
|
||||||
|
assert all(x.shape)
|
||||||
|
|
||||||
|
return x
|
||||||
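DepthToSpaceTime is pixel-shuffle generalized to video: channels are folded into the time and space axes by the expansion factors, and when temporal_expansion > 1 the first temporal_expansion - 1 frames are dropped so the first latent frame decodes to exactly one output frame. A shape sketch with small made-up sizes:

```python
import torch
from einops import rearrange

st, s = 2, 2                                   # temporal and spatial expansion
x = torch.randn(1, 16 * st * s * s, 3, 4, 4)   # (B, C*st*s*s, T, H, W)

y = rearrange(x, "B (C st sh sw) T H W -> B C (T st) (H sh) (W sw)", st=st, sh=s, sw=s)
print(y.shape)         # torch.Size([1, 16, 6, 8, 8])

y = y[:, :, st - 1:]   # drop the first st-1 frames (causal decoder convention)
print(y.shape)         # torch.Size([1, 16, 5, 8, 8])
```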
|
|
||||||
|
|
||||||
|
def norm_fn(
|
||||||
|
in_channels: int,
|
||||||
|
affine: bool = True,
|
||||||
|
):
|
||||||
|
return GroupNormSpatial(affine=affine, num_groups=32, num_channels=in_channels)
|
||||||
|
|
||||||
|
|
||||||
|
class ResBlock(nn.Module):
|
||||||
|
"""Residual block that preserves the spatial dimensions."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
channels: int,
|
||||||
|
*,
|
||||||
|
affine: bool = True,
|
||||||
|
attn_block: Optional[nn.Module] = None,
|
||||||
|
causal: bool = True,
|
||||||
|
prune_bottleneck: bool = False,
|
||||||
|
padding_mode: str,
|
||||||
|
bias: bool = True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.channels = channels
|
||||||
|
|
||||||
|
assert causal
|
||||||
|
self.stack = nn.Sequential(
|
||||||
|
norm_fn(channels, affine=affine),
|
||||||
|
nn.SiLU(inplace=True),
|
||||||
|
PConv3d(
|
||||||
|
in_channels=channels,
|
||||||
|
out_channels=channels // 2 if prune_bottleneck else channels,
|
||||||
|
kernel_size=(3, 3, 3),
|
||||||
|
stride=(1, 1, 1),
|
||||||
|
padding_mode=padding_mode,
|
||||||
|
bias=bias,
|
||||||
|
causal=causal,
|
||||||
|
),
|
||||||
|
norm_fn(channels, affine=affine),
|
||||||
|
nn.SiLU(inplace=True),
|
||||||
|
PConv3d(
|
||||||
|
in_channels=channels // 2 if prune_bottleneck else channels,
|
||||||
|
out_channels=channels,
|
||||||
|
kernel_size=(3, 3, 3),
|
||||||
|
stride=(1, 1, 1),
|
||||||
|
padding_mode=padding_mode,
|
||||||
|
bias=bias,
|
||||||
|
causal=causal,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.attn_block = attn_block if attn_block else nn.Identity()
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
"""Forward pass.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input tensor. Shape: [B, C, T, H, W].
|
||||||
|
"""
|
||||||
|
residual = x
|
||||||
|
x = self.stack(x)
|
||||||
|
x = x + residual
|
||||||
|
del residual
|
||||||
|
|
||||||
|
return self.attn_block(x)
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
head_dim: int = 32,
|
||||||
|
qkv_bias: bool = False,
|
||||||
|
out_bias: bool = True,
|
||||||
|
qk_norm: bool = True,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.head_dim = head_dim
|
||||||
|
self.num_heads = dim // head_dim
|
||||||
|
self.qk_norm = qk_norm
|
||||||
|
|
||||||
|
self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)
|
||||||
|
self.out = nn.Linear(dim, dim, bias=out_bias)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Compute temporal self-attention.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input tensor. Shape: [B, C, T, H, W].
|
||||||
|
chunk_size: Chunk size for large tensors.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
x: Output tensor. Shape: [B, C, T, H, W].
|
||||||
|
"""
|
||||||
|
B, _, T, H, W = x.shape
|
||||||
|
|
||||||
|
if T == 1:
|
||||||
|
# No attention for single frame.
|
||||||
|
x = x.movedim(1, -1) # [B, C, T, H, W] -> [B, T, H, W, C]
|
||||||
|
qkv = self.qkv(x)
|
||||||
|
_, _, x = qkv.chunk(3, dim=-1) # Throw away queries and keys.
|
||||||
|
x = self.out(x)
|
||||||
|
return x.movedim(-1, 1) # [B, T, H, W, C] -> [B, C, T, H, W]
|
||||||
|
|
||||||
|
# 1D temporal attention.
|
||||||
|
x = rearrange(x, "B C t h w -> (B h w) t C")
|
||||||
|
qkv = self.qkv(x)
|
||||||
|
|
||||||
|
# Input: qkv with shape [B, t, 3 * num_heads * head_dim]
|
||||||
|
# Output: x with shape [B, num_heads, t, head_dim]
|
||||||
|
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, self.head_dim).transpose(1, 3).unbind(2)
|
||||||
|
|
||||||
|
if self.qk_norm:
|
||||||
|
q = F.normalize(q, p=2, dim=-1)
|
||||||
|
k = F.normalize(k, p=2, dim=-1)
|
||||||
|
|
||||||
|
x = optimized_attention(q, k, v, self.num_heads, skip_reshape=True)
|
||||||
|
|
||||||
|
assert x.size(0) == q.size(0)
|
||||||
|
|
||||||
|
x = self.out(x)
|
||||||
|
x = rearrange(x, "(B h w) t C -> B C t h w", B=B, h=H, w=W)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
**attn_kwargs,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.norm = norm_fn(dim)
|
||||||
|
self.attn = Attention(dim, **attn_kwargs)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
return x + self.attn(self.norm(x))
|
||||||
|
|
||||||
|
|
||||||
|
class CausalUpsampleBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
num_res_blocks: int,
|
||||||
|
*,
|
||||||
|
temporal_expansion: int = 2,
|
||||||
|
spatial_expansion: int = 2,
|
||||||
|
**block_kwargs,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
blocks = []
|
||||||
|
for _ in range(num_res_blocks):
|
||||||
|
blocks.append(block_fn(in_channels, **block_kwargs))
|
||||||
|
self.blocks = nn.Sequential(*blocks)
|
||||||
|
|
||||||
|
self.temporal_expansion = temporal_expansion
|
||||||
|
self.spatial_expansion = spatial_expansion
|
||||||
|
|
||||||
|
# Change channels in the final convolution layer.
|
||||||
|
self.proj = Conv1x1(
|
||||||
|
in_channels,
|
||||||
|
out_channels * temporal_expansion * (spatial_expansion**2),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.d2st = DepthToSpaceTime(
|
||||||
|
temporal_expansion=temporal_expansion, spatial_expansion=spatial_expansion
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.blocks(x)
|
||||||
|
x = self.proj(x)
|
||||||
|
x = self.d2st(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def block_fn(channels, *, affine: bool = True, has_attention: bool = False, **block_kwargs):
|
||||||
|
attn_block = AttentionBlock(channels) if has_attention else None
|
||||||
|
return ResBlock(channels, affine=affine, attn_block=attn_block, **block_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class DownsampleBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
num_res_blocks,
|
||||||
|
*,
|
||||||
|
temporal_reduction=2,
|
||||||
|
spatial_reduction=2,
|
||||||
|
**block_kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Downsample block for the VAE encoder.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels: Number of input channels.
|
||||||
|
out_channels: Number of output channels.
|
||||||
|
num_res_blocks: Number of residual blocks.
|
||||||
|
temporal_reduction: Temporal reduction factor.
|
||||||
|
spatial_reduction: Spatial reduction factor.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
layers = []
|
||||||
|
|
||||||
|
# Change the channel count in the strided convolution.
|
||||||
|
# This lets the ResBlock have uniform channel count,
|
||||||
|
# as in ConvNeXt.
|
||||||
|
assert in_channels != out_channels
|
||||||
|
layers.append(
|
||||||
|
PConv3d(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=(temporal_reduction, spatial_reduction, spatial_reduction),
|
||||||
|
stride=(temporal_reduction, spatial_reduction, spatial_reduction),
|
||||||
|
# First layer in each block always uses replicate padding
|
||||||
|
padding_mode="replicate",
|
||||||
|
bias=block_kwargs["bias"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
for _ in range(num_res_blocks):
|
||||||
|
layers.append(block_fn(out_channels, **block_kwargs))
|
||||||
|
|
||||||
|
self.layers = nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.layers(x)
|
||||||
|
|
||||||
|
|
||||||
|
def add_fourier_features(inputs: torch.Tensor, start=6, stop=8, step=1):
|
||||||
|
num_freqs = (stop - start) // step
|
||||||
|
assert inputs.ndim == 5
|
||||||
|
C = inputs.size(1)
|
||||||
|
|
||||||
|
# Create Base 2 Fourier features.
|
||||||
|
freqs = torch.arange(start, stop, step, dtype=inputs.dtype, device=inputs.device)
|
||||||
|
assert num_freqs == len(freqs)
|
||||||
|
w = torch.pow(2.0, freqs) * (2 * torch.pi) # [num_freqs]
|
||||||
|
C = inputs.shape[1]
|
||||||
|
w = w.repeat(C)[None, :, None, None, None] # [1, C * num_freqs, 1, 1, 1]
|
||||||
|
|
||||||
|
# Interleaved repeat of input channels to match w.
|
||||||
|
h = inputs.repeat_interleave(num_freqs, dim=1) # [B, C * num_freqs, T, H, W]
|
||||||
|
# Scale channels by frequency.
|
||||||
|
h = w * h
|
||||||
|
|
||||||
|
return torch.cat(
|
||||||
|
[
|
||||||
|
inputs,
|
||||||
|
torch.sin(h),
|
||||||
|
torch.cos(h),
|
||||||
|
],
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
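add_fourier_features concatenates the raw input with sin and cos of each channel at frequencies 2^6 and 2^7 (start=6, stop=8, step=1 gives two frequencies), so C input channels become (1 + 2 * num_freqs) * C; that is why the VideoVAE encoder below is built with in_channels=15 for 3-channel video. A quick standalone check of the same arithmetic:

```python
import torch

def fourier_channels(C, start=6, stop=8, step=1):
    num_freqs = (stop - start) // step
    return (1 + 2 * num_freqs) * C

print(fourier_channels(3))   # 15, matching Encoder(in_channels=15)

x = torch.randn(1, 3, 4, 8, 8)
freqs = torch.arange(6, 8, 1, dtype=x.dtype)
w = (2.0 ** freqs) * (2 * torch.pi)
w = w.repeat(3)[None, :, None, None, None]
h = x.repeat_interleave(2, dim=1) * w        # each channel paired with each frequency
out = torch.cat([x, torch.sin(h), torch.cos(h)], dim=1)
print(out.shape)             # torch.Size([1, 15, 4, 8, 8])
```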
|
|
||||||
|
|
||||||
|
class FourierFeatures(nn.Module):
|
||||||
|
def __init__(self, start: int = 6, stop: int = 8, step: int = 1):
|
||||||
|
super().__init__()
|
||||||
|
self.start = start
|
||||||
|
self.stop = stop
|
||||||
|
self.step = step
|
||||||
|
|
||||||
|
def forward(self, inputs):
|
||||||
|
"""Add Fourier features to inputs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs: Input tensor. Shape: [B, C, T, H, W]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
h: Output tensor. Shape: [B, (1 + 2 * num_freqs) * C, T, H, W]
|
||||||
|
"""
|
||||||
|
return add_fourier_features(inputs, self.start, self.stop, self.step)
|
||||||
|
|
||||||
|
|
||||||
|
class Decoder(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
out_channels: int = 3,
|
||||||
|
latent_dim: int,
|
||||||
|
base_channels: int,
|
||||||
|
channel_multipliers: List[int],
|
||||||
|
num_res_blocks: List[int],
|
||||||
|
temporal_expansions: Optional[List[int]] = None,
|
||||||
|
spatial_expansions: Optional[List[int]] = None,
|
||||||
|
has_attention: List[bool],
|
||||||
|
output_norm: bool = True,
|
||||||
|
nonlinearity: str = "silu",
|
||||||
|
output_nonlinearity: str = "silu",
|
||||||
|
causal: bool = True,
|
||||||
|
**block_kwargs,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.input_channels = latent_dim
|
||||||
|
self.base_channels = base_channels
|
||||||
|
self.channel_multipliers = channel_multipliers
|
||||||
|
self.num_res_blocks = num_res_blocks
|
||||||
|
self.output_nonlinearity = output_nonlinearity
|
||||||
|
assert nonlinearity == "silu"
|
||||||
|
assert causal
|
||||||
|
|
||||||
|
ch = [mult * base_channels for mult in channel_multipliers]
|
||||||
|
self.num_up_blocks = len(ch) - 1
|
||||||
|
assert len(num_res_blocks) == self.num_up_blocks + 2
|
||||||
|
|
||||||
|
blocks = []
|
||||||
|
|
||||||
|
first_block = [
|
||||||
|
ops.Conv3d(latent_dim, ch[-1], kernel_size=(1, 1, 1))
|
||||||
|
] # Input layer.
|
||||||
|
# First set of blocks preserve channel count.
|
||||||
|
for _ in range(num_res_blocks[-1]):
|
||||||
|
first_block.append(
|
||||||
|
block_fn(
|
||||||
|
ch[-1],
|
||||||
|
has_attention=has_attention[-1],
|
||||||
|
causal=causal,
|
||||||
|
**block_kwargs,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
blocks.append(nn.Sequential(*first_block))
|
||||||
|
|
||||||
|
assert len(temporal_expansions) == len(spatial_expansions) == self.num_up_blocks
|
||||||
|
assert len(num_res_blocks) == len(has_attention) == self.num_up_blocks + 2
|
||||||
|
|
||||||
|
upsample_block_fn = CausalUpsampleBlock
|
||||||
|
|
||||||
|
for i in range(self.num_up_blocks):
|
||||||
|
block = upsample_block_fn(
|
||||||
|
ch[-i - 1],
|
||||||
|
ch[-i - 2],
|
||||||
|
num_res_blocks=num_res_blocks[-i - 2],
|
||||||
|
has_attention=has_attention[-i - 2],
|
||||||
|
temporal_expansion=temporal_expansions[-i - 1],
|
||||||
|
spatial_expansion=spatial_expansions[-i - 1],
|
||||||
|
causal=causal,
|
||||||
|
**block_kwargs,
|
||||||
|
)
|
||||||
|
blocks.append(block)
|
||||||
|
|
||||||
|
assert not output_norm
|
||||||
|
|
||||||
|
# Last block. Preserve channel count.
|
||||||
|
last_block = []
|
||||||
|
for _ in range(num_res_blocks[0]):
|
||||||
|
last_block.append(
|
||||||
|
block_fn(
|
||||||
|
ch[0], has_attention=has_attention[0], causal=causal, **block_kwargs
|
||||||
|
)
|
||||||
|
)
|
||||||
|
blocks.append(nn.Sequential(*last_block))
|
||||||
|
|
||||||
|
self.blocks = nn.ModuleList(blocks)
|
||||||
|
self.output_proj = Conv1x1(ch[0], out_channels)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""Forward pass.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Latent tensor. Shape: [B, input_channels, t, h, w]. Scaled [-1, 1].
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
x: Reconstructed video tensor. Shape: [B, C, T, H, W]. Scaled to [-1, 1].
|
||||||
|
T + 1 = (t - 1) * 4.
|
||||||
|
H = h * 16, W = w * 16.
|
||||||
|
"""
|
||||||
|
for block in self.blocks:
|
||||||
|
x = block(x)
|
||||||
|
|
||||||
|
if self.output_nonlinearity == "silu":
|
||||||
|
x = F.silu(x, inplace=not self.training)
|
||||||
|
else:
|
||||||
|
assert (
|
||||||
|
not self.output_nonlinearity
|
||||||
|
) # StyleGAN3 omits the to-RGB nonlinearity.
|
||||||
|
|
||||||
|
return self.output_proj(x).contiguous()
|
||||||
|
|
||||||
|
class LatentDistribution:
|
||||||
|
def __init__(self, mean: torch.Tensor, logvar: torch.Tensor):
|
||||||
|
"""Initialize latent distribution.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mean: Mean of the distribution. Shape: [B, C, T, H, W].
|
||||||
|
logvar: Logarithm of variance of the distribution. Shape: [B, C, T, H, W].
|
||||||
|
"""
|
||||||
|
assert mean.shape == logvar.shape
|
||||||
|
self.mean = mean
|
||||||
|
self.logvar = logvar
|
||||||
|
|
||||||
|
def sample(self, temperature=1.0, generator: torch.Generator = None, noise=None):
|
||||||
|
if temperature == 0.0:
|
||||||
|
return self.mean
|
||||||
|
|
||||||
|
if noise is None:
|
||||||
|
noise = torch.randn(self.mean.shape, device=self.mean.device, dtype=self.mean.dtype, generator=generator)
|
||||||
|
else:
|
||||||
|
assert noise.device == self.mean.device
|
||||||
|
noise = noise.to(self.mean.dtype)
|
||||||
|
|
||||||
|
if temperature != 1.0:
|
||||||
|
raise NotImplementedError(f"Temperature {temperature} is not supported.")
|
||||||
|
|
||||||
|
# Just Gaussian sample with no scaling of variance.
|
||||||
|
return noise * torch.exp(self.logvar * 0.5) + self.mean
|
||||||
|
|
||||||
|
def mode(self):
|
||||||
|
return self.mean
|
||||||
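LatentDistribution.sample is the standard VAE reparameterization: unit Gaussian noise is scaled by exp(0.5 * logvar) and shifted by the mean, temperature 0.0 short-circuits to the mean (which is also what mode() and VideoVAE.encode use), and any temperature other than 1.0 is rejected. A minimal sketch of the sampling step:

```python
import torch

mean = torch.zeros(1, 12, 2, 4, 4)
logvar = torch.full_like(mean, -2.0)   # std = exp(-1), roughly 0.37

noise = torch.randn_like(mean)
sample = noise * torch.exp(0.5 * logvar) + mean

print(sample.std())   # close to 0.37
# temperature == 0.0 would simply return `mean` instead of sampling.
```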
|
|
||||||
|
class Encoder(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
in_channels: int,
|
||||||
|
base_channels: int,
|
||||||
|
channel_multipliers: List[int],
|
||||||
|
num_res_blocks: List[int],
|
||||||
|
latent_dim: int,
|
||||||
|
temporal_reductions: List[int],
|
||||||
|
spatial_reductions: List[int],
|
||||||
|
prune_bottlenecks: List[bool],
|
||||||
|
has_attentions: List[bool],
|
||||||
|
affine: bool = True,
|
||||||
|
bias: bool = True,
|
||||||
|
input_is_conv_1x1: bool = False,
|
||||||
|
padding_mode: str,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.temporal_reductions = temporal_reductions
|
||||||
|
self.spatial_reductions = spatial_reductions
|
||||||
|
self.base_channels = base_channels
|
||||||
|
self.channel_multipliers = channel_multipliers
|
||||||
|
self.num_res_blocks = num_res_blocks
|
||||||
|
self.latent_dim = latent_dim
|
||||||
|
|
||||||
|
self.fourier_features = FourierFeatures()
|
||||||
|
ch = [mult * base_channels for mult in channel_multipliers]
|
||||||
|
num_down_blocks = len(ch) - 1
|
||||||
|
assert len(num_res_blocks) == num_down_blocks + 2
|
||||||
|
|
||||||
|
layers = (
|
||||||
|
[ops.Conv3d(in_channels, ch[0], kernel_size=(1, 1, 1), bias=True)]
|
||||||
|
if not input_is_conv_1x1
|
||||||
|
else [Conv1x1(in_channels, ch[0])]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(prune_bottlenecks) == num_down_blocks + 2
|
||||||
|
assert len(has_attentions) == num_down_blocks + 2
|
||||||
|
block = partial(block_fn, padding_mode=padding_mode, affine=affine, bias=bias)
|
||||||
|
|
||||||
|
for _ in range(num_res_blocks[0]):
|
||||||
|
layers.append(block(ch[0], has_attention=has_attentions[0], prune_bottleneck=prune_bottlenecks[0]))
|
||||||
|
prune_bottlenecks = prune_bottlenecks[1:]
|
||||||
|
has_attentions = has_attentions[1:]
|
||||||
|
|
||||||
|
assert len(temporal_reductions) == len(spatial_reductions) == len(ch) - 1
|
||||||
|
for i in range(num_down_blocks):
|
||||||
|
layer = DownsampleBlock(
|
||||||
|
ch[i],
|
||||||
|
ch[i + 1],
|
||||||
|
num_res_blocks=num_res_blocks[i + 1],
|
||||||
|
temporal_reduction=temporal_reductions[i],
|
||||||
|
spatial_reduction=spatial_reductions[i],
|
||||||
|
prune_bottleneck=prune_bottlenecks[i],
|
||||||
|
has_attention=has_attentions[i],
|
||||||
|
affine=affine,
|
||||||
|
bias=bias,
|
||||||
|
padding_mode=padding_mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
layers.append(layer)
|
||||||
|
|
||||||
|
# Additional blocks.
|
||||||
|
for _ in range(num_res_blocks[-1]):
|
||||||
|
layers.append(block(ch[-1], has_attention=has_attentions[-1], prune_bottleneck=prune_bottlenecks[-1]))
|
||||||
|
|
||||||
|
self.layers = nn.Sequential(*layers)
|
||||||
|
|
||||||
|
# Output layers.
|
||||||
|
self.output_norm = norm_fn(ch[-1])
|
||||||
|
self.output_proj = Conv1x1(ch[-1], 2 * latent_dim, bias=False)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def temporal_downsample(self):
|
||||||
|
return math.prod(self.temporal_reductions)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def spatial_downsample(self):
|
||||||
|
return math.prod(self.spatial_reductions)
|
||||||
|
|
||||||
|
def forward(self, x) -> LatentDistribution:
|
||||||
|
"""Forward pass.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input video tensor. Shape: [B, C, T, H, W]. Scaled to [-1, 1]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
means: Latent tensor. Shape: [B, latent_dim, t, h, w]. Scaled [-1, 1].
|
||||||
|
h = H // 8, w = W // 8, t - 1 = (T - 1) // 6
|
||||||
|
logvar: Shape: [B, latent_dim, t, h, w].
|
||||||
|
"""
|
||||||
|
assert x.ndim == 5, f"Expected 5D input, got {x.shape}"
|
||||||
|
x = self.fourier_features(x)
|
||||||
|
|
||||||
|
x = self.layers(x)
|
||||||
|
|
||||||
|
x = self.output_norm(x)
|
||||||
|
x = F.silu(x, inplace=True)
|
||||||
|
x = self.output_proj(x)
|
||||||
|
|
||||||
|
means, logvar = torch.chunk(x, 2, dim=1)
|
||||||
|
|
||||||
|
assert means.ndim == 5
|
||||||
|
assert logvar.shape == means.shape
|
||||||
|
assert means.size(1) == self.latent_dim
|
||||||
|
|
||||||
|
return LatentDistribution(means, logvar)
|
||||||
|
|
||||||
|
|
||||||
|
class VideoVAE(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.encoder = Encoder(
|
||||||
|
in_channels=15,
|
||||||
|
base_channels=64,
|
||||||
|
channel_multipliers=[1, 2, 4, 6],
|
||||||
|
num_res_blocks=[3, 3, 4, 6, 3],
|
||||||
|
latent_dim=12,
|
||||||
|
temporal_reductions=[1, 2, 3],
|
||||||
|
spatial_reductions=[2, 2, 2],
|
||||||
|
prune_bottlenecks=[False, False, False, False, False],
|
||||||
|
has_attentions=[False, True, True, True, True],
|
||||||
|
affine=True,
|
||||||
|
bias=True,
|
||||||
|
input_is_conv_1x1=True,
|
||||||
|
padding_mode="replicate"
|
||||||
|
)
|
||||||
|
self.decoder = Decoder(
|
||||||
|
out_channels=3,
|
||||||
|
base_channels=128,
|
||||||
|
channel_multipliers=[1, 2, 4, 6],
|
||||||
|
temporal_expansions=[1, 2, 3],
|
||||||
|
spatial_expansions=[2, 2, 2],
|
||||||
|
num_res_blocks=[3, 3, 4, 6, 3],
|
||||||
|
latent_dim=12,
|
||||||
|
has_attention=[False, False, False, False, False],
|
||||||
|
padding_mode="replicate",
|
||||||
|
output_norm=False,
|
||||||
|
nonlinearity="silu",
|
||||||
|
output_nonlinearity="silu",
|
||||||
|
causal=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def encode(self, x):
|
||||||
|
return self.encoder(x).mode()
|
||||||
|
|
||||||
|
def decode(self, x):
|
||||||
|
return self.decoder(x)
|
||||||
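A minimal usage sketch for the encoder/decoder pair above; the variable names, the clip shape, and the assumption that FourierFeatures expands the 3 input channels to the 15 expected by the first convolution are illustrative, not part of the diff.

# Illustrative only: encode a [B, C, T, H, W] video tensor and decode it back.
vae = VideoVAE()
video = torch.randn(1, 3, 25, 256, 256)   # hypothetical clip scaled to [-1, 1]
latents = vae.encode(video)               # [1, 12, t, h, w], mode of the latent distribution
recon = vae.decode(latents)               # reconstructed video tensor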
@@ -372,7 +372,7 @@ class HunYuanDiT(nn.Module):
         for layer, block in enumerate(self.blocks):
             if layer > self.depth // 2:
                 if controls is not None:
-                    skip = skips.pop() + controls.pop()
+                    skip = skips.pop() + controls.pop().to(dtype=x.dtype)
                 else:
                     skip = skips.pop()
             x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)
@@ -393,6 +393,13 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh

     return out

+if model_management.is_nvidia(): #pytorch 2.3 and up seem to have this issue.
+    SDP_BATCH_LIMIT = 2**15
+else:
+    #TODO: other GPUs ?
+    SDP_BATCH_LIMIT = 2**31
+
+
 def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
     if skip_reshape:
         b, _, _, dim_head = q.shape

@@ -404,10 +411,15 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
         (q, k, v),
     )

-    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
-    out = (
-        out.transpose(1, 2).reshape(b, -1, heads * dim_head)
-    )
+    if SDP_BATCH_LIMIT >= q.shape[0]:
+        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
+        out = (
+            out.transpose(1, 2).reshape(b, -1, heads * dim_head)
+        )
+    else:
+        out = torch.empty((q.shape[0], q.shape[2], heads * dim_head), dtype=q.dtype, layout=q.layout, device=q.device)
+        for i in range(0, q.shape[0], SDP_BATCH_LIMIT):
+            out[i : i + SDP_BATCH_LIMIT] = torch.nn.functional.scaled_dot_product_attention(q[i : i + SDP_BATCH_LIMIT], k[i : i + SDP_BATCH_LIMIT], v[i : i + SDP_BATCH_LIMIT], attn_mask=mask, dropout_p=0.0, is_causal=False).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
     return out
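The hunk above splits oversized batches before calling scaled_dot_product_attention. A self-contained sketch of the same chunking pattern, assuming q/k/v are already shaped [B, heads, L, dim_head]; the batch limit and helper name are illustrative:

# Sketch of batch-chunked SDPA; constants and names are assumptions, not ComfyUI API.
import torch
import torch.nn.functional as F

def chunked_sdpa(q, k, v, batch_limit=2**15):
    if q.shape[0] <= batch_limit:
        return F.scaled_dot_product_attention(q, k, v)
    out = torch.empty_like(q)
    for i in range(0, q.shape[0], batch_limit):
        # process a slice of the batch at a time to avoid the large-batch issue
        out[i:i + batch_limit] = F.scaled_dot_product_attention(
            q[i:i + batch_limit], k[i:i + batch_limit], v[i:i + batch_limit])
    return out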
@@ -1,11 +1,11 @@
 import logging
 import math
-from typing import Dict, Optional
+from typing import Dict, Optional, List

 import numpy as np
 import torch
 import torch.nn as nn
-from .. import attention
+from ..attention import optimized_attention
 from einops import rearrange, repeat
 from .util import timestep_embedding
 import comfy.ops

@@ -97,7 +97,7 @@ class PatchEmbed(nn.Module):
         self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

     def forward(self, x):
-        B, C, H, W = x.shape
+        # B, C, H, W = x.shape
         # if self.img_size is not None:
         #     if self.strict_img_size:
         #         _assert(H == self.img_size[0], f"Input height ({H}) doesn't match model ({self.img_size[0]}).")

@@ -266,8 +266,6 @@ def split_qkv(qkv, head_dim):
     qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, -1, head_dim).movedim(2, 0)
     return qkv[0], qkv[1], qkv[2]

-def optimized_attention(qkv, num_heads):
-    return attention.optimized_attention(qkv[0], qkv[1], qkv[2], num_heads)
-
 class SelfAttention(nn.Module):
     ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")

@@ -326,9 +324,9 @@ class SelfAttention(nn.Module):
         return x

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        qkv = self.pre_attention(x)
+        q, k, v = self.pre_attention(x)
         x = optimized_attention(
-            qkv, num_heads=self.num_heads
+            q, k, v, heads=self.num_heads
         )
         x = self.post_attention(x)
         return x

@@ -355,29 +353,9 @@ class RMSNorm(torch.nn.Module):
         else:
             self.register_parameter("weight", None)

-    def _norm(self, x):
-        """
-        Apply the RMSNorm normalization to the input tensor.
-        Args:
-            x (torch.Tensor): The input tensor.
-        Returns:
-            torch.Tensor: The normalized tensor.
-        """
-        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
-
     def forward(self, x):
-        """
-        Forward pass through the RMSNorm layer.
-        Args:
-            x (torch.Tensor): The input tensor.
-        Returns:
-            torch.Tensor: The output tensor after applying RMSNorm.
-        """
-        x = self._norm(x)
-        if self.learnable_scale:
-            return x * self.weight.to(device=x.device, dtype=x.dtype)
-        else:
-            return x
+        return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)


 class SwiGLUFeedForward(nn.Module):
@@ -437,6 +415,7 @@ class DismantledBlock(nn.Module):
         scale_mod_only: bool = False,
         swiglu: bool = False,
         qk_norm: Optional[str] = None,
+        x_block_self_attn: bool = False,
         dtype=None,
         device=None,
         operations=None,

@@ -460,6 +439,24 @@ class DismantledBlock(nn.Module):
             device=device,
             operations=operations
         )
+        if x_block_self_attn:
+            assert not pre_only
+            assert not scale_mod_only
+            self.x_block_self_attn = True
+            self.attn2 = SelfAttention(
+                dim=hidden_size,
+                num_heads=num_heads,
+                qkv_bias=qkv_bias,
+                attn_mode=attn_mode,
+                pre_only=False,
+                qk_norm=qk_norm,
+                rmsnorm=rmsnorm,
+                dtype=dtype,
+                device=device,
+                operations=operations
+            )
+        else:
+            self.x_block_self_attn = False
         if not pre_only:
             if not rmsnorm:
                 self.norm2 = operations.LayerNorm(

@@ -486,7 +483,11 @@ class DismantledBlock(nn.Module):
             multiple_of=256,
         )
         self.scale_mod_only = scale_mod_only
-        if not scale_mod_only:
+        if x_block_self_attn:
+            assert not pre_only
+            assert not scale_mod_only
+            n_mods = 9
+        elif not scale_mod_only:
             n_mods = 6 if not pre_only else 2
         else:
             n_mods = 4 if not pre_only else 1

@@ -547,14 +548,64 @@ class DismantledBlock(nn.Module):
         )
         return x

+    def pre_attention_x(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
+        assert self.x_block_self_attn
+        (
+            shift_msa,
+            scale_msa,
+            gate_msa,
+            shift_mlp,
+            scale_mlp,
+            gate_mlp,
+            shift_msa2,
+            scale_msa2,
+            gate_msa2,
+        ) = self.adaLN_modulation(c).chunk(9, dim=1)
+        x_norm = self.norm1(x)
+        qkv = self.attn.pre_attention(modulate(x_norm, shift_msa, scale_msa))
+        qkv2 = self.attn2.pre_attention(modulate(x_norm, shift_msa2, scale_msa2))
+        return qkv, qkv2, (
+            x,
+            gate_msa,
+            shift_mlp,
+            scale_mlp,
+            gate_mlp,
+            gate_msa2,
+        )
+
+    def post_attention_x(self, attn, attn2, x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2):
+        assert not self.pre_only
+        attn1 = self.attn.post_attention(attn)
+        attn2 = self.attn2.post_attention(attn2)
+        out1 = gate_msa.unsqueeze(1) * attn1
+        out2 = gate_msa2.unsqueeze(1) * attn2
+        x = x + out1
+        x = x + out2
+        x = x + gate_mlp.unsqueeze(1) * self.mlp(
+            modulate(self.norm2(x), shift_mlp, scale_mlp)
+        )
+        return x
+
     def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
         assert not self.pre_only
-        qkv, intermediates = self.pre_attention(x, c)
-        attn = optimized_attention(
-            qkv,
-            num_heads=self.attn.num_heads,
-        )
-        return self.post_attention(attn, *intermediates)
+        if self.x_block_self_attn:
+            qkv, qkv2, intermediates = self.pre_attention_x(x, c)
+            attn, _ = optimized_attention(
+                qkv[0], qkv[1], qkv[2],
+                num_heads=self.attn.num_heads,
+            )
+            attn2, _ = optimized_attention(
+                qkv2[0], qkv2[1], qkv2[2],
+                num_heads=self.attn2.num_heads,
+            )
+            return self.post_attention_x(attn, attn2, *intermediates)
+        else:
+            qkv, intermediates = self.pre_attention(x, c)
+            attn = optimized_attention(
+                qkv[0], qkv[1], qkv[2],
+                heads=self.attn.num_heads,
+            )
+            return self.post_attention(attn, *intermediates)


 def block_mixing(*args, use_checkpoint=True, **kwargs):
@@ -569,7 +620,10 @@ def block_mixing(*args, use_checkpoint=True, **kwargs):
 def _block_mixing(context, x, context_block, x_block, c):
     context_qkv, context_intermediates = context_block.pre_attention(context, c)

-    x_qkv, x_intermediates = x_block.pre_attention(x, c)
+    if x_block.x_block_self_attn:
+        x_qkv, x_qkv2, x_intermediates = x_block.pre_attention_x(x, c)
+    else:
+        x_qkv, x_intermediates = x_block.pre_attention(x, c)

     o = []
     for t in range(3):

@@ -577,8 +631,8 @@ def _block_mixing(context, x, context_block, x_block, c):
     qkv = tuple(o)

     attn = optimized_attention(
-        qkv,
-        num_heads=x_block.attn.num_heads,
+        qkv[0], qkv[1], qkv[2],
+        heads=x_block.attn.num_heads,
     )
     context_attn, x_attn = (
         attn[:, : context_qkv[0].shape[1]],

@@ -590,7 +644,14 @@ def _block_mixing(context, x, context_block, x_block, c):

     else:
         context = None
-    x = x_block.post_attention(x_attn, *x_intermediates)
+    if x_block.x_block_self_attn:
+        attn2 = optimized_attention(
+            x_qkv2[0], x_qkv2[1], x_qkv2[2],
+            heads=x_block.attn2.num_heads,
+        )
+        x = x_block.post_attention_x(x_attn, attn2, *x_intermediates)
+    else:
+        x = x_block.post_attention(x_attn, *x_intermediates)
     return context, x


@@ -605,8 +666,13 @@ class JointBlock(nn.Module):
         super().__init__()
         pre_only = kwargs.pop("pre_only")
         qk_norm = kwargs.pop("qk_norm", None)
+        x_block_self_attn = kwargs.pop("x_block_self_attn", False)
         self.context_block = DismantledBlock(*args, pre_only=pre_only, qk_norm=qk_norm, **kwargs)
-        self.x_block = DismantledBlock(*args, pre_only=False, qk_norm=qk_norm, **kwargs)
+        self.x_block = DismantledBlock(*args,
+                                       pre_only=False,
+                                       qk_norm=qk_norm,
+                                       x_block_self_attn=x_block_self_attn,
+                                       **kwargs)

     def forward(self, *args, **kwargs):
         return block_mixing(
@@ -662,7 +728,7 @@ class SelfAttentionContext(nn.Module):
     def forward(self, x):
         qkv = self.qkv(x)
         q, k, v = split_qkv(qkv, self.dim_head)
-        x = optimized_attention((q.reshape(q.shape[0], q.shape[1], -1), k, v), self.heads)
+        x = optimized_attention(q.reshape(q.shape[0], q.shape[1], -1), k, v, heads=self.heads)
         return self.proj(x)

 class ContextProcessorBlock(nn.Module):

@@ -721,9 +787,12 @@ class MMDiT(nn.Module):
         qk_norm: Optional[str] = None,
         qkv_bias: bool = True,
         context_processor_layers = None,
+        x_block_self_attn: bool = False,
+        x_block_self_attn_layers: Optional[List[int]] = [],
         context_size = 4096,
         num_blocks = None,
         final_layer = True,
+        skip_blocks = False,
         dtype = None, #TODO
         device = None,
         operations = None,

@@ -738,6 +807,7 @@ class MMDiT(nn.Module):
         self.pos_embed_scaling_factor = pos_embed_scaling_factor
         self.pos_embed_offset = pos_embed_offset
         self.pos_embed_max_size = pos_embed_max_size
+        self.x_block_self_attn_layers = x_block_self_attn_layers

         # hidden_size = default(hidden_size, 64 * depth)
         # num_heads = default(num_heads, hidden_size // 64)

@@ -795,26 +865,28 @@ class MMDiT(nn.Module):
             self.pos_embed = None

         self.use_checkpoint = use_checkpoint
-        self.joint_blocks = nn.ModuleList(
-            [
-                JointBlock(
-                    self.hidden_size,
-                    num_heads,
-                    mlp_ratio=mlp_ratio,
-                    qkv_bias=qkv_bias,
-                    attn_mode=attn_mode,
-                    pre_only=(i == num_blocks - 1) and final_layer,
-                    rmsnorm=rmsnorm,
-                    scale_mod_only=scale_mod_only,
-                    swiglu=swiglu,
-                    qk_norm=qk_norm,
-                    dtype=dtype,
-                    device=device,
-                    operations=operations
-                )
-                for i in range(num_blocks)
-            ]
-        )
+        if not skip_blocks:
+            self.joint_blocks = nn.ModuleList(
+                [
+                    JointBlock(
+                        self.hidden_size,
+                        num_heads,
+                        mlp_ratio=mlp_ratio,
+                        qkv_bias=qkv_bias,
+                        attn_mode=attn_mode,
+                        pre_only=(i == num_blocks - 1) and final_layer,
+                        rmsnorm=rmsnorm,
+                        scale_mod_only=scale_mod_only,
+                        swiglu=swiglu,
+                        qk_norm=qk_norm,
+                        x_block_self_attn=(i in self.x_block_self_attn_layers) or x_block_self_attn,
+                        dtype=dtype,
+                        device=device,
+                        operations=operations,
+                    )
+                    for i in range(num_blocks)
+                ]
+            )

         if final_layer:
             self.final_layer = FinalLayer(self.hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
@@ -877,7 +949,9 @@ class MMDiT(nn.Module):
         c_mod: torch.Tensor,
         context: Optional[torch.Tensor] = None,
         control = None,
+        transformer_options = {},
     ) -> torch.Tensor:
+        patches_replace = transformer_options.get("patches_replace", {})
         if self.register_length > 0:
             context = torch.cat(
                 (

@@ -889,14 +963,25 @@ class MMDiT(nn.Module):

         # context is B, L', D
         # x is B, L, D
+        blocks_replace = patches_replace.get("dit", {})
         blocks = len(self.joint_blocks)
         for i in range(blocks):
-            context, x = self.joint_blocks[i](
-                context,
-                x,
-                c=c_mod,
-                use_checkpoint=self.use_checkpoint,
-            )
+            if ("double_block", i) in blocks_replace:
+                def block_wrap(args):
+                    out = {}
+                    out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"])
+                    return out
+
+                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": c_mod}, {"original_block": block_wrap})
+                context = out["txt"]
+                x = out["img"]
+            else:
+                context, x = self.joint_blocks[i](
+                    context,
+                    x,
+                    c=c_mod,
+                    use_checkpoint=self.use_checkpoint,
+                )
             if control is not None:
                 control_o = control.get("output")
                 if i < len(control_o):

@@ -914,6 +999,7 @@ class MMDiT(nn.Module):
         y: Optional[torch.Tensor] = None,
         context: Optional[torch.Tensor] = None,
         control = None,
+        transformer_options = {},
     ) -> torch.Tensor:
         """
         Forward pass of DiT.

@@ -935,7 +1021,7 @@ class MMDiT(nn.Module):
         if context is not None:
             context = self.context_embedder(context)

-        x = self.forward_core_with_concat(x, c, context, control)
+        x = self.forward_core_with_concat(x, c, context, control, transformer_options)

         x = self.unpatchify(x, hw=hw)  # (N, out_channels, H, W)
         return x[:,:,:hw[-2],:hw[-1]]

@@ -949,7 +1035,8 @@ class OpenAISignatureMMDITWrapper(MMDiT):
         context: Optional[torch.Tensor] = None,
         y: Optional[torch.Tensor] = None,
         control = None,
+        transformer_options = {},
         **kwargs,
     ) -> torch.Tensor:
-        return super().forward(x, timesteps, context=context, y=y, control=control)
+        return super().forward(x, timesteps, context=context, y=y, control=control, transformer_options=transformer_options)
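The `patches_replace` path above lets a caller substitute individual joint blocks at run time. A hedged sketch of what such a patch callable might look like; the dictionary keys follow the hunk, while the function name and the edit inside it are purely illustrative:

# Illustrative "double_block" patch: receives the block inputs plus a wrapper that
# runs the original block, and must return the same {"txt", "img"} dictionary.
def my_double_block_patch(args, extra):
    out = extra["original_block"](args)   # {"txt": ..., "img": ...}
    # a real patch could modify out["img"] or out["txt"] here before returning
    return out

# Hypothetical wiring through transformer_options:
# transformer_options = {"patches_replace": {"dit": {("double_block", 3): my_double_block_patch}}}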
@@ -842,6 +842,11 @@ class UNetModel(nn.Module):
         t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
         emb = self.time_embed(t_emb)

+        if "emb_patch" in transformer_patches:
+            patch = transformer_patches["emb_patch"]
+            for p in patch:
+                emb = p(emb, self.model_channels, transformer_options)
+
         if self.num_classes is not None:
             assert y.shape[0] == x.shape[0]
             emb = emb + self.label_emb(y)
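The new `emb_patch` hook is called with the current timestep embedding, the model's channel count, and the transformer options, and must return a tensor of the same shape. A minimal illustrative patch; the name and behavior are assumptions, not part of the diff:

# Hypothetical emb_patch with the (emb, model_channels, transformer_options) signature.
def zero_mean_emb_patch(emb, model_channels, transformer_options):
    return emb - emb.mean(dim=-1, keepdim=True)

# transformer_patches = {"emb_patch": [zero_mean_emb_patch]}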
119  comfy/lora.py
@@ -16,6 +16,7 @@
     along with this program. If not, see <https://www.gnu.org/licenses/>.
 """

+from __future__ import annotations
 import comfy.utils
 import comfy.model_management
 import comfy.model_base

@@ -200,9 +201,13 @@ def load_lora(lora, to_load):

 def model_lora_keys_clip(model, key_map={}):
     sdk = model.state_dict().keys()
+    for k in sdk:
+        if k.endswith(".weight"):
+            key_map["text_encoders.{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names

     text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
     clip_l_present = False
+    clip_g_present = False
     for b in range(32): #TODO: clean up
         for c in LORA_CLIP_MAP:
             k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)

@@ -226,6 +231,7 @@ def model_lora_keys_clip(model, key_map={}):

             k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
             if k in sdk:
+                clip_g_present = True
                 if clip_l_present:
                     lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
                     key_map[lora_key] = k

@@ -241,10 +247,18 @@ def model_lora_keys_clip(model, key_map={}):

     for k in sdk:
         if k.endswith(".weight"):
-            if k.startswith("t5xxl.transformer."):#OneTrainer SD3 lora
+            if k.startswith("t5xxl.transformer."):#OneTrainer SD3 and Flux lora
                 l_key = k[len("t5xxl.transformer."):-len(".weight")]
-                lora_key = "lora_te3_{}".format(l_key.replace(".", "_"))
-                key_map[lora_key] = k
+                t5_index = 1
+                if clip_g_present:
+                    t5_index += 1
+                if clip_l_present:
+                    t5_index += 1
+                    if t5_index == 2:
+                        key_map["lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))] = k #OneTrainer Flux
+                        t5_index += 1
+
+                key_map["lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))] = k
             elif k.startswith("hydit_clip.transformer.bert."): #HunyuanDiT Lora
                 l_key = k[len("hydit_clip.transformer.bert."):-len(".weight")]
                 lora_key = "lora_te1_{}".format(l_key.replace(".", "_"))

@@ -280,6 +294,7 @@ def model_lora_keys_unet(model, key_map={}):
             unet_key = "diffusion_model.{}".format(diffusers_keys[k])
             key_lora = k[:-len(".weight")].replace(".", "_")
             key_map["lora_unet_{}".format(key_lora)] = unet_key
+            key_map["lycoris_{}".format(key_lora)] = unet_key #simpletuner lycoris format

             diffusers_lora_prefix = ["", "unet."]
             for p in diffusers_lora_prefix:

@@ -302,6 +317,10 @@ def model_lora_keys_unet(model, key_map={}):
             key_lora = "lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_")) #OneTrainer lora
             key_map[key_lora] = to

+            key_lora = "lycoris_{}".format(k[:-len(".weight")].replace(".", "_")) #simpletuner lycoris format
+            key_map[key_lora] = to
+
     if isinstance(model, comfy.model_base.AuraFlow): #Diffusers lora AuraFlow
         diffusers_keys = comfy.utils.auraflow_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
         for k in diffusers_keys:

@@ -323,14 +342,15 @@ def model_lora_keys_unet(model, key_map={}):
             to = diffusers_keys[k]
             key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format
             key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
+            key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer

     return key_map


-def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype):
+def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function):
     dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
     lora_diff *= alpha
-    weight_calc = weight + lora_diff.type(weight.dtype)
+    weight_calc = weight + function(lora_diff).type(weight.dtype)
     weight_norm = (
         weight_calc.transpose(0, 1)
         .reshape(weight_calc.shape[1], -1)
@@ -347,6 +367,39 @@ def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediat
     weight[:] = weight_calc
     return weight

+def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:
+    """
+    Pad a tensor to a new shape with zeros.
+
+    Args:
+        tensor (torch.Tensor): The original tensor to be padded.
+        new_shape (List[int]): The desired shape of the padded tensor.
+
+    Returns:
+        torch.Tensor: A new tensor padded with zeros to the specified shape.
+
+    Note:
+        If the new shape is smaller than the original tensor in any dimension,
+        the original tensor will be truncated in that dimension.
+    """
+    if any([new_shape[i] < tensor.shape[i] for i in range(len(new_shape))]):
+        raise ValueError("The new shape must be larger than the original tensor in all dimensions")
+
+    if len(new_shape) != len(tensor.shape):
+        raise ValueError("The new shape must have the same number of dimensions as the original tensor")
+
+    # Create a new tensor filled with zeros
+    padded_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device)
+
+    # Create slicing tuples for both tensors
+    orig_slices = tuple(slice(0, dim) for dim in tensor.shape)
+    new_slices = tuple(slice(0, dim) for dim in tensor.shape)
+
+    # Copy the original tensor into the new tensor
+    padded_tensor[new_slices] = tensor[orig_slices]
+
+    return padded_tensor
+
 def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
     for p in patches:
         strength = p[0]
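A quick illustration of the `pad_tensor_to_shape` helper added above, with toy shapes and values that are made up for the example:

# Pads a [2, 3] weight up to [4, 3]; the original values occupy the top rows, the rest is zero.
w = torch.arange(6, dtype=torch.float32).reshape(2, 3)
padded = pad_tensor_to_shape(w, [4, 3])
assert padded.shape == (4, 3) and torch.equal(padded[:2], w)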
@@ -366,7 +419,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
                 weight *= strength_model

         if isinstance(v, list):
-            v = (calculate_weight(v[1:], v[0].clone(), key, intermediate_dtype=intermediate_dtype), )
+            v = (calculate_weight(v[1:], v[0][1](comfy.model_management.cast_to_device(v[0][0], weight.device, intermediate_dtype, copy=True), inplace=True), key, intermediate_dtype=intermediate_dtype), )

         if len(v) == 1:
             patch_type = "diff"

@@ -375,12 +428,18 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
             v = v[1]

         if patch_type == "diff":
-            w1 = v[0]
+            diff: torch.Tensor = v[0]
+            # An extra flag to pad the weight if the diff's shape is larger than the weight
+            do_pad_weight = len(v) > 1 and v[1]['pad_weight']
+            if do_pad_weight and diff.shape != weight.shape:
+                logging.info("Pad weight {} from {} to shape: {}".format(key, weight.shape, diff.shape))
+                weight = pad_tensor_to_shape(weight, diff.shape)
+
             if strength != 0.0:
-                if w1.shape != weight.shape:
-                    logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
+                if diff.shape != weight.shape:
+                    logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, diff.shape, weight.shape))
                 else:
-                    weight += function(strength * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype))
+                    weight += function(strength * comfy.model_management.cast_to_device(diff, weight.device, weight.dtype))
         elif patch_type == "lora": #lora/locon
             mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
             mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype)

@@ -398,7 +457,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
             try:
                 lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
                 if dora_scale is not None:
-                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
+                    weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
                 else:
                     weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
             except Exception as e:

@@ -444,7 +503,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
             try:
                 lora_diff = torch.kron(w1, w2).reshape(weight.shape)
                 if dora_scale is not None:
-                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
+                    weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
                 else:
                     weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
             except Exception as e:
@@ -481,28 +540,48 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
             try:
                 lora_diff = (m1 * m2).reshape(weight.shape)
                 if dora_scale is not None:
-                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
+                    weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
                 else:
                     weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
             except Exception as e:
                 logging.error("ERROR {} {} {}".format(patch_type, key, e))
         elif patch_type == "glora":
-            if v[4] is not None:
-                alpha = v[4] / v[0].shape[0]
-            else:
-                alpha = 1.0
-
             dora_scale = v[5]

+            old_glora = False
+            if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]:
+                rank = v[0].shape[0]
+                old_glora = True
+
+            if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]:
+                if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]:
+                    pass
+                else:
+                    old_glora = False
+                    rank = v[1].shape[0]
+
             a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
             a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
             b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
             b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)

+            if v[4] is not None:
+                alpha = v[4] / rank
+            else:
+                alpha = 1.0
+
             try:
-                lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape)
+                if old_glora:
+                    lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora
+                else:
+                    if weight.dim() > 2:
+                        lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
+                    else:
+                        lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
+                    lora_diff += torch.mm(b1, b2).reshape(weight.shape)
+
                 if dora_scale is not None:
-                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
+                    weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
                 else:
                     weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
             except Exception as e:
@@ -24,6 +24,7 @@ from comfy.ldm.cascade.stage_b import StageB
 from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
 from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
 from comfy.ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
+import comfy.ldm.genmo.joint_model.asymm_models_joint
 import comfy.ldm.aura.mmdit
 import comfy.ldm.hydit.models
 import comfy.ldm.audio.dit

@@ -96,7 +97,8 @@ class BaseModel(torch.nn.Module):

         if not unet_config.get("disable_unet_model_creation", False):
             if model_config.custom_operations is None:
-                operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype)
+                fp8 = model_config.optimizations.get("fp8", model_config.scaled_fp8 is not None)
+                operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
             else:
                 operations = model_config.custom_operations
             self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)

@@ -244,6 +246,10 @@ class BaseModel(torch.nn.Module):
             extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))

         unet_state_dict = self.diffusion_model.state_dict()
+
+        if self.model_config.scaled_fp8 is not None:
+            unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8)
+
         unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)

         if self.model_type == ModelType.V_PREDICTION:

@@ -713,3 +719,18 @@ class Flux(BaseModel):
         out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
         out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 3.5)]))
         return out
+
+class GenmoMochi(BaseModel):
+    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.genmo.joint_model.asymm_models_joint.AsymmDiTJoint)
+
+    def extra_conds(self, **kwargs):
+        out = super().extra_conds(**kwargs)
+        attention_mask = kwargs.get("attention_mask", None)
+        if attention_mask is not None:
+            out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
+            out['num_tokens'] = comfy.conds.CONDConstant(max(1, torch.sum(attention_mask).item()))
+        cross_attn = kwargs.get("cross_attn", None)
+        if cross_attn is not None:
+            out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+        return out
@@ -70,6 +70,11 @@ def detect_unet_config(state_dict, key_prefix):
         context_processor = '{}context_processor.layers.0.attn.qkv.weight'.format(key_prefix)
         if context_processor in state_dict_keys:
             unet_config["context_processor_layers"] = count_blocks(state_dict_keys, '{}context_processor.layers.'.format(key_prefix) + '{}.')
+        unet_config["x_block_self_attn_layers"] = []
+        for key in state_dict_keys:
+            if key.startswith('{}joint_blocks.'.format(key_prefix)) and key.endswith('.x_block.attn2.qkv.weight'):
+                layer = key[len('{}joint_blocks.'.format(key_prefix)):-len('.x_block.attn2.qkv.weight')]
+                unet_config["x_block_self_attn_layers"].append(int(layer))
         return unet_config

     if '{}clf.1.weight'.format(key_prefix) in state_dict_keys: #stable cascade

@@ -145,6 +150,34 @@ def detect_unet_config(state_dict, key_prefix):
         dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
         return dit_config

+    if '{}t5_yproj.weight'.format(key_prefix) in state_dict_keys: #Genmo mochi preview
+        dit_config = {}
+        dit_config["image_model"] = "mochi_preview"
+        dit_config["depth"] = 48
+        dit_config["patch_size"] = 2
+        dit_config["num_heads"] = 24
+        dit_config["hidden_size_x"] = 3072
+        dit_config["hidden_size_y"] = 1536
+        dit_config["mlp_ratio_x"] = 4.0
+        dit_config["mlp_ratio_y"] = 4.0
+        dit_config["learn_sigma"] = False
+        dit_config["in_channels"] = 12
+        dit_config["qk_norm"] = True
+        dit_config["qkv_bias"] = False
+        dit_config["out_bias"] = True
+        dit_config["attn_drop"] = 0.0
+        dit_config["patch_embed_bias"] = True
+        dit_config["posenc_preserve_area"] = True
+        dit_config["timestep_mlp_bias"] = True
+        dit_config["attend_to_padding"] = False
+        dit_config["timestep_scale"] = 1000.0
+        dit_config["use_t5"] = True
+        dit_config["t5_feat_dim"] = 4096
+        dit_config["t5_token_length"] = 256
+        dit_config["rope_theta"] = 10000.0
+        return dit_config
+
     if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
         return None

@@ -286,9 +319,15 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
         return None
     model_config = model_config_from_unet_config(unet_config, state_dict)
     if model_config is None and use_base_if_no_match:
-        return comfy.supported_models_base.BASE(unet_config)
-    else:
-        return model_config
+        model_config = comfy.supported_models_base.BASE(unet_config)
+
+    scaled_fp8_weight = state_dict.get("{}scaled_fp8".format(unet_key_prefix), None)
+    if scaled_fp8_weight is not None:
+        model_config.scaled_fp8 = scaled_fp8_weight.dtype
+        if model_config.scaled_fp8 == torch.float32:
+            model_config.scaled_fp8 = torch.float8_e4m3fn
+
+    return model_config

 def unet_prefix_from_state_dict(state_dict):
     candidates = ["model.diffusion_model.", #ldm/sgm models

@@ -501,7 +540,11 @@ def model_config_from_diffusers_unet(state_dict):
 def convert_diffusers_mmdit(state_dict, output_prefix=""):
     out_sd = {}

-    if 'transformer_blocks.0.attn.norm_added_k.weight' in state_dict: #Flux
+    if 'joint_transformer_blocks.0.attn.add_k_proj.weight' in state_dict: #AuraFlow
+        num_joint = count_blocks(state_dict, 'joint_transformer_blocks.{}.')
+        num_single = count_blocks(state_dict, 'single_transformer_blocks.{}.')
+        sd_map = comfy.utils.auraflow_to_diffusers({"n_double_layers": num_joint, "n_layers": num_joint + num_single}, output_prefix=output_prefix)
+    elif 'x_embedder.weight' in state_dict: #Flux
         depth = count_blocks(state_dict, 'transformer_blocks.{}.')
         depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.')
         hidden_size = state_dict["x_embedder.bias"].shape[0]

@@ -510,10 +553,6 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
         num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
         depth = state_dict["pos_embed.proj.weight"].shape[0] // 64
         sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix)
-    elif 'joint_transformer_blocks.0.attn.add_k_proj.weight' in state_dict: #AuraFlow
-        num_joint = count_blocks(state_dict, 'joint_transformer_blocks.{}.')
-        num_single = count_blocks(state_dict, 'single_transformer_blocks.{}.')
-        sd_map = comfy.utils.auraflow_to_diffusers({"n_double_layers": num_joint, "n_layers": num_joint + num_single}, output_prefix=output_prefix)
     else:
         return None
@@ -45,6 +45,7 @@ cpu_state = CPUState.GPU
 total_vram = 0

 xpu_available = False
+torch_version = ""
 try:
     torch_version = torch.version.__version__
     xpu_available = (int(torch_version[0]) < 2 or (int(torch_version[0]) == 2 and int(torch_version[2]) <= 4)) and torch.xpu.is_available()

@@ -144,7 +145,7 @@ total_ram = psutil.virtual_memory().total / (1024 * 1024)
 logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))

 try:
-    logging.info("pytorch version: {}".format(torch.version.__version__))
+    logging.info("pytorch version: {}".format(torch_version))
 except:
     pass

@@ -325,7 +326,7 @@ class LoadedModel:
                 self.model_unload()
                 raise e

-        if is_intel_xpu() and not args.disable_ipex_optimize and self.real_model is not None:
+        if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and self.real_model is not None:
             with torch.no_grad():
                 self.real_model = ipex.optimize(self.real_model.eval(), inplace=True, graph_mode=True, concat_linear=True)

@@ -369,12 +370,11 @@ def offloaded_memory(loaded_models, device):
         offloaded_mem += m.model_offloaded_memory()
     return offloaded_mem

-def minimum_inference_memory():
-    return (1024 * 1024 * 1024) * 1.2
+WINDOWS = any(platform.win32_ver())

-EXTRA_RESERVED_VRAM = 200 * 1024 * 1024
-if any(platform.win32_ver()):
-    EXTRA_RESERVED_VRAM = 500 * 1024 * 1024 #Windows is higher because of the shared vram issue
+EXTRA_RESERVED_VRAM = 400 * 1024 * 1024
+if WINDOWS:
+    EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue

 if args.reserve_vram is not None:
     EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024

@@ -383,6 +383,9 @@ if args.reserve_vram is not None:
 def extra_reserved_memory():
     return EXTRA_RESERVED_VRAM

+def minimum_inference_memory():
+    return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
+
 def unload_model_clones(model, unload_weights_only=True, force_unload=True):
     to_unload = []
     for i in range(len(current_loaded_models)):

@@ -405,6 +408,8 @@ def unload_model_clones(model, unload_weights_only=True, force_unload=True):
     if not force_unload:
         if unload_weights_only and unload_weight == False:
             return None
+    else:
+        unload_weight = True

     for i in to_unload:
         logging.debug("unload clone {} {}".format(i, unload_weight))

@@ -421,7 +426,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
         shift_model = current_loaded_models[i]
         if shift_model.device == device:
             if shift_model not in keep_loaded:
-                can_unload.append((sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
+                can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
                 shift_model.currently_used = False

     for x in sorted(can_unload):

@@ -621,6 +626,8 @@ def maximum_vram_for_weights(device=None):
     return (get_total_memory(device) * 0.88 - minimum_inference_memory())

 def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
+    if model_params < 0:
+        model_params = 1000000000000000000000
     if args.bf16_unet:
         return torch.bfloat16
     if args.fp16_unet:

@@ -640,6 +647,9 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
         pass

     if fp8_dtype is not None:
+        if supports_fp8_compute(device): #if fp8 compute is supported the casting is most likely not expensive
+            return fp8_dtype
+
         free_model_memory = maximum_vram_for_weights(device)
         if model_params * 2 > free_model_memory:
             return fp8_dtype
@@ -833,27 +843,21 @@ def force_channels_last():
|
|||||||
#TODO
|
#TODO
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
|
||||||
|
if device is None or weight.device == device:
|
||||||
|
if not copy:
|
||||||
|
if dtype is None or weight.dtype == dtype:
|
||||||
|
return weight
|
||||||
|
return weight.to(dtype=dtype, copy=copy)
|
||||||
|
|
||||||
|
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||||
|
r.copy_(weight, non_blocking=non_blocking)
|
||||||
|
return r
|
||||||
|
|
||||||
def cast_to_device(tensor, device, dtype, copy=False):
|
def cast_to_device(tensor, device, dtype, copy=False):
|
||||||
device_supports_cast = False
|
non_blocking = device_supports_non_blocking(device)
|
||||||
if tensor.dtype == torch.float32 or tensor.dtype == torch.float16:
|
return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy)
|
||||||
device_supports_cast = True
|
|
||||||
elif tensor.dtype == torch.bfloat16:
|
|
||||||
if hasattr(device, 'type') and device.type.startswith("cuda"):
|
|
||||||
device_supports_cast = True
|
|
||||||
elif is_intel_xpu():
|
|
||||||
device_supports_cast = True
|
|
||||||
|
|
||||||
non_blocking = device_should_use_non_blocking(device)
|
|
||||||
|
|
||||||
if device_supports_cast:
|
|
||||||
if copy:
|
|
||||||
if tensor.device == device:
|
|
||||||
return tensor.to(dtype, copy=copy, non_blocking=non_blocking)
|
|
||||||
return tensor.to(device, copy=copy, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
|
|
||||||
else:
|
|
||||||
return tensor.to(device, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
|
|
||||||
else:
|
|
||||||
return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking)
|
|
||||||
|
|
||||||
def xformers_enabled():
|
def xformers_enabled():
|
||||||
global directml_enabled
|
global directml_enabled
|
||||||
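Note: the cast_to helper that now lives in comfy.model_management is a no-op when nothing has to change and an allocate-and-copy otherwise. A minimal standalone sketch of that behavior (a re-statement for illustration, not an import of the patched module):

import torch

def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
    # fast path: already on the target device -> return as-is or convert dtype there
    if device is None or weight.device == device:
        if not copy:
            if dtype is None or weight.dtype == dtype:
                return weight
        return weight.to(dtype=dtype, copy=copy)
    # otherwise allocate on the target device and copy (optionally non_blocking)
    r = torch.empty_like(weight, dtype=dtype, device=device)
    r.copy_(weight, non_blocking=non_blocking)
    return r

w = torch.randn(4, 4)
assert cast_to(w) is w                                        # untouched
assert cast_to(w, dtype=torch.float16).dtype == torch.float16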
@@ -892,7 +896,7 @@ def force_upcast_attention_dtype():
     upcast = args.force_upcast_attention
     try:
         macos_version = tuple(int(n) for n in platform.mac_ver()[0].split("."))
-        if (14, 5) <= macos_version < (14, 7): # black image bug on recent versions of MacOS
+        if (14, 5) <= macos_version <= (15, 2): # black image bug on recent versions of macOS
             upcast = True
     except:
         pass
@@ -999,7 +1003,10 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
     nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050", "p40", "p100", "p6", "p4"]
     for x in nvidia_10_series:
         if x in props.name.lower():
-            return True
+            if WINDOWS or manual_cast:
+                return True
+            else:
+                return False #weird linux behavior where fp32 is faster

     if manual_cast:
         free_model_memory = maximum_vram_for_weights(device)
@@ -1055,6 +1062,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
     return False

 def supports_fp8_compute(device=None):
+    if not is_nvidia():
+        return False
+
     props = torch.cuda.get_device_properties(device)
     if props.major >= 9:
         return True
@@ -1062,6 +1072,14 @@ def supports_fp8_compute(device=None):
         return False
     if props.minor < 9:
         return False
+
+    if int(torch_version[0]) < 2 or (int(torch_version[0]) == 2 and int(torch_version[2]) < 3):
+        return False
+
+    if WINDOWS:
+        if (int(torch_version[0]) == 2 and int(torch_version[2]) < 4):
+            return False
+
     return True

 def soft_empty_cache(force=False):
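Note: the added checks tie fp8 compute to the parsed torch version (minimum 2.3, and 2.4 on Windows). A hedged sketch of just that version test; torch_version is assumed to be the bare version string that the module indexes character-wise, so the check only makes sense for single-digit major/minor versions:

import torch

torch_version = torch.version.__version__   # e.g. "2.4.1+cu121"

def fp8_torch_version_ok(windows=False):
    # mirrors the new gate in supports_fp8_compute(): torch >= 2.3, and >= 2.4 on Windows
    major, minor = int(torch_version[0]), int(torch_version[2])
    if major < 2 or (major == 2 and minor < 3):
        return False
    if windows and major == 2 and minor < 4:
        return False
    return True

print(fp8_torch_version_ok())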
comfy/model_patcher.py
@@ -28,8 +28,20 @@ import comfy.utils
 import comfy.float
 import comfy.model_management
 import comfy.lora
-from comfy.types import UnetWrapperFunction
+from comfy.comfy_types import UnetWrapperFunction

+def string_to_seed(data):
+    crc = 0xFFFFFFFF
+    for byte in data:
+        if isinstance(byte, str):
+            byte = ord(byte)
+        crc ^= byte
+        for _ in range(8):
+            if crc & 1:
+                crc = (crc >> 1) ^ 0xEDB88320
+            else:
+                crc >>= 1
+    return crc ^ 0xFFFFFFFF

 def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
     to = model_options["transformer_options"].copy()
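Note: string_to_seed is a plain CRC-32 over the key name, so the stochastic rounding applied to a given weight key is reproducible run to run. For ASCII keys it matches binascii.crc32, which makes a quick sanity check possible (illustrative only):

import binascii

def string_to_seed(data):
    crc = 0xFFFFFFFF
    for byte in data:
        if isinstance(byte, str):
            byte = ord(byte)
        crc ^= byte
        for _ in range(8):
            crc = (crc >> 1) ^ 0xEDB88320 if crc & 1 else crc >> 1
    return crc ^ 0xFFFFFFFF

key = "diffusion_model.input_blocks.0.0.weight"   # example key, not from a real checkpoint
assert string_to_seed(key) == binascii.crc32(key.encode())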
@@ -76,7 +88,36 @@ class LowVramPatch:
         self.key = key
         self.patches = patches
     def __call__(self, weight):
-        return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)
+        intermediate_dtype = weight.dtype
+        if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
+            intermediate_dtype = torch.float32
+            return comfy.float.stochastic_rounding(comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype), weight.dtype, seed=string_to_seed(self.key))
+
+        return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
+
+def get_key_weight(model, key):
+    set_func = None
+    convert_func = None
+    op_keys = key.rsplit('.', 1)
+    if len(op_keys) < 2:
+        weight = comfy.utils.get_attr(model, key)
+    else:
+        op = comfy.utils.get_attr(model, op_keys[0])
+        try:
+            set_func = getattr(op, "set_{}".format(op_keys[1]))
+        except AttributeError:
+            pass
+
+        try:
+            convert_func = getattr(op, "convert_{}".format(op_keys[1]))
+        except AttributeError:
+            pass
+
+        weight = getattr(op, op_keys[1])
+        if convert_func is not None:
+            weight = comfy.utils.get_attr(model, key)
+
+    return weight, set_func, convert_func
+
 class ModelPatcher:
     def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
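Note: get_key_weight splits a key into (module path, attribute) and probes the owning module for optional set_<attr>/convert_<attr> hooks; the scaled fp8 Linear in comfy/ops.py below provides both. A toy sketch of the probing with a made-up module:

import torch

class ScaledLinearStub(torch.nn.Linear):
    # hypothetical module exposing the two hooks get_key_weight looks for
    def convert_weight(self, weight, inplace=False, **kwargs):
        return weight * 2.0          # pretend de-scaling
    def set_weight(self, weight, inplace_update=False, seed=None, **kwargs):
        self.weight = torch.nn.Parameter(weight / 2.0, requires_grad=False)

model = torch.nn.ModuleDict({"lin": ScaledLinearStub(4, 4)})
key = "lin.weight"
op_path, attr = key.rsplit('.', 1)                      # "lin", "weight"
op = model[op_path]
set_func = getattr(op, "set_" + attr, None)             # found -> patches written back via set_weight
convert_func = getattr(op, "convert_" + attr, None)     # found -> raw weight is de-scaled before patching
print(set_func is not None, convert_func is not None)   # True True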
@@ -271,17 +312,23 @@ class ModelPatcher:
         return list(p)

     def get_key_patches(self, filter_prefix=None):
-        comfy.model_management.unload_model_clones(self)
         model_sd = self.model_state_dict()
         p = {}
         for k in model_sd:
             if filter_prefix is not None:
                 if not k.startswith(filter_prefix):
                     continue
+            bk = self.backup.get(k, None)
+            weight, set_func, convert_func = get_key_weight(self.model, k)
+            if bk is not None:
+                weight = bk.weight
+            if convert_func is None:
+                convert_func = lambda a, **kwargs: a
+
             if k in self.patches:
-                p[k] = [model_sd[k]] + self.patches[k]
+                p[k] = [(weight, convert_func)] + self.patches[k]
             else:
-                p[k] = (model_sd[k],)
+                p[k] = [(weight, convert_func)]
         return p

     def model_state_dict(self, filter_prefix=None):
@@ -297,8 +344,7 @@ class ModelPatcher:
         if key not in self.patches:
             return

-        weight = comfy.utils.get_attr(self.model, key)
-
+        weight, set_func, convert_func = get_key_weight(self.model, key)
         inplace_update = self.weight_inplace_update or inplace_update

         if key not in self.backup:
@@ -308,23 +354,38 @@ class ModelPatcher:
             temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
         else:
             temp_weight = weight.to(torch.float32, copy=True)
+        if convert_func is not None:
+            temp_weight = convert_func(temp_weight, inplace=True)
+
         out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key)
-        out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype)
-        if inplace_update:
-            comfy.utils.copy_to_param(self.model, key, out_weight)
+        if set_func is None:
+            out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
+            if inplace_update:
+                comfy.utils.copy_to_param(self.model, key, out_weight)
+            else:
+                comfy.utils.set_attr_param(self.model, key, out_weight)
         else:
-            comfy.utils.set_attr_param(self.model, key, out_weight)
+            set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))

     def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
         mem_counter = 0
         patch_counter = 0
         lowvram_counter = 0
-        load_completely = []
+        loading = []
         for n, m in self.model.named_modules():
+            if hasattr(m, "comfy_cast_weights") or hasattr(m, "weight"):
+                loading.append((comfy.model_management.module_size(m), n, m))
+
+        load_completely = []
+        loading.sort(reverse=True)
+        for x in loading:
+            n = x[1]
+            m = x[2]
+            module_mem = x[0]
+
             lowvram_weight = False

             if not full_load and hasattr(m, "comfy_cast_weights"):
-                module_mem = comfy.model_management.module_size(m)
                 if mem_counter + module_mem >= lowvram_model_memory:
                     lowvram_weight = True
                     lowvram_counter += 1
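Note: load() now collects (module_size, name, module) for every weighted module up front and sorts the list descending, so the largest modules are the first candidates for lowvram treatment. A toy illustration of that ordering, assuming module size is just the parameter byte count:

import torch

def module_size(m):
    return sum(p.numel() * p.element_size() for p in m.parameters(recurse=False))

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(256, 256))
loading = [(module_size(m), n, m) for n, m in model.named_modules() if hasattr(m, "weight")]
loading.sort(reverse=True)
print([n for _, n, _ in loading])   # largest module first: ['1', '0']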
@@ -356,9 +417,8 @@ class ModelPatcher:
                     wipe_lowvram_weight(m)

                 if hasattr(m, "weight"):
-                    mem_used = comfy.model_management.module_size(m)
-                    mem_counter += mem_used
-                    load_completely.append((mem_used, n, m))
+                    mem_counter += module_mem
+                    load_completely.append((module_mem, n, m))

         load_completely.sort(reverse=True)
         for x in load_completely:

99  comfy/ops.py
@@ -19,16 +19,12 @@
 import torch
 import comfy.model_management
 from comfy.cli_args import args
+import comfy.float

-def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=True):
-    if not copy and (dtype is None or weight.dtype == dtype) and (device is None or weight.device == device):
-        return weight
-    r = torch.empty_like(weight, dtype=dtype, device=device)
-    r.copy_(weight, non_blocking=non_blocking)
-    return r
+cast_to = comfy.model_management.cast_to #TODO: remove once no more references

 def cast_to_input(weight, input, non_blocking=False, copy=True):
-    return cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
+    return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)

 def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
     if input is not None:
@@ -43,12 +39,12 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
     non_blocking = comfy.model_management.device_supports_non_blocking(device)
     if s.bias is not None:
         has_function = s.bias_function is not None
-        bias = cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function)
+        bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function)
         if has_function:
             bias = s.bias_function(bias)

     has_function = s.weight_function is not None
-    weight = cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function)
+    weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function)
     if has_function:
         weight = s.weight_function(weight)
     return weight, bias
@@ -254,20 +250,29 @@ def fp8_linear(self, input):
     if dtype not in [torch.float8_e4m3fn]:
         return None

+    tensor_2d = False
+    if len(input.shape) == 2:
+        tensor_2d = True
+        input = input.unsqueeze(1)
+
     if len(input.shape) == 3:
-        inn = input.reshape(-1, input.shape[2]).to(dtype)
-        non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
         w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype)
         w = w.t()

         scale_weight = self.scale_weight
         scale_input = self.scale_input
         if scale_weight is None:
-            scale_weight = torch.ones((1), device=input.device, dtype=torch.float32)
-            if scale_input is None:
-                scale_input = scale_weight
+            scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
+        else:
+            scale_weight = scale_weight.to(input.device)
+
         if scale_input is None:
-            scale_input = torch.ones((1), device=input.device, dtype=torch.float32)
+            scale_input = torch.ones((), device=input.device, dtype=torch.float32)
+            inn = input.reshape(-1, input.shape[2]).to(dtype)
+        else:
+            scale_input = scale_input.to(input.device)
+            inn = (input * (1.0 / scale_input).to(input.dtype)).reshape(-1, input.shape[2]).to(dtype)

         if bias is not None:
             o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
@@ -277,7 +282,11 @@ def fp8_linear(self, input):
         if isinstance(o, tuple):
             o = o[0]

+        if tensor_2d:
+            return o.reshape(input.shape[0], -1)
+
         return o.reshape((-1, input.shape[1], self.weight.shape[0]))

     return None

 class fp8_ops(manual_cast):
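Note: the reshuffled fp8_linear divides the activations by scale_input before quantizing and hands both per-tensor scales to torch._scaled_mm. A hedged sketch of the core call; it needs an fp8-capable GPU and a torch build whose _scaled_mm accepts per-tensor scales (older builds return a tuple, which is why the code above unwraps o[0]):

import torch

x = torch.randn(16, 64, device="cuda", dtype=torch.bfloat16)    # (tokens, in_features)
w = torch.randn(32, 64, device="cuda", dtype=torch.bfloat16)    # (out_features, in_features)
w_fp8 = w.to(torch.float8_e4m3fn)
scale_input = torch.ones((), device="cuda", dtype=torch.float32)
scale_weight = torch.ones((), device="cuda", dtype=torch.float32)

inn = (x * (1.0 / scale_input).to(x.dtype)).to(torch.float8_e4m3fn)
o = torch._scaled_mm(inn, w_fp8.t(), out_dtype=x.dtype,
                     scale_a=scale_input, scale_b=scale_weight)
if isinstance(o, tuple):    # older torch versions return (out, amax)
    o = o[0]
print(o.shape)              # torch.Size([16, 32])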
@@ -295,11 +304,63 @@ class fp8_ops(manual_cast):
         weight, bias = cast_bias_weight(self, input)
         return torch.nn.functional.linear(input, weight, bias)

+def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
+    class scaled_fp8_op(manual_cast):
+        class Linear(manual_cast.Linear):
+            def __init__(self, *args, **kwargs):
+                if override_dtype is not None:
+                    kwargs['dtype'] = override_dtype
+                super().__init__(*args, **kwargs)
+
+            def reset_parameters(self):
+                if not hasattr(self, 'scale_weight'):
+                    self.scale_weight = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
+
+                if not scale_input:
+                    self.scale_input = None
+
+                if not hasattr(self, 'scale_input'):
+                    self.scale_input = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
+                return None
+
+            def forward_comfy_cast_weights(self, input):
+                if fp8_matrix_mult:
+                    out = fp8_linear(self, input)
+                    if out is not None:
+                        return out
+
+                weight, bias = cast_bias_weight(self, input)
+
+                if weight.numel() < input.numel(): #TODO: optimize
+                    return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
+                else:
+                    return torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
+
+            def convert_weight(self, weight, inplace=False, **kwargs):
+                if inplace:
+                    weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype)
+                    return weight
+                else:
+                    return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype)
+
+            def set_weight(self, weight, inplace_update=False, seed=None, **kwargs):
+                weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
+                if inplace_update:
+                    self.weight.data.copy_(weight)
+                else:
+                    self.weight = torch.nn.Parameter(weight, requires_grad=False)
+
+    return scaled_fp8_op
+
+def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
+    fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
+    if scaled_fp8 is not None:
+        return scaled_fp8_ops(fp8_matrix_mult=fp8_compute, scale_input=True, override_dtype=scaled_fp8)
+
+    if fp8_compute and (fp8_optimizations or args.fast) and not disable_fast_fp8:
+        return fp8_ops
+
-def pick_operations(weight_dtype, compute_dtype, load_device=None):
     if compute_dtype is None or weight_dtype == compute_dtype:
         return disable_weight_init
-    if args.fast:
-        if comfy.model_management.supports_fp8_compute(load_device):
-            return fp8_ops
     return manual_cast

comfy/samplers.py
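Note: pick_operations now owns the whole decision: a detected scaled-fp8 checkpoint gets scaled_fp8_ops (fp8 matrix mult only when the device supports it), --fast or per-model fp8 optimizations keep using fp8_ops, and everything else falls back to disable_weight_init or manual_cast. The priority order restated as a standalone sketch (string results stand in for the op classes):

import torch

def pick(weight_dtype, compute_dtype, fp8_compute, scaled_fp8=None,
         fp8_optimizations=False, fast=False, disable_fast_fp8=False):
    if scaled_fp8 is not None:
        return "scaled_fp8_ops(fp8_matrix_mult=%s, scale_input=True)" % fp8_compute
    if fp8_compute and (fp8_optimizations or fast) and not disable_fast_fp8:
        return "fp8_ops"
    if compute_dtype is None or weight_dtype == compute_dtype:
        return "disable_weight_init"
    return "manual_cast"

print(pick(torch.float8_e4m3fn, torch.bfloat16, fp8_compute=True, scaled_fp8=torch.float8_e4m3fn))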
@@ -6,7 +6,7 @@ from comfy import model_management
 import math
 import logging
 import comfy.sampler_helpers
-import scipy
+import scipy.stats
 import numpy

 def get_area_and_mult(conds, x_in, timestep_in):
@@ -358,11 +358,35 @@ def beta_scheduler(model_sampling, steps, alpha=0.6, beta=0.6):
     ts = numpy.rint(scipy.stats.beta.ppf(ts, alpha, beta) * total_timesteps)

     sigs = []
+    last_t = -1
     for t in ts:
-        sigs += [float(model_sampling.sigmas[int(t)])]
+        if t != last_t:
+            sigs += [float(model_sampling.sigmas[int(t)])]
+        last_t = t
     sigs += [0.0]
     return torch.FloatTensor(sigs)

+# from: https://github.com/genmoai/models/blob/main/src/mochi_preview/infer.py#L41
+def linear_quadratic_schedule(model_sampling, steps, threshold_noise=0.025, linear_steps=None):
+    if steps == 1:
+        sigma_schedule = [1.0, 0.0]
+    else:
+        if linear_steps is None:
+            linear_steps = steps // 2
+        linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)]
+        threshold_noise_step_diff = linear_steps - threshold_noise * steps
+        quadratic_steps = steps - linear_steps
+        quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps ** 2)
+        linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps ** 2)
+        const = quadratic_coef * (linear_steps ** 2)
+        quadratic_sigma_schedule = [
+            quadratic_coef * (i ** 2) + linear_coef * i + const
+            for i in range(linear_steps, steps)
+        ]
+        sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0]
+        sigma_schedule = [1.0 - x for x in sigma_schedule]
+    return torch.FloatTensor(sigma_schedule) * model_sampling.sigma_max.cpu()
+
 def get_mask_aabb(masks):
     if masks.numel() == 0:
         return torch.zeros((0, 4), device=masks.device, dtype=torch.int)
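Note: the new schedule is linear up to threshold_noise over the first half of the steps, quadratic after that, and then flipped so it runs from sigma_max down to 0. A standalone version with sigma_max passed in directly (the model_sampling object is stubbed out for illustration):

import torch

def linear_quadratic_schedule(steps, sigma_max, threshold_noise=0.025, linear_steps=None):
    # same construction as the new scheduler, minus the model_sampling wrapper
    if steps == 1:
        sigma_schedule = [1.0, 0.0]
    else:
        if linear_steps is None:
            linear_steps = steps // 2
        linear_part = [i * threshold_noise / linear_steps for i in range(linear_steps)]
        diff = linear_steps - threshold_noise * steps
        quadratic_steps = steps - linear_steps
        quadratic_coef = diff / (linear_steps * quadratic_steps ** 2)
        linear_coef = threshold_noise / linear_steps - 2 * diff / (quadratic_steps ** 2)
        const = quadratic_coef * (linear_steps ** 2)
        quadratic_part = [quadratic_coef * (i ** 2) + linear_coef * i + const
                          for i in range(linear_steps, steps)]
        sigma_schedule = [1.0 - x for x in linear_part + quadratic_part + [1.0]]
    return torch.FloatTensor(sigma_schedule) * sigma_max

sigmas = linear_quadratic_schedule(10, sigma_max=1.0)
print(float(sigmas[0]), float(sigmas[-1]))   # 1.0 ... 0.0, with steps + 1 entries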
@@ -570,8 +594,8 @@ class Sampler:
         return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma

 KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
-                  "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
-                  "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
+                  "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
+                  "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
                   "ipndm", "ipndm_v", "deis"]

 class KSAMPLER(Sampler):
@@ -729,7 +753,7 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
     return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)


-SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "beta"]
+SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "beta", "linear_quadratic"]
 SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]

 def calculate_sigmas(model_sampling, scheduler_name, steps):
@@ -747,6 +771,8 @@ def calculate_sigmas(model_sampling, scheduler_name, steps):
         sigmas = normal_scheduler(model_sampling, steps, sgm=True)
     elif scheduler_name == "beta":
         sigmas = beta_scheduler(model_sampling, steps)
+    elif scheduler_name == "linear_quadratic":
+        sigmas = linear_quadratic_schedule(model_sampling, steps)
     else:
         logging.error("error invalid scheduler {}".format(scheduler_name))
     return sigmas
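Note: once registered in SCHEDULER_NAMES, the schedule is reachable through the ordinary scheduler selection path. A minimal check, assuming this patched tree is importable (the calculate_sigmas call is left commented out because it needs a real model_sampling object):

import comfy.samplers

print("linear_quadratic" in comfy.samplers.SCHEDULER_NAMES)   # True with this change
# sigmas = comfy.samplers.calculate_sigmas(model_sampling, "linear_quadratic", steps=20)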
164  comfy/sd.py
@@ -7,6 +7,7 @@ from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
 from .ldm.cascade.stage_a import StageA
 from .ldm.cascade.stage_c_coder import StageC_coder
 from .ldm.audio.autoencoder import AudioOobleckVAE
+import comfy.ldm.genmo.vae.model
 import yaml

 import comfy.utils
@@ -25,11 +26,11 @@ import comfy.text_encoders.aura_t5
 import comfy.text_encoders.hydit
 import comfy.text_encoders.flux
 import comfy.text_encoders.long_clipl
+import comfy.text_encoders.genmo

 import comfy.model_patcher
 import comfy.lora
 import comfy.t2i_adapter.adapter
-import comfy.supported_models_base
 import comfy.taesd.taesd

 def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
@@ -70,14 +71,14 @@ class CLIP:
         clip = target.clip
         tokenizer = target.tokenizer

-        load_device = model_management.text_encoder_device()
-        offload_device = model_management.text_encoder_offload_device()
+        load_device = model_options.get("load_device", model_management.text_encoder_device())
+        offload_device = model_options.get("offload_device", model_management.text_encoder_offload_device())
         dtype = model_options.get("dtype", None)
         if dtype is None:
             dtype = model_management.text_encoder_dtype(load_device)

         params['dtype'] = dtype
-        params['device'] = model_management.text_encoder_initial_device(load_device, offload_device, parameters * model_management.dtype_size(dtype))
+        params['device'] = model_options.get("initial_device", model_management.text_encoder_initial_device(load_device, offload_device, parameters * model_management.dtype_size(dtype)))
         params['model_options'] = model_options

         self.cond_stage_model = clip(**(params))
@@ -170,6 +171,7 @@ class VAE:
         self.downscale_ratio = 8
         self.upscale_ratio = 8
         self.latent_channels = 4
+        self.latent_dim = 2
         self.output_channels = 3
         self.process_input = lambda image: image * 2.0 - 1.0
         self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
@@ -239,9 +241,22 @@ class VAE:
             self.output_channels = 2
             self.upscale_ratio = 2048
             self.downscale_ratio = 2048
+            self.latent_dim = 1
             self.process_output = lambda audio: audio
             self.process_input = lambda audio: audio
             self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+        elif "blocks.2.blocks.3.stack.5.weight" in sd or "decoder.blocks.2.blocks.3.stack.5.weight" in sd or "layers.4.layers.1.attn_block.attn.qkv.weight" in sd or "encoder.layers.4.layers.1.attn_block.attn.qkv.weight" in sd: #genmo mochi vae
+            if "blocks.2.blocks.3.stack.5.weight" in sd:
+                sd = comfy.utils.state_dict_prefix_replace(sd, {"": "decoder."})
+            if "layers.4.layers.1.attn_block.attn.qkv.weight" in sd:
+                sd = comfy.utils.state_dict_prefix_replace(sd, {"": "encoder."})
+            self.first_stage_model = comfy.ldm.genmo.vae.model.VideoVAE()
+            self.latent_channels = 12
+            self.latent_dim = 3
+            self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
+            self.memory_used_encode = lambda shape, dtype: (1.5 * max(shape[2], 7) * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
+            self.upscale_ratio = (lambda a: max(0, a * 6 - 5), 8, 8)
+            self.working_dtypes = [torch.float16, torch.float32]
         else:
             logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
             self.first_stage_model = None
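Note: the mochi branch normalizes decoder-only or encoder-only checkpoints by prefixing every key before handing the dict to VideoVAE. A simplified stand-in for comfy.utils.state_dict_prefix_replace showing what the {"": "decoder."} call does (key and value are made up):

def state_dict_prefix_replace(sd, replace_prefix):
    # simplified stand-in: rewrite the matching prefix on every key
    out = {}
    for k, v in sd.items():
        for old, new in replace_prefix.items():
            if k.startswith(old):
                k = new + k[len(old):]
                break
        out[k] = v
    return out

sd = {"blocks.2.blocks.3.stack.5.weight": "tensor..."}
print(state_dict_prefix_replace(sd, {"": "decoder."}))
# {'decoder.blocks.2.blocks.3.stack.5.weight': 'tensor...'}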
@@ -297,6 +312,10 @@ class VAE:
         decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
         return comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device)

+    def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
+        decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
+        return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
+
     def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
         steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
         steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
@@ -315,6 +334,7 @@ class VAE:
         return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device)

     def decode(self, samples_in):
+        pixel_samples = None
         try:
             memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
             model_management.load_models_gpu([self.patcher], memory_required=memory_used)
@@ -322,16 +342,21 @@ class VAE:
             batch_number = int(free_memory / memory_used)
             batch_number = max(1, batch_number)

-            pixel_samples = torch.empty((samples_in.shape[0], self.output_channels) + tuple(map(lambda a: a * self.upscale_ratio, samples_in.shape[2:])), device=self.output_device)
             for x in range(0, samples_in.shape[0], batch_number):
                 samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
-                pixel_samples[x:x+batch_number] = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
+                out = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
+                if pixel_samples is None:
+                    pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
+                pixel_samples[x:x+batch_number] = out

         except model_management.OOM_EXCEPTION as e:
             logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
-            if len(samples_in.shape) == 3:
+            dims = samples_in.ndim - 2
+            if dims == 1:
                 pixel_samples = self.decode_tiled_1d(samples_in)
-            else:
+            elif dims == 2:
                 pixel_samples = self.decode_tiled_(samples_in)
+            elif dims == 3:
+                pixel_samples = self.decode_tiled_3d(samples_in)

         pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
         return pixel_samples
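Note: the OOM fallback now keys off the latent dimensionality instead of a shape special case: ndim - 2 spatial/temporal axes gives 1 for audio, 2 for images, 3 for video latents. A small sketch of the dispatch:

import torch

def tiled_fallback_for(samples_in):
    dims = samples_in.ndim - 2           # axes after (batch, channels)
    if dims == 1:
        return "decode_tiled_1d"         # audio latents (B, C, T)
    elif dims == 2:
        return "decode_tiled_"           # image latents (B, C, H, W)
    elif dims == 3:
        return "decode_tiled_3d"         # video latents (B, C, T, H, W), e.g. mochi
    return None

print(tiled_fallback_for(torch.zeros(1, 12, 4, 8, 8)))   # decode_tiled_3d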
@@ -343,17 +368,22 @@ class VAE:

     def encode(self, pixel_samples):
         pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
-        pixel_samples = pixel_samples.movedim(-1,1)
+        pixel_samples = pixel_samples.movedim(-1, 1)
+        if self.latent_dim == 3:
+            pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
         try:
             memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
             model_management.load_models_gpu([self.patcher], memory_required=memory_used)
             free_memory = model_management.get_free_memory(self.device)
-            batch_number = int(free_memory / memory_used)
+            batch_number = int(free_memory / max(1, memory_used))
             batch_number = max(1, batch_number)
-            samples = torch.empty((pixel_samples.shape[0], self.latent_channels) + tuple(map(lambda a: a // self.downscale_ratio, pixel_samples.shape[2:])), device=self.output_device)
+            samples = None
             for x in range(0, pixel_samples.shape[0], batch_number):
-                pixels_in = self.process_input(pixel_samples[x:x+batch_number]).to(self.vae_dtype).to(self.device)
-                samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
+                pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype).to(self.device)
+                out = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
+                if samples is None:
+                    samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
+                samples[x:x + batch_number] = out

         except model_management.OOM_EXCEPTION as e:
             logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
@@ -399,6 +429,7 @@ class CLIPType(Enum):
     STABLE_AUDIO = 4
     HUNYUAN_DIT = 5
     FLUX = 6
+    MOCHI = 7

 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
     clip_data = []
@@ -406,8 +437,46 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
         clip_data.append(comfy.utils.load_torch_file(p, safe_load=True))
     return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)


+class TEModel(Enum):
+    CLIP_L = 1
+    CLIP_H = 2
+    CLIP_G = 3
+    T5_XXL = 4
+    T5_XL = 5
+    T5_BASE = 6
+
+def detect_te_model(sd):
+    if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
+        return TEModel.CLIP_G
+    if "text_model.encoder.layers.22.mlp.fc1.weight" in sd:
+        return TEModel.CLIP_H
+    if "text_model.encoder.layers.0.mlp.fc1.weight" in sd:
+        return TEModel.CLIP_L
+    if "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd:
+        weight = sd["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
+        if weight.shape[-1] == 4096:
+            return TEModel.T5_XXL
+        elif weight.shape[-1] == 2048:
+            return TEModel.T5_XL
+    if "encoder.block.0.layer.0.SelfAttention.k.weight" in sd:
+        return TEModel.T5_BASE
+    return None
+
+
+def t5xxl_detect(clip_data):
+    weight_name = "encoder.block.23.layer.1.DenseReluDense.wi_1.weight"
+
+    for sd in clip_data:
+        if weight_name in sd:
+            return comfy.text_encoders.sd3_clip.t5_xxl_detect(sd)
+
+    return {}
+
+
 def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
     clip_data = state_dicts

     class EmptyClass:
         pass

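Note: detect_te_model tells the text encoders apart purely from state-dict keys, with the two T5 variants split on the wi_1 width (4096 for XXL, 2048 for XL); t5xxl_detect then pulls the stored dtype so fp8 T5 weights keep it. A toy check with fabricated tensors of just the distinguishing shapes:

import torch

sd_t5xxl = {"encoder.block.23.layer.1.DenseReluDense.wi_1.weight": torch.zeros(10240, 4096)}
sd_clip_l = {"text_model.encoder.layers.0.mlp.fc1.weight": torch.zeros(3072, 768)}

w = sd_t5xxl["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
print(w.shape[-1])                                                  # 4096 -> TEModel.T5_XXL
print("text_model.encoder.layers.0.mlp.fc1.weight" in sd_clip_l)    # True -> TEModel.CLIP_L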
@@ -421,64 +490,65 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
     clip_target = EmptyClass()
     clip_target.params = {}
     if len(clip_data) == 1:
-        if "text_model.encoder.layers.30.mlp.fc1.weight" in clip_data[0]:
+        te_model = detect_te_model(clip_data[0])
+        if te_model == TEModel.CLIP_G:
             if clip_type == CLIPType.STABLE_CASCADE:
                 clip_target.clip = sdxl_clip.StableCascadeClipModel
                 clip_target.tokenizer = sdxl_clip.StableCascadeTokenizer
+            elif clip_type == CLIPType.SD3:
+                clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=True, t5=False)
+                clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
             else:
                 clip_target.clip = sdxl_clip.SDXLRefinerClipModel
                 clip_target.tokenizer = sdxl_clip.SDXLTokenizer
-        elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]:
+        elif te_model == TEModel.CLIP_H:
             clip_target.clip = comfy.text_encoders.sd2_clip.SD2ClipModel
             clip_target.tokenizer = comfy.text_encoders.sd2_clip.SD2Tokenizer
-        elif "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in clip_data[0]:
-            weight = clip_data[0]["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
-            dtype_t5 = weight.dtype
-            if weight.shape[-1] == 4096:
-                clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=dtype_t5)
+        elif te_model == TEModel.T5_XXL:
+            if clip_type == CLIPType.SD3:
+                clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, **t5xxl_detect(clip_data))
                 clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
-            elif weight.shape[-1] == 2048:
-                clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model
-                clip_target.tokenizer = comfy.text_encoders.aura_t5.AuraT5Tokenizer
-        elif "encoder.block.0.layer.0.SelfAttention.k.weight" in clip_data[0]:
+            else: #CLIPType.MOCHI
+                clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
+                clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer
+        elif te_model == TEModel.T5_XL:
+            clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model
+            clip_target.tokenizer = comfy.text_encoders.aura_t5.AuraT5Tokenizer
+        elif te_model == TEModel.T5_BASE:
             clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
             clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
         else:
-            w = clip_data[0].get("text_model.embeddings.position_embedding.weight", None)
-            if w is not None and w.shape[0] == 248:
-                clip_target.clip = comfy.text_encoders.long_clipl.LongClipModel
-                clip_target.tokenizer = comfy.text_encoders.long_clipl.LongClipTokenizer
+            if clip_type == CLIPType.SD3:
+                clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=False, t5=False)
+                clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
             else:
                 clip_target.clip = sd1_clip.SD1ClipModel
                 clip_target.tokenizer = sd1_clip.SD1Tokenizer
     elif len(clip_data) == 2:
         if clip_type == CLIPType.SD3:
-            clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=True, t5=False)
+            te_models = [detect_te_model(clip_data[0]), detect_te_model(clip_data[1])]
+            clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=TEModel.CLIP_L in te_models, clip_g=TEModel.CLIP_G in te_models, t5=TEModel.T5_XXL in te_models, **t5xxl_detect(clip_data))
             clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
         elif clip_type == CLIPType.HUNYUAN_DIT:
             clip_target.clip = comfy.text_encoders.hydit.HyditModel
             clip_target.tokenizer = comfy.text_encoders.hydit.HyditTokenizer
         elif clip_type == CLIPType.FLUX:
-            weight_name = "encoder.block.23.layer.1.DenseReluDense.wi_1.weight"
-            weight = clip_data[0].get(weight_name, clip_data[1].get(weight_name, None))
-            dtype_t5 = None
-            if weight is not None:
-                dtype_t5 = weight.dtype
-
-            clip_target.clip = comfy.text_encoders.flux.flux_clip(dtype_t5=dtype_t5)
+            clip_target.clip = comfy.text_encoders.flux.flux_clip(**t5xxl_detect(clip_data))
             clip_target.tokenizer = comfy.text_encoders.flux.FluxTokenizer
         else:
             clip_target.clip = sdxl_clip.SDXLClipModel
             clip_target.tokenizer = sdxl_clip.SDXLTokenizer
     elif len(clip_data) == 3:
-        clip_target.clip = comfy.text_encoders.sd3_clip.SD3ClipModel
+        clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(**t5xxl_detect(clip_data))
         clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer

     parameters = 0
+    tokenizer_data = {}
     for c in clip_data:
         parameters += comfy.utils.calculate_parameters(c)
+        tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)

-    clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, model_options=model_options)
+    clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, model_options=model_options)
     for c in clip_data:
         m, u = clip.load_sd(c)
         if len(m) > 0:
@@ -544,11 +614,11 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
         return None

     unet_weight_dtype = list(model_config.supported_inference_dtypes)
-    if weight_dtype is not None:
+    if weight_dtype is not None and model_config.scaled_fp8 is None:
         unet_weight_dtype.append(weight_dtype)

     model_config.custom_operations = model_options.get("custom_operations", None)
-    unet_dtype = model_options.get("weight_dtype", None)
+    unet_dtype = model_options.get("dtype", model_options.get("weight_dtype", None))

     if unet_dtype is None:
         unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
@@ -562,7 +632,6 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c

     if output_model:
         inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
-        offload_device = model_management.unet_offload_device()
         model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
         model.load_model_weights(sd, diffusion_model_prefix)

@@ -614,6 +683,8 @@ def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffuse
         sd = temp_sd

     parameters = comfy.utils.calculate_parameters(sd)
+    weight_dtype = comfy.utils.weight_dtype(sd)
+
     load_device = model_management.get_torch_device()
     model_config = model_detection.model_config_from_unet(sd, "")
@@ -640,14 +711,21 @@ def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffuse
             logging.warning("{} {}".format(diffusers_keys[k], k))

     offload_device = model_management.unet_offload_device()
+    unet_weight_dtype = list(model_config.supported_inference_dtypes)
+    if weight_dtype is not None and model_config.scaled_fp8 is None:
+        unet_weight_dtype.append(weight_dtype)
+
     if dtype is None:
-        unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
+        unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
     else:
         unet_dtype = dtype

     manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
     model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
-    model_config.custom_operations = model_options.get("custom_operations", None)
+    model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations)
+    if model_options.get("fp8_optimizations", False):
+        model_config.optimizations["fp8"] = True
+
     model = model_config.get_model(new_sd, "")
     model = model.to(offload_device)
     model.load_model_weights(new_sd, "")

comfy/sd1_clip.py
|
|||||||
"pooled",
|
"pooled",
|
||||||
"hidden"
|
"hidden"
|
||||||
]
|
]
|
||||||
def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
|
def __init__(self, device="cpu", max_length=77,
|
||||||
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,
|
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,
|
||||||
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
|
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
|
||||||
return_projected_pooled=True, return_attention_masks=False, model_options={}): # clip-vit-base-patch32
|
return_projected_pooled=True, return_attention_masks=False, model_options={}): # clip-vit-base-patch32
|
||||||
@@ -94,11 +94,20 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||||||
config = json.load(f)
|
config = json.load(f)
|
||||||
|
|
||||||
operations = model_options.get("custom_operations", None)
|
operations = model_options.get("custom_operations", None)
|
||||||
|
scaled_fp8 = None
|
||||||
|
|
||||||
if operations is None:
|
if operations is None:
|
||||||
operations = comfy.ops.manual_cast
|
scaled_fp8 = model_options.get("scaled_fp8", None)
|
||||||
|
if scaled_fp8 is not None:
|
||||||
|
operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8)
|
||||||
|
else:
|
||||||
|
operations = comfy.ops.manual_cast
|
||||||
|
|
||||||
self.operations = operations
|
self.operations = operations
|
||||||
self.transformer = model_class(config, dtype, device, self.operations)
|
self.transformer = model_class(config, dtype, device, self.operations)
|
||||||
|
if scaled_fp8 is not None:
|
||||||
|
self.transformer.scaled_fp8 = torch.nn.Parameter(torch.tensor([], dtype=scaled_fp8))
|
||||||
|
|
||||||
self.num_layers = self.transformer.num_layers
|
self.num_layers = self.transformer.num_layers
|
||||||
|
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
@@ -542,6 +551,7 @@ class SD1Tokenizer:
     def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer):
         self.clip_name = clip_name
         self.clip = "clip_{}".format(self.clip_name)
+        tokenizer = tokenizer_data.get("{}_tokenizer_class".format(self.clip), tokenizer)
         setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data))

     def tokenize_with_weights(self, text:str, return_word_ids=False):
@@ -570,6 +580,7 @@ class SD1ClipModel(torch.nn.Module):
         self.clip_name = clip_name
         self.clip = "clip_{}".format(self.clip_name)

+        clip_model = model_options.get("{}_class".format(self.clip), clip_model)
         setattr(self, self.clip, clip_model(device=device, dtype=dtype, model_options=model_options, **kwargs))

         self.dtypes = set()

comfy/sdxl_clip.py
@@ -22,7 +22,8 @@ class SDXLClipGTokenizer(sd1_clip.SDTokenizer):

 class SDXLTokenizer:
     def __init__(self, embedding_directory=None, tokenizer_data={}):
-        self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
+        clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
+        self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
         self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory)

     def tokenize_with_weights(self, text:str, return_word_ids=False):
@@ -40,7 +41,8 @@ class SDXLTokenizer:
 class SDXLClipModel(torch.nn.Module):
     def __init__(self, device="cpu", dtype=None, model_options={}):
         super().__init__()
-        self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options)
+        clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
+        self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options)
         self.clip_g = SDXLClipG(device=device, dtype=dtype, model_options=model_options)
         self.dtypes = set([dtype])
@@ -57,7 +59,8 @@ class SDXLClipModel(torch.nn.Module):
|
|||||||
token_weight_pairs_l = token_weight_pairs["l"]
|
token_weight_pairs_l = token_weight_pairs["l"]
|
||||||
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
|
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
|
||||||
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
|
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
|
||||||
return torch.cat([l_out, g_out], dim=-1), g_pooled
|
cut_to = min(l_out.shape[1], g_out.shape[1])
|
||||||
|
return torch.cat([l_out[:,:cut_to], g_out[:,:cut_to]], dim=-1), g_pooled
|
||||||
|
|
||||||
def load_sd(self, sd):
|
def load_sd(self, sd):
|
||||||
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
|
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import comfy.text_encoders.sa_t5
 import comfy.text_encoders.aura_t5
 import comfy.text_encoders.hydit
 import comfy.text_encoders.flux
+import comfy.text_encoders.genmo

 from . import supported_models_base
 from . import latent_formats
@@ -529,12 +530,11 @@ class SD3(supported_models_base.BASE):
             clip_l = True
         if "{}clip_g.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict:
             clip_g = True
-        t5_key = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref)
-        if t5_key in state_dict:
+        t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
+        if "dtype_t5" in t5_detect:
             t5 = True
-            dtype_t5 = state_dict[t5_key].dtype

-        return supported_models_base.ClipTarget(comfy.text_encoders.sd3_clip.SD3Tokenizer, comfy.text_encoders.sd3_clip.sd3_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5))
+        return supported_models_base.ClipTarget(comfy.text_encoders.sd3_clip.SD3Tokenizer, comfy.text_encoders.sd3_clip.sd3_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, **t5_detect))

 class StableAudio(supported_models_base.BASE):
     unet_config = {
@@ -653,11 +653,8 @@ class Flux(supported_models_base.BASE):

     def clip_target(self, state_dict={}):
         pref = self.text_encoder_key_prefix[0]
-        t5_key = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref)
-        dtype_t5 = None
-        if t5_key in state_dict:
-            dtype_t5 = state_dict[t5_key].dtype
-        return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(dtype_t5=dtype_t5))
+        t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
+        return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(**t5_detect))

 class FluxSchnell(Flux):
     unet_config = {
@@ -674,7 +671,36 @@ class FluxSchnell(Flux):
         out = model_base.Flux(self, model_type=model_base.ModelType.FLOW, device=device)
         return out

-models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, Flux, FluxSchnell]
+class GenmoMochi(supported_models_base.BASE):
+    unet_config = {
+        "image_model": "mochi_preview",
+    }
+
+    sampling_settings = {
+        "multiplier": 1.0,
+        "shift": 6.0,
+    }
+
+    unet_extra_config = {}
+    latent_format = latent_formats.Mochi
+
+    memory_usage_factor = 2.0 #TODO
+
+    supported_inference_dtypes = [torch.bfloat16, torch.float32]
+
+    vae_key_prefix = ["vae."]
+    text_encoder_key_prefix = ["text_encoders."]
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.GenmoMochi(self, device=device)
+        return out
+
+    def clip_target(self, state_dict={}):
+        pref = self.text_encoder_key_prefix[0]
+        t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
+        return supported_models_base.ClipTarget(comfy.text_encoders.genmo.MochiT5Tokenizer, comfy.text_encoders.genmo.mochi_te(**t5_detect))
+
+
+models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, Flux, FluxSchnell, GenmoMochi]

 models += [SVD_img2vid]
@@ -49,6 +49,8 @@ class BASE:

     manual_cast_dtype = None
     custom_operations = None
+    scaled_fp8 = None
+    optimizations = {"fp8": False}

     @classmethod
     def matches(s, unet_config, state_dict=None):
@@ -71,6 +73,7 @@ class BASE:
         self.unet_config = unet_config.copy()
         self.sampling_settings = self.sampling_settings.copy()
         self.latent_format = self.latent_format()
+        self.optimizations = self.optimizations.copy()
         for x in self.unet_extra_config:
             self.unet_config[x] = self.unet_extra_config[x]
@@ -1,24 +1,21 @@
 from comfy import sd1_clip
 import comfy.text_encoders.t5
+import comfy.text_encoders.sd3_clip
 import comfy.model_management
 from transformers import T5TokenizerFast
 import torch
 import os

-class T5XXLModel(sd1_clip.SDClipModel):
-    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
-        textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
-        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, model_options=model_options)

 class T5XXLTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
-        super().__init__(tokenizer_path, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
+        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)


 class FluxTokenizer:
     def __init__(self, embedding_directory=None, tokenizer_data={}):
-        self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
+        clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
+        self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
         self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)

     def tokenize_with_weights(self, text:str, return_word_ids=False):
@@ -38,8 +35,9 @@ class FluxClipModel(torch.nn.Module):
     def __init__(self, dtype_t5=None, device="cpu", dtype=None, model_options={}):
         super().__init__()
         dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
-        self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
-        self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
+        clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
+        self.clip_l = clip_l_class(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
+        self.t5xxl = comfy.text_encoders.sd3_clip.T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
         self.dtypes = set([dtype, dtype_t5])

     def set_clip_options(self, options):
@@ -64,8 +62,11 @@ class FluxClipModel(torch.nn.Module):
         else:
             return self.t5xxl.load_sd(sd)

-def flux_clip(dtype_t5=None):
+def flux_clip(dtype_t5=None, t5xxl_scaled_fp8=None):
     class FluxClipModel_(FluxClipModel):
         def __init__(self, device="cpu", dtype=None, model_options={}):
+            if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
+                model_options = model_options.copy()
+                model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
             super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
     return FluxClipModel_
comfy/text_encoders/genmo.py (new file, 38 lines)
@@ -0,0 +1,38 @@
+from comfy import sd1_clip
+import comfy.text_encoders.sd3_clip
+import os
+from transformers import T5TokenizerFast
+
+
+class T5XXLModel(comfy.text_encoders.sd3_clip.T5XXLModel):
+    def __init__(self, **kwargs):
+        kwargs["attention_mask"] = True
+        super().__init__(**kwargs)
+
+
+class MochiT5XXL(sd1_clip.SD1ClipModel):
+    def __init__(self, device="cpu", dtype=None, model_options={}):
+        super().__init__(device=device, dtype=dtype, name="t5xxl", clip_model=T5XXLModel, model_options=model_options)
+
+
+class T5XXLTokenizer(sd1_clip.SDTokenizer):
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
+        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
+
+
+class MochiT5Tokenizer(sd1_clip.SD1Tokenizer):
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
+
+
+def mochi_te(dtype_t5=None, t5xxl_scaled_fp8=None):
+    class MochiTEModel_(MochiT5XXL):
+        def __init__(self, device="cpu", dtype=None, model_options={}):
+            if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
+                model_options = model_options.copy()
+                model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
+            if dtype is None:
+                dtype = dtype_t5
+            super().__init__(device=device, dtype=dtype, model_options=model_options)
+    return MochiTEModel_
@@ -6,9 +6,9 @@ class LongClipTokenizer_(sd1_clip.SDTokenizer):
         super().__init__(max_length=248, embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)

 class LongClipModel_(sd1_clip.SDClipModel):
-    def __init__(self, device="cpu", dtype=None, model_options={}):
+    def __init__(self, *args, **kwargs):
         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "long_clipl.json")
-        super().__init__(device=device, textmodel_json_config=textmodel_json_config, return_projected_pooled=False, dtype=dtype, model_options=model_options)
+        super().__init__(*args, textmodel_json_config=textmodel_json_config, **kwargs)

 class LongClipTokenizer(sd1_clip.SD1Tokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
@@ -17,3 +17,14 @@ class LongClipTokenizer(sd1_clip.SD1Tokenizer):
 class LongClipModel(sd1_clip.SD1ClipModel):
     def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
         super().__init__(device=device, dtype=dtype, model_options=model_options, clip_model=LongClipModel_, **kwargs)
+
+def model_options_long_clip(sd, tokenizer_data, model_options):
+    w = sd.get("clip_l.text_model.embeddings.position_embedding.weight", None)
+    if w is None:
+        w = sd.get("text_model.embeddings.position_embedding.weight", None)
+    if w is not None and w.shape[0] == 248:
+        tokenizer_data = tokenizer_data.copy()
+        model_options = model_options.copy()
+        tokenizer_data["clip_l_tokenizer_class"] = LongClipTokenizer_
+        model_options["clip_l_class"] = LongClipModel_
+    return tokenizer_data, model_options
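Note on usage (illustrative only, the caller is an assumption and not part of this diff): model_options_long_clip is meant to be invoked by whatever code loads the clip_l state dict, and its return values feed the new "clip_l_tokenizer_class" / "clip_l_class" hooks read by the SD1, SDXL, SD3 and Flux tokenizers and clip models shown above. A minimal sketch, assuming sd is the loaded text encoder state dict:

    tokenizer_data, model_options = model_options_long_clip(sd, tokenizer_data, model_options)
    # when a 248-token position embedding is detected, clip_l is transparently swapped for the LongClip variants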
@@ -8,9 +8,27 @@ import comfy.model_management
 import logging

 class T5XXLModel(sd1_clip.SDClipModel):
-    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
+    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=False, model_options={}):
         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
-        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, model_options=model_options)
+        t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None)
+        if t5xxl_scaled_fp8 is not None:
+            model_options = model_options.copy()
+            model_options["scaled_fp8"] = t5xxl_scaled_fp8
+
+        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
+
+
+def t5_xxl_detect(state_dict, prefix=""):
+    out = {}
+    t5_key = "{}encoder.final_layer_norm.weight".format(prefix)
+    if t5_key in state_dict:
+        out["dtype_t5"] = state_dict[t5_key].dtype
+
+    scaled_fp8_key = "{}scaled_fp8".format(prefix)
+    if scaled_fp8_key in state_dict:
+        out["t5xxl_scaled_fp8"] = state_dict[scaled_fp8_key].dtype
+
+    return out

 class T5XXLTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
@@ -20,7 +38,8 @@ class T5XXLTokenizer(sd1_clip.SDTokenizer):

 class SD3Tokenizer:
     def __init__(self, embedding_directory=None, tokenizer_data={}):
-        self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
+        clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
+        self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
         self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory)
         self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)

@@ -38,11 +57,12 @@ class SD3Tokenizer:
         return {}

 class SD3ClipModel(torch.nn.Module):
-    def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu", dtype=None, model_options={}):
+    def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5_attention_mask=False, device="cpu", dtype=None, model_options={}):
         super().__init__()
         self.dtypes = set()
         if clip_l:
-            self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options)
+            clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
+            self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options)
             self.dtypes.add(dtype)
         else:
             self.clip_l = None
@@ -55,7 +75,8 @@ class SD3ClipModel(torch.nn.Module):

         if t5:
             dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
-            self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
+            self.t5_attention_mask = t5_attention_mask
+            self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options, attention_mask=self.t5_attention_mask)
             self.dtypes.add(dtype_t5)
         else:
             self.t5xxl = None
@@ -85,6 +106,7 @@ class SD3ClipModel(torch.nn.Module):
         lg_out = None
         pooled = None
         out = None
+        extra = {}

         if len(token_weight_pairs_g) > 0 or len(token_weight_pairs_l) > 0:
             if self.clip_l is not None:
@@ -95,7 +117,8 @@ class SD3ClipModel(torch.nn.Module):
             if self.clip_g is not None:
                 g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
                 if lg_out is not None:
-                    lg_out = torch.cat([lg_out, g_out], dim=-1)
+                    cut_to = min(lg_out.shape[1], g_out.shape[1])
+                    lg_out = torch.cat([lg_out[:,:cut_to], g_out[:,:cut_to]], dim=-1)
                 else:
                     lg_out = torch.nn.functional.pad(g_out, (768, 0))
             else:
@@ -108,7 +131,11 @@ class SD3ClipModel(torch.nn.Module):
             pooled = torch.cat((l_pooled, g_pooled), dim=-1)

         if self.t5xxl is not None:
-            t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pairs_t5)
+            t5_output = self.t5xxl.encode_token_weights(token_weight_pairs_t5)
+            t5_out, t5_pooled = t5_output[:2]
+            if self.t5_attention_mask:
+                extra["attention_mask"] = t5_output[2]["attention_mask"]
+
             if lg_out is not None:
                 out = torch.cat([lg_out, t5_out], dim=-2)
             else:
@@ -120,7 +147,7 @@ class SD3ClipModel(torch.nn.Module):
         if pooled is None:
             pooled = torch.zeros((1, 768 + 1280), device=comfy.model_management.intermediate_device())

-        return out, pooled
+        return out, pooled, extra

     def load_sd(self, sd):
         if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
@@ -130,8 +157,11 @@ class SD3ClipModel(torch.nn.Module):
         else:
             return self.t5xxl.load_sd(sd)

-def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None):
+def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5xxl_scaled_fp8=None, t5_attention_mask=False):
     class SD3ClipModel_(SD3ClipModel):
         def __init__(self, device="cpu", dtype=None, model_options={}):
-            super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
+            if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
+                model_options = model_options.copy()
+                model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
+            super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, t5_attention_mask=t5_attention_mask, device=device, dtype=dtype, model_options=model_options)
     return SD3ClipModel_
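For orientation, a minimal sketch of how the new t5_xxl_detect helper feeds the factory functions; the prefix string below is an example value, while the real callers in supported_models.py pass "{}t5xxl.transformer.".format(pref):

    t5_detect = t5_xxl_detect(state_dict, "text_encoders.t5xxl.transformer.")
    # possible keys: "dtype_t5" (dtype of the T5 weights) and "t5xxl_scaled_fp8" (dtype of the scaled_fp8 marker tensor)
    clip_class = sd3_clip(clip_l=True, clip_g=True, t5=True, **t5_detect)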
comfy/utils.py (106 changed lines)
@@ -68,7 +68,7 @@ def weight_dtype(sd, prefix=""):
     for k in sd.keys():
         if k.startswith(prefix):
             w = sd[k]
-            dtypes[w.dtype] = dtypes.get(w.dtype, 0) + 1
+            dtypes[w.dtype] = dtypes.get(w.dtype, 0) + w.numel()

     if len(dtypes) == 0:
         return None
@@ -316,10 +316,18 @@ MMDIT_MAP_BLOCK = {
     ("context_block.mlp.fc1.weight", "ff_context.net.0.proj.weight"),
     ("context_block.mlp.fc2.bias", "ff_context.net.2.bias"),
     ("context_block.mlp.fc2.weight", "ff_context.net.2.weight"),
+    ("context_block.attn.ln_q.weight", "attn.norm_added_q.weight"),
+    ("context_block.attn.ln_k.weight", "attn.norm_added_k.weight"),
     ("x_block.adaLN_modulation.1.bias", "norm1.linear.bias"),
     ("x_block.adaLN_modulation.1.weight", "norm1.linear.weight"),
     ("x_block.attn.proj.bias", "attn.to_out.0.bias"),
     ("x_block.attn.proj.weight", "attn.to_out.0.weight"),
+    ("x_block.attn.ln_q.weight", "attn.norm_q.weight"),
+    ("x_block.attn.ln_k.weight", "attn.norm_k.weight"),
+    ("x_block.attn2.proj.bias", "attn2.to_out.0.bias"),
+    ("x_block.attn2.proj.weight", "attn2.to_out.0.weight"),
+    ("x_block.attn2.ln_q.weight", "attn2.norm_q.weight"),
+    ("x_block.attn2.ln_k.weight", "attn2.norm_k.weight"),
     ("x_block.mlp.fc1.bias", "ff.net.0.proj.bias"),
     ("x_block.mlp.fc1.weight", "ff.net.0.proj.weight"),
     ("x_block.mlp.fc2.bias", "ff.net.2.bias"),
@@ -349,6 +357,12 @@ def mmdit_to_diffusers(mmdit_config, output_prefix=""):
         key_map["{}add_k_proj.{}".format(k, end)] = (qkv, (0, offset, offset))
         key_map["{}add_v_proj.{}".format(k, end)] = (qkv, (0, offset * 2, offset))

+        k = "{}.attn2.".format(block_from)
+        qkv = "{}.x_block.attn2.qkv.{}".format(block_to, end)
+        key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, offset))
+        key_map["{}to_k.{}".format(k, end)] = (qkv, (0, offset, offset))
+        key_map["{}to_v.{}".format(k, end)] = (qkv, (0, offset * 2, offset))
+
         for k in MMDIT_MAP_BLOCK:
             key_map["{}.{}".format(block_from, k[1])] = "{}.{}".format(block_to, k[0])

@@ -528,6 +542,8 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
         ("guidance_in.out_layer.weight", "time_text_embed.guidance_embedder.linear_2.weight"),
         ("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift),
         ("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift),
+        ("pos_embed_input.bias", "controlnet_x_embedder.bias"),
+        ("pos_embed_input.weight", "controlnet_x_embedder.weight"),
     }

     for k in MAP_BASIC:
@@ -688,9 +704,14 @@ def lanczos(samples, width, height):
     return result.to(samples.device, samples.dtype)

 def common_upscale(samples, width, height, upscale_method, crop):
+    orig_shape = tuple(samples.shape)
+    if len(orig_shape) > 4:
+        samples = samples.reshape(samples.shape[0], samples.shape[1], -1, samples.shape[-2], samples.shape[-1])
+        samples = samples.movedim(2, 1)
+        samples = samples.reshape(-1, orig_shape[1], orig_shape[-2], orig_shape[-1])
     if crop == "center":
-        old_width = samples.shape[3]
-        old_height = samples.shape[2]
+        old_width = samples.shape[-1]
+        old_height = samples.shape[-2]
         old_aspect = old_width / old_height
         new_aspect = width / height
         x = 0
@@ -699,48 +720,87 @@ def common_upscale(samples, width, height, upscale_method, crop):
             x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
         elif old_aspect < new_aspect:
             y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
-        s = samples[:,:,y:old_height-y,x:old_width-x]
+        s = samples.narrow(-2, y, old_height - y * 2).narrow(-1, x, old_width - x * 2)
     else:
         s = samples

     if upscale_method == "bislerp":
-        return bislerp(s, width, height)
+        out = bislerp(s, width, height)
     elif upscale_method == "lanczos":
-        return lanczos(s, width, height)
+        out = lanczos(s, width, height)
     else:
-        return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
+        out = torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
+
+    if len(orig_shape) == 4:
+        return out
+
+    out = out.reshape((orig_shape[0], -1, orig_shape[1]) + (height, width))
+    return out.movedim(2, 1).reshape(orig_shape[:-2] + (height, width))

 def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
-    return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap)))
+    rows = 1 if height <= tile_y else math.ceil((height - overlap) / (tile_y - overlap))
+    cols = 1 if width <= tile_x else math.ceil((width - overlap) / (tile_x - overlap))
+    return rows * cols

 @torch.inference_mode()
 def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
     dims = len(tile)
-    output = torch.empty([samples.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), samples.shape[2:])), device=output_device)
+
+    if not (isinstance(upscale_amount, (tuple, list))):
+        upscale_amount = [upscale_amount] * dims
+
+    if not (isinstance(overlap, (tuple, list))):
+        overlap = [overlap] * dims
+
+    def get_upscale(dim, val):
+        up = upscale_amount[dim]
+        if callable(up):
+            return up(val)
+        else:
+            return up * val
+
+    def mult_list_upscale(a):
+        out = []
+        for i in range(len(a)):
+            out.append(round(get_upscale(i, a[i])))
+        return out
+
+    output = torch.empty([samples.shape[0], out_channels] + mult_list_upscale(samples.shape[2:]), device=output_device)

     for b in range(samples.shape[0]):
         s = samples[b:b+1]
-        out = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
-        out_div = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)

-        for it in itertools.product(*map(lambda a: range(0, a[0], a[1] - overlap), zip(s.shape[2:], tile))):
+        # handle entire input fitting in a single tile
+        if all(s.shape[d+2] <= tile[d] for d in range(dims)):
+            output[b:b+1] = function(s).to(output_device)
+            if pbar is not None:
+                pbar.update(1)
+            continue
+
+        out = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
+        out_div = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
+
+        positions = [range(0, s.shape[d+2], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]
+
+        for it in itertools.product(*positions):
             s_in = s
             upscaled = []

             for d in range(dims):
-                pos = max(0, min(s.shape[d + 2] - overlap, it[d]))
+                pos = max(0, min(s.shape[d + 2] - (overlap[d] + 1), it[d]))
                 l = min(tile[d], s.shape[d + 2] - pos)
                 s_in = s_in.narrow(d + 2, pos, l)
-                upscaled.append(round(pos * upscale_amount))
+                upscaled.append(round(get_upscale(d, pos)))

             ps = function(s_in).to(output_device)
             mask = torch.ones_like(ps)
-            feather = round(overlap * upscale_amount)
-            for t in range(feather):
-                for d in range(2, dims + 2):
-                    m = mask.narrow(d, t, 1)
-                    m *= ((1.0/feather) * (t + 1))
-                    m = mask.narrow(d, mask.shape[d] -1 -t, 1)
-                    m *= ((1.0/feather) * (t + 1))
+            for d in range(2, dims + 2):
+                feather = round(get_upscale(d - 2, overlap[d - 2]))
+                for t in range(feather):
+                    a = (t + 1) / feather
+                    mask.narrow(d, t, 1).mul_(a)
+                    mask.narrow(d, mask.shape[d] - 1 - t, 1).mul_(a)

             o = out
             o_d = out_div
@@ -748,8 +808,8 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_
                 o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2])
                 o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2])

-            o += ps * mask
-            o_d += mask
+            o.add_(ps * mask)
+            o_d.add_(mask)

             if pbar is not None:
                 pbar.update(1)
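A minimal sketch of what the common_upscale change enables (not part of the patch itself): 5D video latents of shape [batch, channels, frames, height, width] can now be passed directly and are flattened to 4D and restored internally, while 4D inputs behave exactly as before.

    import torch
    import comfy.utils

    video = torch.zeros([1, 4, 16, 60, 104])  # [b, c, t, h, w]
    up = comfy.utils.common_upscale(video, 208, 120, "bilinear", "center")
    assert up.shape == (1, 4, 16, 120, 208)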
@@ -1,11 +1,21 @@
 import itertools
-from typing import Sequence, Mapping
+from typing import Sequence, Mapping, Dict
 from comfy_execution.graph import DynamicPrompt

 import nodes

 from comfy_execution.graph_utils import is_link

+NODE_CLASS_CONTAINS_UNIQUE_ID: Dict[str, bool] = {}
+
+
+def include_unique_id_in_input(class_type: str) -> bool:
+    if class_type in NODE_CLASS_CONTAINS_UNIQUE_ID:
+        return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
+    class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
+    NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] = "UNIQUE_ID" in class_def.INPUT_TYPES().get("hidden", {}).values()
+    return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
+
 class CacheKeySet:
     def __init__(self, dynprompt, node_ids, is_changed_cache):
         self.keys = {}
@@ -56,6 +66,8 @@ class CacheKeySetID(CacheKeySet):
         for node_id in node_ids:
             if node_id in self.keys:
                 continue
+            if not self.dynprompt.has_node(node_id):
+                continue
             node = self.dynprompt.get_node(node_id)
             self.keys[node_id] = (node_id, node["class_type"])
             self.subcache_keys[node_id] = (node_id, node["class_type"])
@@ -74,6 +86,8 @@ class CacheKeySetInputSignature(CacheKeySet):
         for node_id in node_ids:
             if node_id in self.keys:
                 continue
+            if not self.dynprompt.has_node(node_id):
+                continue
             node = self.dynprompt.get_node(node_id)
             self.keys[node_id] = self.get_node_signature(self.dynprompt, node_id)
             self.subcache_keys[node_id] = (node_id, node["class_type"])
@@ -87,11 +101,14 @@ class CacheKeySetInputSignature(CacheKeySet):
         return to_hashable(signature)

     def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
+        if not dynprompt.has_node(node_id):
+            # This node doesn't exist -- we can't cache it.
+            return [float("NaN")]
         node = dynprompt.get_node(node_id)
         class_type = node["class_type"]
         class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
         signature = [class_type, self.is_changed_cache.get(node_id)]
-        if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT):
+        if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
             signature.append(node_id)
         inputs = node["inputs"]
         for key in sorted(inputs.keys()):
@@ -112,6 +129,8 @@ class CacheKeySetInputSignature(CacheKeySet):
         return ancestors, order_mapping

     def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_mapping):
+        if not dynprompt.has_node(node_id):
+            return
         inputs = dynprompt.get_node(node_id)["inputs"]
         input_keys = sorted(inputs.keys())
         for key in input_keys:
@@ -99,30 +99,44 @@ class TopologicalSort:
         self.add_strong_link(from_node_id, from_socket, to_node_id)

     def add_strong_link(self, from_node_id, from_socket, to_node_id):
-        self.add_node(from_node_id)
-        if to_node_id not in self.blocking[from_node_id]:
-            self.blocking[from_node_id][to_node_id] = {}
-            self.blockCount[to_node_id] += 1
-        self.blocking[from_node_id][to_node_id][from_socket] = True
+        if not self.is_cached(from_node_id):
+            self.add_node(from_node_id)
+            if to_node_id not in self.blocking[from_node_id]:
+                self.blocking[from_node_id][to_node_id] = {}
+                self.blockCount[to_node_id] += 1
+            self.blocking[from_node_id][to_node_id][from_socket] = True

-    def add_node(self, unique_id, include_lazy=False, subgraph_nodes=None):
-        if unique_id in self.pendingNodes:
-            return
-        self.pendingNodes[unique_id] = True
-        self.blockCount[unique_id] = 0
-        self.blocking[unique_id] = {}
+    def add_node(self, node_unique_id, include_lazy=False, subgraph_nodes=None):
+        node_ids = [node_unique_id]
+        links = []

-        inputs = self.dynprompt.get_node(unique_id)["inputs"]
-        for input_name in inputs:
-            value = inputs[input_name]
-            if is_link(value):
-                from_node_id, from_socket = value
-                if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
-                    continue
-                input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
-                is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
-                if include_lazy or not is_lazy:
-                    self.add_strong_link(from_node_id, from_socket, unique_id)
+        while len(node_ids) > 0:
+            unique_id = node_ids.pop()
+            if unique_id in self.pendingNodes:
+                continue
+
+            self.pendingNodes[unique_id] = True
+            self.blockCount[unique_id] = 0
+            self.blocking[unique_id] = {}
+
+            inputs = self.dynprompt.get_node(unique_id)["inputs"]
+            for input_name in inputs:
+                value = inputs[input_name]
+                if is_link(value):
+                    from_node_id, from_socket = value
+                    if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
+                        continue
+                    input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
+                    is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
+                    if (include_lazy or not is_lazy) and not self.is_cached(from_node_id):
+                        node_ids.append(from_node_id)
+                    links.append((from_node_id, from_socket, unique_id))
+
+        for link in links:
+            self.add_strong_link(*link)
+
+    def is_cached(self, node_id):
+        return False

     def get_ready_nodes(self):
         return [node_id for node_id in self.pendingNodes if self.blockCount[node_id] == 0]
@@ -146,11 +160,8 @@ class ExecutionList(TopologicalSort):
         self.output_cache = output_cache
         self.staged_node_id = None

-    def add_strong_link(self, from_node_id, from_socket, to_node_id):
-        if self.output_cache.get(from_node_id) is not None:
-            # Nothing to do
-            return
-        super().add_strong_link(from_node_id, from_socket, to_node_id)
+    def is_cached(self, node_id):
+        return self.output_cache.get(node_id) is not None

     def stage_node_execution(self):
         assert self.staged_node_id is None
@@ -16,14 +16,15 @@ class EmptyLatentAudio:

     @classmethod
     def INPUT_TYPES(s):
-        return {"required": {"seconds": ("FLOAT", {"default": 47.6, "min": 1.0, "max": 1000.0, "step": 0.1})}}
+        return {"required": {"seconds": ("FLOAT", {"default": 47.6, "min": 1.0, "max": 1000.0, "step": 0.1}),
+                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
+                             }}
     RETURN_TYPES = ("LATENT",)
     FUNCTION = "generate"

     CATEGORY = "latent/audio"

-    def generate(self, seconds):
-        batch_size = 1
+    def generate(self, seconds, batch_size):
         length = round((seconds * 44100 / 2048) / 2) * 2
         latent = torch.zeros([batch_size, 64, length], device=self.device)
         return ({"samples":latent, "type": "audio"}, )
@@ -58,6 +59,9 @@ class VAEDecodeAudio:

     def decode(self, vae, samples):
         audio = vae.decode(samples["samples"]).movedim(-1, 1)
+        std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0
+        std[std < 1.0] = 1.0
+        audio /= std
         return ({"waveform": audio, "sample_rate": 44100}, )


@@ -183,17 +187,10 @@ class PreviewAudio(SaveAudio):
     }

 class LoadAudio:
-    SUPPORTED_FORMATS = ('.wav', '.mp3', '.ogg', '.flac', '.aiff', '.aif')
-
     @classmethod
     def INPUT_TYPES(s):
         input_dir = folder_paths.get_input_directory()
-        files = [
-            f for f in os.listdir(input_dir)
-            if (os.path.isfile(os.path.join(input_dir, f))
-                and f.endswith(LoadAudio.SUPPORTED_FORMATS)
-            )
-        ]
+        files = folder_paths.filter_files_content_types(os.listdir(input_dir), ["audio", "video"])
         return {"required": {"audio": (sorted(files), {"audio_upload": True})}}

     CATEGORY = "audio"
@@ -1,4 +1,6 @@
 from comfy.cldm.control_types import UNION_CONTROLNET_TYPES
+import nodes
+import comfy.utils

 class SetUnionControlNetType:
     @classmethod
@@ -22,6 +24,37 @@ class SetUnionControlNetType:

         return (control_net,)

+class ControlNetInpaintingAliMamaApply(nodes.ControlNetApplyAdvanced):
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"positive": ("CONDITIONING", ),
+                             "negative": ("CONDITIONING", ),
+                             "control_net": ("CONTROL_NET", ),
+                             "vae": ("VAE", ),
+                             "image": ("IMAGE", ),
+                             "mask": ("MASK", ),
+                             "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+                             "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
+                             "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
+                             }}
+
+    FUNCTION = "apply_inpaint_controlnet"
+
+    CATEGORY = "conditioning/controlnet"
+
+    def apply_inpaint_controlnet(self, positive, negative, control_net, vae, image, mask, strength, start_percent, end_percent):
+        extra_concat = []
+        if control_net.concat_mask:
+            mask = 1.0 - mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
+            mask_apply = comfy.utils.common_upscale(mask, image.shape[2], image.shape[1], "bilinear", "center").round()
+            image = image * mask_apply.movedim(1, -1).repeat(1, 1, 1, image.shape[3])
+            extra_concat = [mask]
+
+        return self.apply_controlnet(positive, negative, control_net, image, strength, start_percent, end_percent, vae=vae, extra_concat=extra_concat)
+
+
+
 NODE_CLASS_MAPPINGS = {
     "SetUnionControlNetType": SetUnionControlNetType,
+    "ControlNetInpaintingAliMamaApply": ControlNetInpaintingAliMamaApply,
 }
@@ -90,6 +90,27 @@ class PolyexponentialScheduler:
         sigmas = k_diffusion_sampling.get_sigmas_polyexponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho)
         return (sigmas, )

+class LaplaceScheduler:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required":
+                    {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
+                     "sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
+                     "sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
+                     "mu": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step":0.1, "round": False}),
+                     "beta": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step":0.1, "round": False}),
+                    }
+               }
+    RETURN_TYPES = ("SIGMAS",)
+    CATEGORY = "sampling/custom_sampling/schedulers"
+
+    FUNCTION = "get_sigmas"
+
+    def get_sigmas(self, steps, sigma_max, sigma_min, mu, beta):
+        sigmas = k_diffusion_sampling.get_sigmas_laplace(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, mu=mu, beta=beta)
+        return (sigmas, )
+
+
 class SDTurboScheduler:
     @classmethod
     def INPUT_TYPES(s):
@@ -673,6 +694,7 @@ NODE_CLASS_MAPPINGS = {
     "KarrasScheduler": KarrasScheduler,
     "ExponentialScheduler": ExponentialScheduler,
     "PolyexponentialScheduler": PolyexponentialScheduler,
+    "LaplaceScheduler": LaplaceScheduler,
     "VPScheduler": VPScheduler,
     "BetaSamplingScheduler": BetaSamplingScheduler,
     "SDTurboScheduler": SDTurboScheduler,
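Usage sketch for the new scheduler node (illustrative only; k_diffusion_sampling is the module alias already used by the other schedulers in this file):

    sigmas = k_diffusion_sampling.get_sigmas_laplace(n=20, sigma_min=0.0291675, sigma_max=14.614642, mu=0.0, beta=0.5)
    # the SIGMAS output of LaplaceScheduler is exactly this tensor and plugs into SamplerCustom like any other *Scheduler node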
@@ -107,7 +107,7 @@ class HypernetworkLoader:
     CATEGORY = "loaders"

     def load_hypernetwork(self, model, hypernetwork_name, strength):
-        hypernetwork_path = folder_paths.get_full_path("hypernetworks", hypernetwork_name)
+        hypernetwork_path = folder_paths.get_full_path_or_raise("hypernetworks", hypernetwork_name)
         model_hypernetwork = model.clone()
         patch = load_hypernetwork_patch(hypernetwork_path, strength)
         if patch is not None:
@@ -1,4 +1,5 @@
|
|||||||
import comfy.utils
|
import comfy.utils
|
||||||
|
import comfy_extras.nodes_post_processing
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
def reshape_latent_to(target_shape, latent):
|
def reshape_latent_to(target_shape, latent):
|
||||||
@@ -145,6 +146,131 @@ class LatentBatchSeedBehavior:
|
        return (samples_out,)


class LatentApplyOperation:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "samples": ("LATENT",),
                              "operation": ("LATENT_OPERATION",),
                              }}

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "op"

    CATEGORY = "latent/advanced/operations"
    EXPERIMENTAL = True

    def op(self, samples, operation):
        samples_out = samples.copy()

        s1 = samples["samples"]
        samples_out["samples"] = operation(latent=s1)
        return (samples_out,)


class LatentApplyOperationCFG:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "model": ("MODEL",),
                              "operation": ("LATENT_OPERATION",),
                              }}
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "latent/advanced/operations"
    EXPERIMENTAL = True

    def patch(self, model, operation):
        m = model.clone()

        def pre_cfg_function(args):
            conds_out = args["conds_out"]
            if len(conds_out) == 2:
                conds_out[0] = operation(latent=(conds_out[0] - conds_out[1])) + conds_out[1]
            else:
                conds_out[0] = operation(latent=conds_out[0])
            return conds_out

        m.set_model_sampler_pre_cfg_function(pre_cfg_function)
        return (m, )


class LatentOperationTonemapReinhard:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}),
                              }}

    RETURN_TYPES = ("LATENT_OPERATION",)
    FUNCTION = "op"

    CATEGORY = "latent/advanced/operations"
    EXPERIMENTAL = True

    def op(self, multiplier):
        def tonemap_reinhard(latent, **kwargs):
            latent_vector_magnitude = (torch.linalg.vector_norm(latent, dim=(1)) + 0.0000000001)[:,None]
            normalized_latent = latent / latent_vector_magnitude

            mean = torch.mean(latent_vector_magnitude, dim=(1,2,3), keepdim=True)
            std = torch.std(latent_vector_magnitude, dim=(1,2,3), keepdim=True)

            top = (std * 5 + mean) * multiplier

            #reinhard
            latent_vector_magnitude *= (1.0 / top)
            new_magnitude = latent_vector_magnitude / (latent_vector_magnitude + 1.0)
            new_magnitude *= top

            return normalized_latent * new_magnitude
        return (tonemap_reinhard,)


class LatentOperationSharpen:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "sharpen_radius": ("INT", {
                "default": 9,
                "min": 1,
                "max": 31,
                "step": 1
            }),
            "sigma": ("FLOAT", {
                "default": 1.0,
                "min": 0.1,
                "max": 10.0,
                "step": 0.1
            }),
            "alpha": ("FLOAT", {
                "default": 0.1,
                "min": 0.0,
                "max": 5.0,
                "step": 0.01
            }),
        }}

    RETURN_TYPES = ("LATENT_OPERATION",)
    FUNCTION = "op"

    CATEGORY = "latent/advanced/operations"
    EXPERIMENTAL = True

    def op(self, sharpen_radius, sigma, alpha):
        def sharpen(latent, **kwargs):
            luminance = (torch.linalg.vector_norm(latent, dim=(1)) + 1e-6)[:,None]
            normalized_latent = latent / luminance
            channels = latent.shape[1]

            kernel_size = sharpen_radius * 2 + 1
            kernel = comfy_extras.nodes_post_processing.gaussian_kernel(kernel_size, sigma, device=luminance.device)
            center = kernel_size // 2

            kernel *= alpha * -10
            kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0

            padded_image = torch.nn.functional.pad(normalized_latent, (sharpen_radius,sharpen_radius,sharpen_radius,sharpen_radius), 'reflect')
            sharpened = torch.nn.functional.conv2d(padded_image, kernel.repeat(channels, 1, 1).unsqueeze(1), padding=kernel_size // 2, groups=channels)[:,:,sharpen_radius:-sharpen_radius, sharpen_radius:-sharpen_radius]

            return luminance * sharpened
        return (sharpen,)

NODE_CLASS_MAPPINGS = {
    "LatentAdd": LatentAdd,
    "LatentSubtract": LatentSubtract,
@@ -152,4 +278,8 @@ NODE_CLASS_MAPPINGS = {
    "LatentInterpolate": LatentInterpolate,
    "LatentBatch": LatentBatch,
    "LatentBatchSeedBehavior": LatentBatchSeedBehavior,
    "LatentApplyOperation": LatentApplyOperation,
    "LatentApplyOperationCFG": LatentApplyOperationCFG,
    "LatentOperationTonemapReinhard": LatentOperationTonemapReinhard,
    "LatentOperationSharpen": LatentOperationSharpen,
}

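A LATENT_OPERATION is just a callable that takes latent= and returns a transformed tensor. As a minimal standalone sketch of the Reinhard remapping above, applied outside the node system (the tensor shape and multiplier are only illustrative):

import torch

def tonemap_reinhard(latent, multiplier=1.0):
    # per-pixel magnitude across the channel dimension, plus its normalized direction
    magnitude = (torch.linalg.vector_norm(latent, dim=1) + 1e-10)[:, None]
    direction = latent / magnitude

    mean = magnitude.mean(dim=(1, 2, 3), keepdim=True)
    std = magnitude.std(dim=(1, 2, 3), keepdim=True)
    top = (std * 5 + mean) * multiplier

    # Reinhard curve m / (m + 1), rescaled by the soft ceiling "top"
    m = magnitude / top
    return direction * (m / (m + 1.0)) * top

latent = torch.randn(1, 4, 64, 64)
print(tonemap_reinhard(latent).abs().max(), latent.abs().max())
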
comfy_extras/nodes_lora_extract.py (new file, 119 lines)
@@ -0,0 +1,119 @@
import torch
import comfy.model_management
import comfy.utils
import folder_paths
import os
import logging
from enum import Enum

CLAMP_QUANTILE = 0.99

def extract_lora(diff, rank):
    conv2d = (len(diff.shape) == 4)
    kernel_size = None if not conv2d else diff.size()[2:4]
    conv2d_3x3 = conv2d and kernel_size != (1, 1)
    out_dim, in_dim = diff.size()[0:2]
    rank = min(rank, in_dim, out_dim)

    if conv2d:
        if conv2d_3x3:
            diff = diff.flatten(start_dim=1)
        else:
            diff = diff.squeeze()

    U, S, Vh = torch.linalg.svd(diff.float())
    U = U[:, :rank]
    S = S[:rank]
    U = U @ torch.diag(S)
    Vh = Vh[:rank, :]

    dist = torch.cat([U.flatten(), Vh.flatten()])
    hi_val = torch.quantile(dist, CLAMP_QUANTILE)
    low_val = -hi_val

    U = U.clamp(low_val, hi_val)
    Vh = Vh.clamp(low_val, hi_val)
    if conv2d:
        U = U.reshape(out_dim, rank, 1, 1)
        Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
    return (U, Vh)

class LORAType(Enum):
    STANDARD = 0
    FULL_DIFF = 1

LORA_TYPES = {"standard": LORAType.STANDARD,
              "full_diff": LORAType.FULL_DIFF}

def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora_type, bias_diff=False):
    comfy.model_management.load_models_gpu([model_diff], force_patch_weights=True)
    sd = model_diff.model_state_dict(filter_prefix=prefix_model)

    for k in sd:
        if k.endswith(".weight"):
            weight_diff = sd[k]
            if lora_type == LORAType.STANDARD:
                if weight_diff.ndim < 2:
                    if bias_diff:
                        output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().half().cpu()
                    continue
                try:
                    out = extract_lora(weight_diff, rank)
                    output_sd["{}{}.lora_up.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[0].contiguous().half().cpu()
                    output_sd["{}{}.lora_down.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[1].contiguous().half().cpu()
                except:
                    logging.warning("Could not generate lora weights for key {}, is the weight difference a zero?".format(k))
            elif lora_type == LORAType.FULL_DIFF:
                output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().half().cpu()

        elif bias_diff and k.endswith(".bias"):
            output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().half().cpu()
    return output_sd

class LoraSave:
    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"filename_prefix": ("STRING", {"default": "loras/ComfyUI_extracted_lora"}),
                             "rank": ("INT", {"default": 8, "min": 1, "max": 4096, "step": 1}),
                             "lora_type": (tuple(LORA_TYPES.keys()),),
                             "bias_diff": ("BOOLEAN", {"default": True}),
                             },
                "optional": {"model_diff": ("MODEL", {"tooltip": "The ModelSubtract output to be converted to a lora."}),
                             "text_encoder_diff": ("CLIP", {"tooltip": "The CLIPSubtract output to be converted to a lora."})},
                }
    RETURN_TYPES = ()
    FUNCTION = "save"
    OUTPUT_NODE = True

    CATEGORY = "_for_testing"

    def save(self, filename_prefix, rank, lora_type, bias_diff, model_diff=None, text_encoder_diff=None):
        if model_diff is None and text_encoder_diff is None:
            return {}

        lora_type = LORA_TYPES.get(lora_type)
        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)

        output_sd = {}
        if model_diff is not None:
            output_sd = calc_lora_model(model_diff, rank, "diffusion_model.", "diffusion_model.", output_sd, lora_type, bias_diff=bias_diff)
        if text_encoder_diff is not None:
            output_sd = calc_lora_model(text_encoder_diff.patcher, rank, "", "text_encoders.", output_sd, lora_type, bias_diff=bias_diff)

        output_checkpoint = f"{filename}_{counter:05}_.safetensors"
        output_checkpoint = os.path.join(full_output_folder, output_checkpoint)

        comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=None)
        return {}

NODE_CLASS_MAPPINGS = {
    "LoraSave": LoraSave
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "LoraSave": "Extract and Save Lora"
}

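The extraction above is a plain truncated SVD with quantile clamping. A rough standalone check of the idea on a toy weight difference (rank, shape and the omission of clamping are illustrative, not part of the diff):

import torch

def toy_extract(diff, rank):
    # truncated SVD: diff is approximated by (U @ diag(S)) @ Vh using the top `rank` singular values
    U, S, Vh = torch.linalg.svd(diff.float())
    up = U[:, :rank] @ torch.diag(S[:rank])   # plays the role of lora_up
    down = Vh[:rank, :]                       # plays the role of lora_down
    return up, down

diff = torch.randn(320, 768)
up, down = toy_extract(diff, rank=8)
err = (up @ down - diff).norm() / diff.norm()
print(up.shape, down.shape, f"relative error {err.item():.3f}")
# on random noise the error stays high; on a genuinely low-rank weight difference it drops sharply
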
comfy_extras/nodes_mochi.py (new file, 26 lines)
@@ -0,0 +1,26 @@
import nodes
import torch
import comfy.model_management

class EmptyMochiLatentVideo:
    def __init__(self):
        self.device = comfy.model_management.intermediate_device()

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
                              "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
                              "length": ("INT", {"default": 25, "min": 7, "max": nodes.MAX_RESOLUTION, "step": 6}),
                              "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
    RETURN_TYPES = ("LATENT",)
    FUNCTION = "generate"

    CATEGORY = "latent/mochi"

    def generate(self, width, height, length, batch_size=1):
        latent = torch.zeros([batch_size, 12, ((length - 1) // 6) + 1, height // 8, width // 8], device=self.device)
        return ({"samples":latent}, )

NODE_CLASS_MAPPINGS = {
    "EmptyMochiLatentVideo": EmptyMochiLatentVideo,
}

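For the default inputs, the empty latent shape works out as follows (a quick sanity check of the arithmetic above, not part of the diff):

width, height, length, batch_size = 848, 480, 25, 1
temporal = ((length - 1) // 6) + 1                     # 25 frames -> 5 latent frames
shape = [batch_size, 12, temporal, height // 8, width // 8]
print(shape)                                           # [1, 12, 5, 60, 106]
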
@@ -17,7 +17,7 @@ class PatchModelAddDownscale:
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

-   CATEGORY = "_for_testing"
+   CATEGORY = "model_patches/unet"

    def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method):
        model_sampling = model.get_model_object("model_sampling")

@@ -101,10 +101,34 @@ class ModelMergeFlux1(comfy_extras.nodes_model_merging.ModelMergeBlocks):

        return {"required": arg_dict}


class ModelMergeSD35_Large(comfy_extras.nodes_model_merging.ModelMergeBlocks):
    CATEGORY = "advanced/model_merging/model_specific"

    @classmethod
    def INPUT_TYPES(s):
        arg_dict = { "model1": ("MODEL",),
                     "model2": ("MODEL",)}

        argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})

        arg_dict["pos_embed."] = argument
        arg_dict["x_embedder."] = argument
        arg_dict["context_embedder."] = argument
        arg_dict["y_embedder."] = argument
        arg_dict["t_embedder."] = argument

        for i in range(38):
            arg_dict["joint_blocks.{}.".format(i)] = argument

        arg_dict["final_layer."] = argument

        return {"required": arg_dict}


NODE_CLASS_MAPPINGS = {
    "ModelMergeSD1": ModelMergeSD1,
    "ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks
    "ModelMergeSDXL": ModelMergeSDXL,
    "ModelMergeSD3_2B": ModelMergeSD3_2B,
    "ModelMergeFlux1": ModelMergeFlux1,
    "ModelMergeSD35_Large": ModelMergeSD35_Large,
}

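The per-block ratio names generated above follow the SD3.5 Large state-dict prefixes; a quick, purely illustrative way to see which keys the node exposes:

keys = (["pos_embed.", "x_embedder.", "context_embedder.", "y_embedder.", "t_embedder."]
        + ["joint_blocks.{}.".format(i) for i in range(38)]
        + ["final_layer."])
print(len(keys), keys[:3], keys[-1])   # 44 prefixes, from "pos_embed." to "final_layer."
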
@@ -26,6 +26,7 @@ class PerpNeg:
    FUNCTION = "patch"

    CATEGORY = "_for_testing"
    DEPRECATED = True

    def patch(self, model, empty_conditioning, neg_scale):
        m = model.clone()

@@ -126,7 +126,7 @@ class PhotoMakerLoader:
    CATEGORY = "_for_testing/photomaker"

    def load_photomaker_model(self, photomaker_model_name):
-       photomaker_model_path = folder_paths.get_full_path("photomaker", photomaker_model_name)
+       photomaker_model_path = folder_paths.get_full_path_or_raise("photomaker", photomaker_model_name)
        photomaker_model = PhotoMakerIDEncoder()
        data = comfy.utils.load_torch_file(photomaker_model_path, safe_load=True)
        if "id_encoder" in data:

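get_full_path_or_raise, added to folder_paths.py later in this diff, behaves like get_full_path but raises instead of returning None, so loaders fail loudly on a missing file. A hedged usage sketch (the model filename is made up for illustration):

import folder_paths

try:
    path = folder_paths.get_full_path_or_raise("photomaker", "photomaker-v1.bin")  # hypothetical filename
except FileNotFoundError as e:
    print(e)   # "Model in folder 'photomaker' with filename 'photomaker-v1.bin' not found."
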
@@ -3,11 +3,11 @@ import comfy.sd
import comfy.model_management
import nodes
import torch
+import re

class TripleCLIPLoader:
    @classmethod
    def INPUT_TYPES(s):
-       return {"required": { "clip_name1": (folder_paths.get_filename_list("clip"), ), "clip_name2": (folder_paths.get_filename_list("clip"), ), "clip_name3": (folder_paths.get_filename_list("clip"), )
+       return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), "clip_name3": (folder_paths.get_filename_list("text_encoders"), )
                              }}
    RETURN_TYPES = ("CLIP",)
    FUNCTION = "load_clip"
@@ -15,9 +15,9 @@ class TripleCLIPLoader:
    CATEGORY = "advanced/loaders"

    def load_clip(self, clip_name1, clip_name2, clip_name3):
-       clip_path1 = folder_paths.get_full_path("clip", clip_name1)
-       clip_path2 = folder_paths.get_full_path("clip", clip_name2)
-       clip_path3 = folder_paths.get_full_path("clip", clip_name3)
+       clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
+       clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
+       clip_path3 = folder_paths.get_full_path_or_raise("text_encoders", clip_name3)
        clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings"))
        return (clip,)

@@ -36,7 +36,7 @@ class EmptySD3LatentImage:
    CATEGORY = "latent/sd3"

    def generate(self, width, height, batch_size=1):
-       latent = torch.ones([batch_size, 16, height // 8, width // 8], device=self.device) * 0.0609
+       latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=self.device)
        return ({"samples":latent}, )

class CLIPTextEncodeSD3:

|
|||||||
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
|
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
|
||||||
}}
|
}}
|
||||||
CATEGORY = "conditioning/controlnet"
|
CATEGORY = "conditioning/controlnet"
|
||||||
|
DEPRECATED = True
|
||||||
|
|
||||||
|
class SkipLayerGuidanceSD3:
|
||||||
|
'''
|
||||||
|
Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers.
|
||||||
|
Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377)
|
||||||
|
Experimental implementation by Dango233@StabilityAI.
|
||||||
|
'''
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": {"model": ("MODEL", ),
|
||||||
|
"layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
|
||||||
|
"scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 10.0, "step": 0.1}),
|
||||||
|
"start_percent": ("FLOAT", {"default": 0.01, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||||
|
"end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001})
|
||||||
|
}}
|
||||||
|
RETURN_TYPES = ("MODEL",)
|
||||||
|
FUNCTION = "skip_guidance"
|
||||||
|
|
||||||
|
CATEGORY = "advanced/guidance"
|
||||||
|
|
||||||
|
|
||||||
|
def skip_guidance(self, model, layers, scale, start_percent, end_percent):
|
||||||
|
if layers == "" or layers == None:
|
||||||
|
return (model, )
|
||||||
|
# check if layer is comma separated integers
|
||||||
|
def skip(args, extra_args):
|
||||||
|
return args
|
||||||
|
|
||||||
|
model_sampling = model.get_model_object("model_sampling")
|
||||||
|
sigma_start = model_sampling.percent_to_sigma(start_percent)
|
||||||
|
sigma_end = model_sampling.percent_to_sigma(end_percent)
|
||||||
|
|
||||||
|
def post_cfg_function(args):
|
||||||
|
model = args["model"]
|
||||||
|
cond_pred = args["cond_denoised"]
|
||||||
|
cond = args["cond"]
|
||||||
|
cfg_result = args["denoised"]
|
||||||
|
sigma = args["sigma"]
|
||||||
|
x = args["input"]
|
||||||
|
model_options = args["model_options"].copy()
|
||||||
|
|
||||||
|
for layer in layers:
|
||||||
|
model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, skip, "dit", "double_block", layer)
|
||||||
|
model_sampling.percent_to_sigma(start_percent)
|
||||||
|
|
||||||
|
sigma_ = sigma[0].item()
|
||||||
|
if scale > 0 and sigma_ >= sigma_end and sigma_ <= sigma_start:
|
||||||
|
(slg,) = comfy.samplers.calc_cond_batch(model, [cond], x, sigma, model_options)
|
||||||
|
cfg_result = cfg_result + (cond_pred - slg) * scale
|
||||||
|
return cfg_result
|
||||||
|
|
||||||
|
layers = re.findall(r'\d+', layers)
|
||||||
|
layers = [int(i) for i in layers]
|
||||||
|
m = model.clone()
|
||||||
|
m.set_model_sampler_post_cfg_function(post_cfg_function)
|
||||||
|
|
||||||
|
return (m, )
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"TripleCLIPLoader": TripleCLIPLoader,
|
"TripleCLIPLoader": TripleCLIPLoader,
|
||||||
"EmptySD3LatentImage": EmptySD3LatentImage,
|
"EmptySD3LatentImage": EmptySD3LatentImage,
|
||||||
"CLIPTextEncodeSD3": CLIPTextEncodeSD3,
|
"CLIPTextEncodeSD3": CLIPTextEncodeSD3,
|
||||||
"ControlNetApplySD3": ControlNetApplySD3,
|
"ControlNetApplySD3": ControlNetApplySD3,
|
||||||
|
"SkipLayerGuidanceSD3": SkipLayerGuidanceSD3,
|
||||||
}
|
}
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
# Sampling
|
# Sampling
|
||||||
"ControlNetApplySD3": "ControlNetApply SD3 and HunyuanDiT",
|
"ControlNetApplySD3": "Apply Controlnet with VAE",
|
||||||
}
|
}
|
||||||
|
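Restated as a formula, the post-CFG hook above leaves the result unchanged outside the scheduled sigma window and otherwise applies

    cfg_result <- cfg_result + scale * (cond_pred - slg)

where slg is the conditional prediction recomputed with the listed double_block layers replaced by the pass-through `skip` patch, i.e. the output is pushed away from the prediction the model makes when those layers are skipped.
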
@@ -116,6 +116,7 @@ class StableCascade_SuperResolutionControlnet:
    RETURN_NAMES = ("controlnet_input", "stage_c", "stage_b")
    FUNCTION = "generate"

    EXPERIMENTAL = True
    CATEGORY = "_for_testing/stable_cascade"

    def generate(self, image, vae):

@@ -154,7 +154,7 @@ class TomePatchModel:
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

-   CATEGORY = "_for_testing"
+   CATEGORY = "model_patches/unet"

    def patch(self, model, ratio):
        self.u = None

comfy_extras/nodes_torch_compile.py (new file, 22 lines)
@@ -0,0 +1,22 @@
import torch

class TorchCompileModel:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "model": ("MODEL",),
                              "backend": (["inductor", "cudagraphs"],),
                              }}
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "_for_testing"
    EXPERIMENTAL = True

    def patch(self, model, backend):
        m = model.clone()
        m.add_object_patch("diffusion_model", torch.compile(model=m.get_model_object("diffusion_model"), backend=backend))
        return (m, )

NODE_CLASS_MAPPINGS = {
    "TorchCompileModel": TorchCompileModel,
}

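Outside the node graph, the wrapping above boils down to a single torch.compile call. A minimal sketch on a throwaway module (the module and backend choice are illustrative; this assumes PyTorch 2.x and a working backend in the environment):

import torch

net = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())
compiled = torch.compile(net, backend="inductor")   # same backends the node exposes
print(compiled(torch.randn(2, 16)).shape)           # torch.Size([2, 16])
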
@@ -25,7 +25,7 @@ class UpscaleModelLoader:
    CATEGORY = "loaders"

    def load_model(self, model_name):
-       model_path = folder_paths.get_full_path("upscale_models", model_name)
+       model_path = folder_paths.get_full_path_or_raise("upscale_models", model_name)
        sd = comfy.utils.load_torch_file(model_path, safe_load=True)
        if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd:
            sd = comfy.utils.state_dict_prefix_replace(sd, {"module.":""})

@@ -17,7 +17,7 @@ class ImageOnlyCheckpointLoader:
    CATEGORY = "loaders/video_models"

    def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
-       ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
+       ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
        out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
        return (out[0], out[3], out[2])

@@ -107,7 +107,7 @@ class VideoTriangleCFGGuidance:
        return (m, )

class ImageOnlyCheckpointSave(comfy_extras.nodes_model_merging.CheckpointSave):
-   CATEGORY = "_for_testing"
+   CATEGORY = "advanced/model_merging"

    @classmethod
    def INPUT_TYPES(s):

@@ -37,6 +37,7 @@ class SaveImageWebsocket:

        return {}

    @classmethod
    def IS_CHANGED(s, images):
        return time.time()

@@ -179,7 +179,13 @@ def merge_result_data(results, obj):
    # merge node execution results
    for i, is_list in zip(range(len(results[0])), output_is_list):
        if is_list:
-           output.append([x for o in results for x in o[i]])
+           value = []
+           for o in results:
+               if isinstance(o[i], ExecutionBlocker):
+                   value.append(o[i])
+               else:
+                   value.extend(o[i])
+           output.append(value)
        else:
            output.append([o[i] for o in results])
    return output

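A rough illustration of why the list branch changed: an ExecutionBlocker is a single sentinel object rather than an iterable output, so it has to be appended whole instead of flattened (the ExecutionBlocker here is a stand-in stub, not the real class):

class ExecutionBlocker:   # stub for illustration only
    pass

results = [[[1, 2]], [ExecutionBlocker()], [[3]]]   # per-node outputs for one list-valued socket
value = []
for o in results:
    if isinstance(o[0], ExecutionBlocker):
        value.append(o[0])        # keep the blocker as a single marker
    else:
        value.extend(o[0])        # flatten normal list outputs
print(value)                      # [1, 2, <ExecutionBlocker object>, 3]
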
@@ -25,11 +25,16 @@ a111:

#comfyui:
#     base_path: path/to/comfyui/
#     # You can use is_default to mark that these folders should be listed first, and used as the default dirs for eg downloads
#     #is_default: true
#     checkpoints: models/checkpoints/
#     clip: models/clip/
#     clip_vision: models/clip_vision/
#     configs: models/configs/
#     controlnet: models/controlnet/
#     diffusion_models: |
#          models/diffusion_models
#          models/unet
#     embeddings: models/embeddings/
#     loras: models/loras/
#     upscale_models: models/upscale_models/

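The is_default flag introduced here feeds the new is_default parameter of folder_paths.add_model_folder_path further down in this diff, which inserts the path at the front of the search list instead of appending it. A hedged usage sketch (the example path is made up):

import folder_paths

# is_default=True puts the folder first, so it becomes the default location (e.g. for downloads)
folder_paths.add_model_folder_path("checkpoints", "/data/models/checkpoints", is_default=True)
print(folder_paths.get_folder_paths("checkpoints")[0])   # /data/models/checkpoints
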
|||||||
108
folder_paths.py
108
folder_paths.py
@@ -2,7 +2,9 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
import mimetypes
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Set, List, Dict, Tuple, Literal
|
||||||
from collections.abc import Collection
|
from collections.abc import Collection
|
||||||
|
|
||||||
supported_pt_extensions: set[str] = {'.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'}
|
supported_pt_extensions: set[str] = {'.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'}
|
||||||
@@ -16,7 +18,7 @@ folder_names_and_paths["configs"] = ([os.path.join(models_dir, "configs")], [".y
|
|||||||
|
|
||||||
folder_names_and_paths["loras"] = ([os.path.join(models_dir, "loras")], supported_pt_extensions)
|
folder_names_and_paths["loras"] = ([os.path.join(models_dir, "loras")], supported_pt_extensions)
|
||||||
folder_names_and_paths["vae"] = ([os.path.join(models_dir, "vae")], supported_pt_extensions)
|
folder_names_and_paths["vae"] = ([os.path.join(models_dir, "vae")], supported_pt_extensions)
|
||||||
folder_names_and_paths["clip"] = ([os.path.join(models_dir, "clip")], supported_pt_extensions)
|
folder_names_and_paths["text_encoders"] = ([os.path.join(models_dir, "text_encoders"), os.path.join(models_dir, "clip")], supported_pt_extensions)
|
||||||
folder_names_and_paths["diffusion_models"] = ([os.path.join(models_dir, "unet"), os.path.join(models_dir, "diffusion_models")], supported_pt_extensions)
|
folder_names_and_paths["diffusion_models"] = ([os.path.join(models_dir, "unet"), os.path.join(models_dir, "diffusion_models")], supported_pt_extensions)
|
||||||
folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision")], supported_pt_extensions)
|
folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision")], supported_pt_extensions)
|
||||||
folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions)
|
folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions)
|
||||||
@@ -44,8 +46,43 @@ user_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "user
|
|||||||
|
|
||||||
filename_list_cache: dict[str, tuple[list[str], dict[str, float], float]] = {}
|
filename_list_cache: dict[str, tuple[list[str], dict[str, float], float]] = {}
|
||||||
|
|
||||||
|
class CacheHelper:
|
||||||
|
"""
|
||||||
|
Helper class for managing file list cache data.
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
self.cache: dict[str, tuple[list[str], dict[str, float], float]] = {}
|
||||||
|
self.active = False
|
||||||
|
|
||||||
|
def get(self, key: str, default=None) -> tuple[list[str], dict[str, float], float]:
|
||||||
|
if not self.active:
|
||||||
|
return default
|
||||||
|
return self.cache.get(key, default)
|
||||||
|
|
||||||
|
def set(self, key: str, value: tuple[list[str], dict[str, float], float]) -> None:
|
||||||
|
if self.active:
|
||||||
|
self.cache[key] = value
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
self.cache.clear()
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self.active = True
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_value, traceback):
|
||||||
|
self.active = False
|
||||||
|
self.clear()
|
||||||
|
|
||||||
|
cache_helper = CacheHelper()
|
||||||
|
|
||||||
|
extension_mimetypes_cache = {
|
||||||
|
"webp" : "image",
|
||||||
|
}
|
||||||
|
|
||||||
def map_legacy(folder_name: str) -> str:
|
def map_legacy(folder_name: str) -> str:
|
||||||
legacy = {"unet": "diffusion_models"}
|
legacy = {"unet": "diffusion_models",
|
||||||
|
"clip": "text_encoders"}
|
||||||
return legacy.get(folder_name, folder_name)
|
return legacy.get(folder_name, folder_name)
|
||||||
|
|
||||||
if not os.path.exists(input_directory):
|
if not os.path.exists(input_directory):
|
||||||
@@ -78,6 +115,13 @@ def get_input_directory() -> str:
|
|||||||
global input_directory
|
global input_directory
|
||||||
return input_directory
|
return input_directory
|
||||||
|
|
||||||
|
def get_user_directory() -> str:
|
||||||
|
return user_directory
|
||||||
|
|
||||||
|
def set_user_directory(user_dir: str) -> None:
|
||||||
|
global user_directory
|
||||||
|
user_directory = user_dir
|
||||||
|
|
||||||
|
|
||||||
#NOTE: used in http server so don't put folders that should not be accessed remotely
|
#NOTE: used in http server so don't put folders that should not be accessed remotely
|
||||||
def get_directory_by_type(type_name: str) -> str | None:
|
def get_directory_by_type(type_name: str) -> str | None:
|
||||||
@@ -89,6 +133,28 @@ def get_directory_by_type(type_name: str) -> str | None:
|
|||||||
return get_input_directory()
|
return get_input_directory()
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def filter_files_content_types(files: List[str], content_types: Literal["image", "video", "audio"]) -> List[str]:
|
||||||
|
"""
|
||||||
|
Example:
|
||||||
|
files = os.listdir(folder_paths.get_input_directory())
|
||||||
|
filter_files_content_types(files, ["image", "audio", "video"])
|
||||||
|
"""
|
||||||
|
global extension_mimetypes_cache
|
||||||
|
result = []
|
||||||
|
for file in files:
|
||||||
|
extension = file.split('.')[-1]
|
||||||
|
if extension not in extension_mimetypes_cache:
|
||||||
|
mime_type, _ = mimetypes.guess_type(file, strict=False)
|
||||||
|
if not mime_type:
|
||||||
|
continue
|
||||||
|
content_type = mime_type.split('/')[0]
|
||||||
|
extension_mimetypes_cache[extension] = content_type
|
||||||
|
else:
|
||||||
|
content_type = extension_mimetypes_cache[extension]
|
||||||
|
|
||||||
|
if content_type in content_types:
|
||||||
|
result.append(file)
|
||||||
|
return result
|
||||||
|
|
||||||
# determine base_dir rely on annotation if name is 'filename.ext [annotation]' format
|
# determine base_dir rely on annotation if name is 'filename.ext [annotation]' format
|
||||||
# otherwise use default_path as base_dir
|
# otherwise use default_path as base_dir
|
||||||
@@ -130,11 +196,14 @@ def exists_annotated_filepath(name) -> bool:
|
|||||||
return os.path.exists(filepath)
|
return os.path.exists(filepath)
|
||||||
|
|
||||||
|
|
||||||
def add_model_folder_path(folder_name: str, full_folder_path: str) -> None:
|
def add_model_folder_path(folder_name: str, full_folder_path: str, is_default: bool = False) -> None:
|
||||||
global folder_names_and_paths
|
global folder_names_and_paths
|
||||||
folder_name = map_legacy(folder_name)
|
folder_name = map_legacy(folder_name)
|
||||||
if folder_name in folder_names_and_paths:
|
if folder_name in folder_names_and_paths:
|
||||||
folder_names_and_paths[folder_name][0].append(full_folder_path)
|
if is_default:
|
||||||
|
folder_names_and_paths[folder_name][0].insert(0, full_folder_path)
|
||||||
|
else:
|
||||||
|
folder_names_and_paths[folder_name][0].append(full_folder_path)
|
||||||
else:
|
else:
|
||||||
folder_names_and_paths[folder_name] = ([full_folder_path], set())
|
folder_names_and_paths[folder_name] = ([full_folder_path], set())
|
||||||
|
|
||||||
@@ -166,8 +235,12 @@ def recursive_search(directory: str, excluded_dir_names: list[str] | None=None)
|
|||||||
for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
|
for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
|
||||||
subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
|
subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
|
||||||
for file_name in filenames:
|
for file_name in filenames:
|
||||||
relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
|
try:
|
||||||
result.append(relative_path)
|
relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
|
||||||
|
result.append(relative_path)
|
||||||
|
except:
|
||||||
|
logging.warning(f"Warning: Unable to access {file_name}. Skipping this file.")
|
||||||
|
continue
|
||||||
|
|
||||||
for d in subdirs:
|
for d in subdirs:
|
||||||
path: str = os.path.join(dirpath, d)
|
path: str = os.path.join(dirpath, d)
|
||||||
@@ -200,6 +273,14 @@ def get_full_path(folder_name: str, filename: str) -> str | None:
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_full_path_or_raise(folder_name: str, filename: str) -> str:
|
||||||
|
full_path = get_full_path(folder_name, filename)
|
||||||
|
if full_path is None:
|
||||||
|
raise FileNotFoundError(f"Model in folder '{folder_name}' with filename '{filename}' not found.")
|
||||||
|
return full_path
|
||||||
|
|
||||||
|
|
||||||
def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]:
|
def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]:
|
||||||
folder_name = map_legacy(folder_name)
|
folder_name = map_legacy(folder_name)
|
||||||
global folder_names_and_paths
|
global folder_names_and_paths
|
||||||
@@ -214,6 +295,10 @@ def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], f
|
|||||||
return sorted(list(output_list)), output_folders, time.perf_counter()
|
return sorted(list(output_list)), output_folders, time.perf_counter()
|
||||||
|
|
||||||
def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float] | None:
|
def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float] | None:
|
||||||
|
strong_cache = cache_helper.get(folder_name)
|
||||||
|
if strong_cache is not None:
|
||||||
|
return strong_cache
|
||||||
|
|
||||||
global filename_list_cache
|
global filename_list_cache
|
||||||
global folder_names_and_paths
|
global folder_names_and_paths
|
||||||
folder_name = map_legacy(folder_name)
|
folder_name = map_legacy(folder_name)
|
||||||
@@ -242,6 +327,7 @@ def get_filename_list(folder_name: str) -> list[str]:
|
|||||||
out = get_filename_list_(folder_name)
|
out = get_filename_list_(folder_name)
|
||||||
global filename_list_cache
|
global filename_list_cache
|
||||||
filename_list_cache[folder_name] = out
|
filename_list_cache[folder_name] = out
|
||||||
|
cache_helper.set(folder_name, out)
|
||||||
return list(out[0])
|
return list(out[0])
|
||||||
|
|
||||||
def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, image_height=0) -> tuple[str, str, int, str, str]:
|
def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, image_height=0) -> tuple[str, str, int, str, str]:
|
||||||
@@ -257,9 +343,17 @@ def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, im
|
|||||||
def compute_vars(input: str, image_width: int, image_height: int) -> str:
|
def compute_vars(input: str, image_width: int, image_height: int) -> str:
|
||||||
input = input.replace("%width%", str(image_width))
|
input = input.replace("%width%", str(image_width))
|
||||||
input = input.replace("%height%", str(image_height))
|
input = input.replace("%height%", str(image_height))
|
||||||
|
now = time.localtime()
|
||||||
|
input = input.replace("%year%", str(now.tm_year))
|
||||||
|
input = input.replace("%month%", str(now.tm_mon).zfill(2))
|
||||||
|
input = input.replace("%day%", str(now.tm_mday).zfill(2))
|
||||||
|
input = input.replace("%hour%", str(now.tm_hour).zfill(2))
|
||||||
|
input = input.replace("%minute%", str(now.tm_min).zfill(2))
|
||||||
|
input = input.replace("%second%", str(now.tm_sec).zfill(2))
|
||||||
return input
|
return input
|
||||||
|
|
||||||
filename_prefix = compute_vars(filename_prefix, image_width, image_height)
|
if "%" in filename_prefix:
|
||||||
|
filename_prefix = compute_vars(filename_prefix, image_width, image_height)
|
||||||
|
|
||||||
subfolder = os.path.dirname(os.path.normpath(filename_prefix))
|
subfolder = os.path.dirname(os.path.normpath(filename_prefix))
|
||||||
filename = os.path.basename(os.path.normpath(filename_prefix))
|
filename = os.path.basename(os.path.normpath(filename_prefix))
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import folder_paths
|
|||||||
import comfy.utils
|
import comfy.utils
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
MAX_PREVIEW_RESOLUTION = 512
|
MAX_PREVIEW_RESOLUTION = args.preview_size
|
||||||
|
|
||||||
def preview_to_image(latent_image):
|
def preview_to_image(latent_image):
|
||||||
latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1
|
latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1
|
||||||
@@ -36,12 +36,20 @@ class TAESDPreviewerImpl(LatentPreviewer):
|
|||||||
|
|
||||||
|
|
||||||
class Latent2RGBPreviewer(LatentPreviewer):
|
class Latent2RGBPreviewer(LatentPreviewer):
|
||||||
def __init__(self, latent_rgb_factors):
|
def __init__(self, latent_rgb_factors, latent_rgb_factors_bias=None):
|
||||||
self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu")
|
self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu").transpose(0, 1)
|
||||||
|
self.latent_rgb_factors_bias = None
|
||||||
|
if latent_rgb_factors_bias is not None:
|
||||||
|
self.latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device="cpu")
|
||||||
|
|
||||||
def decode_latent_to_preview(self, x0):
|
def decode_latent_to_preview(self, x0):
|
||||||
self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device)
|
self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device)
|
||||||
latent_image = x0[0].permute(1, 2, 0) @ self.latent_rgb_factors
|
if self.latent_rgb_factors_bias is not None:
|
||||||
|
self.latent_rgb_factors_bias = self.latent_rgb_factors_bias.to(dtype=x0.dtype, device=x0.device)
|
||||||
|
|
||||||
|
latent_image = torch.nn.functional.linear(x0[0].permute(1, 2, 0), self.latent_rgb_factors, bias=self.latent_rgb_factors_bias)
|
||||||
|
# latent_image = x0[0].permute(1, 2, 0) @ self.latent_rgb_factors
|
||||||
|
|
||||||
return preview_to_image(latent_image)
|
return preview_to_image(latent_image)
|
||||||
|
|
||||||
|
|
||||||
@@ -71,7 +79,7 @@ def get_previewer(device, latent_format):
|
|||||||
|
|
||||||
if previewer is None:
|
if previewer is None:
|
||||||
if latent_format.latent_rgb_factors is not None:
|
if latent_format.latent_rgb_factors is not None:
|
||||||
previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors)
|
previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors, latent_format.latent_rgb_factors_bias)
|
||||||
return previewer
|
return previewer
|
||||||
|
|
||||||
def prepare_callback(model, steps, x0_output_dict=None):
|
def prepare_callback(model, steps, x0_output_dict=None):
|
||||||
|
|||||||
45
main.py
45
main.py
@@ -6,6 +6,10 @@ import importlib.util
|
|||||||
import folder_paths
|
import folder_paths
|
||||||
import time
|
import time
|
||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
|
from app.logger import setup_logger
|
||||||
|
|
||||||
|
|
||||||
|
setup_logger(log_level=args.verbose)
|
||||||
|
|
||||||
|
|
||||||
def execute_prestartup_script():
|
def execute_prestartup_script():
|
||||||
@@ -59,6 +63,7 @@ import threading
|
|||||||
import gc
|
import gc
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import utils.extra_config
|
||||||
|
|
||||||
if os.name == "nt":
|
if os.name == "nt":
|
||||||
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
|
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
|
||||||
@@ -81,7 +86,6 @@ if args.windows_standalone_build:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import yaml
|
|
||||||
|
|
||||||
import execution
|
import execution
|
||||||
import server
|
import server
|
||||||
@@ -156,7 +160,10 @@ def prompt_worker(q, server):
|
|||||||
need_gc = False
|
need_gc = False
|
||||||
|
|
||||||
async def run(server, address='', port=8188, verbose=True, call_on_start=None):
|
async def run(server, address='', port=8188, verbose=True, call_on_start=None):
|
||||||
await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())
|
addresses = []
|
||||||
|
for addr in address.split(","):
|
||||||
|
addresses.append((addr, port))
|
||||||
|
await asyncio.gather(server.start_multi_address(addresses, call_on_start), server.publish_loop())
|
||||||
|
|
||||||
|
|
||||||
def hijack_progress(server):
|
def hijack_progress(server):
|
||||||
@@ -176,27 +183,6 @@ def cleanup_temp():
|
|||||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||||
|
|
||||||
|
|
||||||
def load_extra_path_config(yaml_path):
|
|
||||||
with open(yaml_path, 'r') as stream:
|
|
||||||
config = yaml.safe_load(stream)
|
|
||||||
for c in config:
|
|
||||||
conf = config[c]
|
|
||||||
if conf is None:
|
|
||||||
continue
|
|
||||||
base_path = None
|
|
||||||
if "base_path" in conf:
|
|
||||||
base_path = conf.pop("base_path")
|
|
||||||
for x in conf:
|
|
||||||
for y in conf[x].split("\n"):
|
|
||||||
if len(y) == 0:
|
|
||||||
continue
|
|
||||||
full_path = y
|
|
||||||
if base_path is not None:
|
|
||||||
full_path = os.path.join(base_path, full_path)
|
|
||||||
logging.info("Adding extra search path {} {}".format(x, full_path))
|
|
||||||
folder_paths.add_model_folder_path(x, full_path)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
if args.temp_directory:
|
if args.temp_directory:
|
||||||
temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp")
|
temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp")
|
||||||
@@ -218,11 +204,11 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
|
extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
|
||||||
if os.path.isfile(extra_model_paths_config_path):
|
if os.path.isfile(extra_model_paths_config_path):
|
||||||
load_extra_path_config(extra_model_paths_config_path)
|
utils.extra_config.load_extra_path_config(extra_model_paths_config_path)
|
||||||
|
|
||||||
if args.extra_model_paths_config:
|
if args.extra_model_paths_config:
|
||||||
for config_path in itertools.chain(*args.extra_model_paths_config):
|
for config_path in itertools.chain(*args.extra_model_paths_config):
|
||||||
load_extra_path_config(config_path)
|
utils.extra_config.load_extra_path_config(config_path)
|
||||||
|
|
||||||
nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes)
|
nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes)
|
||||||
|
|
||||||
@@ -243,21 +229,30 @@ if __name__ == "__main__":
|
|||||||
folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip"))
|
folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip"))
|
||||||
folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae"))
|
folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae"))
|
||||||
folder_paths.add_model_folder_path("diffusion_models", os.path.join(folder_paths.get_output_directory(), "diffusion_models"))
|
folder_paths.add_model_folder_path("diffusion_models", os.path.join(folder_paths.get_output_directory(), "diffusion_models"))
|
||||||
|
folder_paths.add_model_folder_path("loras", os.path.join(folder_paths.get_output_directory(), "loras"))
|
||||||
|
|
||||||
if args.input_directory:
|
if args.input_directory:
|
||||||
input_dir = os.path.abspath(args.input_directory)
|
input_dir = os.path.abspath(args.input_directory)
|
||||||
logging.info(f"Setting input directory to: {input_dir}")
|
logging.info(f"Setting input directory to: {input_dir}")
|
||||||
folder_paths.set_input_directory(input_dir)
|
folder_paths.set_input_directory(input_dir)
|
||||||
|
|
||||||
|
if args.user_directory:
|
||||||
|
user_dir = os.path.abspath(args.user_directory)
|
||||||
|
logging.info(f"Setting user directory to: {user_dir}")
|
||||||
|
folder_paths.set_user_directory(user_dir)
|
||||||
|
|
||||||
if args.quick_test_for_ci:
|
if args.quick_test_for_ci:
|
||||||
exit(0)
|
exit(0)
|
||||||
|
|
||||||
|
os.makedirs(folder_paths.get_temp_directory(), exist_ok=True)
|
||||||
call_on_start = None
|
call_on_start = None
|
||||||
if args.auto_launch:
|
if args.auto_launch:
|
||||||
def startup_server(scheme, address, port):
|
def startup_server(scheme, address, port):
|
||||||
import webbrowser
|
import webbrowser
|
||||||
if os.name == 'nt' and address == '0.0.0.0':
|
if os.name == 'nt' and address == '0.0.0.0':
|
||||||
address = '127.0.0.1'
|
address = '127.0.0.1'
|
||||||
|
if ':' in address:
|
||||||
|
address = "[{}]".format(address)
|
||||||
webbrowser.open(f"{scheme}://{address}:{port}")
|
webbrowser.open(f"{scheme}://{address}:{port}")
|
||||||
call_on_start = startup_server
|
call_on_start = startup_server
|
||||||
|
|
||||||
|
@@ -1,2 +1,2 @@
# model_manager/__init__.py
-from .download_models import download_model, DownloadModelStatus, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_model_subdirectory, validate_filename
+from .download_models import download_model, DownloadModelStatus, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_filename

|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
|
#NOTE: This was an experiment and WILL BE REMOVED
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import os
|
import os
|
||||||
import traceback
|
import traceback
|
||||||
import logging
|
import logging
|
||||||
from folder_paths import models_dir
|
from folder_paths import folder_names_and_paths, get_folder_paths
|
||||||
import re
|
import re
|
||||||
from typing import Callable, Any, Optional, Awaitable, Dict
|
from typing import Callable, Any, Optional, Awaitable, Dict
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
@@ -17,6 +18,7 @@ class DownloadStatusType(Enum):
|
|||||||
COMPLETED = "completed"
|
COMPLETED = "completed"
|
||||||
ERROR = "error"
|
ERROR = "error"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DownloadModelStatus():
|
class DownloadModelStatus():
|
||||||
status: str
|
status: str
|
||||||
@@ -29,7 +31,7 @@ class DownloadModelStatus():
|
|||||||
self.progress_percentage = progress_percentage
|
self.progress_percentage = progress_percentage
|
||||||
self.message = message
|
self.message = message
|
||||||
self.already_existed = already_existed
|
self.already_existed = already_existed
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"status": self.status,
|
"status": self.status,
|
||||||
@@ -38,102 +40,112 @@ class DownloadModelStatus():
|
|||||||
"already_existed": self.already_existed
|
"already_existed": self.already_existed
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
|
async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
|
||||||
model_name: str,
|
model_name: str,
|
||||||
model_url: str,
|
model_url: str,
|
||||||
model_sub_directory: str,
|
model_directory: str,
|
||||||
|
folder_path: str,
|
||||||
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
||||||
progress_interval: float = 1.0) -> DownloadModelStatus:
|
progress_interval: float = 1.0) -> DownloadModelStatus:
|
||||||
"""
|
"""
|
||||||
Download a model file from a given URL into the models directory.
|
Download a model file from a given URL into the models directory.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_download_request (Callable[[str], Awaitable[aiohttp.ClientResponse]]):
|
model_download_request (Callable[[str], Awaitable[aiohttp.ClientResponse]]):
|
||||||
A function that makes an HTTP request. This makes it easier to mock in unit tests.
|
A function that makes an HTTP request. This makes it easier to mock in unit tests.
|
||||||
model_name (str):
|
model_name (str):
|
||||||
The name of the model file to be downloaded. This will be the filename on disk.
|
The name of the model file to be downloaded. This will be the filename on disk.
|
||||||
model_url (str):
|
model_url (str):
|
||||||
The URL from which to download the model.
|
The URL from which to download the model.
|
||||||
model_sub_directory (str):
|
model_directory (str):
|
||||||
The subdirectory within the main models directory where the model
|
The subdirectory within the main models directory where the model
|
||||||
should be saved (e.g., 'checkpoints', 'loras', etc.).
|
should be saved (e.g., 'checkpoints', 'loras', etc.).
|
||||||
progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]):
|
progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]):
|
||||||
An asynchronous function to call with progress updates.
|
An asynchronous function to call with progress updates.
|
||||||
|
folder_path (str):
    Full path of the model folder to use as the download root.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
DownloadModelStatus: The result of the download operation.
|
DownloadModelStatus: The result of the download operation.
|
||||||
"""
|
"""
|
||||||
if not validate_model_subdirectory(model_sub_directory):
|
|
||||||
return DownloadModelStatus(
|
|
||||||
DownloadStatusType.ERROR,
|
|
||||||
0,
|
|
||||||
"Invalid model subdirectory",
|
|
||||||
False
|
|
||||||
)
|
|
||||||
|
|
||||||
if not validate_filename(model_name):
|
if not validate_filename(model_name):
|
||||||
return DownloadModelStatus(
|
return DownloadModelStatus(
|
||||||
DownloadStatusType.ERROR,
|
DownloadStatusType.ERROR,
|
||||||
0,
|
0,
|
||||||
"Invalid model name",
|
"Invalid model name",
|
||||||
False
|
False
|
||||||
)
|
)
|
||||||
|
|
||||||
file_path, relative_path = create_model_path(model_name, model_sub_directory, models_dir)
|
if not model_directory in folder_names_and_paths:
|
||||||
existing_file = await check_file_exists(file_path, model_name, progress_callback, relative_path)
|
return DownloadModelStatus(
|
||||||
|
DownloadStatusType.ERROR,
|
||||||
|
0,
|
||||||
|
"Invalid or unrecognized model directory. model_directory must be a known model type (eg 'checkpoints'). If you are seeing this error for a custom model type, ensure the relevant custom nodes are installed and working.",
|
||||||
|
False
|
||||||
|
)
|
||||||
|
|
||||||
|
if not folder_path in get_folder_paths(model_directory):
|
||||||
|
return DownloadModelStatus(
|
||||||
|
DownloadStatusType.ERROR,
|
||||||
|
0,
|
||||||
|
f"Invalid folder path '{folder_path}', does not match the list of known directories ({get_folder_paths(model_directory)}). If you're seeing this in the downloader UI, you may need to refresh the page.",
|
||||||
|
False
|
||||||
|
)
|
||||||
|
|
||||||
|
file_path = create_model_path(model_name, folder_path)
|
||||||
|
existing_file = await check_file_exists(file_path, model_name, progress_callback)
|
||||||
if existing_file:
|
if existing_file:
|
||||||
return existing_file
|
return existing_file
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
logging.info(f"Downloading {model_name} from {model_url}")
|
||||||
status = DownloadModelStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}", False)
|
status = DownloadModelStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}", False)
|
||||||
await progress_callback(relative_path, status)
|
await progress_callback(model_name, status)
|
||||||
|
|
||||||
response = await model_download_request(model_url)
|
response = await model_download_request(model_url)
|
||||||
if response.status != 200:
|
if response.status != 200:
|
||||||
error_message = f"Failed to download {model_name}. Status code: {response.status}"
|
error_message = f"Failed to download {model_name}. Status code: {response.status}"
|
||||||
logging.error(error_message)
|
logging.error(error_message)
|
||||||
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
||||||
await progress_callback(relative_path, status)
|
await progress_callback(model_name, status)
|
||||||
return DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
return DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
||||||
|
|
||||||
-        return await track_download_progress(response, file_path, model_name, progress_callback, relative_path, progress_interval)
+        return await track_download_progress(response, file_path, model_name, progress_callback, progress_interval)
     except Exception as e:
         logging.error(f"Error in downloading model: {e}")
-        return await handle_download_error(e, model_name, progress_callback, relative_path)
+        return await handle_download_error(e, model_name, progress_callback)


-def create_model_path(model_name: str, model_directory: str, models_base_dir: str) -> tuple[str, str]:
-    full_model_dir = os.path.join(models_base_dir, model_directory)
-    os.makedirs(full_model_dir, exist_ok=True)
-    file_path = os.path.join(full_model_dir, model_name)
+def create_model_path(model_name: str, folder_path: str) -> tuple[str, str]:
+    os.makedirs(folder_path, exist_ok=True)
+    file_path = os.path.join(folder_path, model_name)

     # Ensure the resulting path is still within the base directory
     abs_file_path = os.path.abspath(file_path)
-    abs_base_dir = os.path.abspath(str(models_base_dir))
+    abs_base_dir = os.path.abspath(folder_path)
     if os.path.commonprefix([abs_file_path, abs_base_dir]) != abs_base_dir:
-        raise Exception(f"Invalid model directory: {model_directory}/{model_name}")
+        raise Exception(f"Invalid model directory: {folder_path}/{model_name}")

-    relative_path = '/'.join([model_directory, model_name])
-    return file_path, relative_path
+    return file_path

 async def check_file_exists(file_path: str,
                             model_name: str,
-                            progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
-                            relative_path: str) -> Optional[DownloadModelStatus]:
+                            progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]]
+                            ) -> Optional[DownloadModelStatus]:
     if os.path.exists(file_path):
         status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True)
-        await progress_callback(relative_path, status)
+        await progress_callback(model_name, status)
         return status
     return None


 async def track_download_progress(response: aiohttp.ClientResponse,
                                   file_path: str,
                                   model_name: str,
                                   progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
-                                  relative_path: str,
                                   interval: float = 1.0) -> DownloadModelStatus:
     try:
         total_size = int(response.headers.get('Content-Length', 0))
@@ -144,10 +156,11 @@ async def track_download_progress(response: aiohttp.ClientResponse,
             nonlocal last_update_time
             progress = (downloaded / total_size) * 100 if total_size > 0 else 0
             status = DownloadModelStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}", False)
-            await progress_callback(relative_path, status)
+            await progress_callback(model_name, status)
             last_update_time = time.time()

-        with open(file_path, 'wb') as f:
+        temp_file_path = file_path + '.tmp'
+        with open(temp_file_path, 'wb') as f:
             chunk_iterator = response.content.iter_chunked(8192)
             while True:
                 try:
@@ -156,58 +169,39 @@ async def track_download_progress(response: aiohttp.ClientResponse,
                     break
                 f.write(chunk)
                 downloaded += len(chunk)

                 if time.time() - last_update_time >= interval:
                     await update_progress()

+        os.rename(temp_file_path, file_path)

         await update_progress()

         logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}")
         status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False)
-        await progress_callback(relative_path, status)
+        await progress_callback(model_name, status)

         return status
     except Exception as e:
         logging.error(f"Error in track_download_progress: {e}")
         logging.error(traceback.format_exc())
-        return await handle_download_error(e, model_name, progress_callback, relative_path)
+        return await handle_download_error(e, model_name, progress_callback)


 async def handle_download_error(e: Exception,
                                 model_name: str,
-                                progress_callback: Callable[[str, DownloadModelStatus], Any],
-                                relative_path: str) -> DownloadModelStatus:
+                                progress_callback: Callable[[str, DownloadModelStatus], Any]
+                                ) -> DownloadModelStatus:
     error_message = f"Error downloading {model_name}: {str(e)}"
     status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
-    await progress_callback(relative_path, status)
+    await progress_callback(model_name, status)
     return status

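Taken together, these changes drop the relative_path plumbing: progress callbacks are keyed by the model name, and create_model_path builds a single path inside the caller-supplied folder_path. A minimal sketch of driving the updated download_model entry point; the URL, folder choice and callback body are illustrative, and the argument order follows the call sites in server.py and the unit tests further down in this diff:

```python
import asyncio
import aiohttp
import folder_paths
from model_filemanager import download_model, DownloadModelStatus

async def fetch_checkpoint():
    async def on_progress(model_name: str, status: DownloadModelStatus):
        # Called with the model name (not a relative path) after this change.
        print(model_name, status.to_dict())

    async with aiohttp.ClientSession() as session:
        result = await download_model(
            lambda url: session.get(url),                     # request factory
            "model.safetensors",                              # model_filename
            "https://example.com/model.safetensors",          # url (illustrative)
            "checkpoints",                                    # model_directory
            folder_paths.get_folder_paths("checkpoints")[0],  # folder_path (new argument)
            on_progress,
        )
        print(result.to_dict())

asyncio.run(fetch_checkpoint())
```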
-def validate_model_subdirectory(model_subdirectory: str) -> bool:
-    """
-    Validate that the model subdirectory is safe to install into.
-    Must not contain relative paths, nested paths or special characters
-    other than underscores and hyphens.
-
-    Args:
-        model_subdirectory (str): The subdirectory for the specific model type.
-
-    Returns:
-        bool: True if the subdirectory is safe, False otherwise.
-    """
-    if len(model_subdirectory) > 50:
-        return False
-
-    if '..' in model_subdirectory or '/' in model_subdirectory:
-        return False
-
-    if not re.match(r'^[a-zA-Z0-9_-]+$', model_subdirectory):
-        return False
-
-    return True

 def validate_filename(filename: str)-> bool:
     """
     Validate a filename to ensure it's safe and doesn't contain any path traversal attempts.

     Args:
         filename (str): The filename to validate
models/text_encoders/put_text_encoder_files_here (new empty file)

nodes.py (86 lines changed)
@@ -281,7 +281,10 @@ class VAEDecode:
     DESCRIPTION = "Decodes latent images back into pixel space images."

     def decode(self, vae, samples):
-        return (vae.decode(samples["samples"]), )
+        images = vae.decode(samples["samples"])
+        if len(images.shape) == 5: #Combine batches
+            images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
+        return (images, )

 class VAEDecodeTiled:
     @classmethod
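The decode path now collapses a 5-dimensional result (for example, a batch of frame sequences coming out of a video VAE) into one flat image batch. A small sketch of what that reshape does, with made-up dimensions:

```python
import torch

# Hypothetical decode output: 2 sequences x 8 frames x 256x256 RGB
images = torch.zeros(2, 8, 256, 256, 3)
if len(images.shape) == 5:  # Combine batches, as VAEDecode now does
    images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
print(images.shape)  # torch.Size([16, 256, 256, 3])
```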
@@ -511,10 +514,11 @@ class CheckpointLoader:
     FUNCTION = "load_checkpoint"

     CATEGORY = "advanced/loaders"
+    DEPRECATED = True

     def load_checkpoint(self, config_name, ckpt_name):
         config_path = folder_paths.get_full_path("configs", config_name)
-        ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
+        ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
         return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))

 class CheckpointLoaderSimple:
@@ -535,7 +539,7 @@ class CheckpointLoaderSimple:
     DESCRIPTION = "Loads a diffusion model checkpoint, diffusion models are used to denoise latents."

     def load_checkpoint(self, ckpt_name):
-        ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
+        ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
         out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
         return out[:3]

@@ -577,7 +581,7 @@ class unCLIPCheckpointLoader:
     CATEGORY = "loaders"

     def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
-        ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
+        ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
         out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
         return out

@@ -624,7 +628,7 @@ class LoraLoader:
         if strength_model == 0 and strength_clip == 0:
             return (model, clip)

-        lora_path = folder_paths.get_full_path("loras", lora_name)
+        lora_path = folder_paths.get_full_path_or_raise("loras", lora_name)
         lora = None
         if self.loaded_lora is not None:
             if self.loaded_lora[0] == lora_path:
@@ -703,11 +707,11 @@ class VAELoader:
             encoder = next(filter(lambda a: a.startswith("{}_encoder.".format(name)), approx_vaes))
             decoder = next(filter(lambda a: a.startswith("{}_decoder.".format(name)), approx_vaes))

-            enc = comfy.utils.load_torch_file(folder_paths.get_full_path("vae_approx", encoder))
+            enc = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("vae_approx", encoder))
             for k in enc:
                 sd["taesd_encoder.{}".format(k)] = enc[k]

-            dec = comfy.utils.load_torch_file(folder_paths.get_full_path("vae_approx", decoder))
+            dec = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("vae_approx", decoder))
             for k in dec:
                 sd["taesd_decoder.{}".format(k)] = dec[k]

@@ -738,7 +742,7 @@ class VAELoader:
         if vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]:
             sd = self.load_taesd(vae_name)
         else:
-            vae_path = folder_paths.get_full_path("vae", vae_name)
+            vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
             sd = comfy.utils.load_torch_file(vae_path)
         vae = comfy.sd.VAE(sd=sd)
         return (vae,)
@@ -754,7 +758,7 @@ class ControlNetLoader:
     CATEGORY = "loaders"

     def load_controlnet(self, control_net_name):
-        controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
+        controlnet_path = folder_paths.get_full_path_or_raise("controlnet", control_net_name)
         controlnet = comfy.controlnet.load_controlnet(controlnet_path)
         return (controlnet,)

@@ -770,7 +774,7 @@ class DiffControlNetLoader:
     CATEGORY = "loaders"

     def load_controlnet(self, model, control_net_name):
-        controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
+        controlnet_path = folder_paths.get_full_path_or_raise("controlnet", control_net_name)
         controlnet = comfy.controlnet.load_controlnet(controlnet_path, model)
         return (controlnet,)

@@ -786,6 +790,7 @@ class ControlNetApply:
     RETURN_TYPES = ("CONDITIONING",)
     FUNCTION = "apply_controlnet"

+    DEPRECATED = True
     CATEGORY = "conditioning/controlnet"

     def apply_controlnet(self, conditioning, control_net, image, strength):
@@ -815,7 +820,10 @@ class ControlNetApplyAdvanced:
                              "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                              "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
                              "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
-                             }}
+                             },
+                "optional": {"vae": ("VAE", ),
+                             }
+                }

     RETURN_TYPES = ("CONDITIONING","CONDITIONING")
     RETURN_NAMES = ("positive", "negative")
@@ -823,7 +831,7 @@ class ControlNetApplyAdvanced:

     CATEGORY = "conditioning/controlnet"

-    def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None):
+    def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None, extra_concat=[]):
         if strength == 0:
             return (positive, negative)

@@ -840,7 +848,7 @@ class ControlNetApplyAdvanced:
                 if prev_cnet in cnets:
                     c_net = cnets[prev_cnet]
                 else:
-                    c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent), vae)
+                    c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent), vae=vae, extra_concat=extra_concat)
                     c_net.set_previous_controlnet(prev_cnet)
                     cnets[prev_cnet] = c_net

@@ -856,7 +864,7 @@ class UNETLoader:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": { "unet_name": (folder_paths.get_filename_list("diffusion_models"), ),
-                              "weight_dtype": (["default", "fp8_e4m3fn", "fp8_e5m2"],)
+                              "weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"],)
                              }}
     RETURN_TYPES = ("MODEL",)
     FUNCTION = "load_unet"
@@ -867,18 +875,21 @@ class UNETLoader:
         model_options = {}
         if weight_dtype == "fp8_e4m3fn":
             model_options["dtype"] = torch.float8_e4m3fn
+        elif weight_dtype == "fp8_e4m3fn_fast":
+            model_options["dtype"] = torch.float8_e4m3fn
+            model_options["fp8_optimizations"] = True
         elif weight_dtype == "fp8_e5m2":
             model_options["dtype"] = torch.float8_e5m2

-        unet_path = folder_paths.get_full_path("diffusion_models", unet_name)
+        unet_path = folder_paths.get_full_path_or_raise("diffusion_models", unet_name)
         model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
         return (model,)

 class CLIPLoader:
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": { "clip_name": (folder_paths.get_filename_list("clip"), ),
-                              "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio"], ),
+        return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
+                              "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi"], ),
                              }}
     RETURN_TYPES = ("CLIP",)
     FUNCTION = "load_clip"
@@ -892,18 +903,20 @@ class CLIPLoader:
             clip_type = comfy.sd.CLIPType.SD3
         elif type == "stable_audio":
             clip_type = comfy.sd.CLIPType.STABLE_AUDIO
+        elif type == "mochi":
+            clip_type = comfy.sd.CLIPType.MOCHI
         else:
             clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION

-        clip_path = folder_paths.get_full_path("clip", clip_name)
+        clip_path = folder_paths.get_full_path_or_raise("text_encoders", clip_name)
         clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
         return (clip,)

 class DualCLIPLoader:
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": { "clip_name1": (folder_paths.get_filename_list("clip"), ),
-                              "clip_name2": (folder_paths.get_filename_list("clip"), ),
+        return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
+                              "clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
                               "type": (["sdxl", "sd3", "flux"], ),
                              }}
     RETURN_TYPES = ("CLIP",)
@@ -912,8 +925,8 @@ class DualCLIPLoader:
     CATEGORY = "advanced/loaders"

     def load_clip(self, clip_name1, clip_name2, type):
-        clip_path1 = folder_paths.get_full_path("clip", clip_name1)
-        clip_path2 = folder_paths.get_full_path("clip", clip_name2)
+        clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
+        clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
         if type == "sdxl":
             clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
         elif type == "sd3":
@@ -935,7 +948,7 @@ class CLIPVisionLoader:
     CATEGORY = "loaders"

     def load_clip(self, clip_name):
-        clip_path = folder_paths.get_full_path("clip_vision", clip_name)
+        clip_path = folder_paths.get_full_path_or_raise("clip_vision", clip_name)
         clip_vision = comfy.clip_vision.load(clip_path)
         return (clip_vision,)

@@ -965,7 +978,7 @@ class StyleModelLoader:
     CATEGORY = "loaders"

     def load_style_model(self, style_model_name):
-        style_model_path = folder_paths.get_full_path("style_models", style_model_name)
+        style_model_path = folder_paths.get_full_path_or_raise("style_models", style_model_name)
         style_model = comfy.sd.load_style_model(style_model_path)
         return (style_model,)

@@ -1030,7 +1043,7 @@ class GLIGENLoader:
     CATEGORY = "loaders"

     def load_gligen(self, gligen_name):
-        gligen_path = folder_paths.get_full_path("gligen", gligen_name)
+        gligen_path = folder_paths.get_full_path_or_raise("gligen", gligen_name)
         gligen = comfy.sd.load_gligen(gligen_path)
         return (gligen,)

@@ -1171,10 +1184,10 @@ class LatentUpscale:

         if width == 0:
             height = max(64, height)
-            width = max(64, round(samples["samples"].shape[3] * height / samples["samples"].shape[2]))
+            width = max(64, round(samples["samples"].shape[-1] * height / samples["samples"].shape[-2]))
         elif height == 0:
             width = max(64, width)
-            height = max(64, round(samples["samples"].shape[2] * width / samples["samples"].shape[3]))
+            height = max(64, round(samples["samples"].shape[-2] * width / samples["samples"].shape[-1]))
         else:
             width = max(64, width)
             height = max(64, height)
@@ -1196,8 +1209,8 @@ class LatentUpscaleBy:

     def upscale(self, samples, upscale_method, scale_by):
         s = samples.copy()
-        width = round(samples["samples"].shape[3] * scale_by)
-        height = round(samples["samples"].shape[2] * scale_by)
+        width = round(samples["samples"].shape[-1] * scale_by)
+        height = round(samples["samples"].shape[-2] * scale_by)
         s["samples"] = comfy.utils.common_upscale(samples["samples"], width, height, upscale_method, "disabled")
         return (s,)

@@ -1916,8 +1929,8 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "ConditioningSetArea": "Conditioning (Set Area)",
     "ConditioningSetAreaPercentage": "Conditioning (Set Area with Percentage)",
     "ConditioningSetMask": "Conditioning (Set Mask)",
-    "ControlNetApply": "Apply ControlNet",
-    "ControlNetApplyAdvanced": "Apply ControlNet (Advanced)",
+    "ControlNetApply": "Apply ControlNet (OLD)",
+    "ControlNetApplyAdvanced": "Apply ControlNet",
     # Latent
     "VAEEncodeForInpaint": "VAE Encode (for Inpainting)",
     "SetLatentNoiseMask": "Set Latent Noise Mask",
@@ -1944,6 +1957,12 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "ImageInvert": "Invert Image",
     "ImagePadForOutpaint": "Pad Image for Outpainting",
     "ImageBatch": "Batch Images",
+    "ImageCrop": "Image Crop",
+    "ImageBlend": "Image Blend",
+    "ImageBlur": "Image Blur",
+    "ImageQuantize": "Image Quantize",
+    "ImageSharpen": "Image Sharpen",
+    "ImageScaleToTotalPixels": "Scale Image to Total Pixels",
     # _for_testing
     "VAEDecodeTiled": "VAE Decode (Tiled)",
     "VAEEncodeTiled": "VAE Encode (Tiled)",
@@ -2101,6 +2120,9 @@ def init_builtin_extra_nodes():
         "nodes_controlnet.py",
         "nodes_hunyuan.py",
         "nodes_flux.py",
+        "nodes_lora_extract.py",
+        "nodes_torch_compile.py",
+        "nodes_mochi.py",
     ]

     import_failed = []
@@ -2129,3 +2151,5 @@ def init_extra_nodes(init_custom_nodes=True):
         else:
             logging.warning("Please do a: pip install -r requirements.txt")
         logging.warning("")
+
+    return import_failed
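Note how the loader changes above pair with the new models/text_encoders placeholder earlier in this diff: CLIP and text encoder files are now resolved through the "text_encoders" folder name, and the lookups use get_full_path_or_raise, which presumably fails loudly instead of returning None for a missing file. A hedged sketch of how a script or custom node would resolve a text encoder after this change:

```python
import folder_paths

# Files ComfyUI can see under models/text_encoders (plus any extra configured paths).
names = folder_paths.get_filename_list("text_encoders")

if names:
    # Raises (rather than returning None) if the file is missing.
    clip_path = folder_paths.get_full_path_or_raise("text_encoders", names[0])
    print(clip_path)
```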
@@ -79,7 +79,7 @@
     "#!wget -c https://huggingface.co/comfyanonymous/clip_vision_g/resolve/main/clip_vision_g.safetensors -P ./models/clip_vision/\n",
     "\n",
     "# SD1.5\n",
-    "!wget -c https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt -P ./models/checkpoints/\n",
+    "!wget -c https://huggingface.co/Comfy-Org/stable-diffusion-v1-5-archive/resolve/main/v1-5-pruned-emaonly-fp16.safetensors -P ./models/checkpoints/\n",
     "\n",
     "# SD2\n",
     "#!wget -c https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.safetensors -P ./models/checkpoints/\n",
@@ -43,7 +43,7 @@ prompt_text = """
     "4": {
         "class_type": "CheckpointLoaderSimple",
         "inputs": {
-            "ckpt_name": "v1-5-pruned-emaonly.ckpt"
+            "ckpt_name": "v1-5-pruned-emaonly.safetensors"
         }
     },
     "5": {
@@ -38,18 +38,20 @@ def get_images(ws, prompt):
             if data['node'] is None and data['prompt_id'] == prompt_id:
                 break #Execution is done
         else:
+            # If you want to be able to decode the binary stream for latent previews, here is how you can do it:
+            # bytesIO = BytesIO(out[8:])
+            # preview_image = Image.open(bytesIO) # This is your preview in PIL image format, store it in a global
             continue #previews are binary data

     history = get_history(prompt_id)[prompt_id]
-    for o in history['outputs']:
-        for node_id in history['outputs']:
-            node_output = history['outputs'][node_id]
-            if 'images' in node_output:
-                images_output = []
-                for image in node_output['images']:
-                    image_data = get_image(image['filename'], image['subfolder'], image['type'])
-                    images_output.append(image_data)
-                output_images[node_id] = images_output
+    for node_id in history['outputs']:
+        node_output = history['outputs'][node_id]
+        images_output = []
+        if 'images' in node_output:
+            for image in node_output['images']:
+                image_data = get_image(image['filename'], image['subfolder'], image['type'])
+                images_output.append(image_data)
+        output_images[node_id] = images_output

     return output_images

@@ -85,7 +87,7 @@ prompt_text = """
     "4": {
         "class_type": "CheckpointLoaderSimple",
         "inputs": {
-            "ckpt_name": "v1-5-pruned-emaonly.ckpt"
+            "ckpt_name": "v1-5-pruned-emaonly.safetensors"
         }
     },
     "5": {
@@ -152,7 +154,7 @@ prompt["3"]["inputs"]["seed"] = 5
 ws = websocket.WebSocket()
 ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
 images = get_images(ws, prompt)
+ws.close() # for in case this example is used in an environment where it will be repeatedly called, like in a Gradio app. otherwise, you'll randomly receive connection timeouts
 #Commented out code to display the output images:

 # for node_id in images:
@@ -81,7 +81,7 @@ prompt_text = """
     "4": {
         "class_type": "CheckpointLoaderSimple",
         "inputs": {
-            "ckpt_name": "v1-5-pruned-emaonly.ckpt"
+            "ckpt_name": "v1-5-pruned-emaonly.safetensors"
         }
     },
     "5": {
@@ -147,7 +147,7 @@ prompt["3"]["inputs"]["seed"] = 5
 ws = websocket.WebSocket()
 ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
 images = get_images(ws, prompt)
+ws.close() # for in case this example is used in an environment where it will be repeatedly called, like in a Gradio app. otherwise, you'll randomly receive connection timeouts
 #Commented out code to display the output images:

 # for node_id in images:
server.py (158 lines changed)
@@ -12,6 +12,8 @@ import json
 import glob
 import struct
 import ssl
+import socket
+import ipaddress
 from PIL import Image, ImageOps
 from PIL.PngImagePlugin import PngInfo
 from io import BytesIO
@@ -31,7 +33,6 @@ from model_filemanager import download_model, DownloadModelStatus
 from typing import Optional
 from api_server.routes.internal.internal_routes import InternalRoutes
-

 class BinaryEventTypes:
     PREVIEW_IMAGE = 1
     UNENCODED_PREVIEW_IMAGE = 2
@@ -39,9 +40,24 @@ class BinaryEventTypes:
 async def send_socket_catch_exception(function, message):
     try:
         await function(message)
-    except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError) as err:
+    except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError, BrokenPipeError, ConnectionError) as err:
         logging.warning("send error: {}".format(err))

+def get_comfyui_version():
+    comfyui_version = "unknown"
+    repo_path = os.path.dirname(os.path.realpath(__file__))
+    try:
+        import pygit2
+        repo = pygit2.Repository(repo_path)
+        comfyui_version = repo.describe(describe_strategy=pygit2.GIT_DESCRIBE_TAGS)
+    except Exception:
+        try:
+            import subprocess
+            comfyui_version = subprocess.check_output(["git", "describe", "--tags"], cwd=repo_path).decode('utf-8')
+        except Exception as e:
+            logging.warning(f"Failed to get ComfyUI version: {e}")
+    return comfyui_version.strip()
+
 @web.middleware
 async def cache_control(request: web.Request, handler):
     response: web.Response = await handler(request)
@@ -66,6 +82,68 @@ def create_cors_middleware(allowed_origin: str):

     return cors_middleware

+def is_loopback(host):
+    if host is None:
+        return False
+    try:
+        if ipaddress.ip_address(host).is_loopback:
+            return True
+        else:
+            return False
+    except:
+        pass
+
+    loopback = False
+    for family in (socket.AF_INET, socket.AF_INET6):
+        try:
+            r = socket.getaddrinfo(host, None, family, socket.SOCK_STREAM)
+            for family, _, _, _, sockaddr in r:
+                if not ipaddress.ip_address(sockaddr[0]).is_loopback:
+                    return loopback
+                else:
+                    loopback = True
+        except socket.gaierror:
+            pass
+
+    return loopback
+
+
+def create_origin_only_middleware():
+    @web.middleware
+    async def origin_only_middleware(request: web.Request, handler):
+        #this code is used to prevent the case where a random website can queue comfy workflows by making a POST to 127.0.0.1 which browsers don't prevent for some dumb reason.
+        #in that case the Host and Origin hostnames won't match
+        #I know the proper fix would be to add a cookie but this should take care of the problem in the meantime
+        if 'Host' in request.headers and 'Origin' in request.headers:
+            host = request.headers['Host']
+            origin = request.headers['Origin']
+            host_domain = host.lower()
+            parsed = urllib.parse.urlparse(origin)
+            origin_domain = parsed.netloc.lower()
+            host_domain_parsed = urllib.parse.urlsplit('//' + host_domain)
+
+            #limit the check to when the host domain is localhost, this makes it slightly less safe but should still prevent the exploit
+            loopback = is_loopback(host_domain_parsed.hostname)
+
+            if parsed.port is None: #if origin doesn't have a port strip it from the host to handle weird browsers, same for host
+                host_domain = host_domain_parsed.hostname
+            if host_domain_parsed.port is None:
+                origin_domain = parsed.hostname
+
+            if loopback and host_domain is not None and origin_domain is not None and len(host_domain) > 0 and len(origin_domain) > 0:
+                if host_domain != origin_domain:
+                    logging.warning("WARNING: request with non matching host and origin {} != {}, returning 403".format(host_domain, origin_domain))
+                    return web.Response(status=403)
+
+        if request.method == "OPTIONS":
+            response = web.Response()
+        else:
+            response = await handler(request)
+
+        return response
+
+    return origin_only_middleware
+
 class PromptServer():
     def __init__(self, loop):
         PromptServer.instance = self
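A rough illustration of how the new origin check classifies hosts; results for names (as opposed to IP literals) depend on the local resolver, so treat these as typical values rather than guarantees:

```python
# is_loopback() first tries the host as an IP literal, then falls back to
# resolving it and only reports loopback if every returned address is loopback.
print(is_loopback("127.0.0.1"))    # True  - loopback IP literal
print(is_loopback("::1"))          # True  - IPv6 loopback literal
print(is_loopback("localhost"))    # typically True, via getaddrinfo()
print(is_loopback("example.com"))  # False - resolves to public addresses
```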
@@ -85,6 +163,8 @@ class PromptServer():
         middlewares = [cache_control]
         if args.enable_cors_header:
             middlewares.append(create_cors_middleware(args.enable_cors_header))
+        else:
+            middlewares.append(create_origin_only_middleware())

         max_upload_size = round(args.max_upload_size * 1024 * 1024)
         self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares)
@@ -141,6 +221,12 @@ class PromptServer():
         def get_embeddings(self):
             embeddings = folder_paths.get_filename_list("embeddings")
             return web.json_response(list(map(lambda a: os.path.splitext(a)[0], embeddings)))

+        @routes.get("/models")
+        def list_model_types(request):
+            model_types = list(folder_paths.folder_names_and_paths.keys())
+
+            return web.json_response(model_types)
+
         @routes.get("/models/{folder}")
         async def get_models(request):
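The new GET /models route lists the configured model folder types, complementing the existing GET /models/{folder} file listing. A quick client-side example, assuming the default listen address of 127.0.0.1:8188:

```python
import json
import urllib.request

base = "http://127.0.0.1:8188"  # assumed default ComfyUI address

# Folder types, e.g. ["checkpoints", "loras", "vae", ...]
with urllib.request.urlopen(f"{base}/models") as resp:
    print(json.loads(resp.read()))

# Files visible for one folder type
with urllib.request.urlopen(f"{base}/models/checkpoints") as resp:
    print(json.loads(resp.read()))
```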
@@ -401,16 +487,25 @@ class PromptServer():
             return web.json_response(dt["__metadata__"])

         @routes.get("/system_stats")
-        async def get_queue(request):
+        async def system_stats(request):
             device = comfy.model_management.get_torch_device()
             device_name = comfy.model_management.get_torch_device_name(device)
+            cpu_device = comfy.model_management.torch.device("cpu")
+            ram_total = comfy.model_management.get_total_memory(cpu_device)
+            ram_free = comfy.model_management.get_free_memory(cpu_device)
             vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
             vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)

             system_stats = {
                 "system": {
                     "os": os.name,
+                    "ram_total": ram_total,
+                    "ram_free": ram_free,
+                    "comfyui_version": get_comfyui_version(),
                     "python_version": sys.version,
-                    "embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded"
+                    "pytorch_version": comfy.model_management.torch_version,
+                    "embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded",
+                    "argv": sys.argv
                 },
                 "devices": [
                     {
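With these additions, /system_stats reports RAM, the ComfyUI version and the PyTorch version alongside the existing fields. Roughly what a response now looks like; all values are invented, and the per-device entries (unchanged by this diff) are elided:

```python
example_system_stats = {
    "system": {
        "os": "posix",
        "ram_total": 67108864000,
        "ram_free": 32212254720,
        "comfyui_version": "v0.2.2",        # from get_comfyui_version()
        "python_version": "3.11.9 (main, ...)",
        "pytorch_version": "2.4.1+cu121",   # comfy.model_management.torch_version
        "embedded_python": False,
        "argv": ["main.py", "--listen"],
    },
    "devices": [
        # ... unchanged device entries ...
    ],
}
```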
@@ -462,14 +557,15 @@ class PromptServer():

         @routes.get("/object_info")
         async def get_object_info(request):
-            out = {}
-            for x in nodes.NODE_CLASS_MAPPINGS:
-                try:
-                    out[x] = node_info(x)
-                except Exception as e:
-                    logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.")
-                    logging.error(traceback.format_exc())
-            return web.json_response(out)
+            with folder_paths.cache_helper:
+                out = {}
+                for x in nodes.NODE_CLASS_MAPPINGS:
+                    try:
+                        out[x] = node_info(x)
+                    except Exception as e:
+                        logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.")
+                        logging.error(traceback.format_exc())
+                return web.json_response(out)

         @routes.get("/object_info/{node_class}")
         async def get_object_info_node(request):
@@ -583,18 +679,22 @@ class PromptServer():

         # Internal route. Should not be depended upon and is subject to change at any time.
         # TODO(robinhuang): Move to internal route table class once we refactor PromptServer to pass around Websocket.
+        # NOTE: This was an experiment and WILL BE REMOVED
         @routes.post("/internal/models/download")
         async def download_handler(request):
             async def report_progress(filename: str, status: DownloadModelStatus):
-                await self.send_json("download_progress", status.to_dict())
+                payload = status.to_dict()
+                payload['download_path'] = filename
+                await self.send_json("download_progress", payload)

             data = await request.json()
             url = data.get('url')
             model_directory = data.get('model_directory')
+            folder_path = data.get('folder_path')
             model_filename = data.get('model_filename')
             progress_interval = data.get('progress_interval', 1.0) # In seconds, how often to report download progress.

-            if not url or not model_directory or not model_filename:
+            if not url or not model_directory or not model_filename or not folder_path:
                 return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400)

             session = self.client_session
@@ -602,7 +702,7 @@ class PromptServer():
                 logging.error("Client session is not initialized")
                 return web.Response(status=500)

-            task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, report_progress, progress_interval))
+            task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, folder_path, report_progress, progress_interval))
             await task

             return web.json_response(task.result().to_dict())
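The internal download route (flagged above as an experiment that will be removed) now requires folder_path in addition to the previous fields, and the progress events it emits carry the filename as download_path. A hedged example of posting to it; the URL and paths are illustrative:

```python
import json
import urllib.request

body = {
    "url": "https://example.com/model.safetensors",        # illustrative
    "model_directory": "checkpoints",
    "folder_path": "/path/to/ComfyUI/models/checkpoints",  # new required field
    "model_filename": "model.safetensors",
    "progress_interval": 1.0,
}
req = urllib.request.Request(
    "http://127.0.0.1:8188/internal/models/download",
    data=json.dumps(body).encode(),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.loads(resp.read()))  # final DownloadModelStatus as a dict
```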
@@ -719,6 +819,9 @@ class PromptServer():
         await self.send(*msg)

     async def start(self, address, port, verbose=True, call_on_start=None):
+        await self.start_multi_address([(address, port)], call_on_start=call_on_start)
+
+    async def start_multi_address(self, addresses, call_on_start=None):
         runner = web.AppRunner(self.app, access_log=None)
         await runner.setup()
         ssl_ctx = None
@@ -729,17 +832,26 @@ class PromptServer():
                                     keyfile=args.tls_keyfile)
             scheme = "https"

-        site = web.TCPSite(runner, address, port, ssl_context=ssl_ctx)
-        await site.start()
+        logging.info("Starting server\n")
+        for addr in addresses:
+            address = addr[0]
+            port = addr[1]
+            site = web.TCPSite(runner, address, port, ssl_context=ssl_ctx)
+            await site.start()

-        self.address = address
-        self.port = port
+            if not hasattr(self, 'address'):
+                self.address = address #TODO: remove this
+                self.port = port

-        if verbose:
-            logging.info("Starting server\n")
-            logging.info("To see the GUI go to: {}://{}:{}".format(scheme, address, port))
+            if ':' in address:
+                address_print = "[{}]".format(address)
+            else:
+                address_print = address
+
+            logging.info("To see the GUI go to: {}://{}:{}".format(scheme, address_print, port))
+
         if call_on_start is not None:
-            call_on_start(scheme, address, port)
+            call_on_start(scheme, self.address, self.port)

     def add_on_prompt_handler(self, handler):
         self.on_prompt_handlers.append(handler)
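start() is now a thin wrapper over start_multi_address(), which binds one TCPSite per (address, port) pair and logs a GUI URL for each. A hedged sketch of how a caller such as main.py could use it to listen on both IPv4 and IPv6 loopback; it assumes an initialized PromptServer named server inside a running event loop:

```python
await server.start_multi_address(
    [("127.0.0.1", 8188), ("::1", 8188)],
    call_on_start=lambda scheme, address, port: print(f"{scheme}://{address}:{port}"),
)
```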
@@ -2,7 +2,7 @@

 ## Install test dependencies

-`pip install -r tests-units/requirements.txt`
+`pip install -r tests-unit/requirements.txt`

 ## Run tests
-`pytest tests-units/`
+`pytest tests-unit/`
@@ -1,6 +1,7 @@
 import argparse
 import pytest
 from requests.exceptions import HTTPError
+from unittest.mock import patch

 from app.frontend_management import (
     FrontendManager,
@@ -83,6 +84,35 @@ def test_init_frontend_invalid_provider():
     with pytest.raises(HTTPError):
         FrontendManager.init_frontend_unsafe(version_string)

+@pytest.fixture
+def mock_os_functions():
+    with patch('app.frontend_management.os.makedirs') as mock_makedirs, \
+         patch('app.frontend_management.os.listdir') as mock_listdir, \
+         patch('app.frontend_management.os.rmdir') as mock_rmdir:
+        mock_listdir.return_value = []  # Simulate empty directory
+        yield mock_makedirs, mock_listdir, mock_rmdir
+
+@pytest.fixture
+def mock_download():
+    with patch('app.frontend_management.download_release_asset_zip') as mock:
+        mock.side_effect = Exception("Download failed")  # Simulate download failure
+        yield mock
+
+def test_finally_block(mock_os_functions, mock_download, mock_provider):
+    # Arrange
+    mock_makedirs, mock_listdir, mock_rmdir = mock_os_functions
+    version_string = 'test-owner/test-repo@1.0.0'
+
+    # Act & Assert
+    with pytest.raises(Exception):
+        FrontendManager.init_frontend_unsafe(version_string, mock_provider)
+
+    # Assert
+    mock_makedirs.assert_called_once()
+    mock_download.assert_called_once()
+    mock_listdir.assert_called_once()
+    mock_rmdir.assert_called_once()
+
 def test_parse_version_string():
     version_string = "owner/repo@1.0.0"
tests-unit/comfy_test/folder_path_test.py (new file, 66 lines)
@@ -0,0 +1,66 @@
### 🗻 This file is created through the spirit of Mount Fuji at its peak
# TODO(yoland): clean up this after I get back down
import pytest
import os
import tempfile
from unittest.mock import patch

import folder_paths

@pytest.fixture
def temp_dir():
    with tempfile.TemporaryDirectory() as tmpdirname:
        yield tmpdirname


def test_get_directory_by_type():
    test_dir = "/test/dir"
    folder_paths.set_output_directory(test_dir)
    assert folder_paths.get_directory_by_type("output") == test_dir
    assert folder_paths.get_directory_by_type("invalid") is None

def test_annotated_filepath():
    assert folder_paths.annotated_filepath("test.txt") == ("test.txt", None)
    assert folder_paths.annotated_filepath("test.txt [output]") == ("test.txt", folder_paths.get_output_directory())
    assert folder_paths.annotated_filepath("test.txt [input]") == ("test.txt", folder_paths.get_input_directory())
    assert folder_paths.annotated_filepath("test.txt [temp]") == ("test.txt", folder_paths.get_temp_directory())

def test_get_annotated_filepath():
    default_dir = "/default/dir"
    assert folder_paths.get_annotated_filepath("test.txt", default_dir) == os.path.join(default_dir, "test.txt")
    assert folder_paths.get_annotated_filepath("test.txt [output]") == os.path.join(folder_paths.get_output_directory(), "test.txt")

def test_add_model_folder_path():
    folder_paths.add_model_folder_path("test_folder", "/test/path")
    assert "/test/path" in folder_paths.get_folder_paths("test_folder")

def test_recursive_search(temp_dir):
    os.makedirs(os.path.join(temp_dir, "subdir"))
    open(os.path.join(temp_dir, "file1.txt"), "w").close()
    open(os.path.join(temp_dir, "subdir", "file2.txt"), "w").close()

    files, dirs = folder_paths.recursive_search(temp_dir)
    assert set(files) == {"file1.txt", os.path.join("subdir", "file2.txt")}
    assert len(dirs) == 2  # temp_dir and subdir

def test_filter_files_extensions():
    files = ["file1.txt", "file2.jpg", "file3.png", "file4.txt"]
    assert folder_paths.filter_files_extensions(files, [".txt"]) == ["file1.txt", "file4.txt"]
    assert folder_paths.filter_files_extensions(files, [".jpg", ".png"]) == ["file2.jpg", "file3.png"]
    assert folder_paths.filter_files_extensions(files, []) == files

@patch("folder_paths.recursive_search")
@patch("folder_paths.folder_names_and_paths")
def test_get_filename_list(mock_folder_names_and_paths, mock_recursive_search):
    mock_folder_names_and_paths.__getitem__.return_value = (["/test/path"], {".txt"})
    mock_recursive_search.return_value = (["file1.txt", "file2.jpg"], {})
    assert folder_paths.get_filename_list("test_folder") == ["file1.txt"]

def test_get_save_image_path(temp_dir):
    with patch("folder_paths.output_directory", temp_dir):
        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path("test", temp_dir, 100, 100)
        assert os.path.samefile(full_output_folder, temp_dir)
        assert filename == "test"
        assert counter == 1
        assert subfolder == ""
        assert filename_prefix == "test"
tests-unit/folder_paths_test/__init__.py (new empty file)

tests-unit/folder_paths_test/filter_by_content_types_test.py (new file, 52 lines)
@@ -0,0 +1,52 @@
import pytest
import os
import tempfile
from folder_paths import filter_files_content_types


@pytest.fixture(scope="module")
def file_extensions():
    return {
        'image': ['gif', 'heif', 'ico', 'jpeg', 'jpg', 'png', 'pnm', 'ppm', 'svg', 'tiff', 'webp', 'xbm', 'xpm'],
        'audio': ['aif', 'aifc', 'aiff', 'au', 'flac', 'm4a', 'mp2', 'mp3', 'ogg', 'snd', 'wav'],
        'video': ['avi', 'm2v', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ogv', 'qt', 'webm', 'wmv']
    }


@pytest.fixture(scope="module")
def mock_dir(file_extensions):
    with tempfile.TemporaryDirectory() as directory:
        for content_type, extensions in file_extensions.items():
            for extension in extensions:
                with open(f"{directory}/sample_{content_type}.{extension}", "w") as f:
                    f.write(f"Sample {content_type} file in {extension} format")
        yield directory


def test_categorizes_all_correctly(mock_dir, file_extensions):
    files = os.listdir(mock_dir)
    for content_type, extensions in file_extensions.items():
        filtered_files = filter_files_content_types(files, [content_type])
        for extension in extensions:
            assert f"sample_{content_type}.{extension}" in filtered_files


def test_categorizes_all_uniquely(mock_dir, file_extensions):
    files = os.listdir(mock_dir)
    for content_type, extensions in file_extensions.items():
        filtered_files = filter_files_content_types(files, [content_type])
        assert len(filtered_files) == len(extensions)


def test_handles_bad_extensions():
    files = ["file1.txt", "file2.py", "file3.example", "file4.pdf", "file5.ini", "file6.doc", "file7.md"]
    assert filter_files_content_types(files, ["image", "audio", "video"]) == []


def test_handles_no_extension():
    files = ["file1", "file2", "file3", "file4", "file5", "file6", "file7"]
    assert filter_files_content_types(files, ["image", "audio", "video"]) == []


def test_handles_no_files():
    files = []
    assert filter_files_content_types(files, ["image", "audio", "video"]) == []
@@ -1,10 +1,17 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
import tempfile
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from aiohttp import ClientResponse
|
from aiohttp import ClientResponse
|
||||||
import itertools
|
import itertools
|
||||||
import os
|
import os
|
||||||
from unittest.mock import AsyncMock, patch, MagicMock
|
from unittest.mock import AsyncMock, patch, MagicMock
|
||||||
from model_filemanager import download_model, validate_model_subdirectory, track_download_progress, create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus, validate_filename
|
from model_filemanager import download_model, track_download_progress, create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus, validate_filename
|
||||||
|
import folder_paths
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_dir():
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
yield tmpdirname
|
||||||
|
|
||||||
class AsyncIteratorMock:
|
class AsyncIteratorMock:
|
||||||
"""
|
"""
|
||||||
@@ -42,7 +49,7 @@ class ContentMock:
|
|||||||
return AsyncIteratorMock(self.chunks)
|
return AsyncIteratorMock(self.chunks)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_download_model_success():
|
async def test_download_model_success(temp_dir):
|
||||||
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
|
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
|
||||||
mock_response.status = 200
|
mock_response.status = 200
|
||||||
mock_response.headers = {'Content-Length': '1000'}
|
mock_response.headers = {'Content-Length': '1000'}
|
||||||
@@ -53,15 +60,13 @@ async def test_download_model_success():
|
|||||||
mock_make_request = AsyncMock(return_value=mock_response)
|
mock_make_request = AsyncMock(return_value=mock_response)
|
||||||
mock_progress_callback = AsyncMock()
|
mock_progress_callback = AsyncMock()
|
||||||
|
|
||||||
# Mock file operations
|
|
||||||
mock_open = MagicMock()
|
|
||||||
mock_file = MagicMock()
|
|
||||||
mock_open.return_value.__enter__.return_value = mock_file
|
|
||||||
time_values = itertools.count(0, 0.1)
|
time_values = itertools.count(0, 0.1)
|
||||||
|
|
||||||
with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'checkpoints/model.sft')), \
|
fake_paths = {'checkpoints': ([temp_dir], folder_paths.supported_pt_extensions)}
|
||||||
|
|
||||||
|
with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'model.sft')), \
|
||||||
patch('model_filemanager.check_file_exists', return_value=None), \
|
patch('model_filemanager.check_file_exists', return_value=None), \
|
||||||
patch('builtins.open', mock_open), \
|
patch('folder_paths.folder_names_and_paths', fake_paths), \
|
||||||
patch('time.time', side_effect=time_values): # Simulate time passing
|
patch('time.time', side_effect=time_values): # Simulate time passing
|
||||||
|
|
||||||
result = await download_model(
|
result = await download_model(
|
||||||
@@ -69,6 +74,7 @@ async def test_download_model_success():
             'model.sft',
             'http://example.com/model.sft',
             'checkpoints',
+            temp_dir,
             mock_progress_callback
         )

@@ -83,44 +89,48 @@ async def test_download_model_success():

         # Check initial call
         mock_progress_callback.assert_any_call(
-            'checkpoints/model.sft',
+            'model.sft',
             DownloadModelStatus(DownloadStatusType.PENDING, 0, "Starting download of model.sft", False)
         )

         # Check final call
         mock_progress_callback.assert_any_call(
-            'checkpoints/model.sft',
+            'model.sft',
             DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.sft", False)
         )

-        # Verify file writing
-        mock_file.write.assert_any_call(b'a' * 500)
-        mock_file.write.assert_any_call(b'b' * 300)
-        mock_file.write.assert_any_call(b'c' * 200)
+        mock_file_path = os.path.join(temp_dir, 'model.sft')
+        assert os.path.exists(mock_file_path)
+        with open(mock_file_path, 'rb') as mock_file:
+            assert mock_file.read() == b''.join(chunks)
+        os.remove(mock_file_path)

         # Verify request was made
         mock_make_request.assert_called_once_with('http://example.com/model.sft')

 @pytest.mark.asyncio
-async def test_download_model_url_request_failure():
+async def test_download_model_url_request_failure(temp_dir):
     # Mock dependencies
     mock_response = AsyncMock(spec=ClientResponse)
     mock_response.status = 404  # Simulate a "Not Found" error
     mock_get = AsyncMock(return_value=mock_response)
     mock_progress_callback = AsyncMock()

+    fake_paths = {'checkpoints': ([temp_dir], folder_paths.supported_pt_extensions)}
+
     # Mock the create_model_path function
-    with patch('model_filemanager.create_model_path', return_value=('/mock/path/model.safetensors', 'mock/path/model.safetensors')):
-        # Mock the check_file_exists function to return None (file doesn't exist)
-        with patch('model_filemanager.check_file_exists', return_value=None):
+    with patch('model_filemanager.create_model_path', return_value='/mock/path/model.safetensors'), \
+        patch('model_filemanager.check_file_exists', return_value=None), \
+        patch('folder_paths.folder_names_and_paths', fake_paths):
         # Call the function
         result = await download_model(
             mock_get,
             'model.safetensors',
             'http://example.com/model.safetensors',
-            'mock_directory',
+            'checkpoints',
+            temp_dir,
             mock_progress_callback
         )

     # Assert the expected behavior
     assert isinstance(result, DownloadModelStatus)
@@ -130,7 +140,7 @@ async def test_download_model_url_request_failure():

     # Check that progress_callback was called with the correct arguments
     mock_progress_callback.assert_any_call(
-        'mock_directory/model.safetensors',
+        'model.safetensors',
         DownloadModelStatus(
             status=DownloadStatusType.PENDING,
             progress_percentage=0,
@@ -139,7 +149,7 @@ async def test_download_model_url_request_failure():
         )
     )
     mock_progress_callback.assert_called_with(
-        'mock_directory/model.safetensors',
+        'model.safetensors',
         DownloadModelStatus(
             status=DownloadStatusType.ERROR,
             progress_percentage=0,
@@ -153,98 +163,125 @@ async def test_download_model_url_request_failure():

 @pytest.mark.asyncio
 async def test_download_model_invalid_model_subdirectory():

     mock_make_request = AsyncMock()
     mock_progress_callback = AsyncMock()


     result = await download_model(
         mock_make_request,
         'model.sft',
         'http://example.com/model.sft',
         '../bad_path',
+        '../bad_path',
         mock_progress_callback
     )

     # Assert the result
     assert isinstance(result, DownloadModelStatus)
-    assert result.message == 'Invalid model subdirectory'
+    assert result.message.startswith('Invalid or unrecognized model directory')
     assert result.status == 'error'
     assert result.already_existed is False

+@pytest.mark.asyncio
+async def test_download_model_invalid_folder_path():
+    mock_make_request = AsyncMock()
+    mock_progress_callback = AsyncMock()
+
+    result = await download_model(
+        mock_make_request,
+        'model.sft',
+        'http://example.com/model.sft',
+        'checkpoints',
+        'invalid_path',
+        mock_progress_callback
+    )
+
+    # Assert the result
+    assert isinstance(result, DownloadModelStatus)
+    assert result.message.startswith("Invalid folder path")
+    assert result.status == 'error'
+    assert result.already_existed is False
+
-# For create_model_path function
 def test_create_model_path(tmp_path, monkeypatch):
-    mock_models_dir = tmp_path / "models"
-    monkeypatch.setattr('folder_paths.models_dir', str(mock_models_dir))
-
-    model_name = "test_model.sft"
-    model_directory = "test_dir"
-
-    file_path, relative_path = create_model_path(model_name, model_directory, mock_models_dir)
-
-    assert file_path == str(mock_models_dir / model_directory / model_name)
-    assert relative_path == f"{model_directory}/{model_name}"
+    model_name = "model.safetensors"
+    folder_path = os.path.join(tmp_path, "mock_dir")
+
+    file_path = create_model_path(model_name, folder_path)
+
+    assert file_path == os.path.join(folder_path, "model.safetensors")
     assert os.path.exists(os.path.dirname(file_path))

+    with pytest.raises(Exception, match="Invalid model directory"):
+        create_model_path("../path_traversal.safetensors", folder_path)
+
+    with pytest.raises(Exception, match="Invalid model directory"):
+        create_model_path("/etc/some_root_path", folder_path)
+

 @pytest.mark.asyncio
 async def test_check_file_exists_when_file_exists(tmp_path):
     file_path = tmp_path / "existing_model.sft"
     file_path.touch()  # Create an empty file

     mock_callback = AsyncMock()

-    result = await check_file_exists(str(file_path), "existing_model.sft", mock_callback, "test/existing_model.sft")
+    result = await check_file_exists(str(file_path), "existing_model.sft", mock_callback)

     assert result is not None
     assert result.status == "completed"
     assert result.message == "existing_model.sft already exists"
     assert result.already_existed is True

     mock_callback.assert_called_once_with(
-        "test/existing_model.sft",
+        "existing_model.sft",
         DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "existing_model.sft already exists", already_existed=True)
     )

 @pytest.mark.asyncio
 async def test_check_file_exists_when_file_does_not_exist(tmp_path):
     file_path = tmp_path / "non_existing_model.sft"

     mock_callback = AsyncMock()

-    result = await check_file_exists(str(file_path), "non_existing_model.sft", mock_callback, "test/non_existing_model.sft")
+    result = await check_file_exists(str(file_path), "non_existing_model.sft", mock_callback)

     assert result is None
     mock_callback.assert_not_called()

 @pytest.mark.asyncio
-async def test_track_download_progress_no_content_length():
+async def test_track_download_progress_no_content_length(temp_dir):
     mock_response = AsyncMock(spec=aiohttp.ClientResponse)
     mock_response.headers = {}  # No Content-Length header
-    mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 500, b'b' * 500])
+    chunks = [b'a' * 500, b'b' * 500]
+    mock_response.content.iter_chunked.return_value = AsyncIteratorMock(chunks)

     mock_callback = AsyncMock()
-    mock_open = MagicMock(return_value=MagicMock())

-    with patch('builtins.open', mock_open):
-        result = await track_download_progress(
-            mock_response, '/mock/path/model.sft', 'model.sft',
-            mock_callback, 'models/model.sft', interval=0.1
-        )
+    full_path = os.path.join(temp_dir, 'model.sft')
+
+    result = await track_download_progress(
+        mock_response, full_path, 'model.sft',
+        mock_callback, interval=0.1
+    )

     assert result.status == "completed"

+    assert os.path.exists(full_path)
+    with open(full_path, 'rb') as f:
+        assert f.read() == b''.join(chunks)
+    os.remove(full_path)
+
     # Check that progress was reported even without knowing the total size
     mock_callback.assert_any_call(
-        'models/model.sft',
+        'model.sft',
         DownloadModelStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.sft", already_existed=False)
     )

 @pytest.mark.asyncio
-async def test_track_download_progress_interval():
+async def test_track_download_progress_interval(temp_dir):
     mock_response = AsyncMock(spec=aiohttp.ClientResponse)
     mock_response.headers = {'Content-Length': '1000'}
-    mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 100] * 10)
+    chunks = [b'a' * 100] * 10
+    mock_response.content.iter_chunked.return_value = AsyncIteratorMock(chunks)

     mock_callback = AsyncMock()
     mock_open = MagicMock(return_value=MagicMock())
@@ -253,18 +290,18 @@ async def test_track_download_progress_interval():
     mock_time = MagicMock()
     mock_time.side_effect = [i * 0.5 for i in range(30)]  # This should be enough for 10 chunks

-    with patch('builtins.open', mock_open), \
-        patch('time.time', mock_time):
-        await track_download_progress(
-            mock_response, '/mock/path/model.sft', 'model.sft',
-            mock_callback, 'models/model.sft', interval=1.0
-        )
+    full_path = os.path.join(temp_dir, 'model.sft')

-    # Print out the actual call count and the arguments of each call for debugging
-    print(f"mock_callback was called {mock_callback.call_count} times")
-    for i, call in enumerate(mock_callback.call_args_list):
-        args, kwargs = call
-        print(f"Call {i + 1}: {args[1].status}, Progress: {args[1].progress_percentage:.2f}%")
+    with patch('time.time', mock_time):
+        await track_download_progress(
+            mock_response, full_path, 'model.sft',
+            mock_callback, interval=1.0
+        )
+
+    assert os.path.exists(full_path)
+    with open(full_path, 'rb') as f:
+        assert f.read() == b''.join(chunks)
+    os.remove(full_path)

     # Assert that progress was updated at least 3 times (start, at least one interval, and end)
     assert mock_callback.call_count >= 3, f"Expected at least 3 calls, but got {mock_callback.call_count}"
@@ -279,27 +316,6 @@ async def test_track_download_progress_interval():
     assert last_call[0][1].status == "completed"
     assert last_call[0][1].progress_percentage == 100

-def test_valid_subdirectory():
-    assert validate_model_subdirectory("valid-model123") is True
-
-def test_subdirectory_too_long():
-    assert validate_model_subdirectory("a" * 51) is False
-
-def test_subdirectory_with_double_dots():
-    assert validate_model_subdirectory("model/../unsafe") is False
-
-def test_subdirectory_with_slash():
-    assert validate_model_subdirectory("model/unsafe") is False
-
-def test_subdirectory_with_special_characters():
-    assert validate_model_subdirectory("model@unsafe") is False
-
-def test_subdirectory_with_underscore_and_dash():
-    assert validate_model_subdirectory("valid_model-name") is True
-
-def test_empty_subdirectory():
-    assert validate_model_subdirectory("") is False
-
 @pytest.mark.parametrize("filename, expected", [
     ("valid_model.safetensors", True),
     ("valid_model.sft", True),
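The recurring change across the hunks above is that the download tests stop patching builtins.open and instead write through a real temporary directory supplied by the new temp_dir fixture, then assert on the bytes that actually landed on disk. A minimal, self-contained sketch of that pattern, independent of the ComfyUI code (the save_bytes helper is invented purely for illustration):

import os
import tempfile
import pytest

@pytest.fixture
def temp_dir():
    # Each test gets a fresh directory that is cleaned up automatically afterwards.
    with tempfile.TemporaryDirectory() as tmpdirname:
        yield tmpdirname

def save_bytes(path, chunks):
    # Stand-in for the code under test: stream chunks to disk.
    with open(path, 'wb') as f:
        for chunk in chunks:
            f.write(chunk)

def test_writes_real_file(temp_dir):
    chunks = [b'a' * 500, b'b' * 300]
    target = os.path.join(temp_dir, 'model.sft')
    save_bytes(target, chunks)
    # Assert against the actual file contents instead of mock.write calls.
    with open(target, 'rb') as f:
        assert f.read() == b''.join(chunks)

This trades a little speed for much stronger assertions: the tests verify real filesystem behaviour rather than the exact sequence of write calls.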
120 tests-unit/prompt_server_test/user_manager_test.py Normal file
@@ -0,0 +1,120 @@
import pytest
import os
from aiohttp import web
from app.user_manager import UserManager
from unittest.mock import patch

pytestmark = (
    pytest.mark.asyncio
)  # This applies the asyncio mark to all test functions in the module


@pytest.fixture
def user_manager(tmp_path):
    um = UserManager()
    um.get_request_user_filepath = lambda req, file, **kwargs: os.path.join(
        tmp_path, file
    )
    return um


@pytest.fixture
def app(user_manager):
    app = web.Application()
    routes = web.RouteTableDef()
    user_manager.add_routes(routes)
    app.add_routes(routes)
    return app


async def test_listuserdata_empty_directory(aiohttp_client, app, tmp_path):
    client = await aiohttp_client(app)
    resp = await client.get("/userdata?dir=test_dir")
    assert resp.status == 404


async def test_listuserdata_with_files(aiohttp_client, app, tmp_path):
    os.makedirs(tmp_path / "test_dir")
    with open(tmp_path / "test_dir" / "file1.txt", "w") as f:
        f.write("test content")

    client = await aiohttp_client(app)
    resp = await client.get("/userdata?dir=test_dir")
    assert resp.status == 200
    assert await resp.json() == ["file1.txt"]


async def test_listuserdata_recursive(aiohttp_client, app, tmp_path):
    os.makedirs(tmp_path / "test_dir" / "subdir")
    with open(tmp_path / "test_dir" / "file1.txt", "w") as f:
        f.write("test content")
    with open(tmp_path / "test_dir" / "subdir" / "file2.txt", "w") as f:
        f.write("test content")

    client = await aiohttp_client(app)
    resp = await client.get("/userdata?dir=test_dir&recurse=true")
    assert resp.status == 200
    assert set(await resp.json()) == {"file1.txt", "subdir/file2.txt"}


async def test_listuserdata_full_info(aiohttp_client, app, tmp_path):
    os.makedirs(tmp_path / "test_dir")
    with open(tmp_path / "test_dir" / "file1.txt", "w") as f:
        f.write("test content")

    client = await aiohttp_client(app)
    resp = await client.get("/userdata?dir=test_dir&full_info=true")
    assert resp.status == 200
    result = await resp.json()
    assert len(result) == 1
    assert result[0]["path"] == "file1.txt"
    assert "size" in result[0]
    assert "modified" in result[0]


async def test_listuserdata_split_path(aiohttp_client, app, tmp_path):
    os.makedirs(tmp_path / "test_dir" / "subdir")
    with open(tmp_path / "test_dir" / "subdir" / "file1.txt", "w") as f:
        f.write("test content")

    client = await aiohttp_client(app)
    resp = await client.get("/userdata?dir=test_dir&recurse=true&split=true")
    assert resp.status == 200
    assert await resp.json() == [
        ["subdir/file1.txt", "subdir", "file1.txt"]
    ]


async def test_listuserdata_invalid_directory(aiohttp_client, app):
    client = await aiohttp_client(app)
    resp = await client.get("/userdata?dir=")
    assert resp.status == 400


async def test_listuserdata_normalized_separator(aiohttp_client, app, tmp_path):
    os_sep = "\\"
    with patch("os.sep", os_sep):
        with patch("os.path.sep", os_sep):
            os.makedirs(tmp_path / "test_dir" / "subdir")
            with open(tmp_path / "test_dir" / "subdir" / "file1.txt", "w") as f:
                f.write("test content")

            client = await aiohttp_client(app)
            resp = await client.get("/userdata?dir=test_dir&recurse=true")
            assert resp.status == 200
            result = await resp.json()
            assert len(result) == 1
            assert "/" in result[0]  # Ensure forward slash is used
            assert "\\" not in result[0]  # Ensure backslash is not present
            assert result[0] == "subdir/file1.txt"

            # Test with full_info
            resp = await client.get(
                "/userdata?dir=test_dir&recurse=true&full_info=true"
            )
            assert resp.status == 200
            result = await resp.json()
            assert len(result) == 1
            assert "/" in result[0]["path"]  # Ensure forward slash is used
            assert "\\" not in result[0]["path"]  # Ensure backslash is not present
            assert result[0]["path"] == "subdir/file1.txt"
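These tests follow the standard aiohttp route-testing pattern: register the routes on a web.Application and drive them through the aiohttp_client fixture. A minimal sketch of the same pattern, assuming pytest-asyncio and the pytest-aiohttp plugin (which provides aiohttp_client) are installed as the file above already requires; the /ping handler is hypothetical and exists only for this sketch:

import pytest
from aiohttp import web

pytestmark = pytest.mark.asyncio

async def ping(request):
    # Hypothetical handler used only for this sketch.
    return web.json_response({"ok": True})

@pytest.fixture
def app():
    app = web.Application()
    app.router.add_get("/ping", ping)
    return app

async def test_ping(aiohttp_client, app):
    # aiohttp_client spins up a test server around the app and returns a client bound to it.
    client = await aiohttp_client(app)
    resp = await client.get("/ping")
    assert resp.status == 200
    assert await resp.json() == {"ok": True}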
126 tests-unit/utils/extra_config_test.py Normal file
@@ -0,0 +1,126 @@
import pytest
import yaml
import os
from unittest.mock import Mock, patch, mock_open

from utils.extra_config import load_extra_path_config
import folder_paths

@pytest.fixture
def mock_yaml_content():
    return {
        'test_config': {
            'base_path': '~/App/',
            'checkpoints': 'subfolder1',
        }
    }

@pytest.fixture
def mock_expanded_home():
    return '/home/user'

@pytest.fixture
def yaml_config_with_appdata():
    return """
    test_config:
        base_path: '%APPDATA%/ComfyUI'
        checkpoints: 'models/checkpoints'
    """

@pytest.fixture
def mock_yaml_content_appdata(yaml_config_with_appdata):
    return yaml.safe_load(yaml_config_with_appdata)

@pytest.fixture
def mock_expandvars_appdata():
    mock = Mock()
    mock.side_effect = lambda path: path.replace('%APPDATA%', 'C:/Users/TestUser/AppData/Roaming')
    return mock

@pytest.fixture
def mock_add_model_folder_path():
    return Mock()

@pytest.fixture
def mock_expanduser(mock_expanded_home):
    def _expanduser(path):
        if path.startswith('~/'):
            return os.path.join(mock_expanded_home, path[2:])
        return path
    return _expanduser

@pytest.fixture
def mock_yaml_safe_load(mock_yaml_content):
    return Mock(return_value=mock_yaml_content)

@patch('builtins.open', new_callable=mock_open, read_data="dummy file content")
def test_load_extra_model_paths_expands_userpath(
    mock_file,
    monkeypatch,
    mock_add_model_folder_path,
    mock_expanduser,
    mock_yaml_safe_load,
    mock_expanded_home
):
    # Attach mocks used by load_extra_path_config
    monkeypatch.setattr(folder_paths, 'add_model_folder_path', mock_add_model_folder_path)
    monkeypatch.setattr(os.path, 'expanduser', mock_expanduser)
    monkeypatch.setattr(yaml, 'safe_load', mock_yaml_safe_load)

    dummy_yaml_file_name = 'dummy_path.yaml'
    load_extra_path_config(dummy_yaml_file_name)

    expected_calls = [
        ('checkpoints', os.path.join(mock_expanded_home, 'App', 'subfolder1'), False),
    ]

    assert mock_add_model_folder_path.call_count == len(expected_calls)

    # Check if add_model_folder_path was called with the correct arguments
    for actual_call, expected_call in zip(mock_add_model_folder_path.call_args_list, expected_calls):
        assert actual_call.args[0] == expected_call[0]
        assert os.path.normpath(actual_call.args[1]) == os.path.normpath(expected_call[1])  # Normalize and check the path to check on multiple OS.
        assert actual_call.args[2] == expected_call[2]

    # Check if yaml.safe_load was called
    mock_yaml_safe_load.assert_called_once()

    # Check if open was called with the correct file path
    mock_file.assert_called_once_with(dummy_yaml_file_name, 'r')

@patch('builtins.open', new_callable=mock_open)
def test_load_extra_model_paths_expands_appdata(
    mock_file,
    monkeypatch,
    mock_add_model_folder_path,
    mock_expandvars_appdata,
    yaml_config_with_appdata,
    mock_yaml_content_appdata
):
    # Set the mock_file to return yaml with appdata as a variable
    mock_file.return_value.read.return_value = yaml_config_with_appdata

    # Attach mocks
    monkeypatch.setattr(folder_paths, 'add_model_folder_path', mock_add_model_folder_path)
    monkeypatch.setattr(os.path, 'expandvars', mock_expandvars_appdata)
    monkeypatch.setattr(yaml, 'safe_load', Mock(return_value=mock_yaml_content_appdata))

    # Mock expanduser to do nothing (since we're not testing it here)
    monkeypatch.setattr(os.path, 'expanduser', lambda x: x)

    dummy_yaml_file_name = 'dummy_path.yaml'
    load_extra_path_config(dummy_yaml_file_name)

    expected_base_path = 'C:/Users/TestUser/AppData/Roaming/ComfyUI'
    expected_calls = [
        ('checkpoints', os.path.join(expected_base_path, 'models/checkpoints'), False),
    ]

    assert mock_add_model_folder_path.call_count == len(expected_calls)

    # Check the base path variable was expanded
    for actual_call, expected_call in zip(mock_add_model_folder_path.call_args_list, expected_calls):
        assert actual_call.args == expected_call

    # Verify that expandvars was called
    assert mock_expandvars_appdata.called
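The tests above isolate load_extra_path_config by monkeypatching yaml.safe_load, os.path.expanduser/expandvars, and folder_paths.add_model_folder_path. The same monkeypatch technique in miniature, applied to a hypothetical expand_base helper (not part of the ComfyUI code) rather than the real loader:

import os
from unittest.mock import Mock

def expand_base(path):
    # Hypothetical helper standing in for the loader's path-expansion step.
    return os.path.expandvars(os.path.expanduser(path))

def test_expand_base(monkeypatch):
    # Replace the expansion functions with deterministic fakes, as the tests above do.
    monkeypatch.setattr(os.path, 'expanduser', lambda p: p.replace('~', '/home/user'))
    monkeypatch.setattr(os.path, 'expandvars',
                        Mock(side_effect=lambda p: p.replace('%APPDATA%', 'C:/Users/TestUser/AppData/Roaming')))

    assert expand_base('~/App') == '/home/user/App'
    assert expand_base('%APPDATA%/ComfyUI') == 'C:/Users/TestUser/AppData/Roaming/ComfyUI'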
@@ -95,17 +95,16 @@ class ComfyClient:
                 pass # Probably want to store this off for testing

         history = self.get_history(prompt_id)[prompt_id]
-        for o in history['outputs']:
-            for node_id in history['outputs']:
-                node_output = history['outputs'][node_id]
-                result.outputs[node_id] = node_output
-                if 'images' in node_output:
-                    images_output = []
-                    for image in node_output['images']:
-                        image_data = self.get_image(image['filename'], image['subfolder'], image['type'])
-                        image_obj = Image.open(BytesIO(image_data))
-                        images_output.append(image_obj)
-                    node_output['image_objects'] = images_output
+        for node_id in history['outputs']:
+            node_output = history['outputs'][node_id]
+            result.outputs[node_id] = node_output
+            images_output = []
+            if 'images' in node_output:
+                for image in node_output['images']:
+                    image_data = self.get_image(image['filename'], image['subfolder'], image['type'])
+                    image_obj = Image.open(BytesIO(image_data))
+                    images_output.append(image_obj)
+                node_output['image_objects'] = images_output

         return result

@@ -357,6 +356,25 @@ class TestExecution:
         assert 'prompt_id' in e.args[0], f"Did not get back a proper error message: {e}"
         assert e.args[0]['node_id'] == generator.id, "Error should have been on the generator node"

+    def test_missing_node_error(self, client: ComfyClient, builder: GraphBuilder):
+        g = builder
+        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
+        input2 = g.node("StubImage", id="removeme", content="WHITE", height=512, width=512, batch_size=1)
+        input3 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
+        mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1)
+        mix1 = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0))
+        mix2 = g.node("TestLazyMixImages", image1=input1.out(0), image2=input3.out(0), mask=mask.out(0))
+        # We have multiple outputs. The first is invalid, but the second is valid
+        g.node("SaveImage", images=mix1.out(0))
+        g.node("SaveImage", images=mix2.out(0))
+        g.remove_node("removeme")
+
+        client.run(g)
+
+        # Add back in the missing node to make sure the error doesn't break the server
+        input2 = g.node("StubImage", id="removeme", content="WHITE", height=512, width=512, batch_size=1)
+        client.run(g)
+
     def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder):
         g = builder
         # Creating the nodes in this specific order previously caused a bug
@@ -450,8 +468,8 @@ class TestExecution:
         g = builder
         input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)

-        output1 = g.node("PreviewImage", images=input1.out(0))
-        output2 = g.node("PreviewImage", images=input1.out(0))
+        output1 = g.node("SaveImage", images=input1.out(0))
+        output2 = g.node("SaveImage", images=input1.out(0))

         result = client.run(g)
         images1 = result.get_images(output1)
@@ -478,3 +496,29 @@ class TestExecution:
         assert len(images) == 1, "Should have 1 image"
         assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"
         assert not result.did_run(test_node), "The execution should have been cached"
+
+    # This tests that nodes with OUTPUT_IS_LIST function correctly when they receive an ExecutionBlocker
+    # as input. We also test that when that list (containing an ExecutionBlocker) is passed to a node,
+    # only that one entry in the list is blocked.
+    def test_execution_block_list_output(self, client: ComfyClient, builder: GraphBuilder):
+        g = builder
+        image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
+        image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
+        image3 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
+        image_list = g.node("TestMakeListNode", value1=image1.out(0), value2=image2.out(0), value3=image3.out(0))
+        int1 = g.node("StubInt", value=1)
+        int2 = g.node("StubInt", value=2)
+        int3 = g.node("StubInt", value=3)
+        int_list = g.node("TestMakeListNode", value1=int1.out(0), value2=int2.out(0), value3=int3.out(0))
+        compare = g.node("TestIntConditions", a=int_list.out(0), b=2, operation="==")
+        blocker = g.node("TestExecutionBlocker", input=image_list.out(0), block=compare.out(0), verbose=False)
+
+        list_output = g.node("TestMakeListNode", value1=blocker.out(0))
+        output = g.node("PreviewImage", images=list_output.out(0))
+
+        result = client.run(g)
+        assert result.did_run(output), "The execution should have run"
+        images = result.get_images(output)
+        assert len(images) == 2, "Should have 2 images"
+        assert numpy.array(images[0]).min() == 0 and numpy.array(images[0]).max() == 0, "First image should be black"
+        assert numpy.array(images[1]).min() == 0 and numpy.array(images[1]).max() == 0, "Second image should also be black"
@@ -109,15 +109,14 @@ class ComfyClient:
                 continue #previews are binary data

         history = self.get_history(prompt_id)[prompt_id]
-        for o in history['outputs']:
-            for node_id in history['outputs']:
-                node_output = history['outputs'][node_id]
-                if 'images' in node_output:
-                    images_output = []
-                    for image in node_output['images']:
-                        image_data = self.get_image(image['filename'], image['subfolder'], image['type'])
-                        images_output.append(image_data)
-                    output_images[node_id] = images_output
+        for node_id in history['outputs']:
+            node_output = history['outputs'][node_id]
+            images_output = []
+            if 'images' in node_output:
+                for image in node_output['images']:
+                    image_data = self.get_image(image['filename'], image['subfolder'], image['type'])
+                    images_output.append(image_data)
+                output_images[node_id] = images_output

         return output_images

0 utils/__init__.py Normal file
Some files were not shown because too many files have changed in this diff.