Compare commits
453 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ee9547ba31 | ||
|
|
19a64d6291 | ||
|
|
b486885e08 | ||
|
|
0229228f3f | ||
|
|
1ed75ab30e | ||
|
|
99a1fb6027 | ||
|
|
73e04987f7 | ||
|
|
5388df784a | ||
|
|
26e0ba8f8c | ||
|
|
bc6dac4327 | ||
|
|
f18ebbd316 | ||
|
|
15564688ed | ||
|
|
c6b9c11ef6 | ||
|
|
e44d0ac7f7 | ||
|
|
56bc64f351 | ||
|
|
f7d83b72e0 | ||
|
|
80f07952d2 | ||
|
|
57f330caf9 | ||
|
|
601ff9e3db | ||
|
|
341667c4d5 | ||
|
|
1419dee915 | ||
|
|
da13b6b827 | ||
|
|
c86cd58573 | ||
|
|
b5fe39211a | ||
|
|
e946667216 | ||
|
|
d7969cb070 | ||
|
|
bddb02660c | ||
|
|
418eb7062d | ||
|
|
cac68ca813 | ||
|
|
52c1d933b2 | ||
|
|
3cacd3fca5 | ||
|
|
2dda7c11a3 | ||
|
|
3ad3248ad7 | ||
|
|
c441048a4f | ||
|
|
9f4b181ab3 | ||
|
|
cbbf077593 | ||
|
|
0c04a6ae78 | ||
|
|
416ccc9e45 | ||
|
|
ff2ff02168 | ||
|
|
4c5c4ddeda | ||
|
|
79badea452 | ||
|
|
37e5390f5f | ||
|
|
a4f59bc65e | ||
|
|
ca457f7ba1 | ||
|
|
cd6f615038 | ||
|
|
517669aaa3 | ||
|
|
e4e1bff605 | ||
|
|
d6656b0c0c | ||
|
|
f4cdedea62 | ||
|
|
39b1fc4ccc | ||
|
|
0b25f47bd9 | ||
|
|
bda1482a27 | ||
|
|
19ee5d9d8b | ||
|
|
61b50720d0 | ||
|
|
0f954f34af | ||
|
|
5262901c5c | ||
|
|
cc550d5908 | ||
|
|
6d1a3f7d00 | ||
|
|
1b3a650f19 | ||
|
|
e83063bf24 | ||
|
|
558b7d8b22 | ||
|
|
caf2074773 | ||
|
|
bdf393792d | ||
|
|
4e14032c02 | ||
|
|
59d58b1158 | ||
|
|
563291ee51 | ||
|
|
6c0377f43e | ||
|
|
2cddbf0821 | ||
|
|
60749f345d | ||
|
|
d4426dce7c | ||
|
|
d9d7f3c619 | ||
|
|
fd5dfb812c | ||
|
|
3dfdddcc91 | ||
|
|
5747bc6457 | ||
|
|
5bea1d2ec9 | ||
|
|
5def9fbc83 | ||
|
|
7a7efe8424 | ||
|
|
44db978531 | ||
|
|
1c8d11e48a | ||
|
|
a220d11e6b | ||
|
|
23827ca312 | ||
|
|
0fd4e6c778 | ||
|
|
e2fafe0686 | ||
|
|
6579632201 | ||
|
|
ac2f0523ca | ||
|
|
fbf68c4e52 | ||
|
|
93477f8efe | ||
|
|
8af9a91e0c | ||
|
|
005d2d3a13 | ||
|
|
1e21f4c14e | ||
|
|
9a616b81c1 | ||
|
|
3bed56bb13 | ||
|
|
4e402b11c6 | ||
|
|
48272448ad | ||
|
|
f7695b5f9e | ||
|
|
452179fe4f | ||
|
|
bf9a90a145 | ||
|
|
c1b92b719d | ||
|
|
cdc3b97dd5 | ||
|
|
8d4e06324f | ||
|
|
57e8bf6a9f | ||
|
|
0ee322ec5f | ||
|
|
79d5ceae6e | ||
|
|
2d5b3e0078 | ||
|
|
8e4118c0de | ||
|
|
3fc6ebcdd7 | ||
|
|
20a560eb97 | ||
|
|
82c5308561 | ||
|
|
26fb2c68e8 | ||
|
|
bf2650a80e | ||
|
|
53646e0f32 | ||
|
|
20879c78f9 | ||
|
|
b666539595 | ||
|
|
95d8713482 | ||
|
|
0d4e29f13f | ||
|
|
497db6212f | ||
|
|
24dc581dc3 | ||
|
|
4c82741b54 | ||
|
|
15c39ea757 | ||
|
|
b7143b74ce | ||
|
|
61196d8857 | ||
|
|
b4526d3fc3 | ||
|
|
3d802710e7 | ||
|
|
7126ecffde | ||
|
|
ab885b33ba | ||
|
|
839ed3368e | ||
|
|
6e8cdcd3cb | ||
|
|
e5c3f4b87f | ||
|
|
bc6be6c11e | ||
|
|
94323a26a7 | ||
|
|
5818f6cf51 | ||
|
|
0b734de449 | ||
|
|
5e16f1d24b | ||
|
|
2fd9c1308a | ||
|
|
8f0009aad0 | ||
|
|
41444b5236 | ||
|
|
772e620e32 | ||
|
|
07f6eeaa13 | ||
|
|
22535d0589 | ||
|
|
898615122f | ||
|
|
156a28786b | ||
|
|
f498d855ba | ||
|
|
b699a15062 | ||
|
|
9cc90ee3eb | ||
|
|
9a0a5d32ee | ||
|
|
d9f90965c8 | ||
|
|
41886af138 | ||
|
|
22a1d7ce78 | ||
|
|
4ac401af2b | ||
|
|
5fb59c8475 | ||
|
|
122c9ca1ce | ||
|
|
3b9a6cf2b1 | ||
|
|
3748e7ef7a | ||
|
|
8ebf2d8831 | ||
|
|
a72d152b0c | ||
|
|
eb476e6ea9 | ||
|
|
2d28b0b479 | ||
|
|
8b275ce5be | ||
|
|
2a18e98ccf | ||
|
|
8a5281006f | ||
|
|
bdeb1c171c | ||
|
|
9c1ed58ef2 | ||
|
|
8b90e50979 | ||
|
|
6ee066a14f | ||
|
|
dd5b57e3d7 | ||
|
|
75a818c720 | ||
|
|
2865f913f7 | ||
|
|
b49616f951 | ||
|
|
5e29e7a488 | ||
|
|
8afb97cd3f | ||
|
|
69694f40b3 | ||
|
|
c49025f01b | ||
|
|
696672905f | ||
|
|
6c9dbde7de | ||
|
|
ee8abf0cff | ||
|
|
fabf449feb | ||
|
|
cc9cf6d1bd | ||
|
|
1c8286a44b | ||
|
|
1af4a47fd1 | ||
|
|
f2aaa0a475 | ||
|
|
daa1565b93 | ||
|
|
09fdb2b269 | ||
|
|
65a8659182 | ||
|
|
770ab200f2 | ||
|
|
954683d0db | ||
|
|
30c0c81351 | ||
|
|
13b0ff8a6f | ||
|
|
c320801187 | ||
|
|
c0b0cfaeec | ||
|
|
669d9e4c67 | ||
|
|
9ee0a6553a | ||
|
|
5cbb01bc2f | ||
|
|
c3ffbae067 | ||
|
|
d605677b33 | ||
|
|
ce759b7db6 | ||
|
|
52810907e2 | ||
|
|
af8cf79a2d | ||
|
|
66b0961a46 | ||
|
|
754597c8a9 | ||
|
|
915fdb5745 | ||
|
|
5a8a48931a | ||
|
|
8ce2a1052c | ||
|
|
f82314fcfc | ||
|
|
0075c6d096 | ||
|
|
83ca891118 | ||
|
|
f9f9faface | ||
|
|
471cd3eace | ||
|
|
a68bbafddb | ||
|
|
73e3a9e676 | ||
|
|
518c0dc2fe | ||
|
|
ce0542e10b | ||
|
|
8473019d40 | ||
|
|
89f15894dd | ||
|
|
67158994a4 | ||
|
|
7390ff3b1e | ||
|
|
0bedfb26af | ||
|
|
f71cfd2687 | ||
|
|
c695c4af7f | ||
|
|
0dbba9f751 | ||
|
|
f584758271 | ||
|
|
95b7cf9bbe | ||
|
|
191a0d56b4 | ||
|
|
3c60ecd7a8 | ||
|
|
7ae6626723 | ||
|
|
6632365e16 | ||
|
|
ad07796777 | ||
|
|
1b80895285 | ||
|
|
5f9d5a244b | ||
|
|
14eba07acd | ||
|
|
4b2f0d9413 | ||
|
|
25eac1d780 | ||
|
|
e38c94228b | ||
|
|
203942c8b2 | ||
|
|
3c72c89a52 | ||
|
|
614377abd6 | ||
|
|
8dfa0cc552 | ||
|
|
e5ecdfdd2d | ||
|
|
7d29fbf74b | ||
|
|
2c641e64ad | ||
|
|
7d2467e830 | ||
|
|
6f021d8aa0 | ||
|
|
d854ed0bcf | ||
|
|
abcd006b8c | ||
|
|
d985d1d7dc | ||
|
|
d1cdf51e1b | ||
|
|
b4626ab93e | ||
|
|
a9e459c2a4 | ||
|
|
3bb4dec720 | ||
|
|
8733191563 | ||
|
|
83b01f960a | ||
|
|
d72e871cfa | ||
|
|
037c3159b6 | ||
|
|
bdd4a22a2e | ||
|
|
fdf37566ef | ||
|
|
08c8968482 | ||
|
|
479a427a48 | ||
|
|
3a0eeee320 | ||
|
|
447da7ea86 | ||
|
|
9c41bc8d10 | ||
|
|
6ad0ddbae4 | ||
|
|
a55142f904 | ||
|
|
5718ef69bb | ||
|
|
13ecf10a92 | ||
|
|
7a415f47a9 | ||
|
|
89fa2fca24 | ||
|
|
364b69e931 | ||
|
|
dc96a1ae19 | ||
|
|
2d810b081e | ||
|
|
9f7e9f0547 | ||
|
|
a355f38ecc | ||
|
|
38c69080c7 | ||
|
|
70a708d726 | ||
|
|
e7d4782736 | ||
|
|
3326bdfd4e | ||
|
|
68bb885d22 | ||
|
|
ad66f7c7d8 | ||
|
|
de8e8e3b0d | ||
|
|
a1e71cfad1 | ||
|
|
0bfc7cc998 | ||
|
|
7183fd1665 | ||
|
|
254838f23c | ||
|
|
0b7dfa986d | ||
|
|
d514bb38ee | ||
|
|
0849c80e2a | ||
|
|
56e8f5e4fd | ||
|
|
e813abbb2c | ||
|
|
5e68a4ce67 | ||
|
|
ca08597670 | ||
|
|
f48e390032 | ||
|
|
369a6dd2c4 | ||
|
|
b3ce8fb9fd | ||
|
|
cf80d28689 | ||
|
|
6fb44c4b7c | ||
|
|
d2247c1e61 | ||
|
|
cb12ad7049 | ||
|
|
f6b7194f64 | ||
|
|
7c6eb4fb29 | ||
|
|
b962db9952 | ||
|
|
d0b7ab88ba | ||
|
|
405b529545 | ||
|
|
9d720187f1 | ||
|
|
d247bc5a9c | ||
|
|
9f4daca9d9 | ||
|
|
b5d0f2a908 | ||
|
|
e760bf5c40 | ||
|
|
36c83cdbba | ||
|
|
81778a7feb | ||
|
|
bc94662b31 | ||
|
|
9fa8faa44a | ||
|
|
9a7444e39f | ||
|
|
54fca4a218 | ||
|
|
cd4955367e | ||
|
|
8354203d95 | ||
|
|
e0b41243b4 | ||
|
|
619263d4a6 | ||
|
|
e3b0402bb7 | ||
|
|
967867d48c | ||
|
|
cbaac71bf5 | ||
|
|
3ab3516e46 | ||
|
|
9c5fca75f4 | ||
|
|
a5da4d0b3e | ||
|
|
32a60a7bac | ||
|
|
bb52934ba4 | ||
|
|
8aabd7c8c0 | ||
|
|
a09b29ca11 | ||
|
|
9bfee68773 | ||
|
|
ea77750759 | ||
|
|
c27ebeb1c2 | ||
|
|
0c7c98a965 | ||
|
|
dc2eb75b85 | ||
|
|
fa34efe3bd | ||
|
|
5cbaa9e07c | ||
|
|
c7427375ee | ||
|
|
22d1241a50 | ||
|
|
f04229b84d | ||
|
|
f067ad15d1 | ||
|
|
483004dd1d | ||
|
|
00a5d08103 | ||
|
|
d043997d30 | ||
|
|
f1c2301697 | ||
|
|
8d31a6632f | ||
|
|
b643eae08b | ||
|
|
baa6b4dc36 | ||
|
|
d4aeefc297 | ||
|
|
587e7ca654 | ||
|
|
c90459eba0 | ||
|
|
04278afb10 | ||
|
|
935ae153e1 | ||
|
|
e91662e784 | ||
|
|
63fafaef45 | ||
|
|
ec28cd9136 | ||
|
|
6eb5d64522 | ||
|
|
10a79e9898 | ||
|
|
ea3f39bd69 | ||
|
|
b33cd61070 | ||
|
|
34eda0f853 | ||
|
|
d31e226650 | ||
|
|
b79fd7d92c | ||
|
|
38c22e631a | ||
|
|
6bbdcd28ae | ||
|
|
ab130001a8 | ||
|
|
ca4b8f30e0 | ||
|
|
70b84058c1 | ||
|
|
2ca8f6e23d | ||
|
|
7985ff88b9 | ||
|
|
c6812947e9 | ||
|
|
9230f65823 | ||
|
|
6ab1e6fd4a | ||
|
|
07dcbc3a3e | ||
|
|
8ae23d8e80 | ||
|
|
7df42b9a23 | ||
|
|
5d8bbb7281 | ||
|
|
2c1d2375d6 | ||
|
|
64ccb3c7e3 | ||
|
|
9465b23432 | ||
|
|
bb4416dd5b | ||
|
|
c0b0da264b | ||
|
|
c26ca27207 | ||
|
|
7c6bb84016 | ||
|
|
c54d3ed5e6 | ||
|
|
c7ee4b37a1 | ||
|
|
7b70b266d8 | ||
|
|
8f60d093ba | ||
|
|
dafbe321d2 | ||
|
|
5f84ea63e8 | ||
|
|
843a7ff70c | ||
|
|
a60620dcea | ||
|
|
015f73dc49 | ||
|
|
904bf58e7d | ||
|
|
5f50263088 | ||
|
|
5e806f555d | ||
|
|
f07e5bb522 | ||
|
|
03ec517afb | ||
|
|
f257fc999f | ||
|
|
bb50e69839 | ||
|
|
510f3438c1 | ||
|
|
ea63b1c092 | ||
|
|
9953f22fce | ||
|
|
d1a6bd6845 | ||
|
|
83dbac28eb | ||
|
|
538cb068bc | ||
|
|
1b3eee672c | ||
|
|
5a69f84c3c | ||
|
|
9eee470244 | ||
|
|
045377ea89 | ||
|
|
4d341b78e8 | ||
|
|
6138f92084 | ||
|
|
be0726c1ed | ||
|
|
766ae119a8 | ||
|
|
fc90ceb6ba | ||
|
|
4506ddc86a | ||
|
|
20ace7c853 | ||
|
|
b29b3b86c5 | ||
|
|
22ec02afc0 | ||
|
|
39f114c44b | ||
|
|
6730f3e1a3 | ||
|
|
73332160c8 | ||
|
|
2622c55aff | ||
|
|
1beb348ee2 | ||
|
|
9aa39e743c | ||
|
|
d31df04c8a | ||
|
|
e68763f40c | ||
|
|
310ad09258 | ||
|
|
4f7a3cb6fb | ||
|
|
bb222ceddb | ||
|
|
14af129c55 | ||
|
|
fca42836f2 | ||
|
|
858d51f91a | ||
|
|
cd5017c1c9 | ||
|
|
83f343146a | ||
|
|
b021cf67c7 | ||
|
|
1770fc77ed | ||
|
|
05a9f3faa1 | ||
|
|
86c5970ac0 | ||
|
|
bfc214f434 | ||
|
|
3f5939add6 | ||
|
|
5960f946a9 | ||
|
|
5cfe38f41c | ||
|
|
0f9c2a7822 | ||
|
|
153d0a8142 | ||
|
|
ab4dd19b91 | ||
|
|
f1d6cef71c | ||
|
|
33fb282d5c | ||
|
|
50bf66e5c4 | ||
|
|
e60e19b175 | ||
|
|
a5af64d3ce | ||
|
|
3e52e0364c | ||
|
|
34608de2e9 | ||
|
|
39fb74c5bd | ||
|
|
74e124f4d7 | ||
|
|
a562c17e8a | ||
|
|
5942c17d55 | ||
|
|
c032b11e07 |
@@ -28,17 +28,17 @@ def pull(repo, remote_name='origin', branch='master'):
|
||||
|
||||
if repo.index.conflicts is not None:
|
||||
for conflict in repo.index.conflicts:
|
||||
print('Conflicts found in:', conflict[0].path)
|
||||
print('Conflicts found in:', conflict[0].path) # noqa: T201
|
||||
raise AssertionError('Conflicts, ahhhhh!!')
|
||||
|
||||
user = repo.default_signature
|
||||
tree = repo.index.write_tree()
|
||||
commit = repo.create_commit('HEAD',
|
||||
user,
|
||||
user,
|
||||
'Merge!',
|
||||
tree,
|
||||
[repo.head.target, remote_master_id])
|
||||
repo.create_commit('HEAD',
|
||||
user,
|
||||
user,
|
||||
'Merge!',
|
||||
tree,
|
||||
[repo.head.target, remote_master_id])
|
||||
# We need to do this or git CLI will think we are still merging.
|
||||
repo.state_cleanup()
|
||||
else:
|
||||
@@ -49,18 +49,18 @@ repo_path = str(sys.argv[1])
|
||||
repo = pygit2.Repository(repo_path)
|
||||
ident = pygit2.Signature('comfyui', 'comfy@ui')
|
||||
try:
|
||||
print("stashing current changes")
|
||||
print("stashing current changes") # noqa: T201
|
||||
repo.stash(ident)
|
||||
except KeyError:
|
||||
print("nothing to stash")
|
||||
print("nothing to stash") # noqa: T201
|
||||
backup_branch_name = 'backup_branch_{}'.format(datetime.today().strftime('%Y-%m-%d_%H_%M_%S'))
|
||||
print("creating backup branch: {}".format(backup_branch_name))
|
||||
print("creating backup branch: {}".format(backup_branch_name)) # noqa: T201
|
||||
try:
|
||||
repo.branches.local.create(backup_branch_name, repo.head.peel())
|
||||
except:
|
||||
pass
|
||||
|
||||
print("checking out master branch")
|
||||
print("checking out master branch") # noqa: T201
|
||||
branch = repo.lookup_branch('master')
|
||||
if branch is None:
|
||||
ref = repo.lookup_reference('refs/remotes/origin/master')
|
||||
@@ -72,10 +72,29 @@ else:
|
||||
ref = repo.lookup_reference(branch.name)
|
||||
repo.checkout(ref)
|
||||
|
||||
print("pulling latest changes")
|
||||
print("pulling latest changes") # noqa: T201
|
||||
pull(repo)
|
||||
|
||||
print("Done!")
|
||||
if "--stable" in sys.argv:
|
||||
def latest_tag(repo):
|
||||
versions = []
|
||||
for k in repo.references:
|
||||
try:
|
||||
prefix = "refs/tags/v"
|
||||
if k.startswith(prefix):
|
||||
version = list(map(int, k[len(prefix):].split(".")))
|
||||
versions.append((version[0] * 10000000000 + version[1] * 100000 + version[2], k))
|
||||
except:
|
||||
pass
|
||||
versions.sort()
|
||||
if len(versions) > 0:
|
||||
return versions[-1][1]
|
||||
return None
|
||||
latest_tag = latest_tag(repo)
|
||||
if latest_tag is not None:
|
||||
repo.checkout(latest_tag)
|
||||
|
||||
print("Done!") # noqa: T201
|
||||
|
||||
self_update = True
|
||||
if len(sys.argv) > 2:
|
||||
@@ -115,3 +134,13 @@ if not os.path.exists(req_path) or not files_equal(repo_req_path, req_path):
|
||||
shutil.copy(repo_req_path, req_path)
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
stable_update_script = os.path.join(repo_path, ".ci/update_windows/update_comfyui_stable.bat")
|
||||
stable_update_script_to = os.path.join(cur_path, "update_comfyui_stable.bat")
|
||||
|
||||
try:
|
||||
if not file_size(stable_update_script_to) > 10:
|
||||
shutil.copy(stable_update_script, stable_update_script_to)
|
||||
except:
|
||||
pass
|
||||
|
||||
8
.ci/update_windows/update_comfyui_stable.bat
Executable file
8
.ci/update_windows/update_comfyui_stable.bat
Executable file
@@ -0,0 +1,8 @@
|
||||
@echo off
|
||||
..\python_embeded\python.exe .\update.py ..\ComfyUI\ --stable
|
||||
if exist update_new.py (
|
||||
move /y update_new.py update.py
|
||||
echo Running updater again since it got updated.
|
||||
..\python_embeded\python.exe .\update.py ..\ComfyUI\ --skip_self_update --stable
|
||||
)
|
||||
if "%~1"=="" pause
|
||||
@@ -14,7 +14,7 @@ run_cpu.bat
|
||||
|
||||
IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints
|
||||
|
||||
You can download the stable diffusion 1.5 one from: https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt
|
||||
You can download the stable diffusion 1.5 one from: https://huggingface.co/Comfy-Org/stable-diffusion-v1-5-archive/blob/main/v1-5-pruned-emaonly-fp16.safetensors
|
||||
|
||||
|
||||
RECOMMENDED WAY TO UPDATE:
|
||||
|
||||
2
.ci/windows_nightly_base_files/run_nvidia_gpu_fast.bat
Normal file
2
.ci/windows_nightly_base_files/run_nvidia_gpu_fast.bat
Normal file
@@ -0,0 +1,2 @@
|
||||
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast
|
||||
pause
|
||||
2
.gitattributes
vendored
Normal file
2
.gitattributes
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
/web/assets/** linguist-generated
|
||||
/web/** linguist-vendored
|
||||
3
.github/ISSUE_TEMPLATE/config.yml
vendored
3
.github/ISSUE_TEMPLATE/config.yml
vendored
@@ -1,5 +1,8 @@
|
||||
blank_issues_enabled: true
|
||||
contact_links:
|
||||
- name: ComfyUI Frontend Issues
|
||||
url: https://github.com/Comfy-Org/ComfyUI_frontend/issues
|
||||
about: Issues related to the ComfyUI frontend (display issues, user interaction bugs), please go to the frontend repo to file the issue
|
||||
- name: ComfyUI Matrix Space
|
||||
url: https://app.element.io/#/room/%23comfyui_space%3Amatrix.org
|
||||
about: The ComfyUI Matrix Space is available for support and general discussion related to ComfyUI (Matrix is like Discord but open source).
|
||||
|
||||
2
.github/workflows/pullrequest-ci-run.yml
vendored
2
.github/workflows/pullrequest-ci-run.yml
vendored
@@ -23,7 +23,7 @@ jobs:
|
||||
runner_label: [self-hosted, Linux]
|
||||
flags: ""
|
||||
- os: windows
|
||||
runner_label: [self-hosted, win]
|
||||
runner_label: [self-hosted, Windows]
|
||||
flags: ""
|
||||
runs-on: ${{ matrix.runner_label }}
|
||||
steps:
|
||||
|
||||
@@ -3,8 +3,8 @@ name: Python Linting
|
||||
on: [push, pull_request]
|
||||
|
||||
jobs:
|
||||
pylint:
|
||||
name: Run Pylint
|
||||
ruff:
|
||||
name: Run Ruff
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
@@ -16,8 +16,8 @@ jobs:
|
||||
with:
|
||||
python-version: 3.x
|
||||
|
||||
- name: Install Pylint
|
||||
run: pip install pylint
|
||||
- name: Install Ruff
|
||||
run: pip install ruff
|
||||
|
||||
- name: Run Pylint
|
||||
run: pylint --rcfile=.pylintrc $(find . -type f -name "*.py")
|
||||
- name: Run Ruff
|
||||
run: ruff check .
|
||||
6
.github/workflows/stable-release.yml
vendored
6
.github/workflows/stable-release.yml
vendored
@@ -12,17 +12,17 @@ on:
|
||||
description: 'CUDA version'
|
||||
required: true
|
||||
type: string
|
||||
default: "121"
|
||||
default: "124"
|
||||
python_minor:
|
||||
description: 'Python minor version'
|
||||
required: true
|
||||
type: string
|
||||
default: "11"
|
||||
default: "12"
|
||||
python_patch:
|
||||
description: 'Python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "9"
|
||||
default: "7"
|
||||
|
||||
|
||||
jobs:
|
||||
|
||||
21
.github/workflows/stale-issues.yml
vendored
Normal file
21
.github/workflows/stale-issues.yml
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
name: 'Close stale issues'
|
||||
on:
|
||||
schedule:
|
||||
# Run daily at 430 am PT
|
||||
- cron: '30 11 * * *'
|
||||
permissions:
|
||||
issues: write
|
||||
|
||||
jobs:
|
||||
stale:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/stale@v9
|
||||
with:
|
||||
stale-issue-message: "This issue is being marked stale because it has not had any activity for 30 days. Reply below within 7 days if your issue still isn't solved, and it will be left open. Otherwise, the issue will be closed automatically."
|
||||
days-before-stale: 30
|
||||
days-before-close: 7
|
||||
stale-issue-label: 'Stale'
|
||||
only-labels: 'User Support'
|
||||
exempt-all-assignees: true
|
||||
exempt-all-milestones: true
|
||||
76
.github/workflows/test-browser.yml
vendored
76
.github/workflows/test-browser.yml
vendored
@@ -1,76 +0,0 @@
|
||||
# This is a temporary action during frontend TS migration.
|
||||
# This file should be removed after TS migration is completed.
|
||||
# The browser test is here to ensure TS repo is working the same way as the
|
||||
# current JS code.
|
||||
# If you are adding UI feature, please sync your changes to the TS repo:
|
||||
# huchenlei/ComfyUI_frontend and update test expectation files accordingly.
|
||||
name: Playwright Browser Tests CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main, master ]
|
||||
pull_request:
|
||||
branches: [ main, master ]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout ComfyUI
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
repository: "comfyanonymous/ComfyUI"
|
||||
path: "ComfyUI"
|
||||
- name: Checkout ComfyUI_frontend
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
repository: "huchenlei/ComfyUI_frontend"
|
||||
path: "ComfyUI_frontend"
|
||||
ref: "fcc54d803e5b6a9b08a462a1d94899318c96dcbb"
|
||||
- uses: actions/setup-node@v3
|
||||
with:
|
||||
node-version: lts/*
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.8'
|
||||
- name: Install requirements
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
||||
pip install -r requirements.txt
|
||||
pip install wait-for-it
|
||||
working-directory: ComfyUI
|
||||
- name: Start ComfyUI server
|
||||
run: |
|
||||
python main.py --cpu 2>&1 | tee console_output.log &
|
||||
wait-for-it --service 127.0.0.1:8188 -t 600
|
||||
working-directory: ComfyUI
|
||||
- name: Install ComfyUI_frontend dependencies
|
||||
run: |
|
||||
npm ci
|
||||
working-directory: ComfyUI_frontend
|
||||
- name: Install Playwright Browsers
|
||||
run: npx playwright install --with-deps
|
||||
working-directory: ComfyUI_frontend
|
||||
- name: Run Playwright tests
|
||||
run: npx playwright test
|
||||
working-directory: ComfyUI_frontend
|
||||
- name: Check for unhandled exceptions in server log
|
||||
run: |
|
||||
if grep -qE "Exception|Error" console_output.log; then
|
||||
echo "Unhandled exception/error found in server log."
|
||||
exit 1
|
||||
fi
|
||||
working-directory: ComfyUI
|
||||
- uses: actions/upload-artifact@v4
|
||||
if: always()
|
||||
with:
|
||||
name: playwright-report
|
||||
path: ComfyUI_frontend/playwright-report/
|
||||
retention-days: 30
|
||||
- uses: actions/upload-artifact@v4
|
||||
if: always()
|
||||
with:
|
||||
name: console-output
|
||||
path: ComfyUI/console_output.log
|
||||
retention-days: 30
|
||||
53
.github/workflows/test-ci.yml
vendored
53
.github/workflows/test-ci.yml
vendored
@@ -20,7 +20,8 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [macos, linux, windows]
|
||||
# os: [macos, linux, windows]
|
||||
os: [macos, linux]
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12"]
|
||||
cuda_version: ["12.1"]
|
||||
torch_version: ["stable"]
|
||||
@@ -31,9 +32,9 @@ jobs:
|
||||
- os: linux
|
||||
runner_label: [self-hosted, Linux]
|
||||
flags: ""
|
||||
- os: windows
|
||||
runner_label: [self-hosted, win]
|
||||
flags: ""
|
||||
# - os: windows
|
||||
# runner_label: [self-hosted, Windows]
|
||||
# flags: ""
|
||||
runs-on: ${{ matrix.runner_label }}
|
||||
steps:
|
||||
- name: Test Workflows
|
||||
@@ -45,28 +46,28 @@ jobs:
|
||||
google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
|
||||
comfyui_flags: ${{ matrix.flags }}
|
||||
|
||||
test-win-nightly:
|
||||
strategy:
|
||||
fail-fast: true
|
||||
matrix:
|
||||
os: [windows]
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12"]
|
||||
cuda_version: ["12.1"]
|
||||
torch_version: ["nightly"]
|
||||
include:
|
||||
- os: windows
|
||||
runner_label: [self-hosted, win]
|
||||
flags: ""
|
||||
runs-on: ${{ matrix.runner_label }}
|
||||
steps:
|
||||
- name: Test Workflows
|
||||
uses: comfy-org/comfy-action@main
|
||||
with:
|
||||
os: ${{ matrix.os }}
|
||||
python_version: ${{ matrix.python_version }}
|
||||
torch_version: ${{ matrix.torch_version }}
|
||||
google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
|
||||
comfyui_flags: ${{ matrix.flags }}
|
||||
# test-win-nightly:
|
||||
# strategy:
|
||||
# fail-fast: true
|
||||
# matrix:
|
||||
# os: [windows]
|
||||
# python_version: ["3.9", "3.10", "3.11", "3.12"]
|
||||
# cuda_version: ["12.1"]
|
||||
# torch_version: ["nightly"]
|
||||
# include:
|
||||
# - os: windows
|
||||
# runner_label: [self-hosted, Windows]
|
||||
# flags: ""
|
||||
# runs-on: ${{ matrix.runner_label }}
|
||||
# steps:
|
||||
# - name: Test Workflows
|
||||
# uses: comfy-org/comfy-action@main
|
||||
# with:
|
||||
# os: ${{ matrix.os }}
|
||||
# python_version: ${{ matrix.python_version }}
|
||||
# torch_version: ${{ matrix.torch_version }}
|
||||
# google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
|
||||
# comfyui_flags: ${{ matrix.flags }}
|
||||
|
||||
test-unix-nightly:
|
||||
strategy:
|
||||
|
||||
45
.github/workflows/test-launch.yml
vendored
Normal file
45
.github/workflows/test-launch.yml
vendored
Normal file
@@ -0,0 +1,45 @@
|
||||
name: Test server launches without errors
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main, master ]
|
||||
pull_request:
|
||||
branches: [ main, master ]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout ComfyUI
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
repository: "comfyanonymous/ComfyUI"
|
||||
path: "ComfyUI"
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.8'
|
||||
- name: Install requirements
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
||||
pip install -r requirements.txt
|
||||
pip install wait-for-it
|
||||
working-directory: ComfyUI
|
||||
- name: Start ComfyUI server
|
||||
run: |
|
||||
python main.py --cpu 2>&1 | tee console_output.log &
|
||||
wait-for-it --service 127.0.0.1:8188 -t 30
|
||||
working-directory: ComfyUI
|
||||
- name: Check for unhandled exceptions in server log
|
||||
run: |
|
||||
if grep -qE "Exception|Error" console_output.log; then
|
||||
echo "Unhandled exception/error found in server log."
|
||||
exit 1
|
||||
fi
|
||||
working-directory: ComfyUI
|
||||
- uses: actions/upload-artifact@v4
|
||||
if: always()
|
||||
with:
|
||||
name: console-output
|
||||
path: ComfyUI/console_output.log
|
||||
retention-days: 30
|
||||
@@ -1,29 +1,29 @@
|
||||
name: Tests CI
|
||||
name: Unit Tests
|
||||
|
||||
on: [push, pull_request]
|
||||
on:
|
||||
push:
|
||||
branches: [ main, master ]
|
||||
pull_request:
|
||||
branches: [ main, master ]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest, windows-latest, macos-latest]
|
||||
runs-on: ${{ matrix.os }}
|
||||
continue-on-error: true
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-node@v3
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
node-version: 18
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.10'
|
||||
- name: Install requirements
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
||||
pip install -r requirements.txt
|
||||
- name: Run Tests
|
||||
run: |
|
||||
npm ci
|
||||
npm run test:generate
|
||||
npm test -- --verbose
|
||||
working-directory: ./tests-ui
|
||||
- name: Run Unit Tests
|
||||
run: |
|
||||
pip install -r tests-unit/requirements.txt
|
||||
@@ -12,7 +12,7 @@ on:
|
||||
description: 'extra dependencies'
|
||||
required: false
|
||||
type: string
|
||||
default: "\"numpy<2\""
|
||||
default: ""
|
||||
cu:
|
||||
description: 'cuda version'
|
||||
required: true
|
||||
@@ -23,13 +23,13 @@ on:
|
||||
description: 'python minor version'
|
||||
required: true
|
||||
type: string
|
||||
default: "11"
|
||||
default: "12"
|
||||
|
||||
python_patch:
|
||||
description: 'python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "9"
|
||||
default: "7"
|
||||
# push:
|
||||
# branches:
|
||||
# - master
|
||||
|
||||
@@ -67,6 +67,7 @@ jobs:
|
||||
mkdir update
|
||||
cp -r ComfyUI/.ci/update_windows/* ./update/
|
||||
cp -r ComfyUI/.ci/windows_base_files/* ./
|
||||
cp -r ComfyUI/.ci/windows_nightly_base_files/* ./
|
||||
|
||||
echo "call update_comfyui.bat nopause
|
||||
..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2
|
||||
|
||||
@@ -13,13 +13,13 @@ on:
|
||||
description: 'python minor version'
|
||||
required: true
|
||||
type: string
|
||||
default: "11"
|
||||
default: "12"
|
||||
|
||||
python_patch:
|
||||
description: 'python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "9"
|
||||
default: "7"
|
||||
# push:
|
||||
# branches:
|
||||
# - master
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -12,6 +12,7 @@ extra_model_paths.yaml
|
||||
.vscode/
|
||||
.idea/
|
||||
venv/
|
||||
.venv/
|
||||
/web/extensions/*
|
||||
!/web/extensions/logging.js.example
|
||||
!/web/extensions/core/
|
||||
|
||||
24
CODEOWNERS
24
CODEOWNERS
@@ -1 +1,23 @@
|
||||
* @comfyanonymous
|
||||
# Admins
|
||||
* @comfyanonymous
|
||||
|
||||
# Note: Github teams syntax cannot be used here as the repo is not owned by Comfy-Org.
|
||||
# Inlined the team members for now.
|
||||
|
||||
# Maintainers
|
||||
*.md @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
|
||||
/tests/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
|
||||
/tests-unit/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
|
||||
/notebooks/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
|
||||
/script_examples/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
|
||||
/.github/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
|
||||
|
||||
# Python web server
|
||||
/api_server/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
|
||||
/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
|
||||
|
||||
# Frontend assets
|
||||
/web/ @huchenlei @webfiltered @pythongosssss
|
||||
|
||||
# Extra nodes
|
||||
/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink
|
||||
|
||||
211
README.md
211
README.md
@@ -1,8 +1,35 @@
|
||||
ComfyUI
|
||||
=======
|
||||
The most powerful and modular stable diffusion GUI and backend.
|
||||
-----------
|
||||

|
||||
<div align="center">
|
||||
|
||||
# ComfyUI
|
||||
**The most powerful and modular diffusion model GUI and backend.**
|
||||
|
||||
|
||||
[![Website][website-shield]][website-url]
|
||||
[![Dynamic JSON Badge][discord-shield]][discord-url]
|
||||
[![Matrix][matrix-shield]][matrix-url]
|
||||
<br>
|
||||
[![][github-release-shield]][github-release-link]
|
||||
[![][github-release-date-shield]][github-release-link]
|
||||
[![][github-downloads-shield]][github-downloads-link]
|
||||
[![][github-downloads-latest-shield]][github-downloads-link]
|
||||
|
||||
[matrix-shield]: https://img.shields.io/badge/Matrix-000000?style=flat&logo=matrix&logoColor=white
|
||||
[matrix-url]: https://app.element.io/#/room/%23comfyui_space%3Amatrix.org
|
||||
[website-shield]: https://img.shields.io/badge/ComfyOrg-4285F4?style=flat
|
||||
[website-url]: https://www.comfy.org/
|
||||
<!-- Workaround to display total user from https://github.com/badges/shields/issues/4500#issuecomment-2060079995 -->
|
||||
[discord-shield]: https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fdiscord.com%2Fapi%2Finvites%2Fcomfyorg%3Fwith_counts%3Dtrue&query=%24.approximate_member_count&logo=discord&logoColor=white&label=Discord&color=green&suffix=%20total
|
||||
[discord-url]: https://www.comfy.org/discord
|
||||
|
||||
[github-release-shield]: https://img.shields.io/github/v/release/comfyanonymous/ComfyUI?style=flat&sort=semver
|
||||
[github-release-link]: https://github.com/comfyanonymous/ComfyUI/releases
|
||||
[github-release-date-shield]: https://img.shields.io/github/release-date/comfyanonymous/ComfyUI?style=flat
|
||||
[github-downloads-shield]: https://img.shields.io/github/downloads/comfyanonymous/ComfyUI/total?style=flat
|
||||
[github-downloads-latest-shield]: https://img.shields.io/github/downloads/comfyanonymous/ComfyUI/latest/total?style=flat&label=downloads%40latest
|
||||
[github-downloads-link]: https://github.com/comfyanonymous/ComfyUI/releases
|
||||
|
||||

|
||||
</div>
|
||||
|
||||
This ui will let you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. For some workflow examples and see what ComfyUI can do you can check out:
|
||||
### [ComfyUI Examples](https://comfyanonymous.github.io/ComfyUI_examples/)
|
||||
@@ -11,8 +38,21 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
|
||||
|
||||
## Features
|
||||
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
|
||||
- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/), [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/), [SD3](https://comfyanonymous.github.io/ComfyUI_examples/sd3/) and [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
|
||||
- [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
|
||||
- Image Models
|
||||
- SD1.x, SD2.x,
|
||||
- [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/)
|
||||
- [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/)
|
||||
- [SD3 and SD3.5](https://comfyanonymous.github.io/ComfyUI_examples/sd3/)
|
||||
- Pixart Alpha and Sigma
|
||||
- [AuraFlow](https://comfyanonymous.github.io/ComfyUI_examples/aura_flow/)
|
||||
- [HunyuanDiT](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_dit/)
|
||||
- [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
|
||||
- Video Models
|
||||
- [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
|
||||
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
|
||||
- [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/)
|
||||
- [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
|
||||
- [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
|
||||
- Asynchronous Queue system
|
||||
- Many optimizations: Only re-executes the parts of the workflow that changes between executions.
|
||||
- Smart memory management: can automatically run models on GPUs with as low as 1GB vram.
|
||||
@@ -32,9 +72,6 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
|
||||
- [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/)
|
||||
- [Model Merging](https://comfyanonymous.github.io/ComfyUI_examples/model_merging/)
|
||||
- [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/)
|
||||
- [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/)
|
||||
- [AuraFlow](https://comfyanonymous.github.io/ComfyUI_examples/aura_flow/)
|
||||
- [HunyuanDiT](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_dit/)
|
||||
- Latent previews with [TAESD](#how-to-show-high-quality-previews)
|
||||
- Starts up very fast.
|
||||
- Works fully offline: will never download anything.
|
||||
@@ -46,32 +83,39 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
|
||||
|
||||
| Keybind | Explanation |
|
||||
|------------------------------------|--------------------------------------------------------------------------------------------------------------------|
|
||||
| Ctrl + Enter | Queue up current graph for generation |
|
||||
| Ctrl + Shift + Enter | Queue up current graph as first for generation |
|
||||
| Ctrl + Z/Ctrl + Y | Undo/Redo |
|
||||
| Ctrl + S | Save workflow |
|
||||
| Ctrl + O | Load workflow |
|
||||
| Ctrl + A | Select all nodes |
|
||||
| Alt + C | Collapse/uncollapse selected nodes |
|
||||
| Ctrl + M | Mute/unmute selected nodes |
|
||||
| Ctrl + B | Bypass selected nodes (acts like the node was removed from the graph and the wires reconnected through) |
|
||||
| Delete/Backspace | Delete selected nodes |
|
||||
| Ctrl + Backspace | Delete the current graph |
|
||||
| Space | Move the canvas around when held and moving the cursor |
|
||||
| Ctrl/Shift + Click | Add clicked node to selection |
|
||||
| Ctrl + C/Ctrl + V | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) |
|
||||
| Ctrl + C/Ctrl + Shift + V | Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) |
|
||||
| Shift + Drag | Move multiple selected nodes at the same time |
|
||||
| Ctrl + D | Load default graph |
|
||||
| Alt + `+` | Canvas Zoom in |
|
||||
| Alt + `-` | Canvas Zoom out |
|
||||
| Ctrl + Shift + LMB + Vertical drag | Canvas Zoom in/out |
|
||||
| Q | Toggle visibility of the queue |
|
||||
| H | Toggle visibility of history |
|
||||
| R | Refresh graph |
|
||||
| `Ctrl` + `Enter` | Queue up current graph for generation |
|
||||
| `Ctrl` + `Shift` + `Enter` | Queue up current graph as first for generation |
|
||||
| `Ctrl` + `Alt` + `Enter` | Cancel current generation |
|
||||
| `Ctrl` + `Z`/`Ctrl` + `Y` | Undo/Redo |
|
||||
| `Ctrl` + `S` | Save workflow |
|
||||
| `Ctrl` + `O` | Load workflow |
|
||||
| `Ctrl` + `A` | Select all nodes |
|
||||
| `Alt `+ `C` | Collapse/uncollapse selected nodes |
|
||||
| `Ctrl` + `M` | Mute/unmute selected nodes |
|
||||
| `Ctrl` + `B` | Bypass selected nodes (acts like the node was removed from the graph and the wires reconnected through) |
|
||||
| `Delete`/`Backspace` | Delete selected nodes |
|
||||
| `Ctrl` + `Backspace` | Delete the current graph |
|
||||
| `Space` | Move the canvas around when held and moving the cursor |
|
||||
| `Ctrl`/`Shift` + `Click` | Add clicked node to selection |
|
||||
| `Ctrl` + `C`/`Ctrl` + `V` | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) |
|
||||
| `Ctrl` + `C`/`Ctrl` + `Shift` + `V` | Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) |
|
||||
| `Shift` + `Drag` | Move multiple selected nodes at the same time |
|
||||
| `Ctrl` + `D` | Load default graph |
|
||||
| `Alt` + `+` | Canvas Zoom in |
|
||||
| `Alt` + `-` | Canvas Zoom out |
|
||||
| `Ctrl` + `Shift` + LMB + Vertical drag | Canvas Zoom in/out |
|
||||
| `P` | Pin/Unpin selected nodes |
|
||||
| `Ctrl` + `G` | Group selected nodes |
|
||||
| `Q` | Toggle visibility of the queue |
|
||||
| `H` | Toggle visibility of history |
|
||||
| `R` | Refresh graph |
|
||||
| `F` | Show/Hide menu |
|
||||
| `.` | Fit view to selection (Whole graph when nothing is selected) |
|
||||
| Double-Click LMB | Open node quick search palette |
|
||||
| `Shift` + Drag | Move multiple wires at once |
|
||||
| `Ctrl` + `Alt` + LMB | Disconnect all wires from clicked slot |
|
||||
|
||||
Ctrl can also be replaced with Cmd instead for macOS users
|
||||
`Ctrl` can also be replaced with `Cmd` instead for macOS users
|
||||
|
||||
# Installing
|
||||
|
||||
@@ -95,6 +139,8 @@ To run it on services like paperspace, kaggle or colab you can use my [Jupyter N
|
||||
|
||||
## Manual Install (Windows, Linux)
|
||||
|
||||
Note that some dependencies do not yet support python 3.13 so using 3.12 is recommended.
|
||||
|
||||
Git clone this repo.
|
||||
|
||||
Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints
|
||||
@@ -105,21 +151,45 @@ Put your VAE in: models/vae
|
||||
### AMD GPUs (Linux only)
|
||||
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
|
||||
|
||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.0```
|
||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.2```
|
||||
|
||||
This is the command to install the nightly with ROCm 6.0 which might have some performance improvements:
|
||||
This is the command to install the nightly with ROCm 6.2 which might have some performance improvements:
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.1```
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.2.4```
|
||||
|
||||
### Intel GPUs (Windows and Linux)
|
||||
|
||||
(Option 1) Intel Arc GPU users can install native PyTorch with torch.xpu support using pip (currently available in PyTorch nightly builds). More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html)
|
||||
|
||||
1. To install PyTorch nightly, use the following command:
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/xpu```
|
||||
|
||||
2. Launch ComfyUI by running `python main.py`
|
||||
|
||||
|
||||
(Option 2) Alternatively, Intel GPUs supported by Intel Extension for PyTorch (IPEX) can leverage IPEX for improved performance.
|
||||
|
||||
1. For Intel® Arc™ A-Series Graphics utilizing IPEX, create a conda environment and use the commands below:
|
||||
|
||||
```
|
||||
conda install libuv
|
||||
pip install torch==2.3.1.post0+cxx11.abi torchvision==0.18.1.post0+cxx11.abi torchaudio==2.3.1.post0+cxx11.abi intel-extension-for-pytorch==2.3.110.post0+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/
|
||||
```
|
||||
|
||||
For other supported Intel GPUs with IPEX, visit [Installation](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu) for more information.
|
||||
|
||||
Additional discussion and help can be found [here](https://github.com/comfyanonymous/ComfyUI/discussions/476).
|
||||
|
||||
### NVIDIA
|
||||
|
||||
Nvidia users should install stable pytorch using this command:
|
||||
|
||||
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu121```
|
||||
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu124```
|
||||
|
||||
This is the command to install pytorch nightly instead which might have performance improvements:
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124```
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu126```
|
||||
|
||||
#### Troubleshooting
|
||||
|
||||
@@ -139,17 +209,6 @@ After this you should have everything installed and can proceed to running Comfy
|
||||
|
||||
### Others:
|
||||
|
||||
#### Intel GPUs
|
||||
|
||||
Intel GPU support is available for all Intel GPUs supported by Intel's Extension for Pytorch (IPEX) with the support requirements listed in the [Installation](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu) page. Choose your platform and method of install and follow the instructions. The steps are as follows:
|
||||
|
||||
1. Start by installing the drivers or kernel listed or newer in the Installation page of IPEX linked above for Windows and Linux if needed.
|
||||
1. Follow the instructions to install [Intel's oneAPI Basekit](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit-download.html) for your platform.
|
||||
1. Install the packages for IPEX using the instructions provided in the Installation page for your platform.
|
||||
1. Follow the [ComfyUI manual installation](#manual-install-windows-linux) instructions for Windows and Linux and run ComfyUI normally as described above after everything is installed.
|
||||
|
||||
Additional discussion and help can be found [here](https://github.com/comfyanonymous/ComfyUI/discussions/476).
|
||||
|
||||
#### Apple Mac silicon
|
||||
|
||||
You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS version.
|
||||
@@ -177,6 +236,14 @@ For 6700, 6600 and maybe other RDNA2 or older: ```HSA_OVERRIDE_GFX_VERSION=10.3.
|
||||
|
||||
For AMD 7600 and maybe other RDNA3 cards: ```HSA_OVERRIDE_GFX_VERSION=11.0.0 python main.py```
|
||||
|
||||
### AMD ROCm Tips
|
||||
|
||||
You can enable experimental memory efficient attention on pytorch 2.5 in ComfyUI on RDNA3 and potentially other AMD GPUs using this command:
|
||||
|
||||
```TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 python main.py --use-pytorch-cross-attention```
|
||||
|
||||
You can also try setting this env variable `PYTORCH_TUNABLEOP_ENABLED=1` which might speed things up at the cost of a very slow initial run.
|
||||
|
||||
# Notes
|
||||
|
||||
Only parts of the graph that have an output with all the correct inputs will be executed.
|
||||
@@ -200,7 +267,7 @@ To use a textual inversion concepts/embeddings in a text prompt put them in the
|
||||
|
||||
Use ```--preview-method auto``` to enable previews.
|
||||
|
||||
The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth) (for SD1.x and SD2.x) and [taesdxl_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesdxl_decoder.pth) (for SDXL) models and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI to enable high-quality previews.
|
||||
The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_decoder.pth, taesdxl_decoder.pth, taesd3_decoder.pth and taef1_decoder.pth](https://github.com/madebyollin/taesd/) and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI and launch it with `--preview-method taesd` to enable high-quality previews.
|
||||
|
||||
## How to use TLS/SSL?
|
||||
Generate a self-signed certificate (not appropriate for shared/production use) and key by running the command: `openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem -sha256 -days 3650 -nodes -subj "/C=XX/ST=StateName/L=CityName/O=CompanyName/OU=CompanySectionName/CN=CommonNameOrHostname"`
|
||||
@@ -216,9 +283,49 @@ Use `--tls-keyfile key.pem --tls-certfile cert.pem` to enable TLS/SSL, the app w
|
||||
|
||||
See also: [https://www.comfy.org/](https://www.comfy.org/)
|
||||
|
||||
## Frontend Development
|
||||
|
||||
As of August 15, 2024, we have transitioned to a new frontend, which is now hosted in a separate repository: [ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend). This repository now hosts the compiled JS (from TS/Vue) under the `web/` directory.
|
||||
|
||||
### Reporting Issues and Requesting Features
|
||||
|
||||
For any bugs, issues, or feature requests related to the frontend, please use the [ComfyUI Frontend repository](https://github.com/Comfy-Org/ComfyUI_frontend). This will help us manage and address frontend-specific concerns more efficiently.
|
||||
|
||||
### Using the Latest Frontend
|
||||
|
||||
The new frontend is now the default for ComfyUI. However, please note:
|
||||
|
||||
1. The frontend in the main ComfyUI repository is updated weekly.
|
||||
2. Daily releases are available in the separate frontend repository.
|
||||
|
||||
To use the most up-to-date frontend version:
|
||||
|
||||
1. For the latest daily release, launch ComfyUI with this command line argument:
|
||||
|
||||
```
|
||||
--front-end-version Comfy-Org/ComfyUI_frontend@latest
|
||||
```
|
||||
|
||||
2. For a specific version, replace `latest` with the desired version number:
|
||||
|
||||
```
|
||||
--front-end-version Comfy-Org/ComfyUI_frontend@1.2.2
|
||||
```
|
||||
|
||||
This approach allows you to easily switch between the stable weekly release and the cutting-edge daily updates, or even specific versions for testing purposes.
|
||||
|
||||
### Accessing the Legacy Frontend
|
||||
|
||||
If you need to use the legacy frontend for any reason, you can access it using the following command line argument:
|
||||
|
||||
```
|
||||
--front-end-version Comfy-Org/ComfyUI_legacy_frontend@latest
|
||||
```
|
||||
|
||||
This will use a snapshot of the legacy frontend preserved in the [ComfyUI Legacy Frontend repository](https://github.com/Comfy-Org/ComfyUI_legacy_frontend).
|
||||
|
||||
# QA
|
||||
|
||||
### Which GPU should I buy for this?
|
||||
|
||||
[See this page for some recommendations](https://github.com/comfyanonymous/ComfyUI/wiki/Which-GPU-should-I-buy-for-ComfyUI)
|
||||
|
||||
|
||||
0
api_server/__init__.py
Normal file
0
api_server/__init__.py
Normal file
0
api_server/routes/__init__.py
Normal file
0
api_server/routes/__init__.py
Normal file
3
api_server/routes/internal/README.md
Normal file
3
api_server/routes/internal/README.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# ComfyUI Internal Routes
|
||||
|
||||
All routes under the `/internal` path are designated for **internal use by ComfyUI only**. These routes are not intended for use by external applications may change at any time without notice.
|
||||
0
api_server/routes/internal/__init__.py
Normal file
0
api_server/routes/internal/__init__.py
Normal file
75
api_server/routes/internal/internal_routes.py
Normal file
75
api_server/routes/internal/internal_routes.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from aiohttp import web
|
||||
from typing import Optional
|
||||
from folder_paths import models_dir, user_directory, output_directory, folder_names_and_paths
|
||||
from api_server.services.file_service import FileService
|
||||
from api_server.services.terminal_service import TerminalService
|
||||
import app.logger
|
||||
|
||||
class InternalRoutes:
|
||||
'''
|
||||
The top level web router for internal routes: /internal/*
|
||||
The endpoints here should NOT be depended upon. It is for ComfyUI frontend use only.
|
||||
Check README.md for more information.
|
||||
'''
|
||||
|
||||
def __init__(self, prompt_server):
|
||||
self.routes: web.RouteTableDef = web.RouteTableDef()
|
||||
self._app: Optional[web.Application] = None
|
||||
self.file_service = FileService({
|
||||
"models": models_dir,
|
||||
"user": user_directory,
|
||||
"output": output_directory
|
||||
})
|
||||
self.prompt_server = prompt_server
|
||||
self.terminal_service = TerminalService(prompt_server)
|
||||
|
||||
def setup_routes(self):
|
||||
@self.routes.get('/files')
|
||||
async def list_files(request):
|
||||
directory_key = request.query.get('directory', '')
|
||||
try:
|
||||
file_list = self.file_service.list_files(directory_key)
|
||||
return web.json_response({"files": file_list})
|
||||
except ValueError as e:
|
||||
return web.json_response({"error": str(e)}, status=400)
|
||||
except Exception as e:
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
@self.routes.get('/logs')
|
||||
async def get_logs(request):
|
||||
return web.json_response("".join([(l["t"] + " - " + l["m"]) for l in app.logger.get_logs()]))
|
||||
|
||||
@self.routes.get('/logs/raw')
|
||||
async def get_raw_logs(request):
|
||||
self.terminal_service.update_size()
|
||||
return web.json_response({
|
||||
"entries": list(app.logger.get_logs()),
|
||||
"size": {"cols": self.terminal_service.cols, "rows": self.terminal_service.rows}
|
||||
})
|
||||
|
||||
@self.routes.patch('/logs/subscribe')
|
||||
async def subscribe_logs(request):
|
||||
json_data = await request.json()
|
||||
client_id = json_data["clientId"]
|
||||
enabled = json_data["enabled"]
|
||||
if enabled:
|
||||
self.terminal_service.subscribe(client_id)
|
||||
else:
|
||||
self.terminal_service.unsubscribe(client_id)
|
||||
|
||||
return web.Response(status=200)
|
||||
|
||||
|
||||
@self.routes.get('/folder_paths')
|
||||
async def get_folder_paths(request):
|
||||
response = {}
|
||||
for key in folder_names_and_paths:
|
||||
response[key] = folder_names_and_paths[key][0]
|
||||
return web.json_response(response)
|
||||
|
||||
def get_app(self):
|
||||
if self._app is None:
|
||||
self._app = web.Application()
|
||||
self.setup_routes()
|
||||
self._app.add_routes(self.routes)
|
||||
return self._app
|
||||
0
api_server/services/__init__.py
Normal file
0
api_server/services/__init__.py
Normal file
13
api_server/services/file_service.py
Normal file
13
api_server/services/file_service.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from typing import Dict, List, Optional
|
||||
from api_server.utils.file_operations import FileSystemOperations, FileSystemItem
|
||||
|
||||
class FileService:
|
||||
def __init__(self, allowed_directories: Dict[str, str], file_system_ops: Optional[FileSystemOperations] = None):
|
||||
self.allowed_directories: Dict[str, str] = allowed_directories
|
||||
self.file_system_ops: FileSystemOperations = file_system_ops or FileSystemOperations()
|
||||
|
||||
def list_files(self, directory_key: str) -> List[FileSystemItem]:
|
||||
if directory_key not in self.allowed_directories:
|
||||
raise ValueError("Invalid directory key")
|
||||
directory_path: str = self.allowed_directories[directory_key]
|
||||
return self.file_system_ops.walk_directory(directory_path)
|
||||
60
api_server/services/terminal_service.py
Normal file
60
api_server/services/terminal_service.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from app.logger import on_flush
|
||||
import os
|
||||
import shutil
|
||||
|
||||
|
||||
class TerminalService:
|
||||
def __init__(self, server):
|
||||
self.server = server
|
||||
self.cols = None
|
||||
self.rows = None
|
||||
self.subscriptions = set()
|
||||
on_flush(self.send_messages)
|
||||
|
||||
def get_terminal_size(self):
|
||||
try:
|
||||
size = os.get_terminal_size()
|
||||
return (size.columns, size.lines)
|
||||
except OSError:
|
||||
try:
|
||||
size = shutil.get_terminal_size()
|
||||
return (size.columns, size.lines)
|
||||
except OSError:
|
||||
return (80, 24) # fallback to 80x24
|
||||
|
||||
def update_size(self):
|
||||
columns, lines = self.get_terminal_size()
|
||||
changed = False
|
||||
|
||||
if columns != self.cols:
|
||||
self.cols = columns
|
||||
changed = True
|
||||
|
||||
if lines != self.rows:
|
||||
self.rows = lines
|
||||
changed = True
|
||||
|
||||
if changed:
|
||||
return {"cols": self.cols, "rows": self.rows}
|
||||
|
||||
return None
|
||||
|
||||
def subscribe(self, client_id):
|
||||
self.subscriptions.add(client_id)
|
||||
|
||||
def unsubscribe(self, client_id):
|
||||
self.subscriptions.discard(client_id)
|
||||
|
||||
def send_messages(self, entries):
|
||||
if not len(entries) or not len(self.subscriptions):
|
||||
return
|
||||
|
||||
new_size = self.update_size()
|
||||
|
||||
for client_id in self.subscriptions.copy(): # prevent: Set changed size during iteration
|
||||
if client_id not in self.server.sockets:
|
||||
# Automatically unsub if the socket has disconnected
|
||||
self.unsubscribe(client_id)
|
||||
continue
|
||||
|
||||
self.server.send_sync("logs", {"entries": entries, "size": new_size}, client_id)
|
||||
42
api_server/utils/file_operations.py
Normal file
42
api_server/utils/file_operations.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import os
|
||||
from typing import List, Union, TypedDict, Literal
|
||||
from typing_extensions import TypeGuard
|
||||
class FileInfo(TypedDict):
|
||||
name: str
|
||||
path: str
|
||||
type: Literal["file"]
|
||||
size: int
|
||||
|
||||
class DirectoryInfo(TypedDict):
|
||||
name: str
|
||||
path: str
|
||||
type: Literal["directory"]
|
||||
|
||||
FileSystemItem = Union[FileInfo, DirectoryInfo]
|
||||
|
||||
def is_file_info(item: FileSystemItem) -> TypeGuard[FileInfo]:
|
||||
return item["type"] == "file"
|
||||
|
||||
class FileSystemOperations:
|
||||
@staticmethod
|
||||
def walk_directory(directory: str) -> List[FileSystemItem]:
|
||||
file_list: List[FileSystemItem] = []
|
||||
for root, dirs, files in os.walk(directory):
|
||||
for name in files:
|
||||
file_path = os.path.join(root, name)
|
||||
relative_path = os.path.relpath(file_path, directory)
|
||||
file_list.append({
|
||||
"name": name,
|
||||
"path": relative_path,
|
||||
"type": "file",
|
||||
"size": os.path.getsize(file_path)
|
||||
})
|
||||
for name in dirs:
|
||||
dir_path = os.path.join(root, name)
|
||||
relative_path = os.path.relpath(dir_path, directory)
|
||||
file_list.append({
|
||||
"name": name,
|
||||
"path": relative_path,
|
||||
"type": "directory"
|
||||
})
|
||||
return file_list
|
||||
@@ -8,7 +8,7 @@ import zipfile
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import TypedDict
|
||||
from typing import TypedDict, Optional
|
||||
|
||||
import requests
|
||||
from typing_extensions import NotRequired
|
||||
@@ -132,12 +132,13 @@ class FrontendManager:
|
||||
return match_result.group(1), match_result.group(2), match_result.group(3)
|
||||
|
||||
@classmethod
|
||||
def init_frontend_unsafe(cls, version_string: str) -> str:
|
||||
def init_frontend_unsafe(cls, version_string: str, provider: Optional[FrontEndProvider] = None) -> str:
|
||||
"""
|
||||
Initializes the frontend for the specified version.
|
||||
|
||||
Args:
|
||||
version_string (str): The version string.
|
||||
provider (FrontEndProvider, optional): The provider to use. Defaults to None.
|
||||
|
||||
Returns:
|
||||
str: The path to the initialized frontend.
|
||||
@@ -150,7 +151,16 @@ class FrontendManager:
|
||||
return cls.DEFAULT_FRONTEND_PATH
|
||||
|
||||
repo_owner, repo_name, version = cls.parse_version_string(version_string)
|
||||
provider = FrontEndProvider(repo_owner, repo_name)
|
||||
|
||||
if version.startswith("v"):
|
||||
expected_path = str(Path(cls.CUSTOM_FRONTENDS_ROOT) / f"{repo_owner}_{repo_name}" / version.lstrip("v"))
|
||||
if os.path.exists(expected_path):
|
||||
logging.info(f"Using existing copy of specific frontend version tag: {repo_owner}/{repo_name}@{version}")
|
||||
return expected_path
|
||||
|
||||
logging.info(f"Initializing frontend: {repo_owner}/{repo_name}@{version}, requesting version details from GitHub...")
|
||||
|
||||
provider = provider or FrontEndProvider(repo_owner, repo_name)
|
||||
release = provider.get_release(version)
|
||||
|
||||
semantic_version = release["tag_name"].lstrip("v")
|
||||
@@ -158,15 +168,21 @@ class FrontendManager:
|
||||
Path(cls.CUSTOM_FRONTENDS_ROOT) / provider.folder_name / semantic_version
|
||||
)
|
||||
if not os.path.exists(web_root):
|
||||
os.makedirs(web_root, exist_ok=True)
|
||||
logging.info(
|
||||
"Downloading frontend(%s) version(%s) to (%s)",
|
||||
provider.folder_name,
|
||||
semantic_version,
|
||||
web_root,
|
||||
)
|
||||
logging.debug(release)
|
||||
download_release_asset_zip(release, destination_path=web_root)
|
||||
try:
|
||||
os.makedirs(web_root, exist_ok=True)
|
||||
logging.info(
|
||||
"Downloading frontend(%s) version(%s) to (%s)",
|
||||
provider.folder_name,
|
||||
semantic_version,
|
||||
web_root,
|
||||
)
|
||||
logging.debug(release)
|
||||
download_release_asset_zip(release, destination_path=web_root)
|
||||
finally:
|
||||
# Clean up the directory if it is empty, i.e. the download failed
|
||||
if not os.listdir(web_root):
|
||||
os.rmdir(web_root)
|
||||
|
||||
return web_root
|
||||
|
||||
@classmethod
|
||||
|
||||
73
app/logger.py
Normal file
73
app/logger.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from collections import deque
|
||||
from datetime import datetime
|
||||
import io
|
||||
import logging
|
||||
import sys
|
||||
import threading
|
||||
|
||||
logs = None
|
||||
stdout_interceptor = None
|
||||
stderr_interceptor = None
|
||||
|
||||
|
||||
class LogInterceptor(io.TextIOWrapper):
|
||||
def __init__(self, stream, *args, **kwargs):
|
||||
buffer = stream.buffer
|
||||
encoding = stream.encoding
|
||||
super().__init__(buffer, *args, **kwargs, encoding=encoding, line_buffering=stream.line_buffering)
|
||||
self._lock = threading.Lock()
|
||||
self._flush_callbacks = []
|
||||
self._logs_since_flush = []
|
||||
|
||||
def write(self, data):
|
||||
entry = {"t": datetime.now().isoformat(), "m": data}
|
||||
with self._lock:
|
||||
self._logs_since_flush.append(entry)
|
||||
|
||||
# Simple handling for cr to overwrite the last output if it isnt a full line
|
||||
# else logs just get full of progress messages
|
||||
if isinstance(data, str) and data.startswith("\r") and not logs[-1]["m"].endswith("\n"):
|
||||
logs.pop()
|
||||
logs.append(entry)
|
||||
super().write(data)
|
||||
|
||||
def flush(self):
|
||||
super().flush()
|
||||
for cb in self._flush_callbacks:
|
||||
cb(self._logs_since_flush)
|
||||
self._logs_since_flush = []
|
||||
|
||||
def on_flush(self, callback):
|
||||
self._flush_callbacks.append(callback)
|
||||
|
||||
|
||||
def get_logs():
|
||||
return logs
|
||||
|
||||
|
||||
def on_flush(callback):
|
||||
if stdout_interceptor is not None:
|
||||
stdout_interceptor.on_flush(callback)
|
||||
if stderr_interceptor is not None:
|
||||
stderr_interceptor.on_flush(callback)
|
||||
|
||||
def setup_logger(log_level: str = 'INFO', capacity: int = 300):
|
||||
global logs
|
||||
if logs:
|
||||
return
|
||||
|
||||
# Override output streams and log to buffer
|
||||
logs = deque(maxlen=capacity)
|
||||
|
||||
global stdout_interceptor
|
||||
global stderr_interceptor
|
||||
stdout_interceptor = sys.stdout = LogInterceptor(sys.stdout)
|
||||
stderr_interceptor = sys.stderr = LogInterceptor(sys.stderr)
|
||||
|
||||
# Setup default global logger
|
||||
logger = logging.getLogger()
|
||||
logger.setLevel(log_level)
|
||||
|
||||
stream_handler = logging.StreamHandler()
|
||||
stream_handler.setFormatter(logging.Formatter("%(message)s"))
|
||||
logger.addHandler(stream_handler)
|
||||
184
app/model_manager.py
Normal file
184
app/model_manager.py
Normal file
@@ -0,0 +1,184 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import base64
|
||||
import json
|
||||
import time
|
||||
import logging
|
||||
import folder_paths
|
||||
import glob
|
||||
import comfy.utils
|
||||
from aiohttp import web
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
from folder_paths import map_legacy, filter_files_extensions, filter_files_content_types
|
||||
|
||||
|
||||
class ModelFileManager:
|
||||
def __init__(self) -> None:
|
||||
self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {}
|
||||
|
||||
def get_cache(self, key: str, default=None) -> tuple[list[dict], dict[str, float], float] | None:
|
||||
return self.cache.get(key, default)
|
||||
|
||||
def set_cache(self, key: str, value: tuple[list[dict], dict[str, float], float]):
|
||||
self.cache[key] = value
|
||||
|
||||
def clear_cache(self):
|
||||
self.cache.clear()
|
||||
|
||||
def add_routes(self, routes):
|
||||
# NOTE: This is an experiment to replace `/models`
|
||||
@routes.get("/experiment/models")
|
||||
async def get_model_folders(request):
|
||||
model_types = list(folder_paths.folder_names_and_paths.keys())
|
||||
folder_black_list = ["configs", "custom_nodes"]
|
||||
output_folders: list[dict] = []
|
||||
for folder in model_types:
|
||||
if folder in folder_black_list:
|
||||
continue
|
||||
output_folders.append({"name": folder, "folders": folder_paths.get_folder_paths(folder)})
|
||||
return web.json_response(output_folders)
|
||||
|
||||
# NOTE: This is an experiment to replace `/models/{folder}`
|
||||
@routes.get("/experiment/models/{folder}")
|
||||
async def get_all_models(request):
|
||||
folder = request.match_info.get("folder", None)
|
||||
if not folder in folder_paths.folder_names_and_paths:
|
||||
return web.Response(status=404)
|
||||
files = self.get_model_file_list(folder)
|
||||
return web.json_response(files)
|
||||
|
||||
@routes.get("/experiment/models/preview/{folder}/{path_index}/{filename:.*}")
|
||||
async def get_model_preview(request):
|
||||
folder_name = request.match_info.get("folder", None)
|
||||
path_index = int(request.match_info.get("path_index", None))
|
||||
filename = request.match_info.get("filename", None)
|
||||
|
||||
if not folder_name in folder_paths.folder_names_and_paths:
|
||||
return web.Response(status=404)
|
||||
|
||||
folders = folder_paths.folder_names_and_paths[folder_name]
|
||||
folder = folders[0][path_index]
|
||||
full_filename = os.path.join(folder, filename)
|
||||
|
||||
previews = self.get_model_previews(full_filename)
|
||||
default_preview = previews[0] if len(previews) > 0 else None
|
||||
if default_preview is None or (isinstance(default_preview, str) and not os.path.isfile(default_preview)):
|
||||
return web.Response(status=404)
|
||||
|
||||
try:
|
||||
with Image.open(default_preview) as img:
|
||||
img_bytes = BytesIO()
|
||||
img.save(img_bytes, format="WEBP")
|
||||
img_bytes.seek(0)
|
||||
return web.Response(body=img_bytes.getvalue(), content_type="image/webp")
|
||||
except:
|
||||
return web.Response(status=404)
|
||||
|
||||
def get_model_file_list(self, folder_name: str):
|
||||
folder_name = map_legacy(folder_name)
|
||||
folders = folder_paths.folder_names_and_paths[folder_name]
|
||||
output_list: list[dict] = []
|
||||
|
||||
for index, folder in enumerate(folders[0]):
|
||||
if not os.path.isdir(folder):
|
||||
continue
|
||||
out = self.cache_model_file_list_(folder)
|
||||
if out is None:
|
||||
out = self.recursive_search_models_(folder, index)
|
||||
self.set_cache(folder, out)
|
||||
output_list.extend(out[0])
|
||||
|
||||
return output_list
|
||||
|
||||
def cache_model_file_list_(self, folder: str):
|
||||
model_file_list_cache = self.get_cache(folder)
|
||||
|
||||
if model_file_list_cache is None:
|
||||
return None
|
||||
if not os.path.isdir(folder):
|
||||
return None
|
||||
if os.path.getmtime(folder) != model_file_list_cache[1]:
|
||||
return None
|
||||
for x in model_file_list_cache[1]:
|
||||
time_modified = model_file_list_cache[1][x]
|
||||
folder = x
|
||||
if os.path.getmtime(folder) != time_modified:
|
||||
return None
|
||||
|
||||
return model_file_list_cache
|
||||
|
||||
def recursive_search_models_(self, directory: str, pathIndex: int) -> tuple[list[str], dict[str, float], float]:
|
||||
if not os.path.isdir(directory):
|
||||
return [], {}, time.perf_counter()
|
||||
|
||||
excluded_dir_names = [".git"]
|
||||
# TODO use settings
|
||||
include_hidden_files = False
|
||||
|
||||
result: list[str] = []
|
||||
dirs: dict[str, float] = {}
|
||||
|
||||
for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
|
||||
subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
|
||||
if not include_hidden_files:
|
||||
subdirs[:] = [d for d in subdirs if not d.startswith(".")]
|
||||
filenames = [f for f in filenames if not f.startswith(".")]
|
||||
|
||||
filenames = filter_files_extensions(filenames, folder_paths.supported_pt_extensions)
|
||||
|
||||
for file_name in filenames:
|
||||
try:
|
||||
relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
|
||||
result.append(relative_path)
|
||||
except:
|
||||
logging.warning(f"Warning: Unable to access {file_name}. Skipping this file.")
|
||||
continue
|
||||
|
||||
for d in subdirs:
|
||||
path: str = os.path.join(dirpath, d)
|
||||
try:
|
||||
dirs[path] = os.path.getmtime(path)
|
||||
except FileNotFoundError:
|
||||
logging.warning(f"Warning: Unable to access {path}. Skipping this path.")
|
||||
continue
|
||||
|
||||
return [{"name": f, "pathIndex": pathIndex} for f in result], dirs, time.perf_counter()
|
||||
|
||||
def get_model_previews(self, filepath: str) -> list[str | BytesIO]:
|
||||
dirname = os.path.dirname(filepath)
|
||||
|
||||
if not os.path.exists(dirname):
|
||||
return []
|
||||
|
||||
basename = os.path.splitext(filepath)[0]
|
||||
match_files = glob.glob(f"{basename}.*", recursive=False)
|
||||
image_files = filter_files_content_types(match_files, "image")
|
||||
safetensors_file = next(filter(lambda x: x.endswith(".safetensors"), match_files), None)
|
||||
safetensors_metadata = {}
|
||||
|
||||
result: list[str | BytesIO] = []
|
||||
|
||||
for filename in image_files:
|
||||
_basename = os.path.splitext(filename)[0]
|
||||
if _basename == basename:
|
||||
result.append(filename)
|
||||
if _basename == f"{basename}.preview":
|
||||
result.append(filename)
|
||||
|
||||
if safetensors_file:
|
||||
safetensors_filepath = os.path.join(dirname, safetensors_file)
|
||||
header = comfy.utils.safetensors_header(safetensors_filepath, max_size=8*1024*1024)
|
||||
if header:
|
||||
safetensors_metadata = json.loads(header)
|
||||
safetensors_images = safetensors_metadata.get("__metadata__", {}).get("ssmd_cover_images", None)
|
||||
if safetensors_images:
|
||||
safetensors_images = json.loads(safetensors_images)
|
||||
for image in safetensors_images:
|
||||
result.append(BytesIO(base64.b64decode(image)))
|
||||
|
||||
return result
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self.clear_cache()
|
||||
@@ -1,38 +1,58 @@
|
||||
from __future__ import annotations
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
import glob
|
||||
import shutil
|
||||
import logging
|
||||
from aiohttp import web
|
||||
from urllib import parse
|
||||
from comfy.cli_args import args
|
||||
from folder_paths import user_directory
|
||||
import folder_paths
|
||||
from .app_settings import AppSettings
|
||||
from typing import TypedDict
|
||||
|
||||
default_user = "default"
|
||||
users_file = os.path.join(user_directory, "users.json")
|
||||
|
||||
|
||||
class FileInfo(TypedDict):
|
||||
path: str
|
||||
size: int
|
||||
modified: int
|
||||
|
||||
|
||||
def get_file_info(path: str, relative_to: str) -> FileInfo:
|
||||
return {
|
||||
"path": os.path.relpath(path, relative_to).replace(os.sep, '/'),
|
||||
"size": os.path.getsize(path),
|
||||
"modified": os.path.getmtime(path)
|
||||
}
|
||||
|
||||
|
||||
class UserManager():
|
||||
def __init__(self):
|
||||
global user_directory
|
||||
user_directory = folder_paths.get_user_directory()
|
||||
|
||||
self.settings = AppSettings(self)
|
||||
if not os.path.exists(user_directory):
|
||||
os.mkdir(user_directory)
|
||||
os.makedirs(user_directory, exist_ok=True)
|
||||
if not args.multi_user:
|
||||
print("****** User settings have been changed to be stored on the server instead of browser storage. ******")
|
||||
print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")
|
||||
logging.warning("****** User settings have been changed to be stored on the server instead of browser storage. ******")
|
||||
logging.warning("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")
|
||||
|
||||
if args.multi_user:
|
||||
if os.path.isfile(users_file):
|
||||
with open(users_file) as f:
|
||||
if os.path.isfile(self.get_users_file()):
|
||||
with open(self.get_users_file()) as f:
|
||||
self.users = json.load(f)
|
||||
else:
|
||||
self.users = {}
|
||||
else:
|
||||
self.users = {"default": "default"}
|
||||
|
||||
def get_users_file(self):
|
||||
return os.path.join(folder_paths.get_user_directory(), "users.json")
|
||||
|
||||
def get_request_user_id(self, request):
|
||||
user = "default"
|
||||
if args.multi_user and "comfy-user" in request.headers:
|
||||
@@ -44,7 +64,7 @@ class UserManager():
|
||||
return user
|
||||
|
||||
def get_request_user_filepath(self, request, file, type="userdata", create_dir=True):
|
||||
global user_directory
|
||||
user_directory = folder_paths.get_user_directory()
|
||||
|
||||
if type == "userdata":
|
||||
root_dir = user_directory
|
||||
@@ -59,6 +79,10 @@ class UserManager():
|
||||
return None
|
||||
|
||||
if file is not None:
|
||||
# Check if filename is url encoded
|
||||
if "%" in file:
|
||||
file = parse.unquote(file)
|
||||
|
||||
# prevent leaving /{type}/{user}
|
||||
path = os.path.abspath(os.path.join(user_root, file))
|
||||
if os.path.commonpath((user_root, path)) != user_root:
|
||||
@@ -80,8 +104,7 @@ class UserManager():
|
||||
|
||||
self.users[user_id] = name
|
||||
|
||||
global users_file
|
||||
with open(users_file, "w") as f:
|
||||
with open(self.get_users_file(), "w") as f:
|
||||
json.dump(self.users, f)
|
||||
|
||||
return user_id
|
||||
@@ -112,25 +135,65 @@ class UserManager():
|
||||
|
||||
@routes.get("/userdata")
|
||||
async def listuserdata(request):
|
||||
"""
|
||||
List user data files in a specified directory.
|
||||
|
||||
This endpoint allows listing files in a user's data directory, with options for recursion,
|
||||
full file information, and path splitting.
|
||||
|
||||
Query Parameters:
|
||||
- dir (required): The directory to list files from.
|
||||
- recurse (optional): If "true", recursively list files in subdirectories.
|
||||
- full_info (optional): If "true", return detailed file information (path, size, modified time).
|
||||
- split (optional): If "true", split file paths into components (only applies when full_info is false).
|
||||
|
||||
Returns:
|
||||
- 400: If 'dir' parameter is missing.
|
||||
- 403: If the requested path is not allowed.
|
||||
- 404: If the requested directory does not exist.
|
||||
- 200: JSON response with the list of files or file information.
|
||||
|
||||
The response format depends on the query parameters:
|
||||
- Default: List of relative file paths.
|
||||
- full_info=true: List of dictionaries with file details.
|
||||
- split=true (and full_info=false): List of lists, each containing path components.
|
||||
"""
|
||||
directory = request.rel_url.query.get('dir', '')
|
||||
if not directory:
|
||||
return web.Response(status=400)
|
||||
|
||||
return web.Response(status=400, text="Directory not provided")
|
||||
|
||||
path = self.get_request_user_filepath(request, directory)
|
||||
if not path:
|
||||
return web.Response(status=403)
|
||||
|
||||
return web.Response(status=403, text="Invalid directory")
|
||||
|
||||
if not os.path.exists(path):
|
||||
return web.Response(status=404)
|
||||
|
||||
return web.Response(status=404, text="Directory not found")
|
||||
|
||||
recurse = request.rel_url.query.get('recurse', '').lower() == "true"
|
||||
results = glob.glob(os.path.join(
|
||||
glob.escape(path), '**/*'), recursive=recurse)
|
||||
results = [os.path.relpath(x, path) for x in results if os.path.isfile(x)]
|
||||
|
||||
full_info = request.rel_url.query.get('full_info', '').lower() == "true"
|
||||
split_path = request.rel_url.query.get('split', '').lower() == "true"
|
||||
if split_path:
|
||||
results = [[x] + x.split(os.sep) for x in results]
|
||||
|
||||
# Use different patterns based on whether we're recursing or not
|
||||
if recurse:
|
||||
pattern = os.path.join(glob.escape(path), '**', '*')
|
||||
else:
|
||||
pattern = os.path.join(glob.escape(path), '*')
|
||||
|
||||
def process_full_path(full_path: str) -> FileInfo | str | list[str]:
|
||||
if full_info:
|
||||
return get_file_info(full_path, path)
|
||||
|
||||
rel_path = os.path.relpath(full_path, path).replace(os.sep, '/')
|
||||
if split_path:
|
||||
return [rel_path] + rel_path.split('/')
|
||||
|
||||
return rel_path
|
||||
|
||||
results = [
|
||||
process_full_path(full_path)
|
||||
for full_path in glob.glob(pattern, recursive=recurse)
|
||||
if os.path.isfile(full_path)
|
||||
]
|
||||
|
||||
return web.json_response(results)
|
||||
|
||||
@@ -138,14 +201,14 @@ class UserManager():
|
||||
file = request.match_info.get(param, None)
|
||||
if not file:
|
||||
return web.Response(status=400)
|
||||
|
||||
|
||||
path = self.get_request_user_filepath(request, file)
|
||||
if not path:
|
||||
return web.Response(status=403)
|
||||
|
||||
|
||||
if check_exists and not os.path.exists(path):
|
||||
return web.Response(status=404)
|
||||
|
||||
|
||||
return path
|
||||
|
||||
@routes.get("/userdata/{file}")
|
||||
@@ -153,25 +216,56 @@ class UserManager():
|
||||
path = get_user_data_path(request, check_exists=True)
|
||||
if not isinstance(path, str):
|
||||
return path
|
||||
|
||||
|
||||
return web.FileResponse(path)
|
||||
|
||||
@routes.post("/userdata/{file}")
|
||||
async def post_userdata(request):
|
||||
"""
|
||||
Upload or update a user data file.
|
||||
|
||||
This endpoint handles file uploads to a user's data directory, with options for
|
||||
controlling overwrite behavior and response format.
|
||||
|
||||
Query Parameters:
|
||||
- overwrite (optional): If "false", prevents overwriting existing files. Defaults to "true".
|
||||
- full_info (optional): If "true", returns detailed file information (path, size, modified time).
|
||||
If "false", returns only the relative file path.
|
||||
|
||||
Path Parameters:
|
||||
- file: The target file path (URL encoded if necessary).
|
||||
|
||||
Returns:
|
||||
- 400: If 'file' parameter is missing.
|
||||
- 403: If the requested path is not allowed.
|
||||
- 409: If overwrite=false and the file already exists.
|
||||
- 200: JSON response with either:
|
||||
- Full file information (if full_info=true)
|
||||
- Relative file path (if full_info=false)
|
||||
|
||||
The request body should contain the raw file content to be written.
|
||||
"""
|
||||
path = get_user_data_path(request)
|
||||
if not isinstance(path, str):
|
||||
return path
|
||||
|
||||
overwrite = request.query["overwrite"] != "false"
|
||||
|
||||
overwrite = request.query.get("overwrite", 'true') != "false"
|
||||
full_info = request.query.get('full_info', 'false').lower() == "true"
|
||||
|
||||
if not overwrite and os.path.exists(path):
|
||||
return web.Response(status=409)
|
||||
return web.Response(status=409, text="File already exists")
|
||||
|
||||
body = await request.read()
|
||||
|
||||
with open(path, "wb") as f:
|
||||
f.write(body)
|
||||
|
||||
resp = os.path.relpath(path, self.get_request_user_filepath(request, None))
|
||||
|
||||
user_path = self.get_request_user_filepath(request, None)
|
||||
if full_info:
|
||||
resp = get_file_info(path, user_path)
|
||||
else:
|
||||
resp = os.path.relpath(path, user_path)
|
||||
|
||||
return web.json_response(resp)
|
||||
|
||||
@routes.delete("/userdata/{file}")
|
||||
@@ -181,25 +275,56 @@ class UserManager():
|
||||
return path
|
||||
|
||||
os.remove(path)
|
||||
|
||||
|
||||
return web.Response(status=204)
|
||||
|
||||
@routes.post("/userdata/{file}/move/{dest}")
|
||||
async def move_userdata(request):
|
||||
"""
|
||||
Move or rename a user data file.
|
||||
|
||||
This endpoint handles moving or renaming files within a user's data directory, with options for
|
||||
controlling overwrite behavior and response format.
|
||||
|
||||
Path Parameters:
|
||||
- file: The source file path (URL encoded if necessary)
|
||||
- dest: The destination file path (URL encoded if necessary)
|
||||
|
||||
Query Parameters:
|
||||
- overwrite (optional): If "false", prevents overwriting existing files. Defaults to "true".
|
||||
- full_info (optional): If "true", returns detailed file information (path, size, modified time).
|
||||
If "false", returns only the relative file path.
|
||||
|
||||
Returns:
|
||||
- 400: If either 'file' or 'dest' parameter is missing
|
||||
- 403: If either requested path is not allowed
|
||||
- 404: If the source file does not exist
|
||||
- 409: If overwrite=false and the destination file already exists
|
||||
- 200: JSON response with either:
|
||||
- Full file information (if full_info=true)
|
||||
- Relative file path (if full_info=false)
|
||||
"""
|
||||
source = get_user_data_path(request, check_exists=True)
|
||||
if not isinstance(source, str):
|
||||
return source
|
||||
|
||||
|
||||
dest = get_user_data_path(request, check_exists=False, param="dest")
|
||||
if not isinstance(source, str):
|
||||
return dest
|
||||
|
||||
overwrite = request.query["overwrite"] != "false"
|
||||
if not overwrite and os.path.exists(dest):
|
||||
return web.Response(status=409)
|
||||
|
||||
print(f"moving '{source}' -> '{dest}'")
|
||||
overwrite = request.query.get("overwrite", 'true') != "false"
|
||||
full_info = request.query.get('full_info', 'false').lower() == "true"
|
||||
|
||||
if not overwrite and os.path.exists(dest):
|
||||
return web.Response(status=409, text="File already exists")
|
||||
|
||||
logging.info(f"moving '{source}' -> '{dest}'")
|
||||
shutil.move(source, dest)
|
||||
|
||||
resp = os.path.relpath(dest, self.get_request_user_filepath(request, None))
|
||||
|
||||
user_path = self.get_request_user_filepath(request, None)
|
||||
if full_info:
|
||||
resp = get_file_info(dest, user_path)
|
||||
else:
|
||||
resp = os.path.relpath(dest, user_path)
|
||||
|
||||
return web.json_response(resp)
|
||||
|
||||
@@ -2,11 +2,9 @@
|
||||
#and modified
|
||||
|
||||
import torch
|
||||
import torch as th
|
||||
import torch.nn as nn
|
||||
|
||||
from ..ldm.modules.diffusionmodules.util import (
|
||||
zero_module,
|
||||
timestep_embedding,
|
||||
)
|
||||
|
||||
@@ -162,7 +160,6 @@ class ControlNet(nn.Module):
|
||||
if isinstance(self.num_classes, int):
|
||||
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
|
||||
elif self.num_classes == "continuous":
|
||||
print("setting up linear c_adm embedding layer")
|
||||
self.label_emb = nn.Linear(1, time_embed_dim)
|
||||
elif self.num_classes == "sequential":
|
||||
assert adm_in_channels is not None
|
||||
@@ -415,7 +412,6 @@ class ControlNet(nn.Module):
|
||||
out_output = []
|
||||
out_middle = []
|
||||
|
||||
hs = []
|
||||
if self.num_classes is not None:
|
||||
assert y.shape[0] == x.shape[0]
|
||||
emb = emb + self.label_emb(y)
|
||||
|
||||
120
comfy/cldm/dit_embedder.py
Normal file
120
comfy/cldm/dit_embedder.py
Normal file
@@ -0,0 +1,120 @@
|
||||
import math
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import DismantledBlock, PatchEmbed, VectorEmbedder, TimestepEmbedder, get_2d_sincos_pos_embed_torch
|
||||
|
||||
|
||||
class ControlNetEmbedder(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
img_size: int,
|
||||
patch_size: int,
|
||||
in_chans: int,
|
||||
attention_head_dim: int,
|
||||
num_attention_heads: int,
|
||||
adm_in_channels: int,
|
||||
num_layers: int,
|
||||
main_model_double: int,
|
||||
double_y_emb: bool,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
pos_embed_max_size: Optional[int] = None,
|
||||
operations = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.main_model_double = main_model_double
|
||||
self.dtype = dtype
|
||||
self.hidden_size = num_attention_heads * attention_head_dim
|
||||
self.patch_size = patch_size
|
||||
self.x_embedder = PatchEmbed(
|
||||
img_size=img_size,
|
||||
patch_size=patch_size,
|
||||
in_chans=in_chans,
|
||||
embed_dim=self.hidden_size,
|
||||
strict_img_size=pos_embed_max_size is None,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
self.t_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.double_y_emb = double_y_emb
|
||||
if self.double_y_emb:
|
||||
self.orig_y_embedder = VectorEmbedder(
|
||||
adm_in_channels, self.hidden_size, dtype, device, operations=operations
|
||||
)
|
||||
self.y_embedder = VectorEmbedder(
|
||||
self.hidden_size, self.hidden_size, dtype, device, operations=operations
|
||||
)
|
||||
else:
|
||||
self.y_embedder = VectorEmbedder(
|
||||
adm_in_channels, self.hidden_size, dtype, device, operations=operations
|
||||
)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
DismantledBlock(
|
||||
hidden_size=self.hidden_size, num_heads=num_attention_heads, qkv_bias=True,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
)
|
||||
|
||||
# self.use_y_embedder = pooled_projection_dim != self.time_text_embed.text_embedder.linear_1.in_features
|
||||
# TODO double check this logic when 8b
|
||||
self.use_y_embedder = True
|
||||
|
||||
self.controlnet_blocks = nn.ModuleList([])
|
||||
for _ in range(len(self.transformer_blocks)):
|
||||
controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
|
||||
self.controlnet_blocks.append(controlnet_block)
|
||||
|
||||
self.pos_embed_input = PatchEmbed(
|
||||
img_size=img_size,
|
||||
patch_size=patch_size,
|
||||
in_chans=in_chans,
|
||||
embed_dim=self.hidden_size,
|
||||
strict_img_size=False,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
timesteps: torch.Tensor,
|
||||
y: Optional[torch.Tensor] = None,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
hint = None,
|
||||
) -> Tuple[Tensor, List[Tensor]]:
|
||||
x_shape = list(x.shape)
|
||||
x = self.x_embedder(x)
|
||||
if not self.double_y_emb:
|
||||
h = (x_shape[-2] + 1) // self.patch_size
|
||||
w = (x_shape[-1] + 1) // self.patch_size
|
||||
x += get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=x.device)
|
||||
c = self.t_embedder(timesteps, dtype=x.dtype)
|
||||
if y is not None and self.y_embedder is not None:
|
||||
if self.double_y_emb:
|
||||
y = self.orig_y_embedder(y)
|
||||
y = self.y_embedder(y)
|
||||
c = c + y
|
||||
|
||||
x = x + self.pos_embed_input(hint)
|
||||
|
||||
block_out = ()
|
||||
|
||||
repeat = math.ceil(self.main_model_double / len(self.transformer_blocks))
|
||||
for i in range(len(self.transformer_blocks)):
|
||||
out = self.transformer_blocks[i](x, c)
|
||||
if not self.double_y_emb:
|
||||
x = out
|
||||
block_out += (self.controlnet_blocks[i](out),) * repeat
|
||||
|
||||
return {"output": block_out}
|
||||
@@ -1,11 +1,12 @@
|
||||
import torch
|
||||
from typing import Dict, Optional
|
||||
from typing import Optional
|
||||
import comfy.ldm.modules.diffusionmodules.mmdit
|
||||
|
||||
class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
|
||||
def __init__(
|
||||
self,
|
||||
num_blocks = None,
|
||||
control_latent_channels = None,
|
||||
dtype = None,
|
||||
device = None,
|
||||
operations = None,
|
||||
@@ -17,10 +18,13 @@ class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
|
||||
for _ in range(len(self.joint_blocks)):
|
||||
self.controlnet_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype))
|
||||
|
||||
if control_latent_channels is None:
|
||||
control_latent_channels = self.in_channels
|
||||
|
||||
self.pos_embed_input = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(
|
||||
None,
|
||||
self.patch_size,
|
||||
self.in_channels,
|
||||
control_latent_channels,
|
||||
self.hidden_size,
|
||||
bias=True,
|
||||
strict_img_size=False,
|
||||
|
||||
@@ -36,7 +36,7 @@ class EnumAction(argparse.Action):
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
|
||||
parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0,::", help="Specify the IP address to listen on (default: 127.0.0.1). You can give a list of ip addresses by separating them with a comma like: 127.2.2.2,127.3.3.3 If --listen is provided without an argument, it defaults to 0.0.0.0,:: (listens on all ipv4 and ipv6)")
|
||||
parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
|
||||
parser.add_argument("--tls-keyfile", type=str, help="Path to TLS (SSL) key file. Enables TLS, makes app accessible at https://... requires --tls-certfile to function")
|
||||
parser.add_argument("--tls-certfile", type=str, help="Path to TLS (SSL) certificate file. Enables TLS, makes app accessible at https://... requires --tls-keyfile to function")
|
||||
@@ -60,8 +60,10 @@ fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If
|
||||
fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
|
||||
|
||||
fpunet_group = parser.add_mutually_exclusive_group()
|
||||
fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.")
|
||||
fpunet_group.add_argument("--fp16-unet", action="store_true", help="Store unet weights in fp16.")
|
||||
fpunet_group.add_argument("--fp32-unet", action="store_true", help="Run the diffusion model in fp32.")
|
||||
fpunet_group.add_argument("--fp64-unet", action="store_true", help="Run the diffusion model in fp64.")
|
||||
fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the diffusion model in bf16.")
|
||||
fpunet_group.add_argument("--fp16-unet", action="store_true", help="Run the diffusion model in fp16")
|
||||
fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.")
|
||||
fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")
|
||||
|
||||
@@ -82,7 +84,8 @@ parser.add_argument("--force-channels-last", action="store_true", help="Force ch
|
||||
|
||||
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
|
||||
|
||||
parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize when loading models with Intel GPUs.")
|
||||
parser.add_argument("--oneapi-device-selector", type=str, default=None, metavar="SELECTOR_STRING", help="Sets the oneAPI device(s) this instance will use.")
|
||||
parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize default when loading models with Intel's Extension for Pytorch.")
|
||||
|
||||
class LatentPreviewMethod(enum.Enum):
|
||||
NoPreviews = "none"
|
||||
@@ -92,10 +95,17 @@ class LatentPreviewMethod(enum.Enum):
|
||||
|
||||
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
|
||||
|
||||
parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
|
||||
|
||||
cache_group = parser.add_mutually_exclusive_group()
|
||||
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
|
||||
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
|
||||
|
||||
attn_group = parser.add_mutually_exclusive_group()
|
||||
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
|
||||
attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
|
||||
attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
|
||||
attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.")
|
||||
|
||||
parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
|
||||
|
||||
@@ -112,10 +122,14 @@ vram_group.add_argument("--lowvram", action="store_true", help="Split the unet i
|
||||
vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
|
||||
vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
|
||||
|
||||
parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reverved depending on your OS.")
|
||||
|
||||
|
||||
parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
|
||||
|
||||
parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
|
||||
parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
|
||||
parser.add_argument("--fast", action="store_true", help="Enable some untested and potentially quality deteriorating optimizations.")
|
||||
|
||||
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
|
||||
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
|
||||
@@ -126,7 +140,7 @@ parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Dis
|
||||
|
||||
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
|
||||
|
||||
parser.add_argument("--verbose", action="store_true", help="Enables more debug prints.")
|
||||
parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
|
||||
|
||||
# The default built-in provider hosted under web/
|
||||
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
|
||||
@@ -161,6 +175,8 @@ parser.add_argument(
|
||||
help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.",
|
||||
)
|
||||
|
||||
parser.add_argument("--user-directory", type=is_valid_directory, default=None, help="Set the ComfyUI user directory with an absolute path.")
|
||||
|
||||
if comfy.options.args_parsing:
|
||||
args = parser.parse_args()
|
||||
else:
|
||||
@@ -171,10 +187,3 @@ if args.windows_standalone_build:
|
||||
|
||||
if args.disable_auto_launch:
|
||||
args.auto_launch = False
|
||||
|
||||
import logging
|
||||
logging_level = logging.INFO
|
||||
if args.verbose:
|
||||
logging_level = logging.DEBUG
|
||||
|
||||
logging.basicConfig(format="%(message)s", level=logging_level)
|
||||
|
||||
@@ -23,6 +23,7 @@ class CLIPAttention(torch.nn.Module):
|
||||
|
||||
ACTIVATIONS = {"quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
|
||||
"gelu": torch.nn.functional.gelu,
|
||||
"gelu_pytorch_tanh": lambda a: torch.nn.functional.gelu(a, approximate="tanh"),
|
||||
}
|
||||
|
||||
class CLIPMLP(torch.nn.Module):
|
||||
@@ -88,10 +89,11 @@ class CLIPTextModel_(torch.nn.Module):
|
||||
heads = config_dict["num_attention_heads"]
|
||||
intermediate_size = config_dict["intermediate_size"]
|
||||
intermediate_activation = config_dict["hidden_act"]
|
||||
num_positions = config_dict["max_position_embeddings"]
|
||||
self.eos_token_id = config_dict["eos_token_id"]
|
||||
|
||||
super().__init__()
|
||||
self.embeddings = CLIPEmbeddings(embed_dim, dtype=dtype, device=device, operations=operations)
|
||||
self.embeddings = CLIPEmbeddings(embed_dim, num_positions=num_positions, dtype=dtype, device=device, operations=operations)
|
||||
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
|
||||
self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
|
||||
|
||||
@@ -123,7 +125,6 @@ class CLIPTextModel(torch.nn.Module):
|
||||
self.text_model = CLIPTextModel_(config_dict, dtype, device, operations)
|
||||
embed_dim = config_dict["hidden_size"]
|
||||
self.text_projection = operations.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
|
||||
self.text_projection.weight.copy_(torch.eye(embed_dim))
|
||||
self.dtype = dtype
|
||||
|
||||
def get_input_embeddings(self):
|
||||
@@ -139,27 +140,35 @@ class CLIPTextModel(torch.nn.Module):
|
||||
|
||||
|
||||
class CLIPVisionEmbeddings(torch.nn.Module):
|
||||
def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, dtype=None, device=None, operations=None):
|
||||
def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, model_type="", dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device))
|
||||
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
if model_type == "siglip_vision_model":
|
||||
self.class_embedding = None
|
||||
patch_bias = True
|
||||
else:
|
||||
num_patches = num_patches + 1
|
||||
self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device))
|
||||
patch_bias = False
|
||||
|
||||
self.patch_embedding = operations.Conv2d(
|
||||
in_channels=num_channels,
|
||||
out_channels=embed_dim,
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
bias=False,
|
||||
bias=patch_bias,
|
||||
dtype=dtype,
|
||||
device=device
|
||||
)
|
||||
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
num_positions = num_patches + 1
|
||||
self.position_embedding = operations.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
|
||||
self.position_embedding = operations.Embedding(num_patches, embed_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, pixel_values):
|
||||
embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)
|
||||
return torch.cat([comfy.ops.cast_to_input(self.class_embedding, embeds).expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + comfy.ops.cast_to_input(self.position_embedding.weight, embeds)
|
||||
if self.class_embedding is not None:
|
||||
embeds = torch.cat([comfy.ops.cast_to_input(self.class_embedding, embeds).expand(pixel_values.shape[0], 1, -1), embeds], dim=1)
|
||||
return embeds + comfy.ops.cast_to_input(self.position_embedding.weight, embeds)
|
||||
|
||||
|
||||
class CLIPVision(torch.nn.Module):
|
||||
@@ -170,9 +179,15 @@ class CLIPVision(torch.nn.Module):
|
||||
heads = config_dict["num_attention_heads"]
|
||||
intermediate_size = config_dict["intermediate_size"]
|
||||
intermediate_activation = config_dict["hidden_act"]
|
||||
model_type = config_dict["model_type"]
|
||||
|
||||
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], dtype=dtype, device=device, operations=operations)
|
||||
self.pre_layrnorm = operations.LayerNorm(embed_dim)
|
||||
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations)
|
||||
if model_type == "siglip_vision_model":
|
||||
self.pre_layrnorm = lambda a: a
|
||||
self.output_layernorm = True
|
||||
else:
|
||||
self.pre_layrnorm = operations.LayerNorm(embed_dim)
|
||||
self.output_layernorm = False
|
||||
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
|
||||
self.post_layernorm = operations.LayerNorm(embed_dim)
|
||||
|
||||
@@ -181,14 +196,21 @@ class CLIPVision(torch.nn.Module):
|
||||
x = self.pre_layrnorm(x)
|
||||
#TODO: attention_mask?
|
||||
x, i = self.encoder(x, mask=None, intermediate_output=intermediate_output)
|
||||
pooled_output = self.post_layernorm(x[:, 0, :])
|
||||
if self.output_layernorm:
|
||||
x = self.post_layernorm(x)
|
||||
pooled_output = x
|
||||
else:
|
||||
pooled_output = self.post_layernorm(x[:, 0, :])
|
||||
return x, i, pooled_output
|
||||
|
||||
class CLIPVisionModelProjection(torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.vision_model = CLIPVision(config_dict, dtype, device, operations)
|
||||
self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False)
|
||||
if "projection_dim" in config_dict:
|
||||
self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False)
|
||||
else:
|
||||
self.visual_projection = lambda a: a
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
x = self.vision_model(*args, **kwargs)
|
||||
|
||||
@@ -16,13 +16,18 @@ class Output:
|
||||
def __setitem__(self, key, item):
|
||||
setattr(self, key, item)
|
||||
|
||||
def clip_preprocess(image, size=224):
|
||||
mean = torch.tensor([ 0.48145466,0.4578275,0.40821073], device=image.device, dtype=image.dtype)
|
||||
std = torch.tensor([0.26862954,0.26130258,0.27577711], device=image.device, dtype=image.dtype)
|
||||
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
|
||||
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
|
||||
std = torch.tensor(std, device=image.device, dtype=image.dtype)
|
||||
image = image.movedim(-1, 1)
|
||||
if not (image.shape[2] == size and image.shape[3] == size):
|
||||
scale = (size / min(image.shape[2], image.shape[3]))
|
||||
image = torch.nn.functional.interpolate(image, size=(round(scale * image.shape[2]), round(scale * image.shape[3])), mode="bicubic", antialias=True)
|
||||
if crop:
|
||||
scale = (size / min(image.shape[2], image.shape[3]))
|
||||
scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3]))
|
||||
else:
|
||||
scale_size = (size, size)
|
||||
|
||||
image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
|
||||
h = (image.shape[2] - size)//2
|
||||
w = (image.shape[3] - size)//2
|
||||
image = image[:,:,h:h+size,w:w+size]
|
||||
@@ -35,6 +40,8 @@ class ClipVisionModel():
|
||||
config = json.load(f)
|
||||
|
||||
self.image_size = config.get("image_size", 224)
|
||||
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
|
||||
self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
|
||||
self.load_device = comfy.model_management.text_encoder_device()
|
||||
offload_device = comfy.model_management.text_encoder_offload_device()
|
||||
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
|
||||
@@ -49,9 +56,9 @@ class ClipVisionModel():
|
||||
def get_sd(self):
|
||||
return self.model.state_dict()
|
||||
|
||||
def encode_image(self, image):
|
||||
def encode_image(self, image, crop=True):
|
||||
comfy.model_management.load_model_gpu(self.patcher)
|
||||
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size).float()
|
||||
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
|
||||
out = self.model(pixel_values=pixel_values, intermediate_output=-2)
|
||||
|
||||
outputs = Output()
|
||||
@@ -94,7 +101,9 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
|
||||
elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
|
||||
elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
|
||||
if sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
|
||||
if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
|
||||
elif sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
|
||||
else:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
|
||||
@@ -109,8 +118,7 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
|
||||
keys = list(sd.keys())
|
||||
for k in keys:
|
||||
if k not in u:
|
||||
t = sd.pop(k)
|
||||
del t
|
||||
sd.pop(k)
|
||||
return clip
|
||||
|
||||
def load(ckpt_path):
|
||||
|
||||
13
comfy/clip_vision_siglip_384.json
Normal file
13
comfy/clip_vision_siglip_384.json
Normal file
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"num_channels": 3,
|
||||
"hidden_act": "gelu_pytorch_tanh",
|
||||
"hidden_size": 1152,
|
||||
"image_size": 384,
|
||||
"intermediate_size": 4304,
|
||||
"model_type": "siglip_vision_model",
|
||||
"num_attention_heads": 16,
|
||||
"num_hidden_layers": 27,
|
||||
"patch_size": 14,
|
||||
"image_mean": [0.5, 0.5, 0.5],
|
||||
"image_std": [0.5, 0.5, 0.5]
|
||||
}
|
||||
43
comfy/comfy_types/README.md
Normal file
43
comfy/comfy_types/README.md
Normal file
@@ -0,0 +1,43 @@
|
||||
# Comfy Typing
|
||||
## Type hinting for ComfyUI Node development
|
||||
|
||||
This module provides type hinting and concrete convenience types for node developers.
|
||||
If cloned to the custom_nodes directory of ComfyUI, types can be imported using:
|
||||
|
||||
```python
|
||||
from comfy_types import IO, ComfyNodeABC, CheckLazyMixin
|
||||
|
||||
class ExampleNode(ComfyNodeABC):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s) -> InputTypeDict:
|
||||
return {"required": {}}
|
||||
```
|
||||
|
||||
Full example is in [examples/example_nodes.py](examples/example_nodes.py).
|
||||
|
||||
# Types
|
||||
A few primary types are documented below. More complete information is available via the docstrings on each type.
|
||||
|
||||
## `IO`
|
||||
|
||||
A string enum of built-in and a few custom data types. Includes the following special types and their requisite plumbing:
|
||||
|
||||
- `ANY`: `"*"`
|
||||
- `NUMBER`: `"FLOAT,INT"`
|
||||
- `PRIMITIVE`: `"STRING,FLOAT,INT,BOOLEAN"`
|
||||
|
||||
## `ComfyNodeABC`
|
||||
|
||||
An abstract base class for nodes, offering type-hinting / autocomplete, and somewhat-alright docstrings.
|
||||
|
||||
### Type hinting for `INPUT_TYPES`
|
||||
|
||||

|
||||
|
||||
### `INPUT_TYPES` return dict
|
||||
|
||||

|
||||
|
||||
### Options for individual inputs
|
||||
|
||||

|
||||
@@ -1,5 +1,6 @@
|
||||
import torch
|
||||
from typing import Callable, Protocol, TypedDict, Optional, List
|
||||
from .node_typing import IO, InputTypeDict, ComfyNodeABC, CheckLazyMixin
|
||||
|
||||
|
||||
class UnetApplyFunction(Protocol):
|
||||
@@ -30,3 +31,15 @@ class UnetParams(TypedDict):
|
||||
|
||||
|
||||
UnetWrapperFunction = Callable[[UnetApplyFunction, UnetParams], torch.Tensor]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"UnetWrapperFunction",
|
||||
UnetApplyConds.__name__,
|
||||
UnetParams.__name__,
|
||||
UnetApplyFunction.__name__,
|
||||
IO.__name__,
|
||||
InputTypeDict.__name__,
|
||||
ComfyNodeABC.__name__,
|
||||
CheckLazyMixin.__name__,
|
||||
]
|
||||
28
comfy/comfy_types/examples/example_nodes.py
Normal file
28
comfy/comfy_types/examples/example_nodes.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from comfy_types import IO, ComfyNodeABC, InputTypeDict
|
||||
from inspect import cleandoc
|
||||
|
||||
|
||||
class ExampleNode(ComfyNodeABC):
|
||||
"""An example node that just adds 1 to an input integer.
|
||||
|
||||
* Requires an IDE configured with analysis paths etc to be worth looking at.
|
||||
* Not intended for use in ComfyUI.
|
||||
"""
|
||||
|
||||
DESCRIPTION = cleandoc(__doc__)
|
||||
CATEGORY = "examples"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"input_int": (IO.INT, {"defaultInput": True}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.INT,)
|
||||
RETURN_NAMES = ("input_plus_one",)
|
||||
FUNCTION = "execute"
|
||||
|
||||
def execute(self, input_int: int):
|
||||
return (input_int + 1,)
|
||||
BIN
comfy/comfy_types/examples/input_options.png
Normal file
BIN
comfy/comfy_types/examples/input_options.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 19 KiB |
BIN
comfy/comfy_types/examples/input_types.png
Normal file
BIN
comfy/comfy_types/examples/input_types.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 16 KiB |
BIN
comfy/comfy_types/examples/required_hint.png
Normal file
BIN
comfy/comfy_types/examples/required_hint.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 19 KiB |
274
comfy/comfy_types/node_typing.py
Normal file
274
comfy/comfy_types/node_typing.py
Normal file
@@ -0,0 +1,274 @@
|
||||
"""Comfy-specific type hinting"""
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Literal, TypedDict
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class StrEnum(str, Enum):
|
||||
"""Base class for string enums. Python's StrEnum is not available until 3.11."""
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
||||
|
||||
class IO(StrEnum):
|
||||
"""Node input/output data types.
|
||||
|
||||
Includes functionality for ``"*"`` (`ANY`) and ``"MULTI,TYPES"``.
|
||||
"""
|
||||
|
||||
STRING = "STRING"
|
||||
IMAGE = "IMAGE"
|
||||
MASK = "MASK"
|
||||
LATENT = "LATENT"
|
||||
BOOLEAN = "BOOLEAN"
|
||||
INT = "INT"
|
||||
FLOAT = "FLOAT"
|
||||
CONDITIONING = "CONDITIONING"
|
||||
SAMPLER = "SAMPLER"
|
||||
SIGMAS = "SIGMAS"
|
||||
GUIDER = "GUIDER"
|
||||
NOISE = "NOISE"
|
||||
CLIP = "CLIP"
|
||||
CONTROL_NET = "CONTROL_NET"
|
||||
VAE = "VAE"
|
||||
MODEL = "MODEL"
|
||||
CLIP_VISION = "CLIP_VISION"
|
||||
CLIP_VISION_OUTPUT = "CLIP_VISION_OUTPUT"
|
||||
STYLE_MODEL = "STYLE_MODEL"
|
||||
GLIGEN = "GLIGEN"
|
||||
UPSCALE_MODEL = "UPSCALE_MODEL"
|
||||
AUDIO = "AUDIO"
|
||||
WEBCAM = "WEBCAM"
|
||||
POINT = "POINT"
|
||||
FACE_ANALYSIS = "FACE_ANALYSIS"
|
||||
BBOX = "BBOX"
|
||||
SEGS = "SEGS"
|
||||
|
||||
ANY = "*"
|
||||
"""Always matches any type, but at a price.
|
||||
|
||||
Causes some functionality issues (e.g. reroutes, link types), and should be avoided whenever possible.
|
||||
"""
|
||||
NUMBER = "FLOAT,INT"
|
||||
"""A float or an int - could be either"""
|
||||
PRIMITIVE = "STRING,FLOAT,INT,BOOLEAN"
|
||||
"""Could be any of: string, float, int, or bool"""
|
||||
|
||||
def __ne__(self, value: object) -> bool:
|
||||
if self == "*" or value == "*":
|
||||
return False
|
||||
if not isinstance(value, str):
|
||||
return True
|
||||
a = frozenset(self.split(","))
|
||||
b = frozenset(value.split(","))
|
||||
return not (b.issubset(a) or a.issubset(b))
|
||||
|
||||
|
||||
class InputTypeOptions(TypedDict):
|
||||
"""Provides type hinting for the return type of the INPUT_TYPES node function.
|
||||
|
||||
Due to IDE limitations with unions, for now all options are available for all types (e.g. `label_on` is hinted even when the type is not `IO.BOOLEAN`).
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_datatypes
|
||||
"""
|
||||
|
||||
default: bool | str | float | int | list | tuple
|
||||
"""The default value of the widget"""
|
||||
defaultInput: bool
|
||||
"""Defaults to an input slot rather than a widget"""
|
||||
forceInput: bool
|
||||
"""`defaultInput` and also don't allow converting to a widget"""
|
||||
lazy: bool
|
||||
"""Declares that this input uses lazy evaluation"""
|
||||
rawLink: bool
|
||||
"""When a link exists, rather than receiving the evaluated value, you will receive the link (i.e. `["nodeId", <outputIndex>]`). Designed for node expansion."""
|
||||
tooltip: str
|
||||
"""Tooltip for the input (or widget), shown on pointer hover"""
|
||||
# class InputTypeNumber(InputTypeOptions):
|
||||
# default: float | int
|
||||
min: float
|
||||
"""The minimum value of a number (``FLOAT`` | ``INT``)"""
|
||||
max: float
|
||||
"""The maximum value of a number (``FLOAT`` | ``INT``)"""
|
||||
step: float
|
||||
"""The amount to increment or decrement a widget by when stepping up/down (``FLOAT`` | ``INT``)"""
|
||||
round: float
|
||||
"""Floats are rounded by this value (``FLOAT``)"""
|
||||
# class InputTypeBoolean(InputTypeOptions):
|
||||
# default: bool
|
||||
label_on: str
|
||||
"""The label to use in the UI when the bool is True (``BOOLEAN``)"""
|
||||
label_on: str
|
||||
"""The label to use in the UI when the bool is False (``BOOLEAN``)"""
|
||||
# class InputTypeString(InputTypeOptions):
|
||||
# default: str
|
||||
multiline: bool
|
||||
"""Use a multiline text box (``STRING``)"""
|
||||
placeholder: str
|
||||
"""Placeholder text to display in the UI when empty (``STRING``)"""
|
||||
# Deprecated:
|
||||
# defaultVal: str
|
||||
dynamicPrompts: bool
|
||||
"""Causes the front-end to evaluate dynamic prompts (``STRING``)"""
|
||||
|
||||
|
||||
class HiddenInputTypeDict(TypedDict):
|
||||
"""Provides type hinting for the hidden entry of node INPUT_TYPES."""
|
||||
|
||||
node_id: Literal["UNIQUE_ID"]
|
||||
"""UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages)."""
|
||||
unique_id: Literal["UNIQUE_ID"]
|
||||
"""UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages)."""
|
||||
prompt: Literal["PROMPT"]
|
||||
"""PROMPT is the complete prompt sent by the client to the server. See the prompt object for a full description."""
|
||||
extra_pnginfo: Literal["EXTRA_PNGINFO"]
|
||||
"""EXTRA_PNGINFO is a dictionary that will be copied into the metadata of any .png files saved. Custom nodes can store additional information in this dictionary for saving (or as a way to communicate with a downstream node)."""
|
||||
dynprompt: Literal["DYNPROMPT"]
|
||||
"""DYNPROMPT is an instance of comfy_execution.graph.DynamicPrompt. It differs from PROMPT in that it may mutate during the course of execution in response to Node Expansion."""
|
||||
|
||||
|
||||
class InputTypeDict(TypedDict):
|
||||
"""Provides type hinting for node INPUT_TYPES.
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_more_on_inputs
|
||||
"""
|
||||
|
||||
required: dict[str, tuple[IO, InputTypeOptions]]
|
||||
"""Describes all inputs that must be connected for the node to execute."""
|
||||
optional: dict[str, tuple[IO, InputTypeOptions]]
|
||||
"""Describes inputs which do not need to be connected."""
|
||||
hidden: HiddenInputTypeDict
|
||||
"""Offers advanced functionality and server-client communication.
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_more_on_inputs#hidden-inputs
|
||||
"""
|
||||
|
||||
|
||||
class ComfyNodeABC(ABC):
|
||||
"""Abstract base class for Comfy nodes. Includes the names and expected types of attributes.
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview
|
||||
"""
|
||||
|
||||
DESCRIPTION: str
|
||||
"""Node description, shown as a tooltip when hovering over the node.
|
||||
|
||||
Usage::
|
||||
|
||||
# Explicitly define the description
|
||||
DESCRIPTION = "Example description here."
|
||||
|
||||
# Use the docstring of the node class.
|
||||
DESCRIPTION = cleandoc(__doc__)
|
||||
"""
|
||||
CATEGORY: str
|
||||
"""The category of the node, as per the "Add Node" menu.
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#category
|
||||
"""
|
||||
EXPERIMENTAL: bool
|
||||
"""Flags a node as experimental, informing users that it may change or not work as expected."""
|
||||
DEPRECATED: bool
|
||||
"""Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def INPUT_TYPES(s) -> InputTypeDict:
|
||||
"""Defines node inputs.
|
||||
|
||||
* Must include the ``required`` key, which describes all inputs that must be connected for the node to execute.
|
||||
* The ``optional`` key can be added to describe inputs which do not need to be connected.
|
||||
* The ``hidden`` key offers some advanced functionality. More info at: https://docs.comfy.org/essentials/custom_node_more_on_inputs#hidden-inputs
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#input-types
|
||||
"""
|
||||
return {"required": {}}
|
||||
|
||||
OUTPUT_NODE: bool
|
||||
"""Flags this node as an output node, causing any inputs it requires to be executed.
|
||||
|
||||
If a node is not connected to any output nodes, that node will not be executed. Usage::
|
||||
|
||||
OUTPUT_NODE = True
|
||||
|
||||
From the docs:
|
||||
|
||||
By default, a node is not considered an output. Set ``OUTPUT_NODE = True`` to specify that it is.
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#output-node
|
||||
"""
|
||||
INPUT_IS_LIST: bool
|
||||
"""A flag indicating if this node implements the additional code necessary to deal with OUTPUT_IS_LIST nodes.
|
||||
|
||||
All inputs of ``type`` will become ``list[type]``, regardless of how many items are passed in. This also affects ``check_lazy_status``.
|
||||
|
||||
From the docs:
|
||||
|
||||
A node can also override the default input behaviour and receive the whole list in a single call. This is done by setting a class attribute `INPUT_IS_LIST` to ``True``.
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_lists#list-processing
|
||||
"""
|
||||
OUTPUT_IS_LIST: tuple[bool]
|
||||
"""A tuple indicating which node outputs are lists, but will be connected to nodes that expect individual items.
|
||||
|
||||
Connected nodes that do not implement `INPUT_IS_LIST` will be executed once for every item in the list.
|
||||
|
||||
A ``tuple[bool]``, where the items match those in `RETURN_TYPES`::
|
||||
|
||||
RETURN_TYPES = (IO.INT, IO.INT, IO.STRING)
|
||||
OUTPUT_IS_LIST = (True, True, False) # The string output will be handled normally
|
||||
|
||||
From the docs:
|
||||
|
||||
In order to tell Comfy that the list being returned should not be wrapped, but treated as a series of data for sequential processing,
|
||||
the node should provide a class attribute `OUTPUT_IS_LIST`, which is a ``tuple[bool]``, of the same length as `RETURN_TYPES`,
|
||||
specifying which outputs which should be so treated.
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_lists#list-processing
|
||||
"""
|
||||
|
||||
RETURN_TYPES: tuple[IO]
|
||||
"""A tuple representing the outputs of this node.
|
||||
|
||||
Usage::
|
||||
|
||||
RETURN_TYPES = (IO.INT, "INT", "CUSTOM_TYPE")
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#return-types
|
||||
"""
|
||||
RETURN_NAMES: tuple[str]
|
||||
"""The output slot names for each item in `RETURN_TYPES`, e.g. ``RETURN_NAMES = ("count", "filter_string")``
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#return-names
|
||||
"""
|
||||
OUTPUT_TOOLTIPS: tuple[str]
|
||||
"""A tuple of strings to use as tooltips for node outputs, one for each item in `RETURN_TYPES`."""
|
||||
FUNCTION: str
|
||||
"""The name of the function to execute as a literal string, e.g. `FUNCTION = "execute"`
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#function
|
||||
"""
|
||||
|
||||
|
||||
class CheckLazyMixin:
|
||||
"""Provides a basic check_lazy_status implementation and type hinting for nodes that use lazy inputs."""
|
||||
|
||||
def check_lazy_status(self, **kwargs) -> list[str]:
|
||||
"""Returns a list of input names that should be evaluated.
|
||||
|
||||
This basic mixin impl. requires all inputs.
|
||||
|
||||
:kwargs: All node inputs will be included here. If the input is ``None``, it should be assumed that it has not yet been evaluated. \
|
||||
When using ``INPUT_IS_LIST = True``, unevaluated will instead be ``(None,)``.
|
||||
|
||||
Params should match the nodes execution ``FUNCTION`` (self, and all inputs by name).
|
||||
Will be executed repeatedly until it returns an empty list, or all requested items were already evaluated (and sent as params).
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_lazy_evaluation#defining-check-lazy-status
|
||||
"""
|
||||
|
||||
need = [name for name in kwargs if kwargs[name] is None]
|
||||
return need
|
||||
@@ -34,6 +34,12 @@ import comfy.t2i_adapter.adapter
|
||||
import comfy.ldm.cascade.controlnet
|
||||
import comfy.cldm.mmdit
|
||||
import comfy.ldm.hydit.controlnet
|
||||
import comfy.ldm.flux.controlnet
|
||||
import comfy.cldm.dit_embedder
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from comfy.hooks import HookGroup
|
||||
|
||||
|
||||
def broadcast_image_to(tensor, target_batch_size, batched_number):
|
||||
current_batch_size = tensor.shape[0]
|
||||
@@ -58,7 +64,7 @@ class StrengthType(Enum):
|
||||
LINEAR_UP = 2
|
||||
|
||||
class ControlBase:
|
||||
def __init__(self, device=None):
|
||||
def __init__(self):
|
||||
self.cond_hint_original = None
|
||||
self.cond_hint = None
|
||||
self.strength = 1.0
|
||||
@@ -70,20 +76,26 @@ class ControlBase:
|
||||
self.compression_ratio = 8
|
||||
self.upscale_algorithm = 'nearest-exact'
|
||||
self.extra_args = {}
|
||||
|
||||
if device is None:
|
||||
device = comfy.model_management.get_torch_device()
|
||||
self.device = device
|
||||
self.previous_controlnet = None
|
||||
self.extra_conds = []
|
||||
self.strength_type = StrengthType.CONSTANT
|
||||
self.concat_mask = False
|
||||
self.extra_concat_orig = []
|
||||
self.extra_concat = None
|
||||
self.extra_hooks: HookGroup = None
|
||||
self.preprocess_image = lambda a: a
|
||||
|
||||
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None):
|
||||
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
|
||||
self.cond_hint_original = cond_hint
|
||||
self.strength = strength
|
||||
self.timestep_percent_range = timestep_percent_range
|
||||
if self.latent_format is not None:
|
||||
if vae is None:
|
||||
logging.warning("WARNING: no VAE provided to the controlnet apply node when this controlnet requires one.")
|
||||
self.vae = vae
|
||||
self.extra_concat_orig = extra_concat.copy()
|
||||
if self.concat_mask and len(self.extra_concat_orig) == 0:
|
||||
self.extra_concat_orig.append(torch.tensor([[[[1.0]]]]))
|
||||
return self
|
||||
|
||||
def pre_run(self, model, percent_to_timestep_function):
|
||||
@@ -98,9 +110,9 @@ class ControlBase:
|
||||
def cleanup(self):
|
||||
if self.previous_controlnet is not None:
|
||||
self.previous_controlnet.cleanup()
|
||||
if self.cond_hint is not None:
|
||||
del self.cond_hint
|
||||
self.cond_hint = None
|
||||
|
||||
self.cond_hint = None
|
||||
self.extra_concat = None
|
||||
self.timestep_range = None
|
||||
|
||||
def get_models(self):
|
||||
@@ -108,6 +120,14 @@ class ControlBase:
|
||||
if self.previous_controlnet is not None:
|
||||
out += self.previous_controlnet.get_models()
|
||||
return out
|
||||
|
||||
def get_extra_hooks(self):
|
||||
out = []
|
||||
if self.extra_hooks is not None:
|
||||
out.append(self.extra_hooks)
|
||||
if self.previous_controlnet is not None:
|
||||
out += self.previous_controlnet.get_extra_hooks()
|
||||
return out
|
||||
|
||||
def copy_to(self, c):
|
||||
c.cond_hint_original = self.cond_hint_original
|
||||
@@ -121,6 +141,10 @@ class ControlBase:
|
||||
c.vae = self.vae
|
||||
c.extra_conds = self.extra_conds.copy()
|
||||
c.strength_type = self.strength_type
|
||||
c.concat_mask = self.concat_mask
|
||||
c.extra_concat_orig = self.extra_concat_orig.copy()
|
||||
c.extra_hooks = self.extra_hooks.clone() if self.extra_hooks else None
|
||||
c.preprocess_image = self.preprocess_image
|
||||
|
||||
def inference_memory_requirements(self, dtype):
|
||||
if self.previous_controlnet is not None:
|
||||
@@ -146,7 +170,7 @@ class ControlBase:
|
||||
elif self.strength_type == StrengthType.LINEAR_UP:
|
||||
x *= (self.strength ** float(len(control_output) - i))
|
||||
|
||||
if x.dtype != output_dtype:
|
||||
if output_dtype is not None and x.dtype != output_dtype:
|
||||
x = x.to(output_dtype)
|
||||
|
||||
out[key].append(x)
|
||||
@@ -173,8 +197,8 @@ class ControlBase:
|
||||
|
||||
|
||||
class ControlNet(ControlBase):
|
||||
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT):
|
||||
super().__init__(device)
|
||||
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False, preprocess_image=lambda a: a):
|
||||
super().__init__()
|
||||
self.control_model = control_model
|
||||
self.load_device = load_device
|
||||
if control_model is not None:
|
||||
@@ -187,11 +211,13 @@ class ControlNet(ControlBase):
|
||||
self.latent_format = latent_format
|
||||
self.extra_conds += extra_conds
|
||||
self.strength_type = strength_type
|
||||
self.concat_mask = concat_mask
|
||||
self.preprocess_image = preprocess_image
|
||||
|
||||
def get_control(self, x_noisy, t, cond, batched_number):
|
||||
def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
|
||||
control_prev = None
|
||||
if self.previous_controlnet is not None:
|
||||
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
|
||||
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number, transformer_options)
|
||||
|
||||
if self.timestep_range is not None:
|
||||
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
|
||||
@@ -204,7 +230,6 @@ class ControlNet(ControlBase):
|
||||
if self.manual_cast_dtype is not None:
|
||||
dtype = self.manual_cast_dtype
|
||||
|
||||
output_dtype = x_noisy.dtype
|
||||
if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
|
||||
if self.cond_hint is not None:
|
||||
del self.cond_hint
|
||||
@@ -212,14 +237,26 @@ class ControlNet(ControlBase):
|
||||
compression_ratio = self.compression_ratio
|
||||
if self.vae is not None:
|
||||
compression_ratio *= self.vae.downscale_ratio
|
||||
else:
|
||||
if self.latent_format is not None:
|
||||
raise ValueError("This Controlnet needs a VAE but none was provided, please use a ControlNetApply node with a VAE input and connect it.")
|
||||
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
|
||||
self.cond_hint = self.preprocess_image(self.cond_hint)
|
||||
if self.vae is not None:
|
||||
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
|
||||
self.cond_hint = self.vae.encode(self.cond_hint.movedim(1, -1))
|
||||
comfy.model_management.load_models_gpu(loaded_models)
|
||||
if self.latent_format is not None:
|
||||
self.cond_hint = self.latent_format.process_in(self.cond_hint)
|
||||
self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype)
|
||||
if len(self.extra_concat_orig) > 0:
|
||||
to_concat = []
|
||||
for c in self.extra_concat_orig:
|
||||
c = c.to(self.cond_hint.device)
|
||||
c = comfy.utils.common_upscale(c, self.cond_hint.shape[3], self.cond_hint.shape[2], self.upscale_algorithm, "center")
|
||||
to_concat.append(comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[0]))
|
||||
self.cond_hint = torch.cat([self.cond_hint] + to_concat, dim=1)
|
||||
|
||||
self.cond_hint = self.cond_hint.to(device=x_noisy.device, dtype=dtype)
|
||||
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
||||
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
||||
|
||||
@@ -234,7 +271,7 @@ class ControlNet(ControlBase):
|
||||
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
|
||||
|
||||
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra)
|
||||
return self.control_merge(control, control_prev, output_dtype)
|
||||
return self.control_merge(control, control_prev, output_dtype=None)
|
||||
|
||||
def copy(self):
|
||||
c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
|
||||
@@ -260,7 +297,6 @@ class ControlLoraOps:
|
||||
class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp):
|
||||
def __init__(self, in_features: int, out_features: int, bias: bool = True,
|
||||
device=None, dtype=None) -> None:
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
super().__init__()
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
@@ -318,8 +354,8 @@ class ControlLoraOps:
|
||||
|
||||
|
||||
class ControlLora(ControlNet):
|
||||
def __init__(self, control_weights, global_average_pooling=False, device=None):
|
||||
ControlBase.__init__(self, device)
|
||||
def __init__(self, control_weights, global_average_pooling=False, model_options={}): #TODO? model_options
|
||||
ControlBase.__init__(self)
|
||||
self.control_weights = control_weights
|
||||
self.global_average_pooling = global_average_pooling
|
||||
self.extra_conds += ["y"]
|
||||
@@ -345,7 +381,6 @@ class ControlLora(ControlNet):
|
||||
self.control_model.to(comfy.model_management.get_torch_device())
|
||||
diffusion_model = model.diffusion_model
|
||||
sd = diffusion_model.state_dict()
|
||||
cm = self.control_model.state_dict()
|
||||
|
||||
for k in sd:
|
||||
weight = sd[k]
|
||||
@@ -375,21 +410,28 @@ class ControlLora(ControlNet):
|
||||
def inference_memory_requirements(self, dtype):
|
||||
return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
|
||||
|
||||
def controlnet_config(sd):
|
||||
def controlnet_config(sd, model_options={}):
|
||||
model_config = comfy.model_detection.model_config_from_unet(sd, "", True)
|
||||
|
||||
supported_inference_dtypes = model_config.supported_inference_dtypes
|
||||
unet_dtype = model_options.get("dtype", None)
|
||||
if unet_dtype is None:
|
||||
weight_dtype = comfy.utils.weight_dtype(sd)
|
||||
|
||||
supported_inference_dtypes = list(model_config.supported_inference_dtypes)
|
||||
if weight_dtype is not None:
|
||||
supported_inference_dtypes.append(weight_dtype)
|
||||
|
||||
unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
|
||||
|
||||
controlnet_config = model_config.unet_config
|
||||
unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
|
||||
load_device = comfy.model_management.get_torch_device()
|
||||
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
||||
if manual_cast_dtype is not None:
|
||||
operations = comfy.ops.manual_cast
|
||||
else:
|
||||
operations = comfy.ops.disable_weight_init
|
||||
|
||||
return model_config, operations, load_device, unet_dtype, manual_cast_dtype
|
||||
operations = model_options.get("custom_operations", None)
|
||||
if operations is None:
|
||||
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
|
||||
|
||||
offload_device = comfy.model_management.unet_offload_device()
|
||||
return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device
|
||||
|
||||
def controlnet_load_state_dict(control_model, sd):
|
||||
missing, unexpected = control_model.load_state_dict(sd, strict=False)
|
||||
@@ -401,25 +443,108 @@ def controlnet_load_state_dict(control_model, sd):
|
||||
logging.debug("unexpected controlnet keys: {}".format(unexpected))
|
||||
return control_model
|
||||
|
||||
def load_controlnet_mmdit(sd):
|
||||
|
||||
def load_controlnet_mmdit(sd, model_options={}):
|
||||
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(new_sd)
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options)
|
||||
num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
|
||||
for k in sd:
|
||||
new_sd[k] = sd[k]
|
||||
|
||||
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config)
|
||||
concat_mask = False
|
||||
control_latent_channels = new_sd.get("pos_embed_input.proj.weight").shape[1]
|
||||
if control_latent_channels == 17: #inpaint controlnet
|
||||
concat_mask = True
|
||||
|
||||
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, control_latent_channels=control_latent_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||
control_model = controlnet_load_state_dict(control_model, new_sd)
|
||||
|
||||
latent_format = comfy.latent_formats.SD3()
|
||||
latent_format.shift_factor = 0 #SD3 controlnet weirdness
|
||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
||||
return control
|
||||
|
||||
def load_controlnet_hunyuandit(controlnet_data):
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(controlnet_data)
|
||||
|
||||
control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=load_device, dtype=unet_dtype)
|
||||
class ControlNetSD35(ControlNet):
|
||||
def pre_run(self, model, percent_to_timestep_function):
|
||||
if self.control_model.double_y_emb:
|
||||
missing, unexpected = self.control_model.orig_y_embedder.load_state_dict(model.diffusion_model.y_embedder.state_dict(), strict=False)
|
||||
else:
|
||||
missing, unexpected = self.control_model.x_embedder.load_state_dict(model.diffusion_model.x_embedder.state_dict(), strict=False)
|
||||
super().pre_run(model, percent_to_timestep_function)
|
||||
|
||||
def copy(self):
|
||||
c = ControlNetSD35(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
|
||||
c.control_model = self.control_model
|
||||
c.control_model_wrapped = self.control_model_wrapped
|
||||
self.copy_to(c)
|
||||
return c
|
||||
|
||||
def load_controlnet_sd35(sd, model_options={}):
|
||||
control_type = -1
|
||||
if "control_type" in sd:
|
||||
control_type = round(sd.pop("control_type").item())
|
||||
|
||||
# blur_cnet = control_type == 0
|
||||
canny_cnet = control_type == 1
|
||||
depth_cnet = control_type == 2
|
||||
|
||||
new_sd = {}
|
||||
for k in comfy.utils.MMDIT_MAP_BASIC:
|
||||
if k[1] in sd:
|
||||
new_sd[k[0]] = sd.pop(k[1])
|
||||
for k in sd:
|
||||
new_sd[k] = sd[k]
|
||||
sd = new_sd
|
||||
|
||||
y_emb_shape = sd["y_embedder.mlp.0.weight"].shape
|
||||
depth = y_emb_shape[0] // 64
|
||||
hidden_size = 64 * depth
|
||||
num_heads = depth
|
||||
head_dim = hidden_size // num_heads
|
||||
num_blocks = comfy.model_detection.count_blocks(new_sd, 'transformer_blocks.{}.')
|
||||
|
||||
load_device = comfy.model_management.get_torch_device()
|
||||
offload_device = comfy.model_management.unet_offload_device()
|
||||
unet_dtype = comfy.model_management.unet_dtype(model_params=-1)
|
||||
|
||||
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
||||
|
||||
operations = model_options.get("custom_operations", None)
|
||||
if operations is None:
|
||||
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
|
||||
|
||||
control_model = comfy.cldm.dit_embedder.ControlNetEmbedder(img_size=None,
|
||||
patch_size=2,
|
||||
in_chans=16,
|
||||
num_layers=num_blocks,
|
||||
main_model_double=depth,
|
||||
double_y_emb=y_emb_shape[0] == y_emb_shape[1],
|
||||
attention_head_dim=head_dim,
|
||||
num_attention_heads=num_heads,
|
||||
adm_in_channels=2048,
|
||||
device=offload_device,
|
||||
dtype=unet_dtype,
|
||||
operations=operations)
|
||||
|
||||
control_model = controlnet_load_state_dict(control_model, sd)
|
||||
|
||||
latent_format = comfy.latent_formats.SD3()
|
||||
preprocess_image = lambda a: a
|
||||
if canny_cnet:
|
||||
preprocess_image = lambda a: (a * 255 * 0.5 + 0.5)
|
||||
elif depth_cnet:
|
||||
preprocess_image = lambda a: 1.0 - a
|
||||
|
||||
control = ControlNetSD35(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, preprocess_image=preprocess_image)
|
||||
return control
|
||||
|
||||
|
||||
|
||||
def load_controlnet_hunyuandit(controlnet_data, model_options={}):
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data, model_options=model_options)
|
||||
|
||||
control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=offload_device, dtype=unet_dtype)
|
||||
control_model = controlnet_load_state_dict(control_model, controlnet_data)
|
||||
|
||||
latent_format = comfy.latent_formats.SDXL()
|
||||
@@ -427,13 +552,49 @@ def load_controlnet_hunyuandit(controlnet_data):
|
||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT)
|
||||
return control
|
||||
|
||||
def load_controlnet(ckpt_path, model=None):
|
||||
controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
|
||||
def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False, model_options={}):
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
|
||||
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||
control_model = controlnet_load_state_dict(control_model, sd)
|
||||
extra_conds = ['y', 'guidance']
|
||||
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||
return control
|
||||
|
||||
def load_controlnet_flux_instantx(sd, model_options={}):
|
||||
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options)
|
||||
for k in sd:
|
||||
new_sd[k] = sd[k]
|
||||
|
||||
num_union_modes = 0
|
||||
union_cnet = "controlnet_mode_embedder.weight"
|
||||
if union_cnet in new_sd:
|
||||
num_union_modes = new_sd[union_cnet].shape[0]
|
||||
|
||||
control_latent_channels = new_sd.get("pos_embed_input.weight").shape[1] // 4
|
||||
concat_mask = False
|
||||
if control_latent_channels == 17:
|
||||
concat_mask = True
|
||||
|
||||
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(latent_input=True, num_union_modes=num_union_modes, control_latent_channels=control_latent_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||
control_model = controlnet_load_state_dict(control_model, new_sd)
|
||||
|
||||
latent_format = comfy.latent_formats.Flux()
|
||||
extra_conds = ['y', 'guidance']
|
||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||
return control
|
||||
|
||||
def convert_mistoline(sd):
|
||||
return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
|
||||
|
||||
|
||||
def load_controlnet_state_dict(state_dict, model=None, model_options={}):
|
||||
controlnet_data = state_dict
|
||||
if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT
|
||||
return load_controlnet_hunyuandit(controlnet_data)
|
||||
return load_controlnet_hunyuandit(controlnet_data, model_options=model_options)
|
||||
|
||||
if "lora_controlnet" in controlnet_data:
|
||||
return ControlLora(controlnet_data)
|
||||
return ControlLora(controlnet_data, model_options=model_options)
|
||||
|
||||
controlnet_config = None
|
||||
supported_inference_dtypes = None
|
||||
@@ -488,8 +649,18 @@ def load_controlnet(ckpt_path, model=None):
|
||||
if len(leftover_keys) > 0:
|
||||
logging.warning("leftover keys: {}".format(leftover_keys))
|
||||
controlnet_data = new_sd
|
||||
elif "controlnet_blocks.0.weight" in controlnet_data: #SD3 diffusers format
|
||||
return load_controlnet_mmdit(controlnet_data)
|
||||
elif "controlnet_blocks.0.weight" in controlnet_data:
|
||||
if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
|
||||
return load_controlnet_flux_xlabs_mistoline(controlnet_data, model_options=model_options)
|
||||
elif "pos_embed_input.proj.weight" in controlnet_data:
|
||||
if "transformer_blocks.0.adaLN_modulation.1.bias" in controlnet_data:
|
||||
return load_controlnet_sd35(controlnet_data, model_options=model_options) #Stability sd3.5 format
|
||||
else:
|
||||
return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet
|
||||
elif "controlnet_x_embedder.weight" in controlnet_data:
|
||||
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
|
||||
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
|
||||
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)
|
||||
|
||||
pth_key = 'control_model.zero_convs.0.0.weight'
|
||||
pth = False
|
||||
@@ -501,26 +672,38 @@ def load_controlnet(ckpt_path, model=None):
|
||||
elif key in controlnet_data:
|
||||
prefix = ""
|
||||
else:
|
||||
net = load_t2i_adapter(controlnet_data)
|
||||
net = load_t2i_adapter(controlnet_data, model_options=model_options)
|
||||
if net is None:
|
||||
logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
|
||||
logging.error("error could not detect control model type.")
|
||||
return net
|
||||
|
||||
if controlnet_config is None:
|
||||
model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True)
|
||||
supported_inference_dtypes = model_config.supported_inference_dtypes
|
||||
supported_inference_dtypes = list(model_config.supported_inference_dtypes)
|
||||
controlnet_config = model_config.unet_config
|
||||
|
||||
unet_dtype = model_options.get("dtype", None)
|
||||
if unet_dtype is None:
|
||||
weight_dtype = comfy.utils.weight_dtype(controlnet_data)
|
||||
|
||||
if supported_inference_dtypes is None:
|
||||
supported_inference_dtypes = [comfy.model_management.unet_dtype()]
|
||||
|
||||
if weight_dtype is not None:
|
||||
supported_inference_dtypes.append(weight_dtype)
|
||||
|
||||
unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
|
||||
|
||||
load_device = comfy.model_management.get_torch_device()
|
||||
if supported_inference_dtypes is None:
|
||||
unet_dtype = comfy.model_management.unet_dtype()
|
||||
else:
|
||||
unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
|
||||
|
||||
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
||||
if manual_cast_dtype is not None:
|
||||
controlnet_config["operations"] = comfy.ops.manual_cast
|
||||
operations = model_options.get("custom_operations", None)
|
||||
if operations is None:
|
||||
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype)
|
||||
|
||||
controlnet_config["operations"] = operations
|
||||
controlnet_config["dtype"] = unet_dtype
|
||||
controlnet_config["device"] = comfy.model_management.unet_offload_device()
|
||||
controlnet_config.pop("out_channels")
|
||||
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
|
||||
control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
|
||||
@@ -554,22 +737,32 @@ def load_controlnet(ckpt_path, model=None):
|
||||
if len(unexpected) > 0:
|
||||
logging.debug("unexpected controlnet keys: {}".format(unexpected))
|
||||
|
||||
global_average_pooling = False
|
||||
filename = os.path.splitext(ckpt_path)[0]
|
||||
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
|
||||
global_average_pooling = True
|
||||
|
||||
global_average_pooling = model_options.get("global_average_pooling", False)
|
||||
control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
||||
return control
|
||||
|
||||
def load_controlnet(ckpt_path, model=None, model_options={}):
|
||||
if "global_average_pooling" not in model_options:
|
||||
filename = os.path.splitext(ckpt_path)[0]
|
||||
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
|
||||
model_options["global_average_pooling"] = True
|
||||
|
||||
cnet = load_controlnet_state_dict(comfy.utils.load_torch_file(ckpt_path, safe_load=True), model=model, model_options=model_options)
|
||||
if cnet is None:
|
||||
logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
|
||||
return cnet
|
||||
|
||||
class T2IAdapter(ControlBase):
|
||||
def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None):
|
||||
super().__init__(device)
|
||||
super().__init__()
|
||||
self.t2i_model = t2i_model
|
||||
self.channels_in = channels_in
|
||||
self.control_input = None
|
||||
self.compression_ratio = compression_ratio
|
||||
self.upscale_algorithm = upscale_algorithm
|
||||
if device is None:
|
||||
device = comfy.model_management.get_torch_device()
|
||||
self.device = device
|
||||
|
||||
def scale_image_to(self, width, height):
|
||||
unshuffle_amount = self.t2i_model.unshuffle_amount
|
||||
@@ -577,10 +770,10 @@ class T2IAdapter(ControlBase):
|
||||
height = math.ceil(height / unshuffle_amount) * unshuffle_amount
|
||||
return width, height
|
||||
|
||||
def get_control(self, x_noisy, t, cond, batched_number):
|
||||
def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
|
||||
control_prev = None
|
||||
if self.previous_controlnet is not None:
|
||||
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
|
||||
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number, transformer_options)
|
||||
|
||||
if self.timestep_range is not None:
|
||||
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
|
||||
@@ -617,7 +810,7 @@ class T2IAdapter(ControlBase):
|
||||
self.copy_to(c)
|
||||
return c
|
||||
|
||||
def load_t2i_adapter(t2i_data):
|
||||
def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
|
||||
compression_ratio = 8
|
||||
upscale_algorithm = 'nearest-exact'
|
||||
|
||||
@@ -628,7 +821,7 @@ def load_t2i_adapter(t2i_data):
|
||||
for i in range(4):
|
||||
for j in range(2):
|
||||
prefix_replace["adapter.body.{}.resnets.{}.".format(i, j)] = "body.{}.".format(i * 2 + j)
|
||||
prefix_replace["adapter.body.{}.".format(i, j)] = "body.{}.".format(i * 2)
|
||||
prefix_replace["adapter.body.{}.".format(i, )] = "body.{}.".format(i * 2)
|
||||
prefix_replace["adapter."] = ""
|
||||
t2i_data = comfy.utils.state_dict_prefix_replace(t2i_data, prefix_replace)
|
||||
keys = t2i_data.keys()
|
||||
|
||||
@@ -157,16 +157,23 @@ vae_conversion_map_attn = [
|
||||
]
|
||||
|
||||
|
||||
def reshape_weight_for_sd(w):
|
||||
def reshape_weight_for_sd(w, conv3d=False):
|
||||
# convert HF linear weights to SD conv2d weights
|
||||
return w.reshape(*w.shape, 1, 1)
|
||||
if conv3d:
|
||||
return w.reshape(*w.shape, 1, 1, 1)
|
||||
else:
|
||||
return w.reshape(*w.shape, 1, 1)
|
||||
|
||||
|
||||
def convert_vae_state_dict(vae_state_dict):
|
||||
mapping = {k: k for k in vae_state_dict.keys()}
|
||||
conv3d = False
|
||||
for k, v in mapping.items():
|
||||
for sd_part, hf_part in vae_conversion_map:
|
||||
v = v.replace(hf_part, sd_part)
|
||||
if v.endswith(".conv.weight"):
|
||||
if not conv3d and vae_state_dict[k].ndim == 5:
|
||||
conv3d = True
|
||||
mapping[k] = v
|
||||
for k, v in mapping.items():
|
||||
if "attentions" in k:
|
||||
@@ -179,7 +186,7 @@ def convert_vae_state_dict(vae_state_dict):
|
||||
for weight_name in weights_to_convert:
|
||||
if f"mid.attn_1.{weight_name}.weight" in k:
|
||||
logging.debug(f"Reshaping {k} for SD format")
|
||||
new_state_dict[k] = reshape_weight_for_sd(v)
|
||||
new_state_dict[k] = reshape_weight_for_sd(v, conv3d=conv3d)
|
||||
return new_state_dict
|
||||
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_dire
|
||||
if text_encoder2_path is not None:
|
||||
text_encoder_paths.append(text_encoder2_path)
|
||||
|
||||
unet = comfy.sd.load_unet(unet_path)
|
||||
unet = comfy.sd.load_diffusion_model(unet_path)
|
||||
|
||||
clip = None
|
||||
if output_clip:
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
#code taken from: https://github.com/wl-zhao/UniPC and modified
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
import logging
|
||||
|
||||
from tqdm.auto import trange, tqdm
|
||||
from tqdm.auto import trange
|
||||
|
||||
|
||||
class NoiseScheduleVP:
|
||||
@@ -16,7 +16,7 @@ class NoiseScheduleVP:
|
||||
continuous_beta_0=0.1,
|
||||
continuous_beta_1=20.,
|
||||
):
|
||||
"""Create a wrapper class for the forward SDE (VP type).
|
||||
r"""Create a wrapper class for the forward SDE (VP type).
|
||||
|
||||
***
|
||||
Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t.
|
||||
@@ -475,7 +475,7 @@ class UniPC:
|
||||
return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
|
||||
|
||||
def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True):
|
||||
print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
|
||||
logging.info(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
|
||||
ns = self.noise_schedule
|
||||
assert order <= len(model_prev_list)
|
||||
|
||||
@@ -519,7 +519,6 @@ class UniPC:
|
||||
A_p = C_inv_p
|
||||
|
||||
if use_corrector:
|
||||
print('using corrector')
|
||||
C_inv = torch.linalg.inv(C)
|
||||
A_c = C_inv
|
||||
|
||||
@@ -704,7 +703,6 @@ class UniPC:
|
||||
):
|
||||
# t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
|
||||
# t_T = self.noise_schedule.T if t_start is None else t_start
|
||||
device = x.device
|
||||
steps = len(timesteps) - 1
|
||||
if method == 'multistep':
|
||||
assert steps >= order
|
||||
|
||||
67
comfy/float.py
Normal file
67
comfy/float.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import torch
|
||||
|
||||
def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None):
|
||||
mantissa_scaled = torch.where(
|
||||
normal_mask,
|
||||
(abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0) * (2**MANTISSA_BITS),
|
||||
(abs_x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS)))
|
||||
)
|
||||
|
||||
mantissa_scaled += torch.rand(mantissa_scaled.size(), dtype=mantissa_scaled.dtype, layout=mantissa_scaled.layout, device=mantissa_scaled.device, generator=generator)
|
||||
return mantissa_scaled.floor() / (2**MANTISSA_BITS)
|
||||
|
||||
#Not 100% sure about this
|
||||
def manual_stochastic_round_to_float8(x, dtype, generator=None):
|
||||
if dtype == torch.float8_e4m3fn:
|
||||
EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 4, 3, 7
|
||||
elif dtype == torch.float8_e5m2:
|
||||
EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 5, 2, 15
|
||||
else:
|
||||
raise ValueError("Unsupported dtype")
|
||||
|
||||
x = x.half()
|
||||
sign = torch.sign(x)
|
||||
abs_x = x.abs()
|
||||
sign = torch.where(abs_x == 0, 0, sign)
|
||||
|
||||
# Combine exponent calculation and clamping
|
||||
exponent = torch.clamp(
|
||||
torch.floor(torch.log2(abs_x)) + EXPONENT_BIAS,
|
||||
0, 2**EXPONENT_BITS - 1
|
||||
)
|
||||
|
||||
# Combine mantissa calculation and rounding
|
||||
normal_mask = ~(exponent == 0)
|
||||
|
||||
abs_x[:] = calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=generator)
|
||||
|
||||
sign *= torch.where(
|
||||
normal_mask,
|
||||
(2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + abs_x),
|
||||
(2.0 ** (-EXPONENT_BIAS + 1)) * abs_x
|
||||
)
|
||||
|
||||
inf = torch.finfo(dtype)
|
||||
torch.clamp(sign, min=inf.min, max=inf.max, out=sign)
|
||||
return sign
|
||||
|
||||
|
||||
|
||||
def stochastic_rounding(value, dtype, seed=0):
|
||||
if dtype == torch.float32:
|
||||
return value.to(dtype=torch.float32)
|
||||
if dtype == torch.float16:
|
||||
return value.to(dtype=torch.float16)
|
||||
if dtype == torch.bfloat16:
|
||||
return value.to(dtype=torch.bfloat16)
|
||||
if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
|
||||
generator = torch.Generator(device=value.device)
|
||||
generator.manual_seed(seed)
|
||||
output = torch.empty_like(value, dtype=dtype)
|
||||
num_slices = max(1, (value.numel() / (4096 * 4096)))
|
||||
slice_size = max(1, round(value.shape[0] / num_slices))
|
||||
for i in range(0, value.shape[0], slice_size):
|
||||
output[i:i+slice_size].copy_(manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator))
|
||||
return output
|
||||
|
||||
return value.to(dtype=dtype)
|
||||
@@ -1,3 +1,4 @@
|
||||
import math
|
||||
import torch
|
||||
from torch import nn
|
||||
from .ldm.modules.attention import CrossAttention
|
||||
|
||||
691
comfy/hooks.py
Normal file
691
comfy/hooks.py
Normal file
@@ -0,0 +1,691 @@
|
||||
from __future__ import annotations
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
import enum
|
||||
import math
|
||||
import torch
|
||||
import numpy as np
|
||||
import itertools
|
||||
import logging
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_patcher import ModelPatcher, PatcherInjection
|
||||
from comfy.model_base import BaseModel
|
||||
from comfy.sd import CLIP
|
||||
import comfy.lora
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
from node_helpers import conditioning_set_values
|
||||
|
||||
class EnumHookMode(enum.Enum):
|
||||
MinVram = "minvram"
|
||||
MaxSpeed = "maxspeed"
|
||||
|
||||
class EnumHookType(enum.Enum):
|
||||
Weight = "weight"
|
||||
Patch = "patch"
|
||||
ObjectPatch = "object_patch"
|
||||
AddModels = "add_models"
|
||||
Callbacks = "callbacks"
|
||||
Wrappers = "wrappers"
|
||||
SetInjections = "add_injections"
|
||||
|
||||
class EnumWeightTarget(enum.Enum):
|
||||
Model = "model"
|
||||
Clip = "clip"
|
||||
|
||||
class _HookRef:
|
||||
pass
|
||||
|
||||
# NOTE: this is an example of how the should_register function should look
|
||||
def default_should_register(hook: 'Hook', model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
|
||||
return True
|
||||
|
||||
|
||||
class Hook:
|
||||
def __init__(self, hook_type: EnumHookType=None, hook_ref: _HookRef=None, hook_id: str=None,
|
||||
hook_keyframe: 'HookKeyframeGroup'=None):
|
||||
self.hook_type = hook_type
|
||||
self.hook_ref = hook_ref if hook_ref else _HookRef()
|
||||
self.hook_id = hook_id
|
||||
self.hook_keyframe = hook_keyframe if hook_keyframe else HookKeyframeGroup()
|
||||
self.custom_should_register = default_should_register
|
||||
self.auto_apply_to_nonpositive = False
|
||||
|
||||
@property
|
||||
def strength(self):
|
||||
return self.hook_keyframe.strength
|
||||
|
||||
def initialize_timesteps(self, model: 'BaseModel'):
|
||||
self.reset()
|
||||
self.hook_keyframe.initialize_timesteps(model)
|
||||
|
||||
def reset(self):
|
||||
self.hook_keyframe.reset()
|
||||
|
||||
def clone(self, subtype: Callable=None):
|
||||
if subtype is None:
|
||||
subtype = type(self)
|
||||
c: Hook = subtype()
|
||||
c.hook_type = self.hook_type
|
||||
c.hook_ref = self.hook_ref
|
||||
c.hook_id = self.hook_id
|
||||
c.hook_keyframe = self.hook_keyframe
|
||||
c.custom_should_register = self.custom_should_register
|
||||
# TODO: make this do something
|
||||
c.auto_apply_to_nonpositive = self.auto_apply_to_nonpositive
|
||||
return c
|
||||
|
||||
def should_register(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
|
||||
return self.custom_should_register(self, model, model_options, target, registered)
|
||||
|
||||
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
|
||||
raise NotImplementedError("add_hook_patches should be defined for Hook subclasses")
|
||||
|
||||
def on_apply(self, model: 'ModelPatcher', transformer_options: dict[str]):
|
||||
pass
|
||||
|
||||
def on_unapply(self, model: 'ModelPatcher', transformer_options: dict[str]):
|
||||
pass
|
||||
|
||||
def __eq__(self, other: 'Hook'):
|
||||
return self.__class__ == other.__class__ and self.hook_ref == other.hook_ref
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.hook_ref)
|
||||
|
||||
class WeightHook(Hook):
|
||||
def __init__(self, strength_model=1.0, strength_clip=1.0):
|
||||
super().__init__(hook_type=EnumHookType.Weight)
|
||||
self.weights: dict = None
|
||||
self.weights_clip: dict = None
|
||||
self.need_weight_init = True
|
||||
self._strength_model = strength_model
|
||||
self._strength_clip = strength_clip
|
||||
|
||||
@property
|
||||
def strength_model(self):
|
||||
return self._strength_model * self.strength
|
||||
|
||||
@property
|
||||
def strength_clip(self):
|
||||
return self._strength_clip * self.strength
|
||||
|
||||
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
|
||||
if not self.should_register(model, model_options, target, registered):
|
||||
return False
|
||||
weights = None
|
||||
if target == EnumWeightTarget.Model:
|
||||
strength = self._strength_model
|
||||
else:
|
||||
strength = self._strength_clip
|
||||
|
||||
if self.need_weight_init:
|
||||
key_map = {}
|
||||
if target == EnumWeightTarget.Model:
|
||||
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
|
||||
else:
|
||||
key_map = comfy.lora.model_lora_keys_clip(model.model, key_map)
|
||||
weights = comfy.lora.load_lora(self.weights, key_map, log_missing=False)
|
||||
else:
|
||||
if target == EnumWeightTarget.Model:
|
||||
weights = self.weights
|
||||
else:
|
||||
weights = self.weights_clip
|
||||
model.add_hook_patches(hook=self, patches=weights, strength_patch=strength)
|
||||
registered.append(self)
|
||||
return True
|
||||
# TODO: add logs about any keys that were not applied
|
||||
|
||||
def clone(self, subtype: Callable=None):
|
||||
if subtype is None:
|
||||
subtype = type(self)
|
||||
c: WeightHook = super().clone(subtype)
|
||||
c.weights = self.weights
|
||||
c.weights_clip = self.weights_clip
|
||||
c.need_weight_init = self.need_weight_init
|
||||
c._strength_model = self._strength_model
|
||||
c._strength_clip = self._strength_clip
|
||||
return c
|
||||
|
||||
class PatchHook(Hook):
|
||||
def __init__(self):
|
||||
super().__init__(hook_type=EnumHookType.Patch)
|
||||
self.patches: dict = None
|
||||
|
||||
def clone(self, subtype: Callable=None):
|
||||
if subtype is None:
|
||||
subtype = type(self)
|
||||
c: PatchHook = super().clone(subtype)
|
||||
c.patches = self.patches
|
||||
return c
|
||||
# TODO: add functionality
|
||||
|
||||
class ObjectPatchHook(Hook):
|
||||
def __init__(self):
|
||||
super().__init__(hook_type=EnumHookType.ObjectPatch)
|
||||
self.object_patches: dict = None
|
||||
|
||||
def clone(self, subtype: Callable=None):
|
||||
if subtype is None:
|
||||
subtype = type(self)
|
||||
c: ObjectPatchHook = super().clone(subtype)
|
||||
c.object_patches = self.object_patches
|
||||
return c
|
||||
# TODO: add functionality
|
||||
|
||||
class AddModelsHook(Hook):
|
||||
def __init__(self, key: str=None, models: list['ModelPatcher']=None):
|
||||
super().__init__(hook_type=EnumHookType.AddModels)
|
||||
self.key = key
|
||||
self.models = models
|
||||
self.append_when_same = True
|
||||
|
||||
def clone(self, subtype: Callable=None):
|
||||
if subtype is None:
|
||||
subtype = type(self)
|
||||
c: AddModelsHook = super().clone(subtype)
|
||||
c.key = self.key
|
||||
c.models = self.models.copy() if self.models else self.models
|
||||
c.append_when_same = self.append_when_same
|
||||
return c
|
||||
# TODO: add functionality
|
||||
|
||||
class CallbackHook(Hook):
|
||||
def __init__(self, key: str=None, callback: Callable=None):
|
||||
super().__init__(hook_type=EnumHookType.Callbacks)
|
||||
self.key = key
|
||||
self.callback = callback
|
||||
|
||||
def clone(self, subtype: Callable=None):
|
||||
if subtype is None:
|
||||
subtype = type(self)
|
||||
c: CallbackHook = super().clone(subtype)
|
||||
c.key = self.key
|
||||
c.callback = self.callback
|
||||
return c
|
||||
# TODO: add functionality
|
||||
|
||||
class WrapperHook(Hook):
|
||||
def __init__(self, wrappers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None):
|
||||
super().__init__(hook_type=EnumHookType.Wrappers)
|
||||
self.wrappers_dict = wrappers_dict
|
||||
|
||||
def clone(self, subtype: Callable=None):
|
||||
if subtype is None:
|
||||
subtype = type(self)
|
||||
c: WrapperHook = super().clone(subtype)
|
||||
c.wrappers_dict = self.wrappers_dict
|
||||
return c
|
||||
|
||||
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
|
||||
if not self.should_register(model, model_options, target, registered):
|
||||
return False
|
||||
add_model_options = {"transformer_options": self.wrappers_dict}
|
||||
comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
|
||||
registered.append(self)
|
||||
return True
|
||||
|
||||
class SetInjectionsHook(Hook):
|
||||
def __init__(self, key: str=None, injections: list['PatcherInjection']=None):
|
||||
super().__init__(hook_type=EnumHookType.SetInjections)
|
||||
self.key = key
|
||||
self.injections = injections
|
||||
|
||||
def clone(self, subtype: Callable=None):
|
||||
if subtype is None:
|
||||
subtype = type(self)
|
||||
c: SetInjectionsHook = super().clone(subtype)
|
||||
c.key = self.key
|
||||
c.injections = self.injections.copy() if self.injections else self.injections
|
||||
return c
|
||||
|
||||
def add_hook_injections(self, model: 'ModelPatcher'):
|
||||
# TODO: add functionality
|
||||
pass
|
||||
|
||||
class HookGroup:
|
||||
def __init__(self):
|
||||
self.hooks: list[Hook] = []
|
||||
|
||||
def add(self, hook: Hook):
|
||||
if hook not in self.hooks:
|
||||
self.hooks.append(hook)
|
||||
|
||||
def contains(self, hook: Hook):
|
||||
return hook in self.hooks
|
||||
|
||||
def clone(self):
|
||||
c = HookGroup()
|
||||
for hook in self.hooks:
|
||||
c.add(hook.clone())
|
||||
return c
|
||||
|
||||
def clone_and_combine(self, other: 'HookGroup'):
|
||||
c = self.clone()
|
||||
if other is not None:
|
||||
for hook in other.hooks:
|
||||
c.add(hook.clone())
|
||||
return c
|
||||
|
||||
def set_keyframes_on_hooks(self, hook_kf: 'HookKeyframeGroup'):
|
||||
if hook_kf is None:
|
||||
hook_kf = HookKeyframeGroup()
|
||||
else:
|
||||
hook_kf = hook_kf.clone()
|
||||
for hook in self.hooks:
|
||||
hook.hook_keyframe = hook_kf
|
||||
|
||||
def get_dict_repr(self):
|
||||
d: dict[EnumHookType, dict[Hook, None]] = {}
|
||||
for hook in self.hooks:
|
||||
with_type = d.setdefault(hook.hook_type, {})
|
||||
with_type[hook] = None
|
||||
return d
|
||||
|
||||
def get_hooks_for_clip_schedule(self):
|
||||
scheduled_hooks: dict[WeightHook, list[tuple[tuple[float,float], HookKeyframe]]] = {}
|
||||
for hook in self.hooks:
|
||||
# only care about WeightHooks, for now
|
||||
if hook.hook_type == EnumHookType.Weight:
|
||||
hook_schedule = []
|
||||
# if no hook keyframes, assign default value
|
||||
if len(hook.hook_keyframe.keyframes) == 0:
|
||||
hook_schedule.append(((0.0, 1.0), None))
|
||||
scheduled_hooks[hook] = hook_schedule
|
||||
continue
|
||||
# find ranges of values
|
||||
prev_keyframe = hook.hook_keyframe.keyframes[0]
|
||||
for keyframe in hook.hook_keyframe.keyframes:
|
||||
if keyframe.start_percent > prev_keyframe.start_percent and not math.isclose(keyframe.strength, prev_keyframe.strength):
|
||||
hook_schedule.append(((prev_keyframe.start_percent, keyframe.start_percent), prev_keyframe))
|
||||
prev_keyframe = keyframe
|
||||
elif keyframe.start_percent == prev_keyframe.start_percent:
|
||||
prev_keyframe = keyframe
|
||||
# create final range, assuming last start_percent was not 1.0
|
||||
if not math.isclose(prev_keyframe.start_percent, 1.0):
|
||||
hook_schedule.append(((prev_keyframe.start_percent, 1.0), prev_keyframe))
|
||||
scheduled_hooks[hook] = hook_schedule
|
||||
# hooks should not have their schedules in a list of tuples
|
||||
all_ranges: list[tuple[float, float]] = []
|
||||
for range_kfs in scheduled_hooks.values():
|
||||
for t_range, keyframe in range_kfs:
|
||||
all_ranges.append(t_range)
|
||||
# turn list of ranges into boundaries
|
||||
boundaries_set = set(itertools.chain.from_iterable(all_ranges))
|
||||
boundaries_set.add(0.0)
|
||||
boundaries = sorted(boundaries_set)
|
||||
real_ranges = [(boundaries[i], boundaries[i + 1]) for i in range(len(boundaries) - 1)]
|
||||
# with real ranges defined, give appropriate hooks w/ keyframes for each range
|
||||
scheduled_keyframes: list[tuple[tuple[float,float], list[tuple[WeightHook, HookKeyframe]]]] = []
|
||||
for t_range in real_ranges:
|
||||
hooks_schedule = []
|
||||
for hook, val in scheduled_hooks.items():
|
||||
keyframe = None
|
||||
# check if is a keyframe that works for the current t_range
|
||||
for stored_range, stored_kf in val:
|
||||
# if stored start is less than current end, then fits - give it assigned keyframe
|
||||
if stored_range[0] < t_range[1] and stored_range[1] > t_range[0]:
|
||||
keyframe = stored_kf
|
||||
break
|
||||
hooks_schedule.append((hook, keyframe))
|
||||
scheduled_keyframes.append((t_range, hooks_schedule))
|
||||
return scheduled_keyframes
|
||||
|
||||
def reset(self):
|
||||
for hook in self.hooks:
|
||||
hook.reset()
|
||||
|
||||
@staticmethod
|
||||
def combine_all_hooks(hooks_list: list['HookGroup'], require_count=0) -> 'HookGroup':
|
||||
actual: list[HookGroup] = []
|
||||
for group in hooks_list:
|
||||
if group is not None:
|
||||
actual.append(group)
|
||||
if len(actual) < require_count:
|
||||
raise Exception(f"Need at least {require_count} hooks to combine, but only had {len(actual)}.")
|
||||
# if no hooks, then return None
|
||||
if len(actual) == 0:
|
||||
return None
|
||||
# if only 1 hook, just return itself without cloning
|
||||
elif len(actual) == 1:
|
||||
return actual[0]
|
||||
final_hook: HookGroup = None
|
||||
for hook in actual:
|
||||
if final_hook is None:
|
||||
final_hook = hook.clone()
|
||||
else:
|
||||
final_hook = final_hook.clone_and_combine(hook)
|
||||
return final_hook
|
||||
|
||||
|
||||
class HookKeyframe:
|
||||
def __init__(self, strength: float, start_percent=0.0, guarantee_steps=1):
|
||||
self.strength = strength
|
||||
# scheduling
|
||||
self.start_percent = float(start_percent)
|
||||
self.start_t = 999999999.9
|
||||
self.guarantee_steps = guarantee_steps
|
||||
|
||||
def clone(self):
|
||||
c = HookKeyframe(strength=self.strength,
|
||||
start_percent=self.start_percent, guarantee_steps=self.guarantee_steps)
|
||||
c.start_t = self.start_t
|
||||
return c
|
||||
|
||||
class HookKeyframeGroup:
|
||||
def __init__(self):
|
||||
self.keyframes: list[HookKeyframe] = []
|
||||
self._current_keyframe: HookKeyframe = None
|
||||
self._current_used_steps = 0
|
||||
self._current_index = 0
|
||||
self._current_strength = None
|
||||
self._curr_t = -1.
|
||||
|
||||
# properties shadow those of HookWeightsKeyframe
|
||||
@property
|
||||
def strength(self):
|
||||
if self._current_keyframe is not None:
|
||||
return self._current_keyframe.strength
|
||||
return 1.0
|
||||
|
||||
def reset(self):
|
||||
self._current_keyframe = None
|
||||
self._current_used_steps = 0
|
||||
self._current_index = 0
|
||||
self._current_strength = None
|
||||
self.curr_t = -1.
|
||||
self._set_first_as_current()
|
||||
|
||||
def add(self, keyframe: HookKeyframe):
|
||||
# add to end of list, then sort
|
||||
self.keyframes.append(keyframe)
|
||||
self.keyframes = get_sorted_list_via_attr(self.keyframes, "start_percent")
|
||||
self._set_first_as_current()
|
||||
|
||||
def _set_first_as_current(self):
|
||||
if len(self.keyframes) > 0:
|
||||
self._current_keyframe = self.keyframes[0]
|
||||
else:
|
||||
self._current_keyframe = None
|
||||
|
||||
def has_index(self, index: int):
|
||||
return index >= 0 and index < len(self.keyframes)
|
||||
|
||||
def is_empty(self):
|
||||
return len(self.keyframes) == 0
|
||||
|
||||
def clone(self):
|
||||
c = HookKeyframeGroup()
|
||||
for keyframe in self.keyframes:
|
||||
c.keyframes.append(keyframe.clone())
|
||||
c._set_first_as_current()
|
||||
return c
|
||||
|
||||
def initialize_timesteps(self, model: 'BaseModel'):
|
||||
for keyframe in self.keyframes:
|
||||
keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent)
|
||||
|
||||
def prepare_current_keyframe(self, curr_t: float) -> bool:
|
||||
if self.is_empty():
|
||||
return False
|
||||
if curr_t == self._curr_t:
|
||||
return False
|
||||
prev_index = self._current_index
|
||||
prev_strength = self._current_strength
|
||||
# if met guaranteed steps, look for next keyframe in case need to switch
|
||||
if self._current_used_steps >= self._current_keyframe.guarantee_steps:
|
||||
# if has next index, loop through and see if need to switch
|
||||
if self.has_index(self._current_index+1):
|
||||
for i in range(self._current_index+1, len(self.keyframes)):
|
||||
eval_c = self.keyframes[i]
|
||||
# check if start_t is greater or equal to curr_t
|
||||
# NOTE: t is in terms of sigmas, not percent, so bigger number = earlier step in sampling
|
||||
if eval_c.start_t >= curr_t:
|
||||
self._current_index = i
|
||||
self._current_strength = eval_c.strength
|
||||
self._current_keyframe = eval_c
|
||||
self._current_used_steps = 0
|
||||
# if guarantee_steps greater than zero, stop searching for other keyframes
|
||||
if self._current_keyframe.guarantee_steps > 0:
|
||||
break
|
||||
# if eval_c is outside the percent range, stop looking further
|
||||
else: break
|
||||
# update steps current context is used
|
||||
self._current_used_steps += 1
|
||||
# update current timestep this was performed on
|
||||
self._curr_t = curr_t
|
||||
# return True if keyframe changed, False if no change
|
||||
return prev_index != self._current_index and prev_strength != self._current_strength
|
||||
|
||||
|
||||
class InterpolationMethod:
|
||||
LINEAR = "linear"
|
||||
EASE_IN = "ease_in"
|
||||
EASE_OUT = "ease_out"
|
||||
EASE_IN_OUT = "ease_in_out"
|
||||
|
||||
_LIST = [LINEAR, EASE_IN, EASE_OUT, EASE_IN_OUT]
|
||||
|
||||
@classmethod
|
||||
def get_weights(cls, num_from: float, num_to: float, length: int, method: str, reverse=False):
|
||||
diff = num_to - num_from
|
||||
if method == cls.LINEAR:
|
||||
weights = torch.linspace(num_from, num_to, length)
|
||||
elif method == cls.EASE_IN:
|
||||
index = torch.linspace(0, 1, length)
|
||||
weights = diff * np.power(index, 2) + num_from
|
||||
elif method == cls.EASE_OUT:
|
||||
index = torch.linspace(0, 1, length)
|
||||
weights = diff * (1 - np.power(1 - index, 2)) + num_from
|
||||
elif method == cls.EASE_IN_OUT:
|
||||
index = torch.linspace(0, 1, length)
|
||||
weights = diff * ((1 - np.cos(index * np.pi)) / 2) + num_from
|
||||
else:
|
||||
raise ValueError(f"Unrecognized interpolation method '{method}'.")
|
||||
if reverse:
|
||||
weights = weights.flip(dims=(0,))
|
||||
return weights
|
||||
|
||||
def get_sorted_list_via_attr(objects: list, attr: str) -> list:
|
||||
if not objects:
|
||||
return objects
|
||||
elif len(objects) <= 1:
|
||||
return [x for x in objects]
|
||||
# now that we know we have to sort, do it following these rules:
|
||||
# a) if objects have same value of attribute, maintain their relative order
|
||||
# b) perform sorting of the groups of objects with same attributes
|
||||
unique_attrs = {}
|
||||
for o in objects:
|
||||
val_attr = getattr(o, attr)
|
||||
attr_list: list = unique_attrs.get(val_attr, list())
|
||||
attr_list.append(o)
|
||||
if val_attr not in unique_attrs:
|
||||
unique_attrs[val_attr] = attr_list
|
||||
# now that we have the unique attr values grouped together in relative order, sort them by key
|
||||
sorted_attrs = dict(sorted(unique_attrs.items()))
|
||||
# now flatten out the dict into a list to return
|
||||
sorted_list = []
|
||||
for object_list in sorted_attrs.values():
|
||||
sorted_list.extend(object_list)
|
||||
return sorted_list
|
||||
|
||||
def create_hook_lora(lora: dict[str, torch.Tensor], strength_model: float, strength_clip: float):
|
||||
hook_group = HookGroup()
|
||||
hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip)
|
||||
hook_group.add(hook)
|
||||
hook.weights = lora
|
||||
return hook_group
|
||||
|
||||
def create_hook_model_as_lora(weights_model, weights_clip, strength_model: float, strength_clip: float):
|
||||
hook_group = HookGroup()
|
||||
hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip)
|
||||
hook_group.add(hook)
|
||||
patches_model = None
|
||||
patches_clip = None
|
||||
if weights_model is not None:
|
||||
patches_model = {}
|
||||
for key in weights_model:
|
||||
patches_model[key] = ("model_as_lora", (weights_model[key],))
|
||||
if weights_clip is not None:
|
||||
patches_clip = {}
|
||||
for key in weights_clip:
|
||||
patches_clip[key] = ("model_as_lora", (weights_clip[key],))
|
||||
hook.weights = patches_model
|
||||
hook.weights_clip = patches_clip
|
||||
hook.need_weight_init = False
|
||||
return hook_group
|
||||
|
||||
def get_patch_weights_from_model(model: 'ModelPatcher', discard_model_sampling=True):
|
||||
if model is None:
|
||||
return None
|
||||
patches_model: dict[str, torch.Tensor] = model.model.state_dict()
|
||||
if discard_model_sampling:
|
||||
# do not include ANY model_sampling components of the model that should act as a patch
|
||||
for key in list(patches_model.keys()):
|
||||
if key.startswith("model_sampling"):
|
||||
patches_model.pop(key, None)
|
||||
return patches_model
|
||||
|
||||
# NOTE: this function shows how to register weight hooks directly on the ModelPatchers
|
||||
def load_hook_lora_for_models(model: 'ModelPatcher', clip: 'CLIP', lora: dict[str, torch.Tensor],
|
||||
strength_model: float, strength_clip: float):
|
||||
key_map = {}
|
||||
if model is not None:
|
||||
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
|
||||
if clip is not None:
|
||||
key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
|
||||
|
||||
hook_group = HookGroup()
|
||||
hook = WeightHook()
|
||||
hook_group.add(hook)
|
||||
loaded: dict[str] = comfy.lora.load_lora(lora, key_map)
|
||||
if model is not None:
|
||||
new_modelpatcher = model.clone()
|
||||
k = new_modelpatcher.add_hook_patches(hook=hook, patches=loaded, strength_patch=strength_model)
|
||||
else:
|
||||
k = ()
|
||||
new_modelpatcher = None
|
||||
|
||||
if clip is not None:
|
||||
new_clip = clip.clone()
|
||||
k1 = new_clip.patcher.add_hook_patches(hook=hook, patches=loaded, strength_patch=strength_clip)
|
||||
else:
|
||||
k1 = ()
|
||||
new_clip = None
|
||||
k = set(k)
|
||||
k1 = set(k1)
|
||||
for x in loaded:
|
||||
if (x not in k) and (x not in k1):
|
||||
logging.warning(f"NOT LOADED {x}")
|
||||
return (new_modelpatcher, new_clip, hook_group)
|
||||
|
||||
def _combine_hooks_from_values(c_dict: dict[str, HookGroup], values: dict[str, HookGroup], cache: dict[tuple[HookGroup, HookGroup], HookGroup]):
|
||||
hooks_key = 'hooks'
|
||||
# if hooks only exist in one dict, do what's needed so that it ends up in c_dict
|
||||
if hooks_key not in values:
|
||||
return
|
||||
if hooks_key not in c_dict:
|
||||
hooks_value = values.get(hooks_key, None)
|
||||
if hooks_value is not None:
|
||||
c_dict[hooks_key] = hooks_value
|
||||
return
|
||||
# otherwise, need to combine with minimum duplication via cache
|
||||
hooks_tuple = (c_dict[hooks_key], values[hooks_key])
|
||||
cached_hooks = cache.get(hooks_tuple, None)
|
||||
if cached_hooks is None:
|
||||
new_hooks = hooks_tuple[0].clone_and_combine(hooks_tuple[1])
|
||||
cache[hooks_tuple] = new_hooks
|
||||
c_dict[hooks_key] = new_hooks
|
||||
else:
|
||||
c_dict[hooks_key] = cache[hooks_tuple]
|
||||
|
||||
def conditioning_set_values_with_hooks(conditioning, values={}, append_hooks=True):
|
||||
c = []
|
||||
hooks_combine_cache: dict[tuple[HookGroup, HookGroup], HookGroup] = {}
|
||||
for t in conditioning:
|
||||
n = [t[0], t[1].copy()]
|
||||
for k in values:
|
||||
if append_hooks and k == 'hooks':
|
||||
_combine_hooks_from_values(n[1], values, hooks_combine_cache)
|
||||
else:
|
||||
n[1][k] = values[k]
|
||||
c.append(n)
|
||||
|
||||
return c
|
||||
|
||||
def set_hooks_for_conditioning(cond, hooks: HookGroup, append_hooks=True):
|
||||
if hooks is None:
|
||||
return cond
|
||||
return conditioning_set_values_with_hooks(cond, {'hooks': hooks}, append_hooks=append_hooks)
|
||||
|
||||
def set_timesteps_for_conditioning(cond, timestep_range: tuple[float,float]):
|
||||
if timestep_range is None:
|
||||
return cond
|
||||
return conditioning_set_values(cond, {"start_percent": timestep_range[0],
|
||||
"end_percent": timestep_range[1]})
|
||||
|
||||
def set_mask_for_conditioning(cond, mask: torch.Tensor, set_cond_area: str, strength: float):
|
||||
if mask is None:
|
||||
return cond
|
||||
set_area_to_bounds = False
|
||||
if set_cond_area != 'default':
|
||||
set_area_to_bounds = True
|
||||
if len(mask.shape) < 3:
|
||||
mask = mask.unsqueeze(0)
|
||||
return conditioning_set_values(cond, {'mask': mask,
|
||||
'set_area_to_bounds': set_area_to_bounds,
|
||||
'mask_strength': strength})
|
||||
|
||||
def combine_conditioning(conds: list):
|
||||
combined_conds = []
|
||||
for cond in conds:
|
||||
combined_conds.extend(cond)
|
||||
return combined_conds
|
||||
|
||||
def combine_with_new_conds(conds: list, new_conds: list):
|
||||
combined_conds = []
|
||||
for c, new_c in zip(conds, new_conds):
|
||||
combined_conds.append(combine_conditioning([c, new_c]))
|
||||
return combined_conds
|
||||
|
||||
def set_conds_props(conds: list, strength: float, set_cond_area: str,
|
||||
mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
|
||||
final_conds = []
|
||||
for c in conds:
|
||||
# first, apply lora_hook to conditioning, if provided
|
||||
c = set_hooks_for_conditioning(c, hooks, append_hooks=append_hooks)
|
||||
# next, apply mask to conditioning
|
||||
c = set_mask_for_conditioning(cond=c, mask=mask, strength=strength, set_cond_area=set_cond_area)
|
||||
# apply timesteps, if present
|
||||
c = set_timesteps_for_conditioning(cond=c, timestep_range=timesteps_range)
|
||||
# finally, apply mask to conditioning and store
|
||||
final_conds.append(c)
|
||||
return final_conds
|
||||
|
||||
def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1.0, set_cond_area: str="default",
|
||||
mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
|
||||
combined_conds = []
|
||||
for c, masked_c in zip(conds, new_conds):
|
||||
# first, apply lora_hook to new conditioning, if provided
|
||||
masked_c = set_hooks_for_conditioning(masked_c, hooks, append_hooks=append_hooks)
|
||||
# next, apply mask to new conditioning, if provided
|
||||
masked_c = set_mask_for_conditioning(cond=masked_c, mask=mask, set_cond_area=set_cond_area, strength=strength)
|
||||
# apply timesteps, if present
|
||||
masked_c = set_timesteps_for_conditioning(cond=masked_c, timestep_range=timesteps_range)
|
||||
# finally, combine with existing conditioning and store
|
||||
combined_conds.append(combine_conditioning([c, masked_c]))
|
||||
return combined_conds
|
||||
|
||||
def set_default_conds_and_combine(conds: list, new_conds: list,
|
||||
hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
|
||||
combined_conds = []
|
||||
for c, new_c in zip(conds, new_conds):
|
||||
# first, apply lora_hook to new conditioning, if provided
|
||||
new_c = set_hooks_for_conditioning(new_c, hooks, append_hooks=append_hooks)
|
||||
# next, add default_cond key to cond so that during sampling, it can be identified
|
||||
new_c = conditioning_set_values(new_c, {'default': True})
|
||||
# apply timesteps, if present
|
||||
new_c = set_timesteps_for_conditioning(cond=new_c, timestep_range=timesteps_range)
|
||||
# finally, combine with existing conditioning and store
|
||||
combined_conds.append(combine_conditioning([c, new_c]))
|
||||
return combined_conds
|
||||
@@ -11,7 +11,6 @@ import numpy as np
|
||||
# Transfer from the input time (sigma) used in EDM to that (t) used in DEIS.
|
||||
|
||||
def edm2t(edm_steps, epsilon_s=1e-3, sigma_min=0.002, sigma_max=80):
|
||||
vp_sigma = lambda beta_d, beta_min: lambda t: (np.e ** (0.5 * beta_d * (t ** 2) + beta_min * t) - 1) ** 0.5
|
||||
vp_sigma_inv = lambda beta_d, beta_min: lambda sigma: ((beta_min ** 2 + 2 * beta_d * (sigma ** 2 + 1).log()).sqrt() - beta_min) / beta_d
|
||||
vp_beta_d = 2 * (np.log(torch.tensor(sigma_min).cpu() ** 2 + 1) / epsilon_s - np.log(torch.tensor(sigma_max).cpu() ** 2 + 1)) / (epsilon_s - 1)
|
||||
vp_beta_min = np.log(torch.tensor(sigma_max).cpu() ** 2 + 1) - 0.5 * vp_beta_d
|
||||
|
||||
@@ -9,6 +9,7 @@ from tqdm.auto import trange, tqdm
|
||||
from . import utils
|
||||
from . import deis
|
||||
import comfy.model_patcher
|
||||
import comfy.model_sampling
|
||||
|
||||
def append_zero(x):
|
||||
return torch.cat([x, x.new_zeros([1])])
|
||||
@@ -43,6 +44,17 @@ def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
|
||||
return append_zero(sigmas)
|
||||
|
||||
|
||||
def get_sigmas_laplace(n, sigma_min, sigma_max, mu=0., beta=0.5, device='cpu'):
|
||||
"""Constructs the noise schedule proposed by Tiankai et al. (2024). """
|
||||
epsilon = 1e-5 # avoid log(0)
|
||||
x = torch.linspace(0, 1, n, device=device)
|
||||
clamp = lambda x: torch.clamp(x, min=sigma_min, max=sigma_max)
|
||||
lmb = mu - beta * torch.sign(0.5-x) * torch.log(1 - 2 * torch.abs(0.5-x) + epsilon)
|
||||
sigmas = clamp(torch.exp(lmb))
|
||||
return sigmas
|
||||
|
||||
|
||||
|
||||
def to_d(x, sigma, denoised):
|
||||
"""Converts a denoiser output to a Karras ODE derivative."""
|
||||
return (x - denoised) / utils.append_dims(sigma, x.ndim)
|
||||
@@ -152,6 +164,8 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None,
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
if isinstance(model.inner_model.inner_model.model_sampling, comfy.model_sampling.CONST):
|
||||
return sample_euler_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
|
||||
"""Ancestral sampling with Euler method steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
@@ -161,14 +175,42 @@ def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
|
||||
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
d = to_d(x, sigmas[i], denoised)
|
||||
# Euler method
|
||||
dt = sigma_down - sigmas[i]
|
||||
x = x + d * dt
|
||||
if sigmas[i + 1] > 0:
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
||||
|
||||
if sigma_down == 0:
|
||||
x = denoised
|
||||
else:
|
||||
d = to_d(x, sigmas[i], denoised)
|
||||
# Euler method
|
||||
dt = sigma_down - sigmas[i]
|
||||
x = x + d * dt + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_euler_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1., noise_sampler=None):
|
||||
"""Ancestral sampling with Euler method steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
# sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
|
||||
if sigmas[i + 1] == 0:
|
||||
x = denoised
|
||||
else:
|
||||
downstep_ratio = 1 + (sigmas[i + 1] / sigmas[i] - 1) * eta
|
||||
sigma_down = sigmas[i + 1] * downstep_ratio
|
||||
alpha_ip1 = 1 - sigmas[i + 1]
|
||||
alpha_down = 1 - sigma_down
|
||||
renoise_coeff = (sigmas[i + 1]**2 - sigma_down**2 * alpha_ip1**2 / alpha_down**2)**0.5
|
||||
# Euler method
|
||||
sigma_down_i_ratio = sigma_down / sigmas[i]
|
||||
x = sigma_down_i_ratio * x + (1 - sigma_down_i_ratio) * denoised
|
||||
if eta > 0:
|
||||
x = (alpha_ip1 / alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
|
||||
@@ -243,6 +285,9 @@ def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None,
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
if isinstance(model.inner_model.inner_model.model_sampling, comfy.model_sampling.CONST):
|
||||
return sample_dpm_2_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
|
||||
|
||||
"""Ancestral sampling with DPM-Solver second-order steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
@@ -269,6 +314,38 @@ def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpm_2_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
"""Ancestral sampling with DPM-Solver second-order steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta
|
||||
sigma_down = sigmas[i+1] * downstep_ratio
|
||||
alpha_ip1 = 1 - sigmas[i+1]
|
||||
alpha_down = 1 - sigma_down
|
||||
renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5
|
||||
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
d = to_d(x, sigmas[i], denoised)
|
||||
if sigma_down == 0:
|
||||
# Euler method
|
||||
dt = sigma_down - sigmas[i]
|
||||
x = x + d * dt
|
||||
else:
|
||||
# DPM-Solver-2
|
||||
sigma_mid = sigmas[i].log().lerp(sigma_down.log(), 0.5).exp()
|
||||
dt_1 = sigma_mid - sigmas[i]
|
||||
dt_2 = sigma_down - sigmas[i]
|
||||
x_2 = x + d * dt_1
|
||||
denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
|
||||
d_2 = to_d(x_2, sigma_mid, denoised_2)
|
||||
x = x + d_2 * dt_2
|
||||
x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
|
||||
return x
|
||||
|
||||
def linear_multistep_coeff(order, t, i, j):
|
||||
if order - 1 > i:
|
||||
@@ -509,6 +586,9 @@ def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callbac
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
if isinstance(model.inner_model.inner_model.model_sampling, comfy.model_sampling.CONST):
|
||||
return sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
|
||||
|
||||
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
@@ -541,6 +621,55 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None,
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
sigma_fn = lambda lbda: (lbda.exp() + 1) ** -1
|
||||
lambda_fn = lambda sigma: ((1-sigma)/sigma).log()
|
||||
|
||||
# logged_x = x.unsqueeze(0)
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta
|
||||
sigma_down = sigmas[i+1] * downstep_ratio
|
||||
alpha_ip1 = 1 - sigmas[i+1]
|
||||
alpha_down = 1 - sigma_down
|
||||
renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5
|
||||
# sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
if sigmas[i + 1] == 0:
|
||||
# Euler method
|
||||
d = to_d(x, sigmas[i], denoised)
|
||||
dt = sigma_down - sigmas[i]
|
||||
x = x + d * dt
|
||||
else:
|
||||
# DPM-Solver++(2S)
|
||||
if sigmas[i] == 1.0:
|
||||
sigma_s = 0.9999
|
||||
else:
|
||||
t_i, t_down = lambda_fn(sigmas[i]), lambda_fn(sigma_down)
|
||||
r = 1 / 2
|
||||
h = t_down - t_i
|
||||
s = t_i + r * h
|
||||
sigma_s = sigma_fn(s)
|
||||
# sigma_s = sigmas[i+1]
|
||||
sigma_s_i_ratio = sigma_s / sigmas[i]
|
||||
u = sigma_s_i_ratio * x + (1 - sigma_s_i_ratio) * denoised
|
||||
D_i = model(u, sigma_s * s_in, **extra_args)
|
||||
sigma_down_i_ratio = sigma_down / sigmas[i]
|
||||
x = sigma_down_i_ratio * x + (1 - sigma_down_i_ratio) * D_i
|
||||
# print("sigma_i", sigmas[i], "sigma_ip1", sigmas[i+1],"sigma_down", sigma_down, "sigma_down_i_ratio", sigma_down_i_ratio, "sigma_s_i_ratio", sigma_s_i_ratio, "renoise_coeff", renoise_coeff)
|
||||
# Noise addition
|
||||
if sigmas[i + 1] > 0 and eta > 0:
|
||||
x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
|
||||
# logged_x = torch.cat((logged_x, x.unsqueeze(0)), dim=0)
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
|
||||
"""DPM-Solver++ (stochastic)."""
|
||||
@@ -1016,7 +1145,6 @@ def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disabl
|
||||
d = to_d(x, sigma_hat, temp[0])
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
|
||||
dt = sigmas[i + 1] - sigma_hat
|
||||
# Euler method
|
||||
x = denoised + d * sigmas[i + 1]
|
||||
return x
|
||||
@@ -1043,8 +1171,81 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
d = to_d(x, sigmas[i], temp[0])
|
||||
# Euler method
|
||||
dt = sigma_down - sigmas[i]
|
||||
x = denoised + d * sigma_down
|
||||
if sigmas[i + 1] > 0:
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
||||
return x
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
|
||||
temp = [0]
|
||||
def post_cfg_function(args):
|
||||
temp[0] = args["uncond_denoised"]
|
||||
return args["denoised"]
|
||||
|
||||
model_options = extra_args.get("model_options", {}).copy()
|
||||
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
|
||||
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
sigma_fn = lambda t: t.neg().exp()
|
||||
t_fn = lambda sigma: sigma.log().neg()
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
if sigma_down == 0:
|
||||
# Euler method
|
||||
d = to_d(x, sigmas[i], temp[0])
|
||||
x = denoised + d * sigma_down
|
||||
else:
|
||||
# DPM-Solver++(2S)
|
||||
t, t_next = t_fn(sigmas[i]), t_fn(sigma_down)
|
||||
# r = torch.sinh(1 + (2 - eta) * (t_next - t) / (t - t_fn(sigma_up))) works only on non-cfgpp, weird
|
||||
r = 1 / 2
|
||||
h = t_next - t
|
||||
s = t + r * h
|
||||
x_2 = (sigma_fn(s) / sigma_fn(t)) * (x + (denoised - temp[0])) - (-h * r).expm1() * denoised
|
||||
denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
|
||||
x = (sigma_fn(t_next) / sigma_fn(t)) * (x + (denoised - temp[0])) - (-h).expm1() * denoised_2
|
||||
# Noise addition
|
||||
if sigmas[i + 1] > 0:
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_2m_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
|
||||
"""DPM-Solver++(2M)."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
t_fn = lambda sigma: sigma.log().neg()
|
||||
|
||||
old_uncond_denoised = None
|
||||
uncond_denoised = None
|
||||
def post_cfg_function(args):
|
||||
nonlocal uncond_denoised
|
||||
uncond_denoised = args["uncond_denoised"]
|
||||
return args["denoised"]
|
||||
|
||||
model_options = extra_args.get("model_options", {}).copy()
|
||||
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
|
||||
h = t_next - t
|
||||
if old_uncond_denoised is None or sigmas[i + 1] == 0:
|
||||
denoised_mix = -torch.exp(-h) * uncond_denoised
|
||||
else:
|
||||
h_last = t - t_fn(sigmas[i - 1])
|
||||
r = h_last / h
|
||||
denoised_mix = -torch.exp(-h) * uncond_denoised - torch.expm1(-h) * (1 / (2 * r)) * (denoised - old_uncond_denoised)
|
||||
x = denoised + denoised_mix + torch.exp(-h) * x
|
||||
old_uncond_denoised = uncond_denoised
|
||||
return x
|
||||
|
||||
@@ -4,6 +4,7 @@ class LatentFormat:
|
||||
scale_factor = 1.0
|
||||
latent_channels = 4
|
||||
latent_rgb_factors = None
|
||||
latent_rgb_factors_bias = None
|
||||
taesd_decoder_name = None
|
||||
|
||||
def process_in(self, latent):
|
||||
@@ -30,11 +31,13 @@ class SDXL(LatentFormat):
|
||||
def __init__(self):
|
||||
self.latent_rgb_factors = [
|
||||
# R G B
|
||||
[ 0.3920, 0.4054, 0.4549],
|
||||
[-0.2634, -0.0196, 0.0653],
|
||||
[ 0.0568, 0.1687, -0.0755],
|
||||
[-0.3112, -0.2359, -0.2076]
|
||||
[ 0.3651, 0.4232, 0.4341],
|
||||
[-0.2533, -0.0042, 0.1068],
|
||||
[ 0.1076, 0.1111, -0.0362],
|
||||
[-0.3165, -0.2492, -0.2188]
|
||||
]
|
||||
self.latent_rgb_factors_bias = [ 0.1084, -0.0175, -0.0011]
|
||||
|
||||
self.taesd_decoder_name = "taesdxl_decoder"
|
||||
|
||||
class SDXL_Playground_2_5(LatentFormat):
|
||||
@@ -112,23 +115,24 @@ class SD3(LatentFormat):
|
||||
self.scale_factor = 1.5305
|
||||
self.shift_factor = 0.0609
|
||||
self.latent_rgb_factors = [
|
||||
[-0.0645, 0.0177, 0.1052],
|
||||
[ 0.0028, 0.0312, 0.0650],
|
||||
[ 0.1848, 0.0762, 0.0360],
|
||||
[ 0.0944, 0.0360, 0.0889],
|
||||
[ 0.0897, 0.0506, -0.0364],
|
||||
[-0.0020, 0.1203, 0.0284],
|
||||
[ 0.0855, 0.0118, 0.0283],
|
||||
[-0.0539, 0.0658, 0.1047],
|
||||
[-0.0057, 0.0116, 0.0700],
|
||||
[-0.0412, 0.0281, -0.0039],
|
||||
[ 0.1106, 0.1171, 0.1220],
|
||||
[-0.0248, 0.0682, -0.0481],
|
||||
[ 0.0815, 0.0846, 0.1207],
|
||||
[-0.0120, -0.0055, -0.0867],
|
||||
[-0.0749, -0.0634, -0.0456],
|
||||
[-0.1418, -0.1457, -0.1259]
|
||||
[-0.0922, -0.0175, 0.0749],
|
||||
[ 0.0311, 0.0633, 0.0954],
|
||||
[ 0.1994, 0.0927, 0.0458],
|
||||
[ 0.0856, 0.0339, 0.0902],
|
||||
[ 0.0587, 0.0272, -0.0496],
|
||||
[-0.0006, 0.1104, 0.0309],
|
||||
[ 0.0978, 0.0306, 0.0427],
|
||||
[-0.0042, 0.1038, 0.1358],
|
||||
[-0.0194, 0.0020, 0.0669],
|
||||
[-0.0488, 0.0130, -0.0268],
|
||||
[ 0.0922, 0.0988, 0.0951],
|
||||
[-0.0278, 0.0524, -0.0542],
|
||||
[ 0.0332, 0.0456, 0.0895],
|
||||
[-0.0069, -0.0030, -0.0810],
|
||||
[-0.0596, -0.0465, -0.0293],
|
||||
[-0.1448, -0.1463, -0.1189]
|
||||
]
|
||||
self.latent_rgb_factors_bias = [0.2394, 0.2135, 0.1925]
|
||||
self.taesd_decoder_name = "taesd3_decoder"
|
||||
|
||||
def process_in(self, latent):
|
||||
@@ -141,30 +145,234 @@ class StableAudio1(LatentFormat):
|
||||
latent_channels = 64
|
||||
|
||||
class Flux(SD3):
|
||||
latent_channels = 16
|
||||
def __init__(self):
|
||||
self.scale_factor = 0.3611
|
||||
self.shift_factor = 0.1159
|
||||
self.latent_rgb_factors =[
|
||||
[-0.0404, 0.0159, 0.0609],
|
||||
[ 0.0043, 0.0298, 0.0850],
|
||||
[ 0.0328, -0.0749, -0.0503],
|
||||
[-0.0245, 0.0085, 0.0549],
|
||||
[ 0.0966, 0.0894, 0.0530],
|
||||
[ 0.0035, 0.0399, 0.0123],
|
||||
[ 0.0583, 0.1184, 0.1262],
|
||||
[-0.0191, -0.0206, -0.0306],
|
||||
[-0.0324, 0.0055, 0.1001],
|
||||
[ 0.0955, 0.0659, -0.0545],
|
||||
[-0.0504, 0.0231, -0.0013],
|
||||
[ 0.0500, -0.0008, -0.0088],
|
||||
[ 0.0982, 0.0941, 0.0976],
|
||||
[-0.1233, -0.0280, -0.0897],
|
||||
[-0.0005, -0.0530, -0.0020],
|
||||
[-0.1273, -0.0932, -0.0680]
|
||||
[-0.0346, 0.0244, 0.0681],
|
||||
[ 0.0034, 0.0210, 0.0687],
|
||||
[ 0.0275, -0.0668, -0.0433],
|
||||
[-0.0174, 0.0160, 0.0617],
|
||||
[ 0.0859, 0.0721, 0.0329],
|
||||
[ 0.0004, 0.0383, 0.0115],
|
||||
[ 0.0405, 0.0861, 0.0915],
|
||||
[-0.0236, -0.0185, -0.0259],
|
||||
[-0.0245, 0.0250, 0.1180],
|
||||
[ 0.1008, 0.0755, -0.0421],
|
||||
[-0.0515, 0.0201, 0.0011],
|
||||
[ 0.0428, -0.0012, -0.0036],
|
||||
[ 0.0817, 0.0765, 0.0749],
|
||||
[-0.1264, -0.0522, -0.1103],
|
||||
[-0.0280, -0.0881, -0.0499],
|
||||
[-0.1262, -0.0982, -0.0778]
|
||||
]
|
||||
self.latent_rgb_factors_bias = [-0.0329, -0.0718, -0.0851]
|
||||
self.taesd_decoder_name = "taef1_decoder"
|
||||
|
||||
def process_in(self, latent):
|
||||
return (latent - self.shift_factor) * self.scale_factor
|
||||
|
||||
def process_out(self, latent):
|
||||
return (latent / self.scale_factor) + self.shift_factor
|
||||
|
||||
class Mochi(LatentFormat):
|
||||
latent_channels = 12
|
||||
|
||||
def __init__(self):
|
||||
self.scale_factor = 1.0
|
||||
self.latents_mean = torch.tensor([-0.06730895953510081, -0.038011381506090416, -0.07477820912866141,
|
||||
-0.05565264470995561, 0.012767231469026969, -0.04703542746246419,
|
||||
0.043896967884726704, -0.09346305707025976, -0.09918314763016893,
|
||||
-0.008729793427399178, -0.011931556316503654, -0.0321993391887285]).view(1, self.latent_channels, 1, 1, 1)
|
||||
self.latents_std = torch.tensor([0.9263795028493863, 0.9248894543193766, 0.9393059390890617,
|
||||
0.959253732819592, 0.8244560132752793, 0.917259975397747,
|
||||
0.9294154431013696, 1.3720942357788521, 0.881393668867029,
|
||||
0.9168315692124348, 0.9185249279345552, 0.9274757570805041]).view(1, self.latent_channels, 1, 1, 1)
|
||||
|
||||
self.latent_rgb_factors =[
|
||||
[-0.0069, -0.0045, 0.0018],
|
||||
[ 0.0154, -0.0692, -0.0274],
|
||||
[ 0.0333, 0.0019, 0.0206],
|
||||
[-0.1390, 0.0628, 0.1678],
|
||||
[-0.0725, 0.0134, -0.1898],
|
||||
[ 0.0074, -0.0270, -0.0209],
|
||||
[-0.0176, -0.0277, -0.0221],
|
||||
[ 0.5294, 0.5204, 0.3852],
|
||||
[-0.0326, -0.0446, -0.0143],
|
||||
[-0.0659, 0.0153, -0.0153],
|
||||
[ 0.0185, -0.0217, 0.0014],
|
||||
[-0.0396, -0.0495, -0.0281]
|
||||
]
|
||||
self.latent_rgb_factors_bias = [-0.0940, -0.1418, -0.1453]
|
||||
self.taesd_decoder_name = None #TODO
|
||||
|
||||
def process_in(self, latent):
|
||||
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
|
||||
latents_std = self.latents_std.to(latent.device, latent.dtype)
|
||||
return (latent - latents_mean) * self.scale_factor / latents_std
|
||||
|
||||
def process_out(self, latent):
|
||||
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
|
||||
latents_std = self.latents_std.to(latent.device, latent.dtype)
|
||||
return latent * latents_std / self.scale_factor + latents_mean
|
||||
|
||||
class LTXV(LatentFormat):
|
||||
latent_channels = 128
|
||||
def __init__(self):
|
||||
self.latent_rgb_factors = [
|
||||
[ 1.1202e-02, -6.3815e-04, -1.0021e-02],
|
||||
[ 8.6031e-02, 6.5813e-02, 9.5409e-04],
|
||||
[-1.2576e-02, -7.5734e-03, -4.0528e-03],
|
||||
[ 9.4063e-03, -2.1688e-03, 2.6093e-03],
|
||||
[ 3.7636e-03, 1.2765e-02, 9.1548e-03],
|
||||
[ 2.1024e-02, -5.2973e-03, 3.4373e-03],
|
||||
[-8.8896e-03, -1.9703e-02, -1.8761e-02],
|
||||
[-1.3160e-02, -1.0523e-02, 1.9709e-03],
|
||||
[-1.5152e-03, -6.9891e-03, -7.5810e-03],
|
||||
[-1.7247e-03, 4.6560e-04, -3.3839e-03],
|
||||
[ 1.3617e-02, 4.7077e-03, -2.0045e-03],
|
||||
[ 1.0256e-02, 7.7318e-03, 1.3948e-02],
|
||||
[-1.6108e-02, -6.2151e-03, 1.1561e-03],
|
||||
[ 7.3407e-03, 1.5628e-02, 4.4865e-04],
|
||||
[ 9.5357e-04, -2.9518e-03, -1.4760e-02],
|
||||
[ 1.9143e-02, 1.0868e-02, 1.2264e-02],
|
||||
[ 4.4575e-03, 3.6682e-05, -6.8508e-03],
|
||||
[-4.5681e-04, 3.2570e-03, 7.7929e-03],
|
||||
[ 3.3902e-02, 3.3405e-02, 3.7454e-02],
|
||||
[-2.3001e-02, -2.4877e-03, -3.1033e-03],
|
||||
[ 5.0265e-02, 3.8841e-02, 3.3539e-02],
|
||||
[-4.1018e-03, -1.1095e-03, 1.5859e-03],
|
||||
[-1.2689e-01, -1.3107e-01, -2.1005e-01],
|
||||
[ 2.6276e-02, 1.4189e-02, -3.5963e-03],
|
||||
[-4.8679e-03, 8.8486e-03, 7.8029e-03],
|
||||
[-1.6610e-03, -4.8597e-03, -5.2060e-03],
|
||||
[-2.1010e-03, 2.3610e-03, 9.3796e-03],
|
||||
[-2.2482e-02, -2.1305e-02, -1.5087e-02],
|
||||
[-1.5753e-02, -1.0646e-02, -6.5083e-03],
|
||||
[-4.6975e-03, 5.0288e-03, -6.7390e-03],
|
||||
[ 1.1951e-02, 2.0712e-02, 1.6191e-02],
|
||||
[-6.3704e-03, -8.4827e-03, -9.5483e-03],
|
||||
[ 7.2610e-03, -9.9326e-03, -2.2978e-02],
|
||||
[-9.1904e-04, 6.2882e-03, 9.5720e-03],
|
||||
[-3.7178e-02, -3.7123e-02, -5.6713e-02],
|
||||
[-1.3373e-01, -1.0720e-01, -5.3801e-02],
|
||||
[-5.3702e-03, 8.1256e-03, 8.8397e-03],
|
||||
[-1.5247e-01, -2.1437e-01, -2.1843e-01],
|
||||
[ 3.1441e-02, 7.0335e-03, -9.7541e-03],
|
||||
[ 2.1528e-03, -8.9817e-03, -2.1023e-02],
|
||||
[ 3.8461e-03, -5.8957e-03, -1.5014e-02],
|
||||
[-4.3470e-03, -1.2940e-02, -1.5972e-02],
|
||||
[-5.4781e-03, -1.0842e-02, -3.0204e-03],
|
||||
[-6.5347e-03, 3.0806e-03, -1.0163e-02],
|
||||
[-5.0414e-03, -7.1503e-03, -8.9686e-04],
|
||||
[-8.5851e-03, -2.4351e-03, 1.0674e-03],
|
||||
[-9.0016e-03, -9.6493e-03, 1.5692e-03],
|
||||
[ 5.0914e-03, 1.2099e-02, 1.9968e-02],
|
||||
[ 1.3758e-02, 1.1669e-02, 8.1958e-03],
|
||||
[-1.0518e-02, -1.1575e-02, -4.1307e-03],
|
||||
[-2.8410e-02, -3.1266e-02, -2.2149e-02],
|
||||
[ 2.9336e-03, 3.6511e-02, 1.8717e-02],
|
||||
[-1.6703e-02, -1.6696e-02, -4.4529e-03],
|
||||
[ 4.8818e-02, 4.0063e-02, 8.7410e-03],
|
||||
[-1.5066e-02, -5.7328e-04, 2.9785e-03],
|
||||
[-1.7613e-02, -8.1034e-03, 1.3086e-02],
|
||||
[-9.2633e-03, 1.0803e-02, -6.3489e-03],
|
||||
[ 3.0851e-03, 4.7750e-04, 1.2347e-02],
|
||||
[-2.2785e-02, -2.3043e-02, -2.6005e-02],
|
||||
[-2.4787e-02, -1.5389e-02, -2.2104e-02],
|
||||
[-2.3572e-02, 1.0544e-03, 1.2361e-02],
|
||||
[-7.8915e-03, -1.2271e-03, -6.0968e-03],
|
||||
[-1.1478e-02, -1.2543e-03, 6.2679e-03],
|
||||
[-5.4229e-02, 2.6644e-02, 6.3394e-03],
|
||||
[ 4.4216e-03, -7.3338e-03, -1.0464e-02],
|
||||
[-4.5013e-03, 1.6082e-03, 1.4420e-02],
|
||||
[ 1.3673e-02, 8.8877e-03, 4.1253e-03],
|
||||
[-1.0145e-02, 9.0072e-03, 1.5695e-02],
|
||||
[-5.6234e-03, 1.1847e-03, 8.1261e-03],
|
||||
[-3.7171e-03, -5.3538e-03, 1.2590e-03],
|
||||
[ 2.9476e-02, 2.1424e-02, 3.0424e-02],
|
||||
[-3.4925e-02, -2.4340e-02, -2.5316e-02],
|
||||
[-3.4127e-02, -2.2406e-02, -1.0589e-02],
|
||||
[-1.7342e-02, -1.3249e-02, -1.0719e-02],
|
||||
[-2.1478e-03, -8.6051e-03, -2.9878e-03],
|
||||
[ 1.2089e-03, -4.2391e-03, -6.8569e-03],
|
||||
[ 9.0411e-04, -6.6886e-03, -6.7547e-05],
|
||||
[ 1.6048e-02, -1.0057e-02, -2.8929e-02],
|
||||
[ 1.2290e-03, 1.0163e-02, 1.8861e-02],
|
||||
[ 1.7264e-02, 2.7257e-04, 1.3785e-02],
|
||||
[-1.3482e-02, -3.6427e-03, 6.7481e-04],
|
||||
[ 4.6782e-03, -5.2423e-03, 2.4467e-03],
|
||||
[-5.9113e-03, -6.2244e-03, -1.8162e-03],
|
||||
[ 1.5496e-02, 1.4582e-02, 1.9514e-03],
|
||||
[ 7.4958e-03, 1.5886e-03, -8.2305e-03],
|
||||
[ 1.9086e-02, 1.6360e-03, -3.9674e-03],
|
||||
[-5.7021e-03, -2.7307e-03, -4.1066e-03],
|
||||
[ 1.7450e-03, 1.4602e-02, 2.5794e-02],
|
||||
[-8.2788e-04, 2.2902e-03, 4.5161e-03],
|
||||
[ 1.1632e-02, 8.9193e-03, -7.2813e-03],
|
||||
[ 7.5721e-03, 2.6784e-03, 1.1393e-02],
|
||||
[ 5.1939e-03, 3.6903e-03, 1.4049e-02],
|
||||
[-1.8383e-02, -2.2529e-02, -2.4477e-02],
|
||||
[ 5.8842e-04, -5.7874e-03, -1.4770e-02],
|
||||
[-1.6125e-02, -8.6101e-03, -1.4533e-02],
|
||||
[ 2.0540e-02, 2.0729e-02, 6.4338e-03],
|
||||
[ 3.3587e-03, -1.1226e-02, -1.6444e-02],
|
||||
[-1.4742e-03, -1.0489e-02, 1.7097e-03],
|
||||
[ 2.8130e-02, 2.3546e-02, 3.2791e-02],
|
||||
[-1.8532e-02, -1.2842e-02, -8.7756e-03],
|
||||
[-8.0533e-03, -1.0771e-02, -1.7536e-02],
|
||||
[-3.9009e-03, 1.6150e-02, 3.3359e-02],
|
||||
[-7.4554e-03, -1.4154e-02, -6.1910e-03],
|
||||
[ 3.4734e-03, -1.1370e-02, -1.0581e-02],
|
||||
[ 1.1476e-02, 3.9281e-03, 2.8231e-03],
|
||||
[ 7.1639e-03, -1.4741e-03, -3.8066e-03],
|
||||
[ 2.2250e-03, -8.7552e-03, -9.5719e-03],
|
||||
[ 2.4146e-02, 2.1696e-02, 2.8056e-02],
|
||||
[-5.4365e-03, -2.4291e-02, -1.7802e-02],
|
||||
[ 7.4263e-03, 1.0510e-02, 1.2705e-02],
|
||||
[ 6.2669e-03, 6.2658e-03, 1.9211e-02],
|
||||
[ 1.6378e-02, 9.4933e-03, 6.6971e-03],
|
||||
[ 1.7173e-02, 2.3601e-02, 2.3296e-02],
|
||||
[-1.4568e-02, -9.8279e-03, -1.1556e-02],
|
||||
[ 1.4431e-02, 1.4430e-02, 6.6362e-03],
|
||||
[-6.8230e-03, 1.8863e-02, 1.4555e-02],
|
||||
[ 6.1156e-03, 3.4700e-03, -2.6662e-03],
|
||||
[-2.6983e-03, -5.9402e-03, -9.2276e-03],
|
||||
[ 1.0235e-02, 7.4173e-03, -7.6243e-03],
|
||||
[-1.3255e-02, 1.9322e-02, -9.2153e-04],
|
||||
[ 2.4222e-03, -4.8039e-03, -1.5759e-02],
|
||||
[ 2.6244e-02, 2.5951e-02, 2.0249e-02],
|
||||
[ 1.5711e-02, 1.8498e-02, 2.7407e-03],
|
||||
[-2.1714e-03, 4.7214e-03, -2.2443e-02],
|
||||
[-7.4747e-03, 7.4166e-03, 1.4430e-02],
|
||||
[-8.3906e-03, -7.9776e-03, 9.7927e-03],
|
||||
[ 3.8321e-02, 9.6622e-03, -1.9268e-02],
|
||||
[-1.4605e-02, -6.7032e-03, 3.9675e-03]
|
||||
]
|
||||
|
||||
self.latent_rgb_factors_bias = [-0.0571, -0.1657, -0.2512]
|
||||
|
||||
class HunyuanVideo(LatentFormat):
|
||||
latent_channels = 16
|
||||
scale_factor = 0.476986
|
||||
latent_rgb_factors = [
|
||||
[-0.0395, -0.0331, 0.0445],
|
||||
[ 0.0696, 0.0795, 0.0518],
|
||||
[ 0.0135, -0.0945, -0.0282],
|
||||
[ 0.0108, -0.0250, -0.0765],
|
||||
[-0.0209, 0.0032, 0.0224],
|
||||
[-0.0804, -0.0254, -0.0639],
|
||||
[-0.0991, 0.0271, -0.0669],
|
||||
[-0.0646, -0.0422, -0.0400],
|
||||
[-0.0696, -0.0595, -0.0894],
|
||||
[-0.0799, -0.0208, -0.0375],
|
||||
[ 0.1166, 0.1627, 0.0962],
|
||||
[ 0.1165, 0.0432, 0.0407],
|
||||
[-0.2315, -0.1920, -0.1355],
|
||||
[-0.0270, 0.0401, -0.0821],
|
||||
[-0.0616, -0.0997, -0.0727],
|
||||
[ 0.0249, -0.0469, -0.1703]
|
||||
]
|
||||
|
||||
latent_rgb_factors_bias = [ 0.0259, -0.0192, -0.0761]
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from typing import Literal, Dict, Any
|
||||
from typing import Literal
|
||||
import math
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
@@ -97,7 +97,7 @@ def get_activation(activation: Literal["elu", "snake", "none"], antialias=False,
|
||||
raise ValueError(f"Unknown activation {activation}")
|
||||
|
||||
if antialias:
|
||||
act = Activation1d(act)
|
||||
act = Activation1d(act) # noqa: F821 Activation1d is not defined
|
||||
|
||||
return act
|
||||
|
||||
|
||||
@@ -158,7 +158,6 @@ class RotaryEmbedding(nn.Module):
|
||||
def forward(self, t):
|
||||
# device = self.inv_freq.device
|
||||
device = t.device
|
||||
dtype = t.dtype
|
||||
|
||||
# t = t.to(torch.float32)
|
||||
|
||||
@@ -170,7 +169,7 @@ class RotaryEmbedding(nn.Module):
|
||||
if self.scale is None:
|
||||
return freqs, 1.
|
||||
|
||||
power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
|
||||
power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base # noqa: F821 seq_len is not defined
|
||||
scale = comfy.ops.cast_to_input(self.scale, t) ** rearrange(power, 'n -> n 1')
|
||||
scale = torch.cat((scale, scale), dim = -1)
|
||||
|
||||
@@ -229,9 +228,9 @@ class FeedForward(nn.Module):
|
||||
linear_in = GLU(dim, inner_dim, activation, dtype=dtype, device=device, operations=operations)
|
||||
else:
|
||||
linear_in = nn.Sequential(
|
||||
Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
|
||||
rearrange('b n d -> b d n') if use_conv else nn.Identity(),
|
||||
operations.Linear(dim, inner_dim, bias = not no_bias, dtype=dtype, device=device) if not use_conv else operations.Conv1d(dim, inner_dim, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias, dtype=dtype, device=device),
|
||||
Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
|
||||
rearrange('b n d -> b d n') if use_conv else nn.Identity(),
|
||||
activation
|
||||
)
|
||||
|
||||
@@ -246,9 +245,9 @@ class FeedForward(nn.Module):
|
||||
|
||||
self.ff = nn.Sequential(
|
||||
linear_in,
|
||||
Rearrange('b d n -> b n d') if use_conv else nn.Identity(),
|
||||
rearrange('b d n -> b n d') if use_conv else nn.Identity(),
|
||||
linear_out,
|
||||
Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
|
||||
rearrange('b n d -> b d n') if use_conv else nn.Identity(),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
@@ -346,18 +345,13 @@ class Attention(nn.Module):
|
||||
|
||||
# determine masking
|
||||
masks = []
|
||||
final_attn_mask = None # The mask that will be applied to the attention matrix, taking all masks into account
|
||||
|
||||
if input_mask is not None:
|
||||
input_mask = rearrange(input_mask, 'b j -> b 1 1 j')
|
||||
masks.append(~input_mask)
|
||||
|
||||
# Other masks will be added here later
|
||||
|
||||
if len(masks) > 0:
|
||||
final_attn_mask = ~or_reduce(masks)
|
||||
|
||||
n, device = q.shape[-2], q.device
|
||||
n = q.shape[-2]
|
||||
|
||||
causal = self.causal if causal is None else causal
|
||||
|
||||
@@ -612,7 +606,9 @@ class ContinuousTransformer(nn.Module):
|
||||
return_info = False,
|
||||
**kwargs
|
||||
):
|
||||
patches_replace = kwargs.get("transformer_options", {}).get("patches_replace", {})
|
||||
batch, seq, device = *x.shape[:2], x.device
|
||||
context = kwargs["context"]
|
||||
|
||||
info = {
|
||||
"hidden_states": [],
|
||||
@@ -643,9 +639,19 @@ class ContinuousTransformer(nn.Module):
|
||||
if self.use_sinusoidal_emb or self.use_abs_pos_emb:
|
||||
x = x + self.pos_emb(x)
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
# Iterate over the transformer layers
|
||||
for layer in self.layers:
|
||||
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
|
||||
for i, layer in enumerate(self.layers):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context)
|
||||
# x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
|
||||
|
||||
if return_info:
|
||||
@@ -874,7 +880,6 @@ class AudioDiffusionTransformer(nn.Module):
|
||||
mask=None,
|
||||
return_info=False,
|
||||
control=None,
|
||||
transformer_options={},
|
||||
**kwargs):
|
||||
return self._forward(
|
||||
x,
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor, einsum
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
|
||||
from torch import Tensor
|
||||
from typing import List, Union
|
||||
from einops import rearrange
|
||||
import math
|
||||
import comfy.ops
|
||||
|
||||
@@ -147,7 +147,6 @@ class DoubleAttention(nn.Module):
|
||||
|
||||
bsz, seqlen1, _ = c.shape
|
||||
bsz, seqlen2, _ = x.shape
|
||||
seqlen = seqlen1 + seqlen2
|
||||
|
||||
cq, ck, cv = self.w1q(c), self.w1k(c), self.w1v(c)
|
||||
cq = cq.view(bsz, seqlen1, self.n_heads, self.head_dim)
|
||||
@@ -382,7 +381,6 @@ class MMDiT(nn.Module):
|
||||
pe_new = pe_as_2d.squeeze(0).permute(1, 2, 0).flatten(0, 1)
|
||||
self.positional_encoding.data = pe_new.unsqueeze(0).contiguous()
|
||||
self.h_max, self.w_max = target_dim
|
||||
print("PE extended to", target_dim)
|
||||
|
||||
def pe_selection_index_based_on_dim(self, h, w):
|
||||
h_p, w_p = h // self.patch_size, w // self.patch_size
|
||||
@@ -437,7 +435,8 @@ class MMDiT(nn.Module):
|
||||
pos_encoding = pos_encoding[:,from_h:from_h+h,from_w:from_w+w]
|
||||
return x + pos_encoding.reshape(1, -1, self.positional_encoding.shape[-1])
|
||||
|
||||
def forward(self, x, timestep, context, **kwargs):
|
||||
def forward(self, x, timestep, context, transformer_options={}, **kwargs):
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
# patchify x, add PE
|
||||
b, c, h, w = x.shape
|
||||
|
||||
@@ -458,15 +457,36 @@ class MMDiT(nn.Module):
|
||||
|
||||
global_cond = self.t_embedder(t, x.dtype) # B, D
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
if len(self.double_layers) > 0:
|
||||
for layer in self.double_layers:
|
||||
c, x = layer(c, x, global_cond, **kwargs)
|
||||
for i, layer in enumerate(self.double_layers):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["txt"], out["img"] = layer(args["txt"],
|
||||
args["img"],
|
||||
args["vec"])
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond}, {"original_block": block_wrap})
|
||||
c = out["txt"]
|
||||
x = out["img"]
|
||||
else:
|
||||
c, x = layer(c, x, global_cond, **kwargs)
|
||||
|
||||
if len(self.single_layers) > 0:
|
||||
c_len = c.size(1)
|
||||
cx = torch.cat([c, x], dim=1)
|
||||
for layer in self.single_layers:
|
||||
cx = layer(cx, global_cond, **kwargs)
|
||||
for i, layer in enumerate(self.single_layers):
|
||||
if ("single_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = layer(args["img"], args["vec"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond}, {"original_block": block_wrap})
|
||||
cx = out["img"]
|
||||
else:
|
||||
cx = layer(cx, global_cond, **kwargs)
|
||||
|
||||
x = cx[:, c_len:]
|
||||
|
||||
|
||||
@@ -16,7 +16,6 @@
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
from torch import nn
|
||||
from .common import LayerNorm2d_op
|
||||
|
||||
@@ -1,8 +1,30 @@
|
||||
import torch
|
||||
import comfy.ops
|
||||
|
||||
def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
|
||||
if padding_mode == "circular" and torch.jit.is_tracing() or torch.jit.is_scripting():
|
||||
if padding_mode == "circular" and (torch.jit.is_tracing() or torch.jit.is_scripting()):
|
||||
padding_mode = "reflect"
|
||||
pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0]
|
||||
pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1]
|
||||
return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode)
|
||||
|
||||
pad = ()
|
||||
for i in range(img.ndim - 2):
|
||||
pad = (0, (patch_size[i] - img.shape[i + 2] % patch_size[i]) % patch_size[i]) + pad
|
||||
|
||||
return torch.nn.functional.pad(img, pad, mode=padding_mode)
|
||||
|
||||
try:
|
||||
rms_norm_torch = torch.nn.functional.rms_norm
|
||||
except:
|
||||
rms_norm_torch = None
|
||||
|
||||
def rms_norm(x, weight=None, eps=1e-6):
|
||||
if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
|
||||
if weight is None:
|
||||
return rms_norm_torch(x, (x.shape[-1],), eps=eps)
|
||||
else:
|
||||
return rms_norm_torch(x, weight.shape, weight=comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
|
||||
else:
|
||||
r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
|
||||
if weight is None:
|
||||
return r
|
||||
else:
|
||||
return r * comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device)
|
||||
|
||||
203
comfy/ldm/flux/controlnet.py
Normal file
203
comfy/ldm/flux/controlnet.py
Normal file
@@ -0,0 +1,203 @@
|
||||
#Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
|
||||
#modified to support different types of flux controlnets
|
||||
|
||||
import torch
|
||||
import math
|
||||
from torch import Tensor, nn
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from .layers import (timestep_embedding)
|
||||
|
||||
from .model import Flux
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
class MistolineCondDownsamplBlock(nn.Module):
|
||||
def __init__(self, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.encoder = nn.Sequential(
|
||||
operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 1, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 1, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.encoder(x)
|
||||
|
||||
class MistolineControlnetBlock(nn.Module):
|
||||
def __init__(self, hidden_size, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.linear = operations.Linear(hidden_size, hidden_size, dtype=dtype, device=device)
|
||||
self.act = nn.SiLU()
|
||||
|
||||
def forward(self, x):
|
||||
return self.act(self.linear(x))
|
||||
|
||||
|
||||
class ControlNetFlux(Flux):
|
||||
def __init__(self, latent_input=False, num_union_modes=0, mistoline=False, control_latent_channels=None, image_model=None, dtype=None, device=None, operations=None, **kwargs):
|
||||
super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
|
||||
|
||||
self.main_model_double = 19
|
||||
self.main_model_single = 38
|
||||
|
||||
self.mistoline = mistoline
|
||||
# add ControlNet blocks
|
||||
if self.mistoline:
|
||||
control_block = lambda : MistolineControlnetBlock(self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
else:
|
||||
control_block = lambda : operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
|
||||
|
||||
self.controlnet_blocks = nn.ModuleList([])
|
||||
for _ in range(self.params.depth):
|
||||
self.controlnet_blocks.append(control_block())
|
||||
|
||||
self.controlnet_single_blocks = nn.ModuleList([])
|
||||
for _ in range(self.params.depth_single_blocks):
|
||||
self.controlnet_single_blocks.append(control_block())
|
||||
|
||||
self.num_union_modes = num_union_modes
|
||||
self.controlnet_mode_embedder = None
|
||||
if self.num_union_modes > 0:
|
||||
self.controlnet_mode_embedder = operations.Embedding(self.num_union_modes, self.hidden_size, dtype=dtype, device=device)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
self.latent_input = latent_input
|
||||
if control_latent_channels is None:
|
||||
control_latent_channels = self.in_channels
|
||||
else:
|
||||
control_latent_channels *= 2 * 2 #patch size
|
||||
|
||||
self.pos_embed_input = operations.Linear(control_latent_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
|
||||
if not self.latent_input:
|
||||
if self.mistoline:
|
||||
self.input_cond_block = MistolineCondDownsamplBlock(dtype=dtype, device=device, operations=operations)
|
||||
else:
|
||||
self.input_hint_block = nn.Sequential(
|
||||
operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
def forward_orig(
|
||||
self,
|
||||
img: Tensor,
|
||||
img_ids: Tensor,
|
||||
controlnet_cond: Tensor,
|
||||
txt: Tensor,
|
||||
txt_ids: Tensor,
|
||||
timesteps: Tensor,
|
||||
y: Tensor,
|
||||
guidance: Tensor = None,
|
||||
control_type: Tensor = None,
|
||||
) -> Tensor:
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
|
||||
# running on sequences img
|
||||
img = self.img_in(img)
|
||||
|
||||
controlnet_cond = self.pos_embed_input(controlnet_cond)
|
||||
img = img + controlnet_cond
|
||||
vec = self.time_in(timestep_embedding(timesteps, 256))
|
||||
if self.params.guidance_embed:
|
||||
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
|
||||
vec = vec + self.vector_in(y)
|
||||
txt = self.txt_in(txt)
|
||||
|
||||
if self.controlnet_mode_embedder is not None and len(control_type) > 0:
|
||||
control_cond = self.controlnet_mode_embedder(torch.tensor(control_type, device=img.device), out_dtype=img.dtype).unsqueeze(0).repeat((txt.shape[0], 1, 1))
|
||||
txt = torch.cat([control_cond, txt], dim=1)
|
||||
txt_ids = torch.cat([txt_ids[:,:1], txt_ids], dim=1)
|
||||
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
|
||||
controlnet_double = ()
|
||||
|
||||
for i in range(len(self.double_blocks)):
|
||||
img, txt = self.double_blocks[i](img=img, txt=txt, vec=vec, pe=pe)
|
||||
controlnet_double = controlnet_double + (self.controlnet_blocks[i](img),)
|
||||
|
||||
img = torch.cat((txt, img), 1)
|
||||
|
||||
controlnet_single = ()
|
||||
|
||||
for i in range(len(self.single_blocks)):
|
||||
img = self.single_blocks[i](img, vec=vec, pe=pe)
|
||||
controlnet_single = controlnet_single + (self.controlnet_single_blocks[i](img[:, txt.shape[1] :, ...]),)
|
||||
|
||||
repeat = math.ceil(self.main_model_double / len(controlnet_double))
|
||||
if self.latent_input:
|
||||
out_input = ()
|
||||
for x in controlnet_double:
|
||||
out_input += (x,) * repeat
|
||||
else:
|
||||
out_input = (controlnet_double * repeat)
|
||||
|
||||
out = {"input": out_input[:self.main_model_double]}
|
||||
if len(controlnet_single) > 0:
|
||||
repeat = math.ceil(self.main_model_single / len(controlnet_single))
|
||||
out_output = ()
|
||||
if self.latent_input:
|
||||
for x in controlnet_single:
|
||||
out_output += (x,) * repeat
|
||||
else:
|
||||
out_output = (controlnet_single * repeat)
|
||||
out["output"] = out_output[:self.main_model_single]
|
||||
return out
|
||||
|
||||
def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
|
||||
patch_size = 2
|
||||
if self.latent_input:
|
||||
hint = comfy.ldm.common_dit.pad_to_patch_size(hint, (patch_size, patch_size))
|
||||
elif self.mistoline:
|
||||
hint = hint * 2.0 - 1.0
|
||||
hint = self.input_cond_block(hint)
|
||||
else:
|
||||
hint = hint * 2.0 - 1.0
|
||||
hint = self.input_hint_block(hint)
|
||||
|
||||
hint = rearrange(hint, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
||||
|
||||
bs, c, h, w = x.shape
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
||||
|
||||
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
||||
|
||||
h_len = ((h + (patch_size // 2)) // patch_size)
|
||||
w_len = ((w + (patch_size // 2)) // patch_size)
|
||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||
img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
|
||||
img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance, control_type=kwargs.get("control_type", []))
|
||||
@@ -6,6 +6,7 @@ from torch import Tensor, nn
|
||||
|
||||
from .math import attention, rope
|
||||
import comfy.ops
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
|
||||
class EmbedND(nn.Module):
|
||||
@@ -63,10 +64,7 @@ class RMSNorm(torch.nn.Module):
|
||||
self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
|
||||
|
||||
def forward(self, x: Tensor):
|
||||
x_dtype = x.dtype
|
||||
x = x.float()
|
||||
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
|
||||
return (x * rrms).to(dtype=x_dtype) * comfy.ops.cast_to(self.scale, dtype=x_dtype, device=x.device)
|
||||
return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6)
|
||||
|
||||
|
||||
class QKNorm(torch.nn.Module):
|
||||
@@ -116,7 +114,7 @@ class Modulation(nn.Module):
|
||||
|
||||
|
||||
class DoubleStreamBlock(nn.Module):
|
||||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, dtype=None, device=None, operations=None):
|
||||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
@@ -143,8 +141,9 @@ class DoubleStreamBlock(nn.Module):
|
||||
nn.GELU(approximate="tanh"),
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
self.flipped_img_txt = flipped_img_txt
|
||||
|
||||
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor):
|
||||
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None):
|
||||
img_mod1, img_mod2 = self.img_mod(vec)
|
||||
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
||||
|
||||
@@ -162,23 +161,33 @@ class DoubleStreamBlock(nn.Module):
|
||||
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
||||
|
||||
# run actual attention
|
||||
attn = attention(torch.cat((txt_q, img_q), dim=2),
|
||||
torch.cat((txt_k, img_k), dim=2),
|
||||
torch.cat((txt_v, img_v), dim=2), pe=pe)
|
||||
if self.flipped_img_txt:
|
||||
# run actual attention
|
||||
attn = attention(torch.cat((img_q, txt_q), dim=2),
|
||||
torch.cat((img_k, txt_k), dim=2),
|
||||
torch.cat((img_v, txt_v), dim=2),
|
||||
pe=pe, mask=attn_mask)
|
||||
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
||||
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
|
||||
else:
|
||||
# run actual attention
|
||||
attn = attention(torch.cat((txt_q, img_q), dim=2),
|
||||
torch.cat((txt_k, img_k), dim=2),
|
||||
torch.cat((txt_v, img_v), dim=2),
|
||||
pe=pe, mask=attn_mask)
|
||||
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
|
||||
|
||||
# calculate the img bloks
|
||||
img += img_mod1.gate * self.img_attn.proj(img_attn)
|
||||
img += img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
|
||||
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
|
||||
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
|
||||
|
||||
# calculate the txt bloks
|
||||
txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
|
||||
txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
|
||||
|
||||
if txt.dtype == torch.float16:
|
||||
txt = txt.clip(-65504, 65504)
|
||||
txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
|
||||
|
||||
return img, txt
|
||||
|
||||
@@ -219,7 +228,7 @@ class SingleStreamBlock(nn.Module):
|
||||
self.mlp_act = nn.GELU(approximate="tanh")
|
||||
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
|
||||
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None) -> Tensor:
|
||||
mod, _ = self.modulation(vec)
|
||||
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
|
||||
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||
@@ -228,12 +237,12 @@ class SingleStreamBlock(nn.Module):
|
||||
q, k = self.norm(q, k, v)
|
||||
|
||||
# compute attention
|
||||
attn = attention(q, k, v, pe=pe)
|
||||
attn = attention(q, k, v, pe=pe, mask=attn_mask)
|
||||
# compute activation in mlp stream, cat again and run second linear layer
|
||||
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
||||
x += mod.gate * output
|
||||
if x.dtype == torch.float16:
|
||||
x = x.clip(-65504, 65504)
|
||||
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from torch import Tensor
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
import comfy.model_management
|
||||
|
||||
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
|
||||
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
|
||||
q, k = apply_rope(q, k, pe)
|
||||
|
||||
heads = q.shape[1]
|
||||
x = optimized_attention(q, k, v, heads, skip_reshape=True)
|
||||
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
|
||||
return x
|
||||
|
||||
|
||||
@@ -33,3 +34,4 @@ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
|
||||
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
||||
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
||||
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
||||
|
||||
|
||||
@@ -4,6 +4,8 @@ from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from einops import rearrange, repeat
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
from .layers import (
|
||||
DoubleStreamBlock,
|
||||
@@ -14,12 +16,10 @@ from .layers import (
|
||||
timestep_embedding,
|
||||
)
|
||||
|
||||
from einops import rearrange, repeat
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
@dataclass
|
||||
class FluxParams:
|
||||
in_channels: int
|
||||
out_channels: int
|
||||
vec_in_dim: int
|
||||
context_in_dim: int
|
||||
hidden_size: int
|
||||
@@ -29,6 +29,7 @@ class FluxParams:
|
||||
depth_single_blocks: int
|
||||
axes_dim: list
|
||||
theta: int
|
||||
patch_size: int
|
||||
qkv_bias: bool
|
||||
guidance_embed: bool
|
||||
|
||||
@@ -38,13 +39,14 @@ class Flux(nn.Module):
|
||||
Transformer model for flow matching on sequences.
|
||||
"""
|
||||
|
||||
def __init__(self, image_model=None, dtype=None, device=None, operations=None, **kwargs):
|
||||
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
params = FluxParams(**kwargs)
|
||||
self.params = params
|
||||
self.in_channels = params.in_channels * 2 * 2
|
||||
self.out_channels = self.in_channels
|
||||
self.patch_size = params.patch_size
|
||||
self.in_channels = params.in_channels * params.patch_size * params.patch_size
|
||||
self.out_channels = params.out_channels * params.patch_size * params.patch_size
|
||||
if params.hidden_size % params.num_heads != 0:
|
||||
raise ValueError(
|
||||
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
|
||||
@@ -83,7 +85,8 @@ class Flux(nn.Module):
|
||||
]
|
||||
)
|
||||
|
||||
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||
if final_layer:
|
||||
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward_orig(
|
||||
self,
|
||||
@@ -94,7 +97,11 @@ class Flux(nn.Module):
|
||||
timesteps: Tensor,
|
||||
y: Tensor,
|
||||
guidance: Tensor = None,
|
||||
control = None,
|
||||
transformer_options={},
|
||||
attn_mask: Tensor = None,
|
||||
) -> Tensor:
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
|
||||
@@ -106,26 +113,82 @@ class Flux(nn.Module):
|
||||
raise ValueError("Didn't get guidance strength for guidance distilled model.")
|
||||
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
|
||||
|
||||
vec = vec + self.vector_in(y)
|
||||
vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
|
||||
txt = self.txt_in(txt)
|
||||
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
|
||||
for block in self.double_blocks:
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
for i, block in enumerate(self.double_blocks):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"], out["txt"] = block(img=args["img"],
|
||||
txt=args["txt"],
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"))
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": img,
|
||||
"txt": txt,
|
||||
"vec": vec,
|
||||
"pe": pe,
|
||||
"attn_mask": attn_mask},
|
||||
{"original_block": block_wrap})
|
||||
txt = out["txt"]
|
||||
img = out["img"]
|
||||
else:
|
||||
img, txt = block(img=img,
|
||||
txt=txt,
|
||||
vec=vec,
|
||||
pe=pe,
|
||||
attn_mask=attn_mask)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_i = control.get("input")
|
||||
if i < len(control_i):
|
||||
add = control_i[i]
|
||||
if add is not None:
|
||||
img += add
|
||||
|
||||
img = torch.cat((txt, img), 1)
|
||||
for block in self.single_blocks:
|
||||
img = block(img, vec=vec, pe=pe)
|
||||
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
if ("single_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"],
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"))
|
||||
return out
|
||||
|
||||
out = blocks_replace[("single_block", i)]({"img": img,
|
||||
"vec": vec,
|
||||
"pe": pe,
|
||||
"attn_mask": attn_mask},
|
||||
{"original_block": block_wrap})
|
||||
img = out["img"]
|
||||
else:
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_o = control.get("output")
|
||||
if i < len(control_o):
|
||||
add = control_o[i]
|
||||
if add is not None:
|
||||
img[:, txt.shape[1] :, ...] += add
|
||||
|
||||
img = img[:, txt.shape[1] :, ...]
|
||||
|
||||
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
||||
return img
|
||||
|
||||
def forward(self, x, timestep, context, y, guidance, **kwargs):
|
||||
def forward(self, x, timestep, context, y, guidance, control=None, transformer_options={}, **kwargs):
|
||||
bs, c, h, w = x.shape
|
||||
patch_size = 2
|
||||
patch_size = self.patch_size
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
||||
|
||||
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
||||
@@ -133,10 +196,10 @@ class Flux(nn.Module):
|
||||
h_len = ((h + (patch_size // 2)) // patch_size)
|
||||
w_len = ((w + (patch_size // 2)) // patch_size)
|
||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||
img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
|
||||
img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
|
||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance)
|
||||
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
|
||||
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
|
||||
|
||||
25
comfy/ldm/flux/redux.py
Normal file
25
comfy/ldm/flux/redux.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import torch
|
||||
import comfy.ops
|
||||
|
||||
ops = comfy.ops.manual_cast
|
||||
|
||||
class ReduxImageEncoder(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
redux_dim: int = 1152,
|
||||
txt_in_features: int = 4096,
|
||||
device=None,
|
||||
dtype=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.redux_dim = redux_dim
|
||||
self.device = device
|
||||
self.dtype = dtype
|
||||
|
||||
self.redux_up = ops.Linear(redux_dim, txt_in_features * 3, dtype=dtype)
|
||||
self.redux_down = ops.Linear(txt_in_features * 3, txt_in_features, dtype=dtype)
|
||||
|
||||
def forward(self, sigclip_embeds) -> torch.Tensor:
|
||||
projected_x = self.redux_down(torch.nn.functional.silu(self.redux_up(sigclip_embeds)))
|
||||
return projected_x
|
||||
557
comfy/ldm/genmo/joint_model/asymm_models_joint.py
Normal file
557
comfy/ldm/genmo/joint_model/asymm_models_joint.py
Normal file
@@ -0,0 +1,557 @@
|
||||
#original code from https://github.com/genmoai/models under apache 2.0 license
|
||||
#adapted to ComfyUI
|
||||
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
# from flash_attn import flash_attn_varlen_qkvpacked_func
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
from .layers import (
|
||||
FeedForward,
|
||||
PatchEmbed,
|
||||
RMSNorm,
|
||||
TimestepEmbedder,
|
||||
)
|
||||
|
||||
from .rope_mixed import (
|
||||
compute_mixed_rotation,
|
||||
create_position_matrix,
|
||||
)
|
||||
from .temporal_rope import apply_rotary_emb_qk_real
|
||||
from .utils import (
|
||||
AttentionPool,
|
||||
modulate,
|
||||
)
|
||||
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.ops
|
||||
|
||||
|
||||
def modulated_rmsnorm(x, scale, eps=1e-6):
|
||||
# Normalize and modulate
|
||||
x_normed = comfy.ldm.common_dit.rms_norm(x, eps=eps)
|
||||
x_modulated = x_normed * (1 + scale.unsqueeze(1))
|
||||
|
||||
return x_modulated
|
||||
|
||||
|
||||
def residual_tanh_gated_rmsnorm(x, x_res, gate, eps=1e-6):
|
||||
# Apply tanh to gate
|
||||
tanh_gate = torch.tanh(gate).unsqueeze(1)
|
||||
|
||||
# Normalize and apply gated scaling
|
||||
x_normed = comfy.ldm.common_dit.rms_norm(x_res, eps=eps) * tanh_gate
|
||||
|
||||
# Apply residual connection
|
||||
output = x + x_normed
|
||||
|
||||
return output
|
||||
|
||||
class AsymmetricAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim_x: int,
|
||||
dim_y: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = True,
|
||||
qk_norm: bool = False,
|
||||
attn_drop: float = 0.0,
|
||||
update_y: bool = True,
|
||||
out_bias: bool = True,
|
||||
attend_to_padding: bool = False,
|
||||
softmax_scale: Optional[float] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim_x = dim_x
|
||||
self.dim_y = dim_y
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim_x // num_heads
|
||||
self.attn_drop = attn_drop
|
||||
self.update_y = update_y
|
||||
self.attend_to_padding = attend_to_padding
|
||||
self.softmax_scale = softmax_scale
|
||||
if dim_x % num_heads != 0:
|
||||
raise ValueError(
|
||||
f"dim_x={dim_x} should be divisible by num_heads={num_heads}"
|
||||
)
|
||||
|
||||
# Input layers.
|
||||
self.qkv_bias = qkv_bias
|
||||
self.qkv_x = operations.Linear(dim_x, 3 * dim_x, bias=qkv_bias, device=device, dtype=dtype)
|
||||
# Project text features to match visual features (dim_y -> dim_x)
|
||||
self.qkv_y = operations.Linear(dim_y, 3 * dim_x, bias=qkv_bias, device=device, dtype=dtype)
|
||||
|
||||
# Query and key normalization for stability.
|
||||
assert qk_norm
|
||||
self.q_norm_x = RMSNorm(self.head_dim, device=device, dtype=dtype)
|
||||
self.k_norm_x = RMSNorm(self.head_dim, device=device, dtype=dtype)
|
||||
self.q_norm_y = RMSNorm(self.head_dim, device=device, dtype=dtype)
|
||||
self.k_norm_y = RMSNorm(self.head_dim, device=device, dtype=dtype)
|
||||
|
||||
# Output layers. y features go back down from dim_x -> dim_y.
|
||||
self.proj_x = operations.Linear(dim_x, dim_x, bias=out_bias, device=device, dtype=dtype)
|
||||
self.proj_y = (
|
||||
operations.Linear(dim_x, dim_y, bias=out_bias, device=device, dtype=dtype)
|
||||
if update_y
|
||||
else nn.Identity()
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor, # (B, N, dim_x)
|
||||
y: torch.Tensor, # (B, L, dim_y)
|
||||
scale_x: torch.Tensor, # (B, dim_x), modulation for pre-RMSNorm.
|
||||
scale_y: torch.Tensor, # (B, dim_y), modulation for pre-RMSNorm.
|
||||
crop_y,
|
||||
**rope_rotation,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
rope_cos = rope_rotation.get("rope_cos")
|
||||
rope_sin = rope_rotation.get("rope_sin")
|
||||
# Pre-norm for visual features
|
||||
x = modulated_rmsnorm(x, scale_x) # (B, M, dim_x) where M = N / cp_group_size
|
||||
|
||||
# Process visual features
|
||||
# qkv_x = self.qkv_x(x) # (B, M, 3 * dim_x)
|
||||
# assert qkv_x.dtype == torch.bfloat16
|
||||
# qkv_x = all_to_all_collect_tokens(
|
||||
# qkv_x, self.num_heads
|
||||
# ) # (3, B, N, local_h, head_dim)
|
||||
|
||||
# Process text features
|
||||
y = modulated_rmsnorm(y, scale_y) # (B, L, dim_y)
|
||||
q_y, k_y, v_y = self.qkv_y(y).view(y.shape[0], y.shape[1], 3, self.num_heads, -1).unbind(2) # (B, N, local_h, head_dim)
|
||||
|
||||
q_y = self.q_norm_y(q_y)
|
||||
k_y = self.k_norm_y(k_y)
|
||||
|
||||
# Split qkv_x into q, k, v
|
||||
q_x, k_x, v_x = self.qkv_x(x).view(x.shape[0], x.shape[1], 3, self.num_heads, -1).unbind(2) # (B, N, local_h, head_dim)
|
||||
q_x = self.q_norm_x(q_x)
|
||||
q_x = apply_rotary_emb_qk_real(q_x, rope_cos, rope_sin)
|
||||
k_x = self.k_norm_x(k_x)
|
||||
k_x = apply_rotary_emb_qk_real(k_x, rope_cos, rope_sin)
|
||||
|
||||
q = torch.cat([q_x, q_y[:, :crop_y]], dim=1).transpose(1, 2)
|
||||
k = torch.cat([k_x, k_y[:, :crop_y]], dim=1).transpose(1, 2)
|
||||
v = torch.cat([v_x, v_y[:, :crop_y]], dim=1).transpose(1, 2)
|
||||
|
||||
xy = optimized_attention(q,
|
||||
k,
|
||||
v, self.num_heads, skip_reshape=True)
|
||||
|
||||
x, y = torch.tensor_split(xy, (q_x.shape[1],), dim=1)
|
||||
x = self.proj_x(x)
|
||||
o = torch.zeros(y.shape[0], q_y.shape[1], y.shape[-1], device=y.device, dtype=y.dtype)
|
||||
o[:, :y.shape[1]] = y
|
||||
|
||||
y = self.proj_y(o)
|
||||
# print("ox", x)
|
||||
# print("oy", y)
|
||||
return x, y
|
||||
|
||||
|
||||
class AsymmetricJointBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size_x: int,
|
||||
hidden_size_y: int,
|
||||
num_heads: int,
|
||||
*,
|
||||
mlp_ratio_x: float = 8.0, # Ratio of hidden size to d_model for MLP for visual tokens.
|
||||
mlp_ratio_y: float = 4.0, # Ratio of hidden size to d_model for MLP for text tokens.
|
||||
update_y: bool = True, # Whether to update text tokens in this block.
|
||||
device: Optional[torch.device] = None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
**block_kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.update_y = update_y
|
||||
self.hidden_size_x = hidden_size_x
|
||||
self.hidden_size_y = hidden_size_y
|
||||
self.mod_x = operations.Linear(hidden_size_x, 4 * hidden_size_x, device=device, dtype=dtype)
|
||||
if self.update_y:
|
||||
self.mod_y = operations.Linear(hidden_size_x, 4 * hidden_size_y, device=device, dtype=dtype)
|
||||
else:
|
||||
self.mod_y = operations.Linear(hidden_size_x, hidden_size_y, device=device, dtype=dtype)
|
||||
|
||||
# Self-attention:
|
||||
self.attn = AsymmetricAttention(
|
||||
hidden_size_x,
|
||||
hidden_size_y,
|
||||
num_heads=num_heads,
|
||||
update_y=update_y,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
operations=operations,
|
||||
**block_kwargs,
|
||||
)
|
||||
|
||||
# MLP.
|
||||
mlp_hidden_dim_x = int(hidden_size_x * mlp_ratio_x)
|
||||
assert mlp_hidden_dim_x == int(1536 * 8)
|
||||
self.mlp_x = FeedForward(
|
||||
in_features=hidden_size_x,
|
||||
hidden_size=mlp_hidden_dim_x,
|
||||
multiple_of=256,
|
||||
ffn_dim_multiplier=None,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
# MLP for text not needed in last block.
|
||||
if self.update_y:
|
||||
mlp_hidden_dim_y = int(hidden_size_y * mlp_ratio_y)
|
||||
self.mlp_y = FeedForward(
|
||||
in_features=hidden_size_y,
|
||||
hidden_size=mlp_hidden_dim_y,
|
||||
multiple_of=256,
|
||||
ffn_dim_multiplier=None,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
c: torch.Tensor,
|
||||
y: torch.Tensor,
|
||||
**attn_kwargs,
|
||||
):
|
||||
"""Forward pass of a block.
|
||||
|
||||
Args:
|
||||
x: (B, N, dim) tensor of visual tokens
|
||||
c: (B, dim) tensor of conditioned features
|
||||
y: (B, L, dim) tensor of text tokens
|
||||
num_frames: Number of frames in the video. N = num_frames * num_spatial_tokens
|
||||
|
||||
Returns:
|
||||
x: (B, N, dim) tensor of visual tokens after block
|
||||
y: (B, L, dim) tensor of text tokens after block
|
||||
"""
|
||||
N = x.size(1)
|
||||
|
||||
c = F.silu(c)
|
||||
mod_x = self.mod_x(c)
|
||||
scale_msa_x, gate_msa_x, scale_mlp_x, gate_mlp_x = mod_x.chunk(4, dim=1)
|
||||
|
||||
mod_y = self.mod_y(c)
|
||||
if self.update_y:
|
||||
scale_msa_y, gate_msa_y, scale_mlp_y, gate_mlp_y = mod_y.chunk(4, dim=1)
|
||||
else:
|
||||
scale_msa_y = mod_y
|
||||
|
||||
# Self-attention block.
|
||||
x_attn, y_attn = self.attn(
|
||||
x,
|
||||
y,
|
||||
scale_x=scale_msa_x,
|
||||
scale_y=scale_msa_y,
|
||||
**attn_kwargs,
|
||||
)
|
||||
|
||||
assert x_attn.size(1) == N
|
||||
x = residual_tanh_gated_rmsnorm(x, x_attn, gate_msa_x)
|
||||
if self.update_y:
|
||||
y = residual_tanh_gated_rmsnorm(y, y_attn, gate_msa_y)
|
||||
|
||||
# MLP block.
|
||||
x = self.ff_block_x(x, scale_mlp_x, gate_mlp_x)
|
||||
if self.update_y:
|
||||
y = self.ff_block_y(y, scale_mlp_y, gate_mlp_y)
|
||||
|
||||
return x, y
|
||||
|
||||
def ff_block_x(self, x, scale_x, gate_x):
|
||||
x_mod = modulated_rmsnorm(x, scale_x)
|
||||
x_res = self.mlp_x(x_mod)
|
||||
x = residual_tanh_gated_rmsnorm(x, x_res, gate_x) # Sandwich norm
|
||||
return x
|
||||
|
||||
def ff_block_y(self, y, scale_y, gate_y):
|
||||
y_mod = modulated_rmsnorm(y, scale_y)
|
||||
y_res = self.mlp_y(y_mod)
|
||||
y = residual_tanh_gated_rmsnorm(y, y_res, gate_y) # Sandwich norm
|
||||
return y
|
||||
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
"""
|
||||
The final layer of DiT.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
patch_size,
|
||||
out_channels,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.norm_final = operations.LayerNorm(
|
||||
hidden_size, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype
|
||||
)
|
||||
self.mod = operations.Linear(hidden_size, 2 * hidden_size, device=device, dtype=dtype)
|
||||
self.linear = operations.Linear(
|
||||
hidden_size, patch_size * patch_size * out_channels, device=device, dtype=dtype
|
||||
)
|
||||
|
||||
def forward(self, x, c):
|
||||
c = F.silu(c)
|
||||
shift, scale = self.mod(c).chunk(2, dim=1)
|
||||
x = modulate(self.norm_final(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
class AsymmDiTJoint(nn.Module):
|
||||
"""
|
||||
Diffusion model with a Transformer backbone.
|
||||
|
||||
Ingests text embeddings instead of a label.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
patch_size=2,
|
||||
in_channels=4,
|
||||
hidden_size_x=1152,
|
||||
hidden_size_y=1152,
|
||||
depth=48,
|
||||
num_heads=16,
|
||||
mlp_ratio_x=8.0,
|
||||
mlp_ratio_y=4.0,
|
||||
use_t5: bool = False,
|
||||
t5_feat_dim: int = 4096,
|
||||
t5_token_length: int = 256,
|
||||
learn_sigma=True,
|
||||
patch_embed_bias: bool = True,
|
||||
timestep_mlp_bias: bool = True,
|
||||
attend_to_padding: bool = False,
|
||||
timestep_scale: Optional[float] = None,
|
||||
use_extended_posenc: bool = False,
|
||||
posenc_preserve_area: bool = False,
|
||||
rope_theta: float = 10000.0,
|
||||
image_model=None,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
**block_kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.dtype = dtype
|
||||
self.learn_sigma = learn_sigma
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels * 2 if learn_sigma else in_channels
|
||||
self.patch_size = patch_size
|
||||
self.num_heads = num_heads
|
||||
self.hidden_size_x = hidden_size_x
|
||||
self.hidden_size_y = hidden_size_y
|
||||
self.head_dim = (
|
||||
hidden_size_x // num_heads
|
||||
) # Head dimension and count is determined by visual.
|
||||
self.attend_to_padding = attend_to_padding
|
||||
self.use_extended_posenc = use_extended_posenc
|
||||
self.posenc_preserve_area = posenc_preserve_area
|
||||
self.use_t5 = use_t5
|
||||
self.t5_token_length = t5_token_length
|
||||
self.t5_feat_dim = t5_feat_dim
|
||||
self.rope_theta = (
|
||||
rope_theta # Scaling factor for frequency computation for temporal RoPE.
|
||||
)
|
||||
|
||||
self.x_embedder = PatchEmbed(
|
||||
patch_size=patch_size,
|
||||
in_chans=in_channels,
|
||||
embed_dim=hidden_size_x,
|
||||
bias=patch_embed_bias,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
)
|
||||
# Conditionings
|
||||
# Timestep
|
||||
self.t_embedder = TimestepEmbedder(
|
||||
hidden_size_x, bias=timestep_mlp_bias, timestep_scale=timestep_scale, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
if self.use_t5:
|
||||
# Caption Pooling (T5)
|
||||
self.t5_y_embedder = AttentionPool(
|
||||
t5_feat_dim, num_heads=8, output_dim=hidden_size_x, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
# Dense Embedding Projection (T5)
|
||||
self.t5_yproj = operations.Linear(
|
||||
t5_feat_dim, hidden_size_y, bias=True, dtype=dtype, device=device
|
||||
)
|
||||
|
||||
# Initialize pos_frequencies as an empty parameter.
|
||||
self.pos_frequencies = nn.Parameter(
|
||||
torch.empty(3, self.num_heads, self.head_dim // 2, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
assert not self.attend_to_padding
|
||||
|
||||
# for depth 48:
|
||||
# b = 0: AsymmetricJointBlock, update_y=True
|
||||
# b = 1: AsymmetricJointBlock, update_y=True
|
||||
# ...
|
||||
# b = 46: AsymmetricJointBlock, update_y=True
|
||||
# b = 47: AsymmetricJointBlock, update_y=False. No need to update text features.
|
||||
blocks = []
|
||||
for b in range(depth):
|
||||
# Joint multi-modal block
|
||||
update_y = b < depth - 1
|
||||
block = AsymmetricJointBlock(
|
||||
hidden_size_x,
|
||||
hidden_size_y,
|
||||
num_heads,
|
||||
mlp_ratio_x=mlp_ratio_x,
|
||||
mlp_ratio_y=mlp_ratio_y,
|
||||
update_y=update_y,
|
||||
attend_to_padding=attend_to_padding,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
operations=operations,
|
||||
**block_kwargs,
|
||||
)
|
||||
|
||||
blocks.append(block)
|
||||
self.blocks = nn.ModuleList(blocks)
|
||||
|
||||
self.final_layer = FinalLayer(
|
||||
hidden_size_x, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
def embed_x(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x: (B, C=12, T, H, W) tensor of visual tokens
|
||||
|
||||
Returns:
|
||||
x: (B, C=3072, N) tensor of visual tokens with positional embedding.
|
||||
"""
|
||||
return self.x_embedder(x) # Convert BcTHW to BCN
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
sigma: torch.Tensor,
|
||||
t5_feat: torch.Tensor,
|
||||
t5_mask: torch.Tensor,
|
||||
):
|
||||
"""Prepare input and conditioning embeddings."""
|
||||
# Visual patch embeddings with positional encoding.
|
||||
T, H, W = x.shape[-3:]
|
||||
pH, pW = H // self.patch_size, W // self.patch_size
|
||||
x = self.embed_x(x) # (B, N, D), where N = T * H * W / patch_size ** 2
|
||||
assert x.ndim == 3
|
||||
|
||||
pH, pW = H // self.patch_size, W // self.patch_size
|
||||
N = T * pH * pW
|
||||
assert x.size(1) == N
|
||||
pos = create_position_matrix(
|
||||
T, pH=pH, pW=pW, device=x.device, dtype=torch.float32
|
||||
) # (N, 3)
|
||||
rope_cos, rope_sin = compute_mixed_rotation(
|
||||
freqs=comfy.ops.cast_to(self.pos_frequencies, dtype=x.dtype, device=x.device), pos=pos
|
||||
) # Each are (N, num_heads, dim // 2)
|
||||
|
||||
c_t = self.t_embedder(1 - sigma, out_dtype=x.dtype) # (B, D)
|
||||
|
||||
t5_y_pool = self.t5_y_embedder(t5_feat, t5_mask) # (B, D)
|
||||
|
||||
c = c_t + t5_y_pool
|
||||
|
||||
y_feat = self.t5_yproj(t5_feat) # (B, L, t5_feat_dim) --> (B, L, D)
|
||||
|
||||
return x, c, y_feat, rope_cos, rope_sin
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
context: List[torch.Tensor],
|
||||
attention_mask: List[torch.Tensor],
|
||||
num_tokens=256,
|
||||
packed_indices: Dict[str, torch.Tensor] = None,
|
||||
rope_cos: torch.Tensor = None,
|
||||
rope_sin: torch.Tensor = None,
|
||||
control=None, transformer_options={}, **kwargs
|
||||
):
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
y_feat = context
|
||||
y_mask = attention_mask
|
||||
sigma = timestep
|
||||
"""Forward pass of DiT.
|
||||
|
||||
Args:
|
||||
x: (B, C, T, H, W) tensor of spatial inputs (images or latent representations of images)
|
||||
sigma: (B,) tensor of noise standard deviations
|
||||
y_feat: List((B, L, y_feat_dim) tensor of caption token features. For SDXL text encoders: L=77, y_feat_dim=2048)
|
||||
y_mask: List((B, L) boolean tensor indicating which tokens are not padding)
|
||||
packed_indices: Dict with keys for Flash Attention. Result of compute_packed_indices.
|
||||
"""
|
||||
B, _, T, H, W = x.shape
|
||||
|
||||
x, c, y_feat, rope_cos, rope_sin = self.prepare(
|
||||
x, sigma, y_feat, y_mask
|
||||
)
|
||||
del y_mask
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
for i, block in enumerate(self.blocks):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"], out["txt"] = block(
|
||||
args["img"],
|
||||
args["vec"],
|
||||
args["txt"],
|
||||
rope_cos=args["rope_cos"],
|
||||
rope_sin=args["rope_sin"],
|
||||
crop_y=args["num_tokens"]
|
||||
)
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens}, {"original_block": block_wrap})
|
||||
y_feat = out["txt"]
|
||||
x = out["img"]
|
||||
else:
|
||||
x, y_feat = block(
|
||||
x,
|
||||
c,
|
||||
y_feat,
|
||||
rope_cos=rope_cos,
|
||||
rope_sin=rope_sin,
|
||||
crop_y=num_tokens,
|
||||
) # (B, M, D), (B, L, D)
|
||||
del y_feat # Final layers don't use dense text features.
|
||||
|
||||
x = self.final_layer(x, c) # (B, M, patch_size ** 2 * out_channels)
|
||||
x = rearrange(
|
||||
x,
|
||||
"B (T hp wp) (p1 p2 c) -> B c T (hp p1) (wp p2)",
|
||||
T=T,
|
||||
hp=H // self.patch_size,
|
||||
wp=W // self.patch_size,
|
||||
p1=self.patch_size,
|
||||
p2=self.patch_size,
|
||||
c=self.out_channels,
|
||||
)
|
||||
|
||||
return -x
|
||||
164
comfy/ldm/genmo/joint_model/layers.py
Normal file
164
comfy/ldm/genmo/joint_model/layers.py
Normal file
@@ -0,0 +1,164 @@
|
||||
#original code from https://github.com/genmoai/models under apache 2.0 license
|
||||
#adapted to ComfyUI
|
||||
|
||||
import collections.abc
|
||||
import math
|
||||
from itertools import repeat
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
|
||||
# From PyTorch internals
|
||||
def _ntuple(n):
|
||||
def parse(x):
|
||||
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
|
||||
return tuple(x)
|
||||
return tuple(repeat(x, n))
|
||||
|
||||
return parse
|
||||
|
||||
|
||||
to_2tuple = _ntuple(2)
|
||||
|
||||
|
||||
class TimestepEmbedder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
frequency_embedding_size: int = 256,
|
||||
*,
|
||||
bias: bool = True,
|
||||
timestep_scale: Optional[float] = None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
operations.Linear(frequency_embedding_size, hidden_size, bias=bias, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Linear(hidden_size, hidden_size, bias=bias, dtype=dtype, device=device),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
self.timestep_scale = timestep_scale
|
||||
|
||||
@staticmethod
|
||||
def timestep_embedding(t, dim, max_period=10000):
|
||||
half = dim // 2
|
||||
freqs = torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
|
||||
freqs.mul_(-math.log(max_period) / half).exp_()
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat(
|
||||
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
|
||||
)
|
||||
return embedding
|
||||
|
||||
def forward(self, t, out_dtype):
|
||||
if self.timestep_scale is not None:
|
||||
t = t * self.timestep_scale
|
||||
t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype=out_dtype)
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
hidden_size: int,
|
||||
multiple_of: int,
|
||||
ffn_dim_multiplier: Optional[float],
|
||||
device: Optional[torch.device] = None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
# keep parameter count and computation constant compared to standard FFN
|
||||
hidden_size = int(2 * hidden_size / 3)
|
||||
# custom dim factor multiplier
|
||||
if ffn_dim_multiplier is not None:
|
||||
hidden_size = int(ffn_dim_multiplier * hidden_size)
|
||||
hidden_size = multiple_of * ((hidden_size + multiple_of - 1) // multiple_of)
|
||||
|
||||
self.hidden_dim = hidden_size
|
||||
self.w1 = operations.Linear(in_features, 2 * hidden_size, bias=False, device=device, dtype=dtype)
|
||||
self.w2 = operations.Linear(hidden_size, in_features, bias=False, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x, gate = self.w1(x).chunk(2, dim=-1)
|
||||
x = self.w2(F.silu(x) * gate)
|
||||
return x
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 16,
|
||||
in_chans: int = 3,
|
||||
embed_dim: int = 768,
|
||||
norm_layer: Optional[Callable] = None,
|
||||
flatten: bool = True,
|
||||
bias: bool = True,
|
||||
dynamic_img_pad: bool = False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = to_2tuple(patch_size)
|
||||
self.flatten = flatten
|
||||
self.dynamic_img_pad = dynamic_img_pad
|
||||
|
||||
self.proj = operations.Conv2d(
|
||||
in_chans,
|
||||
embed_dim,
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
bias=bias,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
assert norm_layer is None
|
||||
self.norm = (
|
||||
norm_layer(embed_dim, device=device) if norm_layer else nn.Identity()
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
B, _C, T, H, W = x.shape
|
||||
if not self.dynamic_img_pad:
|
||||
assert H % self.patch_size[0] == 0, f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})."
|
||||
assert W % self.patch_size[1] == 0, f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
|
||||
else:
|
||||
pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
|
||||
pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
|
||||
x = F.pad(x, (0, pad_w, 0, pad_h))
|
||||
|
||||
x = rearrange(x, "B C T H W -> (B T) C H W", B=B, T=T)
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size, padding_mode='circular')
|
||||
x = self.proj(x)
|
||||
|
||||
# Flatten temporal and spatial dimensions.
|
||||
if not self.flatten:
|
||||
raise NotImplementedError("Must flatten output.")
|
||||
x = rearrange(x, "(B T) C H W -> B (T H W) C", B=B, T=T)
|
||||
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = torch.nn.Parameter(torch.empty(hidden_size, device=device, dtype=dtype))
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
def forward(self, x):
|
||||
return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)
|
||||
88
comfy/ldm/genmo/joint_model/rope_mixed.py
Normal file
88
comfy/ldm/genmo/joint_model/rope_mixed.py
Normal file
@@ -0,0 +1,88 @@
|
||||
#original code from https://github.com/genmoai/models under apache 2.0 license
|
||||
|
||||
# import functools
|
||||
import math
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def centers(start: float, stop, num, dtype=None, device=None):
|
||||
"""linspace through bin centers.
|
||||
|
||||
Args:
|
||||
start (float): Start of the range.
|
||||
stop (float): End of the range.
|
||||
num (int): Number of points.
|
||||
dtype (torch.dtype): Data type of the points.
|
||||
device (torch.device): Device of the points.
|
||||
|
||||
Returns:
|
||||
centers (Tensor): Centers of the bins. Shape: (num,).
|
||||
"""
|
||||
edges = torch.linspace(start, stop, num + 1, dtype=dtype, device=device)
|
||||
return (edges[:-1] + edges[1:]) / 2
|
||||
|
||||
|
||||
# @functools.lru_cache(maxsize=1)
|
||||
def create_position_matrix(
|
||||
T: int,
|
||||
pH: int,
|
||||
pW: int,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
*,
|
||||
target_area: float = 36864,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
T: int - Temporal dimension
|
||||
pH: int - Height dimension after patchify
|
||||
pW: int - Width dimension after patchify
|
||||
|
||||
Returns:
|
||||
pos: [T * pH * pW, 3] - position matrix
|
||||
"""
|
||||
# Create 1D tensors for each dimension
|
||||
t = torch.arange(T, dtype=dtype)
|
||||
|
||||
# Positionally interpolate to area 36864.
|
||||
# (3072x3072 frame with 16x16 patches = 192x192 latents).
|
||||
# This automatically scales rope positions when the resolution changes.
|
||||
# We use a large target area so the model is more sensitive
|
||||
# to changes in the learned pos_frequencies matrix.
|
||||
scale = math.sqrt(target_area / (pW * pH))
|
||||
w = centers(-pW * scale / 2, pW * scale / 2, pW)
|
||||
h = centers(-pH * scale / 2, pH * scale / 2, pH)
|
||||
|
||||
# Use meshgrid to create 3D grids
|
||||
grid_t, grid_h, grid_w = torch.meshgrid(t, h, w, indexing="ij")
|
||||
|
||||
# Stack and reshape the grids.
|
||||
pos = torch.stack([grid_t, grid_h, grid_w], dim=-1) # [T, pH, pW, 3]
|
||||
pos = pos.view(-1, 3) # [T * pH * pW, 3]
|
||||
pos = pos.to(dtype=dtype, device=device)
|
||||
|
||||
return pos
|
||||
|
||||
|
||||
def compute_mixed_rotation(
|
||||
freqs: torch.Tensor,
|
||||
pos: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Project each 3-dim position into per-head, per-head-dim 1D frequencies.
|
||||
|
||||
Args:
|
||||
freqs: [3, num_heads, num_freqs] - learned rotation frequency (for t, row, col) for each head position
|
||||
pos: [N, 3] - position of each token
|
||||
num_heads: int
|
||||
|
||||
Returns:
|
||||
freqs_cos: [N, num_heads, num_freqs] - cosine components
|
||||
freqs_sin: [N, num_heads, num_freqs] - sine components
|
||||
"""
|
||||
assert freqs.ndim == 3
|
||||
freqs_sum = torch.einsum("Nd,dhf->Nhf", pos.to(freqs), freqs)
|
||||
freqs_cos = torch.cos(freqs_sum)
|
||||
freqs_sin = torch.sin(freqs_sum)
|
||||
return freqs_cos, freqs_sin
|
||||
34
comfy/ldm/genmo/joint_model/temporal_rope.py
Normal file
34
comfy/ldm/genmo/joint_model/temporal_rope.py
Normal file
@@ -0,0 +1,34 @@
|
||||
#original code from https://github.com/genmoai/models under apache 2.0 license
|
||||
|
||||
# Based on Llama3 Implementation.
|
||||
import torch
|
||||
|
||||
|
||||
def apply_rotary_emb_qk_real(
|
||||
xqk: torch.Tensor,
|
||||
freqs_cos: torch.Tensor,
|
||||
freqs_sin: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Apply rotary embeddings to input tensors using the given frequency tensor without complex numbers.
|
||||
|
||||
Args:
|
||||
xqk (torch.Tensor): Query and/or Key tensors to apply rotary embeddings. Shape: (B, S, *, num_heads, D)
|
||||
Can be either just query or just key, or both stacked along some batch or * dim.
|
||||
freqs_cos (torch.Tensor): Precomputed cosine frequency tensor.
|
||||
freqs_sin (torch.Tensor): Precomputed sine frequency tensor.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The input tensor with rotary embeddings applied.
|
||||
"""
|
||||
# Split the last dimension into even and odd parts
|
||||
xqk_even = xqk[..., 0::2]
|
||||
xqk_odd = xqk[..., 1::2]
|
||||
|
||||
# Apply rotation
|
||||
cos_part = (xqk_even * freqs_cos - xqk_odd * freqs_sin).type_as(xqk)
|
||||
sin_part = (xqk_even * freqs_sin + xqk_odd * freqs_cos).type_as(xqk)
|
||||
|
||||
# Interleave the results back into the original shape
|
||||
out = torch.stack([cos_part, sin_part], dim=-1).flatten(-2)
|
||||
return out
|
||||
102
comfy/ldm/genmo/joint_model/utils.py
Normal file
102
comfy/ldm/genmo/joint_model/utils.py
Normal file
@@ -0,0 +1,102 @@
|
||||
#original code from https://github.com/genmoai/models under apache 2.0 license
|
||||
#adapted to ComfyUI
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def modulate(x, shift, scale):
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
|
||||
|
||||
def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch.Tensor:
|
||||
"""
|
||||
Pool tokens in x using mask.
|
||||
|
||||
NOTE: We assume x does not require gradients.
|
||||
|
||||
Args:
|
||||
x: (B, L, D) tensor of tokens.
|
||||
mask: (B, L) boolean tensor indicating which tokens are not padding.
|
||||
|
||||
Returns:
|
||||
pooled: (B, D) tensor of pooled tokens.
|
||||
"""
|
||||
assert x.size(1) == mask.size(1) # Expected mask to have same length as tokens.
|
||||
assert x.size(0) == mask.size(0) # Expected mask to have same batch size as tokens.
|
||||
mask = mask[:, :, None].to(dtype=x.dtype)
|
||||
mask = mask / mask.sum(dim=1, keepdim=True).clamp(min=1)
|
||||
pooled = (x * mask).sum(dim=1, keepdim=keepdim)
|
||||
return pooled
|
||||
|
||||
|
||||
class AttentionPool(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int,
|
||||
num_heads: int,
|
||||
output_dim: int = None,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
spatial_dim (int): Number of tokens in sequence length.
|
||||
embed_dim (int): Dimensionality of input tokens.
|
||||
num_heads (int): Number of attention heads.
|
||||
output_dim (int): Dimensionality of output tokens. Defaults to embed_dim.
|
||||
"""
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.to_kv = operations.Linear(embed_dim, 2 * embed_dim, device=device, dtype=dtype)
|
||||
self.to_q = operations.Linear(embed_dim, embed_dim, device=device, dtype=dtype)
|
||||
self.to_out = operations.Linear(embed_dim, output_dim or embed_dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, mask):
|
||||
"""
|
||||
Args:
|
||||
x (torch.Tensor): (B, L, D) tensor of input tokens.
|
||||
mask (torch.Tensor): (B, L) boolean tensor indicating which tokens are not padding.
|
||||
|
||||
NOTE: We assume x does not require gradients.
|
||||
|
||||
Returns:
|
||||
x (torch.Tensor): (B, D) tensor of pooled tokens.
|
||||
"""
|
||||
D = x.size(2)
|
||||
|
||||
# Construct attention mask, shape: (B, 1, num_queries=1, num_keys=1+L).
|
||||
attn_mask = mask[:, None, None, :].bool() # (B, 1, 1, L).
|
||||
attn_mask = F.pad(attn_mask, (1, 0), value=True) # (B, 1, 1, 1+L).
|
||||
|
||||
# Average non-padding token features. These will be used as the query.
|
||||
x_pool = pool_tokens(x, mask, keepdim=True) # (B, 1, D)
|
||||
|
||||
# Concat pooled features to input sequence.
|
||||
x = torch.cat([x_pool, x], dim=1) # (B, L+1, D)
|
||||
|
||||
# Compute queries, keys, values. Only the mean token is used to create a query.
|
||||
kv = self.to_kv(x) # (B, L+1, 2 * D)
|
||||
q = self.to_q(x[:, 0]) # (B, D)
|
||||
|
||||
# Extract heads.
|
||||
head_dim = D // self.num_heads
|
||||
kv = kv.unflatten(2, (2, self.num_heads, head_dim)) # (B, 1+L, 2, H, head_dim)
|
||||
kv = kv.transpose(1, 3) # (B, H, 2, 1+L, head_dim)
|
||||
k, v = kv.unbind(2) # (B, H, 1+L, head_dim)
|
||||
q = q.unflatten(1, (self.num_heads, head_dim)) # (B, H, head_dim)
|
||||
q = q.unsqueeze(2) # (B, H, 1, head_dim)
|
||||
|
||||
# Compute attention.
|
||||
x = F.scaled_dot_product_attention(
|
||||
q, k, v, attn_mask=attn_mask, dropout_p=0.0
|
||||
) # (B, H, 1, head_dim)
|
||||
|
||||
# Concatenate heads and run output.
|
||||
x = x.squeeze(2).flatten(1, 2) # (B, D = H * head_dim)
|
||||
x = self.to_out(x)
|
||||
return x
|
||||
711
comfy/ldm/genmo/vae/model.py
Normal file
711
comfy/ldm/genmo/vae/model.py
Normal file
@@ -0,0 +1,711 @@
|
||||
#original code from https://github.com/genmoai/models under apache 2.0 license
|
||||
#adapted to ComfyUI
|
||||
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from functools import partial
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
# import mochi_preview.dit.joint_model.context_parallel as cp
|
||||
# from mochi_preview.vae.cp_conv import cp_pass_frames, gather_all_frames
|
||||
|
||||
|
||||
def cast_tuple(t, length=1):
|
||||
return t if isinstance(t, tuple) else ((t,) * length)
|
||||
|
||||
|
||||
class GroupNormSpatial(ops.GroupNorm):
|
||||
"""
|
||||
GroupNorm applied per-frame.
|
||||
"""
|
||||
|
||||
def forward(self, x: torch.Tensor, *, chunk_size: int = 8):
|
||||
B, C, T, H, W = x.shape
|
||||
x = rearrange(x, "B C T H W -> (B T) C H W")
|
||||
# Run group norm in chunks.
|
||||
output = torch.empty_like(x)
|
||||
for b in range(0, B * T, chunk_size):
|
||||
output[b : b + chunk_size] = super().forward(x[b : b + chunk_size])
|
||||
return rearrange(output, "(B T) C H W -> B C T H W", B=B, T=T)
|
||||
|
||||
class PConv3d(ops.Conv3d):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size: Union[int, Tuple[int, int, int]],
|
||||
stride: Union[int, Tuple[int, int, int]],
|
||||
causal: bool = True,
|
||||
context_parallel: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
self.causal = causal
|
||||
self.context_parallel = context_parallel
|
||||
kernel_size = cast_tuple(kernel_size, 3)
|
||||
stride = cast_tuple(stride, 3)
|
||||
height_pad = (kernel_size[1] - 1) // 2
|
||||
width_pad = (kernel_size[2] - 1) // 2
|
||||
|
||||
super().__init__(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
dilation=(1, 1, 1),
|
||||
padding=(0, height_pad, width_pad),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
# Compute padding amounts.
|
||||
context_size = self.kernel_size[0] - 1
|
||||
if self.causal:
|
||||
pad_front = context_size
|
||||
pad_back = 0
|
||||
else:
|
||||
pad_front = context_size // 2
|
||||
pad_back = context_size - pad_front
|
||||
|
||||
# Apply padding.
|
||||
assert self.padding_mode == "replicate" # DEBUG
|
||||
mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
|
||||
x = F.pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode)
|
||||
return super().forward(x)
|
||||
|
||||
|
||||
class Conv1x1(ops.Linear):
|
||||
"""*1x1 Conv implemented with a linear layer."""
|
||||
|
||||
def __init__(self, in_features: int, out_features: int, *args, **kwargs):
|
||||
super().__init__(in_features, out_features, *args, **kwargs)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
x: Input tensor. Shape: [B, C, *] or [B, *, C].
|
||||
|
||||
Returns:
|
||||
x: Output tensor. Shape: [B, C', *] or [B, *, C'].
|
||||
"""
|
||||
x = x.movedim(1, -1)
|
||||
x = super().forward(x)
|
||||
x = x.movedim(-1, 1)
|
||||
return x
|
||||
|
||||
|
||||
class DepthToSpaceTime(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
temporal_expansion: int,
|
||||
spatial_expansion: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.temporal_expansion = temporal_expansion
|
||||
self.spatial_expansion = spatial_expansion
|
||||
|
||||
# When printed, this module should show the temporal and spatial expansion factors.
|
||||
def extra_repr(self):
|
||||
return f"texp={self.temporal_expansion}, sexp={self.spatial_expansion}"
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
x: Input tensor. Shape: [B, C, T, H, W].
|
||||
|
||||
Returns:
|
||||
x: Rearranged tensor. Shape: [B, C/(st*s*s), T*st, H*s, W*s].
|
||||
"""
|
||||
x = rearrange(
|
||||
x,
|
||||
"B (C st sh sw) T H W -> B C (T st) (H sh) (W sw)",
|
||||
st=self.temporal_expansion,
|
||||
sh=self.spatial_expansion,
|
||||
sw=self.spatial_expansion,
|
||||
)
|
||||
|
||||
# cp_rank, _ = cp.get_cp_rank_size()
|
||||
if self.temporal_expansion > 1: # and cp_rank == 0:
|
||||
# Drop the first self.temporal_expansion - 1 frames.
|
||||
# This is because we always want the 3x3x3 conv filter to only apply
|
||||
# to the first frame, and the first frame doesn't need to be repeated.
|
||||
assert all(x.shape)
|
||||
x = x[:, :, self.temporal_expansion - 1 :]
|
||||
assert all(x.shape)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def norm_fn(
|
||||
in_channels: int,
|
||||
affine: bool = True,
|
||||
):
|
||||
return GroupNormSpatial(affine=affine, num_groups=32, num_channels=in_channels)
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
"""Residual block that preserves the spatial dimensions."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
*,
|
||||
affine: bool = True,
|
||||
attn_block: Optional[nn.Module] = None,
|
||||
causal: bool = True,
|
||||
prune_bottleneck: bool = False,
|
||||
padding_mode: str,
|
||||
bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
|
||||
assert causal
|
||||
self.stack = nn.Sequential(
|
||||
norm_fn(channels, affine=affine),
|
||||
nn.SiLU(inplace=True),
|
||||
PConv3d(
|
||||
in_channels=channels,
|
||||
out_channels=channels // 2 if prune_bottleneck else channels,
|
||||
kernel_size=(3, 3, 3),
|
||||
stride=(1, 1, 1),
|
||||
padding_mode=padding_mode,
|
||||
bias=bias,
|
||||
causal=causal,
|
||||
),
|
||||
norm_fn(channels, affine=affine),
|
||||
nn.SiLU(inplace=True),
|
||||
PConv3d(
|
||||
in_channels=channels // 2 if prune_bottleneck else channels,
|
||||
out_channels=channels,
|
||||
kernel_size=(3, 3, 3),
|
||||
stride=(1, 1, 1),
|
||||
padding_mode=padding_mode,
|
||||
bias=bias,
|
||||
causal=causal,
|
||||
),
|
||||
)
|
||||
|
||||
self.attn_block = attn_block if attn_block else nn.Identity()
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
x: Input tensor. Shape: [B, C, T, H, W].
|
||||
"""
|
||||
residual = x
|
||||
x = self.stack(x)
|
||||
x = x + residual
|
||||
del residual
|
||||
|
||||
return self.attn_block(x)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
head_dim: int = 32,
|
||||
qkv_bias: bool = False,
|
||||
out_bias: bool = True,
|
||||
qk_norm: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.head_dim = head_dim
|
||||
self.num_heads = dim // head_dim
|
||||
self.qk_norm = qk_norm
|
||||
|
||||
self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)
|
||||
self.out = nn.Linear(dim, dim, bias=out_bias)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Compute temporal self-attention.
|
||||
|
||||
Args:
|
||||
x: Input tensor. Shape: [B, C, T, H, W].
|
||||
chunk_size: Chunk size for large tensors.
|
||||
|
||||
Returns:
|
||||
x: Output tensor. Shape: [B, C, T, H, W].
|
||||
"""
|
||||
B, _, T, H, W = x.shape
|
||||
|
||||
if T == 1:
|
||||
# No attention for single frame.
|
||||
x = x.movedim(1, -1) # [B, C, T, H, W] -> [B, T, H, W, C]
|
||||
qkv = self.qkv(x)
|
||||
_, _, x = qkv.chunk(3, dim=-1) # Throw away queries and keys.
|
||||
x = self.out(x)
|
||||
return x.movedim(-1, 1) # [B, T, H, W, C] -> [B, C, T, H, W]
|
||||
|
||||
# 1D temporal attention.
|
||||
x = rearrange(x, "B C t h w -> (B h w) t C")
|
||||
qkv = self.qkv(x)
|
||||
|
||||
# Input: qkv with shape [B, t, 3 * num_heads * head_dim]
|
||||
# Output: x with shape [B, num_heads, t, head_dim]
|
||||
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, self.head_dim).transpose(1, 3).unbind(2)
|
||||
|
||||
if self.qk_norm:
|
||||
q = F.normalize(q, p=2, dim=-1)
|
||||
k = F.normalize(k, p=2, dim=-1)
|
||||
|
||||
x = optimized_attention(q, k, v, self.num_heads, skip_reshape=True)
|
||||
|
||||
assert x.size(0) == q.size(0)
|
||||
|
||||
x = self.out(x)
|
||||
x = rearrange(x, "(B h w) t C -> B C t h w", B=B, h=H, w=W)
|
||||
return x
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
**attn_kwargs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.norm = norm_fn(dim)
|
||||
self.attn = Attention(dim, **attn_kwargs)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x + self.attn(self.norm(x))
|
||||
|
||||
|
||||
class CausalUpsampleBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
num_res_blocks: int,
|
||||
*,
|
||||
temporal_expansion: int = 2,
|
||||
spatial_expansion: int = 2,
|
||||
**block_kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
blocks = []
|
||||
for _ in range(num_res_blocks):
|
||||
blocks.append(block_fn(in_channels, **block_kwargs))
|
||||
self.blocks = nn.Sequential(*blocks)
|
||||
|
||||
self.temporal_expansion = temporal_expansion
|
||||
self.spatial_expansion = spatial_expansion
|
||||
|
||||
# Change channels in the final convolution layer.
|
||||
self.proj = Conv1x1(
|
||||
in_channels,
|
||||
out_channels * temporal_expansion * (spatial_expansion**2),
|
||||
)
|
||||
|
||||
self.d2st = DepthToSpaceTime(
|
||||
temporal_expansion=temporal_expansion, spatial_expansion=spatial_expansion
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.blocks(x)
|
||||
x = self.proj(x)
|
||||
x = self.d2st(x)
|
||||
return x
|
||||
|
||||
|
||||
def block_fn(channels, *, affine: bool = True, has_attention: bool = False, **block_kwargs):
|
||||
attn_block = AttentionBlock(channels) if has_attention else None
|
||||
return ResBlock(channels, affine=affine, attn_block=attn_block, **block_kwargs)
|
||||
|
||||
|
||||
class DownsampleBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
num_res_blocks,
|
||||
*,
|
||||
temporal_reduction=2,
|
||||
spatial_reduction=2,
|
||||
**block_kwargs,
|
||||
):
|
||||
"""
|
||||
Downsample block for the VAE encoder.
|
||||
|
||||
Args:
|
||||
in_channels: Number of input channels.
|
||||
out_channels: Number of output channels.
|
||||
num_res_blocks: Number of residual blocks.
|
||||
temporal_reduction: Temporal reduction factor.
|
||||
spatial_reduction: Spatial reduction factor.
|
||||
"""
|
||||
super().__init__()
|
||||
layers = []
|
||||
|
||||
# Change the channel count in the strided convolution.
|
||||
# This lets the ResBlock have uniform channel count,
|
||||
# as in ConvNeXt.
|
||||
assert in_channels != out_channels
|
||||
layers.append(
|
||||
PConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=(temporal_reduction, spatial_reduction, spatial_reduction),
|
||||
stride=(temporal_reduction, spatial_reduction, spatial_reduction),
|
||||
# First layer in each block always uses replicate padding
|
||||
padding_mode="replicate",
|
||||
bias=block_kwargs["bias"],
|
||||
)
|
||||
)
|
||||
|
||||
for _ in range(num_res_blocks):
|
||||
layers.append(block_fn(out_channels, **block_kwargs))
|
||||
|
||||
self.layers = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
return self.layers(x)
|
||||
|
||||
|
||||
def add_fourier_features(inputs: torch.Tensor, start=6, stop=8, step=1):
|
||||
num_freqs = (stop - start) // step
|
||||
assert inputs.ndim == 5
|
||||
C = inputs.size(1)
|
||||
|
||||
# Create Base 2 Fourier features.
|
||||
freqs = torch.arange(start, stop, step, dtype=inputs.dtype, device=inputs.device)
|
||||
assert num_freqs == len(freqs)
|
||||
w = torch.pow(2.0, freqs) * (2 * torch.pi) # [num_freqs]
|
||||
C = inputs.shape[1]
|
||||
w = w.repeat(C)[None, :, None, None, None] # [1, C * num_freqs, 1, 1, 1]
|
||||
|
||||
# Interleaved repeat of input channels to match w.
|
||||
h = inputs.repeat_interleave(num_freqs, dim=1) # [B, C * num_freqs, T, H, W]
|
||||
# Scale channels by frequency.
|
||||
h = w * h
|
||||
|
||||
return torch.cat(
|
||||
[
|
||||
inputs,
|
||||
torch.sin(h),
|
||||
torch.cos(h),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
|
||||
class FourierFeatures(nn.Module):
|
||||
def __init__(self, start: int = 6, stop: int = 8, step: int = 1):
|
||||
super().__init__()
|
||||
self.start = start
|
||||
self.stop = stop
|
||||
self.step = step
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Add Fourier features to inputs.
|
||||
|
||||
Args:
|
||||
inputs: Input tensor. Shape: [B, C, T, H, W]
|
||||
|
||||
Returns:
|
||||
h: Output tensor. Shape: [B, (1 + 2 * num_freqs) * C, T, H, W]
|
||||
"""
|
||||
return add_fourier_features(inputs, self.start, self.stop, self.step)
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
out_channels: int = 3,
|
||||
latent_dim: int,
|
||||
base_channels: int,
|
||||
channel_multipliers: List[int],
|
||||
num_res_blocks: List[int],
|
||||
temporal_expansions: Optional[List[int]] = None,
|
||||
spatial_expansions: Optional[List[int]] = None,
|
||||
has_attention: List[bool],
|
||||
output_norm: bool = True,
|
||||
nonlinearity: str = "silu",
|
||||
output_nonlinearity: str = "silu",
|
||||
causal: bool = True,
|
||||
**block_kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.input_channels = latent_dim
|
||||
self.base_channels = base_channels
|
||||
self.channel_multipliers = channel_multipliers
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.output_nonlinearity = output_nonlinearity
|
||||
assert nonlinearity == "silu"
|
||||
assert causal
|
||||
|
||||
ch = [mult * base_channels for mult in channel_multipliers]
|
||||
self.num_up_blocks = len(ch) - 1
|
||||
assert len(num_res_blocks) == self.num_up_blocks + 2
|
||||
|
||||
blocks = []
|
||||
|
||||
first_block = [
|
||||
ops.Conv3d(latent_dim, ch[-1], kernel_size=(1, 1, 1))
|
||||
] # Input layer.
|
||||
# First set of blocks preserve channel count.
|
||||
for _ in range(num_res_blocks[-1]):
|
||||
first_block.append(
|
||||
block_fn(
|
||||
ch[-1],
|
||||
has_attention=has_attention[-1],
|
||||
causal=causal,
|
||||
**block_kwargs,
|
||||
)
|
||||
)
|
||||
blocks.append(nn.Sequential(*first_block))
|
||||
|
||||
assert len(temporal_expansions) == len(spatial_expansions) == self.num_up_blocks
|
||||
assert len(num_res_blocks) == len(has_attention) == self.num_up_blocks + 2
|
||||
|
||||
upsample_block_fn = CausalUpsampleBlock
|
||||
|
||||
for i in range(self.num_up_blocks):
|
||||
block = upsample_block_fn(
|
||||
ch[-i - 1],
|
||||
ch[-i - 2],
|
||||
num_res_blocks=num_res_blocks[-i - 2],
|
||||
has_attention=has_attention[-i - 2],
|
||||
temporal_expansion=temporal_expansions[-i - 1],
|
||||
spatial_expansion=spatial_expansions[-i - 1],
|
||||
causal=causal,
|
||||
**block_kwargs,
|
||||
)
|
||||
blocks.append(block)
|
||||
|
||||
assert not output_norm
|
||||
|
||||
# Last block. Preserve channel count.
|
||||
last_block = []
|
||||
for _ in range(num_res_blocks[0]):
|
||||
last_block.append(
|
||||
block_fn(
|
||||
ch[0], has_attention=has_attention[0], causal=causal, **block_kwargs
|
||||
)
|
||||
)
|
||||
blocks.append(nn.Sequential(*last_block))
|
||||
|
||||
self.blocks = nn.ModuleList(blocks)
|
||||
self.output_proj = Conv1x1(ch[0], out_channels)
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
x: Latent tensor. Shape: [B, input_channels, t, h, w]. Scaled [-1, 1].
|
||||
|
||||
Returns:
|
||||
x: Reconstructed video tensor. Shape: [B, C, T, H, W]. Scaled to [-1, 1].
|
||||
T + 1 = (t - 1) * 4.
|
||||
H = h * 16, W = w * 16.
|
||||
"""
|
||||
for block in self.blocks:
|
||||
x = block(x)
|
||||
|
||||
if self.output_nonlinearity == "silu":
|
||||
x = F.silu(x, inplace=not self.training)
|
||||
else:
|
||||
assert (
|
||||
not self.output_nonlinearity
|
||||
) # StyleGAN3 omits the to-RGB nonlinearity.
|
||||
|
||||
return self.output_proj(x).contiguous()
|
||||
|
||||
class LatentDistribution:
|
||||
def __init__(self, mean: torch.Tensor, logvar: torch.Tensor):
|
||||
"""Initialize latent distribution.
|
||||
|
||||
Args:
|
||||
mean: Mean of the distribution. Shape: [B, C, T, H, W].
|
||||
logvar: Logarithm of variance of the distribution. Shape: [B, C, T, H, W].
|
||||
"""
|
||||
assert mean.shape == logvar.shape
|
||||
self.mean = mean
|
||||
self.logvar = logvar
|
||||
|
||||
def sample(self, temperature=1.0, generator: torch.Generator = None, noise=None):
|
||||
if temperature == 0.0:
|
||||
return self.mean
|
||||
|
||||
if noise is None:
|
||||
noise = torch.randn(self.mean.shape, device=self.mean.device, dtype=self.mean.dtype, generator=generator)
|
||||
else:
|
||||
assert noise.device == self.mean.device
|
||||
noise = noise.to(self.mean.dtype)
|
||||
|
||||
if temperature != 1.0:
|
||||
raise NotImplementedError(f"Temperature {temperature} is not supported.")
|
||||
|
||||
# Just Gaussian sample with no scaling of variance.
|
||||
return noise * torch.exp(self.logvar * 0.5) + self.mean
|
||||
|
||||
def mode(self):
|
||||
return self.mean
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
in_channels: int,
|
||||
base_channels: int,
|
||||
channel_multipliers: List[int],
|
||||
num_res_blocks: List[int],
|
||||
latent_dim: int,
|
||||
temporal_reductions: List[int],
|
||||
spatial_reductions: List[int],
|
||||
prune_bottlenecks: List[bool],
|
||||
has_attentions: List[bool],
|
||||
affine: bool = True,
|
||||
bias: bool = True,
|
||||
input_is_conv_1x1: bool = False,
|
||||
padding_mode: str,
|
||||
):
|
||||
super().__init__()
|
||||
self.temporal_reductions = temporal_reductions
|
||||
self.spatial_reductions = spatial_reductions
|
||||
self.base_channels = base_channels
|
||||
self.channel_multipliers = channel_multipliers
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.latent_dim = latent_dim
|
||||
|
||||
self.fourier_features = FourierFeatures()
|
||||
ch = [mult * base_channels for mult in channel_multipliers]
|
||||
num_down_blocks = len(ch) - 1
|
||||
assert len(num_res_blocks) == num_down_blocks + 2
|
||||
|
||||
layers = (
|
||||
[ops.Conv3d(in_channels, ch[0], kernel_size=(1, 1, 1), bias=True)]
|
||||
if not input_is_conv_1x1
|
||||
else [Conv1x1(in_channels, ch[0])]
|
||||
)
|
||||
|
||||
assert len(prune_bottlenecks) == num_down_blocks + 2
|
||||
assert len(has_attentions) == num_down_blocks + 2
|
||||
block = partial(block_fn, padding_mode=padding_mode, affine=affine, bias=bias)
|
||||
|
||||
for _ in range(num_res_blocks[0]):
|
||||
layers.append(block(ch[0], has_attention=has_attentions[0], prune_bottleneck=prune_bottlenecks[0]))
|
||||
prune_bottlenecks = prune_bottlenecks[1:]
|
||||
has_attentions = has_attentions[1:]
|
||||
|
||||
assert len(temporal_reductions) == len(spatial_reductions) == len(ch) - 1
|
||||
for i in range(num_down_blocks):
|
||||
layer = DownsampleBlock(
|
||||
ch[i],
|
||||
ch[i + 1],
|
||||
num_res_blocks=num_res_blocks[i + 1],
|
||||
temporal_reduction=temporal_reductions[i],
|
||||
spatial_reduction=spatial_reductions[i],
|
||||
prune_bottleneck=prune_bottlenecks[i],
|
||||
has_attention=has_attentions[i],
|
||||
affine=affine,
|
||||
bias=bias,
|
||||
padding_mode=padding_mode,
|
||||
)
|
||||
|
||||
layers.append(layer)
|
||||
|
||||
# Additional blocks.
|
||||
for _ in range(num_res_blocks[-1]):
|
||||
layers.append(block(ch[-1], has_attention=has_attentions[-1], prune_bottleneck=prune_bottlenecks[-1]))
|
||||
|
||||
self.layers = nn.Sequential(*layers)
|
||||
|
||||
# Output layers.
|
||||
self.output_norm = norm_fn(ch[-1])
|
||||
self.output_proj = Conv1x1(ch[-1], 2 * latent_dim, bias=False)
|
||||
|
||||
@property
|
||||
def temporal_downsample(self):
|
||||
return math.prod(self.temporal_reductions)
|
||||
|
||||
@property
|
||||
def spatial_downsample(self):
|
||||
return math.prod(self.spatial_reductions)
|
||||
|
||||
def forward(self, x) -> LatentDistribution:
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
x: Input video tensor. Shape: [B, C, T, H, W]. Scaled to [-1, 1]
|
||||
|
||||
Returns:
|
||||
means: Latent tensor. Shape: [B, latent_dim, t, h, w]. Scaled [-1, 1].
|
||||
h = H // 8, w = W // 8, t - 1 = (T - 1) // 6
|
||||
logvar: Shape: [B, latent_dim, t, h, w].
|
||||
"""
|
||||
assert x.ndim == 5, f"Expected 5D input, got {x.shape}"
|
||||
x = self.fourier_features(x)
|
||||
|
||||
x = self.layers(x)
|
||||
|
||||
x = self.output_norm(x)
|
||||
x = F.silu(x, inplace=True)
|
||||
x = self.output_proj(x)
|
||||
|
||||
means, logvar = torch.chunk(x, 2, dim=1)
|
||||
|
||||
assert means.ndim == 5
|
||||
assert logvar.shape == means.shape
|
||||
assert means.size(1) == self.latent_dim
|
||||
|
||||
return LatentDistribution(means, logvar)
|
||||
|
||||
|
||||
class VideoVAE(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.encoder = Encoder(
|
||||
in_channels=15,
|
||||
base_channels=64,
|
||||
channel_multipliers=[1, 2, 4, 6],
|
||||
num_res_blocks=[3, 3, 4, 6, 3],
|
||||
latent_dim=12,
|
||||
temporal_reductions=[1, 2, 3],
|
||||
spatial_reductions=[2, 2, 2],
|
||||
prune_bottlenecks=[False, False, False, False, False],
|
||||
has_attentions=[False, True, True, True, True],
|
||||
affine=True,
|
||||
bias=True,
|
||||
input_is_conv_1x1=True,
|
||||
padding_mode="replicate"
|
||||
)
|
||||
self.decoder = Decoder(
|
||||
out_channels=3,
|
||||
base_channels=128,
|
||||
channel_multipliers=[1, 2, 4, 6],
|
||||
temporal_expansions=[1, 2, 3],
|
||||
spatial_expansions=[2, 2, 2],
|
||||
num_res_blocks=[3, 3, 4, 6, 3],
|
||||
latent_dim=12,
|
||||
has_attention=[False, False, False, False, False],
|
||||
padding_mode="replicate",
|
||||
output_norm=False,
|
||||
nonlinearity="silu",
|
||||
output_nonlinearity="silu",
|
||||
causal=True,
|
||||
)
|
||||
|
||||
def encode(self, x):
|
||||
return self.encoder(x).mode()
|
||||
|
||||
def decode(self, x):
|
||||
return self.decoder(x)
|
||||
330
comfy/ldm/hunyuan_video/model.py
Normal file
330
comfy/ldm/hunyuan_video/model.py
Normal file
@@ -0,0 +1,330 @@
|
||||
#Based on Flux code because of weird hunyuan video code license.
|
||||
|
||||
import torch
|
||||
import comfy.ldm.flux.layers
|
||||
import comfy.ldm.modules.diffusionmodules.mmdit
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
|
||||
from dataclasses import dataclass
|
||||
from einops import repeat
|
||||
|
||||
from torch import Tensor, nn
|
||||
|
||||
from comfy.ldm.flux.layers import (
|
||||
DoubleStreamBlock,
|
||||
EmbedND,
|
||||
LastLayer,
|
||||
MLPEmbedder,
|
||||
SingleStreamBlock,
|
||||
timestep_embedding
|
||||
)
|
||||
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
|
||||
@dataclass
|
||||
class HunyuanVideoParams:
|
||||
in_channels: int
|
||||
out_channels: int
|
||||
vec_in_dim: int
|
||||
context_in_dim: int
|
||||
hidden_size: int
|
||||
mlp_ratio: float
|
||||
num_heads: int
|
||||
depth: int
|
||||
depth_single_blocks: int
|
||||
axes_dim: list
|
||||
theta: int
|
||||
patch_size: list
|
||||
qkv_bias: bool
|
||||
guidance_embed: bool
|
||||
|
||||
|
||||
class SelfAttentionRef(nn.Module):
|
||||
def __init__(self, dim: int, qkv_bias: bool = False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
|
||||
self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
|
||||
|
||||
|
||||
class TokenRefinerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
heads,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
mlp_hidden_dim = hidden_size * 4
|
||||
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
self.norm1 = operations.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)
|
||||
self.self_attn = SelfAttentionRef(hidden_size, True, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.norm2 = operations.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)
|
||||
|
||||
self.mlp = nn.Sequential(
|
||||
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
def forward(self, x, c, mask):
|
||||
mod1, mod2 = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||
|
||||
norm_x = self.norm1(x)
|
||||
qkv = self.self_attn.qkv(norm_x)
|
||||
q, k, v = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, self.heads, -1).permute(2, 0, 3, 1, 4)
|
||||
attn = optimized_attention(q, k, v, self.heads, mask=mask, skip_reshape=True)
|
||||
|
||||
x = x + self.self_attn.proj(attn) * mod1.unsqueeze(1)
|
||||
x = x + self.mlp(self.norm2(x)) * mod2.unsqueeze(1)
|
||||
return x
|
||||
|
||||
|
||||
class IndividualTokenRefiner(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
heads,
|
||||
num_blocks,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
TokenRefinerBlock(
|
||||
hidden_size=hidden_size,
|
||||
heads=heads,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
)
|
||||
for _ in range(num_blocks)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x, c, mask):
|
||||
m = None
|
||||
if mask is not None:
|
||||
m = mask.view(mask.shape[0], 1, 1, mask.shape[1]).repeat(1, 1, mask.shape[1], 1)
|
||||
m = m + m.transpose(2, 3)
|
||||
|
||||
for block in self.blocks:
|
||||
x = block(x, c, m)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class TokenRefiner(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
text_dim,
|
||||
hidden_size,
|
||||
heads,
|
||||
num_blocks,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.input_embedder = operations.Linear(text_dim, hidden_size, bias=True, dtype=dtype, device=device)
|
||||
self.t_embedder = MLPEmbedder(256, hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.c_embedder = MLPEmbedder(text_dim, hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.individual_token_refiner = IndividualTokenRefiner(hidden_size, heads, num_blocks, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
timesteps,
|
||||
mask,
|
||||
):
|
||||
t = self.t_embedder(timestep_embedding(timesteps, 256, time_factor=1.0).to(x.dtype))
|
||||
# m = mask.float().unsqueeze(-1)
|
||||
# c = (x.float() * m).sum(dim=1) / m.sum(dim=1) #TODO: the following works when the x.shape is the same length as the tokens but might break otherwise
|
||||
c = x.sum(dim=1) / x.shape[1]
|
||||
|
||||
c = t + self.c_embedder(c.to(x.dtype))
|
||||
x = self.input_embedder(x)
|
||||
x = self.individual_token_refiner(x, c, mask)
|
||||
return x
|
||||
|
||||
class HunyuanVideo(nn.Module):
|
||||
"""
|
||||
Transformer model for flow matching on sequences.
|
||||
"""
|
||||
|
||||
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
params = HunyuanVideoParams(**kwargs)
|
||||
self.params = params
|
||||
self.patch_size = params.patch_size
|
||||
self.in_channels = params.in_channels
|
||||
self.out_channels = params.out_channels
|
||||
if params.hidden_size % params.num_heads != 0:
|
||||
raise ValueError(
|
||||
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
|
||||
)
|
||||
pe_dim = params.hidden_size // params.num_heads
|
||||
if sum(params.axes_dim) != pe_dim:
|
||||
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
|
||||
self.hidden_size = params.hidden_size
|
||||
self.num_heads = params.num_heads
|
||||
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
|
||||
|
||||
self.img_in = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(None, self.patch_size, self.in_channels, self.hidden_size, conv3d=True, dtype=dtype, device=device, operations=operations)
|
||||
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.guidance_in = (
|
||||
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
|
||||
)
|
||||
|
||||
self.txt_in = TokenRefiner(params.context_in_dim, self.hidden_size, self.num_heads, 2, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.double_blocks = nn.ModuleList(
|
||||
[
|
||||
DoubleStreamBlock(
|
||||
self.hidden_size,
|
||||
self.num_heads,
|
||||
mlp_ratio=params.mlp_ratio,
|
||||
qkv_bias=params.qkv_bias,
|
||||
flipped_img_txt=True,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
for _ in range(params.depth)
|
||||
]
|
||||
)
|
||||
|
||||
self.single_blocks = nn.ModuleList(
|
||||
[
|
||||
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
|
||||
for _ in range(params.depth_single_blocks)
|
||||
]
|
||||
)
|
||||
|
||||
if final_layer:
|
||||
self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward_orig(
|
||||
self,
|
||||
img: Tensor,
|
||||
img_ids: Tensor,
|
||||
txt: Tensor,
|
||||
txt_ids: Tensor,
|
||||
txt_mask: Tensor,
|
||||
timesteps: Tensor,
|
||||
y: Tensor,
|
||||
guidance: Tensor = None,
|
||||
control=None,
|
||||
transformer_options={},
|
||||
) -> Tensor:
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
|
||||
initial_shape = list(img.shape)
|
||||
# running on sequences img
|
||||
img = self.img_in(img)
|
||||
vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))
|
||||
|
||||
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
|
||||
|
||||
if self.params.guidance_embed:
|
||||
if guidance is None:
|
||||
raise ValueError("Didn't get guidance strength for guidance distilled model.")
|
||||
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
|
||||
|
||||
if txt_mask is not None and not torch.is_floating_point(txt_mask):
|
||||
txt_mask = (txt_mask - 1).to(img.dtype) * torch.finfo(img.dtype).max
|
||||
|
||||
txt = self.txt_in(txt, timesteps, txt_mask)
|
||||
|
||||
ids = torch.cat((img_ids, txt_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
|
||||
img_len = img.shape[1]
|
||||
if txt_mask is not None:
|
||||
attn_mask_len = img_len + txt.shape[1]
|
||||
attn_mask = torch.zeros((1, 1, attn_mask_len), dtype=img.dtype, device=img.device)
|
||||
attn_mask[:, 0, img_len:] = txt_mask
|
||||
else:
|
||||
attn_mask = None
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
for i, block in enumerate(self.double_blocks):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap})
|
||||
txt = out["txt"]
|
||||
img = out["img"]
|
||||
else:
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_i = control.get("input")
|
||||
if i < len(control_i):
|
||||
add = control_i[i]
|
||||
if add is not None:
|
||||
img += add
|
||||
|
||||
img = torch.cat((img, txt), 1)
|
||||
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
if ("single_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap})
|
||||
img = out["img"]
|
||||
else:
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_o = control.get("output")
|
||||
if i < len(control_o):
|
||||
add = control_o[i]
|
||||
if add is not None:
|
||||
img[:, : img_len] += add
|
||||
|
||||
img = img[:, : img_len]
|
||||
|
||||
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
||||
|
||||
shape = initial_shape[-3:]
|
||||
for i in range(len(shape)):
|
||||
shape[i] = shape[i] // self.patch_size[i]
|
||||
img = img.reshape([img.shape[0]] + shape + [self.out_channels] + self.patch_size)
|
||||
img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
|
||||
img = img.reshape(initial_shape)
|
||||
return img
|
||||
|
||||
def forward(self, x, timestep, context, y, guidance, attention_mask=None, control=None, transformer_options={}, **kwargs):
|
||||
bs, c, t, h, w = x.shape
|
||||
patch_size = self.patch_size
|
||||
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
|
||||
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
|
||||
w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
|
||||
img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
|
||||
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
|
||||
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
|
||||
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, control, transformer_options)
|
||||
return out
|
||||
@@ -1,24 +1,17 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from torch.utils import checkpoint
|
||||
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import (
|
||||
Mlp,
|
||||
TimestepEmbedder,
|
||||
PatchEmbed,
|
||||
RMSNorm,
|
||||
)
|
||||
from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
|
||||
from .poolers import AttentionPool
|
||||
|
||||
import comfy.latent_formats
|
||||
from .models import HunYuanDiTBlock, calc_rope
|
||||
|
||||
from .posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop
|
||||
|
||||
|
||||
class HunYuanControlNet(nn.Module):
|
||||
@@ -171,9 +164,6 @@ class HunYuanControlNet(nn.Module):
|
||||
),
|
||||
)
|
||||
|
||||
# Image embedding
|
||||
num_patches = self.x_embedder.num_patches
|
||||
|
||||
# HUnYuanDiT Blocks
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import comfy.ops
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed, RMSNorm
|
||||
@@ -250,9 +248,6 @@ class HunYuanDiT(nn.Module):
|
||||
operations.Linear(hidden_size * 4, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
# Image embedding
|
||||
num_patches = self.x_embedder.num_patches
|
||||
|
||||
# HUnYuanDiT Blocks
|
||||
self.blocks = nn.ModuleList([
|
||||
HunYuanDiTBlock(hidden_size=hidden_size,
|
||||
@@ -287,7 +282,7 @@ class HunYuanDiT(nn.Module):
|
||||
style=None,
|
||||
return_dict=False,
|
||||
control=None,
|
||||
transformer_options=None,
|
||||
transformer_options={},
|
||||
):
|
||||
"""
|
||||
Forward pass of the encoder.
|
||||
@@ -315,8 +310,7 @@ class HunYuanDiT(nn.Module):
|
||||
return_dict: bool
|
||||
Whether to return a dictionary.
|
||||
"""
|
||||
#import pdb
|
||||
#pdb.set_trace()
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
encoder_hidden_states = context
|
||||
text_states = encoder_hidden_states # 2,77,1024
|
||||
text_states_t5 = encoder_hidden_states_t5 # 2,256,2048
|
||||
@@ -364,6 +358,8 @@ class HunYuanDiT(nn.Module):
|
||||
# Concatenate all extra vectors
|
||||
c = t + self.extra_embedder(extra_vec) # [B, D]
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
|
||||
controls = None
|
||||
if control:
|
||||
controls = control.get("output", None)
|
||||
@@ -372,12 +368,23 @@ class HunYuanDiT(nn.Module):
|
||||
for layer, block in enumerate(self.blocks):
|
||||
if layer > self.depth // 2:
|
||||
if controls is not None:
|
||||
skip = skips.pop() + controls.pop()
|
||||
skip = skips.pop() + controls.pop().to(dtype=x.dtype)
|
||||
else:
|
||||
skip = skips.pop()
|
||||
x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)
|
||||
else:
|
||||
x = block(x, c, text_states, freqs_cis_img) # (N, L, D)
|
||||
skip = None
|
||||
|
||||
if ("double_block", layer) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], args["vec"], args["txt"], args["pe"], args["skip"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", layer)]({"img": x, "txt": text_states, "vec": c, "pe": freqs_cis_img, "skip": skip}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)
|
||||
|
||||
|
||||
if layer < (self.depth // 2 - 1):
|
||||
skips.append(x)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
import comfy.ops
|
||||
|
||||
|
||||
527
comfy/ldm/lightricks/model.py
Normal file
527
comfy/ldm/lightricks/model.py
Normal file
@@ -0,0 +1,527 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import comfy.ldm.modules.attention
|
||||
from comfy.ldm.genmo.joint_model.layers import RMSNorm
|
||||
import comfy.ldm.common_dit
|
||||
from einops import rearrange
|
||||
import math
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
from .symmetric_patchifier import SymmetricPatchifier
|
||||
|
||||
|
||||
def get_timestep_embedding(
|
||||
timesteps: torch.Tensor,
|
||||
embedding_dim: int,
|
||||
flip_sin_to_cos: bool = False,
|
||||
downscale_freq_shift: float = 1,
|
||||
scale: float = 1,
|
||||
max_period: int = 10000,
|
||||
):
|
||||
"""
|
||||
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
|
||||
|
||||
Args
|
||||
timesteps (torch.Tensor):
|
||||
a 1-D Tensor of N indices, one per batch element. These may be fractional.
|
||||
embedding_dim (int):
|
||||
the dimension of the output.
|
||||
flip_sin_to_cos (bool):
|
||||
Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
|
||||
downscale_freq_shift (float):
|
||||
Controls the delta between frequencies between dimensions
|
||||
scale (float):
|
||||
Scaling factor applied to the embeddings.
|
||||
max_period (int):
|
||||
Controls the maximum frequency of the embeddings
|
||||
Returns
|
||||
torch.Tensor: an [N x dim] Tensor of positional embeddings.
|
||||
"""
|
||||
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
|
||||
|
||||
half_dim = embedding_dim // 2
|
||||
exponent = -math.log(max_period) * torch.arange(
|
||||
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
|
||||
)
|
||||
exponent = exponent / (half_dim - downscale_freq_shift)
|
||||
|
||||
emb = torch.exp(exponent)
|
||||
emb = timesteps[:, None].float() * emb[None, :]
|
||||
|
||||
# scale embeddings
|
||||
emb = scale * emb
|
||||
|
||||
# concat sine and cosine embeddings
|
||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
|
||||
|
||||
# flip sine and cosine embeddings
|
||||
if flip_sin_to_cos:
|
||||
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
|
||||
|
||||
# zero pad
|
||||
if embedding_dim % 2 == 1:
|
||||
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
||||
return emb
|
||||
|
||||
|
||||
class TimestepEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
time_embed_dim: int,
|
||||
act_fn: str = "silu",
|
||||
out_dim: int = None,
|
||||
post_act_fn: Optional[str] = None,
|
||||
cond_proj_dim=None,
|
||||
sample_proj_bias=True,
|
||||
dtype=None, device=None, operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.linear_1 = operations.Linear(in_channels, time_embed_dim, sample_proj_bias, dtype=dtype, device=device)
|
||||
|
||||
if cond_proj_dim is not None:
|
||||
self.cond_proj = operations.Linear(cond_proj_dim, in_channels, bias=False, dtype=dtype, device=device)
|
||||
else:
|
||||
self.cond_proj = None
|
||||
|
||||
self.act = nn.SiLU()
|
||||
|
||||
if out_dim is not None:
|
||||
time_embed_dim_out = out_dim
|
||||
else:
|
||||
time_embed_dim_out = time_embed_dim
|
||||
self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias, dtype=dtype, device=device)
|
||||
|
||||
if post_act_fn is None:
|
||||
self.post_act = None
|
||||
# else:
|
||||
# self.post_act = get_activation(post_act_fn)
|
||||
|
||||
def forward(self, sample, condition=None):
|
||||
if condition is not None:
|
||||
sample = sample + self.cond_proj(condition)
|
||||
sample = self.linear_1(sample)
|
||||
|
||||
if self.act is not None:
|
||||
sample = self.act(sample)
|
||||
|
||||
sample = self.linear_2(sample)
|
||||
|
||||
if self.post_act is not None:
|
||||
sample = self.post_act(sample)
|
||||
return sample
|
||||
|
||||
|
||||
class Timesteps(nn.Module):
|
||||
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
|
||||
super().__init__()
|
||||
self.num_channels = num_channels
|
||||
self.flip_sin_to_cos = flip_sin_to_cos
|
||||
self.downscale_freq_shift = downscale_freq_shift
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, timesteps):
|
||||
t_emb = get_timestep_embedding(
|
||||
timesteps,
|
||||
self.num_channels,
|
||||
flip_sin_to_cos=self.flip_sin_to_cos,
|
||||
downscale_freq_shift=self.downscale_freq_shift,
|
||||
scale=self.scale,
|
||||
)
|
||||
return t_emb
|
||||
|
||||
|
||||
class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
|
||||
"""
|
||||
For PixArt-Alpha.
|
||||
|
||||
Reference:
|
||||
https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
self.outdim = size_emb_dim
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
|
||||
return timesteps_emb
|
||||
|
||||
|
||||
class AdaLayerNormSingle(nn.Module):
|
||||
r"""
|
||||
Norm layer adaptive layer norm single (adaLN-single).
|
||||
|
||||
As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
|
||||
|
||||
Parameters:
|
||||
embedding_dim (`int`): The size of each embedding vector.
|
||||
use_additional_conditions (`bool`): To use additional conditions for normalization or not.
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_dim: int, use_additional_conditions: bool = False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
|
||||
embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = operations.Linear(embedding_dim, 6 * embedding_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
timestep: torch.Tensor,
|
||||
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
hidden_dtype: Optional[torch.dtype] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
# No modulation happening here.
|
||||
added_cond_kwargs = added_cond_kwargs or {"resolution": None, "aspect_ratio": None}
|
||||
embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
|
||||
return self.linear(self.silu(embedded_timestep)), embedded_timestep
|
||||
|
||||
class PixArtAlphaTextProjection(nn.Module):
|
||||
"""
|
||||
Projects caption embeddings. Also handles dropout for classifier-free guidance.
|
||||
|
||||
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
|
||||
"""
|
||||
|
||||
def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
if out_features is None:
|
||||
out_features = hidden_size
|
||||
self.linear_1 = operations.Linear(in_features=in_features, out_features=hidden_size, bias=True, dtype=dtype, device=device)
|
||||
if act_fn == "gelu_tanh":
|
||||
self.act_1 = nn.GELU(approximate="tanh")
|
||||
elif act_fn == "silu":
|
||||
self.act_1 = nn.SiLU()
|
||||
else:
|
||||
raise ValueError(f"Unknown activation function: {act_fn}")
|
||||
self.linear_2 = operations.Linear(in_features=hidden_size, out_features=out_features, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, caption):
|
||||
hidden_states = self.linear_1(caption)
|
||||
hidden_states = self.act_1(hidden_states)
|
||||
hidden_states = self.linear_2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class GELU_approx(nn.Module):
|
||||
def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.proj = operations.Linear(dim_in, dim_out, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
return torch.nn.functional.gelu(self.proj(x), approximate="tanh")
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, dim_out, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
project_in = GELU_approx(dim, inner_dim, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.net = nn.Sequential(
|
||||
project_in,
|
||||
nn.Dropout(dropout),
|
||||
operations.Linear(inner_dim, dim_out, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
def apply_rotary_emb(input_tensor, freqs_cis): #TODO: remove duplicate funcs and pick the best/fastest one
|
||||
cos_freqs = freqs_cis[0]
|
||||
sin_freqs = freqs_cis[1]
|
||||
|
||||
t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2)
|
||||
t1, t2 = t_dup.unbind(dim=-1)
|
||||
t_dup = torch.stack((-t2, t1), dim=-1)
|
||||
input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)")
|
||||
|
||||
out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = query_dim if context_dim is None else context_dim
|
||||
self.attn_precision = attn_precision
|
||||
|
||||
self.heads = heads
|
||||
self.dim_head = dim_head
|
||||
|
||||
self.q_norm = RMSNorm(inner_dim, dtype=dtype, device=device)
|
||||
self.k_norm = RMSNorm(inner_dim, dtype=dtype, device=device)
|
||||
|
||||
self.to_q = operations.Linear(query_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
||||
self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
||||
self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
|
||||
|
||||
def forward(self, x, context=None, mask=None, pe=None):
|
||||
q = self.to_q(x)
|
||||
context = x if context is None else context
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
|
||||
if pe is not None:
|
||||
q = apply_rotary_emb(q, pe)
|
||||
k = apply_rotary_emb(k, pe)
|
||||
|
||||
if mask is None:
|
||||
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
|
||||
else:
|
||||
out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
def __init__(self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
self.attn_precision = attn_precision
|
||||
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, context_dim=None, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations)
|
||||
self.ff = FeedForward(dim, dim_out=dim, glu=True, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
|
||||
|
||||
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None):
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
|
||||
|
||||
x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe) * gate_msa
|
||||
|
||||
x += self.attn2(x, context=context, mask=attention_mask)
|
||||
|
||||
y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp
|
||||
x += self.ff(y) * gate_mlp
|
||||
|
||||
return x
|
||||
|
||||
def get_fractional_positions(indices_grid, max_pos):
|
||||
fractional_positions = torch.stack(
|
||||
[
|
||||
indices_grid[:, i] / max_pos[i]
|
||||
for i in range(3)
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
return fractional_positions
|
||||
|
||||
|
||||
def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]):
|
||||
dtype = torch.float32 #self.dtype
|
||||
|
||||
fractional_positions = get_fractional_positions(indices_grid, max_pos)
|
||||
|
||||
start = 1
|
||||
end = theta
|
||||
device = fractional_positions.device
|
||||
|
||||
indices = theta ** (
|
||||
torch.linspace(
|
||||
math.log(start, theta),
|
||||
math.log(end, theta),
|
||||
dim // 6,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
)
|
||||
indices = indices.to(dtype=dtype)
|
||||
|
||||
indices = indices * math.pi / 2
|
||||
|
||||
freqs = (
|
||||
(indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
|
||||
.transpose(-1, -2)
|
||||
.flatten(2)
|
||||
)
|
||||
|
||||
cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
|
||||
sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
|
||||
if dim % 6 != 0:
|
||||
cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6])
|
||||
sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6])
|
||||
cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
|
||||
sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
|
||||
return cos_freq.to(out_dtype), sin_freq.to(out_dtype)
|
||||
|
||||
|
||||
class LTXVModel(torch.nn.Module):
|
||||
def __init__(self,
|
||||
in_channels=128,
|
||||
cross_attention_dim=2048,
|
||||
attention_head_dim=64,
|
||||
num_attention_heads=32,
|
||||
|
||||
caption_channels=4096,
|
||||
num_layers=28,
|
||||
|
||||
|
||||
positional_embedding_theta=10000.0,
|
||||
positional_embedding_max_pos=[20, 2048, 2048],
|
||||
dtype=None, device=None, operations=None, **kwargs):
|
||||
super().__init__()
|
||||
self.generator = None
|
||||
self.dtype = dtype
|
||||
self.out_channels = in_channels
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
self.patchify_proj = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
self.adaln_single = AdaLayerNormSingle(
|
||||
self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
# self.adaln_single.linear = operations.Linear(self.inner_dim, 4 * self.inner_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
self.caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=caption_channels, hidden_size=self.inner_dim, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
self.inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
context_dim=cross_attention_dim,
|
||||
# attn_precision=attn_precision,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
for d in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.scale_shift_table = nn.Parameter(torch.empty(2, self.inner_dim, dtype=dtype, device=device))
|
||||
self.norm_out = operations.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.proj_out = operations.Linear(self.inner_dim, self.out_channels, dtype=dtype, device=device)
|
||||
|
||||
self.patchifier = SymmetricPatchifier(1)
|
||||
|
||||
def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, guiding_latent_noise_scale=0, transformer_options={}, **kwargs):
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
|
||||
indices_grid = self.patchifier.get_grid(
|
||||
orig_num_frames=x.shape[2],
|
||||
orig_height=x.shape[3],
|
||||
orig_width=x.shape[4],
|
||||
batch_size=x.shape[0],
|
||||
scale_grid=((1 / frame_rate) * 8, 32, 32),
|
||||
device=x.device,
|
||||
)
|
||||
|
||||
if guiding_latent is not None:
|
||||
ts = torch.ones([x.shape[0], 1, x.shape[2], x.shape[3], x.shape[4]], device=x.device, dtype=x.dtype)
|
||||
input_ts = timestep.view([timestep.shape[0]] + [1] * (x.ndim - 1))
|
||||
ts *= input_ts
|
||||
ts[:, :, 0] = guiding_latent_noise_scale * (input_ts[:, :, 0] ** 2)
|
||||
timestep = self.patchifier.patchify(ts)
|
||||
input_x = x.clone()
|
||||
x[:, :, 0] = guiding_latent[:, :, 0]
|
||||
if guiding_latent_noise_scale > 0:
|
||||
if self.generator is None:
|
||||
self.generator = torch.Generator(device=x.device).manual_seed(42)
|
||||
elif self.generator.device != x.device:
|
||||
self.generator = torch.Generator(device=x.device).set_state(self.generator.get_state())
|
||||
|
||||
noise_shape = [guiding_latent.shape[0], guiding_latent.shape[1], 1, guiding_latent.shape[3], guiding_latent.shape[4]]
|
||||
scale = guiding_latent_noise_scale * (input_ts ** 2)
|
||||
guiding_noise = scale * torch.randn(size=noise_shape, device=x.device, generator=self.generator)
|
||||
|
||||
x[:, :, 0] = guiding_noise[:, :, 0] + x[:, :, 0] * (1.0 - scale[:, :, 0])
|
||||
|
||||
|
||||
orig_shape = list(x.shape)
|
||||
|
||||
x = self.patchifier.patchify(x)
|
||||
|
||||
x = self.patchify_proj(x)
|
||||
timestep = timestep * 1000.0
|
||||
|
||||
attention_mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1]))
|
||||
attention_mask = attention_mask.masked_fill(attention_mask.to(torch.bool), float("-inf")) # not sure about this
|
||||
# attention_mask = (context != 0).any(dim=2).to(dtype=x.dtype)
|
||||
|
||||
pe = precompute_freqs_cis(indices_grid, dim=self.inner_dim, out_dtype=x.dtype)
|
||||
|
||||
batch_size = x.shape[0]
|
||||
timestep, embedded_timestep = self.adaln_single(
|
||||
timestep.flatten(),
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=x.dtype,
|
||||
)
|
||||
# Second dimension is 1 or number of tokens (if timestep_per_token)
|
||||
timestep = timestep.view(batch_size, -1, timestep.shape[-1])
|
||||
embedded_timestep = embedded_timestep.view(
|
||||
batch_size, -1, embedded_timestep.shape[-1]
|
||||
)
|
||||
|
||||
# 2. Blocks
|
||||
if self.caption_projection is not None:
|
||||
batch_size = x.shape[0]
|
||||
context = self.caption_projection(context)
|
||||
context = context.view(
|
||||
batch_size, -1, x.shape[-1]
|
||||
)
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = block(
|
||||
x,
|
||||
context=context,
|
||||
attention_mask=attention_mask,
|
||||
timestep=timestep,
|
||||
pe=pe
|
||||
)
|
||||
|
||||
# 3. Output
|
||||
scale_shift_values = (
|
||||
self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None]
|
||||
)
|
||||
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
|
||||
x = self.norm_out(x)
|
||||
# Modulation
|
||||
x = x * (1 + scale) + shift
|
||||
x = self.proj_out(x)
|
||||
|
||||
x = self.patchifier.unpatchify(
|
||||
latents=x,
|
||||
output_height=orig_shape[3],
|
||||
output_width=orig_shape[4],
|
||||
output_num_frames=orig_shape[2],
|
||||
out_channels=orig_shape[1] // math.prod(self.patchifier.patch_size),
|
||||
)
|
||||
|
||||
if guiding_latent is not None:
|
||||
x[:, :, 0] = (input_x[:, :, 0] - guiding_latent[:, :, 0]) / input_ts[:, :, 0]
|
||||
|
||||
# print("res", x)
|
||||
return x
|
||||
105
comfy/ldm/lightricks/symmetric_patchifier.py
Normal file
105
comfy/ldm/lightricks/symmetric_patchifier.py
Normal file
@@ -0,0 +1,105 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
|
||||
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
|
||||
dims_to_append = target_dims - x.ndim
|
||||
if dims_to_append < 0:
|
||||
raise ValueError(
|
||||
f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
|
||||
)
|
||||
elif dims_to_append == 0:
|
||||
return x
|
||||
return x[(...,) + (None,) * dims_to_append]
|
||||
|
||||
|
||||
class Patchifier(ABC):
|
||||
def __init__(self, patch_size: int):
|
||||
super().__init__()
|
||||
self._patch_size = (1, patch_size, patch_size)
|
||||
|
||||
@abstractmethod
|
||||
def patchify(
|
||||
self, latents: Tensor, frame_rates: Tensor, scale_grid: bool
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def unpatchify(
|
||||
self,
|
||||
latents: Tensor,
|
||||
output_height: int,
|
||||
output_width: int,
|
||||
output_num_frames: int,
|
||||
out_channels: int,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
pass
|
||||
|
||||
@property
|
||||
def patch_size(self):
|
||||
return self._patch_size
|
||||
|
||||
def get_grid(
|
||||
self, orig_num_frames, orig_height, orig_width, batch_size, scale_grid, device
|
||||
):
|
||||
f = orig_num_frames // self._patch_size[0]
|
||||
h = orig_height // self._patch_size[1]
|
||||
w = orig_width // self._patch_size[2]
|
||||
grid_h = torch.arange(h, dtype=torch.float32, device=device)
|
||||
grid_w = torch.arange(w, dtype=torch.float32, device=device)
|
||||
grid_f = torch.arange(f, dtype=torch.float32, device=device)
|
||||
grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing='ij')
|
||||
grid = torch.stack(grid, dim=0)
|
||||
grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
|
||||
|
||||
if scale_grid is not None:
|
||||
for i in range(3):
|
||||
if isinstance(scale_grid[i], Tensor):
|
||||
scale = append_dims(scale_grid[i], grid.ndim - 1)
|
||||
else:
|
||||
scale = scale_grid[i]
|
||||
grid[:, i, ...] = grid[:, i, ...] * scale * self._patch_size[i]
|
||||
|
||||
grid = rearrange(grid, "b c f h w -> b c (f h w)", b=batch_size)
|
||||
return grid
|
||||
|
||||
|
||||
class SymmetricPatchifier(Patchifier):
|
||||
def patchify(
|
||||
self,
|
||||
latents: Tensor,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
latents = rearrange(
|
||||
latents,
|
||||
"b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
|
||||
p1=self._patch_size[0],
|
||||
p2=self._patch_size[1],
|
||||
p3=self._patch_size[2],
|
||||
)
|
||||
return latents
|
||||
|
||||
def unpatchify(
|
||||
self,
|
||||
latents: Tensor,
|
||||
output_height: int,
|
||||
output_width: int,
|
||||
output_num_frames: int,
|
||||
out_channels: int,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
output_height = output_height // self._patch_size[1]
|
||||
output_width = output_width // self._patch_size[2]
|
||||
latents = rearrange(
|
||||
latents,
|
||||
"b (f h w) (c p q) -> b c f (h p) (w q) ",
|
||||
f=output_num_frames,
|
||||
h=output_height,
|
||||
w=output_width,
|
||||
p=self._patch_size[1],
|
||||
q=self._patch_size[2],
|
||||
)
|
||||
return latents
|
||||
64
comfy/ldm/lightricks/vae/causal_conv3d.py
Normal file
64
comfy/ldm/lightricks/vae/causal_conv3d.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from typing import Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
|
||||
class CausalConv3d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size: int = 3,
|
||||
stride: Union[int, Tuple[int]] = 1,
|
||||
dilation: int = 1,
|
||||
groups: int = 1,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
|
||||
kernel_size = (kernel_size, kernel_size, kernel_size)
|
||||
self.time_kernel_size = kernel_size[0]
|
||||
|
||||
dilation = (dilation, 1, 1)
|
||||
|
||||
height_pad = kernel_size[1] // 2
|
||||
width_pad = kernel_size[2] // 2
|
||||
padding = (0, height_pad, width_pad)
|
||||
|
||||
self.conv = ops.Conv3d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
dilation=dilation,
|
||||
padding=padding,
|
||||
padding_mode="zeros",
|
||||
groups=groups,
|
||||
)
|
||||
|
||||
def forward(self, x, causal: bool = True):
|
||||
if causal:
|
||||
first_frame_pad = x[:, :, :1, :, :].repeat(
|
||||
(1, 1, self.time_kernel_size - 1, 1, 1)
|
||||
)
|
||||
x = torch.concatenate((first_frame_pad, x), dim=2)
|
||||
else:
|
||||
first_frame_pad = x[:, :, :1, :, :].repeat(
|
||||
(1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
|
||||
)
|
||||
last_frame_pad = x[:, :, -1:, :, :].repeat(
|
||||
(1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
|
||||
)
|
||||
x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2)
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
@property
|
||||
def weight(self):
|
||||
return self.conv.weight
|
||||
907
comfy/ldm/lightricks/vae/causal_video_autoencoder.py
Normal file
907
comfy/ldm/lightricks/vae/causal_video_autoencoder.py
Normal file
@@ -0,0 +1,907 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from functools import partial
|
||||
import math
|
||||
from einops import rearrange
|
||||
from typing import Optional, Tuple, Union
|
||||
from .conv_nd_factory import make_conv_nd, make_linear_nd
|
||||
from .pixel_norm import PixelNorm
|
||||
from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
class Encoder(nn.Module):
|
||||
r"""
|
||||
The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
|
||||
|
||||
Args:
|
||||
dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3):
|
||||
The number of dimensions to use in convolutions.
|
||||
in_channels (`int`, *optional*, defaults to 3):
|
||||
The number of input channels.
|
||||
out_channels (`int`, *optional*, defaults to 3):
|
||||
The number of output channels.
|
||||
blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`):
|
||||
The blocks to use. Each block is a tuple of the block name and the number of layers.
|
||||
base_channels (`int`, *optional*, defaults to 128):
|
||||
The number of output channels for the first convolutional layer.
|
||||
norm_num_groups (`int`, *optional*, defaults to 32):
|
||||
The number of groups for normalization.
|
||||
patch_size (`int`, *optional*, defaults to 1):
|
||||
The patch size to use. Should be a power of 2.
|
||||
norm_layer (`str`, *optional*, defaults to `group_norm`):
|
||||
The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
|
||||
latent_log_var (`str`, *optional*, defaults to `per_channel`):
|
||||
The number of channels for the log variance. Can be either `per_channel`, `uniform`, or `none`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dims: Union[int, Tuple[int, int]] = 3,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
blocks=[("res_x", 1)],
|
||||
base_channels: int = 128,
|
||||
norm_num_groups: int = 32,
|
||||
patch_size: Union[int, Tuple[int]] = 1,
|
||||
norm_layer: str = "group_norm", # group_norm, pixel_norm
|
||||
latent_log_var: str = "per_channel",
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.norm_layer = norm_layer
|
||||
self.latent_channels = out_channels
|
||||
self.latent_log_var = latent_log_var
|
||||
self.blocks_desc = blocks
|
||||
|
||||
in_channels = in_channels * patch_size**2
|
||||
output_channel = base_channels
|
||||
|
||||
self.conv_in = make_conv_nd(
|
||||
dims=dims,
|
||||
in_channels=in_channels,
|
||||
out_channels=output_channel,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
causal=True,
|
||||
)
|
||||
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
|
||||
for block_name, block_params in blocks:
|
||||
input_channel = output_channel
|
||||
if isinstance(block_params, int):
|
||||
block_params = {"num_layers": block_params}
|
||||
|
||||
if block_name == "res_x":
|
||||
block = UNetMidBlock3D(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
num_layers=block_params["num_layers"],
|
||||
resnet_eps=1e-6,
|
||||
resnet_groups=norm_num_groups,
|
||||
norm_layer=norm_layer,
|
||||
)
|
||||
elif block_name == "res_x_y":
|
||||
output_channel = block_params.get("multiplier", 2) * output_channel
|
||||
block = ResnetBlock3D(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
eps=1e-6,
|
||||
groups=norm_num_groups,
|
||||
norm_layer=norm_layer,
|
||||
)
|
||||
elif block_name == "compress_time":
|
||||
block = make_conv_nd(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
kernel_size=3,
|
||||
stride=(2, 1, 1),
|
||||
causal=True,
|
||||
)
|
||||
elif block_name == "compress_space":
|
||||
block = make_conv_nd(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
kernel_size=3,
|
||||
stride=(1, 2, 2),
|
||||
causal=True,
|
||||
)
|
||||
elif block_name == "compress_all":
|
||||
block = make_conv_nd(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
kernel_size=3,
|
||||
stride=(2, 2, 2),
|
||||
causal=True,
|
||||
)
|
||||
elif block_name == "compress_all_x_y":
|
||||
output_channel = block_params.get("multiplier", 2) * output_channel
|
||||
block = make_conv_nd(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
kernel_size=3,
|
||||
stride=(2, 2, 2),
|
||||
causal=True,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unknown block: {block_name}")
|
||||
|
||||
self.down_blocks.append(block)
|
||||
|
||||
# out
|
||||
if norm_layer == "group_norm":
|
||||
self.conv_norm_out = nn.GroupNorm(
|
||||
num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6
|
||||
)
|
||||
elif norm_layer == "pixel_norm":
|
||||
self.conv_norm_out = PixelNorm()
|
||||
elif norm_layer == "layer_norm":
|
||||
self.conv_norm_out = LayerNorm(output_channel, eps=1e-6)
|
||||
|
||||
self.conv_act = nn.SiLU()
|
||||
|
||||
conv_out_channels = out_channels
|
||||
if latent_log_var == "per_channel":
|
||||
conv_out_channels *= 2
|
||||
elif latent_log_var == "uniform":
|
||||
conv_out_channels += 1
|
||||
elif latent_log_var != "none":
|
||||
raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
|
||||
self.conv_out = make_conv_nd(
|
||||
dims, output_channel, conv_out_channels, 3, padding=1, causal=True
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
r"""The forward method of the `Encoder` class."""
|
||||
|
||||
sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
checkpoint_fn = (
|
||||
partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
|
||||
if self.gradient_checkpointing and self.training
|
||||
else lambda x: x
|
||||
)
|
||||
|
||||
for down_block in self.down_blocks:
|
||||
sample = checkpoint_fn(down_block)(sample)
|
||||
|
||||
sample = self.conv_norm_out(sample)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample)
|
||||
|
||||
if self.latent_log_var == "uniform":
|
||||
last_channel = sample[:, -1:, ...]
|
||||
num_dims = sample.dim()
|
||||
|
||||
if num_dims == 4:
|
||||
# For shape (B, C, H, W)
|
||||
repeated_last_channel = last_channel.repeat(
|
||||
1, sample.shape[1] - 2, 1, 1
|
||||
)
|
||||
sample = torch.cat([sample, repeated_last_channel], dim=1)
|
||||
elif num_dims == 5:
|
||||
# For shape (B, C, F, H, W)
|
||||
repeated_last_channel = last_channel.repeat(
|
||||
1, sample.shape[1] - 2, 1, 1, 1
|
||||
)
|
||||
sample = torch.cat([sample, repeated_last_channel], dim=1)
|
||||
else:
|
||||
raise ValueError(f"Invalid input shape: {sample.shape}")
|
||||
|
||||
return sample
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
r"""
|
||||
The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.
|
||||
|
||||
Args:
|
||||
dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3):
|
||||
The number of dimensions to use in convolutions.
|
||||
in_channels (`int`, *optional*, defaults to 3):
|
||||
The number of input channels.
|
||||
out_channels (`int`, *optional*, defaults to 3):
|
||||
The number of output channels.
|
||||
blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`):
|
||||
The blocks to use. Each block is a tuple of the block name and the number of layers.
|
||||
base_channels (`int`, *optional*, defaults to 128):
|
||||
The number of output channels for the first convolutional layer.
|
||||
norm_num_groups (`int`, *optional*, defaults to 32):
|
||||
The number of groups for normalization.
|
||||
patch_size (`int`, *optional*, defaults to 1):
|
||||
The patch size to use. Should be a power of 2.
|
||||
norm_layer (`str`, *optional*, defaults to `group_norm`):
|
||||
The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
|
||||
causal (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use causal convolutions or not.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dims,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
blocks=[("res_x", 1)],
|
||||
base_channels: int = 128,
|
||||
layers_per_block: int = 2,
|
||||
norm_num_groups: int = 32,
|
||||
patch_size: int = 1,
|
||||
norm_layer: str = "group_norm",
|
||||
causal: bool = True,
|
||||
timestep_conditioning: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.layers_per_block = layers_per_block
|
||||
out_channels = out_channels * patch_size**2
|
||||
self.causal = causal
|
||||
self.blocks_desc = blocks
|
||||
|
||||
# Compute output channel to be product of all channel-multiplier blocks
|
||||
output_channel = base_channels
|
||||
for block_name, block_params in list(reversed(blocks)):
|
||||
block_params = block_params if isinstance(block_params, dict) else {}
|
||||
if block_name == "res_x_y":
|
||||
output_channel = output_channel * block_params.get("multiplier", 2)
|
||||
if block_name == "compress_all":
|
||||
output_channel = output_channel * block_params.get("multiplier", 1)
|
||||
|
||||
self.conv_in = make_conv_nd(
|
||||
dims,
|
||||
in_channels,
|
||||
output_channel,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
causal=True,
|
||||
)
|
||||
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
|
||||
for block_name, block_params in list(reversed(blocks)):
|
||||
input_channel = output_channel
|
||||
if isinstance(block_params, int):
|
||||
block_params = {"num_layers": block_params}
|
||||
|
||||
if block_name == "res_x":
|
||||
block = UNetMidBlock3D(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
num_layers=block_params["num_layers"],
|
||||
resnet_eps=1e-6,
|
||||
resnet_groups=norm_num_groups,
|
||||
norm_layer=norm_layer,
|
||||
inject_noise=block_params.get("inject_noise", False),
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
)
|
||||
elif block_name == "attn_res_x":
|
||||
block = UNetMidBlock3D(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
num_layers=block_params["num_layers"],
|
||||
resnet_groups=norm_num_groups,
|
||||
norm_layer=norm_layer,
|
||||
inject_noise=block_params.get("inject_noise", False),
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
attention_head_dim=block_params["attention_head_dim"],
|
||||
)
|
||||
elif block_name == "res_x_y":
|
||||
output_channel = output_channel // block_params.get("multiplier", 2)
|
||||
block = ResnetBlock3D(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
eps=1e-6,
|
||||
groups=norm_num_groups,
|
||||
norm_layer=norm_layer,
|
||||
inject_noise=block_params.get("inject_noise", False),
|
||||
timestep_conditioning=False,
|
||||
)
|
||||
elif block_name == "compress_time":
|
||||
block = DepthToSpaceUpsample(
|
||||
dims=dims, in_channels=input_channel, stride=(2, 1, 1)
|
||||
)
|
||||
elif block_name == "compress_space":
|
||||
block = DepthToSpaceUpsample(
|
||||
dims=dims, in_channels=input_channel, stride=(1, 2, 2)
|
||||
)
|
||||
elif block_name == "compress_all":
|
||||
output_channel = output_channel // block_params.get("multiplier", 1)
|
||||
block = DepthToSpaceUpsample(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
stride=(2, 2, 2),
|
||||
residual=block_params.get("residual", False),
|
||||
out_channels_reduction_factor=block_params.get("multiplier", 1),
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unknown layer: {block_name}")
|
||||
|
||||
self.up_blocks.append(block)
|
||||
|
||||
if norm_layer == "group_norm":
|
||||
self.conv_norm_out = nn.GroupNorm(
|
||||
num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6
|
||||
)
|
||||
elif norm_layer == "pixel_norm":
|
||||
self.conv_norm_out = PixelNorm()
|
||||
elif norm_layer == "layer_norm":
|
||||
self.conv_norm_out = LayerNorm(output_channel, eps=1e-6)
|
||||
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = make_conv_nd(
|
||||
dims, output_channel, out_channels, 3, padding=1, causal=True
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
self.timestep_conditioning = timestep_conditioning
|
||||
|
||||
if timestep_conditioning:
|
||||
self.timestep_scale_multiplier = nn.Parameter(
|
||||
torch.tensor(1000.0, dtype=torch.float32)
|
||||
)
|
||||
self.last_time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(
|
||||
output_channel * 2, 0, operations=ops,
|
||||
)
|
||||
self.last_scale_shift_table = nn.Parameter(torch.empty(2, output_channel))
|
||||
|
||||
# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Optional[torch.Tensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
r"""The forward method of the `Decoder` class."""
|
||||
batch_size = sample.shape[0]
|
||||
|
||||
sample = self.conv_in(sample, causal=self.causal)
|
||||
|
||||
checkpoint_fn = (
|
||||
partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
|
||||
if self.gradient_checkpointing and self.training
|
||||
else lambda x: x
|
||||
)
|
||||
|
||||
scaled_timestep = None
|
||||
if self.timestep_conditioning:
|
||||
assert (
|
||||
timestep is not None
|
||||
), "should pass timestep with timestep_conditioning=True"
|
||||
scaled_timestep = timestep * self.timestep_scale_multiplier.to(dtype=sample.dtype, device=sample.device)
|
||||
|
||||
for up_block in self.up_blocks:
|
||||
if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
|
||||
sample = checkpoint_fn(up_block)(
|
||||
sample, causal=self.causal, timestep=scaled_timestep
|
||||
)
|
||||
else:
|
||||
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
|
||||
|
||||
sample = self.conv_norm_out(sample)
|
||||
|
||||
if self.timestep_conditioning:
|
||||
embedded_timestep = self.last_time_embedder(
|
||||
timestep=scaled_timestep.flatten(),
|
||||
resolution=None,
|
||||
aspect_ratio=None,
|
||||
batch_size=sample.shape[0],
|
||||
hidden_dtype=sample.dtype,
|
||||
)
|
||||
embedded_timestep = embedded_timestep.view(
|
||||
batch_size, embedded_timestep.shape[-1], 1, 1, 1
|
||||
)
|
||||
ada_values = self.last_scale_shift_table[
|
||||
None, ..., None, None, None
|
||||
].to(device=sample.device, dtype=sample.dtype) + embedded_timestep.reshape(
|
||||
batch_size,
|
||||
2,
|
||||
-1,
|
||||
embedded_timestep.shape[-3],
|
||||
embedded_timestep.shape[-2],
|
||||
embedded_timestep.shape[-1],
|
||||
)
|
||||
shift, scale = ada_values.unbind(dim=1)
|
||||
sample = sample * (1 + scale) + shift
|
||||
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample, causal=self.causal)
|
||||
|
||||
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||
|
||||
return sample
|
||||
|
||||
|
||||
class UNetMidBlock3D(nn.Module):
|
||||
"""
|
||||
A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks.
|
||||
|
||||
Args:
|
||||
in_channels (`int`): The number of input channels.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
|
||||
num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
|
||||
resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
|
||||
resnet_groups (`int`, *optional*, defaults to 32):
|
||||
The number of groups to use in the group normalization layers of the resnet blocks.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
|
||||
in_channels, height, width)`.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dims: Union[int, Tuple[int, int]],
|
||||
in_channels: int,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_groups: int = 32,
|
||||
norm_layer: str = "group_norm",
|
||||
inject_noise: bool = False,
|
||||
timestep_conditioning: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
resnet_groups = (
|
||||
resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
||||
)
|
||||
|
||||
self.timestep_conditioning = timestep_conditioning
|
||||
|
||||
if timestep_conditioning:
|
||||
self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(
|
||||
in_channels * 4, 0, operations=ops,
|
||||
)
|
||||
|
||||
self.res_blocks = nn.ModuleList(
|
||||
[
|
||||
ResnetBlock3D(
|
||||
dims=dims,
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
norm_layer=norm_layer,
|
||||
inject_noise=inject_noise,
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.FloatTensor, causal: bool = True, timestep: Optional[torch.Tensor] = None
|
||||
) -> torch.FloatTensor:
|
||||
timestep_embed = None
|
||||
if self.timestep_conditioning:
|
||||
assert (
|
||||
timestep is not None
|
||||
), "should pass timestep with timestep_conditioning=True"
|
||||
batch_size = hidden_states.shape[0]
|
||||
timestep_embed = self.time_embedder(
|
||||
timestep=timestep.flatten(),
|
||||
resolution=None,
|
||||
aspect_ratio=None,
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_states.dtype,
|
||||
)
|
||||
timestep_embed = timestep_embed.view(
|
||||
batch_size, timestep_embed.shape[-1], 1, 1, 1
|
||||
)
|
||||
|
||||
for resnet in self.res_blocks:
|
||||
hidden_states = resnet(hidden_states, causal=causal, timestep=timestep_embed)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DepthToSpaceUpsample(nn.Module):
|
||||
def __init__(
|
||||
self, dims, in_channels, stride, residual=False, out_channels_reduction_factor=1
|
||||
):
|
||||
super().__init__()
|
||||
self.stride = stride
|
||||
self.out_channels = (
|
||||
math.prod(stride) * in_channels // out_channels_reduction_factor
|
||||
)
|
||||
self.conv = make_conv_nd(
|
||||
dims=dims,
|
||||
in_channels=in_channels,
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
causal=True,
|
||||
)
|
||||
self.residual = residual
|
||||
self.out_channels_reduction_factor = out_channels_reduction_factor
|
||||
|
||||
def forward(self, x, causal: bool = True, timestep: Optional[torch.Tensor] = None):
|
||||
if self.residual:
|
||||
# Reshape and duplicate the input to match the output shape
|
||||
x_in = rearrange(
|
||||
x,
|
||||
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
|
||||
p1=self.stride[0],
|
||||
p2=self.stride[1],
|
||||
p3=self.stride[2],
|
||||
)
|
||||
num_repeat = math.prod(self.stride) // self.out_channels_reduction_factor
|
||||
x_in = x_in.repeat(1, num_repeat, 1, 1, 1)
|
||||
if self.stride[0] == 2:
|
||||
x_in = x_in[:, :, 1:, :, :]
|
||||
x = self.conv(x, causal=causal)
|
||||
x = rearrange(
|
||||
x,
|
||||
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
|
||||
p1=self.stride[0],
|
||||
p2=self.stride[1],
|
||||
p3=self.stride[2],
|
||||
)
|
||||
if self.stride[0] == 2:
|
||||
x = x[:, :, 1:, :, :]
|
||||
if self.residual:
|
||||
x = x + x_in
|
||||
return x
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(self, dim, eps, elementwise_affine=True) -> None:
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
|
||||
|
||||
def forward(self, x):
|
||||
x = rearrange(x, "b c d h w -> b d h w c")
|
||||
x = self.norm(x)
|
||||
x = rearrange(x, "b d h w c -> b c d h w")
|
||||
return x
|
||||
|
||||
|
||||
class ResnetBlock3D(nn.Module):
|
||||
r"""
|
||||
A Resnet block.
|
||||
|
||||
Parameters:
|
||||
in_channels (`int`): The number of channels in the input.
|
||||
out_channels (`int`, *optional*, default to be `None`):
|
||||
The number of output channels for the first conv layer. If None, same as `in_channels`.
|
||||
dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
|
||||
groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
|
||||
eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dims: Union[int, Tuple[int, int]],
|
||||
in_channels: int,
|
||||
out_channels: Optional[int] = None,
|
||||
dropout: float = 0.0,
|
||||
groups: int = 32,
|
||||
eps: float = 1e-6,
|
||||
norm_layer: str = "group_norm",
|
||||
inject_noise: bool = False,
|
||||
timestep_conditioning: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.inject_noise = inject_noise
|
||||
|
||||
if norm_layer == "group_norm":
|
||||
self.norm1 = nn.GroupNorm(
|
||||
num_groups=groups, num_channels=in_channels, eps=eps, affine=True
|
||||
)
|
||||
elif norm_layer == "pixel_norm":
|
||||
self.norm1 = PixelNorm()
|
||||
elif norm_layer == "layer_norm":
|
||||
self.norm1 = LayerNorm(in_channels, eps=eps, elementwise_affine=True)
|
||||
|
||||
self.non_linearity = nn.SiLU()
|
||||
|
||||
self.conv1 = make_conv_nd(
|
||||
dims,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
causal=True,
|
||||
)
|
||||
|
||||
if inject_noise:
|
||||
self.per_channel_scale1 = nn.Parameter(torch.zeros((in_channels, 1, 1)))
|
||||
|
||||
if norm_layer == "group_norm":
|
||||
self.norm2 = nn.GroupNorm(
|
||||
num_groups=groups, num_channels=out_channels, eps=eps, affine=True
|
||||
)
|
||||
elif norm_layer == "pixel_norm":
|
||||
self.norm2 = PixelNorm()
|
||||
elif norm_layer == "layer_norm":
|
||||
self.norm2 = LayerNorm(out_channels, eps=eps, elementwise_affine=True)
|
||||
|
||||
self.dropout = torch.nn.Dropout(dropout)
|
||||
|
||||
self.conv2 = make_conv_nd(
|
||||
dims,
|
||||
out_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
causal=True,
|
||||
)
|
||||
|
||||
if inject_noise:
|
||||
self.per_channel_scale2 = nn.Parameter(torch.zeros((in_channels, 1, 1)))
|
||||
|
||||
self.conv_shortcut = (
|
||||
make_linear_nd(
|
||||
dims=dims, in_channels=in_channels, out_channels=out_channels
|
||||
)
|
||||
if in_channels != out_channels
|
||||
else nn.Identity()
|
||||
)
|
||||
|
||||
self.norm3 = (
|
||||
LayerNorm(in_channels, eps=eps, elementwise_affine=True)
|
||||
if in_channels != out_channels
|
||||
else nn.Identity()
|
||||
)
|
||||
|
||||
self.timestep_conditioning = timestep_conditioning
|
||||
|
||||
if timestep_conditioning:
|
||||
self.scale_shift_table = nn.Parameter(
|
||||
torch.randn(4, in_channels) / in_channels**0.5
|
||||
)
|
||||
|
||||
def _feed_spatial_noise(
|
||||
self, hidden_states: torch.FloatTensor, per_channel_scale: torch.FloatTensor
|
||||
) -> torch.FloatTensor:
|
||||
spatial_shape = hidden_states.shape[-2:]
|
||||
device = hidden_states.device
|
||||
dtype = hidden_states.dtype
|
||||
|
||||
# similar to the "explicit noise inputs" method in style-gan
|
||||
spatial_noise = torch.randn(spatial_shape, device=device, dtype=dtype)[None]
|
||||
scaled_noise = (spatial_noise * per_channel_scale)[None, :, None, ...]
|
||||
hidden_states = hidden_states + scaled_noise
|
||||
|
||||
return hidden_states
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_tensor: torch.FloatTensor,
|
||||
causal: bool = True,
|
||||
timestep: Optional[torch.Tensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
hidden_states = input_tensor
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
hidden_states = self.norm1(hidden_states)
|
||||
if self.timestep_conditioning:
|
||||
assert (
|
||||
timestep is not None
|
||||
), "should pass timestep with timestep_conditioning=True"
|
||||
ada_values = self.scale_shift_table[
|
||||
None, ..., None, None, None
|
||||
].to(device=hidden_states.device, dtype=hidden_states.dtype) + timestep.reshape(
|
||||
batch_size,
|
||||
4,
|
||||
-1,
|
||||
timestep.shape[-3],
|
||||
timestep.shape[-2],
|
||||
timestep.shape[-1],
|
||||
)
|
||||
shift1, scale1, shift2, scale2 = ada_values.unbind(dim=1)
|
||||
|
||||
hidden_states = hidden_states * (1 + scale1) + shift1
|
||||
|
||||
hidden_states = self.non_linearity(hidden_states)
|
||||
|
||||
hidden_states = self.conv1(hidden_states, causal=causal)
|
||||
|
||||
if self.inject_noise:
|
||||
hidden_states = self._feed_spatial_noise(
|
||||
hidden_states, self.per_channel_scale1.to(device=hidden_states.device, dtype=hidden_states.dtype)
|
||||
)
|
||||
|
||||
hidden_states = self.norm2(hidden_states)
|
||||
|
||||
if self.timestep_conditioning:
|
||||
hidden_states = hidden_states * (1 + scale2) + shift2
|
||||
|
||||
hidden_states = self.non_linearity(hidden_states)
|
||||
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
|
||||
hidden_states = self.conv2(hidden_states, causal=causal)
|
||||
|
||||
if self.inject_noise:
|
||||
hidden_states = self._feed_spatial_noise(
|
||||
hidden_states, self.per_channel_scale2.to(device=hidden_states.device, dtype=hidden_states.dtype)
|
||||
)
|
||||
|
||||
input_tensor = self.norm3(input_tensor)
|
||||
|
||||
batch_size = input_tensor.shape[0]
|
||||
|
||||
input_tensor = self.conv_shortcut(input_tensor)
|
||||
|
||||
output_tensor = input_tensor + hidden_states
|
||||
|
||||
return output_tensor
|
||||
|
||||
|
||||
def patchify(x, patch_size_hw, patch_size_t=1):
|
||||
if patch_size_hw == 1 and patch_size_t == 1:
|
||||
return x
|
||||
if x.dim() == 4:
|
||||
x = rearrange(
|
||||
x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw
|
||||
)
|
||||
elif x.dim() == 5:
|
||||
x = rearrange(
|
||||
x,
|
||||
"b c (f p) (h q) (w r) -> b (c p r q) f h w",
|
||||
p=patch_size_t,
|
||||
q=patch_size_hw,
|
||||
r=patch_size_hw,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid input shape: {x.shape}")
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def unpatchify(x, patch_size_hw, patch_size_t=1):
|
||||
if patch_size_hw == 1 and patch_size_t == 1:
|
||||
return x
|
||||
|
||||
if x.dim() == 4:
|
||||
x = rearrange(
|
||||
x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw
|
||||
)
|
||||
elif x.dim() == 5:
|
||||
x = rearrange(
|
||||
x,
|
||||
"b (c p r q) f h w -> b c (f p) (h q) (w r)",
|
||||
p=patch_size_t,
|
||||
q=patch_size_hw,
|
||||
r=patch_size_hw,
|
||||
)
|
||||
|
||||
return x
|
||||
|
||||
class processor(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.register_buffer("std-of-means", torch.empty(128))
|
||||
self.register_buffer("mean-of-means", torch.empty(128))
|
||||
self.register_buffer("mean-of-stds", torch.empty(128))
|
||||
self.register_buffer("mean-of-stds_over_std-of-means", torch.empty(128))
|
||||
self.register_buffer("channel", torch.empty(128))
|
||||
|
||||
def un_normalize(self, x):
|
||||
return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)
|
||||
|
||||
def normalize(self, x):
|
||||
return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)
|
||||
|
||||
class VideoVAE(nn.Module):
|
||||
def __init__(self, version=0):
|
||||
super().__init__()
|
||||
|
||||
if version == 0:
|
||||
config = {
|
||||
"_class_name": "CausalVideoAutoencoder",
|
||||
"dims": 3,
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 128,
|
||||
"blocks": [
|
||||
["res_x", 4],
|
||||
["compress_all", 1],
|
||||
["res_x_y", 1],
|
||||
["res_x", 3],
|
||||
["compress_all", 1],
|
||||
["res_x_y", 1],
|
||||
["res_x", 3],
|
||||
["compress_all", 1],
|
||||
["res_x", 3],
|
||||
["res_x", 4],
|
||||
],
|
||||
"scaling_factor": 1.0,
|
||||
"norm_layer": "pixel_norm",
|
||||
"patch_size": 4,
|
||||
"latent_log_var": "uniform",
|
||||
"use_quant_conv": False,
|
||||
"causal_decoder": False,
|
||||
}
|
||||
else:
|
||||
config = {
|
||||
"_class_name": "CausalVideoAutoencoder",
|
||||
"dims": 3,
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 128,
|
||||
"decoder_blocks": [
|
||||
["res_x", {"num_layers": 5, "inject_noise": True}],
|
||||
["compress_all", {"residual": True, "multiplier": 2}],
|
||||
["res_x", {"num_layers": 6, "inject_noise": True}],
|
||||
["compress_all", {"residual": True, "multiplier": 2}],
|
||||
["res_x", {"num_layers": 7, "inject_noise": True}],
|
||||
["compress_all", {"residual": True, "multiplier": 2}],
|
||||
["res_x", {"num_layers": 8, "inject_noise": False}]
|
||||
],
|
||||
"encoder_blocks": [
|
||||
["res_x", {"num_layers": 4}],
|
||||
["compress_all", {}],
|
||||
["res_x_y", 1],
|
||||
["res_x", {"num_layers": 3}],
|
||||
["compress_all", {}],
|
||||
["res_x_y", 1],
|
||||
["res_x", {"num_layers": 3}],
|
||||
["compress_all", {}],
|
||||
["res_x", {"num_layers": 3}],
|
||||
["res_x", {"num_layers": 4}]
|
||||
],
|
||||
"scaling_factor": 1.0,
|
||||
"norm_layer": "pixel_norm",
|
||||
"patch_size": 4,
|
||||
"latent_log_var": "uniform",
|
||||
"use_quant_conv": False,
|
||||
"causal_decoder": False,
|
||||
"timestep_conditioning": True,
|
||||
}
|
||||
|
||||
double_z = config.get("double_z", True)
|
||||
latent_log_var = config.get(
|
||||
"latent_log_var", "per_channel" if double_z else "none"
|
||||
)
|
||||
|
||||
self.encoder = Encoder(
|
||||
dims=config["dims"],
|
||||
in_channels=config.get("in_channels", 3),
|
||||
out_channels=config["latent_channels"],
|
||||
blocks=config.get("encoder_blocks", config.get("encoder_blocks", config.get("blocks"))),
|
||||
patch_size=config.get("patch_size", 1),
|
||||
latent_log_var=latent_log_var,
|
||||
norm_layer=config.get("norm_layer", "group_norm"),
|
||||
)
|
||||
|
||||
self.decoder = Decoder(
|
||||
dims=config["dims"],
|
||||
in_channels=config["latent_channels"],
|
||||
out_channels=config.get("out_channels", 3),
|
||||
blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))),
|
||||
patch_size=config.get("patch_size", 1),
|
||||
norm_layer=config.get("norm_layer", "group_norm"),
|
||||
causal=config.get("causal_decoder", False),
|
||||
timestep_conditioning=config.get("timestep_conditioning", False),
|
||||
)
|
||||
|
||||
self.timestep_conditioning = config.get("timestep_conditioning", False)
|
||||
self.per_channel_statistics = processor()
|
||||
|
||||
def encode(self, x):
|
||||
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
|
||||
return self.per_channel_statistics.normalize(means)
|
||||
|
||||
def decode(self, x, timestep=0.05, noise_scale=0.025):
|
||||
if self.timestep_conditioning: #TODO: seed
|
||||
x = torch.randn_like(x) * noise_scale + (1.0 - noise_scale) * x
|
||||
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=timestep)
|
||||
|
||||
82
comfy/ldm/lightricks/vae/conv_nd_factory.py
Normal file
82
comfy/ldm/lightricks/vae/conv_nd_factory.py
Normal file
@@ -0,0 +1,82 @@
|
||||
from typing import Tuple, Union
|
||||
|
||||
|
||||
from .dual_conv3d import DualConv3d
|
||||
from .causal_conv3d import CausalConv3d
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
def make_conv_nd(
|
||||
dims: Union[int, Tuple[int, int]],
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=True,
|
||||
causal=False,
|
||||
):
|
||||
if dims == 2:
|
||||
return ops.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=bias,
|
||||
)
|
||||
elif dims == 3:
|
||||
if causal:
|
||||
return CausalConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=bias,
|
||||
)
|
||||
return ops.Conv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=bias,
|
||||
)
|
||||
elif dims == (2, 1):
|
||||
return DualConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
bias=bias,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
||||
|
||||
def make_linear_nd(
|
||||
dims: int,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
bias=True,
|
||||
):
|
||||
if dims == 2:
|
||||
return ops.Conv2d(
|
||||
in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
|
||||
)
|
||||
elif dims == 3 or dims == (2, 1):
|
||||
return ops.Conv3d(
|
||||
in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
195
comfy/ldm/lightricks/vae/dual_conv3d.py
Normal file
195
comfy/ldm/lightricks/vae/dual_conv3d.py
Normal file
@@ -0,0 +1,195 @@
|
||||
import math
|
||||
from typing import Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
class DualConv3d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride: Union[int, Tuple[int, int, int]] = 1,
|
||||
padding: Union[int, Tuple[int, int, int]] = 0,
|
||||
dilation: Union[int, Tuple[int, int, int]] = 1,
|
||||
groups=1,
|
||||
bias=True,
|
||||
):
|
||||
super(DualConv3d, self).__init__()
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
# Ensure kernel_size, stride, padding, and dilation are tuples of length 3
|
||||
if isinstance(kernel_size, int):
|
||||
kernel_size = (kernel_size, kernel_size, kernel_size)
|
||||
if kernel_size == (1, 1, 1):
|
||||
raise ValueError(
|
||||
"kernel_size must be greater than 1. Use make_linear_nd instead."
|
||||
)
|
||||
if isinstance(stride, int):
|
||||
stride = (stride, stride, stride)
|
||||
if isinstance(padding, int):
|
||||
padding = (padding, padding, padding)
|
||||
if isinstance(dilation, int):
|
||||
dilation = (dilation, dilation, dilation)
|
||||
|
||||
# Set parameters for convolutions
|
||||
self.groups = groups
|
||||
self.bias = bias
|
||||
|
||||
# Define the size of the channels after the first convolution
|
||||
intermediate_channels = (
|
||||
out_channels if in_channels < out_channels else in_channels
|
||||
)
|
||||
|
||||
# Define parameters for the first convolution
|
||||
self.weight1 = nn.Parameter(
|
||||
torch.Tensor(
|
||||
intermediate_channels,
|
||||
in_channels // groups,
|
||||
1,
|
||||
kernel_size[1],
|
||||
kernel_size[2],
|
||||
)
|
||||
)
|
||||
self.stride1 = (1, stride[1], stride[2])
|
||||
self.padding1 = (0, padding[1], padding[2])
|
||||
self.dilation1 = (1, dilation[1], dilation[2])
|
||||
if bias:
|
||||
self.bias1 = nn.Parameter(torch.Tensor(intermediate_channels))
|
||||
else:
|
||||
self.register_parameter("bias1", None)
|
||||
|
||||
# Define parameters for the second convolution
|
||||
self.weight2 = nn.Parameter(
|
||||
torch.Tensor(
|
||||
out_channels, intermediate_channels // groups, kernel_size[0], 1, 1
|
||||
)
|
||||
)
|
||||
self.stride2 = (stride[0], 1, 1)
|
||||
self.padding2 = (padding[0], 0, 0)
|
||||
self.dilation2 = (dilation[0], 1, 1)
|
||||
if bias:
|
||||
self.bias2 = nn.Parameter(torch.Tensor(out_channels))
|
||||
else:
|
||||
self.register_parameter("bias2", None)
|
||||
|
||||
# Initialize weights and biases
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5))
|
||||
nn.init.kaiming_uniform_(self.weight2, a=math.sqrt(5))
|
||||
if self.bias:
|
||||
fan_in1, _ = nn.init._calculate_fan_in_and_fan_out(self.weight1)
|
||||
bound1 = 1 / math.sqrt(fan_in1)
|
||||
nn.init.uniform_(self.bias1, -bound1, bound1)
|
||||
fan_in2, _ = nn.init._calculate_fan_in_and_fan_out(self.weight2)
|
||||
bound2 = 1 / math.sqrt(fan_in2)
|
||||
nn.init.uniform_(self.bias2, -bound2, bound2)
|
||||
|
||||
def forward(self, x, use_conv3d=False, skip_time_conv=False):
|
||||
if use_conv3d:
|
||||
return self.forward_with_3d(x=x, skip_time_conv=skip_time_conv)
|
||||
else:
|
||||
return self.forward_with_2d(x=x, skip_time_conv=skip_time_conv)
|
||||
|
||||
def forward_with_3d(self, x, skip_time_conv):
|
||||
# First convolution
|
||||
x = F.conv3d(
|
||||
x,
|
||||
self.weight1,
|
||||
self.bias1,
|
||||
self.stride1,
|
||||
self.padding1,
|
||||
self.dilation1,
|
||||
self.groups,
|
||||
)
|
||||
|
||||
if skip_time_conv:
|
||||
return x
|
||||
|
||||
# Second convolution
|
||||
x = F.conv3d(
|
||||
x,
|
||||
self.weight2,
|
||||
self.bias2,
|
||||
self.stride2,
|
||||
self.padding2,
|
||||
self.dilation2,
|
||||
self.groups,
|
||||
)
|
||||
|
||||
return x
|
||||
|
||||
def forward_with_2d(self, x, skip_time_conv):
|
||||
b, c, d, h, w = x.shape
|
||||
|
||||
# First 2D convolution
|
||||
x = rearrange(x, "b c d h w -> (b d) c h w")
|
||||
# Squeeze the depth dimension out of weight1 since it's 1
|
||||
weight1 = self.weight1.squeeze(2)
|
||||
# Select stride, padding, and dilation for the 2D convolution
|
||||
stride1 = (self.stride1[1], self.stride1[2])
|
||||
padding1 = (self.padding1[1], self.padding1[2])
|
||||
dilation1 = (self.dilation1[1], self.dilation1[2])
|
||||
x = F.conv2d(x, weight1, self.bias1, stride1, padding1, dilation1, self.groups)
|
||||
|
||||
_, _, h, w = x.shape
|
||||
|
||||
if skip_time_conv:
|
||||
x = rearrange(x, "(b d) c h w -> b c d h w", b=b)
|
||||
return x
|
||||
|
||||
# Second convolution which is essentially treated as a 1D convolution across the 'd' dimension
|
||||
x = rearrange(x, "(b d) c h w -> (b h w) c d", b=b)
|
||||
|
||||
# Reshape weight2 to match the expected dimensions for conv1d
|
||||
weight2 = self.weight2.squeeze(-1).squeeze(-1)
|
||||
# Use only the relevant dimension for stride, padding, and dilation for the 1D convolution
|
||||
stride2 = self.stride2[0]
|
||||
padding2 = self.padding2[0]
|
||||
dilation2 = self.dilation2[0]
|
||||
x = F.conv1d(x, weight2, self.bias2, stride2, padding2, dilation2, self.groups)
|
||||
x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w)
|
||||
|
||||
return x
|
||||
|
||||
@property
|
||||
def weight(self):
|
||||
return self.weight2
|
||||
|
||||
|
||||
def test_dual_conv3d_consistency():
|
||||
# Initialize parameters
|
||||
in_channels = 3
|
||||
out_channels = 5
|
||||
kernel_size = (3, 3, 3)
|
||||
stride = (2, 2, 2)
|
||||
padding = (1, 1, 1)
|
||||
|
||||
# Create an instance of the DualConv3d class
|
||||
dual_conv3d = DualConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
# Example input tensor
|
||||
test_input = torch.randn(1, 3, 10, 10, 10)
|
||||
|
||||
# Perform forward passes with both 3D and 2D settings
|
||||
output_conv3d = dual_conv3d(test_input, use_conv3d=True)
|
||||
output_2d = dual_conv3d(test_input, use_conv3d=False)
|
||||
|
||||
# Assert that the outputs from both methods are sufficiently close
|
||||
assert torch.allclose(
|
||||
output_conv3d, output_2d, atol=1e-6
|
||||
), "Outputs are not consistent between 3D and 2D convolutions."
|
||||
12
comfy/ldm/lightricks/vae/pixel_norm.py
Normal file
12
comfy/ldm/lightricks/vae/pixel_norm.py
Normal file
@@ -0,0 +1,12 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class PixelNorm(nn.Module):
|
||||
def __init__(self, dim=1, eps=1e-8):
|
||||
super(PixelNorm, self).__init__()
|
||||
self.dim = dim
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x):
|
||||
return x / torch.sqrt(torch.mean(x**2, dim=self.dim, keepdim=True) + self.eps)
|
||||
@@ -1,10 +1,12 @@
|
||||
import logging
|
||||
import math
|
||||
import torch
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Tuple, Union
|
||||
|
||||
from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
|
||||
|
||||
from comfy.ldm.util import instantiate_from_config
|
||||
from comfy.ldm.util import get_obj_from_str, instantiate_from_config
|
||||
from comfy.ldm.modules.ema import LitEma
|
||||
import comfy.ops
|
||||
|
||||
@@ -52,7 +54,7 @@ class AbstractAutoencoder(torch.nn.Module):
|
||||
|
||||
if self.use_ema:
|
||||
self.model_ema = LitEma(self, decay=ema_decay)
|
||||
logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
||||
logging.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
||||
|
||||
def get_input(self, batch) -> Any:
|
||||
raise NotImplementedError()
|
||||
@@ -68,14 +70,14 @@ class AbstractAutoencoder(torch.nn.Module):
|
||||
self.model_ema.store(self.parameters())
|
||||
self.model_ema.copy_to(self)
|
||||
if context is not None:
|
||||
logpy.info(f"{context}: Switched to EMA weights")
|
||||
logging.info(f"{context}: Switched to EMA weights")
|
||||
try:
|
||||
yield None
|
||||
finally:
|
||||
if self.use_ema:
|
||||
self.model_ema.restore(self.parameters())
|
||||
if context is not None:
|
||||
logpy.info(f"{context}: Restored training weights")
|
||||
logging.info(f"{context}: Restored training weights")
|
||||
|
||||
def encode(self, *args, **kwargs) -> torch.Tensor:
|
||||
raise NotImplementedError("encode()-method of abstract base class called")
|
||||
@@ -84,7 +86,7 @@ class AbstractAutoencoder(torch.nn.Module):
|
||||
raise NotImplementedError("decode()-method of abstract base class called")
|
||||
|
||||
def instantiate_optimizer_from_config(self, params, lr, cfg):
|
||||
logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
|
||||
logging.info(f"loading >>> {cfg['target']} <<< optimizer from config")
|
||||
return get_obj_from_str(cfg["target"])(
|
||||
params, lr=lr, **cfg.get("params", dict())
|
||||
)
|
||||
@@ -112,7 +114,7 @@ class AutoencodingEngine(AbstractAutoencoder):
|
||||
|
||||
self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
|
||||
self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
|
||||
self.regularization: AbstractRegularizer = instantiate_from_config(
|
||||
self.regularization = instantiate_from_config(
|
||||
regularizer_config
|
||||
)
|
||||
|
||||
@@ -160,12 +162,19 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
|
||||
},
|
||||
**kwargs,
|
||||
)
|
||||
self.quant_conv = comfy.ops.disable_weight_init.Conv2d(
|
||||
|
||||
if ddconfig.get("conv3d", False):
|
||||
conv_op = comfy.ops.disable_weight_init.Conv3d
|
||||
else:
|
||||
conv_op = comfy.ops.disable_weight_init.Conv2d
|
||||
|
||||
self.quant_conv = conv_op(
|
||||
(1 + ddconfig["double_z"]) * ddconfig["z_channels"],
|
||||
(1 + ddconfig["double_z"]) * embed_dim,
|
||||
1,
|
||||
)
|
||||
self.post_quant_conv = comfy.ops.disable_weight_init.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
||||
|
||||
self.post_quant_conv = conv_op(embed_dim, ddconfig["z_channels"], 1)
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
def get_autoencoder_params(self) -> list:
|
||||
|
||||
@@ -15,6 +15,9 @@ if model_management.xformers_enabled():
|
||||
import xformers
|
||||
import xformers.ops
|
||||
|
||||
if model_management.sage_attention_enabled():
|
||||
from sageattention import sageattn
|
||||
|
||||
from comfy.cli_args import args
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
@@ -157,8 +160,6 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
|
||||
b, _, dim_head = query.shape
|
||||
dim_head //= heads
|
||||
|
||||
scale = dim_head ** -0.5
|
||||
|
||||
if skip_reshape:
|
||||
query = query.reshape(b * heads, -1, dim_head)
|
||||
value = value.reshape(b * heads, -1, dim_head)
|
||||
@@ -177,9 +178,8 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
|
||||
bytes_per_token = torch.finfo(query.dtype).bits//8
|
||||
batch_x_heads, q_tokens, _ = query.shape
|
||||
_, _, k_tokens = key.shape
|
||||
qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
|
||||
|
||||
mem_free_total, mem_free_torch = model_management.get_free_memory(query.device, True)
|
||||
mem_free_total, _ = model_management.get_free_memory(query.device, True)
|
||||
|
||||
kv_chunk_size_min = None
|
||||
kv_chunk_size = None
|
||||
@@ -230,7 +230,6 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
||||
|
||||
scale = dim_head ** -0.5
|
||||
|
||||
h = heads
|
||||
if skip_reshape:
|
||||
q, k, v = map(
|
||||
lambda t: t.reshape(b * heads, -1, dim_head),
|
||||
@@ -299,7 +298,10 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
||||
if len(mask.shape) == 2:
|
||||
s1 += mask[i:end]
|
||||
else:
|
||||
s1 += mask[:, i:end]
|
||||
if mask.shape[1] == 1:
|
||||
s1 += mask
|
||||
else:
|
||||
s1 += mask[:, i:end]
|
||||
|
||||
s2 = s1.softmax(dim=-1).to(v.dtype)
|
||||
del s1
|
||||
@@ -341,12 +343,9 @@ except:
|
||||
pass
|
||||
|
||||
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
|
||||
if skip_reshape:
|
||||
b, _, _, dim_head = q.shape
|
||||
else:
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
|
||||
b = q.shape[0]
|
||||
dim_head = q.shape[-1]
|
||||
# check to make sure xformers isn't broken
|
||||
disabled_xformers = False
|
||||
|
||||
if BROKEN_XFORMERS:
|
||||
@@ -358,41 +357,57 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
|
||||
disabled_xformers = True
|
||||
|
||||
if disabled_xformers:
|
||||
return attention_pytorch(q, k, v, heads, mask)
|
||||
return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape)
|
||||
|
||||
if skip_reshape:
|
||||
q, k, v = map(
|
||||
lambda t: t.reshape(b * heads, -1, dim_head),
|
||||
# b h k d -> b k h d
|
||||
q, k, v = map(
|
||||
lambda t: t.permute(0, 2, 1, 3),
|
||||
(q, k, v),
|
||||
)
|
||||
# actually do the reshaping
|
||||
else:
|
||||
dim_head //= heads
|
||||
q, k, v = map(
|
||||
lambda t: t.reshape(b, -1, heads, dim_head),
|
||||
(q, k, v),
|
||||
)
|
||||
|
||||
if mask is not None:
|
||||
pad = 8 - q.shape[1] % 8
|
||||
mask_out = torch.empty([q.shape[0], q.shape[1], q.shape[1] + pad], dtype=q.dtype, device=q.device)
|
||||
mask_out[:, :, :mask.shape[-1]] = mask
|
||||
mask = mask_out[:, :, :mask.shape[-1]]
|
||||
# add a singleton batch dimension
|
||||
if mask.ndim == 2:
|
||||
mask = mask.unsqueeze(0)
|
||||
# add a singleton heads dimension
|
||||
if mask.ndim == 3:
|
||||
mask = mask.unsqueeze(1)
|
||||
# pad to a multiple of 8
|
||||
pad = 8 - mask.shape[-1] % 8
|
||||
# the xformers docs says that it's allowed to have a mask of shape (1, Nq, Nk)
|
||||
# but when using separated heads, the shape has to be (B, H, Nq, Nk)
|
||||
# in flux, this matrix ends up being over 1GB
|
||||
# here, we create a mask with the same batch/head size as the input mask (potentially singleton or full)
|
||||
mask_out = torch.empty([mask.shape[0], mask.shape[1], q.shape[1], mask.shape[-1] + pad], dtype=q.dtype, device=q.device)
|
||||
|
||||
mask_out[..., :mask.shape[-1]] = mask
|
||||
# doesn't this remove the padding again??
|
||||
mask = mask_out[..., :mask.shape[-1]]
|
||||
mask = mask.expand(b, heads, -1, -1)
|
||||
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
|
||||
|
||||
if skip_reshape:
|
||||
out = (
|
||||
out.unsqueeze(0)
|
||||
.reshape(b, heads, -1, dim_head)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(b, -1, heads * dim_head)
|
||||
)
|
||||
else:
|
||||
out = (
|
||||
out.reshape(b, -1, heads * dim_head)
|
||||
)
|
||||
out = (
|
||||
out.reshape(b, -1, heads * dim_head)
|
||||
)
|
||||
|
||||
return out
|
||||
|
||||
if model_management.is_nvidia(): #pytorch 2.3 and up seem to have this issue.
|
||||
SDP_BATCH_LIMIT = 2**15
|
||||
else:
|
||||
#TODO: other GPUs ?
|
||||
SDP_BATCH_LIMIT = 2**31
|
||||
|
||||
|
||||
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
|
||||
if skip_reshape:
|
||||
b, _, _, dim_head = q.shape
|
||||
@@ -404,27 +419,85 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
||||
(q, k, v),
|
||||
)
|
||||
|
||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||
out = (
|
||||
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||
)
|
||||
if mask is not None:
|
||||
# add a batch dimension if there isn't already one
|
||||
if mask.ndim == 2:
|
||||
mask = mask.unsqueeze(0)
|
||||
# add a heads dimension if there isn't already one
|
||||
if mask.ndim == 3:
|
||||
mask = mask.unsqueeze(1)
|
||||
|
||||
if SDP_BATCH_LIMIT >= b:
|
||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||
out = (
|
||||
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||
)
|
||||
else:
|
||||
out = torch.empty((b, q.shape[2], heads * dim_head), dtype=q.dtype, layout=q.layout, device=q.device)
|
||||
for i in range(0, b, SDP_BATCH_LIMIT):
|
||||
m = mask
|
||||
if mask is not None:
|
||||
if mask.shape[0] > 1:
|
||||
m = mask[i : i + SDP_BATCH_LIMIT]
|
||||
|
||||
out[i : i + SDP_BATCH_LIMIT] = torch.nn.functional.scaled_dot_product_attention(
|
||||
q[i : i + SDP_BATCH_LIMIT],
|
||||
k[i : i + SDP_BATCH_LIMIT],
|
||||
v[i : i + SDP_BATCH_LIMIT],
|
||||
attn_mask=m,
|
||||
dropout_p=0.0, is_causal=False
|
||||
).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
|
||||
return out
|
||||
|
||||
|
||||
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
|
||||
if skip_reshape:
|
||||
b, _, _, dim_head = q.shape
|
||||
tensor_layout="HND"
|
||||
else:
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
q, k, v = map(
|
||||
lambda t: t.view(b, -1, heads, dim_head),
|
||||
(q, k, v),
|
||||
)
|
||||
tensor_layout="NHD"
|
||||
|
||||
if mask is not None:
|
||||
# add a batch dimension if there isn't already one
|
||||
if mask.ndim == 2:
|
||||
mask = mask.unsqueeze(0)
|
||||
# add a heads dimension if there isn't already one
|
||||
if mask.ndim == 3:
|
||||
mask = mask.unsqueeze(1)
|
||||
|
||||
out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
|
||||
if tensor_layout == "HND":
|
||||
out = (
|
||||
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||
)
|
||||
else:
|
||||
out = out.reshape(b, -1, heads * dim_head)
|
||||
return out
|
||||
|
||||
|
||||
optimized_attention = attention_basic
|
||||
|
||||
if model_management.xformers_enabled():
|
||||
logging.info("Using xformers cross attention")
|
||||
if model_management.sage_attention_enabled():
|
||||
logging.info("Using sage attention")
|
||||
optimized_attention = attention_sage
|
||||
elif model_management.xformers_enabled():
|
||||
logging.info("Using xformers attention")
|
||||
optimized_attention = attention_xformers
|
||||
elif model_management.pytorch_attention_enabled():
|
||||
logging.info("Using pytorch cross attention")
|
||||
logging.info("Using pytorch attention")
|
||||
optimized_attention = attention_pytorch
|
||||
else:
|
||||
if args.use_split_cross_attention:
|
||||
logging.info("Using split optimization for cross attention")
|
||||
logging.info("Using split optimization for attention")
|
||||
optimized_attention = attention_split
|
||||
else:
|
||||
logging.info("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
|
||||
logging.info("Using sub quadratic optimization for attention, if you have memory or speed issues try using: --use-split-cross-attention")
|
||||
optimized_attention = attention_sub_quad
|
||||
|
||||
optimized_attention_masked = optimized_attention
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import logging
|
||||
import math
|
||||
from typing import Dict, Optional
|
||||
from functools import partial
|
||||
from typing import Dict, Optional, List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .. import attention
|
||||
from ..attention import optimized_attention
|
||||
from einops import rearrange, repeat
|
||||
from .util import timestep_embedding
|
||||
import comfy.ops
|
||||
@@ -72,45 +71,33 @@ class PatchEmbed(nn.Module):
|
||||
strict_img_size: bool = True,
|
||||
dynamic_img_pad: bool = True,
|
||||
padding_mode='circular',
|
||||
conv3d=False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = (patch_size, patch_size)
|
||||
try:
|
||||
len(patch_size)
|
||||
self.patch_size = patch_size
|
||||
except:
|
||||
if conv3d:
|
||||
self.patch_size = (patch_size, patch_size, patch_size)
|
||||
else:
|
||||
self.patch_size = (patch_size, patch_size)
|
||||
self.padding_mode = padding_mode
|
||||
if img_size is not None:
|
||||
self.img_size = (img_size, img_size)
|
||||
self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size)])
|
||||
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
||||
else:
|
||||
self.img_size = None
|
||||
self.grid_size = None
|
||||
self.num_patches = None
|
||||
|
||||
# flatten spatial dim and transpose to channels last, kept for bwd compat
|
||||
self.flatten = flatten
|
||||
self.strict_img_size = strict_img_size
|
||||
self.dynamic_img_pad = dynamic_img_pad
|
||||
|
||||
self.proj = operations.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, dtype=dtype, device=device)
|
||||
if conv3d:
|
||||
self.proj = operations.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, dtype=dtype, device=device)
|
||||
else:
|
||||
self.proj = operations.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, dtype=dtype, device=device)
|
||||
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
B, C, H, W = x.shape
|
||||
# if self.img_size is not None:
|
||||
# if self.strict_img_size:
|
||||
# _assert(H == self.img_size[0], f"Input height ({H}) doesn't match model ({self.img_size[0]}).")
|
||||
# _assert(W == self.img_size[1], f"Input width ({W}) doesn't match model ({self.img_size[1]}).")
|
||||
# elif not self.dynamic_img_pad:
|
||||
# _assert(
|
||||
# H % self.patch_size[0] == 0,
|
||||
# f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})."
|
||||
# )
|
||||
# _assert(
|
||||
# W % self.patch_size[1] == 0,
|
||||
# f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
|
||||
# )
|
||||
if self.dynamic_img_pad:
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size, padding_mode=self.padding_mode)
|
||||
x = self.proj(x)
|
||||
@@ -266,8 +253,6 @@ def split_qkv(qkv, head_dim):
|
||||
qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, -1, head_dim).movedim(2, 0)
|
||||
return qkv[0], qkv[1], qkv[2]
|
||||
|
||||
def optimized_attention(qkv, num_heads):
|
||||
return attention.optimized_attention(qkv[0], qkv[1], qkv[2], num_heads)
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")
|
||||
@@ -326,9 +311,9 @@ class SelfAttention(nn.Module):
|
||||
return x
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
qkv = self.pre_attention(x)
|
||||
q, k, v = self.pre_attention(x)
|
||||
x = optimized_attention(
|
||||
qkv, num_heads=self.num_heads
|
||||
q, k, v, heads=self.num_heads
|
||||
)
|
||||
x = self.post_attention(x)
|
||||
return x
|
||||
@@ -355,29 +340,9 @@ class RMSNorm(torch.nn.Module):
|
||||
else:
|
||||
self.register_parameter("weight", None)
|
||||
|
||||
def _norm(self, x):
|
||||
"""
|
||||
Apply the RMSNorm normalization to the input tensor.
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor.
|
||||
Returns:
|
||||
torch.Tensor: The normalized tensor.
|
||||
"""
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward pass through the RMSNorm layer.
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor.
|
||||
Returns:
|
||||
torch.Tensor: The output tensor after applying RMSNorm.
|
||||
"""
|
||||
x = self._norm(x)
|
||||
if self.learnable_scale:
|
||||
return x * self.weight.to(device=x.device, dtype=x.dtype)
|
||||
else:
|
||||
return x
|
||||
return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)
|
||||
|
||||
|
||||
|
||||
class SwiGLUFeedForward(nn.Module):
|
||||
@@ -437,6 +402,7 @@ class DismantledBlock(nn.Module):
|
||||
scale_mod_only: bool = False,
|
||||
swiglu: bool = False,
|
||||
qk_norm: Optional[str] = None,
|
||||
x_block_self_attn: bool = False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
@@ -460,6 +426,24 @@ class DismantledBlock(nn.Module):
|
||||
device=device,
|
||||
operations=operations
|
||||
)
|
||||
if x_block_self_attn:
|
||||
assert not pre_only
|
||||
assert not scale_mod_only
|
||||
self.x_block_self_attn = True
|
||||
self.attn2 = SelfAttention(
|
||||
dim=hidden_size,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
attn_mode=attn_mode,
|
||||
pre_only=False,
|
||||
qk_norm=qk_norm,
|
||||
rmsnorm=rmsnorm,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
)
|
||||
else:
|
||||
self.x_block_self_attn = False
|
||||
if not pre_only:
|
||||
if not rmsnorm:
|
||||
self.norm2 = operations.LayerNorm(
|
||||
@@ -486,7 +470,11 @@ class DismantledBlock(nn.Module):
|
||||
multiple_of=256,
|
||||
)
|
||||
self.scale_mod_only = scale_mod_only
|
||||
if not scale_mod_only:
|
||||
if x_block_self_attn:
|
||||
assert not pre_only
|
||||
assert not scale_mod_only
|
||||
n_mods = 9
|
||||
elif not scale_mod_only:
|
||||
n_mods = 6 if not pre_only else 2
|
||||
else:
|
||||
n_mods = 4 if not pre_only else 1
|
||||
@@ -547,14 +535,64 @@ class DismantledBlock(nn.Module):
|
||||
)
|
||||
return x
|
||||
|
||||
def pre_attention_x(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
||||
assert self.x_block_self_attn
|
||||
(
|
||||
shift_msa,
|
||||
scale_msa,
|
||||
gate_msa,
|
||||
shift_mlp,
|
||||
scale_mlp,
|
||||
gate_mlp,
|
||||
shift_msa2,
|
||||
scale_msa2,
|
||||
gate_msa2,
|
||||
) = self.adaLN_modulation(c).chunk(9, dim=1)
|
||||
x_norm = self.norm1(x)
|
||||
qkv = self.attn.pre_attention(modulate(x_norm, shift_msa, scale_msa))
|
||||
qkv2 = self.attn2.pre_attention(modulate(x_norm, shift_msa2, scale_msa2))
|
||||
return qkv, qkv2, (
|
||||
x,
|
||||
gate_msa,
|
||||
shift_mlp,
|
||||
scale_mlp,
|
||||
gate_mlp,
|
||||
gate_msa2,
|
||||
)
|
||||
|
||||
def post_attention_x(self, attn, attn2, x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2):
|
||||
assert not self.pre_only
|
||||
attn1 = self.attn.post_attention(attn)
|
||||
attn2 = self.attn2.post_attention(attn2)
|
||||
out1 = gate_msa.unsqueeze(1) * attn1
|
||||
out2 = gate_msa2.unsqueeze(1) * attn2
|
||||
x = x + out1
|
||||
x = x + out2
|
||||
x = x + gate_mlp.unsqueeze(1) * self.mlp(
|
||||
modulate(self.norm2(x), shift_mlp, scale_mlp)
|
||||
)
|
||||
return x
|
||||
|
||||
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
||||
assert not self.pre_only
|
||||
qkv, intermediates = self.pre_attention(x, c)
|
||||
attn = optimized_attention(
|
||||
qkv,
|
||||
num_heads=self.attn.num_heads,
|
||||
)
|
||||
return self.post_attention(attn, *intermediates)
|
||||
if self.x_block_self_attn:
|
||||
qkv, qkv2, intermediates = self.pre_attention_x(x, c)
|
||||
attn, _ = optimized_attention(
|
||||
qkv[0], qkv[1], qkv[2],
|
||||
num_heads=self.attn.num_heads,
|
||||
)
|
||||
attn2, _ = optimized_attention(
|
||||
qkv2[0], qkv2[1], qkv2[2],
|
||||
num_heads=self.attn2.num_heads,
|
||||
)
|
||||
return self.post_attention_x(attn, attn2, *intermediates)
|
||||
else:
|
||||
qkv, intermediates = self.pre_attention(x, c)
|
||||
attn = optimized_attention(
|
||||
qkv[0], qkv[1], qkv[2],
|
||||
heads=self.attn.num_heads,
|
||||
)
|
||||
return self.post_attention(attn, *intermediates)
|
||||
|
||||
|
||||
def block_mixing(*args, use_checkpoint=True, **kwargs):
|
||||
@@ -569,7 +607,10 @@ def block_mixing(*args, use_checkpoint=True, **kwargs):
|
||||
def _block_mixing(context, x, context_block, x_block, c):
|
||||
context_qkv, context_intermediates = context_block.pre_attention(context, c)
|
||||
|
||||
x_qkv, x_intermediates = x_block.pre_attention(x, c)
|
||||
if x_block.x_block_self_attn:
|
||||
x_qkv, x_qkv2, x_intermediates = x_block.pre_attention_x(x, c)
|
||||
else:
|
||||
x_qkv, x_intermediates = x_block.pre_attention(x, c)
|
||||
|
||||
o = []
|
||||
for t in range(3):
|
||||
@@ -577,8 +618,8 @@ def _block_mixing(context, x, context_block, x_block, c):
|
||||
qkv = tuple(o)
|
||||
|
||||
attn = optimized_attention(
|
||||
qkv,
|
||||
num_heads=x_block.attn.num_heads,
|
||||
qkv[0], qkv[1], qkv[2],
|
||||
heads=x_block.attn.num_heads,
|
||||
)
|
||||
context_attn, x_attn = (
|
||||
attn[:, : context_qkv[0].shape[1]],
|
||||
@@ -590,7 +631,14 @@ def _block_mixing(context, x, context_block, x_block, c):
|
||||
|
||||
else:
|
||||
context = None
|
||||
x = x_block.post_attention(x_attn, *x_intermediates)
|
||||
if x_block.x_block_self_attn:
|
||||
attn2 = optimized_attention(
|
||||
x_qkv2[0], x_qkv2[1], x_qkv2[2],
|
||||
heads=x_block.attn2.num_heads,
|
||||
)
|
||||
x = x_block.post_attention_x(x_attn, attn2, *x_intermediates)
|
||||
else:
|
||||
x = x_block.post_attention(x_attn, *x_intermediates)
|
||||
return context, x
|
||||
|
||||
|
||||
@@ -605,8 +653,13 @@ class JointBlock(nn.Module):
|
||||
super().__init__()
|
||||
pre_only = kwargs.pop("pre_only")
|
||||
qk_norm = kwargs.pop("qk_norm", None)
|
||||
x_block_self_attn = kwargs.pop("x_block_self_attn", False)
|
||||
self.context_block = DismantledBlock(*args, pre_only=pre_only, qk_norm=qk_norm, **kwargs)
|
||||
self.x_block = DismantledBlock(*args, pre_only=False, qk_norm=qk_norm, **kwargs)
|
||||
self.x_block = DismantledBlock(*args,
|
||||
pre_only=False,
|
||||
qk_norm=qk_norm,
|
||||
x_block_self_attn=x_block_self_attn,
|
||||
**kwargs)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return block_mixing(
|
||||
@@ -662,7 +715,7 @@ class SelfAttentionContext(nn.Module):
|
||||
def forward(self, x):
|
||||
qkv = self.qkv(x)
|
||||
q, k, v = split_qkv(qkv, self.dim_head)
|
||||
x = optimized_attention((q.reshape(q.shape[0], q.shape[1], -1), k, v), self.heads)
|
||||
x = optimized_attention(q.reshape(q.shape[0], q.shape[1], -1), k, v, heads=self.heads)
|
||||
return self.proj(x)
|
||||
|
||||
class ContextProcessorBlock(nn.Module):
|
||||
@@ -721,9 +774,12 @@ class MMDiT(nn.Module):
|
||||
qk_norm: Optional[str] = None,
|
||||
qkv_bias: bool = True,
|
||||
context_processor_layers = None,
|
||||
x_block_self_attn: bool = False,
|
||||
x_block_self_attn_layers: Optional[List[int]] = [],
|
||||
context_size = 4096,
|
||||
num_blocks = None,
|
||||
final_layer = True,
|
||||
skip_blocks = False,
|
||||
dtype = None, #TODO
|
||||
device = None,
|
||||
operations = None,
|
||||
@@ -738,6 +794,7 @@ class MMDiT(nn.Module):
|
||||
self.pos_embed_scaling_factor = pos_embed_scaling_factor
|
||||
self.pos_embed_offset = pos_embed_offset
|
||||
self.pos_embed_max_size = pos_embed_max_size
|
||||
self.x_block_self_attn_layers = x_block_self_attn_layers
|
||||
|
||||
# hidden_size = default(hidden_size, 64 * depth)
|
||||
# num_heads = default(num_heads, hidden_size // 64)
|
||||
@@ -795,26 +852,28 @@ class MMDiT(nn.Module):
|
||||
self.pos_embed = None
|
||||
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.joint_blocks = nn.ModuleList(
|
||||
[
|
||||
JointBlock(
|
||||
self.hidden_size,
|
||||
num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
attn_mode=attn_mode,
|
||||
pre_only=(i == num_blocks - 1) and final_layer,
|
||||
rmsnorm=rmsnorm,
|
||||
scale_mod_only=scale_mod_only,
|
||||
swiglu=swiglu,
|
||||
qk_norm=qk_norm,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
)
|
||||
for i in range(num_blocks)
|
||||
]
|
||||
)
|
||||
if not skip_blocks:
|
||||
self.joint_blocks = nn.ModuleList(
|
||||
[
|
||||
JointBlock(
|
||||
self.hidden_size,
|
||||
num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
attn_mode=attn_mode,
|
||||
pre_only=(i == num_blocks - 1) and final_layer,
|
||||
rmsnorm=rmsnorm,
|
||||
scale_mod_only=scale_mod_only,
|
||||
swiglu=swiglu,
|
||||
qk_norm=qk_norm,
|
||||
x_block_self_attn=(i in self.x_block_self_attn_layers) or x_block_self_attn,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
for i in range(num_blocks)
|
||||
]
|
||||
)
|
||||
|
||||
if final_layer:
|
||||
self.final_layer = FinalLayer(self.hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||
@@ -877,7 +936,9 @@ class MMDiT(nn.Module):
|
||||
c_mod: torch.Tensor,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
control = None,
|
||||
transformer_options = {},
|
||||
) -> torch.Tensor:
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
if self.register_length > 0:
|
||||
context = torch.cat(
|
||||
(
|
||||
@@ -889,14 +950,25 @@ class MMDiT(nn.Module):
|
||||
|
||||
# context is B, L', D
|
||||
# x is B, L, D
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
blocks = len(self.joint_blocks)
|
||||
for i in range(blocks):
|
||||
context, x = self.joint_blocks[i](
|
||||
context,
|
||||
x,
|
||||
c=c_mod,
|
||||
use_checkpoint=self.use_checkpoint,
|
||||
)
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": c_mod}, {"original_block": block_wrap})
|
||||
context = out["txt"]
|
||||
x = out["img"]
|
||||
else:
|
||||
context, x = self.joint_blocks[i](
|
||||
context,
|
||||
x,
|
||||
c=c_mod,
|
||||
use_checkpoint=self.use_checkpoint,
|
||||
)
|
||||
if control is not None:
|
||||
control_o = control.get("output")
|
||||
if i < len(control_o):
|
||||
@@ -914,6 +986,7 @@ class MMDiT(nn.Module):
|
||||
y: Optional[torch.Tensor] = None,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
control = None,
|
||||
transformer_options = {},
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass of DiT.
|
||||
@@ -935,7 +1008,7 @@ class MMDiT(nn.Module):
|
||||
if context is not None:
|
||||
context = self.context_embedder(context)
|
||||
|
||||
x = self.forward_core_with_concat(x, c, context, control)
|
||||
x = self.forward_core_with_concat(x, c, context, control, transformer_options)
|
||||
|
||||
x = self.unpatchify(x, hw=hw) # (N, out_channels, H, W)
|
||||
return x[:,:,:hw[-2],:hw[-1]]
|
||||
@@ -949,7 +1022,8 @@ class OpenAISignatureMMDITWrapper(MMDiT):
|
||||
context: Optional[torch.Tensor] = None,
|
||||
y: Optional[torch.Tensor] = None,
|
||||
control = None,
|
||||
transformer_options = {},
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
return super().forward(x, timesteps, context=context, y=y, control=control)
|
||||
return super().forward(x, timesteps, context=context, y=y, control=control, transformer_options=transformer_options)
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from typing import Optional, Any
|
||||
import logging
|
||||
|
||||
from comfy import model_management
|
||||
@@ -44,51 +43,100 @@ def Normalize(in_channels, num_groups=32):
|
||||
return ops.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
|
||||
class VideoConv3d(nn.Module):
|
||||
def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding_mode='replicate', padding=1, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
self.padding_mode = padding_mode
|
||||
if padding != 0:
|
||||
padding = (padding, padding, padding, padding, kernel_size - 1, 0)
|
||||
else:
|
||||
kwargs["padding"] = padding
|
||||
|
||||
self.padding = padding
|
||||
self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
if self.padding != 0:
|
||||
x = torch.nn.functional.pad(x, self.padding, mode=self.padding_mode)
|
||||
return self.conv(x)
|
||||
|
||||
def interpolate_up(x, scale_factor):
|
||||
try:
|
||||
return torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="nearest")
|
||||
except: #operation not implemented for bf16
|
||||
orig_shape = list(x.shape)
|
||||
out_shape = orig_shape[:2]
|
||||
for i in range(len(orig_shape) - 2):
|
||||
out_shape.append(round(orig_shape[i + 2] * scale_factor[i]))
|
||||
out = torch.empty(out_shape, dtype=x.dtype, layout=x.layout, device=x.device)
|
||||
split = 8
|
||||
l = out.shape[1] // split
|
||||
for i in range(0, out.shape[1], l):
|
||||
out[:,i:i+l] = torch.nn.functional.interpolate(x[:,i:i+l].to(torch.float32), scale_factor=scale_factor, mode="nearest").to(x.dtype)
|
||||
return out
|
||||
|
||||
class Upsample(nn.Module):
|
||||
def __init__(self, in_channels, with_conv):
|
||||
def __init__(self, in_channels, with_conv, conv_op=ops.Conv2d, scale_factor=2.0):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
self.scale_factor = scale_factor
|
||||
|
||||
if self.with_conv:
|
||||
self.conv = ops.Conv2d(in_channels,
|
||||
self.conv = conv_op(in_channels,
|
||||
in_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
try:
|
||||
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
except: #operation not implemented for bf16
|
||||
b, c, h, w = x.shape
|
||||
out = torch.empty((b, c, h*2, w*2), dtype=x.dtype, layout=x.layout, device=x.device)
|
||||
split = 8
|
||||
l = out.shape[1] // split
|
||||
for i in range(0, out.shape[1], l):
|
||||
out[:,i:i+l] = torch.nn.functional.interpolate(x[:,i:i+l].to(torch.float32), scale_factor=2.0, mode="nearest").to(x.dtype)
|
||||
del x
|
||||
x = out
|
||||
scale_factor = self.scale_factor
|
||||
if isinstance(scale_factor, (int, float)):
|
||||
scale_factor = (scale_factor,) * (x.ndim - 2)
|
||||
|
||||
if x.ndim == 5 and scale_factor[0] > 1.0:
|
||||
t = x.shape[2]
|
||||
if t > 1:
|
||||
a, b = x.split((1, t - 1), dim=2)
|
||||
del x
|
||||
b = interpolate_up(b, scale_factor)
|
||||
else:
|
||||
a = x
|
||||
|
||||
a = interpolate_up(a.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2)
|
||||
if t > 1:
|
||||
x = torch.cat((a, b), dim=2)
|
||||
else:
|
||||
x = a
|
||||
else:
|
||||
x = interpolate_up(x, scale_factor)
|
||||
if self.with_conv:
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
def __init__(self, in_channels, with_conv):
|
||||
def __init__(self, in_channels, with_conv, stride=2, conv_op=ops.Conv2d):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
# no asymmetric padding in torch conv, must do it ourselves
|
||||
self.conv = ops.Conv2d(in_channels,
|
||||
self.conv = conv_op(in_channels,
|
||||
in_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
stride=stride,
|
||||
padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
if self.with_conv:
|
||||
pad = (0,1,0,1)
|
||||
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
||||
if x.ndim == 4:
|
||||
pad = (0, 1, 0, 1)
|
||||
mode = "constant"
|
||||
x = torch.nn.functional.pad(x, pad, mode=mode, value=0)
|
||||
elif x.ndim == 5:
|
||||
pad = (1, 1, 1, 1, 2, 0)
|
||||
mode = "replicate"
|
||||
x = torch.nn.functional.pad(x, pad, mode=mode)
|
||||
x = self.conv(x)
|
||||
else:
|
||||
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
||||
@@ -97,7 +145,7 @@ class Downsample(nn.Module):
|
||||
|
||||
class ResnetBlock(nn.Module):
|
||||
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
|
||||
dropout, temb_channels=512):
|
||||
dropout, temb_channels=512, conv_op=ops.Conv2d):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
@@ -106,7 +154,7 @@ class ResnetBlock(nn.Module):
|
||||
|
||||
self.swish = torch.nn.SiLU(inplace=True)
|
||||
self.norm1 = Normalize(in_channels)
|
||||
self.conv1 = ops.Conv2d(in_channels,
|
||||
self.conv1 = conv_op(in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
@@ -116,20 +164,20 @@ class ResnetBlock(nn.Module):
|
||||
out_channels)
|
||||
self.norm2 = Normalize(out_channels)
|
||||
self.dropout = torch.nn.Dropout(dropout, inplace=True)
|
||||
self.conv2 = ops.Conv2d(out_channels,
|
||||
self.conv2 = conv_op(out_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
self.conv_shortcut = ops.Conv2d(in_channels,
|
||||
self.conv_shortcut = conv_op(in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
else:
|
||||
self.nin_shortcut = ops.Conv2d(in_channels,
|
||||
self.nin_shortcut = conv_op(in_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
@@ -163,7 +211,6 @@ def slice_attention(q, k, v):
|
||||
|
||||
mem_free_total = model_management.get_free_memory(q.device)
|
||||
|
||||
gb = 1024 ** 3
|
||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
|
||||
modifier = 3 if q.element_size() == 2 else 2.5
|
||||
mem_required = tensor_size * modifier
|
||||
@@ -196,21 +243,25 @@ def slice_attention(q, k, v):
|
||||
|
||||
def normal_attention(q, k, v):
|
||||
# compute attention
|
||||
b,c,h,w = q.shape
|
||||
orig_shape = q.shape
|
||||
b = orig_shape[0]
|
||||
c = orig_shape[1]
|
||||
|
||||
q = q.reshape(b,c,h*w)
|
||||
q = q.permute(0,2,1) # b,hw,c
|
||||
k = k.reshape(b,c,h*w) # b,c,hw
|
||||
v = v.reshape(b,c,h*w)
|
||||
q = q.reshape(b, c, -1)
|
||||
q = q.permute(0, 2, 1) # b,hw,c
|
||||
k = k.reshape(b, c, -1) # b,c,hw
|
||||
v = v.reshape(b, c, -1)
|
||||
|
||||
r1 = slice_attention(q, k, v)
|
||||
h_ = r1.reshape(b,c,h,w)
|
||||
h_ = r1.reshape(orig_shape)
|
||||
del r1
|
||||
return h_
|
||||
|
||||
def xformers_attention(q, k, v):
|
||||
# compute attention
|
||||
B, C, H, W = q.shape
|
||||
orig_shape = q.shape
|
||||
B = orig_shape[0]
|
||||
C = orig_shape[1]
|
||||
q, k, v = map(
|
||||
lambda t: t.view(B, C, -1).transpose(1, 2).contiguous(),
|
||||
(q, k, v),
|
||||
@@ -218,14 +269,16 @@ def xformers_attention(q, k, v):
|
||||
|
||||
try:
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
|
||||
out = out.transpose(1, 2).reshape(B, C, H, W)
|
||||
except NotImplementedError as e:
|
||||
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
|
||||
out = out.transpose(1, 2).reshape(orig_shape)
|
||||
except NotImplementedError:
|
||||
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(orig_shape)
|
||||
return out
|
||||
|
||||
def pytorch_attention(q, k, v):
|
||||
# compute attention
|
||||
B, C, H, W = q.shape
|
||||
orig_shape = q.shape
|
||||
B = orig_shape[0]
|
||||
C = orig_shape[1]
|
||||
q, k, v = map(
|
||||
lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
|
||||
(q, k, v),
|
||||
@@ -233,35 +286,35 @@ def pytorch_attention(q, k, v):
|
||||
|
||||
try:
|
||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
|
||||
out = out.transpose(2, 3).reshape(B, C, H, W)
|
||||
except model_management.OOM_EXCEPTION as e:
|
||||
out = out.transpose(2, 3).reshape(orig_shape)
|
||||
except model_management.OOM_EXCEPTION:
|
||||
logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
|
||||
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
|
||||
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(orig_shape)
|
||||
return out
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
|
||||
def __init__(self, in_channels):
|
||||
def __init__(self, in_channels, conv_op=ops.Conv2d):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = Normalize(in_channels)
|
||||
self.q = ops.Conv2d(in_channels,
|
||||
self.q = conv_op(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.k = ops.Conv2d(in_channels,
|
||||
self.k = conv_op(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.v = ops.Conv2d(in_channels,
|
||||
self.v = conv_op(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.proj_out = ops.Conv2d(in_channels,
|
||||
self.proj_out = conv_op(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
@@ -291,8 +344,8 @@ class AttnBlock(nn.Module):
|
||||
return x+h_
|
||||
|
||||
|
||||
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
|
||||
return AttnBlock(in_channels)
|
||||
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None, conv_op=ops.Conv2d):
|
||||
return AttnBlock(in_channels, conv_op=conv_op)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
@@ -451,6 +504,7 @@ class Encoder(nn.Module):
|
||||
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
||||
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
||||
resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
|
||||
conv3d=False, time_compress=None,
|
||||
**ignore_kwargs):
|
||||
super().__init__()
|
||||
if use_linear_attn: attn_type = "linear"
|
||||
@@ -461,8 +515,15 @@ class Encoder(nn.Module):
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
|
||||
if conv3d:
|
||||
conv_op = VideoConv3d
|
||||
mid_attn_conv_op = ops.Conv3d
|
||||
else:
|
||||
conv_op = ops.Conv2d
|
||||
mid_attn_conv_op = ops.Conv2d
|
||||
|
||||
# downsampling
|
||||
self.conv_in = ops.Conv2d(in_channels,
|
||||
self.conv_in = conv_op(in_channels,
|
||||
self.ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
@@ -481,15 +542,20 @@ class Encoder(nn.Module):
|
||||
block.append(ResnetBlock(in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout))
|
||||
dropout=dropout,
|
||||
conv_op=conv_op))
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(make_attn(block_in, attn_type=attn_type))
|
||||
attn.append(make_attn(block_in, attn_type=attn_type, conv_op=conv_op))
|
||||
down = nn.Module()
|
||||
down.block = block
|
||||
down.attn = attn
|
||||
if i_level != self.num_resolutions-1:
|
||||
down.downsample = Downsample(block_in, resamp_with_conv)
|
||||
stride = 2
|
||||
if time_compress is not None:
|
||||
if (self.num_resolutions - 1 - i_level) > math.log2(time_compress):
|
||||
stride = (1, 2, 2)
|
||||
down.downsample = Downsample(block_in, resamp_with_conv, stride=stride, conv_op=conv_op)
|
||||
curr_res = curr_res // 2
|
||||
self.down.append(down)
|
||||
|
||||
@@ -498,16 +564,18 @@ class Encoder(nn.Module):
|
||||
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
|
||||
dropout=dropout,
|
||||
conv_op=conv_op)
|
||||
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type, conv_op=mid_attn_conv_op)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
dropout=dropout,
|
||||
conv_op=conv_op)
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = ops.Conv2d(block_in,
|
||||
self.conv_out = conv_op(block_in,
|
||||
2*z_channels if double_z else z_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
@@ -545,9 +613,10 @@ class Decoder(nn.Module):
|
||||
conv_out_op=ops.Conv2d,
|
||||
resnet_op=ResnetBlock,
|
||||
attn_op=AttnBlock,
|
||||
conv3d=False,
|
||||
time_compress=None,
|
||||
**ignorekwargs):
|
||||
super().__init__()
|
||||
if use_linear_attn: attn_type = "linear"
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
@@ -557,8 +626,15 @@ class Decoder(nn.Module):
|
||||
self.give_pre_end = give_pre_end
|
||||
self.tanh_out = tanh_out
|
||||
|
||||
# compute in_ch_mult, block_in and curr_res at lowest res
|
||||
in_ch_mult = (1,)+tuple(ch_mult)
|
||||
if conv3d:
|
||||
conv_op = VideoConv3d
|
||||
conv_out_op = VideoConv3d
|
||||
mid_attn_conv_op = ops.Conv3d
|
||||
else:
|
||||
conv_op = ops.Conv2d
|
||||
mid_attn_conv_op = ops.Conv2d
|
||||
|
||||
# compute block_in and curr_res at lowest res
|
||||
block_in = ch*ch_mult[self.num_resolutions-1]
|
||||
curr_res = resolution // 2**(self.num_resolutions-1)
|
||||
self.z_shape = (1,z_channels,curr_res,curr_res)
|
||||
@@ -566,7 +642,7 @@ class Decoder(nn.Module):
|
||||
self.z_shape, np.prod(self.z_shape)))
|
||||
|
||||
# z to block_in
|
||||
self.conv_in = ops.Conv2d(z_channels,
|
||||
self.conv_in = conv_op(z_channels,
|
||||
block_in,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
@@ -577,12 +653,14 @@ class Decoder(nn.Module):
|
||||
self.mid.block_1 = resnet_op(in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
self.mid.attn_1 = attn_op(block_in)
|
||||
dropout=dropout,
|
||||
conv_op=conv_op)
|
||||
self.mid.attn_1 = attn_op(block_in, conv_op=mid_attn_conv_op)
|
||||
self.mid.block_2 = resnet_op(in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
dropout=dropout,
|
||||
conv_op=conv_op)
|
||||
|
||||
# upsampling
|
||||
self.up = nn.ModuleList()
|
||||
@@ -594,15 +672,21 @@ class Decoder(nn.Module):
|
||||
block.append(resnet_op(in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout))
|
||||
dropout=dropout,
|
||||
conv_op=conv_op))
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(attn_op(block_in))
|
||||
attn.append(attn_op(block_in, conv_op=conv_op))
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level != 0:
|
||||
up.upsample = Upsample(block_in, resamp_with_conv)
|
||||
scale_factor = 2.0
|
||||
if time_compress is not None:
|
||||
if i_level > math.log2(time_compress):
|
||||
scale_factor = (1.0, 2.0, 2.0)
|
||||
|
||||
up.upsample = Upsample(block_in, resamp_with_conv, conv_op=conv_op, scale_factor=scale_factor)
|
||||
curr_res = curr_res * 2
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
|
||||
|
||||
@@ -9,12 +9,12 @@ import logging
|
||||
from .util import (
|
||||
checkpoint,
|
||||
avg_pool_nd,
|
||||
zero_module,
|
||||
timestep_embedding,
|
||||
AlphaBlender,
|
||||
)
|
||||
from ..attention import SpatialTransformer, SpatialVideoTransformer, default
|
||||
from comfy.ldm.util import exists
|
||||
import comfy.patcher_extension
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
@@ -47,6 +47,15 @@ def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, out
|
||||
elif isinstance(layer, Upsample):
|
||||
x = layer(x, output_shape=output_shape)
|
||||
else:
|
||||
if "patches" in transformer_options and "forward_timestep_embed_patch" in transformer_options["patches"]:
|
||||
found_patched = False
|
||||
for class_type, handler in transformer_options["patches"]["forward_timestep_embed_patch"]:
|
||||
if isinstance(layer, class_type):
|
||||
x = handler(layer, x, emb, context, transformer_options, output_shape, time_context, num_video_frames, image_only_indicator)
|
||||
found_patched = True
|
||||
break
|
||||
if found_patched:
|
||||
continue
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
@@ -819,6 +828,13 @@ class UNetModel(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timesteps, context, y, control, transformer_options, **kwargs)
|
||||
|
||||
def _forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
|
||||
"""
|
||||
Apply the model to an input batch.
|
||||
:param x: an [N x C x ...] Tensor of inputs.
|
||||
@@ -842,6 +858,11 @@ class UNetModel(nn.Module):
|
||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
|
||||
emb = self.time_embed(t_emb)
|
||||
|
||||
if "emb_patch" in transformer_patches:
|
||||
patch = transformer_patches["emb_patch"]
|
||||
for p in patch:
|
||||
emb = p(emb, self.model_channels, transformer_options)
|
||||
|
||||
if self.num_classes is not None:
|
||||
assert y.shape[0] == x.shape[0]
|
||||
emb = emb + self.label_emb(y)
|
||||
|
||||
@@ -4,7 +4,6 @@ import numpy as np
|
||||
from functools import partial
|
||||
|
||||
from .util import extract_into_tensor, make_beta_schedule
|
||||
from comfy.ldm.util import default
|
||||
|
||||
|
||||
class AbstractLowScaleModel(nn.Module):
|
||||
|
||||
@@ -8,8 +8,8 @@
|
||||
# thanks!
|
||||
|
||||
|
||||
import os
|
||||
import math
|
||||
import logging
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
@@ -131,7 +131,7 @@ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timestep
|
||||
# add one to get the final alpha values right (the ones from first scale to data during sampling)
|
||||
steps_out = ddim_timesteps + 1
|
||||
if verbose:
|
||||
print(f'Selected timesteps for ddim sampler: {steps_out}')
|
||||
logging.info(f'Selected timesteps for ddim sampler: {steps_out}')
|
||||
return steps_out
|
||||
|
||||
|
||||
@@ -143,8 +143,8 @@ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
|
||||
# according the the formula provided in https://arxiv.org/abs/2010.02502
|
||||
sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
|
||||
if verbose:
|
||||
print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
|
||||
print(f'For the chosen value of eta, which is {eta}, '
|
||||
logging.info(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
|
||||
logging.info(f'For the chosen value of eta, which is {eta}, '
|
||||
f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
|
||||
return sigmas, alphas, alphas_prev
|
||||
|
||||
|
||||
@@ -30,10 +30,10 @@ class DiagonalGaussianDistribution(object):
|
||||
self.std = torch.exp(0.5 * self.logvar)
|
||||
self.var = torch.exp(self.logvar)
|
||||
if self.deterministic:
|
||||
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
|
||||
self.var = self.std = torch.zeros_like(self.mean, device=self.parameters.device)
|
||||
|
||||
def sample(self):
|
||||
x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
|
||||
x = self.mean + self.std * torch.randn(self.mean.shape, device=self.parameters.device)
|
||||
return x
|
||||
|
||||
def kl(self, other=None):
|
||||
|
||||
@@ -22,7 +22,6 @@ except ImportError:
|
||||
from typing import Optional, NamedTuple, List
|
||||
from typing_extensions import Protocol
|
||||
|
||||
from torch import Tensor
|
||||
from typing import List
|
||||
|
||||
from comfy import model_management
|
||||
@@ -172,7 +171,7 @@ def _get_attention_scores_no_kv_chunking(
|
||||
del attn_scores
|
||||
except model_management.OOM_EXCEPTION:
|
||||
logging.warning("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
|
||||
attn_scores -= attn_scores.max(dim=-1, keepdim=True).values
|
||||
attn_scores -= attn_scores.max(dim=-1, keepdim=True).values # noqa: F821 attn_scores is not defined
|
||||
torch.exp(attn_scores, out=attn_scores)
|
||||
summed = torch.sum(attn_scores, dim=-1, keepdim=True)
|
||||
attn_scores /= summed
|
||||
@@ -234,6 +233,8 @@ def efficient_dot_product_attention(
|
||||
def get_mask_chunk(chunk_idx: int) -> Tensor:
|
||||
if mask is None:
|
||||
return None
|
||||
if mask.shape[1] == 1:
|
||||
return mask
|
||||
chunk = min(query_chunk_size, q_tokens)
|
||||
return mask[:,chunk_idx:chunk_idx + chunk]
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import functools
|
||||
from typing import Callable, Iterable, Union
|
||||
from typing import Iterable, Union
|
||||
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
@@ -194,6 +194,7 @@ def make_time_attn(
|
||||
attn_kwargs=None,
|
||||
alpha: float = 0,
|
||||
merge_strategy: str = "learned",
|
||||
conv_op=ops.Conv2d,
|
||||
):
|
||||
return partialclass(
|
||||
AttnVideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy
|
||||
|
||||
380
comfy/ldm/pixart/blocks.py
Normal file
380
comfy/ldm/pixart/blocks.py
Normal file
@@ -0,0 +1,380 @@
|
||||
# Based on:
|
||||
# https://github.com/PixArt-alpha/PixArt-alpha [Apache 2.0 license]
|
||||
# https://github.com/PixArt-alpha/PixArt-sigma [Apache 2.0 license]
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, Mlp, timestep_embedding
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
# if model_management.xformers_enabled():
|
||||
# import xformers.ops
|
||||
# if int((xformers.__version__).split(".")[2].split("+")[0]) >= 28:
|
||||
# block_diagonal_mask_from_seqlens = xformers.ops.fmha.attn_bias.BlockDiagonalMask.from_seqlens
|
||||
# else:
|
||||
# block_diagonal_mask_from_seqlens = xformers.ops.fmha.BlockDiagonalMask.from_seqlens
|
||||
|
||||
def modulate(x, shift, scale):
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
|
||||
def t2i_modulate(x, shift, scale):
|
||||
return x * (1 + scale) + shift
|
||||
|
||||
class MultiHeadCrossAttention(nn.Module):
|
||||
def __init__(self, d_model, num_heads, attn_drop=0., proj_drop=0., dtype=None, device=None, operations=None, **kwargs):
|
||||
super(MultiHeadCrossAttention, self).__init__()
|
||||
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
|
||||
|
||||
self.d_model = d_model
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = d_model // num_heads
|
||||
|
||||
self.q_linear = operations.Linear(d_model, d_model, dtype=dtype, device=device)
|
||||
self.kv_linear = operations.Linear(d_model, d_model*2, dtype=dtype, device=device)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(self, x, cond, mask=None):
|
||||
# query/value: img tokens; key: condition; mask: if padding tokens
|
||||
B, N, C = x.shape
|
||||
|
||||
q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
|
||||
kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
|
||||
k, v = kv.unbind(2)
|
||||
|
||||
assert mask is None # TODO?
|
||||
# # TODO: xformers needs separate mask logic here
|
||||
# if model_management.xformers_enabled():
|
||||
# attn_bias = None
|
||||
# if mask is not None:
|
||||
# attn_bias = block_diagonal_mask_from_seqlens([N] * B, mask)
|
||||
# x = xformers.ops.memory_efficient_attention(q, k, v, p=0, attn_bias=attn_bias)
|
||||
# else:
|
||||
# q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v),)
|
||||
# attn_mask = None
|
||||
# mask = torch.ones(())
|
||||
# if mask is not None and len(mask) > 1:
|
||||
# # Create equivalent of xformer diagonal block mask, still only correct for square masks
|
||||
# # But depth doesn't matter as tensors can expand in that dimension
|
||||
# attn_mask_template = torch.ones(
|
||||
# [q.shape[2] // B, mask[0]],
|
||||
# dtype=torch.bool,
|
||||
# device=q.device
|
||||
# )
|
||||
# attn_mask = torch.block_diag(attn_mask_template)
|
||||
#
|
||||
# # create a mask on the diagonal for each mask in the batch
|
||||
# for _ in range(B - 1):
|
||||
# attn_mask = torch.block_diag(attn_mask, attn_mask_template)
|
||||
# x = optimized_attention(q, k, v, self.num_heads, mask=attn_mask, skip_reshape=True)
|
||||
|
||||
x = optimized_attention(q.view(B, -1, C), k.view(B, -1, C), v.view(B, -1, C), self.num_heads, mask=None)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class AttentionKVCompress(nn.Module):
|
||||
"""Multi-head Attention block with KV token compression and qk norm."""
|
||||
def __init__(self, dim, num_heads=8, qkv_bias=True, sampling='conv', sr_ratio=1, qk_norm=False, dtype=None, device=None, operations=None, **kwargs):
|
||||
"""
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
num_heads (int): Number of attention heads.
|
||||
qkv_bias (bool: If True, add a learnable bias to query, key, value.
|
||||
"""
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.scale = self.head_dim ** -0.5
|
||||
|
||||
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
|
||||
self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
|
||||
|
||||
self.sampling=sampling # ['conv', 'ave', 'uniform', 'uniform_every']
|
||||
self.sr_ratio = sr_ratio
|
||||
if sr_ratio > 1 and sampling == 'conv':
|
||||
# Avg Conv Init.
|
||||
self.sr = operations.Conv2d(dim, dim, groups=dim, kernel_size=sr_ratio, stride=sr_ratio, dtype=dtype, device=device)
|
||||
# self.sr.weight.data.fill_(1/sr_ratio**2)
|
||||
# self.sr.bias.data.zero_()
|
||||
self.norm = operations.LayerNorm(dim, dtype=dtype, device=device)
|
||||
if qk_norm:
|
||||
self.q_norm = operations.LayerNorm(dim, dtype=dtype, device=device)
|
||||
self.k_norm = operations.LayerNorm(dim, dtype=dtype, device=device)
|
||||
else:
|
||||
self.q_norm = nn.Identity()
|
||||
self.k_norm = nn.Identity()
|
||||
|
||||
def downsample_2d(self, tensor, H, W, scale_factor, sampling=None):
|
||||
if sampling is None or scale_factor == 1:
|
||||
return tensor
|
||||
B, N, C = tensor.shape
|
||||
|
||||
if sampling == 'uniform_every':
|
||||
return tensor[:, ::scale_factor], int(N // scale_factor)
|
||||
|
||||
tensor = tensor.reshape(B, H, W, C).permute(0, 3, 1, 2)
|
||||
new_H, new_W = int(H / scale_factor), int(W / scale_factor)
|
||||
new_N = new_H * new_W
|
||||
|
||||
if sampling == 'ave':
|
||||
tensor = F.interpolate(
|
||||
tensor, scale_factor=1 / scale_factor, mode='nearest'
|
||||
).permute(0, 2, 3, 1)
|
||||
elif sampling == 'uniform':
|
||||
tensor = tensor[:, :, ::scale_factor, ::scale_factor].permute(0, 2, 3, 1)
|
||||
elif sampling == 'conv':
|
||||
tensor = self.sr(tensor).reshape(B, C, -1).permute(0, 2, 1)
|
||||
tensor = self.norm(tensor)
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
return tensor.reshape(B, new_N, C).contiguous(), new_N
|
||||
|
||||
def forward(self, x, mask=None, HW=None, block_id=None):
|
||||
B, N, C = x.shape # 2 4096 1152
|
||||
new_N = N
|
||||
if HW is None:
|
||||
H = W = int(N ** 0.5)
|
||||
else:
|
||||
H, W = HW
|
||||
qkv = self.qkv(x).reshape(B, N, 3, C)
|
||||
|
||||
q, k, v = qkv.unbind(2)
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
|
||||
# KV compression
|
||||
if self.sr_ratio > 1:
|
||||
k, new_N = self.downsample_2d(k, H, W, self.sr_ratio, sampling=self.sampling)
|
||||
v, new_N = self.downsample_2d(v, H, W, self.sr_ratio, sampling=self.sampling)
|
||||
|
||||
q = q.reshape(B, N, self.num_heads, C // self.num_heads)
|
||||
k = k.reshape(B, new_N, self.num_heads, C // self.num_heads)
|
||||
v = v.reshape(B, new_N, self.num_heads, C // self.num_heads)
|
||||
|
||||
if mask is not None:
|
||||
raise NotImplementedError("Attn mask logic not added for self attention")
|
||||
|
||||
# This is never called at the moment
|
||||
# attn_bias = None
|
||||
# if mask is not None:
|
||||
# attn_bias = torch.zeros([B * self.num_heads, q.shape[1], k.shape[1]], dtype=q.dtype, device=q.device)
|
||||
# attn_bias.masked_fill_(mask.squeeze(1).repeat(self.num_heads, 1, 1) == 0, float('-inf'))
|
||||
|
||||
# attention 2
|
||||
q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v),)
|
||||
x = optimized_attention(q, k, v, self.num_heads, mask=None, skip_reshape=True)
|
||||
|
||||
x = x.view(B, N, C)
|
||||
x = self.proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
"""
|
||||
The final layer of PixArt.
|
||||
"""
|
||||
def __init__(self, hidden_size, patch_size, out_channels, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
def forward(self, x, c):
|
||||
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||
x = modulate(self.norm_final(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
class T2IFinalLayer(nn.Module):
|
||||
"""
|
||||
The final layer of PixArt.
|
||||
"""
|
||||
def __init__(self, hidden_size, patch_size, out_channels, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size ** 0.5)
|
||||
self.out_channels = out_channels
|
||||
|
||||
def forward(self, x, t):
|
||||
shift, scale = (self.scale_shift_table[None].to(dtype=x.dtype, device=x.device) + t[:, None]).chunk(2, dim=1)
|
||||
x = t2i_modulate(self.norm_final(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
class MaskFinalLayer(nn.Module):
|
||||
"""
|
||||
The final layer of PixArt.
|
||||
"""
|
||||
def __init__(self, final_hidden_size, c_emb_size, patch_size, out_channels, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm_final = operations.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.linear = operations.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(c_emb_size, 2 * final_hidden_size, bias=True, dtype=dtype, device=device)
|
||||
)
|
||||
def forward(self, x, t):
|
||||
shift, scale = self.adaLN_modulation(t).chunk(2, dim=1)
|
||||
x = modulate(self.norm_final(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
class DecoderLayer(nn.Module):
|
||||
"""
|
||||
The final layer of PixArt.
|
||||
"""
|
||||
def __init__(self, hidden_size, decoder_hidden_size, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm_decoder = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.linear = operations.Linear(hidden_size, decoder_hidden_size, bias=True, dtype=dtype, device=device)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device)
|
||||
)
|
||||
def forward(self, x, t):
|
||||
shift, scale = self.adaLN_modulation(t).chunk(2, dim=1)
|
||||
x = modulate(self.norm_decoder(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
class SizeEmbedder(TimestepEmbedder):
|
||||
"""
|
||||
Embeds scalar timesteps into vector representations.
|
||||
"""
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
|
||||
super().__init__(hidden_size=hidden_size, frequency_embedding_size=frequency_embedding_size, operations=operations)
|
||||
self.mlp = nn.Sequential(
|
||||
operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
self.outdim = hidden_size
|
||||
|
||||
def forward(self, s, bs):
|
||||
if s.ndim == 1:
|
||||
s = s[:, None]
|
||||
assert s.ndim == 2
|
||||
if s.shape[0] != bs:
|
||||
s = s.repeat(bs//s.shape[0], 1)
|
||||
assert s.shape[0] == bs
|
||||
b, dims = s.shape[0], s.shape[1]
|
||||
s = rearrange(s, "b d -> (b d)")
|
||||
s_freq = timestep_embedding(s, self.frequency_embedding_size)
|
||||
s_emb = self.mlp(s_freq.to(s.dtype))
|
||||
s_emb = rearrange(s_emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim)
|
||||
return s_emb
|
||||
|
||||
|
||||
class LabelEmbedder(nn.Module):
|
||||
"""
|
||||
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
|
||||
"""
|
||||
def __init__(self, num_classes, hidden_size, dropout_prob, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
use_cfg_embedding = dropout_prob > 0
|
||||
self.embedding_table = operations.Embedding(num_classes + use_cfg_embedding, hidden_size, dtype=dtype, device=device),
|
||||
self.num_classes = num_classes
|
||||
self.dropout_prob = dropout_prob
|
||||
|
||||
def token_drop(self, labels, force_drop_ids=None):
|
||||
"""
|
||||
Drops labels to enable classifier-free guidance.
|
||||
"""
|
||||
if force_drop_ids is None:
|
||||
drop_ids = torch.rand(labels.shape[0]).cuda() < self.dropout_prob
|
||||
else:
|
||||
drop_ids = force_drop_ids == 1
|
||||
labels = torch.where(drop_ids, self.num_classes, labels)
|
||||
return labels
|
||||
|
||||
def forward(self, labels, train, force_drop_ids=None):
|
||||
use_dropout = self.dropout_prob > 0
|
||||
if (train and use_dropout) or (force_drop_ids is not None):
|
||||
labels = self.token_drop(labels, force_drop_ids)
|
||||
embeddings = self.embedding_table(labels)
|
||||
return embeddings
|
||||
|
||||
|
||||
class CaptionEmbedder(nn.Module):
|
||||
"""
|
||||
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
|
||||
"""
|
||||
def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate='tanh'), token_num=120, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.y_proj = Mlp(
|
||||
in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=act_layer,
|
||||
dtype=dtype, device=device, operations=operations,
|
||||
)
|
||||
self.register_buffer("y_embedding", nn.Parameter(torch.randn(token_num, in_channels) / in_channels ** 0.5))
|
||||
self.uncond_prob = uncond_prob
|
||||
|
||||
def token_drop(self, caption, force_drop_ids=None):
|
||||
"""
|
||||
Drops labels to enable classifier-free guidance.
|
||||
"""
|
||||
if force_drop_ids is None:
|
||||
drop_ids = torch.rand(caption.shape[0]).cuda() < self.uncond_prob
|
||||
else:
|
||||
drop_ids = force_drop_ids == 1
|
||||
caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
|
||||
return caption
|
||||
|
||||
def forward(self, caption, train, force_drop_ids=None):
|
||||
if train:
|
||||
assert caption.shape[2:] == self.y_embedding.shape
|
||||
use_dropout = self.uncond_prob > 0
|
||||
if (train and use_dropout) or (force_drop_ids is not None):
|
||||
caption = self.token_drop(caption, force_drop_ids)
|
||||
caption = self.y_proj(caption)
|
||||
return caption
|
||||
|
||||
|
||||
class CaptionEmbedderDoubleBr(nn.Module):
|
||||
"""
|
||||
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
|
||||
"""
|
||||
def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate='tanh'), token_num=120, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.proj = Mlp(
|
||||
in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=act_layer,
|
||||
dtype=dtype, device=device, operations=operations,
|
||||
)
|
||||
self.embedding = nn.Parameter(torch.randn(1, in_channels) / 10 ** 0.5)
|
||||
self.y_embedding = nn.Parameter(torch.randn(token_num, in_channels) / 10 ** 0.5)
|
||||
self.uncond_prob = uncond_prob
|
||||
|
||||
def token_drop(self, global_caption, caption, force_drop_ids=None):
|
||||
"""
|
||||
Drops labels to enable classifier-free guidance.
|
||||
"""
|
||||
if force_drop_ids is None:
|
||||
drop_ids = torch.rand(global_caption.shape[0]).cuda() < self.uncond_prob
|
||||
else:
|
||||
drop_ids = force_drop_ids == 1
|
||||
global_caption = torch.where(drop_ids[:, None], self.embedding, global_caption)
|
||||
caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
|
||||
return global_caption, caption
|
||||
|
||||
def forward(self, caption, train, force_drop_ids=None):
|
||||
assert caption.shape[2: ] == self.y_embedding.shape
|
||||
global_caption = caption.mean(dim=2).squeeze()
|
||||
use_dropout = self.uncond_prob > 0
|
||||
if (train and use_dropout) or (force_drop_ids is not None):
|
||||
global_caption, caption = self.token_drop(global_caption, caption, force_drop_ids)
|
||||
y_embed = self.proj(global_caption)
|
||||
return y_embed, caption
|
||||
256
comfy/ldm/pixart/pixartms.py
Normal file
256
comfy/ldm/pixart/pixartms.py
Normal file
@@ -0,0 +1,256 @@
|
||||
# Based on:
|
||||
# https://github.com/PixArt-alpha/PixArt-alpha [Apache 2.0 license]
|
||||
# https://github.com/PixArt-alpha/PixArt-sigma [Apache 2.0 license]
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .blocks import (
|
||||
t2i_modulate,
|
||||
CaptionEmbedder,
|
||||
AttentionKVCompress,
|
||||
MultiHeadCrossAttention,
|
||||
T2IFinalLayer,
|
||||
SizeEmbedder,
|
||||
)
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, PatchEmbed, Mlp, get_1d_sincos_pos_embed_from_grid_torch
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed_torch(embed_dim, w, h, pe_interpolation=1.0, base_size=16, device=None, dtype=torch.float32):
|
||||
grid_h, grid_w = torch.meshgrid(
|
||||
torch.arange(h, device=device, dtype=dtype) / (h/base_size) / pe_interpolation,
|
||||
torch.arange(w, device=device, dtype=dtype) / (w/base_size) / pe_interpolation,
|
||||
indexing='ij'
|
||||
)
|
||||
emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_h, device=device, dtype=dtype)
|
||||
emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_w, device=device, dtype=dtype)
|
||||
emb = torch.cat([emb_w, emb_h], dim=1) # (H*W, D)
|
||||
return emb
|
||||
|
||||
class PixArtMSBlock(nn.Module):
|
||||
"""
|
||||
A PixArt block with adaptive layer norm zero (adaLN-Zero) conditioning.
|
||||
"""
|
||||
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., input_size=None,
|
||||
sampling=None, sr_ratio=1, qk_norm=False, dtype=None, device=None, operations=None, **block_kwargs):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.attn = AttentionKVCompress(
|
||||
hidden_size, num_heads=num_heads, qkv_bias=True, sampling=sampling, sr_ratio=sr_ratio,
|
||||
qk_norm=qk_norm, dtype=dtype, device=device, operations=operations, **block_kwargs
|
||||
)
|
||||
self.cross_attn = MultiHeadCrossAttention(
|
||||
hidden_size, num_heads, dtype=dtype, device=device, operations=operations, **block_kwargs
|
||||
)
|
||||
self.norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
# to be compatible with lower version pytorch
|
||||
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
||||
self.mlp = Mlp(
|
||||
in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5)
|
||||
|
||||
def forward(self, x, y, t, mask=None, HW=None, **kwargs):
|
||||
B, N, C = x.shape
|
||||
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None].to(dtype=x.dtype, device=x.device) + t.reshape(B, 6, -1)).chunk(6, dim=1)
|
||||
x = x + (gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=HW))
|
||||
x = x + self.cross_attn(x, y, mask)
|
||||
x = x + (gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
### Core PixArt Model ###
|
||||
class PixArtMS(nn.Module):
|
||||
"""
|
||||
Diffusion model with a Transformer backbone.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
input_size=32,
|
||||
patch_size=2,
|
||||
in_channels=4,
|
||||
hidden_size=1152,
|
||||
depth=28,
|
||||
num_heads=16,
|
||||
mlp_ratio=4.0,
|
||||
class_dropout_prob=0.1,
|
||||
learn_sigma=True,
|
||||
pred_sigma=True,
|
||||
drop_path: float = 0.,
|
||||
caption_channels=4096,
|
||||
pe_interpolation=None,
|
||||
pe_precision=None,
|
||||
config=None,
|
||||
model_max_length=120,
|
||||
micro_condition=True,
|
||||
qk_norm=False,
|
||||
kv_compress_config=None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
**kwargs,
|
||||
):
|
||||
nn.Module.__init__(self)
|
||||
self.dtype = dtype
|
||||
self.pred_sigma = pred_sigma
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels * 2 if pred_sigma else in_channels
|
||||
self.patch_size = patch_size
|
||||
self.num_heads = num_heads
|
||||
self.pe_interpolation = pe_interpolation
|
||||
self.pe_precision = pe_precision
|
||||
self.hidden_size = hidden_size
|
||||
self.depth = depth
|
||||
|
||||
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
||||
self.t_block = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(hidden_size, 6 * hidden_size, bias=True, dtype=dtype, device=device)
|
||||
)
|
||||
self.x_embedder = PatchEmbed(
|
||||
patch_size=patch_size,
|
||||
in_chans=in_channels,
|
||||
embed_dim=hidden_size,
|
||||
bias=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
)
|
||||
self.t_embedder = TimestepEmbedder(
|
||||
hidden_size, dtype=dtype, device=device, operations=operations,
|
||||
)
|
||||
self.y_embedder = CaptionEmbedder(
|
||||
in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob,
|
||||
act_layer=approx_gelu, token_num=model_max_length,
|
||||
dtype=dtype, device=device, operations=operations,
|
||||
)
|
||||
|
||||
self.micro_conditioning = micro_condition
|
||||
if self.micro_conditioning:
|
||||
self.csize_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations)
|
||||
self.ar_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
# For fixed sin-cos embedding:
|
||||
# num_patches = (input_size // patch_size) * (input_size // patch_size)
|
||||
# self.base_size = input_size // self.patch_size
|
||||
# self.register_buffer("pos_embed", torch.zeros(1, num_patches, hidden_size))
|
||||
|
||||
drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule
|
||||
if kv_compress_config is None:
|
||||
kv_compress_config = {
|
||||
'sampling': None,
|
||||
'scale_factor': 1,
|
||||
'kv_compress_layer': [],
|
||||
}
|
||||
self.blocks = nn.ModuleList([
|
||||
PixArtMSBlock(
|
||||
hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i],
|
||||
sampling=kv_compress_config['sampling'],
|
||||
sr_ratio=int(kv_compress_config['scale_factor']) if i in kv_compress_config['kv_compress_layer'] else 1,
|
||||
qk_norm=qk_norm,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
for i in range(depth)
|
||||
])
|
||||
self.final_layer = T2IFinalLayer(
|
||||
hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
def forward_orig(self, x, timestep, y, mask=None, c_size=None, c_ar=None, **kwargs):
|
||||
"""
|
||||
Original forward pass of PixArt.
|
||||
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
|
||||
t: (N,) tensor of diffusion timesteps
|
||||
y: (N, 1, 120, C) conditioning
|
||||
ar: (N, 1): aspect ratio
|
||||
cs: (N ,2) size conditioning for height/width
|
||||
"""
|
||||
B, C, H, W = x.shape
|
||||
c_res = (H + W) // 2
|
||||
pe_interpolation = self.pe_interpolation
|
||||
if pe_interpolation is None or self.pe_precision is not None:
|
||||
# calculate pe_interpolation on-the-fly
|
||||
pe_interpolation = round(c_res / (512/8.0), self.pe_precision or 0)
|
||||
|
||||
pos_embed = get_2d_sincos_pos_embed_torch(
|
||||
self.hidden_size,
|
||||
h=(H // self.patch_size),
|
||||
w=(W // self.patch_size),
|
||||
pe_interpolation=pe_interpolation,
|
||||
base_size=((round(c_res / 64) * 64) // self.patch_size),
|
||||
device=x.device,
|
||||
dtype=x.dtype,
|
||||
).unsqueeze(0)
|
||||
|
||||
x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
|
||||
t = self.t_embedder(timestep, x.dtype) # (N, D)
|
||||
|
||||
if self.micro_conditioning and (c_size is not None and c_ar is not None):
|
||||
bs = x.shape[0]
|
||||
c_size = self.csize_embedder(c_size, bs) # (N, D)
|
||||
c_ar = self.ar_embedder(c_ar, bs) # (N, D)
|
||||
t = t + torch.cat([c_size, c_ar], dim=1)
|
||||
|
||||
t0 = self.t_block(t)
|
||||
y = self.y_embedder(y, self.training) # (N, D)
|
||||
|
||||
if mask is not None:
|
||||
if mask.shape[0] != y.shape[0]:
|
||||
mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
|
||||
mask = mask.squeeze(1).squeeze(1)
|
||||
y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
|
||||
y_lens = mask.sum(dim=1).tolist()
|
||||
else:
|
||||
y_lens = None
|
||||
y = y.squeeze(1).view(1, -1, x.shape[-1])
|
||||
for block in self.blocks:
|
||||
x = block(x, y, t0, y_lens, (H, W), **kwargs) # (N, T, D)
|
||||
|
||||
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
|
||||
x = self.unpatchify(x, H, W) # (N, out_channels, H, W)
|
||||
|
||||
return x
|
||||
|
||||
def forward(self, x, timesteps, context, c_size=None, c_ar=None, **kwargs):
|
||||
B, C, H, W = x.shape
|
||||
|
||||
# Fallback for missing microconds
|
||||
if self.micro_conditioning:
|
||||
if c_size is None:
|
||||
c_size = torch.tensor([H*8, W*8], dtype=x.dtype, device=x.device).repeat(B, 1)
|
||||
|
||||
if c_ar is None:
|
||||
c_ar = torch.tensor([H/W], dtype=x.dtype, device=x.device).repeat(B, 1)
|
||||
|
||||
## Still accepts the input w/o that dim but returns garbage
|
||||
if len(context.shape) == 3:
|
||||
context = context.unsqueeze(1)
|
||||
|
||||
## run original forward pass
|
||||
out = self.forward_orig(x, timesteps, context, c_size=c_size, c_ar=c_ar)
|
||||
|
||||
## only return EPS
|
||||
if self.pred_sigma:
|
||||
return out[:, :self.in_channels]
|
||||
return out
|
||||
|
||||
def unpatchify(self, x, h, w):
|
||||
"""
|
||||
x: (N, T, patch_size**2 * C)
|
||||
imgs: (N, H, W, C)
|
||||
"""
|
||||
c = self.out_channels
|
||||
p = self.x_embedder.patch_size[0]
|
||||
h = h // self.patch_size
|
||||
w = w // self.patch_size
|
||||
assert h * w == x.shape[1]
|
||||
|
||||
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
|
||||
x = torch.einsum('nhwpqc->nchpwq', x)
|
||||
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
|
||||
return imgs
|
||||
@@ -1,4 +1,5 @@
|
||||
import importlib
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from torch import optim
|
||||
@@ -23,7 +24,7 @@ def log_txt_as_img(wh, xc, size=10):
|
||||
try:
|
||||
draw.text((0, 0), lines, fill="black", font=font)
|
||||
except UnicodeEncodeError:
|
||||
print("Cant encode string for logging. Skipping.")
|
||||
logging.warning("Cant encode string for logging. Skipping.")
|
||||
|
||||
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
|
||||
txts.append(txt)
|
||||
@@ -65,7 +66,7 @@ def mean_flat(tensor):
|
||||
def count_params(model, verbose=False):
|
||||
total_params = sum(p.numel() for p in model.parameters())
|
||||
if verbose:
|
||||
print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
|
||||
logging.info(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
|
||||
return total_params
|
||||
|
||||
|
||||
@@ -133,7 +134,6 @@ class AdamWwithEMAandWings(optim.Optimizer):
|
||||
exp_avgs = []
|
||||
exp_avg_sqs = []
|
||||
ema_params_with_grad = []
|
||||
state_sums = []
|
||||
max_exp_avg_sqs = []
|
||||
state_steps = []
|
||||
amsgrad = group['amsgrad']
|
||||
|
||||
368
comfy/lora.py
368
comfy/lora.py
@@ -16,8 +16,12 @@
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
import comfy.model_base
|
||||
import logging
|
||||
import torch
|
||||
|
||||
LORA_CLIP_MAP = {
|
||||
"mlp.fc1": "mlp_fc1",
|
||||
@@ -29,7 +33,7 @@ LORA_CLIP_MAP = {
|
||||
}
|
||||
|
||||
|
||||
def load_lora(lora, to_load):
|
||||
def load_lora(lora, to_load, log_missing=True):
|
||||
patch_dict = {}
|
||||
loaded_keys = set()
|
||||
for x in to_load:
|
||||
@@ -45,10 +49,20 @@ def load_lora(lora, to_load):
|
||||
dora_scale = lora[dora_scale_name]
|
||||
loaded_keys.add(dora_scale_name)
|
||||
|
||||
reshape_name = "{}.reshape_weight".format(x)
|
||||
reshape = None
|
||||
if reshape_name in lora.keys():
|
||||
try:
|
||||
reshape = lora[reshape_name].tolist()
|
||||
loaded_keys.add(reshape_name)
|
||||
except:
|
||||
pass
|
||||
|
||||
regular_lora = "{}.lora_up.weight".format(x)
|
||||
diffusers_lora = "{}_lora.up.weight".format(x)
|
||||
diffusers2_lora = "{}.lora_B.weight".format(x)
|
||||
diffusers3_lora = "{}.lora.up.weight".format(x)
|
||||
mochi_lora = "{}.lora_B".format(x)
|
||||
transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
|
||||
A_name = None
|
||||
|
||||
@@ -68,6 +82,10 @@ def load_lora(lora, to_load):
|
||||
A_name = diffusers3_lora
|
||||
B_name = "{}.lora.down.weight".format(x)
|
||||
mid_name = None
|
||||
elif mochi_lora in lora.keys():
|
||||
A_name = mochi_lora
|
||||
B_name = "{}.lora_A".format(x)
|
||||
mid_name = None
|
||||
elif transformers_lora in lora.keys():
|
||||
A_name = transformers_lora
|
||||
B_name ="{}.lora_linear_layer.down.weight".format(x)
|
||||
@@ -78,7 +96,7 @@ def load_lora(lora, to_load):
|
||||
if mid_name is not None and mid_name in lora.keys():
|
||||
mid = lora[mid_name]
|
||||
loaded_keys.add(mid_name)
|
||||
patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale))
|
||||
patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale, reshape))
|
||||
loaded_keys.add(A_name)
|
||||
loaded_keys.add(B_name)
|
||||
|
||||
@@ -189,17 +207,28 @@ def load_lora(lora, to_load):
|
||||
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,))
|
||||
loaded_keys.add(diff_bias_name)
|
||||
|
||||
for x in lora.keys():
|
||||
if x not in loaded_keys:
|
||||
logging.warning("lora key not loaded: {}".format(x))
|
||||
set_weight_name = "{}.set_weight".format(x)
|
||||
set_weight = lora.get(set_weight_name, None)
|
||||
if set_weight is not None:
|
||||
patch_dict[to_load[x]] = ("set", (set_weight,))
|
||||
loaded_keys.add(set_weight_name)
|
||||
|
||||
if log_missing:
|
||||
for x in lora.keys():
|
||||
if x not in loaded_keys:
|
||||
logging.warning("lora key not loaded: {}".format(x))
|
||||
|
||||
return patch_dict
|
||||
|
||||
def model_lora_keys_clip(model, key_map={}):
|
||||
sdk = model.state_dict().keys()
|
||||
for k in sdk:
|
||||
if k.endswith(".weight"):
|
||||
key_map["text_encoders.{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
|
||||
|
||||
text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
|
||||
clip_l_present = False
|
||||
clip_g_present = False
|
||||
for b in range(32): #TODO: clean up
|
||||
for c in LORA_CLIP_MAP:
|
||||
k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
|
||||
@@ -223,6 +252,7 @@ def model_lora_keys_clip(model, key_map={}):
|
||||
|
||||
k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
|
||||
if k in sdk:
|
||||
clip_g_present = True
|
||||
if clip_l_present:
|
||||
lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
|
||||
key_map[lora_key] = k
|
||||
@@ -238,10 +268,18 @@ def model_lora_keys_clip(model, key_map={}):
|
||||
|
||||
for k in sdk:
|
||||
if k.endswith(".weight"):
|
||||
if k.startswith("t5xxl.transformer."):#OneTrainer SD3 lora
|
||||
if k.startswith("t5xxl.transformer."):#OneTrainer SD3 and Flux lora
|
||||
l_key = k[len("t5xxl.transformer."):-len(".weight")]
|
||||
lora_key = "lora_te3_{}".format(l_key.replace(".", "_"))
|
||||
key_map[lora_key] = k
|
||||
t5_index = 1
|
||||
if clip_g_present:
|
||||
t5_index += 1
|
||||
if clip_l_present:
|
||||
t5_index += 1
|
||||
if t5_index == 2:
|
||||
key_map["lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))] = k #OneTrainer Flux
|
||||
t5_index += 1
|
||||
|
||||
key_map["lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))] = k
|
||||
elif k.startswith("hydit_clip.transformer.bert."): #HunyuanDiT Lora
|
||||
l_key = k[len("hydit_clip.transformer.bert."):-len(".weight")]
|
||||
lora_key = "lora_te1_{}".format(l_key.replace(".", "_"))
|
||||
@@ -265,11 +303,14 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
sdk = sd.keys()
|
||||
|
||||
for k in sdk:
|
||||
if k.startswith("diffusion_model.") and k.endswith(".weight"):
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
|
||||
key_map["lora_unet_{}".format(key_lora)] = k
|
||||
key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config
|
||||
key_map["{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
|
||||
if k.startswith("diffusion_model."):
|
||||
if k.endswith(".weight"):
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
|
||||
key_map["lora_unet_{}".format(key_lora)] = k
|
||||
key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config
|
||||
key_map["{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
|
||||
else:
|
||||
key_map["{}".format(k)] = k #generic lora format for not .weight without any weird key names
|
||||
|
||||
diffusers_keys = comfy.utils.unet_to_diffusers(model.model_config.unet_config)
|
||||
for k in diffusers_keys:
|
||||
@@ -277,6 +318,7 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
unet_key = "diffusion_model.{}".format(diffusers_keys[k])
|
||||
key_lora = k[:-len(".weight")].replace(".", "_")
|
||||
key_map["lora_unet_{}".format(key_lora)] = unet_key
|
||||
key_map["lycoris_{}".format(key_lora)] = unet_key #simpletuner lycoris format
|
||||
|
||||
diffusers_lora_prefix = ["", "unet."]
|
||||
for p in diffusers_lora_prefix:
|
||||
@@ -299,6 +341,9 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
key_lora = "lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_")) #OneTrainer lora
|
||||
key_map[key_lora] = to
|
||||
|
||||
key_lora = "lycoris_{}".format(k[:-len(".weight")].replace(".", "_")) #simpletuner lycoris format
|
||||
key_map[key_lora] = to
|
||||
|
||||
if isinstance(model, comfy.model_base.AuraFlow): #Diffusers lora AuraFlow
|
||||
diffusers_keys = comfy.utils.auraflow_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
|
||||
for k in diffusers_keys:
|
||||
@@ -307,6 +352,20 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
key_lora = "transformer.{}".format(k[:-len(".weight")]) #simpletrainer and probably regular diffusers lora format
|
||||
key_map[key_lora] = to
|
||||
|
||||
if isinstance(model, comfy.model_base.PixArt):
|
||||
diffusers_keys = comfy.utils.pixart_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
|
||||
for k in diffusers_keys:
|
||||
if k.endswith(".weight"):
|
||||
to = diffusers_keys[k]
|
||||
key_lora = "transformer.{}".format(k[:-len(".weight")]) #default format
|
||||
key_map[key_lora] = to
|
||||
|
||||
key_lora = "base_model.model.{}".format(k[:-len(".weight")]) #diffusers training script
|
||||
key_map[key_lora] = to
|
||||
|
||||
key_lora = "unet.base_model.model.{}".format(k[:-len(".weight")]) #old reference peft script
|
||||
key_map[key_lora] = to
|
||||
|
||||
if isinstance(model, comfy.model_base.HunyuanDiT):
|
||||
for k in sdk:
|
||||
if k.startswith("diffusion_model.") and k.endswith(".weight"):
|
||||
@@ -318,7 +377,286 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
for k in diffusers_keys:
|
||||
if k.endswith(".weight"):
|
||||
to = diffusers_keys[k]
|
||||
key_lora = "transformer.{}".format(k[:-len(".weight")]) #simpletrainer and probably regular diffusers flux lora format
|
||||
key_map[key_lora] = to
|
||||
key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format
|
||||
key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
|
||||
key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer
|
||||
|
||||
if isinstance(model, comfy.model_base.GenmoMochi):
|
||||
for k in sdk:
|
||||
if k.startswith("diffusion_model.") and k.endswith(".weight"): #Official Mochi lora format
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")]
|
||||
key_map["{}".format(key_lora)] = k
|
||||
|
||||
if isinstance(model, comfy.model_base.HunyuanVideo):
|
||||
for k in sdk:
|
||||
if k.startswith("diffusion_model.") and k.endswith(".weight"):
|
||||
# diffusion-pipe lora format
|
||||
key_lora = k
|
||||
key_lora = key_lora.replace("_mod.lin.", "_mod.linear.").replace("_attn.qkv.", "_attn_qkv.").replace("_attn.proj.", "_attn_proj.")
|
||||
key_lora = key_lora.replace("mlp.0.", "mlp.fc1.").replace("mlp.2.", "mlp.fc2.")
|
||||
key_lora = key_lora.replace(".modulation.lin.", ".modulation.linear.")
|
||||
key_lora = key_lora[len("diffusion_model."):-len(".weight")]
|
||||
key_map["transformer.{}".format(key_lora)] = k
|
||||
key_map["diffusion_model.{}".format(key_lora)] = k # Old loras
|
||||
|
||||
return key_map
|
||||
|
||||
|
||||
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function):
|
||||
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
|
||||
lora_diff *= alpha
|
||||
weight_calc = weight + function(lora_diff).type(weight.dtype)
|
||||
weight_norm = (
|
||||
weight_calc.transpose(0, 1)
|
||||
.reshape(weight_calc.shape[1], -1)
|
||||
.norm(dim=1, keepdim=True)
|
||||
.reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
|
||||
.transpose(0, 1)
|
||||
)
|
||||
|
||||
weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
|
||||
if strength != 1.0:
|
||||
weight_calc -= weight
|
||||
weight += strength * (weight_calc)
|
||||
else:
|
||||
weight[:] = weight_calc
|
||||
return weight
|
||||
|
||||
def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:
|
||||
"""
|
||||
Pad a tensor to a new shape with zeros.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): The original tensor to be padded.
|
||||
new_shape (List[int]): The desired shape of the padded tensor.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: A new tensor padded with zeros to the specified shape.
|
||||
|
||||
Note:
|
||||
If the new shape is smaller than the original tensor in any dimension,
|
||||
the original tensor will be truncated in that dimension.
|
||||
"""
|
||||
if any([new_shape[i] < tensor.shape[i] for i in range(len(new_shape))]):
|
||||
raise ValueError("The new shape must be larger than the original tensor in all dimensions")
|
||||
|
||||
if len(new_shape) != len(tensor.shape):
|
||||
raise ValueError("The new shape must have the same number of dimensions as the original tensor")
|
||||
|
||||
# Create a new tensor filled with zeros
|
||||
padded_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device)
|
||||
|
||||
# Create slicing tuples for both tensors
|
||||
orig_slices = tuple(slice(0, dim) for dim in tensor.shape)
|
||||
new_slices = tuple(slice(0, dim) for dim in tensor.shape)
|
||||
|
||||
# Copy the original tensor into the new tensor
|
||||
padded_tensor[new_slices] = tensor[orig_slices]
|
||||
|
||||
return padded_tensor
|
||||
|
||||
def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, original_weights=None):
|
||||
for p in patches:
|
||||
strength = p[0]
|
||||
v = p[1]
|
||||
strength_model = p[2]
|
||||
offset = p[3]
|
||||
function = p[4]
|
||||
if function is None:
|
||||
function = lambda a: a
|
||||
|
||||
old_weight = None
|
||||
if offset is not None:
|
||||
old_weight = weight
|
||||
weight = weight.narrow(offset[0], offset[1], offset[2])
|
||||
|
||||
if strength_model != 1.0:
|
||||
weight *= strength_model
|
||||
|
||||
if isinstance(v, list):
|
||||
v = (calculate_weight(v[1:], v[0][1](comfy.model_management.cast_to_device(v[0][0], weight.device, intermediate_dtype, copy=True), inplace=True), key, intermediate_dtype=intermediate_dtype), )
|
||||
|
||||
if len(v) == 1:
|
||||
patch_type = "diff"
|
||||
elif len(v) == 2:
|
||||
patch_type = v[0]
|
||||
v = v[1]
|
||||
|
||||
if patch_type == "diff":
|
||||
diff: torch.Tensor = v[0]
|
||||
# An extra flag to pad the weight if the diff's shape is larger than the weight
|
||||
do_pad_weight = len(v) > 1 and v[1]['pad_weight']
|
||||
if do_pad_weight and diff.shape != weight.shape:
|
||||
logging.info("Pad weight {} from {} to shape: {}".format(key, weight.shape, diff.shape))
|
||||
weight = pad_tensor_to_shape(weight, diff.shape)
|
||||
|
||||
if strength != 0.0:
|
||||
if diff.shape != weight.shape:
|
||||
logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, diff.shape, weight.shape))
|
||||
else:
|
||||
weight += function(strength * comfy.model_management.cast_to_device(diff, weight.device, weight.dtype))
|
||||
elif patch_type == "set":
|
||||
weight.copy_(v[0])
|
||||
elif patch_type == "model_as_lora":
|
||||
target_weight: torch.Tensor = v[0]
|
||||
diff_weight = comfy.model_management.cast_to_device(target_weight, weight.device, intermediate_dtype) - \
|
||||
comfy.model_management.cast_to_device(original_weights[key][0][0], weight.device, intermediate_dtype)
|
||||
weight += function(strength * comfy.model_management.cast_to_device(diff_weight, weight.device, weight.dtype))
|
||||
elif patch_type == "lora": #lora/locon
|
||||
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
|
||||
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype)
|
||||
dora_scale = v[4]
|
||||
reshape = v[5]
|
||||
|
||||
if reshape is not None:
|
||||
weight = pad_tensor_to_shape(weight, reshape)
|
||||
|
||||
if v[2] is not None:
|
||||
alpha = v[2] / mat2.shape[0]
|
||||
else:
|
||||
alpha = 1.0
|
||||
|
||||
if v[3] is not None:
|
||||
#locon mid weights, hopefully the math is fine because I didn't properly test it
|
||||
mat3 = comfy.model_management.cast_to_device(v[3], weight.device, intermediate_dtype)
|
||||
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
|
||||
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
|
||||
try:
|
||||
lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
|
||||
if dora_scale is not None:
|
||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||
else:
|
||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||
except Exception as e:
|
||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||
elif patch_type == "lokr":
|
||||
w1 = v[0]
|
||||
w2 = v[1]
|
||||
w1_a = v[3]
|
||||
w1_b = v[4]
|
||||
w2_a = v[5]
|
||||
w2_b = v[6]
|
||||
t2 = v[7]
|
||||
dora_scale = v[8]
|
||||
dim = None
|
||||
|
||||
if w1 is None:
|
||||
dim = w1_b.shape[0]
|
||||
w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype),
|
||||
comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype))
|
||||
else:
|
||||
w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype)
|
||||
|
||||
if w2 is None:
|
||||
dim = w2_b.shape[0]
|
||||
if t2 is None:
|
||||
w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype),
|
||||
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype))
|
||||
else:
|
||||
w2 = torch.einsum('i j k l, j r, i p -> p r k l',
|
||||
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
|
||||
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype),
|
||||
comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype))
|
||||
else:
|
||||
w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype)
|
||||
|
||||
if len(w2.shape) == 4:
|
||||
w1 = w1.unsqueeze(2).unsqueeze(2)
|
||||
if v[2] is not None and dim is not None:
|
||||
alpha = v[2] / dim
|
||||
else:
|
||||
alpha = 1.0
|
||||
|
||||
try:
|
||||
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
|
||||
if dora_scale is not None:
|
||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||
else:
|
||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||
except Exception as e:
|
||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||
elif patch_type == "loha":
|
||||
w1a = v[0]
|
||||
w1b = v[1]
|
||||
if v[2] is not None:
|
||||
alpha = v[2] / w1b.shape[0]
|
||||
else:
|
||||
alpha = 1.0
|
||||
|
||||
w2a = v[3]
|
||||
w2b = v[4]
|
||||
dora_scale = v[7]
|
||||
if v[5] is not None: #cp decomposition
|
||||
t1 = v[5]
|
||||
t2 = v[6]
|
||||
m1 = torch.einsum('i j k l, j r, i p -> p r k l',
|
||||
comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype),
|
||||
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype),
|
||||
comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype))
|
||||
|
||||
m2 = torch.einsum('i j k l, j r, i p -> p r k l',
|
||||
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
|
||||
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype),
|
||||
comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype))
|
||||
else:
|
||||
m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype),
|
||||
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype))
|
||||
m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype),
|
||||
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype))
|
||||
|
||||
try:
|
||||
lora_diff = (m1 * m2).reshape(weight.shape)
|
||||
if dora_scale is not None:
|
||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||
else:
|
||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||
except Exception as e:
|
||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||
elif patch_type == "glora":
|
||||
dora_scale = v[5]
|
||||
|
||||
old_glora = False
|
||||
if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]:
|
||||
rank = v[0].shape[0]
|
||||
old_glora = True
|
||||
|
||||
if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]:
|
||||
if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]:
|
||||
pass
|
||||
else:
|
||||
old_glora = False
|
||||
rank = v[1].shape[0]
|
||||
|
||||
a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||
|
||||
if v[4] is not None:
|
||||
alpha = v[4] / rank
|
||||
else:
|
||||
alpha = 1.0
|
||||
|
||||
try:
|
||||
if old_glora:
|
||||
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora
|
||||
else:
|
||||
if weight.dim() > 2:
|
||||
lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
|
||||
else:
|
||||
lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
|
||||
lora_diff += torch.mm(b1, b2).reshape(weight.shape)
|
||||
|
||||
if dora_scale is not None:
|
||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||
else:
|
||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||
except Exception as e:
|
||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||
else:
|
||||
logging.warning("patch type not recognized {} {}".format(patch_type, key))
|
||||
|
||||
if old_weight is not None:
|
||||
weight = old_weight
|
||||
|
||||
return weight
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user