Compare commits
652 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
312d511630 | ||
|
|
4f4f1c642a | ||
|
|
010954d277 | ||
|
|
6d46bb4b4c | ||
|
|
67f57c5bcc | ||
|
|
fd943c928f | ||
|
|
d3bd983b91 | ||
|
|
fb4754624d | ||
|
|
180db6753f | ||
|
|
d062fcc5c0 | ||
|
|
456abad834 | ||
|
|
19e45e9b0e | ||
|
|
97f23b81f3 | ||
|
|
08b7cc7506 | ||
|
|
6c319cbb4e | ||
|
|
df1aebe52e | ||
|
|
704fc78854 | ||
|
|
1d9fee79fd | ||
|
|
aeba0b3a26 | ||
|
|
094306b626 | ||
|
|
31260f0275 | ||
|
|
f1c9ca816a | ||
|
|
f2289a1f59 | ||
|
|
fb83eda287 | ||
|
|
5e5e46d40c | ||
|
|
4eba3161cf | ||
|
|
592d056100 | ||
|
|
1c1687ab1c | ||
|
|
e6609dacde | ||
|
|
ba37e67964 | ||
|
|
06c661004e | ||
|
|
c9e1821a7b | ||
|
|
f58f0f5696 | ||
|
|
3a10b9641c | ||
|
|
89a84e32d2 | ||
|
|
e5799c4899 | ||
|
|
a0651359d7 | ||
|
|
ad3bd8aa49 | ||
|
|
5a87757ef9 | ||
|
|
464aece92b | ||
|
|
0b50d4c0db | ||
|
|
30b2eb8a93 | ||
|
|
f85c08df06 | ||
|
|
4202e956a0 | ||
|
|
b838c36720 | ||
|
|
fc39184ea9 | ||
|
|
ded60c33a0 | ||
|
|
8bb858e4d3 | ||
|
|
57893c843f | ||
|
|
65da29aaa9 | ||
|
|
10024a38ea | ||
|
|
87f9130778 | ||
|
|
7e84bf5373 | ||
|
|
4f3b50ba51 | ||
|
|
e930a387d6 | ||
|
|
d8e5662822 | ||
|
|
3d44a09812 | ||
|
|
62690eddec | ||
|
|
05eb10b43a | ||
|
|
f5e4e976f4 | ||
|
|
aee2908d03 | ||
|
|
dc46db7aa4 | ||
|
|
7046983d95 | ||
|
|
1c2d45d2b5 | ||
|
|
c820ef950d | ||
|
|
6a2e4bb9e0 | ||
|
|
f1f9763b4c | ||
|
|
08368f8e00 | ||
|
|
f3ff5c40db | ||
|
|
98ff01e148 | ||
|
|
bab836d88d | ||
|
|
4a9014e201 | ||
|
|
8a7c894d54 | ||
|
|
a814f2e8cc | ||
|
|
481732a0ed | ||
|
|
2156ce9453 | ||
|
|
4136502b7a | ||
|
|
9ad287ff20 | ||
|
|
f5cacaeb14 | ||
|
|
b7ed5f57bd | ||
|
|
b4abca828e | ||
|
|
158419f3a0 | ||
|
|
640c47e7de | ||
|
|
31e9e36c94 | ||
|
|
577de83ca9 | ||
|
|
3535909eb8 | ||
|
|
235d3901fc | ||
|
|
d42613686f | ||
|
|
1b3bf0a5da | ||
|
|
ae60b150e5 | ||
|
|
42da274717 | ||
|
|
28f178a840 | ||
|
|
8ab15c863c | ||
|
|
924d771e18 | ||
|
|
02a1b01aad | ||
|
|
a692c3cca4 | ||
|
|
5d3cc85e13 | ||
|
|
c7c025b8d1 | ||
|
|
fd08e39588 | ||
|
|
56b6ee6754 | ||
|
|
cc33cd3422 | ||
|
|
b9980592c4 | ||
|
|
16417b40d9 | ||
|
|
271c9c5b9e | ||
|
|
a4e679765e | ||
|
|
0cf2e46b17 | ||
|
|
094e9ef126 | ||
|
|
1271c4ef9d | ||
|
|
d9c80a85e5 | ||
|
|
3e62c5513a | ||
|
|
cd18582578 | ||
|
|
80a44b97f5 | ||
|
|
9187a09483 | ||
|
|
3041e5c354 | ||
|
|
7689917113 | ||
|
|
486ad8fdc5 | ||
|
|
065d855f14 | ||
|
|
530494588d | ||
|
|
2ab9618732 | ||
|
|
d9a87c1e6a | ||
|
|
551fe8dcee | ||
|
|
ff99861650 | ||
|
|
8d0661d0ba | ||
|
|
6d32dc049e | ||
|
|
aa9d759df3 | ||
|
|
c6c19e9980 | ||
|
|
08ff5fa08a | ||
|
|
4ca3d84277 | ||
|
|
39c27a3705 | ||
|
|
b1c7291569 | ||
|
|
dbc726f80c | ||
|
|
7ee96455e2 | ||
|
|
0a66d4b0af | ||
|
|
5c5457a4ef | ||
|
|
45503f6499 | ||
|
|
005a91ce2b | ||
|
|
68f0d35296 | ||
|
|
83d04717b6 | ||
|
|
7d329771f9 | ||
|
|
c15909bb62 | ||
|
|
772b4c5945 | ||
|
|
5a50c3c7e5 | ||
|
|
30159a7fe6 | ||
|
|
cb9ac3db58 | ||
|
|
8115a7895b | ||
|
|
c8cd7ad795 | ||
|
|
542b4b36b6 | ||
|
|
ac10a0d69e | ||
|
|
0dcc75ca54 | ||
|
|
b685b8a4e0 | ||
|
|
23e39f2ba7 | ||
|
|
78992c4b25 | ||
|
|
f935d42d8e | ||
|
|
a97f2f850a | ||
|
|
5acb705857 | ||
|
|
5c80da31db | ||
|
|
e2eed9eb9b | ||
|
|
11b68ebd22 | ||
|
|
188b383c35 | ||
|
|
2c1d686ec6 | ||
|
|
e8ddc2be95 | ||
|
|
dea1c7474a | ||
|
|
154f2911aa | ||
|
|
3eaad0590e | ||
|
|
7eaff81be1 | ||
|
|
21a11ef817 | ||
|
|
552615235d | ||
|
|
0738e4ea5d | ||
|
|
92cdc692f4 | ||
|
|
2d6805ce57 | ||
|
|
a8f63c0d5b | ||
|
|
454a635c1b | ||
|
|
966c43ce26 | ||
|
|
3ab231f01f | ||
|
|
1f3fba2af5 | ||
|
|
5d0d4ee98a | ||
|
|
9d57b8afd8 | ||
|
|
5d51794607 | ||
|
|
ce22f687cc | ||
|
|
b6fd3ffd10 | ||
|
|
11b72c9c55 | ||
|
|
2c735c13b4 | ||
|
|
fd27494441 | ||
|
|
f43e1d7f41 | ||
|
|
4486b0d0ff | ||
|
|
636d4bfb89 | ||
|
|
dc300a4569 | ||
|
|
f3b09b9f2d | ||
|
|
7ecd5e9614 | ||
|
|
2383a39e3b | ||
|
|
34e06bf7ec | ||
|
|
55822faa05 | ||
|
|
880c205df1 | ||
|
|
3dc240d089 | ||
|
|
19373aee75 | ||
|
|
93292bc450 | ||
|
|
05d5a75cdc | ||
|
|
eba7a25e7a | ||
|
|
dbcfd092a2 | ||
|
|
c14429940f | ||
|
|
0d720e4367 | ||
|
|
1fc00ba4b6 | ||
|
|
9899d187b1 | ||
|
|
f00f340a56 | ||
|
|
cce1d9145e | ||
|
|
b4dc03ad76 | ||
|
|
9ad792f927 | ||
|
|
6fc5dbd52a | ||
|
|
3e8155f7a3 | ||
|
|
8a438115fb | ||
|
|
a14c2fc356 | ||
|
|
9ee6ca99d8 | ||
|
|
bb495cc9b8 | ||
|
|
e51d9ba5fc | ||
|
|
c87a06f934 | ||
|
|
1714a4c158 | ||
|
|
73ecb75a3d | ||
|
|
22ad513c72 | ||
|
|
ed945a1790 | ||
|
|
f9207c6936 | ||
|
|
8ad7477647 | ||
|
|
98bdca4cb2 | ||
|
|
a26da20a76 | ||
|
|
e346d8584e | ||
|
|
ab31b64412 | ||
|
|
fe29739c68 | ||
|
|
e8345a9b7b | ||
|
|
8c6b9f4481 | ||
|
|
cc7e023a4a | ||
|
|
2f7d8159c3 | ||
|
|
70d7242e57 | ||
|
|
49b732afd5 | ||
|
|
3bfe4e5276 | ||
|
|
89e4ea0175 | ||
|
|
3a100b9a55 | ||
|
|
721253cb05 | ||
|
|
3d2e3a6f29 | ||
|
|
2222cf67fd | ||
|
|
ab5413351e | ||
|
|
2b71aab299 | ||
|
|
301e26b131 | ||
|
|
548457bac4 | ||
|
|
0b4584c741 | ||
|
|
a3100c8452 | ||
|
|
832fc02330 | ||
|
|
2d17d8910c | ||
|
|
a40fcfc2d5 | ||
|
|
0a1f8869c9 | ||
|
|
3661c833bc | ||
|
|
84fdaf7b0e | ||
|
|
8edc1f44c1 | ||
|
|
eade1551bb | ||
|
|
581a9991ff | ||
|
|
e471c726e5 | ||
|
|
75c1c757d9 | ||
|
|
ce9b084279 | ||
|
|
2206246055 | ||
|
|
d9fa9d307f | ||
|
|
83e839a89b | ||
|
|
0cf2274699 | ||
|
|
0956107170 | ||
|
|
a4a956dbbd | ||
|
|
8b9ce4ed18 | ||
|
|
3872b43d4b | ||
|
|
32ca0805b7 | ||
|
|
11f1b41bab | ||
|
|
3b19fc76e3 | ||
|
|
50614f1b79 | ||
|
|
6dc7b0bfe3 | ||
|
|
e8e990d6b8 | ||
|
|
2e24a15905 | ||
|
|
fd5297131f | ||
|
|
55a1b09ddc | ||
|
|
3c3988df45 | ||
|
|
7ebd8087ff | ||
|
|
c624c29d66 | ||
|
|
a2448fc527 | ||
|
|
6a0daa79b6 | ||
|
|
9c98c6358b | ||
|
|
7aceb9f91c | ||
|
|
35504e2f93 | ||
|
|
299436cfed | ||
|
|
52e566d2bc | ||
|
|
9b6cd9b874 | ||
|
|
3fc688aebd | ||
|
|
f4411250f3 | ||
|
|
d2a0fb6bb0 | ||
|
|
01015bff16 | ||
|
|
2330754b0e | ||
|
|
bc219a6487 | ||
|
|
94689766ad | ||
|
|
cfbe4b49ca | ||
|
|
ca8efab79f | ||
|
|
65ea778a5e | ||
|
|
db9f2a34fc | ||
|
|
7946049794 | ||
|
|
6f6349b6a7 | ||
|
|
1f138dd382 | ||
|
|
b779349b55 | ||
|
|
35e2dcf5d7 | ||
|
|
67c7184b74 | ||
|
|
6f8e766509 | ||
|
|
e1da98a14a | ||
|
|
a73410aafa | ||
|
|
9aac21f894 | ||
|
|
528d1b3563 | ||
|
|
2bc4b5968f | ||
|
|
7395b0c0d1 | ||
|
|
0952569493 | ||
|
|
29832b3b61 | ||
|
|
be4e760648 | ||
|
|
c3d9cc4592 | ||
|
|
84cc9cb528 | ||
|
|
ebbb920163 | ||
|
|
d60fe0af4a | ||
|
|
5dbd250965 | ||
|
|
4ab1875283 | ||
|
|
11b1f27cb1 | ||
|
|
70e15fd743 | ||
|
|
e1474150de | ||
|
|
e62d72e8ca | ||
|
|
1650cda030 | ||
|
|
a13125840c | ||
|
|
dfa36e6855 | ||
|
|
0124be4d93 | ||
|
|
29a70ca101 | ||
|
|
0bef826a98 | ||
|
|
85ef295069 | ||
|
|
5d84607bf3 | ||
|
|
c1909f350f | ||
|
|
52b3469606 | ||
|
|
889519971f | ||
|
|
76739c23c3 | ||
|
|
a80bc822a2 | ||
|
|
872780d236 | ||
|
|
6d45ffbe23 | ||
|
|
77633ba77d | ||
|
|
30e6cfb1a0 | ||
|
|
dc134b2fdb | ||
|
|
369b079ff6 | ||
|
|
9c9a7f012a | ||
|
|
93fedd92fe | ||
|
|
745b13649b | ||
|
|
2b140654c7 | ||
|
|
65042f7d39 | ||
|
|
7c7c70c400 | ||
|
|
8362199ee7 | ||
|
|
f86c724ef2 | ||
|
|
d6e5d487ad | ||
|
|
6752a826f6 | ||
|
|
04cf0ccb51 | ||
|
|
9af6320ec9 | ||
|
|
6f81cd8973 | ||
|
|
4dc6709307 | ||
|
|
4d55f16ae8 | ||
|
|
cf0b549d48 | ||
|
|
eb4543474b | ||
|
|
1804397952 | ||
|
|
f4dac8ab6f | ||
|
|
b07f116dea | ||
|
|
714f728820 | ||
|
|
92d8d15300 | ||
|
|
89253e9fe5 | ||
|
|
3ea3bc8546 | ||
|
|
8e69e2ddfd | ||
|
|
0270a0b41c | ||
|
|
26c7baf789 | ||
|
|
c37f15f98e | ||
|
|
4bca7367f3 | ||
|
|
b6fefe686b | ||
|
|
fa62287f1f | ||
|
|
0844998db3 | ||
|
|
4ced06b879 | ||
|
|
cb06e9669b | ||
|
|
0c32f82298 | ||
|
|
189da3726d | ||
|
|
9a66bb972d | ||
|
|
ea0f939df3 | ||
|
|
f37551c1d2 | ||
|
|
63023011b9 | ||
|
|
f40076096e | ||
|
|
96d891cb94 | ||
|
|
4553891bbd | ||
|
|
ace899e71a | ||
|
|
aff16532d4 | ||
|
|
b50ab153f9 | ||
|
|
072db3bea6 | ||
|
|
a6deca6d9a | ||
|
|
41c30e92e7 | ||
|
|
f579a740dd | ||
|
|
d37272532c | ||
|
|
12da6ef581 | ||
|
|
29d4384a75 | ||
|
|
c5be423d6b | ||
|
|
b4d3652d88 | ||
|
|
5715be2ca9 | ||
|
|
0d4d9222c6 | ||
|
|
afc85cdeb6 | ||
|
|
acc152b674 | ||
|
|
b07258cef2 | ||
|
|
31e54b7052 | ||
|
|
8c0bae50c3 | ||
|
|
530412cb9d | ||
|
|
61c8c70c6e | ||
|
|
d0399f4343 | ||
|
|
e2919d38b4 | ||
|
|
93c8607d51 | ||
|
|
b3d6ae15b3 | ||
|
|
2e21122aab | ||
|
|
1cd6cd6080 | ||
|
|
d7b4bf21a2 | ||
|
|
042a905c37 | ||
|
|
019c7029ea | ||
|
|
8773ccf74d | ||
|
|
1d5d6586f3 | ||
|
|
35740259de | ||
|
|
ab888e1e0b | ||
|
|
d9f0fcdb0c | ||
|
|
b124256817 | ||
|
|
af4b7c91be | ||
|
|
e57d2282d1 | ||
|
|
4027466c80 | ||
|
|
095d867147 | ||
|
|
caeb27c3a5 | ||
|
|
3d06e1c555 | ||
|
|
43a74c0de1 | ||
|
|
af93c8d1ee | ||
|
|
832e3f5ca3 | ||
|
|
079eccc92a | ||
|
|
b6951768c4 | ||
|
|
fca304debf | ||
|
|
14880e6dba | ||
|
|
f1059b0b82 | ||
|
|
debabccb84 | ||
|
|
37cd448529 | ||
|
|
94f21f9301 | ||
|
|
60653004e5 | ||
|
|
a57d635c5f | ||
|
|
016b219dcc | ||
|
|
8ac2dddeed | ||
|
|
3e880ac709 | ||
|
|
e5ea112a90 | ||
|
|
8d88bfaff9 | ||
|
|
ed4d92b721 | ||
|
|
932ae8d9ca | ||
|
|
44e19a28d3 | ||
|
|
0a0df5f136 | ||
|
|
24d6871e47 | ||
|
|
9e1d301129 | ||
|
|
768e035868 | ||
|
|
669e0497ea | ||
|
|
541dc08547 | ||
|
|
8d8dc9a262 | ||
|
|
2f98c24360 | ||
|
|
ef85058e97 | ||
|
|
f9230bd357 | ||
|
|
537c27cbf3 | ||
|
|
6ff2e4d550 | ||
|
|
222f48c0f2 | ||
|
|
13fd4d6e45 | ||
|
|
1210d094c7 | ||
|
|
255edf2246 | ||
|
|
4f011b9a00 | ||
|
|
67feb05299 | ||
|
|
6d21740346 | ||
|
|
7fbf4b72fe | ||
|
|
14ca5f5a10 | ||
|
|
ce557cfb88 | ||
|
|
96e2a45193 | ||
|
|
dfa2b6d129 | ||
|
|
f3566f0894 | ||
|
|
ca69b41cee | ||
|
|
a058f52090 | ||
|
|
d6bbe8c40f | ||
|
|
a7fe0a94de | ||
|
|
e857dd48b8 | ||
|
|
d303cb5341 | ||
|
|
fb2ad645a3 | ||
|
|
d8a7a32779 | ||
|
|
a00e1489d2 | ||
|
|
ebf038d4fa | ||
|
|
b4de04a1c1 | ||
|
|
b1a02131c9 | ||
|
|
3a3910f91d | ||
|
|
507199d9a8 | ||
|
|
2f3ab40b62 | ||
|
|
7fc3ccdcc2 | ||
|
|
55add50220 | ||
|
|
0aa2368e46 | ||
|
|
cca96a85ae | ||
|
|
619b8cde74 | ||
|
|
31831e6ef1 | ||
|
|
88ceb28e20 | ||
|
|
23289a6a5c | ||
|
|
9d8b6c1f46 | ||
|
|
6320d05696 | ||
|
|
25683b5b02 | ||
|
|
4758fb64b9 | ||
|
|
008761166f | ||
|
|
bfd5dfd611 | ||
|
|
55ade36d01 | ||
|
|
2e20e399ea | ||
|
|
3baf92d120 | ||
|
|
1709a8441e | ||
|
|
cba58fff0b | ||
|
|
2feb8d0b77 | ||
|
|
5b657f8c15 | ||
|
|
2cdbaf5169 | ||
|
|
c78a45685d | ||
|
|
3aaabb12d4 | ||
|
|
1f1c7b7b56 | ||
|
|
90f349f93d | ||
|
|
b9d9bcba14 | ||
|
|
42086af123 | ||
|
|
6c9bd11fa3 | ||
|
|
ee8a7ab69d | ||
|
|
9c773a241b | ||
|
|
adea2beb5c | ||
|
|
2ff3104f70 | ||
|
|
129d8908f7 | ||
|
|
ff838657fa | ||
|
|
2307ff6746 | ||
|
|
d0f3752e33 | ||
|
|
c515bdf371 | ||
|
|
4209edf48d | ||
|
|
d055325783 | ||
|
|
eeab420c70 | ||
|
|
916d1e14a9 | ||
|
|
c496e53519 | ||
|
|
7da85fac3f | ||
|
|
b65b83af6f | ||
|
|
c8a3492c22 | ||
|
|
5cbf79787f | ||
|
|
d45ebb63f6 | ||
|
|
caa6476a69 | ||
|
|
45671cda0b | ||
|
|
8f29664057 | ||
|
|
0b9839ef43 | ||
|
|
953693b137 | ||
|
|
a39ea87bca | ||
|
|
9e9c8a1c64 | ||
|
|
0f11d60afb | ||
|
|
79eea51a1d | ||
|
|
c0338a46a4 | ||
|
|
1c99734e5a | ||
|
|
67758f50f3 | ||
|
|
02eef72bf5 | ||
|
|
b7572b2f87 | ||
|
|
a90aafafc1 | ||
|
|
d9b7cfac7e | ||
|
|
3507870535 | ||
|
|
82ecb02c1e | ||
|
|
a618f768e0 | ||
|
|
e1dec3c792 | ||
|
|
96697c4bc5 | ||
|
|
b504bd606d | ||
|
|
d170292594 | ||
|
|
9cfd185676 | ||
|
|
4b5bcd8ac4 | ||
|
|
ceb50b2cbf | ||
|
|
160ca08138 | ||
|
|
c4bfdba330 | ||
|
|
ee9547ba31 | ||
|
|
19a64d6291 | ||
|
|
b486885e08 | ||
|
|
0229228f3f | ||
|
|
1ed75ab30e | ||
|
|
99a1fb6027 | ||
|
|
73e04987f7 | ||
|
|
5388df784a | ||
|
|
26e0ba8f8c | ||
|
|
bc6dac4327 | ||
|
|
f18ebbd316 | ||
|
|
15564688ed | ||
|
|
c6b9c11ef6 | ||
|
|
e44d0ac7f7 | ||
|
|
56bc64f351 | ||
|
|
f7d83b72e0 | ||
|
|
80f07952d2 | ||
|
|
57f330caf9 | ||
|
|
601ff9e3db | ||
|
|
341667c4d5 | ||
|
|
1419dee915 | ||
|
|
da13b6b827 | ||
|
|
c86cd58573 | ||
|
|
b5fe39211a | ||
|
|
e946667216 | ||
|
|
d7969cb070 | ||
|
|
bddb02660c | ||
|
|
418eb7062d | ||
|
|
cac68ca813 | ||
|
|
52c1d933b2 | ||
|
|
3cacd3fca5 | ||
|
|
2dda7c11a3 | ||
|
|
3ad3248ad7 | ||
|
|
c441048a4f | ||
|
|
9f4b181ab3 | ||
|
|
cbbf077593 | ||
|
|
0c04a6ae78 | ||
|
|
416ccc9e45 | ||
|
|
ff2ff02168 | ||
|
|
4c5c4ddeda | ||
|
|
79badea452 | ||
|
|
37e5390f5f | ||
|
|
a4f59bc65e | ||
|
|
ca457f7ba1 | ||
|
|
cd6f615038 | ||
|
|
517669aaa3 | ||
|
|
e4e1bff605 | ||
|
|
d6656b0c0c | ||
|
|
f4cdedea62 | ||
|
|
39b1fc4ccc | ||
|
|
0b25f47bd9 | ||
|
|
bda1482a27 | ||
|
|
19ee5d9d8b | ||
|
|
61b50720d0 | ||
|
|
0f954f34af | ||
|
|
5262901c5c | ||
|
|
cc550d5908 | ||
|
|
6d1a3f7d00 | ||
|
|
1b3a650f19 | ||
|
|
e83063bf24 | ||
|
|
558b7d8b22 | ||
|
|
caf2074773 | ||
|
|
bdf393792d | ||
|
|
4e14032c02 | ||
|
|
59d58b1158 | ||
|
|
563291ee51 | ||
|
|
6c0377f43e | ||
|
|
2cddbf0821 | ||
|
|
60749f345d | ||
|
|
d4426dce7c | ||
|
|
d9d7f3c619 | ||
|
|
fd5dfb812c | ||
|
|
3dfdddcc91 | ||
|
|
5747bc6457 | ||
|
|
5bea1d2ec9 | ||
|
|
5def9fbc83 | ||
|
|
7a7efe8424 | ||
|
|
44db978531 | ||
|
|
1c8d11e48a | ||
|
|
a220d11e6b | ||
|
|
23827ca312 | ||
|
|
0fd4e6c778 | ||
|
|
e2fafe0686 | ||
|
|
6579632201 | ||
|
|
ac2f0523ca | ||
|
|
fbf68c4e52 | ||
|
|
93477f8efe | ||
|
|
8af9a91e0c | ||
|
|
005d2d3a13 | ||
|
|
1e21f4c14e |
@@ -28,17 +28,17 @@ def pull(repo, remote_name='origin', branch='master'):
|
||||
|
||||
if repo.index.conflicts is not None:
|
||||
for conflict in repo.index.conflicts:
|
||||
print('Conflicts found in:', conflict[0].path)
|
||||
print('Conflicts found in:', conflict[0].path) # noqa: T201
|
||||
raise AssertionError('Conflicts, ahhhhh!!')
|
||||
|
||||
user = repo.default_signature
|
||||
tree = repo.index.write_tree()
|
||||
commit = repo.create_commit('HEAD',
|
||||
user,
|
||||
user,
|
||||
'Merge!',
|
||||
tree,
|
||||
[repo.head.target, remote_master_id])
|
||||
repo.create_commit('HEAD',
|
||||
user,
|
||||
user,
|
||||
'Merge!',
|
||||
tree,
|
||||
[repo.head.target, remote_master_id])
|
||||
# We need to do this or git CLI will think we are still merging.
|
||||
repo.state_cleanup()
|
||||
else:
|
||||
@@ -49,21 +49,26 @@ repo_path = str(sys.argv[1])
|
||||
repo = pygit2.Repository(repo_path)
|
||||
ident = pygit2.Signature('comfyui', 'comfy@ui')
|
||||
try:
|
||||
print("stashing current changes")
|
||||
print("stashing current changes") # noqa: T201
|
||||
repo.stash(ident)
|
||||
except KeyError:
|
||||
print("nothing to stash")
|
||||
print("nothing to stash") # noqa: T201
|
||||
backup_branch_name = 'backup_branch_{}'.format(datetime.today().strftime('%Y-%m-%d_%H_%M_%S'))
|
||||
print("creating backup branch: {}".format(backup_branch_name))
|
||||
print("creating backup branch: {}".format(backup_branch_name)) # noqa: T201
|
||||
try:
|
||||
repo.branches.local.create(backup_branch_name, repo.head.peel())
|
||||
except:
|
||||
pass
|
||||
|
||||
print("checking out master branch")
|
||||
print("checking out master branch") # noqa: T201
|
||||
branch = repo.lookup_branch('master')
|
||||
if branch is None:
|
||||
ref = repo.lookup_reference('refs/remotes/origin/master')
|
||||
try:
|
||||
ref = repo.lookup_reference('refs/remotes/origin/master')
|
||||
except:
|
||||
print("pulling.") # noqa: T201
|
||||
pull(repo)
|
||||
ref = repo.lookup_reference('refs/remotes/origin/master')
|
||||
repo.checkout(ref)
|
||||
branch = repo.lookup_branch('master')
|
||||
if branch is None:
|
||||
@@ -72,7 +77,7 @@ else:
|
||||
ref = repo.lookup_reference(branch.name)
|
||||
repo.checkout(ref)
|
||||
|
||||
print("pulling latest changes")
|
||||
print("pulling latest changes") # noqa: T201
|
||||
pull(repo)
|
||||
|
||||
if "--stable" in sys.argv:
|
||||
@@ -94,7 +99,7 @@ if "--stable" in sys.argv:
|
||||
if latest_tag is not None:
|
||||
repo.checkout(latest_tag)
|
||||
|
||||
print("Done!")
|
||||
print("Done!") # noqa: T201
|
||||
|
||||
self_update = True
|
||||
if len(sys.argv) > 2:
|
||||
|
||||
@@ -0,0 +1,2 @@
|
||||
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation
|
||||
pause
|
||||
8
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
8
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
@@ -15,6 +15,14 @@ body:
|
||||
steps to replicate what went wrong and others will be able to repeat your steps and see the same issue happen.
|
||||
|
||||
If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first.
|
||||
- type: checkboxes
|
||||
id: custom-nodes-test
|
||||
attributes:
|
||||
label: Custom Node Testing
|
||||
description: Please confirm you have tried to reproduce the issue with all custom nodes disabled.
|
||||
options:
|
||||
- label: I have tried disabling custom nodes and the issue persists (see [how to disable custom nodes](https://docs.comfy.org/troubleshooting/custom-node-issues#step-1%3A-test-with-all-custom-nodes-disabled) if you need help)
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Expected Behavior
|
||||
|
||||
8
.github/ISSUE_TEMPLATE/user-support.yml
vendored
8
.github/ISSUE_TEMPLATE/user-support.yml
vendored
@@ -11,6 +11,14 @@ body:
|
||||
**2:** You have made an effort to find public answers to your question before asking here. In other words, you googled it first, and scrolled through recent help topics.
|
||||
|
||||
If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first.
|
||||
- type: checkboxes
|
||||
id: custom-nodes-test
|
||||
attributes:
|
||||
label: Custom Node Testing
|
||||
description: Please confirm you have tried to reproduce the issue with all custom nodes disabled.
|
||||
options:
|
||||
- label: I have tried disabling custom nodes and the issue persists (see [how to disable custom nodes](https://docs.comfy.org/troubleshooting/custom-node-issues#step-1%3A-test-with-all-custom-nodes-disabled) if you need help)
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Your question
|
||||
|
||||
@@ -3,8 +3,8 @@ name: Python Linting
|
||||
on: [push, pull_request]
|
||||
|
||||
jobs:
|
||||
pylint:
|
||||
name: Run Pylint
|
||||
ruff:
|
||||
name: Run Ruff
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
@@ -16,8 +16,8 @@ jobs:
|
||||
with:
|
||||
python-version: 3.x
|
||||
|
||||
- name: Install Pylint
|
||||
run: pip install pylint
|
||||
- name: Install Ruff
|
||||
run: pip install ruff
|
||||
|
||||
- name: Run Pylint
|
||||
run: pylint --rcfile=.pylintrc $(find . -type f -name "*.py")
|
||||
- name: Run Ruff
|
||||
run: ruff check .
|
||||
12
.github/workflows/stable-release.yml
vendored
12
.github/workflows/stable-release.yml
vendored
@@ -12,7 +12,7 @@ on:
|
||||
description: 'CUDA version'
|
||||
required: true
|
||||
type: string
|
||||
default: "124"
|
||||
default: "128"
|
||||
python_minor:
|
||||
description: 'Python minor version'
|
||||
required: true
|
||||
@@ -22,7 +22,7 @@ on:
|
||||
description: 'Python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "7"
|
||||
default: "10"
|
||||
|
||||
|
||||
jobs:
|
||||
@@ -36,7 +36,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ inputs.git_tag }}
|
||||
fetch-depth: 0
|
||||
fetch-depth: 150
|
||||
persist-credentials: false
|
||||
- uses: actions/cache/restore@v4
|
||||
id: cache
|
||||
@@ -70,7 +70,7 @@ jobs:
|
||||
cd ..
|
||||
|
||||
git clone --depth 1 https://github.com/comfyanonymous/taesd
|
||||
cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/
|
||||
cp taesd/*.safetensors ./ComfyUI_copy/models/vae_approx/
|
||||
|
||||
mkdir ComfyUI_windows_portable
|
||||
mv python_embeded ComfyUI_windows_portable
|
||||
@@ -85,12 +85,14 @@ jobs:
|
||||
|
||||
cd ..
|
||||
|
||||
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
|
||||
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
|
||||
mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_nvidia.7z
|
||||
|
||||
cd ComfyUI_windows_portable
|
||||
python_embeded/python.exe -s ComfyUI/main.py --quick-test-for-ci --cpu
|
||||
|
||||
python_embeded/python.exe -s ./update/update.py ComfyUI/
|
||||
|
||||
ls
|
||||
|
||||
- name: Upload binaries to release
|
||||
|
||||
4
.github/workflows/test-build.yml
vendored
4
.github/workflows/test-build.yml
vendored
@@ -18,7 +18,7 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.8", "3.9", "3.10", "3.11"]
|
||||
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
@@ -28,4 +28,4 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r requirements.txt
|
||||
pip install -r requirements.txt
|
||||
|
||||
53
.github/workflows/test-ci.yml
vendored
53
.github/workflows/test-ci.yml
vendored
@@ -20,7 +20,8 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [macos, linux, windows]
|
||||
# os: [macos, linux, windows]
|
||||
os: [macos, linux]
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12"]
|
||||
cuda_version: ["12.1"]
|
||||
torch_version: ["stable"]
|
||||
@@ -31,9 +32,9 @@ jobs:
|
||||
- os: linux
|
||||
runner_label: [self-hosted, Linux]
|
||||
flags: ""
|
||||
- os: windows
|
||||
runner_label: [self-hosted, Windows]
|
||||
flags: ""
|
||||
# - os: windows
|
||||
# runner_label: [self-hosted, Windows]
|
||||
# flags: ""
|
||||
runs-on: ${{ matrix.runner_label }}
|
||||
steps:
|
||||
- name: Test Workflows
|
||||
@@ -45,28 +46,28 @@ jobs:
|
||||
google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
|
||||
comfyui_flags: ${{ matrix.flags }}
|
||||
|
||||
test-win-nightly:
|
||||
strategy:
|
||||
fail-fast: true
|
||||
matrix:
|
||||
os: [windows]
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12"]
|
||||
cuda_version: ["12.1"]
|
||||
torch_version: ["nightly"]
|
||||
include:
|
||||
- os: windows
|
||||
runner_label: [self-hosted, Windows]
|
||||
flags: ""
|
||||
runs-on: ${{ matrix.runner_label }}
|
||||
steps:
|
||||
- name: Test Workflows
|
||||
uses: comfy-org/comfy-action@main
|
||||
with:
|
||||
os: ${{ matrix.os }}
|
||||
python_version: ${{ matrix.python_version }}
|
||||
torch_version: ${{ matrix.torch_version }}
|
||||
google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
|
||||
comfyui_flags: ${{ matrix.flags }}
|
||||
# test-win-nightly:
|
||||
# strategy:
|
||||
# fail-fast: true
|
||||
# matrix:
|
||||
# os: [windows]
|
||||
# python_version: ["3.9", "3.10", "3.11", "3.12"]
|
||||
# cuda_version: ["12.1"]
|
||||
# torch_version: ["nightly"]
|
||||
# include:
|
||||
# - os: windows
|
||||
# runner_label: [self-hosted, Windows]
|
||||
# flags: ""
|
||||
# runs-on: ${{ matrix.runner_label }}
|
||||
# steps:
|
||||
# - name: Test Workflows
|
||||
# uses: comfy-org/comfy-action@main
|
||||
# with:
|
||||
# os: ${{ matrix.os }}
|
||||
# python_version: ${{ matrix.python_version }}
|
||||
# torch_version: ${{ matrix.torch_version }}
|
||||
# google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
|
||||
# comfyui_flags: ${{ matrix.flags }}
|
||||
|
||||
test-unix-nightly:
|
||||
strategy:
|
||||
|
||||
2
.github/workflows/test-launch.yml
vendored
2
.github/workflows/test-launch.yml
vendored
@@ -17,7 +17,7 @@ jobs:
|
||||
path: "ComfyUI"
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.8'
|
||||
python-version: '3.10'
|
||||
- name: Install requirements
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
|
||||
2
.github/workflows/test-unit.yml
vendored
2
.github/workflows/test-unit.yml
vendored
@@ -18,7 +18,7 @@ jobs:
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.10'
|
||||
python-version: '3.12'
|
||||
- name: Install requirements
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
|
||||
56
.github/workflows/update-api-stubs.yml
vendored
Normal file
56
.github/workflows/update-api-stubs.yml
vendored
Normal file
@@ -0,0 +1,56 @@
|
||||
name: Generate Pydantic Stubs from api.comfy.org
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: '0 0 * * 1'
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
generate-models:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.10'
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install 'datamodel-code-generator[http]'
|
||||
npm install @redocly/cli
|
||||
|
||||
- name: Download OpenAPI spec
|
||||
run: |
|
||||
curl -o openapi.yaml https://api.comfy.org/openapi
|
||||
|
||||
- name: Filter OpenAPI spec with Redocly
|
||||
run: |
|
||||
npx @redocly/cli bundle openapi.yaml --output filtered-openapi.yaml --config comfy_api_nodes/redocly.yaml --remove-unused-components
|
||||
|
||||
- name: Generate API models
|
||||
run: |
|
||||
datamodel-codegen --use-subclass-enum --input filtered-openapi.yaml --output comfy_api_nodes/apis --output-model-type pydantic_v2.BaseModel
|
||||
|
||||
- name: Check for changes
|
||||
id: git-check
|
||||
run: |
|
||||
git diff --exit-code comfy_api_nodes/apis || echo "changes=true" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Create Pull Request
|
||||
if: steps.git-check.outputs.changes == 'true'
|
||||
uses: peter-evans/create-pull-request@v5
|
||||
with:
|
||||
commit-message: 'chore: update API models from OpenAPI spec'
|
||||
title: 'Update API models from api.comfy.org'
|
||||
body: |
|
||||
This PR updates the API models based on the latest api.comfy.org OpenAPI specification.
|
||||
|
||||
Generated automatically by the a Github workflow.
|
||||
branch: update-api-stubs
|
||||
delete-branch: true
|
||||
base: master
|
||||
58
.github/workflows/update-version.yml
vendored
Normal file
58
.github/workflows/update-version.yml
vendored
Normal file
@@ -0,0 +1,58 @@
|
||||
name: Update Version File
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- "pyproject.toml"
|
||||
branches:
|
||||
- master
|
||||
|
||||
jobs:
|
||||
update-version:
|
||||
runs-on: ubuntu-latest
|
||||
# Don't run on fork PRs
|
||||
if: github.event.pull_request.head.repo.full_name == github.repository
|
||||
permissions:
|
||||
pull-requests: write
|
||||
contents: write
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.11"
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
|
||||
- name: Update comfyui_version.py
|
||||
run: |
|
||||
# Read version from pyproject.toml and update comfyui_version.py
|
||||
python -c '
|
||||
import tomllib
|
||||
|
||||
# Read version from pyproject.toml
|
||||
with open("pyproject.toml", "rb") as f:
|
||||
config = tomllib.load(f)
|
||||
version = config["project"]["version"]
|
||||
|
||||
# Write version to comfyui_version.py
|
||||
with open("comfyui_version.py", "w") as f:
|
||||
f.write("# This file is automatically generated by the build process when version is\n")
|
||||
f.write("# updated in pyproject.toml.\n")
|
||||
f.write(f"__version__ = \"{version}\"\n")
|
||||
'
|
||||
|
||||
- name: Commit changes
|
||||
run: |
|
||||
git config --local user.name "github-actions"
|
||||
git config --local user.email "github-actions@github.com"
|
||||
git fetch origin ${{ github.head_ref }}
|
||||
git checkout -B ${{ github.head_ref }} origin/${{ github.head_ref }}
|
||||
git add comfyui_version.py
|
||||
git diff --quiet && git diff --staged --quiet || git commit -m "chore: Update comfyui_version.py to match pyproject.toml"
|
||||
git push origin HEAD:${{ github.head_ref }}
|
||||
@@ -17,7 +17,7 @@ on:
|
||||
description: 'cuda version'
|
||||
required: true
|
||||
type: string
|
||||
default: "124"
|
||||
default: "128"
|
||||
|
||||
python_minor:
|
||||
description: 'python minor version'
|
||||
@@ -29,7 +29,7 @@ on:
|
||||
description: 'python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "7"
|
||||
default: "10"
|
||||
# push:
|
||||
# branches:
|
||||
# - master
|
||||
|
||||
@@ -7,19 +7,19 @@ on:
|
||||
description: 'cuda version'
|
||||
required: true
|
||||
type: string
|
||||
default: "124"
|
||||
default: "128"
|
||||
|
||||
python_minor:
|
||||
description: 'python minor version'
|
||||
required: true
|
||||
type: string
|
||||
default: "12"
|
||||
default: "13"
|
||||
|
||||
python_patch:
|
||||
description: 'python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "4"
|
||||
default: "2"
|
||||
# push:
|
||||
# branches:
|
||||
# - master
|
||||
@@ -34,7 +34,7 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
fetch-depth: 30
|
||||
persist-credentials: false
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
@@ -56,7 +56,7 @@ jobs:
|
||||
cd ..
|
||||
|
||||
git clone --depth 1 https://github.com/comfyanonymous/taesd
|
||||
cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/
|
||||
cp taesd/*.safetensors ./ComfyUI_copy/models/vae_approx/
|
||||
|
||||
mkdir ComfyUI_windows_portable_nightly_pytorch
|
||||
mv python_embeded ComfyUI_windows_portable_nightly_pytorch
|
||||
@@ -74,7 +74,7 @@ jobs:
|
||||
pause" > ./update/update_comfyui_and_python_dependencies.bat
|
||||
cd ..
|
||||
|
||||
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI_windows_portable_nightly_pytorch
|
||||
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI_windows_portable_nightly_pytorch
|
||||
mv ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI/ComfyUI_windows_portable_nvidia_or_cpu_nightly_pytorch.7z
|
||||
|
||||
cd ComfyUI_windows_portable_nightly_pytorch
|
||||
|
||||
12
.github/workflows/windows_release_package.yml
vendored
12
.github/workflows/windows_release_package.yml
vendored
@@ -7,7 +7,7 @@ on:
|
||||
description: 'cuda version'
|
||||
required: true
|
||||
type: string
|
||||
default: "124"
|
||||
default: "128"
|
||||
|
||||
python_minor:
|
||||
description: 'python minor version'
|
||||
@@ -19,7 +19,7 @@ on:
|
||||
description: 'python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "7"
|
||||
default: "10"
|
||||
# push:
|
||||
# branches:
|
||||
# - master
|
||||
@@ -50,7 +50,7 @@ jobs:
|
||||
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
fetch-depth: 150
|
||||
persist-credentials: false
|
||||
- shell: bash
|
||||
run: |
|
||||
@@ -67,7 +67,7 @@ jobs:
|
||||
cd ..
|
||||
|
||||
git clone --depth 1 https://github.com/comfyanonymous/taesd
|
||||
cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/
|
||||
cp taesd/*.safetensors ./ComfyUI_copy/models/vae_approx/
|
||||
|
||||
mkdir ComfyUI_windows_portable
|
||||
mv python_embeded ComfyUI_windows_portable
|
||||
@@ -82,12 +82,14 @@ jobs:
|
||||
|
||||
cd ..
|
||||
|
||||
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
|
||||
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
|
||||
mv ComfyUI_windows_portable.7z ComfyUI/new_ComfyUI_windows_portable_nvidia_cu${{ inputs.cu }}_or_cpu.7z
|
||||
|
||||
cd ComfyUI_windows_portable
|
||||
python_embeded/python.exe -s ComfyUI/main.py --quick-test-for-ci --cpu
|
||||
|
||||
python_embeded/python.exe -s ./update/update.py ComfyUI/
|
||||
|
||||
ls
|
||||
|
||||
- name: Upload binaries to release
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -21,3 +21,6 @@ venv/
|
||||
*.log
|
||||
web_custom_versions/
|
||||
.DS_Store
|
||||
openapi.yaml
|
||||
filtered-openapi.yaml
|
||||
uv.lock
|
||||
|
||||
25
CODEOWNERS
25
CODEOWNERS
@@ -1 +1,24 @@
|
||||
* @comfyanonymous
|
||||
# Admins
|
||||
* @comfyanonymous
|
||||
|
||||
# Note: Github teams syntax cannot be used here as the repo is not owned by Comfy-Org.
|
||||
# Inlined the team members for now.
|
||||
|
||||
# Maintainers
|
||||
*.md @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/tests/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/tests-unit/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/notebooks/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/script_examples/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/.github/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/requirements.txt @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/pyproject.toml @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
|
||||
# Python web server
|
||||
/api_server/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
|
||||
/app/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
|
||||
/utils/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
|
||||
|
||||
# Node developers
|
||||
/comfy_extras/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
|
||||
/comfy/comfy_types/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
|
||||
|
||||
155
README.md
155
README.md
@@ -1,7 +1,7 @@
|
||||
<div align="center">
|
||||
|
||||
# ComfyUI
|
||||
**The most powerful and modular diffusion model GUI and backend.**
|
||||
**The most powerful and modular visual AI engine and application.**
|
||||
|
||||
|
||||
[![Website][website-shield]][website-url]
|
||||
@@ -31,17 +31,49 @@
|
||||

|
||||
</div>
|
||||
|
||||
This ui will let you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. For some workflow examples and see what ComfyUI can do you can check out:
|
||||
### [ComfyUI Examples](https://comfyanonymous.github.io/ComfyUI_examples/)
|
||||
ComfyUI lets you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. Available on Windows, Linux, and macOS.
|
||||
|
||||
### [Installing ComfyUI](#installing)
|
||||
## Get Started
|
||||
|
||||
#### [Desktop Application](https://www.comfy.org/download)
|
||||
- The easiest way to get started.
|
||||
- Available on Windows & macOS.
|
||||
|
||||
#### [Windows Portable Package](#installing)
|
||||
- Get the latest commits and completely portable.
|
||||
- Available on Windows.
|
||||
|
||||
#### [Manual Install](#manual-install-windows-linux)
|
||||
Supports all operating systems and GPU types (NVIDIA, AMD, Intel, Apple Silicon, Ascend).
|
||||
|
||||
## [Examples](https://comfyanonymous.github.io/ComfyUI_examples/)
|
||||
See what ComfyUI can do with the [example workflows](https://comfyanonymous.github.io/ComfyUI_examples/).
|
||||
|
||||
## Features
|
||||
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
|
||||
- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/), [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/), [SD3](https://comfyanonymous.github.io/ComfyUI_examples/sd3/) and [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
|
||||
- [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/)
|
||||
- [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
|
||||
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
|
||||
- Image Models
|
||||
- SD1.x, SD2.x,
|
||||
- [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/)
|
||||
- [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/)
|
||||
- [SD3 and SD3.5](https://comfyanonymous.github.io/ComfyUI_examples/sd3/)
|
||||
- Pixart Alpha and Sigma
|
||||
- [AuraFlow](https://comfyanonymous.github.io/ComfyUI_examples/aura_flow/)
|
||||
- [HunyuanDiT](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_dit/)
|
||||
- [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
|
||||
- [Lumina Image 2.0](https://comfyanonymous.github.io/ComfyUI_examples/lumina2/)
|
||||
- [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/)
|
||||
- Video Models
|
||||
- [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
|
||||
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
|
||||
- [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/)
|
||||
- [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
|
||||
- [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/)
|
||||
- [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
|
||||
- Audio Models
|
||||
- [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
|
||||
- [ACE Step](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
|
||||
- 3D Models
|
||||
- [Hunyuan3D 2.0](https://docs.comfy.org/tutorials/3d/hunyuan3D-2)
|
||||
- Asynchronous Queue system
|
||||
- Many optimizations: Only re-executes the parts of the workflow that changes between executions.
|
||||
- Smart memory management: can automatically run models on GPUs with as low as 1GB vram.
|
||||
@@ -61,9 +93,6 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
|
||||
- [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/)
|
||||
- [Model Merging](https://comfyanonymous.github.io/ComfyUI_examples/model_merging/)
|
||||
- [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/)
|
||||
- [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/)
|
||||
- [AuraFlow](https://comfyanonymous.github.io/ComfyUI_examples/aura_flow/)
|
||||
- [HunyuanDiT](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_dit/)
|
||||
- Latent previews with [TAESD](#how-to-show-high-quality-previews)
|
||||
- Starts up very fast.
|
||||
- Works fully offline: will never download anything.
|
||||
@@ -71,6 +100,22 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
|
||||
|
||||
Workflow examples can be found on the [Examples page](https://comfyanonymous.github.io/ComfyUI_examples/)
|
||||
|
||||
## Release Process
|
||||
|
||||
ComfyUI follows a weekly release cycle every Friday, with three interconnected repositories:
|
||||
|
||||
1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)**
|
||||
- Releases a new stable version (e.g., v0.7.0)
|
||||
- Serves as the foundation for the desktop release
|
||||
|
||||
2. **[ComfyUI Desktop](https://github.com/Comfy-Org/desktop)**
|
||||
- Builds a new release using the latest stable core version
|
||||
|
||||
3. **[ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend)**
|
||||
- Weekly frontend updates are merged into the core repository
|
||||
- Features are frozen for the upcoming core release
|
||||
- Development continues for the next release cycle
|
||||
|
||||
## Shortcuts
|
||||
|
||||
| Keybind | Explanation |
|
||||
@@ -101,6 +146,8 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
|
||||
| `Q` | Toggle visibility of the queue |
|
||||
| `H` | Toggle visibility of history |
|
||||
| `R` | Refresh graph |
|
||||
| `F` | Show/Hide menu |
|
||||
| `.` | Fit view to selection (Whole graph when nothing is selected) |
|
||||
| Double-Click LMB | Open node quick search palette |
|
||||
| `Shift` + Drag | Move multiple wires at once |
|
||||
| `Ctrl` + `Alt` + LMB | Disconnect all wires from clicked slot |
|
||||
@@ -109,7 +156,7 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
|
||||
|
||||
# Installing
|
||||
|
||||
## Windows
|
||||
## Windows Portable
|
||||
|
||||
There is a portable standalone build for Windows that should work for running on Nvidia GPUs or for running on your CPU only on the [releases page](https://github.com/comfyanonymous/ComfyUI/releases).
|
||||
|
||||
@@ -127,9 +174,18 @@ See the [Config file](extra_model_paths.yaml.example) to set the search paths fo
|
||||
|
||||
To run it on services like paperspace, kaggle or colab you can use my [Jupyter Notebook](notebooks/comfyui_colab.ipynb)
|
||||
|
||||
|
||||
## [comfy-cli](https://docs.comfy.org/comfy-cli/getting-started)
|
||||
|
||||
You can install and start ComfyUI using comfy-cli:
|
||||
```bash
|
||||
pip install comfy-cli
|
||||
comfy install
|
||||
```
|
||||
|
||||
## Manual Install (Windows, Linux)
|
||||
|
||||
Note that some dependencies do not yet support python 3.13 so using 3.12 is recommended.
|
||||
python 3.13 is supported but using 3.12 is recommended because some custom nodes and their dependencies might not support it yet.
|
||||
|
||||
Git clone this repo.
|
||||
|
||||
@@ -141,21 +197,45 @@ Put your VAE in: models/vae
|
||||
### AMD GPUs (Linux only)
|
||||
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
|
||||
|
||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.2```
|
||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.3```
|
||||
|
||||
This is the command to install the nightly with ROCm 6.2 which might have some performance improvements:
|
||||
This is the command to install the nightly with ROCm 6.4 which might have some performance improvements:
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.2```
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.4```
|
||||
|
||||
### Intel GPUs (Windows and Linux)
|
||||
|
||||
(Option 1) Intel Arc GPU users can install native PyTorch with torch.xpu support using pip (currently available in PyTorch nightly builds). More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html)
|
||||
|
||||
1. To install PyTorch nightly, use the following command:
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/xpu```
|
||||
|
||||
2. Launch ComfyUI by running `python main.py`
|
||||
|
||||
|
||||
(Option 2) Alternatively, Intel GPUs supported by Intel Extension for PyTorch (IPEX) can leverage IPEX for improved performance.
|
||||
|
||||
1. For Intel® Arc™ A-Series Graphics utilizing IPEX, create a conda environment and use the commands below:
|
||||
|
||||
```
|
||||
conda install libuv
|
||||
pip install torch==2.3.1.post0+cxx11.abi torchvision==0.18.1.post0+cxx11.abi torchaudio==2.3.1.post0+cxx11.abi intel-extension-for-pytorch==2.3.110.post0+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/
|
||||
```
|
||||
|
||||
For other supported Intel GPUs with IPEX, visit [Installation](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu) for more information.
|
||||
|
||||
Additional discussion and help can be found [here](https://github.com/comfyanonymous/ComfyUI/discussions/476).
|
||||
|
||||
### NVIDIA
|
||||
|
||||
Nvidia users should install stable pytorch using this command:
|
||||
|
||||
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu124```
|
||||
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu128```
|
||||
|
||||
This is the command to install pytorch nightly instead which might have performance improvements:
|
||||
This is the command to install pytorch nightly instead which might have performance improvements.
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124```
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128```
|
||||
|
||||
#### Troubleshooting
|
||||
|
||||
@@ -175,17 +255,6 @@ After this you should have everything installed and can proceed to running Comfy
|
||||
|
||||
### Others:
|
||||
|
||||
#### Intel GPUs
|
||||
|
||||
Intel GPU support is available for all Intel GPUs supported by Intel's Extension for Pytorch (IPEX) with the support requirements listed in the [Installation](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu) page. Choose your platform and method of install and follow the instructions. The steps are as follows:
|
||||
|
||||
1. Start by installing the drivers or kernel listed or newer in the Installation page of IPEX linked above for Windows and Linux if needed.
|
||||
1. Follow the instructions to install [Intel's oneAPI Basekit](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit-download.html) for your platform.
|
||||
1. Install the packages for IPEX using the instructions provided in the Installation page for your platform.
|
||||
1. Follow the [ComfyUI manual installation](#manual-install-windows-linux) instructions for Windows and Linux and run ComfyUI normally as described above after everything is installed.
|
||||
|
||||
Additional discussion and help can be found [here](https://github.com/comfyanonymous/ComfyUI/discussions/476).
|
||||
|
||||
#### Apple Mac silicon
|
||||
|
||||
You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS version.
|
||||
@@ -201,6 +270,23 @@ You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS ve
|
||||
|
||||
```pip install torch-directml``` Then you can launch ComfyUI with: ```python main.py --directml```
|
||||
|
||||
#### Ascend NPUs
|
||||
|
||||
For models compatible with Ascend Extension for PyTorch (torch_npu). To get started, ensure your environment meets the prerequisites outlined on the [installation](https://ascend.github.io/docs/sources/ascend/quick_install.html) page. Here's a step-by-step guide tailored to your platform and installation method:
|
||||
|
||||
1. Begin by installing the recommended or newer kernel version for Linux as specified in the Installation page of torch-npu, if necessary.
|
||||
2. Proceed with the installation of Ascend Basekit, which includes the driver, firmware, and CANN, following the instructions provided for your specific platform.
|
||||
3. Next, install the necessary packages for torch-npu by adhering to the platform-specific instructions on the [Installation](https://ascend.github.io/docs/sources/pytorch/install.html#pytorch) page.
|
||||
4. Finally, adhere to the [ComfyUI manual installation](#manual-install-windows-linux) guide for Linux. Once all components are installed, you can run ComfyUI as described earlier.
|
||||
|
||||
#### Cambricon MLUs
|
||||
|
||||
For models compatible with Cambricon Extension for PyTorch (torch_mlu). Here's a step-by-step guide tailored to your platform and installation method:
|
||||
|
||||
1. Install the Cambricon CNToolkit by adhering to the platform-specific instructions on the [Installation](https://www.cambricon.com/docs/sdk_1.15.0/cntoolkit_3.7.2/cntoolkit_install_3.7.2/index.html)
|
||||
2. Next, install the PyTorch(torch_mlu) following the instructions on the [Installation](https://www.cambricon.com/docs/sdk_1.15.0/cambricon_pytorch_1.17.0/user_guide_1.9/index.html)
|
||||
3. Launch ComfyUI by running `python main.py`
|
||||
|
||||
# Running
|
||||
|
||||
```python main.py```
|
||||
@@ -215,7 +301,7 @@ For AMD 7600 and maybe other RDNA3 cards: ```HSA_OVERRIDE_GFX_VERSION=11.0.0 pyt
|
||||
|
||||
### AMD ROCm Tips
|
||||
|
||||
You can enable experimental memory efficient attention on pytorch 2.5 in ComfyUI on RDNA3 and potentially other AMD GPUs using this command:
|
||||
You can enable experimental memory efficient attention on recent pytorch in ComfyUI on some AMD GPUs using this command, it should already be enabled by default on RDNA3. If this improves speed for you on latest pytorch on your GPU please report it so that I can enable it by default.
|
||||
|
||||
```TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 python main.py --use-pytorch-cross-attention```
|
||||
|
||||
@@ -256,6 +342,8 @@ Use `--tls-keyfile key.pem --tls-certfile cert.pem` to enable TLS/SSL, the app w
|
||||
|
||||
## Support and dev channel
|
||||
|
||||
[Discord](https://comfy.org/discord): Try the #help or #feedback channels.
|
||||
|
||||
[Matrix space: #comfyui_space:matrix.org](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) (it's like discord but open source).
|
||||
|
||||
See also: [https://www.comfy.org/](https://www.comfy.org/)
|
||||
@@ -272,7 +360,7 @@ For any bugs, issues, or feature requests related to the frontend, please use th
|
||||
|
||||
The new frontend is now the default for ComfyUI. However, please note:
|
||||
|
||||
1. The frontend in the main ComfyUI repository is updated weekly.
|
||||
1. The frontend in the main ComfyUI repository is updated fortnightly.
|
||||
2. Daily releases are available in the separate frontend repository.
|
||||
|
||||
To use the most up-to-date frontend version:
|
||||
@@ -289,7 +377,7 @@ To use the most up-to-date frontend version:
|
||||
--front-end-version Comfy-Org/ComfyUI_frontend@1.2.2
|
||||
```
|
||||
|
||||
This approach allows you to easily switch between the stable weekly release and the cutting-edge daily updates, or even specific versions for testing purposes.
|
||||
This approach allows you to easily switch between the stable fortnightly release and the cutting-edge daily updates, or even specific versions for testing purposes.
|
||||
|
||||
### Accessing the Legacy Frontend
|
||||
|
||||
@@ -306,4 +394,3 @@ This will use a snapshot of the legacy frontend preserved in the [ComfyUI Legacy
|
||||
### Which GPU should I buy for this?
|
||||
|
||||
[See this page for some recommendations](https://github.com/comfyanonymous/ComfyUI/wiki/Which-GPU-should-I-buy-for-ComfyUI)
|
||||
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from aiohttp import web
|
||||
from typing import Optional
|
||||
from folder_paths import models_dir, user_directory, output_directory, folder_names_and_paths
|
||||
from api_server.services.file_service import FileService
|
||||
from folder_paths import folder_names_and_paths, get_directory_by_type
|
||||
from api_server.services.terminal_service import TerminalService
|
||||
import app.logger
|
||||
import os
|
||||
|
||||
class InternalRoutes:
|
||||
'''
|
||||
@@ -15,32 +15,16 @@ class InternalRoutes:
|
||||
def __init__(self, prompt_server):
|
||||
self.routes: web.RouteTableDef = web.RouteTableDef()
|
||||
self._app: Optional[web.Application] = None
|
||||
self.file_service = FileService({
|
||||
"models": models_dir,
|
||||
"user": user_directory,
|
||||
"output": output_directory
|
||||
})
|
||||
self.prompt_server = prompt_server
|
||||
self.terminal_service = TerminalService(prompt_server)
|
||||
|
||||
def setup_routes(self):
|
||||
@self.routes.get('/files')
|
||||
async def list_files(request):
|
||||
directory_key = request.query.get('directory', '')
|
||||
try:
|
||||
file_list = self.file_service.list_files(directory_key)
|
||||
return web.json_response({"files": file_list})
|
||||
except ValueError as e:
|
||||
return web.json_response({"error": str(e)}, status=400)
|
||||
except Exception as e:
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
@self.routes.get('/logs')
|
||||
async def get_logs(request):
|
||||
return web.json_response("".join([(l["t"] + " - " + l["m"]) for l in app.logger.get_logs()]))
|
||||
|
||||
@self.routes.get('/logs/raw')
|
||||
async def get_logs(request):
|
||||
async def get_raw_logs(request):
|
||||
self.terminal_service.update_size()
|
||||
return web.json_response({
|
||||
"entries": list(app.logger.get_logs()),
|
||||
@@ -67,6 +51,20 @@ class InternalRoutes:
|
||||
response[key] = folder_names_and_paths[key][0]
|
||||
return web.json_response(response)
|
||||
|
||||
@self.routes.get('/files/{directory_type}')
|
||||
async def get_files(request: web.Request) -> web.Response:
|
||||
directory_type = request.match_info['directory_type']
|
||||
if directory_type not in ("output", "input", "temp"):
|
||||
return web.json_response({"error": "Invalid directory type"}, status=400)
|
||||
|
||||
directory = get_directory_by_type(directory_type)
|
||||
sorted_files = sorted(
|
||||
(entry for entry in os.scandir(directory) if entry.is_file()),
|
||||
key=lambda entry: -entry.stat().st_mtime
|
||||
)
|
||||
return web.json_response([entry.name for entry in sorted_files], status=200)
|
||||
|
||||
|
||||
def get_app(self):
|
||||
if self._app is None:
|
||||
self._app = web.Application()
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
from typing import Dict, List, Optional
|
||||
from api_server.utils.file_operations import FileSystemOperations, FileSystemItem
|
||||
|
||||
class FileService:
|
||||
def __init__(self, allowed_directories: Dict[str, str], file_system_ops: Optional[FileSystemOperations] = None):
|
||||
self.allowed_directories: Dict[str, str] = allowed_directories
|
||||
self.file_system_ops: FileSystemOperations = file_system_ops or FileSystemOperations()
|
||||
|
||||
def list_files(self, directory_key: str) -> List[FileSystemItem]:
|
||||
if directory_key not in self.allowed_directories:
|
||||
raise ValueError("Invalid directory key")
|
||||
directory_path: str = self.allowed_directories[directory_key]
|
||||
return self.file_system_ops.walk_directory(directory_path)
|
||||
@@ -25,10 +25,10 @@ class TerminalService:
|
||||
def update_size(self):
|
||||
columns, lines = self.get_terminal_size()
|
||||
changed = False
|
||||
|
||||
|
||||
if columns != self.cols:
|
||||
self.cols = columns
|
||||
changed = True
|
||||
changed = True
|
||||
|
||||
if lines != self.rows:
|
||||
self.rows = lines
|
||||
@@ -48,9 +48,9 @@ class TerminalService:
|
||||
def send_messages(self, entries):
|
||||
if not len(entries) or not len(self.subscriptions):
|
||||
return
|
||||
|
||||
|
||||
new_size = self.update_size()
|
||||
|
||||
|
||||
for client_id in self.subscriptions.copy(): # prevent: Set changed size during iteration
|
||||
if client_id not in self.server.sockets:
|
||||
# Automatically unsub if the socket has disconnected
|
||||
|
||||
@@ -39,4 +39,4 @@ class FileSystemOperations:
|
||||
"path": relative_path,
|
||||
"type": "directory"
|
||||
})
|
||||
return file_list
|
||||
return file_list
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import os
|
||||
import json
|
||||
from aiohttp import web
|
||||
import logging
|
||||
|
||||
|
||||
class AppSettings():
|
||||
@@ -8,11 +9,21 @@ class AppSettings():
|
||||
self.user_manager = user_manager
|
||||
|
||||
def get_settings(self, request):
|
||||
file = self.user_manager.get_request_user_filepath(
|
||||
request, "comfy.settings.json")
|
||||
try:
|
||||
file = self.user_manager.get_request_user_filepath(
|
||||
request,
|
||||
"comfy.settings.json"
|
||||
)
|
||||
except KeyError as e:
|
||||
logging.error("User settings not found.")
|
||||
raise web.HTTPUnauthorized() from e
|
||||
if os.path.isfile(file):
|
||||
with open(file) as f:
|
||||
return json.load(f)
|
||||
try:
|
||||
with open(file) as f:
|
||||
return json.load(f)
|
||||
except:
|
||||
logging.error(f"The user settings file is corrupted: {file}")
|
||||
return {}
|
||||
else:
|
||||
return {}
|
||||
|
||||
@@ -51,4 +62,4 @@ class AppSettings():
|
||||
settings = self.get_settings(request)
|
||||
settings[setting_id] = await request.json()
|
||||
self.save_settings(request, settings)
|
||||
return web.Response(status=200)
|
||||
return web.Response(status=200)
|
||||
|
||||
145
app/custom_node_manager.py
Normal file
145
app/custom_node_manager.py
Normal file
@@ -0,0 +1,145 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import folder_paths
|
||||
import glob
|
||||
from aiohttp import web
|
||||
import json
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
|
||||
from utils.json_util import merge_json_recursive
|
||||
|
||||
|
||||
# Extra locale files to load into main.json
|
||||
EXTRA_LOCALE_FILES = [
|
||||
"nodeDefs.json",
|
||||
"commands.json",
|
||||
"settings.json",
|
||||
]
|
||||
|
||||
|
||||
def safe_load_json_file(file_path: str) -> dict:
|
||||
if not os.path.exists(file_path):
|
||||
return {}
|
||||
|
||||
try:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except json.JSONDecodeError:
|
||||
logging.error(f"Error loading {file_path}")
|
||||
return {}
|
||||
|
||||
|
||||
class CustomNodeManager:
|
||||
@lru_cache(maxsize=1)
|
||||
def build_translations(self):
|
||||
"""Load all custom nodes translations during initialization. Translations are
|
||||
expected to be loaded from `locales/` folder.
|
||||
|
||||
The folder structure is expected to be the following:
|
||||
- custom_nodes/
|
||||
- custom_node_1/
|
||||
- locales/
|
||||
- en/
|
||||
- main.json
|
||||
- commands.json
|
||||
- settings.json
|
||||
|
||||
returned translations are expected to be in the following format:
|
||||
{
|
||||
"en": {
|
||||
"nodeDefs": {...},
|
||||
"commands": {...},
|
||||
"settings": {...},
|
||||
...{other main.json keys}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
translations = {}
|
||||
|
||||
for folder in folder_paths.get_folder_paths("custom_nodes"):
|
||||
# Sort glob results for deterministic ordering
|
||||
for custom_node_dir in sorted(glob.glob(os.path.join(folder, "*/"))):
|
||||
locales_dir = os.path.join(custom_node_dir, "locales")
|
||||
if not os.path.exists(locales_dir):
|
||||
continue
|
||||
|
||||
for lang_dir in glob.glob(os.path.join(locales_dir, "*/")):
|
||||
lang_code = os.path.basename(os.path.dirname(lang_dir))
|
||||
|
||||
if lang_code not in translations:
|
||||
translations[lang_code] = {}
|
||||
|
||||
# Load main.json
|
||||
main_file = os.path.join(lang_dir, "main.json")
|
||||
node_translations = safe_load_json_file(main_file)
|
||||
|
||||
# Load extra locale files
|
||||
for extra_file in EXTRA_LOCALE_FILES:
|
||||
extra_file_path = os.path.join(lang_dir, extra_file)
|
||||
key = extra_file.split(".")[0]
|
||||
json_data = safe_load_json_file(extra_file_path)
|
||||
if json_data:
|
||||
node_translations[key] = json_data
|
||||
|
||||
if node_translations:
|
||||
translations[lang_code] = merge_json_recursive(
|
||||
translations[lang_code], node_translations
|
||||
)
|
||||
|
||||
return translations
|
||||
|
||||
def add_routes(self, routes, webapp, loadedModules):
|
||||
|
||||
example_workflow_folder_names = ["example_workflows", "example", "examples", "workflow", "workflows"]
|
||||
|
||||
@routes.get("/workflow_templates")
|
||||
async def get_workflow_templates(request):
|
||||
"""Returns a web response that contains the map of custom_nodes names and their associated workflow templates. The ones without templates are omitted."""
|
||||
|
||||
files = []
|
||||
|
||||
for folder in folder_paths.get_folder_paths("custom_nodes"):
|
||||
for folder_name in example_workflow_folder_names:
|
||||
pattern = os.path.join(folder, f"*/{folder_name}/*.json")
|
||||
matched_files = glob.glob(pattern)
|
||||
files.extend(matched_files)
|
||||
|
||||
workflow_templates_dict = (
|
||||
{}
|
||||
) # custom_nodes folder name -> example workflow names
|
||||
for file in files:
|
||||
custom_nodes_name = os.path.basename(
|
||||
os.path.dirname(os.path.dirname(file))
|
||||
)
|
||||
workflow_name = os.path.splitext(os.path.basename(file))[0]
|
||||
workflow_templates_dict.setdefault(custom_nodes_name, []).append(
|
||||
workflow_name
|
||||
)
|
||||
return web.json_response(workflow_templates_dict)
|
||||
|
||||
# Serve workflow templates from custom nodes.
|
||||
for module_name, module_dir in loadedModules:
|
||||
for folder_name in example_workflow_folder_names:
|
||||
workflows_dir = os.path.join(module_dir, folder_name)
|
||||
|
||||
if os.path.exists(workflows_dir):
|
||||
if folder_name != "example_workflows":
|
||||
logging.debug(
|
||||
"Found example workflow folder '%s' for custom node '%s', consider renaming it to 'example_workflows'",
|
||||
folder_name, module_name)
|
||||
|
||||
webapp.add_routes(
|
||||
[
|
||||
web.static(
|
||||
"/api/workflow_templates/" + module_name, workflows_dir
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
@routes.get("/i18n")
|
||||
async def get_i18n(request):
|
||||
"""Returns translations from all custom nodes' locales folders."""
|
||||
return web.json_response(self.build_translations())
|
||||
@@ -3,16 +3,69 @@ import argparse
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import tempfile
|
||||
import zipfile
|
||||
import importlib
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import TypedDict, Optional
|
||||
from importlib.metadata import version
|
||||
|
||||
import requests
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
from comfy.cli_args import DEFAULT_VERSION_STRING
|
||||
import app.logger
|
||||
|
||||
# The path to the requirements.txt file
|
||||
req_path = Path(__file__).parents[1] / "requirements.txt"
|
||||
|
||||
|
||||
def frontend_install_warning_message():
|
||||
"""The warning message to display when the frontend version is not up to date."""
|
||||
|
||||
extra = ""
|
||||
if sys.flags.no_user_site:
|
||||
extra = "-s "
|
||||
return f"""
|
||||
Please install the updated requirements.txt file by running:
|
||||
{sys.executable} {extra}-m pip install -r {req_path}
|
||||
|
||||
This error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.
|
||||
|
||||
If you are on the portable package you can run: update\\update_comfyui.bat to solve this problem
|
||||
""".strip()
|
||||
|
||||
|
||||
def check_frontend_version():
|
||||
"""Check if the frontend version is up to date."""
|
||||
|
||||
def parse_version(version: str) -> tuple[int, int, int]:
|
||||
return tuple(map(int, version.split(".")))
|
||||
|
||||
try:
|
||||
frontend_version_str = version("comfyui-frontend-package")
|
||||
frontend_version = parse_version(frontend_version_str)
|
||||
with open(req_path, "r", encoding="utf-8") as f:
|
||||
required_frontend = parse_version(f.readline().split("=")[-1])
|
||||
if frontend_version < required_frontend:
|
||||
app.logger.log_startup_warning(
|
||||
f"""
|
||||
________________________________________________________________________
|
||||
WARNING WARNING WARNING WARNING WARNING
|
||||
|
||||
Installed frontend version {".".join(map(str, frontend_version))} is lower than the recommended version {".".join(map(str, required_frontend))}.
|
||||
|
||||
{frontend_install_warning_message()}
|
||||
________________________________________________________________________
|
||||
""".strip()
|
||||
)
|
||||
else:
|
||||
logging.info("ComfyUI frontend version: {}".format(frontend_version_str))
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to check frontend version: {e}")
|
||||
|
||||
|
||||
REQUEST_TIMEOUT = 10 # seconds
|
||||
@@ -109,9 +162,62 @@ def download_release_asset_zip(release: Release, destination_path: str) -> None:
|
||||
|
||||
|
||||
class FrontendManager:
|
||||
DEFAULT_FRONTEND_PATH = str(Path(__file__).parents[1] / "web")
|
||||
CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")
|
||||
|
||||
@classmethod
|
||||
def default_frontend_path(cls) -> str:
|
||||
try:
|
||||
import comfyui_frontend_package
|
||||
|
||||
return str(importlib.resources.files(comfyui_frontend_package) / "static")
|
||||
except ImportError:
|
||||
logging.error(
|
||||
f"""
|
||||
********** ERROR ***********
|
||||
|
||||
comfyui-frontend-package is not installed.
|
||||
|
||||
{frontend_install_warning_message()}
|
||||
|
||||
********** ERROR ***********
|
||||
""".strip()
|
||||
)
|
||||
sys.exit(-1)
|
||||
|
||||
@classmethod
|
||||
def templates_path(cls) -> str:
|
||||
try:
|
||||
import comfyui_workflow_templates
|
||||
|
||||
return str(
|
||||
importlib.resources.files(comfyui_workflow_templates) / "templates"
|
||||
)
|
||||
except ImportError:
|
||||
logging.error(
|
||||
f"""
|
||||
********** ERROR ***********
|
||||
|
||||
comfyui-workflow-templates is not installed.
|
||||
|
||||
{frontend_install_warning_message()}
|
||||
|
||||
********** ERROR ***********
|
||||
""".strip()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def embedded_docs_path(cls) -> str:
|
||||
"""Get the path to embedded documentation"""
|
||||
try:
|
||||
import comfyui_embedded_docs
|
||||
|
||||
return str(
|
||||
importlib.resources.files(comfyui_embedded_docs) / "docs"
|
||||
)
|
||||
except ImportError:
|
||||
logging.info("comfyui-embedded-docs package not found")
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def parse_version_string(cls, value: str) -> tuple[str, str, str]:
|
||||
"""
|
||||
@@ -132,7 +238,9 @@ class FrontendManager:
|
||||
return match_result.group(1), match_result.group(2), match_result.group(3)
|
||||
|
||||
@classmethod
|
||||
def init_frontend_unsafe(cls, version_string: str, provider: Optional[FrontEndProvider] = None) -> str:
|
||||
def init_frontend_unsafe(
|
||||
cls, version_string: str, provider: Optional[FrontEndProvider] = None
|
||||
) -> str:
|
||||
"""
|
||||
Initializes the frontend for the specified version.
|
||||
|
||||
@@ -148,17 +256,26 @@ class FrontendManager:
|
||||
main error source might be request timeout or invalid URL.
|
||||
"""
|
||||
if version_string == DEFAULT_VERSION_STRING:
|
||||
return cls.DEFAULT_FRONTEND_PATH
|
||||
check_frontend_version()
|
||||
return cls.default_frontend_path()
|
||||
|
||||
repo_owner, repo_name, version = cls.parse_version_string(version_string)
|
||||
|
||||
if version.startswith("v"):
|
||||
expected_path = str(Path(cls.CUSTOM_FRONTENDS_ROOT) / f"{repo_owner}_{repo_name}" / version.lstrip("v"))
|
||||
expected_path = str(
|
||||
Path(cls.CUSTOM_FRONTENDS_ROOT)
|
||||
/ f"{repo_owner}_{repo_name}"
|
||||
/ version.lstrip("v")
|
||||
)
|
||||
if os.path.exists(expected_path):
|
||||
logging.info(f"Using existing copy of specific frontend version tag: {repo_owner}/{repo_name}@{version}")
|
||||
logging.info(
|
||||
f"Using existing copy of specific frontend version tag: {repo_owner}/{repo_name}@{version}"
|
||||
)
|
||||
return expected_path
|
||||
|
||||
logging.info(f"Initializing frontend: {repo_owner}/{repo_name}@{version}, requesting version details from GitHub...")
|
||||
logging.info(
|
||||
f"Initializing frontend: {repo_owner}/{repo_name}@{version}, requesting version details from GitHub..."
|
||||
)
|
||||
|
||||
provider = provider or FrontEndProvider(repo_owner, repo_name)
|
||||
release = provider.get_release(version)
|
||||
@@ -201,4 +318,5 @@ class FrontendManager:
|
||||
except Exception as e:
|
||||
logging.error("Failed to initialize frontend: %s", e)
|
||||
logging.info("Falling back to the default frontend.")
|
||||
return cls.DEFAULT_FRONTEND_PATH
|
||||
check_frontend_version()
|
||||
return cls.default_frontend_path()
|
||||
|
||||
@@ -51,7 +51,7 @@ def on_flush(callback):
|
||||
if stderr_interceptor is not None:
|
||||
stderr_interceptor.on_flush(callback)
|
||||
|
||||
def setup_logger(log_level: str = 'INFO', capacity: int = 300):
|
||||
def setup_logger(log_level: str = 'INFO', capacity: int = 300, use_stdout: bool = False):
|
||||
global logs
|
||||
if logs:
|
||||
return
|
||||
@@ -70,4 +70,29 @@ def setup_logger(log_level: str = 'INFO', capacity: int = 300):
|
||||
|
||||
stream_handler = logging.StreamHandler()
|
||||
stream_handler.setFormatter(logging.Formatter("%(message)s"))
|
||||
|
||||
if use_stdout:
|
||||
# Only errors and critical to stderr
|
||||
stream_handler.addFilter(lambda record: not record.levelno < logging.ERROR)
|
||||
|
||||
# Lesser to stdout
|
||||
stdout_handler = logging.StreamHandler(sys.stdout)
|
||||
stdout_handler.setFormatter(logging.Formatter("%(message)s"))
|
||||
stdout_handler.addFilter(lambda record: record.levelno < logging.ERROR)
|
||||
logger.addHandler(stdout_handler)
|
||||
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
|
||||
STARTUP_WARNINGS = []
|
||||
|
||||
|
||||
def log_startup_warning(msg):
|
||||
logging.warning(msg)
|
||||
STARTUP_WARNINGS.append(msg)
|
||||
|
||||
|
||||
def print_startup_warnings():
|
||||
for s in STARTUP_WARNINGS:
|
||||
logging.warning(s)
|
||||
STARTUP_WARNINGS.clear()
|
||||
|
||||
184
app/model_manager.py
Normal file
184
app/model_manager.py
Normal file
@@ -0,0 +1,184 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import base64
|
||||
import json
|
||||
import time
|
||||
import logging
|
||||
import folder_paths
|
||||
import glob
|
||||
import comfy.utils
|
||||
from aiohttp import web
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
from folder_paths import map_legacy, filter_files_extensions, filter_files_content_types
|
||||
|
||||
|
||||
class ModelFileManager:
|
||||
def __init__(self) -> None:
|
||||
self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {}
|
||||
|
||||
def get_cache(self, key: str, default=None) -> tuple[list[dict], dict[str, float], float] | None:
|
||||
return self.cache.get(key, default)
|
||||
|
||||
def set_cache(self, key: str, value: tuple[list[dict], dict[str, float], float]):
|
||||
self.cache[key] = value
|
||||
|
||||
def clear_cache(self):
|
||||
self.cache.clear()
|
||||
|
||||
def add_routes(self, routes):
|
||||
# NOTE: This is an experiment to replace `/models`
|
||||
@routes.get("/experiment/models")
|
||||
async def get_model_folders(request):
|
||||
model_types = list(folder_paths.folder_names_and_paths.keys())
|
||||
folder_black_list = ["configs", "custom_nodes"]
|
||||
output_folders: list[dict] = []
|
||||
for folder in model_types:
|
||||
if folder in folder_black_list:
|
||||
continue
|
||||
output_folders.append({"name": folder, "folders": folder_paths.get_folder_paths(folder)})
|
||||
return web.json_response(output_folders)
|
||||
|
||||
# NOTE: This is an experiment to replace `/models/{folder}`
|
||||
@routes.get("/experiment/models/{folder}")
|
||||
async def get_all_models(request):
|
||||
folder = request.match_info.get("folder", None)
|
||||
if not folder in folder_paths.folder_names_and_paths:
|
||||
return web.Response(status=404)
|
||||
files = self.get_model_file_list(folder)
|
||||
return web.json_response(files)
|
||||
|
||||
@routes.get("/experiment/models/preview/{folder}/{path_index}/{filename:.*}")
|
||||
async def get_model_preview(request):
|
||||
folder_name = request.match_info.get("folder", None)
|
||||
path_index = int(request.match_info.get("path_index", None))
|
||||
filename = request.match_info.get("filename", None)
|
||||
|
||||
if not folder_name in folder_paths.folder_names_and_paths:
|
||||
return web.Response(status=404)
|
||||
|
||||
folders = folder_paths.folder_names_and_paths[folder_name]
|
||||
folder = folders[0][path_index]
|
||||
full_filename = os.path.join(folder, filename)
|
||||
|
||||
previews = self.get_model_previews(full_filename)
|
||||
default_preview = previews[0] if len(previews) > 0 else None
|
||||
if default_preview is None or (isinstance(default_preview, str) and not os.path.isfile(default_preview)):
|
||||
return web.Response(status=404)
|
||||
|
||||
try:
|
||||
with Image.open(default_preview) as img:
|
||||
img_bytes = BytesIO()
|
||||
img.save(img_bytes, format="WEBP")
|
||||
img_bytes.seek(0)
|
||||
return web.Response(body=img_bytes.getvalue(), content_type="image/webp")
|
||||
except:
|
||||
return web.Response(status=404)
|
||||
|
||||
def get_model_file_list(self, folder_name: str):
|
||||
folder_name = map_legacy(folder_name)
|
||||
folders = folder_paths.folder_names_and_paths[folder_name]
|
||||
output_list: list[dict] = []
|
||||
|
||||
for index, folder in enumerate(folders[0]):
|
||||
if not os.path.isdir(folder):
|
||||
continue
|
||||
out = self.cache_model_file_list_(folder)
|
||||
if out is None:
|
||||
out = self.recursive_search_models_(folder, index)
|
||||
self.set_cache(folder, out)
|
||||
output_list.extend(out[0])
|
||||
|
||||
return output_list
|
||||
|
||||
def cache_model_file_list_(self, folder: str):
|
||||
model_file_list_cache = self.get_cache(folder)
|
||||
|
||||
if model_file_list_cache is None:
|
||||
return None
|
||||
if not os.path.isdir(folder):
|
||||
return None
|
||||
if os.path.getmtime(folder) != model_file_list_cache[1]:
|
||||
return None
|
||||
for x in model_file_list_cache[1]:
|
||||
time_modified = model_file_list_cache[1][x]
|
||||
folder = x
|
||||
if os.path.getmtime(folder) != time_modified:
|
||||
return None
|
||||
|
||||
return model_file_list_cache
|
||||
|
||||
def recursive_search_models_(self, directory: str, pathIndex: int) -> tuple[list[str], dict[str, float], float]:
|
||||
if not os.path.isdir(directory):
|
||||
return [], {}, time.perf_counter()
|
||||
|
||||
excluded_dir_names = [".git"]
|
||||
# TODO use settings
|
||||
include_hidden_files = False
|
||||
|
||||
result: list[str] = []
|
||||
dirs: dict[str, float] = {}
|
||||
|
||||
for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
|
||||
subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
|
||||
if not include_hidden_files:
|
||||
subdirs[:] = [d for d in subdirs if not d.startswith(".")]
|
||||
filenames = [f for f in filenames if not f.startswith(".")]
|
||||
|
||||
filenames = filter_files_extensions(filenames, folder_paths.supported_pt_extensions)
|
||||
|
||||
for file_name in filenames:
|
||||
try:
|
||||
relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
|
||||
result.append(relative_path)
|
||||
except:
|
||||
logging.warning(f"Warning: Unable to access {file_name}. Skipping this file.")
|
||||
continue
|
||||
|
||||
for d in subdirs:
|
||||
path: str = os.path.join(dirpath, d)
|
||||
try:
|
||||
dirs[path] = os.path.getmtime(path)
|
||||
except FileNotFoundError:
|
||||
logging.warning(f"Warning: Unable to access {path}. Skipping this path.")
|
||||
continue
|
||||
|
||||
return [{"name": f, "pathIndex": pathIndex} for f in result], dirs, time.perf_counter()
|
||||
|
||||
def get_model_previews(self, filepath: str) -> list[str | BytesIO]:
|
||||
dirname = os.path.dirname(filepath)
|
||||
|
||||
if not os.path.exists(dirname):
|
||||
return []
|
||||
|
||||
basename = os.path.splitext(filepath)[0]
|
||||
match_files = glob.glob(f"{basename}.*", recursive=False)
|
||||
image_files = filter_files_content_types(match_files, "image")
|
||||
safetensors_file = next(filter(lambda x: x.endswith(".safetensors"), match_files), None)
|
||||
safetensors_metadata = {}
|
||||
|
||||
result: list[str | BytesIO] = []
|
||||
|
||||
for filename in image_files:
|
||||
_basename = os.path.splitext(filename)[0]
|
||||
if _basename == basename:
|
||||
result.append(filename)
|
||||
if _basename == f"{basename}.preview":
|
||||
result.append(filename)
|
||||
|
||||
if safetensors_file:
|
||||
safetensors_filepath = os.path.join(dirname, safetensors_file)
|
||||
header = comfy.utils.safetensors_header(safetensors_filepath, max_size=8*1024*1024)
|
||||
if header:
|
||||
safetensors_metadata = json.loads(header)
|
||||
safetensors_images = safetensors_metadata.get("__metadata__", {}).get("ssmd_cover_images", None)
|
||||
if safetensors_images:
|
||||
safetensors_images = json.loads(safetensors_images)
|
||||
for image in safetensors_images:
|
||||
result.append(BytesIO(base64.b64decode(image)))
|
||||
|
||||
return result
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self.clear_cache()
|
||||
@@ -38,8 +38,8 @@ class UserManager():
|
||||
if not os.path.exists(user_directory):
|
||||
os.makedirs(user_directory, exist_ok=True)
|
||||
if not args.multi_user:
|
||||
print("****** User settings have been changed to be stored on the server instead of browser storage. ******")
|
||||
print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")
|
||||
logging.warning("****** User settings have been changed to be stored on the server instead of browser storage. ******")
|
||||
logging.warning("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")
|
||||
|
||||
if args.multi_user:
|
||||
if os.path.isfile(self.get_users_file()):
|
||||
@@ -197,6 +197,112 @@ class UserManager():
|
||||
|
||||
return web.json_response(results)
|
||||
|
||||
@routes.get("/v2/userdata")
|
||||
async def list_userdata_v2(request):
|
||||
"""
|
||||
List files and directories in a user's data directory.
|
||||
|
||||
This endpoint provides a structured listing of contents within a specified
|
||||
subdirectory of the user's data storage.
|
||||
|
||||
Query Parameters:
|
||||
- path (optional): The relative path within the user's data directory
|
||||
to list. Defaults to the root ('').
|
||||
|
||||
Returns:
|
||||
- 400: If the requested path is invalid, outside the user's data directory, or is not a directory.
|
||||
- 404: If the requested path does not exist.
|
||||
- 403: If the user is invalid.
|
||||
- 500: If there is an error reading the directory contents.
|
||||
- 200: JSON response containing a list of file and directory objects.
|
||||
Each object includes:
|
||||
- name: The name of the file or directory.
|
||||
- type: 'file' or 'directory'.
|
||||
- path: The relative path from the user's data root.
|
||||
- size (for files): The size in bytes.
|
||||
- modified (for files): The last modified timestamp (Unix epoch).
|
||||
"""
|
||||
requested_rel_path = request.rel_url.query.get('path', '')
|
||||
|
||||
# URL-decode the path parameter
|
||||
try:
|
||||
requested_rel_path = parse.unquote(requested_rel_path)
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to decode path parameter: {requested_rel_path}, Error: {e}")
|
||||
return web.Response(status=400, text="Invalid characters in path parameter")
|
||||
|
||||
|
||||
# Check user validity and get the absolute path for the requested directory
|
||||
try:
|
||||
base_user_path = self.get_request_user_filepath(request, None, create_dir=False)
|
||||
|
||||
if requested_rel_path:
|
||||
target_abs_path = self.get_request_user_filepath(request, requested_rel_path, create_dir=False)
|
||||
else:
|
||||
target_abs_path = base_user_path
|
||||
|
||||
except KeyError as e:
|
||||
# Invalid user detected by get_request_user_id inside get_request_user_filepath
|
||||
logging.warning(f"Access denied for user: {e}")
|
||||
return web.Response(status=403, text="Invalid user specified in request")
|
||||
|
||||
|
||||
if not target_abs_path:
|
||||
# Path traversal or other issue detected by get_request_user_filepath
|
||||
return web.Response(status=400, text="Invalid path requested")
|
||||
|
||||
# Handle cases where the user directory or target path doesn't exist
|
||||
if not os.path.exists(target_abs_path):
|
||||
# Check if it's the base user directory that's missing (new user case)
|
||||
if target_abs_path == base_user_path:
|
||||
# It's okay if the base user directory doesn't exist yet, return empty list
|
||||
return web.json_response([])
|
||||
else:
|
||||
# A specific subdirectory was requested but doesn't exist
|
||||
return web.Response(status=404, text="Requested path not found")
|
||||
|
||||
if not os.path.isdir(target_abs_path):
|
||||
return web.Response(status=400, text="Requested path is not a directory")
|
||||
|
||||
results = []
|
||||
try:
|
||||
for root, dirs, files in os.walk(target_abs_path, topdown=True):
|
||||
# Process directories
|
||||
for dir_name in dirs:
|
||||
dir_path = os.path.join(root, dir_name)
|
||||
rel_path = os.path.relpath(dir_path, base_user_path).replace(os.sep, '/')
|
||||
results.append({
|
||||
"name": dir_name,
|
||||
"path": rel_path,
|
||||
"type": "directory"
|
||||
})
|
||||
|
||||
# Process files
|
||||
for file_name in files:
|
||||
file_path = os.path.join(root, file_name)
|
||||
rel_path = os.path.relpath(file_path, base_user_path).replace(os.sep, '/')
|
||||
entry_info = {
|
||||
"name": file_name,
|
||||
"path": rel_path,
|
||||
"type": "file"
|
||||
}
|
||||
try:
|
||||
stats = os.stat(file_path) # Use os.stat for potentially better performance with os.walk
|
||||
entry_info["size"] = stats.st_size
|
||||
entry_info["modified"] = stats.st_mtime
|
||||
except OSError as stat_error:
|
||||
logging.warning(f"Could not stat file {file_path}: {stat_error}")
|
||||
pass # Include file with available info
|
||||
results.append(entry_info)
|
||||
except OSError as e:
|
||||
logging.error(f"Error listing directory {target_abs_path}: {e}")
|
||||
return web.Response(status=500, text="Error reading directory contents")
|
||||
|
||||
# Sort results alphabetically, directories first then files
|
||||
results.sort(key=lambda x: (x['type'] != 'directory', x['name'].lower()))
|
||||
|
||||
return web.json_response(results)
|
||||
|
||||
def get_user_data_path(request, check_exists = False, param = "file"):
|
||||
file = request.match_info.get(param, None)
|
||||
if not file:
|
||||
|
||||
@@ -2,11 +2,9 @@
|
||||
#and modified
|
||||
|
||||
import torch
|
||||
import torch as th
|
||||
import torch.nn as nn
|
||||
|
||||
from ..ldm.modules.diffusionmodules.util import (
|
||||
zero_module,
|
||||
timestep_embedding,
|
||||
)
|
||||
|
||||
@@ -162,7 +160,6 @@ class ControlNet(nn.Module):
|
||||
if isinstance(self.num_classes, int):
|
||||
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
|
||||
elif self.num_classes == "continuous":
|
||||
print("setting up linear c_adm embedding layer")
|
||||
self.label_emb = nn.Linear(1, time_embed_dim)
|
||||
elif self.num_classes == "sequential":
|
||||
assert adm_in_channels is not None
|
||||
@@ -415,7 +412,6 @@ class ControlNet(nn.Module):
|
||||
out_output = []
|
||||
out_middle = []
|
||||
|
||||
hs = []
|
||||
if self.num_classes is not None:
|
||||
assert y.shape[0] == x.shape[0]
|
||||
emb = emb + self.label_emb(y)
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
import math
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
from torch import Tensor
|
||||
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import DismantledBlock, PatchEmbed, VectorEmbedder, TimestepEmbedder, get_2d_sincos_pos_embed_torch
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import torch
|
||||
from typing import Dict, Optional
|
||||
from typing import Optional
|
||||
import comfy.ldm.modules.diffusionmodules.mmdit
|
||||
|
||||
class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import argparse
|
||||
import enum
|
||||
import os
|
||||
from typing import Optional
|
||||
import comfy.options
|
||||
|
||||
|
||||
@@ -43,10 +42,11 @@ parser.add_argument("--tls-certfile", type=str, help="Path to TLS (SSL) certific
|
||||
parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.")
|
||||
parser.add_argument("--max-upload-size", type=float, default=100, help="Set the maximum upload size in MB.")
|
||||
|
||||
parser.add_argument("--base-directory", type=str, default=None, help="Set the ComfyUI base directory for models, custom_nodes, input, output, temp, and user directories.")
|
||||
parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.")
|
||||
parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.")
|
||||
parser.add_argument("--temp-directory", type=str, default=None, help="Set the ComfyUI temp directory (default is in the ComfyUI directory).")
|
||||
parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory.")
|
||||
parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory. Overrides --base-directory.")
|
||||
parser.add_argument("--temp-directory", type=str, default=None, help="Set the ComfyUI temp directory (default is in the ComfyUI directory). Overrides --base-directory.")
|
||||
parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.")
|
||||
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
|
||||
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
|
||||
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
|
||||
@@ -66,6 +66,7 @@ fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the diff
|
||||
fpunet_group.add_argument("--fp16-unet", action="store_true", help="Run the diffusion model in fp16")
|
||||
fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.")
|
||||
fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")
|
||||
fpunet_group.add_argument("--fp8_e8m0fnu-unet", action="store_true", help="Store unet weights in fp8_e8m0fnu.")
|
||||
|
||||
fpvae_group = parser.add_mutually_exclusive_group()
|
||||
fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
|
||||
@@ -79,12 +80,15 @@ fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Stor
|
||||
fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).")
|
||||
fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text encoder weights in fp16.")
|
||||
fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
|
||||
fpte_group.add_argument("--bf16-text-enc", action="store_true", help="Store text encoder weights in bf16.")
|
||||
|
||||
parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.")
|
||||
|
||||
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
|
||||
|
||||
parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize when loading models with Intel GPUs.")
|
||||
parser.add_argument("--oneapi-device-selector", type=str, default=None, metavar="SELECTOR_STRING", help="Sets the oneAPI device(s) this instance will use.")
|
||||
parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize default when loading models with Intel's Extension for Pytorch.")
|
||||
parser.add_argument("--supports-fp8-compute", action="store_true", help="ComfyUI will act like if the device supports fp8 compute.")
|
||||
|
||||
class LatentPreviewMethod(enum.Enum):
|
||||
NoPreviews = "none"
|
||||
@@ -99,11 +103,14 @@ parser.add_argument("--preview-size", type=int, default=512, help="Sets the maxi
|
||||
cache_group = parser.add_mutually_exclusive_group()
|
||||
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
|
||||
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
|
||||
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
|
||||
|
||||
attn_group = parser.add_mutually_exclusive_group()
|
||||
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
|
||||
attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
|
||||
attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
|
||||
attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.")
|
||||
attn_group.add_argument("--use-flash-attention", action="store_true", help="Use FlashAttention.")
|
||||
|
||||
parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
|
||||
|
||||
@@ -120,14 +127,23 @@ vram_group.add_argument("--lowvram", action="store_true", help="Split the unet i
|
||||
vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
|
||||
vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
|
||||
|
||||
parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reverved depending on your OS.")
|
||||
parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.")
|
||||
|
||||
parser.add_argument("--async-offload", action="store_true", help="Use async weight offloading.")
|
||||
|
||||
parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
|
||||
|
||||
parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
|
||||
parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
|
||||
parser.add_argument("--fast", action="store_true", help="Enable some untested and potentially quality deteriorating optimizations.")
|
||||
|
||||
class PerformanceFeature(enum.Enum):
|
||||
Fp16Accumulation = "fp16_accumulation"
|
||||
Fp8MatrixMultiplication = "fp8_matrix_mult"
|
||||
CublasOps = "cublas_ops"
|
||||
|
||||
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult cublas_ops")
|
||||
|
||||
parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
|
||||
|
||||
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
|
||||
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
|
||||
@@ -135,10 +151,12 @@ parser.add_argument("--windows-standalone-build", action="store_true", help="Win
|
||||
|
||||
parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
|
||||
parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Disable loading all custom nodes.")
|
||||
parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes.")
|
||||
|
||||
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
|
||||
|
||||
parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
|
||||
parser.add_argument("--log-stdout", action="store_true", help="Send normal process output to stdout instead of stderr (default).")
|
||||
|
||||
# The default built-in provider hosted under web/
|
||||
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
|
||||
@@ -157,13 +175,14 @@ parser.add_argument(
|
||||
""",
|
||||
)
|
||||
|
||||
def is_valid_directory(path: Optional[str]) -> Optional[str]:
|
||||
"""Validate if the given path is a directory."""
|
||||
if path is None:
|
||||
return None
|
||||
|
||||
def is_valid_directory(path: str) -> str:
|
||||
"""Validate if the given path is a directory, and check permissions."""
|
||||
if not os.path.exists(path):
|
||||
raise argparse.ArgumentTypeError(f"The path '{path}' does not exist.")
|
||||
if not os.path.isdir(path):
|
||||
raise argparse.ArgumentTypeError(f"{path} is not a valid directory.")
|
||||
raise argparse.ArgumentTypeError(f"'{path}' is not a directory.")
|
||||
if not os.access(path, os.R_OK):
|
||||
raise argparse.ArgumentTypeError(f"You do not have read permissions for '{path}'.")
|
||||
return path
|
||||
|
||||
parser.add_argument(
|
||||
@@ -173,7 +192,16 @@ parser.add_argument(
|
||||
help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.",
|
||||
)
|
||||
|
||||
parser.add_argument("--user-directory", type=is_valid_directory, default=None, help="Set the ComfyUI user directory with an absolute path.")
|
||||
parser.add_argument("--user-directory", type=is_valid_directory, default=None, help="Set the ComfyUI user directory with an absolute path. Overrides --base-directory.")
|
||||
|
||||
parser.add_argument("--enable-compress-response-body", action="store_true", help="Enable compressing response body.")
|
||||
|
||||
parser.add_argument(
|
||||
"--comfy-api-base",
|
||||
type=str,
|
||||
default="https://api.comfy.org",
|
||||
help="Set the base URL for the ComfyUI API. (default: https://api.comfy.org)",
|
||||
)
|
||||
|
||||
if comfy.options.args_parsing:
|
||||
args = parser.parse_args()
|
||||
@@ -185,3 +213,17 @@ if args.windows_standalone_build:
|
||||
|
||||
if args.disable_auto_launch:
|
||||
args.auto_launch = False
|
||||
|
||||
if args.force_fp16:
|
||||
args.fp16_unet = True
|
||||
|
||||
|
||||
# '--fast' is not provided, use an empty set
|
||||
if args.fast is None:
|
||||
args.fast = set()
|
||||
# '--fast' is provided with an empty list, enable all optimizations
|
||||
elif args.fast == []:
|
||||
args.fast = set(PerformanceFeature)
|
||||
# '--fast' is provided with a list of performance features, use that list
|
||||
else:
|
||||
args.fast = set(args.fast)
|
||||
|
||||
@@ -97,14 +97,19 @@ class CLIPTextModel_(torch.nn.Module):
|
||||
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
|
||||
self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32):
|
||||
x = self.embeddings(input_tokens, dtype=dtype)
|
||||
def forward(self, input_tokens=None, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32):
|
||||
if embeds is not None:
|
||||
x = embeds + comfy.ops.cast_to(self.embeddings.position_embedding.weight, dtype=dtype, device=embeds.device)
|
||||
else:
|
||||
x = self.embeddings(input_tokens, dtype=dtype)
|
||||
|
||||
mask = None
|
||||
if attention_mask is not None:
|
||||
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
|
||||
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
|
||||
mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max)
|
||||
|
||||
causal_mask = torch.full((x.shape[1], x.shape[1]), -torch.finfo(x.dtype).max, dtype=x.dtype, device=x.device).triu_(1)
|
||||
|
||||
causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
|
||||
if mask is not None:
|
||||
mask += causal_mask
|
||||
else:
|
||||
@@ -115,7 +120,10 @@ class CLIPTextModel_(torch.nn.Module):
|
||||
if i is not None and final_layer_norm_intermediate:
|
||||
i = self.final_layer_norm(i)
|
||||
|
||||
pooled_output = x[torch.arange(x.shape[0], device=x.device), (torch.round(input_tokens).to(dtype=torch.int, device=x.device) == self.eos_token_id).int().argmax(dim=-1),]
|
||||
if num_tokens is not None:
|
||||
pooled_output = x[list(range(x.shape[0])), list(map(lambda a: a - 1, num_tokens))]
|
||||
else:
|
||||
pooled_output = x[torch.arange(x.shape[0], device=x.device), (torch.round(input_tokens).to(dtype=torch.int, device=x.device) == self.eos_token_id).int().argmax(dim=-1),]
|
||||
return x, i, pooled_output
|
||||
|
||||
class CLIPTextModel(torch.nn.Module):
|
||||
@@ -203,6 +211,15 @@ class CLIPVision(torch.nn.Module):
|
||||
pooled_output = self.post_layernorm(x[:, 0, :])
|
||||
return x, i, pooled_output
|
||||
|
||||
class LlavaProjector(torch.nn.Module):
|
||||
def __init__(self, in_dim, out_dim, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.linear_1 = operations.Linear(in_dim, out_dim, bias=True, device=device, dtype=dtype)
|
||||
self.linear_2 = operations.Linear(out_dim, out_dim, bias=True, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear_2(torch.nn.functional.gelu(self.linear_1(x[:, 1:])))
|
||||
|
||||
class CLIPVisionModelProjection(torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
@@ -212,7 +229,16 @@ class CLIPVisionModelProjection(torch.nn.Module):
|
||||
else:
|
||||
self.visual_projection = lambda a: a
|
||||
|
||||
if "llava3" == config_dict.get("projector_type", None):
|
||||
self.multi_modal_projector = LlavaProjector(config_dict["hidden_size"], 4096, dtype, device, operations)
|
||||
else:
|
||||
self.multi_modal_projector = None
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
x = self.vision_model(*args, **kwargs)
|
||||
out = self.visual_projection(x[2])
|
||||
return (x[0], x[1], out)
|
||||
projected = None
|
||||
if self.multi_modal_projector is not None:
|
||||
projected = self.multi_modal_projector(x[1])
|
||||
|
||||
return (x[0], x[1], out, projected)
|
||||
|
||||
@@ -9,6 +9,7 @@ import comfy.model_patcher
|
||||
import comfy.model_management
|
||||
import comfy.utils
|
||||
import comfy.clip_model
|
||||
import comfy.image_encoders.dino2
|
||||
|
||||
class Output:
|
||||
def __getitem__(self, key):
|
||||
@@ -17,6 +18,7 @@ class Output:
|
||||
setattr(self, key, item)
|
||||
|
||||
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
|
||||
image = image[:, :, :, :3] if image.shape[3] > 3 else image
|
||||
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
|
||||
std = torch.tensor(std, device=image.device, dtype=image.dtype)
|
||||
image = image.movedim(-1, 1)
|
||||
@@ -34,6 +36,12 @@ def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], s
|
||||
image = torch.clip((255. * image), 0, 255).round() / 255.0
|
||||
return (image - mean.view([3,1,1])) / std.view([3,1,1])
|
||||
|
||||
IMAGE_ENCODERS = {
|
||||
"clip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
|
||||
"siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
|
||||
"dinov2": comfy.image_encoders.dino2.Dinov2Model,
|
||||
}
|
||||
|
||||
class ClipVisionModel():
|
||||
def __init__(self, json_config):
|
||||
with open(json_config) as f:
|
||||
@@ -42,10 +50,11 @@ class ClipVisionModel():
|
||||
self.image_size = config.get("image_size", 224)
|
||||
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
|
||||
self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
|
||||
model_class = IMAGE_ENCODERS.get(config.get("model_type", "clip_vision_model"))
|
||||
self.load_device = comfy.model_management.text_encoder_device()
|
||||
offload_device = comfy.model_management.text_encoder_offload_device()
|
||||
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
|
||||
self.model = comfy.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, comfy.ops.manual_cast)
|
||||
self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast)
|
||||
self.model.eval()
|
||||
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
@@ -65,6 +74,7 @@ class ClipVisionModel():
|
||||
outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device())
|
||||
outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device())
|
||||
outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
|
||||
outputs["mm_projected"] = out[3]
|
||||
return outputs
|
||||
|
||||
def convert_to_transformers(sd, prefix):
|
||||
@@ -101,12 +111,21 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
|
||||
elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
|
||||
elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
|
||||
embed_shape = sd["vision_model.embeddings.position_embedding.weight"].shape[0]
|
||||
if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
|
||||
elif sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
|
||||
if embed_shape == 729:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
|
||||
elif embed_shape == 1024:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_512.json")
|
||||
elif embed_shape == 577:
|
||||
if "multi_modal_projector.linear_1.bias" in sd:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336_llava.json")
|
||||
else:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
|
||||
else:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
|
||||
elif "embeddings.patch_embeddings.projection.weight" in sd:
|
||||
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json")
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
19
comfy/clip_vision_config_vitl_336_llava.json
Normal file
19
comfy/clip_vision_config_vitl_336_llava.json
Normal file
@@ -0,0 +1,19 @@
|
||||
{
|
||||
"attention_dropout": 0.0,
|
||||
"dropout": 0.0,
|
||||
"hidden_act": "quick_gelu",
|
||||
"hidden_size": 1024,
|
||||
"image_size": 336,
|
||||
"initializer_factor": 1.0,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 4096,
|
||||
"layer_norm_eps": 1e-5,
|
||||
"model_type": "clip_vision_model",
|
||||
"num_attention_heads": 16,
|
||||
"num_channels": 3,
|
||||
"num_hidden_layers": 24,
|
||||
"patch_size": 14,
|
||||
"projection_dim": 768,
|
||||
"projector_type": "llava3",
|
||||
"torch_dtype": "float32"
|
||||
}
|
||||
13
comfy/clip_vision_siglip_512.json
Normal file
13
comfy/clip_vision_siglip_512.json
Normal file
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"num_channels": 3,
|
||||
"hidden_act": "gelu_pytorch_tanh",
|
||||
"hidden_size": 1152,
|
||||
"image_size": 512,
|
||||
"intermediate_size": 4304,
|
||||
"model_type": "siglip_vision_model",
|
||||
"num_attention_heads": 16,
|
||||
"num_hidden_layers": 27,
|
||||
"patch_size": 16,
|
||||
"image_mean": [0.5, 0.5, 0.5],
|
||||
"image_std": [0.5, 0.5, 0.5]
|
||||
}
|
||||
@@ -5,7 +5,7 @@ This module provides type hinting and concrete convenience types for node develo
|
||||
If cloned to the custom_nodes directory of ComfyUI, types can be imported using:
|
||||
|
||||
```python
|
||||
from comfy_types import IO, ComfyNodeABC, CheckLazyMixin
|
||||
from comfy.comfy_types import IO, ComfyNodeABC, CheckLazyMixin
|
||||
|
||||
class ExampleNode(ComfyNodeABC):
|
||||
@classmethod
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import torch
|
||||
from typing import Callable, Protocol, TypedDict, Optional, List
|
||||
from .node_typing import IO, InputTypeDict, ComfyNodeABC, CheckLazyMixin
|
||||
from .node_typing import IO, InputTypeDict, ComfyNodeABC, CheckLazyMixin, FileLocator
|
||||
|
||||
|
||||
class UnetApplyFunction(Protocol):
|
||||
@@ -42,4 +42,5 @@ __all__ = [
|
||||
InputTypeDict.__name__,
|
||||
ComfyNodeABC.__name__,
|
||||
CheckLazyMixin.__name__,
|
||||
FileLocator.__name__,
|
||||
]
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
from comfy_types import IO, ComfyNodeABC, InputTypeDict
|
||||
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
|
||||
from inspect import cleandoc
|
||||
|
||||
|
||||
class ExampleNode(ComfyNodeABC):
|
||||
"""An example node that just adds 1 to an input integer.
|
||||
|
||||
* Requires an IDE configured with analysis paths etc to be worth looking at.
|
||||
* Not intended for use in ComfyUI.
|
||||
* Requires a modern IDE to provide any benefit (detail: an IDE configured with analysis paths etc).
|
||||
* This node is intended as an example for developers only.
|
||||
"""
|
||||
|
||||
DESCRIPTION = cleandoc(__doc__)
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
"""Comfy-specific type hinting"""
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Literal, TypedDict
|
||||
from typing import Literal, TypedDict, Optional
|
||||
from typing_extensions import NotRequired
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
|
||||
@@ -26,6 +27,7 @@ class IO(StrEnum):
|
||||
BOOLEAN = "BOOLEAN"
|
||||
INT = "INT"
|
||||
FLOAT = "FLOAT"
|
||||
COMBO = "COMBO"
|
||||
CONDITIONING = "CONDITIONING"
|
||||
SAMPLER = "SAMPLER"
|
||||
SIGMAS = "SIGMAS"
|
||||
@@ -46,6 +48,7 @@ class IO(StrEnum):
|
||||
FACE_ANALYSIS = "FACE_ANALYSIS"
|
||||
BBOX = "BBOX"
|
||||
SEGS = "SEGS"
|
||||
VIDEO = "VIDEO"
|
||||
|
||||
ANY = "*"
|
||||
"""Always matches any type, but at a price.
|
||||
@@ -67,90 +70,148 @@ class IO(StrEnum):
|
||||
return not (b.issubset(a) or a.issubset(b))
|
||||
|
||||
|
||||
class RemoteInputOptions(TypedDict):
|
||||
route: str
|
||||
"""The route to the remote source."""
|
||||
refresh_button: bool
|
||||
"""Specifies whether to show a refresh button in the UI below the widget."""
|
||||
control_after_refresh: Literal["first", "last"]
|
||||
"""Specifies the control after the refresh button is clicked. If "first", the first item will be automatically selected, and so on."""
|
||||
timeout: int
|
||||
"""The maximum amount of time to wait for a response from the remote source in milliseconds."""
|
||||
max_retries: int
|
||||
"""The maximum number of retries before aborting the request."""
|
||||
refresh: int
|
||||
"""The TTL of the remote input's value in milliseconds. Specifies the interval at which the remote input's value is refreshed."""
|
||||
|
||||
|
||||
class MultiSelectOptions(TypedDict):
|
||||
placeholder: NotRequired[str]
|
||||
"""The placeholder text to display in the multi-select widget when no items are selected."""
|
||||
chip: NotRequired[bool]
|
||||
"""Specifies whether to use chips instead of comma separated values for the multi-select widget."""
|
||||
|
||||
|
||||
class InputTypeOptions(TypedDict):
|
||||
"""Provides type hinting for the return type of the INPUT_TYPES node function.
|
||||
|
||||
Due to IDE limitations with unions, for now all options are available for all types (e.g. `label_on` is hinted even when the type is not `IO.BOOLEAN`).
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_datatypes
|
||||
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/datatypes
|
||||
"""
|
||||
|
||||
default: bool | str | float | int | list | tuple
|
||||
default: NotRequired[bool | str | float | int | list | tuple]
|
||||
"""The default value of the widget"""
|
||||
defaultInput: bool
|
||||
"""Defaults to an input slot rather than a widget"""
|
||||
forceInput: bool
|
||||
"""`defaultInput` and also don't allow converting to a widget"""
|
||||
lazy: bool
|
||||
defaultInput: NotRequired[bool]
|
||||
"""@deprecated in v1.16 frontend. v1.16 frontend allows input socket and widget to co-exist.
|
||||
- defaultInput on required inputs should be dropped.
|
||||
- defaultInput on optional inputs should be replaced with forceInput.
|
||||
Ref: https://github.com/Comfy-Org/ComfyUI_frontend/pull/3364
|
||||
"""
|
||||
forceInput: NotRequired[bool]
|
||||
"""Forces the input to be an input slot rather than a widget even a widget is available for the input type."""
|
||||
lazy: NotRequired[bool]
|
||||
"""Declares that this input uses lazy evaluation"""
|
||||
rawLink: bool
|
||||
rawLink: NotRequired[bool]
|
||||
"""When a link exists, rather than receiving the evaluated value, you will receive the link (i.e. `["nodeId", <outputIndex>]`). Designed for node expansion."""
|
||||
tooltip: str
|
||||
tooltip: NotRequired[str]
|
||||
"""Tooltip for the input (or widget), shown on pointer hover"""
|
||||
socketless: NotRequired[bool]
|
||||
"""All inputs (including widgets) have an input socket to connect links. When ``true``, if there is a widget for this input, no socket will be created.
|
||||
Available from frontend v1.17.5
|
||||
Ref: https://github.com/Comfy-Org/ComfyUI_frontend/pull/3548
|
||||
"""
|
||||
widgetType: NotRequired[str]
|
||||
"""Specifies a type to be used for widget initialization if different from the input type.
|
||||
Available from frontend v1.18.0
|
||||
https://github.com/Comfy-Org/ComfyUI_frontend/pull/3550"""
|
||||
# class InputTypeNumber(InputTypeOptions):
|
||||
# default: float | int
|
||||
min: float
|
||||
min: NotRequired[float]
|
||||
"""The minimum value of a number (``FLOAT`` | ``INT``)"""
|
||||
max: float
|
||||
max: NotRequired[float]
|
||||
"""The maximum value of a number (``FLOAT`` | ``INT``)"""
|
||||
step: float
|
||||
step: NotRequired[float]
|
||||
"""The amount to increment or decrement a widget by when stepping up/down (``FLOAT`` | ``INT``)"""
|
||||
round: float
|
||||
round: NotRequired[float]
|
||||
"""Floats are rounded by this value (``FLOAT``)"""
|
||||
# class InputTypeBoolean(InputTypeOptions):
|
||||
# default: bool
|
||||
label_on: str
|
||||
label_on: NotRequired[str]
|
||||
"""The label to use in the UI when the bool is True (``BOOLEAN``)"""
|
||||
label_on: str
|
||||
label_off: NotRequired[str]
|
||||
"""The label to use in the UI when the bool is False (``BOOLEAN``)"""
|
||||
# class InputTypeString(InputTypeOptions):
|
||||
# default: str
|
||||
multiline: bool
|
||||
multiline: NotRequired[bool]
|
||||
"""Use a multiline text box (``STRING``)"""
|
||||
placeholder: str
|
||||
placeholder: NotRequired[str]
|
||||
"""Placeholder text to display in the UI when empty (``STRING``)"""
|
||||
# Deprecated:
|
||||
# defaultVal: str
|
||||
dynamicPrompts: bool
|
||||
dynamicPrompts: NotRequired[bool]
|
||||
"""Causes the front-end to evaluate dynamic prompts (``STRING``)"""
|
||||
# class InputTypeCombo(InputTypeOptions):
|
||||
image_upload: NotRequired[bool]
|
||||
"""Specifies whether the input should have an image upload button and image preview attached to it. Requires that the input's name is `image`."""
|
||||
image_folder: NotRequired[Literal["input", "output", "temp"]]
|
||||
"""Specifies which folder to get preview images from if the input has the ``image_upload`` flag.
|
||||
"""
|
||||
remote: NotRequired[RemoteInputOptions]
|
||||
"""Specifies the configuration for a remote input.
|
||||
Available after ComfyUI frontend v1.9.7
|
||||
https://github.com/Comfy-Org/ComfyUI_frontend/pull/2422"""
|
||||
control_after_generate: NotRequired[bool]
|
||||
"""Specifies whether a control widget should be added to the input, adding options to automatically change the value after each prompt is queued. Currently only used for INT and COMBO types."""
|
||||
options: NotRequired[list[str | int | float]]
|
||||
"""COMBO type only. Specifies the selectable options for the combo widget.
|
||||
Prefer:
|
||||
["COMBO", {"options": ["Option 1", "Option 2", "Option 3"]}]
|
||||
Over:
|
||||
[["Option 1", "Option 2", "Option 3"]]
|
||||
"""
|
||||
multi_select: NotRequired[MultiSelectOptions]
|
||||
"""COMBO type only. Specifies the configuration for a multi-select widget.
|
||||
Available after ComfyUI frontend v1.13.4
|
||||
https://github.com/Comfy-Org/ComfyUI_frontend/pull/2987"""
|
||||
|
||||
|
||||
class HiddenInputTypeDict(TypedDict):
|
||||
"""Provides type hinting for the hidden entry of node INPUT_TYPES."""
|
||||
|
||||
node_id: Literal["UNIQUE_ID"]
|
||||
node_id: NotRequired[Literal["UNIQUE_ID"]]
|
||||
"""UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages)."""
|
||||
unique_id: Literal["UNIQUE_ID"]
|
||||
unique_id: NotRequired[Literal["UNIQUE_ID"]]
|
||||
"""UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages)."""
|
||||
prompt: Literal["PROMPT"]
|
||||
prompt: NotRequired[Literal["PROMPT"]]
|
||||
"""PROMPT is the complete prompt sent by the client to the server. See the prompt object for a full description."""
|
||||
extra_pnginfo: Literal["EXTRA_PNGINFO"]
|
||||
extra_pnginfo: NotRequired[Literal["EXTRA_PNGINFO"]]
|
||||
"""EXTRA_PNGINFO is a dictionary that will be copied into the metadata of any .png files saved. Custom nodes can store additional information in this dictionary for saving (or as a way to communicate with a downstream node)."""
|
||||
dynprompt: Literal["DYNPROMPT"]
|
||||
dynprompt: NotRequired[Literal["DYNPROMPT"]]
|
||||
"""DYNPROMPT is an instance of comfy_execution.graph.DynamicPrompt. It differs from PROMPT in that it may mutate during the course of execution in response to Node Expansion."""
|
||||
|
||||
|
||||
class InputTypeDict(TypedDict):
|
||||
"""Provides type hinting for node INPUT_TYPES.
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_more_on_inputs
|
||||
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/more_on_inputs
|
||||
"""
|
||||
|
||||
required: dict[str, tuple[IO, InputTypeOptions]]
|
||||
required: NotRequired[dict[str, tuple[IO, InputTypeOptions]]]
|
||||
"""Describes all inputs that must be connected for the node to execute."""
|
||||
optional: dict[str, tuple[IO, InputTypeOptions]]
|
||||
optional: NotRequired[dict[str, tuple[IO, InputTypeOptions]]]
|
||||
"""Describes inputs which do not need to be connected."""
|
||||
hidden: HiddenInputTypeDict
|
||||
hidden: NotRequired[HiddenInputTypeDict]
|
||||
"""Offers advanced functionality and server-client communication.
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_more_on_inputs#hidden-inputs
|
||||
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/more_on_inputs#hidden-inputs
|
||||
"""
|
||||
|
||||
|
||||
class ComfyNodeABC(ABC):
|
||||
"""Abstract base class for Comfy nodes. Includes the names and expected types of attributes.
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview
|
||||
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview
|
||||
"""
|
||||
|
||||
DESCRIPTION: str
|
||||
@@ -167,12 +228,14 @@ class ComfyNodeABC(ABC):
|
||||
CATEGORY: str
|
||||
"""The category of the node, as per the "Add Node" menu.
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#category
|
||||
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#category
|
||||
"""
|
||||
EXPERIMENTAL: bool
|
||||
"""Flags a node as experimental, informing users that it may change or not work as expected."""
|
||||
DEPRECATED: bool
|
||||
"""Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
|
||||
API_NODE: Optional[bool]
|
||||
"""Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview."""
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
@@ -181,9 +244,9 @@ class ComfyNodeABC(ABC):
|
||||
|
||||
* Must include the ``required`` key, which describes all inputs that must be connected for the node to execute.
|
||||
* The ``optional`` key can be added to describe inputs which do not need to be connected.
|
||||
* The ``hidden`` key offers some advanced functionality. More info at: https://docs.comfy.org/essentials/custom_node_more_on_inputs#hidden-inputs
|
||||
* The ``hidden`` key offers some advanced functionality. More info at: https://docs.comfy.org/custom-nodes/backend/more_on_inputs#hidden-inputs
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#input-types
|
||||
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#input-types
|
||||
"""
|
||||
return {"required": {}}
|
||||
|
||||
@@ -198,7 +261,7 @@ class ComfyNodeABC(ABC):
|
||||
|
||||
By default, a node is not considered an output. Set ``OUTPUT_NODE = True`` to specify that it is.
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#output-node
|
||||
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#output-node
|
||||
"""
|
||||
INPUT_IS_LIST: bool
|
||||
"""A flag indicating if this node implements the additional code necessary to deal with OUTPUT_IS_LIST nodes.
|
||||
@@ -209,9 +272,9 @@ class ComfyNodeABC(ABC):
|
||||
|
||||
A node can also override the default input behaviour and receive the whole list in a single call. This is done by setting a class attribute `INPUT_IS_LIST` to ``True``.
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_lists#list-processing
|
||||
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lists#list-processing
|
||||
"""
|
||||
OUTPUT_IS_LIST: tuple[bool]
|
||||
OUTPUT_IS_LIST: tuple[bool, ...]
|
||||
"""A tuple indicating which node outputs are lists, but will be connected to nodes that expect individual items.
|
||||
|
||||
Connected nodes that do not implement `INPUT_IS_LIST` will be executed once for every item in the list.
|
||||
@@ -227,29 +290,29 @@ class ComfyNodeABC(ABC):
|
||||
the node should provide a class attribute `OUTPUT_IS_LIST`, which is a ``tuple[bool]``, of the same length as `RETURN_TYPES`,
|
||||
specifying which outputs which should be so treated.
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_lists#list-processing
|
||||
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lists#list-processing
|
||||
"""
|
||||
|
||||
RETURN_TYPES: tuple[IO]
|
||||
RETURN_TYPES: tuple[IO, ...]
|
||||
"""A tuple representing the outputs of this node.
|
||||
|
||||
Usage::
|
||||
|
||||
RETURN_TYPES = (IO.INT, "INT", "CUSTOM_TYPE")
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#return-types
|
||||
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#return-types
|
||||
"""
|
||||
RETURN_NAMES: tuple[str]
|
||||
RETURN_NAMES: tuple[str, ...]
|
||||
"""The output slot names for each item in `RETURN_TYPES`, e.g. ``RETURN_NAMES = ("count", "filter_string")``
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#return-names
|
||||
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#return-names
|
||||
"""
|
||||
OUTPUT_TOOLTIPS: tuple[str]
|
||||
OUTPUT_TOOLTIPS: tuple[str, ...]
|
||||
"""A tuple of strings to use as tooltips for node outputs, one for each item in `RETURN_TYPES`."""
|
||||
FUNCTION: str
|
||||
"""The name of the function to execute as a literal string, e.g. `FUNCTION = "execute"`
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#function
|
||||
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#function
|
||||
"""
|
||||
|
||||
|
||||
@@ -267,8 +330,19 @@ class CheckLazyMixin:
|
||||
Params should match the nodes execution ``FUNCTION`` (self, and all inputs by name).
|
||||
Will be executed repeatedly until it returns an empty list, or all requested items were already evaluated (and sent as params).
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_lazy_evaluation#defining-check-lazy-status
|
||||
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lazy_evaluation#defining-check-lazy-status
|
||||
"""
|
||||
|
||||
need = [name for name in kwargs if kwargs[name] is None]
|
||||
return need
|
||||
|
||||
|
||||
class FileLocator(TypedDict):
|
||||
"""Provides type hinting for the file location"""
|
||||
|
||||
filename: str
|
||||
"""The filename of the file."""
|
||||
subfolder: str
|
||||
"""The subfolder of the file."""
|
||||
type: Literal["input", "output", "temp"]
|
||||
"""The root folder of the file."""
|
||||
|
||||
@@ -3,9 +3,6 @@ import math
|
||||
import comfy.utils
|
||||
|
||||
|
||||
def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
|
||||
return abs(a*b) // math.gcd(a, b)
|
||||
|
||||
class CONDRegular:
|
||||
def __init__(self, cond):
|
||||
self.cond = cond
|
||||
@@ -27,6 +24,10 @@ class CONDRegular:
|
||||
conds.append(x.cond)
|
||||
return torch.cat(conds)
|
||||
|
||||
def size(self):
|
||||
return list(self.cond.size())
|
||||
|
||||
|
||||
class CONDNoiseShape(CONDRegular):
|
||||
def process_cond(self, batch_size, device, area, **kwargs):
|
||||
data = self.cond
|
||||
@@ -46,7 +47,7 @@ class CONDCrossAttn(CONDRegular):
|
||||
if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen
|
||||
return False
|
||||
|
||||
mult_min = lcm(s1[1], s2[1])
|
||||
mult_min = math.lcm(s1[1], s2[1])
|
||||
diff = mult_min // min(s1[1], s2[1])
|
||||
if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
|
||||
return False
|
||||
@@ -57,7 +58,7 @@ class CONDCrossAttn(CONDRegular):
|
||||
crossattn_max_len = self.cond.shape[1]
|
||||
for x in others:
|
||||
c = x.cond
|
||||
crossattn_max_len = lcm(crossattn_max_len, c.shape[1])
|
||||
crossattn_max_len = math.lcm(crossattn_max_len, c.shape[1])
|
||||
conds.append(c)
|
||||
|
||||
out = []
|
||||
@@ -67,6 +68,7 @@ class CONDCrossAttn(CONDRegular):
|
||||
out.append(c)
|
||||
return torch.cat(out)
|
||||
|
||||
|
||||
class CONDConstant(CONDRegular):
|
||||
def __init__(self, cond):
|
||||
self.cond = cond
|
||||
@@ -81,3 +83,48 @@ class CONDConstant(CONDRegular):
|
||||
|
||||
def concat(self, others):
|
||||
return self.cond
|
||||
|
||||
def size(self):
|
||||
return [1]
|
||||
|
||||
|
||||
class CONDList(CONDRegular):
|
||||
def __init__(self, cond):
|
||||
self.cond = cond
|
||||
|
||||
def process_cond(self, batch_size, device, **kwargs):
|
||||
out = []
|
||||
for c in self.cond:
|
||||
out.append(comfy.utils.repeat_to_batch_size(c, batch_size).to(device))
|
||||
|
||||
return self._copy_with(out)
|
||||
|
||||
def can_concat(self, other):
|
||||
if len(self.cond) != len(other.cond):
|
||||
return False
|
||||
for i in range(len(self.cond)):
|
||||
if self.cond[i].shape != other.cond[i].shape:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def concat(self, others):
|
||||
out = []
|
||||
for i in range(len(self.cond)):
|
||||
o = [self.cond[i]]
|
||||
for x in others:
|
||||
o.append(x.cond[i])
|
||||
out.append(torch.cat(o))
|
||||
|
||||
return out
|
||||
|
||||
def size(self): # hackish implementation to make the mem estimation work
|
||||
o = 0
|
||||
c = 1
|
||||
for c in self.cond:
|
||||
size = c.size()
|
||||
o += math.prod(size)
|
||||
if len(size) > 1:
|
||||
c = size[1]
|
||||
|
||||
return [1, c, o // c]
|
||||
|
||||
@@ -120,7 +120,7 @@ class ControlBase:
|
||||
if self.previous_controlnet is not None:
|
||||
out += self.previous_controlnet.get_models()
|
||||
return out
|
||||
|
||||
|
||||
def get_extra_hooks(self):
|
||||
out = []
|
||||
if self.extra_hooks is not None:
|
||||
@@ -297,7 +297,6 @@ class ControlLoraOps:
|
||||
class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp):
|
||||
def __init__(self, in_features: int, out_features: int, bias: bool = True,
|
||||
device=None, dtype=None) -> None:
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
super().__init__()
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
@@ -382,7 +381,6 @@ class ControlLora(ControlNet):
|
||||
self.control_model.to(comfy.model_management.get_torch_device())
|
||||
diffusion_model = model.diffusion_model
|
||||
sd = diffusion_model.state_dict()
|
||||
cm = self.control_model.state_dict()
|
||||
|
||||
for k in sd:
|
||||
weight = sd[k]
|
||||
@@ -420,10 +418,7 @@ def controlnet_config(sd, model_options={}):
|
||||
weight_dtype = comfy.utils.weight_dtype(sd)
|
||||
|
||||
supported_inference_dtypes = list(model_config.supported_inference_dtypes)
|
||||
if weight_dtype is not None:
|
||||
supported_inference_dtypes.append(weight_dtype)
|
||||
|
||||
unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
|
||||
unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes, weight_dtype=weight_dtype)
|
||||
|
||||
load_device = comfy.model_management.get_torch_device()
|
||||
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
||||
@@ -691,10 +686,7 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
|
||||
if supported_inference_dtypes is None:
|
||||
supported_inference_dtypes = [comfy.model_management.unet_dtype()]
|
||||
|
||||
if weight_dtype is not None:
|
||||
supported_inference_dtypes.append(weight_dtype)
|
||||
|
||||
unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
|
||||
unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes, weight_dtype=weight_dtype)
|
||||
|
||||
load_device = comfy.model_management.get_torch_device()
|
||||
|
||||
@@ -744,6 +736,7 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
|
||||
return control
|
||||
|
||||
def load_controlnet(ckpt_path, model=None, model_options={}):
|
||||
model_options = model_options.copy()
|
||||
if "global_average_pooling" not in model_options:
|
||||
filename = os.path.splitext(ckpt_path)[0]
|
||||
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
|
||||
@@ -823,7 +816,7 @@ def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
|
||||
for i in range(4):
|
||||
for j in range(2):
|
||||
prefix_replace["adapter.body.{}.resnets.{}.".format(i, j)] = "body.{}.".format(i * 2 + j)
|
||||
prefix_replace["adapter.body.{}.".format(i, j)] = "body.{}.".format(i * 2)
|
||||
prefix_replace["adapter.body.{}.".format(i, )] = "body.{}.".format(i * 2)
|
||||
prefix_replace["adapter."] = ""
|
||||
t2i_data = comfy.utils.state_dict_prefix_replace(t2i_data, prefix_replace)
|
||||
keys = t2i_data.keys()
|
||||
|
||||
@@ -4,105 +4,6 @@ import logging
|
||||
|
||||
# conversion code from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py
|
||||
|
||||
# =================#
|
||||
# UNet Conversion #
|
||||
# =================#
|
||||
|
||||
unet_conversion_map = [
|
||||
# (stable-diffusion, HF Diffusers)
|
||||
("time_embed.0.weight", "time_embedding.linear_1.weight"),
|
||||
("time_embed.0.bias", "time_embedding.linear_1.bias"),
|
||||
("time_embed.2.weight", "time_embedding.linear_2.weight"),
|
||||
("time_embed.2.bias", "time_embedding.linear_2.bias"),
|
||||
("input_blocks.0.0.weight", "conv_in.weight"),
|
||||
("input_blocks.0.0.bias", "conv_in.bias"),
|
||||
("out.0.weight", "conv_norm_out.weight"),
|
||||
("out.0.bias", "conv_norm_out.bias"),
|
||||
("out.2.weight", "conv_out.weight"),
|
||||
("out.2.bias", "conv_out.bias"),
|
||||
]
|
||||
|
||||
unet_conversion_map_resnet = [
|
||||
# (stable-diffusion, HF Diffusers)
|
||||
("in_layers.0", "norm1"),
|
||||
("in_layers.2", "conv1"),
|
||||
("out_layers.0", "norm2"),
|
||||
("out_layers.3", "conv2"),
|
||||
("emb_layers.1", "time_emb_proj"),
|
||||
("skip_connection", "conv_shortcut"),
|
||||
]
|
||||
|
||||
unet_conversion_map_layer = []
|
||||
# hardcoded number of downblocks and resnets/attentions...
|
||||
# would need smarter logic for other networks.
|
||||
for i in range(4):
|
||||
# loop over downblocks/upblocks
|
||||
|
||||
for j in range(2):
|
||||
# loop over resnets/attentions for downblocks
|
||||
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
||||
sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
|
||||
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
||||
|
||||
if i < 3:
|
||||
# no attention layers in down_blocks.3
|
||||
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
||||
sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
|
||||
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
||||
|
||||
for j in range(3):
|
||||
# loop over resnets/attentions for upblocks
|
||||
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
||||
sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
|
||||
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
||||
|
||||
if i > 0:
|
||||
# no attention layers in up_blocks.0
|
||||
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
|
||||
sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
|
||||
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
|
||||
|
||||
if i < 3:
|
||||
# no downsample in down_blocks.3
|
||||
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
||||
sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
|
||||
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
||||
|
||||
# no upsample in up_blocks.3
|
||||
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
||||
sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}."
|
||||
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
||||
|
||||
hf_mid_atn_prefix = "mid_block.attentions.0."
|
||||
sd_mid_atn_prefix = "middle_block.1."
|
||||
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
||||
|
||||
for j in range(2):
|
||||
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
||||
sd_mid_res_prefix = f"middle_block.{2 * j}."
|
||||
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
||||
|
||||
|
||||
def convert_unet_state_dict(unet_state_dict):
|
||||
# buyer beware: this is a *brittle* function,
|
||||
# and correct output requires that all of these pieces interact in
|
||||
# the exact order in which I have arranged them.
|
||||
mapping = {k: k for k in unet_state_dict.keys()}
|
||||
for sd_name, hf_name in unet_conversion_map:
|
||||
mapping[hf_name] = sd_name
|
||||
for k, v in mapping.items():
|
||||
if "resnets" in k:
|
||||
for sd_part, hf_part in unet_conversion_map_resnet:
|
||||
v = v.replace(hf_part, sd_part)
|
||||
mapping[k] = v
|
||||
for k, v in mapping.items():
|
||||
for sd_part, hf_part in unet_conversion_map_layer:
|
||||
v = v.replace(hf_part, sd_part)
|
||||
mapping[k] = v
|
||||
new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
|
||||
return new_state_dict
|
||||
|
||||
|
||||
# ================#
|
||||
# VAE Conversion #
|
||||
# ================#
|
||||
@@ -157,16 +58,23 @@ vae_conversion_map_attn = [
|
||||
]
|
||||
|
||||
|
||||
def reshape_weight_for_sd(w):
|
||||
def reshape_weight_for_sd(w, conv3d=False):
|
||||
# convert HF linear weights to SD conv2d weights
|
||||
return w.reshape(*w.shape, 1, 1)
|
||||
if conv3d:
|
||||
return w.reshape(*w.shape, 1, 1, 1)
|
||||
else:
|
||||
return w.reshape(*w.shape, 1, 1)
|
||||
|
||||
|
||||
def convert_vae_state_dict(vae_state_dict):
|
||||
mapping = {k: k for k in vae_state_dict.keys()}
|
||||
conv3d = False
|
||||
for k, v in mapping.items():
|
||||
for sd_part, hf_part in vae_conversion_map:
|
||||
v = v.replace(hf_part, sd_part)
|
||||
if v.endswith(".conv.weight"):
|
||||
if not conv3d and vae_state_dict[k].ndim == 5:
|
||||
conv3d = True
|
||||
mapping[k] = v
|
||||
for k, v in mapping.items():
|
||||
if "attentions" in k:
|
||||
@@ -179,7 +87,7 @@ def convert_vae_state_dict(vae_state_dict):
|
||||
for weight_name in weights_to_convert:
|
||||
if f"mid.attn_1.{weight_name}.weight" in k:
|
||||
logging.debug(f"Reshaping {k} for SD format")
|
||||
new_state_dict[k] = reshape_weight_for_sd(v)
|
||||
new_state_dict[k] = reshape_weight_for_sd(v, conv3d=conv3d)
|
||||
return new_state_dict
|
||||
|
||||
|
||||
@@ -206,6 +114,7 @@ textenc_pattern = re.compile("|".join(protected.keys()))
|
||||
# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
|
||||
code2idx = {"q": 0, "k": 1, "v": 2}
|
||||
|
||||
|
||||
# This function exists because at the time of writing torch.cat can't do fp8 with cuda
|
||||
def cat_tensors(tensors):
|
||||
x = 0
|
||||
@@ -222,6 +131,7 @@ def cat_tensors(tensors):
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
|
||||
new_state_dict = {}
|
||||
capture_qkv_weight = {}
|
||||
@@ -277,5 +187,3 @@ def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
|
||||
|
||||
def convert_text_enc_state_dict(text_enc_dict):
|
||||
return text_enc_dict
|
||||
|
||||
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
#code taken from: https://github.com/wl-zhao/UniPC and modified
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
import logging
|
||||
|
||||
from tqdm.auto import trange, tqdm
|
||||
from tqdm.auto import trange
|
||||
|
||||
|
||||
class NoiseScheduleVP:
|
||||
@@ -80,7 +80,7 @@ class NoiseScheduleVP:
|
||||
'linear' or 'cosine' for continuous-time DPMs.
|
||||
Returns:
|
||||
A wrapper object of the forward SDE (VP type).
|
||||
|
||||
|
||||
===============================================================
|
||||
|
||||
Example:
|
||||
@@ -208,7 +208,7 @@ def model_wrapper(
|
||||
arXiv preprint arXiv:2202.00512 (2022).
|
||||
[2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
|
||||
arXiv preprint arXiv:2210.02303 (2022).
|
||||
|
||||
|
||||
4. "score": marginal score function. (Trained by denoising score matching).
|
||||
Note that the score function and the noise prediction model follows a simple relationship:
|
||||
```
|
||||
@@ -226,7 +226,7 @@ def model_wrapper(
|
||||
The input `model` has the following format:
|
||||
``
|
||||
model(x, t_input, **model_kwargs) -> noise | x_start | v | score
|
||||
``
|
||||
``
|
||||
|
||||
The input `classifier_fn` has the following format:
|
||||
``
|
||||
@@ -240,12 +240,12 @@ def model_wrapper(
|
||||
The input `model` has the following format:
|
||||
``
|
||||
model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
|
||||
``
|
||||
``
|
||||
And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
|
||||
|
||||
[4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
|
||||
arXiv preprint arXiv:2207.12598 (2022).
|
||||
|
||||
|
||||
|
||||
The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
|
||||
or continuous-time labels (i.e. epsilon to T).
|
||||
@@ -254,7 +254,7 @@ def model_wrapper(
|
||||
``
|
||||
def model_fn(x, t_continuous) -> noise:
|
||||
t_input = get_model_input_time(t_continuous)
|
||||
return noise_pred(model, x, t_input, **model_kwargs)
|
||||
return noise_pred(model, x, t_input, **model_kwargs)
|
||||
``
|
||||
where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
|
||||
|
||||
@@ -359,7 +359,7 @@ class UniPC:
|
||||
max_val=1.,
|
||||
variant='bh1',
|
||||
):
|
||||
"""Construct a UniPC.
|
||||
"""Construct a UniPC.
|
||||
|
||||
We support both data_prediction and noise_prediction.
|
||||
"""
|
||||
@@ -372,7 +372,7 @@ class UniPC:
|
||||
|
||||
def dynamic_thresholding_fn(self, x0, t=None):
|
||||
"""
|
||||
The dynamic thresholding method.
|
||||
The dynamic thresholding method.
|
||||
"""
|
||||
dims = x0.dim()
|
||||
p = self.dynamic_thresholding_ratio
|
||||
@@ -404,7 +404,7 @@ class UniPC:
|
||||
|
||||
def model_fn(self, x, t):
|
||||
"""
|
||||
Convert the model to the noise prediction model or the data prediction model.
|
||||
Convert the model to the noise prediction model or the data prediction model.
|
||||
"""
|
||||
if self.predict_x0:
|
||||
return self.data_prediction_fn(x, t)
|
||||
@@ -461,7 +461,7 @@ class UniPC:
|
||||
|
||||
def denoise_to_zero_fn(self, x, s):
|
||||
"""
|
||||
Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization.
|
||||
Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization.
|
||||
"""
|
||||
return self.data_prediction_fn(x, s)
|
||||
|
||||
@@ -475,7 +475,7 @@ class UniPC:
|
||||
return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
|
||||
|
||||
def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True):
|
||||
print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
|
||||
logging.info(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
|
||||
ns = self.noise_schedule
|
||||
assert order <= len(model_prev_list)
|
||||
|
||||
@@ -510,7 +510,7 @@ class UniPC:
|
||||
col = torch.ones_like(rks)
|
||||
for k in range(1, K + 1):
|
||||
C.append(col)
|
||||
col = col * rks / (k + 1)
|
||||
col = col * rks / (k + 1)
|
||||
C = torch.stack(C, dim=1)
|
||||
|
||||
if len(D1s) > 0:
|
||||
@@ -519,7 +519,6 @@ class UniPC:
|
||||
A_p = C_inv_p
|
||||
|
||||
if use_corrector:
|
||||
print('using corrector')
|
||||
C_inv = torch.linalg.inv(C)
|
||||
A_c = C_inv
|
||||
|
||||
@@ -622,12 +621,12 @@ class UniPC:
|
||||
B_h = torch.expm1(hh)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
for i in range(1, order + 1):
|
||||
R.append(torch.pow(rks, i - 1))
|
||||
b.append(h_phi_k * factorial_i / B_h)
|
||||
factorial_i *= (i + 1)
|
||||
h_phi_k = h_phi_k / hh - 1 / factorial_i
|
||||
h_phi_k = h_phi_k / hh - 1 / factorial_i
|
||||
|
||||
R = torch.stack(R)
|
||||
b = torch.tensor(b, device=x.device)
|
||||
@@ -662,7 +661,7 @@ class UniPC:
|
||||
|
||||
if x_t is None:
|
||||
if use_predictor:
|
||||
pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
|
||||
pred_res = torch.tensordot(D1s, rhos_p, dims=([1], [0])) # torch.einsum('k,bkchw->bchw', rhos_p, D1s)
|
||||
else:
|
||||
pred_res = 0
|
||||
x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * pred_res
|
||||
@@ -670,7 +669,7 @@ class UniPC:
|
||||
if use_corrector:
|
||||
model_t = self.model_fn(x_t, t)
|
||||
if D1s is not None:
|
||||
corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
|
||||
corr_res = torch.tensordot(D1s, rhos_c[:-1], dims=([1], [0])) # torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
|
||||
else:
|
||||
corr_res = 0
|
||||
D1_t = (model_t - model_prev_0)
|
||||
@@ -704,7 +703,6 @@ class UniPC:
|
||||
):
|
||||
# t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
|
||||
# t_T = self.noise_schedule.T if t_start is None else t_start
|
||||
device = x.device
|
||||
steps = len(timesteps) - 1
|
||||
if method == 'multistep':
|
||||
assert steps >= order
|
||||
@@ -872,4 +870,4 @@ def sample_unipc(model, noise, sigmas, extra_args=None, callback=None, disable=F
|
||||
return x
|
||||
|
||||
def sample_unipc_bh2(model, noise, sigmas, extra_args=None, callback=None, disable=False):
|
||||
return sample_unipc(model, noise, sigmas, extra_args, callback, disable, variant='bh2')
|
||||
return sample_unipc(model, noise, sigmas, extra_args, callback, disable, variant='bh2')
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import math
|
||||
import torch
|
||||
from torch import nn
|
||||
from .ldm.modules.attention import CrossAttention
|
||||
|
||||
451
comfy/hooks.py
451
comfy/hooks.py
@@ -5,6 +5,7 @@ import math
|
||||
import torch
|
||||
import numpy as np
|
||||
import itertools
|
||||
import logging
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_patcher import ModelPatcher, PatcherInjection
|
||||
@@ -15,130 +16,171 @@ import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
from node_helpers import conditioning_set_values
|
||||
|
||||
# #######################################################################################################
|
||||
# Hooks explanation
|
||||
# -------------------
|
||||
# The purpose of hooks is to allow conds to influence sampling without the need for ComfyUI core code to
|
||||
# make explicit special cases like it does for ControlNet and GLIGEN.
|
||||
#
|
||||
# This is necessary for nodes/features that are intended for use with masked or scheduled conds, or those
|
||||
# that should run special code when a 'marked' cond is used in sampling.
|
||||
# #######################################################################################################
|
||||
|
||||
class EnumHookMode(enum.Enum):
|
||||
'''
|
||||
Priority of hook memory optimization vs. speed, mostly related to WeightHooks.
|
||||
|
||||
MinVram: No caching will occur for any operations related to hooks.
|
||||
MaxSpeed: Excess VRAM (and RAM, once VRAM is sufficiently depleted) will be used to cache hook weights when switching hook groups.
|
||||
'''
|
||||
MinVram = "minvram"
|
||||
MaxSpeed = "maxspeed"
|
||||
|
||||
class EnumHookType(enum.Enum):
|
||||
'''
|
||||
Hook types, each of which has different expected behavior.
|
||||
'''
|
||||
Weight = "weight"
|
||||
Patch = "patch"
|
||||
ObjectPatch = "object_patch"
|
||||
AddModels = "add_models"
|
||||
Callbacks = "callbacks"
|
||||
Wrappers = "wrappers"
|
||||
SetInjections = "add_injections"
|
||||
AdditionalModels = "add_models"
|
||||
TransformerOptions = "transformer_options"
|
||||
Injections = "add_injections"
|
||||
|
||||
class EnumWeightTarget(enum.Enum):
|
||||
Model = "model"
|
||||
Clip = "clip"
|
||||
|
||||
class EnumHookScope(enum.Enum):
|
||||
'''
|
||||
Determines if hook should be limited in its influence over sampling.
|
||||
|
||||
AllConditioning: hook will affect all conds used in sampling.
|
||||
HookedOnly: hook will only affect the conds it was attached to.
|
||||
'''
|
||||
AllConditioning = "all_conditioning"
|
||||
HookedOnly = "hooked_only"
|
||||
|
||||
|
||||
class _HookRef:
|
||||
pass
|
||||
|
||||
# NOTE: this is an example of how the should_register function should look
|
||||
def default_should_register(hook: 'Hook', model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
|
||||
|
||||
def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||
'''Example for how custom_should_register function can look like.'''
|
||||
return True
|
||||
|
||||
|
||||
def create_target_dict(target: EnumWeightTarget=None, **kwargs) -> dict[str]:
|
||||
'''Creates base dictionary for use with Hooks' target param.'''
|
||||
d = {}
|
||||
if target is not None:
|
||||
d['target'] = target
|
||||
d.update(kwargs)
|
||||
return d
|
||||
|
||||
|
||||
class Hook:
|
||||
def __init__(self, hook_type: EnumHookType=None, hook_ref: _HookRef=None, hook_id: str=None,
|
||||
hook_keyframe: 'HookKeyframeGroup'=None):
|
||||
hook_keyframe: HookKeyframeGroup=None, hook_scope=EnumHookScope.AllConditioning):
|
||||
self.hook_type = hook_type
|
||||
'''Enum identifying the general class of this hook.'''
|
||||
self.hook_ref = hook_ref if hook_ref else _HookRef()
|
||||
'''Reference shared between hook clones that have the same value. Should NOT be modified.'''
|
||||
self.hook_id = hook_id
|
||||
'''Optional string ID to identify hook; useful if need to consolidate duplicates at registration time.'''
|
||||
self.hook_keyframe = hook_keyframe if hook_keyframe else HookKeyframeGroup()
|
||||
'''Keyframe storage that can be referenced to get strength for current sampling step.'''
|
||||
self.hook_scope = hook_scope
|
||||
'''Scope of where this hook should apply in terms of the conds used in sampling run.'''
|
||||
self.custom_should_register = default_should_register
|
||||
self.auto_apply_to_nonpositive = False
|
||||
'''Can be overriden with a compatible function to decide if this hook should be registered without the need to override .should_register'''
|
||||
|
||||
@property
|
||||
def strength(self):
|
||||
return self.hook_keyframe.strength
|
||||
|
||||
def initialize_timesteps(self, model: 'BaseModel'):
|
||||
def initialize_timesteps(self, model: BaseModel):
|
||||
self.reset()
|
||||
self.hook_keyframe.initialize_timesteps(model)
|
||||
|
||||
def reset(self):
|
||||
self.hook_keyframe.reset()
|
||||
|
||||
def clone(self, subtype: Callable=None):
|
||||
if subtype is None:
|
||||
subtype = type(self)
|
||||
c: Hook = subtype()
|
||||
def clone(self):
|
||||
c: Hook = self.__class__()
|
||||
c.hook_type = self.hook_type
|
||||
c.hook_ref = self.hook_ref
|
||||
c.hook_id = self.hook_id
|
||||
c.hook_keyframe = self.hook_keyframe
|
||||
c.hook_scope = self.hook_scope
|
||||
c.custom_should_register = self.custom_should_register
|
||||
# TODO: make this do something
|
||||
c.auto_apply_to_nonpositive = self.auto_apply_to_nonpositive
|
||||
return c
|
||||
|
||||
def should_register(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
|
||||
return self.custom_should_register(self, model, model_options, target, registered)
|
||||
def should_register(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||
return self.custom_should_register(self, model, model_options, target_dict, registered)
|
||||
|
||||
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
|
||||
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||
raise NotImplementedError("add_hook_patches should be defined for Hook subclasses")
|
||||
|
||||
def on_apply(self, model: 'ModelPatcher', transformer_options: dict[str]):
|
||||
pass
|
||||
|
||||
def on_unapply(self, model: 'ModelPatcher', transformer_options: dict[str]):
|
||||
pass
|
||||
|
||||
def __eq__(self, other: 'Hook'):
|
||||
def __eq__(self, other: Hook):
|
||||
return self.__class__ == other.__class__ and self.hook_ref == other.hook_ref
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.hook_ref)
|
||||
|
||||
class WeightHook(Hook):
|
||||
'''
|
||||
Hook responsible for tracking weights to be applied to some model/clip.
|
||||
|
||||
Note, value of hook_scope is ignored and is treated as HookedOnly.
|
||||
'''
|
||||
def __init__(self, strength_model=1.0, strength_clip=1.0):
|
||||
super().__init__(hook_type=EnumHookType.Weight)
|
||||
super().__init__(hook_type=EnumHookType.Weight, hook_scope=EnumHookScope.HookedOnly)
|
||||
self.weights: dict = None
|
||||
self.weights_clip: dict = None
|
||||
self.need_weight_init = True
|
||||
self._strength_model = strength_model
|
||||
self._strength_clip = strength_clip
|
||||
|
||||
self.hook_scope = EnumHookScope.HookedOnly # this value does not matter for WeightHooks, just for docs
|
||||
|
||||
@property
|
||||
def strength_model(self):
|
||||
return self._strength_model * self.strength
|
||||
|
||||
|
||||
@property
|
||||
def strength_clip(self):
|
||||
return self._strength_clip * self.strength
|
||||
|
||||
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
|
||||
if not self.should_register(model, model_options, target, registered):
|
||||
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||
if not self.should_register(model, model_options, target_dict, registered):
|
||||
return False
|
||||
weights = None
|
||||
if target == EnumWeightTarget.Model:
|
||||
strength = self._strength_model
|
||||
else:
|
||||
|
||||
target = target_dict.get('target', None)
|
||||
if target == EnumWeightTarget.Clip:
|
||||
strength = self._strength_clip
|
||||
|
||||
else:
|
||||
strength = self._strength_model
|
||||
|
||||
if self.need_weight_init:
|
||||
key_map = {}
|
||||
if target == EnumWeightTarget.Model:
|
||||
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
|
||||
else:
|
||||
if target == EnumWeightTarget.Clip:
|
||||
key_map = comfy.lora.model_lora_keys_clip(model.model, key_map)
|
||||
else:
|
||||
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
|
||||
weights = comfy.lora.load_lora(self.weights, key_map, log_missing=False)
|
||||
else:
|
||||
if target == EnumWeightTarget.Model:
|
||||
weights = self.weights
|
||||
else:
|
||||
if target == EnumWeightTarget.Clip:
|
||||
weights = self.weights_clip
|
||||
k = model.add_hook_patches(hook=self, patches=weights, strength_patch=strength)
|
||||
registered.append(self)
|
||||
else:
|
||||
weights = self.weights
|
||||
model.add_hook_patches(hook=self, patches=weights, strength_patch=strength)
|
||||
registered.add(self)
|
||||
return True
|
||||
# TODO: add logs about any keys that were not applied
|
||||
|
||||
def clone(self, subtype: Callable=None):
|
||||
if subtype is None:
|
||||
subtype = type(self)
|
||||
c: WeightHook = super().clone(subtype)
|
||||
def clone(self):
|
||||
c: WeightHook = super().clone()
|
||||
c.weights = self.weights
|
||||
c.weights_clip = self.weights_clip
|
||||
c.need_weight_init = self.need_weight_init
|
||||
@@ -146,127 +188,158 @@ class WeightHook(Hook):
|
||||
c._strength_clip = self._strength_clip
|
||||
return c
|
||||
|
||||
class PatchHook(Hook):
|
||||
def __init__(self):
|
||||
super().__init__(hook_type=EnumHookType.Patch)
|
||||
self.patches: dict = None
|
||||
|
||||
def clone(self, subtype: Callable=None):
|
||||
if subtype is None:
|
||||
subtype = type(self)
|
||||
c: PatchHook = super().clone(subtype)
|
||||
c.patches = self.patches
|
||||
return c
|
||||
# TODO: add functionality
|
||||
|
||||
class ObjectPatchHook(Hook):
|
||||
def __init__(self):
|
||||
def __init__(self, object_patches: dict[str]=None,
|
||||
hook_scope=EnumHookScope.AllConditioning):
|
||||
super().__init__(hook_type=EnumHookType.ObjectPatch)
|
||||
self.object_patches: dict = None
|
||||
|
||||
def clone(self, subtype: Callable=None):
|
||||
if subtype is None:
|
||||
subtype = type(self)
|
||||
c: ObjectPatchHook = super().clone(subtype)
|
||||
self.object_patches = object_patches
|
||||
self.hook_scope = hook_scope
|
||||
|
||||
def clone(self):
|
||||
c: ObjectPatchHook = super().clone()
|
||||
c.object_patches = self.object_patches
|
||||
return c
|
||||
# TODO: add functionality
|
||||
|
||||
class AddModelsHook(Hook):
|
||||
def __init__(self, key: str=None, models: list['ModelPatcher']=None):
|
||||
super().__init__(hook_type=EnumHookType.AddModels)
|
||||
self.key = key
|
||||
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||
raise NotImplementedError("ObjectPatchHook is not supported yet in ComfyUI.")
|
||||
|
||||
class AdditionalModelsHook(Hook):
|
||||
'''
|
||||
Hook responsible for telling model management any additional models that should be loaded.
|
||||
|
||||
Note, value of hook_scope is ignored and is treated as AllConditioning.
|
||||
'''
|
||||
def __init__(self, models: list[ModelPatcher]=None, key: str=None):
|
||||
super().__init__(hook_type=EnumHookType.AdditionalModels)
|
||||
self.models = models
|
||||
self.append_when_same = True
|
||||
|
||||
def clone(self, subtype: Callable=None):
|
||||
if subtype is None:
|
||||
subtype = type(self)
|
||||
c: AddModelsHook = super().clone(subtype)
|
||||
c.key = self.key
|
||||
c.models = self.models.copy() if self.models else self.models
|
||||
c.append_when_same = self.append_when_same
|
||||
return c
|
||||
# TODO: add functionality
|
||||
|
||||
class CallbackHook(Hook):
|
||||
def __init__(self, key: str=None, callback: Callable=None):
|
||||
super().__init__(hook_type=EnumHookType.Callbacks)
|
||||
self.key = key
|
||||
self.callback = callback
|
||||
|
||||
def clone(self, subtype: Callable=None):
|
||||
if subtype is None:
|
||||
subtype = type(self)
|
||||
c: CallbackHook = super().clone(subtype)
|
||||
def clone(self):
|
||||
c: AdditionalModelsHook = super().clone()
|
||||
c.models = self.models.copy() if self.models else self.models
|
||||
c.key = self.key
|
||||
c.callback = self.callback
|
||||
return c
|
||||
# TODO: add functionality
|
||||
|
||||
class WrapperHook(Hook):
|
||||
def __init__(self, wrappers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None):
|
||||
super().__init__(hook_type=EnumHookType.Wrappers)
|
||||
self.wrappers_dict = wrappers_dict
|
||||
|
||||
def clone(self, subtype: Callable=None):
|
||||
if subtype is None:
|
||||
subtype = type(self)
|
||||
c: WrapperHook = super().clone(subtype)
|
||||
c.wrappers_dict = self.wrappers_dict
|
||||
return c
|
||||
|
||||
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
|
||||
if not self.should_register(model, model_options, target, registered):
|
||||
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||
if not self.should_register(model, model_options, target_dict, registered):
|
||||
return False
|
||||
add_model_options = {"transformer_options": self.wrappers_dict}
|
||||
comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
|
||||
registered.append(self)
|
||||
registered.add(self)
|
||||
return True
|
||||
|
||||
class SetInjectionsHook(Hook):
|
||||
def __init__(self, key: str=None, injections: list['PatcherInjection']=None):
|
||||
super().__init__(hook_type=EnumHookType.SetInjections)
|
||||
class TransformerOptionsHook(Hook):
|
||||
'''
|
||||
Hook responsible for adding wrappers, callbacks, patches, or anything else related to transformer_options.
|
||||
'''
|
||||
def __init__(self, transformers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None,
|
||||
hook_scope=EnumHookScope.AllConditioning):
|
||||
super().__init__(hook_type=EnumHookType.TransformerOptions)
|
||||
self.transformers_dict = transformers_dict
|
||||
self.hook_scope = hook_scope
|
||||
self._skip_adding = False
|
||||
'''Internal value used to avoid double load of transformer_options when hook_scope is AllConditioning.'''
|
||||
|
||||
def clone(self):
|
||||
c: TransformerOptionsHook = super().clone()
|
||||
c.transformers_dict = self.transformers_dict
|
||||
c._skip_adding = self._skip_adding
|
||||
return c
|
||||
|
||||
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||
if not self.should_register(model, model_options, target_dict, registered):
|
||||
return False
|
||||
# NOTE: to_load_options will be used to manually load patches/wrappers/callbacks from hooks
|
||||
self._skip_adding = False
|
||||
if self.hook_scope == EnumHookScope.AllConditioning:
|
||||
add_model_options = {"transformer_options": self.transformers_dict,
|
||||
"to_load_options": self.transformers_dict}
|
||||
# skip_adding if included in AllConditioning to avoid double loading
|
||||
self._skip_adding = True
|
||||
else:
|
||||
add_model_options = {"to_load_options": self.transformers_dict}
|
||||
registered.add(self)
|
||||
comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
|
||||
return True
|
||||
|
||||
def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]):
|
||||
if not self._skip_adding:
|
||||
comfy.patcher_extension.merge_nested_dicts(transformer_options, self.transformers_dict, copy_dict1=False)
|
||||
|
||||
WrapperHook = TransformerOptionsHook
|
||||
'''Only here for backwards compatibility, WrapperHook is identical to TransformerOptionsHook.'''
|
||||
|
||||
class InjectionsHook(Hook):
|
||||
def __init__(self, key: str=None, injections: list[PatcherInjection]=None,
|
||||
hook_scope=EnumHookScope.AllConditioning):
|
||||
super().__init__(hook_type=EnumHookType.Injections)
|
||||
self.key = key
|
||||
self.injections = injections
|
||||
|
||||
def clone(self, subtype: Callable=None):
|
||||
if subtype is None:
|
||||
subtype = type(self)
|
||||
c: SetInjectionsHook = super().clone(subtype)
|
||||
self.hook_scope = hook_scope
|
||||
|
||||
def clone(self):
|
||||
c: InjectionsHook = super().clone()
|
||||
c.key = self.key
|
||||
c.injections = self.injections.copy() if self.injections else self.injections
|
||||
return c
|
||||
|
||||
def add_hook_injections(self, model: 'ModelPatcher'):
|
||||
# TODO: add functionality
|
||||
pass
|
||||
|
||||
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||
raise NotImplementedError("InjectionsHook is not supported yet in ComfyUI.")
|
||||
|
||||
class HookGroup:
|
||||
'''
|
||||
Stores groups of hooks, and allows them to be queried by type.
|
||||
|
||||
To prevent breaking their functionality, never modify the underlying self.hooks or self._hook_dict vars directly;
|
||||
always use the provided functions on HookGroup.
|
||||
'''
|
||||
def __init__(self):
|
||||
self.hooks: list[Hook] = []
|
||||
self._hook_dict: dict[EnumHookType, list[Hook]] = {}
|
||||
|
||||
def __len__(self):
|
||||
return len(self.hooks)
|
||||
|
||||
def add(self, hook: Hook):
|
||||
if hook not in self.hooks:
|
||||
self.hooks.append(hook)
|
||||
|
||||
self._hook_dict.setdefault(hook.hook_type, []).append(hook)
|
||||
|
||||
def remove(self, hook: Hook):
|
||||
if hook in self.hooks:
|
||||
self.hooks.remove(hook)
|
||||
self._hook_dict[hook.hook_type].remove(hook)
|
||||
|
||||
def get_type(self, hook_type: EnumHookType):
|
||||
return self._hook_dict.get(hook_type, [])
|
||||
|
||||
def contains(self, hook: Hook):
|
||||
return hook in self.hooks
|
||||
|
||||
|
||||
def is_subset_of(self, other: HookGroup):
|
||||
self_hooks = set(self.hooks)
|
||||
other_hooks = set(other.hooks)
|
||||
return self_hooks.issubset(other_hooks)
|
||||
|
||||
def new_with_common_hooks(self, other: HookGroup):
|
||||
c = HookGroup()
|
||||
for hook in self.hooks:
|
||||
if other.contains(hook):
|
||||
c.add(hook.clone())
|
||||
return c
|
||||
|
||||
def clone(self):
|
||||
c = HookGroup()
|
||||
for hook in self.hooks:
|
||||
c.add(hook.clone())
|
||||
return c
|
||||
|
||||
def clone_and_combine(self, other: 'HookGroup'):
|
||||
def clone_and_combine(self, other: HookGroup):
|
||||
c = self.clone()
|
||||
if other is not None:
|
||||
for hook in other.hooks:
|
||||
c.add(hook.clone())
|
||||
return c
|
||||
|
||||
def set_keyframes_on_hooks(self, hook_kf: 'HookKeyframeGroup'):
|
||||
|
||||
def set_keyframes_on_hooks(self, hook_kf: HookKeyframeGroup):
|
||||
if hook_kf is None:
|
||||
hook_kf = HookKeyframeGroup()
|
||||
else:
|
||||
@@ -274,36 +347,29 @@ class HookGroup:
|
||||
for hook in self.hooks:
|
||||
hook.hook_keyframe = hook_kf
|
||||
|
||||
def get_dict_repr(self):
|
||||
d: dict[EnumHookType, dict[Hook, None]] = {}
|
||||
for hook in self.hooks:
|
||||
with_type = d.setdefault(hook.hook_type, {})
|
||||
with_type[hook] = None
|
||||
return d
|
||||
|
||||
def get_hooks_for_clip_schedule(self):
|
||||
scheduled_hooks: dict[WeightHook, list[tuple[tuple[float,float], HookKeyframe]]] = {}
|
||||
for hook in self.hooks:
|
||||
# only care about WeightHooks, for now
|
||||
if hook.hook_type == EnumHookType.Weight:
|
||||
hook_schedule = []
|
||||
# if no hook keyframes, assign default value
|
||||
if len(hook.hook_keyframe.keyframes) == 0:
|
||||
hook_schedule.append(((0.0, 1.0), None))
|
||||
scheduled_hooks[hook] = hook_schedule
|
||||
continue
|
||||
# find ranges of values
|
||||
prev_keyframe = hook.hook_keyframe.keyframes[0]
|
||||
for keyframe in hook.hook_keyframe.keyframes:
|
||||
if keyframe.start_percent > prev_keyframe.start_percent and not math.isclose(keyframe.strength, prev_keyframe.strength):
|
||||
hook_schedule.append(((prev_keyframe.start_percent, keyframe.start_percent), prev_keyframe))
|
||||
prev_keyframe = keyframe
|
||||
elif keyframe.start_percent == prev_keyframe.start_percent:
|
||||
prev_keyframe = keyframe
|
||||
# create final range, assuming last start_percent was not 1.0
|
||||
if not math.isclose(prev_keyframe.start_percent, 1.0):
|
||||
hook_schedule.append(((prev_keyframe.start_percent, 1.0), prev_keyframe))
|
||||
# only care about WeightHooks, for now
|
||||
for hook in self.get_type(EnumHookType.Weight):
|
||||
hook: WeightHook
|
||||
hook_schedule = []
|
||||
# if no hook keyframes, assign default value
|
||||
if len(hook.hook_keyframe.keyframes) == 0:
|
||||
hook_schedule.append(((0.0, 1.0), None))
|
||||
scheduled_hooks[hook] = hook_schedule
|
||||
continue
|
||||
# find ranges of values
|
||||
prev_keyframe = hook.hook_keyframe.keyframes[0]
|
||||
for keyframe in hook.hook_keyframe.keyframes:
|
||||
if keyframe.start_percent > prev_keyframe.start_percent and not math.isclose(keyframe.strength, prev_keyframe.strength):
|
||||
hook_schedule.append(((prev_keyframe.start_percent, keyframe.start_percent), prev_keyframe))
|
||||
prev_keyframe = keyframe
|
||||
elif keyframe.start_percent == prev_keyframe.start_percent:
|
||||
prev_keyframe = keyframe
|
||||
# create final range, assuming last start_percent was not 1.0
|
||||
if not math.isclose(prev_keyframe.start_percent, 1.0):
|
||||
hook_schedule.append(((prev_keyframe.start_percent, 1.0), prev_keyframe))
|
||||
scheduled_hooks[hook] = hook_schedule
|
||||
# hooks should not have their schedules in a list of tuples
|
||||
all_ranges: list[tuple[float, float]] = []
|
||||
for range_kfs in scheduled_hooks.values():
|
||||
@@ -335,7 +401,7 @@ class HookGroup:
|
||||
hook.reset()
|
||||
|
||||
@staticmethod
|
||||
def combine_all_hooks(hooks_list: list['HookGroup'], require_count=0) -> 'HookGroup':
|
||||
def combine_all_hooks(hooks_list: list[HookGroup], require_count=0) -> HookGroup:
|
||||
actual: list[HookGroup] = []
|
||||
for group in hooks_list:
|
||||
if group is not None:
|
||||
@@ -364,10 +430,16 @@ class HookKeyframe:
|
||||
self.start_percent = float(start_percent)
|
||||
self.start_t = 999999999.9
|
||||
self.guarantee_steps = guarantee_steps
|
||||
|
||||
|
||||
def get_effective_guarantee_steps(self, max_sigma: torch.Tensor):
|
||||
'''If keyframe starts before current sampling range (max_sigma), treat as 0.'''
|
||||
if self.start_t > max_sigma:
|
||||
return 0
|
||||
return self.guarantee_steps
|
||||
|
||||
def clone(self):
|
||||
c = HookKeyframe(strength=self.strength,
|
||||
start_percent=self.start_percent, guarantee_steps=self.guarantee_steps)
|
||||
start_percent=self.start_percent, guarantee_steps=self.guarantee_steps)
|
||||
c.start_t = self.start_t
|
||||
return c
|
||||
|
||||
@@ -394,7 +466,7 @@ class HookKeyframeGroup:
|
||||
self._current_strength = None
|
||||
self.curr_t = -1.
|
||||
self._set_first_as_current()
|
||||
|
||||
|
||||
def add(self, keyframe: HookKeyframe):
|
||||
# add to end of list, then sort
|
||||
self.keyframes.append(keyframe)
|
||||
@@ -406,33 +478,40 @@ class HookKeyframeGroup:
|
||||
self._current_keyframe = self.keyframes[0]
|
||||
else:
|
||||
self._current_keyframe = None
|
||||
|
||||
|
||||
def has_guarantee_steps(self):
|
||||
for kf in self.keyframes:
|
||||
if kf.guarantee_steps > 0:
|
||||
return True
|
||||
return False
|
||||
|
||||
def has_index(self, index: int):
|
||||
return index >= 0 and index < len(self.keyframes)
|
||||
|
||||
def is_empty(self):
|
||||
return len(self.keyframes) == 0
|
||||
|
||||
|
||||
def clone(self):
|
||||
c = HookKeyframeGroup()
|
||||
for keyframe in self.keyframes:
|
||||
c.keyframes.append(keyframe.clone())
|
||||
c._set_first_as_current()
|
||||
return c
|
||||
|
||||
def initialize_timesteps(self, model: 'BaseModel'):
|
||||
|
||||
def initialize_timesteps(self, model: BaseModel):
|
||||
for keyframe in self.keyframes:
|
||||
keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent)
|
||||
|
||||
def prepare_current_keyframe(self, curr_t: float) -> bool:
|
||||
def prepare_current_keyframe(self, curr_t: float, transformer_options: dict[str, torch.Tensor]) -> bool:
|
||||
if self.is_empty():
|
||||
return False
|
||||
if curr_t == self._curr_t:
|
||||
return False
|
||||
max_sigma = torch.max(transformer_options["sample_sigmas"])
|
||||
prev_index = self._current_index
|
||||
prev_strength = self._current_strength
|
||||
# if met guaranteed steps, look for next keyframe in case need to switch
|
||||
if self._current_used_steps >= self._current_keyframe.guarantee_steps:
|
||||
if self._current_used_steps >= self._current_keyframe.get_effective_guarantee_steps(max_sigma):
|
||||
# if has next index, loop through and see if need to switch
|
||||
if self.has_index(self._current_index+1):
|
||||
for i in range(self._current_index+1, len(self.keyframes)):
|
||||
@@ -445,7 +524,7 @@ class HookKeyframeGroup:
|
||||
self._current_keyframe = eval_c
|
||||
self._current_used_steps = 0
|
||||
# if guarantee_steps greater than zero, stop searching for other keyframes
|
||||
if self._current_keyframe.guarantee_steps > 0:
|
||||
if self._current_keyframe.get_effective_guarantee_steps(max_sigma) > 0:
|
||||
break
|
||||
# if eval_c is outside the percent range, stop looking further
|
||||
else: break
|
||||
@@ -508,6 +587,17 @@ def get_sorted_list_via_attr(objects: list, attr: str) -> list:
|
||||
sorted_list.extend(object_list)
|
||||
return sorted_list
|
||||
|
||||
def create_transformer_options_from_hooks(model: ModelPatcher, hooks: HookGroup, transformer_options: dict[str]=None):
|
||||
# if no hooks or is not a ModelPatcher for sampling, return empty dict
|
||||
if hooks is None or model.is_clip:
|
||||
return {}
|
||||
if transformer_options is None:
|
||||
transformer_options = {}
|
||||
for hook in hooks.get_type(EnumHookType.TransformerOptions):
|
||||
hook: TransformerOptionsHook
|
||||
hook.on_apply_hooks(model, transformer_options)
|
||||
return transformer_options
|
||||
|
||||
def create_hook_lora(lora: dict[str, torch.Tensor], strength_model: float, strength_clip: float):
|
||||
hook_group = HookGroup()
|
||||
hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip)
|
||||
@@ -534,7 +624,7 @@ def create_hook_model_as_lora(weights_model, weights_clip, strength_model: float
|
||||
hook.need_weight_init = False
|
||||
return hook_group
|
||||
|
||||
def get_patch_weights_from_model(model: 'ModelPatcher', discard_model_sampling=True):
|
||||
def get_patch_weights_from_model(model: ModelPatcher, discard_model_sampling=True):
|
||||
if model is None:
|
||||
return None
|
||||
patches_model: dict[str, torch.Tensor] = model.model.state_dict()
|
||||
@@ -546,7 +636,7 @@ def get_patch_weights_from_model(model: 'ModelPatcher', discard_model_sampling=T
|
||||
return patches_model
|
||||
|
||||
# NOTE: this function shows how to register weight hooks directly on the ModelPatchers
|
||||
def load_hook_lora_for_models(model: 'ModelPatcher', clip: 'CLIP', lora: dict[str, torch.Tensor],
|
||||
def load_hook_lora_for_models(model: ModelPatcher, clip: CLIP, lora: dict[str, torch.Tensor],
|
||||
strength_model: float, strength_clip: float):
|
||||
key_map = {}
|
||||
if model is not None:
|
||||
@@ -564,7 +654,7 @@ def load_hook_lora_for_models(model: 'ModelPatcher', clip: 'CLIP', lora: dict[st
|
||||
else:
|
||||
k = ()
|
||||
new_modelpatcher = None
|
||||
|
||||
|
||||
if clip is not None:
|
||||
new_clip = clip.clone()
|
||||
k1 = new_clip.patcher.add_hook_patches(hook=hook, patches=loaded, strength_patch=strength_clip)
|
||||
@@ -575,7 +665,7 @@ def load_hook_lora_for_models(model: 'ModelPatcher', clip: 'CLIP', lora: dict[st
|
||||
k1 = set(k1)
|
||||
for x in loaded:
|
||||
if (x not in k) and (x not in k1):
|
||||
print(f"NOT LOADED {x}")
|
||||
logging.warning(f"NOT LOADED {x}")
|
||||
return (new_modelpatcher, new_clip, hook_group)
|
||||
|
||||
def _combine_hooks_from_values(c_dict: dict[str, HookGroup], values: dict[str, HookGroup], cache: dict[tuple[HookGroup, HookGroup], HookGroup]):
|
||||
@@ -598,24 +688,26 @@ def _combine_hooks_from_values(c_dict: dict[str, HookGroup], values: dict[str, H
|
||||
else:
|
||||
c_dict[hooks_key] = cache[hooks_tuple]
|
||||
|
||||
def conditioning_set_values_with_hooks(conditioning, values={}, append_hooks=True):
|
||||
def conditioning_set_values_with_hooks(conditioning, values={}, append_hooks=True,
|
||||
cache: dict[tuple[HookGroup, HookGroup], HookGroup]=None):
|
||||
c = []
|
||||
hooks_combine_cache: dict[tuple[HookGroup, HookGroup], HookGroup] = {}
|
||||
if cache is None:
|
||||
cache = {}
|
||||
for t in conditioning:
|
||||
n = [t[0], t[1].copy()]
|
||||
for k in values:
|
||||
if append_hooks and k == 'hooks':
|
||||
_combine_hooks_from_values(n[1], values, hooks_combine_cache)
|
||||
_combine_hooks_from_values(n[1], values, cache)
|
||||
else:
|
||||
n[1][k] = values[k]
|
||||
c.append(n)
|
||||
|
||||
return c
|
||||
|
||||
def set_hooks_for_conditioning(cond, hooks: HookGroup, append_hooks=True):
|
||||
def set_hooks_for_conditioning(cond, hooks: HookGroup, append_hooks=True, cache: dict[tuple[HookGroup, HookGroup], HookGroup]=None):
|
||||
if hooks is None:
|
||||
return cond
|
||||
return conditioning_set_values_with_hooks(cond, {'hooks': hooks}, append_hooks=append_hooks)
|
||||
return conditioning_set_values_with_hooks(cond, {'hooks': hooks}, append_hooks=append_hooks, cache=cache)
|
||||
|
||||
def set_timesteps_for_conditioning(cond, timestep_range: tuple[float,float]):
|
||||
if timestep_range is None:
|
||||
@@ -650,9 +742,10 @@ def combine_with_new_conds(conds: list, new_conds: list):
|
||||
def set_conds_props(conds: list, strength: float, set_cond_area: str,
|
||||
mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
|
||||
final_conds = []
|
||||
cache = {}
|
||||
for c in conds:
|
||||
# first, apply lora_hook to conditioning, if provided
|
||||
c = set_hooks_for_conditioning(c, hooks, append_hooks=append_hooks)
|
||||
c = set_hooks_for_conditioning(c, hooks, append_hooks=append_hooks, cache=cache)
|
||||
# next, apply mask to conditioning
|
||||
c = set_mask_for_conditioning(cond=c, mask=mask, strength=strength, set_cond_area=set_cond_area)
|
||||
# apply timesteps, if present
|
||||
@@ -664,9 +757,10 @@ def set_conds_props(conds: list, strength: float, set_cond_area: str,
|
||||
def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1.0, set_cond_area: str="default",
|
||||
mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
|
||||
combined_conds = []
|
||||
cache = {}
|
||||
for c, masked_c in zip(conds, new_conds):
|
||||
# first, apply lora_hook to new conditioning, if provided
|
||||
masked_c = set_hooks_for_conditioning(masked_c, hooks, append_hooks=append_hooks)
|
||||
masked_c = set_hooks_for_conditioning(masked_c, hooks, append_hooks=append_hooks, cache=cache)
|
||||
# next, apply mask to new conditioning, if provided
|
||||
masked_c = set_mask_for_conditioning(cond=masked_c, mask=mask, set_cond_area=set_cond_area, strength=strength)
|
||||
# apply timesteps, if present
|
||||
@@ -678,9 +772,10 @@ def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1.
|
||||
def set_default_conds_and_combine(conds: list, new_conds: list,
|
||||
hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
|
||||
combined_conds = []
|
||||
cache = {}
|
||||
for c, new_c in zip(conds, new_conds):
|
||||
# first, apply lora_hook to new conditioning, if provided
|
||||
new_c = set_hooks_for_conditioning(new_c, hooks, append_hooks=append_hooks)
|
||||
new_c = set_hooks_for_conditioning(new_c, hooks, append_hooks=append_hooks, cache=cache)
|
||||
# next, add default_cond key to cond so that during sampling, it can be identified
|
||||
new_c = conditioning_set_values(new_c, {'default': True})
|
||||
# apply timesteps, if present
|
||||
|
||||
141
comfy/image_encoders/dino2.py
Normal file
141
comfy/image_encoders/dino2.py
Normal file
@@ -0,0 +1,141 @@
|
||||
import torch
|
||||
from comfy.text_encoders.bert import BertAttention
|
||||
import comfy.model_management
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
|
||||
|
||||
class Dino2AttentionOutput(torch.nn.Module):
|
||||
def __init__(self, input_dim, output_dim, layer_norm_eps, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.dense = operations.Linear(input_dim, output_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
return self.dense(x)
|
||||
|
||||
|
||||
class Dino2AttentionBlock(torch.nn.Module):
|
||||
def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.attention = BertAttention(embed_dim, heads, dtype, device, operations)
|
||||
self.output = Dino2AttentionOutput(embed_dim, embed_dim, layer_norm_eps, dtype, device, operations)
|
||||
|
||||
def forward(self, x, mask, optimized_attention):
|
||||
return self.output(self.attention(x, mask, optimized_attention))
|
||||
|
||||
|
||||
class LayerScale(torch.nn.Module):
|
||||
def __init__(self, dim, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.lambda1 = torch.nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
|
||||
|
||||
def forward(self, x):
|
||||
return x * comfy.model_management.cast_to_device(self.lambda1, x.device, x.dtype)
|
||||
|
||||
|
||||
class SwiGLUFFN(torch.nn.Module):
|
||||
def __init__(self, dim, dtype, device, operations):
|
||||
super().__init__()
|
||||
in_features = out_features = dim
|
||||
hidden_features = int(dim * 4)
|
||||
hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
|
||||
|
||||
self.weights_in = operations.Linear(in_features, 2 * hidden_features, bias=True, device=device, dtype=dtype)
|
||||
self.weights_out = operations.Linear(hidden_features, out_features, bias=True, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.weights_in(x)
|
||||
x1, x2 = x.chunk(2, dim=-1)
|
||||
x = torch.nn.functional.silu(x1) * x2
|
||||
return self.weights_out(x)
|
||||
|
||||
|
||||
class Dino2Block(torch.nn.Module):
|
||||
def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations)
|
||||
self.layer_scale1 = LayerScale(dim, dtype, device, operations)
|
||||
self.layer_scale2 = LayerScale(dim, dtype, device, operations)
|
||||
self.mlp = SwiGLUFFN(dim, dtype, device, operations)
|
||||
self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||
self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x, optimized_attention):
|
||||
x = x + self.layer_scale1(self.attention(self.norm1(x), None, optimized_attention))
|
||||
x = x + self.layer_scale2(self.mlp(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
class Dino2Encoder(torch.nn.Module):
|
||||
def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations) for _ in range(num_layers)])
|
||||
|
||||
def forward(self, x, intermediate_output=None):
|
||||
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
|
||||
|
||||
if intermediate_output is not None:
|
||||
if intermediate_output < 0:
|
||||
intermediate_output = len(self.layer) + intermediate_output
|
||||
|
||||
intermediate = None
|
||||
for i, l in enumerate(self.layer):
|
||||
x = l(x, optimized_attention)
|
||||
if i == intermediate_output:
|
||||
intermediate = x.clone()
|
||||
return x, intermediate
|
||||
|
||||
|
||||
class Dino2PatchEmbeddings(torch.nn.Module):
|
||||
def __init__(self, dim, num_channels=3, patch_size=14, image_size=518, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.projection = operations.Conv2d(
|
||||
in_channels=num_channels,
|
||||
out_channels=dim,
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
bias=True,
|
||||
dtype=dtype,
|
||||
device=device
|
||||
)
|
||||
|
||||
def forward(self, pixel_values):
|
||||
return self.projection(pixel_values).flatten(2).transpose(1, 2)
|
||||
|
||||
|
||||
class Dino2Embeddings(torch.nn.Module):
|
||||
def __init__(self, dim, dtype, device, operations):
|
||||
super().__init__()
|
||||
patch_size = 14
|
||||
image_size = 518
|
||||
|
||||
self.patch_embeddings = Dino2PatchEmbeddings(dim, patch_size=patch_size, image_size=image_size, dtype=dtype, device=device, operations=operations)
|
||||
self.position_embeddings = torch.nn.Parameter(torch.empty(1, (image_size // patch_size) ** 2 + 1, dim, dtype=dtype, device=device))
|
||||
self.cls_token = torch.nn.Parameter(torch.empty(1, 1, dim, dtype=dtype, device=device))
|
||||
self.mask_token = torch.nn.Parameter(torch.empty(1, dim, dtype=dtype, device=device))
|
||||
|
||||
def forward(self, pixel_values):
|
||||
x = self.patch_embeddings(pixel_values)
|
||||
# TODO: mask_token?
|
||||
x = torch.cat((self.cls_token.to(device=x.device, dtype=x.dtype).expand(x.shape[0], -1, -1), x), dim=1)
|
||||
x = x + comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype)
|
||||
return x
|
||||
|
||||
|
||||
class Dinov2Model(torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
num_layers = config_dict["num_hidden_layers"]
|
||||
dim = config_dict["hidden_size"]
|
||||
heads = config_dict["num_attention_heads"]
|
||||
layer_norm_eps = config_dict["layer_norm_eps"]
|
||||
|
||||
self.embeddings = Dino2Embeddings(dim, dtype, device, operations)
|
||||
self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations)
|
||||
self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
|
||||
x = self.embeddings(pixel_values)
|
||||
x, i = self.encoder(x, intermediate_output=intermediate_output)
|
||||
x = self.layernorm(x)
|
||||
pooled_output = x[:, 0, :]
|
||||
return x, i, pooled_output, None
|
||||
21
comfy/image_encoders/dino2_giant.json
Normal file
21
comfy/image_encoders/dino2_giant.json
Normal file
@@ -0,0 +1,21 @@
|
||||
{
|
||||
"attention_probs_dropout_prob": 0.0,
|
||||
"drop_path_rate": 0.0,
|
||||
"hidden_act": "gelu",
|
||||
"hidden_dropout_prob": 0.0,
|
||||
"hidden_size": 1536,
|
||||
"image_size": 518,
|
||||
"initializer_range": 0.02,
|
||||
"layer_norm_eps": 1e-06,
|
||||
"layerscale_value": 1.0,
|
||||
"mlp_ratio": 4,
|
||||
"model_type": "dinov2",
|
||||
"num_attention_heads": 24,
|
||||
"num_channels": 3,
|
||||
"num_hidden_layers": 40,
|
||||
"patch_size": 14,
|
||||
"qkv_bias": true,
|
||||
"use_swiglu_ffn": true,
|
||||
"image_mean": [0.485, 0.456, 0.406],
|
||||
"image_std": [0.229, 0.224, 0.225]
|
||||
}
|
||||
@@ -11,7 +11,6 @@ import numpy as np
|
||||
# Transfer from the input time (sigma) used in EDM to that (t) used in DEIS.
|
||||
|
||||
def edm2t(edm_steps, epsilon_s=1e-3, sigma_min=0.002, sigma_max=80):
|
||||
vp_sigma = lambda beta_d, beta_min: lambda t: (np.e ** (0.5 * beta_d * (t ** 2) + beta_min * t) - 1) ** 0.5
|
||||
vp_sigma_inv = lambda beta_d, beta_min: lambda sigma: ((beta_min ** 2 + 2 * beta_d * (sigma ** 2 + 1).log()).sqrt() - beta_min) / beta_d
|
||||
vp_beta_d = 2 * (np.log(torch.tensor(sigma_min).cpu() ** 2 + 1) / epsilon_s - np.log(torch.tensor(sigma_max).cpu() ** 2 + 1)) / (epsilon_s - 1)
|
||||
vp_beta_min = np.log(torch.tensor(sigma_max).cpu() ** 2 + 1) - 0.5 * vp_beta_d
|
||||
|
||||
@@ -40,7 +40,7 @@ def get_sigmas_polyexponential(n, sigma_min, sigma_max, rho=1., device='cpu'):
|
||||
def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
|
||||
"""Constructs a continuous VP noise schedule."""
|
||||
t = torch.linspace(1, eps_s, n, device=device)
|
||||
sigmas = torch.sqrt(torch.exp(beta_d * t ** 2 / 2 + beta_min * t) - 1)
|
||||
sigmas = torch.sqrt(torch.special.expm1(beta_d * t ** 2 / 2 + beta_min * t))
|
||||
return append_zero(sigmas)
|
||||
|
||||
|
||||
@@ -70,8 +70,14 @@ def get_ancestral_step(sigma_from, sigma_to, eta=1.):
|
||||
return sigma_down, sigma_up
|
||||
|
||||
|
||||
def default_noise_sampler(x):
|
||||
return lambda sigma, sigma_next: torch.randn_like(x)
|
||||
def default_noise_sampler(x, seed=None):
|
||||
if seed is not None:
|
||||
generator = torch.Generator(device=x.device)
|
||||
generator.manual_seed(seed)
|
||||
else:
|
||||
generator = None
|
||||
|
||||
return lambda sigma, sigma_next: torch.randn(x.size(), dtype=x.dtype, layout=x.layout, device=x.device, generator=generator)
|
||||
|
||||
|
||||
class BatchedBrownianTree:
|
||||
@@ -168,7 +174,8 @@ def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
|
||||
return sample_euler_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
|
||||
"""Ancestral sampling with Euler method steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
@@ -189,7 +196,8 @@ def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
|
||||
def sample_euler_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1., noise_sampler=None):
|
||||
"""Ancestral sampling with Euler method steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
@@ -290,7 +298,8 @@ def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
|
||||
|
||||
"""Ancestral sampling with DPM-Solver second-order steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
@@ -318,7 +327,8 @@ def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
|
||||
def sample_dpm_2_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
"""Ancestral sampling with DPM-Solver second-order steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
@@ -465,7 +475,7 @@ class DPMSolver(nn.Module):
|
||||
return x_3, eps_cache
|
||||
|
||||
def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0., s_noise=1., noise_sampler=None):
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
noise_sampler = default_noise_sampler(x, seed=self.extra_args.get("seed", None)) if noise_sampler is None else noise_sampler
|
||||
if not t_end > t_start and eta:
|
||||
raise ValueError('eta must be 0 for reverse sampling')
|
||||
|
||||
@@ -504,7 +514,7 @@ class DPMSolver(nn.Module):
|
||||
return x
|
||||
|
||||
def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None):
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
noise_sampler = default_noise_sampler(x, seed=self.extra_args.get("seed", None)) if noise_sampler is None else noise_sampler
|
||||
if order not in {2, 3}:
|
||||
raise ValueError('order should be 2 or 3')
|
||||
forward = t_end > t_start
|
||||
@@ -591,7 +601,8 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None,
|
||||
|
||||
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
sigma_fn = lambda t: t.neg().exp()
|
||||
t_fn = lambda sigma: sigma.log().neg()
|
||||
@@ -625,7 +636,8 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None,
|
||||
def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
sigma_fn = lambda lbda: (lbda.exp() + 1) ** -1
|
||||
lambda_fn = lambda sigma: ((1-sigma)/sigma).log()
|
||||
@@ -676,10 +688,10 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=N
|
||||
if len(sigmas) <= 1:
|
||||
return x
|
||||
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
sigma_fn = lambda t: t.neg().exp()
|
||||
t_fn = lambda sigma: sigma.log().neg()
|
||||
@@ -750,10 +762,10 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
||||
if solver_type not in {'heun', 'midpoint'}:
|
||||
raise ValueError('solver_type must be \'heun\' or \'midpoint\'')
|
||||
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
old_denoised = None
|
||||
@@ -796,10 +808,10 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
||||
if len(sigmas) <= 1:
|
||||
return x
|
||||
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
denoised_1, denoised_2 = None, None
|
||||
@@ -846,7 +858,7 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
||||
def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
if len(sigmas) <= 1:
|
||||
return x
|
||||
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
||||
return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)
|
||||
@@ -855,7 +867,7 @@ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
|
||||
def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
|
||||
if len(sigmas) <= 1:
|
||||
return x
|
||||
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
||||
return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
|
||||
@@ -864,7 +876,7 @@ def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
|
||||
def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
|
||||
if len(sigmas) <= 1:
|
||||
return x
|
||||
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
||||
return sample_dpmpp_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=r)
|
||||
@@ -882,7 +894,8 @@ def DDPMSampler_step(x, sigma, sigma_prev, noise, noise_sampler):
|
||||
|
||||
def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, step_function=None):
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
@@ -902,7 +915,8 @@ def sample_ddpm(model, x, sigmas, extra_args=None, callback=None, disable=None,
|
||||
@torch.no_grad()
|
||||
def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
@@ -1153,7 +1167,8 @@ def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disabl
|
||||
def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
"""Ancestral sampling with Euler method steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
|
||||
temp = [0]
|
||||
def post_cfg_function(args):
|
||||
@@ -1179,7 +1194,8 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No
|
||||
def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
|
||||
temp = [0]
|
||||
def post_cfg_function(args):
|
||||
@@ -1230,7 +1246,7 @@ def sample_dpmpp_2m_cfg_pp(model, x, sigmas, extra_args=None, callback=None, dis
|
||||
nonlocal uncond_denoised
|
||||
uncond_denoised = args["uncond_denoised"]
|
||||
return args["denoised"]
|
||||
|
||||
|
||||
model_options = extra_args.get("model_options", {}).copy()
|
||||
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
|
||||
|
||||
@@ -1249,3 +1265,284 @@ def sample_dpmpp_2m_cfg_pp(model, x, sigmas, extra_args=None, callback=None, dis
|
||||
x = denoised + denoised_mix + torch.exp(-h) * x
|
||||
old_uncond_denoised = uncond_denoised
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None, eta=1., cfg_pp=False):
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
sigma_fn = lambda t: t.neg().exp()
|
||||
t_fn = lambda sigma: sigma.log().neg()
|
||||
phi1_fn = lambda t: torch.expm1(t) / t
|
||||
phi2_fn = lambda t: (phi1_fn(t) - 1.0) / t
|
||||
|
||||
old_sigma_down = None
|
||||
old_denoised = None
|
||||
uncond_denoised = None
|
||||
def post_cfg_function(args):
|
||||
nonlocal uncond_denoised
|
||||
uncond_denoised = args["uncond_denoised"]
|
||||
return args["denoised"]
|
||||
|
||||
if cfg_pp:
|
||||
model_options = extra_args.get("model_options", {}).copy()
|
||||
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
|
||||
if callback is not None:
|
||||
callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised})
|
||||
if sigma_down == 0 or old_denoised is None:
|
||||
# Euler method
|
||||
if cfg_pp:
|
||||
d = to_d(x, sigmas[i], uncond_denoised)
|
||||
x = denoised + d * sigma_down
|
||||
else:
|
||||
d = to_d(x, sigmas[i], denoised)
|
||||
dt = sigma_down - sigmas[i]
|
||||
x = x + d * dt
|
||||
else:
|
||||
# Second order multistep method in https://arxiv.org/pdf/2308.02157
|
||||
t, t_old, t_next, t_prev = t_fn(sigmas[i]), t_fn(old_sigma_down), t_fn(sigma_down), t_fn(sigmas[i - 1])
|
||||
h = t_next - t
|
||||
c2 = (t_prev - t_old) / h
|
||||
|
||||
phi1_val, phi2_val = phi1_fn(-h), phi2_fn(-h)
|
||||
b1 = torch.nan_to_num(phi1_val - phi2_val / c2, nan=0.0)
|
||||
b2 = torch.nan_to_num(phi2_val / c2, nan=0.0)
|
||||
|
||||
if cfg_pp:
|
||||
x = x + (denoised - uncond_denoised)
|
||||
x = sigma_fn(h) * x + h * (b1 * uncond_denoised + b2 * old_denoised)
|
||||
else:
|
||||
x = sigma_fn(h) * x + h * (b1 * denoised + b2 * old_denoised)
|
||||
|
||||
# Noise addition
|
||||
if sigmas[i + 1] > 0:
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
||||
|
||||
if cfg_pp:
|
||||
old_denoised = uncond_denoised
|
||||
else:
|
||||
old_denoised = denoised
|
||||
old_sigma_down = sigma_down
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None):
|
||||
return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=0., cfg_pp=False)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_res_multistep_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None):
|
||||
return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=0., cfg_pp=True)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_res_multistep_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=eta, cfg_pp=False)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_res_multistep_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=eta, cfg_pp=True)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2., cfg_pp=False):
|
||||
"""Gradient-estimation sampler. Paper: https://openreview.net/pdf?id=o2ND9v0CeK"""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
old_d = None
|
||||
|
||||
uncond_denoised = None
|
||||
def post_cfg_function(args):
|
||||
nonlocal uncond_denoised
|
||||
uncond_denoised = args["uncond_denoised"]
|
||||
return args["denoised"]
|
||||
|
||||
if cfg_pp:
|
||||
model_options = extra_args.get("model_options", {}).copy()
|
||||
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
if cfg_pp:
|
||||
d = to_d(x, sigmas[i], uncond_denoised)
|
||||
else:
|
||||
d = to_d(x, sigmas[i], denoised)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
dt = sigmas[i + 1] - sigmas[i]
|
||||
if i == 0:
|
||||
# Euler method
|
||||
if cfg_pp:
|
||||
x = denoised + d * sigmas[i + 1]
|
||||
else:
|
||||
x = x + d * dt
|
||||
else:
|
||||
# Gradient estimation
|
||||
if cfg_pp:
|
||||
d_bar = (ge_gamma - 1) * (d - old_d)
|
||||
x = denoised + d * sigmas[i + 1] + d_bar * dt
|
||||
else:
|
||||
d_bar = ge_gamma * d + (1 - ge_gamma) * old_d
|
||||
x = x + d_bar * dt
|
||||
old_d = d
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_gradient_estimation_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2.):
|
||||
return sample_gradient_estimation(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, ge_gamma=ge_gamma, cfg_pp=True)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None, noise_scaler=None, max_stage=3):
|
||||
"""
|
||||
Extended Reverse-Time SDE solver (VE ER-SDE-Solver-3). Arxiv: https://arxiv.org/abs/2309.06169.
|
||||
Code reference: https://github.com/QinpengCui/ER-SDE-Solver/blob/main/er_sde_solver.py.
|
||||
"""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
def default_noise_scaler(sigma):
|
||||
return sigma * ((sigma ** 0.3).exp() + 10.0)
|
||||
noise_scaler = default_noise_scaler if noise_scaler is None else noise_scaler
|
||||
num_integration_points = 200.0
|
||||
point_indice = torch.arange(0, num_integration_points, dtype=torch.float32, device=x.device)
|
||||
|
||||
old_denoised = None
|
||||
old_denoised_d = None
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
stage_used = min(max_stage, i + 1)
|
||||
if sigmas[i + 1] == 0:
|
||||
x = denoised
|
||||
elif stage_used == 1:
|
||||
r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i])
|
||||
x = r * x + (1 - r) * denoised
|
||||
else:
|
||||
r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i])
|
||||
x = r * x + (1 - r) * denoised
|
||||
|
||||
dt = sigmas[i + 1] - sigmas[i]
|
||||
sigma_step_size = -dt / num_integration_points
|
||||
sigma_pos = sigmas[i + 1] + point_indice * sigma_step_size
|
||||
scaled_pos = noise_scaler(sigma_pos)
|
||||
|
||||
# Stage 2
|
||||
s = torch.sum(1 / scaled_pos) * sigma_step_size
|
||||
denoised_d = (denoised - old_denoised) / (sigmas[i] - sigmas[i - 1])
|
||||
x = x + (dt + s * noise_scaler(sigmas[i + 1])) * denoised_d
|
||||
|
||||
if stage_used >= 3:
|
||||
# Stage 3
|
||||
s_u = torch.sum((sigma_pos - sigmas[i]) / scaled_pos) * sigma_step_size
|
||||
denoised_u = (denoised_d - old_denoised_d) / ((sigmas[i] - sigmas[i - 2]) / 2)
|
||||
x = x + ((dt ** 2) / 2 + s_u * noise_scaler(sigmas[i + 1])) * denoised_u
|
||||
old_denoised_d = denoised_d
|
||||
|
||||
if s_noise != 0 and sigmas[i + 1] > 0:
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (sigmas[i + 1] ** 2 - sigmas[i] ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0)
|
||||
old_denoised = denoised
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
|
||||
'''
|
||||
SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VE Data Prediction) stage 2
|
||||
Arxiv: https://arxiv.org/abs/2305.14267
|
||||
'''
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
inject_noise = eta > 0 and s_noise > 0
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
if sigmas[i + 1] == 0:
|
||||
x = denoised
|
||||
else:
|
||||
t, t_next = -sigmas[i].log(), -sigmas[i + 1].log()
|
||||
h = t_next - t
|
||||
h_eta = h * (eta + 1)
|
||||
s = t + r * h
|
||||
fac = 1 / (2 * r)
|
||||
sigma_s = s.neg().exp()
|
||||
|
||||
coeff_1, coeff_2 = (-r * h_eta).expm1(), (-h_eta).expm1()
|
||||
if inject_noise:
|
||||
noise_coeff_1 = (-2 * r * h * eta).expm1().neg().sqrt()
|
||||
noise_coeff_2 = ((-2 * r * h * eta).expm1() - (-2 * h * eta).expm1()).sqrt()
|
||||
noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s), noise_sampler(sigma_s, sigmas[i + 1])
|
||||
|
||||
# Step 1
|
||||
x_2 = (coeff_1 + 1) * x - coeff_1 * denoised
|
||||
if inject_noise:
|
||||
x_2 = x_2 + sigma_s * (noise_coeff_1 * noise_1) * s_noise
|
||||
denoised_2 = model(x_2, sigma_s * s_in, **extra_args)
|
||||
|
||||
# Step 2
|
||||
denoised_d = (1 - fac) * denoised + fac * denoised_2
|
||||
x = (coeff_2 + 1) * x - coeff_2 * denoised_d
|
||||
if inject_noise:
|
||||
x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3):
|
||||
'''
|
||||
SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VE Data Prediction) stage 3
|
||||
Arxiv: https://arxiv.org/abs/2305.14267
|
||||
'''
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
inject_noise = eta > 0 and s_noise > 0
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
if sigmas[i + 1] == 0:
|
||||
x = denoised
|
||||
else:
|
||||
t, t_next = -sigmas[i].log(), -sigmas[i + 1].log()
|
||||
h = t_next - t
|
||||
h_eta = h * (eta + 1)
|
||||
s_1 = t + r_1 * h
|
||||
s_2 = t + r_2 * h
|
||||
sigma_s_1, sigma_s_2 = s_1.neg().exp(), s_2.neg().exp()
|
||||
|
||||
coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1()
|
||||
if inject_noise:
|
||||
noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt()
|
||||
noise_coeff_2 = ((-2 * r_1 * h * eta).expm1() - (-2 * r_2 * h * eta).expm1()).sqrt()
|
||||
noise_coeff_3 = ((-2 * r_2 * h * eta).expm1() - (-2 * h * eta).expm1()).sqrt()
|
||||
noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1])
|
||||
|
||||
# Step 1
|
||||
x_2 = (coeff_1 + 1) * x - coeff_1 * denoised
|
||||
if inject_noise:
|
||||
x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
|
||||
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
|
||||
|
||||
# Step 2
|
||||
x_3 = (coeff_2 + 1) * x - coeff_2 * denoised + (r_2 / r_1) * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised)
|
||||
if inject_noise:
|
||||
x_3 = x_3 + sigma_s_2 * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
|
||||
denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args)
|
||||
|
||||
# Step 3
|
||||
x = (coeff_3 + 1) * x - coeff_3 * denoised + (1. / r_2) * (coeff_3 / h_eta + 1) * (denoised_3 - denoised)
|
||||
if inject_noise:
|
||||
x = x + sigmas[i + 1] * (noise_coeff_3 * noise_1 + noise_coeff_2 * noise_2 + noise_coeff_1 * noise_3) * s_noise
|
||||
return x
|
||||
|
||||
@@ -3,6 +3,7 @@ import torch
|
||||
class LatentFormat:
|
||||
scale_factor = 1.0
|
||||
latent_channels = 4
|
||||
latent_dimensions = 2
|
||||
latent_rgb_factors = None
|
||||
latent_rgb_factors_bias = None
|
||||
taesd_decoder_name = None
|
||||
@@ -143,6 +144,7 @@ class SD3(LatentFormat):
|
||||
|
||||
class StableAudio1(LatentFormat):
|
||||
latent_channels = 64
|
||||
latent_dimensions = 1
|
||||
|
||||
class Flux(SD3):
|
||||
latent_channels = 16
|
||||
@@ -178,6 +180,7 @@ class Flux(SD3):
|
||||
|
||||
class Mochi(LatentFormat):
|
||||
latent_channels = 12
|
||||
latent_dimensions = 3
|
||||
|
||||
def __init__(self):
|
||||
self.scale_factor = 1.0
|
||||
@@ -219,6 +222,8 @@ class Mochi(LatentFormat):
|
||||
|
||||
class LTXV(LatentFormat):
|
||||
latent_channels = 128
|
||||
latent_dimensions = 3
|
||||
|
||||
def __init__(self):
|
||||
self.latent_rgb_factors = [
|
||||
[ 1.1202e-02, -6.3815e-04, -1.0021e-02],
|
||||
@@ -352,3 +357,116 @@ class LTXV(LatentFormat):
|
||||
]
|
||||
|
||||
self.latent_rgb_factors_bias = [-0.0571, -0.1657, -0.2512]
|
||||
|
||||
class HunyuanVideo(LatentFormat):
|
||||
latent_channels = 16
|
||||
latent_dimensions = 3
|
||||
scale_factor = 0.476986
|
||||
latent_rgb_factors = [
|
||||
[-0.0395, -0.0331, 0.0445],
|
||||
[ 0.0696, 0.0795, 0.0518],
|
||||
[ 0.0135, -0.0945, -0.0282],
|
||||
[ 0.0108, -0.0250, -0.0765],
|
||||
[-0.0209, 0.0032, 0.0224],
|
||||
[-0.0804, -0.0254, -0.0639],
|
||||
[-0.0991, 0.0271, -0.0669],
|
||||
[-0.0646, -0.0422, -0.0400],
|
||||
[-0.0696, -0.0595, -0.0894],
|
||||
[-0.0799, -0.0208, -0.0375],
|
||||
[ 0.1166, 0.1627, 0.0962],
|
||||
[ 0.1165, 0.0432, 0.0407],
|
||||
[-0.2315, -0.1920, -0.1355],
|
||||
[-0.0270, 0.0401, -0.0821],
|
||||
[-0.0616, -0.0997, -0.0727],
|
||||
[ 0.0249, -0.0469, -0.1703]
|
||||
]
|
||||
|
||||
latent_rgb_factors_bias = [ 0.0259, -0.0192, -0.0761]
|
||||
|
||||
class Cosmos1CV8x8x8(LatentFormat):
|
||||
latent_channels = 16
|
||||
latent_dimensions = 3
|
||||
|
||||
latent_rgb_factors = [
|
||||
[ 0.1817, 0.2284, 0.2423],
|
||||
[-0.0586, -0.0862, -0.3108],
|
||||
[-0.4703, -0.4255, -0.3995],
|
||||
[ 0.0803, 0.1963, 0.1001],
|
||||
[-0.0820, -0.1050, 0.0400],
|
||||
[ 0.2511, 0.3098, 0.2787],
|
||||
[-0.1830, -0.2117, -0.0040],
|
||||
[-0.0621, -0.2187, -0.0939],
|
||||
[ 0.3619, 0.1082, 0.1455],
|
||||
[ 0.3164, 0.3922, 0.2575],
|
||||
[ 0.1152, 0.0231, -0.0462],
|
||||
[-0.1434, -0.3609, -0.3665],
|
||||
[ 0.0635, 0.1471, 0.1680],
|
||||
[-0.3635, -0.1963, -0.3248],
|
||||
[-0.1865, 0.0365, 0.2346],
|
||||
[ 0.0447, 0.0994, 0.0881]
|
||||
]
|
||||
|
||||
latent_rgb_factors_bias = [-0.1223, -0.1889, -0.1976]
|
||||
|
||||
class Wan21(LatentFormat):
|
||||
latent_channels = 16
|
||||
latent_dimensions = 3
|
||||
|
||||
latent_rgb_factors = [
|
||||
[-0.1299, -0.1692, 0.2932],
|
||||
[ 0.0671, 0.0406, 0.0442],
|
||||
[ 0.3568, 0.2548, 0.1747],
|
||||
[ 0.0372, 0.2344, 0.1420],
|
||||
[ 0.0313, 0.0189, -0.0328],
|
||||
[ 0.0296, -0.0956, -0.0665],
|
||||
[-0.3477, -0.4059, -0.2925],
|
||||
[ 0.0166, 0.1902, 0.1975],
|
||||
[-0.0412, 0.0267, -0.1364],
|
||||
[-0.1293, 0.0740, 0.1636],
|
||||
[ 0.0680, 0.3019, 0.1128],
|
||||
[ 0.0032, 0.0581, 0.0639],
|
||||
[-0.1251, 0.0927, 0.1699],
|
||||
[ 0.0060, -0.0633, 0.0005],
|
||||
[ 0.3477, 0.2275, 0.2950],
|
||||
[ 0.1984, 0.0913, 0.1861]
|
||||
]
|
||||
|
||||
latent_rgb_factors_bias = [-0.1835, -0.0868, -0.3360]
|
||||
|
||||
def __init__(self):
|
||||
self.scale_factor = 1.0
|
||||
self.latents_mean = torch.tensor([
|
||||
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
|
||||
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
|
||||
]).view(1, self.latent_channels, 1, 1, 1)
|
||||
self.latents_std = torch.tensor([
|
||||
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
|
||||
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
|
||||
]).view(1, self.latent_channels, 1, 1, 1)
|
||||
|
||||
|
||||
self.taesd_decoder_name = None #TODO
|
||||
|
||||
def process_in(self, latent):
|
||||
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
|
||||
latents_std = self.latents_std.to(latent.device, latent.dtype)
|
||||
return (latent - latents_mean) * self.scale_factor / latents_std
|
||||
|
||||
def process_out(self, latent):
|
||||
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
|
||||
latents_std = self.latents_std.to(latent.device, latent.dtype)
|
||||
return latent * latents_std / self.scale_factor + latents_mean
|
||||
|
||||
class Hunyuan3Dv2(LatentFormat):
|
||||
latent_channels = 64
|
||||
latent_dimensions = 1
|
||||
scale_factor = 0.9990943042622529
|
||||
|
||||
class Hunyuan3Dv2mini(LatentFormat):
|
||||
latent_channels = 64
|
||||
latent_dimensions = 1
|
||||
scale_factor = 1.0188137142395404
|
||||
|
||||
class ACEAudio(LatentFormat):
|
||||
latent_channels = 8
|
||||
latent_dimensions = 2
|
||||
|
||||
761
comfy/ldm/ace/attention.py
Normal file
761
comfy/ldm/ace/attention.py
Normal file
@@ -0,0 +1,761 @@
|
||||
# Original from: https://github.com/ace-step/ACE-Step/blob/main/models/attention.py
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Tuple, Union, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
import comfy.model_management
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
query_dim: int,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
heads: int = 8,
|
||||
kv_heads: Optional[int] = None,
|
||||
dim_head: int = 64,
|
||||
dropout: float = 0.0,
|
||||
bias: bool = False,
|
||||
qk_norm: Optional[str] = None,
|
||||
added_kv_proj_dim: Optional[int] = None,
|
||||
added_proj_bias: Optional[bool] = True,
|
||||
out_bias: bool = True,
|
||||
scale_qk: bool = True,
|
||||
only_cross_attention: bool = False,
|
||||
eps: float = 1e-5,
|
||||
rescale_output_factor: float = 1.0,
|
||||
residual_connection: bool = False,
|
||||
processor=None,
|
||||
out_dim: int = None,
|
||||
out_context_dim: int = None,
|
||||
context_pre_only=None,
|
||||
pre_only=False,
|
||||
elementwise_affine: bool = True,
|
||||
is_causal: bool = False,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
|
||||
self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
|
||||
self.query_dim = query_dim
|
||||
self.use_bias = bias
|
||||
self.is_cross_attention = cross_attention_dim is not None
|
||||
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
||||
self.rescale_output_factor = rescale_output_factor
|
||||
self.residual_connection = residual_connection
|
||||
self.dropout = dropout
|
||||
self.fused_projections = False
|
||||
self.out_dim = out_dim if out_dim is not None else query_dim
|
||||
self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
|
||||
self.context_pre_only = context_pre_only
|
||||
self.pre_only = pre_only
|
||||
self.is_causal = is_causal
|
||||
|
||||
self.scale_qk = scale_qk
|
||||
self.scale = dim_head**-0.5 if self.scale_qk else 1.0
|
||||
|
||||
self.heads = out_dim // dim_head if out_dim is not None else heads
|
||||
# for slice_size > 0 the attention score computation
|
||||
# is split across the batch axis to save memory
|
||||
# You can set slice_size with `set_attention_slice`
|
||||
self.sliceable_head_dim = heads
|
||||
|
||||
self.added_kv_proj_dim = added_kv_proj_dim
|
||||
self.only_cross_attention = only_cross_attention
|
||||
|
||||
if self.added_kv_proj_dim is None and self.only_cross_attention:
|
||||
raise ValueError(
|
||||
"`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
|
||||
)
|
||||
|
||||
self.group_norm = None
|
||||
self.spatial_norm = None
|
||||
|
||||
self.norm_q = None
|
||||
self.norm_k = None
|
||||
|
||||
self.norm_cross = None
|
||||
self.to_q = operations.Linear(query_dim, self.inner_dim, bias=bias, dtype=dtype, device=device)
|
||||
|
||||
if not self.only_cross_attention:
|
||||
# only relevant for the `AddedKVProcessor` classes
|
||||
self.to_k = operations.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
|
||||
self.to_v = operations.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
|
||||
else:
|
||||
self.to_k = None
|
||||
self.to_v = None
|
||||
|
||||
self.added_proj_bias = added_proj_bias
|
||||
if self.added_kv_proj_dim is not None:
|
||||
self.add_k_proj = operations.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias, dtype=dtype, device=device)
|
||||
self.add_v_proj = operations.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias, dtype=dtype, device=device)
|
||||
if self.context_pre_only is not None:
|
||||
self.add_q_proj = operations.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias, dtype=dtype, device=device)
|
||||
else:
|
||||
self.add_q_proj = None
|
||||
self.add_k_proj = None
|
||||
self.add_v_proj = None
|
||||
|
||||
if not self.pre_only:
|
||||
self.to_out = nn.ModuleList([])
|
||||
self.to_out.append(operations.Linear(self.inner_dim, self.out_dim, bias=out_bias, dtype=dtype, device=device))
|
||||
self.to_out.append(nn.Dropout(dropout))
|
||||
else:
|
||||
self.to_out = None
|
||||
|
||||
if self.context_pre_only is not None and not self.context_pre_only:
|
||||
self.to_add_out = operations.Linear(self.inner_dim, self.out_context_dim, bias=out_bias, dtype=dtype, device=device)
|
||||
else:
|
||||
self.to_add_out = None
|
||||
|
||||
self.norm_added_q = None
|
||||
self.norm_added_k = None
|
||||
self.processor = processor
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**cross_attention_kwargs,
|
||||
) -> torch.Tensor:
|
||||
return self.processor(
|
||||
self,
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
|
||||
|
||||
class CustomLiteLAProcessor2_0:
|
||||
"""Attention processor used typically in processing the SD3-like self-attention projections. add rms norm for query and key and apply RoPE"""
|
||||
|
||||
def __init__(self):
|
||||
self.kernel_func = nn.ReLU(inplace=False)
|
||||
self.eps = 1e-15
|
||||
self.pad_val = 1.0
|
||||
|
||||
def apply_rotary_emb(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
|
||||
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
|
||||
reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
|
||||
tensors contain rotary embeddings and are returned as real tensors.
|
||||
|
||||
Args:
|
||||
x (`torch.Tensor`):
|
||||
Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
|
||||
freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
|
||||
"""
|
||||
cos, sin = freqs_cis # [S, D]
|
||||
cos = cos[None, None]
|
||||
sin = sin[None, None]
|
||||
cos, sin = cos.to(x.device), sin.to(x.device)
|
||||
|
||||
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
|
||||
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
||||
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
|
||||
|
||||
return out
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
|
||||
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
hidden_states_len = hidden_states.shape[1]
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
if encoder_hidden_states is not None:
|
||||
context_input_ndim = encoder_hidden_states.ndim
|
||||
if context_input_ndim == 4:
|
||||
batch_size, channel, height, width = encoder_hidden_states.shape
|
||||
encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
# `sample` projections.
|
||||
dtype = hidden_states.dtype
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
|
||||
# `context` projections.
|
||||
has_encoder_hidden_state_proj = hasattr(attn, "add_q_proj") and hasattr(attn, "add_k_proj") and hasattr(attn, "add_v_proj")
|
||||
if encoder_hidden_states is not None and has_encoder_hidden_state_proj:
|
||||
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
|
||||
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
||||
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
||||
|
||||
# attention
|
||||
if not attn.is_cross_attention:
|
||||
query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
|
||||
key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
|
||||
value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
|
||||
else:
|
||||
query = hidden_states
|
||||
key = encoder_hidden_states
|
||||
value = encoder_hidden_states
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)
|
||||
key = key.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1).transpose(-1, -2)
|
||||
value = value.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)
|
||||
|
||||
# RoPE需要 [B, H, S, D] 输入
|
||||
# 此时 query是 [B, H, D, S], 需要转成 [B, H, S, D] 才能应用RoPE
|
||||
query = query.permute(0, 1, 3, 2) # [B, H, S, D] (从 [B, H, D, S])
|
||||
|
||||
# Apply query and key normalization if needed
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
# Apply RoPE if needed
|
||||
if rotary_freqs_cis is not None:
|
||||
query = self.apply_rotary_emb(query, rotary_freqs_cis)
|
||||
if not attn.is_cross_attention:
|
||||
key = self.apply_rotary_emb(key, rotary_freqs_cis)
|
||||
elif rotary_freqs_cis_cross is not None and has_encoder_hidden_state_proj:
|
||||
key = self.apply_rotary_emb(key, rotary_freqs_cis_cross)
|
||||
|
||||
# 此时 query是 [B, H, S, D],需要还原成 [B, H, D, S]
|
||||
query = query.permute(0, 1, 3, 2) # [B, H, D, S]
|
||||
|
||||
if attention_mask is not None:
|
||||
# attention_mask: [B, S] -> [B, 1, S, 1]
|
||||
attention_mask = attention_mask[:, None, :, None].to(key.dtype) # [B, 1, S, 1]
|
||||
query = query * attention_mask.permute(0, 1, 3, 2) # [B, H, S, D] * [B, 1, S, 1]
|
||||
if not attn.is_cross_attention:
|
||||
key = key * attention_mask # key: [B, h, S, D] 与 mask [B, 1, S, 1] 相乘
|
||||
value = value * attention_mask.permute(0, 1, 3, 2) # 如果 value 是 [B, h, D, S],那么需调整mask以匹配S维度
|
||||
|
||||
if attn.is_cross_attention and encoder_attention_mask is not None and has_encoder_hidden_state_proj:
|
||||
encoder_attention_mask = encoder_attention_mask[:, None, :, None].to(key.dtype) # [B, 1, S_enc, 1]
|
||||
# 此时 key: [B, h, S_enc, D], value: [B, h, D, S_enc]
|
||||
key = key * encoder_attention_mask # [B, h, S_enc, D] * [B, 1, S_enc, 1]
|
||||
value = value * encoder_attention_mask.permute(0, 1, 3, 2) # [B, h, D, S_enc] * [B, 1, 1, S_enc]
|
||||
|
||||
query = self.kernel_func(query)
|
||||
key = self.kernel_func(key)
|
||||
|
||||
query, key, value = query.float(), key.float(), value.float()
|
||||
|
||||
value = F.pad(value, (0, 0, 0, 1), mode="constant", value=self.pad_val)
|
||||
|
||||
vk = torch.matmul(value, key)
|
||||
|
||||
hidden_states = torch.matmul(vk, query)
|
||||
|
||||
if hidden_states.dtype in [torch.float16, torch.bfloat16]:
|
||||
hidden_states = hidden_states.float()
|
||||
|
||||
hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + self.eps)
|
||||
|
||||
hidden_states = hidden_states.view(batch_size, attn.heads * head_dim, -1).permute(0, 2, 1)
|
||||
|
||||
hidden_states = hidden_states.to(dtype)
|
||||
if encoder_hidden_states is not None:
|
||||
encoder_hidden_states = encoder_hidden_states.to(dtype)
|
||||
|
||||
# Split the attention outputs.
|
||||
if encoder_hidden_states is not None and not attn.is_cross_attention and has_encoder_hidden_state_proj:
|
||||
hidden_states, encoder_hidden_states = (
|
||||
hidden_states[:, : hidden_states_len],
|
||||
hidden_states[:, hidden_states_len:],
|
||||
)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
if encoder_hidden_states is not None and not attn.context_pre_only and not attn.is_cross_attention and hasattr(attn, "to_add_out"):
|
||||
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
if encoder_hidden_states is not None and context_input_ndim == 4:
|
||||
encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
if torch.get_autocast_gpu_dtype() == torch.float16:
|
||||
hidden_states = hidden_states.clip(-65504, 65504)
|
||||
if encoder_hidden_states is not None:
|
||||
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class CustomerAttnProcessor2_0:
|
||||
r"""
|
||||
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
||||
"""
|
||||
|
||||
def apply_rotary_emb(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
|
||||
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
|
||||
reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
|
||||
tensors contain rotary embeddings and are returned as real tensors.
|
||||
|
||||
Args:
|
||||
x (`torch.Tensor`):
|
||||
Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
|
||||
freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
|
||||
"""
|
||||
cos, sin = freqs_cis # [S, D]
|
||||
cos = cos[None, None]
|
||||
sin = sin[None, None]
|
||||
cos, sin = cos.to(x.device), sin.to(x.device)
|
||||
|
||||
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
|
||||
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
||||
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
|
||||
|
||||
return out
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
|
||||
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
|
||||
residual = hidden_states
|
||||
input_ndim = hidden_states.ndim
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
|
||||
has_encoder_hidden_state_proj = hasattr(attn, "add_q_proj") and hasattr(attn, "add_k_proj") and hasattr(attn, "add_v_proj")
|
||||
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
# Apply RoPE if needed
|
||||
if rotary_freqs_cis is not None:
|
||||
query = self.apply_rotary_emb(query, rotary_freqs_cis)
|
||||
if not attn.is_cross_attention:
|
||||
key = self.apply_rotary_emb(key, rotary_freqs_cis)
|
||||
elif rotary_freqs_cis_cross is not None and has_encoder_hidden_state_proj:
|
||||
key = self.apply_rotary_emb(key, rotary_freqs_cis_cross)
|
||||
|
||||
if attn.is_cross_attention and encoder_attention_mask is not None and has_encoder_hidden_state_proj:
|
||||
# attention_mask: N x S1
|
||||
# encoder_attention_mask: N x S2
|
||||
# cross attention 整合attention_mask和encoder_attention_mask
|
||||
combined_mask = attention_mask[:, :, None] * encoder_attention_mask[:, None, :]
|
||||
attention_mask = torch.where(combined_mask == 1, 0.0, -torch.inf)
|
||||
attention_mask = attention_mask[:, None, :, :].expand(-1, attn.heads, -1, -1).to(query.dtype)
|
||||
|
||||
elif not attn.is_cross_attention and attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
# scaled_dot_product_attention expects attention_mask shape to be
|
||||
# (batch, heads, source_length, target_length)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
hidden_states = optimized_attention(
|
||||
query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True,
|
||||
).to(query.dtype)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
return hidden_states
|
||||
|
||||
def val2list(x: list or tuple or any, repeat_time=1) -> list: # type: ignore
|
||||
"""Repeat `val` for `repeat_time` times and return the list or val if list/tuple."""
|
||||
if isinstance(x, (list, tuple)):
|
||||
return list(x)
|
||||
return [x for _ in range(repeat_time)]
|
||||
|
||||
|
||||
def val2tuple(x: list or tuple or any, min_len: int = 1, idx_repeat: int = -1) -> tuple: # type: ignore
|
||||
"""Return tuple with min_len by repeating element at idx_repeat."""
|
||||
# convert to list first
|
||||
x = val2list(x)
|
||||
|
||||
# repeat elements if necessary
|
||||
if len(x) > 0:
|
||||
x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))]
|
||||
|
||||
return tuple(x)
|
||||
|
||||
|
||||
def t2i_modulate(x, shift, scale):
|
||||
return x * (1 + scale) + shift
|
||||
|
||||
|
||||
def get_same_padding(kernel_size: Union[int, Tuple[int, ...]]) -> Union[int, Tuple[int, ...]]:
|
||||
if isinstance(kernel_size, tuple):
|
||||
return tuple([get_same_padding(ks) for ks in kernel_size])
|
||||
else:
|
||||
assert kernel_size % 2 > 0, f"kernel size {kernel_size} should be odd number"
|
||||
return kernel_size // 2
|
||||
|
||||
class ConvLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_dim: int,
|
||||
out_dim: int,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
padding: Union[int, None] = None,
|
||||
use_bias=False,
|
||||
norm=None,
|
||||
act=None,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
if padding is None:
|
||||
padding = get_same_padding(kernel_size)
|
||||
padding *= dilation
|
||||
|
||||
self.in_dim = in_dim
|
||||
self.out_dim = out_dim
|
||||
self.kernel_size = kernel_size
|
||||
self.stride = stride
|
||||
self.dilation = dilation
|
||||
self.groups = groups
|
||||
self.padding = padding
|
||||
self.use_bias = use_bias
|
||||
|
||||
self.conv = operations.Conv1d(
|
||||
in_dim,
|
||||
out_dim,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=use_bias,
|
||||
device=device,
|
||||
dtype=dtype
|
||||
)
|
||||
if norm is not None:
|
||||
self.norm = operations.RMSNorm(out_dim, elementwise_affine=False, dtype=dtype, device=device)
|
||||
else:
|
||||
self.norm = None
|
||||
if act is not None:
|
||||
self.act = nn.SiLU(inplace=True)
|
||||
else:
|
||||
self.act = None
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.conv(x)
|
||||
if self.norm:
|
||||
x = self.norm(x)
|
||||
if self.act:
|
||||
x = self.act(x)
|
||||
return x
|
||||
|
||||
|
||||
class GLUMBConv(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
hidden_features: int,
|
||||
out_feature=None,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding: Union[int, None] = None,
|
||||
use_bias=False,
|
||||
norm=(None, None, None),
|
||||
act=("silu", "silu", None),
|
||||
dilation=1,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
out_feature = out_feature or in_features
|
||||
super().__init__()
|
||||
use_bias = val2tuple(use_bias, 3)
|
||||
norm = val2tuple(norm, 3)
|
||||
act = val2tuple(act, 3)
|
||||
|
||||
self.glu_act = nn.SiLU(inplace=False)
|
||||
self.inverted_conv = ConvLayer(
|
||||
in_features,
|
||||
hidden_features * 2,
|
||||
1,
|
||||
use_bias=use_bias[0],
|
||||
norm=norm[0],
|
||||
act=act[0],
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
self.depth_conv = ConvLayer(
|
||||
hidden_features * 2,
|
||||
hidden_features * 2,
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
groups=hidden_features * 2,
|
||||
padding=padding,
|
||||
use_bias=use_bias[1],
|
||||
norm=norm[1],
|
||||
act=None,
|
||||
dilation=dilation,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
self.point_conv = ConvLayer(
|
||||
hidden_features,
|
||||
out_feature,
|
||||
1,
|
||||
use_bias=use_bias[2],
|
||||
norm=norm[2],
|
||||
act=act[2],
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = x.transpose(1, 2)
|
||||
x = self.inverted_conv(x)
|
||||
x = self.depth_conv(x)
|
||||
|
||||
x, gate = torch.chunk(x, 2, dim=1)
|
||||
gate = self.glu_act(gate)
|
||||
x = x * gate
|
||||
|
||||
x = self.point_conv(x)
|
||||
x = x.transpose(1, 2)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class LinearTransformerBlock(nn.Module):
|
||||
"""
|
||||
A Sana block with global shared adaptive layer norm (adaLN-single) conditioning.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
use_adaln_single=True,
|
||||
cross_attention_dim=None,
|
||||
added_kv_proj_dim=None,
|
||||
context_pre_only=False,
|
||||
mlp_ratio=4.0,
|
||||
add_cross_attention=False,
|
||||
add_cross_attention_dim=None,
|
||||
qk_norm=None,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.norm1 = operations.RMSNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
self.attn = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
added_kv_proj_dim=added_kv_proj_dim,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
out_dim=dim,
|
||||
bias=True,
|
||||
qk_norm=qk_norm,
|
||||
processor=CustomLiteLAProcessor2_0(),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
self.add_cross_attention = add_cross_attention
|
||||
self.context_pre_only = context_pre_only
|
||||
|
||||
if add_cross_attention and add_cross_attention_dim is not None:
|
||||
self.cross_attn = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=add_cross_attention_dim,
|
||||
added_kv_proj_dim=add_cross_attention_dim,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
out_dim=dim,
|
||||
context_pre_only=context_pre_only,
|
||||
bias=True,
|
||||
qk_norm=qk_norm,
|
||||
processor=CustomerAttnProcessor2_0(),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
self.norm2 = operations.RMSNorm(dim, 1e-06, elementwise_affine=False)
|
||||
|
||||
self.ff = GLUMBConv(
|
||||
in_features=dim,
|
||||
hidden_features=int(dim * mlp_ratio),
|
||||
use_bias=(True, True, False),
|
||||
norm=(None, None, None),
|
||||
act=("silu", "silu", None),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
self.use_adaln_single = use_adaln_single
|
||||
if use_adaln_single:
|
||||
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, dtype=dtype, device=device))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor = None,
|
||||
attention_mask: torch.FloatTensor = None,
|
||||
encoder_attention_mask: torch.FloatTensor = None,
|
||||
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
|
||||
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
|
||||
temb: torch.FloatTensor = None,
|
||||
):
|
||||
|
||||
N = hidden_states.shape[0]
|
||||
|
||||
# step 1: AdaLN single
|
||||
if self.use_adaln_single:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
||||
comfy.model_management.cast_to(self.scale_shift_table[None], dtype=temb.dtype, device=temb.device) + temb.reshape(N, 6, -1)
|
||||
).chunk(6, dim=1)
|
||||
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
if self.use_adaln_single:
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
|
||||
|
||||
# step 2: attention
|
||||
if not self.add_cross_attention:
|
||||
attn_output, encoder_hidden_states = self.attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
rotary_freqs_cis=rotary_freqs_cis,
|
||||
rotary_freqs_cis_cross=rotary_freqs_cis_cross,
|
||||
)
|
||||
else:
|
||||
attn_output, _ = self.attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
rotary_freqs_cis=rotary_freqs_cis,
|
||||
rotary_freqs_cis_cross=None,
|
||||
)
|
||||
|
||||
if self.use_adaln_single:
|
||||
attn_output = gate_msa * attn_output
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
if self.add_cross_attention:
|
||||
attn_output = self.cross_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
rotary_freqs_cis=rotary_freqs_cis,
|
||||
rotary_freqs_cis_cross=rotary_freqs_cis_cross,
|
||||
)
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
# step 3: add norm
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
if self.use_adaln_single:
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
|
||||
|
||||
# step 4: feed forward
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
if self.use_adaln_single:
|
||||
ff_output = gate_mlp * ff_output
|
||||
|
||||
hidden_states = hidden_states + ff_output
|
||||
|
||||
return hidden_states
|
||||
1067
comfy/ldm/ace/lyric_encoder.py
Normal file
1067
comfy/ldm/ace/lyric_encoder.py
Normal file
File diff suppressed because it is too large
Load Diff
385
comfy/ldm/ace/model.py
Normal file
385
comfy/ldm/ace/model.py
Normal file
@@ -0,0 +1,385 @@
|
||||
# Original from: https://github.com/ace-step/ACE-Step/blob/main/models/ace_step_transformer.py
|
||||
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional, List, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
import comfy.model_management
|
||||
|
||||
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
|
||||
from .attention import LinearTransformerBlock, t2i_modulate
|
||||
from .lyric_encoder import ConformerEncoder as LyricEncoder
|
||||
|
||||
|
||||
def cross_norm(hidden_states, controlnet_input):
|
||||
# input N x T x c
|
||||
mean_hidden_states, std_hidden_states = hidden_states.mean(dim=(1,2), keepdim=True), hidden_states.std(dim=(1,2), keepdim=True)
|
||||
mean_controlnet_input, std_controlnet_input = controlnet_input.mean(dim=(1,2), keepdim=True), controlnet_input.std(dim=(1,2), keepdim=True)
|
||||
controlnet_input = (controlnet_input - mean_controlnet_input) * (std_hidden_states / (std_controlnet_input + 1e-12)) + mean_hidden_states
|
||||
return controlnet_input
|
||||
|
||||
|
||||
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2
|
||||
class Qwen2RotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, dtype=None, device=None):
|
||||
super().__init__()
|
||||
|
||||
self.dim = dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=device).float() / self.dim))
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
# Build here to make `torch.jit.trace` work.
|
||||
self._set_cos_sin_cache(
|
||||
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.float32
|
||||
)
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
self.max_seq_len_cached = seq_len
|
||||
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
|
||||
|
||||
freqs = torch.outer(t, self.inv_freq)
|
||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
||||
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
||||
|
||||
def forward(self, x, seq_len=None):
|
||||
# x: [bs, num_attention_heads, seq_len, head_size]
|
||||
if seq_len > self.max_seq_len_cached:
|
||||
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
||||
|
||||
return (
|
||||
self.cos_cached[:seq_len].to(dtype=x.dtype),
|
||||
self.sin_cached[:seq_len].to(dtype=x.dtype),
|
||||
)
|
||||
|
||||
|
||||
class T2IFinalLayer(nn.Module):
|
||||
"""
|
||||
The final layer of Sana.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, patch_size=[16, 1], out_channels=256, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm_final = operations.RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.linear = operations.Linear(hidden_size, patch_size[0] * patch_size[1] * out_channels, bias=True, dtype=dtype, device=device)
|
||||
self.scale_shift_table = nn.Parameter(torch.empty(2, hidden_size, dtype=dtype, device=device))
|
||||
self.out_channels = out_channels
|
||||
self.patch_size = patch_size
|
||||
|
||||
def unpatchfy(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
width: int,
|
||||
):
|
||||
# 4 unpatchify
|
||||
new_height, new_width = 1, hidden_states.size(1)
|
||||
hidden_states = hidden_states.reshape(
|
||||
shape=(hidden_states.shape[0], new_height, new_width, self.patch_size[0], self.patch_size[1], self.out_channels)
|
||||
).contiguous()
|
||||
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
|
||||
output = hidden_states.reshape(
|
||||
shape=(hidden_states.shape[0], self.out_channels, new_height * self.patch_size[0], new_width * self.patch_size[1])
|
||||
).contiguous()
|
||||
if width > new_width:
|
||||
output = torch.nn.functional.pad(output, (0, width - new_width, 0, 0), 'constant', 0)
|
||||
elif width < new_width:
|
||||
output = output[:, :, :, :width]
|
||||
return output
|
||||
|
||||
def forward(self, x, t, output_length):
|
||||
shift, scale = (comfy.model_management.cast_to(self.scale_shift_table[None], device=t.device, dtype=t.dtype) + t[:, None]).chunk(2, dim=1)
|
||||
x = t2i_modulate(self.norm_final(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
# unpatchify
|
||||
output = self.unpatchfy(x, output_length)
|
||||
return output
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
"""2D Image to Patch Embedding"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
height=16,
|
||||
width=4096,
|
||||
patch_size=(16, 1),
|
||||
in_channels=8,
|
||||
embed_dim=1152,
|
||||
bias=True,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
patch_size_h, patch_size_w = patch_size
|
||||
self.early_conv_layers = nn.Sequential(
|
||||
operations.Conv2d(in_channels, in_channels*256, kernel_size=patch_size, stride=patch_size, padding=0, bias=bias, dtype=dtype, device=device),
|
||||
operations.GroupNorm(num_groups=32, num_channels=in_channels*256, eps=1e-6, affine=True, dtype=dtype, device=device),
|
||||
operations.Conv2d(in_channels*256, embed_dim, kernel_size=1, stride=1, padding=0, bias=bias, dtype=dtype, device=device)
|
||||
)
|
||||
self.patch_size = patch_size
|
||||
self.height, self.width = height // patch_size_h, width // patch_size_w
|
||||
self.base_size = self.width
|
||||
|
||||
def forward(self, latent):
|
||||
# early convolutions, N x C x H x W -> N x 256 * sqrt(patch_size) x H/patch_size x W/patch_size
|
||||
latent = self.early_conv_layers(latent)
|
||||
latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
|
||||
return latent
|
||||
|
||||
|
||||
class ACEStepTransformer2DModel(nn.Module):
|
||||
# _supports_gradient_checkpointing = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: Optional[int] = 8,
|
||||
num_layers: int = 28,
|
||||
inner_dim: int = 1536,
|
||||
attention_head_dim: int = 64,
|
||||
num_attention_heads: int = 24,
|
||||
mlp_ratio: float = 4.0,
|
||||
out_channels: int = 8,
|
||||
max_position: int = 32768,
|
||||
rope_theta: float = 1000000.0,
|
||||
speaker_embedding_dim: int = 512,
|
||||
text_embedding_dim: int = 768,
|
||||
ssl_encoder_depths: List[int] = [9, 9],
|
||||
ssl_names: List[str] = ["mert", "m-hubert"],
|
||||
ssl_latent_dims: List[int] = [1024, 768],
|
||||
lyric_encoder_vocab_size: int = 6681,
|
||||
lyric_hidden_size: int = 1024,
|
||||
patch_size: List[int] = [16, 1],
|
||||
max_height: int = 16,
|
||||
max_width: int = 4096,
|
||||
audio_model=None,
|
||||
dtype=None, device=None, operations=None
|
||||
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.dtype = dtype
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
self.inner_dim = inner_dim
|
||||
self.out_channels = out_channels
|
||||
self.max_position = max_position
|
||||
self.patch_size = patch_size
|
||||
|
||||
self.rope_theta = rope_theta
|
||||
|
||||
self.rotary_emb = Qwen2RotaryEmbedding(
|
||||
dim=self.attention_head_dim,
|
||||
max_position_embeddings=self.max_position,
|
||||
base=self.rope_theta,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# 2. Define input layers
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.num_layers = num_layers
|
||||
# 3. Define transformers blocks
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
LinearTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
mlp_ratio=mlp_ratio,
|
||||
add_cross_attention=True,
|
||||
add_cross_attention_dim=self.inner_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
for i in range(self.num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=self.inner_dim, dtype=dtype, device=device, operations=operations)
|
||||
self.t_block = nn.Sequential(nn.SiLU(), operations.Linear(self.inner_dim, 6 * self.inner_dim, bias=True, dtype=dtype, device=device))
|
||||
|
||||
# speaker
|
||||
self.speaker_embedder = operations.Linear(speaker_embedding_dim, self.inner_dim, dtype=dtype, device=device)
|
||||
|
||||
# genre
|
||||
self.genre_embedder = operations.Linear(text_embedding_dim, self.inner_dim, dtype=dtype, device=device)
|
||||
|
||||
# lyric
|
||||
self.lyric_embs = operations.Embedding(lyric_encoder_vocab_size, lyric_hidden_size, dtype=dtype, device=device)
|
||||
self.lyric_encoder = LyricEncoder(input_size=lyric_hidden_size, static_chunk_size=0, dtype=dtype, device=device, operations=operations)
|
||||
self.lyric_proj = operations.Linear(lyric_hidden_size, self.inner_dim, dtype=dtype, device=device)
|
||||
|
||||
projector_dim = 2 * self.inner_dim
|
||||
|
||||
self.projectors = nn.ModuleList([
|
||||
nn.Sequential(
|
||||
operations.Linear(self.inner_dim, projector_dim, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Linear(projector_dim, projector_dim, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Linear(projector_dim, ssl_dim, dtype=dtype, device=device),
|
||||
) for ssl_dim in ssl_latent_dims
|
||||
])
|
||||
|
||||
self.proj_in = PatchEmbed(
|
||||
height=max_height,
|
||||
width=max_width,
|
||||
patch_size=patch_size,
|
||||
embed_dim=self.inner_dim,
|
||||
bias=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
self.final_layer = T2IFinalLayer(self.inner_dim, patch_size=patch_size, out_channels=out_channels, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward_lyric_encoder(
|
||||
self,
|
||||
lyric_token_idx: Optional[torch.LongTensor] = None,
|
||||
lyric_mask: Optional[torch.LongTensor] = None,
|
||||
out_dtype=None,
|
||||
):
|
||||
# N x T x D
|
||||
lyric_embs = self.lyric_embs(lyric_token_idx, out_dtype=out_dtype)
|
||||
prompt_prenet_out, _mask = self.lyric_encoder(lyric_embs, lyric_mask, decoding_chunk_size=1, num_decoding_left_chunks=-1)
|
||||
prompt_prenet_out = self.lyric_proj(prompt_prenet_out)
|
||||
return prompt_prenet_out
|
||||
|
||||
def encode(
|
||||
self,
|
||||
encoder_text_hidden_states: Optional[torch.Tensor] = None,
|
||||
text_attention_mask: Optional[torch.LongTensor] = None,
|
||||
speaker_embeds: Optional[torch.FloatTensor] = None,
|
||||
lyric_token_idx: Optional[torch.LongTensor] = None,
|
||||
lyric_mask: Optional[torch.LongTensor] = None,
|
||||
lyrics_strength=1.0,
|
||||
):
|
||||
|
||||
bs = encoder_text_hidden_states.shape[0]
|
||||
device = encoder_text_hidden_states.device
|
||||
|
||||
# speaker embedding
|
||||
encoder_spk_hidden_states = self.speaker_embedder(speaker_embeds).unsqueeze(1)
|
||||
|
||||
# genre embedding
|
||||
encoder_text_hidden_states = self.genre_embedder(encoder_text_hidden_states)
|
||||
|
||||
# lyric
|
||||
encoder_lyric_hidden_states = self.forward_lyric_encoder(
|
||||
lyric_token_idx=lyric_token_idx,
|
||||
lyric_mask=lyric_mask,
|
||||
out_dtype=encoder_text_hidden_states.dtype,
|
||||
)
|
||||
|
||||
encoder_lyric_hidden_states *= lyrics_strength
|
||||
|
||||
encoder_hidden_states = torch.cat([encoder_spk_hidden_states, encoder_text_hidden_states, encoder_lyric_hidden_states], dim=1)
|
||||
|
||||
encoder_hidden_mask = None
|
||||
if text_attention_mask is not None:
|
||||
speaker_mask = torch.ones(bs, 1, device=device)
|
||||
encoder_hidden_mask = torch.cat([speaker_mask, text_attention_mask, lyric_mask], dim=1)
|
||||
|
||||
return encoder_hidden_states, encoder_hidden_mask
|
||||
|
||||
def decode(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
encoder_hidden_mask: torch.Tensor,
|
||||
timestep: Optional[torch.Tensor],
|
||||
output_length: int = 0,
|
||||
block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
|
||||
controlnet_scale: Union[float, torch.Tensor] = 1.0,
|
||||
):
|
||||
embedded_timestep = self.timestep_embedder(self.time_proj(timestep).to(dtype=hidden_states.dtype))
|
||||
temb = self.t_block(embedded_timestep)
|
||||
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
|
||||
# controlnet logic
|
||||
if block_controlnet_hidden_states is not None:
|
||||
control_condi = cross_norm(hidden_states, block_controlnet_hidden_states)
|
||||
hidden_states = hidden_states + control_condi * controlnet_scale
|
||||
|
||||
# inner_hidden_states = []
|
||||
|
||||
rotary_freqs_cis = self.rotary_emb(hidden_states, seq_len=hidden_states.shape[1])
|
||||
encoder_rotary_freqs_cis = self.rotary_emb(encoder_hidden_states, seq_len=encoder_hidden_states.shape[1])
|
||||
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_hidden_mask,
|
||||
rotary_freqs_cis=rotary_freqs_cis,
|
||||
rotary_freqs_cis_cross=encoder_rotary_freqs_cis,
|
||||
temb=temb,
|
||||
)
|
||||
|
||||
output = self.final_layer(hidden_states, embedded_timestep, output_length)
|
||||
return output
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
timestep,
|
||||
attention_mask=None,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
text_attention_mask: Optional[torch.LongTensor] = None,
|
||||
speaker_embeds: Optional[torch.FloatTensor] = None,
|
||||
lyric_token_idx: Optional[torch.LongTensor] = None,
|
||||
lyric_mask: Optional[torch.LongTensor] = None,
|
||||
block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
|
||||
controlnet_scale: Union[float, torch.Tensor] = 1.0,
|
||||
lyrics_strength=1.0,
|
||||
**kwargs
|
||||
):
|
||||
hidden_states = x
|
||||
encoder_text_hidden_states = context
|
||||
encoder_hidden_states, encoder_hidden_mask = self.encode(
|
||||
encoder_text_hidden_states=encoder_text_hidden_states,
|
||||
text_attention_mask=text_attention_mask,
|
||||
speaker_embeds=speaker_embeds,
|
||||
lyric_token_idx=lyric_token_idx,
|
||||
lyric_mask=lyric_mask,
|
||||
lyrics_strength=lyrics_strength,
|
||||
)
|
||||
|
||||
output_length = hidden_states.shape[-1]
|
||||
|
||||
output = self.decode(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_hidden_mask=encoder_hidden_mask,
|
||||
timestep=timestep,
|
||||
output_length=output_length,
|
||||
block_controlnet_hidden_states=block_controlnet_hidden_states,
|
||||
controlnet_scale=controlnet_scale,
|
||||
)
|
||||
|
||||
return output
|
||||
644
comfy/ldm/ace/vae/autoencoder_dc.py
Normal file
644
comfy/ldm/ace/vae/autoencoder_dc.py
Normal file
@@ -0,0 +1,644 @@
|
||||
# Rewritten from diffusers
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from typing import Tuple, Union
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
|
||||
class RMSNorm(ops.RMSNorm):
|
||||
def __init__(self, dim, eps=1e-5, elementwise_affine=True, bias=False):
|
||||
super().__init__(dim, eps=eps, elementwise_affine=elementwise_affine)
|
||||
if elementwise_affine:
|
||||
self.bias = nn.Parameter(torch.empty(dim)) if bias else None
|
||||
|
||||
def forward(self, x):
|
||||
x = super().forward(x)
|
||||
if self.elementwise_affine:
|
||||
if self.bias is not None:
|
||||
x = x + comfy.model_management.cast_to(self.bias, dtype=x.dtype, device=x.device)
|
||||
return x
|
||||
|
||||
|
||||
def get_normalization(norm_type, num_features, num_groups=32, eps=1e-5):
|
||||
if norm_type == "batch_norm":
|
||||
return nn.BatchNorm2d(num_features)
|
||||
elif norm_type == "group_norm":
|
||||
return ops.GroupNorm(num_groups, num_features)
|
||||
elif norm_type == "layer_norm":
|
||||
return ops.LayerNorm(num_features)
|
||||
elif norm_type == "rms_norm":
|
||||
return RMSNorm(num_features, eps=eps, elementwise_affine=True, bias=True)
|
||||
else:
|
||||
raise ValueError(f"Unknown normalization type: {norm_type}")
|
||||
|
||||
|
||||
def get_activation(activation_type):
|
||||
if activation_type == "relu":
|
||||
return nn.ReLU()
|
||||
elif activation_type == "relu6":
|
||||
return nn.ReLU6()
|
||||
elif activation_type == "silu":
|
||||
return nn.SiLU()
|
||||
elif activation_type == "leaky_relu":
|
||||
return nn.LeakyReLU(0.2)
|
||||
else:
|
||||
raise ValueError(f"Unknown activation type: {activation_type}")
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
norm_type: str = "batch_norm",
|
||||
act_fn: str = "relu6",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.norm_type = norm_type
|
||||
self.nonlinearity = get_activation(act_fn) if act_fn is not None else nn.Identity()
|
||||
self.conv1 = ops.Conv2d(in_channels, in_channels, 3, 1, 1)
|
||||
self.conv2 = ops.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False)
|
||||
self.norm = get_normalization(norm_type, out_channels)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
hidden_states = self.conv1(hidden_states)
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
|
||||
if self.norm_type == "rms_norm":
|
||||
# move channel to the last dimension so we apply RMSnorm across channel dimension
|
||||
hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1)
|
||||
else:
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
return hidden_states + residual
|
||||
|
||||
class SanaMultiscaleAttentionProjection(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
num_attention_heads: int,
|
||||
kernel_size: int,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
channels = 3 * in_channels
|
||||
self.proj_in = ops.Conv2d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
padding=kernel_size // 2,
|
||||
groups=channels,
|
||||
bias=False,
|
||||
)
|
||||
self.proj_out = ops.Conv2d(channels, channels, 1, 1, 0, groups=3 * num_attention_heads, bias=False)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
class SanaMultiscaleLinearAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
num_attention_heads: int = None,
|
||||
attention_head_dim: int = 8,
|
||||
mult: float = 1.0,
|
||||
norm_type: str = "batch_norm",
|
||||
kernel_sizes: tuple = (5,),
|
||||
eps: float = 1e-15,
|
||||
residual_connection: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.eps = eps
|
||||
self.attention_head_dim = attention_head_dim
|
||||
self.norm_type = norm_type
|
||||
self.residual_connection = residual_connection
|
||||
|
||||
num_attention_heads = (
|
||||
int(in_channels // attention_head_dim * mult)
|
||||
if num_attention_heads is None
|
||||
else num_attention_heads
|
||||
)
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
self.to_q = ops.Linear(in_channels, inner_dim, bias=False)
|
||||
self.to_k = ops.Linear(in_channels, inner_dim, bias=False)
|
||||
self.to_v = ops.Linear(in_channels, inner_dim, bias=False)
|
||||
|
||||
self.to_qkv_multiscale = nn.ModuleList()
|
||||
for kernel_size in kernel_sizes:
|
||||
self.to_qkv_multiscale.append(
|
||||
SanaMultiscaleAttentionProjection(inner_dim, num_attention_heads, kernel_size)
|
||||
)
|
||||
|
||||
self.nonlinearity = nn.ReLU()
|
||||
self.to_out = ops.Linear(inner_dim * (1 + len(kernel_sizes)), out_channels, bias=False)
|
||||
self.norm_out = get_normalization(norm_type, out_channels)
|
||||
|
||||
def apply_linear_attention(self, query, key, value):
|
||||
value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1)
|
||||
scores = torch.matmul(value, key.transpose(-1, -2))
|
||||
hidden_states = torch.matmul(scores, query)
|
||||
|
||||
hidden_states = hidden_states.to(dtype=torch.float32)
|
||||
hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + self.eps)
|
||||
return hidden_states
|
||||
|
||||
def apply_quadratic_attention(self, query, key, value):
|
||||
scores = torch.matmul(key.transpose(-1, -2), query)
|
||||
scores = scores.to(dtype=torch.float32)
|
||||
scores = scores / (torch.sum(scores, dim=2, keepdim=True) + self.eps)
|
||||
hidden_states = torch.matmul(value, scores.to(value.dtype))
|
||||
return hidden_states
|
||||
|
||||
def forward(self, hidden_states):
|
||||
height, width = hidden_states.shape[-2:]
|
||||
if height * width > self.attention_head_dim:
|
||||
use_linear_attention = True
|
||||
else:
|
||||
use_linear_attention = False
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
batch_size, _, height, width = list(hidden_states.size())
|
||||
original_dtype = hidden_states.dtype
|
||||
|
||||
hidden_states = hidden_states.movedim(1, -1)
|
||||
query = self.to_q(hidden_states)
|
||||
key = self.to_k(hidden_states)
|
||||
value = self.to_v(hidden_states)
|
||||
hidden_states = torch.cat([query, key, value], dim=3)
|
||||
hidden_states = hidden_states.movedim(-1, 1)
|
||||
|
||||
multi_scale_qkv = [hidden_states]
|
||||
for block in self.to_qkv_multiscale:
|
||||
multi_scale_qkv.append(block(hidden_states))
|
||||
|
||||
hidden_states = torch.cat(multi_scale_qkv, dim=1)
|
||||
|
||||
if use_linear_attention:
|
||||
# for linear attention upcast hidden_states to float32
|
||||
hidden_states = hidden_states.to(dtype=torch.float32)
|
||||
|
||||
hidden_states = hidden_states.reshape(batch_size, -1, 3 * self.attention_head_dim, height * width)
|
||||
|
||||
query, key, value = hidden_states.chunk(3, dim=2)
|
||||
query = self.nonlinearity(query)
|
||||
key = self.nonlinearity(key)
|
||||
|
||||
if use_linear_attention:
|
||||
hidden_states = self.apply_linear_attention(query, key, value)
|
||||
hidden_states = hidden_states.to(dtype=original_dtype)
|
||||
else:
|
||||
hidden_states = self.apply_quadratic_attention(query, key, value)
|
||||
|
||||
hidden_states = torch.reshape(hidden_states, (batch_size, -1, height, width))
|
||||
hidden_states = self.to_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
|
||||
|
||||
if self.norm_type == "rms_norm":
|
||||
hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
|
||||
else:
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
|
||||
if self.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class EfficientViTBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
mult: float = 1.0,
|
||||
attention_head_dim: int = 32,
|
||||
qkv_multiscales: tuple = (5,),
|
||||
norm_type: str = "batch_norm",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.attn = SanaMultiscaleLinearAttention(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
mult=mult,
|
||||
attention_head_dim=attention_head_dim,
|
||||
norm_type=norm_type,
|
||||
kernel_sizes=qkv_multiscales,
|
||||
residual_connection=True,
|
||||
)
|
||||
|
||||
self.conv_out = GLUMBConv(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
norm_type="rms_norm",
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.attn(x)
|
||||
x = self.conv_out(x)
|
||||
return x
|
||||
|
||||
|
||||
class GLUMBConv(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
expand_ratio: float = 4,
|
||||
norm_type: str = None,
|
||||
residual_connection: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
hidden_channels = int(expand_ratio * in_channels)
|
||||
self.norm_type = norm_type
|
||||
self.residual_connection = residual_connection
|
||||
|
||||
self.nonlinearity = nn.SiLU()
|
||||
self.conv_inverted = ops.Conv2d(in_channels, hidden_channels * 2, 1, 1, 0)
|
||||
self.conv_depth = ops.Conv2d(hidden_channels * 2, hidden_channels * 2, 3, 1, 1, groups=hidden_channels * 2)
|
||||
self.conv_point = ops.Conv2d(hidden_channels, out_channels, 1, 1, 0, bias=False)
|
||||
|
||||
self.norm = None
|
||||
if norm_type == "rms_norm":
|
||||
self.norm = RMSNorm(out_channels, eps=1e-5, elementwise_affine=True, bias=True)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
if self.residual_connection:
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.conv_inverted(hidden_states)
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
|
||||
hidden_states = self.conv_depth(hidden_states)
|
||||
hidden_states, gate = torch.chunk(hidden_states, 2, dim=1)
|
||||
hidden_states = hidden_states * self.nonlinearity(gate)
|
||||
|
||||
hidden_states = self.conv_point(hidden_states)
|
||||
|
||||
if self.norm_type == "rms_norm":
|
||||
# move channel to the last dimension so we apply RMSnorm across channel dimension
|
||||
hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1)
|
||||
|
||||
if self.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
def get_block(
|
||||
block_type: str,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
attention_head_dim: int,
|
||||
norm_type: str,
|
||||
act_fn: str,
|
||||
qkv_mutliscales: tuple = (),
|
||||
):
|
||||
if block_type == "ResBlock":
|
||||
block = ResBlock(in_channels, out_channels, norm_type, act_fn)
|
||||
elif block_type == "EfficientViTBlock":
|
||||
block = EfficientViTBlock(
|
||||
in_channels,
|
||||
attention_head_dim=attention_head_dim,
|
||||
norm_type=norm_type,
|
||||
qkv_multiscales=qkv_mutliscales
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Block with {block_type=} is not supported.")
|
||||
|
||||
return block
|
||||
|
||||
|
||||
class DCDownBlock2d(nn.Module):
|
||||
def __init__(self, in_channels: int, out_channels: int, downsample: bool = False, shortcut: bool = True) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.downsample = downsample
|
||||
self.factor = 2
|
||||
self.stride = 1 if downsample else 2
|
||||
self.group_size = in_channels * self.factor**2 // out_channels
|
||||
self.shortcut = shortcut
|
||||
|
||||
out_ratio = self.factor**2
|
||||
if downsample:
|
||||
assert out_channels % out_ratio == 0
|
||||
out_channels = out_channels // out_ratio
|
||||
|
||||
self.conv = ops.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=self.stride,
|
||||
padding=1,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
x = self.conv(hidden_states)
|
||||
if self.downsample:
|
||||
x = F.pixel_unshuffle(x, self.factor)
|
||||
|
||||
if self.shortcut:
|
||||
y = F.pixel_unshuffle(hidden_states, self.factor)
|
||||
y = y.unflatten(1, (-1, self.group_size))
|
||||
y = y.mean(dim=2)
|
||||
hidden_states = x + y
|
||||
else:
|
||||
hidden_states = x
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DCUpBlock2d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
interpolate: bool = False,
|
||||
shortcut: bool = True,
|
||||
interpolation_mode: str = "nearest",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.interpolate = interpolate
|
||||
self.interpolation_mode = interpolation_mode
|
||||
self.shortcut = shortcut
|
||||
self.factor = 2
|
||||
self.repeats = out_channels * self.factor**2 // in_channels
|
||||
|
||||
out_ratio = self.factor**2
|
||||
if not interpolate:
|
||||
out_channels = out_channels * out_ratio
|
||||
|
||||
self.conv = ops.Conv2d(in_channels, out_channels, 3, 1, 1)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
if self.interpolate:
|
||||
x = F.interpolate(hidden_states, scale_factor=self.factor, mode=self.interpolation_mode)
|
||||
x = self.conv(x)
|
||||
else:
|
||||
x = self.conv(hidden_states)
|
||||
x = F.pixel_shuffle(x, self.factor)
|
||||
|
||||
if self.shortcut:
|
||||
y = hidden_states.repeat_interleave(self.repeats, dim=1, output_size=hidden_states.shape[1] * self.repeats)
|
||||
y = F.pixel_shuffle(y, self.factor)
|
||||
hidden_states = x + y
|
||||
else:
|
||||
hidden_states = x
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
latent_channels: int,
|
||||
attention_head_dim: int = 32,
|
||||
block_type: str or tuple = "ResBlock",
|
||||
block_out_channels: tuple = (128, 256, 512, 512, 1024, 1024),
|
||||
layers_per_block: tuple = (2, 2, 2, 2, 2, 2),
|
||||
qkv_multiscales: tuple = ((), (), (), (5,), (5,), (5,)),
|
||||
downsample_block_type: str = "pixel_unshuffle",
|
||||
out_shortcut: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
num_blocks = len(block_out_channels)
|
||||
|
||||
if isinstance(block_type, str):
|
||||
block_type = (block_type,) * num_blocks
|
||||
|
||||
if layers_per_block[0] > 0:
|
||||
self.conv_in = ops.Conv2d(
|
||||
in_channels,
|
||||
block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1],
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
)
|
||||
else:
|
||||
self.conv_in = DCDownBlock2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1],
|
||||
downsample=downsample_block_type == "pixel_unshuffle",
|
||||
shortcut=False,
|
||||
)
|
||||
|
||||
down_blocks = []
|
||||
for i, (out_channel, num_layers) in enumerate(zip(block_out_channels, layers_per_block)):
|
||||
down_block_list = []
|
||||
|
||||
for _ in range(num_layers):
|
||||
block = get_block(
|
||||
block_type[i],
|
||||
out_channel,
|
||||
out_channel,
|
||||
attention_head_dim=attention_head_dim,
|
||||
norm_type="rms_norm",
|
||||
act_fn="silu",
|
||||
qkv_mutliscales=qkv_multiscales[i],
|
||||
)
|
||||
down_block_list.append(block)
|
||||
|
||||
if i < num_blocks - 1 and num_layers > 0:
|
||||
downsample_block = DCDownBlock2d(
|
||||
in_channels=out_channel,
|
||||
out_channels=block_out_channels[i + 1],
|
||||
downsample=downsample_block_type == "pixel_unshuffle",
|
||||
shortcut=True,
|
||||
)
|
||||
down_block_list.append(downsample_block)
|
||||
|
||||
down_blocks.append(nn.Sequential(*down_block_list))
|
||||
|
||||
self.down_blocks = nn.ModuleList(down_blocks)
|
||||
|
||||
self.conv_out = ops.Conv2d(block_out_channels[-1], latent_channels, 3, 1, 1)
|
||||
|
||||
self.out_shortcut = out_shortcut
|
||||
if out_shortcut:
|
||||
self.out_shortcut_average_group_size = block_out_channels[-1] // latent_channels
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
for down_block in self.down_blocks:
|
||||
hidden_states = down_block(hidden_states)
|
||||
|
||||
if self.out_shortcut:
|
||||
x = hidden_states.unflatten(1, (-1, self.out_shortcut_average_group_size))
|
||||
x = x.mean(dim=2)
|
||||
hidden_states = self.conv_out(hidden_states) + x
|
||||
else:
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
latent_channels: int,
|
||||
attention_head_dim: int = 32,
|
||||
block_type: str or tuple = "ResBlock",
|
||||
block_out_channels: tuple = (128, 256, 512, 512, 1024, 1024),
|
||||
layers_per_block: tuple = (2, 2, 2, 2, 2, 2),
|
||||
qkv_multiscales: tuple = ((), (), (), (5,), (5,), (5,)),
|
||||
norm_type: str or tuple = "rms_norm",
|
||||
act_fn: str or tuple = "silu",
|
||||
upsample_block_type: str = "pixel_shuffle",
|
||||
in_shortcut: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
num_blocks = len(block_out_channels)
|
||||
|
||||
if isinstance(block_type, str):
|
||||
block_type = (block_type,) * num_blocks
|
||||
if isinstance(norm_type, str):
|
||||
norm_type = (norm_type,) * num_blocks
|
||||
if isinstance(act_fn, str):
|
||||
act_fn = (act_fn,) * num_blocks
|
||||
|
||||
self.conv_in = ops.Conv2d(latent_channels, block_out_channels[-1], 3, 1, 1)
|
||||
|
||||
self.in_shortcut = in_shortcut
|
||||
if in_shortcut:
|
||||
self.in_shortcut_repeats = block_out_channels[-1] // latent_channels
|
||||
|
||||
up_blocks = []
|
||||
for i, (out_channel, num_layers) in reversed(list(enumerate(zip(block_out_channels, layers_per_block)))):
|
||||
up_block_list = []
|
||||
|
||||
if i < num_blocks - 1 and num_layers > 0:
|
||||
upsample_block = DCUpBlock2d(
|
||||
block_out_channels[i + 1],
|
||||
out_channel,
|
||||
interpolate=upsample_block_type == "interpolate",
|
||||
shortcut=True,
|
||||
)
|
||||
up_block_list.append(upsample_block)
|
||||
|
||||
for _ in range(num_layers):
|
||||
block = get_block(
|
||||
block_type[i],
|
||||
out_channel,
|
||||
out_channel,
|
||||
attention_head_dim=attention_head_dim,
|
||||
norm_type=norm_type[i],
|
||||
act_fn=act_fn[i],
|
||||
qkv_mutliscales=qkv_multiscales[i],
|
||||
)
|
||||
up_block_list.append(block)
|
||||
|
||||
up_blocks.insert(0, nn.Sequential(*up_block_list))
|
||||
|
||||
self.up_blocks = nn.ModuleList(up_blocks)
|
||||
|
||||
channels = block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1]
|
||||
|
||||
self.norm_out = RMSNorm(channels, 1e-5, elementwise_affine=True, bias=True)
|
||||
self.conv_act = nn.ReLU()
|
||||
self.conv_out = None
|
||||
|
||||
if layers_per_block[0] > 0:
|
||||
self.conv_out = ops.Conv2d(channels, in_channels, 3, 1, 1)
|
||||
else:
|
||||
self.conv_out = DCUpBlock2d(
|
||||
channels, in_channels, interpolate=upsample_block_type == "interpolate", shortcut=False
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
if self.in_shortcut:
|
||||
x = hidden_states.repeat_interleave(
|
||||
self.in_shortcut_repeats, dim=1, output_size=hidden_states.shape[1] * self.in_shortcut_repeats
|
||||
)
|
||||
hidden_states = self.conv_in(hidden_states) + x
|
||||
else:
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
|
||||
for up_block in reversed(self.up_blocks):
|
||||
hidden_states = up_block(hidden_states)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AutoencoderDC(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 2,
|
||||
latent_channels: int = 8,
|
||||
attention_head_dim: int = 32,
|
||||
encoder_block_types: Union[str, Tuple[str]] = ["ResBlock", "ResBlock", "ResBlock", "EfficientViTBlock"],
|
||||
decoder_block_types: Union[str, Tuple[str]] = ["ResBlock", "ResBlock", "ResBlock", "EfficientViTBlock"],
|
||||
encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 1024),
|
||||
decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 1024),
|
||||
encoder_layers_per_block: Tuple[int] = (2, 2, 3, 3),
|
||||
decoder_layers_per_block: Tuple[int] = (3, 3, 3, 3),
|
||||
encoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (5,), (5,)),
|
||||
decoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (5,), (5,)),
|
||||
upsample_block_type: str = "interpolate",
|
||||
downsample_block_type: str = "Conv",
|
||||
decoder_norm_types: Union[str, Tuple[str]] = "rms_norm",
|
||||
decoder_act_fns: Union[str, Tuple[str]] = "silu",
|
||||
scaling_factor: float = 0.41407,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.encoder = Encoder(
|
||||
in_channels=in_channels,
|
||||
latent_channels=latent_channels,
|
||||
attention_head_dim=attention_head_dim,
|
||||
block_type=encoder_block_types,
|
||||
block_out_channels=encoder_block_out_channels,
|
||||
layers_per_block=encoder_layers_per_block,
|
||||
qkv_multiscales=encoder_qkv_multiscales,
|
||||
downsample_block_type=downsample_block_type,
|
||||
)
|
||||
|
||||
self.decoder = Decoder(
|
||||
in_channels=in_channels,
|
||||
latent_channels=latent_channels,
|
||||
attention_head_dim=attention_head_dim,
|
||||
block_type=decoder_block_types,
|
||||
block_out_channels=decoder_block_out_channels,
|
||||
layers_per_block=decoder_layers_per_block,
|
||||
qkv_multiscales=decoder_qkv_multiscales,
|
||||
norm_type=decoder_norm_types,
|
||||
act_fn=decoder_act_fns,
|
||||
upsample_block_type=upsample_block_type,
|
||||
)
|
||||
|
||||
self.scaling_factor = scaling_factor
|
||||
self.spatial_compression_ratio = 2 ** (len(encoder_block_out_channels) - 1)
|
||||
|
||||
def encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Internal encoding function."""
|
||||
encoded = self.encoder(x)
|
||||
return encoded * self.scaling_factor
|
||||
|
||||
def decode(self, z: torch.Tensor) -> torch.Tensor:
|
||||
# Scale the latents back
|
||||
z = z / self.scaling_factor
|
||||
decoded = self.decoder(z)
|
||||
return decoded
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
z = self.encode(x)
|
||||
return self.decode(z)
|
||||
|
||||
109
comfy/ldm/ace/vae/music_dcae_pipeline.py
Normal file
109
comfy/ldm/ace/vae/music_dcae_pipeline.py
Normal file
@@ -0,0 +1,109 @@
|
||||
# Original from: https://github.com/ace-step/ACE-Step/blob/main/music_dcae/music_dcae_pipeline.py
|
||||
import torch
|
||||
from .autoencoder_dc import AutoencoderDC
|
||||
import logging
|
||||
try:
|
||||
import torchaudio
|
||||
except:
|
||||
logging.warning("torchaudio missing, ACE model will be broken")
|
||||
|
||||
import torchvision.transforms as transforms
|
||||
from .music_vocoder import ADaMoSHiFiGANV1
|
||||
|
||||
|
||||
class MusicDCAE(torch.nn.Module):
|
||||
def __init__(self, source_sample_rate=None, dcae_config={}, vocoder_config={}):
|
||||
super(MusicDCAE, self).__init__()
|
||||
|
||||
self.dcae = AutoencoderDC(**dcae_config)
|
||||
self.vocoder = ADaMoSHiFiGANV1(**vocoder_config)
|
||||
|
||||
if source_sample_rate is None:
|
||||
self.source_sample_rate = 48000
|
||||
else:
|
||||
self.source_sample_rate = source_sample_rate
|
||||
|
||||
# self.resampler = torchaudio.transforms.Resample(source_sample_rate, 44100)
|
||||
|
||||
self.transform = transforms.Compose([
|
||||
transforms.Normalize(0.5, 0.5),
|
||||
])
|
||||
self.min_mel_value = -11.0
|
||||
self.max_mel_value = 3.0
|
||||
self.audio_chunk_size = int(round((1024 * 512 / 44100 * 48000)))
|
||||
self.mel_chunk_size = 1024
|
||||
self.time_dimention_multiple = 8
|
||||
self.latent_chunk_size = self.mel_chunk_size // self.time_dimention_multiple
|
||||
self.scale_factor = 0.1786
|
||||
self.shift_factor = -1.9091
|
||||
|
||||
def load_audio(self, audio_path):
|
||||
audio, sr = torchaudio.load(audio_path)
|
||||
return audio, sr
|
||||
|
||||
def forward_mel(self, audios):
|
||||
mels = []
|
||||
for i in range(len(audios)):
|
||||
image = self.vocoder.mel_transform(audios[i])
|
||||
mels.append(image)
|
||||
mels = torch.stack(mels)
|
||||
return mels
|
||||
|
||||
@torch.no_grad()
|
||||
def encode(self, audios, audio_lengths=None, sr=None):
|
||||
if audio_lengths is None:
|
||||
audio_lengths = torch.tensor([audios.shape[2]] * audios.shape[0])
|
||||
audio_lengths = audio_lengths.to(audios.device)
|
||||
|
||||
if sr is None:
|
||||
sr = self.source_sample_rate
|
||||
|
||||
if sr != 44100:
|
||||
audios = torchaudio.functional.resample(audios, sr, 44100)
|
||||
|
||||
max_audio_len = audios.shape[-1]
|
||||
if max_audio_len % (8 * 512) != 0:
|
||||
audios = torch.nn.functional.pad(audios, (0, 8 * 512 - max_audio_len % (8 * 512)))
|
||||
|
||||
mels = self.forward_mel(audios)
|
||||
mels = (mels - self.min_mel_value) / (self.max_mel_value - self.min_mel_value)
|
||||
mels = self.transform(mels)
|
||||
latents = []
|
||||
for mel in mels:
|
||||
latent = self.dcae.encoder(mel.unsqueeze(0))
|
||||
latents.append(latent)
|
||||
latents = torch.cat(latents, dim=0)
|
||||
# latent_lengths = (audio_lengths / sr * 44100 / 512 / self.time_dimention_multiple).long()
|
||||
latents = (latents - self.shift_factor) * self.scale_factor
|
||||
return latents
|
||||
# return latents, latent_lengths
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(self, latents, audio_lengths=None, sr=None):
|
||||
latents = latents / self.scale_factor + self.shift_factor
|
||||
|
||||
pred_wavs = []
|
||||
|
||||
for latent in latents:
|
||||
mels = self.dcae.decoder(latent.unsqueeze(0))
|
||||
mels = mels * 0.5 + 0.5
|
||||
mels = mels * (self.max_mel_value - self.min_mel_value) + self.min_mel_value
|
||||
wav = self.vocoder.decode(mels[0]).squeeze(1)
|
||||
|
||||
if sr is not None:
|
||||
# resampler = torchaudio.transforms.Resample(44100, sr).to(latents.device).to(latents.dtype)
|
||||
wav = torchaudio.functional.resample(wav, 44100, sr)
|
||||
# wav = resampler(wav)
|
||||
else:
|
||||
sr = 44100
|
||||
pred_wavs.append(wav)
|
||||
|
||||
if audio_lengths is not None:
|
||||
pred_wavs = [wav[:, :length].cpu() for wav, length in zip(pred_wavs, audio_lengths)]
|
||||
return torch.stack(pred_wavs)
|
||||
# return sr, pred_wavs
|
||||
|
||||
def forward(self, audios, audio_lengths=None, sr=None):
|
||||
latents, latent_lengths = self.encode(audios=audios, audio_lengths=audio_lengths, sr=sr)
|
||||
sr, pred_wavs = self.decode(latents=latents, audio_lengths=audio_lengths, sr=sr)
|
||||
return sr, pred_wavs, latents, latent_lengths
|
||||
113
comfy/ldm/ace/vae/music_log_mel.py
Executable file
113
comfy/ldm/ace/vae/music_log_mel.py
Executable file
@@ -0,0 +1,113 @@
|
||||
# Original from: https://github.com/ace-step/ACE-Step/blob/main/music_dcae/music_log_mel.py
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
import logging
|
||||
try:
|
||||
from torchaudio.transforms import MelScale
|
||||
except:
|
||||
logging.warning("torchaudio missing, ACE model will be broken")
|
||||
|
||||
import comfy.model_management
|
||||
|
||||
class LinearSpectrogram(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
n_fft=2048,
|
||||
win_length=2048,
|
||||
hop_length=512,
|
||||
center=False,
|
||||
mode="pow2_sqrt",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.n_fft = n_fft
|
||||
self.win_length = win_length
|
||||
self.hop_length = hop_length
|
||||
self.center = center
|
||||
self.mode = mode
|
||||
|
||||
self.register_buffer("window", torch.hann_window(win_length))
|
||||
|
||||
def forward(self, y: Tensor) -> Tensor:
|
||||
if y.ndim == 3:
|
||||
y = y.squeeze(1)
|
||||
|
||||
y = torch.nn.functional.pad(
|
||||
y.unsqueeze(1),
|
||||
(
|
||||
(self.win_length - self.hop_length) // 2,
|
||||
(self.win_length - self.hop_length + 1) // 2,
|
||||
),
|
||||
mode="reflect",
|
||||
).squeeze(1)
|
||||
dtype = y.dtype
|
||||
spec = torch.stft(
|
||||
y.float(),
|
||||
self.n_fft,
|
||||
hop_length=self.hop_length,
|
||||
win_length=self.win_length,
|
||||
window=comfy.model_management.cast_to(self.window, dtype=torch.float32, device=y.device),
|
||||
center=self.center,
|
||||
pad_mode="reflect",
|
||||
normalized=False,
|
||||
onesided=True,
|
||||
return_complex=True,
|
||||
)
|
||||
spec = torch.view_as_real(spec)
|
||||
|
||||
if self.mode == "pow2_sqrt":
|
||||
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
|
||||
spec = spec.to(dtype)
|
||||
return spec
|
||||
|
||||
|
||||
class LogMelSpectrogram(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
sample_rate=44100,
|
||||
n_fft=2048,
|
||||
win_length=2048,
|
||||
hop_length=512,
|
||||
n_mels=128,
|
||||
center=False,
|
||||
f_min=0.0,
|
||||
f_max=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.sample_rate = sample_rate
|
||||
self.n_fft = n_fft
|
||||
self.win_length = win_length
|
||||
self.hop_length = hop_length
|
||||
self.center = center
|
||||
self.n_mels = n_mels
|
||||
self.f_min = f_min
|
||||
self.f_max = f_max or sample_rate // 2
|
||||
|
||||
self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center)
|
||||
self.mel_scale = MelScale(
|
||||
self.n_mels,
|
||||
self.sample_rate,
|
||||
self.f_min,
|
||||
self.f_max,
|
||||
self.n_fft // 2 + 1,
|
||||
"slaney",
|
||||
"slaney",
|
||||
)
|
||||
|
||||
def compress(self, x: Tensor) -> Tensor:
|
||||
return torch.log(torch.clamp(x, min=1e-5))
|
||||
|
||||
def decompress(self, x: Tensor) -> Tensor:
|
||||
return torch.exp(x)
|
||||
|
||||
def forward(self, x: Tensor, return_linear: bool = False) -> Tensor:
|
||||
linear = self.spectrogram(x)
|
||||
x = self.mel_scale(linear)
|
||||
x = self.compress(x)
|
||||
# print(x.shape)
|
||||
if return_linear:
|
||||
return x, self.compress(linear)
|
||||
|
||||
return x
|
||||
538
comfy/ldm/ace/vae/music_vocoder.py
Executable file
538
comfy/ldm/ace/vae/music_vocoder.py
Executable file
@@ -0,0 +1,538 @@
|
||||
# Original from: https://github.com/ace-step/ACE-Step/blob/main/music_dcae/music_vocoder.py
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from functools import partial
|
||||
from math import prod
|
||||
from typing import Callable, Tuple, List
|
||||
|
||||
import numpy as np
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils.parametrize import remove_parametrizations as remove_weight_norm
|
||||
|
||||
from .music_log_mel import LogMelSpectrogram
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
|
||||
def drop_path(
|
||||
x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
|
||||
):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||
|
||||
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
|
||||
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
||||
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
|
||||
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
|
||||
'survival rate' as the argument.
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
if drop_prob == 0.0 or not training:
|
||||
return x
|
||||
keep_prob = 1 - drop_prob
|
||||
shape = (x.shape[0],) + (1,) * (
|
||||
x.ndim - 1
|
||||
) # work with diff dim tensors, not just 2D ConvNets
|
||||
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
||||
if keep_prob > 0.0 and scale_by_keep:
|
||||
random_tensor.div_(keep_prob)
|
||||
return x * random_tensor
|
||||
|
||||
|
||||
class DropPath(nn.Module):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" # noqa: E501
|
||||
|
||||
def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
|
||||
super(DropPath, self).__init__()
|
||||
self.drop_prob = drop_prob
|
||||
self.scale_by_keep = scale_by_keep
|
||||
|
||||
def forward(self, x):
|
||||
return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
|
||||
|
||||
def extra_repr(self):
|
||||
return f"drop_prob={round(self.drop_prob,3):0.3f}"
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
|
||||
The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
|
||||
shape (batch_size, height, width, channels) while channels_first corresponds to inputs
|
||||
with shape (batch_size, channels, height, width).
|
||||
""" # noqa: E501
|
||||
|
||||
def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(normalized_shape))
|
||||
self.bias = nn.Parameter(torch.zeros(normalized_shape))
|
||||
self.eps = eps
|
||||
self.data_format = data_format
|
||||
if self.data_format not in ["channels_last", "channels_first"]:
|
||||
raise NotImplementedError
|
||||
self.normalized_shape = (normalized_shape,)
|
||||
|
||||
def forward(self, x):
|
||||
if self.data_format == "channels_last":
|
||||
return F.layer_norm(
|
||||
x, self.normalized_shape, comfy.model_management.cast_to(self.weight, dtype=x.dtype, device=x.device), comfy.model_management.cast_to(self.bias, dtype=x.dtype, device=x.device), self.eps
|
||||
)
|
||||
elif self.data_format == "channels_first":
|
||||
u = x.mean(1, keepdim=True)
|
||||
s = (x - u).pow(2).mean(1, keepdim=True)
|
||||
x = (x - u) / torch.sqrt(s + self.eps)
|
||||
x = comfy.model_management.cast_to(self.weight[:, None], dtype=x.dtype, device=x.device) * x + comfy.model_management.cast_to(self.bias[:, None], dtype=x.dtype, device=x.device)
|
||||
return x
|
||||
|
||||
|
||||
class ConvNeXtBlock(nn.Module):
|
||||
r"""ConvNeXt Block. There are two equivalent implementations:
|
||||
(1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
|
||||
(2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
|
||||
We use (2) as we find it slightly faster in PyTorch
|
||||
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
drop_path (float): Stochastic depth rate. Default: 0.0
|
||||
layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
|
||||
kernel_size (int): Kernel size for depthwise conv. Default: 7.
|
||||
dilation (int): Dilation for depthwise conv. Default: 1.
|
||||
""" # noqa: E501
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
drop_path: float = 0.0,
|
||||
layer_scale_init_value: float = 1e-6,
|
||||
mlp_ratio: float = 4.0,
|
||||
kernel_size: int = 7,
|
||||
dilation: int = 1,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.dwconv = ops.Conv1d(
|
||||
dim,
|
||||
dim,
|
||||
kernel_size=kernel_size,
|
||||
padding=int(dilation * (kernel_size - 1) / 2),
|
||||
groups=dim,
|
||||
) # depthwise conv
|
||||
self.norm = LayerNorm(dim, eps=1e-6)
|
||||
self.pwconv1 = ops.Linear(
|
||||
dim, int(mlp_ratio * dim)
|
||||
) # pointwise/1x1 convs, implemented with linear layers
|
||||
self.act = nn.GELU()
|
||||
self.pwconv2 = ops.Linear(int(mlp_ratio * dim), dim)
|
||||
self.gamma = (
|
||||
nn.Parameter(torch.empty((dim)), requires_grad=False)
|
||||
if layer_scale_init_value > 0
|
||||
else None
|
||||
)
|
||||
self.drop_path = DropPath(
|
||||
drop_path) if drop_path > 0.0 else nn.Identity()
|
||||
|
||||
def forward(self, x, apply_residual: bool = True):
|
||||
input = x
|
||||
|
||||
x = self.dwconv(x)
|
||||
x = x.permute(0, 2, 1) # (N, C, L) -> (N, L, C)
|
||||
x = self.norm(x)
|
||||
x = self.pwconv1(x)
|
||||
x = self.act(x)
|
||||
x = self.pwconv2(x)
|
||||
|
||||
if self.gamma is not None:
|
||||
x = comfy.model_management.cast_to(self.gamma, dtype=x.dtype, device=x.device) * x
|
||||
|
||||
x = x.permute(0, 2, 1) # (N, L, C) -> (N, C, L)
|
||||
x = self.drop_path(x)
|
||||
|
||||
if apply_residual:
|
||||
x = input + x
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class ParallelConvNeXtBlock(nn.Module):
|
||||
def __init__(self, kernel_sizes: List[int], *args, **kwargs):
|
||||
super().__init__()
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
ConvNeXtBlock(kernel_size=kernel_size, *args, **kwargs)
|
||||
for kernel_size in kernel_sizes
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return torch.stack(
|
||||
[block(x, apply_residual=False) for block in self.blocks] + [x],
|
||||
dim=1,
|
||||
).sum(dim=1)
|
||||
|
||||
|
||||
class ConvNeXtEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_channels=3,
|
||||
depths=[3, 3, 9, 3],
|
||||
dims=[96, 192, 384, 768],
|
||||
drop_path_rate=0.0,
|
||||
layer_scale_init_value=1e-6,
|
||||
kernel_sizes: Tuple[int] = (7,),
|
||||
):
|
||||
super().__init__()
|
||||
assert len(depths) == len(dims)
|
||||
|
||||
self.channel_layers = nn.ModuleList()
|
||||
stem = nn.Sequential(
|
||||
ops.Conv1d(
|
||||
input_channels,
|
||||
dims[0],
|
||||
kernel_size=7,
|
||||
padding=3,
|
||||
padding_mode="replicate",
|
||||
),
|
||||
LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
|
||||
)
|
||||
self.channel_layers.append(stem)
|
||||
|
||||
for i in range(len(depths) - 1):
|
||||
mid_layer = nn.Sequential(
|
||||
LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
|
||||
ops.Conv1d(dims[i], dims[i + 1], kernel_size=1),
|
||||
)
|
||||
self.channel_layers.append(mid_layer)
|
||||
|
||||
block_fn = (
|
||||
partial(ConvNeXtBlock, kernel_size=kernel_sizes[0])
|
||||
if len(kernel_sizes) == 1
|
||||
else partial(ParallelConvNeXtBlock, kernel_sizes=kernel_sizes)
|
||||
)
|
||||
|
||||
self.stages = nn.ModuleList()
|
||||
drop_path_rates = [
|
||||
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
|
||||
]
|
||||
|
||||
cur = 0
|
||||
for i in range(len(depths)):
|
||||
stage = nn.Sequential(
|
||||
*[
|
||||
block_fn(
|
||||
dim=dims[i],
|
||||
drop_path=drop_path_rates[cur + j],
|
||||
layer_scale_init_value=layer_scale_init_value,
|
||||
)
|
||||
for j in range(depths[i])
|
||||
]
|
||||
)
|
||||
self.stages.append(stage)
|
||||
cur += depths[i]
|
||||
|
||||
self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
for channel_layer, stage in zip(self.channel_layers, self.stages):
|
||||
x = channel_layer(x)
|
||||
x = stage(x)
|
||||
|
||||
return self.norm(x)
|
||||
|
||||
|
||||
def get_padding(kernel_size, dilation=1):
|
||||
return (kernel_size * dilation - dilation) // 2
|
||||
|
||||
|
||||
class ResBlock1(torch.nn.Module):
|
||||
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
|
||||
super().__init__()
|
||||
|
||||
self.convs1 = nn.ModuleList(
|
||||
[
|
||||
torch.nn.utils.parametrizations.weight_norm(
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0]),
|
||||
)
|
||||
),
|
||||
torch.nn.utils.parametrizations.weight_norm(
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1]),
|
||||
)
|
||||
),
|
||||
torch.nn.utils.parametrizations.weight_norm(
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[2],
|
||||
padding=get_padding(kernel_size, dilation[2]),
|
||||
)
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
self.convs2 = nn.ModuleList(
|
||||
[
|
||||
torch.nn.utils.parametrizations.weight_norm(
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1),
|
||||
)
|
||||
),
|
||||
torch.nn.utils.parametrizations.weight_norm(
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1),
|
||||
)
|
||||
),
|
||||
torch.nn.utils.parametrizations.weight_norm(
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1),
|
||||
)
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
for c1, c2 in zip(self.convs1, self.convs2):
|
||||
xt = F.silu(x)
|
||||
xt = c1(xt)
|
||||
xt = F.silu(xt)
|
||||
xt = c2(xt)
|
||||
x = xt + x
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for conv in self.convs1:
|
||||
remove_weight_norm(conv)
|
||||
for conv in self.convs2:
|
||||
remove_weight_norm(conv)
|
||||
|
||||
|
||||
class HiFiGANGenerator(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
hop_length: int = 512,
|
||||
upsample_rates: Tuple[int] = (8, 8, 2, 2, 2),
|
||||
upsample_kernel_sizes: Tuple[int] = (16, 16, 8, 2, 2),
|
||||
resblock_kernel_sizes: Tuple[int] = (3, 7, 11),
|
||||
resblock_dilation_sizes: Tuple[Tuple[int]] = (
|
||||
(1, 3, 5), (1, 3, 5), (1, 3, 5)),
|
||||
num_mels: int = 128,
|
||||
upsample_initial_channel: int = 512,
|
||||
use_template: bool = True,
|
||||
pre_conv_kernel_size: int = 7,
|
||||
post_conv_kernel_size: int = 7,
|
||||
post_activation: Callable = partial(nn.SiLU, inplace=True),
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
assert (
|
||||
prod(upsample_rates) == hop_length
|
||||
), f"hop_length must be {prod(upsample_rates)}"
|
||||
|
||||
self.conv_pre = torch.nn.utils.parametrizations.weight_norm(
|
||||
ops.Conv1d(
|
||||
num_mels,
|
||||
upsample_initial_channel,
|
||||
pre_conv_kernel_size,
|
||||
1,
|
||||
padding=get_padding(pre_conv_kernel_size),
|
||||
)
|
||||
)
|
||||
|
||||
self.num_upsamples = len(upsample_rates)
|
||||
self.num_kernels = len(resblock_kernel_sizes)
|
||||
|
||||
self.noise_convs = nn.ModuleList()
|
||||
self.use_template = use_template
|
||||
self.ups = nn.ModuleList()
|
||||
|
||||
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
||||
c_cur = upsample_initial_channel // (2 ** (i + 1))
|
||||
self.ups.append(
|
||||
torch.nn.utils.parametrizations.weight_norm(
|
||||
ops.ConvTranspose1d(
|
||||
upsample_initial_channel // (2**i),
|
||||
upsample_initial_channel // (2 ** (i + 1)),
|
||||
k,
|
||||
u,
|
||||
padding=(k - u) // 2,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if not use_template:
|
||||
continue
|
||||
|
||||
if i + 1 < len(upsample_rates):
|
||||
stride_f0 = np.prod(upsample_rates[i + 1:])
|
||||
self.noise_convs.append(
|
||||
ops.Conv1d(
|
||||
1,
|
||||
c_cur,
|
||||
kernel_size=stride_f0 * 2,
|
||||
stride=stride_f0,
|
||||
padding=stride_f0 // 2,
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.noise_convs.append(ops.Conv1d(1, c_cur, kernel_size=1))
|
||||
|
||||
self.resblocks = nn.ModuleList()
|
||||
for i in range(len(self.ups)):
|
||||
ch = upsample_initial_channel // (2 ** (i + 1))
|
||||
for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
|
||||
self.resblocks.append(ResBlock1(ch, k, d))
|
||||
|
||||
self.activation_post = post_activation()
|
||||
self.conv_post = torch.nn.utils.parametrizations.weight_norm(
|
||||
ops.Conv1d(
|
||||
ch,
|
||||
1,
|
||||
post_conv_kernel_size,
|
||||
1,
|
||||
padding=get_padding(post_conv_kernel_size),
|
||||
)
|
||||
)
|
||||
|
||||
def forward(self, x, template=None):
|
||||
x = self.conv_pre(x)
|
||||
|
||||
for i in range(self.num_upsamples):
|
||||
x = F.silu(x, inplace=True)
|
||||
x = self.ups[i](x)
|
||||
|
||||
if self.use_template:
|
||||
x = x + self.noise_convs[i](template)
|
||||
|
||||
xs = None
|
||||
|
||||
for j in range(self.num_kernels):
|
||||
if xs is None:
|
||||
xs = self.resblocks[i * self.num_kernels + j](x)
|
||||
else:
|
||||
xs += self.resblocks[i * self.num_kernels + j](x)
|
||||
|
||||
x = xs / self.num_kernels
|
||||
|
||||
x = self.activation_post(x)
|
||||
x = self.conv_post(x)
|
||||
x = torch.tanh(x)
|
||||
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for up in self.ups:
|
||||
remove_weight_norm(up)
|
||||
for block in self.resblocks:
|
||||
block.remove_weight_norm()
|
||||
remove_weight_norm(self.conv_pre)
|
||||
remove_weight_norm(self.conv_post)
|
||||
|
||||
|
||||
class ADaMoSHiFiGANV1(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_channels: int = 128,
|
||||
depths: List[int] = [3, 3, 9, 3],
|
||||
dims: List[int] = [128, 256, 384, 512],
|
||||
drop_path_rate: float = 0.0,
|
||||
kernel_sizes: Tuple[int] = (7,),
|
||||
upsample_rates: Tuple[int] = (4, 4, 2, 2, 2, 2, 2),
|
||||
upsample_kernel_sizes: Tuple[int] = (8, 8, 4, 4, 4, 4, 4),
|
||||
resblock_kernel_sizes: Tuple[int] = (3, 7, 11, 13),
|
||||
resblock_dilation_sizes: Tuple[Tuple[int]] = (
|
||||
(1, 3, 5), (1, 3, 5), (1, 3, 5), (1, 3, 5)),
|
||||
num_mels: int = 512,
|
||||
upsample_initial_channel: int = 1024,
|
||||
use_template: bool = False,
|
||||
pre_conv_kernel_size: int = 13,
|
||||
post_conv_kernel_size: int = 13,
|
||||
sampling_rate: int = 44100,
|
||||
n_fft: int = 2048,
|
||||
win_length: int = 2048,
|
||||
hop_length: int = 512,
|
||||
f_min: int = 40,
|
||||
f_max: int = 16000,
|
||||
n_mels: int = 128,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.backbone = ConvNeXtEncoder(
|
||||
input_channels=input_channels,
|
||||
depths=depths,
|
||||
dims=dims,
|
||||
drop_path_rate=drop_path_rate,
|
||||
kernel_sizes=kernel_sizes,
|
||||
)
|
||||
|
||||
self.head = HiFiGANGenerator(
|
||||
hop_length=hop_length,
|
||||
upsample_rates=upsample_rates,
|
||||
upsample_kernel_sizes=upsample_kernel_sizes,
|
||||
resblock_kernel_sizes=resblock_kernel_sizes,
|
||||
resblock_dilation_sizes=resblock_dilation_sizes,
|
||||
num_mels=num_mels,
|
||||
upsample_initial_channel=upsample_initial_channel,
|
||||
use_template=use_template,
|
||||
pre_conv_kernel_size=pre_conv_kernel_size,
|
||||
post_conv_kernel_size=post_conv_kernel_size,
|
||||
)
|
||||
self.sampling_rate = sampling_rate
|
||||
self.mel_transform = LogMelSpectrogram(
|
||||
sample_rate=sampling_rate,
|
||||
n_fft=n_fft,
|
||||
win_length=win_length,
|
||||
hop_length=hop_length,
|
||||
f_min=f_min,
|
||||
f_max=f_max,
|
||||
n_mels=n_mels,
|
||||
)
|
||||
self.eval()
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(self, mel):
|
||||
y = self.backbone(mel)
|
||||
y = self.head(y)
|
||||
return y
|
||||
|
||||
@torch.no_grad()
|
||||
def encode(self, x):
|
||||
return self.mel_transform(x)
|
||||
|
||||
def forward(self, mel):
|
||||
y = self.backbone(mel)
|
||||
y = self.head(y)
|
||||
return y
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from typing import Literal, Dict, Any
|
||||
from typing import Literal
|
||||
import math
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
@@ -75,16 +75,10 @@ class SnakeBeta(nn.Module):
|
||||
return x
|
||||
|
||||
def WNConv1d(*args, **kwargs):
|
||||
try:
|
||||
return torch.nn.utils.parametrizations.weight_norm(ops.Conv1d(*args, **kwargs))
|
||||
except:
|
||||
return torch.nn.utils.weight_norm(ops.Conv1d(*args, **kwargs)) #support pytorch 2.1 and older
|
||||
return torch.nn.utils.parametrizations.weight_norm(ops.Conv1d(*args, **kwargs))
|
||||
|
||||
def WNConvTranspose1d(*args, **kwargs):
|
||||
try:
|
||||
return torch.nn.utils.parametrizations.weight_norm(ops.ConvTranspose1d(*args, **kwargs))
|
||||
except:
|
||||
return torch.nn.utils.weight_norm(ops.ConvTranspose1d(*args, **kwargs)) #support pytorch 2.1 and older
|
||||
return torch.nn.utils.parametrizations.weight_norm(ops.ConvTranspose1d(*args, **kwargs))
|
||||
|
||||
def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
|
||||
if activation == "elu":
|
||||
@@ -97,7 +91,7 @@ def get_activation(activation: Literal["elu", "snake", "none"], antialias=False,
|
||||
raise ValueError(f"Unknown activation {activation}")
|
||||
|
||||
if antialias:
|
||||
act = Activation1d(act)
|
||||
act = Activation1d(act) # noqa: F821 Activation1d is not defined
|
||||
|
||||
return act
|
||||
|
||||
|
||||
@@ -158,7 +158,6 @@ class RotaryEmbedding(nn.Module):
|
||||
def forward(self, t):
|
||||
# device = self.inv_freq.device
|
||||
device = t.device
|
||||
dtype = t.dtype
|
||||
|
||||
# t = t.to(torch.float32)
|
||||
|
||||
@@ -170,7 +169,7 @@ class RotaryEmbedding(nn.Module):
|
||||
if self.scale is None:
|
||||
return freqs, 1.
|
||||
|
||||
power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
|
||||
power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base # noqa: F821 seq_len is not defined
|
||||
scale = comfy.ops.cast_to_input(self.scale, t) ** rearrange(power, 'n -> n 1')
|
||||
scale = torch.cat((scale, scale), dim = -1)
|
||||
|
||||
@@ -229,9 +228,9 @@ class FeedForward(nn.Module):
|
||||
linear_in = GLU(dim, inner_dim, activation, dtype=dtype, device=device, operations=operations)
|
||||
else:
|
||||
linear_in = nn.Sequential(
|
||||
Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
|
||||
rearrange('b n d -> b d n') if use_conv else nn.Identity(),
|
||||
operations.Linear(dim, inner_dim, bias = not no_bias, dtype=dtype, device=device) if not use_conv else operations.Conv1d(dim, inner_dim, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias, dtype=dtype, device=device),
|
||||
Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
|
||||
rearrange('b n d -> b d n') if use_conv else nn.Identity(),
|
||||
activation
|
||||
)
|
||||
|
||||
@@ -246,9 +245,9 @@ class FeedForward(nn.Module):
|
||||
|
||||
self.ff = nn.Sequential(
|
||||
linear_in,
|
||||
Rearrange('b d n -> b n d') if use_conv else nn.Identity(),
|
||||
rearrange('b d n -> b n d') if use_conv else nn.Identity(),
|
||||
linear_out,
|
||||
Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
|
||||
rearrange('b n d -> b d n') if use_conv else nn.Identity(),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
@@ -346,18 +345,13 @@ class Attention(nn.Module):
|
||||
|
||||
# determine masking
|
||||
masks = []
|
||||
final_attn_mask = None # The mask that will be applied to the attention matrix, taking all masks into account
|
||||
|
||||
if input_mask is not None:
|
||||
input_mask = rearrange(input_mask, 'b j -> b 1 1 j')
|
||||
masks.append(~input_mask)
|
||||
|
||||
# Other masks will be added here later
|
||||
|
||||
if len(masks) > 0:
|
||||
final_attn_mask = ~or_reduce(masks)
|
||||
|
||||
n, device = q.shape[-2], q.device
|
||||
n = q.shape[-2]
|
||||
|
||||
causal = self.causal if causal is None else causal
|
||||
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor, einsum
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
|
||||
from torch import Tensor
|
||||
from typing import List, Union
|
||||
from einops import rearrange
|
||||
import math
|
||||
import comfy.ops
|
||||
|
||||
@@ -147,7 +147,6 @@ class DoubleAttention(nn.Module):
|
||||
|
||||
bsz, seqlen1, _ = c.shape
|
||||
bsz, seqlen2, _ = x.shape
|
||||
seqlen = seqlen1 + seqlen2
|
||||
|
||||
cq, ck, cv = self.w1q(c), self.w1k(c), self.w1v(c)
|
||||
cq = cq.view(bsz, seqlen1, self.n_heads, self.head_dim)
|
||||
@@ -382,7 +381,6 @@ class MMDiT(nn.Module):
|
||||
pe_new = pe_as_2d.squeeze(0).permute(1, 2, 0).flatten(0, 1)
|
||||
self.positional_encoding.data = pe_new.unsqueeze(0).contiguous()
|
||||
self.h_max, self.w_max = target_dim
|
||||
print("PE extended to", target_dim)
|
||||
|
||||
def pe_selection_index_based_on_dim(self, h, w):
|
||||
h_p, w_p = h // self.patch_size, w // self.patch_size
|
||||
|
||||
@@ -16,7 +16,6 @@
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
from torch import nn
|
||||
from .common import LayerNorm2d_op
|
||||
|
||||
@@ -19,6 +19,10 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.autograd import Function
|
||||
import comfy.ops
|
||||
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
|
||||
class vector_quantize(Function):
|
||||
@staticmethod
|
||||
@@ -121,15 +125,15 @@ class ResBlock(nn.Module):
|
||||
self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
|
||||
self.depthwise = nn.Sequential(
|
||||
nn.ReplicationPad2d(1),
|
||||
nn.Conv2d(c, c, kernel_size=3, groups=c)
|
||||
ops.Conv2d(c, c, kernel_size=3, groups=c)
|
||||
)
|
||||
|
||||
# channelwise
|
||||
self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
|
||||
self.channelwise = nn.Sequential(
|
||||
nn.Linear(c, c_hidden),
|
||||
ops.Linear(c, c_hidden),
|
||||
nn.GELU(),
|
||||
nn.Linear(c_hidden, c),
|
||||
ops.Linear(c_hidden, c),
|
||||
)
|
||||
|
||||
self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)
|
||||
@@ -171,16 +175,16 @@ class StageA(nn.Module):
|
||||
# Encoder blocks
|
||||
self.in_block = nn.Sequential(
|
||||
nn.PixelUnshuffle(2),
|
||||
nn.Conv2d(3 * 4, c_levels[0], kernel_size=1)
|
||||
ops.Conv2d(3 * 4, c_levels[0], kernel_size=1)
|
||||
)
|
||||
down_blocks = []
|
||||
for i in range(levels):
|
||||
if i > 0:
|
||||
down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
|
||||
down_blocks.append(ops.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
|
||||
block = ResBlock(c_levels[i], c_levels[i] * 4)
|
||||
down_blocks.append(block)
|
||||
down_blocks.append(nn.Sequential(
|
||||
nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False),
|
||||
ops.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False),
|
||||
nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1
|
||||
))
|
||||
self.down_blocks = nn.Sequential(*down_blocks)
|
||||
@@ -191,7 +195,7 @@ class StageA(nn.Module):
|
||||
|
||||
# Decoder blocks
|
||||
up_blocks = [nn.Sequential(
|
||||
nn.Conv2d(c_latent, c_levels[-1], kernel_size=1)
|
||||
ops.Conv2d(c_latent, c_levels[-1], kernel_size=1)
|
||||
)]
|
||||
for i in range(levels):
|
||||
for j in range(bottleneck_blocks if i == 0 else 1):
|
||||
@@ -199,11 +203,11 @@ class StageA(nn.Module):
|
||||
up_blocks.append(block)
|
||||
if i < levels - 1:
|
||||
up_blocks.append(
|
||||
nn.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2,
|
||||
ops.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2,
|
||||
padding=1))
|
||||
self.up_blocks = nn.Sequential(*up_blocks)
|
||||
self.out_block = nn.Sequential(
|
||||
nn.Conv2d(c_levels[0], 3 * 4, kernel_size=1),
|
||||
ops.Conv2d(c_levels[0], 3 * 4, kernel_size=1),
|
||||
nn.PixelShuffle(2),
|
||||
)
|
||||
|
||||
@@ -232,17 +236,17 @@ class Discriminator(nn.Module):
|
||||
super().__init__()
|
||||
d = max(depth - 3, 3)
|
||||
layers = [
|
||||
nn.utils.spectral_norm(nn.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)),
|
||||
nn.utils.spectral_norm(ops.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)),
|
||||
nn.LeakyReLU(0.2),
|
||||
]
|
||||
for i in range(depth - 1):
|
||||
c_in = c_hidden // (2 ** max((d - i), 0))
|
||||
c_out = c_hidden // (2 ** max((d - 1 - i), 0))
|
||||
layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))
|
||||
layers.append(nn.utils.spectral_norm(ops.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))
|
||||
layers.append(nn.InstanceNorm2d(c_out))
|
||||
layers.append(nn.LeakyReLU(0.2))
|
||||
self.encoder = nn.Sequential(*layers)
|
||||
self.shuffle = nn.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1)
|
||||
self.shuffle = ops.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1)
|
||||
self.logits = nn.Sigmoid()
|
||||
|
||||
def forward(self, x, cond=None):
|
||||
|
||||
@@ -138,7 +138,7 @@ class StageB(nn.Module):
|
||||
# nn.init.normal_(self.pixels_mapper[2].weight, std=0.02) # conditionings
|
||||
# torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs
|
||||
# nn.init.constant_(self.clf[1].weight, 0) # outputs
|
||||
#
|
||||
#
|
||||
# # blocks
|
||||
# for level_block in self.down_blocks + self.up_blocks:
|
||||
# for block in level_block:
|
||||
@@ -148,7 +148,7 @@ class StageB(nn.Module):
|
||||
# for layer in block.modules():
|
||||
# if isinstance(layer, nn.Linear):
|
||||
# nn.init.constant_(layer.weight, 0)
|
||||
#
|
||||
#
|
||||
# def _init_weights(self, m):
|
||||
# if isinstance(m, (nn.Conv2d, nn.Linear)):
|
||||
# torch.nn.init.xavier_uniform_(m.weight)
|
||||
|
||||
@@ -142,7 +142,7 @@ class StageC(nn.Module):
|
||||
# nn.init.normal_(self.clip_img_mapper.weight, std=0.02) # conditionings
|
||||
# torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs
|
||||
# nn.init.constant_(self.clf[1].weight, 0) # outputs
|
||||
#
|
||||
#
|
||||
# # blocks
|
||||
# for level_block in self.down_blocks + self.up_blocks:
|
||||
# for block in level_block:
|
||||
@@ -152,7 +152,7 @@ class StageC(nn.Module):
|
||||
# for layer in block.modules():
|
||||
# if isinstance(layer, nn.Linear):
|
||||
# nn.init.constant_(layer.weight, 0)
|
||||
#
|
||||
#
|
||||
# def _init_weights(self, m):
|
||||
# if isinstance(m, (nn.Conv2d, nn.Linear)):
|
||||
# torch.nn.init.xavier_uniform_(m.weight)
|
||||
|
||||
@@ -19,6 +19,9 @@ import torch
|
||||
import torchvision
|
||||
from torch import nn
|
||||
|
||||
import comfy.ops
|
||||
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
# EfficientNet
|
||||
class EfficientNetEncoder(nn.Module):
|
||||
@@ -26,7 +29,7 @@ class EfficientNetEncoder(nn.Module):
|
||||
super().__init__()
|
||||
self.backbone = torchvision.models.efficientnet_v2_s().features.eval()
|
||||
self.mapper = nn.Sequential(
|
||||
nn.Conv2d(1280, c_latent, kernel_size=1, bias=False),
|
||||
ops.Conv2d(1280, c_latent, kernel_size=1, bias=False),
|
||||
nn.BatchNorm2d(c_latent, affine=False), # then normalize them to have mean 0 and std 1
|
||||
)
|
||||
self.mean = nn.Parameter(torch.tensor([0.485, 0.456, 0.406]))
|
||||
@@ -34,7 +37,7 @@ class EfficientNetEncoder(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
x = x * 0.5 + 0.5
|
||||
x = (x - self.mean.view([3,1,1])) / self.std.view([3,1,1])
|
||||
x = (x - self.mean.view([3,1,1]).to(device=x.device, dtype=x.dtype)) / self.std.view([3,1,1]).to(device=x.device, dtype=x.dtype)
|
||||
o = self.mapper(self.backbone(x))
|
||||
return o
|
||||
|
||||
@@ -44,39 +47,39 @@ class Previewer(nn.Module):
|
||||
def __init__(self, c_in=16, c_hidden=512, c_out=3):
|
||||
super().__init__()
|
||||
self.blocks = nn.Sequential(
|
||||
nn.Conv2d(c_in, c_hidden, kernel_size=1), # 16 channels to 512 channels
|
||||
ops.Conv2d(c_in, c_hidden, kernel_size=1), # 16 channels to 512 channels
|
||||
nn.GELU(),
|
||||
nn.BatchNorm2d(c_hidden),
|
||||
|
||||
nn.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1),
|
||||
ops.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1),
|
||||
nn.GELU(),
|
||||
nn.BatchNorm2d(c_hidden),
|
||||
|
||||
nn.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2), # 16 -> 32
|
||||
ops.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2), # 16 -> 32
|
||||
nn.GELU(),
|
||||
nn.BatchNorm2d(c_hidden // 2),
|
||||
|
||||
nn.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1),
|
||||
ops.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1),
|
||||
nn.GELU(),
|
||||
nn.BatchNorm2d(c_hidden // 2),
|
||||
|
||||
nn.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2), # 32 -> 64
|
||||
ops.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2), # 32 -> 64
|
||||
nn.GELU(),
|
||||
nn.BatchNorm2d(c_hidden // 4),
|
||||
|
||||
nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
|
||||
ops.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
|
||||
nn.GELU(),
|
||||
nn.BatchNorm2d(c_hidden // 4),
|
||||
|
||||
nn.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2), # 64 -> 128
|
||||
ops.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2), # 64 -> 128
|
||||
nn.GELU(),
|
||||
nn.BatchNorm2d(c_hidden // 4),
|
||||
|
||||
nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
|
||||
ops.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
|
||||
nn.GELU(),
|
||||
nn.BatchNorm2d(c_hidden // 4),
|
||||
|
||||
nn.Conv2d(c_hidden // 4, c_out, kernel_size=1),
|
||||
ops.Conv2d(c_hidden // 4, c_out, kernel_size=1),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
181
comfy/ldm/chroma/layers.py
Normal file
181
comfy/ldm/chroma/layers.py
Normal file
@@ -0,0 +1,181 @@
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from comfy.ldm.flux.math import attention
|
||||
from comfy.ldm.flux.layers import (
|
||||
MLPEmbedder,
|
||||
RMSNorm,
|
||||
QKNorm,
|
||||
SelfAttention,
|
||||
ModulationOut,
|
||||
)
|
||||
|
||||
|
||||
|
||||
class ChromaModulationOut(ModulationOut):
|
||||
@classmethod
|
||||
def from_offset(cls, tensor: torch.Tensor, offset: int = 0) -> ModulationOut:
|
||||
return cls(
|
||||
shift=tensor[:, offset : offset + 1, :],
|
||||
scale=tensor[:, offset + 1 : offset + 2, :],
|
||||
gate=tensor[:, offset + 2 : offset + 3, :],
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
class Approximator(nn.Module):
|
||||
def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers = 5, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.in_proj = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
|
||||
self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)])
|
||||
self.norms = nn.ModuleList([RMSNorm(hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)])
|
||||
self.out_proj = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
# Get the device of the module (assumes all parameters are on the same device)
|
||||
return next(self.parameters()).device
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
x = self.in_proj(x)
|
||||
|
||||
for layer, norms in zip(self.layers, self.norms):
|
||||
x = x + layer(norms(x))
|
||||
|
||||
x = self.out_proj(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class DoubleStreamBlock(nn.Module):
|
||||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
self.num_heads = num_heads
|
||||
self.hidden_size = hidden_size
|
||||
self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.img_mlp = nn.Sequential(
|
||||
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
|
||||
nn.GELU(approximate="tanh"),
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.txt_mlp = nn.Sequential(
|
||||
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
|
||||
nn.GELU(approximate="tanh"),
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
self.flipped_img_txt = flipped_img_txt
|
||||
|
||||
def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None):
|
||||
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
|
||||
|
||||
# prepare image for attention
|
||||
img_modulated = torch.addcmul(img_mod1.shift, 1 + img_mod1.scale, self.img_norm1(img))
|
||||
img_qkv = self.img_attn.qkv(img_modulated)
|
||||
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
|
||||
|
||||
# prepare txt for attention
|
||||
txt_modulated = torch.addcmul(txt_mod1.shift, 1 + txt_mod1.scale, self.txt_norm1(txt))
|
||||
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
||||
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
||||
|
||||
# run actual attention
|
||||
attn = attention(torch.cat((txt_q, img_q), dim=2),
|
||||
torch.cat((txt_k, img_k), dim=2),
|
||||
torch.cat((txt_v, img_v), dim=2),
|
||||
pe=pe, mask=attn_mask)
|
||||
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
||||
|
||||
# calculate the img bloks
|
||||
img.addcmul_(img_mod1.gate, self.img_attn.proj(img_attn))
|
||||
img.addcmul_(img_mod2.gate, self.img_mlp(torch.addcmul(img_mod2.shift, 1 + img_mod2.scale, self.img_norm2(img))))
|
||||
|
||||
# calculate the txt bloks
|
||||
txt.addcmul_(txt_mod1.gate, self.txt_attn.proj(txt_attn))
|
||||
txt.addcmul_(txt_mod2.gate, self.txt_mlp(torch.addcmul(txt_mod2.shift, 1 + txt_mod2.scale, self.txt_norm2(txt))))
|
||||
|
||||
if txt.dtype == torch.float16:
|
||||
txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
|
||||
|
||||
return img, txt
|
||||
|
||||
|
||||
class SingleStreamBlock(nn.Module):
|
||||
"""
|
||||
A DiT block with parallel linear layers as described in
|
||||
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
qk_scale: float = None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_dim = hidden_size
|
||||
self.num_heads = num_heads
|
||||
head_dim = hidden_size // num_heads
|
||||
self.scale = qk_scale or head_dim**-0.5
|
||||
|
||||
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
# qkv and mlp_in
|
||||
self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
|
||||
# proj and mlp_out
|
||||
self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)
|
||||
|
||||
self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
|
||||
self.mlp_act = nn.GELU(approximate="tanh")
|
||||
|
||||
def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None) -> Tensor:
|
||||
mod = vec
|
||||
x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x))
|
||||
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||
|
||||
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
q, k = self.norm(q, k, v)
|
||||
|
||||
# compute attention
|
||||
attn = attention(q, k, v, pe=pe, mask=attn_mask)
|
||||
# compute activation in mlp stream, cat again and run second linear layer
|
||||
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
||||
x.addcmul_(mod.gate, output)
|
||||
if x.dtype == torch.float16:
|
||||
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
|
||||
return x
|
||||
|
||||
|
||||
class LastLayer(nn.Module):
|
||||
def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.linear = operations.Linear(hidden_size, out_channels, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x: Tensor, vec: Tensor) -> Tensor:
|
||||
shift, scale = vec
|
||||
shift = shift.squeeze(1)
|
||||
scale = scale.squeeze(1)
|
||||
x = torch.addcmul(shift[:, None, :], 1 + scale[:, None, :], self.norm_final(x))
|
||||
x = self.linear(x)
|
||||
return x
|
||||
271
comfy/ldm/chroma/model.py
Normal file
271
comfy/ldm/chroma/model.py
Normal file
@@ -0,0 +1,271 @@
|
||||
#Original code can be found on: https://github.com/black-forest-labs/flux
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from einops import rearrange, repeat
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
from comfy.ldm.flux.layers import (
|
||||
EmbedND,
|
||||
timestep_embedding,
|
||||
)
|
||||
|
||||
from .layers import (
|
||||
DoubleStreamBlock,
|
||||
LastLayer,
|
||||
SingleStreamBlock,
|
||||
Approximator,
|
||||
ChromaModulationOut,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChromaParams:
|
||||
in_channels: int
|
||||
out_channels: int
|
||||
context_in_dim: int
|
||||
hidden_size: int
|
||||
mlp_ratio: float
|
||||
num_heads: int
|
||||
depth: int
|
||||
depth_single_blocks: int
|
||||
axes_dim: list
|
||||
theta: int
|
||||
patch_size: int
|
||||
qkv_bias: bool
|
||||
in_dim: int
|
||||
out_dim: int
|
||||
hidden_dim: int
|
||||
n_layers: int
|
||||
|
||||
|
||||
|
||||
|
||||
class Chroma(nn.Module):
|
||||
"""
|
||||
Transformer model for flow matching on sequences.
|
||||
"""
|
||||
|
||||
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
params = ChromaParams(**kwargs)
|
||||
self.params = params
|
||||
self.patch_size = params.patch_size
|
||||
self.in_channels = params.in_channels
|
||||
self.out_channels = params.out_channels
|
||||
if params.hidden_size % params.num_heads != 0:
|
||||
raise ValueError(
|
||||
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
|
||||
)
|
||||
pe_dim = params.hidden_size // params.num_heads
|
||||
if sum(params.axes_dim) != pe_dim:
|
||||
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
|
||||
self.hidden_size = params.hidden_size
|
||||
self.num_heads = params.num_heads
|
||||
self.in_dim = params.in_dim
|
||||
self.out_dim = params.out_dim
|
||||
self.hidden_dim = params.hidden_dim
|
||||
self.n_layers = params.n_layers
|
||||
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
|
||||
self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
|
||||
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device)
|
||||
# set as nn identity for now, will overwrite it later.
|
||||
self.distilled_guidance_layer = Approximator(
|
||||
in_dim=self.in_dim,
|
||||
hidden_dim=self.hidden_dim,
|
||||
out_dim=self.out_dim,
|
||||
n_layers=self.n_layers,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
|
||||
self.double_blocks = nn.ModuleList(
|
||||
[
|
||||
DoubleStreamBlock(
|
||||
self.hidden_size,
|
||||
self.num_heads,
|
||||
mlp_ratio=params.mlp_ratio,
|
||||
qkv_bias=params.qkv_bias,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
for _ in range(params.depth)
|
||||
]
|
||||
)
|
||||
|
||||
self.single_blocks = nn.ModuleList(
|
||||
[
|
||||
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
|
||||
for _ in range(params.depth_single_blocks)
|
||||
]
|
||||
)
|
||||
|
||||
if final_layer:
|
||||
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.skip_mmdit = []
|
||||
self.skip_dit = []
|
||||
self.lite = False
|
||||
|
||||
def get_modulations(self, tensor: torch.Tensor, block_type: str, *, idx: int = 0):
|
||||
# This function slices up the modulations tensor which has the following layout:
|
||||
# single : num_single_blocks * 3 elements
|
||||
# double_img : num_double_blocks * 6 elements
|
||||
# double_txt : num_double_blocks * 6 elements
|
||||
# final : 2 elements
|
||||
if block_type == "final":
|
||||
return (tensor[:, -2:-1, :], tensor[:, -1:, :])
|
||||
single_block_count = self.params.depth_single_blocks
|
||||
double_block_count = self.params.depth
|
||||
offset = 3 * idx
|
||||
if block_type == "single":
|
||||
return ChromaModulationOut.from_offset(tensor, offset)
|
||||
# Double block modulations are 6 elements so we double 3 * idx.
|
||||
offset *= 2
|
||||
if block_type in {"double_img", "double_txt"}:
|
||||
# Advance past the single block modulations.
|
||||
offset += 3 * single_block_count
|
||||
if block_type == "double_txt":
|
||||
# Advance past the double block img modulations.
|
||||
offset += 6 * double_block_count
|
||||
return (
|
||||
ChromaModulationOut.from_offset(tensor, offset),
|
||||
ChromaModulationOut.from_offset(tensor, offset + 3),
|
||||
)
|
||||
raise ValueError("Bad block_type")
|
||||
|
||||
|
||||
def forward_orig(
|
||||
self,
|
||||
img: Tensor,
|
||||
img_ids: Tensor,
|
||||
txt: Tensor,
|
||||
txt_ids: Tensor,
|
||||
timesteps: Tensor,
|
||||
guidance: Tensor = None,
|
||||
control = None,
|
||||
transformer_options={},
|
||||
attn_mask: Tensor = None,
|
||||
) -> Tensor:
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
|
||||
# running on sequences img
|
||||
img = self.img_in(img)
|
||||
|
||||
# distilled vector guidance
|
||||
mod_index_length = 344
|
||||
distill_timestep = timestep_embedding(timesteps.detach().clone(), 16).to(img.device, img.dtype)
|
||||
# guidance = guidance *
|
||||
distil_guidance = timestep_embedding(guidance.detach().clone(), 16).to(img.device, img.dtype)
|
||||
|
||||
# get all modulation index
|
||||
modulation_index = timestep_embedding(torch.arange(mod_index_length, device=img.device), 32).to(img.device, img.dtype)
|
||||
# we need to broadcast the modulation index here so each batch has all of the index
|
||||
modulation_index = modulation_index.unsqueeze(0).repeat(img.shape[0], 1, 1).to(img.device, img.dtype)
|
||||
# and we need to broadcast timestep and guidance along too
|
||||
timestep_guidance = torch.cat([distill_timestep, distil_guidance], dim=1).unsqueeze(1).repeat(1, mod_index_length, 1).to(img.dtype).to(img.device, img.dtype)
|
||||
# then and only then we could concatenate it together
|
||||
input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1).to(img.device, img.dtype)
|
||||
|
||||
mod_vectors = self.distilled_guidance_layer(input_vec)
|
||||
|
||||
txt = self.txt_in(txt)
|
||||
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
for i, block in enumerate(self.double_blocks):
|
||||
if i not in self.skip_mmdit:
|
||||
double_mod = (
|
||||
self.get_modulations(mod_vectors, "double_img", idx=i),
|
||||
self.get_modulations(mod_vectors, "double_txt", idx=i),
|
||||
)
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"], out["txt"] = block(img=args["img"],
|
||||
txt=args["txt"],
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"))
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": img,
|
||||
"txt": txt,
|
||||
"vec": double_mod,
|
||||
"pe": pe,
|
||||
"attn_mask": attn_mask},
|
||||
{"original_block": block_wrap})
|
||||
txt = out["txt"]
|
||||
img = out["img"]
|
||||
else:
|
||||
img, txt = block(img=img,
|
||||
txt=txt,
|
||||
vec=double_mod,
|
||||
pe=pe,
|
||||
attn_mask=attn_mask)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_i = control.get("input")
|
||||
if i < len(control_i):
|
||||
add = control_i[i]
|
||||
if add is not None:
|
||||
img += add
|
||||
|
||||
img = torch.cat((txt, img), 1)
|
||||
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
if i not in self.skip_dit:
|
||||
single_mod = self.get_modulations(mod_vectors, "single", idx=i)
|
||||
if ("single_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"],
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"))
|
||||
return out
|
||||
|
||||
out = blocks_replace[("single_block", i)]({"img": img,
|
||||
"vec": single_mod,
|
||||
"pe": pe,
|
||||
"attn_mask": attn_mask},
|
||||
{"original_block": block_wrap})
|
||||
img = out["img"]
|
||||
else:
|
||||
img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_o = control.get("output")
|
||||
if i < len(control_o):
|
||||
add = control_o[i]
|
||||
if add is not None:
|
||||
img[:, txt.shape[1] :, ...] += add
|
||||
|
||||
img = img[:, txt.shape[1] :, ...]
|
||||
final_mod = self.get_modulations(mod_vectors, "final")
|
||||
img = self.final_layer(img, vec=final_mod) # (N, T, patch_size ** 2 * out_channels)
|
||||
return img
|
||||
|
||||
def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
|
||||
bs, c, h, w = x.shape
|
||||
patch_size = 2
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
||||
|
||||
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
||||
|
||||
h_len = ((h + (patch_size // 2)) // patch_size)
|
||||
w_len = ((w + (patch_size // 2)) // patch_size)
|
||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
|
||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
|
||||
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
|
||||
@@ -1,27 +1,16 @@
|
||||
import torch
|
||||
import comfy.ops
|
||||
import comfy.rmsnorm
|
||||
|
||||
|
||||
def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
|
||||
if padding_mode == "circular" and (torch.jit.is_tracing() or torch.jit.is_scripting()):
|
||||
padding_mode = "reflect"
|
||||
pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0]
|
||||
pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1]
|
||||
return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode)
|
||||
|
||||
try:
|
||||
rms_norm_torch = torch.nn.functional.rms_norm
|
||||
except:
|
||||
rms_norm_torch = None
|
||||
pad = ()
|
||||
for i in range(img.ndim - 2):
|
||||
pad = (0, (patch_size[i] - img.shape[i + 2] % patch_size[i]) % patch_size[i]) + pad
|
||||
|
||||
def rms_norm(x, weight=None, eps=1e-6):
|
||||
if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
|
||||
if weight is None:
|
||||
return rms_norm_torch(x, (x.shape[-1],), eps=eps)
|
||||
else:
|
||||
return rms_norm_torch(x, weight.shape, weight=comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
|
||||
else:
|
||||
r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
|
||||
if weight is None:
|
||||
return r
|
||||
else:
|
||||
return r * comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device)
|
||||
return torch.nn.functional.pad(img, pad, mode=padding_mode)
|
||||
|
||||
|
||||
rms_norm = comfy.rmsnorm.rms_norm
|
||||
|
||||
807
comfy/ldm/cosmos/blocks.py
Normal file
807
comfy/ldm/cosmos/blocks.py
Normal file
@@ -0,0 +1,807 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
from torch import nn
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(
|
||||
t: torch.Tensor,
|
||||
freqs: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
t_ = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).float()
|
||||
t_out = freqs[..., 0] * t_[..., 0] + freqs[..., 1] * t_[..., 1]
|
||||
t_out = t_out.movedim(-1, -2).reshape(*t.shape).type_as(t)
|
||||
return t_out
|
||||
|
||||
|
||||
def get_normalization(name: str, channels: int, weight_args={}, operations=None):
|
||||
if name == "I":
|
||||
return nn.Identity()
|
||||
elif name == "R":
|
||||
return operations.RMSNorm(channels, elementwise_affine=True, eps=1e-6, **weight_args)
|
||||
else:
|
||||
raise ValueError(f"Normalization {name} not found")
|
||||
|
||||
|
||||
class BaseAttentionOp(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
"""
|
||||
Generalized attention impl.
|
||||
|
||||
Allowing for both self-attention and cross-attention configurations depending on whether a `context_dim` is provided.
|
||||
If `context_dim` is None, self-attention is assumed.
|
||||
|
||||
Parameters:
|
||||
query_dim (int): Dimension of each query vector.
|
||||
context_dim (int, optional): Dimension of each context vector. If None, self-attention is assumed.
|
||||
heads (int, optional): Number of attention heads. Defaults to 8.
|
||||
dim_head (int, optional): Dimension of each head. Defaults to 64.
|
||||
dropout (float, optional): Dropout rate applied to the output of the attention block. Defaults to 0.0.
|
||||
attn_op (BaseAttentionOp, optional): Custom attention operation to be used instead of the default.
|
||||
qkv_bias (bool, optional): If True, adds a learnable bias to query, key, and value projections. Defaults to False.
|
||||
out_bias (bool, optional): If True, adds a learnable bias to the output projection. Defaults to False.
|
||||
qkv_norm (str, optional): A string representing normalization strategies for query, key, and value projections.
|
||||
Defaults to "SSI".
|
||||
qkv_norm_mode (str, optional): A string representing normalization mode for query, key, and value projections.
|
||||
Defaults to 'per_head'. Only support 'per_head'.
|
||||
|
||||
Examples:
|
||||
>>> attn = Attention(query_dim=128, context_dim=256, heads=4, dim_head=32, dropout=0.1)
|
||||
>>> query = torch.randn(10, 128) # Batch size of 10
|
||||
>>> context = torch.randn(10, 256) # Batch size of 10
|
||||
>>> output = attn(query, context) # Perform the attention operation
|
||||
|
||||
Note:
|
||||
https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
query_dim: int,
|
||||
context_dim=None,
|
||||
heads=8,
|
||||
dim_head=64,
|
||||
dropout=0.0,
|
||||
attn_op: Optional[BaseAttentionOp] = None,
|
||||
qkv_bias: bool = False,
|
||||
out_bias: bool = False,
|
||||
qkv_norm: str = "SSI",
|
||||
qkv_norm_mode: str = "per_head",
|
||||
backend: str = "transformer_engine",
|
||||
qkv_format: str = "bshd",
|
||||
weight_args={},
|
||||
operations=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.is_selfattn = context_dim is None # self attention
|
||||
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = query_dim if context_dim is None else context_dim
|
||||
|
||||
self.heads = heads
|
||||
self.dim_head = dim_head
|
||||
self.qkv_norm_mode = qkv_norm_mode
|
||||
self.qkv_format = qkv_format
|
||||
|
||||
if self.qkv_norm_mode == "per_head":
|
||||
norm_dim = dim_head
|
||||
else:
|
||||
raise ValueError(f"Normalization mode {self.qkv_norm_mode} not found, only support 'per_head'")
|
||||
|
||||
self.backend = backend
|
||||
|
||||
self.to_q = nn.Sequential(
|
||||
operations.Linear(query_dim, inner_dim, bias=qkv_bias, **weight_args),
|
||||
get_normalization(qkv_norm[0], norm_dim, weight_args=weight_args, operations=operations),
|
||||
)
|
||||
self.to_k = nn.Sequential(
|
||||
operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args),
|
||||
get_normalization(qkv_norm[1], norm_dim, weight_args=weight_args, operations=operations),
|
||||
)
|
||||
self.to_v = nn.Sequential(
|
||||
operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args),
|
||||
get_normalization(qkv_norm[2], norm_dim, weight_args=weight_args, operations=operations),
|
||||
)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
operations.Linear(inner_dim, query_dim, bias=out_bias, **weight_args),
|
||||
nn.Dropout(dropout),
|
||||
)
|
||||
|
||||
def cal_qkv(
|
||||
self, x, context=None, mask=None, rope_emb=None, **kwargs
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
del kwargs
|
||||
|
||||
|
||||
"""
|
||||
self.to_q, self.to_k, self.to_v are nn.Sequential with projection + normalization layers.
|
||||
Before 07/24/2024, these modules normalize across all heads.
|
||||
After 07/24/2024, to support tensor parallelism and follow the common practice in the community,
|
||||
we support to normalize per head.
|
||||
To keep the checkpoint copatibility with the previous code,
|
||||
we keep the nn.Sequential but call the projection and the normalization layers separately.
|
||||
We use a flag `self.qkv_norm_mode` to control the normalization behavior.
|
||||
The default value of `self.qkv_norm_mode` is "per_head", which means we normalize per head.
|
||||
"""
|
||||
if self.qkv_norm_mode == "per_head":
|
||||
q = self.to_q[0](x)
|
||||
context = x if context is None else context
|
||||
k = self.to_k[0](context)
|
||||
v = self.to_v[0](context)
|
||||
q, k, v = map(
|
||||
lambda t: rearrange(t, "s b (n c) -> b n s c", n=self.heads, c=self.dim_head),
|
||||
(q, k, v),
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Normalization mode {self.qkv_norm_mode} not found, only support 'per_head'")
|
||||
|
||||
q = self.to_q[1](q)
|
||||
k = self.to_k[1](k)
|
||||
v = self.to_v[1](v)
|
||||
if self.is_selfattn and rope_emb is not None: # only apply to self-attention!
|
||||
# apply_rotary_pos_emb inlined
|
||||
q_shape = q.shape
|
||||
q = q.reshape(*q.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2)
|
||||
q = rope_emb[..., 0] * q[..., 0] + rope_emb[..., 1] * q[..., 1]
|
||||
q = q.movedim(-1, -2).reshape(*q_shape).to(x.dtype)
|
||||
|
||||
# apply_rotary_pos_emb inlined
|
||||
k_shape = k.shape
|
||||
k = k.reshape(*k.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2)
|
||||
k = rope_emb[..., 0] * k[..., 0] + rope_emb[..., 1] * k[..., 1]
|
||||
k = k.movedim(-1, -2).reshape(*k_shape).to(x.dtype)
|
||||
return q, k, v
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
context=None,
|
||||
mask=None,
|
||||
rope_emb=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): The query tensor of shape [B, Mq, K]
|
||||
context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
|
||||
"""
|
||||
q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs)
|
||||
out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True)
|
||||
del q, k, v
|
||||
out = rearrange(out, " b n s c -> s b (n c)")
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
"""
|
||||
Transformer FFN with optional gating
|
||||
|
||||
Parameters:
|
||||
d_model (int): Dimensionality of input features.
|
||||
d_ff (int): Dimensionality of the hidden layer.
|
||||
dropout (float, optional): Dropout rate applied after the activation function. Defaults to 0.1.
|
||||
activation (callable, optional): The activation function applied after the first linear layer.
|
||||
Defaults to nn.ReLU().
|
||||
is_gated (bool, optional): If set to True, incorporates gating mechanism to the feed-forward layer.
|
||||
Defaults to False.
|
||||
bias (bool, optional): If set to True, adds a bias to the linear layers. Defaults to True.
|
||||
|
||||
Example:
|
||||
>>> ff = FeedForward(d_model=512, d_ff=2048)
|
||||
>>> x = torch.randn(64, 10, 512) # Example input tensor
|
||||
>>> output = ff(x)
|
||||
>>> print(output.shape) # Expected shape: (64, 10, 512)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int,
|
||||
d_ff: int,
|
||||
dropout: float = 0.1,
|
||||
activation=nn.ReLU(),
|
||||
is_gated: bool = False,
|
||||
bias: bool = False,
|
||||
weight_args={},
|
||||
operations=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.layer1 = operations.Linear(d_model, d_ff, bias=bias, **weight_args)
|
||||
self.layer2 = operations.Linear(d_ff, d_model, bias=bias, **weight_args)
|
||||
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.activation = activation
|
||||
self.is_gated = is_gated
|
||||
if is_gated:
|
||||
self.linear_gate = operations.Linear(d_model, d_ff, bias=False, **weight_args)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
g = self.activation(self.layer1(x))
|
||||
if self.is_gated:
|
||||
x = g * self.linear_gate(x)
|
||||
else:
|
||||
x = g
|
||||
assert self.dropout.p == 0.0, "we skip dropout"
|
||||
return self.layer2(x)
|
||||
|
||||
|
||||
class GPT2FeedForward(FeedForward):
|
||||
def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1, bias: bool = False, weight_args={}, operations=None):
|
||||
super().__init__(
|
||||
d_model=d_model,
|
||||
d_ff=d_ff,
|
||||
dropout=dropout,
|
||||
activation=nn.GELU(),
|
||||
is_gated=False,
|
||||
bias=bias,
|
||||
weight_args=weight_args,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
assert self.dropout.p == 0.0, "we skip dropout"
|
||||
|
||||
x = self.layer1(x)
|
||||
x = self.activation(x)
|
||||
x = self.layer2(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def modulate(x, shift, scale):
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
|
||||
|
||||
class Timesteps(nn.Module):
|
||||
def __init__(self, num_channels):
|
||||
super().__init__()
|
||||
self.num_channels = num_channels
|
||||
|
||||
def forward(self, timesteps):
|
||||
half_dim = self.num_channels // 2
|
||||
exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
|
||||
exponent = exponent / (half_dim - 0.0)
|
||||
|
||||
emb = torch.exp(exponent)
|
||||
emb = timesteps[:, None].float() * emb[None, :]
|
||||
|
||||
sin_emb = torch.sin(emb)
|
||||
cos_emb = torch.cos(emb)
|
||||
emb = torch.cat([cos_emb, sin_emb], dim=-1)
|
||||
|
||||
return emb
|
||||
|
||||
|
||||
class TimestepEmbedding(nn.Module):
|
||||
def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False, weight_args={}, operations=None):
|
||||
super().__init__()
|
||||
logging.debug(
|
||||
f"Using AdaLN LoRA Flag: {use_adaln_lora}. We enable bias if no AdaLN LoRA for backward compatibility."
|
||||
)
|
||||
self.linear_1 = operations.Linear(in_features, out_features, bias=not use_adaln_lora, **weight_args)
|
||||
self.activation = nn.SiLU()
|
||||
self.use_adaln_lora = use_adaln_lora
|
||||
if use_adaln_lora:
|
||||
self.linear_2 = operations.Linear(out_features, 3 * out_features, bias=False, **weight_args)
|
||||
else:
|
||||
self.linear_2 = operations.Linear(out_features, out_features, bias=True, **weight_args)
|
||||
|
||||
def forward(self, sample: torch.Tensor) -> torch.Tensor:
|
||||
emb = self.linear_1(sample)
|
||||
emb = self.activation(emb)
|
||||
emb = self.linear_2(emb)
|
||||
|
||||
if self.use_adaln_lora:
|
||||
adaln_lora_B_3D = emb
|
||||
emb_B_D = sample
|
||||
else:
|
||||
emb_B_D = emb
|
||||
adaln_lora_B_3D = None
|
||||
|
||||
return emb_B_D, adaln_lora_B_3D
|
||||
|
||||
|
||||
class FourierFeatures(nn.Module):
|
||||
"""
|
||||
Implements a layer that generates Fourier features from input tensors, based on randomly sampled
|
||||
frequencies and phases. This can help in learning high-frequency functions in low-dimensional problems.
|
||||
|
||||
[B] -> [B, D]
|
||||
|
||||
Parameters:
|
||||
num_channels (int): The number of Fourier features to generate.
|
||||
bandwidth (float, optional): The scaling factor for the frequency of the Fourier features. Defaults to 1.
|
||||
normalize (bool, optional): If set to True, the outputs are scaled by sqrt(2), usually to normalize
|
||||
the variance of the features. Defaults to False.
|
||||
|
||||
Example:
|
||||
>>> layer = FourierFeatures(num_channels=256, bandwidth=0.5, normalize=True)
|
||||
>>> x = torch.randn(10, 256) # Example input tensor
|
||||
>>> output = layer(x)
|
||||
>>> print(output.shape) # Expected shape: (10, 256)
|
||||
"""
|
||||
|
||||
def __init__(self, num_channels, bandwidth=1, normalize=False):
|
||||
super().__init__()
|
||||
self.register_buffer("freqs", 2 * np.pi * bandwidth * torch.randn(num_channels), persistent=True)
|
||||
self.register_buffer("phases", 2 * np.pi * torch.rand(num_channels), persistent=True)
|
||||
self.gain = np.sqrt(2) if normalize else 1
|
||||
|
||||
def forward(self, x, gain: float = 1.0):
|
||||
"""
|
||||
Apply the Fourier feature transformation to the input tensor.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor.
|
||||
gain (float, optional): An additional gain factor applied during the forward pass. Defaults to 1.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The transformed tensor, with Fourier features applied.
|
||||
"""
|
||||
in_dtype = x.dtype
|
||||
x = x.to(torch.float32).ger(self.freqs.to(torch.float32)).add(self.phases.to(torch.float32))
|
||||
x = x.cos().mul(self.gain * gain).to(in_dtype)
|
||||
return x
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
"""
|
||||
PatchEmbed is a module for embedding patches from an input tensor by applying either 3D or 2D convolutional layers,
|
||||
depending on the . This module can process inputs with temporal (video) and spatial (image) dimensions,
|
||||
making it suitable for video and image processing tasks. It supports dividing the input into patches
|
||||
and embedding each patch into a vector of size `out_channels`.
|
||||
|
||||
Parameters:
|
||||
- spatial_patch_size (int): The size of each spatial patch.
|
||||
- temporal_patch_size (int): The size of each temporal patch.
|
||||
- in_channels (int): Number of input channels. Default: 3.
|
||||
- out_channels (int): The dimension of the embedding vector for each patch. Default: 768.
|
||||
- bias (bool): If True, adds a learnable bias to the output of the convolutional layers. Default: True.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
spatial_patch_size,
|
||||
temporal_patch_size,
|
||||
in_channels=3,
|
||||
out_channels=768,
|
||||
bias=True,
|
||||
weight_args={},
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.spatial_patch_size = spatial_patch_size
|
||||
self.temporal_patch_size = temporal_patch_size
|
||||
|
||||
self.proj = nn.Sequential(
|
||||
Rearrange(
|
||||
"b c (t r) (h m) (w n) -> b t h w (c r m n)",
|
||||
r=temporal_patch_size,
|
||||
m=spatial_patch_size,
|
||||
n=spatial_patch_size,
|
||||
),
|
||||
operations.Linear(
|
||||
in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=bias, **weight_args
|
||||
),
|
||||
)
|
||||
self.out = nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward pass of the PatchEmbed module.
|
||||
|
||||
Parameters:
|
||||
- x (torch.Tensor): The input tensor of shape (B, C, T, H, W) where
|
||||
B is the batch size,
|
||||
C is the number of channels,
|
||||
T is the temporal dimension,
|
||||
H is the height, and
|
||||
W is the width of the input.
|
||||
|
||||
Returns:
|
||||
- torch.Tensor: The embedded patches as a tensor, with shape b t h w c.
|
||||
"""
|
||||
assert x.dim() == 5
|
||||
_, _, T, H, W = x.shape
|
||||
assert H % self.spatial_patch_size == 0 and W % self.spatial_patch_size == 0
|
||||
assert T % self.temporal_patch_size == 0
|
||||
x = self.proj(x)
|
||||
return self.out(x)
|
||||
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
"""
|
||||
The final layer of video DiT.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
spatial_patch_size,
|
||||
temporal_patch_size,
|
||||
out_channels,
|
||||
use_adaln_lora: bool = False,
|
||||
adaln_lora_dim: int = 256,
|
||||
weight_args={},
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **weight_args)
|
||||
self.linear = operations.Linear(
|
||||
hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, **weight_args
|
||||
)
|
||||
self.hidden_size = hidden_size
|
||||
self.n_adaln_chunks = 2
|
||||
self.use_adaln_lora = use_adaln_lora
|
||||
if use_adaln_lora:
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(hidden_size, adaln_lora_dim, bias=False, **weight_args),
|
||||
operations.Linear(adaln_lora_dim, self.n_adaln_chunks * hidden_size, bias=False, **weight_args),
|
||||
)
|
||||
else:
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(), operations.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False, **weight_args)
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x_BT_HW_D,
|
||||
emb_B_D,
|
||||
adaln_lora_B_3D: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if self.use_adaln_lora:
|
||||
assert adaln_lora_B_3D is not None
|
||||
shift_B_D, scale_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D[:, : 2 * self.hidden_size]).chunk(
|
||||
2, dim=1
|
||||
)
|
||||
else:
|
||||
shift_B_D, scale_B_D = self.adaLN_modulation(emb_B_D).chunk(2, dim=1)
|
||||
|
||||
B = emb_B_D.shape[0]
|
||||
T = x_BT_HW_D.shape[0] // B
|
||||
shift_BT_D, scale_BT_D = repeat(shift_B_D, "b d -> (b t) d", t=T), repeat(scale_B_D, "b d -> (b t) d", t=T)
|
||||
x_BT_HW_D = modulate(self.norm_final(x_BT_HW_D), shift_BT_D, scale_BT_D)
|
||||
|
||||
x_BT_HW_D = self.linear(x_BT_HW_D)
|
||||
return x_BT_HW_D
|
||||
|
||||
|
||||
class VideoAttn(nn.Module):
|
||||
"""
|
||||
Implements video attention with optional cross-attention capabilities.
|
||||
|
||||
This module processes video features while maintaining their spatio-temporal structure. It can perform
|
||||
self-attention within the video features or cross-attention with external context features.
|
||||
|
||||
Parameters:
|
||||
x_dim (int): Dimension of input feature vectors
|
||||
context_dim (Optional[int]): Dimension of context features for cross-attention. None for self-attention
|
||||
num_heads (int): Number of attention heads
|
||||
bias (bool): Whether to include bias in attention projections. Default: False
|
||||
qkv_norm_mode (str): Normalization mode for query/key/value projections. Must be "per_head". Default: "per_head"
|
||||
x_format (str): Format of input tensor. Must be "BTHWD". Default: "BTHWD"
|
||||
|
||||
Input shape:
|
||||
- x: (T, H, W, B, D) video features
|
||||
- context (optional): (M, B, D) context features for cross-attention
|
||||
where:
|
||||
T: temporal dimension
|
||||
H: height
|
||||
W: width
|
||||
B: batch size
|
||||
D: feature dimension
|
||||
M: context sequence length
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
x_dim: int,
|
||||
context_dim: Optional[int],
|
||||
num_heads: int,
|
||||
bias: bool = False,
|
||||
qkv_norm_mode: str = "per_head",
|
||||
x_format: str = "BTHWD",
|
||||
weight_args={},
|
||||
operations=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.x_format = x_format
|
||||
|
||||
self.attn = Attention(
|
||||
x_dim,
|
||||
context_dim,
|
||||
num_heads,
|
||||
x_dim // num_heads,
|
||||
qkv_bias=bias,
|
||||
qkv_norm="RRI",
|
||||
out_bias=bias,
|
||||
qkv_norm_mode=qkv_norm_mode,
|
||||
qkv_format="sbhd",
|
||||
weight_args=weight_args,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
crossattn_mask: Optional[torch.Tensor] = None,
|
||||
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass for video attention.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input tensor of shape (B, T, H, W, D) or (T, H, W, B, D) representing batches of video data.
|
||||
context (Tensor): Context tensor of shape (B, M, D) or (M, B, D),
|
||||
where M is the sequence length of the context.
|
||||
crossattn_mask (Optional[Tensor]): An optional mask for cross-attention mechanisms.
|
||||
rope_emb_L_1_1_D (Optional[Tensor]):
|
||||
Rotary positional embedding tensor of shape (L, 1, 1, D). L == THW for current video training.
|
||||
|
||||
Returns:
|
||||
Tensor: The output tensor with applied attention, maintaining the input shape.
|
||||
"""
|
||||
|
||||
x_T_H_W_B_D = x
|
||||
context_M_B_D = context
|
||||
T, H, W, B, D = x_T_H_W_B_D.shape
|
||||
x_THW_B_D = rearrange(x_T_H_W_B_D, "t h w b d -> (t h w) b d")
|
||||
x_THW_B_D = self.attn(
|
||||
x_THW_B_D,
|
||||
context_M_B_D,
|
||||
crossattn_mask,
|
||||
rope_emb=rope_emb_L_1_1_D,
|
||||
)
|
||||
x_T_H_W_B_D = rearrange(x_THW_B_D, "(t h w) b d -> t h w b d", h=H, w=W)
|
||||
return x_T_H_W_B_D
|
||||
|
||||
|
||||
def adaln_norm_state(norm_state, x, scale, shift):
|
||||
normalized = norm_state(x)
|
||||
return normalized * (1 + scale) + shift
|
||||
|
||||
|
||||
class DITBuildingBlock(nn.Module):
|
||||
"""
|
||||
A building block for the DiT (Diffusion Transformer) architecture that supports different types of
|
||||
attention and MLP operations with adaptive layer normalization.
|
||||
|
||||
Parameters:
|
||||
block_type (str): Type of block - one of:
|
||||
- "cross_attn"/"ca": Cross-attention
|
||||
- "full_attn"/"fa": Full self-attention
|
||||
- "mlp"/"ff": MLP/feedforward block
|
||||
x_dim (int): Dimension of input features
|
||||
context_dim (Optional[int]): Dimension of context features for cross-attention
|
||||
num_heads (int): Number of attention heads
|
||||
mlp_ratio (float): MLP hidden dimension multiplier. Default: 4.0
|
||||
bias (bool): Whether to use bias in layers. Default: False
|
||||
mlp_dropout (float): Dropout rate for MLP. Default: 0.0
|
||||
qkv_norm_mode (str): QKV normalization mode. Default: "per_head"
|
||||
x_format (str): Input tensor format. Default: "BTHWD"
|
||||
use_adaln_lora (bool): Whether to use AdaLN-LoRA. Default: False
|
||||
adaln_lora_dim (int): Dimension for AdaLN-LoRA. Default: 256
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
block_type: str,
|
||||
x_dim: int,
|
||||
context_dim: Optional[int],
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
bias: bool = False,
|
||||
mlp_dropout: float = 0.0,
|
||||
qkv_norm_mode: str = "per_head",
|
||||
x_format: str = "BTHWD",
|
||||
use_adaln_lora: bool = False,
|
||||
adaln_lora_dim: int = 256,
|
||||
weight_args={},
|
||||
operations=None
|
||||
) -> None:
|
||||
block_type = block_type.lower()
|
||||
|
||||
super().__init__()
|
||||
self.x_format = x_format
|
||||
if block_type in ["cross_attn", "ca"]:
|
||||
self.block = VideoAttn(
|
||||
x_dim,
|
||||
context_dim,
|
||||
num_heads,
|
||||
bias=bias,
|
||||
qkv_norm_mode=qkv_norm_mode,
|
||||
x_format=self.x_format,
|
||||
weight_args=weight_args,
|
||||
operations=operations,
|
||||
)
|
||||
elif block_type in ["full_attn", "fa"]:
|
||||
self.block = VideoAttn(
|
||||
x_dim, None, num_heads, bias=bias, qkv_norm_mode=qkv_norm_mode, x_format=self.x_format, weight_args=weight_args, operations=operations
|
||||
)
|
||||
elif block_type in ["mlp", "ff"]:
|
||||
self.block = GPT2FeedForward(x_dim, int(x_dim * mlp_ratio), dropout=mlp_dropout, bias=bias, weight_args=weight_args, operations=operations)
|
||||
else:
|
||||
raise ValueError(f"Unknown block type: {block_type}")
|
||||
|
||||
self.block_type = block_type
|
||||
self.use_adaln_lora = use_adaln_lora
|
||||
|
||||
self.norm_state = nn.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6)
|
||||
self.n_adaln_chunks = 3
|
||||
if use_adaln_lora:
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(x_dim, adaln_lora_dim, bias=False, **weight_args),
|
||||
operations.Linear(adaln_lora_dim, self.n_adaln_chunks * x_dim, bias=False, **weight_args),
|
||||
)
|
||||
else:
|
||||
self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, self.n_adaln_chunks * x_dim, bias=False, **weight_args))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
emb_B_D: torch.Tensor,
|
||||
crossattn_emb: torch.Tensor,
|
||||
crossattn_mask: Optional[torch.Tensor] = None,
|
||||
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
|
||||
adaln_lora_B_3D: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass for dynamically configured blocks with adaptive normalization.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input tensor of shape (B, T, H, W, D) or (T, H, W, B, D).
|
||||
emb_B_D (Tensor): Embedding tensor for adaptive layer normalization modulation.
|
||||
crossattn_emb (Tensor): Tensor for cross-attention blocks.
|
||||
crossattn_mask (Optional[Tensor]): Optional mask for cross-attention.
|
||||
rope_emb_L_1_1_D (Optional[Tensor]):
|
||||
Rotary positional embedding tensor of shape (L, 1, 1, D). L == THW for current video training.
|
||||
|
||||
Returns:
|
||||
Tensor: The output tensor after processing through the configured block and adaptive normalization.
|
||||
"""
|
||||
if self.use_adaln_lora:
|
||||
shift_B_D, scale_B_D, gate_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D).chunk(
|
||||
self.n_adaln_chunks, dim=1
|
||||
)
|
||||
else:
|
||||
shift_B_D, scale_B_D, gate_B_D = self.adaLN_modulation(emb_B_D).chunk(self.n_adaln_chunks, dim=1)
|
||||
|
||||
shift_1_1_1_B_D, scale_1_1_1_B_D, gate_1_1_1_B_D = (
|
||||
shift_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0),
|
||||
scale_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0),
|
||||
gate_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0),
|
||||
)
|
||||
|
||||
if self.block_type in ["mlp", "ff"]:
|
||||
x = x + gate_1_1_1_B_D * self.block(
|
||||
adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D),
|
||||
)
|
||||
elif self.block_type in ["full_attn", "fa"]:
|
||||
x = x + gate_1_1_1_B_D * self.block(
|
||||
adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D),
|
||||
context=None,
|
||||
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
|
||||
)
|
||||
elif self.block_type in ["cross_attn", "ca"]:
|
||||
x = x + gate_1_1_1_B_D * self.block(
|
||||
adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D),
|
||||
context=crossattn_emb,
|
||||
crossattn_mask=crossattn_mask,
|
||||
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown block type: {self.block_type}")
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class GeneralDITTransformerBlock(nn.Module):
|
||||
"""
|
||||
A wrapper module that manages a sequence of DITBuildingBlocks to form a complete transformer layer.
|
||||
Each block in the sequence is specified by a block configuration string.
|
||||
|
||||
Parameters:
|
||||
x_dim (int): Dimension of input features
|
||||
context_dim (int): Dimension of context features for cross-attention blocks
|
||||
num_heads (int): Number of attention heads
|
||||
block_config (str): String specifying block sequence (e.g. "ca-fa-mlp" for cross-attention,
|
||||
full-attention, then MLP)
|
||||
mlp_ratio (float): MLP hidden dimension multiplier. Default: 4.0
|
||||
x_format (str): Input tensor format. Default: "BTHWD"
|
||||
use_adaln_lora (bool): Whether to use AdaLN-LoRA. Default: False
|
||||
adaln_lora_dim (int): Dimension for AdaLN-LoRA. Default: 256
|
||||
|
||||
The block_config string uses "-" to separate block types:
|
||||
- "ca"/"cross_attn": Cross-attention block
|
||||
- "fa"/"full_attn": Full self-attention block
|
||||
- "mlp"/"ff": MLP/feedforward block
|
||||
|
||||
Example:
|
||||
block_config = "ca-fa-mlp" creates a sequence of:
|
||||
1. Cross-attention block
|
||||
2. Full self-attention block
|
||||
3. MLP block
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
x_dim: int,
|
||||
context_dim: int,
|
||||
num_heads: int,
|
||||
block_config: str,
|
||||
mlp_ratio: float = 4.0,
|
||||
x_format: str = "BTHWD",
|
||||
use_adaln_lora: bool = False,
|
||||
adaln_lora_dim: int = 256,
|
||||
weight_args={},
|
||||
operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.blocks = nn.ModuleList()
|
||||
self.x_format = x_format
|
||||
for block_type in block_config.split("-"):
|
||||
self.blocks.append(
|
||||
DITBuildingBlock(
|
||||
block_type,
|
||||
x_dim,
|
||||
context_dim,
|
||||
num_heads,
|
||||
mlp_ratio,
|
||||
x_format=self.x_format,
|
||||
use_adaln_lora=use_adaln_lora,
|
||||
adaln_lora_dim=adaln_lora_dim,
|
||||
weight_args=weight_args,
|
||||
operations=operations,
|
||||
)
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
emb_B_D: torch.Tensor,
|
||||
crossattn_emb: torch.Tensor,
|
||||
crossattn_mask: Optional[torch.Tensor] = None,
|
||||
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
|
||||
adaln_lora_B_3D: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
for block in self.blocks:
|
||||
x = block(
|
||||
x,
|
||||
emb_B_D,
|
||||
crossattn_emb,
|
||||
crossattn_mask,
|
||||
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
|
||||
adaln_lora_B_3D=adaln_lora_B_3D,
|
||||
)
|
||||
return x
|
||||
1041
comfy/ldm/cosmos/cosmos_tokenizer/layers3d.py
Normal file
1041
comfy/ldm/cosmos/cosmos_tokenizer/layers3d.py
Normal file
File diff suppressed because it is too large
Load Diff
377
comfy/ldm/cosmos/cosmos_tokenizer/patching.py
Normal file
377
comfy/ldm/cosmos/cosmos_tokenizer/patching.py
Normal file
@@ -0,0 +1,377 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""The patcher and unpatcher implementation for 2D and 3D data.
|
||||
|
||||
The idea of Haar wavelet is to compute LL, LH, HL, HH component as two 1D convolutions.
|
||||
One on the rows and one on the columns.
|
||||
For example, in 1D signal, we have [a, b], then the low-freq compoenent is [a + b] / 2 and high-freq is [a - b] / 2.
|
||||
We can use a 1D convolution with kernel [1, 1] and stride 2 to represent the L component.
|
||||
For H component, we can use a 1D convolution with kernel [1, -1] and stride 2.
|
||||
Although in principle, we typically only do additional Haar wavelet over the LL component. But here we do it for all
|
||||
as we need to support downsampling for more than 2x.
|
||||
For example, 4x downsampling can be done by 2x Haar and additional 2x Haar, and the shape would be.
|
||||
[3, 256, 256] -> [12, 128, 128] -> [48, 64, 64]
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
_WAVELETS = {
|
||||
"haar": torch.tensor([0.7071067811865476, 0.7071067811865476]),
|
||||
"rearrange": torch.tensor([1.0, 1.0]),
|
||||
}
|
||||
_PERSISTENT = False
|
||||
|
||||
|
||||
class Patcher(torch.nn.Module):
|
||||
"""A module to convert image tensors into patches using torch operations.
|
||||
|
||||
The main difference from `class Patching` is that this module implements
|
||||
all operations using torch, rather than python or numpy, for efficiency purpose.
|
||||
|
||||
It's bit-wise identical to the Patching module outputs, with the added
|
||||
benefit of being torch.jit scriptable.
|
||||
"""
|
||||
|
||||
def __init__(self, patch_size=1, patch_method="haar"):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.patch_method = patch_method
|
||||
self.register_buffer(
|
||||
"wavelets", _WAVELETS[patch_method], persistent=_PERSISTENT
|
||||
)
|
||||
self.range = range(int(torch.log2(torch.tensor(self.patch_size)).item()))
|
||||
self.register_buffer(
|
||||
"_arange",
|
||||
torch.arange(_WAVELETS[patch_method].shape[0]),
|
||||
persistent=_PERSISTENT,
|
||||
)
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, x):
|
||||
if self.patch_method == "haar":
|
||||
return self._haar(x)
|
||||
elif self.patch_method == "rearrange":
|
||||
return self._arrange(x)
|
||||
else:
|
||||
raise ValueError("Unknown patch method: " + self.patch_method)
|
||||
|
||||
def _dwt(self, x, mode="reflect", rescale=False):
|
||||
dtype = x.dtype
|
||||
h = self.wavelets.to(device=x.device)
|
||||
|
||||
n = h.shape[0]
|
||||
g = x.shape[1]
|
||||
hl = h.flip(0).reshape(1, 1, -1).repeat(g, 1, 1)
|
||||
hh = (h * ((-1) ** self._arange.to(device=x.device))).reshape(1, 1, -1).repeat(g, 1, 1)
|
||||
hh = hh.to(dtype=dtype)
|
||||
hl = hl.to(dtype=dtype)
|
||||
|
||||
x = F.pad(x, pad=(n - 2, n - 1, n - 2, n - 1), mode=mode).to(dtype)
|
||||
xl = F.conv2d(x, hl.unsqueeze(2), groups=g, stride=(1, 2))
|
||||
xh = F.conv2d(x, hh.unsqueeze(2), groups=g, stride=(1, 2))
|
||||
xll = F.conv2d(xl, hl.unsqueeze(3), groups=g, stride=(2, 1))
|
||||
xlh = F.conv2d(xl, hh.unsqueeze(3), groups=g, stride=(2, 1))
|
||||
xhl = F.conv2d(xh, hl.unsqueeze(3), groups=g, stride=(2, 1))
|
||||
xhh = F.conv2d(xh, hh.unsqueeze(3), groups=g, stride=(2, 1))
|
||||
|
||||
out = torch.cat([xll, xlh, xhl, xhh], dim=1)
|
||||
if rescale:
|
||||
out = out / 2
|
||||
return out
|
||||
|
||||
def _haar(self, x):
|
||||
for _ in self.range:
|
||||
x = self._dwt(x, rescale=True)
|
||||
return x
|
||||
|
||||
def _arrange(self, x):
|
||||
x = rearrange(
|
||||
x,
|
||||
"b c (h p1) (w p2) -> b (c p1 p2) h w",
|
||||
p1=self.patch_size,
|
||||
p2=self.patch_size,
|
||||
).contiguous()
|
||||
return x
|
||||
|
||||
|
||||
class Patcher3D(Patcher):
|
||||
"""A 3D discrete wavelet transform for video data, expects 5D tensor, i.e. a batch of videos."""
|
||||
|
||||
def __init__(self, patch_size=1, patch_method="haar"):
|
||||
super().__init__(patch_method=patch_method, patch_size=patch_size)
|
||||
self.register_buffer(
|
||||
"patch_size_buffer",
|
||||
patch_size * torch.ones([1], dtype=torch.int32),
|
||||
persistent=_PERSISTENT,
|
||||
)
|
||||
|
||||
def _dwt(self, x, wavelet, mode="reflect", rescale=False):
|
||||
dtype = x.dtype
|
||||
h = self.wavelets.to(device=x.device)
|
||||
|
||||
n = h.shape[0]
|
||||
g = x.shape[1]
|
||||
hl = h.flip(0).reshape(1, 1, -1).repeat(g, 1, 1)
|
||||
hh = (h * ((-1) ** self._arange.to(device=x.device))).reshape(1, 1, -1).repeat(g, 1, 1)
|
||||
hh = hh.to(dtype=dtype)
|
||||
hl = hl.to(dtype=dtype)
|
||||
|
||||
# Handles temporal axis.
|
||||
x = F.pad(
|
||||
x, pad=(max(0, n - 2), n - 1, n - 2, n - 1, n - 2, n - 1), mode=mode
|
||||
).to(dtype)
|
||||
xl = F.conv3d(x, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1))
|
||||
xh = F.conv3d(x, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1))
|
||||
|
||||
# Handles spatial axes.
|
||||
xll = F.conv3d(xl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
|
||||
xlh = F.conv3d(xl, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
|
||||
xhl = F.conv3d(xh, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
|
||||
xhh = F.conv3d(xh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
|
||||
|
||||
xlll = F.conv3d(xll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
|
||||
xllh = F.conv3d(xll, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
|
||||
xlhl = F.conv3d(xlh, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
|
||||
xlhh = F.conv3d(xlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
|
||||
xhll = F.conv3d(xhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
|
||||
xhlh = F.conv3d(xhl, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
|
||||
xhhl = F.conv3d(xhh, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
|
||||
xhhh = F.conv3d(xhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
|
||||
|
||||
out = torch.cat([xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh], dim=1)
|
||||
if rescale:
|
||||
out = out / (2 * torch.sqrt(torch.tensor(2.0)))
|
||||
return out
|
||||
|
||||
def _haar(self, x):
|
||||
xi, xv = torch.split(x, [1, x.shape[2] - 1], dim=2)
|
||||
x = torch.cat([xi.repeat_interleave(self.patch_size, dim=2), xv], dim=2)
|
||||
for _ in self.range:
|
||||
x = self._dwt(x, "haar", rescale=True)
|
||||
return x
|
||||
|
||||
def _arrange(self, x):
|
||||
xi, xv = torch.split(x, [1, x.shape[2] - 1], dim=2)
|
||||
x = torch.cat([xi.repeat_interleave(self.patch_size, dim=2), xv], dim=2)
|
||||
x = rearrange(
|
||||
x,
|
||||
"b c (t p1) (h p2) (w p3) -> b (c p1 p2 p3) t h w",
|
||||
p1=self.patch_size,
|
||||
p2=self.patch_size,
|
||||
p3=self.patch_size,
|
||||
).contiguous()
|
||||
return x
|
||||
|
||||
|
||||
class UnPatcher(torch.nn.Module):
|
||||
"""A module to convert patches into image tensorsusing torch operations.
|
||||
|
||||
The main difference from `class Unpatching` is that this module implements
|
||||
all operations using torch, rather than python or numpy, for efficiency purpose.
|
||||
|
||||
It's bit-wise identical to the Unpatching module outputs, with the added
|
||||
benefit of being torch.jit scriptable.
|
||||
"""
|
||||
|
||||
def __init__(self, patch_size=1, patch_method="haar"):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.patch_method = patch_method
|
||||
self.register_buffer(
|
||||
"wavelets", _WAVELETS[patch_method], persistent=_PERSISTENT
|
||||
)
|
||||
self.range = range(int(torch.log2(torch.tensor(self.patch_size)).item()))
|
||||
self.register_buffer(
|
||||
"_arange",
|
||||
torch.arange(_WAVELETS[patch_method].shape[0]),
|
||||
persistent=_PERSISTENT,
|
||||
)
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, x):
|
||||
if self.patch_method == "haar":
|
||||
return self._ihaar(x)
|
||||
elif self.patch_method == "rearrange":
|
||||
return self._iarrange(x)
|
||||
else:
|
||||
raise ValueError("Unknown patch method: " + self.patch_method)
|
||||
|
||||
def _idwt(self, x, wavelet="haar", mode="reflect", rescale=False):
|
||||
dtype = x.dtype
|
||||
h = self.wavelets.to(device=x.device)
|
||||
n = h.shape[0]
|
||||
|
||||
g = x.shape[1] // 4
|
||||
hl = h.flip([0]).reshape(1, 1, -1).repeat([g, 1, 1])
|
||||
hh = (h * ((-1) ** self._arange.to(device=x.device))).reshape(1, 1, -1).repeat(g, 1, 1)
|
||||
hh = hh.to(dtype=dtype)
|
||||
hl = hl.to(dtype=dtype)
|
||||
|
||||
xll, xlh, xhl, xhh = torch.chunk(x.to(dtype), 4, dim=1)
|
||||
|
||||
# Inverse transform.
|
||||
yl = torch.nn.functional.conv_transpose2d(
|
||||
xll, hl.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)
|
||||
)
|
||||
yl += torch.nn.functional.conv_transpose2d(
|
||||
xlh, hh.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)
|
||||
)
|
||||
yh = torch.nn.functional.conv_transpose2d(
|
||||
xhl, hl.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)
|
||||
)
|
||||
yh += torch.nn.functional.conv_transpose2d(
|
||||
xhh, hh.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)
|
||||
)
|
||||
y = torch.nn.functional.conv_transpose2d(
|
||||
yl, hl.unsqueeze(2), groups=g, stride=(1, 2), padding=(0, n - 2)
|
||||
)
|
||||
y += torch.nn.functional.conv_transpose2d(
|
||||
yh, hh.unsqueeze(2), groups=g, stride=(1, 2), padding=(0, n - 2)
|
||||
)
|
||||
|
||||
if rescale:
|
||||
y = y * 2
|
||||
return y
|
||||
|
||||
def _ihaar(self, x):
|
||||
for _ in self.range:
|
||||
x = self._idwt(x, "haar", rescale=True)
|
||||
return x
|
||||
|
||||
def _iarrange(self, x):
|
||||
x = rearrange(
|
||||
x,
|
||||
"b (c p1 p2) h w -> b c (h p1) (w p2)",
|
||||
p1=self.patch_size,
|
||||
p2=self.patch_size,
|
||||
)
|
||||
return x
|
||||
|
||||
|
||||
class UnPatcher3D(UnPatcher):
|
||||
"""A 3D inverse discrete wavelet transform for video wavelet decompositions."""
|
||||
|
||||
def __init__(self, patch_size=1, patch_method="haar"):
|
||||
super().__init__(patch_method=patch_method, patch_size=patch_size)
|
||||
|
||||
def _idwt(self, x, wavelet="haar", mode="reflect", rescale=False):
|
||||
dtype = x.dtype
|
||||
h = self.wavelets.to(device=x.device)
|
||||
|
||||
g = x.shape[1] // 8 # split into 8 spatio-temporal filtered tesnors.
|
||||
hl = h.flip([0]).reshape(1, 1, -1).repeat([g, 1, 1])
|
||||
hh = (h * ((-1) ** self._arange.to(device=x.device))).reshape(1, 1, -1).repeat(g, 1, 1)
|
||||
hl = hl.to(dtype=dtype)
|
||||
hh = hh.to(dtype=dtype)
|
||||
|
||||
xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh = torch.chunk(x, 8, dim=1)
|
||||
del x
|
||||
|
||||
# Height height transposed convolutions.
|
||||
xll = F.conv_transpose3d(
|
||||
xlll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
|
||||
)
|
||||
del xlll
|
||||
|
||||
xll += F.conv_transpose3d(
|
||||
xllh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
|
||||
)
|
||||
del xllh
|
||||
|
||||
xlh = F.conv_transpose3d(
|
||||
xlhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
|
||||
)
|
||||
del xlhl
|
||||
|
||||
xlh += F.conv_transpose3d(
|
||||
xlhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
|
||||
)
|
||||
del xlhh
|
||||
|
||||
xhl = F.conv_transpose3d(
|
||||
xhll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
|
||||
)
|
||||
del xhll
|
||||
|
||||
xhl += F.conv_transpose3d(
|
||||
xhlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
|
||||
)
|
||||
del xhlh
|
||||
|
||||
xhh = F.conv_transpose3d(
|
||||
xhhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
|
||||
)
|
||||
del xhhl
|
||||
|
||||
xhh += F.conv_transpose3d(
|
||||
xhhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
|
||||
)
|
||||
del xhhh
|
||||
|
||||
# Handles width transposed convolutions.
|
||||
xl = F.conv_transpose3d(
|
||||
xll, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
|
||||
)
|
||||
del xll
|
||||
|
||||
xl += F.conv_transpose3d(
|
||||
xlh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
|
||||
)
|
||||
del xlh
|
||||
|
||||
xh = F.conv_transpose3d(
|
||||
xhl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
|
||||
)
|
||||
del xhl
|
||||
|
||||
xh += F.conv_transpose3d(
|
||||
xhh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
|
||||
)
|
||||
del xhh
|
||||
|
||||
# Handles time axis transposed convolutions.
|
||||
x = F.conv_transpose3d(
|
||||
xl, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)
|
||||
)
|
||||
del xl
|
||||
|
||||
x += F.conv_transpose3d(
|
||||
xh, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)
|
||||
)
|
||||
|
||||
if rescale:
|
||||
x = x * (2 * torch.sqrt(torch.tensor(2.0)))
|
||||
return x
|
||||
|
||||
def _ihaar(self, x):
|
||||
for _ in self.range:
|
||||
x = self._idwt(x, "haar", rescale=True)
|
||||
x = x[:, :, self.patch_size - 1 :, ...]
|
||||
return x
|
||||
|
||||
def _iarrange(self, x):
|
||||
x = rearrange(
|
||||
x,
|
||||
"b (c p1 p2 p3) t h w -> b c (t p1) (h p2) (w p3)",
|
||||
p1=self.patch_size,
|
||||
p2=self.patch_size,
|
||||
p3=self.patch_size,
|
||||
)
|
||||
x = x[:, :, self.patch_size - 1 :, ...]
|
||||
return x
|
||||
112
comfy/ldm/cosmos/cosmos_tokenizer/utils.py
Normal file
112
comfy/ldm/cosmos/cosmos_tokenizer/utils.py
Normal file
@@ -0,0 +1,112 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Shared utilities for the networks module."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
def time2batch(x: torch.Tensor) -> tuple[torch.Tensor, int]:
|
||||
batch_size = x.shape[0]
|
||||
return rearrange(x, "b c t h w -> (b t) c h w"), batch_size
|
||||
|
||||
|
||||
def batch2time(x: torch.Tensor, batch_size: int) -> torch.Tensor:
|
||||
return rearrange(x, "(b t) c h w -> b c t h w", b=batch_size)
|
||||
|
||||
|
||||
def space2batch(x: torch.Tensor) -> tuple[torch.Tensor, int]:
|
||||
batch_size, height = x.shape[0], x.shape[-2]
|
||||
return rearrange(x, "b c t h w -> (b h w) c t"), batch_size, height
|
||||
|
||||
|
||||
def batch2space(x: torch.Tensor, batch_size: int, height: int) -> torch.Tensor:
|
||||
return rearrange(x, "(b h w) c t -> b c t h w", b=batch_size, h=height)
|
||||
|
||||
|
||||
def cast_tuple(t: Any, length: int = 1) -> Any:
|
||||
return t if isinstance(t, tuple) else ((t,) * length)
|
||||
|
||||
|
||||
def replication_pad(x):
|
||||
return torch.cat([x[:, :, :1, ...], x], dim=2)
|
||||
|
||||
|
||||
def divisible_by(num: int, den: int) -> bool:
|
||||
return (num % den) == 0
|
||||
|
||||
|
||||
def is_odd(n: int) -> bool:
|
||||
return not divisible_by(n, 2)
|
||||
|
||||
|
||||
def nonlinearity(x):
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
def Normalize(in_channels, num_groups=32):
|
||||
return ops.GroupNorm(
|
||||
num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
|
||||
)
|
||||
|
||||
|
||||
class CausalNormalize(torch.nn.Module):
|
||||
def __init__(self, in_channels, num_groups=1):
|
||||
super().__init__()
|
||||
self.norm = ops.GroupNorm(
|
||||
num_groups=num_groups,
|
||||
num_channels=in_channels,
|
||||
eps=1e-6,
|
||||
affine=True,
|
||||
)
|
||||
self.num_groups = num_groups
|
||||
|
||||
def forward(self, x):
|
||||
# if num_groups !=1, we apply a spatio-temporal groupnorm for backward compatibility purpose.
|
||||
# All new models should use num_groups=1, otherwise causality is not guaranteed.
|
||||
if self.num_groups == 1:
|
||||
x, batch_size = time2batch(x)
|
||||
return batch2time(self.norm(x), batch_size)
|
||||
return self.norm(x)
|
||||
|
||||
|
||||
def exists(v):
|
||||
return v is not None
|
||||
|
||||
|
||||
def default(*args):
|
||||
for arg in args:
|
||||
if exists(arg):
|
||||
return arg
|
||||
return None
|
||||
|
||||
|
||||
def round_ste(z: torch.Tensor) -> torch.Tensor:
|
||||
"""Round with straight through gradients."""
|
||||
zhat = z.round()
|
||||
return z + (zhat - z).detach()
|
||||
|
||||
|
||||
def log(t, eps=1e-5):
|
||||
return t.clamp(min=eps).log()
|
||||
|
||||
|
||||
def entropy(prob):
|
||||
return (-prob * log(prob)).sum(dim=-1)
|
||||
512
comfy/ldm/cosmos/model.py
Normal file
512
comfy/ldm/cosmos/model.py
Normal file
@@ -0,0 +1,512 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing.
|
||||
"""
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from torch import nn
|
||||
from torchvision import transforms
|
||||
|
||||
from enum import Enum
|
||||
import logging
|
||||
|
||||
from .blocks import (
|
||||
FinalLayer,
|
||||
GeneralDITTransformerBlock,
|
||||
PatchEmbed,
|
||||
TimestepEmbedding,
|
||||
Timesteps,
|
||||
)
|
||||
|
||||
from .position_embedding import LearnablePosEmbAxis, VideoRopePosition3DEmb
|
||||
|
||||
|
||||
class DataType(Enum):
|
||||
IMAGE = "image"
|
||||
VIDEO = "video"
|
||||
|
||||
|
||||
class GeneralDIT(nn.Module):
|
||||
"""
|
||||
A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing.
|
||||
|
||||
Args:
|
||||
max_img_h (int): Maximum height of the input images.
|
||||
max_img_w (int): Maximum width of the input images.
|
||||
max_frames (int): Maximum number of frames in the video sequence.
|
||||
in_channels (int): Number of input channels (e.g., RGB channels for color images).
|
||||
out_channels (int): Number of output channels.
|
||||
patch_spatial (tuple): Spatial resolution of patches for input processing.
|
||||
patch_temporal (int): Temporal resolution of patches for input processing.
|
||||
concat_padding_mask (bool): If True, includes a mask channel in the input to handle padding.
|
||||
block_config (str): Configuration of the transformer block. See Notes for supported block types.
|
||||
model_channels (int): Base number of channels used throughout the model.
|
||||
num_blocks (int): Number of transformer blocks.
|
||||
num_heads (int): Number of heads in the multi-head attention layers.
|
||||
mlp_ratio (float): Expansion ratio for MLP blocks.
|
||||
block_x_format (str): Format of input tensor for transformer blocks ('BTHWD' or 'THWBD').
|
||||
crossattn_emb_channels (int): Number of embedding channels for cross-attention.
|
||||
use_cross_attn_mask (bool): Whether to use mask in cross-attention.
|
||||
pos_emb_cls (str): Type of positional embeddings.
|
||||
pos_emb_learnable (bool): Whether positional embeddings are learnable.
|
||||
pos_emb_interpolation (str): Method for interpolating positional embeddings.
|
||||
affline_emb_norm (bool): Whether to normalize affine embeddings.
|
||||
use_adaln_lora (bool): Whether to use AdaLN-LoRA.
|
||||
adaln_lora_dim (int): Dimension for AdaLN-LoRA.
|
||||
rope_h_extrapolation_ratio (float): Height extrapolation ratio for RoPE.
|
||||
rope_w_extrapolation_ratio (float): Width extrapolation ratio for RoPE.
|
||||
rope_t_extrapolation_ratio (float): Temporal extrapolation ratio for RoPE.
|
||||
extra_per_block_abs_pos_emb (bool): Whether to use extra per-block absolute positional embeddings.
|
||||
extra_per_block_abs_pos_emb_type (str): Type of extra per-block positional embeddings.
|
||||
extra_h_extrapolation_ratio (float): Height extrapolation ratio for extra embeddings.
|
||||
extra_w_extrapolation_ratio (float): Width extrapolation ratio for extra embeddings.
|
||||
extra_t_extrapolation_ratio (float): Temporal extrapolation ratio for extra embeddings.
|
||||
|
||||
Notes:
|
||||
Supported block types in block_config:
|
||||
* cross_attn, ca: Cross attention
|
||||
* full_attn: Full attention on all flattened tokens
|
||||
* mlp, ff: Feed forward block
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_img_h: int,
|
||||
max_img_w: int,
|
||||
max_frames: int,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
patch_spatial: tuple,
|
||||
patch_temporal: int,
|
||||
concat_padding_mask: bool = True,
|
||||
# attention settings
|
||||
block_config: str = "FA-CA-MLP",
|
||||
model_channels: int = 768,
|
||||
num_blocks: int = 10,
|
||||
num_heads: int = 16,
|
||||
mlp_ratio: float = 4.0,
|
||||
block_x_format: str = "BTHWD",
|
||||
# cross attention settings
|
||||
crossattn_emb_channels: int = 1024,
|
||||
use_cross_attn_mask: bool = False,
|
||||
# positional embedding settings
|
||||
pos_emb_cls: str = "sincos",
|
||||
pos_emb_learnable: bool = False,
|
||||
pos_emb_interpolation: str = "crop",
|
||||
affline_emb_norm: bool = False, # whether or not to normalize the affine embedding
|
||||
use_adaln_lora: bool = False,
|
||||
adaln_lora_dim: int = 256,
|
||||
rope_h_extrapolation_ratio: float = 1.0,
|
||||
rope_w_extrapolation_ratio: float = 1.0,
|
||||
rope_t_extrapolation_ratio: float = 1.0,
|
||||
extra_per_block_abs_pos_emb: bool = False,
|
||||
extra_per_block_abs_pos_emb_type: str = "sincos",
|
||||
extra_h_extrapolation_ratio: float = 1.0,
|
||||
extra_w_extrapolation_ratio: float = 1.0,
|
||||
extra_t_extrapolation_ratio: float = 1.0,
|
||||
image_model=None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.max_img_h = max_img_h
|
||||
self.max_img_w = max_img_w
|
||||
self.max_frames = max_frames
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.patch_spatial = patch_spatial
|
||||
self.patch_temporal = patch_temporal
|
||||
self.num_heads = num_heads
|
||||
self.num_blocks = num_blocks
|
||||
self.model_channels = model_channels
|
||||
self.use_cross_attn_mask = use_cross_attn_mask
|
||||
self.concat_padding_mask = concat_padding_mask
|
||||
# positional embedding settings
|
||||
self.pos_emb_cls = pos_emb_cls
|
||||
self.pos_emb_learnable = pos_emb_learnable
|
||||
self.pos_emb_interpolation = pos_emb_interpolation
|
||||
self.affline_emb_norm = affline_emb_norm
|
||||
self.rope_h_extrapolation_ratio = rope_h_extrapolation_ratio
|
||||
self.rope_w_extrapolation_ratio = rope_w_extrapolation_ratio
|
||||
self.rope_t_extrapolation_ratio = rope_t_extrapolation_ratio
|
||||
self.extra_per_block_abs_pos_emb = extra_per_block_abs_pos_emb
|
||||
self.extra_per_block_abs_pos_emb_type = extra_per_block_abs_pos_emb_type.lower()
|
||||
self.extra_h_extrapolation_ratio = extra_h_extrapolation_ratio
|
||||
self.extra_w_extrapolation_ratio = extra_w_extrapolation_ratio
|
||||
self.extra_t_extrapolation_ratio = extra_t_extrapolation_ratio
|
||||
self.dtype = dtype
|
||||
weight_args = {"device": device, "dtype": dtype}
|
||||
|
||||
in_channels = in_channels + 1 if concat_padding_mask else in_channels
|
||||
self.x_embedder = PatchEmbed(
|
||||
spatial_patch_size=patch_spatial,
|
||||
temporal_patch_size=patch_temporal,
|
||||
in_channels=in_channels,
|
||||
out_channels=model_channels,
|
||||
bias=False,
|
||||
weight_args=weight_args,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
self.build_pos_embed(device=device, dtype=dtype)
|
||||
self.block_x_format = block_x_format
|
||||
self.use_adaln_lora = use_adaln_lora
|
||||
self.adaln_lora_dim = adaln_lora_dim
|
||||
self.t_embedder = nn.ModuleList(
|
||||
[Timesteps(model_channels),
|
||||
TimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora, weight_args=weight_args, operations=operations),]
|
||||
)
|
||||
|
||||
self.blocks = nn.ModuleDict()
|
||||
|
||||
for idx in range(num_blocks):
|
||||
self.blocks[f"block{idx}"] = GeneralDITTransformerBlock(
|
||||
x_dim=model_channels,
|
||||
context_dim=crossattn_emb_channels,
|
||||
num_heads=num_heads,
|
||||
block_config=block_config,
|
||||
mlp_ratio=mlp_ratio,
|
||||
x_format=self.block_x_format,
|
||||
use_adaln_lora=use_adaln_lora,
|
||||
adaln_lora_dim=adaln_lora_dim,
|
||||
weight_args=weight_args,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
if self.affline_emb_norm:
|
||||
logging.debug("Building affine embedding normalization layer")
|
||||
self.affline_norm = operations.RMSNorm(model_channels, elementwise_affine=True, eps=1e-6, device=device, dtype=dtype)
|
||||
else:
|
||||
self.affline_norm = nn.Identity()
|
||||
|
||||
self.final_layer = FinalLayer(
|
||||
hidden_size=self.model_channels,
|
||||
spatial_patch_size=self.patch_spatial,
|
||||
temporal_patch_size=self.patch_temporal,
|
||||
out_channels=self.out_channels,
|
||||
use_adaln_lora=self.use_adaln_lora,
|
||||
adaln_lora_dim=self.adaln_lora_dim,
|
||||
weight_args=weight_args,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
def build_pos_embed(self, device=None, dtype=None):
|
||||
if self.pos_emb_cls == "rope3d":
|
||||
cls_type = VideoRopePosition3DEmb
|
||||
else:
|
||||
raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}")
|
||||
|
||||
logging.debug(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}")
|
||||
kwargs = dict(
|
||||
model_channels=self.model_channels,
|
||||
len_h=self.max_img_h // self.patch_spatial,
|
||||
len_w=self.max_img_w // self.patch_spatial,
|
||||
len_t=self.max_frames // self.patch_temporal,
|
||||
is_learnable=self.pos_emb_learnable,
|
||||
interpolation=self.pos_emb_interpolation,
|
||||
head_dim=self.model_channels // self.num_heads,
|
||||
h_extrapolation_ratio=self.rope_h_extrapolation_ratio,
|
||||
w_extrapolation_ratio=self.rope_w_extrapolation_ratio,
|
||||
t_extrapolation_ratio=self.rope_t_extrapolation_ratio,
|
||||
device=device,
|
||||
)
|
||||
self.pos_embedder = cls_type(
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if self.extra_per_block_abs_pos_emb:
|
||||
assert self.extra_per_block_abs_pos_emb_type in [
|
||||
"learnable",
|
||||
], f"Unknown extra_per_block_abs_pos_emb_type {self.extra_per_block_abs_pos_emb_type}"
|
||||
kwargs["h_extrapolation_ratio"] = self.extra_h_extrapolation_ratio
|
||||
kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio
|
||||
kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio
|
||||
kwargs["device"] = device
|
||||
kwargs["dtype"] = dtype
|
||||
self.extra_pos_embedder = LearnablePosEmbAxis(
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def prepare_embedded_sequence(
|
||||
self,
|
||||
x_B_C_T_H_W: torch.Tensor,
|
||||
fps: Optional[torch.Tensor] = None,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
latent_condition: Optional[torch.Tensor] = None,
|
||||
latent_condition_sigma: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""
|
||||
Prepares an embedded sequence tensor by applying positional embeddings and handling padding masks.
|
||||
|
||||
Args:
|
||||
x_B_C_T_H_W (torch.Tensor): video
|
||||
fps (Optional[torch.Tensor]): Frames per second tensor to be used for positional embedding when required.
|
||||
If None, a default value (`self.base_fps`) will be used.
|
||||
padding_mask (Optional[torch.Tensor]): current it is not used
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
- A tensor of shape (B, T, H, W, D) with the embedded sequence.
|
||||
- An optional positional embedding tensor, returned only if the positional embedding class
|
||||
(`self.pos_emb_cls`) includes 'rope'. Otherwise, None.
|
||||
|
||||
Notes:
|
||||
- If `self.concat_padding_mask` is True, a padding mask channel is concatenated to the input tensor.
|
||||
- The method of applying positional embeddings depends on the value of `self.pos_emb_cls`.
|
||||
- If 'rope' is in `self.pos_emb_cls` (case insensitive), the positional embeddings are generated using
|
||||
the `self.pos_embedder` with the shape [T, H, W].
|
||||
- If "fps_aware" is in `self.pos_emb_cls`, the positional embeddings are generated using the
|
||||
`self.pos_embedder` with the fps tensor.
|
||||
- Otherwise, the positional embeddings are generated without considering fps.
|
||||
"""
|
||||
if self.concat_padding_mask:
|
||||
if padding_mask is not None:
|
||||
padding_mask = transforms.functional.resize(
|
||||
padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
|
||||
)
|
||||
else:
|
||||
padding_mask = torch.zeros((x_B_C_T_H_W.shape[0], 1, x_B_C_T_H_W.shape[-2], x_B_C_T_H_W.shape[-1]), dtype=x_B_C_T_H_W.dtype, device=x_B_C_T_H_W.device)
|
||||
|
||||
x_B_C_T_H_W = torch.cat(
|
||||
[x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1
|
||||
)
|
||||
x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W)
|
||||
|
||||
if self.extra_per_block_abs_pos_emb:
|
||||
extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device, dtype=x_B_C_T_H_W.dtype)
|
||||
else:
|
||||
extra_pos_emb = None
|
||||
|
||||
if "rope" in self.pos_emb_cls.lower():
|
||||
return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device), extra_pos_emb
|
||||
|
||||
if "fps_aware" in self.pos_emb_cls:
|
||||
x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device) # [B, T, H, W, D]
|
||||
else:
|
||||
x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, device=x_B_C_T_H_W.device) # [B, T, H, W, D]
|
||||
|
||||
return x_B_T_H_W_D, None, extra_pos_emb
|
||||
|
||||
def decoder_head(
|
||||
self,
|
||||
x_B_T_H_W_D: torch.Tensor,
|
||||
emb_B_D: torch.Tensor,
|
||||
crossattn_emb: torch.Tensor,
|
||||
origin_shape: Tuple[int, int, int, int, int], # [B, C, T, H, W]
|
||||
crossattn_mask: Optional[torch.Tensor] = None,
|
||||
adaln_lora_B_3D: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
del crossattn_emb, crossattn_mask
|
||||
B, C, T_before_patchify, H_before_patchify, W_before_patchify = origin_shape
|
||||
x_BT_HW_D = rearrange(x_B_T_H_W_D, "B T H W D -> (B T) (H W) D")
|
||||
x_BT_HW_D = self.final_layer(x_BT_HW_D, emb_B_D, adaln_lora_B_3D=adaln_lora_B_3D)
|
||||
# This is to ensure x_BT_HW_D has the correct shape because
|
||||
# when we merge T, H, W into one dimension, x_BT_HW_D has shape (B * T * H * W, 1*1, D).
|
||||
x_BT_HW_D = x_BT_HW_D.view(
|
||||
B * T_before_patchify // self.patch_temporal,
|
||||
H_before_patchify // self.patch_spatial * W_before_patchify // self.patch_spatial,
|
||||
-1,
|
||||
)
|
||||
x_B_D_T_H_W = rearrange(
|
||||
x_BT_HW_D,
|
||||
"(B T) (H W) (p1 p2 t C) -> B C (T t) (H p1) (W p2)",
|
||||
p1=self.patch_spatial,
|
||||
p2=self.patch_spatial,
|
||||
H=H_before_patchify // self.patch_spatial,
|
||||
W=W_before_patchify // self.patch_spatial,
|
||||
t=self.patch_temporal,
|
||||
B=B,
|
||||
)
|
||||
return x_B_D_T_H_W
|
||||
|
||||
def forward_before_blocks(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
timesteps: torch.Tensor,
|
||||
crossattn_emb: torch.Tensor,
|
||||
crossattn_mask: Optional[torch.Tensor] = None,
|
||||
fps: Optional[torch.Tensor] = None,
|
||||
image_size: Optional[torch.Tensor] = None,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
scalar_feature: Optional[torch.Tensor] = None,
|
||||
data_type: Optional[DataType] = DataType.VIDEO,
|
||||
latent_condition: Optional[torch.Tensor] = None,
|
||||
latent_condition_sigma: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x: (B, C, T, H, W) tensor of spatial-temp inputs
|
||||
timesteps: (B, ) tensor of timesteps
|
||||
crossattn_emb: (B, N, D) tensor of cross-attention embeddings
|
||||
crossattn_mask: (B, N) tensor of cross-attention masks
|
||||
"""
|
||||
del kwargs
|
||||
assert isinstance(
|
||||
data_type, DataType
|
||||
), f"Expected DataType, got {type(data_type)}. We need discuss this flag later."
|
||||
original_shape = x.shape
|
||||
x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence(
|
||||
x,
|
||||
fps=fps,
|
||||
padding_mask=padding_mask,
|
||||
latent_condition=latent_condition,
|
||||
latent_condition_sigma=latent_condition_sigma,
|
||||
)
|
||||
# logging affline scale information
|
||||
affline_scale_log_info = {}
|
||||
|
||||
timesteps_B_D, adaln_lora_B_3D = self.t_embedder[1](self.t_embedder[0](timesteps.flatten()).to(x.dtype))
|
||||
affline_emb_B_D = timesteps_B_D
|
||||
affline_scale_log_info["timesteps_B_D"] = timesteps_B_D.detach()
|
||||
|
||||
if scalar_feature is not None:
|
||||
raise NotImplementedError("Scalar feature is not implemented yet.")
|
||||
|
||||
affline_scale_log_info["affline_emb_B_D"] = affline_emb_B_D.detach()
|
||||
affline_emb_B_D = self.affline_norm(affline_emb_B_D)
|
||||
|
||||
if self.use_cross_attn_mask:
|
||||
if crossattn_mask is not None and not torch.is_floating_point(crossattn_mask):
|
||||
crossattn_mask = (crossattn_mask - 1).to(x.dtype) * torch.finfo(x.dtype).max
|
||||
crossattn_mask = crossattn_mask[:, None, None, :] # .to(dtype=torch.bool) # [B, 1, 1, length]
|
||||
else:
|
||||
crossattn_mask = None
|
||||
|
||||
if self.blocks["block0"].x_format == "THWBD":
|
||||
x = rearrange(x_B_T_H_W_D, "B T H W D -> T H W B D")
|
||||
if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
|
||||
extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = rearrange(
|
||||
extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "B T H W D -> T H W B D"
|
||||
)
|
||||
crossattn_emb = rearrange(crossattn_emb, "B M D -> M B D")
|
||||
|
||||
if crossattn_mask:
|
||||
crossattn_mask = rearrange(crossattn_mask, "B M -> M B")
|
||||
|
||||
elif self.blocks["block0"].x_format == "BTHWD":
|
||||
x = x_B_T_H_W_D
|
||||
else:
|
||||
raise ValueError(f"Unknown x_format {self.blocks[0].x_format}")
|
||||
output = {
|
||||
"x": x,
|
||||
"affline_emb_B_D": affline_emb_B_D,
|
||||
"crossattn_emb": crossattn_emb,
|
||||
"crossattn_mask": crossattn_mask,
|
||||
"rope_emb_L_1_1_D": rope_emb_L_1_1_D,
|
||||
"adaln_lora_B_3D": adaln_lora_B_3D,
|
||||
"original_shape": original_shape,
|
||||
"extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
|
||||
}
|
||||
return output
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
timesteps: torch.Tensor,
|
||||
context: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
# crossattn_emb: torch.Tensor,
|
||||
# crossattn_mask: Optional[torch.Tensor] = None,
|
||||
fps: Optional[torch.Tensor] = None,
|
||||
image_size: Optional[torch.Tensor] = None,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
scalar_feature: Optional[torch.Tensor] = None,
|
||||
data_type: Optional[DataType] = DataType.VIDEO,
|
||||
latent_condition: Optional[torch.Tensor] = None,
|
||||
latent_condition_sigma: Optional[torch.Tensor] = None,
|
||||
condition_video_augment_sigma: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
x: (B, C, T, H, W) tensor of spatial-temp inputs
|
||||
timesteps: (B, ) tensor of timesteps
|
||||
crossattn_emb: (B, N, D) tensor of cross-attention embeddings
|
||||
crossattn_mask: (B, N) tensor of cross-attention masks
|
||||
condition_video_augment_sigma: (B,) used in lvg(long video generation), we add noise with this sigma to
|
||||
augment condition input, the lvg model will condition on the condition_video_augment_sigma value;
|
||||
we need forward_before_blocks pass to the forward_before_blocks function.
|
||||
"""
|
||||
|
||||
crossattn_emb = context
|
||||
crossattn_mask = attention_mask
|
||||
|
||||
inputs = self.forward_before_blocks(
|
||||
x=x,
|
||||
timesteps=timesteps,
|
||||
crossattn_emb=crossattn_emb,
|
||||
crossattn_mask=crossattn_mask,
|
||||
fps=fps,
|
||||
image_size=image_size,
|
||||
padding_mask=padding_mask,
|
||||
scalar_feature=scalar_feature,
|
||||
data_type=data_type,
|
||||
latent_condition=latent_condition,
|
||||
latent_condition_sigma=latent_condition_sigma,
|
||||
condition_video_augment_sigma=condition_video_augment_sigma,
|
||||
**kwargs,
|
||||
)
|
||||
x, affline_emb_B_D, crossattn_emb, crossattn_mask, rope_emb_L_1_1_D, adaln_lora_B_3D, original_shape = (
|
||||
inputs["x"],
|
||||
inputs["affline_emb_B_D"],
|
||||
inputs["crossattn_emb"],
|
||||
inputs["crossattn_mask"],
|
||||
inputs["rope_emb_L_1_1_D"],
|
||||
inputs["adaln_lora_B_3D"],
|
||||
inputs["original_shape"],
|
||||
)
|
||||
extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = inputs["extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D"].to(x.dtype)
|
||||
del inputs
|
||||
|
||||
if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
|
||||
assert (
|
||||
x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
|
||||
), f"{x.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape} {original_shape}"
|
||||
|
||||
for _, block in self.blocks.items():
|
||||
assert (
|
||||
self.blocks["block0"].x_format == block.x_format
|
||||
), f"First block has x_format {self.blocks[0].x_format}, got {block.x_format}"
|
||||
|
||||
if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
|
||||
x += extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D
|
||||
x = block(
|
||||
x,
|
||||
affline_emb_B_D,
|
||||
crossattn_emb,
|
||||
crossattn_mask,
|
||||
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
|
||||
adaln_lora_B_3D=adaln_lora_B_3D,
|
||||
)
|
||||
|
||||
x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D")
|
||||
|
||||
x_B_D_T_H_W = self.decoder_head(
|
||||
x_B_T_H_W_D=x_B_T_H_W_D,
|
||||
emb_B_D=affline_emb_B_D,
|
||||
crossattn_emb=None,
|
||||
origin_shape=original_shape,
|
||||
crossattn_mask=None,
|
||||
adaln_lora_B_3D=adaln_lora_B_3D,
|
||||
)
|
||||
|
||||
return x_B_D_T_H_W
|
||||
208
comfy/ldm/cosmos/position_embedding.py
Normal file
208
comfy/ldm/cosmos/position_embedding.py
Normal file
@@ -0,0 +1,208 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
from torch import nn
|
||||
import math
|
||||
|
||||
|
||||
def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 0) -> torch.Tensor:
|
||||
"""
|
||||
Normalizes the input tensor along specified dimensions such that the average square norm of elements is adjusted.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor to normalize.
|
||||
dim (list, optional): The dimensions over which to normalize. If None, normalizes over all dimensions except the first.
|
||||
eps (float, optional): A small constant to ensure numerical stability during division.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The normalized tensor.
|
||||
"""
|
||||
if dim is None:
|
||||
dim = list(range(1, x.ndim))
|
||||
norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
|
||||
norm = torch.add(eps, norm, alpha=math.sqrt(norm.numel() / x.numel()))
|
||||
return x / norm.to(x.dtype)
|
||||
|
||||
|
||||
class VideoPositionEmb(nn.Module):
|
||||
def forward(self, x_B_T_H_W_C: torch.Tensor, fps=Optional[torch.Tensor], device=None, dtype=None) -> torch.Tensor:
|
||||
"""
|
||||
It delegates the embedding generation to generate_embeddings function.
|
||||
"""
|
||||
B_T_H_W_C = x_B_T_H_W_C.shape
|
||||
embeddings = self.generate_embeddings(B_T_H_W_C, fps=fps, device=device, dtype=dtype)
|
||||
|
||||
return embeddings
|
||||
|
||||
def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor], device=None):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class VideoRopePosition3DEmb(VideoPositionEmb):
|
||||
def __init__(
|
||||
self,
|
||||
*, # enforce keyword arguments
|
||||
head_dim: int,
|
||||
len_h: int,
|
||||
len_w: int,
|
||||
len_t: int,
|
||||
base_fps: int = 24,
|
||||
h_extrapolation_ratio: float = 1.0,
|
||||
w_extrapolation_ratio: float = 1.0,
|
||||
t_extrapolation_ratio: float = 1.0,
|
||||
device=None,
|
||||
**kwargs, # used for compatibility with other positional embeddings; unused in this class
|
||||
):
|
||||
del kwargs
|
||||
super().__init__()
|
||||
self.register_buffer("seq", torch.arange(max(len_h, len_w, len_t), dtype=torch.float, device=device))
|
||||
self.base_fps = base_fps
|
||||
self.max_h = len_h
|
||||
self.max_w = len_w
|
||||
|
||||
dim = head_dim
|
||||
dim_h = dim // 6 * 2
|
||||
dim_w = dim_h
|
||||
dim_t = dim - 2 * dim_h
|
||||
assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}"
|
||||
self.register_buffer(
|
||||
"dim_spatial_range",
|
||||
torch.arange(0, dim_h, 2, device=device)[: (dim_h // 2)].float() / dim_h,
|
||||
persistent=False,
|
||||
)
|
||||
self.register_buffer(
|
||||
"dim_temporal_range",
|
||||
torch.arange(0, dim_t, 2, device=device)[: (dim_t // 2)].float() / dim_t,
|
||||
persistent=False,
|
||||
)
|
||||
|
||||
self.h_ntk_factor = h_extrapolation_ratio ** (dim_h / (dim_h - 2))
|
||||
self.w_ntk_factor = w_extrapolation_ratio ** (dim_w / (dim_w - 2))
|
||||
self.t_ntk_factor = t_extrapolation_ratio ** (dim_t / (dim_t - 2))
|
||||
|
||||
def generate_embeddings(
|
||||
self,
|
||||
B_T_H_W_C: torch.Size,
|
||||
fps: Optional[torch.Tensor] = None,
|
||||
h_ntk_factor: Optional[float] = None,
|
||||
w_ntk_factor: Optional[float] = None,
|
||||
t_ntk_factor: Optional[float] = None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
):
|
||||
"""
|
||||
Generate embeddings for the given input size.
|
||||
|
||||
Args:
|
||||
B_T_H_W_C (torch.Size): Input tensor size (Batch, Time, Height, Width, Channels).
|
||||
fps (Optional[torch.Tensor], optional): Frames per second. Defaults to None.
|
||||
h_ntk_factor (Optional[float], optional): Height NTK factor. If None, uses self.h_ntk_factor.
|
||||
w_ntk_factor (Optional[float], optional): Width NTK factor. If None, uses self.w_ntk_factor.
|
||||
t_ntk_factor (Optional[float], optional): Time NTK factor. If None, uses self.t_ntk_factor.
|
||||
|
||||
Returns:
|
||||
Not specified in the original code snippet.
|
||||
"""
|
||||
h_ntk_factor = h_ntk_factor if h_ntk_factor is not None else self.h_ntk_factor
|
||||
w_ntk_factor = w_ntk_factor if w_ntk_factor is not None else self.w_ntk_factor
|
||||
t_ntk_factor = t_ntk_factor if t_ntk_factor is not None else self.t_ntk_factor
|
||||
|
||||
h_theta = 10000.0 * h_ntk_factor
|
||||
w_theta = 10000.0 * w_ntk_factor
|
||||
t_theta = 10000.0 * t_ntk_factor
|
||||
|
||||
h_spatial_freqs = 1.0 / (h_theta**self.dim_spatial_range.to(device=device))
|
||||
w_spatial_freqs = 1.0 / (w_theta**self.dim_spatial_range.to(device=device))
|
||||
temporal_freqs = 1.0 / (t_theta**self.dim_temporal_range.to(device=device))
|
||||
|
||||
B, T, H, W, _ = B_T_H_W_C
|
||||
uniform_fps = (fps is None) or isinstance(fps, (int, float)) or (fps.min() == fps.max())
|
||||
assert (
|
||||
uniform_fps or B == 1 or T == 1
|
||||
), "For video batch, batch size should be 1 for non-uniform fps. For image batch, T should be 1"
|
||||
assert (
|
||||
H <= self.max_h and W <= self.max_w
|
||||
), f"Input dimensions (H={H}, W={W}) exceed the maximum dimensions (max_h={self.max_h}, max_w={self.max_w})"
|
||||
half_emb_h = torch.outer(self.seq[:H].to(device=device), h_spatial_freqs)
|
||||
half_emb_w = torch.outer(self.seq[:W].to(device=device), w_spatial_freqs)
|
||||
|
||||
# apply sequence scaling in temporal dimension
|
||||
if fps is None: # image case
|
||||
half_emb_t = torch.outer(self.seq[:T].to(device=device), temporal_freqs)
|
||||
else:
|
||||
half_emb_t = torch.outer(self.seq[:T].to(device=device) / fps * self.base_fps, temporal_freqs)
|
||||
|
||||
half_emb_h = torch.stack([torch.cos(half_emb_h), -torch.sin(half_emb_h), torch.sin(half_emb_h), torch.cos(half_emb_h)], dim=-1)
|
||||
half_emb_w = torch.stack([torch.cos(half_emb_w), -torch.sin(half_emb_w), torch.sin(half_emb_w), torch.cos(half_emb_w)], dim=-1)
|
||||
half_emb_t = torch.stack([torch.cos(half_emb_t), -torch.sin(half_emb_t), torch.sin(half_emb_t), torch.cos(half_emb_t)], dim=-1)
|
||||
|
||||
em_T_H_W_D = torch.cat(
|
||||
[
|
||||
repeat(half_emb_t, "t d x -> t h w d x", h=H, w=W),
|
||||
repeat(half_emb_h, "h d x -> t h w d x", t=T, w=W),
|
||||
repeat(half_emb_w, "w d x -> t h w d x", t=T, h=H),
|
||||
]
|
||||
, dim=-2,
|
||||
)
|
||||
|
||||
return rearrange(em_T_H_W_D, "t h w d (i j) -> (t h w) d i j", i=2, j=2).float()
|
||||
|
||||
|
||||
class LearnablePosEmbAxis(VideoPositionEmb):
|
||||
def __init__(
|
||||
self,
|
||||
*, # enforce keyword arguments
|
||||
interpolation: str,
|
||||
model_channels: int,
|
||||
len_h: int,
|
||||
len_w: int,
|
||||
len_t: int,
|
||||
device=None,
|
||||
dtype=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
interpolation (str): we curretly only support "crop", ideally when we need extrapolation capacity, we should adjust frequency or other more advanced methods. they are not implemented yet.
|
||||
"""
|
||||
del kwargs # unused
|
||||
super().__init__()
|
||||
self.interpolation = interpolation
|
||||
assert self.interpolation in ["crop"], f"Unknown interpolation method {self.interpolation}"
|
||||
|
||||
self.pos_emb_h = nn.Parameter(torch.empty(len_h, model_channels, device=device, dtype=dtype))
|
||||
self.pos_emb_w = nn.Parameter(torch.empty(len_w, model_channels, device=device, dtype=dtype))
|
||||
self.pos_emb_t = nn.Parameter(torch.empty(len_t, model_channels, device=device, dtype=dtype))
|
||||
|
||||
def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor], device=None, dtype=None) -> torch.Tensor:
|
||||
B, T, H, W, _ = B_T_H_W_C
|
||||
if self.interpolation == "crop":
|
||||
emb_h_H = self.pos_emb_h[:H].to(device=device, dtype=dtype)
|
||||
emb_w_W = self.pos_emb_w[:W].to(device=device, dtype=dtype)
|
||||
emb_t_T = self.pos_emb_t[:T].to(device=device, dtype=dtype)
|
||||
emb = (
|
||||
repeat(emb_t_T, "t d-> b t h w d", b=B, h=H, w=W)
|
||||
+ repeat(emb_h_H, "h d-> b t h w d", b=B, t=T, w=W)
|
||||
+ repeat(emb_w_W, "w d-> b t h w d", b=B, t=T, h=H)
|
||||
)
|
||||
assert list(emb.shape)[:4] == [B, T, H, W], f"bad shape: {list(emb.shape)[:4]} != {B, T, H, W}"
|
||||
else:
|
||||
raise ValueError(f"Unknown interpolation method {self.interpolation}")
|
||||
|
||||
return normalize(emb, dim=-1, eps=1e-6)
|
||||
131
comfy/ldm/cosmos/vae.py
Normal file
131
comfy/ldm/cosmos/vae.py
Normal file
@@ -0,0 +1,131 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""The causal continuous video tokenizer with VAE or AE formulation for 3D data.."""
|
||||
|
||||
import logging
|
||||
import torch
|
||||
from torch import nn
|
||||
from enum import Enum
|
||||
import math
|
||||
|
||||
from .cosmos_tokenizer.layers3d import (
|
||||
EncoderFactorized,
|
||||
DecoderFactorized,
|
||||
CausalConv3d,
|
||||
)
|
||||
|
||||
|
||||
class IdentityDistribution(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, parameters):
|
||||
return parameters, (torch.tensor([0.0]), torch.tensor([0.0]))
|
||||
|
||||
|
||||
class GaussianDistribution(torch.nn.Module):
|
||||
def __init__(self, min_logvar: float = -30.0, max_logvar: float = 20.0):
|
||||
super().__init__()
|
||||
self.min_logvar = min_logvar
|
||||
self.max_logvar = max_logvar
|
||||
|
||||
def sample(self, mean, logvar):
|
||||
std = torch.exp(0.5 * logvar)
|
||||
return mean + std * torch.randn_like(mean)
|
||||
|
||||
def forward(self, parameters):
|
||||
mean, logvar = torch.chunk(parameters, 2, dim=1)
|
||||
logvar = torch.clamp(logvar, self.min_logvar, self.max_logvar)
|
||||
return self.sample(mean, logvar), (mean, logvar)
|
||||
|
||||
|
||||
class ContinuousFormulation(Enum):
|
||||
VAE = GaussianDistribution
|
||||
AE = IdentityDistribution
|
||||
|
||||
|
||||
class CausalContinuousVideoTokenizer(nn.Module):
|
||||
def __init__(
|
||||
self, z_channels: int, z_factor: int, latent_channels: int, **kwargs
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.name = kwargs.get("name", "CausalContinuousVideoTokenizer")
|
||||
self.latent_channels = latent_channels
|
||||
self.sigma_data = 0.5
|
||||
|
||||
# encoder_name = kwargs.get("encoder", Encoder3DType.BASE.name)
|
||||
self.encoder = EncoderFactorized(
|
||||
z_channels=z_factor * z_channels, **kwargs
|
||||
)
|
||||
if kwargs.get("temporal_compression", 4) == 4:
|
||||
kwargs["channels_mult"] = [2, 4]
|
||||
# decoder_name = kwargs.get("decoder", Decoder3DType.BASE.name)
|
||||
self.decoder = DecoderFactorized(
|
||||
z_channels=z_channels, **kwargs
|
||||
)
|
||||
|
||||
self.quant_conv = CausalConv3d(
|
||||
z_factor * z_channels,
|
||||
z_factor * latent_channels,
|
||||
kernel_size=1,
|
||||
padding=0,
|
||||
)
|
||||
self.post_quant_conv = CausalConv3d(
|
||||
latent_channels, z_channels, kernel_size=1, padding=0
|
||||
)
|
||||
|
||||
# formulation_name = kwargs.get("formulation", ContinuousFormulation.AE.name)
|
||||
self.distribution = IdentityDistribution() # ContinuousFormulation[formulation_name].value()
|
||||
|
||||
num_parameters = sum(param.numel() for param in self.parameters())
|
||||
logging.debug(f"model={self.name}, num_parameters={num_parameters:,}")
|
||||
logging.debug(
|
||||
f"z_channels={z_channels}, latent_channels={self.latent_channels}."
|
||||
)
|
||||
|
||||
latent_temporal_chunk = 16
|
||||
self.latent_mean = nn.Parameter(torch.zeros([self.latent_channels * latent_temporal_chunk], dtype=torch.float32))
|
||||
self.latent_std = nn.Parameter(torch.ones([self.latent_channels * latent_temporal_chunk], dtype=torch.float32))
|
||||
|
||||
|
||||
def encode(self, x):
|
||||
h = self.encoder(x)
|
||||
moments = self.quant_conv(h)
|
||||
z, posteriors = self.distribution(moments)
|
||||
latent_ch = z.shape[1]
|
||||
latent_t = z.shape[2]
|
||||
in_dtype = z.dtype
|
||||
mean = self.latent_mean.view(latent_ch, -1)
|
||||
std = self.latent_std.view(latent_ch, -1)
|
||||
|
||||
mean = mean.repeat(1, math.ceil(latent_t / mean.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
|
||||
std = std.repeat(1, math.ceil(latent_t / std.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
|
||||
return ((z - mean) / std) * self.sigma_data
|
||||
|
||||
def decode(self, z):
|
||||
in_dtype = z.dtype
|
||||
latent_ch = z.shape[1]
|
||||
latent_t = z.shape[2]
|
||||
mean = self.latent_mean.view(latent_ch, -1)
|
||||
std = self.latent_std.view(latent_ch, -1)
|
||||
|
||||
mean = mean.repeat(1, math.ceil(latent_t / mean.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
|
||||
std = std.repeat(1, math.ceil(latent_t / std.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
|
||||
|
||||
z = z / self.sigma_data
|
||||
z = z * std + mean
|
||||
z = self.post_quant_conv(z)
|
||||
return self.decoder(z)
|
||||
|
||||
@@ -6,9 +6,7 @@ import math
|
||||
from torch import Tensor, nn
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
|
||||
MLPEmbedder, SingleStreamBlock,
|
||||
timestep_embedding)
|
||||
from .layers import (timestep_embedding)
|
||||
|
||||
from .model import Flux
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
@@ -105,7 +105,9 @@ class Modulation(nn.Module):
|
||||
self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, vec: Tensor) -> tuple:
|
||||
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
|
||||
if vec.ndim == 2:
|
||||
vec = vec[:, None, :]
|
||||
out = self.lin(nn.functional.silu(vec)).chunk(self.multiplier, dim=-1)
|
||||
|
||||
return (
|
||||
ModulationOut(*out[:3]),
|
||||
@@ -113,8 +115,22 @@ class Modulation(nn.Module):
|
||||
)
|
||||
|
||||
|
||||
def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
|
||||
if modulation_dims is None:
|
||||
if m_add is not None:
|
||||
return tensor * m_mult + m_add
|
||||
else:
|
||||
return tensor * m_mult
|
||||
else:
|
||||
for d in modulation_dims:
|
||||
tensor[:, d[0]:d[1]] *= m_mult[:, d[2]]
|
||||
if m_add is not None:
|
||||
tensor[:, d[0]:d[1]] += m_add[:, d[2]]
|
||||
return tensor
|
||||
|
||||
|
||||
class DoubleStreamBlock(nn.Module):
|
||||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, dtype=None, device=None, operations=None):
|
||||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
@@ -141,39 +157,50 @@ class DoubleStreamBlock(nn.Module):
|
||||
nn.GELU(approximate="tanh"),
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
self.flipped_img_txt = flipped_img_txt
|
||||
|
||||
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor):
|
||||
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None):
|
||||
img_mod1, img_mod2 = self.img_mod(vec)
|
||||
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
||||
|
||||
# prepare image for attention
|
||||
img_modulated = self.img_norm1(img)
|
||||
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
|
||||
img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
|
||||
img_qkv = self.img_attn.qkv(img_modulated)
|
||||
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
|
||||
|
||||
# prepare txt for attention
|
||||
txt_modulated = self.txt_norm1(txt)
|
||||
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
|
||||
txt_modulated = apply_mod(txt_modulated, (1 + txt_mod1.scale), txt_mod1.shift, modulation_dims_txt)
|
||||
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
||||
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
||||
|
||||
# run actual attention
|
||||
attn = attention(torch.cat((txt_q, img_q), dim=2),
|
||||
torch.cat((txt_k, img_k), dim=2),
|
||||
torch.cat((txt_v, img_v), dim=2), pe=pe)
|
||||
if self.flipped_img_txt:
|
||||
# run actual attention
|
||||
attn = attention(torch.cat((img_q, txt_q), dim=2),
|
||||
torch.cat((img_k, txt_k), dim=2),
|
||||
torch.cat((img_v, txt_v), dim=2),
|
||||
pe=pe, mask=attn_mask)
|
||||
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
||||
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
|
||||
else:
|
||||
# run actual attention
|
||||
attn = attention(torch.cat((txt_q, img_q), dim=2),
|
||||
torch.cat((txt_k, img_k), dim=2),
|
||||
torch.cat((txt_v, img_v), dim=2),
|
||||
pe=pe, mask=attn_mask)
|
||||
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
|
||||
|
||||
# calculate the img bloks
|
||||
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
|
||||
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
|
||||
img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
|
||||
img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)
|
||||
|
||||
# calculate the txt bloks
|
||||
txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
|
||||
txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
|
||||
txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt)
|
||||
txt += apply_mod(self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims_txt)), txt_mod2.gate, None, modulation_dims_txt)
|
||||
|
||||
if txt.dtype == torch.float16:
|
||||
txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
|
||||
@@ -217,19 +244,18 @@ class SingleStreamBlock(nn.Module):
|
||||
self.mlp_act = nn.GELU(approximate="tanh")
|
||||
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
|
||||
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None) -> Tensor:
|
||||
mod, _ = self.modulation(vec)
|
||||
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
|
||||
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||
|
||||
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
q, k = self.norm(q, k, v)
|
||||
|
||||
# compute attention
|
||||
attn = attention(q, k, v, pe=pe)
|
||||
attn = attention(q, k, v, pe=pe, mask=attn_mask)
|
||||
# compute activation in mlp stream, cat again and run second linear layer
|
||||
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
||||
x += mod.gate * output
|
||||
x += apply_mod(output, mod.gate, None, modulation_dims)
|
||||
if x.dtype == torch.float16:
|
||||
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
|
||||
return x
|
||||
@@ -242,8 +268,11 @@ class LastLayer(nn.Module):
|
||||
self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
|
||||
self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device))
|
||||
|
||||
def forward(self, x: Tensor, vec: Tensor) -> Tensor:
|
||||
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
|
||||
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
|
||||
def forward(self, x: Tensor, vec: Tensor, modulation_dims=None) -> Tensor:
|
||||
if vec.ndim == 2:
|
||||
vec = vec[:, None, :]
|
||||
|
||||
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=-1)
|
||||
x = apply_mod(self.norm_final(x), (1 + scale), shift, modulation_dims)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
@@ -1,20 +1,29 @@
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from torch import Tensor
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
import comfy.model_management
|
||||
|
||||
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
|
||||
q, k = apply_rope(q, k, pe)
|
||||
|
||||
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
|
||||
q_shape = q.shape
|
||||
k_shape = k.shape
|
||||
|
||||
if pe is not None:
|
||||
q = q.to(dtype=pe.dtype).reshape(*q.shape[:-1], -1, 1, 2)
|
||||
k = k.to(dtype=pe.dtype).reshape(*k.shape[:-1], -1, 1, 2)
|
||||
q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
|
||||
k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
|
||||
|
||||
heads = q.shape[1]
|
||||
x = optimized_attention(q, k, v, heads, skip_reshape=True)
|
||||
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
|
||||
return x
|
||||
|
||||
|
||||
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
||||
assert dim % 2 == 0
|
||||
if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu():
|
||||
if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu() or comfy.model_management.is_directml_enabled():
|
||||
device = torch.device("cpu")
|
||||
else:
|
||||
device = pos.device
|
||||
@@ -28,8 +37,9 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
||||
|
||||
|
||||
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
|
||||
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
|
||||
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
|
||||
xq_ = xq.to(dtype=freqs_cis.dtype).reshape(*xq.shape[:-1], -1, 1, 2)
|
||||
xk_ = xk.to(dtype=freqs_cis.dtype).reshape(*xk.shape[:-1], -1, 1, 2)
|
||||
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
||||
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
||||
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
||||
|
||||
|
||||
@@ -4,6 +4,8 @@ from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from einops import rearrange, repeat
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
from .layers import (
|
||||
DoubleStreamBlock,
|
||||
@@ -14,9 +16,6 @@ from .layers import (
|
||||
timestep_embedding,
|
||||
)
|
||||
|
||||
from einops import rearrange, repeat
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
@dataclass
|
||||
class FluxParams:
|
||||
in_channels: int
|
||||
@@ -98,8 +97,9 @@ class Flux(nn.Module):
|
||||
timesteps: Tensor,
|
||||
y: Tensor,
|
||||
guidance: Tensor = None,
|
||||
control=None,
|
||||
control = None,
|
||||
transformer_options={},
|
||||
attn_mask: Tensor = None,
|
||||
) -> Tensor:
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
@@ -109,29 +109,44 @@ class Flux(nn.Module):
|
||||
img = self.img_in(img)
|
||||
vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype))
|
||||
if self.params.guidance_embed:
|
||||
if guidance is None:
|
||||
raise ValueError("Didn't get guidance strength for guidance distilled model.")
|
||||
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
|
||||
if guidance is not None:
|
||||
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
|
||||
|
||||
vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
|
||||
txt = self.txt_in(txt)
|
||||
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
if img_ids is not None:
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
else:
|
||||
pe = None
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
for i, block in enumerate(self.double_blocks):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"])
|
||||
out["img"], out["txt"] = block(img=args["img"],
|
||||
txt=args["txt"],
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"))
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("double_block", i)]({"img": img,
|
||||
"txt": txt,
|
||||
"vec": vec,
|
||||
"pe": pe,
|
||||
"attn_mask": attn_mask},
|
||||
{"original_block": block_wrap})
|
||||
txt = out["txt"]
|
||||
img = out["img"]
|
||||
else:
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
||||
img, txt = block(img=img,
|
||||
txt=txt,
|
||||
vec=vec,
|
||||
pe=pe,
|
||||
attn_mask=attn_mask)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_i = control.get("input")
|
||||
@@ -146,13 +161,20 @@ class Flux(nn.Module):
|
||||
if ("single_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"])
|
||||
out["img"] = block(args["img"],
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"))
|
||||
return out
|
||||
|
||||
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("single_block", i)]({"img": img,
|
||||
"vec": vec,
|
||||
"pe": pe,
|
||||
"attn_mask": attn_mask},
|
||||
{"original_block": block_wrap})
|
||||
img = out["img"]
|
||||
else:
|
||||
img = block(img, vec=vec, pe=pe)
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_o = control.get("output")
|
||||
@@ -166,7 +188,7 @@ class Flux(nn.Module):
|
||||
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
||||
return img
|
||||
|
||||
def forward(self, x, timestep, context, y, guidance, control=None, transformer_options={}, **kwargs):
|
||||
def forward(self, x, timestep, context, y, guidance=None, control=None, transformer_options={}, **kwargs):
|
||||
bs, c, h, w = x.shape
|
||||
patch_size = self.patch_size
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
||||
@@ -181,5 +203,5 @@ class Flux(nn.Module):
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options)
|
||||
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
|
||||
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
|
||||
|
||||
@@ -13,7 +13,6 @@ from comfy.ldm.modules.attention import optimized_attention
|
||||
from .layers import (
|
||||
FeedForward,
|
||||
PatchEmbed,
|
||||
RMSNorm,
|
||||
TimestepEmbedder,
|
||||
)
|
||||
|
||||
@@ -90,10 +89,10 @@ class AsymmetricAttention(nn.Module):
|
||||
|
||||
# Query and key normalization for stability.
|
||||
assert qk_norm
|
||||
self.q_norm_x = RMSNorm(self.head_dim, device=device, dtype=dtype)
|
||||
self.k_norm_x = RMSNorm(self.head_dim, device=device, dtype=dtype)
|
||||
self.q_norm_y = RMSNorm(self.head_dim, device=device, dtype=dtype)
|
||||
self.k_norm_y = RMSNorm(self.head_dim, device=device, dtype=dtype)
|
||||
self.q_norm_x = operations.RMSNorm(self.head_dim, eps=1e-5, device=device, dtype=dtype)
|
||||
self.k_norm_x = operations.RMSNorm(self.head_dim, eps=1e-5, device=device, dtype=dtype)
|
||||
self.q_norm_y = operations.RMSNorm(self.head_dim, eps=1e-5, device=device, dtype=dtype)
|
||||
self.k_norm_y = operations.RMSNorm(self.head_dim, eps=1e-5, device=device, dtype=dtype)
|
||||
|
||||
# Output layers. y features go back down from dim_x -> dim_y.
|
||||
self.proj_x = operations.Linear(dim_x, dim_x, bias=out_bias, device=device, dtype=dtype)
|
||||
@@ -461,8 +460,6 @@ class AsymmDiTJoint(nn.Module):
|
||||
pH, pW = H // self.patch_size, W // self.patch_size
|
||||
x = self.embed_x(x) # (B, N, D), where N = T * H * W / patch_size ** 2
|
||||
assert x.ndim == 3
|
||||
B = x.size(0)
|
||||
|
||||
|
||||
pH, pW = H // self.patch_size, W // self.patch_size
|
||||
N = T * pH * pW
|
||||
|
||||
@@ -151,14 +151,3 @@ class PatchEmbed(nn.Module):
|
||||
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = torch.nn.Parameter(torch.empty(hidden_size, device=device, dtype=dtype))
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
def forward(self, x):
|
||||
return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#original code from https://github.com/genmoai/models under apache 2.0 license
|
||||
#adapted to ComfyUI
|
||||
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#original code from https://github.com/genmoai/models under apache 2.0 license
|
||||
#adapted to ComfyUI
|
||||
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from functools import partial
|
||||
import math
|
||||
|
||||
|
||||
802
comfy/ldm/hidream/model.py
Normal file
802
comfy/ldm/hidream/model.py
Normal file
@@ -0,0 +1,802 @@
|
||||
from typing import Optional, Tuple, List
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import einops
|
||||
from einops import repeat
|
||||
|
||||
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
|
||||
import torch.nn.functional as F
|
||||
|
||||
from comfy.ldm.flux.math import apply_rope, rope
|
||||
from comfy.ldm.flux.layers import LastLayer
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
import comfy.model_management
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
|
||||
# Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
|
||||
class EmbedND(nn.Module):
|
||||
def __init__(self, theta: int, axes_dim: List[int]):
|
||||
super().__init__()
|
||||
self.theta = theta
|
||||
self.axes_dim = axes_dim
|
||||
|
||||
def forward(self, ids: torch.Tensor) -> torch.Tensor:
|
||||
n_axes = ids.shape[-1]
|
||||
emb = torch.cat(
|
||||
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
|
||||
dim=-3,
|
||||
)
|
||||
return emb.unsqueeze(2)
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
patch_size=2,
|
||||
in_channels=4,
|
||||
out_channels=1024,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.out_channels = out_channels
|
||||
self.proj = operations.Linear(in_channels * patch_size * patch_size, out_channels, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, latent):
|
||||
latent = self.proj(latent)
|
||||
return latent
|
||||
|
||||
|
||||
class PooledEmbed(nn.Module):
|
||||
def __init__(self, text_emb_dim, hidden_size, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.pooled_embedder = TimestepEmbedding(in_channels=text_emb_dim, time_embed_dim=hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(self, pooled_embed):
|
||||
return self.pooled_embedder(pooled_embed)
|
||||
|
||||
|
||||
class TimestepEmbed(nn.Module):
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(self, timesteps, wdtype):
|
||||
t_emb = self.time_proj(timesteps).to(dtype=wdtype)
|
||||
t_emb = self.timestep_embedder(t_emb)
|
||||
return t_emb
|
||||
|
||||
|
||||
def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
|
||||
return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2])
|
||||
|
||||
|
||||
class HiDreamAttnProcessor_flashattn:
|
||||
"""Attention processor used typically in processing the SD3-like self-attention projections."""
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn,
|
||||
image_tokens: torch.FloatTensor,
|
||||
image_tokens_masks: Optional[torch.FloatTensor] = None,
|
||||
text_tokens: Optional[torch.FloatTensor] = None,
|
||||
rope: torch.FloatTensor = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
dtype = image_tokens.dtype
|
||||
batch_size = image_tokens.shape[0]
|
||||
|
||||
query_i = attn.q_rms_norm(attn.to_q(image_tokens)).to(dtype=dtype)
|
||||
key_i = attn.k_rms_norm(attn.to_k(image_tokens)).to(dtype=dtype)
|
||||
value_i = attn.to_v(image_tokens)
|
||||
|
||||
inner_dim = key_i.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query_i = query_i.view(batch_size, -1, attn.heads, head_dim)
|
||||
key_i = key_i.view(batch_size, -1, attn.heads, head_dim)
|
||||
value_i = value_i.view(batch_size, -1, attn.heads, head_dim)
|
||||
if image_tokens_masks is not None:
|
||||
key_i = key_i * image_tokens_masks.view(batch_size, -1, 1, 1)
|
||||
|
||||
if not attn.single:
|
||||
query_t = attn.q_rms_norm_t(attn.to_q_t(text_tokens)).to(dtype=dtype)
|
||||
key_t = attn.k_rms_norm_t(attn.to_k_t(text_tokens)).to(dtype=dtype)
|
||||
value_t = attn.to_v_t(text_tokens)
|
||||
|
||||
query_t = query_t.view(batch_size, -1, attn.heads, head_dim)
|
||||
key_t = key_t.view(batch_size, -1, attn.heads, head_dim)
|
||||
value_t = value_t.view(batch_size, -1, attn.heads, head_dim)
|
||||
|
||||
num_image_tokens = query_i.shape[1]
|
||||
num_text_tokens = query_t.shape[1]
|
||||
query = torch.cat([query_i, query_t], dim=1)
|
||||
key = torch.cat([key_i, key_t], dim=1)
|
||||
value = torch.cat([value_i, value_t], dim=1)
|
||||
else:
|
||||
query = query_i
|
||||
key = key_i
|
||||
value = value_i
|
||||
|
||||
if query.shape[-1] == rope.shape[-3] * 2:
|
||||
query, key = apply_rope(query, key, rope)
|
||||
else:
|
||||
query_1, query_2 = query.chunk(2, dim=-1)
|
||||
key_1, key_2 = key.chunk(2, dim=-1)
|
||||
query_1, key_1 = apply_rope(query_1, key_1, rope)
|
||||
query = torch.cat([query_1, query_2], dim=-1)
|
||||
key = torch.cat([key_1, key_2], dim=-1)
|
||||
|
||||
hidden_states = attention(query, key, value)
|
||||
|
||||
if not attn.single:
|
||||
hidden_states_i, hidden_states_t = torch.split(hidden_states, [num_image_tokens, num_text_tokens], dim=1)
|
||||
hidden_states_i = attn.to_out(hidden_states_i)
|
||||
hidden_states_t = attn.to_out_t(hidden_states_t)
|
||||
return hidden_states_i, hidden_states_t
|
||||
else:
|
||||
hidden_states = attn.to_out(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
class HiDreamAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
query_dim: int,
|
||||
heads: int = 8,
|
||||
dim_head: int = 64,
|
||||
upcast_attention: bool = False,
|
||||
upcast_softmax: bool = False,
|
||||
scale_qk: bool = True,
|
||||
eps: float = 1e-5,
|
||||
processor = None,
|
||||
out_dim: int = None,
|
||||
single: bool = False,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
# super(Attention, self).__init__()
|
||||
super().__init__()
|
||||
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
|
||||
self.query_dim = query_dim
|
||||
self.upcast_attention = upcast_attention
|
||||
self.upcast_softmax = upcast_softmax
|
||||
self.out_dim = out_dim if out_dim is not None else query_dim
|
||||
|
||||
self.scale_qk = scale_qk
|
||||
self.scale = dim_head**-0.5 if self.scale_qk else 1.0
|
||||
|
||||
self.heads = out_dim // dim_head if out_dim is not None else heads
|
||||
self.sliceable_head_dim = heads
|
||||
self.single = single
|
||||
|
||||
linear_cls = operations.Linear
|
||||
self.linear_cls = linear_cls
|
||||
self.to_q = linear_cls(query_dim, self.inner_dim, dtype=dtype, device=device)
|
||||
self.to_k = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device)
|
||||
self.to_v = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device)
|
||||
self.to_out = linear_cls(self.inner_dim, self.out_dim, dtype=dtype, device=device)
|
||||
self.q_rms_norm = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device)
|
||||
self.k_rms_norm = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device)
|
||||
|
||||
if not single:
|
||||
self.to_q_t = linear_cls(query_dim, self.inner_dim, dtype=dtype, device=device)
|
||||
self.to_k_t = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device)
|
||||
self.to_v_t = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device)
|
||||
self.to_out_t = linear_cls(self.inner_dim, self.out_dim, dtype=dtype, device=device)
|
||||
self.q_rms_norm_t = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device)
|
||||
self.k_rms_norm_t = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device)
|
||||
|
||||
self.processor = processor
|
||||
|
||||
def forward(
|
||||
self,
|
||||
norm_image_tokens: torch.FloatTensor,
|
||||
image_tokens_masks: torch.FloatTensor = None,
|
||||
norm_text_tokens: torch.FloatTensor = None,
|
||||
rope: torch.FloatTensor = None,
|
||||
) -> torch.Tensor:
|
||||
return self.processor(
|
||||
self,
|
||||
image_tokens = norm_image_tokens,
|
||||
image_tokens_masks = image_tokens_masks,
|
||||
text_tokens = norm_text_tokens,
|
||||
rope = rope,
|
||||
)
|
||||
|
||||
|
||||
class FeedForwardSwiGLU(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
multiple_of: int = 256,
|
||||
ffn_dim_multiplier: Optional[float] = None,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
hidden_dim = int(2 * hidden_dim / 3)
|
||||
# custom dim factor multiplier
|
||||
if ffn_dim_multiplier is not None:
|
||||
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
|
||||
hidden_dim = multiple_of * (
|
||||
(hidden_dim + multiple_of - 1) // multiple_of
|
||||
)
|
||||
|
||||
self.w1 = operations.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device)
|
||||
self.w2 = operations.Linear(hidden_dim, dim, bias=False, dtype=dtype, device=device)
|
||||
self.w3 = operations.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
return self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
|
||||
|
||||
|
||||
# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
|
||||
class MoEGate(nn.Module):
|
||||
def __init__(self, embed_dim, num_routed_experts=4, num_activated_experts=2, aux_loss_alpha=0.01, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.top_k = num_activated_experts
|
||||
self.n_routed_experts = num_routed_experts
|
||||
|
||||
self.scoring_func = 'softmax'
|
||||
self.alpha = aux_loss_alpha
|
||||
self.seq_aux = False
|
||||
|
||||
# topk selection algorithm
|
||||
self.norm_topk_prob = False
|
||||
self.gating_dim = embed_dim
|
||||
self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim), dtype=dtype, device=device))
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self) -> None:
|
||||
pass
|
||||
# import torch.nn.init as init
|
||||
# init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
||||
|
||||
def forward(self, hidden_states):
|
||||
bsz, seq_len, h = hidden_states.shape
|
||||
|
||||
### compute gating score
|
||||
hidden_states = hidden_states.view(-1, h)
|
||||
logits = F.linear(hidden_states, comfy.model_management.cast_to(self.weight, dtype=hidden_states.dtype, device=hidden_states.device), None)
|
||||
if self.scoring_func == 'softmax':
|
||||
scores = logits.softmax(dim=-1)
|
||||
else:
|
||||
raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}')
|
||||
|
||||
### select top-k experts
|
||||
topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
|
||||
|
||||
### norm gate to sum 1
|
||||
if self.top_k > 1 and self.norm_topk_prob:
|
||||
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
|
||||
topk_weight = topk_weight / denominator
|
||||
|
||||
aux_loss = None
|
||||
return topk_idx, topk_weight, aux_loss
|
||||
|
||||
|
||||
# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
|
||||
class MOEFeedForwardSwiGLU(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
num_routed_experts: int,
|
||||
num_activated_experts: int,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.shared_experts = FeedForwardSwiGLU(dim, hidden_dim // 2, dtype=dtype, device=device, operations=operations)
|
||||
self.experts = nn.ModuleList([FeedForwardSwiGLU(dim, hidden_dim, dtype=dtype, device=device, operations=operations) for i in range(num_routed_experts)])
|
||||
self.gate = MoEGate(
|
||||
embed_dim = dim,
|
||||
num_routed_experts = num_routed_experts,
|
||||
num_activated_experts = num_activated_experts,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
self.num_activated_experts = num_activated_experts
|
||||
|
||||
def forward(self, x):
|
||||
wtype = x.dtype
|
||||
identity = x
|
||||
orig_shape = x.shape
|
||||
topk_idx, topk_weight, aux_loss = self.gate(x)
|
||||
x = x.view(-1, x.shape[-1])
|
||||
flat_topk_idx = topk_idx.view(-1)
|
||||
if True: # self.training: # TODO: check which branch performs faster
|
||||
x = x.repeat_interleave(self.num_activated_experts, dim=0)
|
||||
y = torch.empty_like(x, dtype=wtype)
|
||||
for i, expert in enumerate(self.experts):
|
||||
y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(dtype=wtype)
|
||||
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
|
||||
y = y.view(*orig_shape).to(dtype=wtype)
|
||||
#y = AddAuxiliaryLoss.apply(y, aux_loss)
|
||||
else:
|
||||
y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
|
||||
y = y + self.shared_experts(identity)
|
||||
return y
|
||||
|
||||
@torch.no_grad()
|
||||
def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
|
||||
expert_cache = torch.zeros_like(x)
|
||||
idxs = flat_expert_indices.argsort()
|
||||
tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
|
||||
token_idxs = idxs // self.num_activated_experts
|
||||
for i, end_idx in enumerate(tokens_per_expert):
|
||||
start_idx = 0 if i == 0 else tokens_per_expert[i-1]
|
||||
if start_idx == end_idx:
|
||||
continue
|
||||
expert = self.experts[i]
|
||||
exp_token_idx = token_idxs[start_idx:end_idx]
|
||||
expert_tokens = x[exp_token_idx]
|
||||
expert_out = expert(expert_tokens)
|
||||
expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
|
||||
|
||||
# for fp16 and other dtype
|
||||
expert_cache = expert_cache.to(expert_out.dtype)
|
||||
expert_cache.scatter_reduce_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out, reduce='sum')
|
||||
return expert_cache
|
||||
|
||||
|
||||
class TextProjection(nn.Module):
|
||||
def __init__(self, in_features, hidden_size, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.linear = operations.Linear(in_features=in_features, out_features=hidden_size, bias=False, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, caption):
|
||||
hidden_states = self.linear(caption)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BlockType:
|
||||
TransformerBlock = 1
|
||||
SingleTransformerBlock = 2
|
||||
|
||||
|
||||
class HiDreamImageSingleTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
num_routed_experts: int = 4,
|
||||
num_activated_experts: int = 2,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
# 1. Attention
|
||||
self.norm1_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
|
||||
self.attn1 = HiDreamAttention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
processor = HiDreamAttnProcessor_flashattn(),
|
||||
single = True,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
# 3. Feed-forward
|
||||
self.norm3_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
|
||||
if num_routed_experts > 0:
|
||||
self.ff_i = MOEFeedForwardSwiGLU(
|
||||
dim = dim,
|
||||
hidden_dim = 4 * dim,
|
||||
num_routed_experts = num_routed_experts,
|
||||
num_activated_experts = num_activated_experts,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
else:
|
||||
self.ff_i = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
image_tokens: torch.FloatTensor,
|
||||
image_tokens_masks: Optional[torch.FloatTensor] = None,
|
||||
text_tokens: Optional[torch.FloatTensor] = None,
|
||||
adaln_input: Optional[torch.FloatTensor] = None,
|
||||
rope: torch.FloatTensor = None,
|
||||
|
||||
) -> torch.FloatTensor:
|
||||
wtype = image_tokens.dtype
|
||||
shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = \
|
||||
self.adaLN_modulation(adaln_input)[:,None].chunk(6, dim=-1)
|
||||
|
||||
# 1. MM-Attention
|
||||
norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype)
|
||||
norm_image_tokens = norm_image_tokens * (1 + scale_msa_i) + shift_msa_i
|
||||
attn_output_i = self.attn1(
|
||||
norm_image_tokens,
|
||||
image_tokens_masks,
|
||||
rope = rope,
|
||||
)
|
||||
image_tokens = gate_msa_i * attn_output_i + image_tokens
|
||||
|
||||
# 2. Feed-forward
|
||||
norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype)
|
||||
norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i) + shift_mlp_i
|
||||
ff_output_i = gate_mlp_i * self.ff_i(norm_image_tokens.to(dtype=wtype))
|
||||
image_tokens = ff_output_i + image_tokens
|
||||
return image_tokens
|
||||
|
||||
|
||||
class HiDreamImageTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
num_routed_experts: int = 4,
|
||||
num_activated_experts: int = 2,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(dim, 12 * dim, bias=True, dtype=dtype, device=device)
|
||||
)
|
||||
# nn.init.zeros_(self.adaLN_modulation[1].weight)
|
||||
# nn.init.zeros_(self.adaLN_modulation[1].bias)
|
||||
|
||||
# 1. Attention
|
||||
self.norm1_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
|
||||
self.norm1_t = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
|
||||
self.attn1 = HiDreamAttention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
processor = HiDreamAttnProcessor_flashattn(),
|
||||
single = False,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
# 3. Feed-forward
|
||||
self.norm3_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
|
||||
if num_routed_experts > 0:
|
||||
self.ff_i = MOEFeedForwardSwiGLU(
|
||||
dim = dim,
|
||||
hidden_dim = 4 * dim,
|
||||
num_routed_experts = num_routed_experts,
|
||||
num_activated_experts = num_activated_experts,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
else:
|
||||
self.ff_i = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim, dtype=dtype, device=device, operations=operations)
|
||||
self.norm3_t = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False)
|
||||
self.ff_t = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
image_tokens: torch.FloatTensor,
|
||||
image_tokens_masks: Optional[torch.FloatTensor] = None,
|
||||
text_tokens: Optional[torch.FloatTensor] = None,
|
||||
adaln_input: Optional[torch.FloatTensor] = None,
|
||||
rope: torch.FloatTensor = None,
|
||||
) -> torch.FloatTensor:
|
||||
wtype = image_tokens.dtype
|
||||
shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i, \
|
||||
shift_msa_t, scale_msa_t, gate_msa_t, shift_mlp_t, scale_mlp_t, gate_mlp_t = \
|
||||
self.adaLN_modulation(adaln_input)[:,None].chunk(12, dim=-1)
|
||||
|
||||
# 1. MM-Attention
|
||||
norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype)
|
||||
norm_image_tokens = norm_image_tokens * (1 + scale_msa_i) + shift_msa_i
|
||||
norm_text_tokens = self.norm1_t(text_tokens).to(dtype=wtype)
|
||||
norm_text_tokens = norm_text_tokens * (1 + scale_msa_t) + shift_msa_t
|
||||
|
||||
attn_output_i, attn_output_t = self.attn1(
|
||||
norm_image_tokens,
|
||||
image_tokens_masks,
|
||||
norm_text_tokens,
|
||||
rope = rope,
|
||||
)
|
||||
|
||||
image_tokens = gate_msa_i * attn_output_i + image_tokens
|
||||
text_tokens = gate_msa_t * attn_output_t + text_tokens
|
||||
|
||||
# 2. Feed-forward
|
||||
norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype)
|
||||
norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i) + shift_mlp_i
|
||||
norm_text_tokens = self.norm3_t(text_tokens).to(dtype=wtype)
|
||||
norm_text_tokens = norm_text_tokens * (1 + scale_mlp_t) + shift_mlp_t
|
||||
|
||||
ff_output_i = gate_mlp_i * self.ff_i(norm_image_tokens)
|
||||
ff_output_t = gate_mlp_t * self.ff_t(norm_text_tokens)
|
||||
image_tokens = ff_output_i + image_tokens
|
||||
text_tokens = ff_output_t + text_tokens
|
||||
return image_tokens, text_tokens
|
||||
|
||||
|
||||
class HiDreamImageBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
num_routed_experts: int = 4,
|
||||
num_activated_experts: int = 2,
|
||||
block_type: BlockType = BlockType.TransformerBlock,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
block_classes = {
|
||||
BlockType.TransformerBlock: HiDreamImageTransformerBlock,
|
||||
BlockType.SingleTransformerBlock: HiDreamImageSingleTransformerBlock,
|
||||
}
|
||||
self.block = block_classes[block_type](
|
||||
dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
num_routed_experts,
|
||||
num_activated_experts,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
image_tokens: torch.FloatTensor,
|
||||
image_tokens_masks: Optional[torch.FloatTensor] = None,
|
||||
text_tokens: Optional[torch.FloatTensor] = None,
|
||||
adaln_input: torch.FloatTensor = None,
|
||||
rope: torch.FloatTensor = None,
|
||||
) -> torch.FloatTensor:
|
||||
return self.block(
|
||||
image_tokens,
|
||||
image_tokens_masks,
|
||||
text_tokens,
|
||||
adaln_input,
|
||||
rope,
|
||||
)
|
||||
|
||||
|
||||
class HiDreamImageTransformer2DModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: Optional[int] = None,
|
||||
in_channels: int = 64,
|
||||
out_channels: Optional[int] = None,
|
||||
num_layers: int = 16,
|
||||
num_single_layers: int = 32,
|
||||
attention_head_dim: int = 128,
|
||||
num_attention_heads: int = 20,
|
||||
caption_channels: List[int] = None,
|
||||
text_emb_dim: int = 2048,
|
||||
num_routed_experts: int = 4,
|
||||
num_activated_experts: int = 2,
|
||||
axes_dims_rope: Tuple[int, int] = (32, 32),
|
||||
max_resolution: Tuple[int, int] = (128, 128),
|
||||
llama_layers: List[int] = None,
|
||||
image_model=None,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
self.patch_size = patch_size
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
self.num_layers = num_layers
|
||||
self.num_single_layers = num_single_layers
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
self.out_channels = out_channels or in_channels
|
||||
self.inner_dim = self.num_attention_heads * self.attention_head_dim
|
||||
self.llama_layers = llama_layers
|
||||
|
||||
self.t_embedder = TimestepEmbed(self.inner_dim, dtype=dtype, device=device, operations=operations)
|
||||
self.p_embedder = PooledEmbed(text_emb_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
|
||||
self.x_embedder = PatchEmbed(
|
||||
patch_size = patch_size,
|
||||
in_channels = in_channels,
|
||||
out_channels = self.inner_dim,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
self.pe_embedder = EmbedND(theta=10000, axes_dim=axes_dims_rope)
|
||||
|
||||
self.double_stream_blocks = nn.ModuleList(
|
||||
[
|
||||
HiDreamImageBlock(
|
||||
dim = self.inner_dim,
|
||||
num_attention_heads = self.num_attention_heads,
|
||||
attention_head_dim = self.attention_head_dim,
|
||||
num_routed_experts = num_routed_experts,
|
||||
num_activated_experts = num_activated_experts,
|
||||
block_type = BlockType.TransformerBlock,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
for i in range(self.num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.single_stream_blocks = nn.ModuleList(
|
||||
[
|
||||
HiDreamImageBlock(
|
||||
dim = self.inner_dim,
|
||||
num_attention_heads = self.num_attention_heads,
|
||||
attention_head_dim = self.attention_head_dim,
|
||||
num_routed_experts = num_routed_experts,
|
||||
num_activated_experts = num_activated_experts,
|
||||
block_type = BlockType.SingleTransformerBlock,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
for i in range(self.num_single_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.final_layer = LastLayer(self.inner_dim, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
caption_channels = [caption_channels[1], ] * (num_layers + num_single_layers) + [caption_channels[0], ]
|
||||
caption_projection = []
|
||||
for caption_channel in caption_channels:
|
||||
caption_projection.append(TextProjection(in_features=caption_channel, hidden_size=self.inner_dim, dtype=dtype, device=device, operations=operations))
|
||||
self.caption_projection = nn.ModuleList(caption_projection)
|
||||
self.max_seq = max_resolution[0] * max_resolution[1] // (patch_size * patch_size)
|
||||
|
||||
def expand_timesteps(self, timesteps, batch_size, device):
|
||||
if not torch.is_tensor(timesteps):
|
||||
is_mps = device.type == "mps"
|
||||
if isinstance(timesteps, float):
|
||||
dtype = torch.float32 if is_mps else torch.float64
|
||||
else:
|
||||
dtype = torch.int32 if is_mps else torch.int64
|
||||
timesteps = torch.tensor([timesteps], dtype=dtype, device=device)
|
||||
elif len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(device)
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timesteps = timesteps.expand(batch_size)
|
||||
return timesteps
|
||||
|
||||
def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]]) -> List[torch.Tensor]:
|
||||
x_arr = []
|
||||
for i, img_size in enumerate(img_sizes):
|
||||
pH, pW = img_size
|
||||
x_arr.append(
|
||||
einops.rearrange(x[i, :pH*pW].reshape(1, pH, pW, -1), 'B H W (p1 p2 C) -> B C (H p1) (W p2)',
|
||||
p1=self.patch_size, p2=self.patch_size)
|
||||
)
|
||||
x = torch.cat(x_arr, dim=0)
|
||||
return x
|
||||
|
||||
def patchify(self, x, max_seq, img_sizes=None):
|
||||
pz2 = self.patch_size * self.patch_size
|
||||
if isinstance(x, torch.Tensor):
|
||||
B = x.shape[0]
|
||||
device = x.device
|
||||
dtype = x.dtype
|
||||
else:
|
||||
B = len(x)
|
||||
device = x[0].device
|
||||
dtype = x[0].dtype
|
||||
x_masks = torch.zeros((B, max_seq), dtype=dtype, device=device)
|
||||
|
||||
if img_sizes is not None:
|
||||
for i, img_size in enumerate(img_sizes):
|
||||
x_masks[i, 0:img_size[0] * img_size[1]] = 1
|
||||
x = einops.rearrange(x, 'B C S p -> B S (p C)', p=pz2)
|
||||
elif isinstance(x, torch.Tensor):
|
||||
pH, pW = x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size
|
||||
x = einops.rearrange(x, 'B C (H p1) (W p2) -> B (H W) (p1 p2 C)', p1=self.patch_size, p2=self.patch_size)
|
||||
img_sizes = [[pH, pW]] * B
|
||||
x_masks = None
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return x, x_masks, img_sizes
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
y: Optional[torch.Tensor] = None,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states_llama3=None,
|
||||
image_cond=None,
|
||||
control = None,
|
||||
transformer_options = {},
|
||||
) -> torch.Tensor:
|
||||
bs, c, h, w = x.shape
|
||||
if image_cond is not None:
|
||||
x = torch.cat([x, image_cond], dim=-1)
|
||||
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
|
||||
timesteps = t
|
||||
pooled_embeds = y
|
||||
T5_encoder_hidden_states = context
|
||||
|
||||
img_sizes = None
|
||||
|
||||
# spatial forward
|
||||
batch_size = hidden_states.shape[0]
|
||||
hidden_states_type = hidden_states.dtype
|
||||
|
||||
# 0. time
|
||||
timesteps = self.expand_timesteps(timesteps, batch_size, hidden_states.device)
|
||||
timesteps = self.t_embedder(timesteps, hidden_states_type)
|
||||
p_embedder = self.p_embedder(pooled_embeds)
|
||||
adaln_input = timesteps + p_embedder
|
||||
|
||||
hidden_states, image_tokens_masks, img_sizes = self.patchify(hidden_states, self.max_seq, img_sizes)
|
||||
if image_tokens_masks is None:
|
||||
pH, pW = img_sizes[0]
|
||||
img_ids = torch.zeros(pH, pW, 3, device=hidden_states.device)
|
||||
img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH, device=hidden_states.device)[:, None]
|
||||
img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW, device=hidden_states.device)[None, :]
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
|
||||
# T5_encoder_hidden_states = encoder_hidden_states[0]
|
||||
encoder_hidden_states = encoder_hidden_states_llama3.movedim(1, 0)
|
||||
encoder_hidden_states = [encoder_hidden_states[k] for k in self.llama_layers]
|
||||
|
||||
if self.caption_projection is not None:
|
||||
new_encoder_hidden_states = []
|
||||
for i, enc_hidden_state in enumerate(encoder_hidden_states):
|
||||
enc_hidden_state = self.caption_projection[i](enc_hidden_state)
|
||||
enc_hidden_state = enc_hidden_state.view(batch_size, -1, hidden_states.shape[-1])
|
||||
new_encoder_hidden_states.append(enc_hidden_state)
|
||||
encoder_hidden_states = new_encoder_hidden_states
|
||||
T5_encoder_hidden_states = self.caption_projection[-1](T5_encoder_hidden_states)
|
||||
T5_encoder_hidden_states = T5_encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
|
||||
encoder_hidden_states.append(T5_encoder_hidden_states)
|
||||
|
||||
txt_ids = torch.zeros(
|
||||
batch_size,
|
||||
encoder_hidden_states[-1].shape[1] + encoder_hidden_states[-2].shape[1] + encoder_hidden_states[0].shape[1],
|
||||
3,
|
||||
device=img_ids.device, dtype=img_ids.dtype
|
||||
)
|
||||
ids = torch.cat((img_ids, txt_ids), dim=1)
|
||||
rope = self.pe_embedder(ids)
|
||||
|
||||
# 2. Blocks
|
||||
block_id = 0
|
||||
initial_encoder_hidden_states = torch.cat([encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1)
|
||||
initial_encoder_hidden_states_seq_len = initial_encoder_hidden_states.shape[1]
|
||||
for bid, block in enumerate(self.double_stream_blocks):
|
||||
cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
|
||||
cur_encoder_hidden_states = torch.cat([initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1)
|
||||
hidden_states, initial_encoder_hidden_states = block(
|
||||
image_tokens = hidden_states,
|
||||
image_tokens_masks = image_tokens_masks,
|
||||
text_tokens = cur_encoder_hidden_states,
|
||||
adaln_input = adaln_input,
|
||||
rope = rope,
|
||||
)
|
||||
initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len]
|
||||
block_id += 1
|
||||
|
||||
image_tokens_seq_len = hidden_states.shape[1]
|
||||
hidden_states = torch.cat([hidden_states, initial_encoder_hidden_states], dim=1)
|
||||
hidden_states_seq_len = hidden_states.shape[1]
|
||||
if image_tokens_masks is not None:
|
||||
encoder_attention_mask_ones = torch.ones(
|
||||
(batch_size, initial_encoder_hidden_states.shape[1] + cur_llama31_encoder_hidden_states.shape[1]),
|
||||
device=image_tokens_masks.device, dtype=image_tokens_masks.dtype
|
||||
)
|
||||
image_tokens_masks = torch.cat([image_tokens_masks, encoder_attention_mask_ones], dim=1)
|
||||
|
||||
for bid, block in enumerate(self.single_stream_blocks):
|
||||
cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
|
||||
hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1)
|
||||
hidden_states = block(
|
||||
image_tokens=hidden_states,
|
||||
image_tokens_masks=image_tokens_masks,
|
||||
text_tokens=None,
|
||||
adaln_input=adaln_input,
|
||||
rope=rope,
|
||||
)
|
||||
hidden_states = hidden_states[:, :hidden_states_seq_len]
|
||||
block_id += 1
|
||||
|
||||
hidden_states = hidden_states[:, :image_tokens_seq_len, ...]
|
||||
output = self.final_layer(hidden_states, adaln_input)
|
||||
output = self.unpatchify(output, img_sizes)
|
||||
return -output[:, :, :h, :w]
|
||||
135
comfy/ldm/hunyuan3d/model.py
Normal file
135
comfy/ldm/hunyuan3d/model.py
Normal file
@@ -0,0 +1,135 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from comfy.ldm.flux.layers import (
|
||||
DoubleStreamBlock,
|
||||
LastLayer,
|
||||
MLPEmbedder,
|
||||
SingleStreamBlock,
|
||||
timestep_embedding,
|
||||
)
|
||||
|
||||
|
||||
class Hunyuan3Dv2(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=64,
|
||||
context_in_dim=1536,
|
||||
hidden_size=1024,
|
||||
mlp_ratio=4.0,
|
||||
num_heads=16,
|
||||
depth=16,
|
||||
depth_single_blocks=32,
|
||||
qkv_bias=True,
|
||||
guidance_embed=False,
|
||||
image_model=None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
|
||||
if hidden_size % num_heads != 0:
|
||||
raise ValueError(
|
||||
f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}"
|
||||
)
|
||||
|
||||
self.max_period = 1000 # While reimplementing the model I noticed that they messed up. This 1000 value was meant to be the time_factor but they set the max_period instead
|
||||
self.latent_in = operations.Linear(in_channels, hidden_size, bias=True, dtype=dtype, device=device)
|
||||
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.guidance_in = (
|
||||
MLPEmbedder(in_dim=256, hidden_dim=hidden_size, dtype=dtype, device=device, operations=operations) if guidance_embed else None
|
||||
)
|
||||
self.cond_in = operations.Linear(context_in_dim, hidden_size, dtype=dtype, device=device)
|
||||
self.double_blocks = nn.ModuleList(
|
||||
[
|
||||
DoubleStreamBlock(
|
||||
hidden_size,
|
||||
num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
for _ in range(depth)
|
||||
]
|
||||
)
|
||||
self.single_blocks = nn.ModuleList(
|
||||
[
|
||||
SingleStreamBlock(
|
||||
hidden_size,
|
||||
num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
for _ in range(depth_single_blocks)
|
||||
]
|
||||
)
|
||||
self.final_layer = LastLayer(hidden_size, 1, in_channels, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(self, x, timestep, context, guidance=None, transformer_options={}, **kwargs):
|
||||
x = x.movedim(-1, -2)
|
||||
timestep = 1.0 - timestep
|
||||
txt = context
|
||||
img = self.latent_in(x)
|
||||
|
||||
vec = self.time_in(timestep_embedding(timestep, 256, self.max_period).to(dtype=img.dtype))
|
||||
if self.guidance_in is not None:
|
||||
if guidance is not None:
|
||||
vec = vec + self.guidance_in(timestep_embedding(guidance, 256, self.max_period).to(img.dtype))
|
||||
|
||||
txt = self.cond_in(txt)
|
||||
pe = None
|
||||
attn_mask = None
|
||||
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
for i, block in enumerate(self.double_blocks):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"], out["txt"] = block(img=args["img"],
|
||||
txt=args["txt"],
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"))
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": img,
|
||||
"txt": txt,
|
||||
"vec": vec,
|
||||
"pe": pe,
|
||||
"attn_mask": attn_mask},
|
||||
{"original_block": block_wrap})
|
||||
txt = out["txt"]
|
||||
img = out["img"]
|
||||
else:
|
||||
img, txt = block(img=img,
|
||||
txt=txt,
|
||||
vec=vec,
|
||||
pe=pe,
|
||||
attn_mask=attn_mask)
|
||||
|
||||
img = torch.cat((txt, img), 1)
|
||||
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
if ("single_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"],
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"))
|
||||
return out
|
||||
|
||||
out = blocks_replace[("single_block", i)]({"img": img,
|
||||
"vec": vec,
|
||||
"pe": pe,
|
||||
"attn_mask": attn_mask},
|
||||
{"original_block": block_wrap})
|
||||
img = out["img"]
|
||||
else:
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
|
||||
|
||||
img = img[:, txt.shape[1]:, ...]
|
||||
img = self.final_layer(img, vec)
|
||||
return img.movedim(-2, -1) * (-1.0)
|
||||
587
comfy/ldm/hunyuan3d/vae.py
Normal file
587
comfy/ldm/hunyuan3d/vae.py
Normal file
@@ -0,0 +1,587 @@
|
||||
# Original: https://github.com/Tencent/Hunyuan3D-2/blob/main/hy3dgen/shapegen/models/autoencoders/model.py
|
||||
# Since the header on their VAE source file was a bit confusing we asked for permission to use this code from tencent under the GPL license used in ComfyUI.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
from typing import Union, Tuple, List, Callable, Optional
|
||||
|
||||
import numpy as np
|
||||
from einops import repeat, rearrange
|
||||
from tqdm import tqdm
|
||||
import logging
|
||||
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
def generate_dense_grid_points(
|
||||
bbox_min: np.ndarray,
|
||||
bbox_max: np.ndarray,
|
||||
octree_resolution: int,
|
||||
indexing: str = "ij",
|
||||
):
|
||||
length = bbox_max - bbox_min
|
||||
num_cells = octree_resolution
|
||||
|
||||
x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
|
||||
y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
|
||||
z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
|
||||
[xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
|
||||
xyz = np.stack((xs, ys, zs), axis=-1)
|
||||
grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]
|
||||
|
||||
return xyz, grid_size, length
|
||||
|
||||
|
||||
class VanillaVolumeDecoder:
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
latents: torch.FloatTensor,
|
||||
geo_decoder: Callable,
|
||||
bounds: Union[Tuple[float], List[float], float] = 1.01,
|
||||
num_chunks: int = 10000,
|
||||
octree_resolution: int = None,
|
||||
enable_pbar: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
device = latents.device
|
||||
dtype = latents.dtype
|
||||
batch_size = latents.shape[0]
|
||||
|
||||
# 1. generate query points
|
||||
if isinstance(bounds, float):
|
||||
bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
|
||||
|
||||
bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
|
||||
xyz_samples, grid_size, length = generate_dense_grid_points(
|
||||
bbox_min=bbox_min,
|
||||
bbox_max=bbox_max,
|
||||
octree_resolution=octree_resolution,
|
||||
indexing="ij"
|
||||
)
|
||||
xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3)
|
||||
|
||||
# 2. latents to 3d volume
|
||||
batch_logits = []
|
||||
for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), desc="Volume Decoding",
|
||||
disable=not enable_pbar):
|
||||
chunk_queries = xyz_samples[start: start + num_chunks, :]
|
||||
chunk_queries = repeat(chunk_queries, "p c -> b p c", b=batch_size)
|
||||
logits = geo_decoder(queries=chunk_queries, latents=latents)
|
||||
batch_logits.append(logits)
|
||||
|
||||
grid_logits = torch.cat(batch_logits, dim=1)
|
||||
grid_logits = grid_logits.view((batch_size, *grid_size)).float()
|
||||
|
||||
return grid_logits
|
||||
|
||||
|
||||
class FourierEmbedder(nn.Module):
|
||||
"""The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
|
||||
each feature dimension of `x[..., i]` into:
|
||||
[
|
||||
sin(x[..., i]),
|
||||
sin(f_1*x[..., i]),
|
||||
sin(f_2*x[..., i]),
|
||||
...
|
||||
sin(f_N * x[..., i]),
|
||||
cos(x[..., i]),
|
||||
cos(f_1*x[..., i]),
|
||||
cos(f_2*x[..., i]),
|
||||
...
|
||||
cos(f_N * x[..., i]),
|
||||
x[..., i] # only present if include_input is True.
|
||||
], here f_i is the frequency.
|
||||
|
||||
Denote the space is [0 / num_freqs, 1 / num_freqs, 2 / num_freqs, 3 / num_freqs, ..., (num_freqs - 1) / num_freqs].
|
||||
If logspace is True, then the frequency f_i is [2^(0 / num_freqs), ..., 2^(i / num_freqs), ...];
|
||||
Otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)].
|
||||
|
||||
Args:
|
||||
num_freqs (int): the number of frequencies, default is 6;
|
||||
logspace (bool): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...],
|
||||
otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)];
|
||||
input_dim (int): the input dimension, default is 3;
|
||||
include_input (bool): include the input tensor or not, default is True.
|
||||
|
||||
Attributes:
|
||||
frequencies (torch.Tensor): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...],
|
||||
otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1);
|
||||
|
||||
out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1),
|
||||
otherwise, it is input_dim * num_freqs * 2.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_freqs: int = 6,
|
||||
logspace: bool = True,
|
||||
input_dim: int = 3,
|
||||
include_input: bool = True,
|
||||
include_pi: bool = True) -> None:
|
||||
|
||||
"""The initialization"""
|
||||
|
||||
super().__init__()
|
||||
|
||||
if logspace:
|
||||
frequencies = 2.0 ** torch.arange(
|
||||
num_freqs,
|
||||
dtype=torch.float32
|
||||
)
|
||||
else:
|
||||
frequencies = torch.linspace(
|
||||
1.0,
|
||||
2.0 ** (num_freqs - 1),
|
||||
num_freqs,
|
||||
dtype=torch.float32
|
||||
)
|
||||
|
||||
if include_pi:
|
||||
frequencies *= torch.pi
|
||||
|
||||
self.register_buffer("frequencies", frequencies, persistent=False)
|
||||
self.include_input = include_input
|
||||
self.num_freqs = num_freqs
|
||||
|
||||
self.out_dim = self.get_dims(input_dim)
|
||||
|
||||
def get_dims(self, input_dim):
|
||||
temp = 1 if self.include_input or self.num_freqs == 0 else 0
|
||||
out_dim = input_dim * (self.num_freqs * 2 + temp)
|
||||
|
||||
return out_dim
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
""" Forward process.
|
||||
|
||||
Args:
|
||||
x: tensor of shape [..., dim]
|
||||
|
||||
Returns:
|
||||
embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)]
|
||||
where temp is 1 if include_input is True and 0 otherwise.
|
||||
"""
|
||||
|
||||
if self.num_freqs > 0:
|
||||
embed = (x[..., None].contiguous() * self.frequencies.to(device=x.device, dtype=x.dtype)).view(*x.shape[:-1], -1)
|
||||
if self.include_input:
|
||||
return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
|
||||
else:
|
||||
return torch.cat((embed.sin(), embed.cos()), dim=-1)
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
class CrossAttentionProcessor:
|
||||
def __call__(self, attn, q, k, v):
|
||||
out = F.scaled_dot_product_attention(q, k, v)
|
||||
return out
|
||||
|
||||
|
||||
class DropPath(nn.Module):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||
"""
|
||||
|
||||
def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
|
||||
super(DropPath, self).__init__()
|
||||
self.drop_prob = drop_prob
|
||||
self.scale_by_keep = scale_by_keep
|
||||
|
||||
def forward(self, x):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||
|
||||
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
|
||||
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
||||
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
|
||||
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
|
||||
'survival rate' as the argument.
|
||||
|
||||
"""
|
||||
if self.drop_prob == 0. or not self.training:
|
||||
return x
|
||||
keep_prob = 1 - self.drop_prob
|
||||
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
||||
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
||||
if keep_prob > 0.0 and self.scale_by_keep:
|
||||
random_tensor.div_(keep_prob)
|
||||
return x * random_tensor
|
||||
|
||||
def extra_repr(self):
|
||||
return f'drop_prob={round(self.drop_prob, 3):0.3f}'
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(
|
||||
self, *,
|
||||
width: int,
|
||||
expand_ratio: int = 4,
|
||||
output_width: int = None,
|
||||
drop_path_rate: float = 0.0
|
||||
):
|
||||
super().__init__()
|
||||
self.width = width
|
||||
self.c_fc = ops.Linear(width, width * expand_ratio)
|
||||
self.c_proj = ops.Linear(width * expand_ratio, output_width if output_width is not None else width)
|
||||
self.gelu = nn.GELU()
|
||||
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
return self.drop_path(self.c_proj(self.gelu(self.c_fc(x))))
|
||||
|
||||
|
||||
class QKVMultiheadCrossAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
heads: int,
|
||||
width=None,
|
||||
qk_norm=False,
|
||||
norm_layer=ops.LayerNorm
|
||||
):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
|
||||
self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
|
||||
|
||||
self.attn_processor = CrossAttentionProcessor()
|
||||
|
||||
def forward(self, q, kv):
|
||||
_, n_ctx, _ = q.shape
|
||||
bs, n_data, width = kv.shape
|
||||
attn_ch = width // self.heads // 2
|
||||
q = q.view(bs, n_ctx, self.heads, -1)
|
||||
kv = kv.view(bs, n_data, self.heads, -1)
|
||||
k, v = torch.split(kv, attn_ch, dim=-1)
|
||||
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
|
||||
out = self.attn_processor(self, q, k, v)
|
||||
out = out.transpose(1, 2).reshape(bs, n_ctx, -1)
|
||||
return out
|
||||
|
||||
|
||||
class MultiheadCrossAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
width: int,
|
||||
heads: int,
|
||||
qkv_bias: bool = True,
|
||||
data_width: Optional[int] = None,
|
||||
norm_layer=ops.LayerNorm,
|
||||
qk_norm: bool = False,
|
||||
kv_cache: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.width = width
|
||||
self.heads = heads
|
||||
self.data_width = width if data_width is None else data_width
|
||||
self.c_q = ops.Linear(width, width, bias=qkv_bias)
|
||||
self.c_kv = ops.Linear(self.data_width, width * 2, bias=qkv_bias)
|
||||
self.c_proj = ops.Linear(width, width)
|
||||
self.attention = QKVMultiheadCrossAttention(
|
||||
heads=heads,
|
||||
width=width,
|
||||
norm_layer=norm_layer,
|
||||
qk_norm=qk_norm
|
||||
)
|
||||
self.kv_cache = kv_cache
|
||||
self.data = None
|
||||
|
||||
def forward(self, x, data):
|
||||
x = self.c_q(x)
|
||||
if self.kv_cache:
|
||||
if self.data is None:
|
||||
self.data = self.c_kv(data)
|
||||
logging.info('Save kv cache,this should be called only once for one mesh')
|
||||
data = self.data
|
||||
else:
|
||||
data = self.c_kv(data)
|
||||
x = self.attention(x, data)
|
||||
x = self.c_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class ResidualCrossAttentionBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
width: int,
|
||||
heads: int,
|
||||
mlp_expand_ratio: int = 4,
|
||||
data_width: Optional[int] = None,
|
||||
qkv_bias: bool = True,
|
||||
norm_layer=ops.LayerNorm,
|
||||
qk_norm: bool = False
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if data_width is None:
|
||||
data_width = width
|
||||
|
||||
self.attn = MultiheadCrossAttention(
|
||||
width=width,
|
||||
heads=heads,
|
||||
data_width=data_width,
|
||||
qkv_bias=qkv_bias,
|
||||
norm_layer=norm_layer,
|
||||
qk_norm=qk_norm
|
||||
)
|
||||
self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
|
||||
self.ln_2 = norm_layer(data_width, elementwise_affine=True, eps=1e-6)
|
||||
self.ln_3 = norm_layer(width, elementwise_affine=True, eps=1e-6)
|
||||
self.mlp = MLP(width=width, expand_ratio=mlp_expand_ratio)
|
||||
|
||||
def forward(self, x: torch.Tensor, data: torch.Tensor):
|
||||
x = x + self.attn(self.ln_1(x), self.ln_2(data))
|
||||
x = x + self.mlp(self.ln_3(x))
|
||||
return x
|
||||
|
||||
|
||||
class QKVMultiheadAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
heads: int,
|
||||
width=None,
|
||||
qk_norm=False,
|
||||
norm_layer=ops.LayerNorm
|
||||
):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
|
||||
self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
|
||||
|
||||
def forward(self, qkv):
|
||||
bs, n_ctx, width = qkv.shape
|
||||
attn_ch = width // self.heads // 3
|
||||
qkv = qkv.view(bs, n_ctx, self.heads, -1)
|
||||
q, k, v = torch.split(qkv, attn_ch, dim=-1)
|
||||
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
|
||||
out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
|
||||
return out
|
||||
|
||||
|
||||
class MultiheadAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
width: int,
|
||||
heads: int,
|
||||
qkv_bias: bool,
|
||||
norm_layer=ops.LayerNorm,
|
||||
qk_norm: bool = False,
|
||||
drop_path_rate: float = 0.0
|
||||
):
|
||||
super().__init__()
|
||||
self.width = width
|
||||
self.heads = heads
|
||||
self.c_qkv = ops.Linear(width, width * 3, bias=qkv_bias)
|
||||
self.c_proj = ops.Linear(width, width)
|
||||
self.attention = QKVMultiheadAttention(
|
||||
heads=heads,
|
||||
width=width,
|
||||
norm_layer=norm_layer,
|
||||
qk_norm=qk_norm
|
||||
)
|
||||
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.c_qkv(x)
|
||||
x = self.attention(x)
|
||||
x = self.drop_path(self.c_proj(x))
|
||||
return x
|
||||
|
||||
|
||||
class ResidualAttentionBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
width: int,
|
||||
heads: int,
|
||||
qkv_bias: bool = True,
|
||||
norm_layer=ops.LayerNorm,
|
||||
qk_norm: bool = False,
|
||||
drop_path_rate: float = 0.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.attn = MultiheadAttention(
|
||||
width=width,
|
||||
heads=heads,
|
||||
qkv_bias=qkv_bias,
|
||||
norm_layer=norm_layer,
|
||||
qk_norm=qk_norm,
|
||||
drop_path_rate=drop_path_rate
|
||||
)
|
||||
self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
|
||||
self.mlp = MLP(width=width, drop_path_rate=drop_path_rate)
|
||||
self.ln_2 = norm_layer(width, elementwise_affine=True, eps=1e-6)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
x = x + self.attn(self.ln_1(x))
|
||||
x = x + self.mlp(self.ln_2(x))
|
||||
return x
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
width: int,
|
||||
layers: int,
|
||||
heads: int,
|
||||
qkv_bias: bool = True,
|
||||
norm_layer=ops.LayerNorm,
|
||||
qk_norm: bool = False,
|
||||
drop_path_rate: float = 0.0
|
||||
):
|
||||
super().__init__()
|
||||
self.width = width
|
||||
self.layers = layers
|
||||
self.resblocks = nn.ModuleList(
|
||||
[
|
||||
ResidualAttentionBlock(
|
||||
width=width,
|
||||
heads=heads,
|
||||
qkv_bias=qkv_bias,
|
||||
norm_layer=norm_layer,
|
||||
qk_norm=qk_norm,
|
||||
drop_path_rate=drop_path_rate
|
||||
)
|
||||
for _ in range(layers)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
for block in self.resblocks:
|
||||
x = block(x)
|
||||
return x
|
||||
|
||||
|
||||
class CrossAttentionDecoder(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
out_channels: int,
|
||||
fourier_embedder: FourierEmbedder,
|
||||
width: int,
|
||||
heads: int,
|
||||
mlp_expand_ratio: int = 4,
|
||||
downsample_ratio: int = 1,
|
||||
enable_ln_post: bool = True,
|
||||
qkv_bias: bool = True,
|
||||
qk_norm: bool = False,
|
||||
label_type: str = "binary"
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.enable_ln_post = enable_ln_post
|
||||
self.fourier_embedder = fourier_embedder
|
||||
self.downsample_ratio = downsample_ratio
|
||||
self.query_proj = ops.Linear(self.fourier_embedder.out_dim, width)
|
||||
if self.downsample_ratio != 1:
|
||||
self.latents_proj = ops.Linear(width * downsample_ratio, width)
|
||||
if self.enable_ln_post == False:
|
||||
qk_norm = False
|
||||
self.cross_attn_decoder = ResidualCrossAttentionBlock(
|
||||
width=width,
|
||||
mlp_expand_ratio=mlp_expand_ratio,
|
||||
heads=heads,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_norm=qk_norm
|
||||
)
|
||||
|
||||
if self.enable_ln_post:
|
||||
self.ln_post = ops.LayerNorm(width)
|
||||
self.output_proj = ops.Linear(width, out_channels)
|
||||
self.label_type = label_type
|
||||
self.count = 0
|
||||
|
||||
def forward(self, queries=None, query_embeddings=None, latents=None):
|
||||
if query_embeddings is None:
|
||||
query_embeddings = self.query_proj(self.fourier_embedder(queries).to(latents.dtype))
|
||||
self.count += query_embeddings.shape[1]
|
||||
if self.downsample_ratio != 1:
|
||||
latents = self.latents_proj(latents)
|
||||
x = self.cross_attn_decoder(query_embeddings, latents)
|
||||
if self.enable_ln_post:
|
||||
x = self.ln_post(x)
|
||||
occ = self.output_proj(x)
|
||||
return occ
|
||||
|
||||
|
||||
class ShapeVAE(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
embed_dim: int,
|
||||
width: int,
|
||||
heads: int,
|
||||
num_decoder_layers: int,
|
||||
geo_decoder_downsample_ratio: int = 1,
|
||||
geo_decoder_mlp_expand_ratio: int = 4,
|
||||
geo_decoder_ln_post: bool = True,
|
||||
num_freqs: int = 8,
|
||||
include_pi: bool = True,
|
||||
qkv_bias: bool = True,
|
||||
qk_norm: bool = False,
|
||||
label_type: str = "binary",
|
||||
drop_path_rate: float = 0.0,
|
||||
scale_factor: float = 1.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.geo_decoder_ln_post = geo_decoder_ln_post
|
||||
|
||||
self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
|
||||
|
||||
self.post_kl = ops.Linear(embed_dim, width)
|
||||
|
||||
self.transformer = Transformer(
|
||||
width=width,
|
||||
layers=num_decoder_layers,
|
||||
heads=heads,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_norm=qk_norm,
|
||||
drop_path_rate=drop_path_rate
|
||||
)
|
||||
|
||||
self.geo_decoder = CrossAttentionDecoder(
|
||||
fourier_embedder=self.fourier_embedder,
|
||||
out_channels=1,
|
||||
mlp_expand_ratio=geo_decoder_mlp_expand_ratio,
|
||||
downsample_ratio=geo_decoder_downsample_ratio,
|
||||
enable_ln_post=self.geo_decoder_ln_post,
|
||||
width=width // geo_decoder_downsample_ratio,
|
||||
heads=heads // geo_decoder_downsample_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_norm=qk_norm,
|
||||
label_type=label_type,
|
||||
)
|
||||
|
||||
self.volume_decoder = VanillaVolumeDecoder()
|
||||
self.scale_factor = scale_factor
|
||||
|
||||
def decode(self, latents, **kwargs):
|
||||
latents = self.post_kl(latents.movedim(-2, -1))
|
||||
latents = self.transformer(latents)
|
||||
|
||||
bounds = kwargs.get("bounds", 1.01)
|
||||
num_chunks = kwargs.get("num_chunks", 8000)
|
||||
octree_resolution = kwargs.get("octree_resolution", 256)
|
||||
enable_pbar = kwargs.get("enable_pbar", True)
|
||||
|
||||
grid_logits = self.volume_decoder(latents, self.geo_decoder, bounds=bounds, num_chunks=num_chunks, octree_resolution=octree_resolution, enable_pbar=enable_pbar)
|
||||
return grid_logits.movedim(-2, -1)
|
||||
|
||||
def encode(self, x):
|
||||
return None
|
||||
355
comfy/ldm/hunyuan_video/model.py
Normal file
355
comfy/ldm/hunyuan_video/model.py
Normal file
@@ -0,0 +1,355 @@
|
||||
#Based on Flux code because of weird hunyuan video code license.
|
||||
|
||||
import torch
|
||||
import comfy.ldm.flux.layers
|
||||
import comfy.ldm.modules.diffusionmodules.mmdit
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
|
||||
from dataclasses import dataclass
|
||||
from einops import repeat
|
||||
|
||||
from torch import Tensor, nn
|
||||
|
||||
from comfy.ldm.flux.layers import (
|
||||
DoubleStreamBlock,
|
||||
EmbedND,
|
||||
LastLayer,
|
||||
MLPEmbedder,
|
||||
SingleStreamBlock,
|
||||
timestep_embedding
|
||||
)
|
||||
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
|
||||
@dataclass
|
||||
class HunyuanVideoParams:
|
||||
in_channels: int
|
||||
out_channels: int
|
||||
vec_in_dim: int
|
||||
context_in_dim: int
|
||||
hidden_size: int
|
||||
mlp_ratio: float
|
||||
num_heads: int
|
||||
depth: int
|
||||
depth_single_blocks: int
|
||||
axes_dim: list
|
||||
theta: int
|
||||
patch_size: list
|
||||
qkv_bias: bool
|
||||
guidance_embed: bool
|
||||
|
||||
|
||||
class SelfAttentionRef(nn.Module):
|
||||
def __init__(self, dim: int, qkv_bias: bool = False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
|
||||
self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
|
||||
|
||||
|
||||
class TokenRefinerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
heads,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
mlp_hidden_dim = hidden_size * 4
|
||||
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
self.norm1 = operations.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)
|
||||
self.self_attn = SelfAttentionRef(hidden_size, True, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.norm2 = operations.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)
|
||||
|
||||
self.mlp = nn.Sequential(
|
||||
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
def forward(self, x, c, mask):
|
||||
mod1, mod2 = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||
|
||||
norm_x = self.norm1(x)
|
||||
qkv = self.self_attn.qkv(norm_x)
|
||||
q, k, v = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, self.heads, -1).permute(2, 0, 3, 1, 4)
|
||||
attn = optimized_attention(q, k, v, self.heads, mask=mask, skip_reshape=True)
|
||||
|
||||
x = x + self.self_attn.proj(attn) * mod1.unsqueeze(1)
|
||||
x = x + self.mlp(self.norm2(x)) * mod2.unsqueeze(1)
|
||||
return x
|
||||
|
||||
|
||||
class IndividualTokenRefiner(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
heads,
|
||||
num_blocks,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
TokenRefinerBlock(
|
||||
hidden_size=hidden_size,
|
||||
heads=heads,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
)
|
||||
for _ in range(num_blocks)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x, c, mask):
|
||||
m = None
|
||||
if mask is not None:
|
||||
m = mask.view(mask.shape[0], 1, 1, mask.shape[1]).repeat(1, 1, mask.shape[1], 1)
|
||||
m = m + m.transpose(2, 3)
|
||||
|
||||
for block in self.blocks:
|
||||
x = block(x, c, m)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class TokenRefiner(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
text_dim,
|
||||
hidden_size,
|
||||
heads,
|
||||
num_blocks,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.input_embedder = operations.Linear(text_dim, hidden_size, bias=True, dtype=dtype, device=device)
|
||||
self.t_embedder = MLPEmbedder(256, hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.c_embedder = MLPEmbedder(text_dim, hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.individual_token_refiner = IndividualTokenRefiner(hidden_size, heads, num_blocks, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
timesteps,
|
||||
mask,
|
||||
):
|
||||
t = self.t_embedder(timestep_embedding(timesteps, 256, time_factor=1.0).to(x.dtype))
|
||||
# m = mask.float().unsqueeze(-1)
|
||||
# c = (x.float() * m).sum(dim=1) / m.sum(dim=1) #TODO: the following works when the x.shape is the same length as the tokens but might break otherwise
|
||||
c = x.sum(dim=1) / x.shape[1]
|
||||
|
||||
c = t + self.c_embedder(c.to(x.dtype))
|
||||
x = self.input_embedder(x)
|
||||
x = self.individual_token_refiner(x, c, mask)
|
||||
return x
|
||||
|
||||
class HunyuanVideo(nn.Module):
|
||||
"""
|
||||
Transformer model for flow matching on sequences.
|
||||
"""
|
||||
|
||||
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
params = HunyuanVideoParams(**kwargs)
|
||||
self.params = params
|
||||
self.patch_size = params.patch_size
|
||||
self.in_channels = params.in_channels
|
||||
self.out_channels = params.out_channels
|
||||
if params.hidden_size % params.num_heads != 0:
|
||||
raise ValueError(
|
||||
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
|
||||
)
|
||||
pe_dim = params.hidden_size // params.num_heads
|
||||
if sum(params.axes_dim) != pe_dim:
|
||||
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
|
||||
self.hidden_size = params.hidden_size
|
||||
self.num_heads = params.num_heads
|
||||
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
|
||||
|
||||
self.img_in = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(None, self.patch_size, self.in_channels, self.hidden_size, conv3d=True, dtype=dtype, device=device, operations=operations)
|
||||
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.guidance_in = (
|
||||
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
|
||||
)
|
||||
|
||||
self.txt_in = TokenRefiner(params.context_in_dim, self.hidden_size, self.num_heads, 2, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.double_blocks = nn.ModuleList(
|
||||
[
|
||||
DoubleStreamBlock(
|
||||
self.hidden_size,
|
||||
self.num_heads,
|
||||
mlp_ratio=params.mlp_ratio,
|
||||
qkv_bias=params.qkv_bias,
|
||||
flipped_img_txt=True,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
for _ in range(params.depth)
|
||||
]
|
||||
)
|
||||
|
||||
self.single_blocks = nn.ModuleList(
|
||||
[
|
||||
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
|
||||
for _ in range(params.depth_single_blocks)
|
||||
]
|
||||
)
|
||||
|
||||
if final_layer:
|
||||
self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward_orig(
|
||||
self,
|
||||
img: Tensor,
|
||||
img_ids: Tensor,
|
||||
txt: Tensor,
|
||||
txt_ids: Tensor,
|
||||
txt_mask: Tensor,
|
||||
timesteps: Tensor,
|
||||
y: Tensor,
|
||||
guidance: Tensor = None,
|
||||
guiding_frame_index=None,
|
||||
ref_latent=None,
|
||||
control=None,
|
||||
transformer_options={},
|
||||
) -> Tensor:
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
|
||||
initial_shape = list(img.shape)
|
||||
# running on sequences img
|
||||
img = self.img_in(img)
|
||||
vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))
|
||||
|
||||
if ref_latent is not None:
|
||||
ref_latent_ids = self.img_ids(ref_latent)
|
||||
ref_latent = self.img_in(ref_latent)
|
||||
img = torch.cat([ref_latent, img], dim=-2)
|
||||
ref_latent_ids[..., 0] = -1
|
||||
ref_latent_ids[..., 2] += (initial_shape[-1] // self.patch_size[-1])
|
||||
img_ids = torch.cat([ref_latent_ids, img_ids], dim=-2)
|
||||
|
||||
if guiding_frame_index is not None:
|
||||
token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0))
|
||||
vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
|
||||
vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1)
|
||||
frame_tokens = (initial_shape[-1] // self.patch_size[-1]) * (initial_shape[-2] // self.patch_size[-2])
|
||||
modulation_dims = [(0, frame_tokens, 0), (frame_tokens, None, 1)]
|
||||
modulation_dims_txt = [(0, None, 1)]
|
||||
else:
|
||||
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
|
||||
modulation_dims = None
|
||||
modulation_dims_txt = None
|
||||
|
||||
if self.params.guidance_embed:
|
||||
if guidance is not None:
|
||||
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
|
||||
|
||||
if txt_mask is not None and not torch.is_floating_point(txt_mask):
|
||||
txt_mask = (txt_mask - 1).to(img.dtype) * torch.finfo(img.dtype).max
|
||||
|
||||
txt = self.txt_in(txt, timesteps, txt_mask)
|
||||
|
||||
ids = torch.cat((img_ids, txt_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
|
||||
img_len = img.shape[1]
|
||||
if txt_mask is not None:
|
||||
attn_mask_len = img_len + txt.shape[1]
|
||||
attn_mask = torch.zeros((1, 1, attn_mask_len), dtype=img.dtype, device=img.device)
|
||||
attn_mask[:, 0, img_len:] = txt_mask
|
||||
else:
|
||||
attn_mask = None
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
for i, block in enumerate(self.double_blocks):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt}, {"original_block": block_wrap})
|
||||
txt = out["txt"]
|
||||
img = out["img"]
|
||||
else:
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_i = control.get("input")
|
||||
if i < len(control_i):
|
||||
add = control_i[i]
|
||||
if add is not None:
|
||||
img += add
|
||||
|
||||
img = torch.cat((img, txt), 1)
|
||||
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
if ("single_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims}, {"original_block": block_wrap})
|
||||
img = out["img"]
|
||||
else:
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_o = control.get("output")
|
||||
if i < len(control_o):
|
||||
add = control_o[i]
|
||||
if add is not None:
|
||||
img[:, : img_len] += add
|
||||
|
||||
img = img[:, : img_len]
|
||||
if ref_latent is not None:
|
||||
img = img[:, ref_latent.shape[1]:]
|
||||
|
||||
img = self.final_layer(img, vec, modulation_dims=modulation_dims) # (N, T, patch_size ** 2 * out_channels)
|
||||
|
||||
shape = initial_shape[-3:]
|
||||
for i in range(len(shape)):
|
||||
shape[i] = shape[i] // self.patch_size[i]
|
||||
img = img.reshape([img.shape[0]] + shape + [self.out_channels] + self.patch_size)
|
||||
img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
|
||||
img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
|
||||
return img
|
||||
|
||||
def img_ids(self, x):
|
||||
bs, c, t, h, w = x.shape
|
||||
patch_size = self.patch_size
|
||||
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
|
||||
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
|
||||
w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
|
||||
img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
|
||||
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
|
||||
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
|
||||
return repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
|
||||
|
||||
def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
|
||||
bs, c, t, h, w = x.shape
|
||||
img_ids = self.img_ids(x)
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options)
|
||||
return out
|
||||
@@ -159,7 +159,7 @@ class CrossAttention(nn.Module):
|
||||
|
||||
q = q.transpose(-2, -3).contiguous() # q -> B, L1, H, C - B, H, L1, C
|
||||
k = k.transpose(-2, -3).contiguous() # k -> B, L2, H, C - B, H, C, L2
|
||||
v = v.transpose(-2, -3).contiguous()
|
||||
v = v.transpose(-2, -3).contiguous()
|
||||
|
||||
context = optimized_attention(q, k, v, self.num_heads, skip_reshape=True, attn_precision=self.attn_precision)
|
||||
|
||||
|
||||
@@ -1,24 +1,17 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from torch.utils import checkpoint
|
||||
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import (
|
||||
Mlp,
|
||||
TimestepEmbedder,
|
||||
PatchEmbed,
|
||||
RMSNorm,
|
||||
)
|
||||
from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
|
||||
from .poolers import AttentionPool
|
||||
|
||||
import comfy.latent_formats
|
||||
from .models import HunYuanDiTBlock, calc_rope
|
||||
|
||||
from .posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop
|
||||
|
||||
|
||||
class HunYuanControlNet(nn.Module):
|
||||
@@ -171,9 +164,6 @@ class HunYuanControlNet(nn.Module):
|
||||
),
|
||||
)
|
||||
|
||||
# Image embedding
|
||||
num_patches = self.x_embedder.num_patches
|
||||
|
||||
# HUnYuanDiT Blocks
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import comfy.ops
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed, RMSNorm
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed
|
||||
from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
|
||||
from torch.utils import checkpoint
|
||||
|
||||
@@ -53,7 +51,7 @@ class HunYuanDiTBlock(nn.Module):
|
||||
if norm_type == "layer":
|
||||
norm_layer = operations.LayerNorm
|
||||
elif norm_type == "rms":
|
||||
norm_layer = RMSNorm
|
||||
norm_layer = operations.RMSNorm
|
||||
else:
|
||||
raise ValueError(f"Unknown norm_type: {norm_type}")
|
||||
|
||||
@@ -250,9 +248,6 @@ class HunYuanDiT(nn.Module):
|
||||
operations.Linear(hidden_size * 4, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
# Image embedding
|
||||
num_patches = self.x_embedder.num_patches
|
||||
|
||||
# HUnYuanDiT Blocks
|
||||
self.blocks = nn.ModuleList([
|
||||
HunYuanDiTBlock(hidden_size=hidden_size,
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
import comfy.ops
|
||||
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import comfy.ldm.modules.attention
|
||||
from comfy.ldm.genmo.joint_model.layers import RMSNorm
|
||||
import comfy.ldm.common_dit
|
||||
from einops import rearrange
|
||||
import math
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
from .symmetric_patchifier import SymmetricPatchifier
|
||||
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
|
||||
|
||||
|
||||
def get_timestep_embedding(
|
||||
@@ -262,8 +261,8 @@ class CrossAttention(nn.Module):
|
||||
self.heads = heads
|
||||
self.dim_head = dim_head
|
||||
|
||||
self.q_norm = RMSNorm(inner_dim, dtype=dtype, device=device)
|
||||
self.k_norm = RMSNorm(inner_dim, dtype=dtype, device=device)
|
||||
self.q_norm = operations.RMSNorm(inner_dim, dtype=dtype, device=device)
|
||||
self.k_norm = operations.RMSNorm(inner_dim, dtype=dtype, device=device)
|
||||
|
||||
self.to_q = operations.Linear(query_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
||||
self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
||||
@@ -377,11 +376,16 @@ class LTXVModel(torch.nn.Module):
|
||||
|
||||
positional_embedding_theta=10000.0,
|
||||
positional_embedding_max_pos=[20, 2048, 2048],
|
||||
causal_temporal_positioning=False,
|
||||
vae_scale_factors=(8, 32, 32),
|
||||
dtype=None, device=None, operations=None, **kwargs):
|
||||
super().__init__()
|
||||
self.generator = None
|
||||
self.vae_scale_factors = vae_scale_factors
|
||||
self.dtype = dtype
|
||||
self.out_channels = in_channels
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
self.causal_temporal_positioning = causal_temporal_positioning
|
||||
|
||||
self.patchify_proj = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
@@ -415,39 +419,31 @@ class LTXVModel(torch.nn.Module):
|
||||
|
||||
self.patchifier = SymmetricPatchifier(1)
|
||||
|
||||
def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, transformer_options={}, **kwargs):
|
||||
def forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
|
||||
indices_grid = self.patchifier.get_grid(
|
||||
orig_num_frames=x.shape[2],
|
||||
orig_height=x.shape[3],
|
||||
orig_width=x.shape[4],
|
||||
batch_size=x.shape[0],
|
||||
scale_grid=((1 / frame_rate) * 8, 32, 32),
|
||||
device=x.device,
|
||||
)
|
||||
|
||||
if guiding_latent is not None:
|
||||
ts = torch.ones([x.shape[0], 1, x.shape[2], x.shape[3], x.shape[4]], device=x.device, dtype=x.dtype)
|
||||
input_ts = timestep.view([timestep.shape[0]] + [1] * (x.ndim - 1))
|
||||
ts *= input_ts
|
||||
ts[:, :, 0] = 0.0
|
||||
timestep = self.patchifier.patchify(ts)
|
||||
input_x = x.clone()
|
||||
x[:, :, 0] = guiding_latent[:, :, 0]
|
||||
|
||||
orig_shape = list(x.shape)
|
||||
|
||||
x = self.patchifier.patchify(x)
|
||||
x, latent_coords = self.patchifier.patchify(x)
|
||||
pixel_coords = latent_to_pixel_coords(
|
||||
latent_coords=latent_coords,
|
||||
scale_factors=self.vae_scale_factors,
|
||||
causal_fix=self.causal_temporal_positioning,
|
||||
)
|
||||
|
||||
if keyframe_idxs is not None:
|
||||
pixel_coords[:, :, -keyframe_idxs.shape[2]:] = keyframe_idxs
|
||||
|
||||
fractional_coords = pixel_coords.to(torch.float32)
|
||||
fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate)
|
||||
|
||||
x = self.patchify_proj(x)
|
||||
timestep = timestep * 1000.0
|
||||
|
||||
attention_mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1]))
|
||||
attention_mask = attention_mask.masked_fill(attention_mask.to(torch.bool), float("-inf")) # not sure about this
|
||||
# attention_mask = (context != 0).any(dim=2).to(dtype=x.dtype)
|
||||
if attention_mask is not None and not torch.is_floating_point(attention_mask):
|
||||
attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max
|
||||
|
||||
pe = precompute_freqs_cis(indices_grid, dim=self.inner_dim, out_dtype=x.dtype)
|
||||
pe = precompute_freqs_cis(fractional_coords, dim=self.inner_dim, out_dtype=x.dtype)
|
||||
|
||||
batch_size = x.shape[0]
|
||||
timestep, embedded_timestep = self.adaln_single(
|
||||
@@ -507,8 +503,4 @@ class LTXVModel(torch.nn.Module):
|
||||
out_channels=orig_shape[1] // math.prod(self.patchifier.patch_size),
|
||||
)
|
||||
|
||||
if guiding_latent is not None:
|
||||
x[:, :, 0] = (input_x[:, :, 0] - guiding_latent[:, :, 0]) / input_ts[:, :, 0]
|
||||
|
||||
# print("res", x)
|
||||
return x
|
||||
|
||||
@@ -6,16 +6,29 @@ from einops import rearrange
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
|
||||
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
|
||||
dims_to_append = target_dims - x.ndim
|
||||
if dims_to_append < 0:
|
||||
raise ValueError(
|
||||
f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
|
||||
)
|
||||
elif dims_to_append == 0:
|
||||
return x
|
||||
return x[(...,) + (None,) * dims_to_append]
|
||||
def latent_to_pixel_coords(
|
||||
latent_coords: Tensor, scale_factors: Tuple[int, int, int], causal_fix: bool = False
|
||||
) -> Tensor:
|
||||
"""
|
||||
Converts latent coordinates to pixel coordinates by scaling them according to the VAE's
|
||||
configuration.
|
||||
Args:
|
||||
latent_coords (Tensor): A tensor of shape [batch_size, 3, num_latents]
|
||||
containing the latent corner coordinates of each token.
|
||||
scale_factors (Tuple[int, int, int]): The scale factors of the VAE's latent space.
|
||||
causal_fix (bool): Whether to take into account the different temporal scale
|
||||
of the first frame. Default = False for backwards compatibility.
|
||||
Returns:
|
||||
Tensor: A tensor of pixel coordinates corresponding to the input latent coordinates.
|
||||
"""
|
||||
pixel_coords = (
|
||||
latent_coords
|
||||
* torch.tensor(scale_factors, device=latent_coords.device)[None, :, None]
|
||||
)
|
||||
if causal_fix:
|
||||
# Fix temporal scale for first frame to 1 due to causality
|
||||
pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - scale_factors[0]).clamp(min=0)
|
||||
return pixel_coords
|
||||
|
||||
|
||||
class Patchifier(ABC):
|
||||
@@ -44,29 +57,26 @@ class Patchifier(ABC):
|
||||
def patch_size(self):
|
||||
return self._patch_size
|
||||
|
||||
def get_grid(
|
||||
self, orig_num_frames, orig_height, orig_width, batch_size, scale_grid, device
|
||||
def get_latent_coords(
|
||||
self, latent_num_frames, latent_height, latent_width, batch_size, device
|
||||
):
|
||||
f = orig_num_frames // self._patch_size[0]
|
||||
h = orig_height // self._patch_size[1]
|
||||
w = orig_width // self._patch_size[2]
|
||||
grid_h = torch.arange(h, dtype=torch.float32, device=device)
|
||||
grid_w = torch.arange(w, dtype=torch.float32, device=device)
|
||||
grid_f = torch.arange(f, dtype=torch.float32, device=device)
|
||||
grid = torch.meshgrid(grid_f, grid_h, grid_w)
|
||||
grid = torch.stack(grid, dim=0)
|
||||
grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
|
||||
|
||||
if scale_grid is not None:
|
||||
for i in range(3):
|
||||
if isinstance(scale_grid[i], Tensor):
|
||||
scale = append_dims(scale_grid[i], grid.ndim - 1)
|
||||
else:
|
||||
scale = scale_grid[i]
|
||||
grid[:, i, ...] = grid[:, i, ...] * scale * self._patch_size[i]
|
||||
|
||||
grid = rearrange(grid, "b c f h w -> b c (f h w)", b=batch_size)
|
||||
return grid
|
||||
"""
|
||||
Return a tensor of shape [batch_size, 3, num_patches] containing the
|
||||
top-left corner latent coordinates of each latent patch.
|
||||
The tensor is repeated for each batch element.
|
||||
"""
|
||||
latent_sample_coords = torch.meshgrid(
|
||||
torch.arange(0, latent_num_frames, self._patch_size[0], device=device),
|
||||
torch.arange(0, latent_height, self._patch_size[1], device=device),
|
||||
torch.arange(0, latent_width, self._patch_size[2], device=device),
|
||||
indexing="ij",
|
||||
)
|
||||
latent_sample_coords = torch.stack(latent_sample_coords, dim=0)
|
||||
latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
|
||||
latent_coords = rearrange(
|
||||
latent_coords, "b c f h w -> b c (f h w)", b=batch_size
|
||||
)
|
||||
return latent_coords
|
||||
|
||||
|
||||
class SymmetricPatchifier(Patchifier):
|
||||
@@ -74,6 +84,8 @@ class SymmetricPatchifier(Patchifier):
|
||||
self,
|
||||
latents: Tensor,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
b, _, f, h, w = latents.shape
|
||||
latent_coords = self.get_latent_coords(f, h, w, b, latents.device)
|
||||
latents = rearrange(
|
||||
latents,
|
||||
"b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
|
||||
@@ -81,7 +93,7 @@ class SymmetricPatchifier(Patchifier):
|
||||
p2=self._patch_size[1],
|
||||
p3=self._patch_size[2],
|
||||
)
|
||||
return latents
|
||||
return latents, latent_coords
|
||||
|
||||
def unpatchify(
|
||||
self,
|
||||
|
||||
@@ -15,6 +15,7 @@ class CausalConv3d(nn.Module):
|
||||
stride: Union[int, Tuple[int]] = 1,
|
||||
dilation: int = 1,
|
||||
groups: int = 1,
|
||||
spatial_padding_mode: str = "zeros",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -38,7 +39,7 @@ class CausalConv3d(nn.Module):
|
||||
stride=stride,
|
||||
dilation=dilation,
|
||||
padding=padding,
|
||||
padding_mode="zeros",
|
||||
padding_mode=spatial_padding_mode,
|
||||
groups=groups,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
from __future__ import annotations
|
||||
import torch
|
||||
from torch import nn
|
||||
from functools import partial
|
||||
import math
|
||||
from einops import rearrange
|
||||
from typing import Any, Mapping, Optional, Tuple, Union, List
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from .conv_nd_factory import make_conv_nd, make_linear_nd
|
||||
from .pixel_norm import PixelNorm
|
||||
from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings
|
||||
import comfy.ops
|
||||
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
class Encoder(nn.Module):
|
||||
r"""
|
||||
@@ -30,7 +34,7 @@ class Encoder(nn.Module):
|
||||
norm_layer (`str`, *optional*, defaults to `group_norm`):
|
||||
The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
|
||||
latent_log_var (`str`, *optional*, defaults to `per_channel`):
|
||||
The number of channels for the log variance. Can be either `per_channel`, `uniform`, or `none`.
|
||||
The number of channels for the log variance. Can be either `per_channel`, `uniform`, `constant` or `none`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -38,12 +42,13 @@ class Encoder(nn.Module):
|
||||
dims: Union[int, Tuple[int, int]] = 3,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
blocks=[("res_x", 1)],
|
||||
blocks: List[Tuple[str, int | dict]] = [("res_x", 1)],
|
||||
base_channels: int = 128,
|
||||
norm_num_groups: int = 32,
|
||||
patch_size: Union[int, Tuple[int]] = 1,
|
||||
norm_layer: str = "group_norm", # group_norm, pixel_norm
|
||||
latent_log_var: str = "per_channel",
|
||||
spatial_padding_mode: str = "zeros",
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
@@ -63,6 +68,7 @@ class Encoder(nn.Module):
|
||||
stride=1,
|
||||
padding=1,
|
||||
causal=True,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
@@ -80,6 +86,7 @@ class Encoder(nn.Module):
|
||||
resnet_eps=1e-6,
|
||||
resnet_groups=norm_num_groups,
|
||||
norm_layer=norm_layer,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "res_x_y":
|
||||
output_channel = block_params.get("multiplier", 2) * output_channel
|
||||
@@ -90,6 +97,7 @@ class Encoder(nn.Module):
|
||||
eps=1e-6,
|
||||
groups=norm_num_groups,
|
||||
norm_layer=norm_layer,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_time":
|
||||
block = make_conv_nd(
|
||||
@@ -99,6 +107,7 @@ class Encoder(nn.Module):
|
||||
kernel_size=3,
|
||||
stride=(2, 1, 1),
|
||||
causal=True,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_space":
|
||||
block = make_conv_nd(
|
||||
@@ -108,6 +117,7 @@ class Encoder(nn.Module):
|
||||
kernel_size=3,
|
||||
stride=(1, 2, 2),
|
||||
causal=True,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_all":
|
||||
block = make_conv_nd(
|
||||
@@ -117,6 +127,7 @@ class Encoder(nn.Module):
|
||||
kernel_size=3,
|
||||
stride=(2, 2, 2),
|
||||
causal=True,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_all_x_y":
|
||||
output_channel = block_params.get("multiplier", 2) * output_channel
|
||||
@@ -127,6 +138,34 @@ class Encoder(nn.Module):
|
||||
kernel_size=3,
|
||||
stride=(2, 2, 2),
|
||||
causal=True,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_all_res":
|
||||
output_channel = block_params.get("multiplier", 2) * output_channel
|
||||
block = SpaceToDepthDownsample(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
stride=(2, 2, 2),
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_space_res":
|
||||
output_channel = block_params.get("multiplier", 2) * output_channel
|
||||
block = SpaceToDepthDownsample(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
stride=(1, 2, 2),
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_time_res":
|
||||
output_channel = block_params.get("multiplier", 2) * output_channel
|
||||
block = SpaceToDepthDownsample(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
stride=(2, 1, 1),
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unknown block: {block_name}")
|
||||
@@ -150,10 +189,18 @@ class Encoder(nn.Module):
|
||||
conv_out_channels *= 2
|
||||
elif latent_log_var == "uniform":
|
||||
conv_out_channels += 1
|
||||
elif latent_log_var == "constant":
|
||||
conv_out_channels += 1
|
||||
elif latent_log_var != "none":
|
||||
raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
|
||||
self.conv_out = make_conv_nd(
|
||||
dims, output_channel, conv_out_channels, 3, padding=1, causal=True
|
||||
dims,
|
||||
output_channel,
|
||||
conv_out_channels,
|
||||
3,
|
||||
padding=1,
|
||||
causal=True,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
@@ -195,6 +242,15 @@ class Encoder(nn.Module):
|
||||
sample = torch.cat([sample, repeated_last_channel], dim=1)
|
||||
else:
|
||||
raise ValueError(f"Invalid input shape: {sample.shape}")
|
||||
elif self.latent_log_var == "constant":
|
||||
sample = sample[:, :-1, ...]
|
||||
approx_ln_0 = (
|
||||
-30
|
||||
) # this is the minimal clamp value in DiagonalGaussianDistribution objects
|
||||
sample = torch.cat(
|
||||
[sample, torch.ones_like(sample, device=sample.device) * approx_ln_0],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
return sample
|
||||
|
||||
@@ -229,13 +285,15 @@ class Decoder(nn.Module):
|
||||
dims,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
blocks=[("res_x", 1)],
|
||||
blocks: List[Tuple[str, int | dict]] = [("res_x", 1)],
|
||||
base_channels: int = 128,
|
||||
layers_per_block: int = 2,
|
||||
norm_num_groups: int = 32,
|
||||
patch_size: int = 1,
|
||||
norm_layer: str = "group_norm",
|
||||
causal: bool = True,
|
||||
timestep_conditioning: bool = False,
|
||||
spatial_padding_mode: str = "zeros",
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
@@ -250,6 +308,8 @@ class Decoder(nn.Module):
|
||||
block_params = block_params if isinstance(block_params, dict) else {}
|
||||
if block_name == "res_x_y":
|
||||
output_channel = output_channel * block_params.get("multiplier", 2)
|
||||
if block_name == "compress_all":
|
||||
output_channel = output_channel * block_params.get("multiplier", 1)
|
||||
|
||||
self.conv_in = make_conv_nd(
|
||||
dims,
|
||||
@@ -259,6 +319,7 @@ class Decoder(nn.Module):
|
||||
stride=1,
|
||||
padding=1,
|
||||
causal=True,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
@@ -276,6 +337,21 @@ class Decoder(nn.Module):
|
||||
resnet_eps=1e-6,
|
||||
resnet_groups=norm_num_groups,
|
||||
norm_layer=norm_layer,
|
||||
inject_noise=block_params.get("inject_noise", False),
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "attn_res_x":
|
||||
block = UNetMidBlock3D(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
num_layers=block_params["num_layers"],
|
||||
resnet_groups=norm_num_groups,
|
||||
norm_layer=norm_layer,
|
||||
inject_noise=block_params.get("inject_noise", False),
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
attention_head_dim=block_params["attention_head_dim"],
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "res_x_y":
|
||||
output_channel = output_channel // block_params.get("multiplier", 2)
|
||||
@@ -286,21 +362,33 @@ class Decoder(nn.Module):
|
||||
eps=1e-6,
|
||||
groups=norm_num_groups,
|
||||
norm_layer=norm_layer,
|
||||
inject_noise=block_params.get("inject_noise", False),
|
||||
timestep_conditioning=False,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_time":
|
||||
block = DepthToSpaceUpsample(
|
||||
dims=dims, in_channels=input_channel, stride=(2, 1, 1)
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
stride=(2, 1, 1),
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_space":
|
||||
block = DepthToSpaceUpsample(
|
||||
dims=dims, in_channels=input_channel, stride=(1, 2, 2)
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
stride=(1, 2, 2),
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_all":
|
||||
output_channel = output_channel // block_params.get("multiplier", 1)
|
||||
block = DepthToSpaceUpsample(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
stride=(2, 2, 2),
|
||||
residual=block_params.get("residual", False),
|
||||
out_channels_reduction_factor=block_params.get("multiplier", 1),
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unknown layer: {block_name}")
|
||||
@@ -318,32 +406,86 @@ class Decoder(nn.Module):
|
||||
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = make_conv_nd(
|
||||
dims, output_channel, out_channels, 3, padding=1, causal=True
|
||||
dims,
|
||||
output_channel,
|
||||
out_channels,
|
||||
3,
|
||||
padding=1,
|
||||
causal=True,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
self.timestep_conditioning = timestep_conditioning
|
||||
|
||||
if timestep_conditioning:
|
||||
self.timestep_scale_multiplier = nn.Parameter(
|
||||
torch.tensor(1000.0, dtype=torch.float32)
|
||||
)
|
||||
self.last_time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(
|
||||
output_channel * 2, 0, operations=ops,
|
||||
)
|
||||
self.last_scale_shift_table = nn.Parameter(torch.empty(2, output_channel))
|
||||
|
||||
# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
|
||||
def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Optional[torch.Tensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
r"""The forward method of the `Decoder` class."""
|
||||
# assert target_shape is not None, "target_shape must be provided"
|
||||
batch_size = sample.shape[0]
|
||||
|
||||
sample = self.conv_in(sample, causal=self.causal)
|
||||
|
||||
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
|
||||
|
||||
checkpoint_fn = (
|
||||
partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
|
||||
if self.gradient_checkpointing and self.training
|
||||
else lambda x: x
|
||||
)
|
||||
|
||||
sample = sample.to(upscale_dtype)
|
||||
scaled_timestep = None
|
||||
if self.timestep_conditioning:
|
||||
assert (
|
||||
timestep is not None
|
||||
), "should pass timestep with timestep_conditioning=True"
|
||||
scaled_timestep = timestep * self.timestep_scale_multiplier.to(dtype=sample.dtype, device=sample.device)
|
||||
|
||||
for up_block in self.up_blocks:
|
||||
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
|
||||
if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
|
||||
sample = checkpoint_fn(up_block)(
|
||||
sample, causal=self.causal, timestep=scaled_timestep
|
||||
)
|
||||
else:
|
||||
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
|
||||
|
||||
sample = self.conv_norm_out(sample)
|
||||
|
||||
if self.timestep_conditioning:
|
||||
embedded_timestep = self.last_time_embedder(
|
||||
timestep=scaled_timestep.flatten(),
|
||||
resolution=None,
|
||||
aspect_ratio=None,
|
||||
batch_size=sample.shape[0],
|
||||
hidden_dtype=sample.dtype,
|
||||
)
|
||||
embedded_timestep = embedded_timestep.view(
|
||||
batch_size, embedded_timestep.shape[-1], 1, 1, 1
|
||||
)
|
||||
ada_values = self.last_scale_shift_table[
|
||||
None, ..., None, None, None
|
||||
].to(device=sample.device, dtype=sample.dtype) + embedded_timestep.reshape(
|
||||
batch_size,
|
||||
2,
|
||||
-1,
|
||||
embedded_timestep.shape[-3],
|
||||
embedded_timestep.shape[-2],
|
||||
embedded_timestep.shape[-1],
|
||||
)
|
||||
shift, scale = ada_values.unbind(dim=1)
|
||||
sample = sample * (1 + scale) + shift
|
||||
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample, causal=self.causal)
|
||||
|
||||
@@ -363,6 +505,12 @@ class UNetMidBlock3D(nn.Module):
|
||||
resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
|
||||
resnet_groups (`int`, *optional*, defaults to 32):
|
||||
The number of groups to use in the group normalization layers of the resnet blocks.
|
||||
norm_layer (`str`, *optional*, defaults to `group_norm`):
|
||||
The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
|
||||
inject_noise (`bool`, *optional*, defaults to `False`):
|
||||
Whether to inject noise into the hidden states.
|
||||
timestep_conditioning (`bool`, *optional*, defaults to `False`):
|
||||
Whether to condition the hidden states on the timestep.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
|
||||
@@ -379,12 +527,22 @@ class UNetMidBlock3D(nn.Module):
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_groups: int = 32,
|
||||
norm_layer: str = "group_norm",
|
||||
inject_noise: bool = False,
|
||||
timestep_conditioning: bool = False,
|
||||
spatial_padding_mode: str = "zeros",
|
||||
):
|
||||
super().__init__()
|
||||
resnet_groups = (
|
||||
resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
||||
)
|
||||
|
||||
self.timestep_conditioning = timestep_conditioning
|
||||
|
||||
if timestep_conditioning:
|
||||
self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(
|
||||
in_channels * 4, 0, operations=ops,
|
||||
)
|
||||
|
||||
self.res_blocks = nn.ModuleList(
|
||||
[
|
||||
ResnetBlock3D(
|
||||
@@ -395,25 +553,105 @@ class UNetMidBlock3D(nn.Module):
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
norm_layer=norm_layer,
|
||||
inject_noise=inject_noise,
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.FloatTensor, causal: bool = True
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
causal: bool = True,
|
||||
timestep: Optional[torch.Tensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
timestep_embed = None
|
||||
if self.timestep_conditioning:
|
||||
assert (
|
||||
timestep is not None
|
||||
), "should pass timestep with timestep_conditioning=True"
|
||||
batch_size = hidden_states.shape[0]
|
||||
timestep_embed = self.time_embedder(
|
||||
timestep=timestep.flatten(),
|
||||
resolution=None,
|
||||
aspect_ratio=None,
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_states.dtype,
|
||||
)
|
||||
timestep_embed = timestep_embed.view(
|
||||
batch_size, timestep_embed.shape[-1], 1, 1, 1
|
||||
)
|
||||
|
||||
for resnet in self.res_blocks:
|
||||
hidden_states = resnet(hidden_states, causal=causal)
|
||||
hidden_states = resnet(hidden_states, causal=causal, timestep=timestep_embed)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DepthToSpaceUpsample(nn.Module):
|
||||
def __init__(self, dims, in_channels, stride, residual=False):
|
||||
class SpaceToDepthDownsample(nn.Module):
|
||||
def __init__(self, dims, in_channels, out_channels, stride, spatial_padding_mode):
|
||||
super().__init__()
|
||||
self.stride = stride
|
||||
self.out_channels = math.prod(stride) * in_channels
|
||||
self.group_size = in_channels * math.prod(stride) // out_channels
|
||||
self.conv = make_conv_nd(
|
||||
dims=dims,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels // math.prod(stride),
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
causal=True,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
|
||||
def forward(self, x, causal: bool = True):
|
||||
if self.stride[0] == 2:
|
||||
x = torch.cat(
|
||||
[x[:, :, :1, :, :], x], dim=2
|
||||
) # duplicate first frames for padding
|
||||
|
||||
# skip connection
|
||||
x_in = rearrange(
|
||||
x,
|
||||
"b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
|
||||
p1=self.stride[0],
|
||||
p2=self.stride[1],
|
||||
p3=self.stride[2],
|
||||
)
|
||||
x_in = rearrange(x_in, "b (c g) d h w -> b c g d h w", g=self.group_size)
|
||||
x_in = x_in.mean(dim=2)
|
||||
|
||||
# conv
|
||||
x = self.conv(x, causal=causal)
|
||||
x = rearrange(
|
||||
x,
|
||||
"b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
|
||||
p1=self.stride[0],
|
||||
p2=self.stride[1],
|
||||
p3=self.stride[2],
|
||||
)
|
||||
|
||||
x = x + x_in
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class DepthToSpaceUpsample(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dims,
|
||||
in_channels,
|
||||
stride,
|
||||
residual=False,
|
||||
out_channels_reduction_factor=1,
|
||||
spatial_padding_mode="zeros",
|
||||
):
|
||||
super().__init__()
|
||||
self.stride = stride
|
||||
self.out_channels = (
|
||||
math.prod(stride) * in_channels // out_channels_reduction_factor
|
||||
)
|
||||
self.conv = make_conv_nd(
|
||||
dims=dims,
|
||||
in_channels=in_channels,
|
||||
@@ -421,10 +659,12 @@ class DepthToSpaceUpsample(nn.Module):
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
causal=True,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
self.residual = residual
|
||||
self.out_channels_reduction_factor = out_channels_reduction_factor
|
||||
|
||||
def forward(self, x, causal: bool = True):
|
||||
def forward(self, x, causal: bool = True, timestep: Optional[torch.Tensor] = None):
|
||||
if self.residual:
|
||||
# Reshape and duplicate the input to match the output shape
|
||||
x_in = rearrange(
|
||||
@@ -434,7 +674,8 @@ class DepthToSpaceUpsample(nn.Module):
|
||||
p2=self.stride[1],
|
||||
p3=self.stride[2],
|
||||
)
|
||||
x_in = x_in.repeat(1, math.prod(self.stride), 1, 1, 1)
|
||||
num_repeat = math.prod(self.stride) // self.out_channels_reduction_factor
|
||||
x_in = x_in.repeat(1, num_repeat, 1, 1, 1)
|
||||
if self.stride[0] == 2:
|
||||
x_in = x_in[:, :, 1:, :, :]
|
||||
x = self.conv(x, causal=causal)
|
||||
@@ -451,11 +692,10 @@ class DepthToSpaceUpsample(nn.Module):
|
||||
x = x + x_in
|
||||
return x
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(self, dim, eps, elementwise_affine=True) -> None:
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
|
||||
self.norm = ops.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
|
||||
|
||||
def forward(self, x):
|
||||
x = rearrange(x, "b c d h w -> b d h w c")
|
||||
@@ -486,11 +726,15 @@ class ResnetBlock3D(nn.Module):
|
||||
groups: int = 32,
|
||||
eps: float = 1e-6,
|
||||
norm_layer: str = "group_norm",
|
||||
inject_noise: bool = False,
|
||||
timestep_conditioning: bool = False,
|
||||
spatial_padding_mode: str = "zeros",
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.inject_noise = inject_noise
|
||||
|
||||
if norm_layer == "group_norm":
|
||||
self.norm1 = nn.GroupNorm(
|
||||
@@ -511,8 +755,12 @@ class ResnetBlock3D(nn.Module):
|
||||
stride=1,
|
||||
padding=1,
|
||||
causal=True,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
|
||||
if inject_noise:
|
||||
self.per_channel_scale1 = nn.Parameter(torch.zeros((in_channels, 1, 1)))
|
||||
|
||||
if norm_layer == "group_norm":
|
||||
self.norm2 = nn.GroupNorm(
|
||||
num_groups=groups, num_channels=out_channels, eps=eps, affine=True
|
||||
@@ -532,8 +780,12 @@ class ResnetBlock3D(nn.Module):
|
||||
stride=1,
|
||||
padding=1,
|
||||
causal=True,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
|
||||
if inject_noise:
|
||||
self.per_channel_scale2 = nn.Parameter(torch.zeros((in_channels, 1, 1)))
|
||||
|
||||
self.conv_shortcut = (
|
||||
make_linear_nd(
|
||||
dims=dims, in_channels=in_channels, out_channels=out_channels
|
||||
@@ -548,29 +800,84 @@ class ResnetBlock3D(nn.Module):
|
||||
else nn.Identity()
|
||||
)
|
||||
|
||||
self.timestep_conditioning = timestep_conditioning
|
||||
|
||||
if timestep_conditioning:
|
||||
self.scale_shift_table = nn.Parameter(
|
||||
torch.randn(4, in_channels) / in_channels**0.5
|
||||
)
|
||||
|
||||
def _feed_spatial_noise(
|
||||
self, hidden_states: torch.FloatTensor, per_channel_scale: torch.FloatTensor
|
||||
) -> torch.FloatTensor:
|
||||
spatial_shape = hidden_states.shape[-2:]
|
||||
device = hidden_states.device
|
||||
dtype = hidden_states.dtype
|
||||
|
||||
# similar to the "explicit noise inputs" method in style-gan
|
||||
spatial_noise = torch.randn(spatial_shape, device=device, dtype=dtype)[None]
|
||||
scaled_noise = (spatial_noise * per_channel_scale)[None, :, None, ...]
|
||||
hidden_states = hidden_states + scaled_noise
|
||||
|
||||
return hidden_states
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_tensor: torch.FloatTensor,
|
||||
causal: bool = True,
|
||||
timestep: Optional[torch.Tensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
hidden_states = input_tensor
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
hidden_states = self.norm1(hidden_states)
|
||||
if self.timestep_conditioning:
|
||||
assert (
|
||||
timestep is not None
|
||||
), "should pass timestep with timestep_conditioning=True"
|
||||
ada_values = self.scale_shift_table[
|
||||
None, ..., None, None, None
|
||||
].to(device=hidden_states.device, dtype=hidden_states.dtype) + timestep.reshape(
|
||||
batch_size,
|
||||
4,
|
||||
-1,
|
||||
timestep.shape[-3],
|
||||
timestep.shape[-2],
|
||||
timestep.shape[-1],
|
||||
)
|
||||
shift1, scale1, shift2, scale2 = ada_values.unbind(dim=1)
|
||||
|
||||
hidden_states = hidden_states * (1 + scale1) + shift1
|
||||
|
||||
hidden_states = self.non_linearity(hidden_states)
|
||||
|
||||
hidden_states = self.conv1(hidden_states, causal=causal)
|
||||
|
||||
if self.inject_noise:
|
||||
hidden_states = self._feed_spatial_noise(
|
||||
hidden_states, self.per_channel_scale1.to(device=hidden_states.device, dtype=hidden_states.dtype)
|
||||
)
|
||||
|
||||
hidden_states = self.norm2(hidden_states)
|
||||
|
||||
if self.timestep_conditioning:
|
||||
hidden_states = hidden_states * (1 + scale2) + shift2
|
||||
|
||||
hidden_states = self.non_linearity(hidden_states)
|
||||
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
|
||||
hidden_states = self.conv2(hidden_states, causal=causal)
|
||||
|
||||
if self.inject_noise:
|
||||
hidden_states = self._feed_spatial_noise(
|
||||
hidden_states, self.per_channel_scale2.to(device=hidden_states.device, dtype=hidden_states.dtype)
|
||||
)
|
||||
|
||||
input_tensor = self.norm3(input_tensor)
|
||||
|
||||
batch_size = input_tensor.shape[0]
|
||||
|
||||
input_tensor = self.conv_shortcut(input_tensor)
|
||||
|
||||
output_tensor = input_tensor + hidden_states
|
||||
@@ -634,34 +941,13 @@ class processor(nn.Module):
|
||||
return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)
|
||||
|
||||
class VideoVAE(nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self, version=0, config=None):
|
||||
super().__init__()
|
||||
config = {
|
||||
"_class_name": "CausalVideoAutoencoder",
|
||||
"dims": 3,
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 128,
|
||||
"blocks": [
|
||||
["res_x", 4],
|
||||
["compress_all", 1],
|
||||
["res_x_y", 1],
|
||||
["res_x", 3],
|
||||
["compress_all", 1],
|
||||
["res_x_y", 1],
|
||||
["res_x", 3],
|
||||
["compress_all", 1],
|
||||
["res_x", 3],
|
||||
["res_x", 4],
|
||||
],
|
||||
"scaling_factor": 1.0,
|
||||
"norm_layer": "pixel_norm",
|
||||
"patch_size": 4,
|
||||
"latent_log_var": "uniform",
|
||||
"use_quant_conv": False,
|
||||
"causal_decoder": False,
|
||||
}
|
||||
|
||||
if config is None:
|
||||
config = self.guess_config(version)
|
||||
|
||||
self.timestep_conditioning = config.get("timestep_conditioning", False)
|
||||
double_z = config.get("double_z", True)
|
||||
latent_log_var = config.get(
|
||||
"latent_log_var", "per_channel" if double_z else "none"
|
||||
@@ -671,28 +957,136 @@ class VideoVAE(nn.Module):
|
||||
dims=config["dims"],
|
||||
in_channels=config.get("in_channels", 3),
|
||||
out_channels=config["latent_channels"],
|
||||
blocks=config.get("encoder_blocks", config.get("blocks")),
|
||||
blocks=config.get("encoder_blocks", config.get("encoder_blocks", config.get("blocks"))),
|
||||
patch_size=config.get("patch_size", 1),
|
||||
latent_log_var=latent_log_var,
|
||||
norm_layer=config.get("norm_layer", "group_norm"),
|
||||
spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
|
||||
)
|
||||
|
||||
self.decoder = Decoder(
|
||||
dims=config["dims"],
|
||||
in_channels=config["latent_channels"],
|
||||
out_channels=config.get("out_channels", 3),
|
||||
blocks=config.get("decoder_blocks", config.get("blocks")),
|
||||
blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))),
|
||||
patch_size=config.get("patch_size", 1),
|
||||
norm_layer=config.get("norm_layer", "group_norm"),
|
||||
causal=config.get("causal_decoder", False),
|
||||
timestep_conditioning=self.timestep_conditioning,
|
||||
spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
|
||||
)
|
||||
|
||||
self.per_channel_statistics = processor()
|
||||
|
||||
def guess_config(self, version):
|
||||
if version == 0:
|
||||
config = {
|
||||
"_class_name": "CausalVideoAutoencoder",
|
||||
"dims": 3,
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 128,
|
||||
"blocks": [
|
||||
["res_x", 4],
|
||||
["compress_all", 1],
|
||||
["res_x_y", 1],
|
||||
["res_x", 3],
|
||||
["compress_all", 1],
|
||||
["res_x_y", 1],
|
||||
["res_x", 3],
|
||||
["compress_all", 1],
|
||||
["res_x", 3],
|
||||
["res_x", 4],
|
||||
],
|
||||
"scaling_factor": 1.0,
|
||||
"norm_layer": "pixel_norm",
|
||||
"patch_size": 4,
|
||||
"latent_log_var": "uniform",
|
||||
"use_quant_conv": False,
|
||||
"causal_decoder": False,
|
||||
}
|
||||
elif version == 1:
|
||||
config = {
|
||||
"_class_name": "CausalVideoAutoencoder",
|
||||
"dims": 3,
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 128,
|
||||
"decoder_blocks": [
|
||||
["res_x", {"num_layers": 5, "inject_noise": True}],
|
||||
["compress_all", {"residual": True, "multiplier": 2}],
|
||||
["res_x", {"num_layers": 6, "inject_noise": True}],
|
||||
["compress_all", {"residual": True, "multiplier": 2}],
|
||||
["res_x", {"num_layers": 7, "inject_noise": True}],
|
||||
["compress_all", {"residual": True, "multiplier": 2}],
|
||||
["res_x", {"num_layers": 8, "inject_noise": False}]
|
||||
],
|
||||
"encoder_blocks": [
|
||||
["res_x", {"num_layers": 4}],
|
||||
["compress_all", {}],
|
||||
["res_x_y", 1],
|
||||
["res_x", {"num_layers": 3}],
|
||||
["compress_all", {}],
|
||||
["res_x_y", 1],
|
||||
["res_x", {"num_layers": 3}],
|
||||
["compress_all", {}],
|
||||
["res_x", {"num_layers": 3}],
|
||||
["res_x", {"num_layers": 4}]
|
||||
],
|
||||
"scaling_factor": 1.0,
|
||||
"norm_layer": "pixel_norm",
|
||||
"patch_size": 4,
|
||||
"latent_log_var": "uniform",
|
||||
"use_quant_conv": False,
|
||||
"causal_decoder": False,
|
||||
"timestep_conditioning": True,
|
||||
}
|
||||
else:
|
||||
config = {
|
||||
"_class_name": "CausalVideoAutoencoder",
|
||||
"dims": 3,
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 128,
|
||||
"encoder_blocks": [
|
||||
["res_x", {"num_layers": 4}],
|
||||
["compress_space_res", {"multiplier": 2}],
|
||||
["res_x", {"num_layers": 6}],
|
||||
["compress_time_res", {"multiplier": 2}],
|
||||
["res_x", {"num_layers": 6}],
|
||||
["compress_all_res", {"multiplier": 2}],
|
||||
["res_x", {"num_layers": 2}],
|
||||
["compress_all_res", {"multiplier": 2}],
|
||||
["res_x", {"num_layers": 2}]
|
||||
],
|
||||
"decoder_blocks": [
|
||||
["res_x", {"num_layers": 5, "inject_noise": False}],
|
||||
["compress_all", {"residual": True, "multiplier": 2}],
|
||||
["res_x", {"num_layers": 5, "inject_noise": False}],
|
||||
["compress_all", {"residual": True, "multiplier": 2}],
|
||||
["res_x", {"num_layers": 5, "inject_noise": False}],
|
||||
["compress_all", {"residual": True, "multiplier": 2}],
|
||||
["res_x", {"num_layers": 5, "inject_noise": False}]
|
||||
],
|
||||
"scaling_factor": 1.0,
|
||||
"norm_layer": "pixel_norm",
|
||||
"patch_size": 4,
|
||||
"latent_log_var": "uniform",
|
||||
"use_quant_conv": False,
|
||||
"causal_decoder": False,
|
||||
"timestep_conditioning": True
|
||||
}
|
||||
return config
|
||||
|
||||
def encode(self, x):
|
||||
frames_count = x.shape[2]
|
||||
if ((frames_count - 1) % 8) != 0:
|
||||
raise ValueError("Invalid number of frames: Encode input must have 1 + 8 * x frames (e.g., 1, 9, 17, ...). Please check your input.")
|
||||
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
|
||||
return self.per_channel_statistics.normalize(means)
|
||||
|
||||
def decode(self, x):
|
||||
return self.decoder(self.per_channel_statistics.un_normalize(x))
|
||||
def decode(self, x, timestep=0.05, noise_scale=0.025):
|
||||
if self.timestep_conditioning: #TODO: seed
|
||||
x = torch.randn_like(x) * noise_scale + (1.0 - noise_scale) * x
|
||||
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=timestep)
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from typing import Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from .dual_conv3d import DualConv3d
|
||||
from .causal_conv3d import CausalConv3d
|
||||
@@ -18,7 +17,11 @@ def make_conv_nd(
|
||||
groups=1,
|
||||
bias=True,
|
||||
causal=False,
|
||||
spatial_padding_mode="zeros",
|
||||
temporal_padding_mode="zeros",
|
||||
):
|
||||
if not (spatial_padding_mode == temporal_padding_mode or causal):
|
||||
raise NotImplementedError("spatial and temporal padding modes must be equal")
|
||||
if dims == 2:
|
||||
return ops.Conv2d(
|
||||
in_channels=in_channels,
|
||||
@@ -29,6 +32,7 @@ def make_conv_nd(
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=bias,
|
||||
padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif dims == 3:
|
||||
if causal:
|
||||
@@ -41,6 +45,7 @@ def make_conv_nd(
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=bias,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
return ops.Conv3d(
|
||||
in_channels=in_channels,
|
||||
@@ -51,6 +56,7 @@ def make_conv_nd(
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=bias,
|
||||
padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif dims == (2, 1):
|
||||
return DualConv3d(
|
||||
@@ -60,6 +66,7 @@ def make_conv_nd(
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
bias=bias,
|
||||
padding_mode=spatial_padding_mode,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
||||
@@ -18,11 +18,13 @@ class DualConv3d(nn.Module):
|
||||
dilation: Union[int, Tuple[int, int, int]] = 1,
|
||||
groups=1,
|
||||
bias=True,
|
||||
padding_mode="zeros",
|
||||
):
|
||||
super(DualConv3d, self).__init__()
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.padding_mode = padding_mode
|
||||
# Ensure kernel_size, stride, padding, and dilation are tuples of length 3
|
||||
if isinstance(kernel_size, int):
|
||||
kernel_size = (kernel_size, kernel_size, kernel_size)
|
||||
@@ -108,6 +110,7 @@ class DualConv3d(nn.Module):
|
||||
self.padding1,
|
||||
self.dilation1,
|
||||
self.groups,
|
||||
padding_mode=self.padding_mode,
|
||||
)
|
||||
|
||||
if skip_time_conv:
|
||||
@@ -122,6 +125,7 @@ class DualConv3d(nn.Module):
|
||||
self.padding2,
|
||||
self.dilation2,
|
||||
self.groups,
|
||||
padding_mode=self.padding_mode,
|
||||
)
|
||||
|
||||
return x
|
||||
@@ -137,7 +141,16 @@ class DualConv3d(nn.Module):
|
||||
stride1 = (self.stride1[1], self.stride1[2])
|
||||
padding1 = (self.padding1[1], self.padding1[2])
|
||||
dilation1 = (self.dilation1[1], self.dilation1[2])
|
||||
x = F.conv2d(x, weight1, self.bias1, stride1, padding1, dilation1, self.groups)
|
||||
x = F.conv2d(
|
||||
x,
|
||||
weight1,
|
||||
self.bias1,
|
||||
stride1,
|
||||
padding1,
|
||||
dilation1,
|
||||
self.groups,
|
||||
padding_mode=self.padding_mode,
|
||||
)
|
||||
|
||||
_, _, h, w = x.shape
|
||||
|
||||
@@ -154,7 +167,16 @@ class DualConv3d(nn.Module):
|
||||
stride2 = self.stride2[0]
|
||||
padding2 = self.padding2[0]
|
||||
dilation2 = self.dilation2[0]
|
||||
x = F.conv1d(x, weight2, self.bias2, stride2, padding2, dilation2, self.groups)
|
||||
x = F.conv1d(
|
||||
x,
|
||||
weight2,
|
||||
self.bias2,
|
||||
stride2,
|
||||
padding2,
|
||||
dilation2,
|
||||
self.groups,
|
||||
padding_mode=self.padding_mode,
|
||||
)
|
||||
x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w)
|
||||
|
||||
return x
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user