Compare commits
376 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9a616b81c1 | ||
|
|
3bed56bb13 | ||
|
|
4e402b11c6 | ||
|
|
48272448ad | ||
|
|
f7695b5f9e | ||
|
|
452179fe4f | ||
|
|
bf9a90a145 | ||
|
|
c1b92b719d | ||
|
|
cdc3b97dd5 | ||
|
|
8d4e06324f | ||
|
|
57e8bf6a9f | ||
|
|
0ee322ec5f | ||
|
|
79d5ceae6e | ||
|
|
2d5b3e0078 | ||
|
|
8e4118c0de | ||
|
|
3fc6ebcdd7 | ||
|
|
20a560eb97 | ||
|
|
82c5308561 | ||
|
|
26fb2c68e8 | ||
|
|
bf2650a80e | ||
|
|
53646e0f32 | ||
|
|
20879c78f9 | ||
|
|
b666539595 | ||
|
|
95d8713482 | ||
|
|
0d4e29f13f | ||
|
|
497db6212f | ||
|
|
24dc581dc3 | ||
|
|
4c82741b54 | ||
|
|
15c39ea757 | ||
|
|
b7143b74ce | ||
|
|
61196d8857 | ||
|
|
b4526d3fc3 | ||
|
|
3d802710e7 | ||
|
|
7126ecffde | ||
|
|
ab885b33ba | ||
|
|
839ed3368e | ||
|
|
6e8cdcd3cb | ||
|
|
e5c3f4b87f | ||
|
|
bc6be6c11e | ||
|
|
94323a26a7 | ||
|
|
5818f6cf51 | ||
|
|
0b734de449 | ||
|
|
5e16f1d24b | ||
|
|
2fd9c1308a | ||
|
|
8f0009aad0 | ||
|
|
41444b5236 | ||
|
|
772e620e32 | ||
|
|
07f6eeaa13 | ||
|
|
22535d0589 | ||
|
|
898615122f | ||
|
|
156a28786b | ||
|
|
f498d855ba | ||
|
|
b699a15062 | ||
|
|
9cc90ee3eb | ||
|
|
9a0a5d32ee | ||
|
|
d9f90965c8 | ||
|
|
41886af138 | ||
|
|
22a1d7ce78 | ||
|
|
4ac401af2b | ||
|
|
5fb59c8475 | ||
|
|
122c9ca1ce | ||
|
|
3b9a6cf2b1 | ||
|
|
3748e7ef7a | ||
|
|
8ebf2d8831 | ||
|
|
a72d152b0c | ||
|
|
eb476e6ea9 | ||
|
|
2d28b0b479 | ||
|
|
8b275ce5be | ||
|
|
2a18e98ccf | ||
|
|
8a5281006f | ||
|
|
bdeb1c171c | ||
|
|
9c1ed58ef2 | ||
|
|
8b90e50979 | ||
|
|
6ee066a14f | ||
|
|
dd5b57e3d7 | ||
|
|
75a818c720 | ||
|
|
2865f913f7 | ||
|
|
b49616f951 | ||
|
|
5e29e7a488 | ||
|
|
8afb97cd3f | ||
|
|
69694f40b3 | ||
|
|
c49025f01b | ||
|
|
696672905f | ||
|
|
6c9dbde7de | ||
|
|
ee8abf0cff | ||
|
|
fabf449feb | ||
|
|
cc9cf6d1bd | ||
|
|
1c8286a44b | ||
|
|
1af4a47fd1 | ||
|
|
f2aaa0a475 | ||
|
|
daa1565b93 | ||
|
|
09fdb2b269 | ||
|
|
65a8659182 | ||
|
|
770ab200f2 | ||
|
|
954683d0db | ||
|
|
30c0c81351 | ||
|
|
13b0ff8a6f | ||
|
|
c320801187 | ||
|
|
c0b0cfaeec | ||
|
|
669d9e4c67 | ||
|
|
9ee0a6553a | ||
|
|
5cbb01bc2f | ||
|
|
c3ffbae067 | ||
|
|
d605677b33 | ||
|
|
ce759b7db6 | ||
|
|
52810907e2 | ||
|
|
af8cf79a2d | ||
|
|
66b0961a46 | ||
|
|
754597c8a9 | ||
|
|
915fdb5745 | ||
|
|
5a8a48931a | ||
|
|
8ce2a1052c | ||
|
|
f82314fcfc | ||
|
|
0075c6d096 | ||
|
|
83ca891118 | ||
|
|
f9f9faface | ||
|
|
471cd3eace | ||
|
|
a68bbafddb | ||
|
|
73e3a9e676 | ||
|
|
518c0dc2fe | ||
|
|
ce0542e10b | ||
|
|
8473019d40 | ||
|
|
89f15894dd | ||
|
|
67158994a4 | ||
|
|
7390ff3b1e | ||
|
|
0bedfb26af | ||
|
|
f71cfd2687 | ||
|
|
c695c4af7f | ||
|
|
0dbba9f751 | ||
|
|
f584758271 | ||
|
|
95b7cf9bbe | ||
|
|
191a0d56b4 | ||
|
|
3c60ecd7a8 | ||
|
|
7ae6626723 | ||
|
|
6632365e16 | ||
|
|
ad07796777 | ||
|
|
1b80895285 | ||
|
|
5f9d5a244b | ||
|
|
14eba07acd | ||
|
|
4b2f0d9413 | ||
|
|
25eac1d780 | ||
|
|
e38c94228b | ||
|
|
203942c8b2 | ||
|
|
3c72c89a52 | ||
|
|
614377abd6 | ||
|
|
8dfa0cc552 | ||
|
|
e5ecdfdd2d | ||
|
|
7d29fbf74b | ||
|
|
2c641e64ad | ||
|
|
7d2467e830 | ||
|
|
6f021d8aa0 | ||
|
|
d854ed0bcf | ||
|
|
abcd006b8c | ||
|
|
d985d1d7dc | ||
|
|
d1cdf51e1b | ||
|
|
b4626ab93e | ||
|
|
a9e459c2a4 | ||
|
|
3bb4dec720 | ||
|
|
8733191563 | ||
|
|
83b01f960a | ||
|
|
d72e871cfa | ||
|
|
037c3159b6 | ||
|
|
bdd4a22a2e | ||
|
|
fdf37566ef | ||
|
|
08c8968482 | ||
|
|
479a427a48 | ||
|
|
3a0eeee320 | ||
|
|
447da7ea86 | ||
|
|
9c41bc8d10 | ||
|
|
6ad0ddbae4 | ||
|
|
a55142f904 | ||
|
|
5718ef69bb | ||
|
|
13ecf10a92 | ||
|
|
7a415f47a9 | ||
|
|
89fa2fca24 | ||
|
|
364b69e931 | ||
|
|
dc96a1ae19 | ||
|
|
2d810b081e | ||
|
|
9f7e9f0547 | ||
|
|
a355f38ecc | ||
|
|
38c69080c7 | ||
|
|
70a708d726 | ||
|
|
e7d4782736 | ||
|
|
3326bdfd4e | ||
|
|
68bb885d22 | ||
|
|
ad66f7c7d8 | ||
|
|
de8e8e3b0d | ||
|
|
a1e71cfad1 | ||
|
|
0bfc7cc998 | ||
|
|
7183fd1665 | ||
|
|
254838f23c | ||
|
|
0b7dfa986d | ||
|
|
d514bb38ee | ||
|
|
0849c80e2a | ||
|
|
56e8f5e4fd | ||
|
|
e813abbb2c | ||
|
|
5e68a4ce67 | ||
|
|
ca08597670 | ||
|
|
f48e390032 | ||
|
|
369a6dd2c4 | ||
|
|
b3ce8fb9fd | ||
|
|
cf80d28689 | ||
|
|
6fb44c4b7c | ||
|
|
d2247c1e61 | ||
|
|
cb12ad7049 | ||
|
|
f6b7194f64 | ||
|
|
7c6eb4fb29 | ||
|
|
b962db9952 | ||
|
|
d0b7ab88ba | ||
|
|
405b529545 | ||
|
|
9d720187f1 | ||
|
|
d247bc5a9c | ||
|
|
9f4daca9d9 | ||
|
|
b5d0f2a908 | ||
|
|
e760bf5c40 | ||
|
|
36c83cdbba | ||
|
|
81778a7feb | ||
|
|
bc94662b31 | ||
|
|
9fa8faa44a | ||
|
|
9a7444e39f | ||
|
|
54fca4a218 | ||
|
|
cd4955367e | ||
|
|
8354203d95 | ||
|
|
e0b41243b4 | ||
|
|
619263d4a6 | ||
|
|
e3b0402bb7 | ||
|
|
967867d48c | ||
|
|
cbaac71bf5 | ||
|
|
3ab3516e46 | ||
|
|
9c5fca75f4 | ||
|
|
a5da4d0b3e | ||
|
|
32a60a7bac | ||
|
|
bb52934ba4 | ||
|
|
8aabd7c8c0 | ||
|
|
a09b29ca11 | ||
|
|
9bfee68773 | ||
|
|
ea77750759 | ||
|
|
c27ebeb1c2 | ||
|
|
0c7c98a965 | ||
|
|
dc2eb75b85 | ||
|
|
fa34efe3bd | ||
|
|
5cbaa9e07c | ||
|
|
c7427375ee | ||
|
|
22d1241a50 | ||
|
|
f04229b84d | ||
|
|
f067ad15d1 | ||
|
|
483004dd1d | ||
|
|
00a5d08103 | ||
|
|
d043997d30 | ||
|
|
f1c2301697 | ||
|
|
8d31a6632f | ||
|
|
b643eae08b | ||
|
|
baa6b4dc36 | ||
|
|
d4aeefc297 | ||
|
|
587e7ca654 | ||
|
|
c90459eba0 | ||
|
|
04278afb10 | ||
|
|
935ae153e1 | ||
|
|
e91662e784 | ||
|
|
63fafaef45 | ||
|
|
ec28cd9136 | ||
|
|
6eb5d64522 | ||
|
|
10a79e9898 | ||
|
|
ea3f39bd69 | ||
|
|
b33cd61070 | ||
|
|
34eda0f853 | ||
|
|
d31e226650 | ||
|
|
b79fd7d92c | ||
|
|
38c22e631a | ||
|
|
6bbdcd28ae | ||
|
|
ab130001a8 | ||
|
|
ca4b8f30e0 | ||
|
|
70b84058c1 | ||
|
|
2ca8f6e23d | ||
|
|
7985ff88b9 | ||
|
|
c6812947e9 | ||
|
|
9230f65823 | ||
|
|
6ab1e6fd4a | ||
|
|
07dcbc3a3e | ||
|
|
8ae23d8e80 | ||
|
|
7df42b9a23 | ||
|
|
5d8bbb7281 | ||
|
|
2c1d2375d6 | ||
|
|
64ccb3c7e3 | ||
|
|
9465b23432 | ||
|
|
bb4416dd5b | ||
|
|
c0b0da264b | ||
|
|
c26ca27207 | ||
|
|
7c6bb84016 | ||
|
|
c54d3ed5e6 | ||
|
|
c7ee4b37a1 | ||
|
|
7b70b266d8 | ||
|
|
8f60d093ba | ||
|
|
dafbe321d2 | ||
|
|
5f84ea63e8 | ||
|
|
843a7ff70c | ||
|
|
a60620dcea | ||
|
|
015f73dc49 | ||
|
|
904bf58e7d | ||
|
|
5f50263088 | ||
|
|
5e806f555d | ||
|
|
f07e5bb522 | ||
|
|
03ec517afb | ||
|
|
f257fc999f | ||
|
|
bb50e69839 | ||
|
|
510f3438c1 | ||
|
|
ea63b1c092 | ||
|
|
9953f22fce | ||
|
|
d1a6bd6845 | ||
|
|
83dbac28eb | ||
|
|
538cb068bc | ||
|
|
1b3eee672c | ||
|
|
5a69f84c3c | ||
|
|
9eee470244 | ||
|
|
045377ea89 | ||
|
|
4d341b78e8 | ||
|
|
6138f92084 | ||
|
|
be0726c1ed | ||
|
|
766ae119a8 | ||
|
|
fc90ceb6ba | ||
|
|
4506ddc86a | ||
|
|
20ace7c853 | ||
|
|
b29b3b86c5 | ||
|
|
22ec02afc0 | ||
|
|
39f114c44b | ||
|
|
6730f3e1a3 | ||
|
|
73332160c8 | ||
|
|
2622c55aff | ||
|
|
1beb348ee2 | ||
|
|
9aa39e743c | ||
|
|
d31df04c8a | ||
|
|
e68763f40c | ||
|
|
310ad09258 | ||
|
|
4f7a3cb6fb | ||
|
|
bb222ceddb | ||
|
|
14af129c55 | ||
|
|
fca42836f2 | ||
|
|
858d51f91a | ||
|
|
cd5017c1c9 | ||
|
|
83f343146a | ||
|
|
b021cf67c7 | ||
|
|
1770fc77ed | ||
|
|
05a9f3faa1 | ||
|
|
86c5970ac0 | ||
|
|
bfc214f434 | ||
|
|
3f5939add6 | ||
|
|
5960f946a9 | ||
|
|
5cfe38f41c | ||
|
|
0f9c2a7822 | ||
|
|
153d0a8142 | ||
|
|
ab4dd19b91 | ||
|
|
f1d6cef71c | ||
|
|
33fb282d5c | ||
|
|
50bf66e5c4 | ||
|
|
e60e19b175 | ||
|
|
a5af64d3ce | ||
|
|
3e52e0364c | ||
|
|
34608de2e9 | ||
|
|
39fb74c5bd | ||
|
|
74e124f4d7 | ||
|
|
a562c17e8a | ||
|
|
5942c17d55 | ||
|
|
c032b11e07 | ||
|
|
b8ffb2937f | ||
|
|
ce37c11164 | ||
|
|
b5c3906b38 | ||
|
|
5d43e75e5b | ||
|
|
517f4a94e4 | ||
|
|
52a471c5c7 | ||
|
|
ad76574cb8 | ||
|
|
9acfe4df41 | ||
|
|
9829b013ea | ||
|
|
5c69cde037 | ||
|
|
e9589d6d92 | ||
|
|
0d82a798a5 | ||
|
|
925fff26fd |
@@ -75,6 +75,25 @@ else:
|
||||
print("pulling latest changes")
|
||||
pull(repo)
|
||||
|
||||
if "--stable" in sys.argv:
|
||||
def latest_tag(repo):
|
||||
versions = []
|
||||
for k in repo.references:
|
||||
try:
|
||||
prefix = "refs/tags/v"
|
||||
if k.startswith(prefix):
|
||||
version = list(map(int, k[len(prefix):].split(".")))
|
||||
versions.append((version[0] * 10000000000 + version[1] * 100000 + version[2], k))
|
||||
except:
|
||||
pass
|
||||
versions.sort()
|
||||
if len(versions) > 0:
|
||||
return versions[-1][1]
|
||||
return None
|
||||
latest_tag = latest_tag(repo)
|
||||
if latest_tag is not None:
|
||||
repo.checkout(latest_tag)
|
||||
|
||||
print("Done!")
|
||||
|
||||
self_update = True
|
||||
@@ -115,3 +134,13 @@ if not os.path.exists(req_path) or not files_equal(repo_req_path, req_path):
|
||||
shutil.copy(repo_req_path, req_path)
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
stable_update_script = os.path.join(repo_path, ".ci/update_windows/update_comfyui_stable.bat")
|
||||
stable_update_script_to = os.path.join(cur_path, "update_comfyui_stable.bat")
|
||||
|
||||
try:
|
||||
if not file_size(stable_update_script_to) > 10:
|
||||
shutil.copy(stable_update_script, stable_update_script_to)
|
||||
except:
|
||||
pass
|
||||
|
||||
8
.ci/update_windows/update_comfyui_stable.bat
Executable file
8
.ci/update_windows/update_comfyui_stable.bat
Executable file
@@ -0,0 +1,8 @@
|
||||
@echo off
|
||||
..\python_embeded\python.exe .\update.py ..\ComfyUI\ --stable
|
||||
if exist update_new.py (
|
||||
move /y update_new.py update.py
|
||||
echo Running updater again since it got updated.
|
||||
..\python_embeded\python.exe .\update.py ..\ComfyUI\ --skip_self_update --stable
|
||||
)
|
||||
if "%~1"=="" pause
|
||||
@@ -14,7 +14,7 @@ run_cpu.bat
|
||||
|
||||
IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints
|
||||
|
||||
You can download the stable diffusion 1.5 one from: https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt
|
||||
You can download the stable diffusion 1.5 one from: https://huggingface.co/Comfy-Org/stable-diffusion-v1-5-archive/blob/main/v1-5-pruned-emaonly-fp16.safetensors
|
||||
|
||||
|
||||
RECOMMENDED WAY TO UPDATE:
|
||||
|
||||
2
.ci/windows_nightly_base_files/run_nvidia_gpu_fast.bat
Normal file
2
.ci/windows_nightly_base_files/run_nvidia_gpu_fast.bat
Normal file
@@ -0,0 +1,2 @@
|
||||
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast
|
||||
pause
|
||||
2
.gitattributes
vendored
Normal file
2
.gitattributes
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
/web/assets/** linguist-generated
|
||||
/web/** linguist-vendored
|
||||
3
.github/ISSUE_TEMPLATE/config.yml
vendored
3
.github/ISSUE_TEMPLATE/config.yml
vendored
@@ -1,5 +1,8 @@
|
||||
blank_issues_enabled: true
|
||||
contact_links:
|
||||
- name: ComfyUI Frontend Issues
|
||||
url: https://github.com/Comfy-Org/ComfyUI_frontend/issues
|
||||
about: Issues related to the ComfyUI frontend (display issues, user interaction bugs), please go to the frontend repo to file the issue
|
||||
- name: ComfyUI Matrix Space
|
||||
url: https://app.element.io/#/room/%23comfyui_space%3Amatrix.org
|
||||
about: The ComfyUI Matrix Space is available for support and general discussion related to ComfyUI (Matrix is like Discord but open source).
|
||||
|
||||
18
.github/workflows/pullrequest-ci-run.yml
vendored
18
.github/workflows/pullrequest-ci-run.yml
vendored
@@ -23,7 +23,7 @@ jobs:
|
||||
runner_label: [self-hosted, Linux]
|
||||
flags: ""
|
||||
- os: windows
|
||||
runner_label: [self-hosted, win]
|
||||
runner_label: [self-hosted, Windows]
|
||||
flags: ""
|
||||
runs-on: ${{ matrix.runner_label }}
|
||||
steps:
|
||||
@@ -35,3 +35,19 @@ jobs:
|
||||
torch_version: ${{ matrix.torch_version }}
|
||||
google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
|
||||
comfyui_flags: ${{ matrix.flags }}
|
||||
use_prior_commit: 'true'
|
||||
comment:
|
||||
if: ${{ github.event.label.name == 'Run-CI-Test' }}
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
pull-requests: write
|
||||
steps:
|
||||
- uses: actions/github-script@v6
|
||||
with:
|
||||
script: |
|
||||
github.rest.issues.createComment({
|
||||
issue_number: context.issue.number,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
body: '(Automated Bot Message) CI Tests are running, you can view the results at https://ci.comfy.org/?branch=${{ github.event.pull_request.number }}%2Fmerge'
|
||||
})
|
||||
|
||||
6
.github/workflows/stable-release.yml
vendored
6
.github/workflows/stable-release.yml
vendored
@@ -12,17 +12,17 @@ on:
|
||||
description: 'CUDA version'
|
||||
required: true
|
||||
type: string
|
||||
default: "121"
|
||||
default: "124"
|
||||
python_minor:
|
||||
description: 'Python minor version'
|
||||
required: true
|
||||
type: string
|
||||
default: "11"
|
||||
default: "12"
|
||||
python_patch:
|
||||
description: 'Python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "9"
|
||||
default: "7"
|
||||
|
||||
|
||||
jobs:
|
||||
|
||||
21
.github/workflows/stale-issues.yml
vendored
Normal file
21
.github/workflows/stale-issues.yml
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
name: 'Close stale issues'
|
||||
on:
|
||||
schedule:
|
||||
# Run daily at 430 am PT
|
||||
- cron: '30 11 * * *'
|
||||
permissions:
|
||||
issues: write
|
||||
|
||||
jobs:
|
||||
stale:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/stale@v9
|
||||
with:
|
||||
stale-issue-message: "This issue is being marked stale because it has not had any activity for 30 days. Reply below within 7 days if your issue still isn't solved, and it will be left open. Otherwise, the issue will be closed automatically."
|
||||
days-before-stale: 30
|
||||
days-before-close: 7
|
||||
stale-issue-label: 'Stale'
|
||||
only-labels: 'User Support'
|
||||
exempt-all-assignees: true
|
||||
exempt-all-milestones: true
|
||||
76
.github/workflows/test-browser.yml
vendored
76
.github/workflows/test-browser.yml
vendored
@@ -1,76 +0,0 @@
|
||||
# This is a temporary action during frontend TS migration.
|
||||
# This file should be removed after TS migration is completed.
|
||||
# The browser test is here to ensure TS repo is working the same way as the
|
||||
# current JS code.
|
||||
# If you are adding UI feature, please sync your changes to the TS repo:
|
||||
# huchenlei/ComfyUI_frontend and update test expectation files accordingly.
|
||||
name: Playwright Browser Tests CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main, master ]
|
||||
pull_request:
|
||||
branches: [ main, master ]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout ComfyUI
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
repository: "comfyanonymous/ComfyUI"
|
||||
path: "ComfyUI"
|
||||
- name: Checkout ComfyUI_frontend
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
repository: "huchenlei/ComfyUI_frontend"
|
||||
path: "ComfyUI_frontend"
|
||||
ref: "fcc54d803e5b6a9b08a462a1d94899318c96dcbb"
|
||||
- uses: actions/setup-node@v3
|
||||
with:
|
||||
node-version: lts/*
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.8'
|
||||
- name: Install requirements
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
||||
pip install -r requirements.txt
|
||||
pip install wait-for-it
|
||||
working-directory: ComfyUI
|
||||
- name: Start ComfyUI server
|
||||
run: |
|
||||
python main.py --cpu 2>&1 | tee console_output.log &
|
||||
wait-for-it --service 127.0.0.1:8188 -t 600
|
||||
working-directory: ComfyUI
|
||||
- name: Install ComfyUI_frontend dependencies
|
||||
run: |
|
||||
npm ci
|
||||
working-directory: ComfyUI_frontend
|
||||
- name: Install Playwright Browsers
|
||||
run: npx playwright install --with-deps
|
||||
working-directory: ComfyUI_frontend
|
||||
- name: Run Playwright tests
|
||||
run: npx playwright test
|
||||
working-directory: ComfyUI_frontend
|
||||
- name: Check for unhandled exceptions in server log
|
||||
run: |
|
||||
if grep -qE "Exception|Error" console_output.log; then
|
||||
echo "Unhandled exception/error found in server log."
|
||||
exit 1
|
||||
fi
|
||||
working-directory: ComfyUI
|
||||
- uses: actions/upload-artifact@v4
|
||||
if: always()
|
||||
with:
|
||||
name: playwright-report
|
||||
path: ComfyUI_frontend/playwright-report/
|
||||
retention-days: 30
|
||||
- uses: actions/upload-artifact@v4
|
||||
if: always()
|
||||
with:
|
||||
name: console-output
|
||||
path: ComfyUI/console_output.log
|
||||
retention-days: 30
|
||||
4
.github/workflows/test-ci.yml
vendored
4
.github/workflows/test-ci.yml
vendored
@@ -32,7 +32,7 @@ jobs:
|
||||
runner_label: [self-hosted, Linux]
|
||||
flags: ""
|
||||
- os: windows
|
||||
runner_label: [self-hosted, win]
|
||||
runner_label: [self-hosted, Windows]
|
||||
flags: ""
|
||||
runs-on: ${{ matrix.runner_label }}
|
||||
steps:
|
||||
@@ -55,7 +55,7 @@ jobs:
|
||||
torch_version: ["nightly"]
|
||||
include:
|
||||
- os: windows
|
||||
runner_label: [self-hosted, win]
|
||||
runner_label: [self-hosted, Windows]
|
||||
flags: ""
|
||||
runs-on: ${{ matrix.runner_label }}
|
||||
steps:
|
||||
|
||||
45
.github/workflows/test-launch.yml
vendored
Normal file
45
.github/workflows/test-launch.yml
vendored
Normal file
@@ -0,0 +1,45 @@
|
||||
name: Test server launches without errors
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main, master ]
|
||||
pull_request:
|
||||
branches: [ main, master ]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout ComfyUI
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
repository: "comfyanonymous/ComfyUI"
|
||||
path: "ComfyUI"
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.8'
|
||||
- name: Install requirements
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
||||
pip install -r requirements.txt
|
||||
pip install wait-for-it
|
||||
working-directory: ComfyUI
|
||||
- name: Start ComfyUI server
|
||||
run: |
|
||||
python main.py --cpu 2>&1 | tee console_output.log &
|
||||
wait-for-it --service 127.0.0.1:8188 -t 30
|
||||
working-directory: ComfyUI
|
||||
- name: Check for unhandled exceptions in server log
|
||||
run: |
|
||||
if grep -qE "Exception|Error" console_output.log; then
|
||||
echo "Unhandled exception/error found in server log."
|
||||
exit 1
|
||||
fi
|
||||
working-directory: ComfyUI
|
||||
- uses: actions/upload-artifact@v4
|
||||
if: always()
|
||||
with:
|
||||
name: console-output
|
||||
path: ComfyUI/console_output.log
|
||||
retention-days: 30
|
||||
@@ -1,29 +1,29 @@
|
||||
name: Tests CI
|
||||
name: Unit Tests
|
||||
|
||||
on: [push, pull_request]
|
||||
on:
|
||||
push:
|
||||
branches: [ main, master ]
|
||||
pull_request:
|
||||
branches: [ main, master ]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest, windows-latest, macos-latest]
|
||||
runs-on: ${{ matrix.os }}
|
||||
continue-on-error: true
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-node@v3
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
node-version: 18
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.10'
|
||||
- name: Install requirements
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
||||
pip install -r requirements.txt
|
||||
- name: Run Tests
|
||||
run: |
|
||||
npm ci
|
||||
npm run test:generate
|
||||
npm test -- --verbose
|
||||
working-directory: ./tests-ui
|
||||
- name: Run Unit Tests
|
||||
run: |
|
||||
pip install -r tests-unit/requirements.txt
|
||||
@@ -12,7 +12,7 @@ on:
|
||||
description: 'extra dependencies'
|
||||
required: false
|
||||
type: string
|
||||
default: "\"numpy<2\""
|
||||
default: ""
|
||||
cu:
|
||||
description: 'cuda version'
|
||||
required: true
|
||||
@@ -23,13 +23,13 @@ on:
|
||||
description: 'python minor version'
|
||||
required: true
|
||||
type: string
|
||||
default: "11"
|
||||
default: "12"
|
||||
|
||||
python_patch:
|
||||
description: 'python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "9"
|
||||
default: "7"
|
||||
# push:
|
||||
# branches:
|
||||
# - master
|
||||
|
||||
@@ -67,6 +67,7 @@ jobs:
|
||||
mkdir update
|
||||
cp -r ComfyUI/.ci/update_windows/* ./update/
|
||||
cp -r ComfyUI/.ci/windows_base_files/* ./
|
||||
cp -r ComfyUI/.ci/windows_nightly_base_files/* ./
|
||||
|
||||
echo "call update_comfyui.bat nopause
|
||||
..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2
|
||||
|
||||
@@ -13,13 +13,13 @@ on:
|
||||
description: 'python minor version'
|
||||
required: true
|
||||
type: string
|
||||
default: "11"
|
||||
default: "12"
|
||||
|
||||
python_patch:
|
||||
description: 'python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "9"
|
||||
default: "7"
|
||||
# push:
|
||||
# branches:
|
||||
# - master
|
||||
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -12,10 +12,12 @@ extra_model_paths.yaml
|
||||
.vscode/
|
||||
.idea/
|
||||
venv/
|
||||
.venv/
|
||||
/web/extensions/*
|
||||
!/web/extensions/logging.js.example
|
||||
!/web/extensions/core/
|
||||
/tests-ui/data/object_info.json
|
||||
/user/
|
||||
*.log
|
||||
web_custom_versions/
|
||||
web_custom_versions/
|
||||
.DS_Store
|
||||
|
||||
153
README.md
153
README.md
@@ -1,8 +1,35 @@
|
||||
ComfyUI
|
||||
=======
|
||||
The most powerful and modular stable diffusion GUI and backend.
|
||||
-----------
|
||||

|
||||
<div align="center">
|
||||
|
||||
# ComfyUI
|
||||
**The most powerful and modular diffusion model GUI and backend.**
|
||||
|
||||
|
||||
[![Website][website-shield]][website-url]
|
||||
[![Dynamic JSON Badge][discord-shield]][discord-url]
|
||||
[![Matrix][matrix-shield]][matrix-url]
|
||||
<br>
|
||||
[![][github-release-shield]][github-release-link]
|
||||
[![][github-release-date-shield]][github-release-link]
|
||||
[![][github-downloads-shield]][github-downloads-link]
|
||||
[![][github-downloads-latest-shield]][github-downloads-link]
|
||||
|
||||
[matrix-shield]: https://img.shields.io/badge/Matrix-000000?style=flat&logo=matrix&logoColor=white
|
||||
[matrix-url]: https://app.element.io/#/room/%23comfyui_space%3Amatrix.org
|
||||
[website-shield]: https://img.shields.io/badge/ComfyOrg-4285F4?style=flat
|
||||
[website-url]: https://www.comfy.org/
|
||||
<!-- Workaround to display total user from https://github.com/badges/shields/issues/4500#issuecomment-2060079995 -->
|
||||
[discord-shield]: https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fdiscord.com%2Fapi%2Finvites%2Fcomfyorg%3Fwith_counts%3Dtrue&query=%24.approximate_member_count&logo=discord&logoColor=white&label=Discord&color=green&suffix=%20total
|
||||
[discord-url]: https://www.comfy.org/discord
|
||||
|
||||
[github-release-shield]: https://img.shields.io/github/v/release/comfyanonymous/ComfyUI?style=flat&sort=semver
|
||||
[github-release-link]: https://github.com/comfyanonymous/ComfyUI/releases
|
||||
[github-release-date-shield]: https://img.shields.io/github/release-date/comfyanonymous/ComfyUI?style=flat
|
||||
[github-downloads-shield]: https://img.shields.io/github/downloads/comfyanonymous/ComfyUI/total?style=flat
|
||||
[github-downloads-latest-shield]: https://img.shields.io/github/downloads/comfyanonymous/ComfyUI/latest/total?style=flat&label=downloads%40latest
|
||||
[github-downloads-link]: https://github.com/comfyanonymous/ComfyUI/releases
|
||||
|
||||

|
||||
</div>
|
||||
|
||||
This ui will let you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. For some workflow examples and see what ComfyUI can do you can check out:
|
||||
### [ComfyUI Examples](https://comfyanonymous.github.io/ComfyUI_examples/)
|
||||
@@ -12,7 +39,9 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
|
||||
## Features
|
||||
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
|
||||
- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/), [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/), [SD3](https://comfyanonymous.github.io/ComfyUI_examples/sd3/) and [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
|
||||
- [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/)
|
||||
- [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
|
||||
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
|
||||
- Asynchronous Queue system
|
||||
- Many optimizations: Only re-executes the parts of the workflow that changes between executions.
|
||||
- Smart memory management: can automatically run models on GPUs with as low as 1GB vram.
|
||||
@@ -46,32 +75,37 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
|
||||
|
||||
| Keybind | Explanation |
|
||||
|------------------------------------|--------------------------------------------------------------------------------------------------------------------|
|
||||
| Ctrl + Enter | Queue up current graph for generation |
|
||||
| Ctrl + Shift + Enter | Queue up current graph as first for generation |
|
||||
| Ctrl + Z/Ctrl + Y | Undo/Redo |
|
||||
| Ctrl + S | Save workflow |
|
||||
| Ctrl + O | Load workflow |
|
||||
| Ctrl + A | Select all nodes |
|
||||
| Alt + C | Collapse/uncollapse selected nodes |
|
||||
| Ctrl + M | Mute/unmute selected nodes |
|
||||
| Ctrl + B | Bypass selected nodes (acts like the node was removed from the graph and the wires reconnected through) |
|
||||
| Delete/Backspace | Delete selected nodes |
|
||||
| Ctrl + Backspace | Delete the current graph |
|
||||
| Space | Move the canvas around when held and moving the cursor |
|
||||
| Ctrl/Shift + Click | Add clicked node to selection |
|
||||
| Ctrl + C/Ctrl + V | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) |
|
||||
| Ctrl + C/Ctrl + Shift + V | Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) |
|
||||
| Shift + Drag | Move multiple selected nodes at the same time |
|
||||
| Ctrl + D | Load default graph |
|
||||
| Alt + `+` | Canvas Zoom in |
|
||||
| Alt + `-` | Canvas Zoom out |
|
||||
| Ctrl + Shift + LMB + Vertical drag | Canvas Zoom in/out |
|
||||
| Q | Toggle visibility of the queue |
|
||||
| H | Toggle visibility of history |
|
||||
| R | Refresh graph |
|
||||
| `Ctrl` + `Enter` | Queue up current graph for generation |
|
||||
| `Ctrl` + `Shift` + `Enter` | Queue up current graph as first for generation |
|
||||
| `Ctrl` + `Alt` + `Enter` | Cancel current generation |
|
||||
| `Ctrl` + `Z`/`Ctrl` + `Y` | Undo/Redo |
|
||||
| `Ctrl` + `S` | Save workflow |
|
||||
| `Ctrl` + `O` | Load workflow |
|
||||
| `Ctrl` + `A` | Select all nodes |
|
||||
| `Alt `+ `C` | Collapse/uncollapse selected nodes |
|
||||
| `Ctrl` + `M` | Mute/unmute selected nodes |
|
||||
| `Ctrl` + `B` | Bypass selected nodes (acts like the node was removed from the graph and the wires reconnected through) |
|
||||
| `Delete`/`Backspace` | Delete selected nodes |
|
||||
| `Ctrl` + `Backspace` | Delete the current graph |
|
||||
| `Space` | Move the canvas around when held and moving the cursor |
|
||||
| `Ctrl`/`Shift` + `Click` | Add clicked node to selection |
|
||||
| `Ctrl` + `C`/`Ctrl` + `V` | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) |
|
||||
| `Ctrl` + `C`/`Ctrl` + `Shift` + `V` | Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) |
|
||||
| `Shift` + `Drag` | Move multiple selected nodes at the same time |
|
||||
| `Ctrl` + `D` | Load default graph |
|
||||
| `Alt` + `+` | Canvas Zoom in |
|
||||
| `Alt` + `-` | Canvas Zoom out |
|
||||
| `Ctrl` + `Shift` + LMB + Vertical drag | Canvas Zoom in/out |
|
||||
| `P` | Pin/Unpin selected nodes |
|
||||
| `Ctrl` + `G` | Group selected nodes |
|
||||
| `Q` | Toggle visibility of the queue |
|
||||
| `H` | Toggle visibility of history |
|
||||
| `R` | Refresh graph |
|
||||
| Double-Click LMB | Open node quick search palette |
|
||||
| `Shift` + Drag | Move multiple wires at once |
|
||||
| `Ctrl` + `Alt` + LMB | Disconnect all wires from clicked slot |
|
||||
|
||||
Ctrl can also be replaced with Cmd instead for macOS users
|
||||
`Ctrl` can also be replaced with `Cmd` instead for macOS users
|
||||
|
||||
# Installing
|
||||
|
||||
@@ -95,6 +129,8 @@ To run it on services like paperspace, kaggle or colab you can use my [Jupyter N
|
||||
|
||||
## Manual Install (Windows, Linux)
|
||||
|
||||
Note that some dependencies do not yet support python 3.13 so using 3.12 is recommended.
|
||||
|
||||
Git clone this repo.
|
||||
|
||||
Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints
|
||||
@@ -105,17 +141,17 @@ Put your VAE in: models/vae
|
||||
### AMD GPUs (Linux only)
|
||||
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
|
||||
|
||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.0```
|
||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.2```
|
||||
|
||||
This is the command to install the nightly with ROCm 6.0 which might have some performance improvements:
|
||||
This is the command to install the nightly with ROCm 6.2 which might have some performance improvements:
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.1```
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.2```
|
||||
|
||||
### NVIDIA
|
||||
|
||||
Nvidia users should install stable pytorch using this command:
|
||||
|
||||
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu121```
|
||||
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu124```
|
||||
|
||||
This is the command to install pytorch nightly instead which might have performance improvements:
|
||||
|
||||
@@ -177,6 +213,14 @@ For 6700, 6600 and maybe other RDNA2 or older: ```HSA_OVERRIDE_GFX_VERSION=10.3.
|
||||
|
||||
For AMD 7600 and maybe other RDNA3 cards: ```HSA_OVERRIDE_GFX_VERSION=11.0.0 python main.py```
|
||||
|
||||
### AMD ROCm Tips
|
||||
|
||||
You can enable experimental memory efficient attention on pytorch 2.5 in ComfyUI on RDNA3 and potentially other AMD GPUs using this command:
|
||||
|
||||
```TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 python main.py --use-pytorch-cross-attention```
|
||||
|
||||
You can also try setting this env variable `PYTORCH_TUNABLEOP_ENABLED=1` which might speed things up at the cost of a very slow initial run.
|
||||
|
||||
# Notes
|
||||
|
||||
Only parts of the graph that have an output with all the correct inputs will be executed.
|
||||
@@ -200,7 +244,7 @@ To use a textual inversion concepts/embeddings in a text prompt put them in the
|
||||
|
||||
Use ```--preview-method auto``` to enable previews.
|
||||
|
||||
The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth) (for SD1.x and SD2.x) and [taesdxl_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesdxl_decoder.pth) (for SDXL) models and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI to enable high-quality previews.
|
||||
The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_decoder.pth, taesdxl_decoder.pth, taesd3_decoder.pth and taef1_decoder.pth](https://github.com/madebyollin/taesd/) and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI and launch it with `--preview-method taesd` to enable high-quality previews.
|
||||
|
||||
## How to use TLS/SSL?
|
||||
Generate a self-signed certificate (not appropriate for shared/production use) and key by running the command: `openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem -sha256 -days 3650 -nodes -subj "/C=XX/ST=StateName/L=CityName/O=CompanyName/OU=CompanySectionName/CN=CommonNameOrHostname"`
|
||||
@@ -216,6 +260,47 @@ Use `--tls-keyfile key.pem --tls-certfile cert.pem` to enable TLS/SSL, the app w
|
||||
|
||||
See also: [https://www.comfy.org/](https://www.comfy.org/)
|
||||
|
||||
## Frontend Development
|
||||
|
||||
As of August 15, 2024, we have transitioned to a new frontend, which is now hosted in a separate repository: [ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend). This repository now hosts the compiled JS (from TS/Vue) under the `web/` directory.
|
||||
|
||||
### Reporting Issues and Requesting Features
|
||||
|
||||
For any bugs, issues, or feature requests related to the frontend, please use the [ComfyUI Frontend repository](https://github.com/Comfy-Org/ComfyUI_frontend). This will help us manage and address frontend-specific concerns more efficiently.
|
||||
|
||||
### Using the Latest Frontend
|
||||
|
||||
The new frontend is now the default for ComfyUI. However, please note:
|
||||
|
||||
1. The frontend in the main ComfyUI repository is updated weekly.
|
||||
2. Daily releases are available in the separate frontend repository.
|
||||
|
||||
To use the most up-to-date frontend version:
|
||||
|
||||
1. For the latest daily release, launch ComfyUI with this command line argument:
|
||||
|
||||
```
|
||||
--front-end-version Comfy-Org/ComfyUI_frontend@latest
|
||||
```
|
||||
|
||||
2. For a specific version, replace `latest` with the desired version number:
|
||||
|
||||
```
|
||||
--front-end-version Comfy-Org/ComfyUI_frontend@1.2.2
|
||||
```
|
||||
|
||||
This approach allows you to easily switch between the stable weekly release and the cutting-edge daily updates, or even specific versions for testing purposes.
|
||||
|
||||
### Accessing the Legacy Frontend
|
||||
|
||||
If you need to use the legacy frontend for any reason, you can access it using the following command line argument:
|
||||
|
||||
```
|
||||
--front-end-version Comfy-Org/ComfyUI_legacy_frontend@latest
|
||||
```
|
||||
|
||||
This will use a snapshot of the legacy frontend preserved in the [ComfyUI Legacy Frontend repository](https://github.com/Comfy-Org/ComfyUI_legacy_frontend).
|
||||
|
||||
# QA
|
||||
|
||||
### Which GPU should I buy for this?
|
||||
|
||||
0
api_server/__init__.py
Normal file
0
api_server/__init__.py
Normal file
0
api_server/routes/__init__.py
Normal file
0
api_server/routes/__init__.py
Normal file
3
api_server/routes/internal/README.md
Normal file
3
api_server/routes/internal/README.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# ComfyUI Internal Routes
|
||||
|
||||
All routes under the `/internal` path are designated for **internal use by ComfyUI only**. These routes are not intended for use by external applications may change at any time without notice.
|
||||
0
api_server/routes/internal/__init__.py
Normal file
0
api_server/routes/internal/__init__.py
Normal file
75
api_server/routes/internal/internal_routes.py
Normal file
75
api_server/routes/internal/internal_routes.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from aiohttp import web
|
||||
from typing import Optional
|
||||
from folder_paths import models_dir, user_directory, output_directory, folder_names_and_paths
|
||||
from api_server.services.file_service import FileService
|
||||
from api_server.services.terminal_service import TerminalService
|
||||
import app.logger
|
||||
|
||||
class InternalRoutes:
|
||||
'''
|
||||
The top level web router for internal routes: /internal/*
|
||||
The endpoints here should NOT be depended upon. It is for ComfyUI frontend use only.
|
||||
Check README.md for more information.
|
||||
'''
|
||||
|
||||
def __init__(self, prompt_server):
|
||||
self.routes: web.RouteTableDef = web.RouteTableDef()
|
||||
self._app: Optional[web.Application] = None
|
||||
self.file_service = FileService({
|
||||
"models": models_dir,
|
||||
"user": user_directory,
|
||||
"output": output_directory
|
||||
})
|
||||
self.prompt_server = prompt_server
|
||||
self.terminal_service = TerminalService(prompt_server)
|
||||
|
||||
def setup_routes(self):
|
||||
@self.routes.get('/files')
|
||||
async def list_files(request):
|
||||
directory_key = request.query.get('directory', '')
|
||||
try:
|
||||
file_list = self.file_service.list_files(directory_key)
|
||||
return web.json_response({"files": file_list})
|
||||
except ValueError as e:
|
||||
return web.json_response({"error": str(e)}, status=400)
|
||||
except Exception as e:
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
@self.routes.get('/logs')
|
||||
async def get_logs(request):
|
||||
return web.json_response("".join([(l["t"] + " - " + l["m"]) for l in app.logger.get_logs()]))
|
||||
|
||||
@self.routes.get('/logs/raw')
|
||||
async def get_logs(request):
|
||||
self.terminal_service.update_size()
|
||||
return web.json_response({
|
||||
"entries": list(app.logger.get_logs()),
|
||||
"size": {"cols": self.terminal_service.cols, "rows": self.terminal_service.rows}
|
||||
})
|
||||
|
||||
@self.routes.patch('/logs/subscribe')
|
||||
async def subscribe_logs(request):
|
||||
json_data = await request.json()
|
||||
client_id = json_data["clientId"]
|
||||
enabled = json_data["enabled"]
|
||||
if enabled:
|
||||
self.terminal_service.subscribe(client_id)
|
||||
else:
|
||||
self.terminal_service.unsubscribe(client_id)
|
||||
|
||||
return web.Response(status=200)
|
||||
|
||||
|
||||
@self.routes.get('/folder_paths')
|
||||
async def get_folder_paths(request):
|
||||
response = {}
|
||||
for key in folder_names_and_paths:
|
||||
response[key] = folder_names_and_paths[key][0]
|
||||
return web.json_response(response)
|
||||
|
||||
def get_app(self):
|
||||
if self._app is None:
|
||||
self._app = web.Application()
|
||||
self.setup_routes()
|
||||
self._app.add_routes(self.routes)
|
||||
return self._app
|
||||
0
api_server/services/__init__.py
Normal file
0
api_server/services/__init__.py
Normal file
13
api_server/services/file_service.py
Normal file
13
api_server/services/file_service.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from typing import Dict, List, Optional
|
||||
from api_server.utils.file_operations import FileSystemOperations, FileSystemItem
|
||||
|
||||
class FileService:
|
||||
def __init__(self, allowed_directories: Dict[str, str], file_system_ops: Optional[FileSystemOperations] = None):
|
||||
self.allowed_directories: Dict[str, str] = allowed_directories
|
||||
self.file_system_ops: FileSystemOperations = file_system_ops or FileSystemOperations()
|
||||
|
||||
def list_files(self, directory_key: str) -> List[FileSystemItem]:
|
||||
if directory_key not in self.allowed_directories:
|
||||
raise ValueError("Invalid directory key")
|
||||
directory_path: str = self.allowed_directories[directory_key]
|
||||
return self.file_system_ops.walk_directory(directory_path)
|
||||
60
api_server/services/terminal_service.py
Normal file
60
api_server/services/terminal_service.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from app.logger import on_flush
|
||||
import os
|
||||
import shutil
|
||||
|
||||
|
||||
class TerminalService:
|
||||
def __init__(self, server):
|
||||
self.server = server
|
||||
self.cols = None
|
||||
self.rows = None
|
||||
self.subscriptions = set()
|
||||
on_flush(self.send_messages)
|
||||
|
||||
def get_terminal_size(self):
|
||||
try:
|
||||
size = os.get_terminal_size()
|
||||
return (size.columns, size.lines)
|
||||
except OSError:
|
||||
try:
|
||||
size = shutil.get_terminal_size()
|
||||
return (size.columns, size.lines)
|
||||
except OSError:
|
||||
return (80, 24) # fallback to 80x24
|
||||
|
||||
def update_size(self):
|
||||
columns, lines = self.get_terminal_size()
|
||||
changed = False
|
||||
|
||||
if columns != self.cols:
|
||||
self.cols = columns
|
||||
changed = True
|
||||
|
||||
if lines != self.rows:
|
||||
self.rows = lines
|
||||
changed = True
|
||||
|
||||
if changed:
|
||||
return {"cols": self.cols, "rows": self.rows}
|
||||
|
||||
return None
|
||||
|
||||
def subscribe(self, client_id):
|
||||
self.subscriptions.add(client_id)
|
||||
|
||||
def unsubscribe(self, client_id):
|
||||
self.subscriptions.discard(client_id)
|
||||
|
||||
def send_messages(self, entries):
|
||||
if not len(entries) or not len(self.subscriptions):
|
||||
return
|
||||
|
||||
new_size = self.update_size()
|
||||
|
||||
for client_id in self.subscriptions.copy(): # prevent: Set changed size during iteration
|
||||
if client_id not in self.server.sockets:
|
||||
# Automatically unsub if the socket has disconnected
|
||||
self.unsubscribe(client_id)
|
||||
continue
|
||||
|
||||
self.server.send_sync("logs", {"entries": entries, "size": new_size}, client_id)
|
||||
42
api_server/utils/file_operations.py
Normal file
42
api_server/utils/file_operations.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import os
|
||||
from typing import List, Union, TypedDict, Literal
|
||||
from typing_extensions import TypeGuard
|
||||
class FileInfo(TypedDict):
|
||||
name: str
|
||||
path: str
|
||||
type: Literal["file"]
|
||||
size: int
|
||||
|
||||
class DirectoryInfo(TypedDict):
|
||||
name: str
|
||||
path: str
|
||||
type: Literal["directory"]
|
||||
|
||||
FileSystemItem = Union[FileInfo, DirectoryInfo]
|
||||
|
||||
def is_file_info(item: FileSystemItem) -> TypeGuard[FileInfo]:
|
||||
return item["type"] == "file"
|
||||
|
||||
class FileSystemOperations:
|
||||
@staticmethod
|
||||
def walk_directory(directory: str) -> List[FileSystemItem]:
|
||||
file_list: List[FileSystemItem] = []
|
||||
for root, dirs, files in os.walk(directory):
|
||||
for name in files:
|
||||
file_path = os.path.join(root, name)
|
||||
relative_path = os.path.relpath(file_path, directory)
|
||||
file_list.append({
|
||||
"name": name,
|
||||
"path": relative_path,
|
||||
"type": "file",
|
||||
"size": os.path.getsize(file_path)
|
||||
})
|
||||
for name in dirs:
|
||||
dir_path = os.path.join(root, name)
|
||||
relative_path = os.path.relpath(dir_path, directory)
|
||||
file_list.append({
|
||||
"name": name,
|
||||
"path": relative_path,
|
||||
"type": "directory"
|
||||
})
|
||||
return file_list
|
||||
@@ -8,7 +8,7 @@ import zipfile
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import TypedDict
|
||||
from typing import TypedDict, Optional
|
||||
|
||||
import requests
|
||||
from typing_extensions import NotRequired
|
||||
@@ -132,12 +132,13 @@ class FrontendManager:
|
||||
return match_result.group(1), match_result.group(2), match_result.group(3)
|
||||
|
||||
@classmethod
|
||||
def init_frontend_unsafe(cls, version_string: str) -> str:
|
||||
def init_frontend_unsafe(cls, version_string: str, provider: Optional[FrontEndProvider] = None) -> str:
|
||||
"""
|
||||
Initializes the frontend for the specified version.
|
||||
|
||||
Args:
|
||||
version_string (str): The version string.
|
||||
provider (FrontEndProvider, optional): The provider to use. Defaults to None.
|
||||
|
||||
Returns:
|
||||
str: The path to the initialized frontend.
|
||||
@@ -150,7 +151,16 @@ class FrontendManager:
|
||||
return cls.DEFAULT_FRONTEND_PATH
|
||||
|
||||
repo_owner, repo_name, version = cls.parse_version_string(version_string)
|
||||
provider = FrontEndProvider(repo_owner, repo_name)
|
||||
|
||||
if version.startswith("v"):
|
||||
expected_path = str(Path(cls.CUSTOM_FRONTENDS_ROOT) / f"{repo_owner}_{repo_name}" / version.lstrip("v"))
|
||||
if os.path.exists(expected_path):
|
||||
logging.info(f"Using existing copy of specific frontend version tag: {repo_owner}/{repo_name}@{version}")
|
||||
return expected_path
|
||||
|
||||
logging.info(f"Initializing frontend: {repo_owner}/{repo_name}@{version}, requesting version details from GitHub...")
|
||||
|
||||
provider = provider or FrontEndProvider(repo_owner, repo_name)
|
||||
release = provider.get_release(version)
|
||||
|
||||
semantic_version = release["tag_name"].lstrip("v")
|
||||
@@ -158,15 +168,21 @@ class FrontendManager:
|
||||
Path(cls.CUSTOM_FRONTENDS_ROOT) / provider.folder_name / semantic_version
|
||||
)
|
||||
if not os.path.exists(web_root):
|
||||
os.makedirs(web_root, exist_ok=True)
|
||||
logging.info(
|
||||
"Downloading frontend(%s) version(%s) to (%s)",
|
||||
provider.folder_name,
|
||||
semantic_version,
|
||||
web_root,
|
||||
)
|
||||
logging.debug(release)
|
||||
download_release_asset_zip(release, destination_path=web_root)
|
||||
try:
|
||||
os.makedirs(web_root, exist_ok=True)
|
||||
logging.info(
|
||||
"Downloading frontend(%s) version(%s) to (%s)",
|
||||
provider.folder_name,
|
||||
semantic_version,
|
||||
web_root,
|
||||
)
|
||||
logging.debug(release)
|
||||
download_release_asset_zip(release, destination_path=web_root)
|
||||
finally:
|
||||
# Clean up the directory if it is empty, i.e. the download failed
|
||||
if not os.listdir(web_root):
|
||||
os.rmdir(web_root)
|
||||
|
||||
return web_root
|
||||
|
||||
@classmethod
|
||||
|
||||
73
app/logger.py
Normal file
73
app/logger.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from collections import deque
|
||||
from datetime import datetime
|
||||
import io
|
||||
import logging
|
||||
import sys
|
||||
import threading
|
||||
|
||||
logs = None
|
||||
stdout_interceptor = None
|
||||
stderr_interceptor = None
|
||||
|
||||
|
||||
class LogInterceptor(io.TextIOWrapper):
|
||||
def __init__(self, stream, *args, **kwargs):
|
||||
buffer = stream.buffer
|
||||
encoding = stream.encoding
|
||||
super().__init__(buffer, *args, **kwargs, encoding=encoding, line_buffering=stream.line_buffering)
|
||||
self._lock = threading.Lock()
|
||||
self._flush_callbacks = []
|
||||
self._logs_since_flush = []
|
||||
|
||||
def write(self, data):
|
||||
entry = {"t": datetime.now().isoformat(), "m": data}
|
||||
with self._lock:
|
||||
self._logs_since_flush.append(entry)
|
||||
|
||||
# Simple handling for cr to overwrite the last output if it isnt a full line
|
||||
# else logs just get full of progress messages
|
||||
if isinstance(data, str) and data.startswith("\r") and not logs[-1]["m"].endswith("\n"):
|
||||
logs.pop()
|
||||
logs.append(entry)
|
||||
super().write(data)
|
||||
|
||||
def flush(self):
|
||||
super().flush()
|
||||
for cb in self._flush_callbacks:
|
||||
cb(self._logs_since_flush)
|
||||
self._logs_since_flush = []
|
||||
|
||||
def on_flush(self, callback):
|
||||
self._flush_callbacks.append(callback)
|
||||
|
||||
|
||||
def get_logs():
|
||||
return logs
|
||||
|
||||
|
||||
def on_flush(callback):
|
||||
if stdout_interceptor is not None:
|
||||
stdout_interceptor.on_flush(callback)
|
||||
if stderr_interceptor is not None:
|
||||
stderr_interceptor.on_flush(callback)
|
||||
|
||||
def setup_logger(log_level: str = 'INFO', capacity: int = 300):
|
||||
global logs
|
||||
if logs:
|
||||
return
|
||||
|
||||
# Override output streams and log to buffer
|
||||
logs = deque(maxlen=capacity)
|
||||
|
||||
global stdout_interceptor
|
||||
global stderr_interceptor
|
||||
stdout_interceptor = sys.stdout = LogInterceptor(sys.stdout)
|
||||
stderr_interceptor = sys.stderr = LogInterceptor(sys.stderr)
|
||||
|
||||
# Setup default global logger
|
||||
logger = logging.getLogger()
|
||||
logger.setLevel(log_level)
|
||||
|
||||
stream_handler = logging.StreamHandler()
|
||||
stream_handler.setFormatter(logging.Formatter("%(message)s"))
|
||||
logger.addHandler(stream_handler)
|
||||
@@ -1,38 +1,58 @@
|
||||
from __future__ import annotations
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
import glob
|
||||
import shutil
|
||||
import logging
|
||||
from aiohttp import web
|
||||
from urllib import parse
|
||||
from comfy.cli_args import args
|
||||
from folder_paths import user_directory
|
||||
import folder_paths
|
||||
from .app_settings import AppSettings
|
||||
from typing import TypedDict
|
||||
|
||||
default_user = "default"
|
||||
users_file = os.path.join(user_directory, "users.json")
|
||||
|
||||
|
||||
class FileInfo(TypedDict):
|
||||
path: str
|
||||
size: int
|
||||
modified: int
|
||||
|
||||
|
||||
def get_file_info(path: str, relative_to: str) -> FileInfo:
|
||||
return {
|
||||
"path": os.path.relpath(path, relative_to).replace(os.sep, '/'),
|
||||
"size": os.path.getsize(path),
|
||||
"modified": os.path.getmtime(path)
|
||||
}
|
||||
|
||||
|
||||
class UserManager():
|
||||
def __init__(self):
|
||||
global user_directory
|
||||
user_directory = folder_paths.get_user_directory()
|
||||
|
||||
self.settings = AppSettings(self)
|
||||
if not os.path.exists(user_directory):
|
||||
os.mkdir(user_directory)
|
||||
os.makedirs(user_directory, exist_ok=True)
|
||||
if not args.multi_user:
|
||||
print("****** User settings have been changed to be stored on the server instead of browser storage. ******")
|
||||
print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")
|
||||
|
||||
if args.multi_user:
|
||||
if os.path.isfile(users_file):
|
||||
with open(users_file) as f:
|
||||
if os.path.isfile(self.get_users_file()):
|
||||
with open(self.get_users_file()) as f:
|
||||
self.users = json.load(f)
|
||||
else:
|
||||
self.users = {}
|
||||
else:
|
||||
self.users = {"default": "default"}
|
||||
|
||||
def get_users_file(self):
|
||||
return os.path.join(folder_paths.get_user_directory(), "users.json")
|
||||
|
||||
def get_request_user_id(self, request):
|
||||
user = "default"
|
||||
if args.multi_user and "comfy-user" in request.headers:
|
||||
@@ -44,7 +64,7 @@ class UserManager():
|
||||
return user
|
||||
|
||||
def get_request_user_filepath(self, request, file, type="userdata", create_dir=True):
|
||||
global user_directory
|
||||
user_directory = folder_paths.get_user_directory()
|
||||
|
||||
if type == "userdata":
|
||||
root_dir = user_directory
|
||||
@@ -59,6 +79,10 @@ class UserManager():
|
||||
return None
|
||||
|
||||
if file is not None:
|
||||
# Check if filename is url encoded
|
||||
if "%" in file:
|
||||
file = parse.unquote(file)
|
||||
|
||||
# prevent leaving /{type}/{user}
|
||||
path = os.path.abspath(os.path.join(user_root, file))
|
||||
if os.path.commonpath((user_root, path)) != user_root:
|
||||
@@ -80,8 +104,7 @@ class UserManager():
|
||||
|
||||
self.users[user_id] = name
|
||||
|
||||
global users_file
|
||||
with open(users_file, "w") as f:
|
||||
with open(self.get_users_file(), "w") as f:
|
||||
json.dump(self.users, f)
|
||||
|
||||
return user_id
|
||||
@@ -112,25 +135,65 @@ class UserManager():
|
||||
|
||||
@routes.get("/userdata")
|
||||
async def listuserdata(request):
|
||||
"""
|
||||
List user data files in a specified directory.
|
||||
|
||||
This endpoint allows listing files in a user's data directory, with options for recursion,
|
||||
full file information, and path splitting.
|
||||
|
||||
Query Parameters:
|
||||
- dir (required): The directory to list files from.
|
||||
- recurse (optional): If "true", recursively list files in subdirectories.
|
||||
- full_info (optional): If "true", return detailed file information (path, size, modified time).
|
||||
- split (optional): If "true", split file paths into components (only applies when full_info is false).
|
||||
|
||||
Returns:
|
||||
- 400: If 'dir' parameter is missing.
|
||||
- 403: If the requested path is not allowed.
|
||||
- 404: If the requested directory does not exist.
|
||||
- 200: JSON response with the list of files or file information.
|
||||
|
||||
The response format depends on the query parameters:
|
||||
- Default: List of relative file paths.
|
||||
- full_info=true: List of dictionaries with file details.
|
||||
- split=true (and full_info=false): List of lists, each containing path components.
|
||||
"""
|
||||
directory = request.rel_url.query.get('dir', '')
|
||||
if not directory:
|
||||
return web.Response(status=400)
|
||||
|
||||
return web.Response(status=400, text="Directory not provided")
|
||||
|
||||
path = self.get_request_user_filepath(request, directory)
|
||||
if not path:
|
||||
return web.Response(status=403)
|
||||
|
||||
return web.Response(status=403, text="Invalid directory")
|
||||
|
||||
if not os.path.exists(path):
|
||||
return web.Response(status=404)
|
||||
|
||||
return web.Response(status=404, text="Directory not found")
|
||||
|
||||
recurse = request.rel_url.query.get('recurse', '').lower() == "true"
|
||||
results = glob.glob(os.path.join(
|
||||
glob.escape(path), '**/*'), recursive=recurse)
|
||||
results = [os.path.relpath(x, path) for x in results if os.path.isfile(x)]
|
||||
|
||||
full_info = request.rel_url.query.get('full_info', '').lower() == "true"
|
||||
split_path = request.rel_url.query.get('split', '').lower() == "true"
|
||||
if split_path:
|
||||
results = [[x] + x.split(os.sep) for x in results]
|
||||
|
||||
# Use different patterns based on whether we're recursing or not
|
||||
if recurse:
|
||||
pattern = os.path.join(glob.escape(path), '**', '*')
|
||||
else:
|
||||
pattern = os.path.join(glob.escape(path), '*')
|
||||
|
||||
def process_full_path(full_path: str) -> FileInfo | str | list[str]:
|
||||
if full_info:
|
||||
return get_file_info(full_path, path)
|
||||
|
||||
rel_path = os.path.relpath(full_path, path).replace(os.sep, '/')
|
||||
if split_path:
|
||||
return [rel_path] + rel_path.split('/')
|
||||
|
||||
return rel_path
|
||||
|
||||
results = [
|
||||
process_full_path(full_path)
|
||||
for full_path in glob.glob(pattern, recursive=recurse)
|
||||
if os.path.isfile(full_path)
|
||||
]
|
||||
|
||||
return web.json_response(results)
|
||||
|
||||
@@ -138,14 +201,14 @@ class UserManager():
|
||||
file = request.match_info.get(param, None)
|
||||
if not file:
|
||||
return web.Response(status=400)
|
||||
|
||||
|
||||
path = self.get_request_user_filepath(request, file)
|
||||
if not path:
|
||||
return web.Response(status=403)
|
||||
|
||||
|
||||
if check_exists and not os.path.exists(path):
|
||||
return web.Response(status=404)
|
||||
|
||||
|
||||
return path
|
||||
|
||||
@routes.get("/userdata/{file}")
|
||||
@@ -153,25 +216,56 @@ class UserManager():
|
||||
path = get_user_data_path(request, check_exists=True)
|
||||
if not isinstance(path, str):
|
||||
return path
|
||||
|
||||
|
||||
return web.FileResponse(path)
|
||||
|
||||
@routes.post("/userdata/{file}")
|
||||
async def post_userdata(request):
|
||||
"""
|
||||
Upload or update a user data file.
|
||||
|
||||
This endpoint handles file uploads to a user's data directory, with options for
|
||||
controlling overwrite behavior and response format.
|
||||
|
||||
Query Parameters:
|
||||
- overwrite (optional): If "false", prevents overwriting existing files. Defaults to "true".
|
||||
- full_info (optional): If "true", returns detailed file information (path, size, modified time).
|
||||
If "false", returns only the relative file path.
|
||||
|
||||
Path Parameters:
|
||||
- file: The target file path (URL encoded if necessary).
|
||||
|
||||
Returns:
|
||||
- 400: If 'file' parameter is missing.
|
||||
- 403: If the requested path is not allowed.
|
||||
- 409: If overwrite=false and the file already exists.
|
||||
- 200: JSON response with either:
|
||||
- Full file information (if full_info=true)
|
||||
- Relative file path (if full_info=false)
|
||||
|
||||
The request body should contain the raw file content to be written.
|
||||
"""
|
||||
path = get_user_data_path(request)
|
||||
if not isinstance(path, str):
|
||||
return path
|
||||
|
||||
overwrite = request.query["overwrite"] != "false"
|
||||
|
||||
overwrite = request.query.get("overwrite", 'true') != "false"
|
||||
full_info = request.query.get('full_info', 'false').lower() == "true"
|
||||
|
||||
if not overwrite and os.path.exists(path):
|
||||
return web.Response(status=409)
|
||||
return web.Response(status=409, text="File already exists")
|
||||
|
||||
body = await request.read()
|
||||
|
||||
with open(path, "wb") as f:
|
||||
f.write(body)
|
||||
|
||||
resp = os.path.relpath(path, self.get_request_user_filepath(request, None))
|
||||
|
||||
user_path = self.get_request_user_filepath(request, None)
|
||||
if full_info:
|
||||
resp = get_file_info(path, user_path)
|
||||
else:
|
||||
resp = os.path.relpath(path, user_path)
|
||||
|
||||
return web.json_response(resp)
|
||||
|
||||
@routes.delete("/userdata/{file}")
|
||||
@@ -181,25 +275,56 @@ class UserManager():
|
||||
return path
|
||||
|
||||
os.remove(path)
|
||||
|
||||
|
||||
return web.Response(status=204)
|
||||
|
||||
@routes.post("/userdata/{file}/move/{dest}")
|
||||
async def move_userdata(request):
|
||||
"""
|
||||
Move or rename a user data file.
|
||||
|
||||
This endpoint handles moving or renaming files within a user's data directory, with options for
|
||||
controlling overwrite behavior and response format.
|
||||
|
||||
Path Parameters:
|
||||
- file: The source file path (URL encoded if necessary)
|
||||
- dest: The destination file path (URL encoded if necessary)
|
||||
|
||||
Query Parameters:
|
||||
- overwrite (optional): If "false", prevents overwriting existing files. Defaults to "true".
|
||||
- full_info (optional): If "true", returns detailed file information (path, size, modified time).
|
||||
If "false", returns only the relative file path.
|
||||
|
||||
Returns:
|
||||
- 400: If either 'file' or 'dest' parameter is missing
|
||||
- 403: If either requested path is not allowed
|
||||
- 404: If the source file does not exist
|
||||
- 409: If overwrite=false and the destination file already exists
|
||||
- 200: JSON response with either:
|
||||
- Full file information (if full_info=true)
|
||||
- Relative file path (if full_info=false)
|
||||
"""
|
||||
source = get_user_data_path(request, check_exists=True)
|
||||
if not isinstance(source, str):
|
||||
return source
|
||||
|
||||
|
||||
dest = get_user_data_path(request, check_exists=False, param="dest")
|
||||
if not isinstance(source, str):
|
||||
return dest
|
||||
|
||||
overwrite = request.query["overwrite"] != "false"
|
||||
if not overwrite and os.path.exists(dest):
|
||||
return web.Response(status=409)
|
||||
|
||||
print(f"moving '{source}' -> '{dest}'")
|
||||
overwrite = request.query.get("overwrite", 'true') != "false"
|
||||
full_info = request.query.get('full_info', 'false').lower() == "true"
|
||||
|
||||
if not overwrite and os.path.exists(dest):
|
||||
return web.Response(status=409, text="File already exists")
|
||||
|
||||
logging.info(f"moving '{source}' -> '{dest}'")
|
||||
shutil.move(source, dest)
|
||||
|
||||
resp = os.path.relpath(dest, self.get_request_user_filepath(request, None))
|
||||
|
||||
user_path = self.get_request_user_filepath(request, None)
|
||||
if full_info:
|
||||
resp = get_file_info(dest, user_path)
|
||||
else:
|
||||
resp = os.path.relpath(dest, user_path)
|
||||
|
||||
return web.json_response(resp)
|
||||
|
||||
122
comfy/cldm/dit_embedder.py
Normal file
122
comfy/cldm/dit_embedder.py
Normal file
@@ -0,0 +1,122 @@
|
||||
import math
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
from torch import Tensor
|
||||
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import DismantledBlock, PatchEmbed, VectorEmbedder, TimestepEmbedder, get_2d_sincos_pos_embed_torch
|
||||
|
||||
|
||||
class ControlNetEmbedder(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
img_size: int,
|
||||
patch_size: int,
|
||||
in_chans: int,
|
||||
attention_head_dim: int,
|
||||
num_attention_heads: int,
|
||||
adm_in_channels: int,
|
||||
num_layers: int,
|
||||
main_model_double: int,
|
||||
double_y_emb: bool,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
pos_embed_max_size: Optional[int] = None,
|
||||
operations = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.main_model_double = main_model_double
|
||||
self.dtype = dtype
|
||||
self.hidden_size = num_attention_heads * attention_head_dim
|
||||
self.patch_size = patch_size
|
||||
self.x_embedder = PatchEmbed(
|
||||
img_size=img_size,
|
||||
patch_size=patch_size,
|
||||
in_chans=in_chans,
|
||||
embed_dim=self.hidden_size,
|
||||
strict_img_size=pos_embed_max_size is None,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
self.t_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.double_y_emb = double_y_emb
|
||||
if self.double_y_emb:
|
||||
self.orig_y_embedder = VectorEmbedder(
|
||||
adm_in_channels, self.hidden_size, dtype, device, operations=operations
|
||||
)
|
||||
self.y_embedder = VectorEmbedder(
|
||||
self.hidden_size, self.hidden_size, dtype, device, operations=operations
|
||||
)
|
||||
else:
|
||||
self.y_embedder = VectorEmbedder(
|
||||
adm_in_channels, self.hidden_size, dtype, device, operations=operations
|
||||
)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
DismantledBlock(
|
||||
hidden_size=self.hidden_size, num_heads=num_attention_heads, qkv_bias=True,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
)
|
||||
|
||||
# self.use_y_embedder = pooled_projection_dim != self.time_text_embed.text_embedder.linear_1.in_features
|
||||
# TODO double check this logic when 8b
|
||||
self.use_y_embedder = True
|
||||
|
||||
self.controlnet_blocks = nn.ModuleList([])
|
||||
for _ in range(len(self.transformer_blocks)):
|
||||
controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
|
||||
self.controlnet_blocks.append(controlnet_block)
|
||||
|
||||
self.pos_embed_input = PatchEmbed(
|
||||
img_size=img_size,
|
||||
patch_size=patch_size,
|
||||
in_chans=in_chans,
|
||||
embed_dim=self.hidden_size,
|
||||
strict_img_size=False,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
timesteps: torch.Tensor,
|
||||
y: Optional[torch.Tensor] = None,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
hint = None,
|
||||
) -> Tuple[Tensor, List[Tensor]]:
|
||||
x_shape = list(x.shape)
|
||||
x = self.x_embedder(x)
|
||||
if not self.double_y_emb:
|
||||
h = (x_shape[-2] + 1) // self.patch_size
|
||||
w = (x_shape[-1] + 1) // self.patch_size
|
||||
x += get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=x.device)
|
||||
c = self.t_embedder(timesteps, dtype=x.dtype)
|
||||
if y is not None and self.y_embedder is not None:
|
||||
if self.double_y_emb:
|
||||
y = self.orig_y_embedder(y)
|
||||
y = self.y_embedder(y)
|
||||
c = c + y
|
||||
|
||||
x = x + self.pos_embed_input(hint)
|
||||
|
||||
block_out = ()
|
||||
|
||||
repeat = math.ceil(self.main_model_double / len(self.transformer_blocks))
|
||||
for i in range(len(self.transformer_blocks)):
|
||||
out = self.transformer_blocks[i](x, c)
|
||||
if not self.double_y_emb:
|
||||
x = out
|
||||
block_out += (self.controlnet_blocks[i](out),) * repeat
|
||||
|
||||
return {"output": block_out}
|
||||
@@ -6,6 +6,7 @@ class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
|
||||
def __init__(
|
||||
self,
|
||||
num_blocks = None,
|
||||
control_latent_channels = None,
|
||||
dtype = None,
|
||||
device = None,
|
||||
operations = None,
|
||||
@@ -17,10 +18,13 @@ class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
|
||||
for _ in range(len(self.joint_blocks)):
|
||||
self.controlnet_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype))
|
||||
|
||||
if control_latent_channels is None:
|
||||
control_latent_channels = self.in_channels
|
||||
|
||||
self.pos_embed_input = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(
|
||||
None,
|
||||
self.patch_size,
|
||||
self.in_channels,
|
||||
control_latent_channels,
|
||||
self.hidden_size,
|
||||
bias=True,
|
||||
strict_img_size=False,
|
||||
|
||||
@@ -36,7 +36,7 @@ class EnumAction(argparse.Action):
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
|
||||
parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0,::", help="Specify the IP address to listen on (default: 127.0.0.1). You can give a list of ip addresses by separating them with a comma like: 127.2.2.2,127.3.3.3 If --listen is provided without an argument, it defaults to 0.0.0.0,:: (listens on all ipv4 and ipv6)")
|
||||
parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
|
||||
parser.add_argument("--tls-keyfile", type=str, help="Path to TLS (SSL) key file. Enables TLS, makes app accessible at https://... requires --tls-certfile to function")
|
||||
parser.add_argument("--tls-certfile", type=str, help="Path to TLS (SSL) certificate file. Enables TLS, makes app accessible at https://... requires --tls-keyfile to function")
|
||||
@@ -60,8 +60,10 @@ fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If
|
||||
fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
|
||||
|
||||
fpunet_group = parser.add_mutually_exclusive_group()
|
||||
fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.")
|
||||
fpunet_group.add_argument("--fp16-unet", action="store_true", help="Store unet weights in fp16.")
|
||||
fpunet_group.add_argument("--fp32-unet", action="store_true", help="Run the diffusion model in fp32.")
|
||||
fpunet_group.add_argument("--fp64-unet", action="store_true", help="Run the diffusion model in fp64.")
|
||||
fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the diffusion model in bf16.")
|
||||
fpunet_group.add_argument("--fp16-unet", action="store_true", help="Run the diffusion model in fp16")
|
||||
fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.")
|
||||
fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")
|
||||
|
||||
@@ -92,6 +94,12 @@ class LatentPreviewMethod(enum.Enum):
|
||||
|
||||
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
|
||||
|
||||
parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
|
||||
|
||||
cache_group = parser.add_mutually_exclusive_group()
|
||||
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
|
||||
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
|
||||
|
||||
attn_group = parser.add_mutually_exclusive_group()
|
||||
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
|
||||
attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
|
||||
@@ -112,10 +120,14 @@ vram_group.add_argument("--lowvram", action="store_true", help="Split the unet i
|
||||
vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
|
||||
vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
|
||||
|
||||
parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reverved depending on your OS.")
|
||||
|
||||
|
||||
parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
|
||||
|
||||
parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
|
||||
parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
|
||||
parser.add_argument("--fast", action="store_true", help="Enable some untested and potentially quality deteriorating optimizations.")
|
||||
|
||||
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
|
||||
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
|
||||
@@ -126,7 +138,7 @@ parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Dis
|
||||
|
||||
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
|
||||
|
||||
parser.add_argument("--verbose", action="store_true", help="Enables more debug prints.")
|
||||
parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
|
||||
|
||||
# The default built-in provider hosted under web/
|
||||
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
|
||||
@@ -161,6 +173,8 @@ parser.add_argument(
|
||||
help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.",
|
||||
)
|
||||
|
||||
parser.add_argument("--user-directory", type=is_valid_directory, default=None, help="Set the ComfyUI user directory with an absolute path.")
|
||||
|
||||
if comfy.options.args_parsing:
|
||||
args = parser.parse_args()
|
||||
else:
|
||||
@@ -171,10 +185,3 @@ if args.windows_standalone_build:
|
||||
|
||||
if args.disable_auto_launch:
|
||||
args.auto_launch = False
|
||||
|
||||
import logging
|
||||
logging_level = logging.INFO
|
||||
if args.verbose:
|
||||
logging_level = logging.DEBUG
|
||||
|
||||
logging.basicConfig(format="%(message)s", level=logging_level)
|
||||
|
||||
@@ -23,6 +23,7 @@ class CLIPAttention(torch.nn.Module):
|
||||
|
||||
ACTIVATIONS = {"quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
|
||||
"gelu": torch.nn.functional.gelu,
|
||||
"gelu_pytorch_tanh": lambda a: torch.nn.functional.gelu(a, approximate="tanh"),
|
||||
}
|
||||
|
||||
class CLIPMLP(torch.nn.Module):
|
||||
@@ -88,10 +89,11 @@ class CLIPTextModel_(torch.nn.Module):
|
||||
heads = config_dict["num_attention_heads"]
|
||||
intermediate_size = config_dict["intermediate_size"]
|
||||
intermediate_activation = config_dict["hidden_act"]
|
||||
num_positions = config_dict["max_position_embeddings"]
|
||||
self.eos_token_id = config_dict["eos_token_id"]
|
||||
|
||||
super().__init__()
|
||||
self.embeddings = CLIPEmbeddings(embed_dim, dtype=dtype, device=device, operations=operations)
|
||||
self.embeddings = CLIPEmbeddings(embed_dim, num_positions=num_positions, dtype=dtype, device=device, operations=operations)
|
||||
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
|
||||
self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
|
||||
|
||||
@@ -123,7 +125,6 @@ class CLIPTextModel(torch.nn.Module):
|
||||
self.text_model = CLIPTextModel_(config_dict, dtype, device, operations)
|
||||
embed_dim = config_dict["hidden_size"]
|
||||
self.text_projection = operations.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
|
||||
self.text_projection.weight.copy_(torch.eye(embed_dim))
|
||||
self.dtype = dtype
|
||||
|
||||
def get_input_embeddings(self):
|
||||
@@ -139,27 +140,35 @@ class CLIPTextModel(torch.nn.Module):
|
||||
|
||||
|
||||
class CLIPVisionEmbeddings(torch.nn.Module):
|
||||
def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, dtype=None, device=None, operations=None):
|
||||
def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, model_type="", dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device))
|
||||
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
if model_type == "siglip_vision_model":
|
||||
self.class_embedding = None
|
||||
patch_bias = True
|
||||
else:
|
||||
num_patches = num_patches + 1
|
||||
self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device))
|
||||
patch_bias = False
|
||||
|
||||
self.patch_embedding = operations.Conv2d(
|
||||
in_channels=num_channels,
|
||||
out_channels=embed_dim,
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
bias=False,
|
||||
bias=patch_bias,
|
||||
dtype=dtype,
|
||||
device=device
|
||||
)
|
||||
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
num_positions = num_patches + 1
|
||||
self.position_embedding = operations.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
|
||||
self.position_embedding = operations.Embedding(num_patches, embed_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, pixel_values):
|
||||
embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)
|
||||
return torch.cat([comfy.ops.cast_to_input(self.class_embedding, embeds).expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + comfy.ops.cast_to_input(self.position_embedding.weight, embeds)
|
||||
if self.class_embedding is not None:
|
||||
embeds = torch.cat([comfy.ops.cast_to_input(self.class_embedding, embeds).expand(pixel_values.shape[0], 1, -1), embeds], dim=1)
|
||||
return embeds + comfy.ops.cast_to_input(self.position_embedding.weight, embeds)
|
||||
|
||||
|
||||
class CLIPVision(torch.nn.Module):
|
||||
@@ -170,9 +179,15 @@ class CLIPVision(torch.nn.Module):
|
||||
heads = config_dict["num_attention_heads"]
|
||||
intermediate_size = config_dict["intermediate_size"]
|
||||
intermediate_activation = config_dict["hidden_act"]
|
||||
model_type = config_dict["model_type"]
|
||||
|
||||
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], dtype=dtype, device=device, operations=operations)
|
||||
self.pre_layrnorm = operations.LayerNorm(embed_dim)
|
||||
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations)
|
||||
if model_type == "siglip_vision_model":
|
||||
self.pre_layrnorm = lambda a: a
|
||||
self.output_layernorm = True
|
||||
else:
|
||||
self.pre_layrnorm = operations.LayerNorm(embed_dim)
|
||||
self.output_layernorm = False
|
||||
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
|
||||
self.post_layernorm = operations.LayerNorm(embed_dim)
|
||||
|
||||
@@ -181,14 +196,21 @@ class CLIPVision(torch.nn.Module):
|
||||
x = self.pre_layrnorm(x)
|
||||
#TODO: attention_mask?
|
||||
x, i = self.encoder(x, mask=None, intermediate_output=intermediate_output)
|
||||
pooled_output = self.post_layernorm(x[:, 0, :])
|
||||
if self.output_layernorm:
|
||||
x = self.post_layernorm(x)
|
||||
pooled_output = x
|
||||
else:
|
||||
pooled_output = self.post_layernorm(x[:, 0, :])
|
||||
return x, i, pooled_output
|
||||
|
||||
class CLIPVisionModelProjection(torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.vision_model = CLIPVision(config_dict, dtype, device, operations)
|
||||
self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False)
|
||||
if "projection_dim" in config_dict:
|
||||
self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False)
|
||||
else:
|
||||
self.visual_projection = lambda a: a
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
x = self.vision_model(*args, **kwargs)
|
||||
|
||||
@@ -16,13 +16,18 @@ class Output:
|
||||
def __setitem__(self, key, item):
|
||||
setattr(self, key, item)
|
||||
|
||||
def clip_preprocess(image, size=224):
|
||||
mean = torch.tensor([ 0.48145466,0.4578275,0.40821073], device=image.device, dtype=image.dtype)
|
||||
std = torch.tensor([0.26862954,0.26130258,0.27577711], device=image.device, dtype=image.dtype)
|
||||
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
|
||||
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
|
||||
std = torch.tensor(std, device=image.device, dtype=image.dtype)
|
||||
image = image.movedim(-1, 1)
|
||||
if not (image.shape[2] == size and image.shape[3] == size):
|
||||
scale = (size / min(image.shape[2], image.shape[3]))
|
||||
image = torch.nn.functional.interpolate(image, size=(round(scale * image.shape[2]), round(scale * image.shape[3])), mode="bicubic", antialias=True)
|
||||
if crop:
|
||||
scale = (size / min(image.shape[2], image.shape[3]))
|
||||
scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3]))
|
||||
else:
|
||||
scale_size = (size, size)
|
||||
|
||||
image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
|
||||
h = (image.shape[2] - size)//2
|
||||
w = (image.shape[3] - size)//2
|
||||
image = image[:,:,h:h+size,w:w+size]
|
||||
@@ -35,6 +40,8 @@ class ClipVisionModel():
|
||||
config = json.load(f)
|
||||
|
||||
self.image_size = config.get("image_size", 224)
|
||||
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
|
||||
self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
|
||||
self.load_device = comfy.model_management.text_encoder_device()
|
||||
offload_device = comfy.model_management.text_encoder_offload_device()
|
||||
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
|
||||
@@ -49,9 +56,9 @@ class ClipVisionModel():
|
||||
def get_sd(self):
|
||||
return self.model.state_dict()
|
||||
|
||||
def encode_image(self, image):
|
||||
def encode_image(self, image, crop=True):
|
||||
comfy.model_management.load_model_gpu(self.patcher)
|
||||
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size).float()
|
||||
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
|
||||
out = self.model(pixel_values=pixel_values, intermediate_output=-2)
|
||||
|
||||
outputs = Output()
|
||||
@@ -94,7 +101,9 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
|
||||
elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
|
||||
elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
|
||||
if sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
|
||||
if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
|
||||
elif sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
|
||||
else:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
|
||||
@@ -109,8 +118,7 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
|
||||
keys = list(sd.keys())
|
||||
for k in keys:
|
||||
if k not in u:
|
||||
t = sd.pop(k)
|
||||
del t
|
||||
sd.pop(k)
|
||||
return clip
|
||||
|
||||
def load(ckpt_path):
|
||||
|
||||
13
comfy/clip_vision_siglip_384.json
Normal file
13
comfy/clip_vision_siglip_384.json
Normal file
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"num_channels": 3,
|
||||
"hidden_act": "gelu_pytorch_tanh",
|
||||
"hidden_size": 1152,
|
||||
"image_size": 384,
|
||||
"intermediate_size": 4304,
|
||||
"model_type": "siglip_vision_model",
|
||||
"num_attention_heads": 16,
|
||||
"num_hidden_layers": 27,
|
||||
"patch_size": 14,
|
||||
"image_mean": [0.5, 0.5, 0.5],
|
||||
"image_std": [0.5, 0.5, 0.5]
|
||||
}
|
||||
43
comfy/comfy_types/README.md
Normal file
43
comfy/comfy_types/README.md
Normal file
@@ -0,0 +1,43 @@
|
||||
# Comfy Typing
|
||||
## Type hinting for ComfyUI Node development
|
||||
|
||||
This module provides type hinting and concrete convenience types for node developers.
|
||||
If cloned to the custom_nodes directory of ComfyUI, types can be imported using:
|
||||
|
||||
```python
|
||||
from comfy_types import IO, ComfyNodeABC, CheckLazyMixin
|
||||
|
||||
class ExampleNode(ComfyNodeABC):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s) -> InputTypeDict:
|
||||
return {"required": {}}
|
||||
```
|
||||
|
||||
Full example is in [examples/example_nodes.py](examples/example_nodes.py).
|
||||
|
||||
# Types
|
||||
A few primary types are documented below. More complete information is available via the docstrings on each type.
|
||||
|
||||
## `IO`
|
||||
|
||||
A string enum of built-in and a few custom data types. Includes the following special types and their requisite plumbing:
|
||||
|
||||
- `ANY`: `"*"`
|
||||
- `NUMBER`: `"FLOAT,INT"`
|
||||
- `PRIMITIVE`: `"STRING,FLOAT,INT,BOOLEAN"`
|
||||
|
||||
## `ComfyNodeABC`
|
||||
|
||||
An abstract base class for nodes, offering type-hinting / autocomplete, and somewhat-alright docstrings.
|
||||
|
||||
### Type hinting for `INPUT_TYPES`
|
||||
|
||||

|
||||
|
||||
### `INPUT_TYPES` return dict
|
||||
|
||||

|
||||
|
||||
### Options for individual inputs
|
||||
|
||||

|
||||
@@ -1,5 +1,6 @@
|
||||
import torch
|
||||
from typing import Callable, Protocol, TypedDict, Optional, List
|
||||
from .node_typing import IO, InputTypeDict, ComfyNodeABC, CheckLazyMixin
|
||||
|
||||
|
||||
class UnetApplyFunction(Protocol):
|
||||
@@ -30,3 +31,15 @@ class UnetParams(TypedDict):
|
||||
|
||||
|
||||
UnetWrapperFunction = Callable[[UnetApplyFunction, UnetParams], torch.Tensor]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"UnetWrapperFunction",
|
||||
UnetApplyConds.__name__,
|
||||
UnetParams.__name__,
|
||||
UnetApplyFunction.__name__,
|
||||
IO.__name__,
|
||||
InputTypeDict.__name__,
|
||||
ComfyNodeABC.__name__,
|
||||
CheckLazyMixin.__name__,
|
||||
]
|
||||
28
comfy/comfy_types/examples/example_nodes.py
Normal file
28
comfy/comfy_types/examples/example_nodes.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from comfy_types import IO, ComfyNodeABC, InputTypeDict
|
||||
from inspect import cleandoc
|
||||
|
||||
|
||||
class ExampleNode(ComfyNodeABC):
|
||||
"""An example node that just adds 1 to an input integer.
|
||||
|
||||
* Requires an IDE configured with analysis paths etc to be worth looking at.
|
||||
* Not intended for use in ComfyUI.
|
||||
"""
|
||||
|
||||
DESCRIPTION = cleandoc(__doc__)
|
||||
CATEGORY = "examples"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"input_int": (IO.INT, {"defaultInput": True}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.INT,)
|
||||
RETURN_NAMES = ("input_plus_one",)
|
||||
FUNCTION = "execute"
|
||||
|
||||
def execute(self, input_int: int):
|
||||
return (input_int + 1,)
|
||||
BIN
comfy/comfy_types/examples/input_options.png
Normal file
BIN
comfy/comfy_types/examples/input_options.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 19 KiB |
BIN
comfy/comfy_types/examples/input_types.png
Normal file
BIN
comfy/comfy_types/examples/input_types.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 16 KiB |
BIN
comfy/comfy_types/examples/required_hint.png
Normal file
BIN
comfy/comfy_types/examples/required_hint.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 19 KiB |
274
comfy/comfy_types/node_typing.py
Normal file
274
comfy/comfy_types/node_typing.py
Normal file
@@ -0,0 +1,274 @@
|
||||
"""Comfy-specific type hinting"""
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Literal, TypedDict
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class StrEnum(str, Enum):
|
||||
"""Base class for string enums. Python's StrEnum is not available until 3.11."""
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
||||
|
||||
class IO(StrEnum):
|
||||
"""Node input/output data types.
|
||||
|
||||
Includes functionality for ``"*"`` (`ANY`) and ``"MULTI,TYPES"``.
|
||||
"""
|
||||
|
||||
STRING = "STRING"
|
||||
IMAGE = "IMAGE"
|
||||
MASK = "MASK"
|
||||
LATENT = "LATENT"
|
||||
BOOLEAN = "BOOLEAN"
|
||||
INT = "INT"
|
||||
FLOAT = "FLOAT"
|
||||
CONDITIONING = "CONDITIONING"
|
||||
SAMPLER = "SAMPLER"
|
||||
SIGMAS = "SIGMAS"
|
||||
GUIDER = "GUIDER"
|
||||
NOISE = "NOISE"
|
||||
CLIP = "CLIP"
|
||||
CONTROL_NET = "CONTROL_NET"
|
||||
VAE = "VAE"
|
||||
MODEL = "MODEL"
|
||||
CLIP_VISION = "CLIP_VISION"
|
||||
CLIP_VISION_OUTPUT = "CLIP_VISION_OUTPUT"
|
||||
STYLE_MODEL = "STYLE_MODEL"
|
||||
GLIGEN = "GLIGEN"
|
||||
UPSCALE_MODEL = "UPSCALE_MODEL"
|
||||
AUDIO = "AUDIO"
|
||||
WEBCAM = "WEBCAM"
|
||||
POINT = "POINT"
|
||||
FACE_ANALYSIS = "FACE_ANALYSIS"
|
||||
BBOX = "BBOX"
|
||||
SEGS = "SEGS"
|
||||
|
||||
ANY = "*"
|
||||
"""Always matches any type, but at a price.
|
||||
|
||||
Causes some functionality issues (e.g. reroutes, link types), and should be avoided whenever possible.
|
||||
"""
|
||||
NUMBER = "FLOAT,INT"
|
||||
"""A float or an int - could be either"""
|
||||
PRIMITIVE = "STRING,FLOAT,INT,BOOLEAN"
|
||||
"""Could be any of: string, float, int, or bool"""
|
||||
|
||||
def __ne__(self, value: object) -> bool:
|
||||
if self == "*" or value == "*":
|
||||
return False
|
||||
if not isinstance(value, str):
|
||||
return True
|
||||
a = frozenset(self.split(","))
|
||||
b = frozenset(value.split(","))
|
||||
return not (b.issubset(a) or a.issubset(b))
|
||||
|
||||
|
||||
class InputTypeOptions(TypedDict):
|
||||
"""Provides type hinting for the return type of the INPUT_TYPES node function.
|
||||
|
||||
Due to IDE limitations with unions, for now all options are available for all types (e.g. `label_on` is hinted even when the type is not `IO.BOOLEAN`).
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_datatypes
|
||||
"""
|
||||
|
||||
default: bool | str | float | int | list | tuple
|
||||
"""The default value of the widget"""
|
||||
defaultInput: bool
|
||||
"""Defaults to an input slot rather than a widget"""
|
||||
forceInput: bool
|
||||
"""`defaultInput` and also don't allow converting to a widget"""
|
||||
lazy: bool
|
||||
"""Declares that this input uses lazy evaluation"""
|
||||
rawLink: bool
|
||||
"""When a link exists, rather than receiving the evaluated value, you will receive the link (i.e. `["nodeId", <outputIndex>]`). Designed for node expansion."""
|
||||
tooltip: str
|
||||
"""Tooltip for the input (or widget), shown on pointer hover"""
|
||||
# class InputTypeNumber(InputTypeOptions):
|
||||
# default: float | int
|
||||
min: float
|
||||
"""The minimum value of a number (``FLOAT`` | ``INT``)"""
|
||||
max: float
|
||||
"""The maximum value of a number (``FLOAT`` | ``INT``)"""
|
||||
step: float
|
||||
"""The amount to increment or decrement a widget by when stepping up/down (``FLOAT`` | ``INT``)"""
|
||||
round: float
|
||||
"""Floats are rounded by this value (``FLOAT``)"""
|
||||
# class InputTypeBoolean(InputTypeOptions):
|
||||
# default: bool
|
||||
label_on: str
|
||||
"""The label to use in the UI when the bool is True (``BOOLEAN``)"""
|
||||
label_on: str
|
||||
"""The label to use in the UI when the bool is False (``BOOLEAN``)"""
|
||||
# class InputTypeString(InputTypeOptions):
|
||||
# default: str
|
||||
multiline: bool
|
||||
"""Use a multiline text box (``STRING``)"""
|
||||
placeholder: str
|
||||
"""Placeholder text to display in the UI when empty (``STRING``)"""
|
||||
# Deprecated:
|
||||
# defaultVal: str
|
||||
dynamicPrompts: bool
|
||||
"""Causes the front-end to evaluate dynamic prompts (``STRING``)"""
|
||||
|
||||
|
||||
class HiddenInputTypeDict(TypedDict):
|
||||
"""Provides type hinting for the hidden entry of node INPUT_TYPES."""
|
||||
|
||||
node_id: Literal["UNIQUE_ID"]
|
||||
"""UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages)."""
|
||||
unique_id: Literal["UNIQUE_ID"]
|
||||
"""UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages)."""
|
||||
prompt: Literal["PROMPT"]
|
||||
"""PROMPT is the complete prompt sent by the client to the server. See the prompt object for a full description."""
|
||||
extra_pnginfo: Literal["EXTRA_PNGINFO"]
|
||||
"""EXTRA_PNGINFO is a dictionary that will be copied into the metadata of any .png files saved. Custom nodes can store additional information in this dictionary for saving (or as a way to communicate with a downstream node)."""
|
||||
dynprompt: Literal["DYNPROMPT"]
|
||||
"""DYNPROMPT is an instance of comfy_execution.graph.DynamicPrompt. It differs from PROMPT in that it may mutate during the course of execution in response to Node Expansion."""
|
||||
|
||||
|
||||
class InputTypeDict(TypedDict):
|
||||
"""Provides type hinting for node INPUT_TYPES.
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_more_on_inputs
|
||||
"""
|
||||
|
||||
required: dict[str, tuple[IO, InputTypeOptions]]
|
||||
"""Describes all inputs that must be connected for the node to execute."""
|
||||
optional: dict[str, tuple[IO, InputTypeOptions]]
|
||||
"""Describes inputs which do not need to be connected."""
|
||||
hidden: HiddenInputTypeDict
|
||||
"""Offers advanced functionality and server-client communication.
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_more_on_inputs#hidden-inputs
|
||||
"""
|
||||
|
||||
|
||||
class ComfyNodeABC(ABC):
|
||||
"""Abstract base class for Comfy nodes. Includes the names and expected types of attributes.
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview
|
||||
"""
|
||||
|
||||
DESCRIPTION: str
|
||||
"""Node description, shown as a tooltip when hovering over the node.
|
||||
|
||||
Usage::
|
||||
|
||||
# Explicitly define the description
|
||||
DESCRIPTION = "Example description here."
|
||||
|
||||
# Use the docstring of the node class.
|
||||
DESCRIPTION = cleandoc(__doc__)
|
||||
"""
|
||||
CATEGORY: str
|
||||
"""The category of the node, as per the "Add Node" menu.
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#category
|
||||
"""
|
||||
EXPERIMENTAL: bool
|
||||
"""Flags a node as experimental, informing users that it may change or not work as expected."""
|
||||
DEPRECATED: bool
|
||||
"""Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def INPUT_TYPES(s) -> InputTypeDict:
|
||||
"""Defines node inputs.
|
||||
|
||||
* Must include the ``required`` key, which describes all inputs that must be connected for the node to execute.
|
||||
* The ``optional`` key can be added to describe inputs which do not need to be connected.
|
||||
* The ``hidden`` key offers some advanced functionality. More info at: https://docs.comfy.org/essentials/custom_node_more_on_inputs#hidden-inputs
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#input-types
|
||||
"""
|
||||
return {"required": {}}
|
||||
|
||||
OUTPUT_NODE: bool
|
||||
"""Flags this node as an output node, causing any inputs it requires to be executed.
|
||||
|
||||
If a node is not connected to any output nodes, that node will not be executed. Usage::
|
||||
|
||||
OUTPUT_NODE = True
|
||||
|
||||
From the docs:
|
||||
|
||||
By default, a node is not considered an output. Set ``OUTPUT_NODE = True`` to specify that it is.
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#output-node
|
||||
"""
|
||||
INPUT_IS_LIST: bool
|
||||
"""A flag indicating if this node implements the additional code necessary to deal with OUTPUT_IS_LIST nodes.
|
||||
|
||||
All inputs of ``type`` will become ``list[type]``, regardless of how many items are passed in. This also affects ``check_lazy_status``.
|
||||
|
||||
From the docs:
|
||||
|
||||
A node can also override the default input behaviour and receive the whole list in a single call. This is done by setting a class attribute `INPUT_IS_LIST` to ``True``.
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_lists#list-processing
|
||||
"""
|
||||
OUTPUT_IS_LIST: tuple[bool]
|
||||
"""A tuple indicating which node outputs are lists, but will be connected to nodes that expect individual items.
|
||||
|
||||
Connected nodes that do not implement `INPUT_IS_LIST` will be executed once for every item in the list.
|
||||
|
||||
A ``tuple[bool]``, where the items match those in `RETURN_TYPES`::
|
||||
|
||||
RETURN_TYPES = (IO.INT, IO.INT, IO.STRING)
|
||||
OUTPUT_IS_LIST = (True, True, False) # The string output will be handled normally
|
||||
|
||||
From the docs:
|
||||
|
||||
In order to tell Comfy that the list being returned should not be wrapped, but treated as a series of data for sequential processing,
|
||||
the node should provide a class attribute `OUTPUT_IS_LIST`, which is a ``tuple[bool]``, of the same length as `RETURN_TYPES`,
|
||||
specifying which outputs which should be so treated.
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_lists#list-processing
|
||||
"""
|
||||
|
||||
RETURN_TYPES: tuple[IO]
|
||||
"""A tuple representing the outputs of this node.
|
||||
|
||||
Usage::
|
||||
|
||||
RETURN_TYPES = (IO.INT, "INT", "CUSTOM_TYPE")
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#return-types
|
||||
"""
|
||||
RETURN_NAMES: tuple[str]
|
||||
"""The output slot names for each item in `RETURN_TYPES`, e.g. ``RETURN_NAMES = ("count", "filter_string")``
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#return-names
|
||||
"""
|
||||
OUTPUT_TOOLTIPS: tuple[str]
|
||||
"""A tuple of strings to use as tooltips for node outputs, one for each item in `RETURN_TYPES`."""
|
||||
FUNCTION: str
|
||||
"""The name of the function to execute as a literal string, e.g. `FUNCTION = "execute"`
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#function
|
||||
"""
|
||||
|
||||
|
||||
class CheckLazyMixin:
|
||||
"""Provides a basic check_lazy_status implementation and type hinting for nodes that use lazy inputs."""
|
||||
|
||||
def check_lazy_status(self, **kwargs) -> list[str]:
|
||||
"""Returns a list of input names that should be evaluated.
|
||||
|
||||
This basic mixin impl. requires all inputs.
|
||||
|
||||
:kwargs: All node inputs will be included here. If the input is ``None``, it should be assumed that it has not yet been evaluated. \
|
||||
When using ``INPUT_IS_LIST = True``, unevaluated will instead be ``(None,)``.
|
||||
|
||||
Params should match the nodes execution ``FUNCTION`` (self, and all inputs by name).
|
||||
Will be executed repeatedly until it returns an empty list, or all requested items were already evaluated (and sent as params).
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_lazy_evaluation#defining-check-lazy-status
|
||||
"""
|
||||
|
||||
need = [name for name in kwargs if kwargs[name] is None]
|
||||
return need
|
||||
@@ -34,6 +34,12 @@ import comfy.t2i_adapter.adapter
|
||||
import comfy.ldm.cascade.controlnet
|
||||
import comfy.cldm.mmdit
|
||||
import comfy.ldm.hydit.controlnet
|
||||
import comfy.ldm.flux.controlnet
|
||||
import comfy.cldm.dit_embedder
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from comfy.hooks import HookGroup
|
||||
|
||||
|
||||
def broadcast_image_to(tensor, target_batch_size, batched_number):
|
||||
current_batch_size = tensor.shape[0]
|
||||
@@ -58,7 +64,7 @@ class StrengthType(Enum):
|
||||
LINEAR_UP = 2
|
||||
|
||||
class ControlBase:
|
||||
def __init__(self, device=None):
|
||||
def __init__(self):
|
||||
self.cond_hint_original = None
|
||||
self.cond_hint = None
|
||||
self.strength = 1.0
|
||||
@@ -70,20 +76,26 @@ class ControlBase:
|
||||
self.compression_ratio = 8
|
||||
self.upscale_algorithm = 'nearest-exact'
|
||||
self.extra_args = {}
|
||||
|
||||
if device is None:
|
||||
device = comfy.model_management.get_torch_device()
|
||||
self.device = device
|
||||
self.previous_controlnet = None
|
||||
self.extra_conds = []
|
||||
self.strength_type = StrengthType.CONSTANT
|
||||
self.concat_mask = False
|
||||
self.extra_concat_orig = []
|
||||
self.extra_concat = None
|
||||
self.extra_hooks: HookGroup = None
|
||||
self.preprocess_image = lambda a: a
|
||||
|
||||
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None):
|
||||
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
|
||||
self.cond_hint_original = cond_hint
|
||||
self.strength = strength
|
||||
self.timestep_percent_range = timestep_percent_range
|
||||
if self.latent_format is not None:
|
||||
if vae is None:
|
||||
logging.warning("WARNING: no VAE provided to the controlnet apply node when this controlnet requires one.")
|
||||
self.vae = vae
|
||||
self.extra_concat_orig = extra_concat.copy()
|
||||
if self.concat_mask and len(self.extra_concat_orig) == 0:
|
||||
self.extra_concat_orig.append(torch.tensor([[[[1.0]]]]))
|
||||
return self
|
||||
|
||||
def pre_run(self, model, percent_to_timestep_function):
|
||||
@@ -98,9 +110,9 @@ class ControlBase:
|
||||
def cleanup(self):
|
||||
if self.previous_controlnet is not None:
|
||||
self.previous_controlnet.cleanup()
|
||||
if self.cond_hint is not None:
|
||||
del self.cond_hint
|
||||
self.cond_hint = None
|
||||
|
||||
self.cond_hint = None
|
||||
self.extra_concat = None
|
||||
self.timestep_range = None
|
||||
|
||||
def get_models(self):
|
||||
@@ -108,6 +120,14 @@ class ControlBase:
|
||||
if self.previous_controlnet is not None:
|
||||
out += self.previous_controlnet.get_models()
|
||||
return out
|
||||
|
||||
def get_extra_hooks(self):
|
||||
out = []
|
||||
if self.extra_hooks is not None:
|
||||
out.append(self.extra_hooks)
|
||||
if self.previous_controlnet is not None:
|
||||
out += self.previous_controlnet.get_extra_hooks()
|
||||
return out
|
||||
|
||||
def copy_to(self, c):
|
||||
c.cond_hint_original = self.cond_hint_original
|
||||
@@ -121,6 +141,10 @@ class ControlBase:
|
||||
c.vae = self.vae
|
||||
c.extra_conds = self.extra_conds.copy()
|
||||
c.strength_type = self.strength_type
|
||||
c.concat_mask = self.concat_mask
|
||||
c.extra_concat_orig = self.extra_concat_orig.copy()
|
||||
c.extra_hooks = self.extra_hooks.clone() if self.extra_hooks else None
|
||||
c.preprocess_image = self.preprocess_image
|
||||
|
||||
def inference_memory_requirements(self, dtype):
|
||||
if self.previous_controlnet is not None:
|
||||
@@ -146,7 +170,7 @@ class ControlBase:
|
||||
elif self.strength_type == StrengthType.LINEAR_UP:
|
||||
x *= (self.strength ** float(len(control_output) - i))
|
||||
|
||||
if x.dtype != output_dtype:
|
||||
if output_dtype is not None and x.dtype != output_dtype:
|
||||
x = x.to(output_dtype)
|
||||
|
||||
out[key].append(x)
|
||||
@@ -173,8 +197,8 @@ class ControlBase:
|
||||
|
||||
|
||||
class ControlNet(ControlBase):
|
||||
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT):
|
||||
super().__init__(device)
|
||||
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False, preprocess_image=lambda a: a):
|
||||
super().__init__()
|
||||
self.control_model = control_model
|
||||
self.load_device = load_device
|
||||
if control_model is not None:
|
||||
@@ -187,11 +211,13 @@ class ControlNet(ControlBase):
|
||||
self.latent_format = latent_format
|
||||
self.extra_conds += extra_conds
|
||||
self.strength_type = strength_type
|
||||
self.concat_mask = concat_mask
|
||||
self.preprocess_image = preprocess_image
|
||||
|
||||
def get_control(self, x_noisy, t, cond, batched_number):
|
||||
def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
|
||||
control_prev = None
|
||||
if self.previous_controlnet is not None:
|
||||
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
|
||||
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number, transformer_options)
|
||||
|
||||
if self.timestep_range is not None:
|
||||
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
|
||||
@@ -204,7 +230,6 @@ class ControlNet(ControlBase):
|
||||
if self.manual_cast_dtype is not None:
|
||||
dtype = self.manual_cast_dtype
|
||||
|
||||
output_dtype = x_noisy.dtype
|
||||
if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
|
||||
if self.cond_hint is not None:
|
||||
del self.cond_hint
|
||||
@@ -212,14 +237,26 @@ class ControlNet(ControlBase):
|
||||
compression_ratio = self.compression_ratio
|
||||
if self.vae is not None:
|
||||
compression_ratio *= self.vae.downscale_ratio
|
||||
else:
|
||||
if self.latent_format is not None:
|
||||
raise ValueError("This Controlnet needs a VAE but none was provided, please use a ControlNetApply node with a VAE input and connect it.")
|
||||
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
|
||||
self.cond_hint = self.preprocess_image(self.cond_hint)
|
||||
if self.vae is not None:
|
||||
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
|
||||
self.cond_hint = self.vae.encode(self.cond_hint.movedim(1, -1))
|
||||
comfy.model_management.load_models_gpu(loaded_models)
|
||||
if self.latent_format is not None:
|
||||
self.cond_hint = self.latent_format.process_in(self.cond_hint)
|
||||
self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype)
|
||||
if len(self.extra_concat_orig) > 0:
|
||||
to_concat = []
|
||||
for c in self.extra_concat_orig:
|
||||
c = c.to(self.cond_hint.device)
|
||||
c = comfy.utils.common_upscale(c, self.cond_hint.shape[3], self.cond_hint.shape[2], self.upscale_algorithm, "center")
|
||||
to_concat.append(comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[0]))
|
||||
self.cond_hint = torch.cat([self.cond_hint] + to_concat, dim=1)
|
||||
|
||||
self.cond_hint = self.cond_hint.to(device=x_noisy.device, dtype=dtype)
|
||||
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
||||
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
||||
|
||||
@@ -234,7 +271,7 @@ class ControlNet(ControlBase):
|
||||
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
|
||||
|
||||
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra)
|
||||
return self.control_merge(control, control_prev, output_dtype)
|
||||
return self.control_merge(control, control_prev, output_dtype=None)
|
||||
|
||||
def copy(self):
|
||||
c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
|
||||
@@ -318,8 +355,8 @@ class ControlLoraOps:
|
||||
|
||||
|
||||
class ControlLora(ControlNet):
|
||||
def __init__(self, control_weights, global_average_pooling=False, device=None):
|
||||
ControlBase.__init__(self, device)
|
||||
def __init__(self, control_weights, global_average_pooling=False, model_options={}): #TODO? model_options
|
||||
ControlBase.__init__(self)
|
||||
self.control_weights = control_weights
|
||||
self.global_average_pooling = global_average_pooling
|
||||
self.extra_conds += ["y"]
|
||||
@@ -375,21 +412,28 @@ class ControlLora(ControlNet):
|
||||
def inference_memory_requirements(self, dtype):
|
||||
return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
|
||||
|
||||
def controlnet_config(sd):
|
||||
def controlnet_config(sd, model_options={}):
|
||||
model_config = comfy.model_detection.model_config_from_unet(sd, "", True)
|
||||
|
||||
supported_inference_dtypes = model_config.supported_inference_dtypes
|
||||
unet_dtype = model_options.get("dtype", None)
|
||||
if unet_dtype is None:
|
||||
weight_dtype = comfy.utils.weight_dtype(sd)
|
||||
|
||||
supported_inference_dtypes = list(model_config.supported_inference_dtypes)
|
||||
if weight_dtype is not None:
|
||||
supported_inference_dtypes.append(weight_dtype)
|
||||
|
||||
unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
|
||||
|
||||
controlnet_config = model_config.unet_config
|
||||
unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
|
||||
load_device = comfy.model_management.get_torch_device()
|
||||
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
||||
if manual_cast_dtype is not None:
|
||||
operations = comfy.ops.manual_cast
|
||||
else:
|
||||
operations = comfy.ops.disable_weight_init
|
||||
|
||||
return model_config, operations, load_device, unet_dtype, manual_cast_dtype
|
||||
operations = model_options.get("custom_operations", None)
|
||||
if operations is None:
|
||||
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
|
||||
|
||||
offload_device = comfy.model_management.unet_offload_device()
|
||||
return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device
|
||||
|
||||
def controlnet_load_state_dict(control_model, sd):
|
||||
missing, unexpected = control_model.load_state_dict(sd, strict=False)
|
||||
@@ -401,25 +445,108 @@ def controlnet_load_state_dict(control_model, sd):
|
||||
logging.debug("unexpected controlnet keys: {}".format(unexpected))
|
||||
return control_model
|
||||
|
||||
def load_controlnet_mmdit(sd):
|
||||
|
||||
def load_controlnet_mmdit(sd, model_options={}):
|
||||
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(new_sd)
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options)
|
||||
num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
|
||||
for k in sd:
|
||||
new_sd[k] = sd[k]
|
||||
|
||||
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config)
|
||||
concat_mask = False
|
||||
control_latent_channels = new_sd.get("pos_embed_input.proj.weight").shape[1]
|
||||
if control_latent_channels == 17: #inpaint controlnet
|
||||
concat_mask = True
|
||||
|
||||
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, control_latent_channels=control_latent_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||
control_model = controlnet_load_state_dict(control_model, new_sd)
|
||||
|
||||
latent_format = comfy.latent_formats.SD3()
|
||||
latent_format.shift_factor = 0 #SD3 controlnet weirdness
|
||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
||||
return control
|
||||
|
||||
def load_controlnet_hunyuandit(controlnet_data):
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(controlnet_data)
|
||||
|
||||
control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=load_device, dtype=unet_dtype)
|
||||
class ControlNetSD35(ControlNet):
|
||||
def pre_run(self, model, percent_to_timestep_function):
|
||||
if self.control_model.double_y_emb:
|
||||
missing, unexpected = self.control_model.orig_y_embedder.load_state_dict(model.diffusion_model.y_embedder.state_dict(), strict=False)
|
||||
else:
|
||||
missing, unexpected = self.control_model.x_embedder.load_state_dict(model.diffusion_model.x_embedder.state_dict(), strict=False)
|
||||
super().pre_run(model, percent_to_timestep_function)
|
||||
|
||||
def copy(self):
|
||||
c = ControlNetSD35(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
|
||||
c.control_model = self.control_model
|
||||
c.control_model_wrapped = self.control_model_wrapped
|
||||
self.copy_to(c)
|
||||
return c
|
||||
|
||||
def load_controlnet_sd35(sd, model_options={}):
|
||||
control_type = -1
|
||||
if "control_type" in sd:
|
||||
control_type = round(sd.pop("control_type").item())
|
||||
|
||||
# blur_cnet = control_type == 0
|
||||
canny_cnet = control_type == 1
|
||||
depth_cnet = control_type == 2
|
||||
|
||||
new_sd = {}
|
||||
for k in comfy.utils.MMDIT_MAP_BASIC:
|
||||
if k[1] in sd:
|
||||
new_sd[k[0]] = sd.pop(k[1])
|
||||
for k in sd:
|
||||
new_sd[k] = sd[k]
|
||||
sd = new_sd
|
||||
|
||||
y_emb_shape = sd["y_embedder.mlp.0.weight"].shape
|
||||
depth = y_emb_shape[0] // 64
|
||||
hidden_size = 64 * depth
|
||||
num_heads = depth
|
||||
head_dim = hidden_size // num_heads
|
||||
num_blocks = comfy.model_detection.count_blocks(new_sd, 'transformer_blocks.{}.')
|
||||
|
||||
load_device = comfy.model_management.get_torch_device()
|
||||
offload_device = comfy.model_management.unet_offload_device()
|
||||
unet_dtype = comfy.model_management.unet_dtype(model_params=-1)
|
||||
|
||||
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
||||
|
||||
operations = model_options.get("custom_operations", None)
|
||||
if operations is None:
|
||||
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
|
||||
|
||||
control_model = comfy.cldm.dit_embedder.ControlNetEmbedder(img_size=None,
|
||||
patch_size=2,
|
||||
in_chans=16,
|
||||
num_layers=num_blocks,
|
||||
main_model_double=depth,
|
||||
double_y_emb=y_emb_shape[0] == y_emb_shape[1],
|
||||
attention_head_dim=head_dim,
|
||||
num_attention_heads=num_heads,
|
||||
adm_in_channels=2048,
|
||||
device=offload_device,
|
||||
dtype=unet_dtype,
|
||||
operations=operations)
|
||||
|
||||
control_model = controlnet_load_state_dict(control_model, sd)
|
||||
|
||||
latent_format = comfy.latent_formats.SD3()
|
||||
preprocess_image = lambda a: a
|
||||
if canny_cnet:
|
||||
preprocess_image = lambda a: (a * 255 * 0.5 + 0.5)
|
||||
elif depth_cnet:
|
||||
preprocess_image = lambda a: 1.0 - a
|
||||
|
||||
control = ControlNetSD35(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, preprocess_image=preprocess_image)
|
||||
return control
|
||||
|
||||
|
||||
|
||||
def load_controlnet_hunyuandit(controlnet_data, model_options={}):
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data, model_options=model_options)
|
||||
|
||||
control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=offload_device, dtype=unet_dtype)
|
||||
control_model = controlnet_load_state_dict(control_model, controlnet_data)
|
||||
|
||||
latent_format = comfy.latent_formats.SDXL()
|
||||
@@ -427,13 +554,49 @@ def load_controlnet_hunyuandit(controlnet_data):
|
||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT)
|
||||
return control
|
||||
|
||||
def load_controlnet(ckpt_path, model=None):
|
||||
controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
|
||||
def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False, model_options={}):
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
|
||||
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||
control_model = controlnet_load_state_dict(control_model, sd)
|
||||
extra_conds = ['y', 'guidance']
|
||||
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||
return control
|
||||
|
||||
def load_controlnet_flux_instantx(sd, model_options={}):
|
||||
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options)
|
||||
for k in sd:
|
||||
new_sd[k] = sd[k]
|
||||
|
||||
num_union_modes = 0
|
||||
union_cnet = "controlnet_mode_embedder.weight"
|
||||
if union_cnet in new_sd:
|
||||
num_union_modes = new_sd[union_cnet].shape[0]
|
||||
|
||||
control_latent_channels = new_sd.get("pos_embed_input.weight").shape[1] // 4
|
||||
concat_mask = False
|
||||
if control_latent_channels == 17:
|
||||
concat_mask = True
|
||||
|
||||
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(latent_input=True, num_union_modes=num_union_modes, control_latent_channels=control_latent_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||
control_model = controlnet_load_state_dict(control_model, new_sd)
|
||||
|
||||
latent_format = comfy.latent_formats.Flux()
|
||||
extra_conds = ['y', 'guidance']
|
||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||
return control
|
||||
|
||||
def convert_mistoline(sd):
|
||||
return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
|
||||
|
||||
|
||||
def load_controlnet_state_dict(state_dict, model=None, model_options={}):
|
||||
controlnet_data = state_dict
|
||||
if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT
|
||||
return load_controlnet_hunyuandit(controlnet_data)
|
||||
return load_controlnet_hunyuandit(controlnet_data, model_options=model_options)
|
||||
|
||||
if "lora_controlnet" in controlnet_data:
|
||||
return ControlLora(controlnet_data)
|
||||
return ControlLora(controlnet_data, model_options=model_options)
|
||||
|
||||
controlnet_config = None
|
||||
supported_inference_dtypes = None
|
||||
@@ -488,8 +651,18 @@ def load_controlnet(ckpt_path, model=None):
|
||||
if len(leftover_keys) > 0:
|
||||
logging.warning("leftover keys: {}".format(leftover_keys))
|
||||
controlnet_data = new_sd
|
||||
elif "controlnet_blocks.0.weight" in controlnet_data: #SD3 diffusers format
|
||||
return load_controlnet_mmdit(controlnet_data)
|
||||
elif "controlnet_blocks.0.weight" in controlnet_data:
|
||||
if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
|
||||
return load_controlnet_flux_xlabs_mistoline(controlnet_data, model_options=model_options)
|
||||
elif "pos_embed_input.proj.weight" in controlnet_data:
|
||||
if "transformer_blocks.0.adaLN_modulation.1.bias" in controlnet_data:
|
||||
return load_controlnet_sd35(controlnet_data, model_options=model_options) #Stability sd3.5 format
|
||||
else:
|
||||
return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet
|
||||
elif "controlnet_x_embedder.weight" in controlnet_data:
|
||||
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
|
||||
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
|
||||
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)
|
||||
|
||||
pth_key = 'control_model.zero_convs.0.0.weight'
|
||||
pth = False
|
||||
@@ -501,26 +674,38 @@ def load_controlnet(ckpt_path, model=None):
|
||||
elif key in controlnet_data:
|
||||
prefix = ""
|
||||
else:
|
||||
net = load_t2i_adapter(controlnet_data)
|
||||
net = load_t2i_adapter(controlnet_data, model_options=model_options)
|
||||
if net is None:
|
||||
logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
|
||||
logging.error("error could not detect control model type.")
|
||||
return net
|
||||
|
||||
if controlnet_config is None:
|
||||
model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True)
|
||||
supported_inference_dtypes = model_config.supported_inference_dtypes
|
||||
supported_inference_dtypes = list(model_config.supported_inference_dtypes)
|
||||
controlnet_config = model_config.unet_config
|
||||
|
||||
unet_dtype = model_options.get("dtype", None)
|
||||
if unet_dtype is None:
|
||||
weight_dtype = comfy.utils.weight_dtype(controlnet_data)
|
||||
|
||||
if supported_inference_dtypes is None:
|
||||
supported_inference_dtypes = [comfy.model_management.unet_dtype()]
|
||||
|
||||
if weight_dtype is not None:
|
||||
supported_inference_dtypes.append(weight_dtype)
|
||||
|
||||
unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
|
||||
|
||||
load_device = comfy.model_management.get_torch_device()
|
||||
if supported_inference_dtypes is None:
|
||||
unet_dtype = comfy.model_management.unet_dtype()
|
||||
else:
|
||||
unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
|
||||
|
||||
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
||||
if manual_cast_dtype is not None:
|
||||
controlnet_config["operations"] = comfy.ops.manual_cast
|
||||
operations = model_options.get("custom_operations", None)
|
||||
if operations is None:
|
||||
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype)
|
||||
|
||||
controlnet_config["operations"] = operations
|
||||
controlnet_config["dtype"] = unet_dtype
|
||||
controlnet_config["device"] = comfy.model_management.unet_offload_device()
|
||||
controlnet_config.pop("out_channels")
|
||||
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
|
||||
control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
|
||||
@@ -554,22 +739,32 @@ def load_controlnet(ckpt_path, model=None):
|
||||
if len(unexpected) > 0:
|
||||
logging.debug("unexpected controlnet keys: {}".format(unexpected))
|
||||
|
||||
global_average_pooling = False
|
||||
filename = os.path.splitext(ckpt_path)[0]
|
||||
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
|
||||
global_average_pooling = True
|
||||
|
||||
global_average_pooling = model_options.get("global_average_pooling", False)
|
||||
control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
||||
return control
|
||||
|
||||
def load_controlnet(ckpt_path, model=None, model_options={}):
|
||||
if "global_average_pooling" not in model_options:
|
||||
filename = os.path.splitext(ckpt_path)[0]
|
||||
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
|
||||
model_options["global_average_pooling"] = True
|
||||
|
||||
cnet = load_controlnet_state_dict(comfy.utils.load_torch_file(ckpt_path, safe_load=True), model=model, model_options=model_options)
|
||||
if cnet is None:
|
||||
logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
|
||||
return cnet
|
||||
|
||||
class T2IAdapter(ControlBase):
|
||||
def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None):
|
||||
super().__init__(device)
|
||||
super().__init__()
|
||||
self.t2i_model = t2i_model
|
||||
self.channels_in = channels_in
|
||||
self.control_input = None
|
||||
self.compression_ratio = compression_ratio
|
||||
self.upscale_algorithm = upscale_algorithm
|
||||
if device is None:
|
||||
device = comfy.model_management.get_torch_device()
|
||||
self.device = device
|
||||
|
||||
def scale_image_to(self, width, height):
|
||||
unshuffle_amount = self.t2i_model.unshuffle_amount
|
||||
@@ -577,10 +772,10 @@ class T2IAdapter(ControlBase):
|
||||
height = math.ceil(height / unshuffle_amount) * unshuffle_amount
|
||||
return width, height
|
||||
|
||||
def get_control(self, x_noisy, t, cond, batched_number):
|
||||
def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
|
||||
control_prev = None
|
||||
if self.previous_controlnet is not None:
|
||||
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
|
||||
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number, transformer_options)
|
||||
|
||||
if self.timestep_range is not None:
|
||||
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
|
||||
@@ -617,7 +812,7 @@ class T2IAdapter(ControlBase):
|
||||
self.copy_to(c)
|
||||
return c
|
||||
|
||||
def load_t2i_adapter(t2i_data):
|
||||
def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
|
||||
compression_ratio = 8
|
||||
upscale_algorithm = 'nearest-exact'
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_dire
|
||||
if text_encoder2_path is not None:
|
||||
text_encoder_paths.append(text_encoder2_path)
|
||||
|
||||
unet = comfy.sd.load_unet(unet_path)
|
||||
unet = comfy.sd.load_diffusion_model(unet_path)
|
||||
|
||||
clip = None
|
||||
if output_clip:
|
||||
|
||||
@@ -16,7 +16,7 @@ class NoiseScheduleVP:
|
||||
continuous_beta_0=0.1,
|
||||
continuous_beta_1=20.,
|
||||
):
|
||||
"""Create a wrapper class for the forward SDE (VP type).
|
||||
r"""Create a wrapper class for the forward SDE (VP type).
|
||||
|
||||
***
|
||||
Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t.
|
||||
|
||||
67
comfy/float.py
Normal file
67
comfy/float.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import torch
|
||||
|
||||
def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None):
|
||||
mantissa_scaled = torch.where(
|
||||
normal_mask,
|
||||
(abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0) * (2**MANTISSA_BITS),
|
||||
(abs_x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS)))
|
||||
)
|
||||
|
||||
mantissa_scaled += torch.rand(mantissa_scaled.size(), dtype=mantissa_scaled.dtype, layout=mantissa_scaled.layout, device=mantissa_scaled.device, generator=generator)
|
||||
return mantissa_scaled.floor() / (2**MANTISSA_BITS)
|
||||
|
||||
#Not 100% sure about this
|
||||
def manual_stochastic_round_to_float8(x, dtype, generator=None):
|
||||
if dtype == torch.float8_e4m3fn:
|
||||
EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 4, 3, 7
|
||||
elif dtype == torch.float8_e5m2:
|
||||
EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 5, 2, 15
|
||||
else:
|
||||
raise ValueError("Unsupported dtype")
|
||||
|
||||
x = x.half()
|
||||
sign = torch.sign(x)
|
||||
abs_x = x.abs()
|
||||
sign = torch.where(abs_x == 0, 0, sign)
|
||||
|
||||
# Combine exponent calculation and clamping
|
||||
exponent = torch.clamp(
|
||||
torch.floor(torch.log2(abs_x)) + EXPONENT_BIAS,
|
||||
0, 2**EXPONENT_BITS - 1
|
||||
)
|
||||
|
||||
# Combine mantissa calculation and rounding
|
||||
normal_mask = ~(exponent == 0)
|
||||
|
||||
abs_x[:] = calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=generator)
|
||||
|
||||
sign *= torch.where(
|
||||
normal_mask,
|
||||
(2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + abs_x),
|
||||
(2.0 ** (-EXPONENT_BIAS + 1)) * abs_x
|
||||
)
|
||||
|
||||
inf = torch.finfo(dtype)
|
||||
torch.clamp(sign, min=inf.min, max=inf.max, out=sign)
|
||||
return sign
|
||||
|
||||
|
||||
|
||||
def stochastic_rounding(value, dtype, seed=0):
|
||||
if dtype == torch.float32:
|
||||
return value.to(dtype=torch.float32)
|
||||
if dtype == torch.float16:
|
||||
return value.to(dtype=torch.float16)
|
||||
if dtype == torch.bfloat16:
|
||||
return value.to(dtype=torch.bfloat16)
|
||||
if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
|
||||
generator = torch.Generator(device=value.device)
|
||||
generator.manual_seed(seed)
|
||||
output = torch.empty_like(value, dtype=dtype)
|
||||
num_slices = max(1, (value.numel() / (4096 * 4096)))
|
||||
slice_size = max(1, round(value.shape[0] / num_slices))
|
||||
for i in range(0, value.shape[0], slice_size):
|
||||
output[i:i+slice_size].copy_(manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator))
|
||||
return output
|
||||
|
||||
return value.to(dtype=dtype)
|
||||
690
comfy/hooks.py
Normal file
690
comfy/hooks.py
Normal file
@@ -0,0 +1,690 @@
|
||||
from __future__ import annotations
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
import enum
|
||||
import math
|
||||
import torch
|
||||
import numpy as np
|
||||
import itertools
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_patcher import ModelPatcher, PatcherInjection
|
||||
from comfy.model_base import BaseModel
|
||||
from comfy.sd import CLIP
|
||||
import comfy.lora
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
from node_helpers import conditioning_set_values
|
||||
|
||||
class EnumHookMode(enum.Enum):
|
||||
MinVram = "minvram"
|
||||
MaxSpeed = "maxspeed"
|
||||
|
||||
class EnumHookType(enum.Enum):
|
||||
Weight = "weight"
|
||||
Patch = "patch"
|
||||
ObjectPatch = "object_patch"
|
||||
AddModels = "add_models"
|
||||
Callbacks = "callbacks"
|
||||
Wrappers = "wrappers"
|
||||
SetInjections = "add_injections"
|
||||
|
||||
class EnumWeightTarget(enum.Enum):
|
||||
Model = "model"
|
||||
Clip = "clip"
|
||||
|
||||
class _HookRef:
|
||||
pass
|
||||
|
||||
# NOTE: this is an example of how the should_register function should look
|
||||
def default_should_register(hook: 'Hook', model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
|
||||
return True
|
||||
|
||||
|
||||
class Hook:
|
||||
def __init__(self, hook_type: EnumHookType=None, hook_ref: _HookRef=None, hook_id: str=None,
|
||||
hook_keyframe: 'HookKeyframeGroup'=None):
|
||||
self.hook_type = hook_type
|
||||
self.hook_ref = hook_ref if hook_ref else _HookRef()
|
||||
self.hook_id = hook_id
|
||||
self.hook_keyframe = hook_keyframe if hook_keyframe else HookKeyframeGroup()
|
||||
self.custom_should_register = default_should_register
|
||||
self.auto_apply_to_nonpositive = False
|
||||
|
||||
@property
|
||||
def strength(self):
|
||||
return self.hook_keyframe.strength
|
||||
|
||||
def initialize_timesteps(self, model: 'BaseModel'):
|
||||
self.reset()
|
||||
self.hook_keyframe.initialize_timesteps(model)
|
||||
|
||||
def reset(self):
|
||||
self.hook_keyframe.reset()
|
||||
|
||||
def clone(self, subtype: Callable=None):
|
||||
if subtype is None:
|
||||
subtype = type(self)
|
||||
c: Hook = subtype()
|
||||
c.hook_type = self.hook_type
|
||||
c.hook_ref = self.hook_ref
|
||||
c.hook_id = self.hook_id
|
||||
c.hook_keyframe = self.hook_keyframe
|
||||
c.custom_should_register = self.custom_should_register
|
||||
# TODO: make this do something
|
||||
c.auto_apply_to_nonpositive = self.auto_apply_to_nonpositive
|
||||
return c
|
||||
|
||||
def should_register(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
|
||||
return self.custom_should_register(self, model, model_options, target, registered)
|
||||
|
||||
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
|
||||
raise NotImplementedError("add_hook_patches should be defined for Hook subclasses")
|
||||
|
||||
def on_apply(self, model: 'ModelPatcher', transformer_options: dict[str]):
|
||||
pass
|
||||
|
||||
def on_unapply(self, model: 'ModelPatcher', transformer_options: dict[str]):
|
||||
pass
|
||||
|
||||
def __eq__(self, other: 'Hook'):
|
||||
return self.__class__ == other.__class__ and self.hook_ref == other.hook_ref
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.hook_ref)
|
||||
|
||||
class WeightHook(Hook):
|
||||
def __init__(self, strength_model=1.0, strength_clip=1.0):
|
||||
super().__init__(hook_type=EnumHookType.Weight)
|
||||
self.weights: dict = None
|
||||
self.weights_clip: dict = None
|
||||
self.need_weight_init = True
|
||||
self._strength_model = strength_model
|
||||
self._strength_clip = strength_clip
|
||||
|
||||
@property
|
||||
def strength_model(self):
|
||||
return self._strength_model * self.strength
|
||||
|
||||
@property
|
||||
def strength_clip(self):
|
||||
return self._strength_clip * self.strength
|
||||
|
||||
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
|
||||
if not self.should_register(model, model_options, target, registered):
|
||||
return False
|
||||
weights = None
|
||||
if target == EnumWeightTarget.Model:
|
||||
strength = self._strength_model
|
||||
else:
|
||||
strength = self._strength_clip
|
||||
|
||||
if self.need_weight_init:
|
||||
key_map = {}
|
||||
if target == EnumWeightTarget.Model:
|
||||
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
|
||||
else:
|
||||
key_map = comfy.lora.model_lora_keys_clip(model.model, key_map)
|
||||
weights = comfy.lora.load_lora(self.weights, key_map, log_missing=False)
|
||||
else:
|
||||
if target == EnumWeightTarget.Model:
|
||||
weights = self.weights
|
||||
else:
|
||||
weights = self.weights_clip
|
||||
k = model.add_hook_patches(hook=self, patches=weights, strength_patch=strength)
|
||||
registered.append(self)
|
||||
return True
|
||||
# TODO: add logs about any keys that were not applied
|
||||
|
||||
def clone(self, subtype: Callable=None):
|
||||
if subtype is None:
|
||||
subtype = type(self)
|
||||
c: WeightHook = super().clone(subtype)
|
||||
c.weights = self.weights
|
||||
c.weights_clip = self.weights_clip
|
||||
c.need_weight_init = self.need_weight_init
|
||||
c._strength_model = self._strength_model
|
||||
c._strength_clip = self._strength_clip
|
||||
return c
|
||||
|
||||
class PatchHook(Hook):
|
||||
def __init__(self):
|
||||
super().__init__(hook_type=EnumHookType.Patch)
|
||||
self.patches: dict = None
|
||||
|
||||
def clone(self, subtype: Callable=None):
|
||||
if subtype is None:
|
||||
subtype = type(self)
|
||||
c: PatchHook = super().clone(subtype)
|
||||
c.patches = self.patches
|
||||
return c
|
||||
# TODO: add functionality
|
||||
|
||||
class ObjectPatchHook(Hook):
|
||||
def __init__(self):
|
||||
super().__init__(hook_type=EnumHookType.ObjectPatch)
|
||||
self.object_patches: dict = None
|
||||
|
||||
def clone(self, subtype: Callable=None):
|
||||
if subtype is None:
|
||||
subtype = type(self)
|
||||
c: ObjectPatchHook = super().clone(subtype)
|
||||
c.object_patches = self.object_patches
|
||||
return c
|
||||
# TODO: add functionality
|
||||
|
||||
class AddModelsHook(Hook):
|
||||
def __init__(self, key: str=None, models: list['ModelPatcher']=None):
|
||||
super().__init__(hook_type=EnumHookType.AddModels)
|
||||
self.key = key
|
||||
self.models = models
|
||||
self.append_when_same = True
|
||||
|
||||
def clone(self, subtype: Callable=None):
|
||||
if subtype is None:
|
||||
subtype = type(self)
|
||||
c: AddModelsHook = super().clone(subtype)
|
||||
c.key = self.key
|
||||
c.models = self.models.copy() if self.models else self.models
|
||||
c.append_when_same = self.append_when_same
|
||||
return c
|
||||
# TODO: add functionality
|
||||
|
||||
class CallbackHook(Hook):
|
||||
def __init__(self, key: str=None, callback: Callable=None):
|
||||
super().__init__(hook_type=EnumHookType.Callbacks)
|
||||
self.key = key
|
||||
self.callback = callback
|
||||
|
||||
def clone(self, subtype: Callable=None):
|
||||
if subtype is None:
|
||||
subtype = type(self)
|
||||
c: CallbackHook = super().clone(subtype)
|
||||
c.key = self.key
|
||||
c.callback = self.callback
|
||||
return c
|
||||
# TODO: add functionality
|
||||
|
||||
class WrapperHook(Hook):
|
||||
def __init__(self, wrappers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None):
|
||||
super().__init__(hook_type=EnumHookType.Wrappers)
|
||||
self.wrappers_dict = wrappers_dict
|
||||
|
||||
def clone(self, subtype: Callable=None):
|
||||
if subtype is None:
|
||||
subtype = type(self)
|
||||
c: WrapperHook = super().clone(subtype)
|
||||
c.wrappers_dict = self.wrappers_dict
|
||||
return c
|
||||
|
||||
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
|
||||
if not self.should_register(model, model_options, target, registered):
|
||||
return False
|
||||
add_model_options = {"transformer_options": self.wrappers_dict}
|
||||
comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
|
||||
registered.append(self)
|
||||
return True
|
||||
|
||||
class SetInjectionsHook(Hook):
|
||||
def __init__(self, key: str=None, injections: list['PatcherInjection']=None):
|
||||
super().__init__(hook_type=EnumHookType.SetInjections)
|
||||
self.key = key
|
||||
self.injections = injections
|
||||
|
||||
def clone(self, subtype: Callable=None):
|
||||
if subtype is None:
|
||||
subtype = type(self)
|
||||
c: SetInjectionsHook = super().clone(subtype)
|
||||
c.key = self.key
|
||||
c.injections = self.injections.copy() if self.injections else self.injections
|
||||
return c
|
||||
|
||||
def add_hook_injections(self, model: 'ModelPatcher'):
|
||||
# TODO: add functionality
|
||||
pass
|
||||
|
||||
class HookGroup:
|
||||
def __init__(self):
|
||||
self.hooks: list[Hook] = []
|
||||
|
||||
def add(self, hook: Hook):
|
||||
if hook not in self.hooks:
|
||||
self.hooks.append(hook)
|
||||
|
||||
def contains(self, hook: Hook):
|
||||
return hook in self.hooks
|
||||
|
||||
def clone(self):
|
||||
c = HookGroup()
|
||||
for hook in self.hooks:
|
||||
c.add(hook.clone())
|
||||
return c
|
||||
|
||||
def clone_and_combine(self, other: 'HookGroup'):
|
||||
c = self.clone()
|
||||
if other is not None:
|
||||
for hook in other.hooks:
|
||||
c.add(hook.clone())
|
||||
return c
|
||||
|
||||
def set_keyframes_on_hooks(self, hook_kf: 'HookKeyframeGroup'):
|
||||
if hook_kf is None:
|
||||
hook_kf = HookKeyframeGroup()
|
||||
else:
|
||||
hook_kf = hook_kf.clone()
|
||||
for hook in self.hooks:
|
||||
hook.hook_keyframe = hook_kf
|
||||
|
||||
def get_dict_repr(self):
|
||||
d: dict[EnumHookType, dict[Hook, None]] = {}
|
||||
for hook in self.hooks:
|
||||
with_type = d.setdefault(hook.hook_type, {})
|
||||
with_type[hook] = None
|
||||
return d
|
||||
|
||||
def get_hooks_for_clip_schedule(self):
|
||||
scheduled_hooks: dict[WeightHook, list[tuple[tuple[float,float], HookKeyframe]]] = {}
|
||||
for hook in self.hooks:
|
||||
# only care about WeightHooks, for now
|
||||
if hook.hook_type == EnumHookType.Weight:
|
||||
hook_schedule = []
|
||||
# if no hook keyframes, assign default value
|
||||
if len(hook.hook_keyframe.keyframes) == 0:
|
||||
hook_schedule.append(((0.0, 1.0), None))
|
||||
scheduled_hooks[hook] = hook_schedule
|
||||
continue
|
||||
# find ranges of values
|
||||
prev_keyframe = hook.hook_keyframe.keyframes[0]
|
||||
for keyframe in hook.hook_keyframe.keyframes:
|
||||
if keyframe.start_percent > prev_keyframe.start_percent and not math.isclose(keyframe.strength, prev_keyframe.strength):
|
||||
hook_schedule.append(((prev_keyframe.start_percent, keyframe.start_percent), prev_keyframe))
|
||||
prev_keyframe = keyframe
|
||||
elif keyframe.start_percent == prev_keyframe.start_percent:
|
||||
prev_keyframe = keyframe
|
||||
# create final range, assuming last start_percent was not 1.0
|
||||
if not math.isclose(prev_keyframe.start_percent, 1.0):
|
||||
hook_schedule.append(((prev_keyframe.start_percent, 1.0), prev_keyframe))
|
||||
scheduled_hooks[hook] = hook_schedule
|
||||
# hooks should not have their schedules in a list of tuples
|
||||
all_ranges: list[tuple[float, float]] = []
|
||||
for range_kfs in scheduled_hooks.values():
|
||||
for t_range, keyframe in range_kfs:
|
||||
all_ranges.append(t_range)
|
||||
# turn list of ranges into boundaries
|
||||
boundaries_set = set(itertools.chain.from_iterable(all_ranges))
|
||||
boundaries_set.add(0.0)
|
||||
boundaries = sorted(boundaries_set)
|
||||
real_ranges = [(boundaries[i], boundaries[i + 1]) for i in range(len(boundaries) - 1)]
|
||||
# with real ranges defined, give appropriate hooks w/ keyframes for each range
|
||||
scheduled_keyframes: list[tuple[tuple[float,float], list[tuple[WeightHook, HookKeyframe]]]] = []
|
||||
for t_range in real_ranges:
|
||||
hooks_schedule = []
|
||||
for hook, val in scheduled_hooks.items():
|
||||
keyframe = None
|
||||
# check if is a keyframe that works for the current t_range
|
||||
for stored_range, stored_kf in val:
|
||||
# if stored start is less than current end, then fits - give it assigned keyframe
|
||||
if stored_range[0] < t_range[1] and stored_range[1] > t_range[0]:
|
||||
keyframe = stored_kf
|
||||
break
|
||||
hooks_schedule.append((hook, keyframe))
|
||||
scheduled_keyframes.append((t_range, hooks_schedule))
|
||||
return scheduled_keyframes
|
||||
|
||||
def reset(self):
|
||||
for hook in self.hooks:
|
||||
hook.reset()
|
||||
|
||||
@staticmethod
|
||||
def combine_all_hooks(hooks_list: list['HookGroup'], require_count=0) -> 'HookGroup':
|
||||
actual: list[HookGroup] = []
|
||||
for group in hooks_list:
|
||||
if group is not None:
|
||||
actual.append(group)
|
||||
if len(actual) < require_count:
|
||||
raise Exception(f"Need at least {require_count} hooks to combine, but only had {len(actual)}.")
|
||||
# if no hooks, then return None
|
||||
if len(actual) == 0:
|
||||
return None
|
||||
# if only 1 hook, just return itself without cloning
|
||||
elif len(actual) == 1:
|
||||
return actual[0]
|
||||
final_hook: HookGroup = None
|
||||
for hook in actual:
|
||||
if final_hook is None:
|
||||
final_hook = hook.clone()
|
||||
else:
|
||||
final_hook = final_hook.clone_and_combine(hook)
|
||||
return final_hook
|
||||
|
||||
|
||||
class HookKeyframe:
|
||||
def __init__(self, strength: float, start_percent=0.0, guarantee_steps=1):
|
||||
self.strength = strength
|
||||
# scheduling
|
||||
self.start_percent = float(start_percent)
|
||||
self.start_t = 999999999.9
|
||||
self.guarantee_steps = guarantee_steps
|
||||
|
||||
def clone(self):
|
||||
c = HookKeyframe(strength=self.strength,
|
||||
start_percent=self.start_percent, guarantee_steps=self.guarantee_steps)
|
||||
c.start_t = self.start_t
|
||||
return c
|
||||
|
||||
class HookKeyframeGroup:
|
||||
def __init__(self):
|
||||
self.keyframes: list[HookKeyframe] = []
|
||||
self._current_keyframe: HookKeyframe = None
|
||||
self._current_used_steps = 0
|
||||
self._current_index = 0
|
||||
self._current_strength = None
|
||||
self._curr_t = -1.
|
||||
|
||||
# properties shadow those of HookWeightsKeyframe
|
||||
@property
|
||||
def strength(self):
|
||||
if self._current_keyframe is not None:
|
||||
return self._current_keyframe.strength
|
||||
return 1.0
|
||||
|
||||
def reset(self):
|
||||
self._current_keyframe = None
|
||||
self._current_used_steps = 0
|
||||
self._current_index = 0
|
||||
self._current_strength = None
|
||||
self.curr_t = -1.
|
||||
self._set_first_as_current()
|
||||
|
||||
def add(self, keyframe: HookKeyframe):
|
||||
# add to end of list, then sort
|
||||
self.keyframes.append(keyframe)
|
||||
self.keyframes = get_sorted_list_via_attr(self.keyframes, "start_percent")
|
||||
self._set_first_as_current()
|
||||
|
||||
def _set_first_as_current(self):
|
||||
if len(self.keyframes) > 0:
|
||||
self._current_keyframe = self.keyframes[0]
|
||||
else:
|
||||
self._current_keyframe = None
|
||||
|
||||
def has_index(self, index: int):
|
||||
return index >= 0 and index < len(self.keyframes)
|
||||
|
||||
def is_empty(self):
|
||||
return len(self.keyframes) == 0
|
||||
|
||||
def clone(self):
|
||||
c = HookKeyframeGroup()
|
||||
for keyframe in self.keyframes:
|
||||
c.keyframes.append(keyframe.clone())
|
||||
c._set_first_as_current()
|
||||
return c
|
||||
|
||||
def initialize_timesteps(self, model: 'BaseModel'):
|
||||
for keyframe in self.keyframes:
|
||||
keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent)
|
||||
|
||||
def prepare_current_keyframe(self, curr_t: float) -> bool:
|
||||
if self.is_empty():
|
||||
return False
|
||||
if curr_t == self._curr_t:
|
||||
return False
|
||||
prev_index = self._current_index
|
||||
prev_strength = self._current_strength
|
||||
# if met guaranteed steps, look for next keyframe in case need to switch
|
||||
if self._current_used_steps >= self._current_keyframe.guarantee_steps:
|
||||
# if has next index, loop through and see if need to switch
|
||||
if self.has_index(self._current_index+1):
|
||||
for i in range(self._current_index+1, len(self.keyframes)):
|
||||
eval_c = self.keyframes[i]
|
||||
# check if start_t is greater or equal to curr_t
|
||||
# NOTE: t is in terms of sigmas, not percent, so bigger number = earlier step in sampling
|
||||
if eval_c.start_t >= curr_t:
|
||||
self._current_index = i
|
||||
self._current_strength = eval_c.strength
|
||||
self._current_keyframe = eval_c
|
||||
self._current_used_steps = 0
|
||||
# if guarantee_steps greater than zero, stop searching for other keyframes
|
||||
if self._current_keyframe.guarantee_steps > 0:
|
||||
break
|
||||
# if eval_c is outside the percent range, stop looking further
|
||||
else: break
|
||||
# update steps current context is used
|
||||
self._current_used_steps += 1
|
||||
# update current timestep this was performed on
|
||||
self._curr_t = curr_t
|
||||
# return True if keyframe changed, False if no change
|
||||
return prev_index != self._current_index and prev_strength != self._current_strength
|
||||
|
||||
|
||||
class InterpolationMethod:
|
||||
LINEAR = "linear"
|
||||
EASE_IN = "ease_in"
|
||||
EASE_OUT = "ease_out"
|
||||
EASE_IN_OUT = "ease_in_out"
|
||||
|
||||
_LIST = [LINEAR, EASE_IN, EASE_OUT, EASE_IN_OUT]
|
||||
|
||||
@classmethod
|
||||
def get_weights(cls, num_from: float, num_to: float, length: int, method: str, reverse=False):
|
||||
diff = num_to - num_from
|
||||
if method == cls.LINEAR:
|
||||
weights = torch.linspace(num_from, num_to, length)
|
||||
elif method == cls.EASE_IN:
|
||||
index = torch.linspace(0, 1, length)
|
||||
weights = diff * np.power(index, 2) + num_from
|
||||
elif method == cls.EASE_OUT:
|
||||
index = torch.linspace(0, 1, length)
|
||||
weights = diff * (1 - np.power(1 - index, 2)) + num_from
|
||||
elif method == cls.EASE_IN_OUT:
|
||||
index = torch.linspace(0, 1, length)
|
||||
weights = diff * ((1 - np.cos(index * np.pi)) / 2) + num_from
|
||||
else:
|
||||
raise ValueError(f"Unrecognized interpolation method '{method}'.")
|
||||
if reverse:
|
||||
weights = weights.flip(dims=(0,))
|
||||
return weights
|
||||
|
||||
def get_sorted_list_via_attr(objects: list, attr: str) -> list:
|
||||
if not objects:
|
||||
return objects
|
||||
elif len(objects) <= 1:
|
||||
return [x for x in objects]
|
||||
# now that we know we have to sort, do it following these rules:
|
||||
# a) if objects have same value of attribute, maintain their relative order
|
||||
# b) perform sorting of the groups of objects with same attributes
|
||||
unique_attrs = {}
|
||||
for o in objects:
|
||||
val_attr = getattr(o, attr)
|
||||
attr_list: list = unique_attrs.get(val_attr, list())
|
||||
attr_list.append(o)
|
||||
if val_attr not in unique_attrs:
|
||||
unique_attrs[val_attr] = attr_list
|
||||
# now that we have the unique attr values grouped together in relative order, sort them by key
|
||||
sorted_attrs = dict(sorted(unique_attrs.items()))
|
||||
# now flatten out the dict into a list to return
|
||||
sorted_list = []
|
||||
for object_list in sorted_attrs.values():
|
||||
sorted_list.extend(object_list)
|
||||
return sorted_list
|
||||
|
||||
def create_hook_lora(lora: dict[str, torch.Tensor], strength_model: float, strength_clip: float):
|
||||
hook_group = HookGroup()
|
||||
hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip)
|
||||
hook_group.add(hook)
|
||||
hook.weights = lora
|
||||
return hook_group
|
||||
|
||||
def create_hook_model_as_lora(weights_model, weights_clip, strength_model: float, strength_clip: float):
|
||||
hook_group = HookGroup()
|
||||
hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip)
|
||||
hook_group.add(hook)
|
||||
patches_model = None
|
||||
patches_clip = None
|
||||
if weights_model is not None:
|
||||
patches_model = {}
|
||||
for key in weights_model:
|
||||
patches_model[key] = ("model_as_lora", (weights_model[key],))
|
||||
if weights_clip is not None:
|
||||
patches_clip = {}
|
||||
for key in weights_clip:
|
||||
patches_clip[key] = ("model_as_lora", (weights_clip[key],))
|
||||
hook.weights = patches_model
|
||||
hook.weights_clip = patches_clip
|
||||
hook.need_weight_init = False
|
||||
return hook_group
|
||||
|
||||
def get_patch_weights_from_model(model: 'ModelPatcher', discard_model_sampling=True):
|
||||
if model is None:
|
||||
return None
|
||||
patches_model: dict[str, torch.Tensor] = model.model.state_dict()
|
||||
if discard_model_sampling:
|
||||
# do not include ANY model_sampling components of the model that should act as a patch
|
||||
for key in list(patches_model.keys()):
|
||||
if key.startswith("model_sampling"):
|
||||
patches_model.pop(key, None)
|
||||
return patches_model
|
||||
|
||||
# NOTE: this function shows how to register weight hooks directly on the ModelPatchers
|
||||
def load_hook_lora_for_models(model: 'ModelPatcher', clip: 'CLIP', lora: dict[str, torch.Tensor],
|
||||
strength_model: float, strength_clip: float):
|
||||
key_map = {}
|
||||
if model is not None:
|
||||
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
|
||||
if clip is not None:
|
||||
key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
|
||||
|
||||
hook_group = HookGroup()
|
||||
hook = WeightHook()
|
||||
hook_group.add(hook)
|
||||
loaded: dict[str] = comfy.lora.load_lora(lora, key_map)
|
||||
if model is not None:
|
||||
new_modelpatcher = model.clone()
|
||||
k = new_modelpatcher.add_hook_patches(hook=hook, patches=loaded, strength_patch=strength_model)
|
||||
else:
|
||||
k = ()
|
||||
new_modelpatcher = None
|
||||
|
||||
if clip is not None:
|
||||
new_clip = clip.clone()
|
||||
k1 = new_clip.patcher.add_hook_patches(hook=hook, patches=loaded, strength_patch=strength_clip)
|
||||
else:
|
||||
k1 = ()
|
||||
new_clip = None
|
||||
k = set(k)
|
||||
k1 = set(k1)
|
||||
for x in loaded:
|
||||
if (x not in k) and (x not in k1):
|
||||
print(f"NOT LOADED {x}")
|
||||
return (new_modelpatcher, new_clip, hook_group)
|
||||
|
||||
def _combine_hooks_from_values(c_dict: dict[str, HookGroup], values: dict[str, HookGroup], cache: dict[tuple[HookGroup, HookGroup], HookGroup]):
|
||||
hooks_key = 'hooks'
|
||||
# if hooks only exist in one dict, do what's needed so that it ends up in c_dict
|
||||
if hooks_key not in values:
|
||||
return
|
||||
if hooks_key not in c_dict:
|
||||
hooks_value = values.get(hooks_key, None)
|
||||
if hooks_value is not None:
|
||||
c_dict[hooks_key] = hooks_value
|
||||
return
|
||||
# otherwise, need to combine with minimum duplication via cache
|
||||
hooks_tuple = (c_dict[hooks_key], values[hooks_key])
|
||||
cached_hooks = cache.get(hooks_tuple, None)
|
||||
if cached_hooks is None:
|
||||
new_hooks = hooks_tuple[0].clone_and_combine(hooks_tuple[1])
|
||||
cache[hooks_tuple] = new_hooks
|
||||
c_dict[hooks_key] = new_hooks
|
||||
else:
|
||||
c_dict[hooks_key] = cache[hooks_tuple]
|
||||
|
||||
def conditioning_set_values_with_hooks(conditioning, values={}, append_hooks=True):
|
||||
c = []
|
||||
hooks_combine_cache: dict[tuple[HookGroup, HookGroup], HookGroup] = {}
|
||||
for t in conditioning:
|
||||
n = [t[0], t[1].copy()]
|
||||
for k in values:
|
||||
if append_hooks and k == 'hooks':
|
||||
_combine_hooks_from_values(n[1], values, hooks_combine_cache)
|
||||
else:
|
||||
n[1][k] = values[k]
|
||||
c.append(n)
|
||||
|
||||
return c
|
||||
|
||||
def set_hooks_for_conditioning(cond, hooks: HookGroup, append_hooks=True):
|
||||
if hooks is None:
|
||||
return cond
|
||||
return conditioning_set_values_with_hooks(cond, {'hooks': hooks}, append_hooks=append_hooks)
|
||||
|
||||
def set_timesteps_for_conditioning(cond, timestep_range: tuple[float,float]):
|
||||
if timestep_range is None:
|
||||
return cond
|
||||
return conditioning_set_values(cond, {"start_percent": timestep_range[0],
|
||||
"end_percent": timestep_range[1]})
|
||||
|
||||
def set_mask_for_conditioning(cond, mask: torch.Tensor, set_cond_area: str, strength: float):
|
||||
if mask is None:
|
||||
return cond
|
||||
set_area_to_bounds = False
|
||||
if set_cond_area != 'default':
|
||||
set_area_to_bounds = True
|
||||
if len(mask.shape) < 3:
|
||||
mask = mask.unsqueeze(0)
|
||||
return conditioning_set_values(cond, {'mask': mask,
|
||||
'set_area_to_bounds': set_area_to_bounds,
|
||||
'mask_strength': strength})
|
||||
|
||||
def combine_conditioning(conds: list):
|
||||
combined_conds = []
|
||||
for cond in conds:
|
||||
combined_conds.extend(cond)
|
||||
return combined_conds
|
||||
|
||||
def combine_with_new_conds(conds: list, new_conds: list):
|
||||
combined_conds = []
|
||||
for c, new_c in zip(conds, new_conds):
|
||||
combined_conds.append(combine_conditioning([c, new_c]))
|
||||
return combined_conds
|
||||
|
||||
def set_conds_props(conds: list, strength: float, set_cond_area: str,
|
||||
mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
|
||||
final_conds = []
|
||||
for c in conds:
|
||||
# first, apply lora_hook to conditioning, if provided
|
||||
c = set_hooks_for_conditioning(c, hooks, append_hooks=append_hooks)
|
||||
# next, apply mask to conditioning
|
||||
c = set_mask_for_conditioning(cond=c, mask=mask, strength=strength, set_cond_area=set_cond_area)
|
||||
# apply timesteps, if present
|
||||
c = set_timesteps_for_conditioning(cond=c, timestep_range=timesteps_range)
|
||||
# finally, apply mask to conditioning and store
|
||||
final_conds.append(c)
|
||||
return final_conds
|
||||
|
||||
def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1.0, set_cond_area: str="default",
|
||||
mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
|
||||
combined_conds = []
|
||||
for c, masked_c in zip(conds, new_conds):
|
||||
# first, apply lora_hook to new conditioning, if provided
|
||||
masked_c = set_hooks_for_conditioning(masked_c, hooks, append_hooks=append_hooks)
|
||||
# next, apply mask to new conditioning, if provided
|
||||
masked_c = set_mask_for_conditioning(cond=masked_c, mask=mask, set_cond_area=set_cond_area, strength=strength)
|
||||
# apply timesteps, if present
|
||||
masked_c = set_timesteps_for_conditioning(cond=masked_c, timestep_range=timesteps_range)
|
||||
# finally, combine with existing conditioning and store
|
||||
combined_conds.append(combine_conditioning([c, masked_c]))
|
||||
return combined_conds
|
||||
|
||||
def set_default_conds_and_combine(conds: list, new_conds: list,
|
||||
hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
|
||||
combined_conds = []
|
||||
for c, new_c in zip(conds, new_conds):
|
||||
# first, apply lora_hook to new conditioning, if provided
|
||||
new_c = set_hooks_for_conditioning(new_c, hooks, append_hooks=append_hooks)
|
||||
# next, add default_cond key to cond so that during sampling, it can be identified
|
||||
new_c = conditioning_set_values(new_c, {'default': True})
|
||||
# apply timesteps, if present
|
||||
new_c = set_timesteps_for_conditioning(cond=new_c, timestep_range=timesteps_range)
|
||||
# finally, combine with existing conditioning and store
|
||||
combined_conds.append(combine_conditioning([c, new_c]))
|
||||
return combined_conds
|
||||
@@ -9,6 +9,7 @@ from tqdm.auto import trange, tqdm
|
||||
from . import utils
|
||||
from . import deis
|
||||
import comfy.model_patcher
|
||||
import comfy.model_sampling
|
||||
|
||||
def append_zero(x):
|
||||
return torch.cat([x, x.new_zeros([1])])
|
||||
@@ -43,6 +44,17 @@ def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
|
||||
return append_zero(sigmas)
|
||||
|
||||
|
||||
def get_sigmas_laplace(n, sigma_min, sigma_max, mu=0., beta=0.5, device='cpu'):
|
||||
"""Constructs the noise schedule proposed by Tiankai et al. (2024). """
|
||||
epsilon = 1e-5 # avoid log(0)
|
||||
x = torch.linspace(0, 1, n, device=device)
|
||||
clamp = lambda x: torch.clamp(x, min=sigma_min, max=sigma_max)
|
||||
lmb = mu - beta * torch.sign(0.5-x) * torch.log(1 - 2 * torch.abs(0.5-x) + epsilon)
|
||||
sigmas = clamp(torch.exp(lmb))
|
||||
return sigmas
|
||||
|
||||
|
||||
|
||||
def to_d(x, sigma, denoised):
|
||||
"""Converts a denoiser output to a Karras ODE derivative."""
|
||||
return (x - denoised) / utils.append_dims(sigma, x.ndim)
|
||||
@@ -152,6 +164,8 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None,
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
if isinstance(model.inner_model.inner_model.model_sampling, comfy.model_sampling.CONST):
|
||||
return sample_euler_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
|
||||
"""Ancestral sampling with Euler method steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
@@ -161,14 +175,42 @@ def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
|
||||
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
d = to_d(x, sigmas[i], denoised)
|
||||
# Euler method
|
||||
dt = sigma_down - sigmas[i]
|
||||
x = x + d * dt
|
||||
if sigmas[i + 1] > 0:
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
||||
|
||||
if sigma_down == 0:
|
||||
x = denoised
|
||||
else:
|
||||
d = to_d(x, sigmas[i], denoised)
|
||||
# Euler method
|
||||
dt = sigma_down - sigmas[i]
|
||||
x = x + d * dt + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_euler_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1., noise_sampler=None):
|
||||
"""Ancestral sampling with Euler method steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
# sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
|
||||
if sigmas[i + 1] == 0:
|
||||
x = denoised
|
||||
else:
|
||||
downstep_ratio = 1 + (sigmas[i + 1] / sigmas[i] - 1) * eta
|
||||
sigma_down = sigmas[i + 1] * downstep_ratio
|
||||
alpha_ip1 = 1 - sigmas[i + 1]
|
||||
alpha_down = 1 - sigma_down
|
||||
renoise_coeff = (sigmas[i + 1]**2 - sigma_down**2 * alpha_ip1**2 / alpha_down**2)**0.5
|
||||
# Euler method
|
||||
sigma_down_i_ratio = sigma_down / sigmas[i]
|
||||
x = sigma_down_i_ratio * x + (1 - sigma_down_i_ratio) * denoised
|
||||
if eta > 0:
|
||||
x = (alpha_ip1 / alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
|
||||
@@ -243,6 +285,9 @@ def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None,
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
if isinstance(model.inner_model.inner_model.model_sampling, comfy.model_sampling.CONST):
|
||||
return sample_dpm_2_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
|
||||
|
||||
"""Ancestral sampling with DPM-Solver second-order steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
@@ -269,6 +314,38 @@ def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpm_2_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
"""Ancestral sampling with DPM-Solver second-order steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta
|
||||
sigma_down = sigmas[i+1] * downstep_ratio
|
||||
alpha_ip1 = 1 - sigmas[i+1]
|
||||
alpha_down = 1 - sigma_down
|
||||
renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5
|
||||
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
d = to_d(x, sigmas[i], denoised)
|
||||
if sigma_down == 0:
|
||||
# Euler method
|
||||
dt = sigma_down - sigmas[i]
|
||||
x = x + d * dt
|
||||
else:
|
||||
# DPM-Solver-2
|
||||
sigma_mid = sigmas[i].log().lerp(sigma_down.log(), 0.5).exp()
|
||||
dt_1 = sigma_mid - sigmas[i]
|
||||
dt_2 = sigma_down - sigmas[i]
|
||||
x_2 = x + d * dt_1
|
||||
denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
|
||||
d_2 = to_d(x_2, sigma_mid, denoised_2)
|
||||
x = x + d_2 * dt_2
|
||||
x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
|
||||
return x
|
||||
|
||||
def linear_multistep_coeff(order, t, i, j):
|
||||
if order - 1 > i:
|
||||
@@ -509,6 +586,9 @@ def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callbac
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
if isinstance(model.inner_model.inner_model.model_sampling, comfy.model_sampling.CONST):
|
||||
return sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
|
||||
|
||||
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
@@ -541,6 +621,55 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None,
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
sigma_fn = lambda lbda: (lbda.exp() + 1) ** -1
|
||||
lambda_fn = lambda sigma: ((1-sigma)/sigma).log()
|
||||
|
||||
# logged_x = x.unsqueeze(0)
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta
|
||||
sigma_down = sigmas[i+1] * downstep_ratio
|
||||
alpha_ip1 = 1 - sigmas[i+1]
|
||||
alpha_down = 1 - sigma_down
|
||||
renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5
|
||||
# sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
if sigmas[i + 1] == 0:
|
||||
# Euler method
|
||||
d = to_d(x, sigmas[i], denoised)
|
||||
dt = sigma_down - sigmas[i]
|
||||
x = x + d * dt
|
||||
else:
|
||||
# DPM-Solver++(2S)
|
||||
if sigmas[i] == 1.0:
|
||||
sigma_s = 0.9999
|
||||
else:
|
||||
t_i, t_down = lambda_fn(sigmas[i]), lambda_fn(sigma_down)
|
||||
r = 1 / 2
|
||||
h = t_down - t_i
|
||||
s = t_i + r * h
|
||||
sigma_s = sigma_fn(s)
|
||||
# sigma_s = sigmas[i+1]
|
||||
sigma_s_i_ratio = sigma_s / sigmas[i]
|
||||
u = sigma_s_i_ratio * x + (1 - sigma_s_i_ratio) * denoised
|
||||
D_i = model(u, sigma_s * s_in, **extra_args)
|
||||
sigma_down_i_ratio = sigma_down / sigmas[i]
|
||||
x = sigma_down_i_ratio * x + (1 - sigma_down_i_ratio) * D_i
|
||||
# print("sigma_i", sigmas[i], "sigma_ip1", sigmas[i+1],"sigma_down", sigma_down, "sigma_down_i_ratio", sigma_down_i_ratio, "sigma_s_i_ratio", sigma_s_i_ratio, "renoise_coeff", renoise_coeff)
|
||||
# Noise addition
|
||||
if sigmas[i + 1] > 0 and eta > 0:
|
||||
x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
|
||||
# logged_x = torch.cat((logged_x, x.unsqueeze(0)), dim=0)
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
|
||||
"""DPM-Solver++ (stochastic)."""
|
||||
@@ -1016,7 +1145,6 @@ def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disabl
|
||||
d = to_d(x, sigma_hat, temp[0])
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
|
||||
dt = sigmas[i + 1] - sigma_hat
|
||||
# Euler method
|
||||
x = denoised + d * sigmas[i + 1]
|
||||
return x
|
||||
@@ -1043,8 +1171,81 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
d = to_d(x, sigmas[i], temp[0])
|
||||
# Euler method
|
||||
dt = sigma_down - sigmas[i]
|
||||
x = denoised + d * sigma_down
|
||||
if sigmas[i + 1] > 0:
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
||||
return x
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
|
||||
temp = [0]
|
||||
def post_cfg_function(args):
|
||||
temp[0] = args["uncond_denoised"]
|
||||
return args["denoised"]
|
||||
|
||||
model_options = extra_args.get("model_options", {}).copy()
|
||||
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
|
||||
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
sigma_fn = lambda t: t.neg().exp()
|
||||
t_fn = lambda sigma: sigma.log().neg()
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
if sigma_down == 0:
|
||||
# Euler method
|
||||
d = to_d(x, sigmas[i], temp[0])
|
||||
x = denoised + d * sigma_down
|
||||
else:
|
||||
# DPM-Solver++(2S)
|
||||
t, t_next = t_fn(sigmas[i]), t_fn(sigma_down)
|
||||
# r = torch.sinh(1 + (2 - eta) * (t_next - t) / (t - t_fn(sigma_up))) works only on non-cfgpp, weird
|
||||
r = 1 / 2
|
||||
h = t_next - t
|
||||
s = t + r * h
|
||||
x_2 = (sigma_fn(s) / sigma_fn(t)) * (x + (denoised - temp[0])) - (-h * r).expm1() * denoised
|
||||
denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
|
||||
x = (sigma_fn(t_next) / sigma_fn(t)) * (x + (denoised - temp[0])) - (-h).expm1() * denoised_2
|
||||
# Noise addition
|
||||
if sigmas[i + 1] > 0:
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_2m_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
|
||||
"""DPM-Solver++(2M)."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
t_fn = lambda sigma: sigma.log().neg()
|
||||
|
||||
old_uncond_denoised = None
|
||||
uncond_denoised = None
|
||||
def post_cfg_function(args):
|
||||
nonlocal uncond_denoised
|
||||
uncond_denoised = args["uncond_denoised"]
|
||||
return args["denoised"]
|
||||
|
||||
model_options = extra_args.get("model_options", {}).copy()
|
||||
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
|
||||
h = t_next - t
|
||||
if old_uncond_denoised is None or sigmas[i + 1] == 0:
|
||||
denoised_mix = -torch.exp(-h) * uncond_denoised
|
||||
else:
|
||||
h_last = t - t_fn(sigmas[i - 1])
|
||||
r = h_last / h
|
||||
denoised_mix = -torch.exp(-h) * uncond_denoised - torch.expm1(-h) * (1 / (2 * r)) * (denoised - old_uncond_denoised)
|
||||
x = denoised + denoised_mix + torch.exp(-h) * x
|
||||
old_uncond_denoised = uncond_denoised
|
||||
return x
|
||||
|
||||
@@ -4,6 +4,7 @@ class LatentFormat:
|
||||
scale_factor = 1.0
|
||||
latent_channels = 4
|
||||
latent_rgb_factors = None
|
||||
latent_rgb_factors_bias = None
|
||||
taesd_decoder_name = None
|
||||
|
||||
def process_in(self, latent):
|
||||
@@ -30,11 +31,13 @@ class SDXL(LatentFormat):
|
||||
def __init__(self):
|
||||
self.latent_rgb_factors = [
|
||||
# R G B
|
||||
[ 0.3920, 0.4054, 0.4549],
|
||||
[-0.2634, -0.0196, 0.0653],
|
||||
[ 0.0568, 0.1687, -0.0755],
|
||||
[-0.3112, -0.2359, -0.2076]
|
||||
[ 0.3651, 0.4232, 0.4341],
|
||||
[-0.2533, -0.0042, 0.1068],
|
||||
[ 0.1076, 0.1111, -0.0362],
|
||||
[-0.3165, -0.2492, -0.2188]
|
||||
]
|
||||
self.latent_rgb_factors_bias = [ 0.1084, -0.0175, -0.0011]
|
||||
|
||||
self.taesd_decoder_name = "taesdxl_decoder"
|
||||
|
||||
class SDXL_Playground_2_5(LatentFormat):
|
||||
@@ -112,23 +115,24 @@ class SD3(LatentFormat):
|
||||
self.scale_factor = 1.5305
|
||||
self.shift_factor = 0.0609
|
||||
self.latent_rgb_factors = [
|
||||
[-0.0645, 0.0177, 0.1052],
|
||||
[ 0.0028, 0.0312, 0.0650],
|
||||
[ 0.1848, 0.0762, 0.0360],
|
||||
[ 0.0944, 0.0360, 0.0889],
|
||||
[ 0.0897, 0.0506, -0.0364],
|
||||
[-0.0020, 0.1203, 0.0284],
|
||||
[ 0.0855, 0.0118, 0.0283],
|
||||
[-0.0539, 0.0658, 0.1047],
|
||||
[-0.0057, 0.0116, 0.0700],
|
||||
[-0.0412, 0.0281, -0.0039],
|
||||
[ 0.1106, 0.1171, 0.1220],
|
||||
[-0.0248, 0.0682, -0.0481],
|
||||
[ 0.0815, 0.0846, 0.1207],
|
||||
[-0.0120, -0.0055, -0.0867],
|
||||
[-0.0749, -0.0634, -0.0456],
|
||||
[-0.1418, -0.1457, -0.1259]
|
||||
[-0.0922, -0.0175, 0.0749],
|
||||
[ 0.0311, 0.0633, 0.0954],
|
||||
[ 0.1994, 0.0927, 0.0458],
|
||||
[ 0.0856, 0.0339, 0.0902],
|
||||
[ 0.0587, 0.0272, -0.0496],
|
||||
[-0.0006, 0.1104, 0.0309],
|
||||
[ 0.0978, 0.0306, 0.0427],
|
||||
[-0.0042, 0.1038, 0.1358],
|
||||
[-0.0194, 0.0020, 0.0669],
|
||||
[-0.0488, 0.0130, -0.0268],
|
||||
[ 0.0922, 0.0988, 0.0951],
|
||||
[-0.0278, 0.0524, -0.0542],
|
||||
[ 0.0332, 0.0456, 0.0895],
|
||||
[-0.0069, -0.0030, -0.0810],
|
||||
[-0.0596, -0.0465, -0.0293],
|
||||
[-0.1448, -0.1463, -0.1189]
|
||||
]
|
||||
self.latent_rgb_factors_bias = [0.2394, 0.2135, 0.1925]
|
||||
self.taesd_decoder_name = "taesd3_decoder"
|
||||
|
||||
def process_in(self, latent):
|
||||
@@ -141,30 +145,210 @@ class StableAudio1(LatentFormat):
|
||||
latent_channels = 64
|
||||
|
||||
class Flux(SD3):
|
||||
latent_channels = 16
|
||||
def __init__(self):
|
||||
self.scale_factor = 0.3611
|
||||
self.shift_factor = 0.1159
|
||||
self.latent_rgb_factors =[
|
||||
[-0.0404, 0.0159, 0.0609],
|
||||
[ 0.0043, 0.0298, 0.0850],
|
||||
[ 0.0328, -0.0749, -0.0503],
|
||||
[-0.0245, 0.0085, 0.0549],
|
||||
[ 0.0966, 0.0894, 0.0530],
|
||||
[ 0.0035, 0.0399, 0.0123],
|
||||
[ 0.0583, 0.1184, 0.1262],
|
||||
[-0.0191, -0.0206, -0.0306],
|
||||
[-0.0324, 0.0055, 0.1001],
|
||||
[ 0.0955, 0.0659, -0.0545],
|
||||
[-0.0504, 0.0231, -0.0013],
|
||||
[ 0.0500, -0.0008, -0.0088],
|
||||
[ 0.0982, 0.0941, 0.0976],
|
||||
[-0.1233, -0.0280, -0.0897],
|
||||
[-0.0005, -0.0530, -0.0020],
|
||||
[-0.1273, -0.0932, -0.0680]
|
||||
[-0.0346, 0.0244, 0.0681],
|
||||
[ 0.0034, 0.0210, 0.0687],
|
||||
[ 0.0275, -0.0668, -0.0433],
|
||||
[-0.0174, 0.0160, 0.0617],
|
||||
[ 0.0859, 0.0721, 0.0329],
|
||||
[ 0.0004, 0.0383, 0.0115],
|
||||
[ 0.0405, 0.0861, 0.0915],
|
||||
[-0.0236, -0.0185, -0.0259],
|
||||
[-0.0245, 0.0250, 0.1180],
|
||||
[ 0.1008, 0.0755, -0.0421],
|
||||
[-0.0515, 0.0201, 0.0011],
|
||||
[ 0.0428, -0.0012, -0.0036],
|
||||
[ 0.0817, 0.0765, 0.0749],
|
||||
[-0.1264, -0.0522, -0.1103],
|
||||
[-0.0280, -0.0881, -0.0499],
|
||||
[-0.1262, -0.0982, -0.0778]
|
||||
]
|
||||
self.latent_rgb_factors_bias = [-0.0329, -0.0718, -0.0851]
|
||||
self.taesd_decoder_name = "taef1_decoder"
|
||||
|
||||
def process_in(self, latent):
|
||||
return (latent - self.shift_factor) * self.scale_factor
|
||||
|
||||
def process_out(self, latent):
|
||||
return (latent / self.scale_factor) + self.shift_factor
|
||||
|
||||
class Mochi(LatentFormat):
|
||||
latent_channels = 12
|
||||
|
||||
def __init__(self):
|
||||
self.scale_factor = 1.0
|
||||
self.latents_mean = torch.tensor([-0.06730895953510081, -0.038011381506090416, -0.07477820912866141,
|
||||
-0.05565264470995561, 0.012767231469026969, -0.04703542746246419,
|
||||
0.043896967884726704, -0.09346305707025976, -0.09918314763016893,
|
||||
-0.008729793427399178, -0.011931556316503654, -0.0321993391887285]).view(1, self.latent_channels, 1, 1, 1)
|
||||
self.latents_std = torch.tensor([0.9263795028493863, 0.9248894543193766, 0.9393059390890617,
|
||||
0.959253732819592, 0.8244560132752793, 0.917259975397747,
|
||||
0.9294154431013696, 1.3720942357788521, 0.881393668867029,
|
||||
0.9168315692124348, 0.9185249279345552, 0.9274757570805041]).view(1, self.latent_channels, 1, 1, 1)
|
||||
|
||||
self.latent_rgb_factors =[
|
||||
[-0.0069, -0.0045, 0.0018],
|
||||
[ 0.0154, -0.0692, -0.0274],
|
||||
[ 0.0333, 0.0019, 0.0206],
|
||||
[-0.1390, 0.0628, 0.1678],
|
||||
[-0.0725, 0.0134, -0.1898],
|
||||
[ 0.0074, -0.0270, -0.0209],
|
||||
[-0.0176, -0.0277, -0.0221],
|
||||
[ 0.5294, 0.5204, 0.3852],
|
||||
[-0.0326, -0.0446, -0.0143],
|
||||
[-0.0659, 0.0153, -0.0153],
|
||||
[ 0.0185, -0.0217, 0.0014],
|
||||
[-0.0396, -0.0495, -0.0281]
|
||||
]
|
||||
self.latent_rgb_factors_bias = [-0.0940, -0.1418, -0.1453]
|
||||
self.taesd_decoder_name = None #TODO
|
||||
|
||||
def process_in(self, latent):
|
||||
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
|
||||
latents_std = self.latents_std.to(latent.device, latent.dtype)
|
||||
return (latent - latents_mean) * self.scale_factor / latents_std
|
||||
|
||||
def process_out(self, latent):
|
||||
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
|
||||
latents_std = self.latents_std.to(latent.device, latent.dtype)
|
||||
return latent * latents_std / self.scale_factor + latents_mean
|
||||
|
||||
class LTXV(LatentFormat):
|
||||
latent_channels = 128
|
||||
def __init__(self):
|
||||
self.latent_rgb_factors = [
|
||||
[ 1.1202e-02, -6.3815e-04, -1.0021e-02],
|
||||
[ 8.6031e-02, 6.5813e-02, 9.5409e-04],
|
||||
[-1.2576e-02, -7.5734e-03, -4.0528e-03],
|
||||
[ 9.4063e-03, -2.1688e-03, 2.6093e-03],
|
||||
[ 3.7636e-03, 1.2765e-02, 9.1548e-03],
|
||||
[ 2.1024e-02, -5.2973e-03, 3.4373e-03],
|
||||
[-8.8896e-03, -1.9703e-02, -1.8761e-02],
|
||||
[-1.3160e-02, -1.0523e-02, 1.9709e-03],
|
||||
[-1.5152e-03, -6.9891e-03, -7.5810e-03],
|
||||
[-1.7247e-03, 4.6560e-04, -3.3839e-03],
|
||||
[ 1.3617e-02, 4.7077e-03, -2.0045e-03],
|
||||
[ 1.0256e-02, 7.7318e-03, 1.3948e-02],
|
||||
[-1.6108e-02, -6.2151e-03, 1.1561e-03],
|
||||
[ 7.3407e-03, 1.5628e-02, 4.4865e-04],
|
||||
[ 9.5357e-04, -2.9518e-03, -1.4760e-02],
|
||||
[ 1.9143e-02, 1.0868e-02, 1.2264e-02],
|
||||
[ 4.4575e-03, 3.6682e-05, -6.8508e-03],
|
||||
[-4.5681e-04, 3.2570e-03, 7.7929e-03],
|
||||
[ 3.3902e-02, 3.3405e-02, 3.7454e-02],
|
||||
[-2.3001e-02, -2.4877e-03, -3.1033e-03],
|
||||
[ 5.0265e-02, 3.8841e-02, 3.3539e-02],
|
||||
[-4.1018e-03, -1.1095e-03, 1.5859e-03],
|
||||
[-1.2689e-01, -1.3107e-01, -2.1005e-01],
|
||||
[ 2.6276e-02, 1.4189e-02, -3.5963e-03],
|
||||
[-4.8679e-03, 8.8486e-03, 7.8029e-03],
|
||||
[-1.6610e-03, -4.8597e-03, -5.2060e-03],
|
||||
[-2.1010e-03, 2.3610e-03, 9.3796e-03],
|
||||
[-2.2482e-02, -2.1305e-02, -1.5087e-02],
|
||||
[-1.5753e-02, -1.0646e-02, -6.5083e-03],
|
||||
[-4.6975e-03, 5.0288e-03, -6.7390e-03],
|
||||
[ 1.1951e-02, 2.0712e-02, 1.6191e-02],
|
||||
[-6.3704e-03, -8.4827e-03, -9.5483e-03],
|
||||
[ 7.2610e-03, -9.9326e-03, -2.2978e-02],
|
||||
[-9.1904e-04, 6.2882e-03, 9.5720e-03],
|
||||
[-3.7178e-02, -3.7123e-02, -5.6713e-02],
|
||||
[-1.3373e-01, -1.0720e-01, -5.3801e-02],
|
||||
[-5.3702e-03, 8.1256e-03, 8.8397e-03],
|
||||
[-1.5247e-01, -2.1437e-01, -2.1843e-01],
|
||||
[ 3.1441e-02, 7.0335e-03, -9.7541e-03],
|
||||
[ 2.1528e-03, -8.9817e-03, -2.1023e-02],
|
||||
[ 3.8461e-03, -5.8957e-03, -1.5014e-02],
|
||||
[-4.3470e-03, -1.2940e-02, -1.5972e-02],
|
||||
[-5.4781e-03, -1.0842e-02, -3.0204e-03],
|
||||
[-6.5347e-03, 3.0806e-03, -1.0163e-02],
|
||||
[-5.0414e-03, -7.1503e-03, -8.9686e-04],
|
||||
[-8.5851e-03, -2.4351e-03, 1.0674e-03],
|
||||
[-9.0016e-03, -9.6493e-03, 1.5692e-03],
|
||||
[ 5.0914e-03, 1.2099e-02, 1.9968e-02],
|
||||
[ 1.3758e-02, 1.1669e-02, 8.1958e-03],
|
||||
[-1.0518e-02, -1.1575e-02, -4.1307e-03],
|
||||
[-2.8410e-02, -3.1266e-02, -2.2149e-02],
|
||||
[ 2.9336e-03, 3.6511e-02, 1.8717e-02],
|
||||
[-1.6703e-02, -1.6696e-02, -4.4529e-03],
|
||||
[ 4.8818e-02, 4.0063e-02, 8.7410e-03],
|
||||
[-1.5066e-02, -5.7328e-04, 2.9785e-03],
|
||||
[-1.7613e-02, -8.1034e-03, 1.3086e-02],
|
||||
[-9.2633e-03, 1.0803e-02, -6.3489e-03],
|
||||
[ 3.0851e-03, 4.7750e-04, 1.2347e-02],
|
||||
[-2.2785e-02, -2.3043e-02, -2.6005e-02],
|
||||
[-2.4787e-02, -1.5389e-02, -2.2104e-02],
|
||||
[-2.3572e-02, 1.0544e-03, 1.2361e-02],
|
||||
[-7.8915e-03, -1.2271e-03, -6.0968e-03],
|
||||
[-1.1478e-02, -1.2543e-03, 6.2679e-03],
|
||||
[-5.4229e-02, 2.6644e-02, 6.3394e-03],
|
||||
[ 4.4216e-03, -7.3338e-03, -1.0464e-02],
|
||||
[-4.5013e-03, 1.6082e-03, 1.4420e-02],
|
||||
[ 1.3673e-02, 8.8877e-03, 4.1253e-03],
|
||||
[-1.0145e-02, 9.0072e-03, 1.5695e-02],
|
||||
[-5.6234e-03, 1.1847e-03, 8.1261e-03],
|
||||
[-3.7171e-03, -5.3538e-03, 1.2590e-03],
|
||||
[ 2.9476e-02, 2.1424e-02, 3.0424e-02],
|
||||
[-3.4925e-02, -2.4340e-02, -2.5316e-02],
|
||||
[-3.4127e-02, -2.2406e-02, -1.0589e-02],
|
||||
[-1.7342e-02, -1.3249e-02, -1.0719e-02],
|
||||
[-2.1478e-03, -8.6051e-03, -2.9878e-03],
|
||||
[ 1.2089e-03, -4.2391e-03, -6.8569e-03],
|
||||
[ 9.0411e-04, -6.6886e-03, -6.7547e-05],
|
||||
[ 1.6048e-02, -1.0057e-02, -2.8929e-02],
|
||||
[ 1.2290e-03, 1.0163e-02, 1.8861e-02],
|
||||
[ 1.7264e-02, 2.7257e-04, 1.3785e-02],
|
||||
[-1.3482e-02, -3.6427e-03, 6.7481e-04],
|
||||
[ 4.6782e-03, -5.2423e-03, 2.4467e-03],
|
||||
[-5.9113e-03, -6.2244e-03, -1.8162e-03],
|
||||
[ 1.5496e-02, 1.4582e-02, 1.9514e-03],
|
||||
[ 7.4958e-03, 1.5886e-03, -8.2305e-03],
|
||||
[ 1.9086e-02, 1.6360e-03, -3.9674e-03],
|
||||
[-5.7021e-03, -2.7307e-03, -4.1066e-03],
|
||||
[ 1.7450e-03, 1.4602e-02, 2.5794e-02],
|
||||
[-8.2788e-04, 2.2902e-03, 4.5161e-03],
|
||||
[ 1.1632e-02, 8.9193e-03, -7.2813e-03],
|
||||
[ 7.5721e-03, 2.6784e-03, 1.1393e-02],
|
||||
[ 5.1939e-03, 3.6903e-03, 1.4049e-02],
|
||||
[-1.8383e-02, -2.2529e-02, -2.4477e-02],
|
||||
[ 5.8842e-04, -5.7874e-03, -1.4770e-02],
|
||||
[-1.6125e-02, -8.6101e-03, -1.4533e-02],
|
||||
[ 2.0540e-02, 2.0729e-02, 6.4338e-03],
|
||||
[ 3.3587e-03, -1.1226e-02, -1.6444e-02],
|
||||
[-1.4742e-03, -1.0489e-02, 1.7097e-03],
|
||||
[ 2.8130e-02, 2.3546e-02, 3.2791e-02],
|
||||
[-1.8532e-02, -1.2842e-02, -8.7756e-03],
|
||||
[-8.0533e-03, -1.0771e-02, -1.7536e-02],
|
||||
[-3.9009e-03, 1.6150e-02, 3.3359e-02],
|
||||
[-7.4554e-03, -1.4154e-02, -6.1910e-03],
|
||||
[ 3.4734e-03, -1.1370e-02, -1.0581e-02],
|
||||
[ 1.1476e-02, 3.9281e-03, 2.8231e-03],
|
||||
[ 7.1639e-03, -1.4741e-03, -3.8066e-03],
|
||||
[ 2.2250e-03, -8.7552e-03, -9.5719e-03],
|
||||
[ 2.4146e-02, 2.1696e-02, 2.8056e-02],
|
||||
[-5.4365e-03, -2.4291e-02, -1.7802e-02],
|
||||
[ 7.4263e-03, 1.0510e-02, 1.2705e-02],
|
||||
[ 6.2669e-03, 6.2658e-03, 1.9211e-02],
|
||||
[ 1.6378e-02, 9.4933e-03, 6.6971e-03],
|
||||
[ 1.7173e-02, 2.3601e-02, 2.3296e-02],
|
||||
[-1.4568e-02, -9.8279e-03, -1.1556e-02],
|
||||
[ 1.4431e-02, 1.4430e-02, 6.6362e-03],
|
||||
[-6.8230e-03, 1.8863e-02, 1.4555e-02],
|
||||
[ 6.1156e-03, 3.4700e-03, -2.6662e-03],
|
||||
[-2.6983e-03, -5.9402e-03, -9.2276e-03],
|
||||
[ 1.0235e-02, 7.4173e-03, -7.6243e-03],
|
||||
[-1.3255e-02, 1.9322e-02, -9.2153e-04],
|
||||
[ 2.4222e-03, -4.8039e-03, -1.5759e-02],
|
||||
[ 2.6244e-02, 2.5951e-02, 2.0249e-02],
|
||||
[ 1.5711e-02, 1.8498e-02, 2.7407e-03],
|
||||
[-2.1714e-03, 4.7214e-03, -2.2443e-02],
|
||||
[-7.4747e-03, 7.4166e-03, 1.4430e-02],
|
||||
[-8.3906e-03, -7.9776e-03, 9.7927e-03],
|
||||
[ 3.8321e-02, 9.6622e-03, -1.9268e-02],
|
||||
[-1.4605e-02, -6.7032e-03, 3.9675e-03]
|
||||
]
|
||||
|
||||
self.latent_rgb_factors_bias = [-0.0571, -0.1657, -0.2512]
|
||||
|
||||
@@ -612,7 +612,9 @@ class ContinuousTransformer(nn.Module):
|
||||
return_info = False,
|
||||
**kwargs
|
||||
):
|
||||
patches_replace = kwargs.get("transformer_options", {}).get("patches_replace", {})
|
||||
batch, seq, device = *x.shape[:2], x.device
|
||||
context = kwargs["context"]
|
||||
|
||||
info = {
|
||||
"hidden_states": [],
|
||||
@@ -643,9 +645,19 @@ class ContinuousTransformer(nn.Module):
|
||||
if self.use_sinusoidal_emb or self.use_abs_pos_emb:
|
||||
x = x + self.pos_emb(x)
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
# Iterate over the transformer layers
|
||||
for layer in self.layers:
|
||||
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
|
||||
for i, layer in enumerate(self.layers):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context)
|
||||
# x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
|
||||
|
||||
if return_info:
|
||||
@@ -874,7 +886,6 @@ class AudioDiffusionTransformer(nn.Module):
|
||||
mask=None,
|
||||
return_info=False,
|
||||
control=None,
|
||||
transformer_options={},
|
||||
**kwargs):
|
||||
return self._forward(
|
||||
x,
|
||||
|
||||
@@ -437,7 +437,8 @@ class MMDiT(nn.Module):
|
||||
pos_encoding = pos_encoding[:,from_h:from_h+h,from_w:from_w+w]
|
||||
return x + pos_encoding.reshape(1, -1, self.positional_encoding.shape[-1])
|
||||
|
||||
def forward(self, x, timestep, context, **kwargs):
|
||||
def forward(self, x, timestep, context, transformer_options={}, **kwargs):
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
# patchify x, add PE
|
||||
b, c, h, w = x.shape
|
||||
|
||||
@@ -458,15 +459,36 @@ class MMDiT(nn.Module):
|
||||
|
||||
global_cond = self.t_embedder(t, x.dtype) # B, D
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
if len(self.double_layers) > 0:
|
||||
for layer in self.double_layers:
|
||||
c, x = layer(c, x, global_cond, **kwargs)
|
||||
for i, layer in enumerate(self.double_layers):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["txt"], out["img"] = layer(args["txt"],
|
||||
args["img"],
|
||||
args["vec"])
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond}, {"original_block": block_wrap})
|
||||
c = out["txt"]
|
||||
x = out["img"]
|
||||
else:
|
||||
c, x = layer(c, x, global_cond, **kwargs)
|
||||
|
||||
if len(self.single_layers) > 0:
|
||||
c_len = c.size(1)
|
||||
cx = torch.cat([c, x], dim=1)
|
||||
for layer in self.single_layers:
|
||||
cx = layer(cx, global_cond, **kwargs)
|
||||
for i, layer in enumerate(self.single_layers):
|
||||
if ("single_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = layer(args["img"], args["vec"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond}, {"original_block": block_wrap})
|
||||
cx = out["img"]
|
||||
else:
|
||||
cx = layer(cx, global_cond, **kwargs)
|
||||
|
||||
x = cx[:, c_len:]
|
||||
|
||||
|
||||
@@ -1,8 +1,27 @@
|
||||
import torch
|
||||
import comfy.ops
|
||||
|
||||
def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
|
||||
if padding_mode == "circular" and torch.jit.is_tracing() or torch.jit.is_scripting():
|
||||
if padding_mode == "circular" and (torch.jit.is_tracing() or torch.jit.is_scripting()):
|
||||
padding_mode = "reflect"
|
||||
pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0]
|
||||
pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1]
|
||||
return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode)
|
||||
|
||||
try:
|
||||
rms_norm_torch = torch.nn.functional.rms_norm
|
||||
except:
|
||||
rms_norm_torch = None
|
||||
|
||||
def rms_norm(x, weight=None, eps=1e-6):
|
||||
if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
|
||||
if weight is None:
|
||||
return rms_norm_torch(x, (x.shape[-1],), eps=eps)
|
||||
else:
|
||||
return rms_norm_torch(x, weight.shape, weight=comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
|
||||
else:
|
||||
r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
|
||||
if weight is None:
|
||||
return r
|
||||
else:
|
||||
return r * comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device)
|
||||
|
||||
205
comfy/ldm/flux/controlnet.py
Normal file
205
comfy/ldm/flux/controlnet.py
Normal file
@@ -0,0 +1,205 @@
|
||||
#Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
|
||||
#modified to support different types of flux controlnets
|
||||
|
||||
import torch
|
||||
import math
|
||||
from torch import Tensor, nn
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
|
||||
MLPEmbedder, SingleStreamBlock,
|
||||
timestep_embedding)
|
||||
|
||||
from .model import Flux
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
class MistolineCondDownsamplBlock(nn.Module):
|
||||
def __init__(self, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.encoder = nn.Sequential(
|
||||
operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 1, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 1, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.encoder(x)
|
||||
|
||||
class MistolineControlnetBlock(nn.Module):
|
||||
def __init__(self, hidden_size, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.linear = operations.Linear(hidden_size, hidden_size, dtype=dtype, device=device)
|
||||
self.act = nn.SiLU()
|
||||
|
||||
def forward(self, x):
|
||||
return self.act(self.linear(x))
|
||||
|
||||
|
||||
class ControlNetFlux(Flux):
|
||||
def __init__(self, latent_input=False, num_union_modes=0, mistoline=False, control_latent_channels=None, image_model=None, dtype=None, device=None, operations=None, **kwargs):
|
||||
super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
|
||||
|
||||
self.main_model_double = 19
|
||||
self.main_model_single = 38
|
||||
|
||||
self.mistoline = mistoline
|
||||
# add ControlNet blocks
|
||||
if self.mistoline:
|
||||
control_block = lambda : MistolineControlnetBlock(self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
else:
|
||||
control_block = lambda : operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
|
||||
|
||||
self.controlnet_blocks = nn.ModuleList([])
|
||||
for _ in range(self.params.depth):
|
||||
self.controlnet_blocks.append(control_block())
|
||||
|
||||
self.controlnet_single_blocks = nn.ModuleList([])
|
||||
for _ in range(self.params.depth_single_blocks):
|
||||
self.controlnet_single_blocks.append(control_block())
|
||||
|
||||
self.num_union_modes = num_union_modes
|
||||
self.controlnet_mode_embedder = None
|
||||
if self.num_union_modes > 0:
|
||||
self.controlnet_mode_embedder = operations.Embedding(self.num_union_modes, self.hidden_size, dtype=dtype, device=device)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
self.latent_input = latent_input
|
||||
if control_latent_channels is None:
|
||||
control_latent_channels = self.in_channels
|
||||
else:
|
||||
control_latent_channels *= 2 * 2 #patch size
|
||||
|
||||
self.pos_embed_input = operations.Linear(control_latent_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
|
||||
if not self.latent_input:
|
||||
if self.mistoline:
|
||||
self.input_cond_block = MistolineCondDownsamplBlock(dtype=dtype, device=device, operations=operations)
|
||||
else:
|
||||
self.input_hint_block = nn.Sequential(
|
||||
operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
def forward_orig(
|
||||
self,
|
||||
img: Tensor,
|
||||
img_ids: Tensor,
|
||||
controlnet_cond: Tensor,
|
||||
txt: Tensor,
|
||||
txt_ids: Tensor,
|
||||
timesteps: Tensor,
|
||||
y: Tensor,
|
||||
guidance: Tensor = None,
|
||||
control_type: Tensor = None,
|
||||
) -> Tensor:
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
|
||||
# running on sequences img
|
||||
img = self.img_in(img)
|
||||
|
||||
controlnet_cond = self.pos_embed_input(controlnet_cond)
|
||||
img = img + controlnet_cond
|
||||
vec = self.time_in(timestep_embedding(timesteps, 256))
|
||||
if self.params.guidance_embed:
|
||||
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
|
||||
vec = vec + self.vector_in(y)
|
||||
txt = self.txt_in(txt)
|
||||
|
||||
if self.controlnet_mode_embedder is not None and len(control_type) > 0:
|
||||
control_cond = self.controlnet_mode_embedder(torch.tensor(control_type, device=img.device), out_dtype=img.dtype).unsqueeze(0).repeat((txt.shape[0], 1, 1))
|
||||
txt = torch.cat([control_cond, txt], dim=1)
|
||||
txt_ids = torch.cat([txt_ids[:,:1], txt_ids], dim=1)
|
||||
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
|
||||
controlnet_double = ()
|
||||
|
||||
for i in range(len(self.double_blocks)):
|
||||
img, txt = self.double_blocks[i](img=img, txt=txt, vec=vec, pe=pe)
|
||||
controlnet_double = controlnet_double + (self.controlnet_blocks[i](img),)
|
||||
|
||||
img = torch.cat((txt, img), 1)
|
||||
|
||||
controlnet_single = ()
|
||||
|
||||
for i in range(len(self.single_blocks)):
|
||||
img = self.single_blocks[i](img, vec=vec, pe=pe)
|
||||
controlnet_single = controlnet_single + (self.controlnet_single_blocks[i](img[:, txt.shape[1] :, ...]),)
|
||||
|
||||
repeat = math.ceil(self.main_model_double / len(controlnet_double))
|
||||
if self.latent_input:
|
||||
out_input = ()
|
||||
for x in controlnet_double:
|
||||
out_input += (x,) * repeat
|
||||
else:
|
||||
out_input = (controlnet_double * repeat)
|
||||
|
||||
out = {"input": out_input[:self.main_model_double]}
|
||||
if len(controlnet_single) > 0:
|
||||
repeat = math.ceil(self.main_model_single / len(controlnet_single))
|
||||
out_output = ()
|
||||
if self.latent_input:
|
||||
for x in controlnet_single:
|
||||
out_output += (x,) * repeat
|
||||
else:
|
||||
out_output = (controlnet_single * repeat)
|
||||
out["output"] = out_output[:self.main_model_single]
|
||||
return out
|
||||
|
||||
def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
|
||||
patch_size = 2
|
||||
if self.latent_input:
|
||||
hint = comfy.ldm.common_dit.pad_to_patch_size(hint, (patch_size, patch_size))
|
||||
elif self.mistoline:
|
||||
hint = hint * 2.0 - 1.0
|
||||
hint = self.input_cond_block(hint)
|
||||
else:
|
||||
hint = hint * 2.0 - 1.0
|
||||
hint = self.input_hint_block(hint)
|
||||
|
||||
hint = rearrange(hint, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
||||
|
||||
bs, c, h, w = x.shape
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
||||
|
||||
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
||||
|
||||
h_len = ((h + (patch_size // 2)) // patch_size)
|
||||
w_len = ((w + (patch_size // 2)) // patch_size)
|
||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||
img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
|
||||
img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance, control_type=kwargs.get("control_type", []))
|
||||
@@ -6,6 +6,7 @@ from torch import Tensor, nn
|
||||
|
||||
from .math import attention, rope
|
||||
import comfy.ops
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
|
||||
class EmbedND(nn.Module):
|
||||
@@ -63,10 +64,7 @@ class RMSNorm(torch.nn.Module):
|
||||
self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
|
||||
|
||||
def forward(self, x: Tensor):
|
||||
x_dtype = x.dtype
|
||||
x = x.float()
|
||||
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
|
||||
return (x * rrms).to(dtype=x_dtype) * comfy.ops.cast_to(self.scale, dtype=x_dtype, device=x.device)
|
||||
return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6)
|
||||
|
||||
|
||||
class QKNorm(torch.nn.Module):
|
||||
@@ -170,15 +168,15 @@ class DoubleStreamBlock(nn.Module):
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
||||
|
||||
# calculate the img bloks
|
||||
img += img_mod1.gate * self.img_attn.proj(img_attn)
|
||||
img += img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
|
||||
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
|
||||
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
|
||||
|
||||
# calculate the txt bloks
|
||||
txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
|
||||
txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
|
||||
|
||||
if txt.dtype == torch.float16:
|
||||
txt = txt.clip(-65504, 65504)
|
||||
txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
|
||||
|
||||
return img, txt
|
||||
|
||||
@@ -233,7 +231,7 @@ class SingleStreamBlock(nn.Module):
|
||||
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
||||
x += mod.gate * output
|
||||
if x.dtype == torch.float16:
|
||||
x = x.clip(-65504, 65504)
|
||||
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ import comfy.ldm.common_dit
|
||||
@dataclass
|
||||
class FluxParams:
|
||||
in_channels: int
|
||||
out_channels: int
|
||||
vec_in_dim: int
|
||||
context_in_dim: int
|
||||
hidden_size: int
|
||||
@@ -29,6 +30,7 @@ class FluxParams:
|
||||
depth_single_blocks: int
|
||||
axes_dim: list
|
||||
theta: int
|
||||
patch_size: int
|
||||
qkv_bias: bool
|
||||
guidance_embed: bool
|
||||
|
||||
@@ -38,13 +40,14 @@ class Flux(nn.Module):
|
||||
Transformer model for flow matching on sequences.
|
||||
"""
|
||||
|
||||
def __init__(self, image_model=None, dtype=None, device=None, operations=None, **kwargs):
|
||||
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
params = FluxParams(**kwargs)
|
||||
self.params = params
|
||||
self.in_channels = params.in_channels * 2 * 2
|
||||
self.out_channels = self.in_channels
|
||||
self.patch_size = params.patch_size
|
||||
self.in_channels = params.in_channels * params.patch_size * params.patch_size
|
||||
self.out_channels = params.out_channels * params.patch_size * params.patch_size
|
||||
if params.hidden_size % params.num_heads != 0:
|
||||
raise ValueError(
|
||||
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
|
||||
@@ -83,7 +86,8 @@ class Flux(nn.Module):
|
||||
]
|
||||
)
|
||||
|
||||
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||
if final_layer:
|
||||
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward_orig(
|
||||
self,
|
||||
@@ -94,7 +98,10 @@ class Flux(nn.Module):
|
||||
timesteps: Tensor,
|
||||
y: Tensor,
|
||||
guidance: Tensor = None,
|
||||
control=None,
|
||||
transformer_options={},
|
||||
) -> Tensor:
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
|
||||
@@ -106,26 +113,62 @@ class Flux(nn.Module):
|
||||
raise ValueError("Didn't get guidance strength for guidance distilled model.")
|
||||
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
|
||||
|
||||
vec = vec + self.vector_in(y)
|
||||
vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
|
||||
txt = self.txt_in(txt)
|
||||
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
|
||||
for block in self.double_blocks:
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
for i, block in enumerate(self.double_blocks):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe}, {"original_block": block_wrap})
|
||||
txt = out["txt"]
|
||||
img = out["img"]
|
||||
else:
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_i = control.get("input")
|
||||
if i < len(control_i):
|
||||
add = control_i[i]
|
||||
if add is not None:
|
||||
img += add
|
||||
|
||||
img = torch.cat((txt, img), 1)
|
||||
for block in self.single_blocks:
|
||||
img = block(img, vec=vec, pe=pe)
|
||||
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
if ("single_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe}, {"original_block": block_wrap})
|
||||
img = out["img"]
|
||||
else:
|
||||
img = block(img, vec=vec, pe=pe)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_o = control.get("output")
|
||||
if i < len(control_o):
|
||||
add = control_o[i]
|
||||
if add is not None:
|
||||
img[:, txt.shape[1] :, ...] += add
|
||||
|
||||
img = img[:, txt.shape[1] :, ...]
|
||||
|
||||
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
||||
return img
|
||||
|
||||
def forward(self, x, timestep, context, y, guidance, **kwargs):
|
||||
def forward(self, x, timestep, context, y, guidance, control=None, transformer_options={}, **kwargs):
|
||||
bs, c, h, w = x.shape
|
||||
patch_size = 2
|
||||
patch_size = self.patch_size
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
||||
|
||||
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
||||
@@ -133,10 +176,10 @@ class Flux(nn.Module):
|
||||
h_len = ((h + (patch_size // 2)) // patch_size)
|
||||
w_len = ((w + (patch_size // 2)) // patch_size)
|
||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||
img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
|
||||
img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
|
||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance)
|
||||
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options)
|
||||
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
|
||||
|
||||
25
comfy/ldm/flux/redux.py
Normal file
25
comfy/ldm/flux/redux.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import torch
|
||||
import comfy.ops
|
||||
|
||||
ops = comfy.ops.manual_cast
|
||||
|
||||
class ReduxImageEncoder(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
redux_dim: int = 1152,
|
||||
txt_in_features: int = 4096,
|
||||
device=None,
|
||||
dtype=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.redux_dim = redux_dim
|
||||
self.device = device
|
||||
self.dtype = dtype
|
||||
|
||||
self.redux_up = ops.Linear(redux_dim, txt_in_features * 3, dtype=dtype)
|
||||
self.redux_down = ops.Linear(txt_in_features * 3, txt_in_features, dtype=dtype)
|
||||
|
||||
def forward(self, sigclip_embeds) -> torch.Tensor:
|
||||
projected_x = self.redux_down(torch.nn.functional.silu(self.redux_up(sigclip_embeds)))
|
||||
return projected_x
|
||||
559
comfy/ldm/genmo/joint_model/asymm_models_joint.py
Normal file
559
comfy/ldm/genmo/joint_model/asymm_models_joint.py
Normal file
@@ -0,0 +1,559 @@
|
||||
#original code from https://github.com/genmoai/models under apache 2.0 license
|
||||
#adapted to ComfyUI
|
||||
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
# from flash_attn import flash_attn_varlen_qkvpacked_func
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
from .layers import (
|
||||
FeedForward,
|
||||
PatchEmbed,
|
||||
RMSNorm,
|
||||
TimestepEmbedder,
|
||||
)
|
||||
|
||||
from .rope_mixed import (
|
||||
compute_mixed_rotation,
|
||||
create_position_matrix,
|
||||
)
|
||||
from .temporal_rope import apply_rotary_emb_qk_real
|
||||
from .utils import (
|
||||
AttentionPool,
|
||||
modulate,
|
||||
)
|
||||
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.ops
|
||||
|
||||
|
||||
def modulated_rmsnorm(x, scale, eps=1e-6):
|
||||
# Normalize and modulate
|
||||
x_normed = comfy.ldm.common_dit.rms_norm(x, eps=eps)
|
||||
x_modulated = x_normed * (1 + scale.unsqueeze(1))
|
||||
|
||||
return x_modulated
|
||||
|
||||
|
||||
def residual_tanh_gated_rmsnorm(x, x_res, gate, eps=1e-6):
|
||||
# Apply tanh to gate
|
||||
tanh_gate = torch.tanh(gate).unsqueeze(1)
|
||||
|
||||
# Normalize and apply gated scaling
|
||||
x_normed = comfy.ldm.common_dit.rms_norm(x_res, eps=eps) * tanh_gate
|
||||
|
||||
# Apply residual connection
|
||||
output = x + x_normed
|
||||
|
||||
return output
|
||||
|
||||
class AsymmetricAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim_x: int,
|
||||
dim_y: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = True,
|
||||
qk_norm: bool = False,
|
||||
attn_drop: float = 0.0,
|
||||
update_y: bool = True,
|
||||
out_bias: bool = True,
|
||||
attend_to_padding: bool = False,
|
||||
softmax_scale: Optional[float] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim_x = dim_x
|
||||
self.dim_y = dim_y
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim_x // num_heads
|
||||
self.attn_drop = attn_drop
|
||||
self.update_y = update_y
|
||||
self.attend_to_padding = attend_to_padding
|
||||
self.softmax_scale = softmax_scale
|
||||
if dim_x % num_heads != 0:
|
||||
raise ValueError(
|
||||
f"dim_x={dim_x} should be divisible by num_heads={num_heads}"
|
||||
)
|
||||
|
||||
# Input layers.
|
||||
self.qkv_bias = qkv_bias
|
||||
self.qkv_x = operations.Linear(dim_x, 3 * dim_x, bias=qkv_bias, device=device, dtype=dtype)
|
||||
# Project text features to match visual features (dim_y -> dim_x)
|
||||
self.qkv_y = operations.Linear(dim_y, 3 * dim_x, bias=qkv_bias, device=device, dtype=dtype)
|
||||
|
||||
# Query and key normalization for stability.
|
||||
assert qk_norm
|
||||
self.q_norm_x = RMSNorm(self.head_dim, device=device, dtype=dtype)
|
||||
self.k_norm_x = RMSNorm(self.head_dim, device=device, dtype=dtype)
|
||||
self.q_norm_y = RMSNorm(self.head_dim, device=device, dtype=dtype)
|
||||
self.k_norm_y = RMSNorm(self.head_dim, device=device, dtype=dtype)
|
||||
|
||||
# Output layers. y features go back down from dim_x -> dim_y.
|
||||
self.proj_x = operations.Linear(dim_x, dim_x, bias=out_bias, device=device, dtype=dtype)
|
||||
self.proj_y = (
|
||||
operations.Linear(dim_x, dim_y, bias=out_bias, device=device, dtype=dtype)
|
||||
if update_y
|
||||
else nn.Identity()
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor, # (B, N, dim_x)
|
||||
y: torch.Tensor, # (B, L, dim_y)
|
||||
scale_x: torch.Tensor, # (B, dim_x), modulation for pre-RMSNorm.
|
||||
scale_y: torch.Tensor, # (B, dim_y), modulation for pre-RMSNorm.
|
||||
crop_y,
|
||||
**rope_rotation,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
rope_cos = rope_rotation.get("rope_cos")
|
||||
rope_sin = rope_rotation.get("rope_sin")
|
||||
# Pre-norm for visual features
|
||||
x = modulated_rmsnorm(x, scale_x) # (B, M, dim_x) where M = N / cp_group_size
|
||||
|
||||
# Process visual features
|
||||
# qkv_x = self.qkv_x(x) # (B, M, 3 * dim_x)
|
||||
# assert qkv_x.dtype == torch.bfloat16
|
||||
# qkv_x = all_to_all_collect_tokens(
|
||||
# qkv_x, self.num_heads
|
||||
# ) # (3, B, N, local_h, head_dim)
|
||||
|
||||
# Process text features
|
||||
y = modulated_rmsnorm(y, scale_y) # (B, L, dim_y)
|
||||
q_y, k_y, v_y = self.qkv_y(y).view(y.shape[0], y.shape[1], 3, self.num_heads, -1).unbind(2) # (B, N, local_h, head_dim)
|
||||
|
||||
q_y = self.q_norm_y(q_y)
|
||||
k_y = self.k_norm_y(k_y)
|
||||
|
||||
# Split qkv_x into q, k, v
|
||||
q_x, k_x, v_x = self.qkv_x(x).view(x.shape[0], x.shape[1], 3, self.num_heads, -1).unbind(2) # (B, N, local_h, head_dim)
|
||||
q_x = self.q_norm_x(q_x)
|
||||
q_x = apply_rotary_emb_qk_real(q_x, rope_cos, rope_sin)
|
||||
k_x = self.k_norm_x(k_x)
|
||||
k_x = apply_rotary_emb_qk_real(k_x, rope_cos, rope_sin)
|
||||
|
||||
q = torch.cat([q_x, q_y[:, :crop_y]], dim=1).transpose(1, 2)
|
||||
k = torch.cat([k_x, k_y[:, :crop_y]], dim=1).transpose(1, 2)
|
||||
v = torch.cat([v_x, v_y[:, :crop_y]], dim=1).transpose(1, 2)
|
||||
|
||||
xy = optimized_attention(q,
|
||||
k,
|
||||
v, self.num_heads, skip_reshape=True)
|
||||
|
||||
x, y = torch.tensor_split(xy, (q_x.shape[1],), dim=1)
|
||||
x = self.proj_x(x)
|
||||
o = torch.zeros(y.shape[0], q_y.shape[1], y.shape[-1], device=y.device, dtype=y.dtype)
|
||||
o[:, :y.shape[1]] = y
|
||||
|
||||
y = self.proj_y(o)
|
||||
# print("ox", x)
|
||||
# print("oy", y)
|
||||
return x, y
|
||||
|
||||
|
||||
class AsymmetricJointBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size_x: int,
|
||||
hidden_size_y: int,
|
||||
num_heads: int,
|
||||
*,
|
||||
mlp_ratio_x: float = 8.0, # Ratio of hidden size to d_model for MLP for visual tokens.
|
||||
mlp_ratio_y: float = 4.0, # Ratio of hidden size to d_model for MLP for text tokens.
|
||||
update_y: bool = True, # Whether to update text tokens in this block.
|
||||
device: Optional[torch.device] = None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
**block_kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.update_y = update_y
|
||||
self.hidden_size_x = hidden_size_x
|
||||
self.hidden_size_y = hidden_size_y
|
||||
self.mod_x = operations.Linear(hidden_size_x, 4 * hidden_size_x, device=device, dtype=dtype)
|
||||
if self.update_y:
|
||||
self.mod_y = operations.Linear(hidden_size_x, 4 * hidden_size_y, device=device, dtype=dtype)
|
||||
else:
|
||||
self.mod_y = operations.Linear(hidden_size_x, hidden_size_y, device=device, dtype=dtype)
|
||||
|
||||
# Self-attention:
|
||||
self.attn = AsymmetricAttention(
|
||||
hidden_size_x,
|
||||
hidden_size_y,
|
||||
num_heads=num_heads,
|
||||
update_y=update_y,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
operations=operations,
|
||||
**block_kwargs,
|
||||
)
|
||||
|
||||
# MLP.
|
||||
mlp_hidden_dim_x = int(hidden_size_x * mlp_ratio_x)
|
||||
assert mlp_hidden_dim_x == int(1536 * 8)
|
||||
self.mlp_x = FeedForward(
|
||||
in_features=hidden_size_x,
|
||||
hidden_size=mlp_hidden_dim_x,
|
||||
multiple_of=256,
|
||||
ffn_dim_multiplier=None,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
# MLP for text not needed in last block.
|
||||
if self.update_y:
|
||||
mlp_hidden_dim_y = int(hidden_size_y * mlp_ratio_y)
|
||||
self.mlp_y = FeedForward(
|
||||
in_features=hidden_size_y,
|
||||
hidden_size=mlp_hidden_dim_y,
|
||||
multiple_of=256,
|
||||
ffn_dim_multiplier=None,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
c: torch.Tensor,
|
||||
y: torch.Tensor,
|
||||
**attn_kwargs,
|
||||
):
|
||||
"""Forward pass of a block.
|
||||
|
||||
Args:
|
||||
x: (B, N, dim) tensor of visual tokens
|
||||
c: (B, dim) tensor of conditioned features
|
||||
y: (B, L, dim) tensor of text tokens
|
||||
num_frames: Number of frames in the video. N = num_frames * num_spatial_tokens
|
||||
|
||||
Returns:
|
||||
x: (B, N, dim) tensor of visual tokens after block
|
||||
y: (B, L, dim) tensor of text tokens after block
|
||||
"""
|
||||
N = x.size(1)
|
||||
|
||||
c = F.silu(c)
|
||||
mod_x = self.mod_x(c)
|
||||
scale_msa_x, gate_msa_x, scale_mlp_x, gate_mlp_x = mod_x.chunk(4, dim=1)
|
||||
|
||||
mod_y = self.mod_y(c)
|
||||
if self.update_y:
|
||||
scale_msa_y, gate_msa_y, scale_mlp_y, gate_mlp_y = mod_y.chunk(4, dim=1)
|
||||
else:
|
||||
scale_msa_y = mod_y
|
||||
|
||||
# Self-attention block.
|
||||
x_attn, y_attn = self.attn(
|
||||
x,
|
||||
y,
|
||||
scale_x=scale_msa_x,
|
||||
scale_y=scale_msa_y,
|
||||
**attn_kwargs,
|
||||
)
|
||||
|
||||
assert x_attn.size(1) == N
|
||||
x = residual_tanh_gated_rmsnorm(x, x_attn, gate_msa_x)
|
||||
if self.update_y:
|
||||
y = residual_tanh_gated_rmsnorm(y, y_attn, gate_msa_y)
|
||||
|
||||
# MLP block.
|
||||
x = self.ff_block_x(x, scale_mlp_x, gate_mlp_x)
|
||||
if self.update_y:
|
||||
y = self.ff_block_y(y, scale_mlp_y, gate_mlp_y)
|
||||
|
||||
return x, y
|
||||
|
||||
def ff_block_x(self, x, scale_x, gate_x):
|
||||
x_mod = modulated_rmsnorm(x, scale_x)
|
||||
x_res = self.mlp_x(x_mod)
|
||||
x = residual_tanh_gated_rmsnorm(x, x_res, gate_x) # Sandwich norm
|
||||
return x
|
||||
|
||||
def ff_block_y(self, y, scale_y, gate_y):
|
||||
y_mod = modulated_rmsnorm(y, scale_y)
|
||||
y_res = self.mlp_y(y_mod)
|
||||
y = residual_tanh_gated_rmsnorm(y, y_res, gate_y) # Sandwich norm
|
||||
return y
|
||||
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
"""
|
||||
The final layer of DiT.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
patch_size,
|
||||
out_channels,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.norm_final = operations.LayerNorm(
|
||||
hidden_size, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype
|
||||
)
|
||||
self.mod = operations.Linear(hidden_size, 2 * hidden_size, device=device, dtype=dtype)
|
||||
self.linear = operations.Linear(
|
||||
hidden_size, patch_size * patch_size * out_channels, device=device, dtype=dtype
|
||||
)
|
||||
|
||||
def forward(self, x, c):
|
||||
c = F.silu(c)
|
||||
shift, scale = self.mod(c).chunk(2, dim=1)
|
||||
x = modulate(self.norm_final(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
class AsymmDiTJoint(nn.Module):
|
||||
"""
|
||||
Diffusion model with a Transformer backbone.
|
||||
|
||||
Ingests text embeddings instead of a label.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
patch_size=2,
|
||||
in_channels=4,
|
||||
hidden_size_x=1152,
|
||||
hidden_size_y=1152,
|
||||
depth=48,
|
||||
num_heads=16,
|
||||
mlp_ratio_x=8.0,
|
||||
mlp_ratio_y=4.0,
|
||||
use_t5: bool = False,
|
||||
t5_feat_dim: int = 4096,
|
||||
t5_token_length: int = 256,
|
||||
learn_sigma=True,
|
||||
patch_embed_bias: bool = True,
|
||||
timestep_mlp_bias: bool = True,
|
||||
attend_to_padding: bool = False,
|
||||
timestep_scale: Optional[float] = None,
|
||||
use_extended_posenc: bool = False,
|
||||
posenc_preserve_area: bool = False,
|
||||
rope_theta: float = 10000.0,
|
||||
image_model=None,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
**block_kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.dtype = dtype
|
||||
self.learn_sigma = learn_sigma
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels * 2 if learn_sigma else in_channels
|
||||
self.patch_size = patch_size
|
||||
self.num_heads = num_heads
|
||||
self.hidden_size_x = hidden_size_x
|
||||
self.hidden_size_y = hidden_size_y
|
||||
self.head_dim = (
|
||||
hidden_size_x // num_heads
|
||||
) # Head dimension and count is determined by visual.
|
||||
self.attend_to_padding = attend_to_padding
|
||||
self.use_extended_posenc = use_extended_posenc
|
||||
self.posenc_preserve_area = posenc_preserve_area
|
||||
self.use_t5 = use_t5
|
||||
self.t5_token_length = t5_token_length
|
||||
self.t5_feat_dim = t5_feat_dim
|
||||
self.rope_theta = (
|
||||
rope_theta # Scaling factor for frequency computation for temporal RoPE.
|
||||
)
|
||||
|
||||
self.x_embedder = PatchEmbed(
|
||||
patch_size=patch_size,
|
||||
in_chans=in_channels,
|
||||
embed_dim=hidden_size_x,
|
||||
bias=patch_embed_bias,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
)
|
||||
# Conditionings
|
||||
# Timestep
|
||||
self.t_embedder = TimestepEmbedder(
|
||||
hidden_size_x, bias=timestep_mlp_bias, timestep_scale=timestep_scale, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
if self.use_t5:
|
||||
# Caption Pooling (T5)
|
||||
self.t5_y_embedder = AttentionPool(
|
||||
t5_feat_dim, num_heads=8, output_dim=hidden_size_x, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
# Dense Embedding Projection (T5)
|
||||
self.t5_yproj = operations.Linear(
|
||||
t5_feat_dim, hidden_size_y, bias=True, dtype=dtype, device=device
|
||||
)
|
||||
|
||||
# Initialize pos_frequencies as an empty parameter.
|
||||
self.pos_frequencies = nn.Parameter(
|
||||
torch.empty(3, self.num_heads, self.head_dim // 2, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
assert not self.attend_to_padding
|
||||
|
||||
# for depth 48:
|
||||
# b = 0: AsymmetricJointBlock, update_y=True
|
||||
# b = 1: AsymmetricJointBlock, update_y=True
|
||||
# ...
|
||||
# b = 46: AsymmetricJointBlock, update_y=True
|
||||
# b = 47: AsymmetricJointBlock, update_y=False. No need to update text features.
|
||||
blocks = []
|
||||
for b in range(depth):
|
||||
# Joint multi-modal block
|
||||
update_y = b < depth - 1
|
||||
block = AsymmetricJointBlock(
|
||||
hidden_size_x,
|
||||
hidden_size_y,
|
||||
num_heads,
|
||||
mlp_ratio_x=mlp_ratio_x,
|
||||
mlp_ratio_y=mlp_ratio_y,
|
||||
update_y=update_y,
|
||||
attend_to_padding=attend_to_padding,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
operations=operations,
|
||||
**block_kwargs,
|
||||
)
|
||||
|
||||
blocks.append(block)
|
||||
self.blocks = nn.ModuleList(blocks)
|
||||
|
||||
self.final_layer = FinalLayer(
|
||||
hidden_size_x, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
def embed_x(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x: (B, C=12, T, H, W) tensor of visual tokens
|
||||
|
||||
Returns:
|
||||
x: (B, C=3072, N) tensor of visual tokens with positional embedding.
|
||||
"""
|
||||
return self.x_embedder(x) # Convert BcTHW to BCN
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
sigma: torch.Tensor,
|
||||
t5_feat: torch.Tensor,
|
||||
t5_mask: torch.Tensor,
|
||||
):
|
||||
"""Prepare input and conditioning embeddings."""
|
||||
# Visual patch embeddings with positional encoding.
|
||||
T, H, W = x.shape[-3:]
|
||||
pH, pW = H // self.patch_size, W // self.patch_size
|
||||
x = self.embed_x(x) # (B, N, D), where N = T * H * W / patch_size ** 2
|
||||
assert x.ndim == 3
|
||||
B = x.size(0)
|
||||
|
||||
|
||||
pH, pW = H // self.patch_size, W // self.patch_size
|
||||
N = T * pH * pW
|
||||
assert x.size(1) == N
|
||||
pos = create_position_matrix(
|
||||
T, pH=pH, pW=pW, device=x.device, dtype=torch.float32
|
||||
) # (N, 3)
|
||||
rope_cos, rope_sin = compute_mixed_rotation(
|
||||
freqs=comfy.ops.cast_to(self.pos_frequencies, dtype=x.dtype, device=x.device), pos=pos
|
||||
) # Each are (N, num_heads, dim // 2)
|
||||
|
||||
c_t = self.t_embedder(1 - sigma, out_dtype=x.dtype) # (B, D)
|
||||
|
||||
t5_y_pool = self.t5_y_embedder(t5_feat, t5_mask) # (B, D)
|
||||
|
||||
c = c_t + t5_y_pool
|
||||
|
||||
y_feat = self.t5_yproj(t5_feat) # (B, L, t5_feat_dim) --> (B, L, D)
|
||||
|
||||
return x, c, y_feat, rope_cos, rope_sin
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
context: List[torch.Tensor],
|
||||
attention_mask: List[torch.Tensor],
|
||||
num_tokens=256,
|
||||
packed_indices: Dict[str, torch.Tensor] = None,
|
||||
rope_cos: torch.Tensor = None,
|
||||
rope_sin: torch.Tensor = None,
|
||||
control=None, transformer_options={}, **kwargs
|
||||
):
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
y_feat = context
|
||||
y_mask = attention_mask
|
||||
sigma = timestep
|
||||
"""Forward pass of DiT.
|
||||
|
||||
Args:
|
||||
x: (B, C, T, H, W) tensor of spatial inputs (images or latent representations of images)
|
||||
sigma: (B,) tensor of noise standard deviations
|
||||
y_feat: List((B, L, y_feat_dim) tensor of caption token features. For SDXL text encoders: L=77, y_feat_dim=2048)
|
||||
y_mask: List((B, L) boolean tensor indicating which tokens are not padding)
|
||||
packed_indices: Dict with keys for Flash Attention. Result of compute_packed_indices.
|
||||
"""
|
||||
B, _, T, H, W = x.shape
|
||||
|
||||
x, c, y_feat, rope_cos, rope_sin = self.prepare(
|
||||
x, sigma, y_feat, y_mask
|
||||
)
|
||||
del y_mask
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
for i, block in enumerate(self.blocks):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"], out["txt"] = block(
|
||||
args["img"],
|
||||
args["vec"],
|
||||
args["txt"],
|
||||
rope_cos=args["rope_cos"],
|
||||
rope_sin=args["rope_sin"],
|
||||
crop_y=args["num_tokens"]
|
||||
)
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens}, {"original_block": block_wrap})
|
||||
y_feat = out["txt"]
|
||||
x = out["img"]
|
||||
else:
|
||||
x, y_feat = block(
|
||||
x,
|
||||
c,
|
||||
y_feat,
|
||||
rope_cos=rope_cos,
|
||||
rope_sin=rope_sin,
|
||||
crop_y=num_tokens,
|
||||
) # (B, M, D), (B, L, D)
|
||||
del y_feat # Final layers don't use dense text features.
|
||||
|
||||
x = self.final_layer(x, c) # (B, M, patch_size ** 2 * out_channels)
|
||||
x = rearrange(
|
||||
x,
|
||||
"B (T hp wp) (p1 p2 c) -> B c T (hp p1) (wp p2)",
|
||||
T=T,
|
||||
hp=H // self.patch_size,
|
||||
wp=W // self.patch_size,
|
||||
p1=self.patch_size,
|
||||
p2=self.patch_size,
|
||||
c=self.out_channels,
|
||||
)
|
||||
|
||||
return -x
|
||||
164
comfy/ldm/genmo/joint_model/layers.py
Normal file
164
comfy/ldm/genmo/joint_model/layers.py
Normal file
@@ -0,0 +1,164 @@
|
||||
#original code from https://github.com/genmoai/models under apache 2.0 license
|
||||
#adapted to ComfyUI
|
||||
|
||||
import collections.abc
|
||||
import math
|
||||
from itertools import repeat
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
|
||||
# From PyTorch internals
|
||||
def _ntuple(n):
|
||||
def parse(x):
|
||||
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
|
||||
return tuple(x)
|
||||
return tuple(repeat(x, n))
|
||||
|
||||
return parse
|
||||
|
||||
|
||||
to_2tuple = _ntuple(2)
|
||||
|
||||
|
||||
class TimestepEmbedder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
frequency_embedding_size: int = 256,
|
||||
*,
|
||||
bias: bool = True,
|
||||
timestep_scale: Optional[float] = None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
operations.Linear(frequency_embedding_size, hidden_size, bias=bias, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Linear(hidden_size, hidden_size, bias=bias, dtype=dtype, device=device),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
self.timestep_scale = timestep_scale
|
||||
|
||||
@staticmethod
|
||||
def timestep_embedding(t, dim, max_period=10000):
|
||||
half = dim // 2
|
||||
freqs = torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
|
||||
freqs.mul_(-math.log(max_period) / half).exp_()
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat(
|
||||
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
|
||||
)
|
||||
return embedding
|
||||
|
||||
def forward(self, t, out_dtype):
|
||||
if self.timestep_scale is not None:
|
||||
t = t * self.timestep_scale
|
||||
t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype=out_dtype)
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
hidden_size: int,
|
||||
multiple_of: int,
|
||||
ffn_dim_multiplier: Optional[float],
|
||||
device: Optional[torch.device] = None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
# keep parameter count and computation constant compared to standard FFN
|
||||
hidden_size = int(2 * hidden_size / 3)
|
||||
# custom dim factor multiplier
|
||||
if ffn_dim_multiplier is not None:
|
||||
hidden_size = int(ffn_dim_multiplier * hidden_size)
|
||||
hidden_size = multiple_of * ((hidden_size + multiple_of - 1) // multiple_of)
|
||||
|
||||
self.hidden_dim = hidden_size
|
||||
self.w1 = operations.Linear(in_features, 2 * hidden_size, bias=False, device=device, dtype=dtype)
|
||||
self.w2 = operations.Linear(hidden_size, in_features, bias=False, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x, gate = self.w1(x).chunk(2, dim=-1)
|
||||
x = self.w2(F.silu(x) * gate)
|
||||
return x
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 16,
|
||||
in_chans: int = 3,
|
||||
embed_dim: int = 768,
|
||||
norm_layer: Optional[Callable] = None,
|
||||
flatten: bool = True,
|
||||
bias: bool = True,
|
||||
dynamic_img_pad: bool = False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = to_2tuple(patch_size)
|
||||
self.flatten = flatten
|
||||
self.dynamic_img_pad = dynamic_img_pad
|
||||
|
||||
self.proj = operations.Conv2d(
|
||||
in_chans,
|
||||
embed_dim,
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
bias=bias,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
assert norm_layer is None
|
||||
self.norm = (
|
||||
norm_layer(embed_dim, device=device) if norm_layer else nn.Identity()
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
B, _C, T, H, W = x.shape
|
||||
if not self.dynamic_img_pad:
|
||||
assert H % self.patch_size[0] == 0, f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})."
|
||||
assert W % self.patch_size[1] == 0, f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
|
||||
else:
|
||||
pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
|
||||
pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
|
||||
x = F.pad(x, (0, pad_w, 0, pad_h))
|
||||
|
||||
x = rearrange(x, "B C T H W -> (B T) C H W", B=B, T=T)
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size, padding_mode='circular')
|
||||
x = self.proj(x)
|
||||
|
||||
# Flatten temporal and spatial dimensions.
|
||||
if not self.flatten:
|
||||
raise NotImplementedError("Must flatten output.")
|
||||
x = rearrange(x, "(B T) C H W -> B (T H W) C", B=B, T=T)
|
||||
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = torch.nn.Parameter(torch.empty(hidden_size, device=device, dtype=dtype))
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
def forward(self, x):
|
||||
return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)
|
||||
88
comfy/ldm/genmo/joint_model/rope_mixed.py
Normal file
88
comfy/ldm/genmo/joint_model/rope_mixed.py
Normal file
@@ -0,0 +1,88 @@
|
||||
#original code from https://github.com/genmoai/models under apache 2.0 license
|
||||
|
||||
# import functools
|
||||
import math
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def centers(start: float, stop, num, dtype=None, device=None):
|
||||
"""linspace through bin centers.
|
||||
|
||||
Args:
|
||||
start (float): Start of the range.
|
||||
stop (float): End of the range.
|
||||
num (int): Number of points.
|
||||
dtype (torch.dtype): Data type of the points.
|
||||
device (torch.device): Device of the points.
|
||||
|
||||
Returns:
|
||||
centers (Tensor): Centers of the bins. Shape: (num,).
|
||||
"""
|
||||
edges = torch.linspace(start, stop, num + 1, dtype=dtype, device=device)
|
||||
return (edges[:-1] + edges[1:]) / 2
|
||||
|
||||
|
||||
# @functools.lru_cache(maxsize=1)
|
||||
def create_position_matrix(
|
||||
T: int,
|
||||
pH: int,
|
||||
pW: int,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
*,
|
||||
target_area: float = 36864,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
T: int - Temporal dimension
|
||||
pH: int - Height dimension after patchify
|
||||
pW: int - Width dimension after patchify
|
||||
|
||||
Returns:
|
||||
pos: [T * pH * pW, 3] - position matrix
|
||||
"""
|
||||
# Create 1D tensors for each dimension
|
||||
t = torch.arange(T, dtype=dtype)
|
||||
|
||||
# Positionally interpolate to area 36864.
|
||||
# (3072x3072 frame with 16x16 patches = 192x192 latents).
|
||||
# This automatically scales rope positions when the resolution changes.
|
||||
# We use a large target area so the model is more sensitive
|
||||
# to changes in the learned pos_frequencies matrix.
|
||||
scale = math.sqrt(target_area / (pW * pH))
|
||||
w = centers(-pW * scale / 2, pW * scale / 2, pW)
|
||||
h = centers(-pH * scale / 2, pH * scale / 2, pH)
|
||||
|
||||
# Use meshgrid to create 3D grids
|
||||
grid_t, grid_h, grid_w = torch.meshgrid(t, h, w, indexing="ij")
|
||||
|
||||
# Stack and reshape the grids.
|
||||
pos = torch.stack([grid_t, grid_h, grid_w], dim=-1) # [T, pH, pW, 3]
|
||||
pos = pos.view(-1, 3) # [T * pH * pW, 3]
|
||||
pos = pos.to(dtype=dtype, device=device)
|
||||
|
||||
return pos
|
||||
|
||||
|
||||
def compute_mixed_rotation(
|
||||
freqs: torch.Tensor,
|
||||
pos: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Project each 3-dim position into per-head, per-head-dim 1D frequencies.
|
||||
|
||||
Args:
|
||||
freqs: [3, num_heads, num_freqs] - learned rotation frequency (for t, row, col) for each head position
|
||||
pos: [N, 3] - position of each token
|
||||
num_heads: int
|
||||
|
||||
Returns:
|
||||
freqs_cos: [N, num_heads, num_freqs] - cosine components
|
||||
freqs_sin: [N, num_heads, num_freqs] - sine components
|
||||
"""
|
||||
assert freqs.ndim == 3
|
||||
freqs_sum = torch.einsum("Nd,dhf->Nhf", pos.to(freqs), freqs)
|
||||
freqs_cos = torch.cos(freqs_sum)
|
||||
freqs_sin = torch.sin(freqs_sum)
|
||||
return freqs_cos, freqs_sin
|
||||
34
comfy/ldm/genmo/joint_model/temporal_rope.py
Normal file
34
comfy/ldm/genmo/joint_model/temporal_rope.py
Normal file
@@ -0,0 +1,34 @@
|
||||
#original code from https://github.com/genmoai/models under apache 2.0 license
|
||||
|
||||
# Based on Llama3 Implementation.
|
||||
import torch
|
||||
|
||||
|
||||
def apply_rotary_emb_qk_real(
|
||||
xqk: torch.Tensor,
|
||||
freqs_cos: torch.Tensor,
|
||||
freqs_sin: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Apply rotary embeddings to input tensors using the given frequency tensor without complex numbers.
|
||||
|
||||
Args:
|
||||
xqk (torch.Tensor): Query and/or Key tensors to apply rotary embeddings. Shape: (B, S, *, num_heads, D)
|
||||
Can be either just query or just key, or both stacked along some batch or * dim.
|
||||
freqs_cos (torch.Tensor): Precomputed cosine frequency tensor.
|
||||
freqs_sin (torch.Tensor): Precomputed sine frequency tensor.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The input tensor with rotary embeddings applied.
|
||||
"""
|
||||
# Split the last dimension into even and odd parts
|
||||
xqk_even = xqk[..., 0::2]
|
||||
xqk_odd = xqk[..., 1::2]
|
||||
|
||||
# Apply rotation
|
||||
cos_part = (xqk_even * freqs_cos - xqk_odd * freqs_sin).type_as(xqk)
|
||||
sin_part = (xqk_even * freqs_sin + xqk_odd * freqs_cos).type_as(xqk)
|
||||
|
||||
# Interleave the results back into the original shape
|
||||
out = torch.stack([cos_part, sin_part], dim=-1).flatten(-2)
|
||||
return out
|
||||
102
comfy/ldm/genmo/joint_model/utils.py
Normal file
102
comfy/ldm/genmo/joint_model/utils.py
Normal file
@@ -0,0 +1,102 @@
|
||||
#original code from https://github.com/genmoai/models under apache 2.0 license
|
||||
#adapted to ComfyUI
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def modulate(x, shift, scale):
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
|
||||
|
||||
def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch.Tensor:
|
||||
"""
|
||||
Pool tokens in x using mask.
|
||||
|
||||
NOTE: We assume x does not require gradients.
|
||||
|
||||
Args:
|
||||
x: (B, L, D) tensor of tokens.
|
||||
mask: (B, L) boolean tensor indicating which tokens are not padding.
|
||||
|
||||
Returns:
|
||||
pooled: (B, D) tensor of pooled tokens.
|
||||
"""
|
||||
assert x.size(1) == mask.size(1) # Expected mask to have same length as tokens.
|
||||
assert x.size(0) == mask.size(0) # Expected mask to have same batch size as tokens.
|
||||
mask = mask[:, :, None].to(dtype=x.dtype)
|
||||
mask = mask / mask.sum(dim=1, keepdim=True).clamp(min=1)
|
||||
pooled = (x * mask).sum(dim=1, keepdim=keepdim)
|
||||
return pooled
|
||||
|
||||
|
||||
class AttentionPool(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int,
|
||||
num_heads: int,
|
||||
output_dim: int = None,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
spatial_dim (int): Number of tokens in sequence length.
|
||||
embed_dim (int): Dimensionality of input tokens.
|
||||
num_heads (int): Number of attention heads.
|
||||
output_dim (int): Dimensionality of output tokens. Defaults to embed_dim.
|
||||
"""
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.to_kv = operations.Linear(embed_dim, 2 * embed_dim, device=device, dtype=dtype)
|
||||
self.to_q = operations.Linear(embed_dim, embed_dim, device=device, dtype=dtype)
|
||||
self.to_out = operations.Linear(embed_dim, output_dim or embed_dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, mask):
|
||||
"""
|
||||
Args:
|
||||
x (torch.Tensor): (B, L, D) tensor of input tokens.
|
||||
mask (torch.Tensor): (B, L) boolean tensor indicating which tokens are not padding.
|
||||
|
||||
NOTE: We assume x does not require gradients.
|
||||
|
||||
Returns:
|
||||
x (torch.Tensor): (B, D) tensor of pooled tokens.
|
||||
"""
|
||||
D = x.size(2)
|
||||
|
||||
# Construct attention mask, shape: (B, 1, num_queries=1, num_keys=1+L).
|
||||
attn_mask = mask[:, None, None, :].bool() # (B, 1, 1, L).
|
||||
attn_mask = F.pad(attn_mask, (1, 0), value=True) # (B, 1, 1, 1+L).
|
||||
|
||||
# Average non-padding token features. These will be used as the query.
|
||||
x_pool = pool_tokens(x, mask, keepdim=True) # (B, 1, D)
|
||||
|
||||
# Concat pooled features to input sequence.
|
||||
x = torch.cat([x_pool, x], dim=1) # (B, L+1, D)
|
||||
|
||||
# Compute queries, keys, values. Only the mean token is used to create a query.
|
||||
kv = self.to_kv(x) # (B, L+1, 2 * D)
|
||||
q = self.to_q(x[:, 0]) # (B, D)
|
||||
|
||||
# Extract heads.
|
||||
head_dim = D // self.num_heads
|
||||
kv = kv.unflatten(2, (2, self.num_heads, head_dim)) # (B, 1+L, 2, H, head_dim)
|
||||
kv = kv.transpose(1, 3) # (B, H, 2, 1+L, head_dim)
|
||||
k, v = kv.unbind(2) # (B, H, 1+L, head_dim)
|
||||
q = q.unflatten(1, (self.num_heads, head_dim)) # (B, H, head_dim)
|
||||
q = q.unsqueeze(2) # (B, H, 1, head_dim)
|
||||
|
||||
# Compute attention.
|
||||
x = F.scaled_dot_product_attention(
|
||||
q, k, v, attn_mask=attn_mask, dropout_p=0.0
|
||||
) # (B, H, 1, head_dim)
|
||||
|
||||
# Concatenate heads and run output.
|
||||
x = x.squeeze(2).flatten(1, 2) # (B, D = H * head_dim)
|
||||
x = self.to_out(x)
|
||||
return x
|
||||
711
comfy/ldm/genmo/vae/model.py
Normal file
711
comfy/ldm/genmo/vae/model.py
Normal file
@@ -0,0 +1,711 @@
|
||||
#original code from https://github.com/genmoai/models under apache 2.0 license
|
||||
#adapted to ComfyUI
|
||||
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
from functools import partial
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
# import mochi_preview.dit.joint_model.context_parallel as cp
|
||||
# from mochi_preview.vae.cp_conv import cp_pass_frames, gather_all_frames
|
||||
|
||||
|
||||
def cast_tuple(t, length=1):
|
||||
return t if isinstance(t, tuple) else ((t,) * length)
|
||||
|
||||
|
||||
class GroupNormSpatial(ops.GroupNorm):
|
||||
"""
|
||||
GroupNorm applied per-frame.
|
||||
"""
|
||||
|
||||
def forward(self, x: torch.Tensor, *, chunk_size: int = 8):
|
||||
B, C, T, H, W = x.shape
|
||||
x = rearrange(x, "B C T H W -> (B T) C H W")
|
||||
# Run group norm in chunks.
|
||||
output = torch.empty_like(x)
|
||||
for b in range(0, B * T, chunk_size):
|
||||
output[b : b + chunk_size] = super().forward(x[b : b + chunk_size])
|
||||
return rearrange(output, "(B T) C H W -> B C T H W", B=B, T=T)
|
||||
|
||||
class PConv3d(ops.Conv3d):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size: Union[int, Tuple[int, int, int]],
|
||||
stride: Union[int, Tuple[int, int, int]],
|
||||
causal: bool = True,
|
||||
context_parallel: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
self.causal = causal
|
||||
self.context_parallel = context_parallel
|
||||
kernel_size = cast_tuple(kernel_size, 3)
|
||||
stride = cast_tuple(stride, 3)
|
||||
height_pad = (kernel_size[1] - 1) // 2
|
||||
width_pad = (kernel_size[2] - 1) // 2
|
||||
|
||||
super().__init__(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
dilation=(1, 1, 1),
|
||||
padding=(0, height_pad, width_pad),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
# Compute padding amounts.
|
||||
context_size = self.kernel_size[0] - 1
|
||||
if self.causal:
|
||||
pad_front = context_size
|
||||
pad_back = 0
|
||||
else:
|
||||
pad_front = context_size // 2
|
||||
pad_back = context_size - pad_front
|
||||
|
||||
# Apply padding.
|
||||
assert self.padding_mode == "replicate" # DEBUG
|
||||
mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
|
||||
x = F.pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode)
|
||||
return super().forward(x)
|
||||
|
||||
|
||||
class Conv1x1(ops.Linear):
|
||||
"""*1x1 Conv implemented with a linear layer."""
|
||||
|
||||
def __init__(self, in_features: int, out_features: int, *args, **kwargs):
|
||||
super().__init__(in_features, out_features, *args, **kwargs)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
x: Input tensor. Shape: [B, C, *] or [B, *, C].
|
||||
|
||||
Returns:
|
||||
x: Output tensor. Shape: [B, C', *] or [B, *, C'].
|
||||
"""
|
||||
x = x.movedim(1, -1)
|
||||
x = super().forward(x)
|
||||
x = x.movedim(-1, 1)
|
||||
return x
|
||||
|
||||
|
||||
class DepthToSpaceTime(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
temporal_expansion: int,
|
||||
spatial_expansion: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.temporal_expansion = temporal_expansion
|
||||
self.spatial_expansion = spatial_expansion
|
||||
|
||||
# When printed, this module should show the temporal and spatial expansion factors.
|
||||
def extra_repr(self):
|
||||
return f"texp={self.temporal_expansion}, sexp={self.spatial_expansion}"
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
x: Input tensor. Shape: [B, C, T, H, W].
|
||||
|
||||
Returns:
|
||||
x: Rearranged tensor. Shape: [B, C/(st*s*s), T*st, H*s, W*s].
|
||||
"""
|
||||
x = rearrange(
|
||||
x,
|
||||
"B (C st sh sw) T H W -> B C (T st) (H sh) (W sw)",
|
||||
st=self.temporal_expansion,
|
||||
sh=self.spatial_expansion,
|
||||
sw=self.spatial_expansion,
|
||||
)
|
||||
|
||||
# cp_rank, _ = cp.get_cp_rank_size()
|
||||
if self.temporal_expansion > 1: # and cp_rank == 0:
|
||||
# Drop the first self.temporal_expansion - 1 frames.
|
||||
# This is because we always want the 3x3x3 conv filter to only apply
|
||||
# to the first frame, and the first frame doesn't need to be repeated.
|
||||
assert all(x.shape)
|
||||
x = x[:, :, self.temporal_expansion - 1 :]
|
||||
assert all(x.shape)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def norm_fn(
|
||||
in_channels: int,
|
||||
affine: bool = True,
|
||||
):
|
||||
return GroupNormSpatial(affine=affine, num_groups=32, num_channels=in_channels)
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
"""Residual block that preserves the spatial dimensions."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
*,
|
||||
affine: bool = True,
|
||||
attn_block: Optional[nn.Module] = None,
|
||||
causal: bool = True,
|
||||
prune_bottleneck: bool = False,
|
||||
padding_mode: str,
|
||||
bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
|
||||
assert causal
|
||||
self.stack = nn.Sequential(
|
||||
norm_fn(channels, affine=affine),
|
||||
nn.SiLU(inplace=True),
|
||||
PConv3d(
|
||||
in_channels=channels,
|
||||
out_channels=channels // 2 if prune_bottleneck else channels,
|
||||
kernel_size=(3, 3, 3),
|
||||
stride=(1, 1, 1),
|
||||
padding_mode=padding_mode,
|
||||
bias=bias,
|
||||
causal=causal,
|
||||
),
|
||||
norm_fn(channels, affine=affine),
|
||||
nn.SiLU(inplace=True),
|
||||
PConv3d(
|
||||
in_channels=channels // 2 if prune_bottleneck else channels,
|
||||
out_channels=channels,
|
||||
kernel_size=(3, 3, 3),
|
||||
stride=(1, 1, 1),
|
||||
padding_mode=padding_mode,
|
||||
bias=bias,
|
||||
causal=causal,
|
||||
),
|
||||
)
|
||||
|
||||
self.attn_block = attn_block if attn_block else nn.Identity()
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
x: Input tensor. Shape: [B, C, T, H, W].
|
||||
"""
|
||||
residual = x
|
||||
x = self.stack(x)
|
||||
x = x + residual
|
||||
del residual
|
||||
|
||||
return self.attn_block(x)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
head_dim: int = 32,
|
||||
qkv_bias: bool = False,
|
||||
out_bias: bool = True,
|
||||
qk_norm: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.head_dim = head_dim
|
||||
self.num_heads = dim // head_dim
|
||||
self.qk_norm = qk_norm
|
||||
|
||||
self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)
|
||||
self.out = nn.Linear(dim, dim, bias=out_bias)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Compute temporal self-attention.
|
||||
|
||||
Args:
|
||||
x: Input tensor. Shape: [B, C, T, H, W].
|
||||
chunk_size: Chunk size for large tensors.
|
||||
|
||||
Returns:
|
||||
x: Output tensor. Shape: [B, C, T, H, W].
|
||||
"""
|
||||
B, _, T, H, W = x.shape
|
||||
|
||||
if T == 1:
|
||||
# No attention for single frame.
|
||||
x = x.movedim(1, -1) # [B, C, T, H, W] -> [B, T, H, W, C]
|
||||
qkv = self.qkv(x)
|
||||
_, _, x = qkv.chunk(3, dim=-1) # Throw away queries and keys.
|
||||
x = self.out(x)
|
||||
return x.movedim(-1, 1) # [B, T, H, W, C] -> [B, C, T, H, W]
|
||||
|
||||
# 1D temporal attention.
|
||||
x = rearrange(x, "B C t h w -> (B h w) t C")
|
||||
qkv = self.qkv(x)
|
||||
|
||||
# Input: qkv with shape [B, t, 3 * num_heads * head_dim]
|
||||
# Output: x with shape [B, num_heads, t, head_dim]
|
||||
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, self.head_dim).transpose(1, 3).unbind(2)
|
||||
|
||||
if self.qk_norm:
|
||||
q = F.normalize(q, p=2, dim=-1)
|
||||
k = F.normalize(k, p=2, dim=-1)
|
||||
|
||||
x = optimized_attention(q, k, v, self.num_heads, skip_reshape=True)
|
||||
|
||||
assert x.size(0) == q.size(0)
|
||||
|
||||
x = self.out(x)
|
||||
x = rearrange(x, "(B h w) t C -> B C t h w", B=B, h=H, w=W)
|
||||
return x
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
**attn_kwargs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.norm = norm_fn(dim)
|
||||
self.attn = Attention(dim, **attn_kwargs)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x + self.attn(self.norm(x))
|
||||
|
||||
|
||||
class CausalUpsampleBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
num_res_blocks: int,
|
||||
*,
|
||||
temporal_expansion: int = 2,
|
||||
spatial_expansion: int = 2,
|
||||
**block_kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
blocks = []
|
||||
for _ in range(num_res_blocks):
|
||||
blocks.append(block_fn(in_channels, **block_kwargs))
|
||||
self.blocks = nn.Sequential(*blocks)
|
||||
|
||||
self.temporal_expansion = temporal_expansion
|
||||
self.spatial_expansion = spatial_expansion
|
||||
|
||||
# Change channels in the final convolution layer.
|
||||
self.proj = Conv1x1(
|
||||
in_channels,
|
||||
out_channels * temporal_expansion * (spatial_expansion**2),
|
||||
)
|
||||
|
||||
self.d2st = DepthToSpaceTime(
|
||||
temporal_expansion=temporal_expansion, spatial_expansion=spatial_expansion
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.blocks(x)
|
||||
x = self.proj(x)
|
||||
x = self.d2st(x)
|
||||
return x
|
||||
|
||||
|
||||
def block_fn(channels, *, affine: bool = True, has_attention: bool = False, **block_kwargs):
|
||||
attn_block = AttentionBlock(channels) if has_attention else None
|
||||
return ResBlock(channels, affine=affine, attn_block=attn_block, **block_kwargs)
|
||||
|
||||
|
||||
class DownsampleBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
num_res_blocks,
|
||||
*,
|
||||
temporal_reduction=2,
|
||||
spatial_reduction=2,
|
||||
**block_kwargs,
|
||||
):
|
||||
"""
|
||||
Downsample block for the VAE encoder.
|
||||
|
||||
Args:
|
||||
in_channels: Number of input channels.
|
||||
out_channels: Number of output channels.
|
||||
num_res_blocks: Number of residual blocks.
|
||||
temporal_reduction: Temporal reduction factor.
|
||||
spatial_reduction: Spatial reduction factor.
|
||||
"""
|
||||
super().__init__()
|
||||
layers = []
|
||||
|
||||
# Change the channel count in the strided convolution.
|
||||
# This lets the ResBlock have uniform channel count,
|
||||
# as in ConvNeXt.
|
||||
assert in_channels != out_channels
|
||||
layers.append(
|
||||
PConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=(temporal_reduction, spatial_reduction, spatial_reduction),
|
||||
stride=(temporal_reduction, spatial_reduction, spatial_reduction),
|
||||
# First layer in each block always uses replicate padding
|
||||
padding_mode="replicate",
|
||||
bias=block_kwargs["bias"],
|
||||
)
|
||||
)
|
||||
|
||||
for _ in range(num_res_blocks):
|
||||
layers.append(block_fn(out_channels, **block_kwargs))
|
||||
|
||||
self.layers = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
return self.layers(x)
|
||||
|
||||
|
||||
def add_fourier_features(inputs: torch.Tensor, start=6, stop=8, step=1):
|
||||
num_freqs = (stop - start) // step
|
||||
assert inputs.ndim == 5
|
||||
C = inputs.size(1)
|
||||
|
||||
# Create Base 2 Fourier features.
|
||||
freqs = torch.arange(start, stop, step, dtype=inputs.dtype, device=inputs.device)
|
||||
assert num_freqs == len(freqs)
|
||||
w = torch.pow(2.0, freqs) * (2 * torch.pi) # [num_freqs]
|
||||
C = inputs.shape[1]
|
||||
w = w.repeat(C)[None, :, None, None, None] # [1, C * num_freqs, 1, 1, 1]
|
||||
|
||||
# Interleaved repeat of input channels to match w.
|
||||
h = inputs.repeat_interleave(num_freqs, dim=1) # [B, C * num_freqs, T, H, W]
|
||||
# Scale channels by frequency.
|
||||
h = w * h
|
||||
|
||||
return torch.cat(
|
||||
[
|
||||
inputs,
|
||||
torch.sin(h),
|
||||
torch.cos(h),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
|
||||
class FourierFeatures(nn.Module):
|
||||
def __init__(self, start: int = 6, stop: int = 8, step: int = 1):
|
||||
super().__init__()
|
||||
self.start = start
|
||||
self.stop = stop
|
||||
self.step = step
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Add Fourier features to inputs.
|
||||
|
||||
Args:
|
||||
inputs: Input tensor. Shape: [B, C, T, H, W]
|
||||
|
||||
Returns:
|
||||
h: Output tensor. Shape: [B, (1 + 2 * num_freqs) * C, T, H, W]
|
||||
"""
|
||||
return add_fourier_features(inputs, self.start, self.stop, self.step)
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
out_channels: int = 3,
|
||||
latent_dim: int,
|
||||
base_channels: int,
|
||||
channel_multipliers: List[int],
|
||||
num_res_blocks: List[int],
|
||||
temporal_expansions: Optional[List[int]] = None,
|
||||
spatial_expansions: Optional[List[int]] = None,
|
||||
has_attention: List[bool],
|
||||
output_norm: bool = True,
|
||||
nonlinearity: str = "silu",
|
||||
output_nonlinearity: str = "silu",
|
||||
causal: bool = True,
|
||||
**block_kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.input_channels = latent_dim
|
||||
self.base_channels = base_channels
|
||||
self.channel_multipliers = channel_multipliers
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.output_nonlinearity = output_nonlinearity
|
||||
assert nonlinearity == "silu"
|
||||
assert causal
|
||||
|
||||
ch = [mult * base_channels for mult in channel_multipliers]
|
||||
self.num_up_blocks = len(ch) - 1
|
||||
assert len(num_res_blocks) == self.num_up_blocks + 2
|
||||
|
||||
blocks = []
|
||||
|
||||
first_block = [
|
||||
ops.Conv3d(latent_dim, ch[-1], kernel_size=(1, 1, 1))
|
||||
] # Input layer.
|
||||
# First set of blocks preserve channel count.
|
||||
for _ in range(num_res_blocks[-1]):
|
||||
first_block.append(
|
||||
block_fn(
|
||||
ch[-1],
|
||||
has_attention=has_attention[-1],
|
||||
causal=causal,
|
||||
**block_kwargs,
|
||||
)
|
||||
)
|
||||
blocks.append(nn.Sequential(*first_block))
|
||||
|
||||
assert len(temporal_expansions) == len(spatial_expansions) == self.num_up_blocks
|
||||
assert len(num_res_blocks) == len(has_attention) == self.num_up_blocks + 2
|
||||
|
||||
upsample_block_fn = CausalUpsampleBlock
|
||||
|
||||
for i in range(self.num_up_blocks):
|
||||
block = upsample_block_fn(
|
||||
ch[-i - 1],
|
||||
ch[-i - 2],
|
||||
num_res_blocks=num_res_blocks[-i - 2],
|
||||
has_attention=has_attention[-i - 2],
|
||||
temporal_expansion=temporal_expansions[-i - 1],
|
||||
spatial_expansion=spatial_expansions[-i - 1],
|
||||
causal=causal,
|
||||
**block_kwargs,
|
||||
)
|
||||
blocks.append(block)
|
||||
|
||||
assert not output_norm
|
||||
|
||||
# Last block. Preserve channel count.
|
||||
last_block = []
|
||||
for _ in range(num_res_blocks[0]):
|
||||
last_block.append(
|
||||
block_fn(
|
||||
ch[0], has_attention=has_attention[0], causal=causal, **block_kwargs
|
||||
)
|
||||
)
|
||||
blocks.append(nn.Sequential(*last_block))
|
||||
|
||||
self.blocks = nn.ModuleList(blocks)
|
||||
self.output_proj = Conv1x1(ch[0], out_channels)
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
x: Latent tensor. Shape: [B, input_channels, t, h, w]. Scaled [-1, 1].
|
||||
|
||||
Returns:
|
||||
x: Reconstructed video tensor. Shape: [B, C, T, H, W]. Scaled to [-1, 1].
|
||||
T + 1 = (t - 1) * 4.
|
||||
H = h * 16, W = w * 16.
|
||||
"""
|
||||
for block in self.blocks:
|
||||
x = block(x)
|
||||
|
||||
if self.output_nonlinearity == "silu":
|
||||
x = F.silu(x, inplace=not self.training)
|
||||
else:
|
||||
assert (
|
||||
not self.output_nonlinearity
|
||||
) # StyleGAN3 omits the to-RGB nonlinearity.
|
||||
|
||||
return self.output_proj(x).contiguous()
|
||||
|
||||
class LatentDistribution:
|
||||
def __init__(self, mean: torch.Tensor, logvar: torch.Tensor):
|
||||
"""Initialize latent distribution.
|
||||
|
||||
Args:
|
||||
mean: Mean of the distribution. Shape: [B, C, T, H, W].
|
||||
logvar: Logarithm of variance of the distribution. Shape: [B, C, T, H, W].
|
||||
"""
|
||||
assert mean.shape == logvar.shape
|
||||
self.mean = mean
|
||||
self.logvar = logvar
|
||||
|
||||
def sample(self, temperature=1.0, generator: torch.Generator = None, noise=None):
|
||||
if temperature == 0.0:
|
||||
return self.mean
|
||||
|
||||
if noise is None:
|
||||
noise = torch.randn(self.mean.shape, device=self.mean.device, dtype=self.mean.dtype, generator=generator)
|
||||
else:
|
||||
assert noise.device == self.mean.device
|
||||
noise = noise.to(self.mean.dtype)
|
||||
|
||||
if temperature != 1.0:
|
||||
raise NotImplementedError(f"Temperature {temperature} is not supported.")
|
||||
|
||||
# Just Gaussian sample with no scaling of variance.
|
||||
return noise * torch.exp(self.logvar * 0.5) + self.mean
|
||||
|
||||
def mode(self):
|
||||
return self.mean
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
in_channels: int,
|
||||
base_channels: int,
|
||||
channel_multipliers: List[int],
|
||||
num_res_blocks: List[int],
|
||||
latent_dim: int,
|
||||
temporal_reductions: List[int],
|
||||
spatial_reductions: List[int],
|
||||
prune_bottlenecks: List[bool],
|
||||
has_attentions: List[bool],
|
||||
affine: bool = True,
|
||||
bias: bool = True,
|
||||
input_is_conv_1x1: bool = False,
|
||||
padding_mode: str,
|
||||
):
|
||||
super().__init__()
|
||||
self.temporal_reductions = temporal_reductions
|
||||
self.spatial_reductions = spatial_reductions
|
||||
self.base_channels = base_channels
|
||||
self.channel_multipliers = channel_multipliers
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.latent_dim = latent_dim
|
||||
|
||||
self.fourier_features = FourierFeatures()
|
||||
ch = [mult * base_channels for mult in channel_multipliers]
|
||||
num_down_blocks = len(ch) - 1
|
||||
assert len(num_res_blocks) == num_down_blocks + 2
|
||||
|
||||
layers = (
|
||||
[ops.Conv3d(in_channels, ch[0], kernel_size=(1, 1, 1), bias=True)]
|
||||
if not input_is_conv_1x1
|
||||
else [Conv1x1(in_channels, ch[0])]
|
||||
)
|
||||
|
||||
assert len(prune_bottlenecks) == num_down_blocks + 2
|
||||
assert len(has_attentions) == num_down_blocks + 2
|
||||
block = partial(block_fn, padding_mode=padding_mode, affine=affine, bias=bias)
|
||||
|
||||
for _ in range(num_res_blocks[0]):
|
||||
layers.append(block(ch[0], has_attention=has_attentions[0], prune_bottleneck=prune_bottlenecks[0]))
|
||||
prune_bottlenecks = prune_bottlenecks[1:]
|
||||
has_attentions = has_attentions[1:]
|
||||
|
||||
assert len(temporal_reductions) == len(spatial_reductions) == len(ch) - 1
|
||||
for i in range(num_down_blocks):
|
||||
layer = DownsampleBlock(
|
||||
ch[i],
|
||||
ch[i + 1],
|
||||
num_res_blocks=num_res_blocks[i + 1],
|
||||
temporal_reduction=temporal_reductions[i],
|
||||
spatial_reduction=spatial_reductions[i],
|
||||
prune_bottleneck=prune_bottlenecks[i],
|
||||
has_attention=has_attentions[i],
|
||||
affine=affine,
|
||||
bias=bias,
|
||||
padding_mode=padding_mode,
|
||||
)
|
||||
|
||||
layers.append(layer)
|
||||
|
||||
# Additional blocks.
|
||||
for _ in range(num_res_blocks[-1]):
|
||||
layers.append(block(ch[-1], has_attention=has_attentions[-1], prune_bottleneck=prune_bottlenecks[-1]))
|
||||
|
||||
self.layers = nn.Sequential(*layers)
|
||||
|
||||
# Output layers.
|
||||
self.output_norm = norm_fn(ch[-1])
|
||||
self.output_proj = Conv1x1(ch[-1], 2 * latent_dim, bias=False)
|
||||
|
||||
@property
|
||||
def temporal_downsample(self):
|
||||
return math.prod(self.temporal_reductions)
|
||||
|
||||
@property
|
||||
def spatial_downsample(self):
|
||||
return math.prod(self.spatial_reductions)
|
||||
|
||||
def forward(self, x) -> LatentDistribution:
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
x: Input video tensor. Shape: [B, C, T, H, W]. Scaled to [-1, 1]
|
||||
|
||||
Returns:
|
||||
means: Latent tensor. Shape: [B, latent_dim, t, h, w]. Scaled [-1, 1].
|
||||
h = H // 8, w = W // 8, t - 1 = (T - 1) // 6
|
||||
logvar: Shape: [B, latent_dim, t, h, w].
|
||||
"""
|
||||
assert x.ndim == 5, f"Expected 5D input, got {x.shape}"
|
||||
x = self.fourier_features(x)
|
||||
|
||||
x = self.layers(x)
|
||||
|
||||
x = self.output_norm(x)
|
||||
x = F.silu(x, inplace=True)
|
||||
x = self.output_proj(x)
|
||||
|
||||
means, logvar = torch.chunk(x, 2, dim=1)
|
||||
|
||||
assert means.ndim == 5
|
||||
assert logvar.shape == means.shape
|
||||
assert means.size(1) == self.latent_dim
|
||||
|
||||
return LatentDistribution(means, logvar)
|
||||
|
||||
|
||||
class VideoVAE(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.encoder = Encoder(
|
||||
in_channels=15,
|
||||
base_channels=64,
|
||||
channel_multipliers=[1, 2, 4, 6],
|
||||
num_res_blocks=[3, 3, 4, 6, 3],
|
||||
latent_dim=12,
|
||||
temporal_reductions=[1, 2, 3],
|
||||
spatial_reductions=[2, 2, 2],
|
||||
prune_bottlenecks=[False, False, False, False, False],
|
||||
has_attentions=[False, True, True, True, True],
|
||||
affine=True,
|
||||
bias=True,
|
||||
input_is_conv_1x1=True,
|
||||
padding_mode="replicate"
|
||||
)
|
||||
self.decoder = Decoder(
|
||||
out_channels=3,
|
||||
base_channels=128,
|
||||
channel_multipliers=[1, 2, 4, 6],
|
||||
temporal_expansions=[1, 2, 3],
|
||||
spatial_expansions=[2, 2, 2],
|
||||
num_res_blocks=[3, 3, 4, 6, 3],
|
||||
latent_dim=12,
|
||||
has_attention=[False, False, False, False, False],
|
||||
padding_mode="replicate",
|
||||
output_norm=False,
|
||||
nonlinearity="silu",
|
||||
output_nonlinearity="silu",
|
||||
causal=True,
|
||||
)
|
||||
|
||||
def encode(self, x):
|
||||
return self.encoder(x).mode()
|
||||
|
||||
def decode(self, x):
|
||||
return self.decoder(x)
|
||||
@@ -287,7 +287,7 @@ class HunYuanDiT(nn.Module):
|
||||
style=None,
|
||||
return_dict=False,
|
||||
control=None,
|
||||
transformer_options=None,
|
||||
transformer_options={},
|
||||
):
|
||||
"""
|
||||
Forward pass of the encoder.
|
||||
@@ -315,8 +315,7 @@ class HunYuanDiT(nn.Module):
|
||||
return_dict: bool
|
||||
Whether to return a dictionary.
|
||||
"""
|
||||
#import pdb
|
||||
#pdb.set_trace()
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
encoder_hidden_states = context
|
||||
text_states = encoder_hidden_states # 2,77,1024
|
||||
text_states_t5 = encoder_hidden_states_t5 # 2,256,2048
|
||||
@@ -364,6 +363,8 @@ class HunYuanDiT(nn.Module):
|
||||
# Concatenate all extra vectors
|
||||
c = t + self.extra_embedder(extra_vec) # [B, D]
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
|
||||
controls = None
|
||||
if control:
|
||||
controls = control.get("output", None)
|
||||
@@ -372,12 +373,23 @@ class HunYuanDiT(nn.Module):
|
||||
for layer, block in enumerate(self.blocks):
|
||||
if layer > self.depth // 2:
|
||||
if controls is not None:
|
||||
skip = skips.pop() + controls.pop()
|
||||
skip = skips.pop() + controls.pop().to(dtype=x.dtype)
|
||||
else:
|
||||
skip = skips.pop()
|
||||
x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)
|
||||
else:
|
||||
x = block(x, c, text_states, freqs_cis_img) # (N, L, D)
|
||||
skip = None
|
||||
|
||||
if ("double_block", layer) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], args["vec"], args["txt"], args["pe"], args["skip"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", layer)]({"img": x, "txt": text_states, "vec": c, "pe": freqs_cis_img, "skip": skip}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)
|
||||
|
||||
|
||||
if layer < (self.depth // 2 - 1):
|
||||
skips.append(x)
|
||||
|
||||
514
comfy/ldm/lightricks/model.py
Normal file
514
comfy/ldm/lightricks/model.py
Normal file
@@ -0,0 +1,514 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import comfy.ldm.modules.attention
|
||||
from comfy.ldm.genmo.joint_model.layers import RMSNorm
|
||||
import comfy.ldm.common_dit
|
||||
from einops import rearrange
|
||||
import math
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
from .symmetric_patchifier import SymmetricPatchifier
|
||||
|
||||
|
||||
def get_timestep_embedding(
|
||||
timesteps: torch.Tensor,
|
||||
embedding_dim: int,
|
||||
flip_sin_to_cos: bool = False,
|
||||
downscale_freq_shift: float = 1,
|
||||
scale: float = 1,
|
||||
max_period: int = 10000,
|
||||
):
|
||||
"""
|
||||
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
|
||||
|
||||
Args
|
||||
timesteps (torch.Tensor):
|
||||
a 1-D Tensor of N indices, one per batch element. These may be fractional.
|
||||
embedding_dim (int):
|
||||
the dimension of the output.
|
||||
flip_sin_to_cos (bool):
|
||||
Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
|
||||
downscale_freq_shift (float):
|
||||
Controls the delta between frequencies between dimensions
|
||||
scale (float):
|
||||
Scaling factor applied to the embeddings.
|
||||
max_period (int):
|
||||
Controls the maximum frequency of the embeddings
|
||||
Returns
|
||||
torch.Tensor: an [N x dim] Tensor of positional embeddings.
|
||||
"""
|
||||
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
|
||||
|
||||
half_dim = embedding_dim // 2
|
||||
exponent = -math.log(max_period) * torch.arange(
|
||||
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
|
||||
)
|
||||
exponent = exponent / (half_dim - downscale_freq_shift)
|
||||
|
||||
emb = torch.exp(exponent)
|
||||
emb = timesteps[:, None].float() * emb[None, :]
|
||||
|
||||
# scale embeddings
|
||||
emb = scale * emb
|
||||
|
||||
# concat sine and cosine embeddings
|
||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
|
||||
|
||||
# flip sine and cosine embeddings
|
||||
if flip_sin_to_cos:
|
||||
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
|
||||
|
||||
# zero pad
|
||||
if embedding_dim % 2 == 1:
|
||||
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
||||
return emb
|
||||
|
||||
|
||||
class TimestepEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
time_embed_dim: int,
|
||||
act_fn: str = "silu",
|
||||
out_dim: int = None,
|
||||
post_act_fn: Optional[str] = None,
|
||||
cond_proj_dim=None,
|
||||
sample_proj_bias=True,
|
||||
dtype=None, device=None, operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.linear_1 = operations.Linear(in_channels, time_embed_dim, sample_proj_bias, dtype=dtype, device=device)
|
||||
|
||||
if cond_proj_dim is not None:
|
||||
self.cond_proj = operations.Linear(cond_proj_dim, in_channels, bias=False, dtype=dtype, device=device)
|
||||
else:
|
||||
self.cond_proj = None
|
||||
|
||||
self.act = nn.SiLU()
|
||||
|
||||
if out_dim is not None:
|
||||
time_embed_dim_out = out_dim
|
||||
else:
|
||||
time_embed_dim_out = time_embed_dim
|
||||
self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias, dtype=dtype, device=device)
|
||||
|
||||
if post_act_fn is None:
|
||||
self.post_act = None
|
||||
# else:
|
||||
# self.post_act = get_activation(post_act_fn)
|
||||
|
||||
def forward(self, sample, condition=None):
|
||||
if condition is not None:
|
||||
sample = sample + self.cond_proj(condition)
|
||||
sample = self.linear_1(sample)
|
||||
|
||||
if self.act is not None:
|
||||
sample = self.act(sample)
|
||||
|
||||
sample = self.linear_2(sample)
|
||||
|
||||
if self.post_act is not None:
|
||||
sample = self.post_act(sample)
|
||||
return sample
|
||||
|
||||
|
||||
class Timesteps(nn.Module):
|
||||
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
|
||||
super().__init__()
|
||||
self.num_channels = num_channels
|
||||
self.flip_sin_to_cos = flip_sin_to_cos
|
||||
self.downscale_freq_shift = downscale_freq_shift
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, timesteps):
|
||||
t_emb = get_timestep_embedding(
|
||||
timesteps,
|
||||
self.num_channels,
|
||||
flip_sin_to_cos=self.flip_sin_to_cos,
|
||||
downscale_freq_shift=self.downscale_freq_shift,
|
||||
scale=self.scale,
|
||||
)
|
||||
return t_emb
|
||||
|
||||
|
||||
class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
|
||||
"""
|
||||
For PixArt-Alpha.
|
||||
|
||||
Reference:
|
||||
https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
self.outdim = size_emb_dim
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
|
||||
return timesteps_emb
|
||||
|
||||
|
||||
class AdaLayerNormSingle(nn.Module):
|
||||
r"""
|
||||
Norm layer adaptive layer norm single (adaLN-single).
|
||||
|
||||
As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
|
||||
|
||||
Parameters:
|
||||
embedding_dim (`int`): The size of each embedding vector.
|
||||
use_additional_conditions (`bool`): To use additional conditions for normalization or not.
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_dim: int, use_additional_conditions: bool = False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
|
||||
embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = operations.Linear(embedding_dim, 6 * embedding_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
timestep: torch.Tensor,
|
||||
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
hidden_dtype: Optional[torch.dtype] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
# No modulation happening here.
|
||||
added_cond_kwargs = added_cond_kwargs or {"resolution": None, "aspect_ratio": None}
|
||||
embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
|
||||
return self.linear(self.silu(embedded_timestep)), embedded_timestep
|
||||
|
||||
class PixArtAlphaTextProjection(nn.Module):
|
||||
"""
|
||||
Projects caption embeddings. Also handles dropout for classifier-free guidance.
|
||||
|
||||
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
|
||||
"""
|
||||
|
||||
def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
if out_features is None:
|
||||
out_features = hidden_size
|
||||
self.linear_1 = operations.Linear(in_features=in_features, out_features=hidden_size, bias=True, dtype=dtype, device=device)
|
||||
if act_fn == "gelu_tanh":
|
||||
self.act_1 = nn.GELU(approximate="tanh")
|
||||
elif act_fn == "silu":
|
||||
self.act_1 = nn.SiLU()
|
||||
else:
|
||||
raise ValueError(f"Unknown activation function: {act_fn}")
|
||||
self.linear_2 = operations.Linear(in_features=hidden_size, out_features=out_features, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, caption):
|
||||
hidden_states = self.linear_1(caption)
|
||||
hidden_states = self.act_1(hidden_states)
|
||||
hidden_states = self.linear_2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class GELU_approx(nn.Module):
|
||||
def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.proj = operations.Linear(dim_in, dim_out, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
return torch.nn.functional.gelu(self.proj(x), approximate="tanh")
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, dim_out, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
project_in = GELU_approx(dim, inner_dim, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.net = nn.Sequential(
|
||||
project_in,
|
||||
nn.Dropout(dropout),
|
||||
operations.Linear(inner_dim, dim_out, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
def apply_rotary_emb(input_tensor, freqs_cis): #TODO: remove duplicate funcs and pick the best/fastest one
|
||||
cos_freqs = freqs_cis[0]
|
||||
sin_freqs = freqs_cis[1]
|
||||
|
||||
t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2)
|
||||
t1, t2 = t_dup.unbind(dim=-1)
|
||||
t_dup = torch.stack((-t2, t1), dim=-1)
|
||||
input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)")
|
||||
|
||||
out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = query_dim if context_dim is None else context_dim
|
||||
self.attn_precision = attn_precision
|
||||
|
||||
self.heads = heads
|
||||
self.dim_head = dim_head
|
||||
|
||||
self.q_norm = RMSNorm(inner_dim, dtype=dtype, device=device)
|
||||
self.k_norm = RMSNorm(inner_dim, dtype=dtype, device=device)
|
||||
|
||||
self.to_q = operations.Linear(query_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
||||
self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
||||
self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
|
||||
|
||||
def forward(self, x, context=None, mask=None, pe=None):
|
||||
q = self.to_q(x)
|
||||
context = x if context is None else context
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
|
||||
if pe is not None:
|
||||
q = apply_rotary_emb(q, pe)
|
||||
k = apply_rotary_emb(k, pe)
|
||||
|
||||
if mask is None:
|
||||
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
|
||||
else:
|
||||
out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
def __init__(self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
self.attn_precision = attn_precision
|
||||
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, context_dim=None, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations)
|
||||
self.ff = FeedForward(dim, dim_out=dim, glu=True, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
|
||||
|
||||
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None):
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
|
||||
|
||||
x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe) * gate_msa
|
||||
|
||||
x += self.attn2(x, context=context, mask=attention_mask)
|
||||
|
||||
y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp
|
||||
x += self.ff(y) * gate_mlp
|
||||
|
||||
return x
|
||||
|
||||
def get_fractional_positions(indices_grid, max_pos):
|
||||
fractional_positions = torch.stack(
|
||||
[
|
||||
indices_grid[:, i] / max_pos[i]
|
||||
for i in range(3)
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
return fractional_positions
|
||||
|
||||
|
||||
def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]):
|
||||
dtype = torch.float32 #self.dtype
|
||||
|
||||
fractional_positions = get_fractional_positions(indices_grid, max_pos)
|
||||
|
||||
start = 1
|
||||
end = theta
|
||||
device = fractional_positions.device
|
||||
|
||||
indices = theta ** (
|
||||
torch.linspace(
|
||||
math.log(start, theta),
|
||||
math.log(end, theta),
|
||||
dim // 6,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
)
|
||||
indices = indices.to(dtype=dtype)
|
||||
|
||||
indices = indices * math.pi / 2
|
||||
|
||||
freqs = (
|
||||
(indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
|
||||
.transpose(-1, -2)
|
||||
.flatten(2)
|
||||
)
|
||||
|
||||
cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
|
||||
sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
|
||||
if dim % 6 != 0:
|
||||
cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6])
|
||||
sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6])
|
||||
cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
|
||||
sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
|
||||
return cos_freq.to(out_dtype), sin_freq.to(out_dtype)
|
||||
|
||||
|
||||
class LTXVModel(torch.nn.Module):
|
||||
def __init__(self,
|
||||
in_channels=128,
|
||||
cross_attention_dim=2048,
|
||||
attention_head_dim=64,
|
||||
num_attention_heads=32,
|
||||
|
||||
caption_channels=4096,
|
||||
num_layers=28,
|
||||
|
||||
|
||||
positional_embedding_theta=10000.0,
|
||||
positional_embedding_max_pos=[20, 2048, 2048],
|
||||
dtype=None, device=None, operations=None, **kwargs):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
self.out_channels = in_channels
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
self.patchify_proj = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
self.adaln_single = AdaLayerNormSingle(
|
||||
self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
# self.adaln_single.linear = operations.Linear(self.inner_dim, 4 * self.inner_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
self.caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=caption_channels, hidden_size=self.inner_dim, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
self.inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
context_dim=cross_attention_dim,
|
||||
# attn_precision=attn_precision,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
for d in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.scale_shift_table = nn.Parameter(torch.empty(2, self.inner_dim, dtype=dtype, device=device))
|
||||
self.norm_out = operations.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.proj_out = operations.Linear(self.inner_dim, self.out_channels, dtype=dtype, device=device)
|
||||
|
||||
self.patchifier = SymmetricPatchifier(1)
|
||||
|
||||
def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, transformer_options={}, **kwargs):
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
|
||||
indices_grid = self.patchifier.get_grid(
|
||||
orig_num_frames=x.shape[2],
|
||||
orig_height=x.shape[3],
|
||||
orig_width=x.shape[4],
|
||||
batch_size=x.shape[0],
|
||||
scale_grid=((1 / frame_rate) * 8, 32, 32),
|
||||
device=x.device,
|
||||
)
|
||||
|
||||
if guiding_latent is not None:
|
||||
ts = torch.ones([x.shape[0], 1, x.shape[2], x.shape[3], x.shape[4]], device=x.device, dtype=x.dtype)
|
||||
input_ts = timestep.view([timestep.shape[0]] + [1] * (x.ndim - 1))
|
||||
ts *= input_ts
|
||||
ts[:, :, 0] = 0.0
|
||||
timestep = self.patchifier.patchify(ts)
|
||||
input_x = x.clone()
|
||||
x[:, :, 0] = guiding_latent[:, :, 0]
|
||||
|
||||
orig_shape = list(x.shape)
|
||||
|
||||
x = self.patchifier.patchify(x)
|
||||
|
||||
x = self.patchify_proj(x)
|
||||
timestep = timestep * 1000.0
|
||||
|
||||
attention_mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1]))
|
||||
attention_mask = attention_mask.masked_fill(attention_mask.to(torch.bool), float("-inf")) # not sure about this
|
||||
# attention_mask = (context != 0).any(dim=2).to(dtype=x.dtype)
|
||||
|
||||
pe = precompute_freqs_cis(indices_grid, dim=self.inner_dim, out_dtype=x.dtype)
|
||||
|
||||
batch_size = x.shape[0]
|
||||
timestep, embedded_timestep = self.adaln_single(
|
||||
timestep.flatten(),
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=x.dtype,
|
||||
)
|
||||
# Second dimension is 1 or number of tokens (if timestep_per_token)
|
||||
timestep = timestep.view(batch_size, -1, timestep.shape[-1])
|
||||
embedded_timestep = embedded_timestep.view(
|
||||
batch_size, -1, embedded_timestep.shape[-1]
|
||||
)
|
||||
|
||||
# 2. Blocks
|
||||
if self.caption_projection is not None:
|
||||
batch_size = x.shape[0]
|
||||
context = self.caption_projection(context)
|
||||
context = context.view(
|
||||
batch_size, -1, x.shape[-1]
|
||||
)
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = block(
|
||||
x,
|
||||
context=context,
|
||||
attention_mask=attention_mask,
|
||||
timestep=timestep,
|
||||
pe=pe
|
||||
)
|
||||
|
||||
# 3. Output
|
||||
scale_shift_values = (
|
||||
self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None]
|
||||
)
|
||||
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
|
||||
x = self.norm_out(x)
|
||||
# Modulation
|
||||
x = x * (1 + scale) + shift
|
||||
x = self.proj_out(x)
|
||||
|
||||
x = self.patchifier.unpatchify(
|
||||
latents=x,
|
||||
output_height=orig_shape[3],
|
||||
output_width=orig_shape[4],
|
||||
output_num_frames=orig_shape[2],
|
||||
out_channels=orig_shape[1] // math.prod(self.patchifier.patch_size),
|
||||
)
|
||||
|
||||
if guiding_latent is not None:
|
||||
x[:, :, 0] = (input_x[:, :, 0] - guiding_latent[:, :, 0]) / input_ts[:, :, 0]
|
||||
|
||||
# print("res", x)
|
||||
return x
|
||||
105
comfy/ldm/lightricks/symmetric_patchifier.py
Normal file
105
comfy/ldm/lightricks/symmetric_patchifier.py
Normal file
@@ -0,0 +1,105 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
|
||||
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
|
||||
dims_to_append = target_dims - x.ndim
|
||||
if dims_to_append < 0:
|
||||
raise ValueError(
|
||||
f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
|
||||
)
|
||||
elif dims_to_append == 0:
|
||||
return x
|
||||
return x[(...,) + (None,) * dims_to_append]
|
||||
|
||||
|
||||
class Patchifier(ABC):
|
||||
def __init__(self, patch_size: int):
|
||||
super().__init__()
|
||||
self._patch_size = (1, patch_size, patch_size)
|
||||
|
||||
@abstractmethod
|
||||
def patchify(
|
||||
self, latents: Tensor, frame_rates: Tensor, scale_grid: bool
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def unpatchify(
|
||||
self,
|
||||
latents: Tensor,
|
||||
output_height: int,
|
||||
output_width: int,
|
||||
output_num_frames: int,
|
||||
out_channels: int,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
pass
|
||||
|
||||
@property
|
||||
def patch_size(self):
|
||||
return self._patch_size
|
||||
|
||||
def get_grid(
|
||||
self, orig_num_frames, orig_height, orig_width, batch_size, scale_grid, device
|
||||
):
|
||||
f = orig_num_frames // self._patch_size[0]
|
||||
h = orig_height // self._patch_size[1]
|
||||
w = orig_width // self._patch_size[2]
|
||||
grid_h = torch.arange(h, dtype=torch.float32, device=device)
|
||||
grid_w = torch.arange(w, dtype=torch.float32, device=device)
|
||||
grid_f = torch.arange(f, dtype=torch.float32, device=device)
|
||||
grid = torch.meshgrid(grid_f, grid_h, grid_w)
|
||||
grid = torch.stack(grid, dim=0)
|
||||
grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
|
||||
|
||||
if scale_grid is not None:
|
||||
for i in range(3):
|
||||
if isinstance(scale_grid[i], Tensor):
|
||||
scale = append_dims(scale_grid[i], grid.ndim - 1)
|
||||
else:
|
||||
scale = scale_grid[i]
|
||||
grid[:, i, ...] = grid[:, i, ...] * scale * self._patch_size[i]
|
||||
|
||||
grid = rearrange(grid, "b c f h w -> b c (f h w)", b=batch_size)
|
||||
return grid
|
||||
|
||||
|
||||
class SymmetricPatchifier(Patchifier):
|
||||
def patchify(
|
||||
self,
|
||||
latents: Tensor,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
latents = rearrange(
|
||||
latents,
|
||||
"b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
|
||||
p1=self._patch_size[0],
|
||||
p2=self._patch_size[1],
|
||||
p3=self._patch_size[2],
|
||||
)
|
||||
return latents
|
||||
|
||||
def unpatchify(
|
||||
self,
|
||||
latents: Tensor,
|
||||
output_height: int,
|
||||
output_width: int,
|
||||
output_num_frames: int,
|
||||
out_channels: int,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
output_height = output_height // self._patch_size[1]
|
||||
output_width = output_width // self._patch_size[2]
|
||||
latents = rearrange(
|
||||
latents,
|
||||
"b (f h w) (c p q) -> b c f (h p) (w q) ",
|
||||
f=output_num_frames,
|
||||
h=output_height,
|
||||
w=output_width,
|
||||
p=self._patch_size[1],
|
||||
q=self._patch_size[2],
|
||||
)
|
||||
return latents
|
||||
64
comfy/ldm/lightricks/vae/causal_conv3d.py
Normal file
64
comfy/ldm/lightricks/vae/causal_conv3d.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from typing import Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
|
||||
class CausalConv3d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size: int = 3,
|
||||
stride: Union[int, Tuple[int]] = 1,
|
||||
dilation: int = 1,
|
||||
groups: int = 1,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
|
||||
kernel_size = (kernel_size, kernel_size, kernel_size)
|
||||
self.time_kernel_size = kernel_size[0]
|
||||
|
||||
dilation = (dilation, 1, 1)
|
||||
|
||||
height_pad = kernel_size[1] // 2
|
||||
width_pad = kernel_size[2] // 2
|
||||
padding = (0, height_pad, width_pad)
|
||||
|
||||
self.conv = ops.Conv3d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
dilation=dilation,
|
||||
padding=padding,
|
||||
padding_mode="zeros",
|
||||
groups=groups,
|
||||
)
|
||||
|
||||
def forward(self, x, causal: bool = True):
|
||||
if causal:
|
||||
first_frame_pad = x[:, :, :1, :, :].repeat(
|
||||
(1, 1, self.time_kernel_size - 1, 1, 1)
|
||||
)
|
||||
x = torch.concatenate((first_frame_pad, x), dim=2)
|
||||
else:
|
||||
first_frame_pad = x[:, :, :1, :, :].repeat(
|
||||
(1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
|
||||
)
|
||||
last_frame_pad = x[:, :, -1:, :, :].repeat(
|
||||
(1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
|
||||
)
|
||||
x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2)
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
@property
|
||||
def weight(self):
|
||||
return self.conv.weight
|
||||
698
comfy/ldm/lightricks/vae/causal_video_autoencoder.py
Normal file
698
comfy/ldm/lightricks/vae/causal_video_autoencoder.py
Normal file
@@ -0,0 +1,698 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from functools import partial
|
||||
import math
|
||||
from einops import rearrange
|
||||
from typing import Any, Mapping, Optional, Tuple, Union, List
|
||||
from .conv_nd_factory import make_conv_nd, make_linear_nd
|
||||
from .pixel_norm import PixelNorm
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
r"""
|
||||
The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
|
||||
|
||||
Args:
|
||||
dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3):
|
||||
The number of dimensions to use in convolutions.
|
||||
in_channels (`int`, *optional*, defaults to 3):
|
||||
The number of input channels.
|
||||
out_channels (`int`, *optional*, defaults to 3):
|
||||
The number of output channels.
|
||||
blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`):
|
||||
The blocks to use. Each block is a tuple of the block name and the number of layers.
|
||||
base_channels (`int`, *optional*, defaults to 128):
|
||||
The number of output channels for the first convolutional layer.
|
||||
norm_num_groups (`int`, *optional*, defaults to 32):
|
||||
The number of groups for normalization.
|
||||
patch_size (`int`, *optional*, defaults to 1):
|
||||
The patch size to use. Should be a power of 2.
|
||||
norm_layer (`str`, *optional*, defaults to `group_norm`):
|
||||
The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
|
||||
latent_log_var (`str`, *optional*, defaults to `per_channel`):
|
||||
The number of channels for the log variance. Can be either `per_channel`, `uniform`, or `none`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dims: Union[int, Tuple[int, int]] = 3,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
blocks=[("res_x", 1)],
|
||||
base_channels: int = 128,
|
||||
norm_num_groups: int = 32,
|
||||
patch_size: Union[int, Tuple[int]] = 1,
|
||||
norm_layer: str = "group_norm", # group_norm, pixel_norm
|
||||
latent_log_var: str = "per_channel",
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.norm_layer = norm_layer
|
||||
self.latent_channels = out_channels
|
||||
self.latent_log_var = latent_log_var
|
||||
self.blocks_desc = blocks
|
||||
|
||||
in_channels = in_channels * patch_size**2
|
||||
output_channel = base_channels
|
||||
|
||||
self.conv_in = make_conv_nd(
|
||||
dims=dims,
|
||||
in_channels=in_channels,
|
||||
out_channels=output_channel,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
causal=True,
|
||||
)
|
||||
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
|
||||
for block_name, block_params in blocks:
|
||||
input_channel = output_channel
|
||||
if isinstance(block_params, int):
|
||||
block_params = {"num_layers": block_params}
|
||||
|
||||
if block_name == "res_x":
|
||||
block = UNetMidBlock3D(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
num_layers=block_params["num_layers"],
|
||||
resnet_eps=1e-6,
|
||||
resnet_groups=norm_num_groups,
|
||||
norm_layer=norm_layer,
|
||||
)
|
||||
elif block_name == "res_x_y":
|
||||
output_channel = block_params.get("multiplier", 2) * output_channel
|
||||
block = ResnetBlock3D(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
eps=1e-6,
|
||||
groups=norm_num_groups,
|
||||
norm_layer=norm_layer,
|
||||
)
|
||||
elif block_name == "compress_time":
|
||||
block = make_conv_nd(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
kernel_size=3,
|
||||
stride=(2, 1, 1),
|
||||
causal=True,
|
||||
)
|
||||
elif block_name == "compress_space":
|
||||
block = make_conv_nd(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
kernel_size=3,
|
||||
stride=(1, 2, 2),
|
||||
causal=True,
|
||||
)
|
||||
elif block_name == "compress_all":
|
||||
block = make_conv_nd(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
kernel_size=3,
|
||||
stride=(2, 2, 2),
|
||||
causal=True,
|
||||
)
|
||||
elif block_name == "compress_all_x_y":
|
||||
output_channel = block_params.get("multiplier", 2) * output_channel
|
||||
block = make_conv_nd(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
kernel_size=3,
|
||||
stride=(2, 2, 2),
|
||||
causal=True,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unknown block: {block_name}")
|
||||
|
||||
self.down_blocks.append(block)
|
||||
|
||||
# out
|
||||
if norm_layer == "group_norm":
|
||||
self.conv_norm_out = nn.GroupNorm(
|
||||
num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6
|
||||
)
|
||||
elif norm_layer == "pixel_norm":
|
||||
self.conv_norm_out = PixelNorm()
|
||||
elif norm_layer == "layer_norm":
|
||||
self.conv_norm_out = LayerNorm(output_channel, eps=1e-6)
|
||||
|
||||
self.conv_act = nn.SiLU()
|
||||
|
||||
conv_out_channels = out_channels
|
||||
if latent_log_var == "per_channel":
|
||||
conv_out_channels *= 2
|
||||
elif latent_log_var == "uniform":
|
||||
conv_out_channels += 1
|
||||
elif latent_log_var != "none":
|
||||
raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
|
||||
self.conv_out = make_conv_nd(
|
||||
dims, output_channel, conv_out_channels, 3, padding=1, causal=True
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
r"""The forward method of the `Encoder` class."""
|
||||
|
||||
sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
checkpoint_fn = (
|
||||
partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
|
||||
if self.gradient_checkpointing and self.training
|
||||
else lambda x: x
|
||||
)
|
||||
|
||||
for down_block in self.down_blocks:
|
||||
sample = checkpoint_fn(down_block)(sample)
|
||||
|
||||
sample = self.conv_norm_out(sample)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample)
|
||||
|
||||
if self.latent_log_var == "uniform":
|
||||
last_channel = sample[:, -1:, ...]
|
||||
num_dims = sample.dim()
|
||||
|
||||
if num_dims == 4:
|
||||
# For shape (B, C, H, W)
|
||||
repeated_last_channel = last_channel.repeat(
|
||||
1, sample.shape[1] - 2, 1, 1
|
||||
)
|
||||
sample = torch.cat([sample, repeated_last_channel], dim=1)
|
||||
elif num_dims == 5:
|
||||
# For shape (B, C, F, H, W)
|
||||
repeated_last_channel = last_channel.repeat(
|
||||
1, sample.shape[1] - 2, 1, 1, 1
|
||||
)
|
||||
sample = torch.cat([sample, repeated_last_channel], dim=1)
|
||||
else:
|
||||
raise ValueError(f"Invalid input shape: {sample.shape}")
|
||||
|
||||
return sample
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
r"""
|
||||
The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.
|
||||
|
||||
Args:
|
||||
dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3):
|
||||
The number of dimensions to use in convolutions.
|
||||
in_channels (`int`, *optional*, defaults to 3):
|
||||
The number of input channels.
|
||||
out_channels (`int`, *optional*, defaults to 3):
|
||||
The number of output channels.
|
||||
blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`):
|
||||
The blocks to use. Each block is a tuple of the block name and the number of layers.
|
||||
base_channels (`int`, *optional*, defaults to 128):
|
||||
The number of output channels for the first convolutional layer.
|
||||
norm_num_groups (`int`, *optional*, defaults to 32):
|
||||
The number of groups for normalization.
|
||||
patch_size (`int`, *optional*, defaults to 1):
|
||||
The patch size to use. Should be a power of 2.
|
||||
norm_layer (`str`, *optional*, defaults to `group_norm`):
|
||||
The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
|
||||
causal (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use causal convolutions or not.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dims,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
blocks=[("res_x", 1)],
|
||||
base_channels: int = 128,
|
||||
layers_per_block: int = 2,
|
||||
norm_num_groups: int = 32,
|
||||
patch_size: int = 1,
|
||||
norm_layer: str = "group_norm",
|
||||
causal: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.layers_per_block = layers_per_block
|
||||
out_channels = out_channels * patch_size**2
|
||||
self.causal = causal
|
||||
self.blocks_desc = blocks
|
||||
|
||||
# Compute output channel to be product of all channel-multiplier blocks
|
||||
output_channel = base_channels
|
||||
for block_name, block_params in list(reversed(blocks)):
|
||||
block_params = block_params if isinstance(block_params, dict) else {}
|
||||
if block_name == "res_x_y":
|
||||
output_channel = output_channel * block_params.get("multiplier", 2)
|
||||
|
||||
self.conv_in = make_conv_nd(
|
||||
dims,
|
||||
in_channels,
|
||||
output_channel,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
causal=True,
|
||||
)
|
||||
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
|
||||
for block_name, block_params in list(reversed(blocks)):
|
||||
input_channel = output_channel
|
||||
if isinstance(block_params, int):
|
||||
block_params = {"num_layers": block_params}
|
||||
|
||||
if block_name == "res_x":
|
||||
block = UNetMidBlock3D(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
num_layers=block_params["num_layers"],
|
||||
resnet_eps=1e-6,
|
||||
resnet_groups=norm_num_groups,
|
||||
norm_layer=norm_layer,
|
||||
)
|
||||
elif block_name == "res_x_y":
|
||||
output_channel = output_channel // block_params.get("multiplier", 2)
|
||||
block = ResnetBlock3D(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
eps=1e-6,
|
||||
groups=norm_num_groups,
|
||||
norm_layer=norm_layer,
|
||||
)
|
||||
elif block_name == "compress_time":
|
||||
block = DepthToSpaceUpsample(
|
||||
dims=dims, in_channels=input_channel, stride=(2, 1, 1)
|
||||
)
|
||||
elif block_name == "compress_space":
|
||||
block = DepthToSpaceUpsample(
|
||||
dims=dims, in_channels=input_channel, stride=(1, 2, 2)
|
||||
)
|
||||
elif block_name == "compress_all":
|
||||
block = DepthToSpaceUpsample(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
stride=(2, 2, 2),
|
||||
residual=block_params.get("residual", False),
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unknown layer: {block_name}")
|
||||
|
||||
self.up_blocks.append(block)
|
||||
|
||||
if norm_layer == "group_norm":
|
||||
self.conv_norm_out = nn.GroupNorm(
|
||||
num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6
|
||||
)
|
||||
elif norm_layer == "pixel_norm":
|
||||
self.conv_norm_out = PixelNorm()
|
||||
elif norm_layer == "layer_norm":
|
||||
self.conv_norm_out = LayerNorm(output_channel, eps=1e-6)
|
||||
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = make_conv_nd(
|
||||
dims, output_channel, out_channels, 3, padding=1, causal=True
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
|
||||
def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
r"""The forward method of the `Decoder` class."""
|
||||
# assert target_shape is not None, "target_shape must be provided"
|
||||
|
||||
sample = self.conv_in(sample, causal=self.causal)
|
||||
|
||||
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
|
||||
|
||||
checkpoint_fn = (
|
||||
partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
|
||||
if self.gradient_checkpointing and self.training
|
||||
else lambda x: x
|
||||
)
|
||||
|
||||
sample = sample.to(upscale_dtype)
|
||||
|
||||
for up_block in self.up_blocks:
|
||||
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
|
||||
|
||||
sample = self.conv_norm_out(sample)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample, causal=self.causal)
|
||||
|
||||
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||
|
||||
return sample
|
||||
|
||||
|
||||
class UNetMidBlock3D(nn.Module):
|
||||
"""
|
||||
A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks.
|
||||
|
||||
Args:
|
||||
in_channels (`int`): The number of input channels.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
|
||||
num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
|
||||
resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
|
||||
resnet_groups (`int`, *optional*, defaults to 32):
|
||||
The number of groups to use in the group normalization layers of the resnet blocks.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
|
||||
in_channels, height, width)`.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dims: Union[int, Tuple[int, int]],
|
||||
in_channels: int,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_groups: int = 32,
|
||||
norm_layer: str = "group_norm",
|
||||
):
|
||||
super().__init__()
|
||||
resnet_groups = (
|
||||
resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
||||
)
|
||||
|
||||
self.res_blocks = nn.ModuleList(
|
||||
[
|
||||
ResnetBlock3D(
|
||||
dims=dims,
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
norm_layer=norm_layer,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.FloatTensor, causal: bool = True
|
||||
) -> torch.FloatTensor:
|
||||
for resnet in self.res_blocks:
|
||||
hidden_states = resnet(hidden_states, causal=causal)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DepthToSpaceUpsample(nn.Module):
|
||||
def __init__(self, dims, in_channels, stride, residual=False):
|
||||
super().__init__()
|
||||
self.stride = stride
|
||||
self.out_channels = math.prod(stride) * in_channels
|
||||
self.conv = make_conv_nd(
|
||||
dims=dims,
|
||||
in_channels=in_channels,
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
causal=True,
|
||||
)
|
||||
self.residual = residual
|
||||
|
||||
def forward(self, x, causal: bool = True):
|
||||
if self.residual:
|
||||
# Reshape and duplicate the input to match the output shape
|
||||
x_in = rearrange(
|
||||
x,
|
||||
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
|
||||
p1=self.stride[0],
|
||||
p2=self.stride[1],
|
||||
p3=self.stride[2],
|
||||
)
|
||||
x_in = x_in.repeat(1, math.prod(self.stride), 1, 1, 1)
|
||||
if self.stride[0] == 2:
|
||||
x_in = x_in[:, :, 1:, :, :]
|
||||
x = self.conv(x, causal=causal)
|
||||
x = rearrange(
|
||||
x,
|
||||
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
|
||||
p1=self.stride[0],
|
||||
p2=self.stride[1],
|
||||
p3=self.stride[2],
|
||||
)
|
||||
if self.stride[0] == 2:
|
||||
x = x[:, :, 1:, :, :]
|
||||
if self.residual:
|
||||
x = x + x_in
|
||||
return x
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(self, dim, eps, elementwise_affine=True) -> None:
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
|
||||
|
||||
def forward(self, x):
|
||||
x = rearrange(x, "b c d h w -> b d h w c")
|
||||
x = self.norm(x)
|
||||
x = rearrange(x, "b d h w c -> b c d h w")
|
||||
return x
|
||||
|
||||
|
||||
class ResnetBlock3D(nn.Module):
|
||||
r"""
|
||||
A Resnet block.
|
||||
|
||||
Parameters:
|
||||
in_channels (`int`): The number of channels in the input.
|
||||
out_channels (`int`, *optional*, default to be `None`):
|
||||
The number of output channels for the first conv layer. If None, same as `in_channels`.
|
||||
dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
|
||||
groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
|
||||
eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dims: Union[int, Tuple[int, int]],
|
||||
in_channels: int,
|
||||
out_channels: Optional[int] = None,
|
||||
dropout: float = 0.0,
|
||||
groups: int = 32,
|
||||
eps: float = 1e-6,
|
||||
norm_layer: str = "group_norm",
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
|
||||
if norm_layer == "group_norm":
|
||||
self.norm1 = nn.GroupNorm(
|
||||
num_groups=groups, num_channels=in_channels, eps=eps, affine=True
|
||||
)
|
||||
elif norm_layer == "pixel_norm":
|
||||
self.norm1 = PixelNorm()
|
||||
elif norm_layer == "layer_norm":
|
||||
self.norm1 = LayerNorm(in_channels, eps=eps, elementwise_affine=True)
|
||||
|
||||
self.non_linearity = nn.SiLU()
|
||||
|
||||
self.conv1 = make_conv_nd(
|
||||
dims,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
causal=True,
|
||||
)
|
||||
|
||||
if norm_layer == "group_norm":
|
||||
self.norm2 = nn.GroupNorm(
|
||||
num_groups=groups, num_channels=out_channels, eps=eps, affine=True
|
||||
)
|
||||
elif norm_layer == "pixel_norm":
|
||||
self.norm2 = PixelNorm()
|
||||
elif norm_layer == "layer_norm":
|
||||
self.norm2 = LayerNorm(out_channels, eps=eps, elementwise_affine=True)
|
||||
|
||||
self.dropout = torch.nn.Dropout(dropout)
|
||||
|
||||
self.conv2 = make_conv_nd(
|
||||
dims,
|
||||
out_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
causal=True,
|
||||
)
|
||||
|
||||
self.conv_shortcut = (
|
||||
make_linear_nd(
|
||||
dims=dims, in_channels=in_channels, out_channels=out_channels
|
||||
)
|
||||
if in_channels != out_channels
|
||||
else nn.Identity()
|
||||
)
|
||||
|
||||
self.norm3 = (
|
||||
LayerNorm(in_channels, eps=eps, elementwise_affine=True)
|
||||
if in_channels != out_channels
|
||||
else nn.Identity()
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_tensor: torch.FloatTensor,
|
||||
causal: bool = True,
|
||||
) -> torch.FloatTensor:
|
||||
hidden_states = input_tensor
|
||||
|
||||
hidden_states = self.norm1(hidden_states)
|
||||
|
||||
hidden_states = self.non_linearity(hidden_states)
|
||||
|
||||
hidden_states = self.conv1(hidden_states, causal=causal)
|
||||
|
||||
hidden_states = self.norm2(hidden_states)
|
||||
|
||||
hidden_states = self.non_linearity(hidden_states)
|
||||
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
|
||||
hidden_states = self.conv2(hidden_states, causal=causal)
|
||||
|
||||
input_tensor = self.norm3(input_tensor)
|
||||
|
||||
input_tensor = self.conv_shortcut(input_tensor)
|
||||
|
||||
output_tensor = input_tensor + hidden_states
|
||||
|
||||
return output_tensor
|
||||
|
||||
|
||||
def patchify(x, patch_size_hw, patch_size_t=1):
|
||||
if patch_size_hw == 1 and patch_size_t == 1:
|
||||
return x
|
||||
if x.dim() == 4:
|
||||
x = rearrange(
|
||||
x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw
|
||||
)
|
||||
elif x.dim() == 5:
|
||||
x = rearrange(
|
||||
x,
|
||||
"b c (f p) (h q) (w r) -> b (c p r q) f h w",
|
||||
p=patch_size_t,
|
||||
q=patch_size_hw,
|
||||
r=patch_size_hw,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid input shape: {x.shape}")
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def unpatchify(x, patch_size_hw, patch_size_t=1):
|
||||
if patch_size_hw == 1 and patch_size_t == 1:
|
||||
return x
|
||||
|
||||
if x.dim() == 4:
|
||||
x = rearrange(
|
||||
x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw
|
||||
)
|
||||
elif x.dim() == 5:
|
||||
x = rearrange(
|
||||
x,
|
||||
"b (c p r q) f h w -> b c (f p) (h q) (w r)",
|
||||
p=patch_size_t,
|
||||
q=patch_size_hw,
|
||||
r=patch_size_hw,
|
||||
)
|
||||
|
||||
return x
|
||||
|
||||
class processor(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.register_buffer("std-of-means", torch.empty(128))
|
||||
self.register_buffer("mean-of-means", torch.empty(128))
|
||||
self.register_buffer("mean-of-stds", torch.empty(128))
|
||||
self.register_buffer("mean-of-stds_over_std-of-means", torch.empty(128))
|
||||
self.register_buffer("channel", torch.empty(128))
|
||||
|
||||
def un_normalize(self, x):
|
||||
return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)
|
||||
|
||||
def normalize(self, x):
|
||||
return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)
|
||||
|
||||
class VideoVAE(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
config = {
|
||||
"_class_name": "CausalVideoAutoencoder",
|
||||
"dims": 3,
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 128,
|
||||
"blocks": [
|
||||
["res_x", 4],
|
||||
["compress_all", 1],
|
||||
["res_x_y", 1],
|
||||
["res_x", 3],
|
||||
["compress_all", 1],
|
||||
["res_x_y", 1],
|
||||
["res_x", 3],
|
||||
["compress_all", 1],
|
||||
["res_x", 3],
|
||||
["res_x", 4],
|
||||
],
|
||||
"scaling_factor": 1.0,
|
||||
"norm_layer": "pixel_norm",
|
||||
"patch_size": 4,
|
||||
"latent_log_var": "uniform",
|
||||
"use_quant_conv": False,
|
||||
"causal_decoder": False,
|
||||
}
|
||||
|
||||
double_z = config.get("double_z", True)
|
||||
latent_log_var = config.get(
|
||||
"latent_log_var", "per_channel" if double_z else "none"
|
||||
)
|
||||
|
||||
self.encoder = Encoder(
|
||||
dims=config["dims"],
|
||||
in_channels=config.get("in_channels", 3),
|
||||
out_channels=config["latent_channels"],
|
||||
blocks=config.get("encoder_blocks", config.get("blocks")),
|
||||
patch_size=config.get("patch_size", 1),
|
||||
latent_log_var=latent_log_var,
|
||||
norm_layer=config.get("norm_layer", "group_norm"),
|
||||
)
|
||||
|
||||
self.decoder = Decoder(
|
||||
dims=config["dims"],
|
||||
in_channels=config["latent_channels"],
|
||||
out_channels=config.get("out_channels", 3),
|
||||
blocks=config.get("decoder_blocks", config.get("blocks")),
|
||||
patch_size=config.get("patch_size", 1),
|
||||
norm_layer=config.get("norm_layer", "group_norm"),
|
||||
causal=config.get("causal_decoder", False),
|
||||
)
|
||||
|
||||
self.per_channel_statistics = processor()
|
||||
|
||||
def encode(self, x):
|
||||
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
|
||||
return self.per_channel_statistics.normalize(means)
|
||||
|
||||
def decode(self, x):
|
||||
return self.decoder(self.per_channel_statistics.un_normalize(x))
|
||||
|
||||
83
comfy/ldm/lightricks/vae/conv_nd_factory.py
Normal file
83
comfy/ldm/lightricks/vae/conv_nd_factory.py
Normal file
@@ -0,0 +1,83 @@
|
||||
from typing import Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from .dual_conv3d import DualConv3d
|
||||
from .causal_conv3d import CausalConv3d
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
def make_conv_nd(
|
||||
dims: Union[int, Tuple[int, int]],
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=True,
|
||||
causal=False,
|
||||
):
|
||||
if dims == 2:
|
||||
return ops.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=bias,
|
||||
)
|
||||
elif dims == 3:
|
||||
if causal:
|
||||
return CausalConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=bias,
|
||||
)
|
||||
return ops.Conv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=bias,
|
||||
)
|
||||
elif dims == (2, 1):
|
||||
return DualConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
bias=bias,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
||||
|
||||
def make_linear_nd(
|
||||
dims: int,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
bias=True,
|
||||
):
|
||||
if dims == 2:
|
||||
return ops.Conv2d(
|
||||
in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
|
||||
)
|
||||
elif dims == 3 or dims == (2, 1):
|
||||
return ops.Conv3d(
|
||||
in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
195
comfy/ldm/lightricks/vae/dual_conv3d.py
Normal file
195
comfy/ldm/lightricks/vae/dual_conv3d.py
Normal file
@@ -0,0 +1,195 @@
|
||||
import math
|
||||
from typing import Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
class DualConv3d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride: Union[int, Tuple[int, int, int]] = 1,
|
||||
padding: Union[int, Tuple[int, int, int]] = 0,
|
||||
dilation: Union[int, Tuple[int, int, int]] = 1,
|
||||
groups=1,
|
||||
bias=True,
|
||||
):
|
||||
super(DualConv3d, self).__init__()
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
# Ensure kernel_size, stride, padding, and dilation are tuples of length 3
|
||||
if isinstance(kernel_size, int):
|
||||
kernel_size = (kernel_size, kernel_size, kernel_size)
|
||||
if kernel_size == (1, 1, 1):
|
||||
raise ValueError(
|
||||
"kernel_size must be greater than 1. Use make_linear_nd instead."
|
||||
)
|
||||
if isinstance(stride, int):
|
||||
stride = (stride, stride, stride)
|
||||
if isinstance(padding, int):
|
||||
padding = (padding, padding, padding)
|
||||
if isinstance(dilation, int):
|
||||
dilation = (dilation, dilation, dilation)
|
||||
|
||||
# Set parameters for convolutions
|
||||
self.groups = groups
|
||||
self.bias = bias
|
||||
|
||||
# Define the size of the channels after the first convolution
|
||||
intermediate_channels = (
|
||||
out_channels if in_channels < out_channels else in_channels
|
||||
)
|
||||
|
||||
# Define parameters for the first convolution
|
||||
self.weight1 = nn.Parameter(
|
||||
torch.Tensor(
|
||||
intermediate_channels,
|
||||
in_channels // groups,
|
||||
1,
|
||||
kernel_size[1],
|
||||
kernel_size[2],
|
||||
)
|
||||
)
|
||||
self.stride1 = (1, stride[1], stride[2])
|
||||
self.padding1 = (0, padding[1], padding[2])
|
||||
self.dilation1 = (1, dilation[1], dilation[2])
|
||||
if bias:
|
||||
self.bias1 = nn.Parameter(torch.Tensor(intermediate_channels))
|
||||
else:
|
||||
self.register_parameter("bias1", None)
|
||||
|
||||
# Define parameters for the second convolution
|
||||
self.weight2 = nn.Parameter(
|
||||
torch.Tensor(
|
||||
out_channels, intermediate_channels // groups, kernel_size[0], 1, 1
|
||||
)
|
||||
)
|
||||
self.stride2 = (stride[0], 1, 1)
|
||||
self.padding2 = (padding[0], 0, 0)
|
||||
self.dilation2 = (dilation[0], 1, 1)
|
||||
if bias:
|
||||
self.bias2 = nn.Parameter(torch.Tensor(out_channels))
|
||||
else:
|
||||
self.register_parameter("bias2", None)
|
||||
|
||||
# Initialize weights and biases
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5))
|
||||
nn.init.kaiming_uniform_(self.weight2, a=math.sqrt(5))
|
||||
if self.bias:
|
||||
fan_in1, _ = nn.init._calculate_fan_in_and_fan_out(self.weight1)
|
||||
bound1 = 1 / math.sqrt(fan_in1)
|
||||
nn.init.uniform_(self.bias1, -bound1, bound1)
|
||||
fan_in2, _ = nn.init._calculate_fan_in_and_fan_out(self.weight2)
|
||||
bound2 = 1 / math.sqrt(fan_in2)
|
||||
nn.init.uniform_(self.bias2, -bound2, bound2)
|
||||
|
||||
def forward(self, x, use_conv3d=False, skip_time_conv=False):
|
||||
if use_conv3d:
|
||||
return self.forward_with_3d(x=x, skip_time_conv=skip_time_conv)
|
||||
else:
|
||||
return self.forward_with_2d(x=x, skip_time_conv=skip_time_conv)
|
||||
|
||||
def forward_with_3d(self, x, skip_time_conv):
|
||||
# First convolution
|
||||
x = F.conv3d(
|
||||
x,
|
||||
self.weight1,
|
||||
self.bias1,
|
||||
self.stride1,
|
||||
self.padding1,
|
||||
self.dilation1,
|
||||
self.groups,
|
||||
)
|
||||
|
||||
if skip_time_conv:
|
||||
return x
|
||||
|
||||
# Second convolution
|
||||
x = F.conv3d(
|
||||
x,
|
||||
self.weight2,
|
||||
self.bias2,
|
||||
self.stride2,
|
||||
self.padding2,
|
||||
self.dilation2,
|
||||
self.groups,
|
||||
)
|
||||
|
||||
return x
|
||||
|
||||
def forward_with_2d(self, x, skip_time_conv):
|
||||
b, c, d, h, w = x.shape
|
||||
|
||||
# First 2D convolution
|
||||
x = rearrange(x, "b c d h w -> (b d) c h w")
|
||||
# Squeeze the depth dimension out of weight1 since it's 1
|
||||
weight1 = self.weight1.squeeze(2)
|
||||
# Select stride, padding, and dilation for the 2D convolution
|
||||
stride1 = (self.stride1[1], self.stride1[2])
|
||||
padding1 = (self.padding1[1], self.padding1[2])
|
||||
dilation1 = (self.dilation1[1], self.dilation1[2])
|
||||
x = F.conv2d(x, weight1, self.bias1, stride1, padding1, dilation1, self.groups)
|
||||
|
||||
_, _, h, w = x.shape
|
||||
|
||||
if skip_time_conv:
|
||||
x = rearrange(x, "(b d) c h w -> b c d h w", b=b)
|
||||
return x
|
||||
|
||||
# Second convolution which is essentially treated as a 1D convolution across the 'd' dimension
|
||||
x = rearrange(x, "(b d) c h w -> (b h w) c d", b=b)
|
||||
|
||||
# Reshape weight2 to match the expected dimensions for conv1d
|
||||
weight2 = self.weight2.squeeze(-1).squeeze(-1)
|
||||
# Use only the relevant dimension for stride, padding, and dilation for the 1D convolution
|
||||
stride2 = self.stride2[0]
|
||||
padding2 = self.padding2[0]
|
||||
dilation2 = self.dilation2[0]
|
||||
x = F.conv1d(x, weight2, self.bias2, stride2, padding2, dilation2, self.groups)
|
||||
x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w)
|
||||
|
||||
return x
|
||||
|
||||
@property
|
||||
def weight(self):
|
||||
return self.weight2
|
||||
|
||||
|
||||
def test_dual_conv3d_consistency():
|
||||
# Initialize parameters
|
||||
in_channels = 3
|
||||
out_channels = 5
|
||||
kernel_size = (3, 3, 3)
|
||||
stride = (2, 2, 2)
|
||||
padding = (1, 1, 1)
|
||||
|
||||
# Create an instance of the DualConv3d class
|
||||
dual_conv3d = DualConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
# Example input tensor
|
||||
test_input = torch.randn(1, 3, 10, 10, 10)
|
||||
|
||||
# Perform forward passes with both 3D and 2D settings
|
||||
output_conv3d = dual_conv3d(test_input, use_conv3d=True)
|
||||
output_2d = dual_conv3d(test_input, use_conv3d=False)
|
||||
|
||||
# Assert that the outputs from both methods are sufficiently close
|
||||
assert torch.allclose(
|
||||
output_conv3d, output_2d, atol=1e-6
|
||||
), "Outputs are not consistent between 3D and 2D convolutions."
|
||||
12
comfy/ldm/lightricks/vae/pixel_norm.py
Normal file
12
comfy/ldm/lightricks/vae/pixel_norm.py
Normal file
@@ -0,0 +1,12 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class PixelNorm(nn.Module):
|
||||
def __init__(self, dim=1, eps=1e-8):
|
||||
super(PixelNorm, self).__init__()
|
||||
self.dim = dim
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x):
|
||||
return x / torch.sqrt(torch.mean(x**2, dim=self.dim, keepdim=True) + self.eps)
|
||||
@@ -299,7 +299,10 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
||||
if len(mask.shape) == 2:
|
||||
s1 += mask[i:end]
|
||||
else:
|
||||
s1 += mask[:, i:end]
|
||||
if mask.shape[1] == 1:
|
||||
s1 += mask
|
||||
else:
|
||||
s1 += mask[:, i:end]
|
||||
|
||||
s2 = s1.softmax(dim=-1).to(v.dtype)
|
||||
del s1
|
||||
@@ -358,7 +361,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
|
||||
disabled_xformers = True
|
||||
|
||||
if disabled_xformers:
|
||||
return attention_pytorch(q, k, v, heads, mask)
|
||||
return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape)
|
||||
|
||||
if skip_reshape:
|
||||
q, k, v = map(
|
||||
@@ -372,10 +375,10 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
|
||||
)
|
||||
|
||||
if mask is not None:
|
||||
pad = 8 - q.shape[1] % 8
|
||||
mask_out = torch.empty([q.shape[0], q.shape[1], q.shape[1] + pad], dtype=q.dtype, device=q.device)
|
||||
mask_out[:, :, :mask.shape[-1]] = mask
|
||||
mask = mask_out[:, :, :mask.shape[-1]]
|
||||
pad = 8 - mask.shape[-1] % 8
|
||||
mask_out = torch.empty([q.shape[0], q.shape[2], q.shape[1], mask.shape[-1] + pad], dtype=q.dtype, device=q.device)
|
||||
mask_out[..., :mask.shape[-1]] = mask
|
||||
mask = mask_out[..., :mask.shape[-1]]
|
||||
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
|
||||
|
||||
@@ -393,6 +396,13 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
|
||||
|
||||
return out
|
||||
|
||||
if model_management.is_nvidia(): #pytorch 2.3 and up seem to have this issue.
|
||||
SDP_BATCH_LIMIT = 2**15
|
||||
else:
|
||||
#TODO: other GPUs ?
|
||||
SDP_BATCH_LIMIT = 2**31
|
||||
|
||||
|
||||
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
|
||||
if skip_reshape:
|
||||
b, _, _, dim_head = q.shape
|
||||
@@ -404,10 +414,15 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
||||
(q, k, v),
|
||||
)
|
||||
|
||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||
out = (
|
||||
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||
)
|
||||
if SDP_BATCH_LIMIT >= q.shape[0]:
|
||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||
out = (
|
||||
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||
)
|
||||
else:
|
||||
out = torch.empty((q.shape[0], q.shape[2], heads * dim_head), dtype=q.dtype, layout=q.layout, device=q.device)
|
||||
for i in range(0, q.shape[0], SDP_BATCH_LIMIT):
|
||||
out[i : i + SDP_BATCH_LIMIT] = torch.nn.functional.scaled_dot_product_attention(q[i : i + SDP_BATCH_LIMIT], k[i : i + SDP_BATCH_LIMIT], v[i : i + SDP_BATCH_LIMIT], attn_mask=mask, dropout_p=0.0, is_causal=False).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
|
||||
return out
|
||||
|
||||
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import logging
|
||||
import math
|
||||
from typing import Dict, Optional
|
||||
from typing import Dict, Optional, List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .. import attention
|
||||
from ..attention import optimized_attention
|
||||
from einops import rearrange, repeat
|
||||
from .util import timestep_embedding
|
||||
import comfy.ops
|
||||
@@ -97,7 +97,7 @@ class PatchEmbed(nn.Module):
|
||||
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
B, C, H, W = x.shape
|
||||
# B, C, H, W = x.shape
|
||||
# if self.img_size is not None:
|
||||
# if self.strict_img_size:
|
||||
# _assert(H == self.img_size[0], f"Input height ({H}) doesn't match model ({self.img_size[0]}).")
|
||||
@@ -266,8 +266,6 @@ def split_qkv(qkv, head_dim):
|
||||
qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, -1, head_dim).movedim(2, 0)
|
||||
return qkv[0], qkv[1], qkv[2]
|
||||
|
||||
def optimized_attention(qkv, num_heads):
|
||||
return attention.optimized_attention(qkv[0], qkv[1], qkv[2], num_heads)
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")
|
||||
@@ -326,9 +324,9 @@ class SelfAttention(nn.Module):
|
||||
return x
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
qkv = self.pre_attention(x)
|
||||
q, k, v = self.pre_attention(x)
|
||||
x = optimized_attention(
|
||||
qkv, num_heads=self.num_heads
|
||||
q, k, v, heads=self.num_heads
|
||||
)
|
||||
x = self.post_attention(x)
|
||||
return x
|
||||
@@ -355,29 +353,9 @@ class RMSNorm(torch.nn.Module):
|
||||
else:
|
||||
self.register_parameter("weight", None)
|
||||
|
||||
def _norm(self, x):
|
||||
"""
|
||||
Apply the RMSNorm normalization to the input tensor.
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor.
|
||||
Returns:
|
||||
torch.Tensor: The normalized tensor.
|
||||
"""
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward pass through the RMSNorm layer.
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor.
|
||||
Returns:
|
||||
torch.Tensor: The output tensor after applying RMSNorm.
|
||||
"""
|
||||
x = self._norm(x)
|
||||
if self.learnable_scale:
|
||||
return x * self.weight.to(device=x.device, dtype=x.dtype)
|
||||
else:
|
||||
return x
|
||||
return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)
|
||||
|
||||
|
||||
|
||||
class SwiGLUFeedForward(nn.Module):
|
||||
@@ -437,6 +415,7 @@ class DismantledBlock(nn.Module):
|
||||
scale_mod_only: bool = False,
|
||||
swiglu: bool = False,
|
||||
qk_norm: Optional[str] = None,
|
||||
x_block_self_attn: bool = False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
@@ -460,6 +439,24 @@ class DismantledBlock(nn.Module):
|
||||
device=device,
|
||||
operations=operations
|
||||
)
|
||||
if x_block_self_attn:
|
||||
assert not pre_only
|
||||
assert not scale_mod_only
|
||||
self.x_block_self_attn = True
|
||||
self.attn2 = SelfAttention(
|
||||
dim=hidden_size,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
attn_mode=attn_mode,
|
||||
pre_only=False,
|
||||
qk_norm=qk_norm,
|
||||
rmsnorm=rmsnorm,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
)
|
||||
else:
|
||||
self.x_block_self_attn = False
|
||||
if not pre_only:
|
||||
if not rmsnorm:
|
||||
self.norm2 = operations.LayerNorm(
|
||||
@@ -486,7 +483,11 @@ class DismantledBlock(nn.Module):
|
||||
multiple_of=256,
|
||||
)
|
||||
self.scale_mod_only = scale_mod_only
|
||||
if not scale_mod_only:
|
||||
if x_block_self_attn:
|
||||
assert not pre_only
|
||||
assert not scale_mod_only
|
||||
n_mods = 9
|
||||
elif not scale_mod_only:
|
||||
n_mods = 6 if not pre_only else 2
|
||||
else:
|
||||
n_mods = 4 if not pre_only else 1
|
||||
@@ -547,14 +548,64 @@ class DismantledBlock(nn.Module):
|
||||
)
|
||||
return x
|
||||
|
||||
def pre_attention_x(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
||||
assert self.x_block_self_attn
|
||||
(
|
||||
shift_msa,
|
||||
scale_msa,
|
||||
gate_msa,
|
||||
shift_mlp,
|
||||
scale_mlp,
|
||||
gate_mlp,
|
||||
shift_msa2,
|
||||
scale_msa2,
|
||||
gate_msa2,
|
||||
) = self.adaLN_modulation(c).chunk(9, dim=1)
|
||||
x_norm = self.norm1(x)
|
||||
qkv = self.attn.pre_attention(modulate(x_norm, shift_msa, scale_msa))
|
||||
qkv2 = self.attn2.pre_attention(modulate(x_norm, shift_msa2, scale_msa2))
|
||||
return qkv, qkv2, (
|
||||
x,
|
||||
gate_msa,
|
||||
shift_mlp,
|
||||
scale_mlp,
|
||||
gate_mlp,
|
||||
gate_msa2,
|
||||
)
|
||||
|
||||
def post_attention_x(self, attn, attn2, x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2):
|
||||
assert not self.pre_only
|
||||
attn1 = self.attn.post_attention(attn)
|
||||
attn2 = self.attn2.post_attention(attn2)
|
||||
out1 = gate_msa.unsqueeze(1) * attn1
|
||||
out2 = gate_msa2.unsqueeze(1) * attn2
|
||||
x = x + out1
|
||||
x = x + out2
|
||||
x = x + gate_mlp.unsqueeze(1) * self.mlp(
|
||||
modulate(self.norm2(x), shift_mlp, scale_mlp)
|
||||
)
|
||||
return x
|
||||
|
||||
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
||||
assert not self.pre_only
|
||||
qkv, intermediates = self.pre_attention(x, c)
|
||||
attn = optimized_attention(
|
||||
qkv,
|
||||
num_heads=self.attn.num_heads,
|
||||
)
|
||||
return self.post_attention(attn, *intermediates)
|
||||
if self.x_block_self_attn:
|
||||
qkv, qkv2, intermediates = self.pre_attention_x(x, c)
|
||||
attn, _ = optimized_attention(
|
||||
qkv[0], qkv[1], qkv[2],
|
||||
num_heads=self.attn.num_heads,
|
||||
)
|
||||
attn2, _ = optimized_attention(
|
||||
qkv2[0], qkv2[1], qkv2[2],
|
||||
num_heads=self.attn2.num_heads,
|
||||
)
|
||||
return self.post_attention_x(attn, attn2, *intermediates)
|
||||
else:
|
||||
qkv, intermediates = self.pre_attention(x, c)
|
||||
attn = optimized_attention(
|
||||
qkv[0], qkv[1], qkv[2],
|
||||
heads=self.attn.num_heads,
|
||||
)
|
||||
return self.post_attention(attn, *intermediates)
|
||||
|
||||
|
||||
def block_mixing(*args, use_checkpoint=True, **kwargs):
|
||||
@@ -569,7 +620,10 @@ def block_mixing(*args, use_checkpoint=True, **kwargs):
|
||||
def _block_mixing(context, x, context_block, x_block, c):
|
||||
context_qkv, context_intermediates = context_block.pre_attention(context, c)
|
||||
|
||||
x_qkv, x_intermediates = x_block.pre_attention(x, c)
|
||||
if x_block.x_block_self_attn:
|
||||
x_qkv, x_qkv2, x_intermediates = x_block.pre_attention_x(x, c)
|
||||
else:
|
||||
x_qkv, x_intermediates = x_block.pre_attention(x, c)
|
||||
|
||||
o = []
|
||||
for t in range(3):
|
||||
@@ -577,8 +631,8 @@ def _block_mixing(context, x, context_block, x_block, c):
|
||||
qkv = tuple(o)
|
||||
|
||||
attn = optimized_attention(
|
||||
qkv,
|
||||
num_heads=x_block.attn.num_heads,
|
||||
qkv[0], qkv[1], qkv[2],
|
||||
heads=x_block.attn.num_heads,
|
||||
)
|
||||
context_attn, x_attn = (
|
||||
attn[:, : context_qkv[0].shape[1]],
|
||||
@@ -590,7 +644,14 @@ def _block_mixing(context, x, context_block, x_block, c):
|
||||
|
||||
else:
|
||||
context = None
|
||||
x = x_block.post_attention(x_attn, *x_intermediates)
|
||||
if x_block.x_block_self_attn:
|
||||
attn2 = optimized_attention(
|
||||
x_qkv2[0], x_qkv2[1], x_qkv2[2],
|
||||
heads=x_block.attn2.num_heads,
|
||||
)
|
||||
x = x_block.post_attention_x(x_attn, attn2, *x_intermediates)
|
||||
else:
|
||||
x = x_block.post_attention(x_attn, *x_intermediates)
|
||||
return context, x
|
||||
|
||||
|
||||
@@ -605,8 +666,13 @@ class JointBlock(nn.Module):
|
||||
super().__init__()
|
||||
pre_only = kwargs.pop("pre_only")
|
||||
qk_norm = kwargs.pop("qk_norm", None)
|
||||
x_block_self_attn = kwargs.pop("x_block_self_attn", False)
|
||||
self.context_block = DismantledBlock(*args, pre_only=pre_only, qk_norm=qk_norm, **kwargs)
|
||||
self.x_block = DismantledBlock(*args, pre_only=False, qk_norm=qk_norm, **kwargs)
|
||||
self.x_block = DismantledBlock(*args,
|
||||
pre_only=False,
|
||||
qk_norm=qk_norm,
|
||||
x_block_self_attn=x_block_self_attn,
|
||||
**kwargs)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return block_mixing(
|
||||
@@ -662,7 +728,7 @@ class SelfAttentionContext(nn.Module):
|
||||
def forward(self, x):
|
||||
qkv = self.qkv(x)
|
||||
q, k, v = split_qkv(qkv, self.dim_head)
|
||||
x = optimized_attention((q.reshape(q.shape[0], q.shape[1], -1), k, v), self.heads)
|
||||
x = optimized_attention(q.reshape(q.shape[0], q.shape[1], -1), k, v, heads=self.heads)
|
||||
return self.proj(x)
|
||||
|
||||
class ContextProcessorBlock(nn.Module):
|
||||
@@ -721,9 +787,12 @@ class MMDiT(nn.Module):
|
||||
qk_norm: Optional[str] = None,
|
||||
qkv_bias: bool = True,
|
||||
context_processor_layers = None,
|
||||
x_block_self_attn: bool = False,
|
||||
x_block_self_attn_layers: Optional[List[int]] = [],
|
||||
context_size = 4096,
|
||||
num_blocks = None,
|
||||
final_layer = True,
|
||||
skip_blocks = False,
|
||||
dtype = None, #TODO
|
||||
device = None,
|
||||
operations = None,
|
||||
@@ -738,6 +807,7 @@ class MMDiT(nn.Module):
|
||||
self.pos_embed_scaling_factor = pos_embed_scaling_factor
|
||||
self.pos_embed_offset = pos_embed_offset
|
||||
self.pos_embed_max_size = pos_embed_max_size
|
||||
self.x_block_self_attn_layers = x_block_self_attn_layers
|
||||
|
||||
# hidden_size = default(hidden_size, 64 * depth)
|
||||
# num_heads = default(num_heads, hidden_size // 64)
|
||||
@@ -795,26 +865,28 @@ class MMDiT(nn.Module):
|
||||
self.pos_embed = None
|
||||
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.joint_blocks = nn.ModuleList(
|
||||
[
|
||||
JointBlock(
|
||||
self.hidden_size,
|
||||
num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
attn_mode=attn_mode,
|
||||
pre_only=(i == num_blocks - 1) and final_layer,
|
||||
rmsnorm=rmsnorm,
|
||||
scale_mod_only=scale_mod_only,
|
||||
swiglu=swiglu,
|
||||
qk_norm=qk_norm,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
)
|
||||
for i in range(num_blocks)
|
||||
]
|
||||
)
|
||||
if not skip_blocks:
|
||||
self.joint_blocks = nn.ModuleList(
|
||||
[
|
||||
JointBlock(
|
||||
self.hidden_size,
|
||||
num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
attn_mode=attn_mode,
|
||||
pre_only=(i == num_blocks - 1) and final_layer,
|
||||
rmsnorm=rmsnorm,
|
||||
scale_mod_only=scale_mod_only,
|
||||
swiglu=swiglu,
|
||||
qk_norm=qk_norm,
|
||||
x_block_self_attn=(i in self.x_block_self_attn_layers) or x_block_self_attn,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
for i in range(num_blocks)
|
||||
]
|
||||
)
|
||||
|
||||
if final_layer:
|
||||
self.final_layer = FinalLayer(self.hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||
@@ -877,7 +949,9 @@ class MMDiT(nn.Module):
|
||||
c_mod: torch.Tensor,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
control = None,
|
||||
transformer_options = {},
|
||||
) -> torch.Tensor:
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
if self.register_length > 0:
|
||||
context = torch.cat(
|
||||
(
|
||||
@@ -889,14 +963,25 @@ class MMDiT(nn.Module):
|
||||
|
||||
# context is B, L', D
|
||||
# x is B, L, D
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
blocks = len(self.joint_blocks)
|
||||
for i in range(blocks):
|
||||
context, x = self.joint_blocks[i](
|
||||
context,
|
||||
x,
|
||||
c=c_mod,
|
||||
use_checkpoint=self.use_checkpoint,
|
||||
)
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": c_mod}, {"original_block": block_wrap})
|
||||
context = out["txt"]
|
||||
x = out["img"]
|
||||
else:
|
||||
context, x = self.joint_blocks[i](
|
||||
context,
|
||||
x,
|
||||
c=c_mod,
|
||||
use_checkpoint=self.use_checkpoint,
|
||||
)
|
||||
if control is not None:
|
||||
control_o = control.get("output")
|
||||
if i < len(control_o):
|
||||
@@ -914,6 +999,7 @@ class MMDiT(nn.Module):
|
||||
y: Optional[torch.Tensor] = None,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
control = None,
|
||||
transformer_options = {},
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass of DiT.
|
||||
@@ -935,7 +1021,7 @@ class MMDiT(nn.Module):
|
||||
if context is not None:
|
||||
context = self.context_embedder(context)
|
||||
|
||||
x = self.forward_core_with_concat(x, c, context, control)
|
||||
x = self.forward_core_with_concat(x, c, context, control, transformer_options)
|
||||
|
||||
x = self.unpatchify(x, hw=hw) # (N, out_channels, H, W)
|
||||
return x[:,:,:hw[-2],:hw[-1]]
|
||||
@@ -949,7 +1035,8 @@ class OpenAISignatureMMDITWrapper(MMDiT):
|
||||
context: Optional[torch.Tensor] = None,
|
||||
y: Optional[torch.Tensor] = None,
|
||||
control = None,
|
||||
transformer_options = {},
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
return super().forward(x, timesteps, context=context, y=y, control=control)
|
||||
return super().forward(x, timesteps, context=context, y=y, control=control, transformer_options=transformer_options)
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ from .util import (
|
||||
)
|
||||
from ..attention import SpatialTransformer, SpatialVideoTransformer, default
|
||||
from comfy.ldm.util import exists
|
||||
import comfy.patcher_extension
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
@@ -47,6 +48,15 @@ def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, out
|
||||
elif isinstance(layer, Upsample):
|
||||
x = layer(x, output_shape=output_shape)
|
||||
else:
|
||||
if "patches" in transformer_options and "forward_timestep_embed_patch" in transformer_options["patches"]:
|
||||
found_patched = False
|
||||
for class_type, handler in transformer_options["patches"]["forward_timestep_embed_patch"]:
|
||||
if isinstance(layer, class_type):
|
||||
x = handler(layer, x, emb, context, transformer_options, output_shape, time_context, num_video_frames, image_only_indicator)
|
||||
found_patched = True
|
||||
break
|
||||
if found_patched:
|
||||
continue
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
@@ -819,6 +829,13 @@ class UNetModel(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timesteps, context, y, control, transformer_options, **kwargs)
|
||||
|
||||
def _forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
|
||||
"""
|
||||
Apply the model to an input batch.
|
||||
:param x: an [N x C x ...] Tensor of inputs.
|
||||
@@ -842,6 +859,11 @@ class UNetModel(nn.Module):
|
||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
|
||||
emb = self.time_embed(t_emb)
|
||||
|
||||
if "emb_patch" in transformer_patches:
|
||||
patch = transformer_patches["emb_patch"]
|
||||
for p in patch:
|
||||
emb = p(emb, self.model_channels, transformer_options)
|
||||
|
||||
if self.num_classes is not None:
|
||||
assert y.shape[0] == x.shape[0]
|
||||
emb = emb + self.label_emb(y)
|
||||
|
||||
@@ -234,6 +234,8 @@ def efficient_dot_product_attention(
|
||||
def get_mask_chunk(chunk_idx: int) -> Tensor:
|
||||
if mask is None:
|
||||
return None
|
||||
if mask.shape[1] == 1:
|
||||
return mask
|
||||
chunk = min(query_chunk_size, q_tokens)
|
||||
return mask[:,chunk_idx:chunk_idx + chunk]
|
||||
|
||||
|
||||
343
comfy/lora.py
343
comfy/lora.py
@@ -16,8 +16,12 @@
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
import comfy.model_base
|
||||
import logging
|
||||
import torch
|
||||
|
||||
LORA_CLIP_MAP = {
|
||||
"mlp.fc1": "mlp_fc1",
|
||||
@@ -29,7 +33,7 @@ LORA_CLIP_MAP = {
|
||||
}
|
||||
|
||||
|
||||
def load_lora(lora, to_load):
|
||||
def load_lora(lora, to_load, log_missing=True):
|
||||
patch_dict = {}
|
||||
loaded_keys = set()
|
||||
for x in to_load:
|
||||
@@ -45,10 +49,20 @@ def load_lora(lora, to_load):
|
||||
dora_scale = lora[dora_scale_name]
|
||||
loaded_keys.add(dora_scale_name)
|
||||
|
||||
reshape_name = "{}.reshape_weight".format(x)
|
||||
reshape = None
|
||||
if reshape_name in lora.keys():
|
||||
try:
|
||||
reshape = lora[reshape_name].tolist()
|
||||
loaded_keys.add(reshape_name)
|
||||
except:
|
||||
pass
|
||||
|
||||
regular_lora = "{}.lora_up.weight".format(x)
|
||||
diffusers_lora = "{}_lora.up.weight".format(x)
|
||||
diffusers2_lora = "{}.lora_B.weight".format(x)
|
||||
diffusers3_lora = "{}.lora.up.weight".format(x)
|
||||
mochi_lora = "{}.lora_B".format(x)
|
||||
transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
|
||||
A_name = None
|
||||
|
||||
@@ -68,6 +82,10 @@ def load_lora(lora, to_load):
|
||||
A_name = diffusers3_lora
|
||||
B_name = "{}.lora.down.weight".format(x)
|
||||
mid_name = None
|
||||
elif mochi_lora in lora.keys():
|
||||
A_name = mochi_lora
|
||||
B_name = "{}.lora_A".format(x)
|
||||
mid_name = None
|
||||
elif transformers_lora in lora.keys():
|
||||
A_name = transformers_lora
|
||||
B_name ="{}.lora_linear_layer.down.weight".format(x)
|
||||
@@ -78,7 +96,7 @@ def load_lora(lora, to_load):
|
||||
if mid_name is not None and mid_name in lora.keys():
|
||||
mid = lora[mid_name]
|
||||
loaded_keys.add(mid_name)
|
||||
patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale))
|
||||
patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale, reshape))
|
||||
loaded_keys.add(A_name)
|
||||
loaded_keys.add(B_name)
|
||||
|
||||
@@ -189,17 +207,28 @@ def load_lora(lora, to_load):
|
||||
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,))
|
||||
loaded_keys.add(diff_bias_name)
|
||||
|
||||
for x in lora.keys():
|
||||
if x not in loaded_keys:
|
||||
logging.warning("lora key not loaded: {}".format(x))
|
||||
set_weight_name = "{}.set_weight".format(x)
|
||||
set_weight = lora.get(set_weight_name, None)
|
||||
if set_weight is not None:
|
||||
patch_dict[to_load[x]] = ("set", (set_weight,))
|
||||
loaded_keys.add(set_weight_name)
|
||||
|
||||
if log_missing:
|
||||
for x in lora.keys():
|
||||
if x not in loaded_keys:
|
||||
logging.warning("lora key not loaded: {}".format(x))
|
||||
|
||||
return patch_dict
|
||||
|
||||
def model_lora_keys_clip(model, key_map={}):
|
||||
sdk = model.state_dict().keys()
|
||||
for k in sdk:
|
||||
if k.endswith(".weight"):
|
||||
key_map["text_encoders.{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
|
||||
|
||||
text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
|
||||
clip_l_present = False
|
||||
clip_g_present = False
|
||||
for b in range(32): #TODO: clean up
|
||||
for c in LORA_CLIP_MAP:
|
||||
k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
|
||||
@@ -223,6 +252,7 @@ def model_lora_keys_clip(model, key_map={}):
|
||||
|
||||
k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
|
||||
if k in sdk:
|
||||
clip_g_present = True
|
||||
if clip_l_present:
|
||||
lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
|
||||
key_map[lora_key] = k
|
||||
@@ -238,10 +268,18 @@ def model_lora_keys_clip(model, key_map={}):
|
||||
|
||||
for k in sdk:
|
||||
if k.endswith(".weight"):
|
||||
if k.startswith("t5xxl.transformer."):#OneTrainer SD3 lora
|
||||
if k.startswith("t5xxl.transformer."):#OneTrainer SD3 and Flux lora
|
||||
l_key = k[len("t5xxl.transformer."):-len(".weight")]
|
||||
lora_key = "lora_te3_{}".format(l_key.replace(".", "_"))
|
||||
key_map[lora_key] = k
|
||||
t5_index = 1
|
||||
if clip_g_present:
|
||||
t5_index += 1
|
||||
if clip_l_present:
|
||||
t5_index += 1
|
||||
if t5_index == 2:
|
||||
key_map["lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))] = k #OneTrainer Flux
|
||||
t5_index += 1
|
||||
|
||||
key_map["lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))] = k
|
||||
elif k.startswith("hydit_clip.transformer.bert."): #HunyuanDiT Lora
|
||||
l_key = k[len("hydit_clip.transformer.bert."):-len(".weight")]
|
||||
lora_key = "lora_te1_{}".format(l_key.replace(".", "_"))
|
||||
@@ -265,11 +303,14 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
sdk = sd.keys()
|
||||
|
||||
for k in sdk:
|
||||
if k.startswith("diffusion_model.") and k.endswith(".weight"):
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
|
||||
key_map["lora_unet_{}".format(key_lora)] = k
|
||||
key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config
|
||||
key_map["{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
|
||||
if k.startswith("diffusion_model."):
|
||||
if k.endswith(".weight"):
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
|
||||
key_map["lora_unet_{}".format(key_lora)] = k
|
||||
key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config
|
||||
key_map["{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
|
||||
else:
|
||||
key_map["{}".format(k)] = k #generic lora format for not .weight without any weird key names
|
||||
|
||||
diffusers_keys = comfy.utils.unet_to_diffusers(model.model_config.unet_config)
|
||||
for k in diffusers_keys:
|
||||
@@ -277,6 +318,7 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
unet_key = "diffusion_model.{}".format(diffusers_keys[k])
|
||||
key_lora = k[:-len(".weight")].replace(".", "_")
|
||||
key_map["lora_unet_{}".format(key_lora)] = unet_key
|
||||
key_map["lycoris_{}".format(key_lora)] = unet_key #simpletuner lycoris format
|
||||
|
||||
diffusers_lora_prefix = ["", "unet."]
|
||||
for p in diffusers_lora_prefix:
|
||||
@@ -299,6 +341,10 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
key_lora = "lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_")) #OneTrainer lora
|
||||
key_map[key_lora] = to
|
||||
|
||||
key_lora = "lycoris_{}".format(k[:-len(".weight")].replace(".", "_")) #simpletuner lycoris format
|
||||
key_map[key_lora] = to
|
||||
|
||||
|
||||
if isinstance(model, comfy.model_base.AuraFlow): #Diffusers lora AuraFlow
|
||||
diffusers_keys = comfy.utils.auraflow_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
|
||||
for k in diffusers_keys:
|
||||
@@ -318,7 +364,274 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
for k in diffusers_keys:
|
||||
if k.endswith(".weight"):
|
||||
to = diffusers_keys[k]
|
||||
key_lora = "transformer.{}".format(k[:-len(".weight")]) #simpletrainer and probably regular diffusers flux lora format
|
||||
key_map[key_lora] = to
|
||||
key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format
|
||||
key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
|
||||
key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer
|
||||
|
||||
if isinstance(model, comfy.model_base.GenmoMochi):
|
||||
for k in sdk:
|
||||
if k.startswith("diffusion_model.") and k.endswith(".weight"): #Official Mochi lora format
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")]
|
||||
key_map["{}".format(key_lora)] = k
|
||||
|
||||
return key_map
|
||||
|
||||
|
||||
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function):
|
||||
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
|
||||
lora_diff *= alpha
|
||||
weight_calc = weight + function(lora_diff).type(weight.dtype)
|
||||
weight_norm = (
|
||||
weight_calc.transpose(0, 1)
|
||||
.reshape(weight_calc.shape[1], -1)
|
||||
.norm(dim=1, keepdim=True)
|
||||
.reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
|
||||
.transpose(0, 1)
|
||||
)
|
||||
|
||||
weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
|
||||
if strength != 1.0:
|
||||
weight_calc -= weight
|
||||
weight += strength * (weight_calc)
|
||||
else:
|
||||
weight[:] = weight_calc
|
||||
return weight
|
||||
|
||||
def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:
|
||||
"""
|
||||
Pad a tensor to a new shape with zeros.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): The original tensor to be padded.
|
||||
new_shape (List[int]): The desired shape of the padded tensor.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: A new tensor padded with zeros to the specified shape.
|
||||
|
||||
Note:
|
||||
If the new shape is smaller than the original tensor in any dimension,
|
||||
the original tensor will be truncated in that dimension.
|
||||
"""
|
||||
if any([new_shape[i] < tensor.shape[i] for i in range(len(new_shape))]):
|
||||
raise ValueError("The new shape must be larger than the original tensor in all dimensions")
|
||||
|
||||
if len(new_shape) != len(tensor.shape):
|
||||
raise ValueError("The new shape must have the same number of dimensions as the original tensor")
|
||||
|
||||
# Create a new tensor filled with zeros
|
||||
padded_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device)
|
||||
|
||||
# Create slicing tuples for both tensors
|
||||
orig_slices = tuple(slice(0, dim) for dim in tensor.shape)
|
||||
new_slices = tuple(slice(0, dim) for dim in tensor.shape)
|
||||
|
||||
# Copy the original tensor into the new tensor
|
||||
padded_tensor[new_slices] = tensor[orig_slices]
|
||||
|
||||
return padded_tensor
|
||||
|
||||
def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, original_weights=None):
|
||||
for p in patches:
|
||||
strength = p[0]
|
||||
v = p[1]
|
||||
strength_model = p[2]
|
||||
offset = p[3]
|
||||
function = p[4]
|
||||
if function is None:
|
||||
function = lambda a: a
|
||||
|
||||
old_weight = None
|
||||
if offset is not None:
|
||||
old_weight = weight
|
||||
weight = weight.narrow(offset[0], offset[1], offset[2])
|
||||
|
||||
if strength_model != 1.0:
|
||||
weight *= strength_model
|
||||
|
||||
if isinstance(v, list):
|
||||
v = (calculate_weight(v[1:], v[0][1](comfy.model_management.cast_to_device(v[0][0], weight.device, intermediate_dtype, copy=True), inplace=True), key, intermediate_dtype=intermediate_dtype), )
|
||||
|
||||
if len(v) == 1:
|
||||
patch_type = "diff"
|
||||
elif len(v) == 2:
|
||||
patch_type = v[0]
|
||||
v = v[1]
|
||||
|
||||
if patch_type == "diff":
|
||||
diff: torch.Tensor = v[0]
|
||||
# An extra flag to pad the weight if the diff's shape is larger than the weight
|
||||
do_pad_weight = len(v) > 1 and v[1]['pad_weight']
|
||||
if do_pad_weight and diff.shape != weight.shape:
|
||||
logging.info("Pad weight {} from {} to shape: {}".format(key, weight.shape, diff.shape))
|
||||
weight = pad_tensor_to_shape(weight, diff.shape)
|
||||
|
||||
if strength != 0.0:
|
||||
if diff.shape != weight.shape:
|
||||
logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, diff.shape, weight.shape))
|
||||
else:
|
||||
weight += function(strength * comfy.model_management.cast_to_device(diff, weight.device, weight.dtype))
|
||||
elif patch_type == "set":
|
||||
weight.copy_(v[0])
|
||||
elif patch_type == "model_as_lora":
|
||||
target_weight: torch.Tensor = v[0]
|
||||
diff_weight = comfy.model_management.cast_to_device(target_weight, weight.device, intermediate_dtype) - \
|
||||
comfy.model_management.cast_to_device(original_weights[key][0][0], weight.device, intermediate_dtype)
|
||||
weight += function(strength * comfy.model_management.cast_to_device(diff_weight, weight.device, weight.dtype))
|
||||
elif patch_type == "lora": #lora/locon
|
||||
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
|
||||
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype)
|
||||
dora_scale = v[4]
|
||||
reshape = v[5]
|
||||
|
||||
if reshape is not None:
|
||||
weight = pad_tensor_to_shape(weight, reshape)
|
||||
|
||||
if v[2] is not None:
|
||||
alpha = v[2] / mat2.shape[0]
|
||||
else:
|
||||
alpha = 1.0
|
||||
|
||||
if v[3] is not None:
|
||||
#locon mid weights, hopefully the math is fine because I didn't properly test it
|
||||
mat3 = comfy.model_management.cast_to_device(v[3], weight.device, intermediate_dtype)
|
||||
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
|
||||
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
|
||||
try:
|
||||
lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
|
||||
if dora_scale is not None:
|
||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||
else:
|
||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||
except Exception as e:
|
||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||
elif patch_type == "lokr":
|
||||
w1 = v[0]
|
||||
w2 = v[1]
|
||||
w1_a = v[3]
|
||||
w1_b = v[4]
|
||||
w2_a = v[5]
|
||||
w2_b = v[6]
|
||||
t2 = v[7]
|
||||
dora_scale = v[8]
|
||||
dim = None
|
||||
|
||||
if w1 is None:
|
||||
dim = w1_b.shape[0]
|
||||
w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype),
|
||||
comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype))
|
||||
else:
|
||||
w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype)
|
||||
|
||||
if w2 is None:
|
||||
dim = w2_b.shape[0]
|
||||
if t2 is None:
|
||||
w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype),
|
||||
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype))
|
||||
else:
|
||||
w2 = torch.einsum('i j k l, j r, i p -> p r k l',
|
||||
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
|
||||
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype),
|
||||
comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype))
|
||||
else:
|
||||
w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype)
|
||||
|
||||
if len(w2.shape) == 4:
|
||||
w1 = w1.unsqueeze(2).unsqueeze(2)
|
||||
if v[2] is not None and dim is not None:
|
||||
alpha = v[2] / dim
|
||||
else:
|
||||
alpha = 1.0
|
||||
|
||||
try:
|
||||
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
|
||||
if dora_scale is not None:
|
||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||
else:
|
||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||
except Exception as e:
|
||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||
elif patch_type == "loha":
|
||||
w1a = v[0]
|
||||
w1b = v[1]
|
||||
if v[2] is not None:
|
||||
alpha = v[2] / w1b.shape[0]
|
||||
else:
|
||||
alpha = 1.0
|
||||
|
||||
w2a = v[3]
|
||||
w2b = v[4]
|
||||
dora_scale = v[7]
|
||||
if v[5] is not None: #cp decomposition
|
||||
t1 = v[5]
|
||||
t2 = v[6]
|
||||
m1 = torch.einsum('i j k l, j r, i p -> p r k l',
|
||||
comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype),
|
||||
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype),
|
||||
comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype))
|
||||
|
||||
m2 = torch.einsum('i j k l, j r, i p -> p r k l',
|
||||
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
|
||||
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype),
|
||||
comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype))
|
||||
else:
|
||||
m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype),
|
||||
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype))
|
||||
m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype),
|
||||
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype))
|
||||
|
||||
try:
|
||||
lora_diff = (m1 * m2).reshape(weight.shape)
|
||||
if dora_scale is not None:
|
||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||
else:
|
||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||
except Exception as e:
|
||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||
elif patch_type == "glora":
|
||||
dora_scale = v[5]
|
||||
|
||||
old_glora = False
|
||||
if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]:
|
||||
rank = v[0].shape[0]
|
||||
old_glora = True
|
||||
|
||||
if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]:
|
||||
if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]:
|
||||
pass
|
||||
else:
|
||||
old_glora = False
|
||||
rank = v[1].shape[0]
|
||||
|
||||
a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||
|
||||
if v[4] is not None:
|
||||
alpha = v[4] / rank
|
||||
else:
|
||||
alpha = 1.0
|
||||
|
||||
try:
|
||||
if old_glora:
|
||||
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora
|
||||
else:
|
||||
if weight.dim() > 2:
|
||||
lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
|
||||
else:
|
||||
lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
|
||||
lora_diff += torch.mm(b1, b2).reshape(weight.shape)
|
||||
|
||||
if dora_scale is not None:
|
||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||
else:
|
||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||
except Exception as e:
|
||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||
else:
|
||||
logging.warning("patch type not recognized {} {}".format(patch_type, key))
|
||||
|
||||
if old_weight is not None:
|
||||
weight = old_weight
|
||||
|
||||
return weight
|
||||
|
||||
17
comfy/lora_convert.py
Normal file
17
comfy/lora_convert.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import torch
|
||||
|
||||
|
||||
def convert_lora_bfl_control(sd): #BFL loras for Flux
|
||||
sd_out = {}
|
||||
for k in sd:
|
||||
k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.scale.set_weight"))
|
||||
sd_out[k_to] = sd[k]
|
||||
|
||||
sd_out["diffusion_model.img_in.reshape_weight"] = torch.tensor([sd["img_in.lora_B.weight"].shape[0], sd["img_in.lora_A.weight"].shape[1]])
|
||||
return sd_out
|
||||
|
||||
|
||||
def convert_lora(sd):
|
||||
if "img_in.lora_A.weight" in sd and "single_blocks.0.norm.key_norm.scale" in sd:
|
||||
return convert_lora_bfl_control(sd)
|
||||
return sd
|
||||
@@ -1,3 +1,21 @@
|
||||
"""
|
||||
This file is part of ComfyUI.
|
||||
Copyright (C) 2024 Comfy
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import logging
|
||||
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
|
||||
@@ -6,19 +24,25 @@ from comfy.ldm.cascade.stage_b import StageB
|
||||
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
|
||||
from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
|
||||
import comfy.ldm.genmo.joint_model.asymm_models_joint
|
||||
import comfy.ldm.aura.mmdit
|
||||
import comfy.ldm.hydit.models
|
||||
import comfy.ldm.audio.dit
|
||||
import comfy.ldm.audio.embedders
|
||||
import comfy.ldm.flux.model
|
||||
import comfy.ldm.lightricks.model
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
import comfy.conds
|
||||
import comfy.ops
|
||||
from enum import Enum
|
||||
from . import utils
|
||||
import comfy.latent_formats
|
||||
import math
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
|
||||
class ModelType(Enum):
|
||||
EPS = 1
|
||||
@@ -75,12 +99,14 @@ class BaseModel(torch.nn.Module):
|
||||
self.model_config = model_config
|
||||
self.manual_cast_dtype = model_config.manual_cast_dtype
|
||||
self.device = device
|
||||
self.current_patcher: 'ModelPatcher' = None
|
||||
|
||||
if not unet_config.get("disable_unet_model_creation", False):
|
||||
if self.manual_cast_dtype is not None:
|
||||
operations = comfy.ops.manual_cast
|
||||
if model_config.custom_operations is None:
|
||||
fp8 = model_config.optimizations.get("fp8", model_config.scaled_fp8 is not None)
|
||||
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
|
||||
else:
|
||||
operations = comfy.ops.disable_weight_init
|
||||
operations = model_config.custom_operations
|
||||
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
|
||||
if comfy.model_management.force_channels_last():
|
||||
self.diffusion_model.to(memory_format=torch.channels_last)
|
||||
@@ -99,6 +125,13 @@ class BaseModel(torch.nn.Module):
|
||||
self.memory_usage_factor = model_config.memory_usage_factor
|
||||
|
||||
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._apply_model,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.APPLY_MODEL, transformer_options)
|
||||
).execute(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)
|
||||
|
||||
def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
|
||||
sigma = t
|
||||
xc = self.model_sampling.calculate_input(sigma, x)
|
||||
if c_concat is not None:
|
||||
@@ -133,8 +166,7 @@ class BaseModel(torch.nn.Module):
|
||||
def encode_adm(self, **kwargs):
|
||||
return None
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
def concat_cond(self, **kwargs):
|
||||
if len(self.concat_keys) > 0:
|
||||
cond_concat = []
|
||||
denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
|
||||
@@ -173,7 +205,14 @@ class BaseModel(torch.nn.Module):
|
||||
elif ck == "masked_image":
|
||||
cond_concat.append(self.blank_inpaint_image_like(noise))
|
||||
data = torch.cat(cond_concat, dim=1)
|
||||
out['c_concat'] = comfy.conds.CONDNoiseShape(data)
|
||||
return data
|
||||
return None
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
concat_cond = self.concat_cond(**kwargs)
|
||||
if concat_cond is not None:
|
||||
out['c_concat'] = comfy.conds.CONDNoiseShape(concat_cond)
|
||||
|
||||
adm = self.encode_adm(**kwargs)
|
||||
if adm is not None:
|
||||
@@ -226,6 +265,10 @@ class BaseModel(torch.nn.Module):
|
||||
extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))
|
||||
|
||||
unet_state_dict = self.diffusion_model.state_dict()
|
||||
|
||||
if self.model_config.scaled_fp8 is not None:
|
||||
unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8)
|
||||
|
||||
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
|
||||
|
||||
if self.model_type == ModelType.V_PREDICTION:
|
||||
@@ -499,9 +542,7 @@ class SD_X4Upscaler(BaseModel):
|
||||
return out
|
||||
|
||||
class IP2P:
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
|
||||
def concat_cond(self, **kwargs):
|
||||
image = kwargs.get("concat_latent_image", None)
|
||||
noise = kwargs.get("noise", None)
|
||||
device = kwargs["device"]
|
||||
@@ -513,18 +554,15 @@ class IP2P:
|
||||
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
|
||||
image = utils.resize_to_batch_size(image, noise.shape[0])
|
||||
return self.process_ip2p_image_in(image)
|
||||
|
||||
out['c_concat'] = comfy.conds.CONDNoiseShape(self.process_ip2p_image_in(image))
|
||||
adm = self.encode_adm(**kwargs)
|
||||
if adm is not None:
|
||||
out['y'] = comfy.conds.CONDRegular(adm)
|
||||
return out
|
||||
|
||||
class SD15_instructpix2pix(IP2P, BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
|
||||
super().__init__(model_config, model_type, device=device)
|
||||
self.process_ip2p_image_in = lambda image: image
|
||||
|
||||
|
||||
class SDXL_instructpix2pix(IP2P, SDXL):
|
||||
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
|
||||
super().__init__(model_config, model_type, device=device)
|
||||
@@ -685,6 +723,44 @@ class Flux(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.flux.model.Flux)
|
||||
|
||||
def concat_cond(self, **kwargs):
|
||||
try:
|
||||
#Handle Flux control loras dynamically changing the img_in weight.
|
||||
num_channels = self.diffusion_model.img_in.weight.shape[1] // (self.diffusion_model.patch_size * self.diffusion_model.patch_size)
|
||||
except:
|
||||
#Some cases like tensorrt might not have the weights accessible
|
||||
num_channels = self.model_config.unet_config["in_channels"]
|
||||
|
||||
out_channels = self.model_config.unet_config["out_channels"]
|
||||
|
||||
if num_channels <= out_channels:
|
||||
return None
|
||||
|
||||
image = kwargs.get("concat_latent_image", None)
|
||||
noise = kwargs.get("noise", None)
|
||||
device = kwargs["device"]
|
||||
|
||||
if image is None:
|
||||
image = torch.zeros_like(noise)
|
||||
|
||||
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
image = utils.resize_to_batch_size(image, noise.shape[0])
|
||||
image = self.process_latent_in(image)
|
||||
if num_channels <= out_channels * 2:
|
||||
return image
|
||||
|
||||
#inpaint model
|
||||
mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
|
||||
if mask is None:
|
||||
mask = torch.ones_like(noise)[:, :1]
|
||||
|
||||
mask = torch.mean(mask, dim=1, keepdim=True)
|
||||
print(mask.shape)
|
||||
mask = utils.common_upscale(mask.to(device), noise.shape[-1] * 8, noise.shape[-2] * 8, "bilinear", "center")
|
||||
mask = mask.view(mask.shape[0], mask.shape[2] // 8, 8, mask.shape[3] // 8, 8).permute(0, 2, 4, 1, 3).reshape(mask.shape[0], -1, mask.shape[2] // 8, mask.shape[3] // 8)
|
||||
mask = utils.resize_to_batch_size(mask, noise.shape[0])
|
||||
return torch.cat((image, mask), dim=1)
|
||||
|
||||
def encode_adm(self, **kwargs):
|
||||
return kwargs["pooled_output"]
|
||||
|
||||
@@ -695,3 +771,38 @@ class Flux(BaseModel):
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 3.5)]))
|
||||
return out
|
||||
|
||||
class GenmoMochi(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.genmo.joint_model.asymm_models_joint.AsymmDiTJoint)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
if attention_mask is not None:
|
||||
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||
out['num_tokens'] = comfy.conds.CONDConstant(max(1, torch.sum(attention_mask).item()))
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
return out
|
||||
|
||||
class LTXV(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.model.LTXVModel) #TODO
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
if attention_mask is not None:
|
||||
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
|
||||
guiding_latent = kwargs.get("guiding_latent", None)
|
||||
if guiding_latent is not None:
|
||||
out['guiding_latent'] = comfy.conds.CONDRegular(guiding_latent)
|
||||
|
||||
out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))
|
||||
return out
|
||||
|
||||
@@ -70,6 +70,11 @@ def detect_unet_config(state_dict, key_prefix):
|
||||
context_processor = '{}context_processor.layers.0.attn.qkv.weight'.format(key_prefix)
|
||||
if context_processor in state_dict_keys:
|
||||
unet_config["context_processor_layers"] = count_blocks(state_dict_keys, '{}context_processor.layers.'.format(key_prefix) + '{}.')
|
||||
unet_config["x_block_self_attn_layers"] = []
|
||||
for key in state_dict_keys:
|
||||
if key.startswith('{}joint_blocks.'.format(key_prefix)) and key.endswith('.x_block.attn2.qkv.weight'):
|
||||
layer = key[len('{}joint_blocks.'.format(key_prefix)):-len('.x_block.attn2.qkv.weight')]
|
||||
unet_config["x_block_self_attn_layers"].append(int(layer))
|
||||
return unet_config
|
||||
|
||||
if '{}clf.1.weight'.format(key_prefix) in state_dict_keys: #stable cascade
|
||||
@@ -132,6 +137,12 @@ def detect_unet_config(state_dict, key_prefix):
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "flux"
|
||||
dit_config["in_channels"] = 16
|
||||
patch_size = 2
|
||||
dit_config["patch_size"] = patch_size
|
||||
in_key = "{}img_in.weight".format(key_prefix)
|
||||
if in_key in state_dict_keys:
|
||||
dit_config["in_channels"] = state_dict[in_key].shape[1] // (patch_size * patch_size)
|
||||
dit_config["out_channels"] = 16
|
||||
dit_config["vec_in_dim"] = 768
|
||||
dit_config["context_in_dim"] = 4096
|
||||
dit_config["hidden_size"] = 3072
|
||||
@@ -145,6 +156,38 @@ def detect_unet_config(state_dict, key_prefix):
|
||||
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
|
||||
return dit_config
|
||||
|
||||
if '{}t5_yproj.weight'.format(key_prefix) in state_dict_keys: #Genmo mochi preview
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "mochi_preview"
|
||||
dit_config["depth"] = 48
|
||||
dit_config["patch_size"] = 2
|
||||
dit_config["num_heads"] = 24
|
||||
dit_config["hidden_size_x"] = 3072
|
||||
dit_config["hidden_size_y"] = 1536
|
||||
dit_config["mlp_ratio_x"] = 4.0
|
||||
dit_config["mlp_ratio_y"] = 4.0
|
||||
dit_config["learn_sigma"] = False
|
||||
dit_config["in_channels"] = 12
|
||||
dit_config["qk_norm"] = True
|
||||
dit_config["qkv_bias"] = False
|
||||
dit_config["out_bias"] = True
|
||||
dit_config["attn_drop"] = 0.0
|
||||
dit_config["patch_embed_bias"] = True
|
||||
dit_config["posenc_preserve_area"] = True
|
||||
dit_config["timestep_mlp_bias"] = True
|
||||
dit_config["attend_to_padding"] = False
|
||||
dit_config["timestep_scale"] = 1000.0
|
||||
dit_config["use_t5"] = True
|
||||
dit_config["t5_feat_dim"] = 4096
|
||||
dit_config["t5_token_length"] = 256
|
||||
dit_config["rope_theta"] = 10000.0
|
||||
return dit_config
|
||||
|
||||
if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "ltxv"
|
||||
return dit_config
|
||||
|
||||
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
|
||||
return None
|
||||
|
||||
@@ -286,9 +329,16 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
|
||||
return None
|
||||
model_config = model_config_from_unet_config(unet_config, state_dict)
|
||||
if model_config is None and use_base_if_no_match:
|
||||
return comfy.supported_models_base.BASE(unet_config)
|
||||
else:
|
||||
return model_config
|
||||
model_config = comfy.supported_models_base.BASE(unet_config)
|
||||
|
||||
scaled_fp8_key = "{}scaled_fp8".format(unet_key_prefix)
|
||||
if scaled_fp8_key in state_dict:
|
||||
scaled_fp8_weight = state_dict.pop(scaled_fp8_key)
|
||||
model_config.scaled_fp8 = scaled_fp8_weight.dtype
|
||||
if model_config.scaled_fp8 == torch.float32:
|
||||
model_config.scaled_fp8 = torch.float8_e4m3fn
|
||||
|
||||
return model_config
|
||||
|
||||
def unet_prefix_from_state_dict(state_dict):
|
||||
candidates = ["model.diffusion_model.", #ldm/sgm models
|
||||
@@ -472,9 +522,15 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
|
||||
'transformer_depth': [0, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': False,
|
||||
'context_dim': 768, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 1, 1, 1, 1],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
SD15_diffusers_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': None,
|
||||
'dtype': dtype, 'in_channels': 9, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0],
|
||||
'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, 'num_heads': 8,
|
||||
'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
|
||||
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p]
|
||||
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p, SD15_diffusers_inpaint]
|
||||
|
||||
for unet_config in supported_models:
|
||||
matches = True
|
||||
@@ -495,7 +551,11 @@ def model_config_from_diffusers_unet(state_dict):
|
||||
def convert_diffusers_mmdit(state_dict, output_prefix=""):
|
||||
out_sd = {}
|
||||
|
||||
if 'transformer_blocks.0.attn.norm_added_k.weight' in state_dict: #Flux
|
||||
if 'joint_transformer_blocks.0.attn.add_k_proj.weight' in state_dict: #AuraFlow
|
||||
num_joint = count_blocks(state_dict, 'joint_transformer_blocks.{}.')
|
||||
num_single = count_blocks(state_dict, 'single_transformer_blocks.{}.')
|
||||
sd_map = comfy.utils.auraflow_to_diffusers({"n_double_layers": num_joint, "n_layers": num_joint + num_single}, output_prefix=output_prefix)
|
||||
elif 'x_embedder.weight' in state_dict: #Flux
|
||||
depth = count_blocks(state_dict, 'transformer_blocks.{}.')
|
||||
depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.')
|
||||
hidden_size = state_dict["x_embedder.bias"].shape[0]
|
||||
@@ -504,10 +564,6 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
|
||||
num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
|
||||
depth = state_dict["pos_embed.proj.weight"].shape[0] // 64
|
||||
sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix)
|
||||
elif 'joint_transformer_blocks.0.attn.add_k_proj.weight' in state_dict: #AuraFlow
|
||||
num_joint = count_blocks(state_dict, 'joint_transformer_blocks.{}.')
|
||||
num_single = count_blocks(state_dict, 'single_transformer_blocks.{}.')
|
||||
sd_map = comfy.utils.auraflow_to_diffusers({"n_double_layers": num_joint, "n_layers": num_joint + num_single}, output_prefix=output_prefix)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
@@ -23,6 +23,8 @@ from comfy.cli_args import args
|
||||
import torch
|
||||
import sys
|
||||
import platform
|
||||
import weakref
|
||||
import gc
|
||||
|
||||
class VRAMState(Enum):
|
||||
DISABLED = 0 #No vram present: no need to move models to vram
|
||||
@@ -44,9 +46,15 @@ cpu_state = CPUState.GPU
|
||||
|
||||
total_vram = 0
|
||||
|
||||
lowvram_available = True
|
||||
xpu_available = False
|
||||
torch_version = ""
|
||||
try:
|
||||
torch_version = torch.version.__version__
|
||||
xpu_available = (int(torch_version[0]) < 2 or (int(torch_version[0]) == 2 and int(torch_version[2]) <= 4)) and torch.xpu.is_available()
|
||||
except:
|
||||
pass
|
||||
|
||||
lowvram_available = True
|
||||
if args.deterministic:
|
||||
logging.info("Using deterministic algorithms for pytorch")
|
||||
torch.use_deterministic_algorithms(True, warn_only=True)
|
||||
@@ -66,10 +74,10 @@ if args.directml is not None:
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
if torch.xpu.is_available():
|
||||
xpu_available = True
|
||||
_ = torch.xpu.device_count()
|
||||
xpu_available = torch.xpu.is_available()
|
||||
except:
|
||||
pass
|
||||
xpu_available = xpu_available or (hasattr(torch, "xpu") and torch.xpu.is_available())
|
||||
|
||||
try:
|
||||
if torch.backends.mps.is_available():
|
||||
@@ -139,7 +147,7 @@ total_ram = psutil.virtual_memory().total / (1024 * 1024)
|
||||
logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
|
||||
|
||||
try:
|
||||
logging.info("pytorch version: {}".format(torch.version.__version__))
|
||||
logging.info("pytorch version: {}".format(torch_version))
|
||||
except:
|
||||
pass
|
||||
|
||||
@@ -189,7 +197,6 @@ VAE_DTYPES = [torch.float32]
|
||||
|
||||
try:
|
||||
if is_nvidia():
|
||||
torch_version = torch.version.__version__
|
||||
if int(torch_version[0]) >= 2:
|
||||
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
@@ -282,11 +289,27 @@ def module_size(module):
|
||||
|
||||
class LoadedModel:
|
||||
def __init__(self, model):
|
||||
self.model = model
|
||||
self._set_model(model)
|
||||
self.device = model.load_device
|
||||
self.weights_loaded = False
|
||||
self.real_model = None
|
||||
self.currently_used = True
|
||||
self.model_finalizer = None
|
||||
self._patcher_finalizer = None
|
||||
|
||||
def _set_model(self, model):
|
||||
self._model = weakref.ref(model)
|
||||
if model.parent is not None:
|
||||
self._parent_model = weakref.ref(model.parent)
|
||||
self._patcher_finalizer = weakref.finalize(model, self._switch_parent)
|
||||
|
||||
def _switch_parent(self):
|
||||
model = self._parent_model()
|
||||
if model is not None:
|
||||
self._set_model(model)
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
return self._model()
|
||||
|
||||
def model_memory(self):
|
||||
return self.model.model_size()
|
||||
@@ -301,34 +324,23 @@ class LoadedModel:
|
||||
return self.model_memory()
|
||||
|
||||
def model_load(self, lowvram_model_memory=0, force_patch_weights=False):
|
||||
patch_model_to = self.device
|
||||
|
||||
self.model.model_patches_to(self.device)
|
||||
self.model.model_patches_to(self.model.model_dtype())
|
||||
|
||||
load_weights = not self.weights_loaded
|
||||
# if self.model.loaded_size() > 0:
|
||||
use_more_vram = lowvram_model_memory
|
||||
if use_more_vram == 0:
|
||||
use_more_vram = 1e32
|
||||
self.model_use_more_vram(use_more_vram, force_patch_weights=force_patch_weights)
|
||||
real_model = self.model.model
|
||||
|
||||
if self.model.loaded_size() > 0:
|
||||
use_more_vram = lowvram_model_memory
|
||||
if use_more_vram == 0:
|
||||
use_more_vram = 1e32
|
||||
self.model_use_more_vram(use_more_vram)
|
||||
else:
|
||||
try:
|
||||
if lowvram_model_memory > 0 and load_weights:
|
||||
self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
|
||||
else:
|
||||
self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights)
|
||||
except Exception as e:
|
||||
self.model.unpatch_model(self.model.offload_device)
|
||||
self.model_unload()
|
||||
raise e
|
||||
if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and real_model is not None:
|
||||
with torch.no_grad():
|
||||
real_model = ipex.optimize(real_model.eval(), inplace=True, graph_mode=True, concat_linear=True)
|
||||
|
||||
if is_intel_xpu() and not args.disable_ipex_optimize:
|
||||
self.real_model = ipex.optimize(self.real_model.eval(), graph_mode=True, concat_linear=True)
|
||||
|
||||
self.weights_loaded = True
|
||||
return self.real_model
|
||||
self.real_model = weakref.ref(real_model)
|
||||
self.model_finalizer = weakref.finalize(real_model, cleanup_models)
|
||||
return real_model
|
||||
|
||||
def should_reload_model(self, force_patch_weights=False):
|
||||
if force_patch_weights and self.model.lowvram_patch_counter() > 0:
|
||||
@@ -338,20 +350,29 @@ class LoadedModel:
|
||||
def model_unload(self, memory_to_free=None, unpatch_weights=True):
|
||||
if memory_to_free is not None:
|
||||
if memory_to_free < self.model.loaded_size():
|
||||
self.model.partially_unload(self.model.offload_device, memory_to_free)
|
||||
return False
|
||||
self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
|
||||
self.model.model_patches_to(self.model.offload_device)
|
||||
self.weights_loaded = self.weights_loaded and not unpatch_weights
|
||||
freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
|
||||
if freed >= memory_to_free:
|
||||
return False
|
||||
self.model.detach(unpatch_weights)
|
||||
self.model_finalizer.detach()
|
||||
self.model_finalizer = None
|
||||
self.real_model = None
|
||||
return True
|
||||
|
||||
def model_use_more_vram(self, extra_memory):
|
||||
return self.model.partially_load(self.device, extra_memory)
|
||||
def model_use_more_vram(self, extra_memory, force_patch_weights=False):
|
||||
return self.model.partially_load(self.device, extra_memory, force_patch_weights=force_patch_weights)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.model is other.model
|
||||
|
||||
def __del__(self):
|
||||
if self._patcher_finalizer is not None:
|
||||
self._patcher_finalizer.detach()
|
||||
|
||||
def is_dead(self):
|
||||
return self.real_model() is not None and self.model is None
|
||||
|
||||
|
||||
def use_more_memory(extra_memory, loaded_models, device):
|
||||
for m in loaded_models:
|
||||
if m.device == device:
|
||||
@@ -366,39 +387,24 @@ def offloaded_memory(loaded_models, device):
|
||||
offloaded_mem += m.model_offloaded_memory()
|
||||
return offloaded_mem
|
||||
|
||||
WINDOWS = any(platform.win32_ver())
|
||||
|
||||
EXTRA_RESERVED_VRAM = 400 * 1024 * 1024
|
||||
if WINDOWS:
|
||||
EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue
|
||||
|
||||
if args.reserve_vram is not None:
|
||||
EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024
|
||||
logging.debug("Reserving {}MB vram for other applications.".format(EXTRA_RESERVED_VRAM / (1024 * 1024)))
|
||||
|
||||
def extra_reserved_memory():
|
||||
return EXTRA_RESERVED_VRAM
|
||||
|
||||
def minimum_inference_memory():
|
||||
return (1024 * 1024 * 1024) * 1.2
|
||||
|
||||
def unload_model_clones(model, unload_weights_only=True, force_unload=True):
|
||||
to_unload = []
|
||||
for i in range(len(current_loaded_models)):
|
||||
if model.is_clone(current_loaded_models[i].model):
|
||||
to_unload = [i] + to_unload
|
||||
|
||||
if len(to_unload) == 0:
|
||||
return True
|
||||
|
||||
same_weights = 0
|
||||
for i in to_unload:
|
||||
if model.clone_has_same_weights(current_loaded_models[i].model):
|
||||
same_weights += 1
|
||||
|
||||
if same_weights == len(to_unload):
|
||||
unload_weight = False
|
||||
else:
|
||||
unload_weight = True
|
||||
|
||||
if not force_unload:
|
||||
if unload_weights_only and unload_weight == False:
|
||||
return None
|
||||
|
||||
for i in to_unload:
|
||||
logging.debug("unload clone {} {}".format(i, unload_weight))
|
||||
current_loaded_models.pop(i).model_unload(unpatch_weights=unload_weight)
|
||||
|
||||
return unload_weight
|
||||
return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
|
||||
|
||||
def free_memory(memory_required, device, keep_loaded=[]):
|
||||
cleanup_models_gc()
|
||||
unloaded_model = []
|
||||
can_unload = []
|
||||
unloaded_models = []
|
||||
@@ -406,8 +412,8 @@ def free_memory(memory_required, device, keep_loaded=[]):
|
||||
for i in range(len(current_loaded_models) -1, -1, -1):
|
||||
shift_model = current_loaded_models[i]
|
||||
if shift_model.device == device:
|
||||
if shift_model not in keep_loaded:
|
||||
can_unload.append((sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
|
||||
if shift_model not in keep_loaded and not shift_model.is_dead():
|
||||
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
|
||||
shift_model.currently_used = False
|
||||
|
||||
for x in sorted(can_unload):
|
||||
@@ -434,24 +440,23 @@ def free_memory(memory_required, device, keep_loaded=[]):
|
||||
soft_empty_cache()
|
||||
return unloaded_models
|
||||
|
||||
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None):
|
||||
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
|
||||
cleanup_models_gc()
|
||||
global vram_state
|
||||
|
||||
inference_memory = minimum_inference_memory()
|
||||
extra_mem = max(inference_memory, memory_required) + 100 * 1024 * 1024
|
||||
extra_mem = max(inference_memory, memory_required + extra_reserved_memory())
|
||||
if minimum_memory_required is None:
|
||||
minimum_memory_required = extra_mem
|
||||
else:
|
||||
minimum_memory_required = max(inference_memory, minimum_memory_required) + 100 * 1024 * 1024
|
||||
minimum_memory_required = max(inference_memory, minimum_memory_required + extra_reserved_memory())
|
||||
|
||||
models = set(models)
|
||||
|
||||
models_to_load = []
|
||||
models_already_loaded = []
|
||||
|
||||
for x in models:
|
||||
loaded_model = LoadedModel(x)
|
||||
loaded = None
|
||||
|
||||
try:
|
||||
loaded_model_index = current_loaded_models.index(loaded_model)
|
||||
except:
|
||||
@@ -459,51 +464,35 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
||||
|
||||
if loaded_model_index is not None:
|
||||
loaded = current_loaded_models[loaded_model_index]
|
||||
if loaded.should_reload_model(force_patch_weights=force_patch_weights): #TODO: cleanup this model reload logic
|
||||
current_loaded_models.pop(loaded_model_index).model_unload(unpatch_weights=True)
|
||||
loaded = None
|
||||
else:
|
||||
loaded.currently_used = True
|
||||
models_already_loaded.append(loaded)
|
||||
|
||||
if loaded is None:
|
||||
loaded.currently_used = True
|
||||
models_to_load.append(loaded)
|
||||
else:
|
||||
if hasattr(x, "model"):
|
||||
logging.info(f"Requested to load {x.model.__class__.__name__}")
|
||||
models_to_load.append(loaded_model)
|
||||
|
||||
if len(models_to_load) == 0:
|
||||
devs = set(map(lambda a: a.device, models_already_loaded))
|
||||
for d in devs:
|
||||
if d != torch.device("cpu"):
|
||||
free_memory(extra_mem + offloaded_memory(models_already_loaded, d), d, models_already_loaded)
|
||||
free_mem = get_free_memory(d)
|
||||
if free_mem < minimum_memory_required:
|
||||
logging.info("Unloading models for lowram load.") #TODO: partial model unloading when this case happens, also handle the opposite case where models can be unlowvramed.
|
||||
models_to_load = free_memory(minimum_memory_required, d)
|
||||
logging.info("{} models unloaded.".format(len(models_to_load)))
|
||||
else:
|
||||
use_more_memory(free_mem - minimum_memory_required, models_already_loaded, d)
|
||||
if len(models_to_load) == 0:
|
||||
return
|
||||
|
||||
logging.info(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")
|
||||
for loaded_model in models_to_load:
|
||||
to_unload = []
|
||||
for i in range(len(current_loaded_models)):
|
||||
if loaded_model.model.is_clone(current_loaded_models[i].model):
|
||||
to_unload = [i] + to_unload
|
||||
for i in to_unload:
|
||||
current_loaded_models.pop(i).model.detach(unpatch_all=False)
|
||||
|
||||
total_memory_required = {}
|
||||
for loaded_model in models_to_load:
|
||||
unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) #unload clones where the weights are different
|
||||
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
|
||||
|
||||
for loaded_model in models_already_loaded:
|
||||
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
|
||||
|
||||
for loaded_model in models_to_load:
|
||||
weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) #unload the rest of the clones where the weights can stay loaded
|
||||
if weights_unloaded is not None:
|
||||
loaded_model.weights_loaded = not weights_unloaded
|
||||
|
||||
for device in total_memory_required:
|
||||
if device != torch.device("cpu"):
|
||||
free_memory(total_memory_required[device] * 1.1 + extra_mem, device, models_already_loaded)
|
||||
free_memory(total_memory_required[device] * 1.1 + extra_mem, device)
|
||||
|
||||
for device in total_memory_required:
|
||||
if device != torch.device("cpu"):
|
||||
free_mem = get_free_memory(device)
|
||||
if free_mem < minimum_memory_required:
|
||||
models_l = free_memory(minimum_memory_required, device)
|
||||
logging.info("{} models unloaded.".format(len(models_l)))
|
||||
|
||||
for loaded_model in models_to_load:
|
||||
model = loaded_model.model
|
||||
@@ -513,7 +502,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
||||
else:
|
||||
vram_set_state = vram_state
|
||||
lowvram_model_memory = 0
|
||||
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
|
||||
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM) and not force_full_load:
|
||||
model_size = loaded_model.model_memory_required(torch_dev)
|
||||
current_free_mem = get_free_memory(torch_dev)
|
||||
lowvram_model_memory = max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required), min(current_free_mem * 0.4, current_free_mem - minimum_inference_memory()))
|
||||
@@ -525,17 +514,8 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
||||
|
||||
cur_loaded_model = loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
|
||||
current_loaded_models.insert(0, loaded_model)
|
||||
|
||||
|
||||
devs = set(map(lambda a: a.device, models_already_loaded))
|
||||
for d in devs:
|
||||
if d != torch.device("cpu"):
|
||||
free_mem = get_free_memory(d)
|
||||
if free_mem > minimum_memory_required:
|
||||
use_more_memory(free_mem - minimum_memory_required, models_already_loaded, d)
|
||||
return
|
||||
|
||||
|
||||
def load_model_gpu(model):
|
||||
return load_models_gpu([model])
|
||||
|
||||
@@ -549,19 +529,35 @@ def loaded_models(only_currently_used=False):
|
||||
output.append(m.model)
|
||||
return output
|
||||
|
||||
def cleanup_models(keep_clone_weights_loaded=False):
|
||||
|
||||
def cleanup_models_gc():
|
||||
do_gc = False
|
||||
for i in range(len(current_loaded_models)):
|
||||
cur = current_loaded_models[i]
|
||||
if cur.is_dead():
|
||||
logging.info("Potential memory leak detected with model {}, doing a full garbage collect, for maximum performance avoid circular references in the model code.".format(cur.real_model().__class__.__name__))
|
||||
do_gc = True
|
||||
break
|
||||
|
||||
if do_gc:
|
||||
gc.collect()
|
||||
soft_empty_cache()
|
||||
|
||||
for i in range(len(current_loaded_models)):
|
||||
cur = current_loaded_models[i]
|
||||
if cur.is_dead():
|
||||
logging.warning("WARNING, memory leak with model {}. Please make sure it is not being referenced from somewhere.".format(cur.real_model().__class__.__name__))
|
||||
|
||||
|
||||
|
||||
def cleanup_models():
|
||||
to_delete = []
|
||||
for i in range(len(current_loaded_models)):
|
||||
if sys.getrefcount(current_loaded_models[i].model) <= 2:
|
||||
if not keep_clone_weights_loaded:
|
||||
to_delete = [i] + to_delete
|
||||
#TODO: find a less fragile way to do this.
|
||||
elif sys.getrefcount(current_loaded_models[i].real_model) <= 3: #references from .real_model + the .model
|
||||
to_delete = [i] + to_delete
|
||||
if current_loaded_models[i].real_model() is None:
|
||||
to_delete = [i] + to_delete
|
||||
|
||||
for i in to_delete:
|
||||
x = current_loaded_models.pop(i)
|
||||
x.model_unload()
|
||||
del x
|
||||
|
||||
def dtype_size(dtype):
|
||||
@@ -605,6 +601,12 @@ def maximum_vram_for_weights(device=None):
|
||||
return (get_total_memory(device) * 0.88 - minimum_inference_memory())
|
||||
|
||||
def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
|
||||
if model_params < 0:
|
||||
model_params = 1000000000000000000000
|
||||
if args.fp32_unet:
|
||||
return torch.float32
|
||||
if args.fp64_unet:
|
||||
return torch.float64
|
||||
if args.bf16_unet:
|
||||
return torch.bfloat16
|
||||
if args.fp16_unet:
|
||||
@@ -624,6 +626,9 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
|
||||
pass
|
||||
|
||||
if fp8_dtype is not None:
|
||||
if supports_fp8_compute(device): #if fp8 compute is supported the casting is most likely not expensive
|
||||
return fp8_dtype
|
||||
|
||||
free_model_memory = maximum_vram_for_weights(device)
|
||||
if model_params * 2 > free_model_memory:
|
||||
return fp8_dtype
|
||||
@@ -648,7 +653,7 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
|
||||
|
||||
# None means no manual cast
|
||||
def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
|
||||
if weight_dtype == torch.float32:
|
||||
if weight_dtype == torch.float32 or weight_dtype == torch.float64:
|
||||
return None
|
||||
|
||||
fp16_supported = should_use_fp16(inference_device, prioritize_performance=False)
|
||||
@@ -659,6 +664,7 @@ def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.flo
|
||||
if bf16_supported and weight_dtype == torch.bfloat16:
|
||||
return None
|
||||
|
||||
fp16_supported = should_use_fp16(inference_device, prioritize_performance=True)
|
||||
for dt in supported_dtypes:
|
||||
if dt == torch.float16 and fp16_supported:
|
||||
return torch.float16
|
||||
@@ -684,6 +690,20 @@ def text_encoder_device():
|
||||
else:
|
||||
return torch.device("cpu")
|
||||
|
||||
def text_encoder_initial_device(load_device, offload_device, model_size=0):
|
||||
if load_device == offload_device or model_size <= 1024 * 1024 * 1024:
|
||||
return offload_device
|
||||
|
||||
if is_device_mps(load_device):
|
||||
return offload_device
|
||||
|
||||
mem_l = get_free_memory(load_device)
|
||||
mem_o = get_free_memory(offload_device)
|
||||
if mem_l > (mem_o * 0.5) and model_size * 1.2 < mem_l:
|
||||
return load_device
|
||||
else:
|
||||
return offload_device
|
||||
|
||||
def text_encoder_dtype(device=None):
|
||||
if args.fp8_e4m3fn_text_enc:
|
||||
return torch.float8_e4m3fn
|
||||
@@ -802,27 +822,21 @@ def force_channels_last():
|
||||
#TODO
|
||||
return False
|
||||
|
||||
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
|
||||
if device is None or weight.device == device:
|
||||
if not copy:
|
||||
if dtype is None or weight.dtype == dtype:
|
||||
return weight
|
||||
return weight.to(dtype=dtype, copy=copy)
|
||||
|
||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||
r.copy_(weight, non_blocking=non_blocking)
|
||||
return r
|
||||
|
||||
def cast_to_device(tensor, device, dtype, copy=False):
|
||||
device_supports_cast = False
|
||||
if tensor.dtype == torch.float32 or tensor.dtype == torch.float16:
|
||||
device_supports_cast = True
|
||||
elif tensor.dtype == torch.bfloat16:
|
||||
if hasattr(device, 'type') and device.type.startswith("cuda"):
|
||||
device_supports_cast = True
|
||||
elif is_intel_xpu():
|
||||
device_supports_cast = True
|
||||
non_blocking = device_supports_non_blocking(device)
|
||||
return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy)
|
||||
|
||||
non_blocking = device_should_use_non_blocking(device)
|
||||
|
||||
if device_supports_cast:
|
||||
if copy:
|
||||
if tensor.device == device:
|
||||
return tensor.to(dtype, copy=copy, non_blocking=non_blocking)
|
||||
return tensor.to(device, copy=copy, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
|
||||
else:
|
||||
return tensor.to(device, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
|
||||
else:
|
||||
return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking)
|
||||
|
||||
def xformers_enabled():
|
||||
global directml_enabled
|
||||
@@ -860,7 +874,8 @@ def pytorch_attention_flash_attention():
|
||||
def force_upcast_attention_dtype():
|
||||
upcast = args.force_upcast_attention
|
||||
try:
|
||||
if platform.mac_ver()[0] in ['14.5']: #black image bug on OSX Sonoma 14.5
|
||||
macos_version = tuple(int(n) for n in platform.mac_ver()[0].split("."))
|
||||
if (14, 5) <= macos_version <= (15, 2): # black image bug on recent versions of macOS
|
||||
upcast = True
|
||||
except:
|
||||
pass
|
||||
@@ -956,23 +971,23 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
if torch.version.hip:
|
||||
return True
|
||||
|
||||
props = torch.cuda.get_device_properties("cuda")
|
||||
props = torch.cuda.get_device_properties(device)
|
||||
if props.major >= 8:
|
||||
return True
|
||||
|
||||
if props.major < 6:
|
||||
return False
|
||||
|
||||
fp16_works = False
|
||||
#FP16 is confirmed working on a 1080 (GP104) but it's a bit slower than FP32 so it should only be enabled
|
||||
#when the model doesn't actually fit on the card
|
||||
#TODO: actually test if GP106 and others have the same type of behavior
|
||||
#FP16 is confirmed working on a 1080 (GP104) and on latest pytorch actually seems faster than fp32
|
||||
nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050", "p40", "p100", "p6", "p4"]
|
||||
for x in nvidia_10_series:
|
||||
if x in props.name.lower():
|
||||
fp16_works = True
|
||||
if WINDOWS or manual_cast:
|
||||
return True
|
||||
else:
|
||||
return False #weird linux behavior where fp32 is faster
|
||||
|
||||
if fp16_works or manual_cast:
|
||||
if manual_cast:
|
||||
free_model_memory = maximum_vram_for_weights(device)
|
||||
if (not prioritize_performance) or model_params * 4 > free_model_memory:
|
||||
return True
|
||||
@@ -1012,7 +1027,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
if is_intel_xpu():
|
||||
return True
|
||||
|
||||
props = torch.cuda.get_device_properties("cuda")
|
||||
props = torch.cuda.get_device_properties(device)
|
||||
if props.major >= 8:
|
||||
return True
|
||||
|
||||
@@ -1025,6 +1040,27 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
|
||||
return False
|
||||
|
||||
def supports_fp8_compute(device=None):
|
||||
if not is_nvidia():
|
||||
return False
|
||||
|
||||
props = torch.cuda.get_device_properties(device)
|
||||
if props.major >= 9:
|
||||
return True
|
||||
if props.major < 8:
|
||||
return False
|
||||
if props.minor < 9:
|
||||
return False
|
||||
|
||||
if int(torch_version[0]) < 2 or (int(torch_version[0]) == 2 and int(torch_version[2]) < 3):
|
||||
return False
|
||||
|
||||
if WINDOWS:
|
||||
if (int(torch_version[0]) == 2 and int(torch_version[2]) < 4):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def soft_empty_cache(force=False):
|
||||
global cpu_state
|
||||
if cpu_state == CPUState.MPS:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -2,6 +2,25 @@ import torch
|
||||
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
|
||||
import math
|
||||
|
||||
def rescale_zero_terminal_snr_sigmas(sigmas):
|
||||
alphas_cumprod = 1 / ((sigmas * sigmas) + 1)
|
||||
alphas_bar_sqrt = alphas_cumprod.sqrt()
|
||||
|
||||
# Store old values.
|
||||
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
||||
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
||||
|
||||
# Shift so the last timestep is zero.
|
||||
alphas_bar_sqrt -= (alphas_bar_sqrt_T)
|
||||
|
||||
# Scale so the first timestep is back to the old value.
|
||||
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
||||
|
||||
# Convert alphas_bar_sqrt to betas
|
||||
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
|
||||
alphas_bar[-1] = 4.8973451890853435e-08
|
||||
return ((1 - alphas_bar) / alphas_bar) ** 0.5
|
||||
|
||||
class EPS:
|
||||
def calculate_input(self, sigma, noise):
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
|
||||
@@ -48,7 +67,7 @@ class CONST:
|
||||
return latent / (1.0 - sigma)
|
||||
|
||||
class ModelSamplingDiscrete(torch.nn.Module):
|
||||
def __init__(self, model_config=None):
|
||||
def __init__(self, model_config=None, zsnr=None):
|
||||
super().__init__()
|
||||
|
||||
if model_config is not None:
|
||||
@@ -61,11 +80,14 @@ class ModelSamplingDiscrete(torch.nn.Module):
|
||||
linear_end = sampling_settings.get("linear_end", 0.012)
|
||||
timesteps = sampling_settings.get("timesteps", 1000)
|
||||
|
||||
self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=8e-3)
|
||||
if zsnr is None:
|
||||
zsnr = sampling_settings.get("zsnr", False)
|
||||
|
||||
self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=8e-3, zsnr=zsnr)
|
||||
self.sigma_data = 1.0
|
||||
|
||||
def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
|
||||
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
||||
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3, zsnr=False):
|
||||
if given_betas is not None:
|
||||
betas = given_betas
|
||||
else:
|
||||
@@ -83,6 +105,9 @@ class ModelSamplingDiscrete(torch.nn.Module):
|
||||
# self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
|
||||
|
||||
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
|
||||
if zsnr:
|
||||
sigmas = rescale_zero_terminal_snr_sigmas(sigmas)
|
||||
|
||||
self.set_sigmas(sigmas)
|
||||
|
||||
def set_sigmas(self, sigmas):
|
||||
|
||||
148
comfy/ops.py
148
comfy/ops.py
@@ -18,29 +18,34 @@
|
||||
|
||||
import torch
|
||||
import comfy.model_management
|
||||
from comfy.cli_args import args
|
||||
import comfy.float
|
||||
|
||||
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
|
||||
|
||||
def cast_to(weight, dtype=None, device=None, non_blocking=False):
|
||||
return weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
def cast_to_input(weight, input, non_blocking=False, copy=True):
|
||||
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
|
||||
|
||||
def cast_to_input(weight, input, non_blocking=False):
|
||||
return cast_to(weight, input.dtype, input.device, non_blocking=non_blocking)
|
||||
|
||||
def cast_bias_weight(s, input=None, dtype=None, device=None):
|
||||
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
|
||||
if input is not None:
|
||||
if dtype is None:
|
||||
dtype = input.dtype
|
||||
if bias_dtype is None:
|
||||
bias_dtype = dtype
|
||||
if device is None:
|
||||
device = input.device
|
||||
|
||||
bias = None
|
||||
non_blocking = comfy.model_management.device_should_use_non_blocking(device)
|
||||
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
||||
if s.bias is not None:
|
||||
bias = cast_to(s.bias, dtype, device, non_blocking=non_blocking)
|
||||
if s.bias_function is not None:
|
||||
has_function = s.bias_function is not None
|
||||
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function)
|
||||
if has_function:
|
||||
bias = s.bias_function(bias)
|
||||
weight = cast_to(s.weight, dtype, device, non_blocking=non_blocking)
|
||||
if s.weight_function is not None:
|
||||
|
||||
has_function = s.weight_function is not None
|
||||
weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function)
|
||||
if has_function:
|
||||
weight = s.weight_function(weight)
|
||||
return weight, bias
|
||||
|
||||
@@ -238,3 +243,124 @@ class manual_cast(disable_weight_init):
|
||||
|
||||
class Embedding(disable_weight_init.Embedding):
|
||||
comfy_cast_weights = True
|
||||
|
||||
|
||||
def fp8_linear(self, input):
|
||||
dtype = self.weight.dtype
|
||||
if dtype not in [torch.float8_e4m3fn]:
|
||||
return None
|
||||
|
||||
tensor_2d = False
|
||||
if len(input.shape) == 2:
|
||||
tensor_2d = True
|
||||
input = input.unsqueeze(1)
|
||||
|
||||
|
||||
if len(input.shape) == 3:
|
||||
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype)
|
||||
w = w.t()
|
||||
|
||||
scale_weight = self.scale_weight
|
||||
scale_input = self.scale_input
|
||||
if scale_weight is None:
|
||||
scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
|
||||
else:
|
||||
scale_weight = scale_weight.to(input.device)
|
||||
|
||||
if scale_input is None:
|
||||
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
|
||||
inn = input.reshape(-1, input.shape[2]).to(dtype)
|
||||
else:
|
||||
scale_input = scale_input.to(input.device)
|
||||
inn = (input * (1.0 / scale_input).to(input.dtype)).reshape(-1, input.shape[2]).to(dtype)
|
||||
|
||||
if bias is not None:
|
||||
o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
|
||||
else:
|
||||
o = torch._scaled_mm(inn, w, out_dtype=input.dtype, scale_a=scale_input, scale_b=scale_weight)
|
||||
|
||||
if isinstance(o, tuple):
|
||||
o = o[0]
|
||||
|
||||
if tensor_2d:
|
||||
return o.reshape(input.shape[0], -1)
|
||||
|
||||
return o.reshape((-1, input.shape[1], self.weight.shape[0]))
|
||||
|
||||
return None
|
||||
|
||||
class fp8_ops(manual_cast):
|
||||
class Linear(manual_cast.Linear):
|
||||
def reset_parameters(self):
|
||||
self.scale_weight = None
|
||||
self.scale_input = None
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
out = fp8_linear(self, input)
|
||||
if out is not None:
|
||||
return out
|
||||
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
|
||||
def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
|
||||
class scaled_fp8_op(manual_cast):
|
||||
class Linear(manual_cast.Linear):
|
||||
def __init__(self, *args, **kwargs):
|
||||
if override_dtype is not None:
|
||||
kwargs['dtype'] = override_dtype
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def reset_parameters(self):
|
||||
if not hasattr(self, 'scale_weight'):
|
||||
self.scale_weight = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
|
||||
|
||||
if not scale_input:
|
||||
self.scale_input = None
|
||||
|
||||
if not hasattr(self, 'scale_input'):
|
||||
self.scale_input = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
if fp8_matrix_mult:
|
||||
out = fp8_linear(self, input)
|
||||
if out is not None:
|
||||
return out
|
||||
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
|
||||
if weight.numel() < input.numel(): #TODO: optimize
|
||||
return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
|
||||
else:
|
||||
return torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
|
||||
|
||||
def convert_weight(self, weight, inplace=False, **kwargs):
|
||||
if inplace:
|
||||
weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype)
|
||||
return weight
|
||||
else:
|
||||
return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype)
|
||||
|
||||
def set_weight(self, weight, inplace_update=False, seed=None, **kwargs):
|
||||
weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
|
||||
if inplace_update:
|
||||
self.weight.data.copy_(weight)
|
||||
else:
|
||||
self.weight = torch.nn.Parameter(weight, requires_grad=False)
|
||||
|
||||
return scaled_fp8_op
|
||||
|
||||
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
|
||||
fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
|
||||
if scaled_fp8 is not None:
|
||||
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute, scale_input=True, override_dtype=scaled_fp8)
|
||||
|
||||
if fp8_compute and (fp8_optimizations or args.fast) and not disable_fast_fp8:
|
||||
return fp8_ops
|
||||
|
||||
if compute_dtype is None or weight_dtype == compute_dtype:
|
||||
return disable_weight_init
|
||||
|
||||
return manual_cast
|
||||
|
||||
156
comfy/patcher_extension.py
Normal file
156
comfy/patcher_extension.py
Normal file
@@ -0,0 +1,156 @@
|
||||
from __future__ import annotations
|
||||
from typing import Callable
|
||||
|
||||
class CallbacksMP:
|
||||
ON_CLONE = "on_clone"
|
||||
ON_LOAD = "on_load_after"
|
||||
ON_DETACH = "on_detach_after"
|
||||
ON_CLEANUP = "on_cleanup"
|
||||
ON_PRE_RUN = "on_pre_run"
|
||||
ON_PREPARE_STATE = "on_prepare_state"
|
||||
ON_APPLY_HOOKS = "on_apply_hooks"
|
||||
ON_REGISTER_ALL_HOOK_PATCHES = "on_register_all_hook_patches"
|
||||
ON_INJECT_MODEL = "on_inject_model"
|
||||
ON_EJECT_MODEL = "on_eject_model"
|
||||
|
||||
# callbacks dict is in the format:
|
||||
# {"call_type": {"key": [Callable1, Callable2, ...]} }
|
||||
@classmethod
|
||||
def init_callbacks(cls) -> dict[str, dict[str, list[Callable]]]:
|
||||
return {}
|
||||
|
||||
def add_callback(call_type: str, callback: Callable, transformer_options: dict, is_model_options=False):
|
||||
add_callback_with_key(call_type, None, callback, transformer_options, is_model_options)
|
||||
|
||||
def add_callback_with_key(call_type: str, key: str, callback: Callable, transformer_options: dict, is_model_options=False):
|
||||
if is_model_options:
|
||||
transformer_options = transformer_options.setdefault("transformer_options", {})
|
||||
callbacks: dict[str, dict[str, list]] = transformer_options.setdefault("callbacks", {})
|
||||
c = callbacks.setdefault(call_type, {}).setdefault(key, [])
|
||||
c.append(callback)
|
||||
|
||||
def get_callbacks_with_key(call_type: str, key: str, transformer_options: dict, is_model_options=False):
|
||||
if is_model_options:
|
||||
transformer_options = transformer_options.get("transformer_options", {})
|
||||
c_list = []
|
||||
callbacks: dict[str, list] = transformer_options.get("callbacks", {})
|
||||
c_list.extend(callbacks.get(call_type, {}).get(key, []))
|
||||
return c_list
|
||||
|
||||
def get_all_callbacks(call_type: str, transformer_options: dict, is_model_options=False):
|
||||
if is_model_options:
|
||||
transformer_options = transformer_options.get("transformer_options", {})
|
||||
c_list = []
|
||||
callbacks: dict[str, list] = transformer_options.get("callbacks", {})
|
||||
for c in callbacks.get(call_type, {}).values():
|
||||
c_list.extend(c)
|
||||
return c_list
|
||||
|
||||
class WrappersMP:
|
||||
OUTER_SAMPLE = "outer_sample"
|
||||
SAMPLER_SAMPLE = "sampler_sample"
|
||||
CALC_COND_BATCH = "calc_cond_batch"
|
||||
APPLY_MODEL = "apply_model"
|
||||
DIFFUSION_MODEL = "diffusion_model"
|
||||
|
||||
# wrappers dict is in the format:
|
||||
# {"wrapper_type": {"key": [Callable1, Callable2, ...]} }
|
||||
@classmethod
|
||||
def init_wrappers(cls) -> dict[str, dict[str, list[Callable]]]:
|
||||
return {}
|
||||
|
||||
def add_wrapper(wrapper_type: str, wrapper: Callable, transformer_options: dict, is_model_options=False):
|
||||
add_wrapper_with_key(wrapper_type, None, wrapper, transformer_options, is_model_options)
|
||||
|
||||
def add_wrapper_with_key(wrapper_type: str, key: str, wrapper: Callable, transformer_options: dict, is_model_options=False):
|
||||
if is_model_options:
|
||||
transformer_options = transformer_options.setdefault("transformer_options", {})
|
||||
wrappers: dict[str, dict[str, list]] = transformer_options.setdefault("wrappers", {})
|
||||
w = wrappers.setdefault(wrapper_type, {}).setdefault(key, [])
|
||||
w.append(wrapper)
|
||||
|
||||
def get_wrappers_with_key(wrapper_type: str, key: str, transformer_options: dict, is_model_options=False):
|
||||
if is_model_options:
|
||||
transformer_options = transformer_options.get("transformer_options", {})
|
||||
w_list = []
|
||||
wrappers: dict[str, list] = transformer_options.get("wrappers", {})
|
||||
w_list.extend(wrappers.get(wrapper_type, {}).get(key, []))
|
||||
return w_list
|
||||
|
||||
def get_all_wrappers(wrapper_type: str, transformer_options: dict, is_model_options=False):
|
||||
if is_model_options:
|
||||
transformer_options = transformer_options.get("transformer_options", {})
|
||||
w_list = []
|
||||
wrappers: dict[str, list] = transformer_options.get("wrappers", {})
|
||||
for w in wrappers.get(wrapper_type, {}).values():
|
||||
w_list.extend(w)
|
||||
return w_list
|
||||
|
||||
class WrapperExecutor:
|
||||
"""Handles call stack of wrappers around a function in an ordered manner."""
|
||||
def __init__(self, original: Callable, class_obj: object, wrappers: list[Callable], idx: int):
|
||||
# NOTE: class_obj exists so that wrappers surrounding a class method can access
|
||||
# the class instance at runtime via executor.class_obj
|
||||
self.original = original
|
||||
self.class_obj = class_obj
|
||||
self.wrappers = wrappers.copy()
|
||||
self.idx = idx
|
||||
self.is_last = idx == len(wrappers)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
"""Calls the next wrapper or original function, whichever is appropriate."""
|
||||
new_executor = self._create_next_executor()
|
||||
return new_executor.execute(*args, **kwargs)
|
||||
|
||||
def execute(self, *args, **kwargs):
|
||||
"""Used to initiate executor internally - DO NOT use this if you received executor in wrapper."""
|
||||
args = list(args)
|
||||
kwargs = dict(kwargs)
|
||||
if self.is_last:
|
||||
return self.original(*args, **kwargs)
|
||||
return self.wrappers[self.idx](self, *args, **kwargs)
|
||||
|
||||
def _create_next_executor(self) -> 'WrapperExecutor':
|
||||
new_idx = self.idx + 1
|
||||
if new_idx > len(self.wrappers):
|
||||
raise Exception(f"Wrapper idx exceeded available wrappers; something went very wrong.")
|
||||
if self.class_obj is None:
|
||||
return WrapperExecutor.new_executor(self.original, self.wrappers, new_idx)
|
||||
return WrapperExecutor.new_class_executor(self.original, self.class_obj, self.wrappers, new_idx)
|
||||
|
||||
@classmethod
|
||||
def new_executor(cls, original: Callable, wrappers: list[Callable], idx=0):
|
||||
return cls(original, class_obj=None, wrappers=wrappers, idx=idx)
|
||||
|
||||
@classmethod
|
||||
def new_class_executor(cls, original: Callable, class_obj: object, wrappers: list[Callable], idx=0):
|
||||
return cls(original, class_obj, wrappers, idx=idx)
|
||||
|
||||
class PatcherInjection:
|
||||
def __init__(self, inject: Callable, eject: Callable):
|
||||
self.inject = inject
|
||||
self.eject = eject
|
||||
|
||||
def copy_nested_dicts(input_dict: dict):
|
||||
new_dict = input_dict.copy()
|
||||
for key, value in input_dict.items():
|
||||
if isinstance(value, dict):
|
||||
new_dict[key] = copy_nested_dicts(value)
|
||||
elif isinstance(value, list):
|
||||
new_dict[key] = value.copy()
|
||||
return new_dict
|
||||
|
||||
def merge_nested_dicts(dict1: dict, dict2: dict, copy_dict1=True):
|
||||
if copy_dict1:
|
||||
merged_dict = copy_nested_dicts(dict1)
|
||||
else:
|
||||
merged_dict = dict1
|
||||
for key, value in dict2.items():
|
||||
if isinstance(value, dict):
|
||||
curr_value = merged_dict.setdefault(key, {})
|
||||
merged_dict[key] = merge_nested_dicts(value, curr_value)
|
||||
elif isinstance(value, list):
|
||||
merged_dict.setdefault(key, []).extend(value)
|
||||
else:
|
||||
merged_dict[key] = value
|
||||
return merged_dict
|
||||
@@ -1,22 +1,61 @@
|
||||
from __future__ import annotations
|
||||
import uuid
|
||||
import torch
|
||||
import comfy.model_management
|
||||
import comfy.conds
|
||||
import comfy.utils
|
||||
import comfy.hooks
|
||||
import comfy.patcher_extension
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
from comfy.model_base import BaseModel
|
||||
from comfy.controlnet import ControlBase
|
||||
|
||||
def prepare_mask(noise_mask, shape, device):
|
||||
"""ensures noise mask is of proper dimensions"""
|
||||
noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear")
|
||||
noise_mask = torch.cat([noise_mask] * shape[1], dim=1)
|
||||
noise_mask = comfy.utils.repeat_to_batch_size(noise_mask, shape[0])
|
||||
noise_mask = noise_mask.to(device)
|
||||
return noise_mask
|
||||
return comfy.utils.reshape_mask(noise_mask, shape).to(device)
|
||||
|
||||
def get_models_from_cond(cond, model_type):
|
||||
models = []
|
||||
for c in cond:
|
||||
if model_type in c:
|
||||
models += [c[model_type]]
|
||||
if isinstance(c[model_type], list):
|
||||
models += c[model_type]
|
||||
else:
|
||||
models += [c[model_type]]
|
||||
return models
|
||||
|
||||
def get_hooks_from_cond(cond, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]]):
|
||||
# get hooks from conds, and collect cnets so they can be checked for extra_hooks
|
||||
cnets: list[ControlBase] = []
|
||||
for c in cond:
|
||||
if 'hooks' in c:
|
||||
for hook in c['hooks'].hooks:
|
||||
hook: comfy.hooks.Hook
|
||||
with_type = hooks_dict.setdefault(hook.hook_type, {})
|
||||
with_type[hook] = None
|
||||
if 'control' in c:
|
||||
cnets.append(c['control'])
|
||||
|
||||
def get_extra_hooks_from_cnet(cnet: ControlBase, _list: list):
|
||||
if cnet.extra_hooks is not None:
|
||||
_list.append(cnet.extra_hooks)
|
||||
if cnet.previous_controlnet is None:
|
||||
return _list
|
||||
return get_extra_hooks_from_cnet(cnet.previous_controlnet, _list)
|
||||
|
||||
hooks_list = []
|
||||
cnets = set(cnets)
|
||||
for base_cnet in cnets:
|
||||
get_extra_hooks_from_cnet(base_cnet, hooks_list)
|
||||
extra_hooks = comfy.hooks.HookGroup.combine_all_hooks(hooks_list)
|
||||
if extra_hooks is not None:
|
||||
for hook in extra_hooks.hooks:
|
||||
with_type = hooks_dict.setdefault(hook.hook_type, {})
|
||||
with_type[hook] = None
|
||||
|
||||
return hooks_dict
|
||||
|
||||
def convert_cond(cond):
|
||||
out = []
|
||||
for c in cond:
|
||||
@@ -26,17 +65,22 @@ def convert_cond(cond):
|
||||
model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0]) #TODO: remove
|
||||
temp["cross_attn"] = c[0]
|
||||
temp["model_conds"] = model_conds
|
||||
temp["uuid"] = uuid.uuid4()
|
||||
out.append(temp)
|
||||
return out
|
||||
|
||||
def get_additional_models(conds, dtype):
|
||||
"""loads additional models in conditioning"""
|
||||
cnets = []
|
||||
cnets: list[ControlBase] = []
|
||||
gligen = []
|
||||
add_models = []
|
||||
hooks: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]] = {}
|
||||
|
||||
for k in conds:
|
||||
cnets += get_models_from_cond(conds[k], "control")
|
||||
gligen += get_models_from_cond(conds[k], "gligen")
|
||||
add_models += get_models_from_cond(conds[k], "additional_models")
|
||||
get_hooks_from_cond(conds[k], hooks)
|
||||
|
||||
control_nets = set(cnets)
|
||||
|
||||
@@ -47,7 +91,9 @@ def get_additional_models(conds, dtype):
|
||||
inference_memory += m.inference_memory_requirements(dtype)
|
||||
|
||||
gligen = [x[1] for x in gligen]
|
||||
models = control_models + gligen
|
||||
hook_models = [x.model for x in hooks.get(comfy.hooks.EnumHookType.AddModels, {}).keys()]
|
||||
models = control_models + gligen + add_models + hook_models
|
||||
|
||||
return models, inference_memory
|
||||
|
||||
def cleanup_additional_models(models):
|
||||
@@ -57,10 +103,11 @@ def cleanup_additional_models(models):
|
||||
m.cleanup()
|
||||
|
||||
|
||||
def prepare_sampling(model, noise_shape, conds):
|
||||
def prepare_sampling(model: 'ModelPatcher', noise_shape, conds):
|
||||
device = model.load_device
|
||||
real_model = None
|
||||
real_model: 'BaseModel' = None
|
||||
models, inference_memory = get_additional_models(conds, model.model_dtype())
|
||||
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
|
||||
memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory
|
||||
minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory
|
||||
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required)
|
||||
@@ -76,3 +123,14 @@ def cleanup_models(conds, models):
|
||||
control_cleanup += get_models_from_cond(conds[k], "control")
|
||||
|
||||
cleanup_additional_models(set(control_cleanup))
|
||||
|
||||
def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
|
||||
# check for hooks in conds - if not registered, see if can be applied
|
||||
hooks = {}
|
||||
for k in conds:
|
||||
get_hooks_from_cond(conds[k], hooks)
|
||||
# add wrappers and callbacks from ModelPatcher to transformer_options
|
||||
model_options["transformer_options"]["wrappers"] = comfy.patcher_extension.copy_nested_dicts(model.wrappers)
|
||||
model_options["transformer_options"]["callbacks"] = comfy.patcher_extension.copy_nested_dicts(model.callbacks)
|
||||
# register hooks on model/model_options
|
||||
model.register_all_hook_patches(hooks, comfy.hooks.EnumWeightTarget.Model, model_options)
|
||||
|
||||
@@ -1,12 +1,22 @@
|
||||
from __future__ import annotations
|
||||
from .k_diffusion import sampling as k_diffusion_sampling
|
||||
from .extra_samplers import uni_pc
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
from comfy.model_base import BaseModel
|
||||
from comfy.controlnet import ControlBase
|
||||
import torch
|
||||
import collections
|
||||
from comfy import model_management
|
||||
import math
|
||||
import logging
|
||||
import comfy.samplers
|
||||
import comfy.sampler_helpers
|
||||
import scipy
|
||||
import comfy.model_patcher
|
||||
import comfy.patcher_extension
|
||||
import comfy.hooks
|
||||
import scipy.stats
|
||||
import numpy
|
||||
|
||||
def get_area_and_mult(conds, x_in, timestep_in):
|
||||
@@ -70,6 +80,7 @@ def get_area_and_mult(conds, x_in, timestep_in):
|
||||
for c in model_conds:
|
||||
conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)
|
||||
|
||||
hooks = conds.get('hooks', None)
|
||||
control = conds.get('control', None)
|
||||
|
||||
patches = None
|
||||
@@ -85,8 +96,8 @@ def get_area_and_mult(conds, x_in, timestep_in):
|
||||
|
||||
patches['middle_patch'] = [gligen_patch]
|
||||
|
||||
cond_obj = collections.namedtuple('cond_obj', ['input_x', 'mult', 'conditioning', 'area', 'control', 'patches'])
|
||||
return cond_obj(input_x, mult, conditioning, area, control, patches)
|
||||
cond_obj = collections.namedtuple('cond_obj', ['input_x', 'mult', 'conditioning', 'area', 'control', 'patches', 'uuid', 'hooks'])
|
||||
return cond_obj(input_x, mult, conditioning, area, control, patches, conds['uuid'], hooks)
|
||||
|
||||
def cond_equal_size(c1, c2):
|
||||
if c1 is c2:
|
||||
@@ -138,110 +149,184 @@ def cond_cat(c_list):
|
||||
|
||||
return out
|
||||
|
||||
def calc_cond_batch(model, conds, x_in, timestep, model_options):
|
||||
def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]], default_conds: list[list[dict]], x_in, timestep):
|
||||
# need to figure out remaining unmasked area for conds
|
||||
default_mults = []
|
||||
for _ in default_conds:
|
||||
default_mults.append(torch.ones_like(x_in))
|
||||
# look through each finalized cond in hooked_to_run for 'mult' and subtract it from each cond
|
||||
for lora_hooks, to_run in hooked_to_run.items():
|
||||
for cond_obj, i in to_run:
|
||||
# if no default_cond for cond_type, do nothing
|
||||
if len(default_conds[i]) == 0:
|
||||
continue
|
||||
area: list[int] = cond_obj.area
|
||||
if area is not None:
|
||||
curr_default_mult: torch.Tensor = default_mults[i]
|
||||
dims = len(area) // 2
|
||||
for i in range(dims):
|
||||
curr_default_mult = curr_default_mult.narrow(i + 2, area[i + dims], area[i])
|
||||
curr_default_mult -= cond_obj.mult
|
||||
else:
|
||||
default_mults[i] -= cond_obj.mult
|
||||
# for each default_mult, ReLU to make negatives=0, and then check for any nonzeros
|
||||
for i, mult in enumerate(default_mults):
|
||||
# if no default_cond for cond type, do nothing
|
||||
if len(default_conds[i]) == 0:
|
||||
continue
|
||||
torch.nn.functional.relu(mult, inplace=True)
|
||||
# if mult is all zeros, then don't add default_cond
|
||||
if torch.max(mult) == 0.0:
|
||||
continue
|
||||
|
||||
cond = default_conds[i]
|
||||
for x in cond:
|
||||
# do get_area_and_mult to get all the expected values
|
||||
p = comfy.samplers.get_area_and_mult(x, x_in, timestep)
|
||||
if p is None:
|
||||
continue
|
||||
# replace p's mult with calculated mult
|
||||
p = p._replace(mult=mult)
|
||||
if p.hooks is not None:
|
||||
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks)
|
||||
hooked_to_run.setdefault(p.hooks, list())
|
||||
hooked_to_run[p.hooks] += [(p, i)]
|
||||
|
||||
def calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
|
||||
executor = comfy.patcher_extension.WrapperExecutor.new_executor(
|
||||
_calc_cond_batch,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, model_options, is_model_options=True)
|
||||
)
|
||||
return executor.execute(model, conds, x_in, timestep, model_options)
|
||||
|
||||
def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
|
||||
out_conds = []
|
||||
out_counts = []
|
||||
to_run = []
|
||||
# separate conds by matching hooks
|
||||
hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]] = {}
|
||||
default_conds = []
|
||||
has_default_conds = False
|
||||
|
||||
for i in range(len(conds)):
|
||||
out_conds.append(torch.zeros_like(x_in))
|
||||
out_counts.append(torch.ones_like(x_in) * 1e-37)
|
||||
|
||||
cond = conds[i]
|
||||
default_c = []
|
||||
if cond is not None:
|
||||
for x in cond:
|
||||
p = get_area_and_mult(x, x_in, timestep)
|
||||
if 'default' in x:
|
||||
default_c.append(x)
|
||||
has_default_conds = True
|
||||
continue
|
||||
p = comfy.samplers.get_area_and_mult(x, x_in, timestep)
|
||||
if p is None:
|
||||
continue
|
||||
if p.hooks is not None:
|
||||
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks)
|
||||
hooked_to_run.setdefault(p.hooks, list())
|
||||
hooked_to_run[p.hooks] += [(p, i)]
|
||||
default_conds.append(default_c)
|
||||
|
||||
to_run += [(p, i)]
|
||||
if has_default_conds:
|
||||
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep)
|
||||
|
||||
while len(to_run) > 0:
|
||||
first = to_run[0]
|
||||
first_shape = first[0][0].shape
|
||||
to_batch_temp = []
|
||||
for x in range(len(to_run)):
|
||||
if can_concat_cond(to_run[x][0], first[0]):
|
||||
to_batch_temp += [x]
|
||||
model.current_patcher.prepare_state(timestep)
|
||||
|
||||
to_batch_temp.reverse()
|
||||
to_batch = to_batch_temp[:1]
|
||||
# run every hooked_to_run separately
|
||||
for hooks, to_run in hooked_to_run.items():
|
||||
while len(to_run) > 0:
|
||||
first = to_run[0]
|
||||
first_shape = first[0][0].shape
|
||||
to_batch_temp = []
|
||||
for x in range(len(to_run)):
|
||||
if can_concat_cond(to_run[x][0], first[0]):
|
||||
to_batch_temp += [x]
|
||||
|
||||
free_memory = model_management.get_free_memory(x_in.device)
|
||||
for i in range(1, len(to_batch_temp) + 1):
|
||||
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
|
||||
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
|
||||
if model.memory_required(input_shape) * 1.5 < free_memory:
|
||||
to_batch = batch_amount
|
||||
break
|
||||
to_batch_temp.reverse()
|
||||
to_batch = to_batch_temp[:1]
|
||||
|
||||
input_x = []
|
||||
mult = []
|
||||
c = []
|
||||
cond_or_uncond = []
|
||||
area = []
|
||||
control = None
|
||||
patches = None
|
||||
for x in to_batch:
|
||||
o = to_run.pop(x)
|
||||
p = o[0]
|
||||
input_x.append(p.input_x)
|
||||
mult.append(p.mult)
|
||||
c.append(p.conditioning)
|
||||
area.append(p.area)
|
||||
cond_or_uncond.append(o[1])
|
||||
control = p.control
|
||||
patches = p.patches
|
||||
free_memory = model_management.get_free_memory(x_in.device)
|
||||
for i in range(1, len(to_batch_temp) + 1):
|
||||
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
|
||||
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
|
||||
if model.memory_required(input_shape) * 1.5 < free_memory:
|
||||
to_batch = batch_amount
|
||||
break
|
||||
|
||||
batch_chunks = len(cond_or_uncond)
|
||||
input_x = torch.cat(input_x)
|
||||
c = cond_cat(c)
|
||||
timestep_ = torch.cat([timestep] * batch_chunks)
|
||||
input_x = []
|
||||
mult = []
|
||||
c = []
|
||||
cond_or_uncond = []
|
||||
uuids = []
|
||||
area = []
|
||||
control = None
|
||||
patches = None
|
||||
for x in to_batch:
|
||||
o = to_run.pop(x)
|
||||
p = o[0]
|
||||
input_x.append(p.input_x)
|
||||
mult.append(p.mult)
|
||||
c.append(p.conditioning)
|
||||
area.append(p.area)
|
||||
cond_or_uncond.append(o[1])
|
||||
uuids.append(p.uuid)
|
||||
control = p.control
|
||||
patches = p.patches
|
||||
|
||||
if control is not None:
|
||||
c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond))
|
||||
batch_chunks = len(cond_or_uncond)
|
||||
input_x = torch.cat(input_x)
|
||||
c = cond_cat(c)
|
||||
timestep_ = torch.cat([timestep] * batch_chunks)
|
||||
|
||||
transformer_options = {}
|
||||
if 'transformer_options' in model_options:
|
||||
transformer_options = model_options['transformer_options'].copy()
|
||||
transformer_options = model.current_patcher.apply_hooks(hooks=hooks)
|
||||
if 'transformer_options' in model_options:
|
||||
transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options,
|
||||
model_options['transformer_options'],
|
||||
copy_dict1=False)
|
||||
|
||||
if patches is not None:
|
||||
if "patches" in transformer_options:
|
||||
cur_patches = transformer_options["patches"].copy()
|
||||
for p in patches:
|
||||
if p in cur_patches:
|
||||
cur_patches[p] = cur_patches[p] + patches[p]
|
||||
else:
|
||||
cur_patches[p] = patches[p]
|
||||
transformer_options["patches"] = cur_patches
|
||||
if patches is not None:
|
||||
# TODO: replace with merge_nested_dicts function
|
||||
if "patches" in transformer_options:
|
||||
cur_patches = transformer_options["patches"].copy()
|
||||
for p in patches:
|
||||
if p in cur_patches:
|
||||
cur_patches[p] = cur_patches[p] + patches[p]
|
||||
else:
|
||||
cur_patches[p] = patches[p]
|
||||
transformer_options["patches"] = cur_patches
|
||||
else:
|
||||
transformer_options["patches"] = patches
|
||||
|
||||
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
|
||||
transformer_options["uuids"] = uuids[:]
|
||||
transformer_options["sigmas"] = timestep
|
||||
|
||||
c['transformer_options'] = transformer_options
|
||||
|
||||
if control is not None:
|
||||
c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options)
|
||||
|
||||
if 'model_function_wrapper' in model_options:
|
||||
output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
|
||||
else:
|
||||
transformer_options["patches"] = patches
|
||||
output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
|
||||
|
||||
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
|
||||
transformer_options["sigmas"] = timestep
|
||||
|
||||
c['transformer_options'] = transformer_options
|
||||
|
||||
if 'model_function_wrapper' in model_options:
|
||||
output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
|
||||
else:
|
||||
output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
|
||||
|
||||
for o in range(batch_chunks):
|
||||
cond_index = cond_or_uncond[o]
|
||||
a = area[o]
|
||||
if a is None:
|
||||
out_conds[cond_index] += output[o] * mult[o]
|
||||
out_counts[cond_index] += mult[o]
|
||||
else:
|
||||
out_c = out_conds[cond_index]
|
||||
out_cts = out_counts[cond_index]
|
||||
dims = len(a) // 2
|
||||
for i in range(dims):
|
||||
out_c = out_c.narrow(i + 2, a[i + dims], a[i])
|
||||
out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
|
||||
out_c += output[o] * mult[o]
|
||||
out_cts += mult[o]
|
||||
for o in range(batch_chunks):
|
||||
cond_index = cond_or_uncond[o]
|
||||
a = area[o]
|
||||
if a is None:
|
||||
out_conds[cond_index] += output[o] * mult[o]
|
||||
out_counts[cond_index] += mult[o]
|
||||
else:
|
||||
out_c = out_conds[cond_index]
|
||||
out_cts = out_counts[cond_index]
|
||||
dims = len(a) // 2
|
||||
for i in range(dims):
|
||||
out_c = out_c.narrow(i + 2, a[i + dims], a[i])
|
||||
out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
|
||||
out_c += output[o] * mult[o]
|
||||
out_cts += mult[o]
|
||||
|
||||
for i in range(len(out_conds)):
|
||||
out_conds[i] /= out_counts[i]
|
||||
@@ -358,11 +443,35 @@ def beta_scheduler(model_sampling, steps, alpha=0.6, beta=0.6):
|
||||
ts = numpy.rint(scipy.stats.beta.ppf(ts, alpha, beta) * total_timesteps)
|
||||
|
||||
sigs = []
|
||||
last_t = -1
|
||||
for t in ts:
|
||||
sigs += [float(model_sampling.sigmas[int(t)])]
|
||||
if t != last_t:
|
||||
sigs += [float(model_sampling.sigmas[int(t)])]
|
||||
last_t = t
|
||||
sigs += [0.0]
|
||||
return torch.FloatTensor(sigs)
|
||||
|
||||
# from: https://github.com/genmoai/models/blob/main/src/mochi_preview/infer.py#L41
|
||||
def linear_quadratic_schedule(model_sampling, steps, threshold_noise=0.025, linear_steps=None):
|
||||
if steps == 1:
|
||||
sigma_schedule = [1.0, 0.0]
|
||||
else:
|
||||
if linear_steps is None:
|
||||
linear_steps = steps // 2
|
||||
linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)]
|
||||
threshold_noise_step_diff = linear_steps - threshold_noise * steps
|
||||
quadratic_steps = steps - linear_steps
|
||||
quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps ** 2)
|
||||
linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps ** 2)
|
||||
const = quadratic_coef * (linear_steps ** 2)
|
||||
quadratic_sigma_schedule = [
|
||||
quadratic_coef * (i ** 2) + linear_coef * i + const
|
||||
for i in range(linear_steps, steps)
|
||||
]
|
||||
sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0]
|
||||
sigma_schedule = [1.0 - x for x in sigma_schedule]
|
||||
return torch.FloatTensor(sigma_schedule) * model_sampling.sigma_max.cpu()
|
||||
|
||||
def get_mask_aabb(masks):
|
||||
if masks.numel() == 0:
|
||||
return torch.zeros((0, 4), device=masks.device, dtype=torch.int)
|
||||
@@ -476,10 +585,15 @@ def calculate_start_end_timesteps(model, conds):
|
||||
|
||||
timestep_start = None
|
||||
timestep_end = None
|
||||
if 'start_percent' in x:
|
||||
timestep_start = s.percent_to_sigma(x['start_percent'])
|
||||
if 'end_percent' in x:
|
||||
timestep_end = s.percent_to_sigma(x['end_percent'])
|
||||
# handle clip hook schedule, if needed
|
||||
if 'clip_start_percent' in x:
|
||||
timestep_start = s.percent_to_sigma(max(x['clip_start_percent'], x.get('start_percent', 0.0)))
|
||||
timestep_end = s.percent_to_sigma(min(x['clip_end_percent'], x.get('end_percent', 1.0)))
|
||||
else:
|
||||
if 'start_percent' in x:
|
||||
timestep_start = s.percent_to_sigma(x['start_percent'])
|
||||
if 'end_percent' in x:
|
||||
timestep_end = s.percent_to_sigma(x['end_percent'])
|
||||
|
||||
if (timestep_start is not None) or (timestep_end is not None):
|
||||
n = x.copy()
|
||||
@@ -570,8 +684,8 @@ class Sampler:
|
||||
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
|
||||
|
||||
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
|
||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
||||
"ipndm", "ipndm_v", "deis"]
|
||||
|
||||
class KSAMPLER(Sampler):
|
||||
@@ -649,6 +763,12 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N
|
||||
if k != kk:
|
||||
create_cond_with_same_area_if_none(conds[kk], c)
|
||||
|
||||
for k in conds:
|
||||
for c in conds[k]:
|
||||
if 'hooks' in c:
|
||||
for hook in c['hooks'].hooks:
|
||||
hook.initialize_timesteps(model)
|
||||
|
||||
for k in conds:
|
||||
pre_run_control(model, conds[k])
|
||||
|
||||
@@ -661,9 +781,46 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N
|
||||
|
||||
return conds
|
||||
|
||||
|
||||
def preprocess_conds_hooks(conds: dict[str, list[dict[str]]]):
|
||||
# determine which ControlNets have extra_hooks that should be combined with normal hooks
|
||||
hook_replacement: dict[tuple[ControlBase, comfy.hooks.HookGroup], list[dict]] = {}
|
||||
for k in conds:
|
||||
for kk in conds[k]:
|
||||
if 'control' in kk:
|
||||
control: 'ControlBase' = kk['control']
|
||||
extra_hooks = control.get_extra_hooks()
|
||||
if len(extra_hooks) > 0:
|
||||
hooks: comfy.hooks.HookGroup = kk.get('hooks', None)
|
||||
to_replace = hook_replacement.setdefault((control, hooks), [])
|
||||
to_replace.append(kk)
|
||||
# if nothing to replace, do nothing
|
||||
if len(hook_replacement) == 0:
|
||||
return
|
||||
|
||||
# for optimal sampling performance, common ControlNets + hook combos should have identical hooks
|
||||
# on the cond dicts
|
||||
for key, conds_to_modify in hook_replacement.items():
|
||||
control = key[0]
|
||||
hooks = key[1]
|
||||
hooks = comfy.hooks.HookGroup.combine_all_hooks(control.get_extra_hooks() + [hooks])
|
||||
# if combined hooks are not None, set as new hooks for all relevant conds
|
||||
if hooks is not None:
|
||||
for cond in conds_to_modify:
|
||||
cond['hooks'] = hooks
|
||||
|
||||
|
||||
def get_total_hook_groups_in_conds(conds: dict[str, list[dict[str]]]):
|
||||
hooks_set = set()
|
||||
for k in conds:
|
||||
for kk in conds[k]:
|
||||
hooks_set.add(kk.get('hooks', None))
|
||||
return len(hooks_set)
|
||||
|
||||
|
||||
class CFGGuider:
|
||||
def __init__(self, model_patcher):
|
||||
self.model_patcher = model_patcher
|
||||
self.model_patcher: 'ModelPatcher' = model_patcher
|
||||
self.model_options = model_patcher.model_options
|
||||
self.original_conds = {}
|
||||
self.cfg = 1.0
|
||||
@@ -690,19 +847,17 @@ class CFGGuider:
|
||||
|
||||
self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed)
|
||||
|
||||
extra_args = {"model_options": self.model_options, "seed":seed}
|
||||
extra_args = {"model_options": comfy.model_patcher.create_model_options_clone(self.model_options), "seed": seed}
|
||||
|
||||
samples = sampler.sample(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
|
||||
executor = comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
sampler.sample,
|
||||
sampler,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE, extra_args["model_options"], is_model_options=True)
|
||||
)
|
||||
samples = executor.execute(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
|
||||
return self.inner_model.process_latent_out(samples.to(torch.float32))
|
||||
|
||||
def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||
if sigmas.shape[-1] == 0:
|
||||
return latent_image
|
||||
|
||||
self.conds = {}
|
||||
for k in self.original_conds:
|
||||
self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))
|
||||
|
||||
def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds)
|
||||
device = self.model_patcher.load_device
|
||||
|
||||
@@ -713,14 +868,48 @@ class CFGGuider:
|
||||
latent_image = latent_image.to(device)
|
||||
sigmas = sigmas.to(device)
|
||||
|
||||
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
||||
try:
|
||||
self.model_patcher.pre_run()
|
||||
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
||||
finally:
|
||||
self.model_patcher.cleanup()
|
||||
|
||||
comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
|
||||
del self.inner_model
|
||||
del self.conds
|
||||
del self.loaded_models
|
||||
return output
|
||||
|
||||
def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||
if sigmas.shape[-1] == 0:
|
||||
return latent_image
|
||||
|
||||
self.conds = {}
|
||||
for k in self.original_conds:
|
||||
self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))
|
||||
preprocess_conds_hooks(self.conds)
|
||||
|
||||
try:
|
||||
orig_model_options = self.model_options
|
||||
self.model_options = comfy.model_patcher.create_model_options_clone(self.model_options)
|
||||
# if one hook type (or just None), then don't bother caching weights for hooks (will never change after first step)
|
||||
orig_hook_mode = self.model_patcher.hook_mode
|
||||
if get_total_hook_groups_in_conds(self.conds) <= 1:
|
||||
self.model_patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
|
||||
comfy.sampler_helpers.prepare_model_patcher(self.model_patcher, self.conds, self.model_options)
|
||||
executor = comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self.outer_sample,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, self.model_options, is_model_options=True)
|
||||
)
|
||||
output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
||||
finally:
|
||||
self.model_options = orig_model_options
|
||||
self.model_patcher.hook_mode = orig_hook_mode
|
||||
self.model_patcher.restore_hook_patches()
|
||||
|
||||
del self.conds
|
||||
return output
|
||||
|
||||
|
||||
def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||
cfg_guider = CFGGuider(model)
|
||||
@@ -729,7 +918,7 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
|
||||
return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
||||
|
||||
|
||||
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "beta"]
|
||||
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "beta", "linear_quadratic"]
|
||||
SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]
|
||||
|
||||
def calculate_sigmas(model_sampling, scheduler_name, steps):
|
||||
@@ -747,6 +936,8 @@ def calculate_sigmas(model_sampling, scheduler_name, steps):
|
||||
sigmas = normal_scheduler(model_sampling, steps, sgm=True)
|
||||
elif scheduler_name == "beta":
|
||||
sigmas = beta_scheduler(model_sampling, steps)
|
||||
elif scheduler_name == "linear_quadratic":
|
||||
sigmas = linear_quadratic_schedule(model_sampling, steps)
|
||||
else:
|
||||
logging.error("error invalid scheduler {}".format(scheduler_name))
|
||||
return sigmas
|
||||
|
||||
357
comfy/sd.py
357
comfy/sd.py
@@ -1,12 +1,16 @@
|
||||
from __future__ import annotations
|
||||
import torch
|
||||
from enum import Enum
|
||||
import logging
|
||||
|
||||
from comfy import model_management
|
||||
from comfy.utils import ProgressBar
|
||||
from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
|
||||
from .ldm.cascade.stage_a import StageA
|
||||
from .ldm.cascade.stage_c_coder import StageC_coder
|
||||
from .ldm.audio.autoencoder import AudioOobleckVAE
|
||||
import comfy.ldm.genmo.vae.model
|
||||
import comfy.ldm.lightricks.vae.causal_video_autoencoder
|
||||
import yaml
|
||||
|
||||
import comfy.utils
|
||||
@@ -24,13 +28,19 @@ import comfy.text_encoders.sa_t5
|
||||
import comfy.text_encoders.aura_t5
|
||||
import comfy.text_encoders.hydit
|
||||
import comfy.text_encoders.flux
|
||||
import comfy.text_encoders.long_clipl
|
||||
import comfy.text_encoders.genmo
|
||||
import comfy.text_encoders.lt
|
||||
|
||||
import comfy.model_patcher
|
||||
import comfy.lora
|
||||
import comfy.lora_convert
|
||||
import comfy.hooks
|
||||
import comfy.t2i_adapter.adapter
|
||||
import comfy.supported_models_base
|
||||
import comfy.taesd.taesd
|
||||
|
||||
import comfy.ldm.flux.redux
|
||||
|
||||
def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
|
||||
key_map = {}
|
||||
if model is not None:
|
||||
@@ -38,6 +48,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
|
||||
if clip is not None:
|
||||
key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
|
||||
|
||||
lora = comfy.lora_convert.convert_lora(lora)
|
||||
loaded = comfy.lora.load_lora(lora, key_map)
|
||||
if model is not None:
|
||||
new_modelpatcher = model.clone()
|
||||
@@ -62,29 +73,42 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
|
||||
|
||||
|
||||
class CLIP:
|
||||
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}):
|
||||
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, model_options={}):
|
||||
if no_init:
|
||||
return
|
||||
params = target.params.copy()
|
||||
clip = target.clip
|
||||
tokenizer = target.tokenizer
|
||||
|
||||
load_device = model_management.text_encoder_device()
|
||||
offload_device = model_management.text_encoder_offload_device()
|
||||
params['device'] = offload_device
|
||||
dtype = model_management.text_encoder_dtype(load_device)
|
||||
load_device = model_options.get("load_device", model_management.text_encoder_device())
|
||||
offload_device = model_options.get("offload_device", model_management.text_encoder_offload_device())
|
||||
dtype = model_options.get("dtype", None)
|
||||
if dtype is None:
|
||||
dtype = model_management.text_encoder_dtype(load_device)
|
||||
|
||||
params['dtype'] = dtype
|
||||
params['device'] = model_options.get("initial_device", model_management.text_encoder_initial_device(load_device, offload_device, parameters * model_management.dtype_size(dtype)))
|
||||
params['model_options'] = model_options
|
||||
|
||||
self.cond_stage_model = clip(**(params))
|
||||
|
||||
for dt in self.cond_stage_model.dtypes:
|
||||
if not model_management.supports_cast(load_device, dt):
|
||||
load_device = offload_device
|
||||
if params['device'] != offload_device:
|
||||
self.cond_stage_model.to(offload_device)
|
||||
logging.warning("Had to shift TE back.")
|
||||
|
||||
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
|
||||
self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
|
||||
self.patcher.is_clip = True
|
||||
self.apply_hooks_to_conds = None
|
||||
if params['device'] == load_device:
|
||||
model_management.load_models_gpu([self.patcher], force_full_load=True)
|
||||
self.layer_idx = None
|
||||
logging.debug("CLIP model load device: {}, offload device: {}".format(load_device, offload_device))
|
||||
self.use_clip_schedule = False
|
||||
logging.debug("CLIP model load device: {}, offload device: {}, current: {}".format(load_device, offload_device, params['device']))
|
||||
|
||||
def clone(self):
|
||||
n = CLIP(no_init=True)
|
||||
@@ -92,6 +116,8 @@ class CLIP:
|
||||
n.cond_stage_model = self.cond_stage_model
|
||||
n.tokenizer = self.tokenizer
|
||||
n.layer_idx = self.layer_idx
|
||||
n.use_clip_schedule = self.use_clip_schedule
|
||||
n.apply_hooks_to_conds = self.apply_hooks_to_conds
|
||||
return n
|
||||
|
||||
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
|
||||
@@ -103,6 +129,69 @@ class CLIP:
|
||||
def tokenize(self, text, return_word_ids=False):
|
||||
return self.tokenizer.tokenize_with_weights(text, return_word_ids)
|
||||
|
||||
def add_hooks_to_dict(self, pooled_dict: dict[str]):
|
||||
if self.apply_hooks_to_conds:
|
||||
pooled_dict["hooks"] = self.apply_hooks_to_conds
|
||||
return pooled_dict
|
||||
|
||||
def encode_from_tokens_scheduled(self, tokens, unprojected=False, add_dict: dict[str]={}, show_pbar=True):
|
||||
all_cond_pooled: list[tuple[torch.Tensor, dict[str]]] = []
|
||||
all_hooks = self.patcher.forced_hooks
|
||||
if all_hooks is None or not self.use_clip_schedule:
|
||||
# if no hooks or shouldn't use clip schedule, do unscheduled encode_from_tokens and perform add_dict
|
||||
return_pooled = "unprojected" if unprojected else True
|
||||
pooled_dict = self.encode_from_tokens(tokens, return_pooled=return_pooled, return_dict=True)
|
||||
cond = pooled_dict.pop("cond")
|
||||
# add/update any keys with the provided add_dict
|
||||
pooled_dict.update(add_dict)
|
||||
all_cond_pooled.append([cond, pooled_dict])
|
||||
else:
|
||||
scheduled_keyframes = all_hooks.get_hooks_for_clip_schedule()
|
||||
|
||||
self.cond_stage_model.reset_clip_options()
|
||||
if self.layer_idx is not None:
|
||||
self.cond_stage_model.set_clip_options({"layer": self.layer_idx})
|
||||
if unprojected:
|
||||
self.cond_stage_model.set_clip_options({"projected_pooled": False})
|
||||
|
||||
self.load_model()
|
||||
all_hooks.reset()
|
||||
self.patcher.patch_hooks(None)
|
||||
if show_pbar:
|
||||
pbar = ProgressBar(len(scheduled_keyframes))
|
||||
|
||||
for scheduled_opts in scheduled_keyframes:
|
||||
t_range = scheduled_opts[0]
|
||||
# don't bother encoding any conds outside of start_percent and end_percent bounds
|
||||
if "start_percent" in add_dict:
|
||||
if t_range[1] < add_dict["start_percent"]:
|
||||
continue
|
||||
if "end_percent" in add_dict:
|
||||
if t_range[0] > add_dict["end_percent"]:
|
||||
continue
|
||||
hooks_keyframes = scheduled_opts[1]
|
||||
for hook, keyframe in hooks_keyframes:
|
||||
hook.hook_keyframe._current_keyframe = keyframe
|
||||
# apply appropriate hooks with values that match new hook_keyframe
|
||||
self.patcher.patch_hooks(all_hooks)
|
||||
# perform encoding as normal
|
||||
o = self.cond_stage_model.encode_token_weights(tokens)
|
||||
cond, pooled = o[:2]
|
||||
pooled_dict = {"pooled_output": pooled}
|
||||
# add clip_start_percent and clip_end_percent in pooled
|
||||
pooled_dict["clip_start_percent"] = t_range[0]
|
||||
pooled_dict["clip_end_percent"] = t_range[1]
|
||||
# add/update any keys with the provided add_dict
|
||||
pooled_dict.update(add_dict)
|
||||
# add hooks stored on clip
|
||||
self.add_hooks_to_dict(pooled_dict)
|
||||
all_cond_pooled.append([cond, pooled_dict])
|
||||
if show_pbar:
|
||||
pbar.update(1)
|
||||
model_management.throw_exception_if_processing_interrupted()
|
||||
all_hooks.reset()
|
||||
return all_cond_pooled
|
||||
|
||||
def encode_from_tokens(self, tokens, return_pooled=False, return_dict=False):
|
||||
self.cond_stage_model.reset_clip_options()
|
||||
|
||||
@@ -120,6 +209,7 @@ class CLIP:
|
||||
if len(o) > 2:
|
||||
for k in o[2]:
|
||||
out[k] = o[2][k]
|
||||
self.add_hooks_to_dict(out)
|
||||
return out
|
||||
|
||||
if return_pooled:
|
||||
@@ -160,6 +250,7 @@ class VAE:
|
||||
self.downscale_ratio = 8
|
||||
self.upscale_ratio = 8
|
||||
self.latent_channels = 4
|
||||
self.latent_dim = 2
|
||||
self.output_channels = 3
|
||||
self.process_input = lambda image: image * 2.0 - 1.0
|
||||
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
@@ -229,9 +320,30 @@ class VAE:
|
||||
self.output_channels = 2
|
||||
self.upscale_ratio = 2048
|
||||
self.downscale_ratio = 2048
|
||||
self.latent_dim = 1
|
||||
self.process_output = lambda audio: audio
|
||||
self.process_input = lambda audio: audio
|
||||
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
elif "blocks.2.blocks.3.stack.5.weight" in sd or "decoder.blocks.2.blocks.3.stack.5.weight" in sd or "layers.4.layers.1.attn_block.attn.qkv.weight" in sd or "encoder.layers.4.layers.1.attn_block.attn.qkv.weight" in sd: #genmo mochi vae
|
||||
if "blocks.2.blocks.3.stack.5.weight" in sd:
|
||||
sd = comfy.utils.state_dict_prefix_replace(sd, {"": "decoder."})
|
||||
if "layers.4.layers.1.attn_block.attn.qkv.weight" in sd:
|
||||
sd = comfy.utils.state_dict_prefix_replace(sd, {"": "encoder."})
|
||||
self.first_stage_model = comfy.ldm.genmo.vae.model.VideoVAE()
|
||||
self.latent_channels = 12
|
||||
self.latent_dim = 3
|
||||
self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
|
||||
self.memory_used_encode = lambda shape, dtype: (1.5 * max(shape[2], 7) * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
|
||||
self.upscale_ratio = (lambda a: max(0, a * 6 - 5), 8, 8)
|
||||
self.working_dtypes = [torch.float16, torch.float32]
|
||||
elif "decoder.up_blocks.0.res_blocks.0.conv1.conv.weight" in sd: #lightricks ltxv
|
||||
self.first_stage_model = comfy.ldm.lightricks.vae.causal_video_autoencoder.VideoVAE()
|
||||
self.latent_channels = 128
|
||||
self.latent_dim = 3
|
||||
self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
|
||||
self.memory_used_encode = lambda shape, dtype: (70 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
|
||||
self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 32, 32)
|
||||
self.working_dtypes = [torch.bfloat16, torch.float32]
|
||||
else:
|
||||
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
|
||||
self.first_stage_model = None
|
||||
@@ -287,6 +399,10 @@ class VAE:
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
|
||||
return comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device)
|
||||
|
||||
def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
|
||||
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
|
||||
|
||||
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
|
||||
steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
|
||||
steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
|
||||
@@ -305,6 +421,7 @@ class VAE:
|
||||
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device)
|
||||
|
||||
def decode(self, samples_in):
|
||||
pixel_samples = None
|
||||
try:
|
||||
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
|
||||
@@ -312,38 +429,66 @@ class VAE:
|
||||
batch_number = int(free_memory / memory_used)
|
||||
batch_number = max(1, batch_number)
|
||||
|
||||
pixel_samples = torch.empty((samples_in.shape[0], self.output_channels) + tuple(map(lambda a: a * self.upscale_ratio, samples_in.shape[2:])), device=self.output_device)
|
||||
for x in range(0, samples_in.shape[0], batch_number):
|
||||
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
|
||||
pixel_samples[x:x+batch_number] = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
|
||||
out = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
|
||||
if pixel_samples is None:
|
||||
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
|
||||
pixel_samples[x:x+batch_number] = out
|
||||
except model_management.OOM_EXCEPTION as e:
|
||||
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
||||
if len(samples_in.shape) == 3:
|
||||
dims = samples_in.ndim - 2
|
||||
if dims == 1:
|
||||
pixel_samples = self.decode_tiled_1d(samples_in)
|
||||
else:
|
||||
elif dims == 2:
|
||||
pixel_samples = self.decode_tiled_(samples_in)
|
||||
elif dims == 3:
|
||||
tile = 256 // self.spacial_compression_decode()
|
||||
overlap = tile // 4
|
||||
pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
||||
|
||||
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
|
||||
return pixel_samples
|
||||
|
||||
def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16):
|
||||
model_management.load_model_gpu(self.patcher)
|
||||
output = self.decode_tiled_(samples, tile_x, tile_y, overlap)
|
||||
return output.movedim(1,-1)
|
||||
def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None):
|
||||
memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
|
||||
dims = samples.ndim - 2
|
||||
args = {}
|
||||
if tile_x is not None:
|
||||
args["tile_x"] = tile_x
|
||||
if tile_y is not None:
|
||||
args["tile_y"] = tile_y
|
||||
if overlap is not None:
|
||||
args["overlap"] = overlap
|
||||
|
||||
if dims == 1:
|
||||
args.pop("tile_y")
|
||||
output = self.decode_tiled_1d(samples, **args)
|
||||
elif dims == 2:
|
||||
output = self.decode_tiled_(samples, **args)
|
||||
elif dims == 3:
|
||||
output = self.decode_tiled_3d(samples, **args)
|
||||
return output.movedim(1, -1)
|
||||
|
||||
def encode(self, pixel_samples):
|
||||
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
|
||||
pixel_samples = pixel_samples.movedim(-1,1)
|
||||
pixel_samples = pixel_samples.movedim(-1, 1)
|
||||
if self.latent_dim == 3:
|
||||
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
|
||||
try:
|
||||
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
|
||||
free_memory = model_management.get_free_memory(self.device)
|
||||
batch_number = int(free_memory / memory_used)
|
||||
batch_number = int(free_memory / max(1, memory_used))
|
||||
batch_number = max(1, batch_number)
|
||||
samples = torch.empty((pixel_samples.shape[0], self.latent_channels) + tuple(map(lambda a: a // self.downscale_ratio, pixel_samples.shape[2:])), device=self.output_device)
|
||||
samples = None
|
||||
for x in range(0, pixel_samples.shape[0], batch_number):
|
||||
pixels_in = self.process_input(pixel_samples[x:x+batch_number]).to(self.vae_dtype).to(self.device)
|
||||
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
|
||||
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype).to(self.device)
|
||||
out = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
|
||||
if samples is None:
|
||||
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
|
||||
samples[x:x + batch_number] = out
|
||||
|
||||
except model_management.OOM_EXCEPTION as e:
|
||||
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
|
||||
@@ -364,6 +509,12 @@ class VAE:
|
||||
def get_sd(self):
|
||||
return self.first_stage_model.state_dict()
|
||||
|
||||
def spacial_compression_decode(self):
|
||||
try:
|
||||
return self.upscale_ratio[-1]
|
||||
except:
|
||||
return self.upscale_ratio
|
||||
|
||||
class StyleModel:
|
||||
def __init__(self, model, device="cpu"):
|
||||
self.model = model
|
||||
@@ -377,6 +528,8 @@ def load_style_model(ckpt_path):
|
||||
keys = model_data.keys()
|
||||
if "style_embedding" in keys:
|
||||
model = comfy.t2i_adapter.adapter.StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8)
|
||||
elif "redux_down.weight" in keys:
|
||||
model = comfy.ldm.flux.redux.ReduxImageEncoder()
|
||||
else:
|
||||
raise Exception("invalid style model {}".format(ckpt_path))
|
||||
model.load_state_dict(model_data)
|
||||
@@ -389,11 +542,54 @@ class CLIPType(Enum):
|
||||
STABLE_AUDIO = 4
|
||||
HUNYUAN_DIT = 5
|
||||
FLUX = 6
|
||||
MOCHI = 7
|
||||
LTXV = 8
|
||||
|
||||
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION):
|
||||
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
||||
clip_data = []
|
||||
for p in ckpt_paths:
|
||||
clip_data.append(comfy.utils.load_torch_file(p, safe_load=True))
|
||||
return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)
|
||||
|
||||
|
||||
class TEModel(Enum):
|
||||
CLIP_L = 1
|
||||
CLIP_H = 2
|
||||
CLIP_G = 3
|
||||
T5_XXL = 4
|
||||
T5_XL = 5
|
||||
T5_BASE = 6
|
||||
|
||||
def detect_te_model(sd):
|
||||
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
|
||||
return TEModel.CLIP_G
|
||||
if "text_model.encoder.layers.22.mlp.fc1.weight" in sd:
|
||||
return TEModel.CLIP_H
|
||||
if "text_model.encoder.layers.0.mlp.fc1.weight" in sd:
|
||||
return TEModel.CLIP_L
|
||||
if "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd:
|
||||
weight = sd["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
|
||||
if weight.shape[-1] == 4096:
|
||||
return TEModel.T5_XXL
|
||||
elif weight.shape[-1] == 2048:
|
||||
return TEModel.T5_XL
|
||||
if "encoder.block.0.layer.0.SelfAttention.k.weight" in sd:
|
||||
return TEModel.T5_BASE
|
||||
return None
|
||||
|
||||
|
||||
def t5xxl_detect(clip_data):
|
||||
weight_name = "encoder.block.23.layer.1.DenseReluDense.wi_1.weight"
|
||||
|
||||
for sd in clip_data:
|
||||
if weight_name in sd:
|
||||
return comfy.text_encoders.sd3_clip.t5_xxl_detect(sd)
|
||||
|
||||
return {}
|
||||
|
||||
|
||||
def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
||||
clip_data = state_dicts
|
||||
|
||||
class EmptyClass:
|
||||
pass
|
||||
@@ -408,55 +604,68 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
|
||||
clip_target = EmptyClass()
|
||||
clip_target.params = {}
|
||||
if len(clip_data) == 1:
|
||||
if "text_model.encoder.layers.30.mlp.fc1.weight" in clip_data[0]:
|
||||
te_model = detect_te_model(clip_data[0])
|
||||
if te_model == TEModel.CLIP_G:
|
||||
if clip_type == CLIPType.STABLE_CASCADE:
|
||||
clip_target.clip = sdxl_clip.StableCascadeClipModel
|
||||
clip_target.tokenizer = sdxl_clip.StableCascadeTokenizer
|
||||
elif clip_type == CLIPType.SD3:
|
||||
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=True, t5=False)
|
||||
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||
else:
|
||||
clip_target.clip = sdxl_clip.SDXLRefinerClipModel
|
||||
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
||||
elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]:
|
||||
elif te_model == TEModel.CLIP_H:
|
||||
clip_target.clip = comfy.text_encoders.sd2_clip.SD2ClipModel
|
||||
clip_target.tokenizer = comfy.text_encoders.sd2_clip.SD2Tokenizer
|
||||
elif "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in clip_data[0]:
|
||||
weight = clip_data[0]["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
|
||||
dtype_t5 = weight.dtype
|
||||
if weight.shape[-1] == 4096:
|
||||
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=dtype_t5)
|
||||
elif te_model == TEModel.T5_XXL:
|
||||
if clip_type == CLIPType.SD3:
|
||||
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, **t5xxl_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||
elif weight.shape[-1] == 2048:
|
||||
clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model
|
||||
clip_target.tokenizer = comfy.text_encoders.aura_t5.AuraT5Tokenizer
|
||||
elif "encoder.block.0.layer.0.SelfAttention.k.weight" in clip_data[0]:
|
||||
elif clip_type == CLIPType.LTXV:
|
||||
clip_target.clip = comfy.text_encoders.lt.ltxv_te(**t5xxl_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.lt.LTXVT5Tokenizer
|
||||
else: #CLIPType.MOCHI
|
||||
clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer
|
||||
elif te_model == TEModel.T5_XL:
|
||||
clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model
|
||||
clip_target.tokenizer = comfy.text_encoders.aura_t5.AuraT5Tokenizer
|
||||
elif te_model == TEModel.T5_BASE:
|
||||
clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
|
||||
clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
|
||||
else:
|
||||
clip_target.clip = sd1_clip.SD1ClipModel
|
||||
clip_target.tokenizer = sd1_clip.SD1Tokenizer
|
||||
if clip_type == CLIPType.SD3:
|
||||
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=False, t5=False)
|
||||
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||
else:
|
||||
clip_target.clip = sd1_clip.SD1ClipModel
|
||||
clip_target.tokenizer = sd1_clip.SD1Tokenizer
|
||||
elif len(clip_data) == 2:
|
||||
if clip_type == CLIPType.SD3:
|
||||
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=True, t5=False)
|
||||
te_models = [detect_te_model(clip_data[0]), detect_te_model(clip_data[1])]
|
||||
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=TEModel.CLIP_L in te_models, clip_g=TEModel.CLIP_G in te_models, t5=TEModel.T5_XXL in te_models, **t5xxl_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||
elif clip_type == CLIPType.HUNYUAN_DIT:
|
||||
clip_target.clip = comfy.text_encoders.hydit.HyditModel
|
||||
clip_target.tokenizer = comfy.text_encoders.hydit.HyditTokenizer
|
||||
elif clip_type == CLIPType.FLUX:
|
||||
weight_name = "encoder.block.23.layer.1.DenseReluDense.wi_1.weight"
|
||||
weight = clip_data[0].get(weight_name, clip_data[1].get(weight_name, None))
|
||||
dtype_t5 = None
|
||||
if weight is not None:
|
||||
dtype_t5 = weight.dtype
|
||||
|
||||
clip_target.clip = comfy.text_encoders.flux.flux_clip(dtype_t5=dtype_t5)
|
||||
clip_target.clip = comfy.text_encoders.flux.flux_clip(**t5xxl_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.flux.FluxTokenizer
|
||||
else:
|
||||
clip_target.clip = sdxl_clip.SDXLClipModel
|
||||
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
||||
elif len(clip_data) == 3:
|
||||
clip_target.clip = comfy.text_encoders.sd3_clip.SD3ClipModel
|
||||
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(**t5xxl_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory)
|
||||
parameters = 0
|
||||
tokenizer_data = {}
|
||||
for c in clip_data:
|
||||
parameters += comfy.utils.calculate_parameters(c)
|
||||
tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)
|
||||
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, model_options=model_options)
|
||||
for c in clip_data:
|
||||
m, u = clip.load_sd(c)
|
||||
if len(m) > 0:
|
||||
@@ -498,15 +707,19 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
|
||||
|
||||
return (model, clip, vae)
|
||||
|
||||
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True):
|
||||
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}):
|
||||
sd = comfy.utils.load_torch_file(ckpt_path)
|
||||
sd_keys = sd.keys()
|
||||
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options)
|
||||
if out is None:
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
|
||||
return out
|
||||
|
||||
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}):
|
||||
clip = None
|
||||
clipvision = None
|
||||
vae = None
|
||||
model = None
|
||||
model_patcher = None
|
||||
clip_target = None
|
||||
|
||||
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
|
||||
parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix)
|
||||
@@ -515,13 +728,18 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
|
||||
model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix)
|
||||
if model_config is None:
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
|
||||
return None
|
||||
|
||||
unet_weight_dtype = list(model_config.supported_inference_dtypes)
|
||||
if weight_dtype is not None:
|
||||
if weight_dtype is not None and model_config.scaled_fp8 is None:
|
||||
unet_weight_dtype.append(weight_dtype)
|
||||
|
||||
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
|
||||
model_config.custom_operations = model_options.get("custom_operations", None)
|
||||
unet_dtype = model_options.get("dtype", model_options.get("weight_dtype", None))
|
||||
|
||||
if unet_dtype is None:
|
||||
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
|
||||
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
||||
|
||||
@@ -531,7 +749,6 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
|
||||
if output_model:
|
||||
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
|
||||
offload_device = model_management.unet_offload_device()
|
||||
model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
|
||||
model.load_model_weights(sd, diffusion_model_prefix)
|
||||
|
||||
@@ -545,7 +762,8 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
if clip_target is not None:
|
||||
clip_sd = model_config.process_clip_state_dict(sd)
|
||||
if len(clip_sd) > 0:
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd)
|
||||
parameters = comfy.utils.calculate_parameters(clip_sd)
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, model_options=te_model_options)
|
||||
m, u = clip.load_sd(clip_sd, full_model=True)
|
||||
if len(m) > 0:
|
||||
m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
|
||||
@@ -567,12 +785,13 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
|
||||
if inital_load_device != torch.device("cpu"):
|
||||
logging.info("loaded straight to GPU")
|
||||
model_management.load_model_gpu(model_patcher)
|
||||
model_management.load_models_gpu([model_patcher], force_full_load=True)
|
||||
|
||||
return (model_patcher, clip, vae, clipvision)
|
||||
|
||||
|
||||
def load_unet_state_dict(sd, dtype=None): #load unet in diffusers or regular format
|
||||
def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffusers or regular format
|
||||
dtype = model_options.get("dtype", None)
|
||||
|
||||
#Allow loading unets from checkpoint files
|
||||
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
|
||||
@@ -581,6 +800,8 @@ def load_unet_state_dict(sd, dtype=None): #load unet in diffusers or regular for
|
||||
sd = temp_sd
|
||||
|
||||
parameters = comfy.utils.calculate_parameters(sd)
|
||||
weight_dtype = comfy.utils.weight_dtype(sd)
|
||||
|
||||
load_device = model_management.get_torch_device()
|
||||
model_config = model_detection.model_config_from_unet(sd, "")
|
||||
|
||||
@@ -607,13 +828,21 @@ def load_unet_state_dict(sd, dtype=None): #load unet in diffusers or regular for
|
||||
logging.warning("{} {}".format(diffusers_keys[k], k))
|
||||
|
||||
offload_device = model_management.unet_offload_device()
|
||||
unet_weight_dtype = list(model_config.supported_inference_dtypes)
|
||||
if weight_dtype is not None and model_config.scaled_fp8 is None:
|
||||
unet_weight_dtype.append(weight_dtype)
|
||||
|
||||
if dtype is None:
|
||||
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
|
||||
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
|
||||
else:
|
||||
unet_dtype = dtype
|
||||
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
||||
model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations)
|
||||
if model_options.get("fp8_optimizations", False):
|
||||
model_config.optimizations["fp8"] = True
|
||||
|
||||
model = model_config.get_model(new_sd, "")
|
||||
model = model.to(offload_device)
|
||||
model.load_model_weights(new_sd, "")
|
||||
@@ -622,24 +851,36 @@ def load_unet_state_dict(sd, dtype=None): #load unet in diffusers or regular for
|
||||
logging.info("left over keys in unet: {}".format(left_over))
|
||||
return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
|
||||
|
||||
def load_unet(unet_path, dtype=None):
|
||||
|
||||
def load_diffusion_model(unet_path, model_options={}):
|
||||
sd = comfy.utils.load_torch_file(unet_path)
|
||||
model = load_unet_state_dict(sd, dtype=dtype)
|
||||
model = load_diffusion_model_state_dict(sd, model_options=model_options)
|
||||
if model is None:
|
||||
logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path))
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
|
||||
return model
|
||||
|
||||
def load_unet(unet_path, dtype=None):
|
||||
print("WARNING: the load_unet function has been deprecated and will be removed please switch to: load_diffusion_model")
|
||||
return load_diffusion_model(unet_path, model_options={"dtype": dtype})
|
||||
|
||||
def load_unet_state_dict(sd, dtype=None):
|
||||
print("WARNING: the load_unet_state_dict function has been deprecated and will be removed please switch to: load_diffusion_model_state_dict")
|
||||
return load_diffusion_model_state_dict(sd, model_options={"dtype": dtype})
|
||||
|
||||
def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None, extra_keys={}):
|
||||
clip_sd = None
|
||||
load_models = [model]
|
||||
if clip is not None:
|
||||
load_models.append(clip.load_model())
|
||||
clip_sd = clip.get_sd()
|
||||
vae_sd = None
|
||||
if vae is not None:
|
||||
vae_sd = vae.get_sd()
|
||||
|
||||
model_management.load_models_gpu(load_models, force_patch_weights=True)
|
||||
clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
|
||||
sd = model.model.state_dict_for_saving(clip_sd, vae.get_sd(), clip_vision_sd)
|
||||
sd = model.model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd)
|
||||
for k in extra_keys:
|
||||
sd[k] = extra_keys[k]
|
||||
|
||||
|
||||
@@ -75,16 +75,15 @@ class ClipTokenWeightEncoder:
|
||||
return r
|
||||
|
||||
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
"""Uses the CLIP transformer encoder for text (from huggingface)"""
|
||||
LAYERS = [
|
||||
"last",
|
||||
"pooled",
|
||||
"hidden"
|
||||
]
|
||||
def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
|
||||
def __init__(self, device="cpu", max_length=77,
|
||||
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,
|
||||
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
|
||||
return_projected_pooled=True, return_attention_masks=False): # clip-vit-base-patch32
|
||||
return_projected_pooled=True, return_attention_masks=False, model_options={}): # clip-vit-base-patch32
|
||||
super().__init__()
|
||||
assert layer in self.LAYERS
|
||||
|
||||
@@ -94,8 +93,21 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
with open(textmodel_json_config) as f:
|
||||
config = json.load(f)
|
||||
|
||||
self.operations = comfy.ops.manual_cast
|
||||
operations = model_options.get("custom_operations", None)
|
||||
scaled_fp8 = None
|
||||
|
||||
if operations is None:
|
||||
scaled_fp8 = model_options.get("scaled_fp8", None)
|
||||
if scaled_fp8 is not None:
|
||||
operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8)
|
||||
else:
|
||||
operations = comfy.ops.manual_cast
|
||||
|
||||
self.operations = operations
|
||||
self.transformer = model_class(config, dtype, device, self.operations)
|
||||
if scaled_fp8 is not None:
|
||||
self.transformer.scaled_fp8 = torch.nn.Parameter(torch.tensor([], dtype=scaled_fp8))
|
||||
|
||||
self.num_layers = self.transformer.num_layers
|
||||
|
||||
self.max_length = max_length
|
||||
@@ -539,6 +551,7 @@ class SD1Tokenizer:
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer):
|
||||
self.clip_name = clip_name
|
||||
self.clip = "clip_{}".format(self.clip_name)
|
||||
tokenizer = tokenizer_data.get("{}_tokenizer_class".format(self.clip), tokenizer)
|
||||
setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data))
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
||||
@@ -552,8 +565,12 @@ class SD1Tokenizer:
|
||||
def state_dict(self):
|
||||
return {}
|
||||
|
||||
class SD1CheckpointClipModel(SDClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__(device=device, return_projected_pooled=False, dtype=dtype, model_options=model_options)
|
||||
|
||||
class SD1ClipModel(torch.nn.Module):
|
||||
def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, name=None, **kwargs):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}, clip_name="l", clip_model=SD1CheckpointClipModel, name=None, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
if name is not None:
|
||||
@@ -563,7 +580,8 @@ class SD1ClipModel(torch.nn.Module):
|
||||
self.clip_name = clip_name
|
||||
self.clip = "clip_{}".format(self.clip_name)
|
||||
|
||||
setattr(self, self.clip, clip_model(device=device, dtype=dtype, **kwargs))
|
||||
clip_model = model_options.get("{}_class".format(self.clip), clip_model)
|
||||
setattr(self, self.clip, clip_model(device=device, dtype=dtype, model_options=model_options, **kwargs))
|
||||
|
||||
self.dtypes = set()
|
||||
if dtype is not None:
|
||||
|
||||
@@ -3,14 +3,14 @@ import torch
|
||||
import os
|
||||
|
||||
class SDXLClipG(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None):
|
||||
def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None, model_options={}):
|
||||
if layer == "penultimate":
|
||||
layer="hidden"
|
||||
layer_idx=-2
|
||||
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
|
||||
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
|
||||
special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False)
|
||||
special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False, return_projected_pooled=True, model_options=model_options)
|
||||
|
||||
def load_sd(self, sd):
|
||||
return super().load_sd(sd)
|
||||
@@ -22,7 +22,8 @@ class SDXLClipGTokenizer(sd1_clip.SDTokenizer):
|
||||
|
||||
class SDXLTokenizer:
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
|
||||
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
|
||||
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
|
||||
self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory)
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
||||
@@ -38,10 +39,11 @@ class SDXLTokenizer:
|
||||
return {}
|
||||
|
||||
class SDXLClipModel(torch.nn.Module):
|
||||
def __init__(self, device="cpu", dtype=None):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__()
|
||||
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False)
|
||||
self.clip_g = SDXLClipG(device=device, dtype=dtype)
|
||||
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
|
||||
self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options)
|
||||
self.clip_g = SDXLClipG(device=device, dtype=dtype, model_options=model_options)
|
||||
self.dtypes = set([dtype])
|
||||
|
||||
def set_clip_options(self, options):
|
||||
@@ -57,7 +59,8 @@ class SDXLClipModel(torch.nn.Module):
|
||||
token_weight_pairs_l = token_weight_pairs["l"]
|
||||
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
|
||||
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
|
||||
return torch.cat([l_out, g_out], dim=-1), g_pooled
|
||||
cut_to = min(l_out.shape[1], g_out.shape[1])
|
||||
return torch.cat([l_out[:,:cut_to], g_out[:,:cut_to]], dim=-1), g_pooled
|
||||
|
||||
def load_sd(self, sd):
|
||||
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
|
||||
@@ -66,8 +69,8 @@ class SDXLClipModel(torch.nn.Module):
|
||||
return self.clip_l.load_sd(sd)
|
||||
|
||||
class SDXLRefinerClipModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None):
|
||||
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG)
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG, model_options=model_options)
|
||||
|
||||
|
||||
class StableCascadeClipGTokenizer(sd1_clip.SDTokenizer):
|
||||
@@ -79,14 +82,14 @@ class StableCascadeTokenizer(sd1_clip.SD1Tokenizer):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="g", tokenizer=StableCascadeClipGTokenizer)
|
||||
|
||||
class StableCascadeClipG(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None):
|
||||
def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None, model_options={}):
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
|
||||
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
|
||||
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True)
|
||||
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True, return_projected_pooled=True, model_options=model_options)
|
||||
|
||||
def load_sd(self, sd):
|
||||
return super().load_sd(sd)
|
||||
|
||||
class StableCascadeClipModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None):
|
||||
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=StableCascadeClipG)
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=StableCascadeClipG, model_options=model_options)
|
||||
|
||||
@@ -10,6 +10,8 @@ import comfy.text_encoders.sa_t5
|
||||
import comfy.text_encoders.aura_t5
|
||||
import comfy.text_encoders.hydit
|
||||
import comfy.text_encoders.flux
|
||||
import comfy.text_encoders.genmo
|
||||
import comfy.text_encoders.lt
|
||||
|
||||
from . import supported_models_base
|
||||
from . import latent_formats
|
||||
@@ -181,7 +183,7 @@ class SDXL(supported_models_base.BASE):
|
||||
|
||||
latent_format = latent_formats.SDXL
|
||||
|
||||
memory_usage_factor = 0.7
|
||||
memory_usage_factor = 0.8
|
||||
|
||||
def model_type(self, state_dict, prefix=""):
|
||||
if 'edm_mean' in state_dict and 'edm_std' in state_dict: #Playground V2.5
|
||||
@@ -196,6 +198,8 @@ class SDXL(supported_models_base.BASE):
|
||||
self.sampling_settings["sigma_min"] = float(state_dict["edm_vpred.sigma_min"].item())
|
||||
return model_base.ModelType.V_PREDICTION_EDM
|
||||
elif "v_pred" in state_dict:
|
||||
if "ztsnr" in state_dict: #Some zsnr anime checkpoints
|
||||
self.sampling_settings["zsnr"] = True
|
||||
return model_base.ModelType.V_PREDICTION
|
||||
else:
|
||||
return model_base.ModelType.EPS
|
||||
@@ -529,12 +533,11 @@ class SD3(supported_models_base.BASE):
|
||||
clip_l = True
|
||||
if "{}clip_g.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict:
|
||||
clip_g = True
|
||||
t5_key = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref)
|
||||
if t5_key in state_dict:
|
||||
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
||||
if "dtype_t5" in t5_detect:
|
||||
t5 = True
|
||||
dtype_t5 = state_dict[t5_key].dtype
|
||||
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.sd3_clip.SD3Tokenizer, comfy.text_encoders.sd3_clip.sd3_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.sd3_clip.SD3Tokenizer, comfy.text_encoders.sd3_clip.sd3_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, **t5_detect))
|
||||
|
||||
class StableAudio(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
@@ -653,10 +656,17 @@ class Flux(supported_models_base.BASE):
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
t5_key = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref)
|
||||
if t5_key in state_dict:
|
||||
dtype_t5 = state_dict[t5_key].dtype
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(dtype_t5=dtype_t5))
|
||||
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(**t5_detect))
|
||||
|
||||
class FluxInpaint(Flux):
|
||||
unet_config = {
|
||||
"image_model": "flux",
|
||||
"guidance_embed": True,
|
||||
"in_channels": 96,
|
||||
}
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
class FluxSchnell(Flux):
|
||||
unet_config = {
|
||||
@@ -673,7 +683,63 @@ class FluxSchnell(Flux):
|
||||
out = model_base.Flux(self, model_type=model_base.ModelType.FLOW, device=device)
|
||||
return out
|
||||
|
||||
class GenmoMochi(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "mochi_preview",
|
||||
}
|
||||
|
||||
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, Flux, FluxSchnell]
|
||||
sampling_settings = {
|
||||
"multiplier": 1.0,
|
||||
"shift": 6.0,
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.Mochi
|
||||
|
||||
memory_usage_factor = 2.0 #TODO
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.GenmoMochi(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.genmo.MochiT5Tokenizer, comfy.text_encoders.genmo.mochi_te(**t5_detect))
|
||||
|
||||
class LTXV(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "ltxv",
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 2.37,
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.LTXV
|
||||
|
||||
memory_usage_factor = 2.7
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.LTXV(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.lt.LTXVT5Tokenizer, comfy.text_encoders.lt.ltxv_te(**t5_detect))
|
||||
|
||||
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
||||
@@ -1,3 +1,21 @@
|
||||
"""
|
||||
This file is part of ComfyUI.
|
||||
Copyright (C) 2024 Comfy
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from . import model_base
|
||||
from . import utils
|
||||
@@ -30,6 +48,9 @@ class BASE:
|
||||
memory_usage_factor = 2.0
|
||||
|
||||
manual_cast_dtype = None
|
||||
custom_operations = None
|
||||
scaled_fp8 = None
|
||||
optimizations = {"fp8": False}
|
||||
|
||||
@classmethod
|
||||
def matches(s, unet_config, state_dict=None):
|
||||
@@ -52,6 +73,7 @@ class BASE:
|
||||
self.unet_config = unet_config.copy()
|
||||
self.sampling_settings = self.sampling_settings.copy()
|
||||
self.latent_format = self.latent_format()
|
||||
self.optimizations = self.optimizations.copy()
|
||||
for x in self.unet_extra_config:
|
||||
self.unet_config[x] = self.unet_extra_config[x]
|
||||
|
||||
|
||||
@@ -4,9 +4,9 @@ import comfy.text_encoders.t5
|
||||
import os
|
||||
|
||||
class PT5XlModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_pile_config_xl.json")
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 2, "pad": 1}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True)
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 2, "pad": 1}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True, model_options=model_options)
|
||||
|
||||
class PT5XlTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
@@ -18,5 +18,5 @@ class AuraT5Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="pile_t5xl", tokenizer=PT5XlTokenizer)
|
||||
|
||||
class AuraT5Model(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, **kwargs):
|
||||
super().__init__(device=device, dtype=dtype, name="pile_t5xl", clip_model=PT5XlModel, **kwargs)
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options, name="pile_t5xl", clip_model=PT5XlModel, **kwargs)
|
||||
|
||||
@@ -1,24 +1,21 @@
|
||||
from comfy import sd1_clip
|
||||
import comfy.text_encoders.t5
|
||||
import comfy.text_encoders.sd3_clip
|
||||
import comfy.model_management
|
||||
from transformers import T5TokenizerFast
|
||||
import torch
|
||||
import os
|
||||
|
||||
class T5XXLModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5)
|
||||
|
||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
|
||||
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
|
||||
|
||||
|
||||
class FluxTokenizer:
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
|
||||
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
|
||||
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
|
||||
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
||||
@@ -35,11 +32,12 @@ class FluxTokenizer:
|
||||
|
||||
|
||||
class FluxClipModel(torch.nn.Module):
|
||||
def __init__(self, dtype_t5=None, device="cpu", dtype=None):
|
||||
def __init__(self, dtype_t5=None, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__()
|
||||
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
|
||||
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False)
|
||||
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5)
|
||||
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
|
||||
self.clip_l = clip_l_class(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
|
||||
self.t5xxl = comfy.text_encoders.sd3_clip.T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
|
||||
self.dtypes = set([dtype, dtype_t5])
|
||||
|
||||
def set_clip_options(self, options):
|
||||
@@ -64,8 +62,11 @@ class FluxClipModel(torch.nn.Module):
|
||||
else:
|
||||
return self.t5xxl.load_sd(sd)
|
||||
|
||||
def flux_clip(dtype_t5=None):
|
||||
def flux_clip(dtype_t5=None, t5xxl_scaled_fp8=None):
|
||||
class FluxClipModel_(FluxClipModel):
|
||||
def __init__(self, device="cpu", dtype=None):
|
||||
super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype)
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
|
||||
model_options = model_options.copy()
|
||||
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
|
||||
super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
|
||||
return FluxClipModel_
|
||||
|
||||
38
comfy/text_encoders/genmo.py
Normal file
38
comfy/text_encoders/genmo.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from comfy import sd1_clip
|
||||
import comfy.text_encoders.sd3_clip
|
||||
import os
|
||||
from transformers import T5TokenizerFast
|
||||
|
||||
|
||||
class T5XXLModel(comfy.text_encoders.sd3_clip.T5XXLModel):
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["attention_mask"] = True
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
class MochiT5XXL(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__(device=device, dtype=dtype, name="t5xxl", clip_model=T5XXLModel, model_options=model_options)
|
||||
|
||||
|
||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
|
||||
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
|
||||
|
||||
|
||||
class MochiT5Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
|
||||
|
||||
|
||||
def mochi_te(dtype_t5=None, t5xxl_scaled_fp8=None):
|
||||
class MochiTEModel_(MochiT5XXL):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
|
||||
model_options = model_options.copy()
|
||||
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
|
||||
if dtype is None:
|
||||
dtype = dtype_t5
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return MochiTEModel_
|
||||
@@ -7,9 +7,9 @@ import os
|
||||
import torch
|
||||
|
||||
class HyditBertModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hydit_clip.json")
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 101, "end": 102, "pad": 0}, model_class=BertModel, enable_attention_masks=True, return_attention_masks=True)
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 101, "end": 102, "pad": 0}, model_class=BertModel, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)
|
||||
|
||||
class HyditBertTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
@@ -18,9 +18,9 @@ class HyditBertTokenizer(sd1_clip.SDTokenizer):
|
||||
|
||||
|
||||
class MT5XLModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mt5_config_xl.json")
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, return_attention_masks=True)
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)
|
||||
|
||||
class MT5XLTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
@@ -50,10 +50,10 @@ class HyditTokenizer:
|
||||
return {"mt5xl.spiece_model": self.mt5xl.state_dict()["spiece_model"]}
|
||||
|
||||
class HyditModel(torch.nn.Module):
|
||||
def __init__(self, device="cpu", dtype=None):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__()
|
||||
self.hydit_clip = HyditBertModel(dtype=dtype)
|
||||
self.mt5xl = MT5XLModel(dtype=dtype)
|
||||
self.hydit_clip = HyditBertModel(dtype=dtype, model_options=model_options)
|
||||
self.mt5xl = MT5XLModel(dtype=dtype, model_options=model_options)
|
||||
|
||||
self.dtypes = set()
|
||||
if dtype is not None:
|
||||
|
||||
25
comfy/text_encoders/long_clipl.json
Normal file
25
comfy/text_encoders/long_clipl.json
Normal file
@@ -0,0 +1,25 @@
|
||||
{
|
||||
"_name_or_path": "openai/clip-vit-large-patch14",
|
||||
"architectures": [
|
||||
"CLIPTextModel"
|
||||
],
|
||||
"attention_dropout": 0.0,
|
||||
"bos_token_id": 0,
|
||||
"dropout": 0.0,
|
||||
"eos_token_id": 49407,
|
||||
"hidden_act": "quick_gelu",
|
||||
"hidden_size": 768,
|
||||
"initializer_factor": 1.0,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 3072,
|
||||
"layer_norm_eps": 1e-05,
|
||||
"max_position_embeddings": 248,
|
||||
"model_type": "clip_text_model",
|
||||
"num_attention_heads": 12,
|
||||
"num_hidden_layers": 12,
|
||||
"pad_token_id": 1,
|
||||
"projection_dim": 768,
|
||||
"torch_dtype": "float32",
|
||||
"transformers_version": "4.24.0",
|
||||
"vocab_size": 49408
|
||||
}
|
||||
30
comfy/text_encoders/long_clipl.py
Normal file
30
comfy/text_encoders/long_clipl.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from comfy import sd1_clip
|
||||
import os
|
||||
|
||||
class LongClipTokenizer_(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(max_length=248, embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
|
||||
class LongClipModel_(sd1_clip.SDClipModel):
|
||||
def __init__(self, *args, **kwargs):
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "long_clipl.json")
|
||||
super().__init__(*args, textmodel_json_config=textmodel_json_config, **kwargs)
|
||||
|
||||
class LongClipTokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, tokenizer=LongClipTokenizer_)
|
||||
|
||||
class LongClipModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options, clip_model=LongClipModel_, **kwargs)
|
||||
|
||||
def model_options_long_clip(sd, tokenizer_data, model_options):
|
||||
w = sd.get("clip_l.text_model.embeddings.position_embedding.weight", None)
|
||||
if w is None:
|
||||
w = sd.get("text_model.embeddings.position_embedding.weight", None)
|
||||
if w is not None and w.shape[0] == 248:
|
||||
tokenizer_data = tokenizer_data.copy()
|
||||
model_options = model_options.copy()
|
||||
tokenizer_data["clip_l_tokenizer_class"] = LongClipTokenizer_
|
||||
model_options["clip_l_class"] = LongClipModel_
|
||||
return tokenizer_data, model_options
|
||||
18
comfy/text_encoders/lt.py
Normal file
18
comfy/text_encoders/lt.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from comfy import sd1_clip
|
||||
import os
|
||||
from transformers import T5TokenizerFast
|
||||
import comfy.text_encoders.genmo
|
||||
|
||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
|
||||
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128) #pad to 128?
|
||||
|
||||
|
||||
class LTXVT5Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
|
||||
|
||||
|
||||
def ltxv_te(*args, **kwargs):
|
||||
return comfy.text_encoders.genmo.mochi_te(*args, **kwargs)
|
||||
@@ -4,9 +4,9 @@ import comfy.text_encoders.t5
|
||||
import os
|
||||
|
||||
class T5BaseModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_base.json")
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True)
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, model_options=model_options, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True)
|
||||
|
||||
class T5BaseTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
@@ -18,5 +18,5 @@ class SAT5Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5base", tokenizer=T5BaseTokenizer)
|
||||
|
||||
class SAT5Model(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, **kwargs):
|
||||
super().__init__(device=device, dtype=dtype, name="t5base", clip_model=T5BaseModel, **kwargs)
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options, name="t5base", clip_model=T5BaseModel, **kwargs)
|
||||
|
||||
@@ -2,13 +2,13 @@ from comfy import sd1_clip
|
||||
import os
|
||||
|
||||
class SD2ClipHModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None):
|
||||
def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None, model_options={}):
|
||||
if layer == "penultimate":
|
||||
layer="hidden"
|
||||
layer_idx=-2
|
||||
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json")
|
||||
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0})
|
||||
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0}, return_projected_pooled=True, model_options=model_options)
|
||||
|
||||
class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}):
|
||||
@@ -19,5 +19,5 @@ class SD2Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="h", tokenizer=SD2ClipHTokenizer)
|
||||
|
||||
class SD2ClipModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, **kwargs):
|
||||
super().__init__(device=device, dtype=dtype, clip_name="h", clip_model=SD2ClipHModel, **kwargs)
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options, clip_name="h", clip_model=SD2ClipHModel, **kwargs)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user