Compare commits
198 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ef85058e97 | ||
|
|
f9230bd357 | ||
|
|
537c27cbf3 | ||
|
|
6ff2e4d550 | ||
|
|
222f48c0f2 | ||
|
|
13fd4d6e45 | ||
|
|
1210d094c7 | ||
|
|
255edf2246 | ||
|
|
4f011b9a00 | ||
|
|
67feb05299 | ||
|
|
6d21740346 | ||
|
|
7fbf4b72fe | ||
|
|
14ca5f5a10 | ||
|
|
ce557cfb88 | ||
|
|
96e2a45193 | ||
|
|
dfa2b6d129 | ||
|
|
f3566f0894 | ||
|
|
ca69b41cee | ||
|
|
a058f52090 | ||
|
|
d6bbe8c40f | ||
|
|
a7fe0a94de | ||
|
|
e857dd48b8 | ||
|
|
d303cb5341 | ||
|
|
fb2ad645a3 | ||
|
|
d8a7a32779 | ||
|
|
a00e1489d2 | ||
|
|
ebf038d4fa | ||
|
|
b4de04a1c1 | ||
|
|
b1a02131c9 | ||
|
|
3a3910f91d | ||
|
|
507199d9a8 | ||
|
|
2f3ab40b62 | ||
|
|
7fc3ccdcc2 | ||
|
|
55add50220 | ||
|
|
0aa2368e46 | ||
|
|
cca96a85ae | ||
|
|
619b8cde74 | ||
|
|
31831e6ef1 | ||
|
|
88ceb28e20 | ||
|
|
23289a6a5c | ||
|
|
9d8b6c1f46 | ||
|
|
6320d05696 | ||
|
|
25683b5b02 | ||
|
|
4758fb64b9 | ||
|
|
008761166f | ||
|
|
bfd5dfd611 | ||
|
|
55ade36d01 | ||
|
|
2e20e399ea | ||
|
|
3baf92d120 | ||
|
|
1709a8441e | ||
|
|
cba58fff0b | ||
|
|
2feb8d0b77 | ||
|
|
5b657f8c15 | ||
|
|
2cdbaf5169 | ||
|
|
c78a45685d | ||
|
|
3aaabb12d4 | ||
|
|
1f1c7b7b56 | ||
|
|
90f349f93d | ||
|
|
b9d9bcba14 | ||
|
|
42086af123 | ||
|
|
6c9bd11fa3 | ||
|
|
ee8a7ab69d | ||
|
|
9c773a241b | ||
|
|
adea2beb5c | ||
|
|
2ff3104f70 | ||
|
|
129d8908f7 | ||
|
|
ff838657fa | ||
|
|
2307ff6746 | ||
|
|
d0f3752e33 | ||
|
|
c515bdf371 | ||
|
|
4209edf48d | ||
|
|
d055325783 | ||
|
|
eeab420c70 | ||
|
|
916d1e14a9 | ||
|
|
c496e53519 | ||
|
|
7da85fac3f | ||
|
|
b65b83af6f | ||
|
|
c8a3492c22 | ||
|
|
5cbf79787f | ||
|
|
d45ebb63f6 | ||
|
|
caa6476a69 | ||
|
|
45671cda0b | ||
|
|
8f29664057 | ||
|
|
0b9839ef43 | ||
|
|
953693b137 | ||
|
|
a39ea87bca | ||
|
|
9e9c8a1c64 | ||
|
|
0f11d60afb | ||
|
|
79eea51a1d | ||
|
|
c0338a46a4 | ||
|
|
1c99734e5a | ||
|
|
67758f50f3 | ||
|
|
02eef72bf5 | ||
|
|
b7572b2f87 | ||
|
|
a90aafafc1 | ||
|
|
d9b7cfac7e | ||
|
|
3507870535 | ||
|
|
82ecb02c1e | ||
|
|
a618f768e0 | ||
|
|
e1dec3c792 | ||
|
|
96697c4bc5 | ||
|
|
b504bd606d | ||
|
|
d170292594 | ||
|
|
9cfd185676 | ||
|
|
4b5bcd8ac4 | ||
|
|
ceb50b2cbf | ||
|
|
160ca08138 | ||
|
|
c4bfdba330 | ||
|
|
ee9547ba31 | ||
|
|
19a64d6291 | ||
|
|
b486885e08 | ||
|
|
0229228f3f | ||
|
|
1ed75ab30e | ||
|
|
99a1fb6027 | ||
|
|
73e04987f7 | ||
|
|
5388df784a | ||
|
|
26e0ba8f8c | ||
|
|
bc6dac4327 | ||
|
|
f18ebbd316 | ||
|
|
15564688ed | ||
|
|
c6b9c11ef6 | ||
|
|
e44d0ac7f7 | ||
|
|
56bc64f351 | ||
|
|
f7d83b72e0 | ||
|
|
80f07952d2 | ||
|
|
57f330caf9 | ||
|
|
601ff9e3db | ||
|
|
341667c4d5 | ||
|
|
1419dee915 | ||
|
|
da13b6b827 | ||
|
|
c86cd58573 | ||
|
|
b5fe39211a | ||
|
|
e946667216 | ||
|
|
d7969cb070 | ||
|
|
bddb02660c | ||
|
|
418eb7062d | ||
|
|
cac68ca813 | ||
|
|
52c1d933b2 | ||
|
|
3cacd3fca5 | ||
|
|
2dda7c11a3 | ||
|
|
3ad3248ad7 | ||
|
|
c441048a4f | ||
|
|
9f4b181ab3 | ||
|
|
cbbf077593 | ||
|
|
0c04a6ae78 | ||
|
|
416ccc9e45 | ||
|
|
ff2ff02168 | ||
|
|
4c5c4ddeda | ||
|
|
79badea452 | ||
|
|
37e5390f5f | ||
|
|
a4f59bc65e | ||
|
|
ca457f7ba1 | ||
|
|
cd6f615038 | ||
|
|
517669aaa3 | ||
|
|
e4e1bff605 | ||
|
|
d6656b0c0c | ||
|
|
f4cdedea62 | ||
|
|
39b1fc4ccc | ||
|
|
0b25f47bd9 | ||
|
|
bda1482a27 | ||
|
|
19ee5d9d8b | ||
|
|
61b50720d0 | ||
|
|
0f954f34af | ||
|
|
5262901c5c | ||
|
|
cc550d5908 | ||
|
|
6d1a3f7d00 | ||
|
|
1b3a650f19 | ||
|
|
e83063bf24 | ||
|
|
558b7d8b22 | ||
|
|
caf2074773 | ||
|
|
bdf393792d | ||
|
|
4e14032c02 | ||
|
|
59d58b1158 | ||
|
|
563291ee51 | ||
|
|
6c0377f43e | ||
|
|
2cddbf0821 | ||
|
|
60749f345d | ||
|
|
d4426dce7c | ||
|
|
d9d7f3c619 | ||
|
|
fd5dfb812c | ||
|
|
3dfdddcc91 | ||
|
|
5747bc6457 | ||
|
|
5bea1d2ec9 | ||
|
|
5def9fbc83 | ||
|
|
7a7efe8424 | ||
|
|
44db978531 | ||
|
|
1c8d11e48a | ||
|
|
a220d11e6b | ||
|
|
23827ca312 | ||
|
|
0fd4e6c778 | ||
|
|
e2fafe0686 | ||
|
|
6579632201 | ||
|
|
ac2f0523ca | ||
|
|
fbf68c4e52 | ||
|
|
93477f8efe | ||
|
|
8af9a91e0c | ||
|
|
005d2d3a13 | ||
|
|
1e21f4c14e |
@@ -28,17 +28,17 @@ def pull(repo, remote_name='origin', branch='master'):
|
|||||||
|
|
||||||
if repo.index.conflicts is not None:
|
if repo.index.conflicts is not None:
|
||||||
for conflict in repo.index.conflicts:
|
for conflict in repo.index.conflicts:
|
||||||
print('Conflicts found in:', conflict[0].path)
|
print('Conflicts found in:', conflict[0].path) # noqa: T201
|
||||||
raise AssertionError('Conflicts, ahhhhh!!')
|
raise AssertionError('Conflicts, ahhhhh!!')
|
||||||
|
|
||||||
user = repo.default_signature
|
user = repo.default_signature
|
||||||
tree = repo.index.write_tree()
|
tree = repo.index.write_tree()
|
||||||
commit = repo.create_commit('HEAD',
|
repo.create_commit('HEAD',
|
||||||
user,
|
user,
|
||||||
user,
|
user,
|
||||||
'Merge!',
|
'Merge!',
|
||||||
tree,
|
tree,
|
||||||
[repo.head.target, remote_master_id])
|
[repo.head.target, remote_master_id])
|
||||||
# We need to do this or git CLI will think we are still merging.
|
# We need to do this or git CLI will think we are still merging.
|
||||||
repo.state_cleanup()
|
repo.state_cleanup()
|
||||||
else:
|
else:
|
||||||
@@ -49,18 +49,18 @@ repo_path = str(sys.argv[1])
|
|||||||
repo = pygit2.Repository(repo_path)
|
repo = pygit2.Repository(repo_path)
|
||||||
ident = pygit2.Signature('comfyui', 'comfy@ui')
|
ident = pygit2.Signature('comfyui', 'comfy@ui')
|
||||||
try:
|
try:
|
||||||
print("stashing current changes")
|
print("stashing current changes") # noqa: T201
|
||||||
repo.stash(ident)
|
repo.stash(ident)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
print("nothing to stash")
|
print("nothing to stash") # noqa: T201
|
||||||
backup_branch_name = 'backup_branch_{}'.format(datetime.today().strftime('%Y-%m-%d_%H_%M_%S'))
|
backup_branch_name = 'backup_branch_{}'.format(datetime.today().strftime('%Y-%m-%d_%H_%M_%S'))
|
||||||
print("creating backup branch: {}".format(backup_branch_name))
|
print("creating backup branch: {}".format(backup_branch_name)) # noqa: T201
|
||||||
try:
|
try:
|
||||||
repo.branches.local.create(backup_branch_name, repo.head.peel())
|
repo.branches.local.create(backup_branch_name, repo.head.peel())
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
print("checking out master branch")
|
print("checking out master branch") # noqa: T201
|
||||||
branch = repo.lookup_branch('master')
|
branch = repo.lookup_branch('master')
|
||||||
if branch is None:
|
if branch is None:
|
||||||
ref = repo.lookup_reference('refs/remotes/origin/master')
|
ref = repo.lookup_reference('refs/remotes/origin/master')
|
||||||
@@ -72,7 +72,7 @@ else:
|
|||||||
ref = repo.lookup_reference(branch.name)
|
ref = repo.lookup_reference(branch.name)
|
||||||
repo.checkout(ref)
|
repo.checkout(ref)
|
||||||
|
|
||||||
print("pulling latest changes")
|
print("pulling latest changes") # noqa: T201
|
||||||
pull(repo)
|
pull(repo)
|
||||||
|
|
||||||
if "--stable" in sys.argv:
|
if "--stable" in sys.argv:
|
||||||
@@ -94,7 +94,7 @@ if "--stable" in sys.argv:
|
|||||||
if latest_tag is not None:
|
if latest_tag is not None:
|
||||||
repo.checkout(latest_tag)
|
repo.checkout(latest_tag)
|
||||||
|
|
||||||
print("Done!")
|
print("Done!") # noqa: T201
|
||||||
|
|
||||||
self_update = True
|
self_update = True
|
||||||
if len(sys.argv) > 2:
|
if len(sys.argv) > 2:
|
||||||
|
|||||||
@@ -3,8 +3,8 @@ name: Python Linting
|
|||||||
on: [push, pull_request]
|
on: [push, pull_request]
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
pylint:
|
ruff:
|
||||||
name: Run Pylint
|
name: Run Ruff
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
@@ -16,8 +16,8 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
python-version: 3.x
|
python-version: 3.x
|
||||||
|
|
||||||
- name: Install Pylint
|
- name: Install Ruff
|
||||||
run: pip install pylint
|
run: pip install ruff
|
||||||
|
|
||||||
- name: Run Pylint
|
- name: Run Ruff
|
||||||
run: pylint --rcfile=.pylintrc $(find . -type f -name "*.py")
|
run: ruff check .
|
||||||
4
.github/workflows/stable-release.yml
vendored
4
.github/workflows/stable-release.yml
vendored
@@ -12,7 +12,7 @@ on:
|
|||||||
description: 'CUDA version'
|
description: 'CUDA version'
|
||||||
required: true
|
required: true
|
||||||
type: string
|
type: string
|
||||||
default: "124"
|
default: "126"
|
||||||
python_minor:
|
python_minor:
|
||||||
description: 'Python minor version'
|
description: 'Python minor version'
|
||||||
required: true
|
required: true
|
||||||
@@ -22,7 +22,7 @@ on:
|
|||||||
description: 'Python patch version'
|
description: 'Python patch version'
|
||||||
required: true
|
required: true
|
||||||
type: string
|
type: string
|
||||||
default: "7"
|
default: "8"
|
||||||
|
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
|
|||||||
4
.github/workflows/test-build.yml
vendored
4
.github/workflows/test-build.yml
vendored
@@ -18,7 +18,7 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
python-version: ["3.8", "3.9", "3.10", "3.11"]
|
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
- name: Set up Python ${{ matrix.python-version }}
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
@@ -28,4 +28,4 @@ jobs:
|
|||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
pip install -r requirements.txt
|
pip install -r requirements.txt
|
||||||
|
|||||||
53
.github/workflows/test-ci.yml
vendored
53
.github/workflows/test-ci.yml
vendored
@@ -20,7 +20,8 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
os: [macos, linux, windows]
|
# os: [macos, linux, windows]
|
||||||
|
os: [macos, linux]
|
||||||
python_version: ["3.9", "3.10", "3.11", "3.12"]
|
python_version: ["3.9", "3.10", "3.11", "3.12"]
|
||||||
cuda_version: ["12.1"]
|
cuda_version: ["12.1"]
|
||||||
torch_version: ["stable"]
|
torch_version: ["stable"]
|
||||||
@@ -31,9 +32,9 @@ jobs:
|
|||||||
- os: linux
|
- os: linux
|
||||||
runner_label: [self-hosted, Linux]
|
runner_label: [self-hosted, Linux]
|
||||||
flags: ""
|
flags: ""
|
||||||
- os: windows
|
# - os: windows
|
||||||
runner_label: [self-hosted, Windows]
|
# runner_label: [self-hosted, Windows]
|
||||||
flags: ""
|
# flags: ""
|
||||||
runs-on: ${{ matrix.runner_label }}
|
runs-on: ${{ matrix.runner_label }}
|
||||||
steps:
|
steps:
|
||||||
- name: Test Workflows
|
- name: Test Workflows
|
||||||
@@ -45,28 +46,28 @@ jobs:
|
|||||||
google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
|
google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
|
||||||
comfyui_flags: ${{ matrix.flags }}
|
comfyui_flags: ${{ matrix.flags }}
|
||||||
|
|
||||||
test-win-nightly:
|
# test-win-nightly:
|
||||||
strategy:
|
# strategy:
|
||||||
fail-fast: true
|
# fail-fast: true
|
||||||
matrix:
|
# matrix:
|
||||||
os: [windows]
|
# os: [windows]
|
||||||
python_version: ["3.9", "3.10", "3.11", "3.12"]
|
# python_version: ["3.9", "3.10", "3.11", "3.12"]
|
||||||
cuda_version: ["12.1"]
|
# cuda_version: ["12.1"]
|
||||||
torch_version: ["nightly"]
|
# torch_version: ["nightly"]
|
||||||
include:
|
# include:
|
||||||
- os: windows
|
# - os: windows
|
||||||
runner_label: [self-hosted, Windows]
|
# runner_label: [self-hosted, Windows]
|
||||||
flags: ""
|
# flags: ""
|
||||||
runs-on: ${{ matrix.runner_label }}
|
# runs-on: ${{ matrix.runner_label }}
|
||||||
steps:
|
# steps:
|
||||||
- name: Test Workflows
|
# - name: Test Workflows
|
||||||
uses: comfy-org/comfy-action@main
|
# uses: comfy-org/comfy-action@main
|
||||||
with:
|
# with:
|
||||||
os: ${{ matrix.os }}
|
# os: ${{ matrix.os }}
|
||||||
python_version: ${{ matrix.python_version }}
|
# python_version: ${{ matrix.python_version }}
|
||||||
torch_version: ${{ matrix.torch_version }}
|
# torch_version: ${{ matrix.torch_version }}
|
||||||
google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
|
# google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
|
||||||
comfyui_flags: ${{ matrix.flags }}
|
# comfyui_flags: ${{ matrix.flags }}
|
||||||
|
|
||||||
test-unix-nightly:
|
test-unix-nightly:
|
||||||
strategy:
|
strategy:
|
||||||
|
|||||||
2
.github/workflows/test-launch.yml
vendored
2
.github/workflows/test-launch.yml
vendored
@@ -17,7 +17,7 @@ jobs:
|
|||||||
path: "ComfyUI"
|
path: "ComfyUI"
|
||||||
- uses: actions/setup-python@v4
|
- uses: actions/setup-python@v4
|
||||||
with:
|
with:
|
||||||
python-version: '3.8'
|
python-version: '3.9'
|
||||||
- name: Install requirements
|
- name: Install requirements
|
||||||
run: |
|
run: |
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
|
|||||||
2
.github/workflows/test-unit.yml
vendored
2
.github/workflows/test-unit.yml
vendored
@@ -18,7 +18,7 @@ jobs:
|
|||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v4
|
uses: actions/setup-python@v4
|
||||||
with:
|
with:
|
||||||
python-version: '3.10'
|
python-version: '3.12'
|
||||||
- name: Install requirements
|
- name: Install requirements
|
||||||
run: |
|
run: |
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
|
|||||||
58
.github/workflows/update-frontend.yml
vendored
Normal file
58
.github/workflows/update-frontend.yml
vendored
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
name: Update Frontend Release
|
||||||
|
|
||||||
|
on:
|
||||||
|
workflow_dispatch:
|
||||||
|
inputs:
|
||||||
|
version:
|
||||||
|
description: "Frontend version to update to (e.g., 1.0.0)"
|
||||||
|
required: true
|
||||||
|
type: string
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
update-frontend:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
permissions:
|
||||||
|
contents: write
|
||||||
|
pull-requests: write
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout ComfyUI
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
- uses: actions/setup-python@v4
|
||||||
|
with:
|
||||||
|
python-version: '3.10'
|
||||||
|
- name: Install requirements
|
||||||
|
run: |
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
||||||
|
pip install -r requirements.txt
|
||||||
|
pip install wait-for-it
|
||||||
|
# Frontend asset will be downloaded to ComfyUI/web_custom_versions/Comfy-Org_ComfyUI_frontend/{version}
|
||||||
|
- name: Start ComfyUI server
|
||||||
|
run: |
|
||||||
|
python main.py --cpu --front-end-version Comfy-Org/ComfyUI_frontend@${{ github.event.inputs.version }} 2>&1 | tee console_output.log &
|
||||||
|
wait-for-it --service 127.0.0.1:8188 -t 30
|
||||||
|
- name: Configure Git
|
||||||
|
run: |
|
||||||
|
git config --global user.name "GitHub Action"
|
||||||
|
git config --global user.email "action@github.com"
|
||||||
|
# Replace existing frontend content with the new version and remove .js.map files
|
||||||
|
# See https://github.com/Comfy-Org/ComfyUI_frontend/issues/2145 for why we remove .js.map files
|
||||||
|
- name: Update frontend content
|
||||||
|
run: |
|
||||||
|
rm -rf web/
|
||||||
|
cp -r web_custom_versions/Comfy-Org_ComfyUI_frontend/${{ github.event.inputs.version }} web/
|
||||||
|
rm web/**/*.js.map
|
||||||
|
- name: Create Pull Request
|
||||||
|
uses: peter-evans/create-pull-request@v7
|
||||||
|
with:
|
||||||
|
token: ${{ secrets.PR_BOT_PAT }}
|
||||||
|
commit-message: "Update frontend to v${{ github.event.inputs.version }}"
|
||||||
|
title: "Frontend Update: v${{ github.event.inputs.version }}"
|
||||||
|
body: |
|
||||||
|
Automated PR to update frontend content to version ${{ github.event.inputs.version }}
|
||||||
|
|
||||||
|
This PR was created automatically by the frontend update workflow.
|
||||||
|
branch: release-${{ github.event.inputs.version }}
|
||||||
|
base: master
|
||||||
|
labels: Frontend,dependencies
|
||||||
58
.github/workflows/update-version.yml
vendored
Normal file
58
.github/workflows/update-version.yml
vendored
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
name: Update Version File
|
||||||
|
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
paths:
|
||||||
|
- "pyproject.toml"
|
||||||
|
branches:
|
||||||
|
- master
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
update-version:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
# Don't run on fork PRs
|
||||||
|
if: github.event.pull_request.head.repo.full_name == github.repository
|
||||||
|
permissions:
|
||||||
|
pull-requests: write
|
||||||
|
contents: write
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout repository
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v4
|
||||||
|
with:
|
||||||
|
python-version: "3.11"
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
|
||||||
|
- name: Update comfyui_version.py
|
||||||
|
run: |
|
||||||
|
# Read version from pyproject.toml and update comfyui_version.py
|
||||||
|
python -c '
|
||||||
|
import tomllib
|
||||||
|
|
||||||
|
# Read version from pyproject.toml
|
||||||
|
with open("pyproject.toml", "rb") as f:
|
||||||
|
config = tomllib.load(f)
|
||||||
|
version = config["project"]["version"]
|
||||||
|
|
||||||
|
# Write version to comfyui_version.py
|
||||||
|
with open("comfyui_version.py", "w") as f:
|
||||||
|
f.write("# This file is automatically generated by the build process when version is\n")
|
||||||
|
f.write("# updated in pyproject.toml.\n")
|
||||||
|
f.write(f"__version__ = \"{version}\"\n")
|
||||||
|
'
|
||||||
|
|
||||||
|
- name: Commit changes
|
||||||
|
run: |
|
||||||
|
git config --local user.name "github-actions"
|
||||||
|
git config --local user.email "github-actions@github.com"
|
||||||
|
git fetch origin ${{ github.head_ref }}
|
||||||
|
git checkout -B ${{ github.head_ref }} origin/${{ github.head_ref }}
|
||||||
|
git add comfyui_version.py
|
||||||
|
git diff --quiet && git diff --staged --quiet || git commit -m "chore: Update comfyui_version.py to match pyproject.toml"
|
||||||
|
git push origin HEAD:${{ github.head_ref }}
|
||||||
@@ -17,7 +17,7 @@ on:
|
|||||||
description: 'cuda version'
|
description: 'cuda version'
|
||||||
required: true
|
required: true
|
||||||
type: string
|
type: string
|
||||||
default: "124"
|
default: "126"
|
||||||
|
|
||||||
python_minor:
|
python_minor:
|
||||||
description: 'python minor version'
|
description: 'python minor version'
|
||||||
@@ -29,7 +29,7 @@ on:
|
|||||||
description: 'python patch version'
|
description: 'python patch version'
|
||||||
required: true
|
required: true
|
||||||
type: string
|
type: string
|
||||||
default: "7"
|
default: "8"
|
||||||
# push:
|
# push:
|
||||||
# branches:
|
# branches:
|
||||||
# - master
|
# - master
|
||||||
|
|||||||
@@ -7,19 +7,19 @@ on:
|
|||||||
description: 'cuda version'
|
description: 'cuda version'
|
||||||
required: true
|
required: true
|
||||||
type: string
|
type: string
|
||||||
default: "124"
|
default: "126"
|
||||||
|
|
||||||
python_minor:
|
python_minor:
|
||||||
description: 'python minor version'
|
description: 'python minor version'
|
||||||
required: true
|
required: true
|
||||||
type: string
|
type: string
|
||||||
default: "12"
|
default: "13"
|
||||||
|
|
||||||
python_patch:
|
python_patch:
|
||||||
description: 'python patch version'
|
description: 'python patch version'
|
||||||
required: true
|
required: true
|
||||||
type: string
|
type: string
|
||||||
default: "4"
|
default: "1"
|
||||||
# push:
|
# push:
|
||||||
# branches:
|
# branches:
|
||||||
# - master
|
# - master
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ on:
|
|||||||
description: 'cuda version'
|
description: 'cuda version'
|
||||||
required: true
|
required: true
|
||||||
type: string
|
type: string
|
||||||
default: "124"
|
default: "126"
|
||||||
|
|
||||||
python_minor:
|
python_minor:
|
||||||
description: 'python minor version'
|
description: 'python minor version'
|
||||||
@@ -19,7 +19,7 @@ on:
|
|||||||
description: 'python patch version'
|
description: 'python patch version'
|
||||||
required: true
|
required: true
|
||||||
type: string
|
type: string
|
||||||
default: "7"
|
default: "8"
|
||||||
# push:
|
# push:
|
||||||
# branches:
|
# branches:
|
||||||
# - master
|
# - master
|
||||||
|
|||||||
25
CODEOWNERS
25
CODEOWNERS
@@ -1 +1,24 @@
|
|||||||
* @comfyanonymous
|
# Admins
|
||||||
|
* @comfyanonymous
|
||||||
|
|
||||||
|
# Note: Github teams syntax cannot be used here as the repo is not owned by Comfy-Org.
|
||||||
|
# Inlined the team members for now.
|
||||||
|
|
||||||
|
# Maintainers
|
||||||
|
*.md @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
|
||||||
|
/tests/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
|
||||||
|
/tests-unit/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
|
||||||
|
/notebooks/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
|
||||||
|
/script_examples/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
|
||||||
|
/.github/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
|
||||||
|
|
||||||
|
# Python web server
|
||||||
|
/api_server/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
|
||||||
|
/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
|
||||||
|
/utils/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
|
||||||
|
|
||||||
|
# Frontend assets
|
||||||
|
/web/ @huchenlei @webfiltered @pythongosssss @yoland68 @robinjhuang
|
||||||
|
|
||||||
|
# Extra nodes
|
||||||
|
/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink
|
||||||
|
|||||||
77
README.md
77
README.md
@@ -38,10 +38,22 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
|
|||||||
|
|
||||||
## Features
|
## Features
|
||||||
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
|
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
|
||||||
- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/), [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/), [SD3](https://comfyanonymous.github.io/ComfyUI_examples/sd3/) and [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
|
- Image Models
|
||||||
- [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/)
|
- SD1.x, SD2.x,
|
||||||
- [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
|
- [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/)
|
||||||
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
|
- [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/)
|
||||||
|
- [SD3 and SD3.5](https://comfyanonymous.github.io/ComfyUI_examples/sd3/)
|
||||||
|
- Pixart Alpha and Sigma
|
||||||
|
- [AuraFlow](https://comfyanonymous.github.io/ComfyUI_examples/aura_flow/)
|
||||||
|
- [HunyuanDiT](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_dit/)
|
||||||
|
- [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
|
||||||
|
- Video Models
|
||||||
|
- [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
|
||||||
|
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
|
||||||
|
- [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/)
|
||||||
|
- [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
|
||||||
|
- [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/)
|
||||||
|
- [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
|
||||||
- Asynchronous Queue system
|
- Asynchronous Queue system
|
||||||
- Many optimizations: Only re-executes the parts of the workflow that changes between executions.
|
- Many optimizations: Only re-executes the parts of the workflow that changes between executions.
|
||||||
- Smart memory management: can automatically run models on GPUs with as low as 1GB vram.
|
- Smart memory management: can automatically run models on GPUs with as low as 1GB vram.
|
||||||
@@ -61,9 +73,6 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
|
|||||||
- [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/)
|
- [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/)
|
||||||
- [Model Merging](https://comfyanonymous.github.io/ComfyUI_examples/model_merging/)
|
- [Model Merging](https://comfyanonymous.github.io/ComfyUI_examples/model_merging/)
|
||||||
- [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/)
|
- [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/)
|
||||||
- [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/)
|
|
||||||
- [AuraFlow](https://comfyanonymous.github.io/ComfyUI_examples/aura_flow/)
|
|
||||||
- [HunyuanDiT](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_dit/)
|
|
||||||
- Latent previews with [TAESD](#how-to-show-high-quality-previews)
|
- Latent previews with [TAESD](#how-to-show-high-quality-previews)
|
||||||
- Starts up very fast.
|
- Starts up very fast.
|
||||||
- Works fully offline: will never download anything.
|
- Works fully offline: will never download anything.
|
||||||
@@ -101,6 +110,8 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
|
|||||||
| `Q` | Toggle visibility of the queue |
|
| `Q` | Toggle visibility of the queue |
|
||||||
| `H` | Toggle visibility of history |
|
| `H` | Toggle visibility of history |
|
||||||
| `R` | Refresh graph |
|
| `R` | Refresh graph |
|
||||||
|
| `F` | Show/Hide menu |
|
||||||
|
| `.` | Fit view to selection (Whole graph when nothing is selected) |
|
||||||
| Double-Click LMB | Open node quick search palette |
|
| Double-Click LMB | Open node quick search palette |
|
||||||
| `Shift` + Drag | Move multiple wires at once |
|
| `Shift` + Drag | Move multiple wires at once |
|
||||||
| `Ctrl` + `Alt` + LMB | Disconnect all wires from clicked slot |
|
| `Ctrl` + `Alt` + LMB | Disconnect all wires from clicked slot |
|
||||||
@@ -143,9 +154,33 @@ AMD users can install rocm and pytorch with pip if you don't have it already ins
|
|||||||
|
|
||||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.2```
|
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.2```
|
||||||
|
|
||||||
This is the command to install the nightly with ROCm 6.2 which might have some performance improvements:
|
This is the command to install the nightly with ROCm 6.3 which might have some performance improvements:
|
||||||
|
|
||||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.2```
|
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.3```
|
||||||
|
|
||||||
|
### Intel GPUs (Windows and Linux)
|
||||||
|
|
||||||
|
(Option 1) Intel Arc GPU users can install native PyTorch with torch.xpu support using pip (currently available in PyTorch nightly builds). More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html)
|
||||||
|
|
||||||
|
1. To install PyTorch nightly, use the following command:
|
||||||
|
|
||||||
|
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/xpu```
|
||||||
|
|
||||||
|
2. Launch ComfyUI by running `python main.py`
|
||||||
|
|
||||||
|
|
||||||
|
(Option 2) Alternatively, Intel GPUs supported by Intel Extension for PyTorch (IPEX) can leverage IPEX for improved performance.
|
||||||
|
|
||||||
|
1. For Intel® Arc™ A-Series Graphics utilizing IPEX, create a conda environment and use the commands below:
|
||||||
|
|
||||||
|
```
|
||||||
|
conda install libuv
|
||||||
|
pip install torch==2.3.1.post0+cxx11.abi torchvision==0.18.1.post0+cxx11.abi torchaudio==2.3.1.post0+cxx11.abi intel-extension-for-pytorch==2.3.110.post0+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/
|
||||||
|
```
|
||||||
|
|
||||||
|
For other supported Intel GPUs with IPEX, visit [Installation](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu) for more information.
|
||||||
|
|
||||||
|
Additional discussion and help can be found [here](https://github.com/comfyanonymous/ComfyUI/discussions/476).
|
||||||
|
|
||||||
### NVIDIA
|
### NVIDIA
|
||||||
|
|
||||||
@@ -155,7 +190,7 @@ Nvidia users should install stable pytorch using this command:
|
|||||||
|
|
||||||
This is the command to install pytorch nightly instead which might have performance improvements:
|
This is the command to install pytorch nightly instead which might have performance improvements:
|
||||||
|
|
||||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124```
|
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu126```
|
||||||
|
|
||||||
#### Troubleshooting
|
#### Troubleshooting
|
||||||
|
|
||||||
@@ -175,17 +210,6 @@ After this you should have everything installed and can proceed to running Comfy
|
|||||||
|
|
||||||
### Others:
|
### Others:
|
||||||
|
|
||||||
#### Intel GPUs
|
|
||||||
|
|
||||||
Intel GPU support is available for all Intel GPUs supported by Intel's Extension for Pytorch (IPEX) with the support requirements listed in the [Installation](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu) page. Choose your platform and method of install and follow the instructions. The steps are as follows:
|
|
||||||
|
|
||||||
1. Start by installing the drivers or kernel listed or newer in the Installation page of IPEX linked above for Windows and Linux if needed.
|
|
||||||
1. Follow the instructions to install [Intel's oneAPI Basekit](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit-download.html) for your platform.
|
|
||||||
1. Install the packages for IPEX using the instructions provided in the Installation page for your platform.
|
|
||||||
1. Follow the [ComfyUI manual installation](#manual-install-windows-linux) instructions for Windows and Linux and run ComfyUI normally as described above after everything is installed.
|
|
||||||
|
|
||||||
Additional discussion and help can be found [here](https://github.com/comfyanonymous/ComfyUI/discussions/476).
|
|
||||||
|
|
||||||
#### Apple Mac silicon
|
#### Apple Mac silicon
|
||||||
|
|
||||||
You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS version.
|
You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS version.
|
||||||
@@ -201,6 +225,16 @@ You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS ve
|
|||||||
|
|
||||||
```pip install torch-directml``` Then you can launch ComfyUI with: ```python main.py --directml```
|
```pip install torch-directml``` Then you can launch ComfyUI with: ```python main.py --directml```
|
||||||
|
|
||||||
|
#### Ascend NPUs
|
||||||
|
|
||||||
|
For models compatible with Ascend Extension for PyTorch (torch_npu). To get started, ensure your environment meets the prerequisites outlined on the [installation](https://ascend.github.io/docs/sources/ascend/quick_install.html) page. Here's a step-by-step guide tailored to your platform and installation method:
|
||||||
|
|
||||||
|
1. Begin by installing the recommended or newer kernel version for Linux as specified in the Installation page of torch-npu, if necessary.
|
||||||
|
2. Proceed with the installation of Ascend Basekit, which includes the driver, firmware, and CANN, following the instructions provided for your specific platform.
|
||||||
|
3. Next, install the necessary packages for torch-npu by adhering to the platform-specific instructions on the [Installation](https://ascend.github.io/docs/sources/pytorch/install.html#pytorch) page.
|
||||||
|
4. Finally, adhere to the [ComfyUI manual installation](#manual-install-windows-linux) guide for Linux. Once all components are installed, you can run ComfyUI as described earlier.
|
||||||
|
|
||||||
|
|
||||||
# Running
|
# Running
|
||||||
|
|
||||||
```python main.py```
|
```python main.py```
|
||||||
@@ -306,4 +340,3 @@ This will use a snapshot of the legacy frontend preserved in the [ComfyUI Legacy
|
|||||||
### Which GPU should I buy for this?
|
### Which GPU should I buy for this?
|
||||||
|
|
||||||
[See this page for some recommendations](https://github.com/comfyanonymous/ComfyUI/wiki/Which-GPU-should-I-buy-for-ComfyUI)
|
[See this page for some recommendations](https://github.com/comfyanonymous/ComfyUI/wiki/Which-GPU-should-I-buy-for-ComfyUI)
|
||||||
|
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ class InternalRoutes:
|
|||||||
return web.json_response("".join([(l["t"] + " - " + l["m"]) for l in app.logger.get_logs()]))
|
return web.json_response("".join([(l["t"] + " - " + l["m"]) for l in app.logger.get_logs()]))
|
||||||
|
|
||||||
@self.routes.get('/logs/raw')
|
@self.routes.get('/logs/raw')
|
||||||
async def get_logs(request):
|
async def get_raw_logs(request):
|
||||||
self.terminal_service.update_size()
|
self.terminal_service.update_size()
|
||||||
return web.json_response({
|
return web.json_response({
|
||||||
"entries": list(app.logger.get_logs()),
|
"entries": list(app.logger.get_logs()),
|
||||||
|
|||||||
@@ -10,4 +10,4 @@ class FileService:
|
|||||||
if directory_key not in self.allowed_directories:
|
if directory_key not in self.allowed_directories:
|
||||||
raise ValueError("Invalid directory key")
|
raise ValueError("Invalid directory key")
|
||||||
directory_path: str = self.allowed_directories[directory_key]
|
directory_path: str = self.allowed_directories[directory_key]
|
||||||
return self.file_system_ops.walk_directory(directory_path)
|
return self.file_system_ops.walk_directory(directory_path)
|
||||||
|
|||||||
@@ -25,10 +25,10 @@ class TerminalService:
|
|||||||
def update_size(self):
|
def update_size(self):
|
||||||
columns, lines = self.get_terminal_size()
|
columns, lines = self.get_terminal_size()
|
||||||
changed = False
|
changed = False
|
||||||
|
|
||||||
if columns != self.cols:
|
if columns != self.cols:
|
||||||
self.cols = columns
|
self.cols = columns
|
||||||
changed = True
|
changed = True
|
||||||
|
|
||||||
if lines != self.rows:
|
if lines != self.rows:
|
||||||
self.rows = lines
|
self.rows = lines
|
||||||
@@ -48,9 +48,9 @@ class TerminalService:
|
|||||||
def send_messages(self, entries):
|
def send_messages(self, entries):
|
||||||
if not len(entries) or not len(self.subscriptions):
|
if not len(entries) or not len(self.subscriptions):
|
||||||
return
|
return
|
||||||
|
|
||||||
new_size = self.update_size()
|
new_size = self.update_size()
|
||||||
|
|
||||||
for client_id in self.subscriptions.copy(): # prevent: Set changed size during iteration
|
for client_id in self.subscriptions.copy(): # prevent: Set changed size during iteration
|
||||||
if client_id not in self.server.sockets:
|
if client_id not in self.server.sockets:
|
||||||
# Automatically unsub if the socket has disconnected
|
# Automatically unsub if the socket has disconnected
|
||||||
|
|||||||
@@ -39,4 +39,4 @@ class FileSystemOperations:
|
|||||||
"path": relative_path,
|
"path": relative_path,
|
||||||
"type": "directory"
|
"type": "directory"
|
||||||
})
|
})
|
||||||
return file_list
|
return file_list
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
import logging
|
||||||
|
|
||||||
|
|
||||||
class AppSettings():
|
class AppSettings():
|
||||||
@@ -11,8 +12,12 @@ class AppSettings():
|
|||||||
file = self.user_manager.get_request_user_filepath(
|
file = self.user_manager.get_request_user_filepath(
|
||||||
request, "comfy.settings.json")
|
request, "comfy.settings.json")
|
||||||
if os.path.isfile(file):
|
if os.path.isfile(file):
|
||||||
with open(file) as f:
|
try:
|
||||||
return json.load(f)
|
with open(file) as f:
|
||||||
|
return json.load(f)
|
||||||
|
except:
|
||||||
|
logging.error(f"The user settings file is corrupted: {file}")
|
||||||
|
return {}
|
||||||
else:
|
else:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
@@ -51,4 +56,4 @@ class AppSettings():
|
|||||||
settings = self.get_settings(request)
|
settings = self.get_settings(request)
|
||||||
settings[setting_id] = await request.json()
|
settings[setting_id] = await request.json()
|
||||||
self.save_settings(request, settings)
|
self.save_settings(request, settings)
|
||||||
return web.Response(status=200)
|
return web.Response(status=200)
|
||||||
|
|||||||
134
app/custom_node_manager.py
Normal file
134
app/custom_node_manager.py
Normal file
@@ -0,0 +1,134 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import folder_paths
|
||||||
|
import glob
|
||||||
|
from aiohttp import web
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
from utils.json_util import merge_json_recursive
|
||||||
|
|
||||||
|
|
||||||
|
# Extra locale files to load into main.json
|
||||||
|
EXTRA_LOCALE_FILES = [
|
||||||
|
"nodeDefs.json",
|
||||||
|
"commands.json",
|
||||||
|
"settings.json",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def safe_load_json_file(file_path: str) -> dict:
|
||||||
|
if not os.path.exists(file_path):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(file_path, "r", encoding="utf-8") as f:
|
||||||
|
return json.load(f)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logging.error(f"Error loading {file_path}")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
class CustomNodeManager:
|
||||||
|
@lru_cache(maxsize=1)
|
||||||
|
def build_translations(self):
|
||||||
|
"""Load all custom nodes translations during initialization. Translations are
|
||||||
|
expected to be loaded from `locales/` folder.
|
||||||
|
|
||||||
|
The folder structure is expected to be the following:
|
||||||
|
- custom_nodes/
|
||||||
|
- custom_node_1/
|
||||||
|
- locales/
|
||||||
|
- en/
|
||||||
|
- main.json
|
||||||
|
- commands.json
|
||||||
|
- settings.json
|
||||||
|
|
||||||
|
returned translations are expected to be in the following format:
|
||||||
|
{
|
||||||
|
"en": {
|
||||||
|
"nodeDefs": {...},
|
||||||
|
"commands": {...},
|
||||||
|
"settings": {...},
|
||||||
|
...{other main.json keys}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
translations = {}
|
||||||
|
|
||||||
|
for folder in folder_paths.get_folder_paths("custom_nodes"):
|
||||||
|
# Sort glob results for deterministic ordering
|
||||||
|
for custom_node_dir in sorted(glob.glob(os.path.join(folder, "*/"))):
|
||||||
|
locales_dir = os.path.join(custom_node_dir, "locales")
|
||||||
|
if not os.path.exists(locales_dir):
|
||||||
|
continue
|
||||||
|
|
||||||
|
for lang_dir in glob.glob(os.path.join(locales_dir, "*/")):
|
||||||
|
lang_code = os.path.basename(os.path.dirname(lang_dir))
|
||||||
|
|
||||||
|
if lang_code not in translations:
|
||||||
|
translations[lang_code] = {}
|
||||||
|
|
||||||
|
# Load main.json
|
||||||
|
main_file = os.path.join(lang_dir, "main.json")
|
||||||
|
node_translations = safe_load_json_file(main_file)
|
||||||
|
|
||||||
|
# Load extra locale files
|
||||||
|
for extra_file in EXTRA_LOCALE_FILES:
|
||||||
|
extra_file_path = os.path.join(lang_dir, extra_file)
|
||||||
|
key = extra_file.split(".")[0]
|
||||||
|
json_data = safe_load_json_file(extra_file_path)
|
||||||
|
if json_data:
|
||||||
|
node_translations[key] = json_data
|
||||||
|
|
||||||
|
if node_translations:
|
||||||
|
translations[lang_code] = merge_json_recursive(
|
||||||
|
translations[lang_code], node_translations
|
||||||
|
)
|
||||||
|
|
||||||
|
return translations
|
||||||
|
|
||||||
|
def add_routes(self, routes, webapp, loadedModules):
|
||||||
|
|
||||||
|
@routes.get("/workflow_templates")
|
||||||
|
async def get_workflow_templates(request):
|
||||||
|
"""Returns a web response that contains the map of custom_nodes names and their associated workflow templates. The ones without templates are omitted."""
|
||||||
|
files = [
|
||||||
|
file
|
||||||
|
for folder in folder_paths.get_folder_paths("custom_nodes")
|
||||||
|
for file in glob.glob(
|
||||||
|
os.path.join(folder, "*/example_workflows/*.json")
|
||||||
|
)
|
||||||
|
]
|
||||||
|
workflow_templates_dict = (
|
||||||
|
{}
|
||||||
|
) # custom_nodes folder name -> example workflow names
|
||||||
|
for file in files:
|
||||||
|
custom_nodes_name = os.path.basename(
|
||||||
|
os.path.dirname(os.path.dirname(file))
|
||||||
|
)
|
||||||
|
workflow_name = os.path.splitext(os.path.basename(file))[0]
|
||||||
|
workflow_templates_dict.setdefault(custom_nodes_name, []).append(
|
||||||
|
workflow_name
|
||||||
|
)
|
||||||
|
return web.json_response(workflow_templates_dict)
|
||||||
|
|
||||||
|
# Serve workflow templates from custom nodes.
|
||||||
|
for module_name, module_dir in loadedModules:
|
||||||
|
workflows_dir = os.path.join(module_dir, "example_workflows")
|
||||||
|
if os.path.exists(workflows_dir):
|
||||||
|
webapp.add_routes(
|
||||||
|
[
|
||||||
|
web.static(
|
||||||
|
"/api/workflow_templates/" + module_name, workflows_dir
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
@routes.get("/i18n")
|
||||||
|
async def get_i18n(request):
|
||||||
|
"""Returns translations from all custom nodes' locales folders."""
|
||||||
|
return web.json_response(self.build_translations())
|
||||||
@@ -51,7 +51,7 @@ def on_flush(callback):
|
|||||||
if stderr_interceptor is not None:
|
if stderr_interceptor is not None:
|
||||||
stderr_interceptor.on_flush(callback)
|
stderr_interceptor.on_flush(callback)
|
||||||
|
|
||||||
def setup_logger(log_level: str = 'INFO', capacity: int = 300):
|
def setup_logger(log_level: str = 'INFO', capacity: int = 300, use_stdout: bool = False):
|
||||||
global logs
|
global logs
|
||||||
if logs:
|
if logs:
|
||||||
return
|
return
|
||||||
@@ -70,4 +70,15 @@ def setup_logger(log_level: str = 'INFO', capacity: int = 300):
|
|||||||
|
|
||||||
stream_handler = logging.StreamHandler()
|
stream_handler = logging.StreamHandler()
|
||||||
stream_handler.setFormatter(logging.Formatter("%(message)s"))
|
stream_handler.setFormatter(logging.Formatter("%(message)s"))
|
||||||
|
|
||||||
|
if use_stdout:
|
||||||
|
# Only errors and critical to stderr
|
||||||
|
stream_handler.addFilter(lambda record: not record.levelno < logging.ERROR)
|
||||||
|
|
||||||
|
# Lesser to stdout
|
||||||
|
stdout_handler = logging.StreamHandler(sys.stdout)
|
||||||
|
stdout_handler.setFormatter(logging.Formatter("%(message)s"))
|
||||||
|
stdout_handler.addFilter(lambda record: record.levelno < logging.ERROR)
|
||||||
|
logger.addHandler(stdout_handler)
|
||||||
|
|
||||||
logger.addHandler(stream_handler)
|
logger.addHandler(stream_handler)
|
||||||
|
|||||||
184
app/model_manager.py
Normal file
184
app/model_manager.py
Normal file
@@ -0,0 +1,184 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import logging
|
||||||
|
import folder_paths
|
||||||
|
import glob
|
||||||
|
import comfy.utils
|
||||||
|
from aiohttp import web
|
||||||
|
from PIL import Image
|
||||||
|
from io import BytesIO
|
||||||
|
from folder_paths import map_legacy, filter_files_extensions, filter_files_content_types
|
||||||
|
|
||||||
|
|
||||||
|
class ModelFileManager:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {}
|
||||||
|
|
||||||
|
def get_cache(self, key: str, default=None) -> tuple[list[dict], dict[str, float], float] | None:
|
||||||
|
return self.cache.get(key, default)
|
||||||
|
|
||||||
|
def set_cache(self, key: str, value: tuple[list[dict], dict[str, float], float]):
|
||||||
|
self.cache[key] = value
|
||||||
|
|
||||||
|
def clear_cache(self):
|
||||||
|
self.cache.clear()
|
||||||
|
|
||||||
|
def add_routes(self, routes):
|
||||||
|
# NOTE: This is an experiment to replace `/models`
|
||||||
|
@routes.get("/experiment/models")
|
||||||
|
async def get_model_folders(request):
|
||||||
|
model_types = list(folder_paths.folder_names_and_paths.keys())
|
||||||
|
folder_black_list = ["configs", "custom_nodes"]
|
||||||
|
output_folders: list[dict] = []
|
||||||
|
for folder in model_types:
|
||||||
|
if folder in folder_black_list:
|
||||||
|
continue
|
||||||
|
output_folders.append({"name": folder, "folders": folder_paths.get_folder_paths(folder)})
|
||||||
|
return web.json_response(output_folders)
|
||||||
|
|
||||||
|
# NOTE: This is an experiment to replace `/models/{folder}`
|
||||||
|
@routes.get("/experiment/models/{folder}")
|
||||||
|
async def get_all_models(request):
|
||||||
|
folder = request.match_info.get("folder", None)
|
||||||
|
if not folder in folder_paths.folder_names_and_paths:
|
||||||
|
return web.Response(status=404)
|
||||||
|
files = self.get_model_file_list(folder)
|
||||||
|
return web.json_response(files)
|
||||||
|
|
||||||
|
@routes.get("/experiment/models/preview/{folder}/{path_index}/{filename:.*}")
|
||||||
|
async def get_model_preview(request):
|
||||||
|
folder_name = request.match_info.get("folder", None)
|
||||||
|
path_index = int(request.match_info.get("path_index", None))
|
||||||
|
filename = request.match_info.get("filename", None)
|
||||||
|
|
||||||
|
if not folder_name in folder_paths.folder_names_and_paths:
|
||||||
|
return web.Response(status=404)
|
||||||
|
|
||||||
|
folders = folder_paths.folder_names_and_paths[folder_name]
|
||||||
|
folder = folders[0][path_index]
|
||||||
|
full_filename = os.path.join(folder, filename)
|
||||||
|
|
||||||
|
previews = self.get_model_previews(full_filename)
|
||||||
|
default_preview = previews[0] if len(previews) > 0 else None
|
||||||
|
if default_preview is None or (isinstance(default_preview, str) and not os.path.isfile(default_preview)):
|
||||||
|
return web.Response(status=404)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with Image.open(default_preview) as img:
|
||||||
|
img_bytes = BytesIO()
|
||||||
|
img.save(img_bytes, format="WEBP")
|
||||||
|
img_bytes.seek(0)
|
||||||
|
return web.Response(body=img_bytes.getvalue(), content_type="image/webp")
|
||||||
|
except:
|
||||||
|
return web.Response(status=404)
|
||||||
|
|
||||||
|
def get_model_file_list(self, folder_name: str):
|
||||||
|
folder_name = map_legacy(folder_name)
|
||||||
|
folders = folder_paths.folder_names_and_paths[folder_name]
|
||||||
|
output_list: list[dict] = []
|
||||||
|
|
||||||
|
for index, folder in enumerate(folders[0]):
|
||||||
|
if not os.path.isdir(folder):
|
||||||
|
continue
|
||||||
|
out = self.cache_model_file_list_(folder)
|
||||||
|
if out is None:
|
||||||
|
out = self.recursive_search_models_(folder, index)
|
||||||
|
self.set_cache(folder, out)
|
||||||
|
output_list.extend(out[0])
|
||||||
|
|
||||||
|
return output_list
|
||||||
|
|
||||||
|
def cache_model_file_list_(self, folder: str):
|
||||||
|
model_file_list_cache = self.get_cache(folder)
|
||||||
|
|
||||||
|
if model_file_list_cache is None:
|
||||||
|
return None
|
||||||
|
if not os.path.isdir(folder):
|
||||||
|
return None
|
||||||
|
if os.path.getmtime(folder) != model_file_list_cache[1]:
|
||||||
|
return None
|
||||||
|
for x in model_file_list_cache[1]:
|
||||||
|
time_modified = model_file_list_cache[1][x]
|
||||||
|
folder = x
|
||||||
|
if os.path.getmtime(folder) != time_modified:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return model_file_list_cache
|
||||||
|
|
||||||
|
def recursive_search_models_(self, directory: str, pathIndex: int) -> tuple[list[str], dict[str, float], float]:
|
||||||
|
if not os.path.isdir(directory):
|
||||||
|
return [], {}, time.perf_counter()
|
||||||
|
|
||||||
|
excluded_dir_names = [".git"]
|
||||||
|
# TODO use settings
|
||||||
|
include_hidden_files = False
|
||||||
|
|
||||||
|
result: list[str] = []
|
||||||
|
dirs: dict[str, float] = {}
|
||||||
|
|
||||||
|
for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
|
||||||
|
subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
|
||||||
|
if not include_hidden_files:
|
||||||
|
subdirs[:] = [d for d in subdirs if not d.startswith(".")]
|
||||||
|
filenames = [f for f in filenames if not f.startswith(".")]
|
||||||
|
|
||||||
|
filenames = filter_files_extensions(filenames, folder_paths.supported_pt_extensions)
|
||||||
|
|
||||||
|
for file_name in filenames:
|
||||||
|
try:
|
||||||
|
relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
|
||||||
|
result.append(relative_path)
|
||||||
|
except:
|
||||||
|
logging.warning(f"Warning: Unable to access {file_name}. Skipping this file.")
|
||||||
|
continue
|
||||||
|
|
||||||
|
for d in subdirs:
|
||||||
|
path: str = os.path.join(dirpath, d)
|
||||||
|
try:
|
||||||
|
dirs[path] = os.path.getmtime(path)
|
||||||
|
except FileNotFoundError:
|
||||||
|
logging.warning(f"Warning: Unable to access {path}. Skipping this path.")
|
||||||
|
continue
|
||||||
|
|
||||||
|
return [{"name": f, "pathIndex": pathIndex} for f in result], dirs, time.perf_counter()
|
||||||
|
|
||||||
|
def get_model_previews(self, filepath: str) -> list[str | BytesIO]:
|
||||||
|
dirname = os.path.dirname(filepath)
|
||||||
|
|
||||||
|
if not os.path.exists(dirname):
|
||||||
|
return []
|
||||||
|
|
||||||
|
basename = os.path.splitext(filepath)[0]
|
||||||
|
match_files = glob.glob(f"{basename}.*", recursive=False)
|
||||||
|
image_files = filter_files_content_types(match_files, "image")
|
||||||
|
safetensors_file = next(filter(lambda x: x.endswith(".safetensors"), match_files), None)
|
||||||
|
safetensors_metadata = {}
|
||||||
|
|
||||||
|
result: list[str | BytesIO] = []
|
||||||
|
|
||||||
|
for filename in image_files:
|
||||||
|
_basename = os.path.splitext(filename)[0]
|
||||||
|
if _basename == basename:
|
||||||
|
result.append(filename)
|
||||||
|
if _basename == f"{basename}.preview":
|
||||||
|
result.append(filename)
|
||||||
|
|
||||||
|
if safetensors_file:
|
||||||
|
safetensors_filepath = os.path.join(dirname, safetensors_file)
|
||||||
|
header = comfy.utils.safetensors_header(safetensors_filepath, max_size=8*1024*1024)
|
||||||
|
if header:
|
||||||
|
safetensors_metadata = json.loads(header)
|
||||||
|
safetensors_images = safetensors_metadata.get("__metadata__", {}).get("ssmd_cover_images", None)
|
||||||
|
if safetensors_images:
|
||||||
|
safetensors_images = json.loads(safetensors_images)
|
||||||
|
for image in safetensors_images:
|
||||||
|
result.append(BytesIO(base64.b64decode(image)))
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_value, traceback):
|
||||||
|
self.clear_cache()
|
||||||
@@ -38,8 +38,8 @@ class UserManager():
|
|||||||
if not os.path.exists(user_directory):
|
if not os.path.exists(user_directory):
|
||||||
os.makedirs(user_directory, exist_ok=True)
|
os.makedirs(user_directory, exist_ok=True)
|
||||||
if not args.multi_user:
|
if not args.multi_user:
|
||||||
print("****** User settings have been changed to be stored on the server instead of browser storage. ******")
|
logging.warning("****** User settings have been changed to be stored on the server instead of browser storage. ******")
|
||||||
print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")
|
logging.warning("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")
|
||||||
|
|
||||||
if args.multi_user:
|
if args.multi_user:
|
||||||
if os.path.isfile(self.get_users_file()):
|
if os.path.isfile(self.get_users_file()):
|
||||||
|
|||||||
@@ -2,11 +2,9 @@
|
|||||||
#and modified
|
#and modified
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch as th
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from ..ldm.modules.diffusionmodules.util import (
|
from ..ldm.modules.diffusionmodules.util import (
|
||||||
zero_module,
|
|
||||||
timestep_embedding,
|
timestep_embedding,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -162,7 +160,6 @@ class ControlNet(nn.Module):
|
|||||||
if isinstance(self.num_classes, int):
|
if isinstance(self.num_classes, int):
|
||||||
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
|
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
|
||||||
elif self.num_classes == "continuous":
|
elif self.num_classes == "continuous":
|
||||||
print("setting up linear c_adm embedding layer")
|
|
||||||
self.label_emb = nn.Linear(1, time_embed_dim)
|
self.label_emb = nn.Linear(1, time_embed_dim)
|
||||||
elif self.num_classes == "sequential":
|
elif self.num_classes == "sequential":
|
||||||
assert adm_in_channels is not None
|
assert adm_in_channels is not None
|
||||||
@@ -415,7 +412,6 @@ class ControlNet(nn.Module):
|
|||||||
out_output = []
|
out_output = []
|
||||||
out_middle = []
|
out_middle = []
|
||||||
|
|
||||||
hs = []
|
|
||||||
if self.num_classes is not None:
|
if self.num_classes is not None:
|
||||||
assert y.shape[0] == x.shape[0]
|
assert y.shape[0] == x.shape[0]
|
||||||
emb = emb + self.label_emb(y)
|
emb = emb + self.label_emb(y)
|
||||||
|
|||||||
@@ -1,10 +1,8 @@
|
|||||||
import math
|
import math
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from einops import rearrange
|
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from comfy.ldm.modules.diffusionmodules.mmdit import DismantledBlock, PatchEmbed, VectorEmbedder, TimestepEmbedder, get_2d_sincos_pos_embed_torch
|
from comfy.ldm.modules.diffusionmodules.mmdit import DismantledBlock, PatchEmbed, VectorEmbedder, TimestepEmbedder, get_2d_sincos_pos_embed_torch
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import torch
|
import torch
|
||||||
from typing import Dict, Optional
|
from typing import Optional
|
||||||
import comfy.ldm.modules.diffusionmodules.mmdit
|
import comfy.ldm.modules.diffusionmodules.mmdit
|
||||||
|
|
||||||
class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
|
class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
|
||||||
|
|||||||
@@ -43,10 +43,11 @@ parser.add_argument("--tls-certfile", type=str, help="Path to TLS (SSL) certific
|
|||||||
parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.")
|
parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.")
|
||||||
parser.add_argument("--max-upload-size", type=float, default=100, help="Set the maximum upload size in MB.")
|
parser.add_argument("--max-upload-size", type=float, default=100, help="Set the maximum upload size in MB.")
|
||||||
|
|
||||||
|
parser.add_argument("--base-directory", type=str, default=None, help="Set the ComfyUI base directory for models, custom_nodes, input, output, temp, and user directories.")
|
||||||
parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.")
|
parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.")
|
||||||
parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.")
|
parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory. Overrides --base-directory.")
|
||||||
parser.add_argument("--temp-directory", type=str, default=None, help="Set the ComfyUI temp directory (default is in the ComfyUI directory).")
|
parser.add_argument("--temp-directory", type=str, default=None, help="Set the ComfyUI temp directory (default is in the ComfyUI directory). Overrides --base-directory.")
|
||||||
parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory.")
|
parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.")
|
||||||
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
|
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
|
||||||
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
|
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
|
||||||
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
|
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
|
||||||
@@ -84,7 +85,8 @@ parser.add_argument("--force-channels-last", action="store_true", help="Force ch
|
|||||||
|
|
||||||
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
|
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
|
||||||
|
|
||||||
parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize when loading models with Intel GPUs.")
|
parser.add_argument("--oneapi-device-selector", type=str, default=None, metavar="SELECTOR_STRING", help="Sets the oneAPI device(s) this instance will use.")
|
||||||
|
parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize default when loading models with Intel's Extension for Pytorch.")
|
||||||
|
|
||||||
class LatentPreviewMethod(enum.Enum):
|
class LatentPreviewMethod(enum.Enum):
|
||||||
NoPreviews = "none"
|
NoPreviews = "none"
|
||||||
@@ -104,6 +106,7 @@ attn_group = parser.add_mutually_exclusive_group()
|
|||||||
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
|
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
|
||||||
attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
|
attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
|
||||||
attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
|
attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
|
||||||
|
attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.")
|
||||||
|
|
||||||
parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
|
parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
|
||||||
|
|
||||||
@@ -120,7 +123,7 @@ vram_group.add_argument("--lowvram", action="store_true", help="Split the unet i
|
|||||||
vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
|
vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
|
||||||
vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
|
vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
|
||||||
|
|
||||||
parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reverved depending on your OS.")
|
parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.")
|
||||||
|
|
||||||
|
|
||||||
parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
|
parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
|
||||||
@@ -139,6 +142,7 @@ parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Dis
|
|||||||
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
|
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
|
||||||
|
|
||||||
parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
|
parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
|
||||||
|
parser.add_argument("--log-stdout", action="store_true", help="Send normal process output to stdout instead of stderr (default).")
|
||||||
|
|
||||||
# The default built-in provider hosted under web/
|
# The default built-in provider hosted under web/
|
||||||
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
|
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
|
||||||
@@ -173,7 +177,7 @@ parser.add_argument(
|
|||||||
help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.",
|
help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument("--user-directory", type=is_valid_directory, default=None, help="Set the ComfyUI user directory with an absolute path.")
|
parser.add_argument("--user-directory", type=is_valid_directory, default=None, help="Set the ComfyUI user directory with an absolute path. Overrides --base-directory.")
|
||||||
|
|
||||||
if comfy.options.args_parsing:
|
if comfy.options.args_parsing:
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ This module provides type hinting and concrete convenience types for node develo
|
|||||||
If cloned to the custom_nodes directory of ComfyUI, types can be imported using:
|
If cloned to the custom_nodes directory of ComfyUI, types can be imported using:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from comfy_types import IO, ComfyNodeABC, CheckLazyMixin
|
from comfy.comfy_types import IO, ComfyNodeABC, CheckLazyMixin
|
||||||
|
|
||||||
class ExampleNode(ComfyNodeABC):
|
class ExampleNode(ComfyNodeABC):
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
from comfy_types import IO, ComfyNodeABC, InputTypeDict
|
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
|
||||||
from inspect import cleandoc
|
from inspect import cleandoc
|
||||||
|
|
||||||
|
|
||||||
class ExampleNode(ComfyNodeABC):
|
class ExampleNode(ComfyNodeABC):
|
||||||
"""An example node that just adds 1 to an input integer.
|
"""An example node that just adds 1 to an input integer.
|
||||||
|
|
||||||
* Requires an IDE configured with analysis paths etc to be worth looking at.
|
* Requires a modern IDE to provide any benefit (detail: an IDE configured with analysis paths etc).
|
||||||
* Not intended for use in ComfyUI.
|
* This node is intended as an example for developers only.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
DESCRIPTION = cleandoc(__doc__)
|
DESCRIPTION = cleandoc(__doc__)
|
||||||
|
|||||||
@@ -3,9 +3,6 @@ import math
|
|||||||
import comfy.utils
|
import comfy.utils
|
||||||
|
|
||||||
|
|
||||||
def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
|
|
||||||
return abs(a*b) // math.gcd(a, b)
|
|
||||||
|
|
||||||
class CONDRegular:
|
class CONDRegular:
|
||||||
def __init__(self, cond):
|
def __init__(self, cond):
|
||||||
self.cond = cond
|
self.cond = cond
|
||||||
@@ -46,7 +43,7 @@ class CONDCrossAttn(CONDRegular):
|
|||||||
if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen
|
if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen
|
||||||
return False
|
return False
|
||||||
|
|
||||||
mult_min = lcm(s1[1], s2[1])
|
mult_min = math.lcm(s1[1], s2[1])
|
||||||
diff = mult_min // min(s1[1], s2[1])
|
diff = mult_min // min(s1[1], s2[1])
|
||||||
if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
|
if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
|
||||||
return False
|
return False
|
||||||
@@ -57,7 +54,7 @@ class CONDCrossAttn(CONDRegular):
|
|||||||
crossattn_max_len = self.cond.shape[1]
|
crossattn_max_len = self.cond.shape[1]
|
||||||
for x in others:
|
for x in others:
|
||||||
c = x.cond
|
c = x.cond
|
||||||
crossattn_max_len = lcm(crossattn_max_len, c.shape[1])
|
crossattn_max_len = math.lcm(crossattn_max_len, c.shape[1])
|
||||||
conds.append(c)
|
conds.append(c)
|
||||||
|
|
||||||
out = []
|
out = []
|
||||||
|
|||||||
@@ -120,7 +120,7 @@ class ControlBase:
|
|||||||
if self.previous_controlnet is not None:
|
if self.previous_controlnet is not None:
|
||||||
out += self.previous_controlnet.get_models()
|
out += self.previous_controlnet.get_models()
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def get_extra_hooks(self):
|
def get_extra_hooks(self):
|
||||||
out = []
|
out = []
|
||||||
if self.extra_hooks is not None:
|
if self.extra_hooks is not None:
|
||||||
@@ -297,7 +297,6 @@ class ControlLoraOps:
|
|||||||
class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp):
|
class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp):
|
||||||
def __init__(self, in_features: int, out_features: int, bias: bool = True,
|
def __init__(self, in_features: int, out_features: int, bias: bool = True,
|
||||||
device=None, dtype=None) -> None:
|
device=None, dtype=None) -> None:
|
||||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.in_features = in_features
|
self.in_features = in_features
|
||||||
self.out_features = out_features
|
self.out_features = out_features
|
||||||
@@ -382,7 +381,6 @@ class ControlLora(ControlNet):
|
|||||||
self.control_model.to(comfy.model_management.get_torch_device())
|
self.control_model.to(comfy.model_management.get_torch_device())
|
||||||
diffusion_model = model.diffusion_model
|
diffusion_model = model.diffusion_model
|
||||||
sd = diffusion_model.state_dict()
|
sd = diffusion_model.state_dict()
|
||||||
cm = self.control_model.state_dict()
|
|
||||||
|
|
||||||
for k in sd:
|
for k in sd:
|
||||||
weight = sd[k]
|
weight = sd[k]
|
||||||
@@ -823,7 +821,7 @@ def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
|
|||||||
for i in range(4):
|
for i in range(4):
|
||||||
for j in range(2):
|
for j in range(2):
|
||||||
prefix_replace["adapter.body.{}.resnets.{}.".format(i, j)] = "body.{}.".format(i * 2 + j)
|
prefix_replace["adapter.body.{}.resnets.{}.".format(i, j)] = "body.{}.".format(i * 2 + j)
|
||||||
prefix_replace["adapter.body.{}.".format(i, j)] = "body.{}.".format(i * 2)
|
prefix_replace["adapter.body.{}.".format(i, )] = "body.{}.".format(i * 2)
|
||||||
prefix_replace["adapter."] = ""
|
prefix_replace["adapter."] = ""
|
||||||
t2i_data = comfy.utils.state_dict_prefix_replace(t2i_data, prefix_replace)
|
t2i_data = comfy.utils.state_dict_prefix_replace(t2i_data, prefix_replace)
|
||||||
keys = t2i_data.keys()
|
keys = t2i_data.keys()
|
||||||
|
|||||||
@@ -4,105 +4,6 @@ import logging
|
|||||||
|
|
||||||
# conversion code from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py
|
# conversion code from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py
|
||||||
|
|
||||||
# =================#
|
|
||||||
# UNet Conversion #
|
|
||||||
# =================#
|
|
||||||
|
|
||||||
unet_conversion_map = [
|
|
||||||
# (stable-diffusion, HF Diffusers)
|
|
||||||
("time_embed.0.weight", "time_embedding.linear_1.weight"),
|
|
||||||
("time_embed.0.bias", "time_embedding.linear_1.bias"),
|
|
||||||
("time_embed.2.weight", "time_embedding.linear_2.weight"),
|
|
||||||
("time_embed.2.bias", "time_embedding.linear_2.bias"),
|
|
||||||
("input_blocks.0.0.weight", "conv_in.weight"),
|
|
||||||
("input_blocks.0.0.bias", "conv_in.bias"),
|
|
||||||
("out.0.weight", "conv_norm_out.weight"),
|
|
||||||
("out.0.bias", "conv_norm_out.bias"),
|
|
||||||
("out.2.weight", "conv_out.weight"),
|
|
||||||
("out.2.bias", "conv_out.bias"),
|
|
||||||
]
|
|
||||||
|
|
||||||
unet_conversion_map_resnet = [
|
|
||||||
# (stable-diffusion, HF Diffusers)
|
|
||||||
("in_layers.0", "norm1"),
|
|
||||||
("in_layers.2", "conv1"),
|
|
||||||
("out_layers.0", "norm2"),
|
|
||||||
("out_layers.3", "conv2"),
|
|
||||||
("emb_layers.1", "time_emb_proj"),
|
|
||||||
("skip_connection", "conv_shortcut"),
|
|
||||||
]
|
|
||||||
|
|
||||||
unet_conversion_map_layer = []
|
|
||||||
# hardcoded number of downblocks and resnets/attentions...
|
|
||||||
# would need smarter logic for other networks.
|
|
||||||
for i in range(4):
|
|
||||||
# loop over downblocks/upblocks
|
|
||||||
|
|
||||||
for j in range(2):
|
|
||||||
# loop over resnets/attentions for downblocks
|
|
||||||
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
|
||||||
sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
|
|
||||||
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
|
||||||
|
|
||||||
if i < 3:
|
|
||||||
# no attention layers in down_blocks.3
|
|
||||||
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
|
||||||
sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
|
|
||||||
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
|
||||||
|
|
||||||
for j in range(3):
|
|
||||||
# loop over resnets/attentions for upblocks
|
|
||||||
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
|
||||||
sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
|
|
||||||
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
|
||||||
|
|
||||||
if i > 0:
|
|
||||||
# no attention layers in up_blocks.0
|
|
||||||
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
|
|
||||||
sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
|
|
||||||
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
|
|
||||||
|
|
||||||
if i < 3:
|
|
||||||
# no downsample in down_blocks.3
|
|
||||||
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
|
||||||
sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
|
|
||||||
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
|
||||||
|
|
||||||
# no upsample in up_blocks.3
|
|
||||||
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
|
||||||
sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}."
|
|
||||||
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
|
||||||
|
|
||||||
hf_mid_atn_prefix = "mid_block.attentions.0."
|
|
||||||
sd_mid_atn_prefix = "middle_block.1."
|
|
||||||
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
|
||||||
|
|
||||||
for j in range(2):
|
|
||||||
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
|
||||||
sd_mid_res_prefix = f"middle_block.{2 * j}."
|
|
||||||
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
|
||||||
|
|
||||||
|
|
||||||
def convert_unet_state_dict(unet_state_dict):
|
|
||||||
# buyer beware: this is a *brittle* function,
|
|
||||||
# and correct output requires that all of these pieces interact in
|
|
||||||
# the exact order in which I have arranged them.
|
|
||||||
mapping = {k: k for k in unet_state_dict.keys()}
|
|
||||||
for sd_name, hf_name in unet_conversion_map:
|
|
||||||
mapping[hf_name] = sd_name
|
|
||||||
for k, v in mapping.items():
|
|
||||||
if "resnets" in k:
|
|
||||||
for sd_part, hf_part in unet_conversion_map_resnet:
|
|
||||||
v = v.replace(hf_part, sd_part)
|
|
||||||
mapping[k] = v
|
|
||||||
for k, v in mapping.items():
|
|
||||||
for sd_part, hf_part in unet_conversion_map_layer:
|
|
||||||
v = v.replace(hf_part, sd_part)
|
|
||||||
mapping[k] = v
|
|
||||||
new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
|
|
||||||
return new_state_dict
|
|
||||||
|
|
||||||
|
|
||||||
# ================#
|
# ================#
|
||||||
# VAE Conversion #
|
# VAE Conversion #
|
||||||
# ================#
|
# ================#
|
||||||
@@ -157,16 +58,23 @@ vae_conversion_map_attn = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def reshape_weight_for_sd(w):
|
def reshape_weight_for_sd(w, conv3d=False):
|
||||||
# convert HF linear weights to SD conv2d weights
|
# convert HF linear weights to SD conv2d weights
|
||||||
return w.reshape(*w.shape, 1, 1)
|
if conv3d:
|
||||||
|
return w.reshape(*w.shape, 1, 1, 1)
|
||||||
|
else:
|
||||||
|
return w.reshape(*w.shape, 1, 1)
|
||||||
|
|
||||||
|
|
||||||
def convert_vae_state_dict(vae_state_dict):
|
def convert_vae_state_dict(vae_state_dict):
|
||||||
mapping = {k: k for k in vae_state_dict.keys()}
|
mapping = {k: k for k in vae_state_dict.keys()}
|
||||||
|
conv3d = False
|
||||||
for k, v in mapping.items():
|
for k, v in mapping.items():
|
||||||
for sd_part, hf_part in vae_conversion_map:
|
for sd_part, hf_part in vae_conversion_map:
|
||||||
v = v.replace(hf_part, sd_part)
|
v = v.replace(hf_part, sd_part)
|
||||||
|
if v.endswith(".conv.weight"):
|
||||||
|
if not conv3d and vae_state_dict[k].ndim == 5:
|
||||||
|
conv3d = True
|
||||||
mapping[k] = v
|
mapping[k] = v
|
||||||
for k, v in mapping.items():
|
for k, v in mapping.items():
|
||||||
if "attentions" in k:
|
if "attentions" in k:
|
||||||
@@ -179,7 +87,7 @@ def convert_vae_state_dict(vae_state_dict):
|
|||||||
for weight_name in weights_to_convert:
|
for weight_name in weights_to_convert:
|
||||||
if f"mid.attn_1.{weight_name}.weight" in k:
|
if f"mid.attn_1.{weight_name}.weight" in k:
|
||||||
logging.debug(f"Reshaping {k} for SD format")
|
logging.debug(f"Reshaping {k} for SD format")
|
||||||
new_state_dict[k] = reshape_weight_for_sd(v)
|
new_state_dict[k] = reshape_weight_for_sd(v, conv3d=conv3d)
|
||||||
return new_state_dict
|
return new_state_dict
|
||||||
|
|
||||||
|
|
||||||
@@ -206,6 +114,7 @@ textenc_pattern = re.compile("|".join(protected.keys()))
|
|||||||
# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
|
# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
|
||||||
code2idx = {"q": 0, "k": 1, "v": 2}
|
code2idx = {"q": 0, "k": 1, "v": 2}
|
||||||
|
|
||||||
|
|
||||||
# This function exists because at the time of writing torch.cat can't do fp8 with cuda
|
# This function exists because at the time of writing torch.cat can't do fp8 with cuda
|
||||||
def cat_tensors(tensors):
|
def cat_tensors(tensors):
|
||||||
x = 0
|
x = 0
|
||||||
@@ -222,6 +131,7 @@ def cat_tensors(tensors):
|
|||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
|
def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
|
||||||
new_state_dict = {}
|
new_state_dict = {}
|
||||||
capture_qkv_weight = {}
|
capture_qkv_weight = {}
|
||||||
@@ -277,5 +187,3 @@ def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
|
|||||||
|
|
||||||
def convert_text_enc_state_dict(text_enc_dict):
|
def convert_text_enc_state_dict(text_enc_dict):
|
||||||
return text_enc_dict
|
return text_enc_dict
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
#code taken from: https://github.com/wl-zhao/UniPC and modified
|
#code taken from: https://github.com/wl-zhao/UniPC and modified
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
|
||||||
import math
|
import math
|
||||||
|
import logging
|
||||||
|
|
||||||
from tqdm.auto import trange, tqdm
|
from tqdm.auto import trange
|
||||||
|
|
||||||
|
|
||||||
class NoiseScheduleVP:
|
class NoiseScheduleVP:
|
||||||
@@ -80,7 +80,7 @@ class NoiseScheduleVP:
|
|||||||
'linear' or 'cosine' for continuous-time DPMs.
|
'linear' or 'cosine' for continuous-time DPMs.
|
||||||
Returns:
|
Returns:
|
||||||
A wrapper object of the forward SDE (VP type).
|
A wrapper object of the forward SDE (VP type).
|
||||||
|
|
||||||
===============================================================
|
===============================================================
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
@@ -208,7 +208,7 @@ def model_wrapper(
|
|||||||
arXiv preprint arXiv:2202.00512 (2022).
|
arXiv preprint arXiv:2202.00512 (2022).
|
||||||
[2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
|
[2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
|
||||||
arXiv preprint arXiv:2210.02303 (2022).
|
arXiv preprint arXiv:2210.02303 (2022).
|
||||||
|
|
||||||
4. "score": marginal score function. (Trained by denoising score matching).
|
4. "score": marginal score function. (Trained by denoising score matching).
|
||||||
Note that the score function and the noise prediction model follows a simple relationship:
|
Note that the score function and the noise prediction model follows a simple relationship:
|
||||||
```
|
```
|
||||||
@@ -226,7 +226,7 @@ def model_wrapper(
|
|||||||
The input `model` has the following format:
|
The input `model` has the following format:
|
||||||
``
|
``
|
||||||
model(x, t_input, **model_kwargs) -> noise | x_start | v | score
|
model(x, t_input, **model_kwargs) -> noise | x_start | v | score
|
||||||
``
|
``
|
||||||
|
|
||||||
The input `classifier_fn` has the following format:
|
The input `classifier_fn` has the following format:
|
||||||
``
|
``
|
||||||
@@ -240,12 +240,12 @@ def model_wrapper(
|
|||||||
The input `model` has the following format:
|
The input `model` has the following format:
|
||||||
``
|
``
|
||||||
model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
|
model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
|
||||||
``
|
``
|
||||||
And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
|
And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
|
||||||
|
|
||||||
[4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
|
[4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
|
||||||
arXiv preprint arXiv:2207.12598 (2022).
|
arXiv preprint arXiv:2207.12598 (2022).
|
||||||
|
|
||||||
|
|
||||||
The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
|
The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
|
||||||
or continuous-time labels (i.e. epsilon to T).
|
or continuous-time labels (i.e. epsilon to T).
|
||||||
@@ -254,7 +254,7 @@ def model_wrapper(
|
|||||||
``
|
``
|
||||||
def model_fn(x, t_continuous) -> noise:
|
def model_fn(x, t_continuous) -> noise:
|
||||||
t_input = get_model_input_time(t_continuous)
|
t_input = get_model_input_time(t_continuous)
|
||||||
return noise_pred(model, x, t_input, **model_kwargs)
|
return noise_pred(model, x, t_input, **model_kwargs)
|
||||||
``
|
``
|
||||||
where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
|
where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
|
||||||
|
|
||||||
@@ -359,7 +359,7 @@ class UniPC:
|
|||||||
max_val=1.,
|
max_val=1.,
|
||||||
variant='bh1',
|
variant='bh1',
|
||||||
):
|
):
|
||||||
"""Construct a UniPC.
|
"""Construct a UniPC.
|
||||||
|
|
||||||
We support both data_prediction and noise_prediction.
|
We support both data_prediction and noise_prediction.
|
||||||
"""
|
"""
|
||||||
@@ -372,7 +372,7 @@ class UniPC:
|
|||||||
|
|
||||||
def dynamic_thresholding_fn(self, x0, t=None):
|
def dynamic_thresholding_fn(self, x0, t=None):
|
||||||
"""
|
"""
|
||||||
The dynamic thresholding method.
|
The dynamic thresholding method.
|
||||||
"""
|
"""
|
||||||
dims = x0.dim()
|
dims = x0.dim()
|
||||||
p = self.dynamic_thresholding_ratio
|
p = self.dynamic_thresholding_ratio
|
||||||
@@ -404,7 +404,7 @@ class UniPC:
|
|||||||
|
|
||||||
def model_fn(self, x, t):
|
def model_fn(self, x, t):
|
||||||
"""
|
"""
|
||||||
Convert the model to the noise prediction model or the data prediction model.
|
Convert the model to the noise prediction model or the data prediction model.
|
||||||
"""
|
"""
|
||||||
if self.predict_x0:
|
if self.predict_x0:
|
||||||
return self.data_prediction_fn(x, t)
|
return self.data_prediction_fn(x, t)
|
||||||
@@ -461,7 +461,7 @@ class UniPC:
|
|||||||
|
|
||||||
def denoise_to_zero_fn(self, x, s):
|
def denoise_to_zero_fn(self, x, s):
|
||||||
"""
|
"""
|
||||||
Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization.
|
Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization.
|
||||||
"""
|
"""
|
||||||
return self.data_prediction_fn(x, s)
|
return self.data_prediction_fn(x, s)
|
||||||
|
|
||||||
@@ -475,7 +475,7 @@ class UniPC:
|
|||||||
return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
|
return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
|
||||||
|
|
||||||
def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True):
|
def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True):
|
||||||
print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
|
logging.info(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
|
||||||
ns = self.noise_schedule
|
ns = self.noise_schedule
|
||||||
assert order <= len(model_prev_list)
|
assert order <= len(model_prev_list)
|
||||||
|
|
||||||
@@ -510,7 +510,7 @@ class UniPC:
|
|||||||
col = torch.ones_like(rks)
|
col = torch.ones_like(rks)
|
||||||
for k in range(1, K + 1):
|
for k in range(1, K + 1):
|
||||||
C.append(col)
|
C.append(col)
|
||||||
col = col * rks / (k + 1)
|
col = col * rks / (k + 1)
|
||||||
C = torch.stack(C, dim=1)
|
C = torch.stack(C, dim=1)
|
||||||
|
|
||||||
if len(D1s) > 0:
|
if len(D1s) > 0:
|
||||||
@@ -519,7 +519,6 @@ class UniPC:
|
|||||||
A_p = C_inv_p
|
A_p = C_inv_p
|
||||||
|
|
||||||
if use_corrector:
|
if use_corrector:
|
||||||
print('using corrector')
|
|
||||||
C_inv = torch.linalg.inv(C)
|
C_inv = torch.linalg.inv(C)
|
||||||
A_c = C_inv
|
A_c = C_inv
|
||||||
|
|
||||||
@@ -622,12 +621,12 @@ class UniPC:
|
|||||||
B_h = torch.expm1(hh)
|
B_h = torch.expm1(hh)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
for i in range(1, order + 1):
|
for i in range(1, order + 1):
|
||||||
R.append(torch.pow(rks, i - 1))
|
R.append(torch.pow(rks, i - 1))
|
||||||
b.append(h_phi_k * factorial_i / B_h)
|
b.append(h_phi_k * factorial_i / B_h)
|
||||||
factorial_i *= (i + 1)
|
factorial_i *= (i + 1)
|
||||||
h_phi_k = h_phi_k / hh - 1 / factorial_i
|
h_phi_k = h_phi_k / hh - 1 / factorial_i
|
||||||
|
|
||||||
R = torch.stack(R)
|
R = torch.stack(R)
|
||||||
b = torch.tensor(b, device=x.device)
|
b = torch.tensor(b, device=x.device)
|
||||||
@@ -662,7 +661,7 @@ class UniPC:
|
|||||||
|
|
||||||
if x_t is None:
|
if x_t is None:
|
||||||
if use_predictor:
|
if use_predictor:
|
||||||
pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
|
pred_res = torch.tensordot(D1s, rhos_p, dims=([1], [0])) # torch.einsum('k,bkchw->bchw', rhos_p, D1s)
|
||||||
else:
|
else:
|
||||||
pred_res = 0
|
pred_res = 0
|
||||||
x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * pred_res
|
x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * pred_res
|
||||||
@@ -670,7 +669,7 @@ class UniPC:
|
|||||||
if use_corrector:
|
if use_corrector:
|
||||||
model_t = self.model_fn(x_t, t)
|
model_t = self.model_fn(x_t, t)
|
||||||
if D1s is not None:
|
if D1s is not None:
|
||||||
corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
|
corr_res = torch.tensordot(D1s, rhos_c[:-1], dims=([1], [0])) # torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
|
||||||
else:
|
else:
|
||||||
corr_res = 0
|
corr_res = 0
|
||||||
D1_t = (model_t - model_prev_0)
|
D1_t = (model_t - model_prev_0)
|
||||||
@@ -704,7 +703,6 @@ class UniPC:
|
|||||||
):
|
):
|
||||||
# t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
|
# t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
|
||||||
# t_T = self.noise_schedule.T if t_start is None else t_start
|
# t_T = self.noise_schedule.T if t_start is None else t_start
|
||||||
device = x.device
|
|
||||||
steps = len(timesteps) - 1
|
steps = len(timesteps) - 1
|
||||||
if method == 'multistep':
|
if method == 'multistep':
|
||||||
assert steps >= order
|
assert steps >= order
|
||||||
@@ -872,4 +870,4 @@ def sample_unipc(model, noise, sigmas, extra_args=None, callback=None, disable=F
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
def sample_unipc_bh2(model, noise, sigmas, extra_args=None, callback=None, disable=False):
|
def sample_unipc_bh2(model, noise, sigmas, extra_args=None, callback=None, disable=False):
|
||||||
return sample_unipc(model, noise, sigmas, extra_args, callback, disable, variant='bh2')
|
return sample_unipc(model, noise, sigmas, extra_args, callback, disable, variant='bh2')
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import math
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from .ldm.modules.attention import CrossAttention
|
from .ldm.modules.attention import CrossAttention
|
||||||
|
|||||||
451
comfy/hooks.py
451
comfy/hooks.py
@@ -5,6 +5,7 @@ import math
|
|||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import itertools
|
import itertools
|
||||||
|
import logging
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from comfy.model_patcher import ModelPatcher, PatcherInjection
|
from comfy.model_patcher import ModelPatcher, PatcherInjection
|
||||||
@@ -15,130 +16,171 @@ import comfy.model_management
|
|||||||
import comfy.patcher_extension
|
import comfy.patcher_extension
|
||||||
from node_helpers import conditioning_set_values
|
from node_helpers import conditioning_set_values
|
||||||
|
|
||||||
|
# #######################################################################################################
|
||||||
|
# Hooks explanation
|
||||||
|
# -------------------
|
||||||
|
# The purpose of hooks is to allow conds to influence sampling without the need for ComfyUI core code to
|
||||||
|
# make explicit special cases like it does for ControlNet and GLIGEN.
|
||||||
|
#
|
||||||
|
# This is necessary for nodes/features that are intended for use with masked or scheduled conds, or those
|
||||||
|
# that should run special code when a 'marked' cond is used in sampling.
|
||||||
|
# #######################################################################################################
|
||||||
|
|
||||||
class EnumHookMode(enum.Enum):
|
class EnumHookMode(enum.Enum):
|
||||||
|
'''
|
||||||
|
Priority of hook memory optimization vs. speed, mostly related to WeightHooks.
|
||||||
|
|
||||||
|
MinVram: No caching will occur for any operations related to hooks.
|
||||||
|
MaxSpeed: Excess VRAM (and RAM, once VRAM is sufficiently depleted) will be used to cache hook weights when switching hook groups.
|
||||||
|
'''
|
||||||
MinVram = "minvram"
|
MinVram = "minvram"
|
||||||
MaxSpeed = "maxspeed"
|
MaxSpeed = "maxspeed"
|
||||||
|
|
||||||
class EnumHookType(enum.Enum):
|
class EnumHookType(enum.Enum):
|
||||||
|
'''
|
||||||
|
Hook types, each of which has different expected behavior.
|
||||||
|
'''
|
||||||
Weight = "weight"
|
Weight = "weight"
|
||||||
Patch = "patch"
|
|
||||||
ObjectPatch = "object_patch"
|
ObjectPatch = "object_patch"
|
||||||
AddModels = "add_models"
|
AdditionalModels = "add_models"
|
||||||
Callbacks = "callbacks"
|
TransformerOptions = "transformer_options"
|
||||||
Wrappers = "wrappers"
|
Injections = "add_injections"
|
||||||
SetInjections = "add_injections"
|
|
||||||
|
|
||||||
class EnumWeightTarget(enum.Enum):
|
class EnumWeightTarget(enum.Enum):
|
||||||
Model = "model"
|
Model = "model"
|
||||||
Clip = "clip"
|
Clip = "clip"
|
||||||
|
|
||||||
|
class EnumHookScope(enum.Enum):
|
||||||
|
'''
|
||||||
|
Determines if hook should be limited in its influence over sampling.
|
||||||
|
|
||||||
|
AllConditioning: hook will affect all conds used in sampling.
|
||||||
|
HookedOnly: hook will only affect the conds it was attached to.
|
||||||
|
'''
|
||||||
|
AllConditioning = "all_conditioning"
|
||||||
|
HookedOnly = "hooked_only"
|
||||||
|
|
||||||
|
|
||||||
class _HookRef:
|
class _HookRef:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# NOTE: this is an example of how the should_register function should look
|
|
||||||
def default_should_register(hook: 'Hook', model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
|
def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||||
|
'''Example for how custom_should_register function can look like.'''
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def create_target_dict(target: EnumWeightTarget=None, **kwargs) -> dict[str]:
|
||||||
|
'''Creates base dictionary for use with Hooks' target param.'''
|
||||||
|
d = {}
|
||||||
|
if target is not None:
|
||||||
|
d['target'] = target
|
||||||
|
d.update(kwargs)
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
class Hook:
|
class Hook:
|
||||||
def __init__(self, hook_type: EnumHookType=None, hook_ref: _HookRef=None, hook_id: str=None,
|
def __init__(self, hook_type: EnumHookType=None, hook_ref: _HookRef=None, hook_id: str=None,
|
||||||
hook_keyframe: 'HookKeyframeGroup'=None):
|
hook_keyframe: HookKeyframeGroup=None, hook_scope=EnumHookScope.AllConditioning):
|
||||||
self.hook_type = hook_type
|
self.hook_type = hook_type
|
||||||
|
'''Enum identifying the general class of this hook.'''
|
||||||
self.hook_ref = hook_ref if hook_ref else _HookRef()
|
self.hook_ref = hook_ref if hook_ref else _HookRef()
|
||||||
|
'''Reference shared between hook clones that have the same value. Should NOT be modified.'''
|
||||||
self.hook_id = hook_id
|
self.hook_id = hook_id
|
||||||
|
'''Optional string ID to identify hook; useful if need to consolidate duplicates at registration time.'''
|
||||||
self.hook_keyframe = hook_keyframe if hook_keyframe else HookKeyframeGroup()
|
self.hook_keyframe = hook_keyframe if hook_keyframe else HookKeyframeGroup()
|
||||||
|
'''Keyframe storage that can be referenced to get strength for current sampling step.'''
|
||||||
|
self.hook_scope = hook_scope
|
||||||
|
'''Scope of where this hook should apply in terms of the conds used in sampling run.'''
|
||||||
self.custom_should_register = default_should_register
|
self.custom_should_register = default_should_register
|
||||||
self.auto_apply_to_nonpositive = False
|
'''Can be overriden with a compatible function to decide if this hook should be registered without the need to override .should_register'''
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def strength(self):
|
def strength(self):
|
||||||
return self.hook_keyframe.strength
|
return self.hook_keyframe.strength
|
||||||
|
|
||||||
def initialize_timesteps(self, model: 'BaseModel'):
|
def initialize_timesteps(self, model: BaseModel):
|
||||||
self.reset()
|
self.reset()
|
||||||
self.hook_keyframe.initialize_timesteps(model)
|
self.hook_keyframe.initialize_timesteps(model)
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
self.hook_keyframe.reset()
|
self.hook_keyframe.reset()
|
||||||
|
|
||||||
def clone(self, subtype: Callable=None):
|
def clone(self):
|
||||||
if subtype is None:
|
c: Hook = self.__class__()
|
||||||
subtype = type(self)
|
|
||||||
c: Hook = subtype()
|
|
||||||
c.hook_type = self.hook_type
|
c.hook_type = self.hook_type
|
||||||
c.hook_ref = self.hook_ref
|
c.hook_ref = self.hook_ref
|
||||||
c.hook_id = self.hook_id
|
c.hook_id = self.hook_id
|
||||||
c.hook_keyframe = self.hook_keyframe
|
c.hook_keyframe = self.hook_keyframe
|
||||||
|
c.hook_scope = self.hook_scope
|
||||||
c.custom_should_register = self.custom_should_register
|
c.custom_should_register = self.custom_should_register
|
||||||
# TODO: make this do something
|
|
||||||
c.auto_apply_to_nonpositive = self.auto_apply_to_nonpositive
|
|
||||||
return c
|
return c
|
||||||
|
|
||||||
def should_register(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
|
def should_register(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||||
return self.custom_should_register(self, model, model_options, target, registered)
|
return self.custom_should_register(self, model, model_options, target_dict, registered)
|
||||||
|
|
||||||
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
|
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||||
raise NotImplementedError("add_hook_patches should be defined for Hook subclasses")
|
raise NotImplementedError("add_hook_patches should be defined for Hook subclasses")
|
||||||
|
|
||||||
def on_apply(self, model: 'ModelPatcher', transformer_options: dict[str]):
|
def __eq__(self, other: Hook):
|
||||||
pass
|
|
||||||
|
|
||||||
def on_unapply(self, model: 'ModelPatcher', transformer_options: dict[str]):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def __eq__(self, other: 'Hook'):
|
|
||||||
return self.__class__ == other.__class__ and self.hook_ref == other.hook_ref
|
return self.__class__ == other.__class__ and self.hook_ref == other.hook_ref
|
||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
return hash(self.hook_ref)
|
return hash(self.hook_ref)
|
||||||
|
|
||||||
class WeightHook(Hook):
|
class WeightHook(Hook):
|
||||||
|
'''
|
||||||
|
Hook responsible for tracking weights to be applied to some model/clip.
|
||||||
|
|
||||||
|
Note, value of hook_scope is ignored and is treated as HookedOnly.
|
||||||
|
'''
|
||||||
def __init__(self, strength_model=1.0, strength_clip=1.0):
|
def __init__(self, strength_model=1.0, strength_clip=1.0):
|
||||||
super().__init__(hook_type=EnumHookType.Weight)
|
super().__init__(hook_type=EnumHookType.Weight, hook_scope=EnumHookScope.HookedOnly)
|
||||||
self.weights: dict = None
|
self.weights: dict = None
|
||||||
self.weights_clip: dict = None
|
self.weights_clip: dict = None
|
||||||
self.need_weight_init = True
|
self.need_weight_init = True
|
||||||
self._strength_model = strength_model
|
self._strength_model = strength_model
|
||||||
self._strength_clip = strength_clip
|
self._strength_clip = strength_clip
|
||||||
|
self.hook_scope = EnumHookScope.HookedOnly # this value does not matter for WeightHooks, just for docs
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def strength_model(self):
|
def strength_model(self):
|
||||||
return self._strength_model * self.strength
|
return self._strength_model * self.strength
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def strength_clip(self):
|
def strength_clip(self):
|
||||||
return self._strength_clip * self.strength
|
return self._strength_clip * self.strength
|
||||||
|
|
||||||
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
|
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||||
if not self.should_register(model, model_options, target, registered):
|
if not self.should_register(model, model_options, target_dict, registered):
|
||||||
return False
|
return False
|
||||||
weights = None
|
weights = None
|
||||||
if target == EnumWeightTarget.Model:
|
|
||||||
strength = self._strength_model
|
target = target_dict.get('target', None)
|
||||||
else:
|
if target == EnumWeightTarget.Clip:
|
||||||
strength = self._strength_clip
|
strength = self._strength_clip
|
||||||
|
else:
|
||||||
|
strength = self._strength_model
|
||||||
|
|
||||||
if self.need_weight_init:
|
if self.need_weight_init:
|
||||||
key_map = {}
|
key_map = {}
|
||||||
if target == EnumWeightTarget.Model:
|
if target == EnumWeightTarget.Clip:
|
||||||
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
|
|
||||||
else:
|
|
||||||
key_map = comfy.lora.model_lora_keys_clip(model.model, key_map)
|
key_map = comfy.lora.model_lora_keys_clip(model.model, key_map)
|
||||||
|
else:
|
||||||
|
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
|
||||||
weights = comfy.lora.load_lora(self.weights, key_map, log_missing=False)
|
weights = comfy.lora.load_lora(self.weights, key_map, log_missing=False)
|
||||||
else:
|
else:
|
||||||
if target == EnumWeightTarget.Model:
|
if target == EnumWeightTarget.Clip:
|
||||||
weights = self.weights
|
|
||||||
else:
|
|
||||||
weights = self.weights_clip
|
weights = self.weights_clip
|
||||||
k = model.add_hook_patches(hook=self, patches=weights, strength_patch=strength)
|
else:
|
||||||
registered.append(self)
|
weights = self.weights
|
||||||
|
model.add_hook_patches(hook=self, patches=weights, strength_patch=strength)
|
||||||
|
registered.add(self)
|
||||||
return True
|
return True
|
||||||
# TODO: add logs about any keys that were not applied
|
# TODO: add logs about any keys that were not applied
|
||||||
|
|
||||||
def clone(self, subtype: Callable=None):
|
def clone(self):
|
||||||
if subtype is None:
|
c: WeightHook = super().clone()
|
||||||
subtype = type(self)
|
|
||||||
c: WeightHook = super().clone(subtype)
|
|
||||||
c.weights = self.weights
|
c.weights = self.weights
|
||||||
c.weights_clip = self.weights_clip
|
c.weights_clip = self.weights_clip
|
||||||
c.need_weight_init = self.need_weight_init
|
c.need_weight_init = self.need_weight_init
|
||||||
@@ -146,127 +188,158 @@ class WeightHook(Hook):
|
|||||||
c._strength_clip = self._strength_clip
|
c._strength_clip = self._strength_clip
|
||||||
return c
|
return c
|
||||||
|
|
||||||
class PatchHook(Hook):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(hook_type=EnumHookType.Patch)
|
|
||||||
self.patches: dict = None
|
|
||||||
|
|
||||||
def clone(self, subtype: Callable=None):
|
|
||||||
if subtype is None:
|
|
||||||
subtype = type(self)
|
|
||||||
c: PatchHook = super().clone(subtype)
|
|
||||||
c.patches = self.patches
|
|
||||||
return c
|
|
||||||
# TODO: add functionality
|
|
||||||
|
|
||||||
class ObjectPatchHook(Hook):
|
class ObjectPatchHook(Hook):
|
||||||
def __init__(self):
|
def __init__(self, object_patches: dict[str]=None,
|
||||||
|
hook_scope=EnumHookScope.AllConditioning):
|
||||||
super().__init__(hook_type=EnumHookType.ObjectPatch)
|
super().__init__(hook_type=EnumHookType.ObjectPatch)
|
||||||
self.object_patches: dict = None
|
self.object_patches = object_patches
|
||||||
|
self.hook_scope = hook_scope
|
||||||
def clone(self, subtype: Callable=None):
|
|
||||||
if subtype is None:
|
def clone(self):
|
||||||
subtype = type(self)
|
c: ObjectPatchHook = super().clone()
|
||||||
c: ObjectPatchHook = super().clone(subtype)
|
|
||||||
c.object_patches = self.object_patches
|
c.object_patches = self.object_patches
|
||||||
return c
|
return c
|
||||||
# TODO: add functionality
|
|
||||||
|
|
||||||
class AddModelsHook(Hook):
|
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||||
def __init__(self, key: str=None, models: list['ModelPatcher']=None):
|
raise NotImplementedError("ObjectPatchHook is not supported yet in ComfyUI.")
|
||||||
super().__init__(hook_type=EnumHookType.AddModels)
|
|
||||||
self.key = key
|
class AdditionalModelsHook(Hook):
|
||||||
|
'''
|
||||||
|
Hook responsible for telling model management any additional models that should be loaded.
|
||||||
|
|
||||||
|
Note, value of hook_scope is ignored and is treated as AllConditioning.
|
||||||
|
'''
|
||||||
|
def __init__(self, models: list[ModelPatcher]=None, key: str=None):
|
||||||
|
super().__init__(hook_type=EnumHookType.AdditionalModels)
|
||||||
self.models = models
|
self.models = models
|
||||||
self.append_when_same = True
|
|
||||||
|
|
||||||
def clone(self, subtype: Callable=None):
|
|
||||||
if subtype is None:
|
|
||||||
subtype = type(self)
|
|
||||||
c: AddModelsHook = super().clone(subtype)
|
|
||||||
c.key = self.key
|
|
||||||
c.models = self.models.copy() if self.models else self.models
|
|
||||||
c.append_when_same = self.append_when_same
|
|
||||||
return c
|
|
||||||
# TODO: add functionality
|
|
||||||
|
|
||||||
class CallbackHook(Hook):
|
|
||||||
def __init__(self, key: str=None, callback: Callable=None):
|
|
||||||
super().__init__(hook_type=EnumHookType.Callbacks)
|
|
||||||
self.key = key
|
self.key = key
|
||||||
self.callback = callback
|
|
||||||
|
|
||||||
def clone(self, subtype: Callable=None):
|
def clone(self):
|
||||||
if subtype is None:
|
c: AdditionalModelsHook = super().clone()
|
||||||
subtype = type(self)
|
c.models = self.models.copy() if self.models else self.models
|
||||||
c: CallbackHook = super().clone(subtype)
|
|
||||||
c.key = self.key
|
c.key = self.key
|
||||||
c.callback = self.callback
|
|
||||||
return c
|
return c
|
||||||
# TODO: add functionality
|
|
||||||
|
|
||||||
class WrapperHook(Hook):
|
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||||
def __init__(self, wrappers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None):
|
if not self.should_register(model, model_options, target_dict, registered):
|
||||||
super().__init__(hook_type=EnumHookType.Wrappers)
|
|
||||||
self.wrappers_dict = wrappers_dict
|
|
||||||
|
|
||||||
def clone(self, subtype: Callable=None):
|
|
||||||
if subtype is None:
|
|
||||||
subtype = type(self)
|
|
||||||
c: WrapperHook = super().clone(subtype)
|
|
||||||
c.wrappers_dict = self.wrappers_dict
|
|
||||||
return c
|
|
||||||
|
|
||||||
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
|
|
||||||
if not self.should_register(model, model_options, target, registered):
|
|
||||||
return False
|
return False
|
||||||
add_model_options = {"transformer_options": self.wrappers_dict}
|
registered.add(self)
|
||||||
comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
|
|
||||||
registered.append(self)
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
class SetInjectionsHook(Hook):
|
class TransformerOptionsHook(Hook):
|
||||||
def __init__(self, key: str=None, injections: list['PatcherInjection']=None):
|
'''
|
||||||
super().__init__(hook_type=EnumHookType.SetInjections)
|
Hook responsible for adding wrappers, callbacks, patches, or anything else related to transformer_options.
|
||||||
|
'''
|
||||||
|
def __init__(self, transformers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None,
|
||||||
|
hook_scope=EnumHookScope.AllConditioning):
|
||||||
|
super().__init__(hook_type=EnumHookType.TransformerOptions)
|
||||||
|
self.transformers_dict = transformers_dict
|
||||||
|
self.hook_scope = hook_scope
|
||||||
|
self._skip_adding = False
|
||||||
|
'''Internal value used to avoid double load of transformer_options when hook_scope is AllConditioning.'''
|
||||||
|
|
||||||
|
def clone(self):
|
||||||
|
c: TransformerOptionsHook = super().clone()
|
||||||
|
c.transformers_dict = self.transformers_dict
|
||||||
|
c._skip_adding = self._skip_adding
|
||||||
|
return c
|
||||||
|
|
||||||
|
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||||
|
if not self.should_register(model, model_options, target_dict, registered):
|
||||||
|
return False
|
||||||
|
# NOTE: to_load_options will be used to manually load patches/wrappers/callbacks from hooks
|
||||||
|
self._skip_adding = False
|
||||||
|
if self.hook_scope == EnumHookScope.AllConditioning:
|
||||||
|
add_model_options = {"transformer_options": self.transformers_dict,
|
||||||
|
"to_load_options": self.transformers_dict}
|
||||||
|
# skip_adding if included in AllConditioning to avoid double loading
|
||||||
|
self._skip_adding = True
|
||||||
|
else:
|
||||||
|
add_model_options = {"to_load_options": self.transformers_dict}
|
||||||
|
registered.add(self)
|
||||||
|
comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]):
|
||||||
|
if not self._skip_adding:
|
||||||
|
comfy.patcher_extension.merge_nested_dicts(transformer_options, self.transformers_dict, copy_dict1=False)
|
||||||
|
|
||||||
|
WrapperHook = TransformerOptionsHook
|
||||||
|
'''Only here for backwards compatibility, WrapperHook is identical to TransformerOptionsHook.'''
|
||||||
|
|
||||||
|
class InjectionsHook(Hook):
|
||||||
|
def __init__(self, key: str=None, injections: list[PatcherInjection]=None,
|
||||||
|
hook_scope=EnumHookScope.AllConditioning):
|
||||||
|
super().__init__(hook_type=EnumHookType.Injections)
|
||||||
self.key = key
|
self.key = key
|
||||||
self.injections = injections
|
self.injections = injections
|
||||||
|
self.hook_scope = hook_scope
|
||||||
def clone(self, subtype: Callable=None):
|
|
||||||
if subtype is None:
|
def clone(self):
|
||||||
subtype = type(self)
|
c: InjectionsHook = super().clone()
|
||||||
c: SetInjectionsHook = super().clone(subtype)
|
|
||||||
c.key = self.key
|
c.key = self.key
|
||||||
c.injections = self.injections.copy() if self.injections else self.injections
|
c.injections = self.injections.copy() if self.injections else self.injections
|
||||||
return c
|
return c
|
||||||
|
|
||||||
def add_hook_injections(self, model: 'ModelPatcher'):
|
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||||
# TODO: add functionality
|
raise NotImplementedError("InjectionsHook is not supported yet in ComfyUI.")
|
||||||
pass
|
|
||||||
|
|
||||||
class HookGroup:
|
class HookGroup:
|
||||||
|
'''
|
||||||
|
Stores groups of hooks, and allows them to be queried by type.
|
||||||
|
|
||||||
|
To prevent breaking their functionality, never modify the underlying self.hooks or self._hook_dict vars directly;
|
||||||
|
always use the provided functions on HookGroup.
|
||||||
|
'''
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.hooks: list[Hook] = []
|
self.hooks: list[Hook] = []
|
||||||
|
self._hook_dict: dict[EnumHookType, list[Hook]] = {}
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.hooks)
|
||||||
|
|
||||||
def add(self, hook: Hook):
|
def add(self, hook: Hook):
|
||||||
if hook not in self.hooks:
|
if hook not in self.hooks:
|
||||||
self.hooks.append(hook)
|
self.hooks.append(hook)
|
||||||
|
self._hook_dict.setdefault(hook.hook_type, []).append(hook)
|
||||||
|
|
||||||
|
def remove(self, hook: Hook):
|
||||||
|
if hook in self.hooks:
|
||||||
|
self.hooks.remove(hook)
|
||||||
|
self._hook_dict[hook.hook_type].remove(hook)
|
||||||
|
|
||||||
|
def get_type(self, hook_type: EnumHookType):
|
||||||
|
return self._hook_dict.get(hook_type, [])
|
||||||
|
|
||||||
def contains(self, hook: Hook):
|
def contains(self, hook: Hook):
|
||||||
return hook in self.hooks
|
return hook in self.hooks
|
||||||
|
|
||||||
|
def is_subset_of(self, other: HookGroup):
|
||||||
|
self_hooks = set(self.hooks)
|
||||||
|
other_hooks = set(other.hooks)
|
||||||
|
return self_hooks.issubset(other_hooks)
|
||||||
|
|
||||||
|
def new_with_common_hooks(self, other: HookGroup):
|
||||||
|
c = HookGroup()
|
||||||
|
for hook in self.hooks:
|
||||||
|
if other.contains(hook):
|
||||||
|
c.add(hook.clone())
|
||||||
|
return c
|
||||||
|
|
||||||
def clone(self):
|
def clone(self):
|
||||||
c = HookGroup()
|
c = HookGroup()
|
||||||
for hook in self.hooks:
|
for hook in self.hooks:
|
||||||
c.add(hook.clone())
|
c.add(hook.clone())
|
||||||
return c
|
return c
|
||||||
|
|
||||||
def clone_and_combine(self, other: 'HookGroup'):
|
def clone_and_combine(self, other: HookGroup):
|
||||||
c = self.clone()
|
c = self.clone()
|
||||||
if other is not None:
|
if other is not None:
|
||||||
for hook in other.hooks:
|
for hook in other.hooks:
|
||||||
c.add(hook.clone())
|
c.add(hook.clone())
|
||||||
return c
|
return c
|
||||||
|
|
||||||
def set_keyframes_on_hooks(self, hook_kf: 'HookKeyframeGroup'):
|
def set_keyframes_on_hooks(self, hook_kf: HookKeyframeGroup):
|
||||||
if hook_kf is None:
|
if hook_kf is None:
|
||||||
hook_kf = HookKeyframeGroup()
|
hook_kf = HookKeyframeGroup()
|
||||||
else:
|
else:
|
||||||
@@ -274,36 +347,29 @@ class HookGroup:
|
|||||||
for hook in self.hooks:
|
for hook in self.hooks:
|
||||||
hook.hook_keyframe = hook_kf
|
hook.hook_keyframe = hook_kf
|
||||||
|
|
||||||
def get_dict_repr(self):
|
|
||||||
d: dict[EnumHookType, dict[Hook, None]] = {}
|
|
||||||
for hook in self.hooks:
|
|
||||||
with_type = d.setdefault(hook.hook_type, {})
|
|
||||||
with_type[hook] = None
|
|
||||||
return d
|
|
||||||
|
|
||||||
def get_hooks_for_clip_schedule(self):
|
def get_hooks_for_clip_schedule(self):
|
||||||
scheduled_hooks: dict[WeightHook, list[tuple[tuple[float,float], HookKeyframe]]] = {}
|
scheduled_hooks: dict[WeightHook, list[tuple[tuple[float,float], HookKeyframe]]] = {}
|
||||||
for hook in self.hooks:
|
# only care about WeightHooks, for now
|
||||||
# only care about WeightHooks, for now
|
for hook in self.get_type(EnumHookType.Weight):
|
||||||
if hook.hook_type == EnumHookType.Weight:
|
hook: WeightHook
|
||||||
hook_schedule = []
|
hook_schedule = []
|
||||||
# if no hook keyframes, assign default value
|
# if no hook keyframes, assign default value
|
||||||
if len(hook.hook_keyframe.keyframes) == 0:
|
if len(hook.hook_keyframe.keyframes) == 0:
|
||||||
hook_schedule.append(((0.0, 1.0), None))
|
hook_schedule.append(((0.0, 1.0), None))
|
||||||
scheduled_hooks[hook] = hook_schedule
|
|
||||||
continue
|
|
||||||
# find ranges of values
|
|
||||||
prev_keyframe = hook.hook_keyframe.keyframes[0]
|
|
||||||
for keyframe in hook.hook_keyframe.keyframes:
|
|
||||||
if keyframe.start_percent > prev_keyframe.start_percent and not math.isclose(keyframe.strength, prev_keyframe.strength):
|
|
||||||
hook_schedule.append(((prev_keyframe.start_percent, keyframe.start_percent), prev_keyframe))
|
|
||||||
prev_keyframe = keyframe
|
|
||||||
elif keyframe.start_percent == prev_keyframe.start_percent:
|
|
||||||
prev_keyframe = keyframe
|
|
||||||
# create final range, assuming last start_percent was not 1.0
|
|
||||||
if not math.isclose(prev_keyframe.start_percent, 1.0):
|
|
||||||
hook_schedule.append(((prev_keyframe.start_percent, 1.0), prev_keyframe))
|
|
||||||
scheduled_hooks[hook] = hook_schedule
|
scheduled_hooks[hook] = hook_schedule
|
||||||
|
continue
|
||||||
|
# find ranges of values
|
||||||
|
prev_keyframe = hook.hook_keyframe.keyframes[0]
|
||||||
|
for keyframe in hook.hook_keyframe.keyframes:
|
||||||
|
if keyframe.start_percent > prev_keyframe.start_percent and not math.isclose(keyframe.strength, prev_keyframe.strength):
|
||||||
|
hook_schedule.append(((prev_keyframe.start_percent, keyframe.start_percent), prev_keyframe))
|
||||||
|
prev_keyframe = keyframe
|
||||||
|
elif keyframe.start_percent == prev_keyframe.start_percent:
|
||||||
|
prev_keyframe = keyframe
|
||||||
|
# create final range, assuming last start_percent was not 1.0
|
||||||
|
if not math.isclose(prev_keyframe.start_percent, 1.0):
|
||||||
|
hook_schedule.append(((prev_keyframe.start_percent, 1.0), prev_keyframe))
|
||||||
|
scheduled_hooks[hook] = hook_schedule
|
||||||
# hooks should not have their schedules in a list of tuples
|
# hooks should not have their schedules in a list of tuples
|
||||||
all_ranges: list[tuple[float, float]] = []
|
all_ranges: list[tuple[float, float]] = []
|
||||||
for range_kfs in scheduled_hooks.values():
|
for range_kfs in scheduled_hooks.values():
|
||||||
@@ -335,7 +401,7 @@ class HookGroup:
|
|||||||
hook.reset()
|
hook.reset()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def combine_all_hooks(hooks_list: list['HookGroup'], require_count=0) -> 'HookGroup':
|
def combine_all_hooks(hooks_list: list[HookGroup], require_count=0) -> HookGroup:
|
||||||
actual: list[HookGroup] = []
|
actual: list[HookGroup] = []
|
||||||
for group in hooks_list:
|
for group in hooks_list:
|
||||||
if group is not None:
|
if group is not None:
|
||||||
@@ -364,10 +430,16 @@ class HookKeyframe:
|
|||||||
self.start_percent = float(start_percent)
|
self.start_percent = float(start_percent)
|
||||||
self.start_t = 999999999.9
|
self.start_t = 999999999.9
|
||||||
self.guarantee_steps = guarantee_steps
|
self.guarantee_steps = guarantee_steps
|
||||||
|
|
||||||
|
def get_effective_guarantee_steps(self, max_sigma: torch.Tensor):
|
||||||
|
'''If keyframe starts before current sampling range (max_sigma), treat as 0.'''
|
||||||
|
if self.start_t > max_sigma:
|
||||||
|
return 0
|
||||||
|
return self.guarantee_steps
|
||||||
|
|
||||||
def clone(self):
|
def clone(self):
|
||||||
c = HookKeyframe(strength=self.strength,
|
c = HookKeyframe(strength=self.strength,
|
||||||
start_percent=self.start_percent, guarantee_steps=self.guarantee_steps)
|
start_percent=self.start_percent, guarantee_steps=self.guarantee_steps)
|
||||||
c.start_t = self.start_t
|
c.start_t = self.start_t
|
||||||
return c
|
return c
|
||||||
|
|
||||||
@@ -394,7 +466,7 @@ class HookKeyframeGroup:
|
|||||||
self._current_strength = None
|
self._current_strength = None
|
||||||
self.curr_t = -1.
|
self.curr_t = -1.
|
||||||
self._set_first_as_current()
|
self._set_first_as_current()
|
||||||
|
|
||||||
def add(self, keyframe: HookKeyframe):
|
def add(self, keyframe: HookKeyframe):
|
||||||
# add to end of list, then sort
|
# add to end of list, then sort
|
||||||
self.keyframes.append(keyframe)
|
self.keyframes.append(keyframe)
|
||||||
@@ -406,33 +478,40 @@ class HookKeyframeGroup:
|
|||||||
self._current_keyframe = self.keyframes[0]
|
self._current_keyframe = self.keyframes[0]
|
||||||
else:
|
else:
|
||||||
self._current_keyframe = None
|
self._current_keyframe = None
|
||||||
|
|
||||||
|
def has_guarantee_steps(self):
|
||||||
|
for kf in self.keyframes:
|
||||||
|
if kf.guarantee_steps > 0:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
def has_index(self, index: int):
|
def has_index(self, index: int):
|
||||||
return index >= 0 and index < len(self.keyframes)
|
return index >= 0 and index < len(self.keyframes)
|
||||||
|
|
||||||
def is_empty(self):
|
def is_empty(self):
|
||||||
return len(self.keyframes) == 0
|
return len(self.keyframes) == 0
|
||||||
|
|
||||||
def clone(self):
|
def clone(self):
|
||||||
c = HookKeyframeGroup()
|
c = HookKeyframeGroup()
|
||||||
for keyframe in self.keyframes:
|
for keyframe in self.keyframes:
|
||||||
c.keyframes.append(keyframe.clone())
|
c.keyframes.append(keyframe.clone())
|
||||||
c._set_first_as_current()
|
c._set_first_as_current()
|
||||||
return c
|
return c
|
||||||
|
|
||||||
def initialize_timesteps(self, model: 'BaseModel'):
|
def initialize_timesteps(self, model: BaseModel):
|
||||||
for keyframe in self.keyframes:
|
for keyframe in self.keyframes:
|
||||||
keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent)
|
keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent)
|
||||||
|
|
||||||
def prepare_current_keyframe(self, curr_t: float) -> bool:
|
def prepare_current_keyframe(self, curr_t: float, transformer_options: dict[str, torch.Tensor]) -> bool:
|
||||||
if self.is_empty():
|
if self.is_empty():
|
||||||
return False
|
return False
|
||||||
if curr_t == self._curr_t:
|
if curr_t == self._curr_t:
|
||||||
return False
|
return False
|
||||||
|
max_sigma = torch.max(transformer_options["sample_sigmas"])
|
||||||
prev_index = self._current_index
|
prev_index = self._current_index
|
||||||
prev_strength = self._current_strength
|
prev_strength = self._current_strength
|
||||||
# if met guaranteed steps, look for next keyframe in case need to switch
|
# if met guaranteed steps, look for next keyframe in case need to switch
|
||||||
if self._current_used_steps >= self._current_keyframe.guarantee_steps:
|
if self._current_used_steps >= self._current_keyframe.get_effective_guarantee_steps(max_sigma):
|
||||||
# if has next index, loop through and see if need to switch
|
# if has next index, loop through and see if need to switch
|
||||||
if self.has_index(self._current_index+1):
|
if self.has_index(self._current_index+1):
|
||||||
for i in range(self._current_index+1, len(self.keyframes)):
|
for i in range(self._current_index+1, len(self.keyframes)):
|
||||||
@@ -445,7 +524,7 @@ class HookKeyframeGroup:
|
|||||||
self._current_keyframe = eval_c
|
self._current_keyframe = eval_c
|
||||||
self._current_used_steps = 0
|
self._current_used_steps = 0
|
||||||
# if guarantee_steps greater than zero, stop searching for other keyframes
|
# if guarantee_steps greater than zero, stop searching for other keyframes
|
||||||
if self._current_keyframe.guarantee_steps > 0:
|
if self._current_keyframe.get_effective_guarantee_steps(max_sigma) > 0:
|
||||||
break
|
break
|
||||||
# if eval_c is outside the percent range, stop looking further
|
# if eval_c is outside the percent range, stop looking further
|
||||||
else: break
|
else: break
|
||||||
@@ -508,6 +587,17 @@ def get_sorted_list_via_attr(objects: list, attr: str) -> list:
|
|||||||
sorted_list.extend(object_list)
|
sorted_list.extend(object_list)
|
||||||
return sorted_list
|
return sorted_list
|
||||||
|
|
||||||
|
def create_transformer_options_from_hooks(model: ModelPatcher, hooks: HookGroup, transformer_options: dict[str]=None):
|
||||||
|
# if no hooks or is not a ModelPatcher for sampling, return empty dict
|
||||||
|
if hooks is None or model.is_clip:
|
||||||
|
return {}
|
||||||
|
if transformer_options is None:
|
||||||
|
transformer_options = {}
|
||||||
|
for hook in hooks.get_type(EnumHookType.TransformerOptions):
|
||||||
|
hook: TransformerOptionsHook
|
||||||
|
hook.on_apply_hooks(model, transformer_options)
|
||||||
|
return transformer_options
|
||||||
|
|
||||||
def create_hook_lora(lora: dict[str, torch.Tensor], strength_model: float, strength_clip: float):
|
def create_hook_lora(lora: dict[str, torch.Tensor], strength_model: float, strength_clip: float):
|
||||||
hook_group = HookGroup()
|
hook_group = HookGroup()
|
||||||
hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip)
|
hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip)
|
||||||
@@ -534,7 +624,7 @@ def create_hook_model_as_lora(weights_model, weights_clip, strength_model: float
|
|||||||
hook.need_weight_init = False
|
hook.need_weight_init = False
|
||||||
return hook_group
|
return hook_group
|
||||||
|
|
||||||
def get_patch_weights_from_model(model: 'ModelPatcher', discard_model_sampling=True):
|
def get_patch_weights_from_model(model: ModelPatcher, discard_model_sampling=True):
|
||||||
if model is None:
|
if model is None:
|
||||||
return None
|
return None
|
||||||
patches_model: dict[str, torch.Tensor] = model.model.state_dict()
|
patches_model: dict[str, torch.Tensor] = model.model.state_dict()
|
||||||
@@ -546,7 +636,7 @@ def get_patch_weights_from_model(model: 'ModelPatcher', discard_model_sampling=T
|
|||||||
return patches_model
|
return patches_model
|
||||||
|
|
||||||
# NOTE: this function shows how to register weight hooks directly on the ModelPatchers
|
# NOTE: this function shows how to register weight hooks directly on the ModelPatchers
|
||||||
def load_hook_lora_for_models(model: 'ModelPatcher', clip: 'CLIP', lora: dict[str, torch.Tensor],
|
def load_hook_lora_for_models(model: ModelPatcher, clip: CLIP, lora: dict[str, torch.Tensor],
|
||||||
strength_model: float, strength_clip: float):
|
strength_model: float, strength_clip: float):
|
||||||
key_map = {}
|
key_map = {}
|
||||||
if model is not None:
|
if model is not None:
|
||||||
@@ -564,7 +654,7 @@ def load_hook_lora_for_models(model: 'ModelPatcher', clip: 'CLIP', lora: dict[st
|
|||||||
else:
|
else:
|
||||||
k = ()
|
k = ()
|
||||||
new_modelpatcher = None
|
new_modelpatcher = None
|
||||||
|
|
||||||
if clip is not None:
|
if clip is not None:
|
||||||
new_clip = clip.clone()
|
new_clip = clip.clone()
|
||||||
k1 = new_clip.patcher.add_hook_patches(hook=hook, patches=loaded, strength_patch=strength_clip)
|
k1 = new_clip.patcher.add_hook_patches(hook=hook, patches=loaded, strength_patch=strength_clip)
|
||||||
@@ -575,7 +665,7 @@ def load_hook_lora_for_models(model: 'ModelPatcher', clip: 'CLIP', lora: dict[st
|
|||||||
k1 = set(k1)
|
k1 = set(k1)
|
||||||
for x in loaded:
|
for x in loaded:
|
||||||
if (x not in k) and (x not in k1):
|
if (x not in k) and (x not in k1):
|
||||||
print(f"NOT LOADED {x}")
|
logging.warning(f"NOT LOADED {x}")
|
||||||
return (new_modelpatcher, new_clip, hook_group)
|
return (new_modelpatcher, new_clip, hook_group)
|
||||||
|
|
||||||
def _combine_hooks_from_values(c_dict: dict[str, HookGroup], values: dict[str, HookGroup], cache: dict[tuple[HookGroup, HookGroup], HookGroup]):
|
def _combine_hooks_from_values(c_dict: dict[str, HookGroup], values: dict[str, HookGroup], cache: dict[tuple[HookGroup, HookGroup], HookGroup]):
|
||||||
@@ -598,24 +688,26 @@ def _combine_hooks_from_values(c_dict: dict[str, HookGroup], values: dict[str, H
|
|||||||
else:
|
else:
|
||||||
c_dict[hooks_key] = cache[hooks_tuple]
|
c_dict[hooks_key] = cache[hooks_tuple]
|
||||||
|
|
||||||
def conditioning_set_values_with_hooks(conditioning, values={}, append_hooks=True):
|
def conditioning_set_values_with_hooks(conditioning, values={}, append_hooks=True,
|
||||||
|
cache: dict[tuple[HookGroup, HookGroup], HookGroup]=None):
|
||||||
c = []
|
c = []
|
||||||
hooks_combine_cache: dict[tuple[HookGroup, HookGroup], HookGroup] = {}
|
if cache is None:
|
||||||
|
cache = {}
|
||||||
for t in conditioning:
|
for t in conditioning:
|
||||||
n = [t[0], t[1].copy()]
|
n = [t[0], t[1].copy()]
|
||||||
for k in values:
|
for k in values:
|
||||||
if append_hooks and k == 'hooks':
|
if append_hooks and k == 'hooks':
|
||||||
_combine_hooks_from_values(n[1], values, hooks_combine_cache)
|
_combine_hooks_from_values(n[1], values, cache)
|
||||||
else:
|
else:
|
||||||
n[1][k] = values[k]
|
n[1][k] = values[k]
|
||||||
c.append(n)
|
c.append(n)
|
||||||
|
|
||||||
return c
|
return c
|
||||||
|
|
||||||
def set_hooks_for_conditioning(cond, hooks: HookGroup, append_hooks=True):
|
def set_hooks_for_conditioning(cond, hooks: HookGroup, append_hooks=True, cache: dict[tuple[HookGroup, HookGroup], HookGroup]=None):
|
||||||
if hooks is None:
|
if hooks is None:
|
||||||
return cond
|
return cond
|
||||||
return conditioning_set_values_with_hooks(cond, {'hooks': hooks}, append_hooks=append_hooks)
|
return conditioning_set_values_with_hooks(cond, {'hooks': hooks}, append_hooks=append_hooks, cache=cache)
|
||||||
|
|
||||||
def set_timesteps_for_conditioning(cond, timestep_range: tuple[float,float]):
|
def set_timesteps_for_conditioning(cond, timestep_range: tuple[float,float]):
|
||||||
if timestep_range is None:
|
if timestep_range is None:
|
||||||
@@ -650,9 +742,10 @@ def combine_with_new_conds(conds: list, new_conds: list):
|
|||||||
def set_conds_props(conds: list, strength: float, set_cond_area: str,
|
def set_conds_props(conds: list, strength: float, set_cond_area: str,
|
||||||
mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
|
mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
|
||||||
final_conds = []
|
final_conds = []
|
||||||
|
cache = {}
|
||||||
for c in conds:
|
for c in conds:
|
||||||
# first, apply lora_hook to conditioning, if provided
|
# first, apply lora_hook to conditioning, if provided
|
||||||
c = set_hooks_for_conditioning(c, hooks, append_hooks=append_hooks)
|
c = set_hooks_for_conditioning(c, hooks, append_hooks=append_hooks, cache=cache)
|
||||||
# next, apply mask to conditioning
|
# next, apply mask to conditioning
|
||||||
c = set_mask_for_conditioning(cond=c, mask=mask, strength=strength, set_cond_area=set_cond_area)
|
c = set_mask_for_conditioning(cond=c, mask=mask, strength=strength, set_cond_area=set_cond_area)
|
||||||
# apply timesteps, if present
|
# apply timesteps, if present
|
||||||
@@ -664,9 +757,10 @@ def set_conds_props(conds: list, strength: float, set_cond_area: str,
|
|||||||
def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1.0, set_cond_area: str="default",
|
def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1.0, set_cond_area: str="default",
|
||||||
mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
|
mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
|
||||||
combined_conds = []
|
combined_conds = []
|
||||||
|
cache = {}
|
||||||
for c, masked_c in zip(conds, new_conds):
|
for c, masked_c in zip(conds, new_conds):
|
||||||
# first, apply lora_hook to new conditioning, if provided
|
# first, apply lora_hook to new conditioning, if provided
|
||||||
masked_c = set_hooks_for_conditioning(masked_c, hooks, append_hooks=append_hooks)
|
masked_c = set_hooks_for_conditioning(masked_c, hooks, append_hooks=append_hooks, cache=cache)
|
||||||
# next, apply mask to new conditioning, if provided
|
# next, apply mask to new conditioning, if provided
|
||||||
masked_c = set_mask_for_conditioning(cond=masked_c, mask=mask, set_cond_area=set_cond_area, strength=strength)
|
masked_c = set_mask_for_conditioning(cond=masked_c, mask=mask, set_cond_area=set_cond_area, strength=strength)
|
||||||
# apply timesteps, if present
|
# apply timesteps, if present
|
||||||
@@ -678,9 +772,10 @@ def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1.
|
|||||||
def set_default_conds_and_combine(conds: list, new_conds: list,
|
def set_default_conds_and_combine(conds: list, new_conds: list,
|
||||||
hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
|
hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
|
||||||
combined_conds = []
|
combined_conds = []
|
||||||
|
cache = {}
|
||||||
for c, new_c in zip(conds, new_conds):
|
for c, new_c in zip(conds, new_conds):
|
||||||
# first, apply lora_hook to new conditioning, if provided
|
# first, apply lora_hook to new conditioning, if provided
|
||||||
new_c = set_hooks_for_conditioning(new_c, hooks, append_hooks=append_hooks)
|
new_c = set_hooks_for_conditioning(new_c, hooks, append_hooks=append_hooks, cache=cache)
|
||||||
# next, add default_cond key to cond so that during sampling, it can be identified
|
# next, add default_cond key to cond so that during sampling, it can be identified
|
||||||
new_c = conditioning_set_values(new_c, {'default': True})
|
new_c = conditioning_set_values(new_c, {'default': True})
|
||||||
# apply timesteps, if present
|
# apply timesteps, if present
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ import numpy as np
|
|||||||
# Transfer from the input time (sigma) used in EDM to that (t) used in DEIS.
|
# Transfer from the input time (sigma) used in EDM to that (t) used in DEIS.
|
||||||
|
|
||||||
def edm2t(edm_steps, epsilon_s=1e-3, sigma_min=0.002, sigma_max=80):
|
def edm2t(edm_steps, epsilon_s=1e-3, sigma_min=0.002, sigma_max=80):
|
||||||
vp_sigma = lambda beta_d, beta_min: lambda t: (np.e ** (0.5 * beta_d * (t ** 2) + beta_min * t) - 1) ** 0.5
|
|
||||||
vp_sigma_inv = lambda beta_d, beta_min: lambda sigma: ((beta_min ** 2 + 2 * beta_d * (sigma ** 2 + 1).log()).sqrt() - beta_min) / beta_d
|
vp_sigma_inv = lambda beta_d, beta_min: lambda sigma: ((beta_min ** 2 + 2 * beta_d * (sigma ** 2 + 1).log()).sqrt() - beta_min) / beta_d
|
||||||
vp_beta_d = 2 * (np.log(torch.tensor(sigma_min).cpu() ** 2 + 1) / epsilon_s - np.log(torch.tensor(sigma_max).cpu() ** 2 + 1)) / (epsilon_s - 1)
|
vp_beta_d = 2 * (np.log(torch.tensor(sigma_min).cpu() ** 2 + 1) / epsilon_s - np.log(torch.tensor(sigma_max).cpu() ** 2 + 1)) / (epsilon_s - 1)
|
||||||
vp_beta_min = np.log(torch.tensor(sigma_max).cpu() ** 2 + 1) - 0.5 * vp_beta_d
|
vp_beta_min = np.log(torch.tensor(sigma_max).cpu() ** 2 + 1) - 0.5 * vp_beta_d
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ def get_sigmas_polyexponential(n, sigma_min, sigma_max, rho=1., device='cpu'):
|
|||||||
def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
|
def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
|
||||||
"""Constructs a continuous VP noise schedule."""
|
"""Constructs a continuous VP noise schedule."""
|
||||||
t = torch.linspace(1, eps_s, n, device=device)
|
t = torch.linspace(1, eps_s, n, device=device)
|
||||||
sigmas = torch.sqrt(torch.exp(beta_d * t ** 2 / 2 + beta_min * t) - 1)
|
sigmas = torch.sqrt(torch.special.expm1(beta_d * t ** 2 / 2 + beta_min * t))
|
||||||
return append_zero(sigmas)
|
return append_zero(sigmas)
|
||||||
|
|
||||||
|
|
||||||
@@ -70,8 +70,14 @@ def get_ancestral_step(sigma_from, sigma_to, eta=1.):
|
|||||||
return sigma_down, sigma_up
|
return sigma_down, sigma_up
|
||||||
|
|
||||||
|
|
||||||
def default_noise_sampler(x):
|
def default_noise_sampler(x, seed=None):
|
||||||
return lambda sigma, sigma_next: torch.randn_like(x)
|
if seed is not None:
|
||||||
|
generator = torch.Generator(device=x.device)
|
||||||
|
generator.manual_seed(seed)
|
||||||
|
else:
|
||||||
|
generator = None
|
||||||
|
|
||||||
|
return lambda sigma, sigma_next: torch.randn(x.size(), dtype=x.dtype, layout=x.layout, device=x.device, generator=generator)
|
||||||
|
|
||||||
|
|
||||||
class BatchedBrownianTree:
|
class BatchedBrownianTree:
|
||||||
@@ -168,7 +174,8 @@ def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
|
|||||||
return sample_euler_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
|
return sample_euler_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
|
||||||
"""Ancestral sampling with Euler method steps."""
|
"""Ancestral sampling with Euler method steps."""
|
||||||
extra_args = {} if extra_args is None else extra_args
|
extra_args = {} if extra_args is None else extra_args
|
||||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
seed = extra_args.get("seed", None)
|
||||||
|
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||||
s_in = x.new_ones([x.shape[0]])
|
s_in = x.new_ones([x.shape[0]])
|
||||||
for i in trange(len(sigmas) - 1, disable=disable):
|
for i in trange(len(sigmas) - 1, disable=disable):
|
||||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||||
@@ -189,7 +196,8 @@ def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
|
|||||||
def sample_euler_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1., noise_sampler=None):
|
def sample_euler_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1., noise_sampler=None):
|
||||||
"""Ancestral sampling with Euler method steps."""
|
"""Ancestral sampling with Euler method steps."""
|
||||||
extra_args = {} if extra_args is None else extra_args
|
extra_args = {} if extra_args is None else extra_args
|
||||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
seed = extra_args.get("seed", None)
|
||||||
|
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||||
s_in = x.new_ones([x.shape[0]])
|
s_in = x.new_ones([x.shape[0]])
|
||||||
for i in trange(len(sigmas) - 1, disable=disable):
|
for i in trange(len(sigmas) - 1, disable=disable):
|
||||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||||
@@ -290,7 +298,8 @@ def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
|
|||||||
|
|
||||||
"""Ancestral sampling with DPM-Solver second-order steps."""
|
"""Ancestral sampling with DPM-Solver second-order steps."""
|
||||||
extra_args = {} if extra_args is None else extra_args
|
extra_args = {} if extra_args is None else extra_args
|
||||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
seed = extra_args.get("seed", None)
|
||||||
|
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||||
s_in = x.new_ones([x.shape[0]])
|
s_in = x.new_ones([x.shape[0]])
|
||||||
for i in trange(len(sigmas) - 1, disable=disable):
|
for i in trange(len(sigmas) - 1, disable=disable):
|
||||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||||
@@ -318,7 +327,8 @@ def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
|
|||||||
def sample_dpm_2_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
def sample_dpm_2_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||||
"""Ancestral sampling with DPM-Solver second-order steps."""
|
"""Ancestral sampling with DPM-Solver second-order steps."""
|
||||||
extra_args = {} if extra_args is None else extra_args
|
extra_args = {} if extra_args is None else extra_args
|
||||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
seed = extra_args.get("seed", None)
|
||||||
|
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||||
s_in = x.new_ones([x.shape[0]])
|
s_in = x.new_ones([x.shape[0]])
|
||||||
for i in trange(len(sigmas) - 1, disable=disable):
|
for i in trange(len(sigmas) - 1, disable=disable):
|
||||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||||
@@ -465,7 +475,7 @@ class DPMSolver(nn.Module):
|
|||||||
return x_3, eps_cache
|
return x_3, eps_cache
|
||||||
|
|
||||||
def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0., s_noise=1., noise_sampler=None):
|
def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0., s_noise=1., noise_sampler=None):
|
||||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
noise_sampler = default_noise_sampler(x, seed=self.extra_args.get("seed", None)) if noise_sampler is None else noise_sampler
|
||||||
if not t_end > t_start and eta:
|
if not t_end > t_start and eta:
|
||||||
raise ValueError('eta must be 0 for reverse sampling')
|
raise ValueError('eta must be 0 for reverse sampling')
|
||||||
|
|
||||||
@@ -504,7 +514,7 @@ class DPMSolver(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None):
|
def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None):
|
||||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
noise_sampler = default_noise_sampler(x, seed=self.extra_args.get("seed", None)) if noise_sampler is None else noise_sampler
|
||||||
if order not in {2, 3}:
|
if order not in {2, 3}:
|
||||||
raise ValueError('order should be 2 or 3')
|
raise ValueError('order should be 2 or 3')
|
||||||
forward = t_end > t_start
|
forward = t_end > t_start
|
||||||
@@ -591,7 +601,8 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None,
|
|||||||
|
|
||||||
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
|
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
|
||||||
extra_args = {} if extra_args is None else extra_args
|
extra_args = {} if extra_args is None else extra_args
|
||||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
seed = extra_args.get("seed", None)
|
||||||
|
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||||
s_in = x.new_ones([x.shape[0]])
|
s_in = x.new_ones([x.shape[0]])
|
||||||
sigma_fn = lambda t: t.neg().exp()
|
sigma_fn = lambda t: t.neg().exp()
|
||||||
t_fn = lambda sigma: sigma.log().neg()
|
t_fn = lambda sigma: sigma.log().neg()
|
||||||
@@ -625,7 +636,8 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None,
|
|||||||
def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||||
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
|
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
|
||||||
extra_args = {} if extra_args is None else extra_args
|
extra_args = {} if extra_args is None else extra_args
|
||||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
seed = extra_args.get("seed", None)
|
||||||
|
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||||
s_in = x.new_ones([x.shape[0]])
|
s_in = x.new_ones([x.shape[0]])
|
||||||
sigma_fn = lambda lbda: (lbda.exp() + 1) ** -1
|
sigma_fn = lambda lbda: (lbda.exp() + 1) ** -1
|
||||||
lambda_fn = lambda sigma: ((1-sigma)/sigma).log()
|
lambda_fn = lambda sigma: ((1-sigma)/sigma).log()
|
||||||
@@ -882,7 +894,8 @@ def DDPMSampler_step(x, sigma, sigma_prev, noise, noise_sampler):
|
|||||||
|
|
||||||
def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, step_function=None):
|
def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, step_function=None):
|
||||||
extra_args = {} if extra_args is None else extra_args
|
extra_args = {} if extra_args is None else extra_args
|
||||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
seed = extra_args.get("seed", None)
|
||||||
|
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||||
s_in = x.new_ones([x.shape[0]])
|
s_in = x.new_ones([x.shape[0]])
|
||||||
|
|
||||||
for i in trange(len(sigmas) - 1, disable=disable):
|
for i in trange(len(sigmas) - 1, disable=disable):
|
||||||
@@ -902,7 +915,8 @@ def sample_ddpm(model, x, sigmas, extra_args=None, callback=None, disable=None,
|
|||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
|
def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
|
||||||
extra_args = {} if extra_args is None else extra_args
|
extra_args = {} if extra_args is None else extra_args
|
||||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
seed = extra_args.get("seed", None)
|
||||||
|
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||||
s_in = x.new_ones([x.shape[0]])
|
s_in = x.new_ones([x.shape[0]])
|
||||||
for i in trange(len(sigmas) - 1, disable=disable):
|
for i in trange(len(sigmas) - 1, disable=disable):
|
||||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||||
@@ -1153,7 +1167,8 @@ def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disabl
|
|||||||
def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||||
"""Ancestral sampling with Euler method steps."""
|
"""Ancestral sampling with Euler method steps."""
|
||||||
extra_args = {} if extra_args is None else extra_args
|
extra_args = {} if extra_args is None else extra_args
|
||||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
seed = extra_args.get("seed", None)
|
||||||
|
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||||
|
|
||||||
temp = [0]
|
temp = [0]
|
||||||
def post_cfg_function(args):
|
def post_cfg_function(args):
|
||||||
@@ -1179,7 +1194,8 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No
|
|||||||
def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||||
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
|
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
|
||||||
extra_args = {} if extra_args is None else extra_args
|
extra_args = {} if extra_args is None else extra_args
|
||||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
seed = extra_args.get("seed", None)
|
||||||
|
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||||
|
|
||||||
temp = [0]
|
temp = [0]
|
||||||
def post_cfg_function(args):
|
def post_cfg_function(args):
|
||||||
@@ -1230,7 +1246,7 @@ def sample_dpmpp_2m_cfg_pp(model, x, sigmas, extra_args=None, callback=None, dis
|
|||||||
nonlocal uncond_denoised
|
nonlocal uncond_denoised
|
||||||
uncond_denoised = args["uncond_denoised"]
|
uncond_denoised = args["uncond_denoised"]
|
||||||
return args["denoised"]
|
return args["denoised"]
|
||||||
|
|
||||||
model_options = extra_args.get("model_options", {}).copy()
|
model_options = extra_args.get("model_options", {}).copy()
|
||||||
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
|
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
|
||||||
|
|
||||||
@@ -1249,3 +1265,97 @@ def sample_dpmpp_2m_cfg_pp(model, x, sigmas, extra_args=None, callback=None, dis
|
|||||||
x = denoised + denoised_mix + torch.exp(-h) * x
|
x = denoised + denoised_mix + torch.exp(-h) * x
|
||||||
old_uncond_denoised = uncond_denoised
|
old_uncond_denoised = uncond_denoised
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None, cfg_pp=False):
|
||||||
|
extra_args = {} if extra_args is None else extra_args
|
||||||
|
seed = extra_args.get("seed", None)
|
||||||
|
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||||
|
s_in = x.new_ones([x.shape[0]])
|
||||||
|
sigma_fn = lambda t: t.neg().exp()
|
||||||
|
t_fn = lambda sigma: sigma.log().neg()
|
||||||
|
phi1_fn = lambda t: torch.expm1(t) / t
|
||||||
|
phi2_fn = lambda t: (phi1_fn(t) - 1.0) / t
|
||||||
|
|
||||||
|
old_denoised = None
|
||||||
|
uncond_denoised = None
|
||||||
|
def post_cfg_function(args):
|
||||||
|
nonlocal uncond_denoised
|
||||||
|
uncond_denoised = args["uncond_denoised"]
|
||||||
|
return args["denoised"]
|
||||||
|
|
||||||
|
if cfg_pp:
|
||||||
|
model_options = extra_args.get("model_options", {}).copy()
|
||||||
|
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
|
||||||
|
|
||||||
|
for i in trange(len(sigmas) - 1, disable=disable):
|
||||||
|
if s_churn > 0:
|
||||||
|
gamma = min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.0
|
||||||
|
sigma_hat = sigmas[i] * (gamma + 1)
|
||||||
|
else:
|
||||||
|
gamma = 0
|
||||||
|
sigma_hat = sigmas[i]
|
||||||
|
|
||||||
|
if gamma > 0:
|
||||||
|
eps = torch.randn_like(x) * s_noise
|
||||||
|
x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5
|
||||||
|
denoised = model(x, sigma_hat * s_in, **extra_args)
|
||||||
|
if callback is not None:
|
||||||
|
callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigma_hat, "denoised": denoised})
|
||||||
|
if sigmas[i + 1] == 0 or old_denoised is None:
|
||||||
|
# Euler method
|
||||||
|
if cfg_pp:
|
||||||
|
d = to_d(x, sigma_hat, uncond_denoised)
|
||||||
|
x = denoised + d * sigmas[i + 1]
|
||||||
|
else:
|
||||||
|
d = to_d(x, sigma_hat, denoised)
|
||||||
|
dt = sigmas[i + 1] - sigma_hat
|
||||||
|
x = x + d * dt
|
||||||
|
else:
|
||||||
|
# Second order multistep method in https://arxiv.org/pdf/2308.02157
|
||||||
|
t, t_next, t_prev = t_fn(sigmas[i]), t_fn(sigmas[i + 1]), t_fn(sigmas[i - 1])
|
||||||
|
h = t_next - t
|
||||||
|
c2 = (t_prev - t) / h
|
||||||
|
|
||||||
|
phi1_val, phi2_val = phi1_fn(-h), phi2_fn(-h)
|
||||||
|
b1 = torch.nan_to_num(phi1_val - 1.0 / c2 * phi2_val, nan=0.0)
|
||||||
|
b2 = torch.nan_to_num(1.0 / c2 * phi2_val, nan=0.0)
|
||||||
|
|
||||||
|
if cfg_pp:
|
||||||
|
x = x + (denoised - uncond_denoised)
|
||||||
|
|
||||||
|
x = (sigma_fn(t_next) / sigma_fn(t)) * x + h * (b1 * denoised + b2 * old_denoised)
|
||||||
|
|
||||||
|
old_denoised = denoised
|
||||||
|
return x
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample_res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
|
||||||
|
return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise, noise_sampler=noise_sampler, cfg_pp=False)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample_res_multistep_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
|
||||||
|
return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise, noise_sampler=noise_sampler, cfg_pp=True)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2.):
|
||||||
|
"""Gradient-estimation sampler. Paper: https://openreview.net/pdf?id=o2ND9v0CeK"""
|
||||||
|
extra_args = {} if extra_args is None else extra_args
|
||||||
|
s_in = x.new_ones([x.shape[0]])
|
||||||
|
old_d = None
|
||||||
|
|
||||||
|
for i in trange(len(sigmas) - 1, disable=disable):
|
||||||
|
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||||
|
d = to_d(x, sigmas[i], denoised)
|
||||||
|
if callback is not None:
|
||||||
|
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||||
|
dt = sigmas[i + 1] - sigmas[i]
|
||||||
|
if i == 0:
|
||||||
|
# Euler method
|
||||||
|
x = x + d * dt
|
||||||
|
else:
|
||||||
|
# Gradient estimation
|
||||||
|
d_bar = ge_gamma * d + (1 - ge_gamma) * old_d
|
||||||
|
x = x + d_bar * dt
|
||||||
|
old_d = d
|
||||||
|
return x
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import torch
|
|||||||
class LatentFormat:
|
class LatentFormat:
|
||||||
scale_factor = 1.0
|
scale_factor = 1.0
|
||||||
latent_channels = 4
|
latent_channels = 4
|
||||||
|
latent_dimensions = 2
|
||||||
latent_rgb_factors = None
|
latent_rgb_factors = None
|
||||||
latent_rgb_factors_bias = None
|
latent_rgb_factors_bias = None
|
||||||
taesd_decoder_name = None
|
taesd_decoder_name = None
|
||||||
@@ -143,6 +144,7 @@ class SD3(LatentFormat):
|
|||||||
|
|
||||||
class StableAudio1(LatentFormat):
|
class StableAudio1(LatentFormat):
|
||||||
latent_channels = 64
|
latent_channels = 64
|
||||||
|
latent_dimensions = 1
|
||||||
|
|
||||||
class Flux(SD3):
|
class Flux(SD3):
|
||||||
latent_channels = 16
|
latent_channels = 16
|
||||||
@@ -178,6 +180,7 @@ class Flux(SD3):
|
|||||||
|
|
||||||
class Mochi(LatentFormat):
|
class Mochi(LatentFormat):
|
||||||
latent_channels = 12
|
latent_channels = 12
|
||||||
|
latent_dimensions = 3
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.scale_factor = 1.0
|
self.scale_factor = 1.0
|
||||||
@@ -219,6 +222,8 @@ class Mochi(LatentFormat):
|
|||||||
|
|
||||||
class LTXV(LatentFormat):
|
class LTXV(LatentFormat):
|
||||||
latent_channels = 128
|
latent_channels = 128
|
||||||
|
latent_dimensions = 3
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.latent_rgb_factors = [
|
self.latent_rgb_factors = [
|
||||||
[ 1.1202e-02, -6.3815e-04, -1.0021e-02],
|
[ 1.1202e-02, -6.3815e-04, -1.0021e-02],
|
||||||
@@ -352,3 +357,53 @@ class LTXV(LatentFormat):
|
|||||||
]
|
]
|
||||||
|
|
||||||
self.latent_rgb_factors_bias = [-0.0571, -0.1657, -0.2512]
|
self.latent_rgb_factors_bias = [-0.0571, -0.1657, -0.2512]
|
||||||
|
|
||||||
|
class HunyuanVideo(LatentFormat):
|
||||||
|
latent_channels = 16
|
||||||
|
latent_dimensions = 3
|
||||||
|
scale_factor = 0.476986
|
||||||
|
latent_rgb_factors = [
|
||||||
|
[-0.0395, -0.0331, 0.0445],
|
||||||
|
[ 0.0696, 0.0795, 0.0518],
|
||||||
|
[ 0.0135, -0.0945, -0.0282],
|
||||||
|
[ 0.0108, -0.0250, -0.0765],
|
||||||
|
[-0.0209, 0.0032, 0.0224],
|
||||||
|
[-0.0804, -0.0254, -0.0639],
|
||||||
|
[-0.0991, 0.0271, -0.0669],
|
||||||
|
[-0.0646, -0.0422, -0.0400],
|
||||||
|
[-0.0696, -0.0595, -0.0894],
|
||||||
|
[-0.0799, -0.0208, -0.0375],
|
||||||
|
[ 0.1166, 0.1627, 0.0962],
|
||||||
|
[ 0.1165, 0.0432, 0.0407],
|
||||||
|
[-0.2315, -0.1920, -0.1355],
|
||||||
|
[-0.0270, 0.0401, -0.0821],
|
||||||
|
[-0.0616, -0.0997, -0.0727],
|
||||||
|
[ 0.0249, -0.0469, -0.1703]
|
||||||
|
]
|
||||||
|
|
||||||
|
latent_rgb_factors_bias = [ 0.0259, -0.0192, -0.0761]
|
||||||
|
|
||||||
|
class Cosmos1CV8x8x8(LatentFormat):
|
||||||
|
latent_channels = 16
|
||||||
|
latent_dimensions = 3
|
||||||
|
|
||||||
|
latent_rgb_factors = [
|
||||||
|
[ 0.1817, 0.2284, 0.2423],
|
||||||
|
[-0.0586, -0.0862, -0.3108],
|
||||||
|
[-0.4703, -0.4255, -0.3995],
|
||||||
|
[ 0.0803, 0.1963, 0.1001],
|
||||||
|
[-0.0820, -0.1050, 0.0400],
|
||||||
|
[ 0.2511, 0.3098, 0.2787],
|
||||||
|
[-0.1830, -0.2117, -0.0040],
|
||||||
|
[-0.0621, -0.2187, -0.0939],
|
||||||
|
[ 0.3619, 0.1082, 0.1455],
|
||||||
|
[ 0.3164, 0.3922, 0.2575],
|
||||||
|
[ 0.1152, 0.0231, -0.0462],
|
||||||
|
[-0.1434, -0.3609, -0.3665],
|
||||||
|
[ 0.0635, 0.1471, 0.1680],
|
||||||
|
[-0.3635, -0.1963, -0.3248],
|
||||||
|
[-0.1865, 0.0365, 0.2346],
|
||||||
|
[ 0.0447, 0.0994, 0.0881]
|
||||||
|
]
|
||||||
|
|
||||||
|
latent_rgb_factors_bias = [-0.1223, -0.1889, -0.1976]
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from typing import Literal, Dict, Any
|
from typing import Literal
|
||||||
import math
|
import math
|
||||||
import comfy.ops
|
import comfy.ops
|
||||||
ops = comfy.ops.disable_weight_init
|
ops = comfy.ops.disable_weight_init
|
||||||
@@ -97,7 +97,7 @@ def get_activation(activation: Literal["elu", "snake", "none"], antialias=False,
|
|||||||
raise ValueError(f"Unknown activation {activation}")
|
raise ValueError(f"Unknown activation {activation}")
|
||||||
|
|
||||||
if antialias:
|
if antialias:
|
||||||
act = Activation1d(act)
|
act = Activation1d(act) # noqa: F821 Activation1d is not defined
|
||||||
|
|
||||||
return act
|
return act
|
||||||
|
|
||||||
|
|||||||
@@ -158,7 +158,6 @@ class RotaryEmbedding(nn.Module):
|
|||||||
def forward(self, t):
|
def forward(self, t):
|
||||||
# device = self.inv_freq.device
|
# device = self.inv_freq.device
|
||||||
device = t.device
|
device = t.device
|
||||||
dtype = t.dtype
|
|
||||||
|
|
||||||
# t = t.to(torch.float32)
|
# t = t.to(torch.float32)
|
||||||
|
|
||||||
@@ -170,7 +169,7 @@ class RotaryEmbedding(nn.Module):
|
|||||||
if self.scale is None:
|
if self.scale is None:
|
||||||
return freqs, 1.
|
return freqs, 1.
|
||||||
|
|
||||||
power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
|
power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base # noqa: F821 seq_len is not defined
|
||||||
scale = comfy.ops.cast_to_input(self.scale, t) ** rearrange(power, 'n -> n 1')
|
scale = comfy.ops.cast_to_input(self.scale, t) ** rearrange(power, 'n -> n 1')
|
||||||
scale = torch.cat((scale, scale), dim = -1)
|
scale = torch.cat((scale, scale), dim = -1)
|
||||||
|
|
||||||
@@ -229,9 +228,9 @@ class FeedForward(nn.Module):
|
|||||||
linear_in = GLU(dim, inner_dim, activation, dtype=dtype, device=device, operations=operations)
|
linear_in = GLU(dim, inner_dim, activation, dtype=dtype, device=device, operations=operations)
|
||||||
else:
|
else:
|
||||||
linear_in = nn.Sequential(
|
linear_in = nn.Sequential(
|
||||||
Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
|
rearrange('b n d -> b d n') if use_conv else nn.Identity(),
|
||||||
operations.Linear(dim, inner_dim, bias = not no_bias, dtype=dtype, device=device) if not use_conv else operations.Conv1d(dim, inner_dim, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias, dtype=dtype, device=device),
|
operations.Linear(dim, inner_dim, bias = not no_bias, dtype=dtype, device=device) if not use_conv else operations.Conv1d(dim, inner_dim, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias, dtype=dtype, device=device),
|
||||||
Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
|
rearrange('b n d -> b d n') if use_conv else nn.Identity(),
|
||||||
activation
|
activation
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -246,9 +245,9 @@ class FeedForward(nn.Module):
|
|||||||
|
|
||||||
self.ff = nn.Sequential(
|
self.ff = nn.Sequential(
|
||||||
linear_in,
|
linear_in,
|
||||||
Rearrange('b d n -> b n d') if use_conv else nn.Identity(),
|
rearrange('b d n -> b n d') if use_conv else nn.Identity(),
|
||||||
linear_out,
|
linear_out,
|
||||||
Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
|
rearrange('b n d -> b d n') if use_conv else nn.Identity(),
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
@@ -346,18 +345,13 @@ class Attention(nn.Module):
|
|||||||
|
|
||||||
# determine masking
|
# determine masking
|
||||||
masks = []
|
masks = []
|
||||||
final_attn_mask = None # The mask that will be applied to the attention matrix, taking all masks into account
|
|
||||||
|
|
||||||
if input_mask is not None:
|
if input_mask is not None:
|
||||||
input_mask = rearrange(input_mask, 'b j -> b 1 1 j')
|
input_mask = rearrange(input_mask, 'b j -> b 1 1 j')
|
||||||
masks.append(~input_mask)
|
masks.append(~input_mask)
|
||||||
|
|
||||||
# Other masks will be added here later
|
# Other masks will be added here later
|
||||||
|
n = q.shape[-2]
|
||||||
if len(masks) > 0:
|
|
||||||
final_attn_mask = ~or_reduce(masks)
|
|
||||||
|
|
||||||
n, device = q.shape[-2], q.device
|
|
||||||
|
|
||||||
causal = self.causal if causal is None else causal
|
causal = self.causal if causal is None else causal
|
||||||
|
|
||||||
|
|||||||
@@ -2,8 +2,8 @@
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch import Tensor, einsum
|
from torch import Tensor
|
||||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
|
from typing import List, Union
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
import math
|
import math
|
||||||
import comfy.ops
|
import comfy.ops
|
||||||
|
|||||||
@@ -147,7 +147,6 @@ class DoubleAttention(nn.Module):
|
|||||||
|
|
||||||
bsz, seqlen1, _ = c.shape
|
bsz, seqlen1, _ = c.shape
|
||||||
bsz, seqlen2, _ = x.shape
|
bsz, seqlen2, _ = x.shape
|
||||||
seqlen = seqlen1 + seqlen2
|
|
||||||
|
|
||||||
cq, ck, cv = self.w1q(c), self.w1k(c), self.w1v(c)
|
cq, ck, cv = self.w1q(c), self.w1k(c), self.w1v(c)
|
||||||
cq = cq.view(bsz, seqlen1, self.n_heads, self.head_dim)
|
cq = cq.view(bsz, seqlen1, self.n_heads, self.head_dim)
|
||||||
@@ -382,7 +381,6 @@ class MMDiT(nn.Module):
|
|||||||
pe_new = pe_as_2d.squeeze(0).permute(1, 2, 0).flatten(0, 1)
|
pe_new = pe_as_2d.squeeze(0).permute(1, 2, 0).flatten(0, 1)
|
||||||
self.positional_encoding.data = pe_new.unsqueeze(0).contiguous()
|
self.positional_encoding.data = pe_new.unsqueeze(0).contiguous()
|
||||||
self.h_max, self.w_max = target_dim
|
self.h_max, self.w_max = target_dim
|
||||||
print("PE extended to", target_dim)
|
|
||||||
|
|
||||||
def pe_selection_index_based_on_dim(self, h, w):
|
def pe_selection_index_based_on_dim(self, h, w):
|
||||||
h_p, w_p = h // self.patch_size, w // self.patch_size
|
h_p, w_p = h // self.patch_size, w // self.patch_size
|
||||||
|
|||||||
@@ -16,7 +16,6 @@
|
|||||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
|
||||||
import torchvision
|
import torchvision
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from .common import LayerNorm2d_op
|
from .common import LayerNorm2d_op
|
||||||
|
|||||||
@@ -138,7 +138,7 @@ class StageB(nn.Module):
|
|||||||
# nn.init.normal_(self.pixels_mapper[2].weight, std=0.02) # conditionings
|
# nn.init.normal_(self.pixels_mapper[2].weight, std=0.02) # conditionings
|
||||||
# torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs
|
# torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs
|
||||||
# nn.init.constant_(self.clf[1].weight, 0) # outputs
|
# nn.init.constant_(self.clf[1].weight, 0) # outputs
|
||||||
#
|
#
|
||||||
# # blocks
|
# # blocks
|
||||||
# for level_block in self.down_blocks + self.up_blocks:
|
# for level_block in self.down_blocks + self.up_blocks:
|
||||||
# for block in level_block:
|
# for block in level_block:
|
||||||
@@ -148,7 +148,7 @@ class StageB(nn.Module):
|
|||||||
# for layer in block.modules():
|
# for layer in block.modules():
|
||||||
# if isinstance(layer, nn.Linear):
|
# if isinstance(layer, nn.Linear):
|
||||||
# nn.init.constant_(layer.weight, 0)
|
# nn.init.constant_(layer.weight, 0)
|
||||||
#
|
#
|
||||||
# def _init_weights(self, m):
|
# def _init_weights(self, m):
|
||||||
# if isinstance(m, (nn.Conv2d, nn.Linear)):
|
# if isinstance(m, (nn.Conv2d, nn.Linear)):
|
||||||
# torch.nn.init.xavier_uniform_(m.weight)
|
# torch.nn.init.xavier_uniform_(m.weight)
|
||||||
|
|||||||
@@ -142,7 +142,7 @@ class StageC(nn.Module):
|
|||||||
# nn.init.normal_(self.clip_img_mapper.weight, std=0.02) # conditionings
|
# nn.init.normal_(self.clip_img_mapper.weight, std=0.02) # conditionings
|
||||||
# torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs
|
# torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs
|
||||||
# nn.init.constant_(self.clf[1].weight, 0) # outputs
|
# nn.init.constant_(self.clf[1].weight, 0) # outputs
|
||||||
#
|
#
|
||||||
# # blocks
|
# # blocks
|
||||||
# for level_block in self.down_blocks + self.up_blocks:
|
# for level_block in self.down_blocks + self.up_blocks:
|
||||||
# for block in level_block:
|
# for block in level_block:
|
||||||
@@ -152,7 +152,7 @@ class StageC(nn.Module):
|
|||||||
# for layer in block.modules():
|
# for layer in block.modules():
|
||||||
# if isinstance(layer, nn.Linear):
|
# if isinstance(layer, nn.Linear):
|
||||||
# nn.init.constant_(layer.weight, 0)
|
# nn.init.constant_(layer.weight, 0)
|
||||||
#
|
#
|
||||||
# def _init_weights(self, m):
|
# def _init_weights(self, m):
|
||||||
# if isinstance(m, (nn.Conv2d, nn.Linear)):
|
# if isinstance(m, (nn.Conv2d, nn.Linear)):
|
||||||
# torch.nn.init.xavier_uniform_(m.weight)
|
# torch.nn.init.xavier_uniform_(m.weight)
|
||||||
|
|||||||
@@ -4,9 +4,12 @@ import comfy.ops
|
|||||||
def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
|
def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
|
||||||
if padding_mode == "circular" and (torch.jit.is_tracing() or torch.jit.is_scripting()):
|
if padding_mode == "circular" and (torch.jit.is_tracing() or torch.jit.is_scripting()):
|
||||||
padding_mode = "reflect"
|
padding_mode = "reflect"
|
||||||
pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0]
|
|
||||||
pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1]
|
pad = ()
|
||||||
return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode)
|
for i in range(img.ndim - 2):
|
||||||
|
pad = (0, (patch_size[i] - img.shape[i + 2] % patch_size[i]) % patch_size[i]) + pad
|
||||||
|
|
||||||
|
return torch.nn.functional.pad(img, pad, mode=padding_mode)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
rms_norm_torch = torch.nn.functional.rms_norm
|
rms_norm_torch = torch.nn.functional.rms_norm
|
||||||
|
|||||||
808
comfy/ldm/cosmos/blocks.py
Normal file
808
comfy/ldm/cosmos/blocks.py
Normal file
@@ -0,0 +1,808 @@
|
|||||||
|
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import math
|
||||||
|
from typing import Optional
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
from einops.layers.torch import Rearrange
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
|
||||||
|
from comfy.ldm.modules.attention import optimized_attention
|
||||||
|
|
||||||
|
|
||||||
|
def apply_rotary_pos_emb(
|
||||||
|
t: torch.Tensor,
|
||||||
|
freqs: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
t_ = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).float()
|
||||||
|
t_out = freqs[..., 0] * t_[..., 0] + freqs[..., 1] * t_[..., 1]
|
||||||
|
t_out = t_out.movedim(-1, -2).reshape(*t.shape).type_as(t)
|
||||||
|
return t_out
|
||||||
|
|
||||||
|
|
||||||
|
def get_normalization(name: str, channels: int, weight_args={}):
|
||||||
|
if name == "I":
|
||||||
|
return nn.Identity()
|
||||||
|
elif name == "R":
|
||||||
|
return RMSNorm(channels, elementwise_affine=True, eps=1e-6, **weight_args)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Normalization {name} not found")
|
||||||
|
|
||||||
|
|
||||||
|
class BaseAttentionOp(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
"""
|
||||||
|
Generalized attention impl.
|
||||||
|
|
||||||
|
Allowing for both self-attention and cross-attention configurations depending on whether a `context_dim` is provided.
|
||||||
|
If `context_dim` is None, self-attention is assumed.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
query_dim (int): Dimension of each query vector.
|
||||||
|
context_dim (int, optional): Dimension of each context vector. If None, self-attention is assumed.
|
||||||
|
heads (int, optional): Number of attention heads. Defaults to 8.
|
||||||
|
dim_head (int, optional): Dimension of each head. Defaults to 64.
|
||||||
|
dropout (float, optional): Dropout rate applied to the output of the attention block. Defaults to 0.0.
|
||||||
|
attn_op (BaseAttentionOp, optional): Custom attention operation to be used instead of the default.
|
||||||
|
qkv_bias (bool, optional): If True, adds a learnable bias to query, key, and value projections. Defaults to False.
|
||||||
|
out_bias (bool, optional): If True, adds a learnable bias to the output projection. Defaults to False.
|
||||||
|
qkv_norm (str, optional): A string representing normalization strategies for query, key, and value projections.
|
||||||
|
Defaults to "SSI".
|
||||||
|
qkv_norm_mode (str, optional): A string representing normalization mode for query, key, and value projections.
|
||||||
|
Defaults to 'per_head'. Only support 'per_head'.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> attn = Attention(query_dim=128, context_dim=256, heads=4, dim_head=32, dropout=0.1)
|
||||||
|
>>> query = torch.randn(10, 128) # Batch size of 10
|
||||||
|
>>> context = torch.randn(10, 256) # Batch size of 10
|
||||||
|
>>> output = attn(query, context) # Perform the attention operation
|
||||||
|
|
||||||
|
Note:
|
||||||
|
https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
query_dim: int,
|
||||||
|
context_dim=None,
|
||||||
|
heads=8,
|
||||||
|
dim_head=64,
|
||||||
|
dropout=0.0,
|
||||||
|
attn_op: Optional[BaseAttentionOp] = None,
|
||||||
|
qkv_bias: bool = False,
|
||||||
|
out_bias: bool = False,
|
||||||
|
qkv_norm: str = "SSI",
|
||||||
|
qkv_norm_mode: str = "per_head",
|
||||||
|
backend: str = "transformer_engine",
|
||||||
|
qkv_format: str = "bshd",
|
||||||
|
weight_args={},
|
||||||
|
operations=None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.is_selfattn = context_dim is None # self attention
|
||||||
|
|
||||||
|
inner_dim = dim_head * heads
|
||||||
|
context_dim = query_dim if context_dim is None else context_dim
|
||||||
|
|
||||||
|
self.heads = heads
|
||||||
|
self.dim_head = dim_head
|
||||||
|
self.qkv_norm_mode = qkv_norm_mode
|
||||||
|
self.qkv_format = qkv_format
|
||||||
|
|
||||||
|
if self.qkv_norm_mode == "per_head":
|
||||||
|
norm_dim = dim_head
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Normalization mode {self.qkv_norm_mode} not found, only support 'per_head'")
|
||||||
|
|
||||||
|
self.backend = backend
|
||||||
|
|
||||||
|
self.to_q = nn.Sequential(
|
||||||
|
operations.Linear(query_dim, inner_dim, bias=qkv_bias, **weight_args),
|
||||||
|
get_normalization(qkv_norm[0], norm_dim),
|
||||||
|
)
|
||||||
|
self.to_k = nn.Sequential(
|
||||||
|
operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args),
|
||||||
|
get_normalization(qkv_norm[1], norm_dim),
|
||||||
|
)
|
||||||
|
self.to_v = nn.Sequential(
|
||||||
|
operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args),
|
||||||
|
get_normalization(qkv_norm[2], norm_dim),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.to_out = nn.Sequential(
|
||||||
|
operations.Linear(inner_dim, query_dim, bias=out_bias, **weight_args),
|
||||||
|
nn.Dropout(dropout),
|
||||||
|
)
|
||||||
|
|
||||||
|
def cal_qkv(
|
||||||
|
self, x, context=None, mask=None, rope_emb=None, **kwargs
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
del kwargs
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
self.to_q, self.to_k, self.to_v are nn.Sequential with projection + normalization layers.
|
||||||
|
Before 07/24/2024, these modules normalize across all heads.
|
||||||
|
After 07/24/2024, to support tensor parallelism and follow the common practice in the community,
|
||||||
|
we support to normalize per head.
|
||||||
|
To keep the checkpoint copatibility with the previous code,
|
||||||
|
we keep the nn.Sequential but call the projection and the normalization layers separately.
|
||||||
|
We use a flag `self.qkv_norm_mode` to control the normalization behavior.
|
||||||
|
The default value of `self.qkv_norm_mode` is "per_head", which means we normalize per head.
|
||||||
|
"""
|
||||||
|
if self.qkv_norm_mode == "per_head":
|
||||||
|
q = self.to_q[0](x)
|
||||||
|
context = x if context is None else context
|
||||||
|
k = self.to_k[0](context)
|
||||||
|
v = self.to_v[0](context)
|
||||||
|
q, k, v = map(
|
||||||
|
lambda t: rearrange(t, "s b (n c) -> b n s c", n=self.heads, c=self.dim_head),
|
||||||
|
(q, k, v),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Normalization mode {self.qkv_norm_mode} not found, only support 'per_head'")
|
||||||
|
|
||||||
|
q = self.to_q[1](q)
|
||||||
|
k = self.to_k[1](k)
|
||||||
|
v = self.to_v[1](v)
|
||||||
|
if self.is_selfattn and rope_emb is not None: # only apply to self-attention!
|
||||||
|
# apply_rotary_pos_emb inlined
|
||||||
|
q_shape = q.shape
|
||||||
|
q = q.reshape(*q.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2)
|
||||||
|
q = rope_emb[..., 0] * q[..., 0] + rope_emb[..., 1] * q[..., 1]
|
||||||
|
q = q.movedim(-1, -2).reshape(*q_shape).to(x.dtype)
|
||||||
|
|
||||||
|
# apply_rotary_pos_emb inlined
|
||||||
|
k_shape = k.shape
|
||||||
|
k = k.reshape(*k.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2)
|
||||||
|
k = rope_emb[..., 0] * k[..., 0] + rope_emb[..., 1] * k[..., 1]
|
||||||
|
k = k.movedim(-1, -2).reshape(*k_shape).to(x.dtype)
|
||||||
|
return q, k, v
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
context=None,
|
||||||
|
mask=None,
|
||||||
|
rope_emb=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x (Tensor): The query tensor of shape [B, Mq, K]
|
||||||
|
context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
|
||||||
|
"""
|
||||||
|
q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs)
|
||||||
|
out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True)
|
||||||
|
del q, k, v
|
||||||
|
out = rearrange(out, " b n s c -> s b (n c)")
|
||||||
|
return self.to_out(out)
|
||||||
|
|
||||||
|
|
||||||
|
class FeedForward(nn.Module):
|
||||||
|
"""
|
||||||
|
Transformer FFN with optional gating
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
d_model (int): Dimensionality of input features.
|
||||||
|
d_ff (int): Dimensionality of the hidden layer.
|
||||||
|
dropout (float, optional): Dropout rate applied after the activation function. Defaults to 0.1.
|
||||||
|
activation (callable, optional): The activation function applied after the first linear layer.
|
||||||
|
Defaults to nn.ReLU().
|
||||||
|
is_gated (bool, optional): If set to True, incorporates gating mechanism to the feed-forward layer.
|
||||||
|
Defaults to False.
|
||||||
|
bias (bool, optional): If set to True, adds a bias to the linear layers. Defaults to True.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> ff = FeedForward(d_model=512, d_ff=2048)
|
||||||
|
>>> x = torch.randn(64, 10, 512) # Example input tensor
|
||||||
|
>>> output = ff(x)
|
||||||
|
>>> print(output.shape) # Expected shape: (64, 10, 512)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
d_model: int,
|
||||||
|
d_ff: int,
|
||||||
|
dropout: float = 0.1,
|
||||||
|
activation=nn.ReLU(),
|
||||||
|
is_gated: bool = False,
|
||||||
|
bias: bool = False,
|
||||||
|
weight_args={},
|
||||||
|
operations=None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.layer1 = operations.Linear(d_model, d_ff, bias=bias, **weight_args)
|
||||||
|
self.layer2 = operations.Linear(d_ff, d_model, bias=bias, **weight_args)
|
||||||
|
|
||||||
|
self.dropout = nn.Dropout(dropout)
|
||||||
|
self.activation = activation
|
||||||
|
self.is_gated = is_gated
|
||||||
|
if is_gated:
|
||||||
|
self.linear_gate = operations.Linear(d_model, d_ff, bias=False, **weight_args)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
g = self.activation(self.layer1(x))
|
||||||
|
if self.is_gated:
|
||||||
|
x = g * self.linear_gate(x)
|
||||||
|
else:
|
||||||
|
x = g
|
||||||
|
assert self.dropout.p == 0.0, "we skip dropout"
|
||||||
|
return self.layer2(x)
|
||||||
|
|
||||||
|
|
||||||
|
class GPT2FeedForward(FeedForward):
|
||||||
|
def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1, bias: bool = False, weight_args={}, operations=None):
|
||||||
|
super().__init__(
|
||||||
|
d_model=d_model,
|
||||||
|
d_ff=d_ff,
|
||||||
|
dropout=dropout,
|
||||||
|
activation=nn.GELU(),
|
||||||
|
is_gated=False,
|
||||||
|
bias=bias,
|
||||||
|
weight_args=weight_args,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
assert self.dropout.p == 0.0, "we skip dropout"
|
||||||
|
|
||||||
|
x = self.layer1(x)
|
||||||
|
x = self.activation(x)
|
||||||
|
x = self.layer2(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def modulate(x, shift, scale):
|
||||||
|
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||||
|
|
||||||
|
|
||||||
|
class Timesteps(nn.Module):
|
||||||
|
def __init__(self, num_channels):
|
||||||
|
super().__init__()
|
||||||
|
self.num_channels = num_channels
|
||||||
|
|
||||||
|
def forward(self, timesteps):
|
||||||
|
half_dim = self.num_channels // 2
|
||||||
|
exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
|
||||||
|
exponent = exponent / (half_dim - 0.0)
|
||||||
|
|
||||||
|
emb = torch.exp(exponent)
|
||||||
|
emb = timesteps[:, None].float() * emb[None, :]
|
||||||
|
|
||||||
|
sin_emb = torch.sin(emb)
|
||||||
|
cos_emb = torch.cos(emb)
|
||||||
|
emb = torch.cat([cos_emb, sin_emb], dim=-1)
|
||||||
|
|
||||||
|
return emb
|
||||||
|
|
||||||
|
|
||||||
|
class TimestepEmbedding(nn.Module):
|
||||||
|
def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False, weight_args={}, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
logging.debug(
|
||||||
|
f"Using AdaLN LoRA Flag: {use_adaln_lora}. We enable bias if no AdaLN LoRA for backward compatibility."
|
||||||
|
)
|
||||||
|
self.linear_1 = operations.Linear(in_features, out_features, bias=not use_adaln_lora, **weight_args)
|
||||||
|
self.activation = nn.SiLU()
|
||||||
|
self.use_adaln_lora = use_adaln_lora
|
||||||
|
if use_adaln_lora:
|
||||||
|
self.linear_2 = operations.Linear(out_features, 3 * out_features, bias=False, **weight_args)
|
||||||
|
else:
|
||||||
|
self.linear_2 = operations.Linear(out_features, out_features, bias=True, **weight_args)
|
||||||
|
|
||||||
|
def forward(self, sample: torch.Tensor) -> torch.Tensor:
|
||||||
|
emb = self.linear_1(sample)
|
||||||
|
emb = self.activation(emb)
|
||||||
|
emb = self.linear_2(emb)
|
||||||
|
|
||||||
|
if self.use_adaln_lora:
|
||||||
|
adaln_lora_B_3D = emb
|
||||||
|
emb_B_D = sample
|
||||||
|
else:
|
||||||
|
emb_B_D = emb
|
||||||
|
adaln_lora_B_3D = None
|
||||||
|
|
||||||
|
return emb_B_D, adaln_lora_B_3D
|
||||||
|
|
||||||
|
|
||||||
|
class FourierFeatures(nn.Module):
|
||||||
|
"""
|
||||||
|
Implements a layer that generates Fourier features from input tensors, based on randomly sampled
|
||||||
|
frequencies and phases. This can help in learning high-frequency functions in low-dimensional problems.
|
||||||
|
|
||||||
|
[B] -> [B, D]
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
num_channels (int): The number of Fourier features to generate.
|
||||||
|
bandwidth (float, optional): The scaling factor for the frequency of the Fourier features. Defaults to 1.
|
||||||
|
normalize (bool, optional): If set to True, the outputs are scaled by sqrt(2), usually to normalize
|
||||||
|
the variance of the features. Defaults to False.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> layer = FourierFeatures(num_channels=256, bandwidth=0.5, normalize=True)
|
||||||
|
>>> x = torch.randn(10, 256) # Example input tensor
|
||||||
|
>>> output = layer(x)
|
||||||
|
>>> print(output.shape) # Expected shape: (10, 256)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, num_channels, bandwidth=1, normalize=False):
|
||||||
|
super().__init__()
|
||||||
|
self.register_buffer("freqs", 2 * np.pi * bandwidth * torch.randn(num_channels), persistent=True)
|
||||||
|
self.register_buffer("phases", 2 * np.pi * torch.rand(num_channels), persistent=True)
|
||||||
|
self.gain = np.sqrt(2) if normalize else 1
|
||||||
|
|
||||||
|
def forward(self, x, gain: float = 1.0):
|
||||||
|
"""
|
||||||
|
Apply the Fourier feature transformation to the input tensor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (torch.Tensor): The input tensor.
|
||||||
|
gain (float, optional): An additional gain factor applied during the forward pass. Defaults to 1.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: The transformed tensor, with Fourier features applied.
|
||||||
|
"""
|
||||||
|
in_dtype = x.dtype
|
||||||
|
x = x.to(torch.float32).ger(self.freqs.to(torch.float32)).add(self.phases.to(torch.float32))
|
||||||
|
x = x.cos().mul(self.gain * gain).to(in_dtype)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class PatchEmbed(nn.Module):
|
||||||
|
"""
|
||||||
|
PatchEmbed is a module for embedding patches from an input tensor by applying either 3D or 2D convolutional layers,
|
||||||
|
depending on the . This module can process inputs with temporal (video) and spatial (image) dimensions,
|
||||||
|
making it suitable for video and image processing tasks. It supports dividing the input into patches
|
||||||
|
and embedding each patch into a vector of size `out_channels`.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- spatial_patch_size (int): The size of each spatial patch.
|
||||||
|
- temporal_patch_size (int): The size of each temporal patch.
|
||||||
|
- in_channels (int): Number of input channels. Default: 3.
|
||||||
|
- out_channels (int): The dimension of the embedding vector for each patch. Default: 768.
|
||||||
|
- bias (bool): If True, adds a learnable bias to the output of the convolutional layers. Default: True.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
spatial_patch_size,
|
||||||
|
temporal_patch_size,
|
||||||
|
in_channels=3,
|
||||||
|
out_channels=768,
|
||||||
|
bias=True,
|
||||||
|
weight_args={},
|
||||||
|
operations=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.spatial_patch_size = spatial_patch_size
|
||||||
|
self.temporal_patch_size = temporal_patch_size
|
||||||
|
|
||||||
|
self.proj = nn.Sequential(
|
||||||
|
Rearrange(
|
||||||
|
"b c (t r) (h m) (w n) -> b t h w (c r m n)",
|
||||||
|
r=temporal_patch_size,
|
||||||
|
m=spatial_patch_size,
|
||||||
|
n=spatial_patch_size,
|
||||||
|
),
|
||||||
|
operations.Linear(
|
||||||
|
in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=bias, **weight_args
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.out = nn.Identity()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""
|
||||||
|
Forward pass of the PatchEmbed module.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- x (torch.Tensor): The input tensor of shape (B, C, T, H, W) where
|
||||||
|
B is the batch size,
|
||||||
|
C is the number of channels,
|
||||||
|
T is the temporal dimension,
|
||||||
|
H is the height, and
|
||||||
|
W is the width of the input.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- torch.Tensor: The embedded patches as a tensor, with shape b t h w c.
|
||||||
|
"""
|
||||||
|
assert x.dim() == 5
|
||||||
|
_, _, T, H, W = x.shape
|
||||||
|
assert H % self.spatial_patch_size == 0 and W % self.spatial_patch_size == 0
|
||||||
|
assert T % self.temporal_patch_size == 0
|
||||||
|
x = self.proj(x)
|
||||||
|
return self.out(x)
|
||||||
|
|
||||||
|
|
||||||
|
class FinalLayer(nn.Module):
|
||||||
|
"""
|
||||||
|
The final layer of video DiT.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size,
|
||||||
|
spatial_patch_size,
|
||||||
|
temporal_patch_size,
|
||||||
|
out_channels,
|
||||||
|
use_adaln_lora: bool = False,
|
||||||
|
adaln_lora_dim: int = 256,
|
||||||
|
weight_args={},
|
||||||
|
operations=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **weight_args)
|
||||||
|
self.linear = operations.Linear(
|
||||||
|
hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, **weight_args
|
||||||
|
)
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.n_adaln_chunks = 2
|
||||||
|
self.use_adaln_lora = use_adaln_lora
|
||||||
|
if use_adaln_lora:
|
||||||
|
self.adaLN_modulation = nn.Sequential(
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Linear(hidden_size, adaln_lora_dim, bias=False, **weight_args),
|
||||||
|
operations.Linear(adaln_lora_dim, self.n_adaln_chunks * hidden_size, bias=False, **weight_args),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.adaLN_modulation = nn.Sequential(
|
||||||
|
nn.SiLU(), operations.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False, **weight_args)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x_BT_HW_D,
|
||||||
|
emb_B_D,
|
||||||
|
adaln_lora_B_3D: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
|
if self.use_adaln_lora:
|
||||||
|
assert adaln_lora_B_3D is not None
|
||||||
|
shift_B_D, scale_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D[:, : 2 * self.hidden_size]).chunk(
|
||||||
|
2, dim=1
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
shift_B_D, scale_B_D = self.adaLN_modulation(emb_B_D).chunk(2, dim=1)
|
||||||
|
|
||||||
|
B = emb_B_D.shape[0]
|
||||||
|
T = x_BT_HW_D.shape[0] // B
|
||||||
|
shift_BT_D, scale_BT_D = repeat(shift_B_D, "b d -> (b t) d", t=T), repeat(scale_B_D, "b d -> (b t) d", t=T)
|
||||||
|
x_BT_HW_D = modulate(self.norm_final(x_BT_HW_D), shift_BT_D, scale_BT_D)
|
||||||
|
|
||||||
|
x_BT_HW_D = self.linear(x_BT_HW_D)
|
||||||
|
return x_BT_HW_D
|
||||||
|
|
||||||
|
|
||||||
|
class VideoAttn(nn.Module):
|
||||||
|
"""
|
||||||
|
Implements video attention with optional cross-attention capabilities.
|
||||||
|
|
||||||
|
This module processes video features while maintaining their spatio-temporal structure. It can perform
|
||||||
|
self-attention within the video features or cross-attention with external context features.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
x_dim (int): Dimension of input feature vectors
|
||||||
|
context_dim (Optional[int]): Dimension of context features for cross-attention. None for self-attention
|
||||||
|
num_heads (int): Number of attention heads
|
||||||
|
bias (bool): Whether to include bias in attention projections. Default: False
|
||||||
|
qkv_norm_mode (str): Normalization mode for query/key/value projections. Must be "per_head". Default: "per_head"
|
||||||
|
x_format (str): Format of input tensor. Must be "BTHWD". Default: "BTHWD"
|
||||||
|
|
||||||
|
Input shape:
|
||||||
|
- x: (T, H, W, B, D) video features
|
||||||
|
- context (optional): (M, B, D) context features for cross-attention
|
||||||
|
where:
|
||||||
|
T: temporal dimension
|
||||||
|
H: height
|
||||||
|
W: width
|
||||||
|
B: batch size
|
||||||
|
D: feature dimension
|
||||||
|
M: context sequence length
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
x_dim: int,
|
||||||
|
context_dim: Optional[int],
|
||||||
|
num_heads: int,
|
||||||
|
bias: bool = False,
|
||||||
|
qkv_norm_mode: str = "per_head",
|
||||||
|
x_format: str = "BTHWD",
|
||||||
|
weight_args={},
|
||||||
|
operations=None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.x_format = x_format
|
||||||
|
|
||||||
|
self.attn = Attention(
|
||||||
|
x_dim,
|
||||||
|
context_dim,
|
||||||
|
num_heads,
|
||||||
|
x_dim // num_heads,
|
||||||
|
qkv_bias=bias,
|
||||||
|
qkv_norm="RRI",
|
||||||
|
out_bias=bias,
|
||||||
|
qkv_norm_mode=qkv_norm_mode,
|
||||||
|
qkv_format="sbhd",
|
||||||
|
weight_args=weight_args,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
context: Optional[torch.Tensor] = None,
|
||||||
|
crossattn_mask: Optional[torch.Tensor] = None,
|
||||||
|
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Forward pass for video attention.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (Tensor): Input tensor of shape (B, T, H, W, D) or (T, H, W, B, D) representing batches of video data.
|
||||||
|
context (Tensor): Context tensor of shape (B, M, D) or (M, B, D),
|
||||||
|
where M is the sequence length of the context.
|
||||||
|
crossattn_mask (Optional[Tensor]): An optional mask for cross-attention mechanisms.
|
||||||
|
rope_emb_L_1_1_D (Optional[Tensor]):
|
||||||
|
Rotary positional embedding tensor of shape (L, 1, 1, D). L == THW for current video training.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: The output tensor with applied attention, maintaining the input shape.
|
||||||
|
"""
|
||||||
|
|
||||||
|
x_T_H_W_B_D = x
|
||||||
|
context_M_B_D = context
|
||||||
|
T, H, W, B, D = x_T_H_W_B_D.shape
|
||||||
|
x_THW_B_D = rearrange(x_T_H_W_B_D, "t h w b d -> (t h w) b d")
|
||||||
|
x_THW_B_D = self.attn(
|
||||||
|
x_THW_B_D,
|
||||||
|
context_M_B_D,
|
||||||
|
crossattn_mask,
|
||||||
|
rope_emb=rope_emb_L_1_1_D,
|
||||||
|
)
|
||||||
|
x_T_H_W_B_D = rearrange(x_THW_B_D, "(t h w) b d -> t h w b d", h=H, w=W)
|
||||||
|
return x_T_H_W_B_D
|
||||||
|
|
||||||
|
|
||||||
|
def adaln_norm_state(norm_state, x, scale, shift):
|
||||||
|
normalized = norm_state(x)
|
||||||
|
return normalized * (1 + scale) + shift
|
||||||
|
|
||||||
|
|
||||||
|
class DITBuildingBlock(nn.Module):
|
||||||
|
"""
|
||||||
|
A building block for the DiT (Diffusion Transformer) architecture that supports different types of
|
||||||
|
attention and MLP operations with adaptive layer normalization.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
block_type (str): Type of block - one of:
|
||||||
|
- "cross_attn"/"ca": Cross-attention
|
||||||
|
- "full_attn"/"fa": Full self-attention
|
||||||
|
- "mlp"/"ff": MLP/feedforward block
|
||||||
|
x_dim (int): Dimension of input features
|
||||||
|
context_dim (Optional[int]): Dimension of context features for cross-attention
|
||||||
|
num_heads (int): Number of attention heads
|
||||||
|
mlp_ratio (float): MLP hidden dimension multiplier. Default: 4.0
|
||||||
|
bias (bool): Whether to use bias in layers. Default: False
|
||||||
|
mlp_dropout (float): Dropout rate for MLP. Default: 0.0
|
||||||
|
qkv_norm_mode (str): QKV normalization mode. Default: "per_head"
|
||||||
|
x_format (str): Input tensor format. Default: "BTHWD"
|
||||||
|
use_adaln_lora (bool): Whether to use AdaLN-LoRA. Default: False
|
||||||
|
adaln_lora_dim (int): Dimension for AdaLN-LoRA. Default: 256
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
block_type: str,
|
||||||
|
x_dim: int,
|
||||||
|
context_dim: Optional[int],
|
||||||
|
num_heads: int,
|
||||||
|
mlp_ratio: float = 4.0,
|
||||||
|
bias: bool = False,
|
||||||
|
mlp_dropout: float = 0.0,
|
||||||
|
qkv_norm_mode: str = "per_head",
|
||||||
|
x_format: str = "BTHWD",
|
||||||
|
use_adaln_lora: bool = False,
|
||||||
|
adaln_lora_dim: int = 256,
|
||||||
|
weight_args={},
|
||||||
|
operations=None
|
||||||
|
) -> None:
|
||||||
|
block_type = block_type.lower()
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
self.x_format = x_format
|
||||||
|
if block_type in ["cross_attn", "ca"]:
|
||||||
|
self.block = VideoAttn(
|
||||||
|
x_dim,
|
||||||
|
context_dim,
|
||||||
|
num_heads,
|
||||||
|
bias=bias,
|
||||||
|
qkv_norm_mode=qkv_norm_mode,
|
||||||
|
x_format=self.x_format,
|
||||||
|
weight_args=weight_args,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
elif block_type in ["full_attn", "fa"]:
|
||||||
|
self.block = VideoAttn(
|
||||||
|
x_dim, None, num_heads, bias=bias, qkv_norm_mode=qkv_norm_mode, x_format=self.x_format, weight_args=weight_args, operations=operations
|
||||||
|
)
|
||||||
|
elif block_type in ["mlp", "ff"]:
|
||||||
|
self.block = GPT2FeedForward(x_dim, int(x_dim * mlp_ratio), dropout=mlp_dropout, bias=bias, weight_args=weight_args, operations=operations)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown block type: {block_type}")
|
||||||
|
|
||||||
|
self.block_type = block_type
|
||||||
|
self.use_adaln_lora = use_adaln_lora
|
||||||
|
|
||||||
|
self.norm_state = nn.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6)
|
||||||
|
self.n_adaln_chunks = 3
|
||||||
|
if use_adaln_lora:
|
||||||
|
self.adaLN_modulation = nn.Sequential(
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Linear(x_dim, adaln_lora_dim, bias=False, **weight_args),
|
||||||
|
operations.Linear(adaln_lora_dim, self.n_adaln_chunks * x_dim, bias=False, **weight_args),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, self.n_adaln_chunks * x_dim, bias=False, **weight_args))
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
emb_B_D: torch.Tensor,
|
||||||
|
crossattn_emb: torch.Tensor,
|
||||||
|
crossattn_mask: Optional[torch.Tensor] = None,
|
||||||
|
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
|
||||||
|
adaln_lora_B_3D: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Forward pass for dynamically configured blocks with adaptive normalization.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (Tensor): Input tensor of shape (B, T, H, W, D) or (T, H, W, B, D).
|
||||||
|
emb_B_D (Tensor): Embedding tensor for adaptive layer normalization modulation.
|
||||||
|
crossattn_emb (Tensor): Tensor for cross-attention blocks.
|
||||||
|
crossattn_mask (Optional[Tensor]): Optional mask for cross-attention.
|
||||||
|
rope_emb_L_1_1_D (Optional[Tensor]):
|
||||||
|
Rotary positional embedding tensor of shape (L, 1, 1, D). L == THW for current video training.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: The output tensor after processing through the configured block and adaptive normalization.
|
||||||
|
"""
|
||||||
|
if self.use_adaln_lora:
|
||||||
|
shift_B_D, scale_B_D, gate_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D).chunk(
|
||||||
|
self.n_adaln_chunks, dim=1
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
shift_B_D, scale_B_D, gate_B_D = self.adaLN_modulation(emb_B_D).chunk(self.n_adaln_chunks, dim=1)
|
||||||
|
|
||||||
|
shift_1_1_1_B_D, scale_1_1_1_B_D, gate_1_1_1_B_D = (
|
||||||
|
shift_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0),
|
||||||
|
scale_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0),
|
||||||
|
gate_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0),
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.block_type in ["mlp", "ff"]:
|
||||||
|
x = x + gate_1_1_1_B_D * self.block(
|
||||||
|
adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D),
|
||||||
|
)
|
||||||
|
elif self.block_type in ["full_attn", "fa"]:
|
||||||
|
x = x + gate_1_1_1_B_D * self.block(
|
||||||
|
adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D),
|
||||||
|
context=None,
|
||||||
|
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
|
||||||
|
)
|
||||||
|
elif self.block_type in ["cross_attn", "ca"]:
|
||||||
|
x = x + gate_1_1_1_B_D * self.block(
|
||||||
|
adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D),
|
||||||
|
context=crossattn_emb,
|
||||||
|
crossattn_mask=crossattn_mask,
|
||||||
|
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown block type: {self.block_type}")
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class GeneralDITTransformerBlock(nn.Module):
|
||||||
|
"""
|
||||||
|
A wrapper module that manages a sequence of DITBuildingBlocks to form a complete transformer layer.
|
||||||
|
Each block in the sequence is specified by a block configuration string.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
x_dim (int): Dimension of input features
|
||||||
|
context_dim (int): Dimension of context features for cross-attention blocks
|
||||||
|
num_heads (int): Number of attention heads
|
||||||
|
block_config (str): String specifying block sequence (e.g. "ca-fa-mlp" for cross-attention,
|
||||||
|
full-attention, then MLP)
|
||||||
|
mlp_ratio (float): MLP hidden dimension multiplier. Default: 4.0
|
||||||
|
x_format (str): Input tensor format. Default: "BTHWD"
|
||||||
|
use_adaln_lora (bool): Whether to use AdaLN-LoRA. Default: False
|
||||||
|
adaln_lora_dim (int): Dimension for AdaLN-LoRA. Default: 256
|
||||||
|
|
||||||
|
The block_config string uses "-" to separate block types:
|
||||||
|
- "ca"/"cross_attn": Cross-attention block
|
||||||
|
- "fa"/"full_attn": Full self-attention block
|
||||||
|
- "mlp"/"ff": MLP/feedforward block
|
||||||
|
|
||||||
|
Example:
|
||||||
|
block_config = "ca-fa-mlp" creates a sequence of:
|
||||||
|
1. Cross-attention block
|
||||||
|
2. Full self-attention block
|
||||||
|
3. MLP block
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
x_dim: int,
|
||||||
|
context_dim: int,
|
||||||
|
num_heads: int,
|
||||||
|
block_config: str,
|
||||||
|
mlp_ratio: float = 4.0,
|
||||||
|
x_format: str = "BTHWD",
|
||||||
|
use_adaln_lora: bool = False,
|
||||||
|
adaln_lora_dim: int = 256,
|
||||||
|
weight_args={},
|
||||||
|
operations=None
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.blocks = nn.ModuleList()
|
||||||
|
self.x_format = x_format
|
||||||
|
for block_type in block_config.split("-"):
|
||||||
|
self.blocks.append(
|
||||||
|
DITBuildingBlock(
|
||||||
|
block_type,
|
||||||
|
x_dim,
|
||||||
|
context_dim,
|
||||||
|
num_heads,
|
||||||
|
mlp_ratio,
|
||||||
|
x_format=self.x_format,
|
||||||
|
use_adaln_lora=use_adaln_lora,
|
||||||
|
adaln_lora_dim=adaln_lora_dim,
|
||||||
|
weight_args=weight_args,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
emb_B_D: torch.Tensor,
|
||||||
|
crossattn_emb: torch.Tensor,
|
||||||
|
crossattn_mask: Optional[torch.Tensor] = None,
|
||||||
|
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
|
||||||
|
adaln_lora_B_3D: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
for block in self.blocks:
|
||||||
|
x = block(
|
||||||
|
x,
|
||||||
|
emb_B_D,
|
||||||
|
crossattn_emb,
|
||||||
|
crossattn_mask,
|
||||||
|
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
|
||||||
|
adaln_lora_B_3D=adaln_lora_B_3D,
|
||||||
|
)
|
||||||
|
return x
|
||||||
1041
comfy/ldm/cosmos/cosmos_tokenizer/layers3d.py
Normal file
1041
comfy/ldm/cosmos/cosmos_tokenizer/layers3d.py
Normal file
File diff suppressed because it is too large
Load Diff
377
comfy/ldm/cosmos/cosmos_tokenizer/patching.py
Normal file
377
comfy/ldm/cosmos/cosmos_tokenizer/patching.py
Normal file
@@ -0,0 +1,377 @@
|
|||||||
|
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""The patcher and unpatcher implementation for 2D and 3D data.
|
||||||
|
|
||||||
|
The idea of Haar wavelet is to compute LL, LH, HL, HH component as two 1D convolutions.
|
||||||
|
One on the rows and one on the columns.
|
||||||
|
For example, in 1D signal, we have [a, b], then the low-freq compoenent is [a + b] / 2 and high-freq is [a - b] / 2.
|
||||||
|
We can use a 1D convolution with kernel [1, 1] and stride 2 to represent the L component.
|
||||||
|
For H component, we can use a 1D convolution with kernel [1, -1] and stride 2.
|
||||||
|
Although in principle, we typically only do additional Haar wavelet over the LL component. But here we do it for all
|
||||||
|
as we need to support downsampling for more than 2x.
|
||||||
|
For example, 4x downsampling can be done by 2x Haar and additional 2x Haar, and the shape would be.
|
||||||
|
[3, 256, 256] -> [12, 128, 128] -> [48, 64, 64]
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
_WAVELETS = {
|
||||||
|
"haar": torch.tensor([0.7071067811865476, 0.7071067811865476]),
|
||||||
|
"rearrange": torch.tensor([1.0, 1.0]),
|
||||||
|
}
|
||||||
|
_PERSISTENT = False
|
||||||
|
|
||||||
|
|
||||||
|
class Patcher(torch.nn.Module):
|
||||||
|
"""A module to convert image tensors into patches using torch operations.
|
||||||
|
|
||||||
|
The main difference from `class Patching` is that this module implements
|
||||||
|
all operations using torch, rather than python or numpy, for efficiency purpose.
|
||||||
|
|
||||||
|
It's bit-wise identical to the Patching module outputs, with the added
|
||||||
|
benefit of being torch.jit scriptable.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, patch_size=1, patch_method="haar"):
|
||||||
|
super().__init__()
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.patch_method = patch_method
|
||||||
|
self.register_buffer(
|
||||||
|
"wavelets", _WAVELETS[patch_method], persistent=_PERSISTENT
|
||||||
|
)
|
||||||
|
self.range = range(int(torch.log2(torch.tensor(self.patch_size)).item()))
|
||||||
|
self.register_buffer(
|
||||||
|
"_arange",
|
||||||
|
torch.arange(_WAVELETS[patch_method].shape[0]),
|
||||||
|
persistent=_PERSISTENT,
|
||||||
|
)
|
||||||
|
for param in self.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.patch_method == "haar":
|
||||||
|
return self._haar(x)
|
||||||
|
elif self.patch_method == "rearrange":
|
||||||
|
return self._arrange(x)
|
||||||
|
else:
|
||||||
|
raise ValueError("Unknown patch method: " + self.patch_method)
|
||||||
|
|
||||||
|
def _dwt(self, x, mode="reflect", rescale=False):
|
||||||
|
dtype = x.dtype
|
||||||
|
h = self.wavelets.to(device=x.device)
|
||||||
|
|
||||||
|
n = h.shape[0]
|
||||||
|
g = x.shape[1]
|
||||||
|
hl = h.flip(0).reshape(1, 1, -1).repeat(g, 1, 1)
|
||||||
|
hh = (h * ((-1) ** self._arange.to(device=x.device))).reshape(1, 1, -1).repeat(g, 1, 1)
|
||||||
|
hh = hh.to(dtype=dtype)
|
||||||
|
hl = hl.to(dtype=dtype)
|
||||||
|
|
||||||
|
x = F.pad(x, pad=(n - 2, n - 1, n - 2, n - 1), mode=mode).to(dtype)
|
||||||
|
xl = F.conv2d(x, hl.unsqueeze(2), groups=g, stride=(1, 2))
|
||||||
|
xh = F.conv2d(x, hh.unsqueeze(2), groups=g, stride=(1, 2))
|
||||||
|
xll = F.conv2d(xl, hl.unsqueeze(3), groups=g, stride=(2, 1))
|
||||||
|
xlh = F.conv2d(xl, hh.unsqueeze(3), groups=g, stride=(2, 1))
|
||||||
|
xhl = F.conv2d(xh, hl.unsqueeze(3), groups=g, stride=(2, 1))
|
||||||
|
xhh = F.conv2d(xh, hh.unsqueeze(3), groups=g, stride=(2, 1))
|
||||||
|
|
||||||
|
out = torch.cat([xll, xlh, xhl, xhh], dim=1)
|
||||||
|
if rescale:
|
||||||
|
out = out / 2
|
||||||
|
return out
|
||||||
|
|
||||||
|
def _haar(self, x):
|
||||||
|
for _ in self.range:
|
||||||
|
x = self._dwt(x, rescale=True)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def _arrange(self, x):
|
||||||
|
x = rearrange(
|
||||||
|
x,
|
||||||
|
"b c (h p1) (w p2) -> b (c p1 p2) h w",
|
||||||
|
p1=self.patch_size,
|
||||||
|
p2=self.patch_size,
|
||||||
|
).contiguous()
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Patcher3D(Patcher):
|
||||||
|
"""A 3D discrete wavelet transform for video data, expects 5D tensor, i.e. a batch of videos."""
|
||||||
|
|
||||||
|
def __init__(self, patch_size=1, patch_method="haar"):
|
||||||
|
super().__init__(patch_method=patch_method, patch_size=patch_size)
|
||||||
|
self.register_buffer(
|
||||||
|
"patch_size_buffer",
|
||||||
|
patch_size * torch.ones([1], dtype=torch.int32),
|
||||||
|
persistent=_PERSISTENT,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _dwt(self, x, wavelet, mode="reflect", rescale=False):
|
||||||
|
dtype = x.dtype
|
||||||
|
h = self.wavelets.to(device=x.device)
|
||||||
|
|
||||||
|
n = h.shape[0]
|
||||||
|
g = x.shape[1]
|
||||||
|
hl = h.flip(0).reshape(1, 1, -1).repeat(g, 1, 1)
|
||||||
|
hh = (h * ((-1) ** self._arange.to(device=x.device))).reshape(1, 1, -1).repeat(g, 1, 1)
|
||||||
|
hh = hh.to(dtype=dtype)
|
||||||
|
hl = hl.to(dtype=dtype)
|
||||||
|
|
||||||
|
# Handles temporal axis.
|
||||||
|
x = F.pad(
|
||||||
|
x, pad=(max(0, n - 2), n - 1, n - 2, n - 1, n - 2, n - 1), mode=mode
|
||||||
|
).to(dtype)
|
||||||
|
xl = F.conv3d(x, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1))
|
||||||
|
xh = F.conv3d(x, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1))
|
||||||
|
|
||||||
|
# Handles spatial axes.
|
||||||
|
xll = F.conv3d(xl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
|
||||||
|
xlh = F.conv3d(xl, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
|
||||||
|
xhl = F.conv3d(xh, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
|
||||||
|
xhh = F.conv3d(xh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
|
||||||
|
|
||||||
|
xlll = F.conv3d(xll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
|
||||||
|
xllh = F.conv3d(xll, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
|
||||||
|
xlhl = F.conv3d(xlh, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
|
||||||
|
xlhh = F.conv3d(xlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
|
||||||
|
xhll = F.conv3d(xhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
|
||||||
|
xhlh = F.conv3d(xhl, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
|
||||||
|
xhhl = F.conv3d(xhh, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
|
||||||
|
xhhh = F.conv3d(xhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
|
||||||
|
|
||||||
|
out = torch.cat([xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh], dim=1)
|
||||||
|
if rescale:
|
||||||
|
out = out / (2 * torch.sqrt(torch.tensor(2.0)))
|
||||||
|
return out
|
||||||
|
|
||||||
|
def _haar(self, x):
|
||||||
|
xi, xv = torch.split(x, [1, x.shape[2] - 1], dim=2)
|
||||||
|
x = torch.cat([xi.repeat_interleave(self.patch_size, dim=2), xv], dim=2)
|
||||||
|
for _ in self.range:
|
||||||
|
x = self._dwt(x, "haar", rescale=True)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def _arrange(self, x):
|
||||||
|
xi, xv = torch.split(x, [1, x.shape[2] - 1], dim=2)
|
||||||
|
x = torch.cat([xi.repeat_interleave(self.patch_size, dim=2), xv], dim=2)
|
||||||
|
x = rearrange(
|
||||||
|
x,
|
||||||
|
"b c (t p1) (h p2) (w p3) -> b (c p1 p2 p3) t h w",
|
||||||
|
p1=self.patch_size,
|
||||||
|
p2=self.patch_size,
|
||||||
|
p3=self.patch_size,
|
||||||
|
).contiguous()
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class UnPatcher(torch.nn.Module):
|
||||||
|
"""A module to convert patches into image tensorsusing torch operations.
|
||||||
|
|
||||||
|
The main difference from `class Unpatching` is that this module implements
|
||||||
|
all operations using torch, rather than python or numpy, for efficiency purpose.
|
||||||
|
|
||||||
|
It's bit-wise identical to the Unpatching module outputs, with the added
|
||||||
|
benefit of being torch.jit scriptable.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, patch_size=1, patch_method="haar"):
|
||||||
|
super().__init__()
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.patch_method = patch_method
|
||||||
|
self.register_buffer(
|
||||||
|
"wavelets", _WAVELETS[patch_method], persistent=_PERSISTENT
|
||||||
|
)
|
||||||
|
self.range = range(int(torch.log2(torch.tensor(self.patch_size)).item()))
|
||||||
|
self.register_buffer(
|
||||||
|
"_arange",
|
||||||
|
torch.arange(_WAVELETS[patch_method].shape[0]),
|
||||||
|
persistent=_PERSISTENT,
|
||||||
|
)
|
||||||
|
for param in self.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.patch_method == "haar":
|
||||||
|
return self._ihaar(x)
|
||||||
|
elif self.patch_method == "rearrange":
|
||||||
|
return self._iarrange(x)
|
||||||
|
else:
|
||||||
|
raise ValueError("Unknown patch method: " + self.patch_method)
|
||||||
|
|
||||||
|
def _idwt(self, x, wavelet="haar", mode="reflect", rescale=False):
|
||||||
|
dtype = x.dtype
|
||||||
|
h = self.wavelets.to(device=x.device)
|
||||||
|
n = h.shape[0]
|
||||||
|
|
||||||
|
g = x.shape[1] // 4
|
||||||
|
hl = h.flip([0]).reshape(1, 1, -1).repeat([g, 1, 1])
|
||||||
|
hh = (h * ((-1) ** self._arange.to(device=x.device))).reshape(1, 1, -1).repeat(g, 1, 1)
|
||||||
|
hh = hh.to(dtype=dtype)
|
||||||
|
hl = hl.to(dtype=dtype)
|
||||||
|
|
||||||
|
xll, xlh, xhl, xhh = torch.chunk(x.to(dtype), 4, dim=1)
|
||||||
|
|
||||||
|
# Inverse transform.
|
||||||
|
yl = torch.nn.functional.conv_transpose2d(
|
||||||
|
xll, hl.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)
|
||||||
|
)
|
||||||
|
yl += torch.nn.functional.conv_transpose2d(
|
||||||
|
xlh, hh.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)
|
||||||
|
)
|
||||||
|
yh = torch.nn.functional.conv_transpose2d(
|
||||||
|
xhl, hl.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)
|
||||||
|
)
|
||||||
|
yh += torch.nn.functional.conv_transpose2d(
|
||||||
|
xhh, hh.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)
|
||||||
|
)
|
||||||
|
y = torch.nn.functional.conv_transpose2d(
|
||||||
|
yl, hl.unsqueeze(2), groups=g, stride=(1, 2), padding=(0, n - 2)
|
||||||
|
)
|
||||||
|
y += torch.nn.functional.conv_transpose2d(
|
||||||
|
yh, hh.unsqueeze(2), groups=g, stride=(1, 2), padding=(0, n - 2)
|
||||||
|
)
|
||||||
|
|
||||||
|
if rescale:
|
||||||
|
y = y * 2
|
||||||
|
return y
|
||||||
|
|
||||||
|
def _ihaar(self, x):
|
||||||
|
for _ in self.range:
|
||||||
|
x = self._idwt(x, "haar", rescale=True)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def _iarrange(self, x):
|
||||||
|
x = rearrange(
|
||||||
|
x,
|
||||||
|
"b (c p1 p2) h w -> b c (h p1) (w p2)",
|
||||||
|
p1=self.patch_size,
|
||||||
|
p2=self.patch_size,
|
||||||
|
)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class UnPatcher3D(UnPatcher):
|
||||||
|
"""A 3D inverse discrete wavelet transform for video wavelet decompositions."""
|
||||||
|
|
||||||
|
def __init__(self, patch_size=1, patch_method="haar"):
|
||||||
|
super().__init__(patch_method=patch_method, patch_size=patch_size)
|
||||||
|
|
||||||
|
def _idwt(self, x, wavelet="haar", mode="reflect", rescale=False):
|
||||||
|
dtype = x.dtype
|
||||||
|
h = self.wavelets.to(device=x.device)
|
||||||
|
|
||||||
|
g = x.shape[1] // 8 # split into 8 spatio-temporal filtered tesnors.
|
||||||
|
hl = h.flip([0]).reshape(1, 1, -1).repeat([g, 1, 1])
|
||||||
|
hh = (h * ((-1) ** self._arange.to(device=x.device))).reshape(1, 1, -1).repeat(g, 1, 1)
|
||||||
|
hl = hl.to(dtype=dtype)
|
||||||
|
hh = hh.to(dtype=dtype)
|
||||||
|
|
||||||
|
xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh = torch.chunk(x, 8, dim=1)
|
||||||
|
del x
|
||||||
|
|
||||||
|
# Height height transposed convolutions.
|
||||||
|
xll = F.conv_transpose3d(
|
||||||
|
xlll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
|
||||||
|
)
|
||||||
|
del xlll
|
||||||
|
|
||||||
|
xll += F.conv_transpose3d(
|
||||||
|
xllh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
|
||||||
|
)
|
||||||
|
del xllh
|
||||||
|
|
||||||
|
xlh = F.conv_transpose3d(
|
||||||
|
xlhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
|
||||||
|
)
|
||||||
|
del xlhl
|
||||||
|
|
||||||
|
xlh += F.conv_transpose3d(
|
||||||
|
xlhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
|
||||||
|
)
|
||||||
|
del xlhh
|
||||||
|
|
||||||
|
xhl = F.conv_transpose3d(
|
||||||
|
xhll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
|
||||||
|
)
|
||||||
|
del xhll
|
||||||
|
|
||||||
|
xhl += F.conv_transpose3d(
|
||||||
|
xhlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
|
||||||
|
)
|
||||||
|
del xhlh
|
||||||
|
|
||||||
|
xhh = F.conv_transpose3d(
|
||||||
|
xhhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
|
||||||
|
)
|
||||||
|
del xhhl
|
||||||
|
|
||||||
|
xhh += F.conv_transpose3d(
|
||||||
|
xhhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
|
||||||
|
)
|
||||||
|
del xhhh
|
||||||
|
|
||||||
|
# Handles width transposed convolutions.
|
||||||
|
xl = F.conv_transpose3d(
|
||||||
|
xll, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
|
||||||
|
)
|
||||||
|
del xll
|
||||||
|
|
||||||
|
xl += F.conv_transpose3d(
|
||||||
|
xlh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
|
||||||
|
)
|
||||||
|
del xlh
|
||||||
|
|
||||||
|
xh = F.conv_transpose3d(
|
||||||
|
xhl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
|
||||||
|
)
|
||||||
|
del xhl
|
||||||
|
|
||||||
|
xh += F.conv_transpose3d(
|
||||||
|
xhh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
|
||||||
|
)
|
||||||
|
del xhh
|
||||||
|
|
||||||
|
# Handles time axis transposed convolutions.
|
||||||
|
x = F.conv_transpose3d(
|
||||||
|
xl, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)
|
||||||
|
)
|
||||||
|
del xl
|
||||||
|
|
||||||
|
x += F.conv_transpose3d(
|
||||||
|
xh, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
if rescale:
|
||||||
|
x = x * (2 * torch.sqrt(torch.tensor(2.0)))
|
||||||
|
return x
|
||||||
|
|
||||||
|
def _ihaar(self, x):
|
||||||
|
for _ in self.range:
|
||||||
|
x = self._idwt(x, "haar", rescale=True)
|
||||||
|
x = x[:, :, self.patch_size - 1 :, ...]
|
||||||
|
return x
|
||||||
|
|
||||||
|
def _iarrange(self, x):
|
||||||
|
x = rearrange(
|
||||||
|
x,
|
||||||
|
"b (c p1 p2 p3) t h w -> b c (t p1) (h p2) (w p3)",
|
||||||
|
p1=self.patch_size,
|
||||||
|
p2=self.patch_size,
|
||||||
|
p3=self.patch_size,
|
||||||
|
)
|
||||||
|
x = x[:, :, self.patch_size - 1 :, ...]
|
||||||
|
return x
|
||||||
112
comfy/ldm/cosmos/cosmos_tokenizer/utils.py
Normal file
112
comfy/ldm/cosmos/cosmos_tokenizer/utils.py
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Shared utilities for the networks module."""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
|
||||||
|
import comfy.ops
|
||||||
|
ops = comfy.ops.disable_weight_init
|
||||||
|
|
||||||
|
def time2batch(x: torch.Tensor) -> tuple[torch.Tensor, int]:
|
||||||
|
batch_size = x.shape[0]
|
||||||
|
return rearrange(x, "b c t h w -> (b t) c h w"), batch_size
|
||||||
|
|
||||||
|
|
||||||
|
def batch2time(x: torch.Tensor, batch_size: int) -> torch.Tensor:
|
||||||
|
return rearrange(x, "(b t) c h w -> b c t h w", b=batch_size)
|
||||||
|
|
||||||
|
|
||||||
|
def space2batch(x: torch.Tensor) -> tuple[torch.Tensor, int]:
|
||||||
|
batch_size, height = x.shape[0], x.shape[-2]
|
||||||
|
return rearrange(x, "b c t h w -> (b h w) c t"), batch_size, height
|
||||||
|
|
||||||
|
|
||||||
|
def batch2space(x: torch.Tensor, batch_size: int, height: int) -> torch.Tensor:
|
||||||
|
return rearrange(x, "(b h w) c t -> b c t h w", b=batch_size, h=height)
|
||||||
|
|
||||||
|
|
||||||
|
def cast_tuple(t: Any, length: int = 1) -> Any:
|
||||||
|
return t if isinstance(t, tuple) else ((t,) * length)
|
||||||
|
|
||||||
|
|
||||||
|
def replication_pad(x):
|
||||||
|
return torch.cat([x[:, :, :1, ...], x], dim=2)
|
||||||
|
|
||||||
|
|
||||||
|
def divisible_by(num: int, den: int) -> bool:
|
||||||
|
return (num % den) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def is_odd(n: int) -> bool:
|
||||||
|
return not divisible_by(n, 2)
|
||||||
|
|
||||||
|
|
||||||
|
def nonlinearity(x):
|
||||||
|
return x * torch.sigmoid(x)
|
||||||
|
|
||||||
|
|
||||||
|
def Normalize(in_channels, num_groups=32):
|
||||||
|
return ops.GroupNorm(
|
||||||
|
num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CausalNormalize(torch.nn.Module):
|
||||||
|
def __init__(self, in_channels, num_groups=1):
|
||||||
|
super().__init__()
|
||||||
|
self.norm = ops.GroupNorm(
|
||||||
|
num_groups=num_groups,
|
||||||
|
num_channels=in_channels,
|
||||||
|
eps=1e-6,
|
||||||
|
affine=True,
|
||||||
|
)
|
||||||
|
self.num_groups = num_groups
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# if num_groups !=1, we apply a spatio-temporal groupnorm for backward compatibility purpose.
|
||||||
|
# All new models should use num_groups=1, otherwise causality is not guaranteed.
|
||||||
|
if self.num_groups == 1:
|
||||||
|
x, batch_size = time2batch(x)
|
||||||
|
return batch2time(self.norm(x), batch_size)
|
||||||
|
return self.norm(x)
|
||||||
|
|
||||||
|
|
||||||
|
def exists(v):
|
||||||
|
return v is not None
|
||||||
|
|
||||||
|
|
||||||
|
def default(*args):
|
||||||
|
for arg in args:
|
||||||
|
if exists(arg):
|
||||||
|
return arg
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def round_ste(z: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Round with straight through gradients."""
|
||||||
|
zhat = z.round()
|
||||||
|
return z + (zhat - z).detach()
|
||||||
|
|
||||||
|
|
||||||
|
def log(t, eps=1e-5):
|
||||||
|
return t.clamp(min=eps).log()
|
||||||
|
|
||||||
|
|
||||||
|
def entropy(prob):
|
||||||
|
return (-prob * log(prob)).sum(dim=-1)
|
||||||
514
comfy/ldm/cosmos/model.py
Normal file
514
comfy/ldm/cosmos/model.py
Normal file
@@ -0,0 +1,514 @@
|
|||||||
|
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from einops import rearrange
|
||||||
|
from torch import nn
|
||||||
|
from torchvision import transforms
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
|
||||||
|
|
||||||
|
from .blocks import (
|
||||||
|
FinalLayer,
|
||||||
|
GeneralDITTransformerBlock,
|
||||||
|
PatchEmbed,
|
||||||
|
TimestepEmbedding,
|
||||||
|
Timesteps,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .position_embedding import LearnablePosEmbAxis, VideoRopePosition3DEmb
|
||||||
|
|
||||||
|
|
||||||
|
class DataType(Enum):
|
||||||
|
IMAGE = "image"
|
||||||
|
VIDEO = "video"
|
||||||
|
|
||||||
|
|
||||||
|
class GeneralDIT(nn.Module):
|
||||||
|
"""
|
||||||
|
A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_img_h (int): Maximum height of the input images.
|
||||||
|
max_img_w (int): Maximum width of the input images.
|
||||||
|
max_frames (int): Maximum number of frames in the video sequence.
|
||||||
|
in_channels (int): Number of input channels (e.g., RGB channels for color images).
|
||||||
|
out_channels (int): Number of output channels.
|
||||||
|
patch_spatial (tuple): Spatial resolution of patches for input processing.
|
||||||
|
patch_temporal (int): Temporal resolution of patches for input processing.
|
||||||
|
concat_padding_mask (bool): If True, includes a mask channel in the input to handle padding.
|
||||||
|
block_config (str): Configuration of the transformer block. See Notes for supported block types.
|
||||||
|
model_channels (int): Base number of channels used throughout the model.
|
||||||
|
num_blocks (int): Number of transformer blocks.
|
||||||
|
num_heads (int): Number of heads in the multi-head attention layers.
|
||||||
|
mlp_ratio (float): Expansion ratio for MLP blocks.
|
||||||
|
block_x_format (str): Format of input tensor for transformer blocks ('BTHWD' or 'THWBD').
|
||||||
|
crossattn_emb_channels (int): Number of embedding channels for cross-attention.
|
||||||
|
use_cross_attn_mask (bool): Whether to use mask in cross-attention.
|
||||||
|
pos_emb_cls (str): Type of positional embeddings.
|
||||||
|
pos_emb_learnable (bool): Whether positional embeddings are learnable.
|
||||||
|
pos_emb_interpolation (str): Method for interpolating positional embeddings.
|
||||||
|
affline_emb_norm (bool): Whether to normalize affine embeddings.
|
||||||
|
use_adaln_lora (bool): Whether to use AdaLN-LoRA.
|
||||||
|
adaln_lora_dim (int): Dimension for AdaLN-LoRA.
|
||||||
|
rope_h_extrapolation_ratio (float): Height extrapolation ratio for RoPE.
|
||||||
|
rope_w_extrapolation_ratio (float): Width extrapolation ratio for RoPE.
|
||||||
|
rope_t_extrapolation_ratio (float): Temporal extrapolation ratio for RoPE.
|
||||||
|
extra_per_block_abs_pos_emb (bool): Whether to use extra per-block absolute positional embeddings.
|
||||||
|
extra_per_block_abs_pos_emb_type (str): Type of extra per-block positional embeddings.
|
||||||
|
extra_h_extrapolation_ratio (float): Height extrapolation ratio for extra embeddings.
|
||||||
|
extra_w_extrapolation_ratio (float): Width extrapolation ratio for extra embeddings.
|
||||||
|
extra_t_extrapolation_ratio (float): Temporal extrapolation ratio for extra embeddings.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
Supported block types in block_config:
|
||||||
|
* cross_attn, ca: Cross attention
|
||||||
|
* full_attn: Full attention on all flattened tokens
|
||||||
|
* mlp, ff: Feed forward block
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
max_img_h: int,
|
||||||
|
max_img_w: int,
|
||||||
|
max_frames: int,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
patch_spatial: tuple,
|
||||||
|
patch_temporal: int,
|
||||||
|
concat_padding_mask: bool = True,
|
||||||
|
# attention settings
|
||||||
|
block_config: str = "FA-CA-MLP",
|
||||||
|
model_channels: int = 768,
|
||||||
|
num_blocks: int = 10,
|
||||||
|
num_heads: int = 16,
|
||||||
|
mlp_ratio: float = 4.0,
|
||||||
|
block_x_format: str = "BTHWD",
|
||||||
|
# cross attention settings
|
||||||
|
crossattn_emb_channels: int = 1024,
|
||||||
|
use_cross_attn_mask: bool = False,
|
||||||
|
# positional embedding settings
|
||||||
|
pos_emb_cls: str = "sincos",
|
||||||
|
pos_emb_learnable: bool = False,
|
||||||
|
pos_emb_interpolation: str = "crop",
|
||||||
|
affline_emb_norm: bool = False, # whether or not to normalize the affine embedding
|
||||||
|
use_adaln_lora: bool = False,
|
||||||
|
adaln_lora_dim: int = 256,
|
||||||
|
rope_h_extrapolation_ratio: float = 1.0,
|
||||||
|
rope_w_extrapolation_ratio: float = 1.0,
|
||||||
|
rope_t_extrapolation_ratio: float = 1.0,
|
||||||
|
extra_per_block_abs_pos_emb: bool = False,
|
||||||
|
extra_per_block_abs_pos_emb_type: str = "sincos",
|
||||||
|
extra_h_extrapolation_ratio: float = 1.0,
|
||||||
|
extra_w_extrapolation_ratio: float = 1.0,
|
||||||
|
extra_t_extrapolation_ratio: float = 1.0,
|
||||||
|
image_model=None,
|
||||||
|
device=None,
|
||||||
|
dtype=None,
|
||||||
|
operations=None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.max_img_h = max_img_h
|
||||||
|
self.max_img_w = max_img_w
|
||||||
|
self.max_frames = max_frames
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.patch_spatial = patch_spatial
|
||||||
|
self.patch_temporal = patch_temporal
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.num_blocks = num_blocks
|
||||||
|
self.model_channels = model_channels
|
||||||
|
self.use_cross_attn_mask = use_cross_attn_mask
|
||||||
|
self.concat_padding_mask = concat_padding_mask
|
||||||
|
# positional embedding settings
|
||||||
|
self.pos_emb_cls = pos_emb_cls
|
||||||
|
self.pos_emb_learnable = pos_emb_learnable
|
||||||
|
self.pos_emb_interpolation = pos_emb_interpolation
|
||||||
|
self.affline_emb_norm = affline_emb_norm
|
||||||
|
self.rope_h_extrapolation_ratio = rope_h_extrapolation_ratio
|
||||||
|
self.rope_w_extrapolation_ratio = rope_w_extrapolation_ratio
|
||||||
|
self.rope_t_extrapolation_ratio = rope_t_extrapolation_ratio
|
||||||
|
self.extra_per_block_abs_pos_emb = extra_per_block_abs_pos_emb
|
||||||
|
self.extra_per_block_abs_pos_emb_type = extra_per_block_abs_pos_emb_type.lower()
|
||||||
|
self.extra_h_extrapolation_ratio = extra_h_extrapolation_ratio
|
||||||
|
self.extra_w_extrapolation_ratio = extra_w_extrapolation_ratio
|
||||||
|
self.extra_t_extrapolation_ratio = extra_t_extrapolation_ratio
|
||||||
|
self.dtype = dtype
|
||||||
|
weight_args = {"device": device, "dtype": dtype}
|
||||||
|
|
||||||
|
in_channels = in_channels + 1 if concat_padding_mask else in_channels
|
||||||
|
self.x_embedder = PatchEmbed(
|
||||||
|
spatial_patch_size=patch_spatial,
|
||||||
|
temporal_patch_size=patch_temporal,
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=model_channels,
|
||||||
|
bias=False,
|
||||||
|
weight_args=weight_args,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.build_pos_embed(device=device, dtype=dtype)
|
||||||
|
self.block_x_format = block_x_format
|
||||||
|
self.use_adaln_lora = use_adaln_lora
|
||||||
|
self.adaln_lora_dim = adaln_lora_dim
|
||||||
|
self.t_embedder = nn.ModuleList(
|
||||||
|
[Timesteps(model_channels),
|
||||||
|
TimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora, weight_args=weight_args, operations=operations),]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.blocks = nn.ModuleDict()
|
||||||
|
|
||||||
|
for idx in range(num_blocks):
|
||||||
|
self.blocks[f"block{idx}"] = GeneralDITTransformerBlock(
|
||||||
|
x_dim=model_channels,
|
||||||
|
context_dim=crossattn_emb_channels,
|
||||||
|
num_heads=num_heads,
|
||||||
|
block_config=block_config,
|
||||||
|
mlp_ratio=mlp_ratio,
|
||||||
|
x_format=self.block_x_format,
|
||||||
|
use_adaln_lora=use_adaln_lora,
|
||||||
|
adaln_lora_dim=adaln_lora_dim,
|
||||||
|
weight_args=weight_args,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.affline_emb_norm:
|
||||||
|
logging.debug("Building affine embedding normalization layer")
|
||||||
|
self.affline_norm = RMSNorm(model_channels, elementwise_affine=True, eps=1e-6)
|
||||||
|
else:
|
||||||
|
self.affline_norm = nn.Identity()
|
||||||
|
|
||||||
|
self.final_layer = FinalLayer(
|
||||||
|
hidden_size=self.model_channels,
|
||||||
|
spatial_patch_size=self.patch_spatial,
|
||||||
|
temporal_patch_size=self.patch_temporal,
|
||||||
|
out_channels=self.out_channels,
|
||||||
|
use_adaln_lora=self.use_adaln_lora,
|
||||||
|
adaln_lora_dim=self.adaln_lora_dim,
|
||||||
|
weight_args=weight_args,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
|
||||||
|
def build_pos_embed(self, device=None, dtype=None):
|
||||||
|
if self.pos_emb_cls == "rope3d":
|
||||||
|
cls_type = VideoRopePosition3DEmb
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}")
|
||||||
|
|
||||||
|
logging.debug(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}")
|
||||||
|
kwargs = dict(
|
||||||
|
model_channels=self.model_channels,
|
||||||
|
len_h=self.max_img_h // self.patch_spatial,
|
||||||
|
len_w=self.max_img_w // self.patch_spatial,
|
||||||
|
len_t=self.max_frames // self.patch_temporal,
|
||||||
|
is_learnable=self.pos_emb_learnable,
|
||||||
|
interpolation=self.pos_emb_interpolation,
|
||||||
|
head_dim=self.model_channels // self.num_heads,
|
||||||
|
h_extrapolation_ratio=self.rope_h_extrapolation_ratio,
|
||||||
|
w_extrapolation_ratio=self.rope_w_extrapolation_ratio,
|
||||||
|
t_extrapolation_ratio=self.rope_t_extrapolation_ratio,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
self.pos_embedder = cls_type(
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.extra_per_block_abs_pos_emb:
|
||||||
|
assert self.extra_per_block_abs_pos_emb_type in [
|
||||||
|
"learnable",
|
||||||
|
], f"Unknown extra_per_block_abs_pos_emb_type {self.extra_per_block_abs_pos_emb_type}"
|
||||||
|
kwargs["h_extrapolation_ratio"] = self.extra_h_extrapolation_ratio
|
||||||
|
kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio
|
||||||
|
kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio
|
||||||
|
kwargs["device"] = device
|
||||||
|
kwargs["dtype"] = dtype
|
||||||
|
self.extra_pos_embedder = LearnablePosEmbAxis(
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def prepare_embedded_sequence(
|
||||||
|
self,
|
||||||
|
x_B_C_T_H_W: torch.Tensor,
|
||||||
|
fps: Optional[torch.Tensor] = None,
|
||||||
|
padding_mask: Optional[torch.Tensor] = None,
|
||||||
|
latent_condition: Optional[torch.Tensor] = None,
|
||||||
|
latent_condition_sigma: Optional[torch.Tensor] = None,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
|
"""
|
||||||
|
Prepares an embedded sequence tensor by applying positional embeddings and handling padding masks.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x_B_C_T_H_W (torch.Tensor): video
|
||||||
|
fps (Optional[torch.Tensor]): Frames per second tensor to be used for positional embedding when required.
|
||||||
|
If None, a default value (`self.base_fps`) will be used.
|
||||||
|
padding_mask (Optional[torch.Tensor]): current it is not used
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
|
- A tensor of shape (B, T, H, W, D) with the embedded sequence.
|
||||||
|
- An optional positional embedding tensor, returned only if the positional embedding class
|
||||||
|
(`self.pos_emb_cls`) includes 'rope'. Otherwise, None.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
- If `self.concat_padding_mask` is True, a padding mask channel is concatenated to the input tensor.
|
||||||
|
- The method of applying positional embeddings depends on the value of `self.pos_emb_cls`.
|
||||||
|
- If 'rope' is in `self.pos_emb_cls` (case insensitive), the positional embeddings are generated using
|
||||||
|
the `self.pos_embedder` with the shape [T, H, W].
|
||||||
|
- If "fps_aware" is in `self.pos_emb_cls`, the positional embeddings are generated using the
|
||||||
|
`self.pos_embedder` with the fps tensor.
|
||||||
|
- Otherwise, the positional embeddings are generated without considering fps.
|
||||||
|
"""
|
||||||
|
if self.concat_padding_mask:
|
||||||
|
if padding_mask is not None:
|
||||||
|
padding_mask = transforms.functional.resize(
|
||||||
|
padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
padding_mask = torch.zeros((x_B_C_T_H_W.shape[0], 1, x_B_C_T_H_W.shape[-2], x_B_C_T_H_W.shape[-1]), dtype=x_B_C_T_H_W.dtype, device=x_B_C_T_H_W.device)
|
||||||
|
|
||||||
|
x_B_C_T_H_W = torch.cat(
|
||||||
|
[x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1
|
||||||
|
)
|
||||||
|
x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W)
|
||||||
|
|
||||||
|
if self.extra_per_block_abs_pos_emb:
|
||||||
|
extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device, dtype=x_B_C_T_H_W.dtype)
|
||||||
|
else:
|
||||||
|
extra_pos_emb = None
|
||||||
|
|
||||||
|
if "rope" in self.pos_emb_cls.lower():
|
||||||
|
return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device), extra_pos_emb
|
||||||
|
|
||||||
|
if "fps_aware" in self.pos_emb_cls:
|
||||||
|
x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device) # [B, T, H, W, D]
|
||||||
|
else:
|
||||||
|
x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, device=x_B_C_T_H_W.device) # [B, T, H, W, D]
|
||||||
|
|
||||||
|
return x_B_T_H_W_D, None, extra_pos_emb
|
||||||
|
|
||||||
|
def decoder_head(
|
||||||
|
self,
|
||||||
|
x_B_T_H_W_D: torch.Tensor,
|
||||||
|
emb_B_D: torch.Tensor,
|
||||||
|
crossattn_emb: torch.Tensor,
|
||||||
|
origin_shape: Tuple[int, int, int, int, int], # [B, C, T, H, W]
|
||||||
|
crossattn_mask: Optional[torch.Tensor] = None,
|
||||||
|
adaln_lora_B_3D: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
del crossattn_emb, crossattn_mask
|
||||||
|
B, C, T_before_patchify, H_before_patchify, W_before_patchify = origin_shape
|
||||||
|
x_BT_HW_D = rearrange(x_B_T_H_W_D, "B T H W D -> (B T) (H W) D")
|
||||||
|
x_BT_HW_D = self.final_layer(x_BT_HW_D, emb_B_D, adaln_lora_B_3D=adaln_lora_B_3D)
|
||||||
|
# This is to ensure x_BT_HW_D has the correct shape because
|
||||||
|
# when we merge T, H, W into one dimension, x_BT_HW_D has shape (B * T * H * W, 1*1, D).
|
||||||
|
x_BT_HW_D = x_BT_HW_D.view(
|
||||||
|
B * T_before_patchify // self.patch_temporal,
|
||||||
|
H_before_patchify // self.patch_spatial * W_before_patchify // self.patch_spatial,
|
||||||
|
-1,
|
||||||
|
)
|
||||||
|
x_B_D_T_H_W = rearrange(
|
||||||
|
x_BT_HW_D,
|
||||||
|
"(B T) (H W) (p1 p2 t C) -> B C (T t) (H p1) (W p2)",
|
||||||
|
p1=self.patch_spatial,
|
||||||
|
p2=self.patch_spatial,
|
||||||
|
H=H_before_patchify // self.patch_spatial,
|
||||||
|
W=W_before_patchify // self.patch_spatial,
|
||||||
|
t=self.patch_temporal,
|
||||||
|
B=B,
|
||||||
|
)
|
||||||
|
return x_B_D_T_H_W
|
||||||
|
|
||||||
|
def forward_before_blocks(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
timesteps: torch.Tensor,
|
||||||
|
crossattn_emb: torch.Tensor,
|
||||||
|
crossattn_mask: Optional[torch.Tensor] = None,
|
||||||
|
fps: Optional[torch.Tensor] = None,
|
||||||
|
image_size: Optional[torch.Tensor] = None,
|
||||||
|
padding_mask: Optional[torch.Tensor] = None,
|
||||||
|
scalar_feature: Optional[torch.Tensor] = None,
|
||||||
|
data_type: Optional[DataType] = DataType.VIDEO,
|
||||||
|
latent_condition: Optional[torch.Tensor] = None,
|
||||||
|
latent_condition_sigma: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x: (B, C, T, H, W) tensor of spatial-temp inputs
|
||||||
|
timesteps: (B, ) tensor of timesteps
|
||||||
|
crossattn_emb: (B, N, D) tensor of cross-attention embeddings
|
||||||
|
crossattn_mask: (B, N) tensor of cross-attention masks
|
||||||
|
"""
|
||||||
|
del kwargs
|
||||||
|
assert isinstance(
|
||||||
|
data_type, DataType
|
||||||
|
), f"Expected DataType, got {type(data_type)}. We need discuss this flag later."
|
||||||
|
original_shape = x.shape
|
||||||
|
x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence(
|
||||||
|
x,
|
||||||
|
fps=fps,
|
||||||
|
padding_mask=padding_mask,
|
||||||
|
latent_condition=latent_condition,
|
||||||
|
latent_condition_sigma=latent_condition_sigma,
|
||||||
|
)
|
||||||
|
# logging affline scale information
|
||||||
|
affline_scale_log_info = {}
|
||||||
|
|
||||||
|
timesteps_B_D, adaln_lora_B_3D = self.t_embedder[1](self.t_embedder[0](timesteps.flatten()).to(x.dtype))
|
||||||
|
affline_emb_B_D = timesteps_B_D
|
||||||
|
affline_scale_log_info["timesteps_B_D"] = timesteps_B_D.detach()
|
||||||
|
|
||||||
|
if scalar_feature is not None:
|
||||||
|
raise NotImplementedError("Scalar feature is not implemented yet.")
|
||||||
|
|
||||||
|
affline_scale_log_info["affline_emb_B_D"] = affline_emb_B_D.detach()
|
||||||
|
affline_emb_B_D = self.affline_norm(affline_emb_B_D)
|
||||||
|
|
||||||
|
if self.use_cross_attn_mask:
|
||||||
|
if crossattn_mask is not None and not torch.is_floating_point(crossattn_mask):
|
||||||
|
crossattn_mask = (crossattn_mask - 1).to(x.dtype) * torch.finfo(x.dtype).max
|
||||||
|
crossattn_mask = crossattn_mask[:, None, None, :] # .to(dtype=torch.bool) # [B, 1, 1, length]
|
||||||
|
else:
|
||||||
|
crossattn_mask = None
|
||||||
|
|
||||||
|
if self.blocks["block0"].x_format == "THWBD":
|
||||||
|
x = rearrange(x_B_T_H_W_D, "B T H W D -> T H W B D")
|
||||||
|
if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
|
||||||
|
extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = rearrange(
|
||||||
|
extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "B T H W D -> T H W B D"
|
||||||
|
)
|
||||||
|
crossattn_emb = rearrange(crossattn_emb, "B M D -> M B D")
|
||||||
|
|
||||||
|
if crossattn_mask:
|
||||||
|
crossattn_mask = rearrange(crossattn_mask, "B M -> M B")
|
||||||
|
|
||||||
|
elif self.blocks["block0"].x_format == "BTHWD":
|
||||||
|
x = x_B_T_H_W_D
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown x_format {self.blocks[0].x_format}")
|
||||||
|
output = {
|
||||||
|
"x": x,
|
||||||
|
"affline_emb_B_D": affline_emb_B_D,
|
||||||
|
"crossattn_emb": crossattn_emb,
|
||||||
|
"crossattn_mask": crossattn_mask,
|
||||||
|
"rope_emb_L_1_1_D": rope_emb_L_1_1_D,
|
||||||
|
"adaln_lora_B_3D": adaln_lora_B_3D,
|
||||||
|
"original_shape": original_shape,
|
||||||
|
"extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
|
||||||
|
}
|
||||||
|
return output
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
timesteps: torch.Tensor,
|
||||||
|
context: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
# crossattn_emb: torch.Tensor,
|
||||||
|
# crossattn_mask: Optional[torch.Tensor] = None,
|
||||||
|
fps: Optional[torch.Tensor] = None,
|
||||||
|
image_size: Optional[torch.Tensor] = None,
|
||||||
|
padding_mask: Optional[torch.Tensor] = None,
|
||||||
|
scalar_feature: Optional[torch.Tensor] = None,
|
||||||
|
data_type: Optional[DataType] = DataType.VIDEO,
|
||||||
|
latent_condition: Optional[torch.Tensor] = None,
|
||||||
|
latent_condition_sigma: Optional[torch.Tensor] = None,
|
||||||
|
condition_video_augment_sigma: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x: (B, C, T, H, W) tensor of spatial-temp inputs
|
||||||
|
timesteps: (B, ) tensor of timesteps
|
||||||
|
crossattn_emb: (B, N, D) tensor of cross-attention embeddings
|
||||||
|
crossattn_mask: (B, N) tensor of cross-attention masks
|
||||||
|
condition_video_augment_sigma: (B,) used in lvg(long video generation), we add noise with this sigma to
|
||||||
|
augment condition input, the lvg model will condition on the condition_video_augment_sigma value;
|
||||||
|
we need forward_before_blocks pass to the forward_before_blocks function.
|
||||||
|
"""
|
||||||
|
|
||||||
|
crossattn_emb = context
|
||||||
|
crossattn_mask = attention_mask
|
||||||
|
|
||||||
|
inputs = self.forward_before_blocks(
|
||||||
|
x=x,
|
||||||
|
timesteps=timesteps,
|
||||||
|
crossattn_emb=crossattn_emb,
|
||||||
|
crossattn_mask=crossattn_mask,
|
||||||
|
fps=fps,
|
||||||
|
image_size=image_size,
|
||||||
|
padding_mask=padding_mask,
|
||||||
|
scalar_feature=scalar_feature,
|
||||||
|
data_type=data_type,
|
||||||
|
latent_condition=latent_condition,
|
||||||
|
latent_condition_sigma=latent_condition_sigma,
|
||||||
|
condition_video_augment_sigma=condition_video_augment_sigma,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
x, affline_emb_B_D, crossattn_emb, crossattn_mask, rope_emb_L_1_1_D, adaln_lora_B_3D, original_shape = (
|
||||||
|
inputs["x"],
|
||||||
|
inputs["affline_emb_B_D"],
|
||||||
|
inputs["crossattn_emb"],
|
||||||
|
inputs["crossattn_mask"],
|
||||||
|
inputs["rope_emb_L_1_1_D"],
|
||||||
|
inputs["adaln_lora_B_3D"],
|
||||||
|
inputs["original_shape"],
|
||||||
|
)
|
||||||
|
extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = inputs["extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D"].to(x.dtype)
|
||||||
|
del inputs
|
||||||
|
|
||||||
|
if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
|
||||||
|
assert (
|
||||||
|
x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
|
||||||
|
), f"{x.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape} {original_shape}"
|
||||||
|
|
||||||
|
for _, block in self.blocks.items():
|
||||||
|
assert (
|
||||||
|
self.blocks["block0"].x_format == block.x_format
|
||||||
|
), f"First block has x_format {self.blocks[0].x_format}, got {block.x_format}"
|
||||||
|
|
||||||
|
if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
|
||||||
|
x += extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D
|
||||||
|
x = block(
|
||||||
|
x,
|
||||||
|
affline_emb_B_D,
|
||||||
|
crossattn_emb,
|
||||||
|
crossattn_mask,
|
||||||
|
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
|
||||||
|
adaln_lora_B_3D=adaln_lora_B_3D,
|
||||||
|
)
|
||||||
|
|
||||||
|
x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D")
|
||||||
|
|
||||||
|
x_B_D_T_H_W = self.decoder_head(
|
||||||
|
x_B_T_H_W_D=x_B_T_H_W_D,
|
||||||
|
emb_B_D=affline_emb_B_D,
|
||||||
|
crossattn_emb=None,
|
||||||
|
origin_shape=original_shape,
|
||||||
|
crossattn_mask=None,
|
||||||
|
adaln_lora_B_3D=adaln_lora_B_3D,
|
||||||
|
)
|
||||||
|
|
||||||
|
return x_B_D_T_H_W
|
||||||
208
comfy/ldm/cosmos/position_embedding.py
Normal file
208
comfy/ldm/cosmos/position_embedding.py
Normal file
@@ -0,0 +1,208 @@
|
|||||||
|
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
from torch import nn
|
||||||
|
import math
|
||||||
|
|
||||||
|
|
||||||
|
def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 0) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Normalizes the input tensor along specified dimensions such that the average square norm of elements is adjusted.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (torch.Tensor): The input tensor to normalize.
|
||||||
|
dim (list, optional): The dimensions over which to normalize. If None, normalizes over all dimensions except the first.
|
||||||
|
eps (float, optional): A small constant to ensure numerical stability during division.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: The normalized tensor.
|
||||||
|
"""
|
||||||
|
if dim is None:
|
||||||
|
dim = list(range(1, x.ndim))
|
||||||
|
norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
|
||||||
|
norm = torch.add(eps, norm, alpha=math.sqrt(norm.numel() / x.numel()))
|
||||||
|
return x / norm.to(x.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class VideoPositionEmb(nn.Module):
|
||||||
|
def forward(self, x_B_T_H_W_C: torch.Tensor, fps=Optional[torch.Tensor], device=None, dtype=None) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
It delegates the embedding generation to generate_embeddings function.
|
||||||
|
"""
|
||||||
|
B_T_H_W_C = x_B_T_H_W_C.shape
|
||||||
|
embeddings = self.generate_embeddings(B_T_H_W_C, fps=fps, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor], device=None):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class VideoRopePosition3DEmb(VideoPositionEmb):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*, # enforce keyword arguments
|
||||||
|
head_dim: int,
|
||||||
|
len_h: int,
|
||||||
|
len_w: int,
|
||||||
|
len_t: int,
|
||||||
|
base_fps: int = 24,
|
||||||
|
h_extrapolation_ratio: float = 1.0,
|
||||||
|
w_extrapolation_ratio: float = 1.0,
|
||||||
|
t_extrapolation_ratio: float = 1.0,
|
||||||
|
device=None,
|
||||||
|
**kwargs, # used for compatibility with other positional embeddings; unused in this class
|
||||||
|
):
|
||||||
|
del kwargs
|
||||||
|
super().__init__()
|
||||||
|
self.register_buffer("seq", torch.arange(max(len_h, len_w, len_t), dtype=torch.float, device=device))
|
||||||
|
self.base_fps = base_fps
|
||||||
|
self.max_h = len_h
|
||||||
|
self.max_w = len_w
|
||||||
|
|
||||||
|
dim = head_dim
|
||||||
|
dim_h = dim // 6 * 2
|
||||||
|
dim_w = dim_h
|
||||||
|
dim_t = dim - 2 * dim_h
|
||||||
|
assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}"
|
||||||
|
self.register_buffer(
|
||||||
|
"dim_spatial_range",
|
||||||
|
torch.arange(0, dim_h, 2, device=device)[: (dim_h // 2)].float() / dim_h,
|
||||||
|
persistent=False,
|
||||||
|
)
|
||||||
|
self.register_buffer(
|
||||||
|
"dim_temporal_range",
|
||||||
|
torch.arange(0, dim_t, 2, device=device)[: (dim_t // 2)].float() / dim_t,
|
||||||
|
persistent=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.h_ntk_factor = h_extrapolation_ratio ** (dim_h / (dim_h - 2))
|
||||||
|
self.w_ntk_factor = w_extrapolation_ratio ** (dim_w / (dim_w - 2))
|
||||||
|
self.t_ntk_factor = t_extrapolation_ratio ** (dim_t / (dim_t - 2))
|
||||||
|
|
||||||
|
def generate_embeddings(
|
||||||
|
self,
|
||||||
|
B_T_H_W_C: torch.Size,
|
||||||
|
fps: Optional[torch.Tensor] = None,
|
||||||
|
h_ntk_factor: Optional[float] = None,
|
||||||
|
w_ntk_factor: Optional[float] = None,
|
||||||
|
t_ntk_factor: Optional[float] = None,
|
||||||
|
device=None,
|
||||||
|
dtype=None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Generate embeddings for the given input size.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
B_T_H_W_C (torch.Size): Input tensor size (Batch, Time, Height, Width, Channels).
|
||||||
|
fps (Optional[torch.Tensor], optional): Frames per second. Defaults to None.
|
||||||
|
h_ntk_factor (Optional[float], optional): Height NTK factor. If None, uses self.h_ntk_factor.
|
||||||
|
w_ntk_factor (Optional[float], optional): Width NTK factor. If None, uses self.w_ntk_factor.
|
||||||
|
t_ntk_factor (Optional[float], optional): Time NTK factor. If None, uses self.t_ntk_factor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Not specified in the original code snippet.
|
||||||
|
"""
|
||||||
|
h_ntk_factor = h_ntk_factor if h_ntk_factor is not None else self.h_ntk_factor
|
||||||
|
w_ntk_factor = w_ntk_factor if w_ntk_factor is not None else self.w_ntk_factor
|
||||||
|
t_ntk_factor = t_ntk_factor if t_ntk_factor is not None else self.t_ntk_factor
|
||||||
|
|
||||||
|
h_theta = 10000.0 * h_ntk_factor
|
||||||
|
w_theta = 10000.0 * w_ntk_factor
|
||||||
|
t_theta = 10000.0 * t_ntk_factor
|
||||||
|
|
||||||
|
h_spatial_freqs = 1.0 / (h_theta**self.dim_spatial_range.to(device=device))
|
||||||
|
w_spatial_freqs = 1.0 / (w_theta**self.dim_spatial_range.to(device=device))
|
||||||
|
temporal_freqs = 1.0 / (t_theta**self.dim_temporal_range.to(device=device))
|
||||||
|
|
||||||
|
B, T, H, W, _ = B_T_H_W_C
|
||||||
|
uniform_fps = (fps is None) or isinstance(fps, (int, float)) or (fps.min() == fps.max())
|
||||||
|
assert (
|
||||||
|
uniform_fps or B == 1 or T == 1
|
||||||
|
), "For video batch, batch size should be 1 for non-uniform fps. For image batch, T should be 1"
|
||||||
|
assert (
|
||||||
|
H <= self.max_h and W <= self.max_w
|
||||||
|
), f"Input dimensions (H={H}, W={W}) exceed the maximum dimensions (max_h={self.max_h}, max_w={self.max_w})"
|
||||||
|
half_emb_h = torch.outer(self.seq[:H].to(device=device), h_spatial_freqs)
|
||||||
|
half_emb_w = torch.outer(self.seq[:W].to(device=device), w_spatial_freqs)
|
||||||
|
|
||||||
|
# apply sequence scaling in temporal dimension
|
||||||
|
if fps is None: # image case
|
||||||
|
half_emb_t = torch.outer(self.seq[:T].to(device=device), temporal_freqs)
|
||||||
|
else:
|
||||||
|
half_emb_t = torch.outer(self.seq[:T].to(device=device) / fps * self.base_fps, temporal_freqs)
|
||||||
|
|
||||||
|
half_emb_h = torch.stack([torch.cos(half_emb_h), -torch.sin(half_emb_h), torch.sin(half_emb_h), torch.cos(half_emb_h)], dim=-1)
|
||||||
|
half_emb_w = torch.stack([torch.cos(half_emb_w), -torch.sin(half_emb_w), torch.sin(half_emb_w), torch.cos(half_emb_w)], dim=-1)
|
||||||
|
half_emb_t = torch.stack([torch.cos(half_emb_t), -torch.sin(half_emb_t), torch.sin(half_emb_t), torch.cos(half_emb_t)], dim=-1)
|
||||||
|
|
||||||
|
em_T_H_W_D = torch.cat(
|
||||||
|
[
|
||||||
|
repeat(half_emb_t, "t d x -> t h w d x", h=H, w=W),
|
||||||
|
repeat(half_emb_h, "h d x -> t h w d x", t=T, w=W),
|
||||||
|
repeat(half_emb_w, "w d x -> t h w d x", t=T, h=H),
|
||||||
|
]
|
||||||
|
, dim=-2,
|
||||||
|
)
|
||||||
|
|
||||||
|
return rearrange(em_T_H_W_D, "t h w d (i j) -> (t h w) d i j", i=2, j=2).float()
|
||||||
|
|
||||||
|
|
||||||
|
class LearnablePosEmbAxis(VideoPositionEmb):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*, # enforce keyword arguments
|
||||||
|
interpolation: str,
|
||||||
|
model_channels: int,
|
||||||
|
len_h: int,
|
||||||
|
len_w: int,
|
||||||
|
len_t: int,
|
||||||
|
device=None,
|
||||||
|
dtype=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
interpolation (str): we curretly only support "crop", ideally when we need extrapolation capacity, we should adjust frequency or other more advanced methods. they are not implemented yet.
|
||||||
|
"""
|
||||||
|
del kwargs # unused
|
||||||
|
super().__init__()
|
||||||
|
self.interpolation = interpolation
|
||||||
|
assert self.interpolation in ["crop"], f"Unknown interpolation method {self.interpolation}"
|
||||||
|
|
||||||
|
self.pos_emb_h = nn.Parameter(torch.empty(len_h, model_channels, device=device, dtype=dtype))
|
||||||
|
self.pos_emb_w = nn.Parameter(torch.empty(len_w, model_channels, device=device, dtype=dtype))
|
||||||
|
self.pos_emb_t = nn.Parameter(torch.empty(len_t, model_channels, device=device, dtype=dtype))
|
||||||
|
|
||||||
|
def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor], device=None, dtype=None) -> torch.Tensor:
|
||||||
|
B, T, H, W, _ = B_T_H_W_C
|
||||||
|
if self.interpolation == "crop":
|
||||||
|
emb_h_H = self.pos_emb_h[:H].to(device=device, dtype=dtype)
|
||||||
|
emb_w_W = self.pos_emb_w[:W].to(device=device, dtype=dtype)
|
||||||
|
emb_t_T = self.pos_emb_t[:T].to(device=device, dtype=dtype)
|
||||||
|
emb = (
|
||||||
|
repeat(emb_t_T, "t d-> b t h w d", b=B, h=H, w=W)
|
||||||
|
+ repeat(emb_h_H, "h d-> b t h w d", b=B, t=T, w=W)
|
||||||
|
+ repeat(emb_w_W, "w d-> b t h w d", b=B, t=T, h=H)
|
||||||
|
)
|
||||||
|
assert list(emb.shape)[:4] == [B, T, H, W], f"bad shape: {list(emb.shape)[:4]} != {B, T, H, W}"
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown interpolation method {self.interpolation}")
|
||||||
|
|
||||||
|
return normalize(emb, dim=-1, eps=1e-6)
|
||||||
131
comfy/ldm/cosmos/vae.py
Normal file
131
comfy/ldm/cosmos/vae.py
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""The causal continuous video tokenizer with VAE or AE formulation for 3D data.."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from enum import Enum
|
||||||
|
import math
|
||||||
|
|
||||||
|
from .cosmos_tokenizer.layers3d import (
|
||||||
|
EncoderFactorized,
|
||||||
|
DecoderFactorized,
|
||||||
|
CausalConv3d,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class IdentityDistribution(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def forward(self, parameters):
|
||||||
|
return parameters, (torch.tensor([0.0]), torch.tensor([0.0]))
|
||||||
|
|
||||||
|
|
||||||
|
class GaussianDistribution(torch.nn.Module):
|
||||||
|
def __init__(self, min_logvar: float = -30.0, max_logvar: float = 20.0):
|
||||||
|
super().__init__()
|
||||||
|
self.min_logvar = min_logvar
|
||||||
|
self.max_logvar = max_logvar
|
||||||
|
|
||||||
|
def sample(self, mean, logvar):
|
||||||
|
std = torch.exp(0.5 * logvar)
|
||||||
|
return mean + std * torch.randn_like(mean)
|
||||||
|
|
||||||
|
def forward(self, parameters):
|
||||||
|
mean, logvar = torch.chunk(parameters, 2, dim=1)
|
||||||
|
logvar = torch.clamp(logvar, self.min_logvar, self.max_logvar)
|
||||||
|
return self.sample(mean, logvar), (mean, logvar)
|
||||||
|
|
||||||
|
|
||||||
|
class ContinuousFormulation(Enum):
|
||||||
|
VAE = GaussianDistribution
|
||||||
|
AE = IdentityDistribution
|
||||||
|
|
||||||
|
|
||||||
|
class CausalContinuousVideoTokenizer(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, z_channels: int, z_factor: int, latent_channels: int, **kwargs
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.name = kwargs.get("name", "CausalContinuousVideoTokenizer")
|
||||||
|
self.latent_channels = latent_channels
|
||||||
|
self.sigma_data = 0.5
|
||||||
|
|
||||||
|
# encoder_name = kwargs.get("encoder", Encoder3DType.BASE.name)
|
||||||
|
self.encoder = EncoderFactorized(
|
||||||
|
z_channels=z_factor * z_channels, **kwargs
|
||||||
|
)
|
||||||
|
if kwargs.get("temporal_compression", 4) == 4:
|
||||||
|
kwargs["channels_mult"] = [2, 4]
|
||||||
|
# decoder_name = kwargs.get("decoder", Decoder3DType.BASE.name)
|
||||||
|
self.decoder = DecoderFactorized(
|
||||||
|
z_channels=z_channels, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
self.quant_conv = CausalConv3d(
|
||||||
|
z_factor * z_channels,
|
||||||
|
z_factor * latent_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
padding=0,
|
||||||
|
)
|
||||||
|
self.post_quant_conv = CausalConv3d(
|
||||||
|
latent_channels, z_channels, kernel_size=1, padding=0
|
||||||
|
)
|
||||||
|
|
||||||
|
# formulation_name = kwargs.get("formulation", ContinuousFormulation.AE.name)
|
||||||
|
self.distribution = IdentityDistribution() # ContinuousFormulation[formulation_name].value()
|
||||||
|
|
||||||
|
num_parameters = sum(param.numel() for param in self.parameters())
|
||||||
|
logging.debug(f"model={self.name}, num_parameters={num_parameters:,}")
|
||||||
|
logging.debug(
|
||||||
|
f"z_channels={z_channels}, latent_channels={self.latent_channels}."
|
||||||
|
)
|
||||||
|
|
||||||
|
latent_temporal_chunk = 16
|
||||||
|
self.latent_mean = nn.Parameter(torch.zeros([self.latent_channels * latent_temporal_chunk], dtype=torch.float32))
|
||||||
|
self.latent_std = nn.Parameter(torch.ones([self.latent_channels * latent_temporal_chunk], dtype=torch.float32))
|
||||||
|
|
||||||
|
|
||||||
|
def encode(self, x):
|
||||||
|
h = self.encoder(x)
|
||||||
|
moments = self.quant_conv(h)
|
||||||
|
z, posteriors = self.distribution(moments)
|
||||||
|
latent_ch = z.shape[1]
|
||||||
|
latent_t = z.shape[2]
|
||||||
|
in_dtype = z.dtype
|
||||||
|
mean = self.latent_mean.view(latent_ch, -1)
|
||||||
|
std = self.latent_std.view(latent_ch, -1)
|
||||||
|
|
||||||
|
mean = mean.repeat(1, math.ceil(latent_t / mean.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
|
||||||
|
std = std.repeat(1, math.ceil(latent_t / std.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
|
||||||
|
return ((z - mean) / std) * self.sigma_data
|
||||||
|
|
||||||
|
def decode(self, z):
|
||||||
|
in_dtype = z.dtype
|
||||||
|
latent_ch = z.shape[1]
|
||||||
|
latent_t = z.shape[2]
|
||||||
|
mean = self.latent_mean.view(latent_ch, -1)
|
||||||
|
std = self.latent_std.view(latent_ch, -1)
|
||||||
|
|
||||||
|
mean = mean.repeat(1, math.ceil(latent_t / mean.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
|
||||||
|
std = std.repeat(1, math.ceil(latent_t / std.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
|
||||||
|
|
||||||
|
z = z / self.sigma_data
|
||||||
|
z = z * std + mean
|
||||||
|
z = self.post_quant_conv(z)
|
||||||
|
return self.decoder(z)
|
||||||
|
|
||||||
@@ -6,9 +6,7 @@ import math
|
|||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
|
|
||||||
from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
|
from .layers import (timestep_embedding)
|
||||||
MLPEmbedder, SingleStreamBlock,
|
|
||||||
timestep_embedding)
|
|
||||||
|
|
||||||
from .model import Flux
|
from .model import Flux
|
||||||
import comfy.ldm.common_dit
|
import comfy.ldm.common_dit
|
||||||
|
|||||||
@@ -114,7 +114,7 @@ class Modulation(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class DoubleStreamBlock(nn.Module):
|
class DoubleStreamBlock(nn.Module):
|
||||||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, dtype=None, device=None, operations=None):
|
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||||
@@ -141,8 +141,9 @@ class DoubleStreamBlock(nn.Module):
|
|||||||
nn.GELU(approximate="tanh"),
|
nn.GELU(approximate="tanh"),
|
||||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||||
)
|
)
|
||||||
|
self.flipped_img_txt = flipped_img_txt
|
||||||
|
|
||||||
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor):
|
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None):
|
||||||
img_mod1, img_mod2 = self.img_mod(vec)
|
img_mod1, img_mod2 = self.img_mod(vec)
|
||||||
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
||||||
|
|
||||||
@@ -160,12 +161,22 @@ class DoubleStreamBlock(nn.Module):
|
|||||||
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||||
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
||||||
|
|
||||||
# run actual attention
|
if self.flipped_img_txt:
|
||||||
attn = attention(torch.cat((txt_q, img_q), dim=2),
|
# run actual attention
|
||||||
torch.cat((txt_k, img_k), dim=2),
|
attn = attention(torch.cat((img_q, txt_q), dim=2),
|
||||||
torch.cat((txt_v, img_v), dim=2), pe=pe)
|
torch.cat((img_k, txt_k), dim=2),
|
||||||
|
torch.cat((img_v, txt_v), dim=2),
|
||||||
|
pe=pe, mask=attn_mask)
|
||||||
|
|
||||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
|
||||||
|
else:
|
||||||
|
# run actual attention
|
||||||
|
attn = attention(torch.cat((txt_q, img_q), dim=2),
|
||||||
|
torch.cat((txt_k, img_k), dim=2),
|
||||||
|
torch.cat((txt_v, img_v), dim=2),
|
||||||
|
pe=pe, mask=attn_mask)
|
||||||
|
|
||||||
|
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
|
||||||
|
|
||||||
# calculate the img bloks
|
# calculate the img bloks
|
||||||
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
|
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
|
||||||
@@ -217,16 +228,15 @@ class SingleStreamBlock(nn.Module):
|
|||||||
self.mlp_act = nn.GELU(approximate="tanh")
|
self.mlp_act = nn.GELU(approximate="tanh")
|
||||||
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
|
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
|
||||||
|
|
||||||
def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
|
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None) -> Tensor:
|
||||||
mod, _ = self.modulation(vec)
|
mod, _ = self.modulation(vec)
|
||||||
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
|
qkv, mlp = torch.split(self.linear1((1 + mod.scale) * self.pre_norm(x) + mod.shift), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||||
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
|
||||||
|
|
||||||
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||||
q, k = self.norm(q, k, v)
|
q, k = self.norm(q, k, v)
|
||||||
|
|
||||||
# compute attention
|
# compute attention
|
||||||
attn = attention(q, k, v, pe=pe)
|
attn = attention(q, k, v, pe=pe, mask=attn_mask)
|
||||||
# compute activation in mlp stream, cat again and run second linear layer
|
# compute activation in mlp stream, cat again and run second linear layer
|
||||||
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
||||||
x += mod.gate * output
|
x += mod.gate * output
|
||||||
|
|||||||
@@ -1,14 +1,22 @@
|
|||||||
import torch
|
import torch
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from comfy.ldm.modules.attention import optimized_attention
|
from comfy.ldm.modules.attention import optimized_attention
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
|
|
||||||
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
|
|
||||||
q, k = apply_rope(q, k, pe)
|
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
|
||||||
|
q_shape = q.shape
|
||||||
|
k_shape = k.shape
|
||||||
|
|
||||||
|
q = q.float().reshape(*q.shape[:-1], -1, 1, 2)
|
||||||
|
k = k.float().reshape(*k.shape[:-1], -1, 1, 2)
|
||||||
|
q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
|
||||||
|
k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
|
||||||
|
|
||||||
heads = q.shape[1]
|
heads = q.shape[1]
|
||||||
x = optimized_attention(q, k, v, heads, skip_reshape=True)
|
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@@ -33,3 +41,4 @@ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
|
|||||||
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
||||||
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
||||||
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ from dataclasses import dataclass
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
import comfy.ldm.common_dit
|
||||||
|
|
||||||
from .layers import (
|
from .layers import (
|
||||||
DoubleStreamBlock,
|
DoubleStreamBlock,
|
||||||
@@ -14,9 +16,6 @@ from .layers import (
|
|||||||
timestep_embedding,
|
timestep_embedding,
|
||||||
)
|
)
|
||||||
|
|
||||||
from einops import rearrange, repeat
|
|
||||||
import comfy.ldm.common_dit
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FluxParams:
|
class FluxParams:
|
||||||
in_channels: int
|
in_channels: int
|
||||||
@@ -98,8 +97,9 @@ class Flux(nn.Module):
|
|||||||
timesteps: Tensor,
|
timesteps: Tensor,
|
||||||
y: Tensor,
|
y: Tensor,
|
||||||
guidance: Tensor = None,
|
guidance: Tensor = None,
|
||||||
control=None,
|
control = None,
|
||||||
transformer_options={},
|
transformer_options={},
|
||||||
|
attn_mask: Tensor = None,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
patches_replace = transformer_options.get("patches_replace", {})
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
if img.ndim != 3 or txt.ndim != 3:
|
if img.ndim != 3 or txt.ndim != 3:
|
||||||
@@ -109,9 +109,8 @@ class Flux(nn.Module):
|
|||||||
img = self.img_in(img)
|
img = self.img_in(img)
|
||||||
vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype))
|
vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype))
|
||||||
if self.params.guidance_embed:
|
if self.params.guidance_embed:
|
||||||
if guidance is None:
|
if guidance is not None:
|
||||||
raise ValueError("Didn't get guidance strength for guidance distilled model.")
|
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
|
||||||
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
|
|
||||||
|
|
||||||
vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
|
vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
|
||||||
txt = self.txt_in(txt)
|
txt = self.txt_in(txt)
|
||||||
@@ -124,14 +123,27 @@ class Flux(nn.Module):
|
|||||||
if ("double_block", i) in blocks_replace:
|
if ("double_block", i) in blocks_replace:
|
||||||
def block_wrap(args):
|
def block_wrap(args):
|
||||||
out = {}
|
out = {}
|
||||||
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"])
|
out["img"], out["txt"] = block(img=args["img"],
|
||||||
|
txt=args["txt"],
|
||||||
|
vec=args["vec"],
|
||||||
|
pe=args["pe"],
|
||||||
|
attn_mask=args.get("attn_mask"))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe}, {"original_block": block_wrap})
|
out = blocks_replace[("double_block", i)]({"img": img,
|
||||||
|
"txt": txt,
|
||||||
|
"vec": vec,
|
||||||
|
"pe": pe,
|
||||||
|
"attn_mask": attn_mask},
|
||||||
|
{"original_block": block_wrap})
|
||||||
txt = out["txt"]
|
txt = out["txt"]
|
||||||
img = out["img"]
|
img = out["img"]
|
||||||
else:
|
else:
|
||||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
img, txt = block(img=img,
|
||||||
|
txt=txt,
|
||||||
|
vec=vec,
|
||||||
|
pe=pe,
|
||||||
|
attn_mask=attn_mask)
|
||||||
|
|
||||||
if control is not None: # Controlnet
|
if control is not None: # Controlnet
|
||||||
control_i = control.get("input")
|
control_i = control.get("input")
|
||||||
@@ -146,13 +158,20 @@ class Flux(nn.Module):
|
|||||||
if ("single_block", i) in blocks_replace:
|
if ("single_block", i) in blocks_replace:
|
||||||
def block_wrap(args):
|
def block_wrap(args):
|
||||||
out = {}
|
out = {}
|
||||||
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"])
|
out["img"] = block(args["img"],
|
||||||
|
vec=args["vec"],
|
||||||
|
pe=args["pe"],
|
||||||
|
attn_mask=args.get("attn_mask"))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe}, {"original_block": block_wrap})
|
out = blocks_replace[("single_block", i)]({"img": img,
|
||||||
|
"vec": vec,
|
||||||
|
"pe": pe,
|
||||||
|
"attn_mask": attn_mask},
|
||||||
|
{"original_block": block_wrap})
|
||||||
img = out["img"]
|
img = out["img"]
|
||||||
else:
|
else:
|
||||||
img = block(img, vec=vec, pe=pe)
|
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
|
||||||
|
|
||||||
if control is not None: # Controlnet
|
if control is not None: # Controlnet
|
||||||
control_o = control.get("output")
|
control_o = control.get("output")
|
||||||
@@ -166,7 +185,7 @@ class Flux(nn.Module):
|
|||||||
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
||||||
return img
|
return img
|
||||||
|
|
||||||
def forward(self, x, timestep, context, y, guidance, control=None, transformer_options={}, **kwargs):
|
def forward(self, x, timestep, context, y, guidance=None, control=None, transformer_options={}, **kwargs):
|
||||||
bs, c, h, w = x.shape
|
bs, c, h, w = x.shape
|
||||||
patch_size = self.patch_size
|
patch_size = self.patch_size
|
||||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
||||||
@@ -181,5 +200,5 @@ class Flux(nn.Module):
|
|||||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||||
|
|
||||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||||
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options)
|
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
|
||||||
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
|
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
|
||||||
|
|||||||
@@ -461,8 +461,6 @@ class AsymmDiTJoint(nn.Module):
|
|||||||
pH, pW = H // self.patch_size, W // self.patch_size
|
pH, pW = H // self.patch_size, W // self.patch_size
|
||||||
x = self.embed_x(x) # (B, N, D), where N = T * H * W / patch_size ** 2
|
x = self.embed_x(x) # (B, N, D), where N = T * H * W / patch_size ** 2
|
||||||
assert x.ndim == 3
|
assert x.ndim == 3
|
||||||
B = x.size(0)
|
|
||||||
|
|
||||||
|
|
||||||
pH, pW = H // self.patch_size, W // self.patch_size
|
pH, pW = H // self.patch_size, W // self.patch_size
|
||||||
N = T * pH * pW
|
N = T * pH * pW
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
#original code from https://github.com/genmoai/models under apache 2.0 license
|
#original code from https://github.com/genmoai/models under apache 2.0 license
|
||||||
#adapted to ComfyUI
|
#adapted to ComfyUI
|
||||||
|
|
||||||
from typing import Optional, Tuple
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
#original code from https://github.com/genmoai/models under apache 2.0 license
|
#original code from https://github.com/genmoai/models under apache 2.0 license
|
||||||
#adapted to ComfyUI
|
#adapted to ComfyUI
|
||||||
|
|
||||||
from typing import Callable, List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
from functools import partial
|
from functools import partial
|
||||||
import math
|
import math
|
||||||
|
|
||||||
|
|||||||
329
comfy/ldm/hunyuan_video/model.py
Normal file
329
comfy/ldm/hunyuan_video/model.py
Normal file
@@ -0,0 +1,329 @@
|
|||||||
|
#Based on Flux code because of weird hunyuan video code license.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import comfy.ldm.flux.layers
|
||||||
|
import comfy.ldm.modules.diffusionmodules.mmdit
|
||||||
|
from comfy.ldm.modules.attention import optimized_attention
|
||||||
|
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from einops import repeat
|
||||||
|
|
||||||
|
from torch import Tensor, nn
|
||||||
|
|
||||||
|
from comfy.ldm.flux.layers import (
|
||||||
|
DoubleStreamBlock,
|
||||||
|
EmbedND,
|
||||||
|
LastLayer,
|
||||||
|
MLPEmbedder,
|
||||||
|
SingleStreamBlock,
|
||||||
|
timestep_embedding
|
||||||
|
)
|
||||||
|
|
||||||
|
import comfy.ldm.common_dit
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class HunyuanVideoParams:
|
||||||
|
in_channels: int
|
||||||
|
out_channels: int
|
||||||
|
vec_in_dim: int
|
||||||
|
context_in_dim: int
|
||||||
|
hidden_size: int
|
||||||
|
mlp_ratio: float
|
||||||
|
num_heads: int
|
||||||
|
depth: int
|
||||||
|
depth_single_blocks: int
|
||||||
|
axes_dim: list
|
||||||
|
theta: int
|
||||||
|
patch_size: list
|
||||||
|
qkv_bias: bool
|
||||||
|
guidance_embed: bool
|
||||||
|
|
||||||
|
|
||||||
|
class SelfAttentionRef(nn.Module):
|
||||||
|
def __init__(self, dim: int, qkv_bias: bool = False, dtype=None, device=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
|
||||||
|
self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
|
||||||
|
class TokenRefinerBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size,
|
||||||
|
heads,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.heads = heads
|
||||||
|
mlp_hidden_dim = hidden_size * 4
|
||||||
|
|
||||||
|
self.adaLN_modulation = nn.Sequential(
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.norm1 = operations.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)
|
||||||
|
self.self_attn = SelfAttentionRef(hidden_size, True, dtype=dtype, device=device, operations=operations)
|
||||||
|
|
||||||
|
self.norm2 = operations.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
self.mlp = nn.Sequential(
|
||||||
|
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x, c, mask):
|
||||||
|
mod1, mod2 = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||||
|
|
||||||
|
norm_x = self.norm1(x)
|
||||||
|
qkv = self.self_attn.qkv(norm_x)
|
||||||
|
q, k, v = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, self.heads, -1).permute(2, 0, 3, 1, 4)
|
||||||
|
attn = optimized_attention(q, k, v, self.heads, mask=mask, skip_reshape=True)
|
||||||
|
|
||||||
|
x = x + self.self_attn.proj(attn) * mod1.unsqueeze(1)
|
||||||
|
x = x + self.mlp(self.norm2(x)) * mod2.unsqueeze(1)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class IndividualTokenRefiner(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size,
|
||||||
|
heads,
|
||||||
|
num_blocks,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.blocks = nn.ModuleList(
|
||||||
|
[
|
||||||
|
TokenRefinerBlock(
|
||||||
|
hidden_size=hidden_size,
|
||||||
|
heads=heads,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations
|
||||||
|
)
|
||||||
|
for _ in range(num_blocks)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x, c, mask):
|
||||||
|
m = None
|
||||||
|
if mask is not None:
|
||||||
|
m = mask.view(mask.shape[0], 1, 1, mask.shape[1]).repeat(1, 1, mask.shape[1], 1)
|
||||||
|
m = m + m.transpose(2, 3)
|
||||||
|
|
||||||
|
for block in self.blocks:
|
||||||
|
x = block(x, c, m)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class TokenRefiner(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
text_dim,
|
||||||
|
hidden_size,
|
||||||
|
heads,
|
||||||
|
num_blocks,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.input_embedder = operations.Linear(text_dim, hidden_size, bias=True, dtype=dtype, device=device)
|
||||||
|
self.t_embedder = MLPEmbedder(256, hidden_size, dtype=dtype, device=device, operations=operations)
|
||||||
|
self.c_embedder = MLPEmbedder(text_dim, hidden_size, dtype=dtype, device=device, operations=operations)
|
||||||
|
self.individual_token_refiner = IndividualTokenRefiner(hidden_size, heads, num_blocks, dtype=dtype, device=device, operations=operations)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
timesteps,
|
||||||
|
mask,
|
||||||
|
):
|
||||||
|
t = self.t_embedder(timestep_embedding(timesteps, 256, time_factor=1.0).to(x.dtype))
|
||||||
|
# m = mask.float().unsqueeze(-1)
|
||||||
|
# c = (x.float() * m).sum(dim=1) / m.sum(dim=1) #TODO: the following works when the x.shape is the same length as the tokens but might break otherwise
|
||||||
|
c = x.sum(dim=1) / x.shape[1]
|
||||||
|
|
||||||
|
c = t + self.c_embedder(c.to(x.dtype))
|
||||||
|
x = self.input_embedder(x)
|
||||||
|
x = self.individual_token_refiner(x, c, mask)
|
||||||
|
return x
|
||||||
|
|
||||||
|
class HunyuanVideo(nn.Module):
|
||||||
|
"""
|
||||||
|
Transformer model for flow matching on sequences.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
self.dtype = dtype
|
||||||
|
params = HunyuanVideoParams(**kwargs)
|
||||||
|
self.params = params
|
||||||
|
self.patch_size = params.patch_size
|
||||||
|
self.in_channels = params.in_channels
|
||||||
|
self.out_channels = params.out_channels
|
||||||
|
if params.hidden_size % params.num_heads != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
|
||||||
|
)
|
||||||
|
pe_dim = params.hidden_size // params.num_heads
|
||||||
|
if sum(params.axes_dim) != pe_dim:
|
||||||
|
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
|
||||||
|
self.hidden_size = params.hidden_size
|
||||||
|
self.num_heads = params.num_heads
|
||||||
|
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
|
||||||
|
|
||||||
|
self.img_in = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(None, self.patch_size, self.in_channels, self.hidden_size, conv3d=True, dtype=dtype, device=device, operations=operations)
|
||||||
|
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||||
|
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||||
|
self.guidance_in = (
|
||||||
|
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
|
||||||
|
)
|
||||||
|
|
||||||
|
self.txt_in = TokenRefiner(params.context_in_dim, self.hidden_size, self.num_heads, 2, dtype=dtype, device=device, operations=operations)
|
||||||
|
|
||||||
|
self.double_blocks = nn.ModuleList(
|
||||||
|
[
|
||||||
|
DoubleStreamBlock(
|
||||||
|
self.hidden_size,
|
||||||
|
self.num_heads,
|
||||||
|
mlp_ratio=params.mlp_ratio,
|
||||||
|
qkv_bias=params.qkv_bias,
|
||||||
|
flipped_img_txt=True,
|
||||||
|
dtype=dtype, device=device, operations=operations
|
||||||
|
)
|
||||||
|
for _ in range(params.depth)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.single_blocks = nn.ModuleList(
|
||||||
|
[
|
||||||
|
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
|
||||||
|
for _ in range(params.depth_single_blocks)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
if final_layer:
|
||||||
|
self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||||
|
|
||||||
|
def forward_orig(
|
||||||
|
self,
|
||||||
|
img: Tensor,
|
||||||
|
img_ids: Tensor,
|
||||||
|
txt: Tensor,
|
||||||
|
txt_ids: Tensor,
|
||||||
|
txt_mask: Tensor,
|
||||||
|
timesteps: Tensor,
|
||||||
|
y: Tensor,
|
||||||
|
guidance: Tensor = None,
|
||||||
|
control=None,
|
||||||
|
transformer_options={},
|
||||||
|
) -> Tensor:
|
||||||
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
|
|
||||||
|
initial_shape = list(img.shape)
|
||||||
|
# running on sequences img
|
||||||
|
img = self.img_in(img)
|
||||||
|
vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))
|
||||||
|
|
||||||
|
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
|
||||||
|
|
||||||
|
if self.params.guidance_embed:
|
||||||
|
if guidance is not None:
|
||||||
|
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
|
||||||
|
|
||||||
|
if txt_mask is not None and not torch.is_floating_point(txt_mask):
|
||||||
|
txt_mask = (txt_mask - 1).to(img.dtype) * torch.finfo(img.dtype).max
|
||||||
|
|
||||||
|
txt = self.txt_in(txt, timesteps, txt_mask)
|
||||||
|
|
||||||
|
ids = torch.cat((img_ids, txt_ids), dim=1)
|
||||||
|
pe = self.pe_embedder(ids)
|
||||||
|
|
||||||
|
img_len = img.shape[1]
|
||||||
|
if txt_mask is not None:
|
||||||
|
attn_mask_len = img_len + txt.shape[1]
|
||||||
|
attn_mask = torch.zeros((1, 1, attn_mask_len), dtype=img.dtype, device=img.device)
|
||||||
|
attn_mask[:, 0, img_len:] = txt_mask
|
||||||
|
else:
|
||||||
|
attn_mask = None
|
||||||
|
|
||||||
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
|
for i, block in enumerate(self.double_blocks):
|
||||||
|
if ("double_block", i) in blocks_replace:
|
||||||
|
def block_wrap(args):
|
||||||
|
out = {}
|
||||||
|
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"])
|
||||||
|
return out
|
||||||
|
|
||||||
|
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap})
|
||||||
|
txt = out["txt"]
|
||||||
|
img = out["img"]
|
||||||
|
else:
|
||||||
|
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask)
|
||||||
|
|
||||||
|
if control is not None: # Controlnet
|
||||||
|
control_i = control.get("input")
|
||||||
|
if i < len(control_i):
|
||||||
|
add = control_i[i]
|
||||||
|
if add is not None:
|
||||||
|
img += add
|
||||||
|
|
||||||
|
img = torch.cat((img, txt), 1)
|
||||||
|
|
||||||
|
for i, block in enumerate(self.single_blocks):
|
||||||
|
if ("single_block", i) in blocks_replace:
|
||||||
|
def block_wrap(args):
|
||||||
|
out = {}
|
||||||
|
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"])
|
||||||
|
return out
|
||||||
|
|
||||||
|
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap})
|
||||||
|
img = out["img"]
|
||||||
|
else:
|
||||||
|
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
|
||||||
|
|
||||||
|
if control is not None: # Controlnet
|
||||||
|
control_o = control.get("output")
|
||||||
|
if i < len(control_o):
|
||||||
|
add = control_o[i]
|
||||||
|
if add is not None:
|
||||||
|
img[:, : img_len] += add
|
||||||
|
|
||||||
|
img = img[:, : img_len]
|
||||||
|
|
||||||
|
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
||||||
|
|
||||||
|
shape = initial_shape[-3:]
|
||||||
|
for i in range(len(shape)):
|
||||||
|
shape[i] = shape[i] // self.patch_size[i]
|
||||||
|
img = img.reshape([img.shape[0]] + shape + [self.out_channels] + self.patch_size)
|
||||||
|
img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
|
||||||
|
img = img.reshape(initial_shape)
|
||||||
|
return img
|
||||||
|
|
||||||
|
def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, control=None, transformer_options={}, **kwargs):
|
||||||
|
bs, c, t, h, w = x.shape
|
||||||
|
patch_size = self.patch_size
|
||||||
|
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
|
||||||
|
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
|
||||||
|
w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
|
||||||
|
img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||||
|
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
|
||||||
|
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
|
||||||
|
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
|
||||||
|
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
|
||||||
|
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||||
|
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, control, transformer_options)
|
||||||
|
return out
|
||||||
@@ -159,7 +159,7 @@ class CrossAttention(nn.Module):
|
|||||||
|
|
||||||
q = q.transpose(-2, -3).contiguous() # q -> B, L1, H, C - B, H, L1, C
|
q = q.transpose(-2, -3).contiguous() # q -> B, L1, H, C - B, H, L1, C
|
||||||
k = k.transpose(-2, -3).contiguous() # k -> B, L2, H, C - B, H, C, L2
|
k = k.transpose(-2, -3).contiguous() # k -> B, L2, H, C - B, H, C, L2
|
||||||
v = v.transpose(-2, -3).contiguous()
|
v = v.transpose(-2, -3).contiguous()
|
||||||
|
|
||||||
context = optimized_attention(q, k, v, self.num_heads, skip_reshape=True, attn_precision=self.attn_precision)
|
context = optimized_attention(q, k, v, self.num_heads, skip_reshape=True, attn_precision=self.attn_precision)
|
||||||
|
|
||||||
|
|||||||
@@ -1,24 +1,17 @@
|
|||||||
from typing import Any, Optional
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
from torch.utils import checkpoint
|
|
||||||
|
|
||||||
from comfy.ldm.modules.diffusionmodules.mmdit import (
|
from comfy.ldm.modules.diffusionmodules.mmdit import (
|
||||||
Mlp,
|
|
||||||
TimestepEmbedder,
|
TimestepEmbedder,
|
||||||
PatchEmbed,
|
PatchEmbed,
|
||||||
RMSNorm,
|
|
||||||
)
|
)
|
||||||
from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
|
|
||||||
from .poolers import AttentionPool
|
from .poolers import AttentionPool
|
||||||
|
|
||||||
import comfy.latent_formats
|
import comfy.latent_formats
|
||||||
from .models import HunYuanDiTBlock, calc_rope
|
from .models import HunYuanDiTBlock, calc_rope
|
||||||
|
|
||||||
from .posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop
|
|
||||||
|
|
||||||
|
|
||||||
class HunYuanControlNet(nn.Module):
|
class HunYuanControlNet(nn.Module):
|
||||||
@@ -171,9 +164,6 @@ class HunYuanControlNet(nn.Module):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Image embedding
|
|
||||||
num_patches = self.x_embedder.num_patches
|
|
||||||
|
|
||||||
# HUnYuanDiT Blocks
|
# HUnYuanDiT Blocks
|
||||||
self.blocks = nn.ModuleList(
|
self.blocks = nn.ModuleList(
|
||||||
[
|
[
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
from typing import Any
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
import comfy.ops
|
import comfy.ops
|
||||||
from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed, RMSNorm
|
from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed, RMSNorm
|
||||||
@@ -250,9 +248,6 @@ class HunYuanDiT(nn.Module):
|
|||||||
operations.Linear(hidden_size * 4, hidden_size, bias=True, dtype=dtype, device=device),
|
operations.Linear(hidden_size * 4, hidden_size, bias=True, dtype=dtype, device=device),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Image embedding
|
|
||||||
num_patches = self.x_embedder.num_patches
|
|
||||||
|
|
||||||
# HUnYuanDiT Blocks
|
# HUnYuanDiT Blocks
|
||||||
self.blocks = nn.ModuleList([
|
self.blocks = nn.ModuleList([
|
||||||
HunYuanDiTBlock(hidden_size=hidden_size,
|
HunYuanDiTBlock(hidden_size=hidden_size,
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
|
||||||
from comfy.ldm.modules.attention import optimized_attention
|
from comfy.ldm.modules.attention import optimized_attention
|
||||||
import comfy.ops
|
import comfy.ops
|
||||||
|
|
||||||
|
|||||||
@@ -379,6 +379,7 @@ class LTXVModel(torch.nn.Module):
|
|||||||
positional_embedding_max_pos=[20, 2048, 2048],
|
positional_embedding_max_pos=[20, 2048, 2048],
|
||||||
dtype=None, device=None, operations=None, **kwargs):
|
dtype=None, device=None, operations=None, **kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.generator = None
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self.out_channels = in_channels
|
self.out_channels = in_channels
|
||||||
self.inner_dim = num_attention_heads * attention_head_dim
|
self.inner_dim = num_attention_heads * attention_head_dim
|
||||||
@@ -415,7 +416,7 @@ class LTXVModel(torch.nn.Module):
|
|||||||
|
|
||||||
self.patchifier = SymmetricPatchifier(1)
|
self.patchifier = SymmetricPatchifier(1)
|
||||||
|
|
||||||
def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, transformer_options={}, **kwargs):
|
def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, guiding_latent_noise_scale=0, transformer_options={}, **kwargs):
|
||||||
patches_replace = transformer_options.get("patches_replace", {})
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
|
|
||||||
indices_grid = self.patchifier.get_grid(
|
indices_grid = self.patchifier.get_grid(
|
||||||
@@ -431,10 +432,22 @@ class LTXVModel(torch.nn.Module):
|
|||||||
ts = torch.ones([x.shape[0], 1, x.shape[2], x.shape[3], x.shape[4]], device=x.device, dtype=x.dtype)
|
ts = torch.ones([x.shape[0], 1, x.shape[2], x.shape[3], x.shape[4]], device=x.device, dtype=x.dtype)
|
||||||
input_ts = timestep.view([timestep.shape[0]] + [1] * (x.ndim - 1))
|
input_ts = timestep.view([timestep.shape[0]] + [1] * (x.ndim - 1))
|
||||||
ts *= input_ts
|
ts *= input_ts
|
||||||
ts[:, :, 0] = 0.0
|
ts[:, :, 0] = guiding_latent_noise_scale * (input_ts[:, :, 0] ** 2)
|
||||||
timestep = self.patchifier.patchify(ts)
|
timestep = self.patchifier.patchify(ts)
|
||||||
input_x = x.clone()
|
input_x = x.clone()
|
||||||
x[:, :, 0] = guiding_latent[:, :, 0]
|
x[:, :, 0] = guiding_latent[:, :, 0]
|
||||||
|
if guiding_latent_noise_scale > 0:
|
||||||
|
if self.generator is None:
|
||||||
|
self.generator = torch.Generator(device=x.device).manual_seed(42)
|
||||||
|
elif self.generator.device != x.device:
|
||||||
|
self.generator = torch.Generator(device=x.device).set_state(self.generator.get_state())
|
||||||
|
|
||||||
|
noise_shape = [guiding_latent.shape[0], guiding_latent.shape[1], 1, guiding_latent.shape[3], guiding_latent.shape[4]]
|
||||||
|
scale = guiding_latent_noise_scale * (input_ts ** 2)
|
||||||
|
guiding_noise = scale * torch.randn(size=noise_shape, device=x.device, generator=self.generator)
|
||||||
|
|
||||||
|
x[:, :, 0] = guiding_noise[:, :, 0] + x[:, :, 0] * (1.0 - scale[:, :, 0])
|
||||||
|
|
||||||
|
|
||||||
orig_shape = list(x.shape)
|
orig_shape = list(x.shape)
|
||||||
|
|
||||||
@@ -443,9 +456,8 @@ class LTXVModel(torch.nn.Module):
|
|||||||
x = self.patchify_proj(x)
|
x = self.patchify_proj(x)
|
||||||
timestep = timestep * 1000.0
|
timestep = timestep * 1000.0
|
||||||
|
|
||||||
attention_mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1]))
|
if attention_mask is not None and not torch.is_floating_point(attention_mask):
|
||||||
attention_mask = attention_mask.masked_fill(attention_mask.to(torch.bool), float("-inf")) # not sure about this
|
attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max
|
||||||
# attention_mask = (context != 0).any(dim=2).to(dtype=x.dtype)
|
|
||||||
|
|
||||||
pe = precompute_freqs_cis(indices_grid, dim=self.inner_dim, out_dtype=x.dtype)
|
pe = precompute_freqs_cis(indices_grid, dim=self.inner_dim, out_dtype=x.dtype)
|
||||||
|
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ class Patchifier(ABC):
|
|||||||
grid_h = torch.arange(h, dtype=torch.float32, device=device)
|
grid_h = torch.arange(h, dtype=torch.float32, device=device)
|
||||||
grid_w = torch.arange(w, dtype=torch.float32, device=device)
|
grid_w = torch.arange(w, dtype=torch.float32, device=device)
|
||||||
grid_f = torch.arange(f, dtype=torch.float32, device=device)
|
grid_f = torch.arange(f, dtype=torch.float32, device=device)
|
||||||
grid = torch.meshgrid(grid_f, grid_h, grid_w)
|
grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing='ij')
|
||||||
grid = torch.stack(grid, dim=0)
|
grid = torch.stack(grid, dim=0)
|
||||||
grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
|
grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
|
||||||
|
|
||||||
|
|||||||
@@ -3,10 +3,12 @@ from torch import nn
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
import math
|
import math
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from typing import Any, Mapping, Optional, Tuple, Union, List
|
from typing import Optional, Tuple, Union
|
||||||
from .conv_nd_factory import make_conv_nd, make_linear_nd
|
from .conv_nd_factory import make_conv_nd, make_linear_nd
|
||||||
from .pixel_norm import PixelNorm
|
from .pixel_norm import PixelNorm
|
||||||
|
from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings
|
||||||
|
import comfy.ops
|
||||||
|
ops = comfy.ops.disable_weight_init
|
||||||
|
|
||||||
class Encoder(nn.Module):
|
class Encoder(nn.Module):
|
||||||
r"""
|
r"""
|
||||||
@@ -236,6 +238,7 @@ class Decoder(nn.Module):
|
|||||||
patch_size: int = 1,
|
patch_size: int = 1,
|
||||||
norm_layer: str = "group_norm",
|
norm_layer: str = "group_norm",
|
||||||
causal: bool = True,
|
causal: bool = True,
|
||||||
|
timestep_conditioning: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.patch_size = patch_size
|
self.patch_size = patch_size
|
||||||
@@ -250,6 +253,8 @@ class Decoder(nn.Module):
|
|||||||
block_params = block_params if isinstance(block_params, dict) else {}
|
block_params = block_params if isinstance(block_params, dict) else {}
|
||||||
if block_name == "res_x_y":
|
if block_name == "res_x_y":
|
||||||
output_channel = output_channel * block_params.get("multiplier", 2)
|
output_channel = output_channel * block_params.get("multiplier", 2)
|
||||||
|
if block_name == "compress_all":
|
||||||
|
output_channel = output_channel * block_params.get("multiplier", 1)
|
||||||
|
|
||||||
self.conv_in = make_conv_nd(
|
self.conv_in = make_conv_nd(
|
||||||
dims,
|
dims,
|
||||||
@@ -276,6 +281,19 @@ class Decoder(nn.Module):
|
|||||||
resnet_eps=1e-6,
|
resnet_eps=1e-6,
|
||||||
resnet_groups=norm_num_groups,
|
resnet_groups=norm_num_groups,
|
||||||
norm_layer=norm_layer,
|
norm_layer=norm_layer,
|
||||||
|
inject_noise=block_params.get("inject_noise", False),
|
||||||
|
timestep_conditioning=timestep_conditioning,
|
||||||
|
)
|
||||||
|
elif block_name == "attn_res_x":
|
||||||
|
block = UNetMidBlock3D(
|
||||||
|
dims=dims,
|
||||||
|
in_channels=input_channel,
|
||||||
|
num_layers=block_params["num_layers"],
|
||||||
|
resnet_groups=norm_num_groups,
|
||||||
|
norm_layer=norm_layer,
|
||||||
|
inject_noise=block_params.get("inject_noise", False),
|
||||||
|
timestep_conditioning=timestep_conditioning,
|
||||||
|
attention_head_dim=block_params["attention_head_dim"],
|
||||||
)
|
)
|
||||||
elif block_name == "res_x_y":
|
elif block_name == "res_x_y":
|
||||||
output_channel = output_channel // block_params.get("multiplier", 2)
|
output_channel = output_channel // block_params.get("multiplier", 2)
|
||||||
@@ -286,6 +304,8 @@ class Decoder(nn.Module):
|
|||||||
eps=1e-6,
|
eps=1e-6,
|
||||||
groups=norm_num_groups,
|
groups=norm_num_groups,
|
||||||
norm_layer=norm_layer,
|
norm_layer=norm_layer,
|
||||||
|
inject_noise=block_params.get("inject_noise", False),
|
||||||
|
timestep_conditioning=False,
|
||||||
)
|
)
|
||||||
elif block_name == "compress_time":
|
elif block_name == "compress_time":
|
||||||
block = DepthToSpaceUpsample(
|
block = DepthToSpaceUpsample(
|
||||||
@@ -296,11 +316,13 @@ class Decoder(nn.Module):
|
|||||||
dims=dims, in_channels=input_channel, stride=(1, 2, 2)
|
dims=dims, in_channels=input_channel, stride=(1, 2, 2)
|
||||||
)
|
)
|
||||||
elif block_name == "compress_all":
|
elif block_name == "compress_all":
|
||||||
|
output_channel = output_channel // block_params.get("multiplier", 1)
|
||||||
block = DepthToSpaceUpsample(
|
block = DepthToSpaceUpsample(
|
||||||
dims=dims,
|
dims=dims,
|
||||||
in_channels=input_channel,
|
in_channels=input_channel,
|
||||||
stride=(2, 2, 2),
|
stride=(2, 2, 2),
|
||||||
residual=block_params.get("residual", False),
|
residual=block_params.get("residual", False),
|
||||||
|
out_channels_reduction_factor=block_params.get("multiplier", 1),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"unknown layer: {block_name}")
|
raise ValueError(f"unknown layer: {block_name}")
|
||||||
@@ -323,27 +345,75 @@ class Decoder(nn.Module):
|
|||||||
|
|
||||||
self.gradient_checkpointing = False
|
self.gradient_checkpointing = False
|
||||||
|
|
||||||
|
self.timestep_conditioning = timestep_conditioning
|
||||||
|
|
||||||
|
if timestep_conditioning:
|
||||||
|
self.timestep_scale_multiplier = nn.Parameter(
|
||||||
|
torch.tensor(1000.0, dtype=torch.float32)
|
||||||
|
)
|
||||||
|
self.last_time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(
|
||||||
|
output_channel * 2, 0, operations=ops,
|
||||||
|
)
|
||||||
|
self.last_scale_shift_table = nn.Parameter(torch.empty(2, output_channel))
|
||||||
|
|
||||||
# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
|
# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
|
||||||
def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
def forward(
|
||||||
|
self,
|
||||||
|
sample: torch.FloatTensor,
|
||||||
|
timestep: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.FloatTensor:
|
||||||
r"""The forward method of the `Decoder` class."""
|
r"""The forward method of the `Decoder` class."""
|
||||||
# assert target_shape is not None, "target_shape must be provided"
|
batch_size = sample.shape[0]
|
||||||
|
|
||||||
sample = self.conv_in(sample, causal=self.causal)
|
sample = self.conv_in(sample, causal=self.causal)
|
||||||
|
|
||||||
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
|
|
||||||
|
|
||||||
checkpoint_fn = (
|
checkpoint_fn = (
|
||||||
partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
|
partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
|
||||||
if self.gradient_checkpointing and self.training
|
if self.gradient_checkpointing and self.training
|
||||||
else lambda x: x
|
else lambda x: x
|
||||||
)
|
)
|
||||||
|
|
||||||
sample = sample.to(upscale_dtype)
|
scaled_timestep = None
|
||||||
|
if self.timestep_conditioning:
|
||||||
|
assert (
|
||||||
|
timestep is not None
|
||||||
|
), "should pass timestep with timestep_conditioning=True"
|
||||||
|
scaled_timestep = timestep * self.timestep_scale_multiplier.to(dtype=sample.dtype, device=sample.device)
|
||||||
|
|
||||||
for up_block in self.up_blocks:
|
for up_block in self.up_blocks:
|
||||||
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
|
if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
|
||||||
|
sample = checkpoint_fn(up_block)(
|
||||||
|
sample, causal=self.causal, timestep=scaled_timestep
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
|
||||||
|
|
||||||
sample = self.conv_norm_out(sample)
|
sample = self.conv_norm_out(sample)
|
||||||
|
|
||||||
|
if self.timestep_conditioning:
|
||||||
|
embedded_timestep = self.last_time_embedder(
|
||||||
|
timestep=scaled_timestep.flatten(),
|
||||||
|
resolution=None,
|
||||||
|
aspect_ratio=None,
|
||||||
|
batch_size=sample.shape[0],
|
||||||
|
hidden_dtype=sample.dtype,
|
||||||
|
)
|
||||||
|
embedded_timestep = embedded_timestep.view(
|
||||||
|
batch_size, embedded_timestep.shape[-1], 1, 1, 1
|
||||||
|
)
|
||||||
|
ada_values = self.last_scale_shift_table[
|
||||||
|
None, ..., None, None, None
|
||||||
|
].to(device=sample.device, dtype=sample.dtype) + embedded_timestep.reshape(
|
||||||
|
batch_size,
|
||||||
|
2,
|
||||||
|
-1,
|
||||||
|
embedded_timestep.shape[-3],
|
||||||
|
embedded_timestep.shape[-2],
|
||||||
|
embedded_timestep.shape[-1],
|
||||||
|
)
|
||||||
|
shift, scale = ada_values.unbind(dim=1)
|
||||||
|
sample = sample * (1 + scale) + shift
|
||||||
|
|
||||||
sample = self.conv_act(sample)
|
sample = self.conv_act(sample)
|
||||||
sample = self.conv_out(sample, causal=self.causal)
|
sample = self.conv_out(sample, causal=self.causal)
|
||||||
|
|
||||||
@@ -379,12 +449,21 @@ class UNetMidBlock3D(nn.Module):
|
|||||||
resnet_eps: float = 1e-6,
|
resnet_eps: float = 1e-6,
|
||||||
resnet_groups: int = 32,
|
resnet_groups: int = 32,
|
||||||
norm_layer: str = "group_norm",
|
norm_layer: str = "group_norm",
|
||||||
|
inject_noise: bool = False,
|
||||||
|
timestep_conditioning: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
resnet_groups = (
|
resnet_groups = (
|
||||||
resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.timestep_conditioning = timestep_conditioning
|
||||||
|
|
||||||
|
if timestep_conditioning:
|
||||||
|
self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(
|
||||||
|
in_channels * 4, 0, operations=ops,
|
||||||
|
)
|
||||||
|
|
||||||
self.res_blocks = nn.ModuleList(
|
self.res_blocks = nn.ModuleList(
|
||||||
[
|
[
|
||||||
ResnetBlock3D(
|
ResnetBlock3D(
|
||||||
@@ -395,25 +474,48 @@ class UNetMidBlock3D(nn.Module):
|
|||||||
groups=resnet_groups,
|
groups=resnet_groups,
|
||||||
dropout=dropout,
|
dropout=dropout,
|
||||||
norm_layer=norm_layer,
|
norm_layer=norm_layer,
|
||||||
|
inject_noise=inject_noise,
|
||||||
|
timestep_conditioning=timestep_conditioning,
|
||||||
)
|
)
|
||||||
for _ in range(num_layers)
|
for _ in range(num_layers)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, hidden_states: torch.FloatTensor, causal: bool = True
|
self, hidden_states: torch.FloatTensor, causal: bool = True, timestep: Optional[torch.Tensor] = None
|
||||||
) -> torch.FloatTensor:
|
) -> torch.FloatTensor:
|
||||||
|
timestep_embed = None
|
||||||
|
if self.timestep_conditioning:
|
||||||
|
assert (
|
||||||
|
timestep is not None
|
||||||
|
), "should pass timestep with timestep_conditioning=True"
|
||||||
|
batch_size = hidden_states.shape[0]
|
||||||
|
timestep_embed = self.time_embedder(
|
||||||
|
timestep=timestep.flatten(),
|
||||||
|
resolution=None,
|
||||||
|
aspect_ratio=None,
|
||||||
|
batch_size=batch_size,
|
||||||
|
hidden_dtype=hidden_states.dtype,
|
||||||
|
)
|
||||||
|
timestep_embed = timestep_embed.view(
|
||||||
|
batch_size, timestep_embed.shape[-1], 1, 1, 1
|
||||||
|
)
|
||||||
|
|
||||||
for resnet in self.res_blocks:
|
for resnet in self.res_blocks:
|
||||||
hidden_states = resnet(hidden_states, causal=causal)
|
hidden_states = resnet(hidden_states, causal=causal, timestep=timestep_embed)
|
||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class DepthToSpaceUpsample(nn.Module):
|
class DepthToSpaceUpsample(nn.Module):
|
||||||
def __init__(self, dims, in_channels, stride, residual=False):
|
def __init__(
|
||||||
|
self, dims, in_channels, stride, residual=False, out_channels_reduction_factor=1
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.stride = stride
|
self.stride = stride
|
||||||
self.out_channels = math.prod(stride) * in_channels
|
self.out_channels = (
|
||||||
|
math.prod(stride) * in_channels // out_channels_reduction_factor
|
||||||
|
)
|
||||||
self.conv = make_conv_nd(
|
self.conv = make_conv_nd(
|
||||||
dims=dims,
|
dims=dims,
|
||||||
in_channels=in_channels,
|
in_channels=in_channels,
|
||||||
@@ -423,8 +525,9 @@ class DepthToSpaceUpsample(nn.Module):
|
|||||||
causal=True,
|
causal=True,
|
||||||
)
|
)
|
||||||
self.residual = residual
|
self.residual = residual
|
||||||
|
self.out_channels_reduction_factor = out_channels_reduction_factor
|
||||||
|
|
||||||
def forward(self, x, causal: bool = True):
|
def forward(self, x, causal: bool = True, timestep: Optional[torch.Tensor] = None):
|
||||||
if self.residual:
|
if self.residual:
|
||||||
# Reshape and duplicate the input to match the output shape
|
# Reshape and duplicate the input to match the output shape
|
||||||
x_in = rearrange(
|
x_in = rearrange(
|
||||||
@@ -434,7 +537,8 @@ class DepthToSpaceUpsample(nn.Module):
|
|||||||
p2=self.stride[1],
|
p2=self.stride[1],
|
||||||
p3=self.stride[2],
|
p3=self.stride[2],
|
||||||
)
|
)
|
||||||
x_in = x_in.repeat(1, math.prod(self.stride), 1, 1, 1)
|
num_repeat = math.prod(self.stride) // self.out_channels_reduction_factor
|
||||||
|
x_in = x_in.repeat(1, num_repeat, 1, 1, 1)
|
||||||
if self.stride[0] == 2:
|
if self.stride[0] == 2:
|
||||||
x_in = x_in[:, :, 1:, :, :]
|
x_in = x_in[:, :, 1:, :, :]
|
||||||
x = self.conv(x, causal=causal)
|
x = self.conv(x, causal=causal)
|
||||||
@@ -451,7 +555,6 @@ class DepthToSpaceUpsample(nn.Module):
|
|||||||
x = x + x_in
|
x = x + x_in
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class LayerNorm(nn.Module):
|
class LayerNorm(nn.Module):
|
||||||
def __init__(self, dim, eps, elementwise_affine=True) -> None:
|
def __init__(self, dim, eps, elementwise_affine=True) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -486,11 +589,14 @@ class ResnetBlock3D(nn.Module):
|
|||||||
groups: int = 32,
|
groups: int = 32,
|
||||||
eps: float = 1e-6,
|
eps: float = 1e-6,
|
||||||
norm_layer: str = "group_norm",
|
norm_layer: str = "group_norm",
|
||||||
|
inject_noise: bool = False,
|
||||||
|
timestep_conditioning: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
out_channels = in_channels if out_channels is None else out_channels
|
out_channels = in_channels if out_channels is None else out_channels
|
||||||
self.out_channels = out_channels
|
self.out_channels = out_channels
|
||||||
|
self.inject_noise = inject_noise
|
||||||
|
|
||||||
if norm_layer == "group_norm":
|
if norm_layer == "group_norm":
|
||||||
self.norm1 = nn.GroupNorm(
|
self.norm1 = nn.GroupNorm(
|
||||||
@@ -513,6 +619,9 @@ class ResnetBlock3D(nn.Module):
|
|||||||
causal=True,
|
causal=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if inject_noise:
|
||||||
|
self.per_channel_scale1 = nn.Parameter(torch.zeros((in_channels, 1, 1)))
|
||||||
|
|
||||||
if norm_layer == "group_norm":
|
if norm_layer == "group_norm":
|
||||||
self.norm2 = nn.GroupNorm(
|
self.norm2 = nn.GroupNorm(
|
||||||
num_groups=groups, num_channels=out_channels, eps=eps, affine=True
|
num_groups=groups, num_channels=out_channels, eps=eps, affine=True
|
||||||
@@ -534,6 +643,9 @@ class ResnetBlock3D(nn.Module):
|
|||||||
causal=True,
|
causal=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if inject_noise:
|
||||||
|
self.per_channel_scale2 = nn.Parameter(torch.zeros((in_channels, 1, 1)))
|
||||||
|
|
||||||
self.conv_shortcut = (
|
self.conv_shortcut = (
|
||||||
make_linear_nd(
|
make_linear_nd(
|
||||||
dims=dims, in_channels=in_channels, out_channels=out_channels
|
dims=dims, in_channels=in_channels, out_channels=out_channels
|
||||||
@@ -548,29 +660,84 @@ class ResnetBlock3D(nn.Module):
|
|||||||
else nn.Identity()
|
else nn.Identity()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.timestep_conditioning = timestep_conditioning
|
||||||
|
|
||||||
|
if timestep_conditioning:
|
||||||
|
self.scale_shift_table = nn.Parameter(
|
||||||
|
torch.randn(4, in_channels) / in_channels**0.5
|
||||||
|
)
|
||||||
|
|
||||||
|
def _feed_spatial_noise(
|
||||||
|
self, hidden_states: torch.FloatTensor, per_channel_scale: torch.FloatTensor
|
||||||
|
) -> torch.FloatTensor:
|
||||||
|
spatial_shape = hidden_states.shape[-2:]
|
||||||
|
device = hidden_states.device
|
||||||
|
dtype = hidden_states.dtype
|
||||||
|
|
||||||
|
# similar to the "explicit noise inputs" method in style-gan
|
||||||
|
spatial_noise = torch.randn(spatial_shape, device=device, dtype=dtype)[None]
|
||||||
|
scaled_noise = (spatial_noise * per_channel_scale)[None, :, None, ...]
|
||||||
|
hidden_states = hidden_states + scaled_noise
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_tensor: torch.FloatTensor,
|
input_tensor: torch.FloatTensor,
|
||||||
causal: bool = True,
|
causal: bool = True,
|
||||||
|
timestep: Optional[torch.Tensor] = None,
|
||||||
) -> torch.FloatTensor:
|
) -> torch.FloatTensor:
|
||||||
hidden_states = input_tensor
|
hidden_states = input_tensor
|
||||||
|
batch_size = hidden_states.shape[0]
|
||||||
|
|
||||||
hidden_states = self.norm1(hidden_states)
|
hidden_states = self.norm1(hidden_states)
|
||||||
|
if self.timestep_conditioning:
|
||||||
|
assert (
|
||||||
|
timestep is not None
|
||||||
|
), "should pass timestep with timestep_conditioning=True"
|
||||||
|
ada_values = self.scale_shift_table[
|
||||||
|
None, ..., None, None, None
|
||||||
|
].to(device=hidden_states.device, dtype=hidden_states.dtype) + timestep.reshape(
|
||||||
|
batch_size,
|
||||||
|
4,
|
||||||
|
-1,
|
||||||
|
timestep.shape[-3],
|
||||||
|
timestep.shape[-2],
|
||||||
|
timestep.shape[-1],
|
||||||
|
)
|
||||||
|
shift1, scale1, shift2, scale2 = ada_values.unbind(dim=1)
|
||||||
|
|
||||||
|
hidden_states = hidden_states * (1 + scale1) + shift1
|
||||||
|
|
||||||
hidden_states = self.non_linearity(hidden_states)
|
hidden_states = self.non_linearity(hidden_states)
|
||||||
|
|
||||||
hidden_states = self.conv1(hidden_states, causal=causal)
|
hidden_states = self.conv1(hidden_states, causal=causal)
|
||||||
|
|
||||||
|
if self.inject_noise:
|
||||||
|
hidden_states = self._feed_spatial_noise(
|
||||||
|
hidden_states, self.per_channel_scale1.to(device=hidden_states.device, dtype=hidden_states.dtype)
|
||||||
|
)
|
||||||
|
|
||||||
hidden_states = self.norm2(hidden_states)
|
hidden_states = self.norm2(hidden_states)
|
||||||
|
|
||||||
|
if self.timestep_conditioning:
|
||||||
|
hidden_states = hidden_states * (1 + scale2) + shift2
|
||||||
|
|
||||||
hidden_states = self.non_linearity(hidden_states)
|
hidden_states = self.non_linearity(hidden_states)
|
||||||
|
|
||||||
hidden_states = self.dropout(hidden_states)
|
hidden_states = self.dropout(hidden_states)
|
||||||
|
|
||||||
hidden_states = self.conv2(hidden_states, causal=causal)
|
hidden_states = self.conv2(hidden_states, causal=causal)
|
||||||
|
|
||||||
|
if self.inject_noise:
|
||||||
|
hidden_states = self._feed_spatial_noise(
|
||||||
|
hidden_states, self.per_channel_scale2.to(device=hidden_states.device, dtype=hidden_states.dtype)
|
||||||
|
)
|
||||||
|
|
||||||
input_tensor = self.norm3(input_tensor)
|
input_tensor = self.norm3(input_tensor)
|
||||||
|
|
||||||
|
batch_size = input_tensor.shape[0]
|
||||||
|
|
||||||
input_tensor = self.conv_shortcut(input_tensor)
|
input_tensor = self.conv_shortcut(input_tensor)
|
||||||
|
|
||||||
output_tensor = input_tensor + hidden_states
|
output_tensor = input_tensor + hidden_states
|
||||||
@@ -634,33 +801,71 @@ class processor(nn.Module):
|
|||||||
return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)
|
return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)
|
||||||
|
|
||||||
class VideoVAE(nn.Module):
|
class VideoVAE(nn.Module):
|
||||||
def __init__(self):
|
def __init__(self, version=0):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
config = {
|
|
||||||
"_class_name": "CausalVideoAutoencoder",
|
if version == 0:
|
||||||
"dims": 3,
|
config = {
|
||||||
"in_channels": 3,
|
"_class_name": "CausalVideoAutoencoder",
|
||||||
"out_channels": 3,
|
"dims": 3,
|
||||||
"latent_channels": 128,
|
"in_channels": 3,
|
||||||
"blocks": [
|
"out_channels": 3,
|
||||||
["res_x", 4],
|
"latent_channels": 128,
|
||||||
["compress_all", 1],
|
"blocks": [
|
||||||
["res_x_y", 1],
|
["res_x", 4],
|
||||||
["res_x", 3],
|
["compress_all", 1],
|
||||||
["compress_all", 1],
|
["res_x_y", 1],
|
||||||
["res_x_y", 1],
|
["res_x", 3],
|
||||||
["res_x", 3],
|
["compress_all", 1],
|
||||||
["compress_all", 1],
|
["res_x_y", 1],
|
||||||
["res_x", 3],
|
["res_x", 3],
|
||||||
["res_x", 4],
|
["compress_all", 1],
|
||||||
],
|
["res_x", 3],
|
||||||
"scaling_factor": 1.0,
|
["res_x", 4],
|
||||||
"norm_layer": "pixel_norm",
|
],
|
||||||
"patch_size": 4,
|
"scaling_factor": 1.0,
|
||||||
"latent_log_var": "uniform",
|
"norm_layer": "pixel_norm",
|
||||||
"use_quant_conv": False,
|
"patch_size": 4,
|
||||||
"causal_decoder": False,
|
"latent_log_var": "uniform",
|
||||||
}
|
"use_quant_conv": False,
|
||||||
|
"causal_decoder": False,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
config = {
|
||||||
|
"_class_name": "CausalVideoAutoencoder",
|
||||||
|
"dims": 3,
|
||||||
|
"in_channels": 3,
|
||||||
|
"out_channels": 3,
|
||||||
|
"latent_channels": 128,
|
||||||
|
"decoder_blocks": [
|
||||||
|
["res_x", {"num_layers": 5, "inject_noise": True}],
|
||||||
|
["compress_all", {"residual": True, "multiplier": 2}],
|
||||||
|
["res_x", {"num_layers": 6, "inject_noise": True}],
|
||||||
|
["compress_all", {"residual": True, "multiplier": 2}],
|
||||||
|
["res_x", {"num_layers": 7, "inject_noise": True}],
|
||||||
|
["compress_all", {"residual": True, "multiplier": 2}],
|
||||||
|
["res_x", {"num_layers": 8, "inject_noise": False}]
|
||||||
|
],
|
||||||
|
"encoder_blocks": [
|
||||||
|
["res_x", {"num_layers": 4}],
|
||||||
|
["compress_all", {}],
|
||||||
|
["res_x_y", 1],
|
||||||
|
["res_x", {"num_layers": 3}],
|
||||||
|
["compress_all", {}],
|
||||||
|
["res_x_y", 1],
|
||||||
|
["res_x", {"num_layers": 3}],
|
||||||
|
["compress_all", {}],
|
||||||
|
["res_x", {"num_layers": 3}],
|
||||||
|
["res_x", {"num_layers": 4}]
|
||||||
|
],
|
||||||
|
"scaling_factor": 1.0,
|
||||||
|
"norm_layer": "pixel_norm",
|
||||||
|
"patch_size": 4,
|
||||||
|
"latent_log_var": "uniform",
|
||||||
|
"use_quant_conv": False,
|
||||||
|
"causal_decoder": False,
|
||||||
|
"timestep_conditioning": True,
|
||||||
|
}
|
||||||
|
|
||||||
double_z = config.get("double_z", True)
|
double_z = config.get("double_z", True)
|
||||||
latent_log_var = config.get(
|
latent_log_var = config.get(
|
||||||
@@ -671,7 +876,7 @@ class VideoVAE(nn.Module):
|
|||||||
dims=config["dims"],
|
dims=config["dims"],
|
||||||
in_channels=config.get("in_channels", 3),
|
in_channels=config.get("in_channels", 3),
|
||||||
out_channels=config["latent_channels"],
|
out_channels=config["latent_channels"],
|
||||||
blocks=config.get("encoder_blocks", config.get("blocks")),
|
blocks=config.get("encoder_blocks", config.get("encoder_blocks", config.get("blocks"))),
|
||||||
patch_size=config.get("patch_size", 1),
|
patch_size=config.get("patch_size", 1),
|
||||||
latent_log_var=latent_log_var,
|
latent_log_var=latent_log_var,
|
||||||
norm_layer=config.get("norm_layer", "group_norm"),
|
norm_layer=config.get("norm_layer", "group_norm"),
|
||||||
@@ -681,18 +886,22 @@ class VideoVAE(nn.Module):
|
|||||||
dims=config["dims"],
|
dims=config["dims"],
|
||||||
in_channels=config["latent_channels"],
|
in_channels=config["latent_channels"],
|
||||||
out_channels=config.get("out_channels", 3),
|
out_channels=config.get("out_channels", 3),
|
||||||
blocks=config.get("decoder_blocks", config.get("blocks")),
|
blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))),
|
||||||
patch_size=config.get("patch_size", 1),
|
patch_size=config.get("patch_size", 1),
|
||||||
norm_layer=config.get("norm_layer", "group_norm"),
|
norm_layer=config.get("norm_layer", "group_norm"),
|
||||||
causal=config.get("causal_decoder", False),
|
causal=config.get("causal_decoder", False),
|
||||||
|
timestep_conditioning=config.get("timestep_conditioning", False),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.timestep_conditioning = config.get("timestep_conditioning", False)
|
||||||
self.per_channel_statistics = processor()
|
self.per_channel_statistics = processor()
|
||||||
|
|
||||||
def encode(self, x):
|
def encode(self, x):
|
||||||
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
|
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
|
||||||
return self.per_channel_statistics.normalize(means)
|
return self.per_channel_statistics.normalize(means)
|
||||||
|
|
||||||
def decode(self, x):
|
def decode(self, x, timestep=0.05, noise_scale=0.025):
|
||||||
return self.decoder(self.per_channel_statistics.un_normalize(x))
|
if self.timestep_conditioning: #TODO: seed
|
||||||
|
x = torch.randn_like(x) * noise_scale + (1.0 - noise_scale) * x
|
||||||
|
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=timestep)
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
from typing import Tuple, Union
|
from typing import Tuple, Union
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from .dual_conv3d import DualConv3d
|
from .dual_conv3d import DualConv3d
|
||||||
from .causal_conv3d import CausalConv3d
|
from .causal_conv3d import CausalConv3d
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
|
import logging
|
||||||
|
import math
|
||||||
import torch
|
import torch
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
from typing import Any, Dict, Tuple, Union
|
||||||
|
|
||||||
from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
|
from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
|
||||||
|
|
||||||
from comfy.ldm.util import instantiate_from_config
|
from comfy.ldm.util import get_obj_from_str, instantiate_from_config
|
||||||
from comfy.ldm.modules.ema import LitEma
|
from comfy.ldm.modules.ema import LitEma
|
||||||
import comfy.ops
|
import comfy.ops
|
||||||
|
|
||||||
@@ -52,7 +54,7 @@ class AbstractAutoencoder(torch.nn.Module):
|
|||||||
|
|
||||||
if self.use_ema:
|
if self.use_ema:
|
||||||
self.model_ema = LitEma(self, decay=ema_decay)
|
self.model_ema = LitEma(self, decay=ema_decay)
|
||||||
logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
logging.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
||||||
|
|
||||||
def get_input(self, batch) -> Any:
|
def get_input(self, batch) -> Any:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
@@ -68,14 +70,14 @@ class AbstractAutoencoder(torch.nn.Module):
|
|||||||
self.model_ema.store(self.parameters())
|
self.model_ema.store(self.parameters())
|
||||||
self.model_ema.copy_to(self)
|
self.model_ema.copy_to(self)
|
||||||
if context is not None:
|
if context is not None:
|
||||||
logpy.info(f"{context}: Switched to EMA weights")
|
logging.info(f"{context}: Switched to EMA weights")
|
||||||
try:
|
try:
|
||||||
yield None
|
yield None
|
||||||
finally:
|
finally:
|
||||||
if self.use_ema:
|
if self.use_ema:
|
||||||
self.model_ema.restore(self.parameters())
|
self.model_ema.restore(self.parameters())
|
||||||
if context is not None:
|
if context is not None:
|
||||||
logpy.info(f"{context}: Restored training weights")
|
logging.info(f"{context}: Restored training weights")
|
||||||
|
|
||||||
def encode(self, *args, **kwargs) -> torch.Tensor:
|
def encode(self, *args, **kwargs) -> torch.Tensor:
|
||||||
raise NotImplementedError("encode()-method of abstract base class called")
|
raise NotImplementedError("encode()-method of abstract base class called")
|
||||||
@@ -84,7 +86,7 @@ class AbstractAutoencoder(torch.nn.Module):
|
|||||||
raise NotImplementedError("decode()-method of abstract base class called")
|
raise NotImplementedError("decode()-method of abstract base class called")
|
||||||
|
|
||||||
def instantiate_optimizer_from_config(self, params, lr, cfg):
|
def instantiate_optimizer_from_config(self, params, lr, cfg):
|
||||||
logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
|
logging.info(f"loading >>> {cfg['target']} <<< optimizer from config")
|
||||||
return get_obj_from_str(cfg["target"])(
|
return get_obj_from_str(cfg["target"])(
|
||||||
params, lr=lr, **cfg.get("params", dict())
|
params, lr=lr, **cfg.get("params", dict())
|
||||||
)
|
)
|
||||||
@@ -112,7 +114,7 @@ class AutoencodingEngine(AbstractAutoencoder):
|
|||||||
|
|
||||||
self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
|
self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
|
||||||
self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
|
self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
|
||||||
self.regularization: AbstractRegularizer = instantiate_from_config(
|
self.regularization = instantiate_from_config(
|
||||||
regularizer_config
|
regularizer_config
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -160,12 +162,19 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
|
|||||||
},
|
},
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
self.quant_conv = comfy.ops.disable_weight_init.Conv2d(
|
|
||||||
|
if ddconfig.get("conv3d", False):
|
||||||
|
conv_op = comfy.ops.disable_weight_init.Conv3d
|
||||||
|
else:
|
||||||
|
conv_op = comfy.ops.disable_weight_init.Conv2d
|
||||||
|
|
||||||
|
self.quant_conv = conv_op(
|
||||||
(1 + ddconfig["double_z"]) * ddconfig["z_channels"],
|
(1 + ddconfig["double_z"]) * ddconfig["z_channels"],
|
||||||
(1 + ddconfig["double_z"]) * embed_dim,
|
(1 + ddconfig["double_z"]) * embed_dim,
|
||||||
1,
|
1,
|
||||||
)
|
)
|
||||||
self.post_quant_conv = comfy.ops.disable_weight_init.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
|
||||||
|
self.post_quant_conv = conv_op(embed_dim, ddconfig["z_channels"], 1)
|
||||||
self.embed_dim = embed_dim
|
self.embed_dim = embed_dim
|
||||||
|
|
||||||
def get_autoencoder_params(self) -> list:
|
def get_autoencoder_params(self) -> list:
|
||||||
|
|||||||
@@ -15,6 +15,9 @@ if model_management.xformers_enabled():
|
|||||||
import xformers
|
import xformers
|
||||||
import xformers.ops
|
import xformers.ops
|
||||||
|
|
||||||
|
if model_management.sage_attention_enabled():
|
||||||
|
from sageattention import sageattn
|
||||||
|
|
||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
import comfy.ops
|
import comfy.ops
|
||||||
ops = comfy.ops.disable_weight_init
|
ops = comfy.ops.disable_weight_init
|
||||||
@@ -86,7 +89,7 @@ class FeedForward(nn.Module):
|
|||||||
def Normalize(in_channels, dtype=None, device=None):
|
def Normalize(in_channels, dtype=None, device=None):
|
||||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
|
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
|
||||||
|
|
||||||
def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
|
def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||||
attn_precision = get_attn_precision(attn_precision)
|
attn_precision = get_attn_precision(attn_precision)
|
||||||
|
|
||||||
if skip_reshape:
|
if skip_reshape:
|
||||||
@@ -139,16 +142,23 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
|||||||
sim = sim.softmax(dim=-1)
|
sim = sim.softmax(dim=-1)
|
||||||
|
|
||||||
out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
|
out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
|
||||||
out = (
|
|
||||||
out.unsqueeze(0)
|
if skip_output_reshape:
|
||||||
.reshape(b, heads, -1, dim_head)
|
out = (
|
||||||
.permute(0, 2, 1, 3)
|
out.unsqueeze(0)
|
||||||
.reshape(b, -1, heads * dim_head)
|
.reshape(b, heads, -1, dim_head)
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
out = (
|
||||||
|
out.unsqueeze(0)
|
||||||
|
.reshape(b, heads, -1, dim_head)
|
||||||
|
.permute(0, 2, 1, 3)
|
||||||
|
.reshape(b, -1, heads * dim_head)
|
||||||
|
)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False):
|
def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||||
attn_precision = get_attn_precision(attn_precision)
|
attn_precision = get_attn_precision(attn_precision)
|
||||||
|
|
||||||
if skip_reshape:
|
if skip_reshape:
|
||||||
@@ -157,8 +167,6 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
|
|||||||
b, _, dim_head = query.shape
|
b, _, dim_head = query.shape
|
||||||
dim_head //= heads
|
dim_head //= heads
|
||||||
|
|
||||||
scale = dim_head ** -0.5
|
|
||||||
|
|
||||||
if skip_reshape:
|
if skip_reshape:
|
||||||
query = query.reshape(b * heads, -1, dim_head)
|
query = query.reshape(b * heads, -1, dim_head)
|
||||||
value = value.reshape(b * heads, -1, dim_head)
|
value = value.reshape(b * heads, -1, dim_head)
|
||||||
@@ -177,9 +185,8 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
|
|||||||
bytes_per_token = torch.finfo(query.dtype).bits//8
|
bytes_per_token = torch.finfo(query.dtype).bits//8
|
||||||
batch_x_heads, q_tokens, _ = query.shape
|
batch_x_heads, q_tokens, _ = query.shape
|
||||||
_, _, k_tokens = key.shape
|
_, _, k_tokens = key.shape
|
||||||
qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
|
|
||||||
|
|
||||||
mem_free_total, mem_free_torch = model_management.get_free_memory(query.device, True)
|
mem_free_total, _ = model_management.get_free_memory(query.device, True)
|
||||||
|
|
||||||
kv_chunk_size_min = None
|
kv_chunk_size_min = None
|
||||||
kv_chunk_size = None
|
kv_chunk_size = None
|
||||||
@@ -215,11 +222,13 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
|
|||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = hidden_states.to(dtype)
|
hidden_states = hidden_states.to(dtype)
|
||||||
|
if skip_output_reshape:
|
||||||
hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
|
hidden_states = hidden_states.unflatten(0, (-1, heads))
|
||||||
|
else:
|
||||||
|
hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
|
def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||||
attn_precision = get_attn_precision(attn_precision)
|
attn_precision = get_attn_precision(attn_precision)
|
||||||
|
|
||||||
if skip_reshape:
|
if skip_reshape:
|
||||||
@@ -230,7 +239,6 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
|||||||
|
|
||||||
scale = dim_head ** -0.5
|
scale = dim_head ** -0.5
|
||||||
|
|
||||||
h = heads
|
|
||||||
if skip_reshape:
|
if skip_reshape:
|
||||||
q, k, v = map(
|
q, k, v = map(
|
||||||
lambda t: t.reshape(b * heads, -1, dim_head),
|
lambda t: t.reshape(b * heads, -1, dim_head),
|
||||||
@@ -327,12 +335,18 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
|||||||
|
|
||||||
del q, k, v
|
del q, k, v
|
||||||
|
|
||||||
r1 = (
|
if skip_output_reshape:
|
||||||
r1.unsqueeze(0)
|
r1 = (
|
||||||
.reshape(b, heads, -1, dim_head)
|
r1.unsqueeze(0)
|
||||||
.permute(0, 2, 1, 3)
|
.reshape(b, heads, -1, dim_head)
|
||||||
.reshape(b, -1, heads * dim_head)
|
)
|
||||||
)
|
else:
|
||||||
|
r1 = (
|
||||||
|
r1.unsqueeze(0)
|
||||||
|
.reshape(b, heads, -1, dim_head)
|
||||||
|
.permute(0, 2, 1, 3)
|
||||||
|
.reshape(b, -1, heads * dim_head)
|
||||||
|
)
|
||||||
return r1
|
return r1
|
||||||
|
|
||||||
BROKEN_XFORMERS = False
|
BROKEN_XFORMERS = False
|
||||||
@@ -343,13 +357,10 @@ try:
|
|||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
|
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||||
if skip_reshape:
|
b = q.shape[0]
|
||||||
b, _, _, dim_head = q.shape
|
dim_head = q.shape[-1]
|
||||||
else:
|
# check to make sure xformers isn't broken
|
||||||
b, _, dim_head = q.shape
|
|
||||||
dim_head //= heads
|
|
||||||
|
|
||||||
disabled_xformers = False
|
disabled_xformers = False
|
||||||
|
|
||||||
if BROKEN_XFORMERS:
|
if BROKEN_XFORMERS:
|
||||||
@@ -364,31 +375,43 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
|
|||||||
return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape)
|
return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape)
|
||||||
|
|
||||||
if skip_reshape:
|
if skip_reshape:
|
||||||
q, k, v = map(
|
# b h k d -> b k h d
|
||||||
lambda t: t.reshape(b * heads, -1, dim_head),
|
q, k, v = map(
|
||||||
|
lambda t: t.permute(0, 2, 1, 3),
|
||||||
(q, k, v),
|
(q, k, v),
|
||||||
)
|
)
|
||||||
|
# actually do the reshaping
|
||||||
else:
|
else:
|
||||||
|
dim_head //= heads
|
||||||
q, k, v = map(
|
q, k, v = map(
|
||||||
lambda t: t.reshape(b, -1, heads, dim_head),
|
lambda t: t.reshape(b, -1, heads, dim_head),
|
||||||
(q, k, v),
|
(q, k, v),
|
||||||
)
|
)
|
||||||
|
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
|
# add a singleton batch dimension
|
||||||
|
if mask.ndim == 2:
|
||||||
|
mask = mask.unsqueeze(0)
|
||||||
|
# add a singleton heads dimension
|
||||||
|
if mask.ndim == 3:
|
||||||
|
mask = mask.unsqueeze(1)
|
||||||
|
# pad to a multiple of 8
|
||||||
pad = 8 - mask.shape[-1] % 8
|
pad = 8 - mask.shape[-1] % 8
|
||||||
mask_out = torch.empty([q.shape[0], q.shape[2], q.shape[1], mask.shape[-1] + pad], dtype=q.dtype, device=q.device)
|
# the xformers docs says that it's allowed to have a mask of shape (1, Nq, Nk)
|
||||||
|
# but when using separated heads, the shape has to be (B, H, Nq, Nk)
|
||||||
|
# in flux, this matrix ends up being over 1GB
|
||||||
|
# here, we create a mask with the same batch/head size as the input mask (potentially singleton or full)
|
||||||
|
mask_out = torch.empty([mask.shape[0], mask.shape[1], q.shape[1], mask.shape[-1] + pad], dtype=q.dtype, device=q.device)
|
||||||
|
|
||||||
mask_out[..., :mask.shape[-1]] = mask
|
mask_out[..., :mask.shape[-1]] = mask
|
||||||
|
# doesn't this remove the padding again??
|
||||||
mask = mask_out[..., :mask.shape[-1]]
|
mask = mask_out[..., :mask.shape[-1]]
|
||||||
|
mask = mask.expand(b, heads, -1, -1)
|
||||||
|
|
||||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
|
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
|
||||||
|
|
||||||
if skip_reshape:
|
if skip_output_reshape:
|
||||||
out = (
|
out = out.permute(0, 2, 1, 3)
|
||||||
out.unsqueeze(0)
|
|
||||||
.reshape(b, heads, -1, dim_head)
|
|
||||||
.permute(0, 2, 1, 3)
|
|
||||||
.reshape(b, -1, heads * dim_head)
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
out = (
|
out = (
|
||||||
out.reshape(b, -1, heads * dim_head)
|
out.reshape(b, -1, heads * dim_head)
|
||||||
@@ -403,7 +426,7 @@ else:
|
|||||||
SDP_BATCH_LIMIT = 2**31
|
SDP_BATCH_LIMIT = 2**31
|
||||||
|
|
||||||
|
|
||||||
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
|
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||||
if skip_reshape:
|
if skip_reshape:
|
||||||
b, _, _, dim_head = q.shape
|
b, _, _, dim_head = q.shape
|
||||||
else:
|
else:
|
||||||
@@ -414,32 +437,90 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
|||||||
(q, k, v),
|
(q, k, v),
|
||||||
)
|
)
|
||||||
|
|
||||||
if SDP_BATCH_LIMIT >= q.shape[0]:
|
if mask is not None:
|
||||||
|
# add a batch dimension if there isn't already one
|
||||||
|
if mask.ndim == 2:
|
||||||
|
mask = mask.unsqueeze(0)
|
||||||
|
# add a heads dimension if there isn't already one
|
||||||
|
if mask.ndim == 3:
|
||||||
|
mask = mask.unsqueeze(1)
|
||||||
|
|
||||||
|
if SDP_BATCH_LIMIT >= b:
|
||||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||||
out = (
|
if not skip_output_reshape:
|
||||||
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
out = (
|
||||||
)
|
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
out = torch.empty((q.shape[0], q.shape[2], heads * dim_head), dtype=q.dtype, layout=q.layout, device=q.device)
|
out = torch.empty((b, q.shape[2], heads * dim_head), dtype=q.dtype, layout=q.layout, device=q.device)
|
||||||
for i in range(0, q.shape[0], SDP_BATCH_LIMIT):
|
for i in range(0, b, SDP_BATCH_LIMIT):
|
||||||
out[i : i + SDP_BATCH_LIMIT] = torch.nn.functional.scaled_dot_product_attention(q[i : i + SDP_BATCH_LIMIT], k[i : i + SDP_BATCH_LIMIT], v[i : i + SDP_BATCH_LIMIT], attn_mask=mask, dropout_p=0.0, is_causal=False).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
|
m = mask
|
||||||
|
if mask is not None:
|
||||||
|
if mask.shape[0] > 1:
|
||||||
|
m = mask[i : i + SDP_BATCH_LIMIT]
|
||||||
|
|
||||||
|
out[i : i + SDP_BATCH_LIMIT] = torch.nn.functional.scaled_dot_product_attention(
|
||||||
|
q[i : i + SDP_BATCH_LIMIT],
|
||||||
|
k[i : i + SDP_BATCH_LIMIT],
|
||||||
|
v[i : i + SDP_BATCH_LIMIT],
|
||||||
|
attn_mask=m,
|
||||||
|
dropout_p=0.0, is_causal=False
|
||||||
|
).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||||
|
if skip_reshape:
|
||||||
|
b, _, _, dim_head = q.shape
|
||||||
|
tensor_layout="HND"
|
||||||
|
else:
|
||||||
|
b, _, dim_head = q.shape
|
||||||
|
dim_head //= heads
|
||||||
|
q, k, v = map(
|
||||||
|
lambda t: t.view(b, -1, heads, dim_head),
|
||||||
|
(q, k, v),
|
||||||
|
)
|
||||||
|
tensor_layout="NHD"
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
# add a batch dimension if there isn't already one
|
||||||
|
if mask.ndim == 2:
|
||||||
|
mask = mask.unsqueeze(0)
|
||||||
|
# add a heads dimension if there isn't already one
|
||||||
|
if mask.ndim == 3:
|
||||||
|
mask = mask.unsqueeze(1)
|
||||||
|
|
||||||
|
out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
|
||||||
|
if tensor_layout == "HND":
|
||||||
|
if not skip_output_reshape:
|
||||||
|
out = (
|
||||||
|
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if skip_output_reshape:
|
||||||
|
out = out.transpose(1, 2)
|
||||||
|
else:
|
||||||
|
out = out.reshape(b, -1, heads * dim_head)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
optimized_attention = attention_basic
|
optimized_attention = attention_basic
|
||||||
|
|
||||||
if model_management.xformers_enabled():
|
if model_management.sage_attention_enabled():
|
||||||
logging.info("Using xformers cross attention")
|
logging.info("Using sage attention")
|
||||||
|
optimized_attention = attention_sage
|
||||||
|
elif model_management.xformers_enabled():
|
||||||
|
logging.info("Using xformers attention")
|
||||||
optimized_attention = attention_xformers
|
optimized_attention = attention_xformers
|
||||||
elif model_management.pytorch_attention_enabled():
|
elif model_management.pytorch_attention_enabled():
|
||||||
logging.info("Using pytorch cross attention")
|
logging.info("Using pytorch attention")
|
||||||
optimized_attention = attention_pytorch
|
optimized_attention = attention_pytorch
|
||||||
else:
|
else:
|
||||||
if args.use_split_cross_attention:
|
if args.use_split_cross_attention:
|
||||||
logging.info("Using split optimization for cross attention")
|
logging.info("Using split optimization for attention")
|
||||||
optimized_attention = attention_split
|
optimized_attention = attention_split
|
||||||
else:
|
else:
|
||||||
logging.info("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
|
logging.info("Using sub quadratic optimization for attention, if you have memory or speed issues try using: --use-split-cross-attention")
|
||||||
optimized_attention = attention_sub_quad
|
optimized_attention = attention_sub_quad
|
||||||
|
|
||||||
optimized_attention_masked = optimized_attention
|
optimized_attention_masked = optimized_attention
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
import logging
|
from functools import partial
|
||||||
import math
|
|
||||||
from typing import Dict, Optional, List
|
from typing import Dict, Optional, List
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -72,45 +71,33 @@ class PatchEmbed(nn.Module):
|
|||||||
strict_img_size: bool = True,
|
strict_img_size: bool = True,
|
||||||
dynamic_img_pad: bool = True,
|
dynamic_img_pad: bool = True,
|
||||||
padding_mode='circular',
|
padding_mode='circular',
|
||||||
|
conv3d=False,
|
||||||
dtype=None,
|
dtype=None,
|
||||||
device=None,
|
device=None,
|
||||||
operations=None,
|
operations=None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.patch_size = (patch_size, patch_size)
|
try:
|
||||||
|
len(patch_size)
|
||||||
|
self.patch_size = patch_size
|
||||||
|
except:
|
||||||
|
if conv3d:
|
||||||
|
self.patch_size = (patch_size, patch_size, patch_size)
|
||||||
|
else:
|
||||||
|
self.patch_size = (patch_size, patch_size)
|
||||||
self.padding_mode = padding_mode
|
self.padding_mode = padding_mode
|
||||||
if img_size is not None:
|
|
||||||
self.img_size = (img_size, img_size)
|
|
||||||
self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size)])
|
|
||||||
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
|
||||||
else:
|
|
||||||
self.img_size = None
|
|
||||||
self.grid_size = None
|
|
||||||
self.num_patches = None
|
|
||||||
|
|
||||||
# flatten spatial dim and transpose to channels last, kept for bwd compat
|
# flatten spatial dim and transpose to channels last, kept for bwd compat
|
||||||
self.flatten = flatten
|
self.flatten = flatten
|
||||||
self.strict_img_size = strict_img_size
|
self.strict_img_size = strict_img_size
|
||||||
self.dynamic_img_pad = dynamic_img_pad
|
self.dynamic_img_pad = dynamic_img_pad
|
||||||
|
if conv3d:
|
||||||
self.proj = operations.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, dtype=dtype, device=device)
|
self.proj = operations.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, dtype=dtype, device=device)
|
||||||
|
else:
|
||||||
|
self.proj = operations.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, dtype=dtype, device=device)
|
||||||
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
# B, C, H, W = x.shape
|
|
||||||
# if self.img_size is not None:
|
|
||||||
# if self.strict_img_size:
|
|
||||||
# _assert(H == self.img_size[0], f"Input height ({H}) doesn't match model ({self.img_size[0]}).")
|
|
||||||
# _assert(W == self.img_size[1], f"Input width ({W}) doesn't match model ({self.img_size[1]}).")
|
|
||||||
# elif not self.dynamic_img_pad:
|
|
||||||
# _assert(
|
|
||||||
# H % self.patch_size[0] == 0,
|
|
||||||
# f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})."
|
|
||||||
# )
|
|
||||||
# _assert(
|
|
||||||
# W % self.patch_size[1] == 0,
|
|
||||||
# f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
|
|
||||||
# )
|
|
||||||
if self.dynamic_img_pad:
|
if self.dynamic_img_pad:
|
||||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size, padding_mode=self.padding_mode)
|
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size, padding_mode=self.padding_mode)
|
||||||
x = self.proj(x)
|
x = self.proj(x)
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ import math
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from typing import Optional, Any
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from comfy import model_management
|
from comfy import model_management
|
||||||
@@ -44,51 +43,100 @@ def Normalize(in_channels, num_groups=32):
|
|||||||
return ops.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
return ops.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||||
|
|
||||||
|
|
||||||
|
class VideoConv3d(nn.Module):
|
||||||
|
def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding_mode='replicate', padding=1, **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.padding_mode = padding_mode
|
||||||
|
if padding != 0:
|
||||||
|
padding = (padding, padding, padding, padding, kernel_size - 1, 0)
|
||||||
|
else:
|
||||||
|
kwargs["padding"] = padding
|
||||||
|
|
||||||
|
self.padding = padding
|
||||||
|
self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.padding != 0:
|
||||||
|
x = torch.nn.functional.pad(x, self.padding, mode=self.padding_mode)
|
||||||
|
return self.conv(x)
|
||||||
|
|
||||||
|
def interpolate_up(x, scale_factor):
|
||||||
|
try:
|
||||||
|
return torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="nearest")
|
||||||
|
except: #operation not implemented for bf16
|
||||||
|
orig_shape = list(x.shape)
|
||||||
|
out_shape = orig_shape[:2]
|
||||||
|
for i in range(len(orig_shape) - 2):
|
||||||
|
out_shape.append(round(orig_shape[i + 2] * scale_factor[i]))
|
||||||
|
out = torch.empty(out_shape, dtype=x.dtype, layout=x.layout, device=x.device)
|
||||||
|
split = 8
|
||||||
|
l = out.shape[1] // split
|
||||||
|
for i in range(0, out.shape[1], l):
|
||||||
|
out[:,i:i+l] = torch.nn.functional.interpolate(x[:,i:i+l].to(torch.float32), scale_factor=scale_factor, mode="nearest").to(x.dtype)
|
||||||
|
return out
|
||||||
|
|
||||||
class Upsample(nn.Module):
|
class Upsample(nn.Module):
|
||||||
def __init__(self, in_channels, with_conv):
|
def __init__(self, in_channels, with_conv, conv_op=ops.Conv2d, scale_factor=2.0):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.with_conv = with_conv
|
self.with_conv = with_conv
|
||||||
|
self.scale_factor = scale_factor
|
||||||
|
|
||||||
if self.with_conv:
|
if self.with_conv:
|
||||||
self.conv = ops.Conv2d(in_channels,
|
self.conv = conv_op(in_channels,
|
||||||
in_channels,
|
in_channels,
|
||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
stride=1,
|
stride=1,
|
||||||
padding=1)
|
padding=1)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
try:
|
scale_factor = self.scale_factor
|
||||||
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
if isinstance(scale_factor, (int, float)):
|
||||||
except: #operation not implemented for bf16
|
scale_factor = (scale_factor,) * (x.ndim - 2)
|
||||||
b, c, h, w = x.shape
|
|
||||||
out = torch.empty((b, c, h*2, w*2), dtype=x.dtype, layout=x.layout, device=x.device)
|
|
||||||
split = 8
|
|
||||||
l = out.shape[1] // split
|
|
||||||
for i in range(0, out.shape[1], l):
|
|
||||||
out[:,i:i+l] = torch.nn.functional.interpolate(x[:,i:i+l].to(torch.float32), scale_factor=2.0, mode="nearest").to(x.dtype)
|
|
||||||
del x
|
|
||||||
x = out
|
|
||||||
|
|
||||||
|
if x.ndim == 5 and scale_factor[0] > 1.0:
|
||||||
|
t = x.shape[2]
|
||||||
|
if t > 1:
|
||||||
|
a, b = x.split((1, t - 1), dim=2)
|
||||||
|
del x
|
||||||
|
b = interpolate_up(b, scale_factor)
|
||||||
|
else:
|
||||||
|
a = x
|
||||||
|
|
||||||
|
a = interpolate_up(a.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2)
|
||||||
|
if t > 1:
|
||||||
|
x = torch.cat((a, b), dim=2)
|
||||||
|
else:
|
||||||
|
x = a
|
||||||
|
else:
|
||||||
|
x = interpolate_up(x, scale_factor)
|
||||||
if self.with_conv:
|
if self.with_conv:
|
||||||
x = self.conv(x)
|
x = self.conv(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class Downsample(nn.Module):
|
class Downsample(nn.Module):
|
||||||
def __init__(self, in_channels, with_conv):
|
def __init__(self, in_channels, with_conv, stride=2, conv_op=ops.Conv2d):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.with_conv = with_conv
|
self.with_conv = with_conv
|
||||||
if self.with_conv:
|
if self.with_conv:
|
||||||
# no asymmetric padding in torch conv, must do it ourselves
|
# no asymmetric padding in torch conv, must do it ourselves
|
||||||
self.conv = ops.Conv2d(in_channels,
|
self.conv = conv_op(in_channels,
|
||||||
in_channels,
|
in_channels,
|
||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
stride=2,
|
stride=stride,
|
||||||
padding=0)
|
padding=0)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
if self.with_conv:
|
if self.with_conv:
|
||||||
pad = (0,1,0,1)
|
if x.ndim == 4:
|
||||||
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
pad = (0, 1, 0, 1)
|
||||||
|
mode = "constant"
|
||||||
|
x = torch.nn.functional.pad(x, pad, mode=mode, value=0)
|
||||||
|
elif x.ndim == 5:
|
||||||
|
pad = (1, 1, 1, 1, 2, 0)
|
||||||
|
mode = "replicate"
|
||||||
|
x = torch.nn.functional.pad(x, pad, mode=mode)
|
||||||
x = self.conv(x)
|
x = self.conv(x)
|
||||||
else:
|
else:
|
||||||
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
||||||
@@ -97,7 +145,7 @@ class Downsample(nn.Module):
|
|||||||
|
|
||||||
class ResnetBlock(nn.Module):
|
class ResnetBlock(nn.Module):
|
||||||
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
|
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
|
||||||
dropout, temb_channels=512):
|
dropout, temb_channels=512, conv_op=ops.Conv2d):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
out_channels = in_channels if out_channels is None else out_channels
|
out_channels = in_channels if out_channels is None else out_channels
|
||||||
@@ -106,7 +154,7 @@ class ResnetBlock(nn.Module):
|
|||||||
|
|
||||||
self.swish = torch.nn.SiLU(inplace=True)
|
self.swish = torch.nn.SiLU(inplace=True)
|
||||||
self.norm1 = Normalize(in_channels)
|
self.norm1 = Normalize(in_channels)
|
||||||
self.conv1 = ops.Conv2d(in_channels,
|
self.conv1 = conv_op(in_channels,
|
||||||
out_channels,
|
out_channels,
|
||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
stride=1,
|
stride=1,
|
||||||
@@ -116,20 +164,20 @@ class ResnetBlock(nn.Module):
|
|||||||
out_channels)
|
out_channels)
|
||||||
self.norm2 = Normalize(out_channels)
|
self.norm2 = Normalize(out_channels)
|
||||||
self.dropout = torch.nn.Dropout(dropout, inplace=True)
|
self.dropout = torch.nn.Dropout(dropout, inplace=True)
|
||||||
self.conv2 = ops.Conv2d(out_channels,
|
self.conv2 = conv_op(out_channels,
|
||||||
out_channels,
|
out_channels,
|
||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
stride=1,
|
stride=1,
|
||||||
padding=1)
|
padding=1)
|
||||||
if self.in_channels != self.out_channels:
|
if self.in_channels != self.out_channels:
|
||||||
if self.use_conv_shortcut:
|
if self.use_conv_shortcut:
|
||||||
self.conv_shortcut = ops.Conv2d(in_channels,
|
self.conv_shortcut = conv_op(in_channels,
|
||||||
out_channels,
|
out_channels,
|
||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
stride=1,
|
stride=1,
|
||||||
padding=1)
|
padding=1)
|
||||||
else:
|
else:
|
||||||
self.nin_shortcut = ops.Conv2d(in_channels,
|
self.nin_shortcut = conv_op(in_channels,
|
||||||
out_channels,
|
out_channels,
|
||||||
kernel_size=1,
|
kernel_size=1,
|
||||||
stride=1,
|
stride=1,
|
||||||
@@ -163,7 +211,6 @@ def slice_attention(q, k, v):
|
|||||||
|
|
||||||
mem_free_total = model_management.get_free_memory(q.device)
|
mem_free_total = model_management.get_free_memory(q.device)
|
||||||
|
|
||||||
gb = 1024 ** 3
|
|
||||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
|
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
|
||||||
modifier = 3 if q.element_size() == 2 else 2.5
|
modifier = 3 if q.element_size() == 2 else 2.5
|
||||||
mem_required = tensor_size * modifier
|
mem_required = tensor_size * modifier
|
||||||
@@ -196,21 +243,25 @@ def slice_attention(q, k, v):
|
|||||||
|
|
||||||
def normal_attention(q, k, v):
|
def normal_attention(q, k, v):
|
||||||
# compute attention
|
# compute attention
|
||||||
b,c,h,w = q.shape
|
orig_shape = q.shape
|
||||||
|
b = orig_shape[0]
|
||||||
|
c = orig_shape[1]
|
||||||
|
|
||||||
q = q.reshape(b,c,h*w)
|
q = q.reshape(b, c, -1)
|
||||||
q = q.permute(0,2,1) # b,hw,c
|
q = q.permute(0, 2, 1) # b,hw,c
|
||||||
k = k.reshape(b,c,h*w) # b,c,hw
|
k = k.reshape(b, c, -1) # b,c,hw
|
||||||
v = v.reshape(b,c,h*w)
|
v = v.reshape(b, c, -1)
|
||||||
|
|
||||||
r1 = slice_attention(q, k, v)
|
r1 = slice_attention(q, k, v)
|
||||||
h_ = r1.reshape(b,c,h,w)
|
h_ = r1.reshape(orig_shape)
|
||||||
del r1
|
del r1
|
||||||
return h_
|
return h_
|
||||||
|
|
||||||
def xformers_attention(q, k, v):
|
def xformers_attention(q, k, v):
|
||||||
# compute attention
|
# compute attention
|
||||||
B, C, H, W = q.shape
|
orig_shape = q.shape
|
||||||
|
B = orig_shape[0]
|
||||||
|
C = orig_shape[1]
|
||||||
q, k, v = map(
|
q, k, v = map(
|
||||||
lambda t: t.view(B, C, -1).transpose(1, 2).contiguous(),
|
lambda t: t.view(B, C, -1).transpose(1, 2).contiguous(),
|
||||||
(q, k, v),
|
(q, k, v),
|
||||||
@@ -218,14 +269,16 @@ def xformers_attention(q, k, v):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
|
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
|
||||||
out = out.transpose(1, 2).reshape(B, C, H, W)
|
out = out.transpose(1, 2).reshape(orig_shape)
|
||||||
except NotImplementedError as e:
|
except NotImplementedError:
|
||||||
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
|
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(orig_shape)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def pytorch_attention(q, k, v):
|
def pytorch_attention(q, k, v):
|
||||||
# compute attention
|
# compute attention
|
||||||
B, C, H, W = q.shape
|
orig_shape = q.shape
|
||||||
|
B = orig_shape[0]
|
||||||
|
C = orig_shape[1]
|
||||||
q, k, v = map(
|
q, k, v = map(
|
||||||
lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
|
lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
|
||||||
(q, k, v),
|
(q, k, v),
|
||||||
@@ -233,49 +286,52 @@ def pytorch_attention(q, k, v):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
|
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
|
||||||
out = out.transpose(2, 3).reshape(B, C, H, W)
|
out = out.transpose(2, 3).reshape(orig_shape)
|
||||||
except model_management.OOM_EXCEPTION as e:
|
except model_management.OOM_EXCEPTION:
|
||||||
logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
|
logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
|
||||||
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
|
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(orig_shape)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def vae_attention():
|
||||||
|
if model_management.xformers_enabled_vae():
|
||||||
|
logging.info("Using xformers attention in VAE")
|
||||||
|
return xformers_attention
|
||||||
|
elif model_management.pytorch_attention_enabled():
|
||||||
|
logging.info("Using pytorch attention in VAE")
|
||||||
|
return pytorch_attention
|
||||||
|
else:
|
||||||
|
logging.info("Using split attention in VAE")
|
||||||
|
return normal_attention
|
||||||
|
|
||||||
class AttnBlock(nn.Module):
|
class AttnBlock(nn.Module):
|
||||||
def __init__(self, in_channels):
|
def __init__(self, in_channels, conv_op=ops.Conv2d):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
|
|
||||||
self.norm = Normalize(in_channels)
|
self.norm = Normalize(in_channels)
|
||||||
self.q = ops.Conv2d(in_channels,
|
self.q = conv_op(in_channels,
|
||||||
in_channels,
|
in_channels,
|
||||||
kernel_size=1,
|
kernel_size=1,
|
||||||
stride=1,
|
stride=1,
|
||||||
padding=0)
|
padding=0)
|
||||||
self.k = ops.Conv2d(in_channels,
|
self.k = conv_op(in_channels,
|
||||||
in_channels,
|
in_channels,
|
||||||
kernel_size=1,
|
kernel_size=1,
|
||||||
stride=1,
|
stride=1,
|
||||||
padding=0)
|
padding=0)
|
||||||
self.v = ops.Conv2d(in_channels,
|
self.v = conv_op(in_channels,
|
||||||
in_channels,
|
in_channels,
|
||||||
kernel_size=1,
|
kernel_size=1,
|
||||||
stride=1,
|
stride=1,
|
||||||
padding=0)
|
padding=0)
|
||||||
self.proj_out = ops.Conv2d(in_channels,
|
self.proj_out = conv_op(in_channels,
|
||||||
in_channels,
|
in_channels,
|
||||||
kernel_size=1,
|
kernel_size=1,
|
||||||
stride=1,
|
stride=1,
|
||||||
padding=0)
|
padding=0)
|
||||||
|
|
||||||
if model_management.xformers_enabled_vae():
|
self.optimized_attention = vae_attention()
|
||||||
logging.info("Using xformers attention in VAE")
|
|
||||||
self.optimized_attention = xformers_attention
|
|
||||||
elif model_management.pytorch_attention_enabled():
|
|
||||||
logging.info("Using pytorch attention in VAE")
|
|
||||||
self.optimized_attention = pytorch_attention
|
|
||||||
else:
|
|
||||||
logging.info("Using split attention in VAE")
|
|
||||||
self.optimized_attention = normal_attention
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
h_ = x
|
h_ = x
|
||||||
@@ -291,8 +347,8 @@ class AttnBlock(nn.Module):
|
|||||||
return x+h_
|
return x+h_
|
||||||
|
|
||||||
|
|
||||||
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
|
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None, conv_op=ops.Conv2d):
|
||||||
return AttnBlock(in_channels)
|
return AttnBlock(in_channels, conv_op=conv_op)
|
||||||
|
|
||||||
|
|
||||||
class Model(nn.Module):
|
class Model(nn.Module):
|
||||||
@@ -451,6 +507,7 @@ class Encoder(nn.Module):
|
|||||||
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
||||||
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
||||||
resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
|
resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
|
||||||
|
conv3d=False, time_compress=None,
|
||||||
**ignore_kwargs):
|
**ignore_kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if use_linear_attn: attn_type = "linear"
|
if use_linear_attn: attn_type = "linear"
|
||||||
@@ -461,8 +518,15 @@ class Encoder(nn.Module):
|
|||||||
self.resolution = resolution
|
self.resolution = resolution
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
|
|
||||||
|
if conv3d:
|
||||||
|
conv_op = VideoConv3d
|
||||||
|
mid_attn_conv_op = ops.Conv3d
|
||||||
|
else:
|
||||||
|
conv_op = ops.Conv2d
|
||||||
|
mid_attn_conv_op = ops.Conv2d
|
||||||
|
|
||||||
# downsampling
|
# downsampling
|
||||||
self.conv_in = ops.Conv2d(in_channels,
|
self.conv_in = conv_op(in_channels,
|
||||||
self.ch,
|
self.ch,
|
||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
stride=1,
|
stride=1,
|
||||||
@@ -481,15 +545,20 @@ class Encoder(nn.Module):
|
|||||||
block.append(ResnetBlock(in_channels=block_in,
|
block.append(ResnetBlock(in_channels=block_in,
|
||||||
out_channels=block_out,
|
out_channels=block_out,
|
||||||
temb_channels=self.temb_ch,
|
temb_channels=self.temb_ch,
|
||||||
dropout=dropout))
|
dropout=dropout,
|
||||||
|
conv_op=conv_op))
|
||||||
block_in = block_out
|
block_in = block_out
|
||||||
if curr_res in attn_resolutions:
|
if curr_res in attn_resolutions:
|
||||||
attn.append(make_attn(block_in, attn_type=attn_type))
|
attn.append(make_attn(block_in, attn_type=attn_type, conv_op=conv_op))
|
||||||
down = nn.Module()
|
down = nn.Module()
|
||||||
down.block = block
|
down.block = block
|
||||||
down.attn = attn
|
down.attn = attn
|
||||||
if i_level != self.num_resolutions-1:
|
if i_level != self.num_resolutions-1:
|
||||||
down.downsample = Downsample(block_in, resamp_with_conv)
|
stride = 2
|
||||||
|
if time_compress is not None:
|
||||||
|
if (self.num_resolutions - 1 - i_level) > math.log2(time_compress):
|
||||||
|
stride = (1, 2, 2)
|
||||||
|
down.downsample = Downsample(block_in, resamp_with_conv, stride=stride, conv_op=conv_op)
|
||||||
curr_res = curr_res // 2
|
curr_res = curr_res // 2
|
||||||
self.down.append(down)
|
self.down.append(down)
|
||||||
|
|
||||||
@@ -498,16 +567,18 @@ class Encoder(nn.Module):
|
|||||||
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
||||||
out_channels=block_in,
|
out_channels=block_in,
|
||||||
temb_channels=self.temb_ch,
|
temb_channels=self.temb_ch,
|
||||||
dropout=dropout)
|
dropout=dropout,
|
||||||
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
|
conv_op=conv_op)
|
||||||
|
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type, conv_op=mid_attn_conv_op)
|
||||||
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
||||||
out_channels=block_in,
|
out_channels=block_in,
|
||||||
temb_channels=self.temb_ch,
|
temb_channels=self.temb_ch,
|
||||||
dropout=dropout)
|
dropout=dropout,
|
||||||
|
conv_op=conv_op)
|
||||||
|
|
||||||
# end
|
# end
|
||||||
self.norm_out = Normalize(block_in)
|
self.norm_out = Normalize(block_in)
|
||||||
self.conv_out = ops.Conv2d(block_in,
|
self.conv_out = conv_op(block_in,
|
||||||
2*z_channels if double_z else z_channels,
|
2*z_channels if double_z else z_channels,
|
||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
stride=1,
|
stride=1,
|
||||||
@@ -545,9 +616,10 @@ class Decoder(nn.Module):
|
|||||||
conv_out_op=ops.Conv2d,
|
conv_out_op=ops.Conv2d,
|
||||||
resnet_op=ResnetBlock,
|
resnet_op=ResnetBlock,
|
||||||
attn_op=AttnBlock,
|
attn_op=AttnBlock,
|
||||||
|
conv3d=False,
|
||||||
|
time_compress=None,
|
||||||
**ignorekwargs):
|
**ignorekwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if use_linear_attn: attn_type = "linear"
|
|
||||||
self.ch = ch
|
self.ch = ch
|
||||||
self.temb_ch = 0
|
self.temb_ch = 0
|
||||||
self.num_resolutions = len(ch_mult)
|
self.num_resolutions = len(ch_mult)
|
||||||
@@ -557,8 +629,15 @@ class Decoder(nn.Module):
|
|||||||
self.give_pre_end = give_pre_end
|
self.give_pre_end = give_pre_end
|
||||||
self.tanh_out = tanh_out
|
self.tanh_out = tanh_out
|
||||||
|
|
||||||
# compute in_ch_mult, block_in and curr_res at lowest res
|
if conv3d:
|
||||||
in_ch_mult = (1,)+tuple(ch_mult)
|
conv_op = VideoConv3d
|
||||||
|
conv_out_op = VideoConv3d
|
||||||
|
mid_attn_conv_op = ops.Conv3d
|
||||||
|
else:
|
||||||
|
conv_op = ops.Conv2d
|
||||||
|
mid_attn_conv_op = ops.Conv2d
|
||||||
|
|
||||||
|
# compute block_in and curr_res at lowest res
|
||||||
block_in = ch*ch_mult[self.num_resolutions-1]
|
block_in = ch*ch_mult[self.num_resolutions-1]
|
||||||
curr_res = resolution // 2**(self.num_resolutions-1)
|
curr_res = resolution // 2**(self.num_resolutions-1)
|
||||||
self.z_shape = (1,z_channels,curr_res,curr_res)
|
self.z_shape = (1,z_channels,curr_res,curr_res)
|
||||||
@@ -566,7 +645,7 @@ class Decoder(nn.Module):
|
|||||||
self.z_shape, np.prod(self.z_shape)))
|
self.z_shape, np.prod(self.z_shape)))
|
||||||
|
|
||||||
# z to block_in
|
# z to block_in
|
||||||
self.conv_in = ops.Conv2d(z_channels,
|
self.conv_in = conv_op(z_channels,
|
||||||
block_in,
|
block_in,
|
||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
stride=1,
|
stride=1,
|
||||||
@@ -577,12 +656,14 @@ class Decoder(nn.Module):
|
|||||||
self.mid.block_1 = resnet_op(in_channels=block_in,
|
self.mid.block_1 = resnet_op(in_channels=block_in,
|
||||||
out_channels=block_in,
|
out_channels=block_in,
|
||||||
temb_channels=self.temb_ch,
|
temb_channels=self.temb_ch,
|
||||||
dropout=dropout)
|
dropout=dropout,
|
||||||
self.mid.attn_1 = attn_op(block_in)
|
conv_op=conv_op)
|
||||||
|
self.mid.attn_1 = attn_op(block_in, conv_op=mid_attn_conv_op)
|
||||||
self.mid.block_2 = resnet_op(in_channels=block_in,
|
self.mid.block_2 = resnet_op(in_channels=block_in,
|
||||||
out_channels=block_in,
|
out_channels=block_in,
|
||||||
temb_channels=self.temb_ch,
|
temb_channels=self.temb_ch,
|
||||||
dropout=dropout)
|
dropout=dropout,
|
||||||
|
conv_op=conv_op)
|
||||||
|
|
||||||
# upsampling
|
# upsampling
|
||||||
self.up = nn.ModuleList()
|
self.up = nn.ModuleList()
|
||||||
@@ -594,15 +675,21 @@ class Decoder(nn.Module):
|
|||||||
block.append(resnet_op(in_channels=block_in,
|
block.append(resnet_op(in_channels=block_in,
|
||||||
out_channels=block_out,
|
out_channels=block_out,
|
||||||
temb_channels=self.temb_ch,
|
temb_channels=self.temb_ch,
|
||||||
dropout=dropout))
|
dropout=dropout,
|
||||||
|
conv_op=conv_op))
|
||||||
block_in = block_out
|
block_in = block_out
|
||||||
if curr_res in attn_resolutions:
|
if curr_res in attn_resolutions:
|
||||||
attn.append(attn_op(block_in))
|
attn.append(attn_op(block_in, conv_op=conv_op))
|
||||||
up = nn.Module()
|
up = nn.Module()
|
||||||
up.block = block
|
up.block = block
|
||||||
up.attn = attn
|
up.attn = attn
|
||||||
if i_level != 0:
|
if i_level != 0:
|
||||||
up.upsample = Upsample(block_in, resamp_with_conv)
|
scale_factor = 2.0
|
||||||
|
if time_compress is not None:
|
||||||
|
if i_level > math.log2(time_compress):
|
||||||
|
scale_factor = (1.0, 2.0, 2.0)
|
||||||
|
|
||||||
|
up.upsample = Upsample(block_in, resamp_with_conv, conv_op=conv_op, scale_factor=scale_factor)
|
||||||
curr_res = curr_res * 2
|
curr_res = curr_res * 2
|
||||||
self.up.insert(0, up) # prepend to get consistent order
|
self.up.insert(0, up) # prepend to get consistent order
|
||||||
|
|
||||||
@@ -615,9 +702,6 @@ class Decoder(nn.Module):
|
|||||||
padding=1)
|
padding=1)
|
||||||
|
|
||||||
def forward(self, z, **kwargs):
|
def forward(self, z, **kwargs):
|
||||||
#assert z.shape[1:] == self.z_shape[1:]
|
|
||||||
self.last_z_shape = z.shape
|
|
||||||
|
|
||||||
# timestep embedding
|
# timestep embedding
|
||||||
temb = None
|
temb = None
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ import logging
|
|||||||
from .util import (
|
from .util import (
|
||||||
checkpoint,
|
checkpoint,
|
||||||
avg_pool_nd,
|
avg_pool_nd,
|
||||||
zero_module,
|
|
||||||
timestep_embedding,
|
timestep_embedding,
|
||||||
AlphaBlender,
|
AlphaBlender,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import numpy as np
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
from .util import extract_into_tensor, make_beta_schedule
|
from .util import extract_into_tensor, make_beta_schedule
|
||||||
from comfy.ldm.util import default
|
|
||||||
|
|
||||||
|
|
||||||
class AbstractLowScaleModel(nn.Module):
|
class AbstractLowScaleModel(nn.Module):
|
||||||
|
|||||||
@@ -8,8 +8,8 @@
|
|||||||
# thanks!
|
# thanks!
|
||||||
|
|
||||||
|
|
||||||
import os
|
|
||||||
import math
|
import math
|
||||||
|
import logging
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -131,7 +131,7 @@ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timestep
|
|||||||
# add one to get the final alpha values right (the ones from first scale to data during sampling)
|
# add one to get the final alpha values right (the ones from first scale to data during sampling)
|
||||||
steps_out = ddim_timesteps + 1
|
steps_out = ddim_timesteps + 1
|
||||||
if verbose:
|
if verbose:
|
||||||
print(f'Selected timesteps for ddim sampler: {steps_out}')
|
logging.info(f'Selected timesteps for ddim sampler: {steps_out}')
|
||||||
return steps_out
|
return steps_out
|
||||||
|
|
||||||
|
|
||||||
@@ -143,8 +143,8 @@ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
|
|||||||
# according the the formula provided in https://arxiv.org/abs/2010.02502
|
# according the the formula provided in https://arxiv.org/abs/2010.02502
|
||||||
sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
|
sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
|
||||||
if verbose:
|
if verbose:
|
||||||
print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
|
logging.info(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
|
||||||
print(f'For the chosen value of eta, which is {eta}, '
|
logging.info(f'For the chosen value of eta, which is {eta}, '
|
||||||
f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
|
f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
|
||||||
return sigmas, alphas, alphas_prev
|
return sigmas, alphas, alphas_prev
|
||||||
|
|
||||||
|
|||||||
@@ -30,10 +30,10 @@ class DiagonalGaussianDistribution(object):
|
|||||||
self.std = torch.exp(0.5 * self.logvar)
|
self.std = torch.exp(0.5 * self.logvar)
|
||||||
self.var = torch.exp(self.logvar)
|
self.var = torch.exp(self.logvar)
|
||||||
if self.deterministic:
|
if self.deterministic:
|
||||||
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
|
self.var = self.std = torch.zeros_like(self.mean, device=self.parameters.device)
|
||||||
|
|
||||||
def sample(self):
|
def sample(self):
|
||||||
x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
|
x = self.mean + self.std * torch.randn(self.mean.shape, device=self.parameters.device)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def kl(self, other=None):
|
def kl(self, other=None):
|
||||||
|
|||||||
@@ -17,12 +17,11 @@ import math
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from typing import Optional, NamedTuple, List, Protocol
|
from typing import Optional, NamedTuple, List, Protocol
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from typing import Optional, NamedTuple, List
|
from typing import Optional, NamedTuple, List
|
||||||
from typing_extensions import Protocol
|
from typing_extensions import Protocol
|
||||||
|
|
||||||
from torch import Tensor
|
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from comfy import model_management
|
from comfy import model_management
|
||||||
@@ -172,7 +171,7 @@ def _get_attention_scores_no_kv_chunking(
|
|||||||
del attn_scores
|
del attn_scores
|
||||||
except model_management.OOM_EXCEPTION:
|
except model_management.OOM_EXCEPTION:
|
||||||
logging.warning("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
|
logging.warning("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
|
||||||
attn_scores -= attn_scores.max(dim=-1, keepdim=True).values
|
attn_scores -= attn_scores.max(dim=-1, keepdim=True).values # noqa: F821 attn_scores is not defined
|
||||||
torch.exp(attn_scores, out=attn_scores)
|
torch.exp(attn_scores, out=attn_scores)
|
||||||
summed = torch.sum(attn_scores, dim=-1, keepdim=True)
|
summed = torch.sum(attn_scores, dim=-1, keepdim=True)
|
||||||
attn_scores /= summed
|
attn_scores /= summed
|
||||||
@@ -262,7 +261,7 @@ def efficient_dot_product_attention(
|
|||||||
value=value,
|
value=value,
|
||||||
mask=mask,
|
mask=mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
|
# TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
|
||||||
# and pass slices to be mutated, instead of torch.cat()ing the returned slices
|
# and pass slices to be mutated, instead of torch.cat()ing the returned slices
|
||||||
res = torch.cat([
|
res = torch.cat([
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import functools
|
import functools
|
||||||
from typing import Callable, Iterable, Union
|
from typing import Iterable, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
@@ -194,6 +194,7 @@ def make_time_attn(
|
|||||||
attn_kwargs=None,
|
attn_kwargs=None,
|
||||||
alpha: float = 0,
|
alpha: float = 0,
|
||||||
merge_strategy: str = "learned",
|
merge_strategy: str = "learned",
|
||||||
|
conv_op=ops.Conv2d,
|
||||||
):
|
):
|
||||||
return partialclass(
|
return partialclass(
|
||||||
AttnVideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy
|
AttnVideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy
|
||||||
|
|||||||
380
comfy/ldm/pixart/blocks.py
Normal file
380
comfy/ldm/pixart/blocks.py
Normal file
@@ -0,0 +1,380 @@
|
|||||||
|
# Based on:
|
||||||
|
# https://github.com/PixArt-alpha/PixArt-alpha [Apache 2.0 license]
|
||||||
|
# https://github.com/PixArt-alpha/PixArt-sigma [Apache 2.0 license]
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, Mlp, timestep_embedding
|
||||||
|
from comfy.ldm.modules.attention import optimized_attention
|
||||||
|
|
||||||
|
# if model_management.xformers_enabled():
|
||||||
|
# import xformers.ops
|
||||||
|
# if int((xformers.__version__).split(".")[2].split("+")[0]) >= 28:
|
||||||
|
# block_diagonal_mask_from_seqlens = xformers.ops.fmha.attn_bias.BlockDiagonalMask.from_seqlens
|
||||||
|
# else:
|
||||||
|
# block_diagonal_mask_from_seqlens = xformers.ops.fmha.BlockDiagonalMask.from_seqlens
|
||||||
|
|
||||||
|
def modulate(x, shift, scale):
|
||||||
|
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||||
|
|
||||||
|
def t2i_modulate(x, shift, scale):
|
||||||
|
return x * (1 + scale) + shift
|
||||||
|
|
||||||
|
class MultiHeadCrossAttention(nn.Module):
|
||||||
|
def __init__(self, d_model, num_heads, attn_drop=0., proj_drop=0., dtype=None, device=None, operations=None, **kwargs):
|
||||||
|
super(MultiHeadCrossAttention, self).__init__()
|
||||||
|
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
|
||||||
|
|
||||||
|
self.d_model = d_model
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_dim = d_model // num_heads
|
||||||
|
|
||||||
|
self.q_linear = operations.Linear(d_model, d_model, dtype=dtype, device=device)
|
||||||
|
self.kv_linear = operations.Linear(d_model, d_model*2, dtype=dtype, device=device)
|
||||||
|
self.attn_drop = nn.Dropout(attn_drop)
|
||||||
|
self.proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)
|
||||||
|
self.proj_drop = nn.Dropout(proj_drop)
|
||||||
|
|
||||||
|
def forward(self, x, cond, mask=None):
|
||||||
|
# query/value: img tokens; key: condition; mask: if padding tokens
|
||||||
|
B, N, C = x.shape
|
||||||
|
|
||||||
|
q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
|
||||||
|
kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
|
||||||
|
k, v = kv.unbind(2)
|
||||||
|
|
||||||
|
assert mask is None # TODO?
|
||||||
|
# # TODO: xformers needs separate mask logic here
|
||||||
|
# if model_management.xformers_enabled():
|
||||||
|
# attn_bias = None
|
||||||
|
# if mask is not None:
|
||||||
|
# attn_bias = block_diagonal_mask_from_seqlens([N] * B, mask)
|
||||||
|
# x = xformers.ops.memory_efficient_attention(q, k, v, p=0, attn_bias=attn_bias)
|
||||||
|
# else:
|
||||||
|
# q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v),)
|
||||||
|
# attn_mask = None
|
||||||
|
# mask = torch.ones(())
|
||||||
|
# if mask is not None and len(mask) > 1:
|
||||||
|
# # Create equivalent of xformer diagonal block mask, still only correct for square masks
|
||||||
|
# # But depth doesn't matter as tensors can expand in that dimension
|
||||||
|
# attn_mask_template = torch.ones(
|
||||||
|
# [q.shape[2] // B, mask[0]],
|
||||||
|
# dtype=torch.bool,
|
||||||
|
# device=q.device
|
||||||
|
# )
|
||||||
|
# attn_mask = torch.block_diag(attn_mask_template)
|
||||||
|
#
|
||||||
|
# # create a mask on the diagonal for each mask in the batch
|
||||||
|
# for _ in range(B - 1):
|
||||||
|
# attn_mask = torch.block_diag(attn_mask, attn_mask_template)
|
||||||
|
# x = optimized_attention(q, k, v, self.num_heads, mask=attn_mask, skip_reshape=True)
|
||||||
|
|
||||||
|
x = optimized_attention(q.view(B, -1, C), k.view(B, -1, C), v.view(B, -1, C), self.num_heads, mask=None)
|
||||||
|
x = self.proj(x)
|
||||||
|
x = self.proj_drop(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionKVCompress(nn.Module):
|
||||||
|
"""Multi-head Attention block with KV token compression and qk norm."""
|
||||||
|
def __init__(self, dim, num_heads=8, qkv_bias=True, sampling='conv', sr_ratio=1, qk_norm=False, dtype=None, device=None, operations=None, **kwargs):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
dim (int): Number of input channels.
|
||||||
|
num_heads (int): Number of attention heads.
|
||||||
|
qkv_bias (bool: If True, add a learnable bias to query, key, value.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_dim = dim // num_heads
|
||||||
|
self.scale = self.head_dim ** -0.5
|
||||||
|
|
||||||
|
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
|
||||||
|
self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
self.sampling=sampling # ['conv', 'ave', 'uniform', 'uniform_every']
|
||||||
|
self.sr_ratio = sr_ratio
|
||||||
|
if sr_ratio > 1 and sampling == 'conv':
|
||||||
|
# Avg Conv Init.
|
||||||
|
self.sr = operations.Conv2d(dim, dim, groups=dim, kernel_size=sr_ratio, stride=sr_ratio, dtype=dtype, device=device)
|
||||||
|
# self.sr.weight.data.fill_(1/sr_ratio**2)
|
||||||
|
# self.sr.bias.data.zero_()
|
||||||
|
self.norm = operations.LayerNorm(dim, dtype=dtype, device=device)
|
||||||
|
if qk_norm:
|
||||||
|
self.q_norm = operations.LayerNorm(dim, dtype=dtype, device=device)
|
||||||
|
self.k_norm = operations.LayerNorm(dim, dtype=dtype, device=device)
|
||||||
|
else:
|
||||||
|
self.q_norm = nn.Identity()
|
||||||
|
self.k_norm = nn.Identity()
|
||||||
|
|
||||||
|
def downsample_2d(self, tensor, H, W, scale_factor, sampling=None):
|
||||||
|
if sampling is None or scale_factor == 1:
|
||||||
|
return tensor
|
||||||
|
B, N, C = tensor.shape
|
||||||
|
|
||||||
|
if sampling == 'uniform_every':
|
||||||
|
return tensor[:, ::scale_factor], int(N // scale_factor)
|
||||||
|
|
||||||
|
tensor = tensor.reshape(B, H, W, C).permute(0, 3, 1, 2)
|
||||||
|
new_H, new_W = int(H / scale_factor), int(W / scale_factor)
|
||||||
|
new_N = new_H * new_W
|
||||||
|
|
||||||
|
if sampling == 'ave':
|
||||||
|
tensor = F.interpolate(
|
||||||
|
tensor, scale_factor=1 / scale_factor, mode='nearest'
|
||||||
|
).permute(0, 2, 3, 1)
|
||||||
|
elif sampling == 'uniform':
|
||||||
|
tensor = tensor[:, :, ::scale_factor, ::scale_factor].permute(0, 2, 3, 1)
|
||||||
|
elif sampling == 'conv':
|
||||||
|
tensor = self.sr(tensor).reshape(B, C, -1).permute(0, 2, 1)
|
||||||
|
tensor = self.norm(tensor)
|
||||||
|
else:
|
||||||
|
raise ValueError
|
||||||
|
|
||||||
|
return tensor.reshape(B, new_N, C).contiguous(), new_N
|
||||||
|
|
||||||
|
def forward(self, x, mask=None, HW=None, block_id=None):
|
||||||
|
B, N, C = x.shape # 2 4096 1152
|
||||||
|
new_N = N
|
||||||
|
if HW is None:
|
||||||
|
H = W = int(N ** 0.5)
|
||||||
|
else:
|
||||||
|
H, W = HW
|
||||||
|
qkv = self.qkv(x).reshape(B, N, 3, C)
|
||||||
|
|
||||||
|
q, k, v = qkv.unbind(2)
|
||||||
|
q = self.q_norm(q)
|
||||||
|
k = self.k_norm(k)
|
||||||
|
|
||||||
|
# KV compression
|
||||||
|
if self.sr_ratio > 1:
|
||||||
|
k, new_N = self.downsample_2d(k, H, W, self.sr_ratio, sampling=self.sampling)
|
||||||
|
v, new_N = self.downsample_2d(v, H, W, self.sr_ratio, sampling=self.sampling)
|
||||||
|
|
||||||
|
q = q.reshape(B, N, self.num_heads, C // self.num_heads)
|
||||||
|
k = k.reshape(B, new_N, self.num_heads, C // self.num_heads)
|
||||||
|
v = v.reshape(B, new_N, self.num_heads, C // self.num_heads)
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
raise NotImplementedError("Attn mask logic not added for self attention")
|
||||||
|
|
||||||
|
# This is never called at the moment
|
||||||
|
# attn_bias = None
|
||||||
|
# if mask is not None:
|
||||||
|
# attn_bias = torch.zeros([B * self.num_heads, q.shape[1], k.shape[1]], dtype=q.dtype, device=q.device)
|
||||||
|
# attn_bias.masked_fill_(mask.squeeze(1).repeat(self.num_heads, 1, 1) == 0, float('-inf'))
|
||||||
|
|
||||||
|
# attention 2
|
||||||
|
q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v),)
|
||||||
|
x = optimized_attention(q, k, v, self.num_heads, mask=None, skip_reshape=True)
|
||||||
|
|
||||||
|
x = x.view(B, N, C)
|
||||||
|
x = self.proj(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class FinalLayer(nn.Module):
|
||||||
|
"""
|
||||||
|
The final layer of PixArt.
|
||||||
|
"""
|
||||||
|
def __init__(self, hidden_size, patch_size, out_channels, dtype=None, device=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||||
|
self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
|
||||||
|
self.adaLN_modulation = nn.Sequential(
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x, c):
|
||||||
|
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||||
|
x = modulate(self.norm_final(x), shift, scale)
|
||||||
|
x = self.linear(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
class T2IFinalLayer(nn.Module):
|
||||||
|
"""
|
||||||
|
The final layer of PixArt.
|
||||||
|
"""
|
||||||
|
def __init__(self, hidden_size, patch_size, out_channels, dtype=None, device=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||||
|
self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
|
||||||
|
self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size ** 0.5)
|
||||||
|
self.out_channels = out_channels
|
||||||
|
|
||||||
|
def forward(self, x, t):
|
||||||
|
shift, scale = (self.scale_shift_table[None].to(dtype=x.dtype, device=x.device) + t[:, None]).chunk(2, dim=1)
|
||||||
|
x = t2i_modulate(self.norm_final(x), shift, scale)
|
||||||
|
x = self.linear(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class MaskFinalLayer(nn.Module):
|
||||||
|
"""
|
||||||
|
The final layer of PixArt.
|
||||||
|
"""
|
||||||
|
def __init__(self, final_hidden_size, c_emb_size, patch_size, out_channels, dtype=None, device=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
self.norm_final = operations.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||||
|
self.linear = operations.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
|
||||||
|
self.adaLN_modulation = nn.Sequential(
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Linear(c_emb_size, 2 * final_hidden_size, bias=True, dtype=dtype, device=device)
|
||||||
|
)
|
||||||
|
def forward(self, x, t):
|
||||||
|
shift, scale = self.adaLN_modulation(t).chunk(2, dim=1)
|
||||||
|
x = modulate(self.norm_final(x), shift, scale)
|
||||||
|
x = self.linear(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class DecoderLayer(nn.Module):
|
||||||
|
"""
|
||||||
|
The final layer of PixArt.
|
||||||
|
"""
|
||||||
|
def __init__(self, hidden_size, decoder_hidden_size, dtype=None, device=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
self.norm_decoder = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||||
|
self.linear = operations.Linear(hidden_size, decoder_hidden_size, bias=True, dtype=dtype, device=device)
|
||||||
|
self.adaLN_modulation = nn.Sequential(
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device)
|
||||||
|
)
|
||||||
|
def forward(self, x, t):
|
||||||
|
shift, scale = self.adaLN_modulation(t).chunk(2, dim=1)
|
||||||
|
x = modulate(self.norm_decoder(x), shift, scale)
|
||||||
|
x = self.linear(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class SizeEmbedder(TimestepEmbedder):
|
||||||
|
"""
|
||||||
|
Embeds scalar timesteps into vector representations.
|
||||||
|
"""
|
||||||
|
def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
|
||||||
|
super().__init__(hidden_size=hidden_size, frequency_embedding_size=frequency_embedding_size, operations=operations)
|
||||||
|
self.mlp = nn.Sequential(
|
||||||
|
operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
|
||||||
|
)
|
||||||
|
self.frequency_embedding_size = frequency_embedding_size
|
||||||
|
self.outdim = hidden_size
|
||||||
|
|
||||||
|
def forward(self, s, bs):
|
||||||
|
if s.ndim == 1:
|
||||||
|
s = s[:, None]
|
||||||
|
assert s.ndim == 2
|
||||||
|
if s.shape[0] != bs:
|
||||||
|
s = s.repeat(bs//s.shape[0], 1)
|
||||||
|
assert s.shape[0] == bs
|
||||||
|
b, dims = s.shape[0], s.shape[1]
|
||||||
|
s = rearrange(s, "b d -> (b d)")
|
||||||
|
s_freq = timestep_embedding(s, self.frequency_embedding_size)
|
||||||
|
s_emb = self.mlp(s_freq.to(s.dtype))
|
||||||
|
s_emb = rearrange(s_emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim)
|
||||||
|
return s_emb
|
||||||
|
|
||||||
|
|
||||||
|
class LabelEmbedder(nn.Module):
|
||||||
|
"""
|
||||||
|
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
|
||||||
|
"""
|
||||||
|
def __init__(self, num_classes, hidden_size, dropout_prob, dtype=None, device=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
use_cfg_embedding = dropout_prob > 0
|
||||||
|
self.embedding_table = operations.Embedding(num_classes + use_cfg_embedding, hidden_size, dtype=dtype, device=device),
|
||||||
|
self.num_classes = num_classes
|
||||||
|
self.dropout_prob = dropout_prob
|
||||||
|
|
||||||
|
def token_drop(self, labels, force_drop_ids=None):
|
||||||
|
"""
|
||||||
|
Drops labels to enable classifier-free guidance.
|
||||||
|
"""
|
||||||
|
if force_drop_ids is None:
|
||||||
|
drop_ids = torch.rand(labels.shape[0]).cuda() < self.dropout_prob
|
||||||
|
else:
|
||||||
|
drop_ids = force_drop_ids == 1
|
||||||
|
labels = torch.where(drop_ids, self.num_classes, labels)
|
||||||
|
return labels
|
||||||
|
|
||||||
|
def forward(self, labels, train, force_drop_ids=None):
|
||||||
|
use_dropout = self.dropout_prob > 0
|
||||||
|
if (train and use_dropout) or (force_drop_ids is not None):
|
||||||
|
labels = self.token_drop(labels, force_drop_ids)
|
||||||
|
embeddings = self.embedding_table(labels)
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
|
class CaptionEmbedder(nn.Module):
|
||||||
|
"""
|
||||||
|
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
|
||||||
|
"""
|
||||||
|
def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate='tanh'), token_num=120, dtype=None, device=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
self.y_proj = Mlp(
|
||||||
|
in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=act_layer,
|
||||||
|
dtype=dtype, device=device, operations=operations,
|
||||||
|
)
|
||||||
|
self.register_buffer("y_embedding", nn.Parameter(torch.randn(token_num, in_channels) / in_channels ** 0.5))
|
||||||
|
self.uncond_prob = uncond_prob
|
||||||
|
|
||||||
|
def token_drop(self, caption, force_drop_ids=None):
|
||||||
|
"""
|
||||||
|
Drops labels to enable classifier-free guidance.
|
||||||
|
"""
|
||||||
|
if force_drop_ids is None:
|
||||||
|
drop_ids = torch.rand(caption.shape[0]).cuda() < self.uncond_prob
|
||||||
|
else:
|
||||||
|
drop_ids = force_drop_ids == 1
|
||||||
|
caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
|
||||||
|
return caption
|
||||||
|
|
||||||
|
def forward(self, caption, train, force_drop_ids=None):
|
||||||
|
if train:
|
||||||
|
assert caption.shape[2:] == self.y_embedding.shape
|
||||||
|
use_dropout = self.uncond_prob > 0
|
||||||
|
if (train and use_dropout) or (force_drop_ids is not None):
|
||||||
|
caption = self.token_drop(caption, force_drop_ids)
|
||||||
|
caption = self.y_proj(caption)
|
||||||
|
return caption
|
||||||
|
|
||||||
|
|
||||||
|
class CaptionEmbedderDoubleBr(nn.Module):
|
||||||
|
"""
|
||||||
|
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
|
||||||
|
"""
|
||||||
|
def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate='tanh'), token_num=120, dtype=None, device=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
self.proj = Mlp(
|
||||||
|
in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=act_layer,
|
||||||
|
dtype=dtype, device=device, operations=operations,
|
||||||
|
)
|
||||||
|
self.embedding = nn.Parameter(torch.randn(1, in_channels) / 10 ** 0.5)
|
||||||
|
self.y_embedding = nn.Parameter(torch.randn(token_num, in_channels) / 10 ** 0.5)
|
||||||
|
self.uncond_prob = uncond_prob
|
||||||
|
|
||||||
|
def token_drop(self, global_caption, caption, force_drop_ids=None):
|
||||||
|
"""
|
||||||
|
Drops labels to enable classifier-free guidance.
|
||||||
|
"""
|
||||||
|
if force_drop_ids is None:
|
||||||
|
drop_ids = torch.rand(global_caption.shape[0]).cuda() < self.uncond_prob
|
||||||
|
else:
|
||||||
|
drop_ids = force_drop_ids == 1
|
||||||
|
global_caption = torch.where(drop_ids[:, None], self.embedding, global_caption)
|
||||||
|
caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
|
||||||
|
return global_caption, caption
|
||||||
|
|
||||||
|
def forward(self, caption, train, force_drop_ids=None):
|
||||||
|
assert caption.shape[2: ] == self.y_embedding.shape
|
||||||
|
global_caption = caption.mean(dim=2).squeeze()
|
||||||
|
use_dropout = self.uncond_prob > 0
|
||||||
|
if (train and use_dropout) or (force_drop_ids is not None):
|
||||||
|
global_caption, caption = self.token_drop(global_caption, caption, force_drop_ids)
|
||||||
|
y_embed = self.proj(global_caption)
|
||||||
|
return y_embed, caption
|
||||||
256
comfy/ldm/pixart/pixartms.py
Normal file
256
comfy/ldm/pixart/pixartms.py
Normal file
@@ -0,0 +1,256 @@
|
|||||||
|
# Based on:
|
||||||
|
# https://github.com/PixArt-alpha/PixArt-alpha [Apache 2.0 license]
|
||||||
|
# https://github.com/PixArt-alpha/PixArt-sigma [Apache 2.0 license]
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from .blocks import (
|
||||||
|
t2i_modulate,
|
||||||
|
CaptionEmbedder,
|
||||||
|
AttentionKVCompress,
|
||||||
|
MultiHeadCrossAttention,
|
||||||
|
T2IFinalLayer,
|
||||||
|
SizeEmbedder,
|
||||||
|
)
|
||||||
|
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, PatchEmbed, Mlp, get_1d_sincos_pos_embed_from_grid_torch
|
||||||
|
|
||||||
|
|
||||||
|
def get_2d_sincos_pos_embed_torch(embed_dim, w, h, pe_interpolation=1.0, base_size=16, device=None, dtype=torch.float32):
|
||||||
|
grid_h, grid_w = torch.meshgrid(
|
||||||
|
torch.arange(h, device=device, dtype=dtype) / (h/base_size) / pe_interpolation,
|
||||||
|
torch.arange(w, device=device, dtype=dtype) / (w/base_size) / pe_interpolation,
|
||||||
|
indexing='ij'
|
||||||
|
)
|
||||||
|
emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_h, device=device, dtype=dtype)
|
||||||
|
emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_w, device=device, dtype=dtype)
|
||||||
|
emb = torch.cat([emb_w, emb_h], dim=1) # (H*W, D)
|
||||||
|
return emb
|
||||||
|
|
||||||
|
class PixArtMSBlock(nn.Module):
|
||||||
|
"""
|
||||||
|
A PixArt block with adaptive layer norm zero (adaLN-Zero) conditioning.
|
||||||
|
"""
|
||||||
|
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., input_size=None,
|
||||||
|
sampling=None, sr_ratio=1, qk_norm=False, dtype=None, device=None, operations=None, **block_kwargs):
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||||
|
self.attn = AttentionKVCompress(
|
||||||
|
hidden_size, num_heads=num_heads, qkv_bias=True, sampling=sampling, sr_ratio=sr_ratio,
|
||||||
|
qk_norm=qk_norm, dtype=dtype, device=device, operations=operations, **block_kwargs
|
||||||
|
)
|
||||||
|
self.cross_attn = MultiHeadCrossAttention(
|
||||||
|
hidden_size, num_heads, dtype=dtype, device=device, operations=operations, **block_kwargs
|
||||||
|
)
|
||||||
|
self.norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||||
|
# to be compatible with lower version pytorch
|
||||||
|
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
||||||
|
self.mlp = Mlp(
|
||||||
|
in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu,
|
||||||
|
dtype=dtype, device=device, operations=operations
|
||||||
|
)
|
||||||
|
self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5)
|
||||||
|
|
||||||
|
def forward(self, x, y, t, mask=None, HW=None, **kwargs):
|
||||||
|
B, N, C = x.shape
|
||||||
|
|
||||||
|
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None].to(dtype=x.dtype, device=x.device) + t.reshape(B, 6, -1)).chunk(6, dim=1)
|
||||||
|
x = x + (gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=HW))
|
||||||
|
x = x + self.cross_attn(x, y, mask)
|
||||||
|
x = x + (gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
### Core PixArt Model ###
|
||||||
|
class PixArtMS(nn.Module):
|
||||||
|
"""
|
||||||
|
Diffusion model with a Transformer backbone.
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_size=32,
|
||||||
|
patch_size=2,
|
||||||
|
in_channels=4,
|
||||||
|
hidden_size=1152,
|
||||||
|
depth=28,
|
||||||
|
num_heads=16,
|
||||||
|
mlp_ratio=4.0,
|
||||||
|
class_dropout_prob=0.1,
|
||||||
|
learn_sigma=True,
|
||||||
|
pred_sigma=True,
|
||||||
|
drop_path: float = 0.,
|
||||||
|
caption_channels=4096,
|
||||||
|
pe_interpolation=None,
|
||||||
|
pe_precision=None,
|
||||||
|
config=None,
|
||||||
|
model_max_length=120,
|
||||||
|
micro_condition=True,
|
||||||
|
qk_norm=False,
|
||||||
|
kv_compress_config=None,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
nn.Module.__init__(self)
|
||||||
|
self.dtype = dtype
|
||||||
|
self.pred_sigma = pred_sigma
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = in_channels * 2 if pred_sigma else in_channels
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.pe_interpolation = pe_interpolation
|
||||||
|
self.pe_precision = pe_precision
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.depth = depth
|
||||||
|
|
||||||
|
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
||||||
|
self.t_block = nn.Sequential(
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Linear(hidden_size, 6 * hidden_size, bias=True, dtype=dtype, device=device)
|
||||||
|
)
|
||||||
|
self.x_embedder = PatchEmbed(
|
||||||
|
patch_size=patch_size,
|
||||||
|
in_chans=in_channels,
|
||||||
|
embed_dim=hidden_size,
|
||||||
|
bias=True,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations
|
||||||
|
)
|
||||||
|
self.t_embedder = TimestepEmbedder(
|
||||||
|
hidden_size, dtype=dtype, device=device, operations=operations,
|
||||||
|
)
|
||||||
|
self.y_embedder = CaptionEmbedder(
|
||||||
|
in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob,
|
||||||
|
act_layer=approx_gelu, token_num=model_max_length,
|
||||||
|
dtype=dtype, device=device, operations=operations,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.micro_conditioning = micro_condition
|
||||||
|
if self.micro_conditioning:
|
||||||
|
self.csize_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations)
|
||||||
|
self.ar_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations)
|
||||||
|
|
||||||
|
# For fixed sin-cos embedding:
|
||||||
|
# num_patches = (input_size // patch_size) * (input_size // patch_size)
|
||||||
|
# self.base_size = input_size // self.patch_size
|
||||||
|
# self.register_buffer("pos_embed", torch.zeros(1, num_patches, hidden_size))
|
||||||
|
|
||||||
|
drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule
|
||||||
|
if kv_compress_config is None:
|
||||||
|
kv_compress_config = {
|
||||||
|
'sampling': None,
|
||||||
|
'scale_factor': 1,
|
||||||
|
'kv_compress_layer': [],
|
||||||
|
}
|
||||||
|
self.blocks = nn.ModuleList([
|
||||||
|
PixArtMSBlock(
|
||||||
|
hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i],
|
||||||
|
sampling=kv_compress_config['sampling'],
|
||||||
|
sr_ratio=int(kv_compress_config['scale_factor']) if i in kv_compress_config['kv_compress_layer'] else 1,
|
||||||
|
qk_norm=qk_norm,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
for i in range(depth)
|
||||||
|
])
|
||||||
|
self.final_layer = T2IFinalLayer(
|
||||||
|
hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward_orig(self, x, timestep, y, mask=None, c_size=None, c_ar=None, **kwargs):
|
||||||
|
"""
|
||||||
|
Original forward pass of PixArt.
|
||||||
|
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
|
||||||
|
t: (N,) tensor of diffusion timesteps
|
||||||
|
y: (N, 1, 120, C) conditioning
|
||||||
|
ar: (N, 1): aspect ratio
|
||||||
|
cs: (N ,2) size conditioning for height/width
|
||||||
|
"""
|
||||||
|
B, C, H, W = x.shape
|
||||||
|
c_res = (H + W) // 2
|
||||||
|
pe_interpolation = self.pe_interpolation
|
||||||
|
if pe_interpolation is None or self.pe_precision is not None:
|
||||||
|
# calculate pe_interpolation on-the-fly
|
||||||
|
pe_interpolation = round(c_res / (512/8.0), self.pe_precision or 0)
|
||||||
|
|
||||||
|
pos_embed = get_2d_sincos_pos_embed_torch(
|
||||||
|
self.hidden_size,
|
||||||
|
h=(H // self.patch_size),
|
||||||
|
w=(W // self.patch_size),
|
||||||
|
pe_interpolation=pe_interpolation,
|
||||||
|
base_size=((round(c_res / 64) * 64) // self.patch_size),
|
||||||
|
device=x.device,
|
||||||
|
dtype=x.dtype,
|
||||||
|
).unsqueeze(0)
|
||||||
|
|
||||||
|
x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
|
||||||
|
t = self.t_embedder(timestep, x.dtype) # (N, D)
|
||||||
|
|
||||||
|
if self.micro_conditioning and (c_size is not None and c_ar is not None):
|
||||||
|
bs = x.shape[0]
|
||||||
|
c_size = self.csize_embedder(c_size, bs) # (N, D)
|
||||||
|
c_ar = self.ar_embedder(c_ar, bs) # (N, D)
|
||||||
|
t = t + torch.cat([c_size, c_ar], dim=1)
|
||||||
|
|
||||||
|
t0 = self.t_block(t)
|
||||||
|
y = self.y_embedder(y, self.training) # (N, D)
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
if mask.shape[0] != y.shape[0]:
|
||||||
|
mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
|
||||||
|
mask = mask.squeeze(1).squeeze(1)
|
||||||
|
y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
|
||||||
|
y_lens = mask.sum(dim=1).tolist()
|
||||||
|
else:
|
||||||
|
y_lens = None
|
||||||
|
y = y.squeeze(1).view(1, -1, x.shape[-1])
|
||||||
|
for block in self.blocks:
|
||||||
|
x = block(x, y, t0, y_lens, (H, W), **kwargs) # (N, T, D)
|
||||||
|
|
||||||
|
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
|
||||||
|
x = self.unpatchify(x, H, W) # (N, out_channels, H, W)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forward(self, x, timesteps, context, c_size=None, c_ar=None, **kwargs):
|
||||||
|
B, C, H, W = x.shape
|
||||||
|
|
||||||
|
# Fallback for missing microconds
|
||||||
|
if self.micro_conditioning:
|
||||||
|
if c_size is None:
|
||||||
|
c_size = torch.tensor([H*8, W*8], dtype=x.dtype, device=x.device).repeat(B, 1)
|
||||||
|
|
||||||
|
if c_ar is None:
|
||||||
|
c_ar = torch.tensor([H/W], dtype=x.dtype, device=x.device).repeat(B, 1)
|
||||||
|
|
||||||
|
## Still accepts the input w/o that dim but returns garbage
|
||||||
|
if len(context.shape) == 3:
|
||||||
|
context = context.unsqueeze(1)
|
||||||
|
|
||||||
|
## run original forward pass
|
||||||
|
out = self.forward_orig(x, timesteps, context, c_size=c_size, c_ar=c_ar)
|
||||||
|
|
||||||
|
## only return EPS
|
||||||
|
if self.pred_sigma:
|
||||||
|
return out[:, :self.in_channels]
|
||||||
|
return out
|
||||||
|
|
||||||
|
def unpatchify(self, x, h, w):
|
||||||
|
"""
|
||||||
|
x: (N, T, patch_size**2 * C)
|
||||||
|
imgs: (N, H, W, C)
|
||||||
|
"""
|
||||||
|
c = self.out_channels
|
||||||
|
p = self.x_embedder.patch_size[0]
|
||||||
|
h = h // self.patch_size
|
||||||
|
w = w // self.patch_size
|
||||||
|
assert h * w == x.shape[1]
|
||||||
|
|
||||||
|
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
|
||||||
|
x = torch.einsum('nhwpqc->nchpwq', x)
|
||||||
|
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
|
||||||
|
return imgs
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
import importlib
|
import importlib
|
||||||
|
import logging
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import optim
|
from torch import optim
|
||||||
@@ -23,7 +24,7 @@ def log_txt_as_img(wh, xc, size=10):
|
|||||||
try:
|
try:
|
||||||
draw.text((0, 0), lines, fill="black", font=font)
|
draw.text((0, 0), lines, fill="black", font=font)
|
||||||
except UnicodeEncodeError:
|
except UnicodeEncodeError:
|
||||||
print("Cant encode string for logging. Skipping.")
|
logging.warning("Cant encode string for logging. Skipping.")
|
||||||
|
|
||||||
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
|
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
|
||||||
txts.append(txt)
|
txts.append(txt)
|
||||||
@@ -65,7 +66,7 @@ def mean_flat(tensor):
|
|||||||
def count_params(model, verbose=False):
|
def count_params(model, verbose=False):
|
||||||
total_params = sum(p.numel() for p in model.parameters())
|
total_params = sum(p.numel() for p in model.parameters())
|
||||||
if verbose:
|
if verbose:
|
||||||
print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
|
logging.info(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
|
||||||
return total_params
|
return total_params
|
||||||
|
|
||||||
|
|
||||||
@@ -133,7 +134,6 @@ class AdamWwithEMAandWings(optim.Optimizer):
|
|||||||
exp_avgs = []
|
exp_avgs = []
|
||||||
exp_avg_sqs = []
|
exp_avg_sqs = []
|
||||||
ema_params_with_grad = []
|
ema_params_with_grad = []
|
||||||
state_sums = []
|
|
||||||
max_exp_avg_sqs = []
|
max_exp_avg_sqs = []
|
||||||
state_steps = []
|
state_steps = []
|
||||||
amsgrad = group['amsgrad']
|
amsgrad = group['amsgrad']
|
||||||
@@ -194,4 +194,4 @@ class AdamWwithEMAandWings(optim.Optimizer):
|
|||||||
for param, ema_param in zip(params_with_grad, ema_params_with_grad):
|
for param, ema_param in zip(params_with_grad, ema_params_with_grad):
|
||||||
ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)
|
ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)
|
||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
|||||||
@@ -344,7 +344,6 @@ def model_lora_keys_unet(model, key_map={}):
|
|||||||
key_lora = "lycoris_{}".format(k[:-len(".weight")].replace(".", "_")) #simpletuner lycoris format
|
key_lora = "lycoris_{}".format(k[:-len(".weight")].replace(".", "_")) #simpletuner lycoris format
|
||||||
key_map[key_lora] = to
|
key_map[key_lora] = to
|
||||||
|
|
||||||
|
|
||||||
if isinstance(model, comfy.model_base.AuraFlow): #Diffusers lora AuraFlow
|
if isinstance(model, comfy.model_base.AuraFlow): #Diffusers lora AuraFlow
|
||||||
diffusers_keys = comfy.utils.auraflow_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
|
diffusers_keys = comfy.utils.auraflow_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
|
||||||
for k in diffusers_keys:
|
for k in diffusers_keys:
|
||||||
@@ -353,6 +352,20 @@ def model_lora_keys_unet(model, key_map={}):
|
|||||||
key_lora = "transformer.{}".format(k[:-len(".weight")]) #simpletrainer and probably regular diffusers lora format
|
key_lora = "transformer.{}".format(k[:-len(".weight")]) #simpletrainer and probably regular diffusers lora format
|
||||||
key_map[key_lora] = to
|
key_map[key_lora] = to
|
||||||
|
|
||||||
|
if isinstance(model, comfy.model_base.PixArt):
|
||||||
|
diffusers_keys = comfy.utils.pixart_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
|
||||||
|
for k in diffusers_keys:
|
||||||
|
if k.endswith(".weight"):
|
||||||
|
to = diffusers_keys[k]
|
||||||
|
key_lora = "transformer.{}".format(k[:-len(".weight")]) #default format
|
||||||
|
key_map[key_lora] = to
|
||||||
|
|
||||||
|
key_lora = "base_model.model.{}".format(k[:-len(".weight")]) #diffusers training script
|
||||||
|
key_map[key_lora] = to
|
||||||
|
|
||||||
|
key_lora = "unet.base_model.model.{}".format(k[:-len(".weight")]) #old reference peft script
|
||||||
|
key_map[key_lora] = to
|
||||||
|
|
||||||
if isinstance(model, comfy.model_base.HunyuanDiT):
|
if isinstance(model, comfy.model_base.HunyuanDiT):
|
||||||
for k in sdk:
|
for k in sdk:
|
||||||
if k.startswith("diffusion_model.") and k.endswith(".weight"):
|
if k.startswith("diffusion_model.") and k.endswith(".weight"):
|
||||||
@@ -374,6 +387,18 @@ def model_lora_keys_unet(model, key_map={}):
|
|||||||
key_lora = k[len("diffusion_model."):-len(".weight")]
|
key_lora = k[len("diffusion_model."):-len(".weight")]
|
||||||
key_map["{}".format(key_lora)] = k
|
key_map["{}".format(key_lora)] = k
|
||||||
|
|
||||||
|
if isinstance(model, comfy.model_base.HunyuanVideo):
|
||||||
|
for k in sdk:
|
||||||
|
if k.startswith("diffusion_model.") and k.endswith(".weight"):
|
||||||
|
# diffusion-pipe lora format
|
||||||
|
key_lora = k
|
||||||
|
key_lora = key_lora.replace("_mod.lin.", "_mod.linear.").replace("_attn.qkv.", "_attn_qkv.").replace("_attn.proj.", "_attn_proj.")
|
||||||
|
key_lora = key_lora.replace("mlp.0.", "mlp.fc1.").replace("mlp.2.", "mlp.fc2.")
|
||||||
|
key_lora = key_lora.replace(".modulation.lin.", ".modulation.linear.")
|
||||||
|
key_lora = key_lora[len("diffusion_model."):-len(".weight")]
|
||||||
|
key_map["transformer.{}".format(key_lora)] = k
|
||||||
|
key_map["diffusion_model.{}".format(key_lora)] = k # Old loras
|
||||||
|
|
||||||
return key_map
|
return key_map
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -26,11 +26,14 @@ from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAug
|
|||||||
from comfy.ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
|
from comfy.ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
|
||||||
import comfy.ldm.genmo.joint_model.asymm_models_joint
|
import comfy.ldm.genmo.joint_model.asymm_models_joint
|
||||||
import comfy.ldm.aura.mmdit
|
import comfy.ldm.aura.mmdit
|
||||||
|
import comfy.ldm.pixart.pixartms
|
||||||
import comfy.ldm.hydit.models
|
import comfy.ldm.hydit.models
|
||||||
import comfy.ldm.audio.dit
|
import comfy.ldm.audio.dit
|
||||||
import comfy.ldm.audio.embedders
|
import comfy.ldm.audio.embedders
|
||||||
import comfy.ldm.flux.model
|
import comfy.ldm.flux.model
|
||||||
import comfy.ldm.lightricks.model
|
import comfy.ldm.lightricks.model
|
||||||
|
import comfy.ldm.hunyuan_video.model
|
||||||
|
import comfy.ldm.cosmos.model
|
||||||
|
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.patcher_extension
|
import comfy.patcher_extension
|
||||||
@@ -145,7 +148,9 @@ class BaseModel(torch.nn.Module):
|
|||||||
|
|
||||||
xc = xc.to(dtype)
|
xc = xc.to(dtype)
|
||||||
t = self.model_sampling.timestep(t).float()
|
t = self.model_sampling.timestep(t).float()
|
||||||
context = context.to(dtype)
|
if context is not None:
|
||||||
|
context = context.to(dtype)
|
||||||
|
|
||||||
extra_conds = {}
|
extra_conds = {}
|
||||||
for o in kwargs:
|
for o in kwargs:
|
||||||
extra = kwargs[o]
|
extra = kwargs[o]
|
||||||
@@ -186,9 +191,10 @@ class BaseModel(torch.nn.Module):
|
|||||||
|
|
||||||
if denoise_mask is not None:
|
if denoise_mask is not None:
|
||||||
if len(denoise_mask.shape) == len(noise.shape):
|
if len(denoise_mask.shape) == len(noise.shape):
|
||||||
denoise_mask = denoise_mask[:,:1]
|
denoise_mask = denoise_mask[:, :1]
|
||||||
|
|
||||||
denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1]))
|
num_dim = noise.ndim - 2
|
||||||
|
denoise_mask = denoise_mask.reshape((-1, 1) + tuple(denoise_mask.shape[-num_dim:]))
|
||||||
if denoise_mask.shape[-2:] != noise.shape[-2:]:
|
if denoise_mask.shape[-2:] != noise.shape[-2:]:
|
||||||
denoise_mask = utils.common_upscale(denoise_mask, noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
denoise_mask = utils.common_upscale(denoise_mask, noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||||
denoise_mask = utils.resize_to_batch_size(denoise_mask.round(), noise.shape[0])
|
denoise_mask = utils.resize_to_batch_size(denoise_mask.round(), noise.shape[0])
|
||||||
@@ -198,12 +204,16 @@ class BaseModel(torch.nn.Module):
|
|||||||
if ck == "mask":
|
if ck == "mask":
|
||||||
cond_concat.append(denoise_mask.to(device))
|
cond_concat.append(denoise_mask.to(device))
|
||||||
elif ck == "masked_image":
|
elif ck == "masked_image":
|
||||||
cond_concat.append(concat_latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space
|
cond_concat.append(concat_latent_image.to(device)) # NOTE: the latent_image should be masked by the mask in pixel space
|
||||||
|
elif ck == "mask_inverted":
|
||||||
|
cond_concat.append(1.0 - denoise_mask.to(device))
|
||||||
else:
|
else:
|
||||||
if ck == "mask":
|
if ck == "mask":
|
||||||
cond_concat.append(torch.ones_like(noise)[:,:1])
|
cond_concat.append(torch.ones_like(noise)[:, :1])
|
||||||
elif ck == "masked_image":
|
elif ck == "masked_image":
|
||||||
cond_concat.append(self.blank_inpaint_image_like(noise))
|
cond_concat.append(self.blank_inpaint_image_like(noise))
|
||||||
|
elif ck == "mask_inverted":
|
||||||
|
cond_concat.append(torch.zeros_like(noise)[:, :1])
|
||||||
data = torch.cat(cond_concat, dim=1)
|
data = torch.cat(cond_concat, dim=1)
|
||||||
return data
|
return data
|
||||||
return None
|
return None
|
||||||
@@ -291,6 +301,9 @@ class BaseModel(torch.nn.Module):
|
|||||||
return blank_image
|
return blank_image
|
||||||
self.blank_inpaint_image_like = blank_inpaint_image_like
|
self.blank_inpaint_image_like = blank_inpaint_image_like
|
||||||
|
|
||||||
|
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
|
||||||
|
return self.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(noise.shape) - 1)), noise, latent_image)
|
||||||
|
|
||||||
def memory_required(self, input_shape):
|
def memory_required(self, input_shape):
|
||||||
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
|
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
|
||||||
dtype = self.get_dtype()
|
dtype = self.get_dtype()
|
||||||
@@ -427,7 +440,6 @@ class SVD_img2vid(BaseModel):
|
|||||||
|
|
||||||
latent_image = kwargs.get("concat_latent_image", None)
|
latent_image = kwargs.get("concat_latent_image", None)
|
||||||
noise = kwargs.get("noise", None)
|
noise = kwargs.get("noise", None)
|
||||||
device = kwargs["device"]
|
|
||||||
|
|
||||||
if latent_image is None:
|
if latent_image is None:
|
||||||
latent_image = torch.zeros_like(noise)
|
latent_image = torch.zeros_like(noise)
|
||||||
@@ -539,6 +551,10 @@ class SD_X4Upscaler(BaseModel):
|
|||||||
|
|
||||||
out['c_concat'] = comfy.conds.CONDNoiseShape(image)
|
out['c_concat'] = comfy.conds.CONDNoiseShape(image)
|
||||||
out['y'] = comfy.conds.CONDRegular(noise_level)
|
out['y'] = comfy.conds.CONDRegular(noise_level)
|
||||||
|
|
||||||
|
cross_attn = kwargs.get("cross_attn", None)
|
||||||
|
if cross_attn is not None:
|
||||||
|
out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
class IP2P:
|
class IP2P:
|
||||||
@@ -687,6 +703,7 @@ class StableAudio1(BaseModel):
|
|||||||
sd["{}{}".format(k, l)] = s[l]
|
sd["{}{}".format(k, l)] = s[l]
|
||||||
return sd
|
return sd
|
||||||
|
|
||||||
|
|
||||||
class HunyuanDiT(BaseModel):
|
class HunyuanDiT(BaseModel):
|
||||||
def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None):
|
def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None):
|
||||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hydit.models.HunYuanDiT)
|
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hydit.models.HunYuanDiT)
|
||||||
@@ -711,14 +728,31 @@ class HunyuanDiT(BaseModel):
|
|||||||
|
|
||||||
width = kwargs.get("width", 768)
|
width = kwargs.get("width", 768)
|
||||||
height = kwargs.get("height", 768)
|
height = kwargs.get("height", 768)
|
||||||
crop_w = kwargs.get("crop_w", 0)
|
|
||||||
crop_h = kwargs.get("crop_h", 0)
|
|
||||||
target_width = kwargs.get("target_width", width)
|
target_width = kwargs.get("target_width", width)
|
||||||
target_height = kwargs.get("target_height", height)
|
target_height = kwargs.get("target_height", height)
|
||||||
|
|
||||||
out['image_meta_size'] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]]))
|
out['image_meta_size'] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]]))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
class PixArt(BaseModel):
|
||||||
|
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
|
||||||
|
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.pixart.pixartms.PixArtMS)
|
||||||
|
|
||||||
|
def extra_conds(self, **kwargs):
|
||||||
|
out = super().extra_conds(**kwargs)
|
||||||
|
|
||||||
|
cross_attn = kwargs.get("cross_attn", None)
|
||||||
|
if cross_attn is not None:
|
||||||
|
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||||
|
|
||||||
|
width = kwargs.get("width", None)
|
||||||
|
height = kwargs.get("height", None)
|
||||||
|
if width is not None and height is not None:
|
||||||
|
out["c_size"] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width]]))
|
||||||
|
out["c_ar"] = comfy.conds.CONDRegular(torch.FloatTensor([[kwargs.get("aspect_ratio", height/width)]]))
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
class Flux(BaseModel):
|
class Flux(BaseModel):
|
||||||
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
|
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
|
||||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.flux.model.Flux)
|
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.flux.model.Flux)
|
||||||
@@ -755,7 +789,6 @@ class Flux(BaseModel):
|
|||||||
mask = torch.ones_like(noise)[:, :1]
|
mask = torch.ones_like(noise)[:, :1]
|
||||||
|
|
||||||
mask = torch.mean(mask, dim=1, keepdim=True)
|
mask = torch.mean(mask, dim=1, keepdim=True)
|
||||||
print(mask.shape)
|
|
||||||
mask = utils.common_upscale(mask.to(device), noise.shape[-1] * 8, noise.shape[-2] * 8, "bilinear", "center")
|
mask = utils.common_upscale(mask.to(device), noise.shape[-1] * 8, noise.shape[-2] * 8, "bilinear", "center")
|
||||||
mask = mask.view(mask.shape[0], mask.shape[2] // 8, 8, mask.shape[3] // 8, 8).permute(0, 2, 4, 1, 3).reshape(mask.shape[0], -1, mask.shape[2] // 8, mask.shape[3] // 8)
|
mask = mask.view(mask.shape[0], mask.shape[2] // 8, 8, mask.shape[3] // 8, 8).permute(0, 2, 4, 1, 3).reshape(mask.shape[0], -1, mask.shape[2] // 8, mask.shape[3] // 8)
|
||||||
mask = utils.resize_to_batch_size(mask, noise.shape[0])
|
mask = utils.resize_to_batch_size(mask, noise.shape[0])
|
||||||
@@ -769,7 +802,20 @@ class Flux(BaseModel):
|
|||||||
cross_attn = kwargs.get("cross_attn", None)
|
cross_attn = kwargs.get("cross_attn", None)
|
||||||
if cross_attn is not None:
|
if cross_attn is not None:
|
||||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||||
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 3.5)]))
|
# upscale the attention mask, since now we
|
||||||
|
attention_mask = kwargs.get("attention_mask", None)
|
||||||
|
if attention_mask is not None:
|
||||||
|
shape = kwargs["noise"].shape
|
||||||
|
mask_ref_size = kwargs["attention_mask_img_shape"]
|
||||||
|
# the model will pad to the patch size, and then divide
|
||||||
|
# essentially dividing and rounding up
|
||||||
|
(h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size))
|
||||||
|
attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok))
|
||||||
|
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||||
|
|
||||||
|
guidance = kwargs.get("guidance", 3.5)
|
||||||
|
if guidance is not None:
|
||||||
|
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
class GenmoMochi(BaseModel):
|
class GenmoMochi(BaseModel):
|
||||||
@@ -804,5 +850,57 @@ class LTXV(BaseModel):
|
|||||||
if guiding_latent is not None:
|
if guiding_latent is not None:
|
||||||
out['guiding_latent'] = comfy.conds.CONDRegular(guiding_latent)
|
out['guiding_latent'] = comfy.conds.CONDRegular(guiding_latent)
|
||||||
|
|
||||||
|
guiding_latent_noise_scale = kwargs.get("guiding_latent_noise_scale", None)
|
||||||
|
if guiding_latent_noise_scale is not None:
|
||||||
|
out["guiding_latent_noise_scale"] = comfy.conds.CONDConstant(guiding_latent_noise_scale)
|
||||||
|
|
||||||
out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))
|
out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
class HunyuanVideo(BaseModel):
|
||||||
|
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||||
|
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo)
|
||||||
|
|
||||||
|
def encode_adm(self, **kwargs):
|
||||||
|
return kwargs["pooled_output"]
|
||||||
|
|
||||||
|
def extra_conds(self, **kwargs):
|
||||||
|
out = super().extra_conds(**kwargs)
|
||||||
|
attention_mask = kwargs.get("attention_mask", None)
|
||||||
|
if attention_mask is not None:
|
||||||
|
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||||
|
cross_attn = kwargs.get("cross_attn", None)
|
||||||
|
if cross_attn is not None:
|
||||||
|
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||||
|
|
||||||
|
guidance = kwargs.get("guidance", 6.0)
|
||||||
|
if guidance is not None:
|
||||||
|
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
|
||||||
|
return out
|
||||||
|
|
||||||
|
class CosmosVideo(BaseModel):
|
||||||
|
def __init__(self, model_config, model_type=ModelType.EDM, image_to_video=False, device=None):
|
||||||
|
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.cosmos.model.GeneralDIT)
|
||||||
|
self.image_to_video = image_to_video
|
||||||
|
if self.image_to_video:
|
||||||
|
self.concat_keys = ("mask_inverted",)
|
||||||
|
|
||||||
|
def extra_conds(self, **kwargs):
|
||||||
|
out = super().extra_conds(**kwargs)
|
||||||
|
attention_mask = kwargs.get("attention_mask", None)
|
||||||
|
if attention_mask is not None:
|
||||||
|
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||||
|
cross_attn = kwargs.get("cross_attn", None)
|
||||||
|
if cross_attn is not None:
|
||||||
|
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||||
|
|
||||||
|
out['fps'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", None))
|
||||||
|
return out
|
||||||
|
|
||||||
|
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
|
||||||
|
sigma = sigma.reshape([sigma.shape[0]] + [1] * (len(noise.shape) - 1))
|
||||||
|
sigma_noise_augmentation = 0 #TODO
|
||||||
|
if sigma_noise_augmentation != 0:
|
||||||
|
latent_image = latent_image + noise
|
||||||
|
latent_image = self.model_sampling.calculate_input(torch.tensor([sigma_noise_augmentation], device=latent_image.device, dtype=latent_image.dtype), latent_image)
|
||||||
|
return latent_image * ((sigma ** 2 + self.model_sampling.sigma_data ** 2) ** 0.5)
|
||||||
|
|||||||
@@ -133,6 +133,26 @@ def detect_unet_config(state_dict, key_prefix):
|
|||||||
unet_config["image_model"] = "hydit1"
|
unet_config["image_model"] = "hydit1"
|
||||||
return unet_config
|
return unet_config
|
||||||
|
|
||||||
|
if '{}txt_in.individual_token_refiner.blocks.0.norm1.weight'.format(key_prefix) in state_dict_keys: #Hunyuan Video
|
||||||
|
dit_config = {}
|
||||||
|
dit_config["image_model"] = "hunyuan_video"
|
||||||
|
dit_config["in_channels"] = 16
|
||||||
|
dit_config["patch_size"] = [1, 2, 2]
|
||||||
|
dit_config["out_channels"] = 16
|
||||||
|
dit_config["vec_in_dim"] = 768
|
||||||
|
dit_config["context_in_dim"] = 4096
|
||||||
|
dit_config["hidden_size"] = 3072
|
||||||
|
dit_config["mlp_ratio"] = 4.0
|
||||||
|
dit_config["num_heads"] = 24
|
||||||
|
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
|
||||||
|
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
|
||||||
|
dit_config["axes_dim"] = [16, 56, 56]
|
||||||
|
dit_config["theta"] = 256
|
||||||
|
dit_config["qkv_bias"] = True
|
||||||
|
guidance_keys = list(filter(lambda a: a.startswith("{}guidance_in.".format(key_prefix)), state_dict_keys))
|
||||||
|
dit_config["guidance_embed"] = len(guidance_keys) > 0
|
||||||
|
return dit_config
|
||||||
|
|
||||||
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys: #Flux
|
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys: #Flux
|
||||||
dit_config = {}
|
dit_config = {}
|
||||||
dit_config["image_model"] = "flux"
|
dit_config["image_model"] = "flux"
|
||||||
@@ -183,11 +203,87 @@ def detect_unet_config(state_dict, key_prefix):
|
|||||||
dit_config["rope_theta"] = 10000.0
|
dit_config["rope_theta"] = 10000.0
|
||||||
return dit_config
|
return dit_config
|
||||||
|
|
||||||
|
if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys and '{}pos_embed.proj.bias'.format(key_prefix) in state_dict_keys:
|
||||||
|
# PixArt diffusers
|
||||||
|
return None
|
||||||
|
|
||||||
if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv
|
if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv
|
||||||
dit_config = {}
|
dit_config = {}
|
||||||
dit_config["image_model"] = "ltxv"
|
dit_config["image_model"] = "ltxv"
|
||||||
return dit_config
|
return dit_config
|
||||||
|
|
||||||
|
if '{}t_block.1.weight'.format(key_prefix) in state_dict_keys: # PixArt
|
||||||
|
patch_size = 2
|
||||||
|
dit_config = {}
|
||||||
|
dit_config["num_heads"] = 16
|
||||||
|
dit_config["patch_size"] = patch_size
|
||||||
|
dit_config["hidden_size"] = 1152
|
||||||
|
dit_config["in_channels"] = 4
|
||||||
|
dit_config["depth"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')
|
||||||
|
|
||||||
|
y_key = "{}y_embedder.y_embedding".format(key_prefix)
|
||||||
|
if y_key in state_dict_keys:
|
||||||
|
dit_config["model_max_length"] = state_dict[y_key].shape[0]
|
||||||
|
|
||||||
|
pe_key = "{}pos_embed".format(key_prefix)
|
||||||
|
if pe_key in state_dict_keys:
|
||||||
|
dit_config["input_size"] = int(math.sqrt(state_dict[pe_key].shape[1])) * patch_size
|
||||||
|
dit_config["pe_interpolation"] = dit_config["input_size"] // (512//8) # guess
|
||||||
|
|
||||||
|
ar_key = "{}ar_embedder.mlp.0.weight".format(key_prefix)
|
||||||
|
if ar_key in state_dict_keys:
|
||||||
|
dit_config["image_model"] = "pixart_alpha"
|
||||||
|
dit_config["micro_condition"] = True
|
||||||
|
else:
|
||||||
|
dit_config["image_model"] = "pixart_sigma"
|
||||||
|
dit_config["micro_condition"] = False
|
||||||
|
return dit_config
|
||||||
|
|
||||||
|
if '{}blocks.block0.blocks.0.block.attn.to_q.0.weight'.format(key_prefix) in state_dict_keys:
|
||||||
|
dit_config = {}
|
||||||
|
dit_config["image_model"] = "cosmos"
|
||||||
|
dit_config["max_img_h"] = 240
|
||||||
|
dit_config["max_img_w"] = 240
|
||||||
|
dit_config["max_frames"] = 128
|
||||||
|
concat_padding_mask = True
|
||||||
|
dit_config["in_channels"] = (state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[1] // 4) - int(concat_padding_mask)
|
||||||
|
dit_config["out_channels"] = 16
|
||||||
|
dit_config["patch_spatial"] = 2
|
||||||
|
dit_config["patch_temporal"] = 1
|
||||||
|
dit_config["model_channels"] = state_dict['{}blocks.block0.blocks.0.block.attn.to_q.0.weight'.format(key_prefix)].shape[0]
|
||||||
|
dit_config["block_config"] = "FA-CA-MLP"
|
||||||
|
dit_config["concat_padding_mask"] = concat_padding_mask
|
||||||
|
dit_config["pos_emb_cls"] = "rope3d"
|
||||||
|
dit_config["pos_emb_learnable"] = False
|
||||||
|
dit_config["pos_emb_interpolation"] = "crop"
|
||||||
|
dit_config["block_x_format"] = "THWBD"
|
||||||
|
dit_config["affline_emb_norm"] = True
|
||||||
|
dit_config["use_adaln_lora"] = True
|
||||||
|
dit_config["adaln_lora_dim"] = 256
|
||||||
|
|
||||||
|
if dit_config["model_channels"] == 4096:
|
||||||
|
# 7B
|
||||||
|
dit_config["num_blocks"] = 28
|
||||||
|
dit_config["num_heads"] = 32
|
||||||
|
dit_config["extra_per_block_abs_pos_emb"] = True
|
||||||
|
dit_config["rope_h_extrapolation_ratio"] = 1.0
|
||||||
|
dit_config["rope_w_extrapolation_ratio"] = 1.0
|
||||||
|
dit_config["rope_t_extrapolation_ratio"] = 2.0
|
||||||
|
dit_config["extra_per_block_abs_pos_emb_type"] = "learnable"
|
||||||
|
else: # 5120
|
||||||
|
# 14B
|
||||||
|
dit_config["num_blocks"] = 36
|
||||||
|
dit_config["num_heads"] = 40
|
||||||
|
dit_config["extra_per_block_abs_pos_emb"] = True
|
||||||
|
dit_config["rope_h_extrapolation_ratio"] = 2.0
|
||||||
|
dit_config["rope_w_extrapolation_ratio"] = 2.0
|
||||||
|
dit_config["rope_t_extrapolation_ratio"] = 2.0
|
||||||
|
dit_config["extra_h_extrapolation_ratio"] = 2.0
|
||||||
|
dit_config["extra_w_extrapolation_ratio"] = 2.0
|
||||||
|
dit_config["extra_t_extrapolation_ratio"] = 2.0
|
||||||
|
dit_config["extra_per_block_abs_pos_emb_type"] = "learnable"
|
||||||
|
return dit_config
|
||||||
|
|
||||||
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
|
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -216,7 +312,6 @@ def detect_unet_config(state_dict, key_prefix):
|
|||||||
|
|
||||||
num_res_blocks = []
|
num_res_blocks = []
|
||||||
channel_mult = []
|
channel_mult = []
|
||||||
attention_resolutions = []
|
|
||||||
transformer_depth = []
|
transformer_depth = []
|
||||||
transformer_depth_output = []
|
transformer_depth_output = []
|
||||||
context_dim = None
|
context_dim = None
|
||||||
@@ -343,6 +438,7 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
|
|||||||
def unet_prefix_from_state_dict(state_dict):
|
def unet_prefix_from_state_dict(state_dict):
|
||||||
candidates = ["model.diffusion_model.", #ldm/sgm models
|
candidates = ["model.diffusion_model.", #ldm/sgm models
|
||||||
"model.model.", #audio models
|
"model.model.", #audio models
|
||||||
|
"net.", #cosmos
|
||||||
]
|
]
|
||||||
counts = {k: 0 for k in candidates}
|
counts = {k: 0 for k in candidates}
|
||||||
for k in state_dict:
|
for k in state_dict:
|
||||||
@@ -388,7 +484,6 @@ def convert_config(unet_config):
|
|||||||
t_out += [d] * (res + 1)
|
t_out += [d] * (res + 1)
|
||||||
s *= 2
|
s *= 2
|
||||||
transformer_depth = t_in
|
transformer_depth = t_in
|
||||||
transformer_depth_output = t_out
|
|
||||||
new_config["transformer_depth"] = t_in
|
new_config["transformer_depth"] = t_in
|
||||||
new_config["transformer_depth_output"] = t_out
|
new_config["transformer_depth_output"] = t_out
|
||||||
new_config["transformer_depth_middle"] = transformer_depth_middle
|
new_config["transformer_depth_middle"] = transformer_depth_middle
|
||||||
@@ -522,12 +617,12 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
|
|||||||
'transformer_depth': [0, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': False,
|
'transformer_depth': [0, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': False,
|
||||||
'context_dim': 768, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 1, 1, 1, 1],
|
'context_dim': 768, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 1, 1, 1, 1],
|
||||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||||
|
|
||||||
SD15_diffusers_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': None,
|
SD15_diffusers_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': None,
|
||||||
'dtype': dtype, 'in_channels': 9, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0],
|
'dtype': dtype, 'in_channels': 9, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0],
|
||||||
'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, 'num_heads': 8,
|
'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, 'num_heads': 8,
|
||||||
'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
|
'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
|
||||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||||
|
|
||||||
|
|
||||||
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p, SD15_diffusers_inpaint]
|
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p, SD15_diffusers_inpaint]
|
||||||
@@ -555,6 +650,9 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
|
|||||||
num_joint = count_blocks(state_dict, 'joint_transformer_blocks.{}.')
|
num_joint = count_blocks(state_dict, 'joint_transformer_blocks.{}.')
|
||||||
num_single = count_blocks(state_dict, 'single_transformer_blocks.{}.')
|
num_single = count_blocks(state_dict, 'single_transformer_blocks.{}.')
|
||||||
sd_map = comfy.utils.auraflow_to_diffusers({"n_double_layers": num_joint, "n_layers": num_joint + num_single}, output_prefix=output_prefix)
|
sd_map = comfy.utils.auraflow_to_diffusers({"n_double_layers": num_joint, "n_layers": num_joint + num_single}, output_prefix=output_prefix)
|
||||||
|
elif 'adaln_single.emb.timestep_embedder.linear_1.bias' in state_dict and 'pos_embed.proj.bias' in state_dict: # PixArt
|
||||||
|
num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
|
||||||
|
sd_map = comfy.utils.pixart_to_diffusers({"depth": num_blocks}, output_prefix=output_prefix)
|
||||||
elif 'x_embedder.weight' in state_dict: #Flux
|
elif 'x_embedder.weight' in state_dict: #Flux
|
||||||
depth = count_blocks(state_dict, 'transformer_blocks.{}.')
|
depth = count_blocks(state_dict, 'transformer_blocks.{}.')
|
||||||
depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.')
|
depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.')
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ if args.directml is not None:
|
|||||||
try:
|
try:
|
||||||
import intel_extension_for_pytorch as ipex
|
import intel_extension_for_pytorch as ipex
|
||||||
_ = torch.xpu.device_count()
|
_ = torch.xpu.device_count()
|
||||||
xpu_available = torch.xpu.is_available()
|
xpu_available = xpu_available or torch.xpu.is_available()
|
||||||
except:
|
except:
|
||||||
xpu_available = xpu_available or (hasattr(torch, "xpu") and torch.xpu.is_available())
|
xpu_available = xpu_available or (hasattr(torch, "xpu") and torch.xpu.is_available())
|
||||||
|
|
||||||
@@ -86,6 +86,13 @@ try:
|
|||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
import torch_npu # noqa: F401
|
||||||
|
_ = torch.npu.device_count()
|
||||||
|
npu_available = torch.npu.is_available()
|
||||||
|
except:
|
||||||
|
npu_available = False
|
||||||
|
|
||||||
if args.cpu:
|
if args.cpu:
|
||||||
cpu_state = CPUState.CPU
|
cpu_state = CPUState.CPU
|
||||||
|
|
||||||
@@ -97,6 +104,12 @@ def is_intel_xpu():
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def is_ascend_npu():
|
||||||
|
global npu_available
|
||||||
|
if npu_available:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
def get_torch_device():
|
def get_torch_device():
|
||||||
global directml_enabled
|
global directml_enabled
|
||||||
global cpu_state
|
global cpu_state
|
||||||
@@ -110,6 +123,8 @@ def get_torch_device():
|
|||||||
else:
|
else:
|
||||||
if is_intel_xpu():
|
if is_intel_xpu():
|
||||||
return torch.device("xpu", torch.xpu.current_device())
|
return torch.device("xpu", torch.xpu.current_device())
|
||||||
|
elif is_ascend_npu():
|
||||||
|
return torch.device("npu", torch.npu.current_device())
|
||||||
else:
|
else:
|
||||||
return torch.device(torch.cuda.current_device())
|
return torch.device(torch.cuda.current_device())
|
||||||
|
|
||||||
@@ -130,6 +145,12 @@ def get_total_memory(dev=None, torch_total_too=False):
|
|||||||
mem_reserved = stats['reserved_bytes.all.current']
|
mem_reserved = stats['reserved_bytes.all.current']
|
||||||
mem_total_torch = mem_reserved
|
mem_total_torch = mem_reserved
|
||||||
mem_total = torch.xpu.get_device_properties(dev).total_memory
|
mem_total = torch.xpu.get_device_properties(dev).total_memory
|
||||||
|
elif is_ascend_npu():
|
||||||
|
stats = torch.npu.memory_stats(dev)
|
||||||
|
mem_reserved = stats['reserved_bytes.all.current']
|
||||||
|
_, mem_total_npu = torch.npu.mem_get_info(dev)
|
||||||
|
mem_total_torch = mem_reserved
|
||||||
|
mem_total = mem_total_npu
|
||||||
else:
|
else:
|
||||||
stats = torch.cuda.memory_stats(dev)
|
stats = torch.cuda.memory_stats(dev)
|
||||||
mem_reserved = stats['reserved_bytes.all.current']
|
mem_reserved = stats['reserved_bytes.all.current']
|
||||||
@@ -188,38 +209,44 @@ def is_nvidia():
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def is_amd():
|
||||||
|
global cpu_state
|
||||||
|
if cpu_state == CPUState.GPU:
|
||||||
|
if torch.version.hip:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
MIN_WEIGHT_MEMORY_RATIO = 0.4
|
||||||
|
if is_nvidia():
|
||||||
|
MIN_WEIGHT_MEMORY_RATIO = 0.1
|
||||||
|
|
||||||
ENABLE_PYTORCH_ATTENTION = False
|
ENABLE_PYTORCH_ATTENTION = False
|
||||||
if args.use_pytorch_cross_attention:
|
if args.use_pytorch_cross_attention:
|
||||||
ENABLE_PYTORCH_ATTENTION = True
|
ENABLE_PYTORCH_ATTENTION = True
|
||||||
XFORMERS_IS_AVAILABLE = False
|
XFORMERS_IS_AVAILABLE = False
|
||||||
|
|
||||||
VAE_DTYPES = [torch.float32]
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if is_nvidia():
|
if is_nvidia():
|
||||||
if int(torch_version[0]) >= 2:
|
if int(torch_version[0]) >= 2:
|
||||||
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||||
ENABLE_PYTORCH_ATTENTION = True
|
ENABLE_PYTORCH_ATTENTION = True
|
||||||
if torch.cuda.is_bf16_supported() and torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8:
|
if is_intel_xpu() or is_ascend_npu():
|
||||||
VAE_DTYPES = [torch.bfloat16] + VAE_DTYPES
|
|
||||||
if is_intel_xpu():
|
|
||||||
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||||
ENABLE_PYTORCH_ATTENTION = True
|
ENABLE_PYTORCH_ATTENTION = True
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if is_intel_xpu():
|
|
||||||
VAE_DTYPES = [torch.bfloat16] + VAE_DTYPES
|
|
||||||
|
|
||||||
if args.cpu_vae:
|
|
||||||
VAE_DTYPES = [torch.float32]
|
|
||||||
|
|
||||||
|
|
||||||
if ENABLE_PYTORCH_ATTENTION:
|
if ENABLE_PYTORCH_ATTENTION:
|
||||||
torch.backends.cuda.enable_math_sdp(True)
|
torch.backends.cuda.enable_math_sdp(True)
|
||||||
torch.backends.cuda.enable_flash_sdp(True)
|
torch.backends.cuda.enable_flash_sdp(True)
|
||||||
torch.backends.cuda.enable_mem_efficient_sdp(True)
|
torch.backends.cuda.enable_mem_efficient_sdp(True)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if int(torch_version[0]) == 2 and int(torch_version[2]) >= 5:
|
||||||
|
torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
|
||||||
|
except:
|
||||||
|
logging.warning("Warning, could not set allow_fp16_bf16_reduction_math_sdp")
|
||||||
|
|
||||||
if args.lowvram:
|
if args.lowvram:
|
||||||
set_vram_to = VRAMState.LOW_VRAM
|
set_vram_to = VRAMState.LOW_VRAM
|
||||||
lowvram_available = True
|
lowvram_available = True
|
||||||
@@ -268,6 +295,8 @@ def get_torch_device_name(device):
|
|||||||
return "{}".format(device.type)
|
return "{}".format(device.type)
|
||||||
elif is_intel_xpu():
|
elif is_intel_xpu():
|
||||||
return "{} {}".format(device, torch.xpu.get_device_name(device))
|
return "{} {}".format(device, torch.xpu.get_device_name(device))
|
||||||
|
elif is_ascend_npu():
|
||||||
|
return "{} {}".format(device, torch.npu.get_device_name(device))
|
||||||
else:
|
else:
|
||||||
return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
|
return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
|
||||||
|
|
||||||
@@ -314,6 +343,9 @@ class LoadedModel:
|
|||||||
def model_memory(self):
|
def model_memory(self):
|
||||||
return self.model.model_size()
|
return self.model.model_size()
|
||||||
|
|
||||||
|
def model_loaded_memory(self):
|
||||||
|
return self.model.loaded_size()
|
||||||
|
|
||||||
def model_offloaded_memory(self):
|
def model_offloaded_memory(self):
|
||||||
return self.model.model_size() - self.model.loaded_size()
|
return self.model.model_size() - self.model.loaded_size()
|
||||||
|
|
||||||
@@ -503,16 +535,16 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
|||||||
vram_set_state = vram_state
|
vram_set_state = vram_state
|
||||||
lowvram_model_memory = 0
|
lowvram_model_memory = 0
|
||||||
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM) and not force_full_load:
|
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM) and not force_full_load:
|
||||||
model_size = loaded_model.model_memory_required(torch_dev)
|
loaded_memory = loaded_model.model_loaded_memory()
|
||||||
current_free_mem = get_free_memory(torch_dev)
|
current_free_mem = get_free_memory(torch_dev) + loaded_memory
|
||||||
lowvram_model_memory = max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required), min(current_free_mem * 0.4, current_free_mem - minimum_inference_memory()))
|
|
||||||
if model_size <= lowvram_model_memory: #only switch to lowvram if really necessary
|
lowvram_model_memory = max(64 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
|
||||||
lowvram_model_memory = 0
|
lowvram_model_memory = max(0.1, lowvram_model_memory - loaded_memory)
|
||||||
|
|
||||||
if vram_set_state == VRAMState.NO_VRAM:
|
if vram_set_state == VRAMState.NO_VRAM:
|
||||||
lowvram_model_memory = 64 * 1024 * 1024
|
lowvram_model_memory = 0.1
|
||||||
|
|
||||||
cur_loaded_model = loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
|
loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
|
||||||
current_loaded_models.insert(0, loaded_model)
|
current_loaded_models.insert(0, loaded_model)
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -581,7 +613,7 @@ def unet_offload_device():
|
|||||||
|
|
||||||
def unet_inital_load_device(parameters, dtype):
|
def unet_inital_load_device(parameters, dtype):
|
||||||
torch_dev = get_torch_device()
|
torch_dev = get_torch_device()
|
||||||
if vram_state == VRAMState.HIGH_VRAM:
|
if vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.SHARED:
|
||||||
return torch_dev
|
return torch_dev
|
||||||
|
|
||||||
cpu_dev = torch.device("cpu")
|
cpu_dev = torch.device("cpu")
|
||||||
@@ -695,7 +727,7 @@ def text_encoder_initial_device(load_device, offload_device, model_size=0):
|
|||||||
return offload_device
|
return offload_device
|
||||||
|
|
||||||
if is_device_mps(load_device):
|
if is_device_mps(load_device):
|
||||||
return offload_device
|
return load_device
|
||||||
|
|
||||||
mem_l = get_free_memory(load_device)
|
mem_l = get_free_memory(load_device)
|
||||||
mem_o = get_free_memory(offload_device)
|
mem_o = get_free_memory(offload_device)
|
||||||
@@ -738,7 +770,6 @@ def vae_offload_device():
|
|||||||
return torch.device("cpu")
|
return torch.device("cpu")
|
||||||
|
|
||||||
def vae_dtype(device=None, allowed_dtypes=[]):
|
def vae_dtype(device=None, allowed_dtypes=[]):
|
||||||
global VAE_DTYPES
|
|
||||||
if args.fp16_vae:
|
if args.fp16_vae:
|
||||||
return torch.float16
|
return torch.float16
|
||||||
elif args.bf16_vae:
|
elif args.bf16_vae:
|
||||||
@@ -747,12 +778,14 @@ def vae_dtype(device=None, allowed_dtypes=[]):
|
|||||||
return torch.float32
|
return torch.float32
|
||||||
|
|
||||||
for d in allowed_dtypes:
|
for d in allowed_dtypes:
|
||||||
if d == torch.float16 and should_use_fp16(device, prioritize_performance=False):
|
if d == torch.float16 and should_use_fp16(device):
|
||||||
return d
|
|
||||||
if d in VAE_DTYPES:
|
|
||||||
return d
|
return d
|
||||||
|
|
||||||
return VAE_DTYPES[0]
|
# NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32
|
||||||
|
if d == torch.bfloat16 and (not is_amd()) and should_use_bf16(device):
|
||||||
|
return d
|
||||||
|
|
||||||
|
return torch.float32
|
||||||
|
|
||||||
def get_autocast_device(dev):
|
def get_autocast_device(dev):
|
||||||
if hasattr(dev, 'type'):
|
if hasattr(dev, 'type'):
|
||||||
@@ -837,6 +870,8 @@ def cast_to_device(tensor, device, dtype, copy=False):
|
|||||||
non_blocking = device_supports_non_blocking(device)
|
non_blocking = device_supports_non_blocking(device)
|
||||||
return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy)
|
return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy)
|
||||||
|
|
||||||
|
def sage_attention_enabled():
|
||||||
|
return args.use_sage_attention
|
||||||
|
|
||||||
def xformers_enabled():
|
def xformers_enabled():
|
||||||
global directml_enabled
|
global directml_enabled
|
||||||
@@ -845,6 +880,8 @@ def xformers_enabled():
|
|||||||
return False
|
return False
|
||||||
if is_intel_xpu():
|
if is_intel_xpu():
|
||||||
return False
|
return False
|
||||||
|
if is_ascend_npu():
|
||||||
|
return False
|
||||||
if directml_enabled:
|
if directml_enabled:
|
||||||
return False
|
return False
|
||||||
return XFORMERS_IS_AVAILABLE
|
return XFORMERS_IS_AVAILABLE
|
||||||
@@ -869,16 +906,23 @@ def pytorch_attention_flash_attention():
|
|||||||
return True
|
return True
|
||||||
if is_intel_xpu():
|
if is_intel_xpu():
|
||||||
return True
|
return True
|
||||||
|
if is_ascend_npu():
|
||||||
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def mac_version():
|
||||||
|
try:
|
||||||
|
return tuple(int(n) for n in platform.mac_ver()[0].split("."))
|
||||||
|
except:
|
||||||
|
return None
|
||||||
|
|
||||||
def force_upcast_attention_dtype():
|
def force_upcast_attention_dtype():
|
||||||
upcast = args.force_upcast_attention
|
upcast = args.force_upcast_attention
|
||||||
try:
|
|
||||||
macos_version = tuple(int(n) for n in platform.mac_ver()[0].split("."))
|
macos_version = mac_version()
|
||||||
if (14, 5) <= macos_version <= (15, 2): # black image bug on recent versions of macOS
|
if macos_version is not None and ((14, 5) <= macos_version <= (15, 2)): # black image bug on recent versions of macOS
|
||||||
upcast = True
|
upcast = True
|
||||||
except:
|
|
||||||
pass
|
|
||||||
if upcast:
|
if upcast:
|
||||||
return torch.float32
|
return torch.float32
|
||||||
else:
|
else:
|
||||||
@@ -903,6 +947,13 @@ def get_free_memory(dev=None, torch_free_too=False):
|
|||||||
mem_free_torch = mem_reserved - mem_active
|
mem_free_torch = mem_reserved - mem_active
|
||||||
mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved
|
mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved
|
||||||
mem_free_total = mem_free_xpu + mem_free_torch
|
mem_free_total = mem_free_xpu + mem_free_torch
|
||||||
|
elif is_ascend_npu():
|
||||||
|
stats = torch.npu.memory_stats(dev)
|
||||||
|
mem_active = stats['active_bytes.all.current']
|
||||||
|
mem_reserved = stats['reserved_bytes.all.current']
|
||||||
|
mem_free_npu, _ = torch.npu.mem_get_info(dev)
|
||||||
|
mem_free_torch = mem_reserved - mem_active
|
||||||
|
mem_free_total = mem_free_npu + mem_free_torch
|
||||||
else:
|
else:
|
||||||
stats = torch.cuda.memory_stats(dev)
|
stats = torch.cuda.memory_stats(dev)
|
||||||
mem_active = stats['active_bytes.all.current']
|
mem_active = stats['active_bytes.all.current']
|
||||||
@@ -949,17 +1000,13 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
|
|||||||
if FORCE_FP16:
|
if FORCE_FP16:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if device is not None:
|
|
||||||
if is_device_mps(device):
|
|
||||||
return True
|
|
||||||
|
|
||||||
if FORCE_FP32:
|
if FORCE_FP32:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if directml_enabled:
|
if directml_enabled:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if mps_mode():
|
if (device is not None and is_device_mps(device)) or mps_mode():
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if cpu_mode():
|
if cpu_mode():
|
||||||
@@ -968,6 +1015,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
|
|||||||
if is_intel_xpu():
|
if is_intel_xpu():
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
if is_ascend_npu():
|
||||||
|
return True
|
||||||
|
|
||||||
if torch.version.hip:
|
if torch.version.hip:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -1008,17 +1058,15 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
|
|||||||
if is_device_cpu(device): #TODO ? bf16 works on CPU but is extremely slow
|
if is_device_cpu(device): #TODO ? bf16 works on CPU but is extremely slow
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if device is not None:
|
|
||||||
if is_device_mps(device):
|
|
||||||
return True
|
|
||||||
|
|
||||||
if FORCE_FP32:
|
if FORCE_FP32:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if directml_enabled:
|
if directml_enabled:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if mps_mode():
|
if (device is not None and is_device_mps(device)) or mps_mode():
|
||||||
|
if mac_version() < (14,):
|
||||||
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if cpu_mode():
|
if cpu_mode():
|
||||||
@@ -1067,19 +1115,16 @@ def soft_empty_cache(force=False):
|
|||||||
torch.mps.empty_cache()
|
torch.mps.empty_cache()
|
||||||
elif is_intel_xpu():
|
elif is_intel_xpu():
|
||||||
torch.xpu.empty_cache()
|
torch.xpu.empty_cache()
|
||||||
|
elif is_ascend_npu():
|
||||||
|
torch.npu.empty_cache()
|
||||||
elif torch.cuda.is_available():
|
elif torch.cuda.is_available():
|
||||||
if force or is_nvidia(): #This seems to make things worse on ROCm so I only do it for cuda
|
torch.cuda.empty_cache()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.ipc_collect()
|
||||||
torch.cuda.ipc_collect()
|
|
||||||
|
|
||||||
def unload_all_models():
|
def unload_all_models():
|
||||||
free_memory(1e30, get_torch_device())
|
free_memory(1e30, get_torch_device())
|
||||||
|
|
||||||
|
|
||||||
def resolve_lowvram_weight(weight, model, key): #TODO: remove
|
|
||||||
print("WARNING: The comfy.model_management.resolve_lowvram_weight function will be removed soon, please stop using it.")
|
|
||||||
return weight
|
|
||||||
|
|
||||||
#TODO: might be cleaner to put this somewhere else
|
#TODO: might be cleaner to put this somewhere else
|
||||||
import threading
|
import threading
|
||||||
|
|
||||||
|
|||||||
@@ -83,7 +83,7 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
|
|||||||
|
|
||||||
def create_model_options_clone(orig_model_options: dict):
|
def create_model_options_clone(orig_model_options: dict):
|
||||||
return comfy.patcher_extension.copy_nested_dicts(orig_model_options)
|
return comfy.patcher_extension.copy_nested_dicts(orig_model_options)
|
||||||
|
|
||||||
def create_hook_patches_clone(orig_hook_patches):
|
def create_hook_patches_clone(orig_hook_patches):
|
||||||
new_hook_patches = {}
|
new_hook_patches = {}
|
||||||
for hook_ref in orig_hook_patches:
|
for hook_ref in orig_hook_patches:
|
||||||
@@ -141,7 +141,7 @@ class AutoPatcherEjector:
|
|||||||
self.was_injected = False
|
self.was_injected = False
|
||||||
self.prev_skip_injection = False
|
self.prev_skip_injection = False
|
||||||
self.skip_and_inject_on_exit_only = skip_and_inject_on_exit_only
|
self.skip_and_inject_on_exit_only = skip_and_inject_on_exit_only
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
self.was_injected = False
|
self.was_injected = False
|
||||||
self.prev_skip_injection = self.model.skip_injection
|
self.prev_skip_injection = self.model.skip_injection
|
||||||
@@ -164,7 +164,7 @@ class MemoryCounter:
|
|||||||
self.value = initial
|
self.value = initial
|
||||||
self.minimum = minimum
|
self.minimum = minimum
|
||||||
# TODO: add a safe limit besides 0
|
# TODO: add a safe limit besides 0
|
||||||
|
|
||||||
def use(self, weight: torch.Tensor):
|
def use(self, weight: torch.Tensor):
|
||||||
weight_size = weight.nelement() * weight.element_size()
|
weight_size = weight.nelement() * weight.element_size()
|
||||||
if self.is_useable(weight_size):
|
if self.is_useable(weight_size):
|
||||||
@@ -210,7 +210,7 @@ class ModelPatcher:
|
|||||||
self.injections: dict[str, list[PatcherInjection]] = {}
|
self.injections: dict[str, list[PatcherInjection]] = {}
|
||||||
|
|
||||||
self.hook_patches: dict[comfy.hooks._HookRef] = {}
|
self.hook_patches: dict[comfy.hooks._HookRef] = {}
|
||||||
self.hook_patches_backup: dict[comfy.hooks._HookRef] = {}
|
self.hook_patches_backup: dict[comfy.hooks._HookRef] = None
|
||||||
self.hook_backup: dict[str, tuple[torch.Tensor, torch.device]] = {}
|
self.hook_backup: dict[str, tuple[torch.Tensor, torch.device]] = {}
|
||||||
self.cached_hook_patches: dict[comfy.hooks.HookGroup, dict[str, torch.Tensor]] = {}
|
self.cached_hook_patches: dict[comfy.hooks.HookGroup, dict[str, torch.Tensor]] = {}
|
||||||
self.current_hooks: Optional[comfy.hooks.HookGroup] = None
|
self.current_hooks: Optional[comfy.hooks.HookGroup] = None
|
||||||
@@ -282,7 +282,7 @@ class ModelPatcher:
|
|||||||
n.injections[k] = i.copy()
|
n.injections[k] = i.copy()
|
||||||
# hooks
|
# hooks
|
||||||
n.hook_patches = create_hook_patches_clone(self.hook_patches)
|
n.hook_patches = create_hook_patches_clone(self.hook_patches)
|
||||||
n.hook_patches_backup = create_hook_patches_clone(self.hook_patches_backup)
|
n.hook_patches_backup = create_hook_patches_clone(self.hook_patches_backup) if self.hook_patches_backup else self.hook_patches_backup
|
||||||
for group in self.cached_hook_patches:
|
for group in self.cached_hook_patches:
|
||||||
n.cached_hook_patches[group] = {}
|
n.cached_hook_patches[group] = {}
|
||||||
for k in self.cached_hook_patches[group]:
|
for k in self.cached_hook_patches[group]:
|
||||||
@@ -402,7 +402,20 @@ class ModelPatcher:
|
|||||||
def add_object_patch(self, name, obj):
|
def add_object_patch(self, name, obj):
|
||||||
self.object_patches[name] = obj
|
self.object_patches[name] = obj
|
||||||
|
|
||||||
def get_model_object(self, name):
|
def get_model_object(self, name: str) -> torch.nn.Module:
|
||||||
|
"""Retrieves a nested attribute from an object using dot notation considering
|
||||||
|
object patches.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): The attribute path using dot notation (e.g. "model.layer.weight")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The value of the requested attribute
|
||||||
|
|
||||||
|
Example:
|
||||||
|
patcher = ModelPatcher()
|
||||||
|
weight = patcher.get_model_object("layer1.conv.weight")
|
||||||
|
"""
|
||||||
if name in self.object_patches:
|
if name in self.object_patches:
|
||||||
return self.object_patches[name]
|
return self.object_patches[name]
|
||||||
else:
|
else:
|
||||||
@@ -711,7 +724,7 @@ class ModelPatcher:
|
|||||||
else:
|
else:
|
||||||
comfy.utils.set_attr_param(self.model, key, bk.weight)
|
comfy.utils.set_attr_param(self.model, key, bk.weight)
|
||||||
self.backup.pop(key)
|
self.backup.pop(key)
|
||||||
|
|
||||||
weight_key = "{}.weight".format(n)
|
weight_key = "{}.weight".format(n)
|
||||||
bias_key = "{}.bias".format(n)
|
bias_key = "{}.bias".format(n)
|
||||||
if move_weight:
|
if move_weight:
|
||||||
@@ -773,7 +786,7 @@ class ModelPatcher:
|
|||||||
return self.model.device
|
return self.model.device
|
||||||
|
|
||||||
def calculate_weight(self, patches, weight, key, intermediate_dtype=torch.float32):
|
def calculate_weight(self, patches, weight, key, intermediate_dtype=torch.float32):
|
||||||
print("WARNING the ModelPatcher.calculate_weight function is deprecated, please use: comfy.lora.calculate_weight instead")
|
logging.warning("The ModelPatcher.calculate_weight function is deprecated, please use: comfy.lora.calculate_weight instead")
|
||||||
return comfy.lora.calculate_weight(patches, weight, key, intermediate_dtype=intermediate_dtype)
|
return comfy.lora.calculate_weight(patches, weight, key, intermediate_dtype=intermediate_dtype)
|
||||||
|
|
||||||
def cleanup(self):
|
def cleanup(self):
|
||||||
@@ -789,7 +802,7 @@ class ModelPatcher:
|
|||||||
def add_callback_with_key(self, call_type: str, key: str, callback: Callable):
|
def add_callback_with_key(self, call_type: str, key: str, callback: Callable):
|
||||||
c = self.callbacks.setdefault(call_type, {}).setdefault(key, [])
|
c = self.callbacks.setdefault(call_type, {}).setdefault(key, [])
|
||||||
c.append(callback)
|
c.append(callback)
|
||||||
|
|
||||||
def remove_callbacks_with_key(self, call_type: str, key: str):
|
def remove_callbacks_with_key(self, call_type: str, key: str):
|
||||||
c = self.callbacks.get(call_type, {})
|
c = self.callbacks.get(call_type, {})
|
||||||
if key in c:
|
if key in c:
|
||||||
@@ -797,7 +810,7 @@ class ModelPatcher:
|
|||||||
|
|
||||||
def get_callbacks(self, call_type: str, key: str):
|
def get_callbacks(self, call_type: str, key: str):
|
||||||
return self.callbacks.get(call_type, {}).get(key, [])
|
return self.callbacks.get(call_type, {}).get(key, [])
|
||||||
|
|
||||||
def get_all_callbacks(self, call_type: str):
|
def get_all_callbacks(self, call_type: str):
|
||||||
c_list = []
|
c_list = []
|
||||||
for c in self.callbacks.get(call_type, {}).values():
|
for c in self.callbacks.get(call_type, {}).values():
|
||||||
@@ -810,7 +823,7 @@ class ModelPatcher:
|
|||||||
def add_wrapper_with_key(self, wrapper_type: str, key: str, wrapper: Callable):
|
def add_wrapper_with_key(self, wrapper_type: str, key: str, wrapper: Callable):
|
||||||
w = self.wrappers.setdefault(wrapper_type, {}).setdefault(key, [])
|
w = self.wrappers.setdefault(wrapper_type, {}).setdefault(key, [])
|
||||||
w.append(wrapper)
|
w.append(wrapper)
|
||||||
|
|
||||||
def remove_wrappers_with_key(self, wrapper_type: str, key: str):
|
def remove_wrappers_with_key(self, wrapper_type: str, key: str):
|
||||||
w = self.wrappers.get(wrapper_type, {})
|
w = self.wrappers.get(wrapper_type, {})
|
||||||
if key in w:
|
if key in w:
|
||||||
@@ -831,7 +844,7 @@ class ModelPatcher:
|
|||||||
def remove_attachments(self, key: str):
|
def remove_attachments(self, key: str):
|
||||||
if key in self.attachments:
|
if key in self.attachments:
|
||||||
self.attachments.pop(key)
|
self.attachments.pop(key)
|
||||||
|
|
||||||
def get_attachment(self, key: str):
|
def get_attachment(self, key: str):
|
||||||
return self.attachments.get(key, None)
|
return self.attachments.get(key, None)
|
||||||
|
|
||||||
@@ -842,6 +855,9 @@ class ModelPatcher:
|
|||||||
if key in self.injections:
|
if key in self.injections:
|
||||||
self.injections.pop(key)
|
self.injections.pop(key)
|
||||||
|
|
||||||
|
def get_injections(self, key: str):
|
||||||
|
return self.injections.get(key, None)
|
||||||
|
|
||||||
def set_additional_models(self, key: str, models: list['ModelPatcher']):
|
def set_additional_models(self, key: str, models: list['ModelPatcher']):
|
||||||
self.additional_models[key] = models
|
self.additional_models[key] = models
|
||||||
|
|
||||||
@@ -851,7 +867,7 @@ class ModelPatcher:
|
|||||||
|
|
||||||
def get_additional_models_with_key(self, key: str):
|
def get_additional_models_with_key(self, key: str):
|
||||||
return self.additional_models.get(key, [])
|
return self.additional_models.get(key, [])
|
||||||
|
|
||||||
def get_additional_models(self):
|
def get_additional_models(self):
|
||||||
all_models = []
|
all_models = []
|
||||||
for models in self.additional_models.values():
|
for models in self.additional_models.values():
|
||||||
@@ -906,24 +922,25 @@ class ModelPatcher:
|
|||||||
self.model.current_patcher = self
|
self.model.current_patcher = self
|
||||||
for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN):
|
for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN):
|
||||||
callback(self)
|
callback(self)
|
||||||
|
|
||||||
def prepare_state(self, timestep):
|
def prepare_state(self, timestep):
|
||||||
for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE):
|
for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE):
|
||||||
callback(self, timestep)
|
callback(self, timestep)
|
||||||
|
|
||||||
def restore_hook_patches(self):
|
def restore_hook_patches(self):
|
||||||
if len(self.hook_patches_backup) > 0:
|
if self.hook_patches_backup is not None:
|
||||||
self.hook_patches = self.hook_patches_backup
|
self.hook_patches = self.hook_patches_backup
|
||||||
self.hook_patches_backup = {}
|
self.hook_patches_backup = None
|
||||||
|
|
||||||
def set_hook_mode(self, hook_mode: comfy.hooks.EnumHookMode):
|
def set_hook_mode(self, hook_mode: comfy.hooks.EnumHookMode):
|
||||||
self.hook_mode = hook_mode
|
self.hook_mode = hook_mode
|
||||||
|
|
||||||
def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup):
|
def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup, model_options: dict[str]):
|
||||||
curr_t = t[0]
|
curr_t = t[0]
|
||||||
reset_current_hooks = False
|
reset_current_hooks = False
|
||||||
|
transformer_options = model_options.get("transformer_options", {})
|
||||||
for hook in hook_group.hooks:
|
for hook in hook_group.hooks:
|
||||||
changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t)
|
changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t, transformer_options=transformer_options)
|
||||||
# if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref;
|
# if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref;
|
||||||
# this will cause the weights to be recalculated when sampling
|
# this will cause the weights to be recalculated when sampling
|
||||||
if changed:
|
if changed:
|
||||||
@@ -939,25 +956,26 @@ class ModelPatcher:
|
|||||||
if reset_current_hooks:
|
if reset_current_hooks:
|
||||||
self.patch_hooks(None)
|
self.patch_hooks(None)
|
||||||
|
|
||||||
def register_all_hook_patches(self, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]], target: comfy.hooks.EnumWeightTarget, model_options: dict=None):
|
def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None,
|
||||||
|
registered: comfy.hooks.HookGroup = None):
|
||||||
self.restore_hook_patches()
|
self.restore_hook_patches()
|
||||||
registered_hooks: list[comfy.hooks.Hook] = []
|
if registered is None:
|
||||||
# handle WrapperHooks, if model_options provided
|
registered = comfy.hooks.HookGroup()
|
||||||
if model_options is not None:
|
|
||||||
for hook in hooks_dict.get(comfy.hooks.EnumHookType.Wrappers, {}):
|
|
||||||
hook.add_hook_patches(self, model_options, target, registered_hooks)
|
|
||||||
# handle WeightHooks
|
# handle WeightHooks
|
||||||
weight_hooks_to_register: list[comfy.hooks.WeightHook] = []
|
weight_hooks_to_register: list[comfy.hooks.WeightHook] = []
|
||||||
for hook in hooks_dict.get(comfy.hooks.EnumHookType.Weight, {}):
|
for hook in hooks.get_type(comfy.hooks.EnumHookType.Weight):
|
||||||
if hook.hook_ref not in self.hook_patches:
|
if hook.hook_ref not in self.hook_patches:
|
||||||
weight_hooks_to_register.append(hook)
|
weight_hooks_to_register.append(hook)
|
||||||
|
else:
|
||||||
|
registered.add(hook)
|
||||||
if len(weight_hooks_to_register) > 0:
|
if len(weight_hooks_to_register) > 0:
|
||||||
# clone hook_patches to become backup so that any non-dynamic hooks will return to their original state
|
# clone hook_patches to become backup so that any non-dynamic hooks will return to their original state
|
||||||
self.hook_patches_backup = create_hook_patches_clone(self.hook_patches)
|
self.hook_patches_backup = create_hook_patches_clone(self.hook_patches)
|
||||||
for hook in weight_hooks_to_register:
|
for hook in weight_hooks_to_register:
|
||||||
hook.add_hook_patches(self, model_options, target, registered_hooks)
|
hook.add_hook_patches(self, model_options, target_dict, registered)
|
||||||
for callback in self.get_all_callbacks(CallbacksMP.ON_REGISTER_ALL_HOOK_PATCHES):
|
for callback in self.get_all_callbacks(CallbacksMP.ON_REGISTER_ALL_HOOK_PATCHES):
|
||||||
callback(self, hooks_dict, target)
|
callback(self, hooks, target_dict, model_options, registered)
|
||||||
|
return registered
|
||||||
|
|
||||||
def add_hook_patches(self, hook: comfy.hooks.WeightHook, patches, strength_patch=1.0, strength_model=1.0):
|
def add_hook_patches(self, hook: comfy.hooks.WeightHook, patches, strength_patch=1.0, strength_model=1.0):
|
||||||
with self.use_ejected():
|
with self.use_ejected():
|
||||||
@@ -975,7 +993,7 @@ class ModelPatcher:
|
|||||||
key = k[0]
|
key = k[0]
|
||||||
if len(k) > 2:
|
if len(k) > 2:
|
||||||
function = k[2]
|
function = k[2]
|
||||||
|
|
||||||
if key in model_sd:
|
if key in model_sd:
|
||||||
p.add(k)
|
p.add(k)
|
||||||
current_patches: list[tuple] = current_hook_patches.get(key, [])
|
current_patches: list[tuple] = current_hook_patches.get(key, [])
|
||||||
@@ -1008,11 +1026,11 @@ class ModelPatcher:
|
|||||||
def apply_hooks(self, hooks: comfy.hooks.HookGroup, transformer_options: dict=None, force_apply=False):
|
def apply_hooks(self, hooks: comfy.hooks.HookGroup, transformer_options: dict=None, force_apply=False):
|
||||||
# TODO: return transformer_options dict with any additions from hooks
|
# TODO: return transformer_options dict with any additions from hooks
|
||||||
if self.current_hooks == hooks and (not force_apply or (not self.is_clip and hooks is None)):
|
if self.current_hooks == hooks and (not force_apply or (not self.is_clip and hooks is None)):
|
||||||
return {}
|
return comfy.hooks.create_transformer_options_from_hooks(self, hooks, transformer_options)
|
||||||
self.patch_hooks(hooks=hooks)
|
self.patch_hooks(hooks=hooks)
|
||||||
for callback in self.get_all_callbacks(CallbacksMP.ON_APPLY_HOOKS):
|
for callback in self.get_all_callbacks(CallbacksMP.ON_APPLY_HOOKS):
|
||||||
callback(self, hooks)
|
callback(self, hooks)
|
||||||
return {}
|
return comfy.hooks.create_transformer_options_from_hooks(self, hooks, transformer_options)
|
||||||
|
|
||||||
def patch_hooks(self, hooks: comfy.hooks.HookGroup):
|
def patch_hooks(self, hooks: comfy.hooks.HookGroup):
|
||||||
with self.use_ejected():
|
with self.use_ejected():
|
||||||
@@ -1029,7 +1047,7 @@ class ModelPatcher:
|
|||||||
if cached_weights is not None:
|
if cached_weights is not None:
|
||||||
for key in cached_weights:
|
for key in cached_weights:
|
||||||
if key not in model_sd_keys:
|
if key not in model_sd_keys:
|
||||||
print(f"WARNING cached hook could not patch. key does not exist in model: {key}")
|
logging.warning(f"Cached hook could not patch. Key does not exist in model: {key}")
|
||||||
continue
|
continue
|
||||||
self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter)
|
self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter)
|
||||||
else:
|
else:
|
||||||
@@ -1039,7 +1057,7 @@ class ModelPatcher:
|
|||||||
original_weights = self.get_key_patches()
|
original_weights = self.get_key_patches()
|
||||||
for key in relevant_patches:
|
for key in relevant_patches:
|
||||||
if key not in model_sd_keys:
|
if key not in model_sd_keys:
|
||||||
print(f"WARNING cached hook would not patch. key does not exist in model: {key}")
|
logging.warning(f"Cached hook would not patch. Key does not exist in model: {key}")
|
||||||
continue
|
continue
|
||||||
self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights,
|
self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights,
|
||||||
memory_counter=memory_counter)
|
memory_counter=memory_counter)
|
||||||
@@ -1063,7 +1081,7 @@ class ModelPatcher:
|
|||||||
def patch_hook_weight_to_device(self, hooks: comfy.hooks.HookGroup, combined_patches: dict, key: str, original_weights: dict, memory_counter: MemoryCounter):
|
def patch_hook_weight_to_device(self, hooks: comfy.hooks.HookGroup, combined_patches: dict, key: str, original_weights: dict, memory_counter: MemoryCounter):
|
||||||
if key not in combined_patches:
|
if key not in combined_patches:
|
||||||
return
|
return
|
||||||
|
|
||||||
weight, set_func, convert_func = get_key_weight(self.model, key)
|
weight, set_func, convert_func = get_key_weight(self.model, key)
|
||||||
weight: torch.Tensor
|
weight: torch.Tensor
|
||||||
if key not in self.hook_backup:
|
if key not in self.hook_backup:
|
||||||
@@ -1098,7 +1116,7 @@ class ModelPatcher:
|
|||||||
del temp_weight
|
del temp_weight
|
||||||
del out_weight
|
del out_weight
|
||||||
del weight
|
del weight
|
||||||
|
|
||||||
def unpatch_hooks(self) -> None:
|
def unpatch_hooks(self) -> None:
|
||||||
with self.use_ejected():
|
with self.use_ejected():
|
||||||
if len(self.hook_backup) == 0:
|
if len(self.hook_backup) == 0:
|
||||||
@@ -1107,7 +1125,7 @@ class ModelPatcher:
|
|||||||
keys = list(self.hook_backup.keys())
|
keys = list(self.hook_backup.keys())
|
||||||
for k in keys:
|
for k in keys:
|
||||||
comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
|
comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
|
||||||
|
|
||||||
self.hook_backup.clear()
|
self.hook_backup.clear()
|
||||||
self.current_hooks = None
|
self.current_hooks = None
|
||||||
|
|
||||||
|
|||||||
@@ -243,7 +243,7 @@ class ModelSamplingDiscreteFlow(torch.nn.Module):
|
|||||||
return 1.0
|
return 1.0
|
||||||
if percent >= 1.0:
|
if percent >= 1.0:
|
||||||
return 0.0
|
return 0.0
|
||||||
return 1.0 - percent
|
return time_snr_shift(self.shift, 1.0 - percent)
|
||||||
|
|
||||||
class StableCascadeSampling(ModelSamplingDiscrete):
|
class StableCascadeSampling(ModelSamplingDiscrete):
|
||||||
def __init__(self, model_config=None):
|
def __init__(self, model_config=None):
|
||||||
@@ -336,4 +336,4 @@ class ModelSamplingFlux(torch.nn.Module):
|
|||||||
return 1.0
|
return 1.0
|
||||||
if percent >= 1.0:
|
if percent >= 1.0:
|
||||||
return 0.0
|
return 0.0
|
||||||
return 1.0 - percent
|
return flux_time_shift(self.shift, 1.0, 1.0 - percent)
|
||||||
|
|||||||
18
comfy/ops.py
18
comfy/ops.py
@@ -255,9 +255,10 @@ def fp8_linear(self, input):
|
|||||||
tensor_2d = True
|
tensor_2d = True
|
||||||
input = input.unsqueeze(1)
|
input = input.unsqueeze(1)
|
||||||
|
|
||||||
|
input_shape = input.shape
|
||||||
|
input_dtype = input.dtype
|
||||||
if len(input.shape) == 3:
|
if len(input.shape) == 3:
|
||||||
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype)
|
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype)
|
||||||
w = w.t()
|
w = w.t()
|
||||||
|
|
||||||
scale_weight = self.scale_weight
|
scale_weight = self.scale_weight
|
||||||
@@ -269,23 +270,24 @@ def fp8_linear(self, input):
|
|||||||
|
|
||||||
if scale_input is None:
|
if scale_input is None:
|
||||||
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
|
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
|
||||||
inn = input.reshape(-1, input.shape[2]).to(dtype)
|
input = torch.clamp(input, min=-448, max=448, out=input)
|
||||||
|
input = input.reshape(-1, input_shape[2]).to(dtype)
|
||||||
else:
|
else:
|
||||||
scale_input = scale_input.to(input.device)
|
scale_input = scale_input.to(input.device)
|
||||||
inn = (input * (1.0 / scale_input).to(input.dtype)).reshape(-1, input.shape[2]).to(dtype)
|
input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype)
|
||||||
|
|
||||||
if bias is not None:
|
if bias is not None:
|
||||||
o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
|
o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
|
||||||
else:
|
else:
|
||||||
o = torch._scaled_mm(inn, w, out_dtype=input.dtype, scale_a=scale_input, scale_b=scale_weight)
|
o = torch._scaled_mm(input, w, out_dtype=input_dtype, scale_a=scale_input, scale_b=scale_weight)
|
||||||
|
|
||||||
if isinstance(o, tuple):
|
if isinstance(o, tuple):
|
||||||
o = o[0]
|
o = o[0]
|
||||||
|
|
||||||
if tensor_2d:
|
if tensor_2d:
|
||||||
return o.reshape(input.shape[0], -1)
|
return o.reshape(input_shape[0], -1)
|
||||||
|
|
||||||
return o.reshape((-1, input.shape[1], self.weight.shape[0]))
|
return o.reshape((-1, input_shape[1], self.weight.shape[0]))
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@@ -96,12 +96,12 @@ class WrapperExecutor:
|
|||||||
self.wrappers = wrappers.copy()
|
self.wrappers = wrappers.copy()
|
||||||
self.idx = idx
|
self.idx = idx
|
||||||
self.is_last = idx == len(wrappers)
|
self.is_last = idx == len(wrappers)
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
"""Calls the next wrapper or original function, whichever is appropriate."""
|
"""Calls the next wrapper or original function, whichever is appropriate."""
|
||||||
new_executor = self._create_next_executor()
|
new_executor = self._create_next_executor()
|
||||||
return new_executor.execute(*args, **kwargs)
|
return new_executor.execute(*args, **kwargs)
|
||||||
|
|
||||||
def execute(self, *args, **kwargs):
|
def execute(self, *args, **kwargs):
|
||||||
"""Used to initiate executor internally - DO NOT use this if you received executor in wrapper."""
|
"""Used to initiate executor internally - DO NOT use this if you received executor in wrapper."""
|
||||||
args = list(args)
|
args = list(args)
|
||||||
@@ -113,7 +113,7 @@ class WrapperExecutor:
|
|||||||
def _create_next_executor(self) -> 'WrapperExecutor':
|
def _create_next_executor(self) -> 'WrapperExecutor':
|
||||||
new_idx = self.idx + 1
|
new_idx = self.idx + 1
|
||||||
if new_idx > len(self.wrappers):
|
if new_idx > len(self.wrappers):
|
||||||
raise Exception(f"Wrapper idx exceeded available wrappers; something went very wrong.")
|
raise Exception("Wrapper idx exceeded available wrappers; something went very wrong.")
|
||||||
if self.class_obj is None:
|
if self.class_obj is None:
|
||||||
return WrapperExecutor.new_executor(self.original, self.wrappers, new_idx)
|
return WrapperExecutor.new_executor(self.original, self.wrappers, new_idx)
|
||||||
return WrapperExecutor.new_class_executor(self.original, self.class_obj, self.wrappers, new_idx)
|
return WrapperExecutor.new_class_executor(self.original, self.class_obj, self.wrappers, new_idx)
|
||||||
@@ -121,7 +121,7 @@ class WrapperExecutor:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def new_executor(cls, original: Callable, wrappers: list[Callable], idx=0):
|
def new_executor(cls, original: Callable, wrappers: list[Callable], idx=0):
|
||||||
return cls(original, class_obj=None, wrappers=wrappers, idx=idx)
|
return cls(original, class_obj=None, wrappers=wrappers, idx=idx)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def new_class_executor(cls, original: Callable, class_obj: object, wrappers: list[Callable], idx=0):
|
def new_class_executor(cls, original: Callable, class_obj: object, wrappers: list[Callable], idx=0):
|
||||||
return cls(original, class_obj, wrappers, idx=idx)
|
return cls(original, class_obj, wrappers, idx=idx)
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ def prepare_noise(latent_image, seed, noise_inds=None):
|
|||||||
generator = torch.manual_seed(seed)
|
generator = torch.manual_seed(seed)
|
||||||
if noise_inds is None:
|
if noise_inds is None:
|
||||||
return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
|
return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
|
||||||
|
|
||||||
unique_inds, inverse = np.unique(noise_inds, return_inverse=True)
|
unique_inds, inverse = np.unique(noise_inds, return_inverse=True)
|
||||||
noises = []
|
noises = []
|
||||||
for i in range(unique_inds[-1]+1):
|
for i in range(unique_inds[-1]+1):
|
||||||
@@ -25,9 +25,11 @@ def prepare_noise(latent_image, seed, noise_inds=None):
|
|||||||
return noises
|
return noises
|
||||||
|
|
||||||
def fix_empty_latent_channels(model, latent_image):
|
def fix_empty_latent_channels(model, latent_image):
|
||||||
latent_channels = model.get_model_object("latent_format").latent_channels #Resize the empty latent image so it has the right number of channels
|
latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels
|
||||||
if latent_channels != latent_image.shape[1] and torch.count_nonzero(latent_image) == 0:
|
if latent_format.latent_channels != latent_image.shape[1] and torch.count_nonzero(latent_image) == 0:
|
||||||
latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_channels, dim=1)
|
latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)
|
||||||
|
if latent_format.latent_dimensions == 3 and latent_image.ndim == 4:
|
||||||
|
latent_image = latent_image.unsqueeze(2)
|
||||||
return latent_image
|
return latent_image
|
||||||
|
|
||||||
def prepare_sampling(model, noise_shape, positive, negative, noise_mask):
|
def prepare_sampling(model, noise_shape, positive, negative, noise_mask):
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
import uuid
|
import uuid
|
||||||
import torch
|
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.conds
|
import comfy.conds
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
@@ -25,15 +24,13 @@ def get_models_from_cond(cond, model_type):
|
|||||||
models += [c[model_type]]
|
models += [c[model_type]]
|
||||||
return models
|
return models
|
||||||
|
|
||||||
def get_hooks_from_cond(cond, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]]):
|
def get_hooks_from_cond(cond, full_hooks: comfy.hooks.HookGroup):
|
||||||
# get hooks from conds, and collect cnets so they can be checked for extra_hooks
|
# get hooks from conds, and collect cnets so they can be checked for extra_hooks
|
||||||
cnets: list[ControlBase] = []
|
cnets: list[ControlBase] = []
|
||||||
for c in cond:
|
for c in cond:
|
||||||
if 'hooks' in c:
|
if 'hooks' in c:
|
||||||
for hook in c['hooks'].hooks:
|
for hook in c['hooks'].hooks:
|
||||||
hook: comfy.hooks.Hook
|
full_hooks.add(hook)
|
||||||
with_type = hooks_dict.setdefault(hook.hook_type, {})
|
|
||||||
with_type[hook] = None
|
|
||||||
if 'control' in c:
|
if 'control' in c:
|
||||||
cnets.append(c['control'])
|
cnets.append(c['control'])
|
||||||
|
|
||||||
@@ -43,7 +40,7 @@ def get_hooks_from_cond(cond, hooks_dict: dict[comfy.hooks.EnumHookType, dict[co
|
|||||||
if cnet.previous_controlnet is None:
|
if cnet.previous_controlnet is None:
|
||||||
return _list
|
return _list
|
||||||
return get_extra_hooks_from_cnet(cnet.previous_controlnet, _list)
|
return get_extra_hooks_from_cnet(cnet.previous_controlnet, _list)
|
||||||
|
|
||||||
hooks_list = []
|
hooks_list = []
|
||||||
cnets = set(cnets)
|
cnets = set(cnets)
|
||||||
for base_cnet in cnets:
|
for base_cnet in cnets:
|
||||||
@@ -51,10 +48,9 @@ def get_hooks_from_cond(cond, hooks_dict: dict[comfy.hooks.EnumHookType, dict[co
|
|||||||
extra_hooks = comfy.hooks.HookGroup.combine_all_hooks(hooks_list)
|
extra_hooks = comfy.hooks.HookGroup.combine_all_hooks(hooks_list)
|
||||||
if extra_hooks is not None:
|
if extra_hooks is not None:
|
||||||
for hook in extra_hooks.hooks:
|
for hook in extra_hooks.hooks:
|
||||||
with_type = hooks_dict.setdefault(hook.hook_type, {})
|
full_hooks.add(hook)
|
||||||
with_type[hook] = None
|
|
||||||
|
|
||||||
return hooks_dict
|
return full_hooks
|
||||||
|
|
||||||
def convert_cond(cond):
|
def convert_cond(cond):
|
||||||
out = []
|
out = []
|
||||||
@@ -62,7 +58,6 @@ def convert_cond(cond):
|
|||||||
temp = c[1].copy()
|
temp = c[1].copy()
|
||||||
model_conds = temp.get("model_conds", {})
|
model_conds = temp.get("model_conds", {})
|
||||||
if c[0] is not None:
|
if c[0] is not None:
|
||||||
model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0]) #TODO: remove
|
|
||||||
temp["cross_attn"] = c[0]
|
temp["cross_attn"] = c[0]
|
||||||
temp["model_conds"] = model_conds
|
temp["model_conds"] = model_conds
|
||||||
temp["uuid"] = uuid.uuid4()
|
temp["uuid"] = uuid.uuid4()
|
||||||
@@ -74,13 +69,11 @@ def get_additional_models(conds, dtype):
|
|||||||
cnets: list[ControlBase] = []
|
cnets: list[ControlBase] = []
|
||||||
gligen = []
|
gligen = []
|
||||||
add_models = []
|
add_models = []
|
||||||
hooks: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]] = {}
|
|
||||||
|
|
||||||
for k in conds:
|
for k in conds:
|
||||||
cnets += get_models_from_cond(conds[k], "control")
|
cnets += get_models_from_cond(conds[k], "control")
|
||||||
gligen += get_models_from_cond(conds[k], "gligen")
|
gligen += get_models_from_cond(conds[k], "gligen")
|
||||||
add_models += get_models_from_cond(conds[k], "additional_models")
|
add_models += get_models_from_cond(conds[k], "additional_models")
|
||||||
get_hooks_from_cond(conds[k], hooks)
|
|
||||||
|
|
||||||
control_nets = set(cnets)
|
control_nets = set(cnets)
|
||||||
|
|
||||||
@@ -91,11 +84,20 @@ def get_additional_models(conds, dtype):
|
|||||||
inference_memory += m.inference_memory_requirements(dtype)
|
inference_memory += m.inference_memory_requirements(dtype)
|
||||||
|
|
||||||
gligen = [x[1] for x in gligen]
|
gligen = [x[1] for x in gligen]
|
||||||
hook_models = [x.model for x in hooks.get(comfy.hooks.EnumHookType.AddModels, {}).keys()]
|
models = control_models + gligen + add_models
|
||||||
models = control_models + gligen + add_models + hook_models
|
|
||||||
|
|
||||||
return models, inference_memory
|
return models, inference_memory
|
||||||
|
|
||||||
|
def get_additional_models_from_model_options(model_options: dict[str]=None):
|
||||||
|
"""loads additional models from registered AddModels hooks"""
|
||||||
|
models = []
|
||||||
|
if model_options is not None and "registered_hooks" in model_options:
|
||||||
|
registered: comfy.hooks.HookGroup = model_options["registered_hooks"]
|
||||||
|
for hook in registered.get_type(comfy.hooks.EnumHookType.AdditionalModels):
|
||||||
|
hook: comfy.hooks.AdditionalModelsHook
|
||||||
|
models.extend(hook.models)
|
||||||
|
return models
|
||||||
|
|
||||||
def cleanup_additional_models(models):
|
def cleanup_additional_models(models):
|
||||||
"""cleanup additional models that were loaded"""
|
"""cleanup additional models that were loaded"""
|
||||||
for m in models:
|
for m in models:
|
||||||
@@ -103,10 +105,10 @@ def cleanup_additional_models(models):
|
|||||||
m.cleanup()
|
m.cleanup()
|
||||||
|
|
||||||
|
|
||||||
def prepare_sampling(model: 'ModelPatcher', noise_shape, conds):
|
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
|
||||||
device = model.load_device
|
real_model: BaseModel = None
|
||||||
real_model: 'BaseModel' = None
|
|
||||||
models, inference_memory = get_additional_models(conds, model.model_dtype())
|
models, inference_memory = get_additional_models(conds, model.model_dtype())
|
||||||
|
models += get_additional_models_from_model_options(model_options)
|
||||||
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
|
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
|
||||||
memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory
|
memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory
|
||||||
minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory
|
minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory
|
||||||
@@ -125,12 +127,35 @@ def cleanup_models(conds, models):
|
|||||||
cleanup_additional_models(set(control_cleanup))
|
cleanup_additional_models(set(control_cleanup))
|
||||||
|
|
||||||
def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
|
def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
|
||||||
|
'''
|
||||||
|
Registers hooks from conds.
|
||||||
|
'''
|
||||||
# check for hooks in conds - if not registered, see if can be applied
|
# check for hooks in conds - if not registered, see if can be applied
|
||||||
hooks = {}
|
hooks = comfy.hooks.HookGroup()
|
||||||
for k in conds:
|
for k in conds:
|
||||||
get_hooks_from_cond(conds[k], hooks)
|
get_hooks_from_cond(conds[k], hooks)
|
||||||
# add wrappers and callbacks from ModelPatcher to transformer_options
|
# add wrappers and callbacks from ModelPatcher to transformer_options
|
||||||
model_options["transformer_options"]["wrappers"] = comfy.patcher_extension.copy_nested_dicts(model.wrappers)
|
model_options["transformer_options"]["wrappers"] = comfy.patcher_extension.copy_nested_dicts(model.wrappers)
|
||||||
model_options["transformer_options"]["callbacks"] = comfy.patcher_extension.copy_nested_dicts(model.callbacks)
|
model_options["transformer_options"]["callbacks"] = comfy.patcher_extension.copy_nested_dicts(model.callbacks)
|
||||||
# register hooks on model/model_options
|
# begin registering hooks
|
||||||
model.register_all_hook_patches(hooks, comfy.hooks.EnumWeightTarget.Model, model_options)
|
registered = comfy.hooks.HookGroup()
|
||||||
|
target_dict = comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Model)
|
||||||
|
# handle all TransformerOptionsHooks
|
||||||
|
for hook in hooks.get_type(comfy.hooks.EnumHookType.TransformerOptions):
|
||||||
|
hook: comfy.hooks.TransformerOptionsHook
|
||||||
|
hook.add_hook_patches(model, model_options, target_dict, registered)
|
||||||
|
# handle all AddModelsHooks
|
||||||
|
for hook in hooks.get_type(comfy.hooks.EnumHookType.AdditionalModels):
|
||||||
|
hook: comfy.hooks.AdditionalModelsHook
|
||||||
|
hook.add_hook_patches(model, model_options, target_dict, registered)
|
||||||
|
# handle all WeightHooks by registering on ModelPatcher
|
||||||
|
model.register_all_hook_patches(hooks, target_dict, model_options, registered)
|
||||||
|
# add registered_hooks onto model_options for further reference
|
||||||
|
if len(registered) > 0:
|
||||||
|
model_options["registered_hooks"] = registered
|
||||||
|
# merge original wrappers and callbacks with hooked wrappers and callbacks
|
||||||
|
to_load_options: dict[str] = model_options.setdefault("to_load_options", {})
|
||||||
|
for wc_name in ["wrappers", "callbacks"]:
|
||||||
|
comfy.patcher_extension.merge_nested_dicts(to_load_options.setdefault(wc_name, {}), model_options["transformer_options"][wc_name],
|
||||||
|
copy_dict1=False)
|
||||||
|
return to_load_options
|
||||||
|
|||||||
@@ -1,17 +1,17 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from .k_diffusion import sampling as k_diffusion_sampling
|
from .k_diffusion import sampling as k_diffusion_sampling
|
||||||
from .extra_samplers import uni_pc
|
from .extra_samplers import uni_pc
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, Callable, NamedTuple
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from comfy.model_patcher import ModelPatcher
|
from comfy.model_patcher import ModelPatcher
|
||||||
from comfy.model_base import BaseModel
|
from comfy.model_base import BaseModel
|
||||||
from comfy.controlnet import ControlBase
|
from comfy.controlnet import ControlBase
|
||||||
import torch
|
import torch
|
||||||
|
from functools import partial
|
||||||
import collections
|
import collections
|
||||||
from comfy import model_management
|
from comfy import model_management
|
||||||
import math
|
import math
|
||||||
import logging
|
import logging
|
||||||
import comfy.samplers
|
|
||||||
import comfy.sampler_helpers
|
import comfy.sampler_helpers
|
||||||
import comfy.model_patcher
|
import comfy.model_patcher
|
||||||
import comfy.patcher_extension
|
import comfy.patcher_extension
|
||||||
@@ -130,11 +130,6 @@ def can_concat_cond(c1, c2):
|
|||||||
return cond_equal_size(c1.conditioning, c2.conditioning)
|
return cond_equal_size(c1.conditioning, c2.conditioning)
|
||||||
|
|
||||||
def cond_cat(c_list):
|
def cond_cat(c_list):
|
||||||
c_crossattn = []
|
|
||||||
c_concat = []
|
|
||||||
c_adm = []
|
|
||||||
crossattn_max_len = 0
|
|
||||||
|
|
||||||
temp = {}
|
temp = {}
|
||||||
for x in c_list:
|
for x in c_list:
|
||||||
for k in x:
|
for k in x:
|
||||||
@@ -149,7 +144,7 @@ def cond_cat(c_list):
|
|||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]], default_conds: list[list[dict]], x_in, timestep):
|
def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]], default_conds: list[list[dict]], x_in, timestep, model_options):
|
||||||
# need to figure out remaining unmasked area for conds
|
# need to figure out remaining unmasked area for conds
|
||||||
default_mults = []
|
default_mults = []
|
||||||
for _ in default_conds:
|
for _ in default_conds:
|
||||||
@@ -182,13 +177,13 @@ def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.H
|
|||||||
cond = default_conds[i]
|
cond = default_conds[i]
|
||||||
for x in cond:
|
for x in cond:
|
||||||
# do get_area_and_mult to get all the expected values
|
# do get_area_and_mult to get all the expected values
|
||||||
p = comfy.samplers.get_area_and_mult(x, x_in, timestep)
|
p = get_area_and_mult(x, x_in, timestep)
|
||||||
if p is None:
|
if p is None:
|
||||||
continue
|
continue
|
||||||
# replace p's mult with calculated mult
|
# replace p's mult with calculated mult
|
||||||
p = p._replace(mult=mult)
|
p = p._replace(mult=mult)
|
||||||
if p.hooks is not None:
|
if p.hooks is not None:
|
||||||
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks)
|
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options)
|
||||||
hooked_to_run.setdefault(p.hooks, list())
|
hooked_to_run.setdefault(p.hooks, list())
|
||||||
hooked_to_run[p.hooks] += [(p, i)]
|
hooked_to_run[p.hooks] += [(p, i)]
|
||||||
|
|
||||||
@@ -219,17 +214,17 @@ def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Te
|
|||||||
default_c.append(x)
|
default_c.append(x)
|
||||||
has_default_conds = True
|
has_default_conds = True
|
||||||
continue
|
continue
|
||||||
p = comfy.samplers.get_area_and_mult(x, x_in, timestep)
|
p = get_area_and_mult(x, x_in, timestep)
|
||||||
if p is None:
|
if p is None:
|
||||||
continue
|
continue
|
||||||
if p.hooks is not None:
|
if p.hooks is not None:
|
||||||
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks)
|
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options)
|
||||||
hooked_to_run.setdefault(p.hooks, list())
|
hooked_to_run.setdefault(p.hooks, list())
|
||||||
hooked_to_run[p.hooks] += [(p, i)]
|
hooked_to_run[p.hooks] += [(p, i)]
|
||||||
default_conds.append(default_c)
|
default_conds.append(default_c)
|
||||||
|
|
||||||
if has_default_conds:
|
if has_default_conds:
|
||||||
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep)
|
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)
|
||||||
|
|
||||||
model.current_patcher.prepare_state(timestep)
|
model.current_patcher.prepare_state(timestep)
|
||||||
|
|
||||||
@@ -346,7 +341,7 @@ def cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_o
|
|||||||
cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale
|
cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale
|
||||||
|
|
||||||
for fn in model_options.get("sampler_post_cfg_function", []):
|
for fn in model_options.get("sampler_post_cfg_function", []):
|
||||||
args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred,
|
args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "cond_scale": cond_scale, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred,
|
||||||
"sigma": timestep, "model_options": model_options, "input": x}
|
"sigma": timestep, "model_options": model_options, "input": x}
|
||||||
cfg_result = fn(args)
|
cfg_result = fn(args)
|
||||||
|
|
||||||
@@ -380,7 +375,7 @@ class KSamplerX0Inpaint:
|
|||||||
if "denoise_mask_function" in model_options:
|
if "denoise_mask_function" in model_options:
|
||||||
denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask, extra_options={"model": self.inner_model, "sigmas": self.sigmas})
|
denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask, extra_options={"model": self.inner_model, "sigmas": self.sigmas})
|
||||||
latent_mask = 1. - denoise_mask
|
latent_mask = 1. - denoise_mask
|
||||||
x = x * denoise_mask + self.inner_model.inner_model.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1)), self.noise, self.latent_image) * latent_mask
|
x = x * denoise_mask + self.inner_model.inner_model.scale_latent_inpaint(x=x, sigma=sigma, noise=self.noise, latent_image=self.latent_image) * latent_mask
|
||||||
out = self.inner_model(x, sigma, model_options=model_options, seed=seed)
|
out = self.inner_model(x, sigma, model_options=model_options, seed=seed)
|
||||||
if denoise_mask is not None:
|
if denoise_mask is not None:
|
||||||
out = out * denoise_mask + self.latent_image * latent_mask
|
out = out * denoise_mask + self.latent_image * latent_mask
|
||||||
@@ -472,6 +467,13 @@ def linear_quadratic_schedule(model_sampling, steps, threshold_noise=0.025, line
|
|||||||
sigma_schedule = [1.0 - x for x in sigma_schedule]
|
sigma_schedule = [1.0 - x for x in sigma_schedule]
|
||||||
return torch.FloatTensor(sigma_schedule) * model_sampling.sigma_max.cpu()
|
return torch.FloatTensor(sigma_schedule) * model_sampling.sigma_max.cpu()
|
||||||
|
|
||||||
|
# Referenced from https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15608
|
||||||
|
def kl_optimal_scheduler(n: int, sigma_min: float, sigma_max: float) -> torch.Tensor:
|
||||||
|
adj_idxs = torch.arange(n, dtype=torch.float).div_(n - 1)
|
||||||
|
sigmas = adj_idxs.new_zeros(n + 1)
|
||||||
|
sigmas[:-1] = (adj_idxs * math.atan(sigma_min) + (1 - adj_idxs) * math.atan(sigma_max)).tan_()
|
||||||
|
return sigmas
|
||||||
|
|
||||||
def get_mask_aabb(masks):
|
def get_mask_aabb(masks):
|
||||||
if masks.numel() == 0:
|
if masks.numel() == 0:
|
||||||
return torch.zeros((0, 4), device=masks.device, dtype=torch.int)
|
return torch.zeros((0, 4), device=masks.device, dtype=torch.int)
|
||||||
@@ -608,8 +610,6 @@ def pre_run_control(model, conds):
|
|||||||
for t in range(len(conds)):
|
for t in range(len(conds)):
|
||||||
x = conds[t]
|
x = conds[t]
|
||||||
|
|
||||||
timestep_start = None
|
|
||||||
timestep_end = None
|
|
||||||
percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
|
percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
|
||||||
if 'control' in x:
|
if 'control' in x:
|
||||||
x['control'].pre_run(model, percent_to_timestep_function)
|
x['control'].pre_run(model, percent_to_timestep_function)
|
||||||
@@ -686,7 +686,7 @@ class Sampler:
|
|||||||
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
|
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
|
||||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
|
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||||
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
||||||
"ipndm", "ipndm_v", "deis"]
|
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "gradient_estimation"]
|
||||||
|
|
||||||
class KSAMPLER(Sampler):
|
class KSAMPLER(Sampler):
|
||||||
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
|
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
|
||||||
@@ -809,6 +809,33 @@ def preprocess_conds_hooks(conds: dict[str, list[dict[str]]]):
|
|||||||
for cond in conds_to_modify:
|
for cond in conds_to_modify:
|
||||||
cond['hooks'] = hooks
|
cond['hooks'] = hooks
|
||||||
|
|
||||||
|
def filter_registered_hooks_on_conds(conds: dict[str, list[dict[str]]], model_options: dict[str]):
|
||||||
|
'''Modify 'hooks' on conds so that only hooks that were registered remain. Properly accounts for
|
||||||
|
HookGroups that have the same reference.'''
|
||||||
|
registered: comfy.hooks.HookGroup = model_options.get('registered_hooks', None)
|
||||||
|
# if None were registered, make sure all hooks are cleaned from conds
|
||||||
|
if registered is None:
|
||||||
|
for k in conds:
|
||||||
|
for kk in conds[k]:
|
||||||
|
kk.pop('hooks', None)
|
||||||
|
return
|
||||||
|
# find conds that contain hooks to be replaced - group by common HookGroup refs
|
||||||
|
hook_replacement: dict[comfy.hooks.HookGroup, list[dict]] = {}
|
||||||
|
for k in conds:
|
||||||
|
for kk in conds[k]:
|
||||||
|
hooks: comfy.hooks.HookGroup = kk.get('hooks', None)
|
||||||
|
if hooks is not None:
|
||||||
|
if not hooks.is_subset_of(registered):
|
||||||
|
to_replace = hook_replacement.setdefault(hooks, [])
|
||||||
|
to_replace.append(kk)
|
||||||
|
# for each hook to replace, create a new proper HookGroup and assign to all common conds
|
||||||
|
for hooks, conds_to_modify in hook_replacement.items():
|
||||||
|
new_hooks = hooks.new_with_common_hooks(registered)
|
||||||
|
if len(new_hooks) == 0:
|
||||||
|
new_hooks = None
|
||||||
|
for kk in conds_to_modify:
|
||||||
|
kk['hooks'] = new_hooks
|
||||||
|
|
||||||
|
|
||||||
def get_total_hook_groups_in_conds(conds: dict[str, list[dict[str]]]):
|
def get_total_hook_groups_in_conds(conds: dict[str, list[dict[str]]]):
|
||||||
hooks_set = set()
|
hooks_set = set()
|
||||||
@@ -818,9 +845,58 @@ def get_total_hook_groups_in_conds(conds: dict[str, list[dict[str]]]):
|
|||||||
return len(hooks_set)
|
return len(hooks_set)
|
||||||
|
|
||||||
|
|
||||||
|
def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
|
||||||
|
'''
|
||||||
|
If any patches from hooks, wrappers, or callbacks have .to to be called, call it.
|
||||||
|
'''
|
||||||
|
if model_options is None:
|
||||||
|
return
|
||||||
|
to_load_options = model_options.get("to_load_options", None)
|
||||||
|
if to_load_options is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
casts = []
|
||||||
|
if device is not None:
|
||||||
|
casts.append(device)
|
||||||
|
if dtype is not None:
|
||||||
|
casts.append(dtype)
|
||||||
|
# if nothing to apply, do nothing
|
||||||
|
if len(casts) == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
# try to call .to on patches
|
||||||
|
if "patches" in to_load_options:
|
||||||
|
patches = to_load_options["patches"]
|
||||||
|
for name in patches:
|
||||||
|
patch_list = patches[name]
|
||||||
|
for i in range(len(patch_list)):
|
||||||
|
if hasattr(patch_list[i], "to"):
|
||||||
|
for cast in casts:
|
||||||
|
patch_list[i] = patch_list[i].to(cast)
|
||||||
|
if "patches_replace" in to_load_options:
|
||||||
|
patches = to_load_options["patches_replace"]
|
||||||
|
for name in patches:
|
||||||
|
patch_list = patches[name]
|
||||||
|
for k in patch_list:
|
||||||
|
if hasattr(patch_list[k], "to"):
|
||||||
|
for cast in casts:
|
||||||
|
patch_list[k] = patch_list[k].to(cast)
|
||||||
|
# try to call .to on any wrappers/callbacks
|
||||||
|
wrappers_and_callbacks = ["wrappers", "callbacks"]
|
||||||
|
for wc_name in wrappers_and_callbacks:
|
||||||
|
if wc_name in to_load_options:
|
||||||
|
wc: dict[str, list] = to_load_options[wc_name]
|
||||||
|
for wc_dict in wc.values():
|
||||||
|
for wc_list in wc_dict.values():
|
||||||
|
for i in range(len(wc_list)):
|
||||||
|
if hasattr(wc_list[i], "to"):
|
||||||
|
for cast in casts:
|
||||||
|
wc_list[i] = wc_list[i].to(cast)
|
||||||
|
|
||||||
|
|
||||||
class CFGGuider:
|
class CFGGuider:
|
||||||
def __init__(self, model_patcher):
|
def __init__(self, model_patcher: ModelPatcher):
|
||||||
self.model_patcher: 'ModelPatcher' = model_patcher
|
self.model_patcher = model_patcher
|
||||||
self.model_options = model_patcher.model_options
|
self.model_options = model_patcher.model_options
|
||||||
self.original_conds = {}
|
self.original_conds = {}
|
||||||
self.cfg = 1.0
|
self.cfg = 1.0
|
||||||
@@ -847,7 +923,9 @@ class CFGGuider:
|
|||||||
|
|
||||||
self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed)
|
self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed)
|
||||||
|
|
||||||
extra_args = {"model_options": comfy.model_patcher.create_model_options_clone(self.model_options), "seed": seed}
|
extra_model_options = comfy.model_patcher.create_model_options_clone(self.model_options)
|
||||||
|
extra_model_options.setdefault("transformer_options", {})["sample_sigmas"] = sigmas
|
||||||
|
extra_args = {"model_options": extra_model_options, "seed": seed}
|
||||||
|
|
||||||
executor = comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
executor = comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||||
sampler.sample,
|
sampler.sample,
|
||||||
@@ -858,7 +936,7 @@ class CFGGuider:
|
|||||||
return self.inner_model.process_latent_out(samples.to(torch.float32))
|
return self.inner_model.process_latent_out(samples.to(torch.float32))
|
||||||
|
|
||||||
def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
|
def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||||
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds)
|
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
|
||||||
device = self.model_patcher.load_device
|
device = self.model_patcher.load_device
|
||||||
|
|
||||||
if denoise_mask is not None:
|
if denoise_mask is not None:
|
||||||
@@ -867,6 +945,7 @@ class CFGGuider:
|
|||||||
noise = noise.to(device)
|
noise = noise.to(device)
|
||||||
latent_image = latent_image.to(device)
|
latent_image = latent_image.to(device)
|
||||||
sigmas = sigmas.to(device)
|
sigmas = sigmas.to(device)
|
||||||
|
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.model_patcher.pre_run()
|
self.model_patcher.pre_run()
|
||||||
@@ -896,6 +975,7 @@ class CFGGuider:
|
|||||||
if get_total_hook_groups_in_conds(self.conds) <= 1:
|
if get_total_hook_groups_in_conds(self.conds) <= 1:
|
||||||
self.model_patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
|
self.model_patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
|
||||||
comfy.sampler_helpers.prepare_model_patcher(self.model_patcher, self.conds, self.model_options)
|
comfy.sampler_helpers.prepare_model_patcher(self.model_patcher, self.conds, self.model_options)
|
||||||
|
filter_registered_hooks_on_conds(self.conds, self.model_options)
|
||||||
executor = comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
executor = comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||||
self.outer_sample,
|
self.outer_sample,
|
||||||
self,
|
self,
|
||||||
@@ -903,6 +983,7 @@ class CFGGuider:
|
|||||||
)
|
)
|
||||||
output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
||||||
finally:
|
finally:
|
||||||
|
cast_to_load_options(self.model_options, device=self.model_patcher.offload_device)
|
||||||
self.model_options = orig_model_options
|
self.model_options = orig_model_options
|
||||||
self.model_patcher.hook_mode = orig_hook_mode
|
self.model_patcher.hook_mode = orig_hook_mode
|
||||||
self.model_patcher.restore_hook_patches()
|
self.model_patcher.restore_hook_patches()
|
||||||
@@ -918,29 +999,37 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
|
|||||||
return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
||||||
|
|
||||||
|
|
||||||
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "beta", "linear_quadratic"]
|
|
||||||
SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]
|
SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]
|
||||||
|
|
||||||
def calculate_sigmas(model_sampling, scheduler_name, steps):
|
class SchedulerHandler(NamedTuple):
|
||||||
if scheduler_name == "karras":
|
handler: Callable[..., torch.Tensor]
|
||||||
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))
|
# Boolean indicates whether to call the handler like:
|
||||||
elif scheduler_name == "exponential":
|
# scheduler_function(model_sampling, steps) or
|
||||||
sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))
|
# scheduler_function(n, sigma_min: float, sigma_max: float)
|
||||||
elif scheduler_name == "normal":
|
use_ms: bool = True
|
||||||
sigmas = normal_scheduler(model_sampling, steps)
|
|
||||||
elif scheduler_name == "simple":
|
SCHEDULER_HANDLERS = {
|
||||||
sigmas = simple_scheduler(model_sampling, steps)
|
"normal": SchedulerHandler(normal_scheduler),
|
||||||
elif scheduler_name == "ddim_uniform":
|
"karras": SchedulerHandler(k_diffusion_sampling.get_sigmas_karras, use_ms=False),
|
||||||
sigmas = ddim_scheduler(model_sampling, steps)
|
"exponential": SchedulerHandler(k_diffusion_sampling.get_sigmas_exponential, use_ms=False),
|
||||||
elif scheduler_name == "sgm_uniform":
|
"sgm_uniform": SchedulerHandler(partial(normal_scheduler, sgm=True)),
|
||||||
sigmas = normal_scheduler(model_sampling, steps, sgm=True)
|
"simple": SchedulerHandler(simple_scheduler),
|
||||||
elif scheduler_name == "beta":
|
"ddim_uniform": SchedulerHandler(ddim_scheduler),
|
||||||
sigmas = beta_scheduler(model_sampling, steps)
|
"beta": SchedulerHandler(beta_scheduler),
|
||||||
elif scheduler_name == "linear_quadratic":
|
"linear_quadratic": SchedulerHandler(linear_quadratic_schedule),
|
||||||
sigmas = linear_quadratic_schedule(model_sampling, steps)
|
"kl_optimal": SchedulerHandler(kl_optimal_scheduler, use_ms=False),
|
||||||
else:
|
}
|
||||||
logging.error("error invalid scheduler {}".format(scheduler_name))
|
SCHEDULER_NAMES = list(SCHEDULER_HANDLERS)
|
||||||
return sigmas
|
|
||||||
|
def calculate_sigmas(model_sampling: object, scheduler_name: str, steps: int) -> torch.Tensor:
|
||||||
|
handler = SCHEDULER_HANDLERS.get(scheduler_name)
|
||||||
|
if handler is None:
|
||||||
|
err = f"error invalid scheduler {scheduler_name}"
|
||||||
|
logging.error(err)
|
||||||
|
raise ValueError(err)
|
||||||
|
if handler.use_ms:
|
||||||
|
return handler.handler(model_sampling, steps)
|
||||||
|
return handler.handler(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))
|
||||||
|
|
||||||
def sampler_object(name):
|
def sampler_object(name):
|
||||||
if name == "uni_pc":
|
if name == "uni_pc":
|
||||||
|
|||||||
183
comfy/sd.py
183
comfy/sd.py
@@ -11,7 +11,9 @@ from .ldm.cascade.stage_c_coder import StageC_coder
|
|||||||
from .ldm.audio.autoencoder import AudioOobleckVAE
|
from .ldm.audio.autoencoder import AudioOobleckVAE
|
||||||
import comfy.ldm.genmo.vae.model
|
import comfy.ldm.genmo.vae.model
|
||||||
import comfy.ldm.lightricks.vae.causal_video_autoencoder
|
import comfy.ldm.lightricks.vae.causal_video_autoencoder
|
||||||
|
import comfy.ldm.cosmos.vae
|
||||||
import yaml
|
import yaml
|
||||||
|
import math
|
||||||
|
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
|
|
||||||
@@ -26,11 +28,14 @@ import comfy.text_encoders.sd2_clip
|
|||||||
import comfy.text_encoders.sd3_clip
|
import comfy.text_encoders.sd3_clip
|
||||||
import comfy.text_encoders.sa_t5
|
import comfy.text_encoders.sa_t5
|
||||||
import comfy.text_encoders.aura_t5
|
import comfy.text_encoders.aura_t5
|
||||||
|
import comfy.text_encoders.pixart_t5
|
||||||
import comfy.text_encoders.hydit
|
import comfy.text_encoders.hydit
|
||||||
import comfy.text_encoders.flux
|
import comfy.text_encoders.flux
|
||||||
import comfy.text_encoders.long_clipl
|
import comfy.text_encoders.long_clipl
|
||||||
import comfy.text_encoders.genmo
|
import comfy.text_encoders.genmo
|
||||||
import comfy.text_encoders.lt
|
import comfy.text_encoders.lt
|
||||||
|
import comfy.text_encoders.hunyuan_video
|
||||||
|
import comfy.text_encoders.cosmos
|
||||||
|
|
||||||
import comfy.model_patcher
|
import comfy.model_patcher
|
||||||
import comfy.lora
|
import comfy.lora
|
||||||
@@ -108,7 +113,7 @@ class CLIP:
|
|||||||
model_management.load_models_gpu([self.patcher], force_full_load=True)
|
model_management.load_models_gpu([self.patcher], force_full_load=True)
|
||||||
self.layer_idx = None
|
self.layer_idx = None
|
||||||
self.use_clip_schedule = False
|
self.use_clip_schedule = False
|
||||||
logging.debug("CLIP model load device: {}, offload device: {}, current: {}".format(load_device, offload_device, params['device']))
|
logging.info("CLIP/text encoder model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype))
|
||||||
|
|
||||||
def clone(self):
|
def clone(self):
|
||||||
n = CLIP(no_init=True)
|
n = CLIP(no_init=True)
|
||||||
@@ -256,6 +261,9 @@ class VAE:
|
|||||||
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
|
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
self.working_dtypes = [torch.bfloat16, torch.float32]
|
self.working_dtypes = [torch.bfloat16, torch.float32]
|
||||||
|
|
||||||
|
self.downscale_index_formula = None
|
||||||
|
self.upscale_index_formula = None
|
||||||
|
|
||||||
if config is None:
|
if config is None:
|
||||||
if "decoder.mid.block_1.mix_factor" in sd:
|
if "decoder.mid.block_1.mix_factor" in sd:
|
||||||
encoder_config = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
|
encoder_config = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
|
||||||
@@ -306,8 +314,8 @@ class VAE:
|
|||||||
self.upscale_ratio = 4
|
self.upscale_ratio = 4
|
||||||
|
|
||||||
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
|
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
|
||||||
if 'quant_conv.weight' in sd:
|
if 'post_quant_conv.weight' in sd:
|
||||||
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
|
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
|
||||||
else:
|
else:
|
||||||
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
|
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
|
||||||
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
|
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
|
||||||
@@ -335,14 +343,53 @@ class VAE:
|
|||||||
self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
|
self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
|
||||||
self.memory_used_encode = lambda shape, dtype: (1.5 * max(shape[2], 7) * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
|
self.memory_used_encode = lambda shape, dtype: (1.5 * max(shape[2], 7) * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
|
||||||
self.upscale_ratio = (lambda a: max(0, a * 6 - 5), 8, 8)
|
self.upscale_ratio = (lambda a: max(0, a * 6 - 5), 8, 8)
|
||||||
|
self.upscale_index_formula = (6, 8, 8)
|
||||||
|
self.downscale_ratio = (lambda a: max(0, math.floor((a + 5) / 6)), 8, 8)
|
||||||
|
self.downscale_index_formula = (6, 8, 8)
|
||||||
self.working_dtypes = [torch.float16, torch.float32]
|
self.working_dtypes = [torch.float16, torch.float32]
|
||||||
elif "decoder.up_blocks.0.res_blocks.0.conv1.conv.weight" in sd: #lightricks ltxv
|
elif "decoder.up_blocks.0.res_blocks.0.conv1.conv.weight" in sd: #lightricks ltxv
|
||||||
self.first_stage_model = comfy.ldm.lightricks.vae.causal_video_autoencoder.VideoVAE()
|
tensor_conv1 = sd["decoder.up_blocks.0.res_blocks.0.conv1.conv.weight"]
|
||||||
|
version = 0
|
||||||
|
if tensor_conv1.shape[0] == 512:
|
||||||
|
version = 0
|
||||||
|
elif tensor_conv1.shape[0] == 1024:
|
||||||
|
version = 1
|
||||||
|
self.first_stage_model = comfy.ldm.lightricks.vae.causal_video_autoencoder.VideoVAE(version=version)
|
||||||
self.latent_channels = 128
|
self.latent_channels = 128
|
||||||
self.latent_dim = 3
|
self.latent_dim = 3
|
||||||
self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
|
self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
|
||||||
self.memory_used_encode = lambda shape, dtype: (70 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
|
self.memory_used_encode = lambda shape, dtype: (70 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
|
||||||
self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 32, 32)
|
self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 32, 32)
|
||||||
|
self.upscale_index_formula = (8, 32, 32)
|
||||||
|
self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32)
|
||||||
|
self.downscale_index_formula = (8, 32, 32)
|
||||||
|
self.working_dtypes = [torch.bfloat16, torch.float32]
|
||||||
|
elif "decoder.conv_in.conv.weight" in sd:
|
||||||
|
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
|
||||||
|
ddconfig["conv3d"] = True
|
||||||
|
ddconfig["time_compress"] = 4
|
||||||
|
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
|
||||||
|
self.upscale_index_formula = (4, 8, 8)
|
||||||
|
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
|
||||||
|
self.downscale_index_formula = (4, 8, 8)
|
||||||
|
self.latent_dim = 3
|
||||||
|
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
|
||||||
|
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
|
||||||
|
self.memory_used_decode = lambda shape, dtype: (1500 * shape[2] * shape[3] * shape[4] * (4 * 8 * 8)) * model_management.dtype_size(dtype)
|
||||||
|
self.memory_used_encode = lambda shape, dtype: (900 * max(shape[2], 2) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
|
||||||
|
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||||
|
elif "decoder.unpatcher3d.wavelets" in sd:
|
||||||
|
self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 8, 8)
|
||||||
|
self.upscale_index_formula = (8, 8, 8)
|
||||||
|
self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 8, 8)
|
||||||
|
self.downscale_index_formula = (8, 8, 8)
|
||||||
|
self.latent_dim = 3
|
||||||
|
self.latent_channels = 16
|
||||||
|
ddconfig = {'z_channels': 16, 'latent_channels': self.latent_channels, 'z_factor': 1, 'resolution': 1024, 'in_channels': 3, 'out_channels': 3, 'channels': 128, 'channels_mult': [2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [32], 'dropout': 0.0, 'patch_size': 4, 'num_groups': 1, 'temporal_compression': 8, 'spacial_compression': 8}
|
||||||
|
self.first_stage_model = comfy.ldm.cosmos.vae.CausalContinuousVideoTokenizer(**ddconfig)
|
||||||
|
#TODO: these values are a bit off because this is not a standard VAE
|
||||||
|
self.memory_used_decode = lambda shape, dtype: (50 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
|
||||||
|
self.memory_used_encode = lambda shape, dtype: (50 * (round((shape[2] + 7) / 8) * 8) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
|
||||||
self.working_dtypes = [torch.bfloat16, torch.float32]
|
self.working_dtypes = [torch.bfloat16, torch.float32]
|
||||||
else:
|
else:
|
||||||
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
|
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
|
||||||
@@ -370,13 +417,15 @@ class VAE:
|
|||||||
self.output_device = model_management.intermediate_device()
|
self.output_device = model_management.intermediate_device()
|
||||||
|
|
||||||
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
|
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
|
||||||
logging.debug("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
|
logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
|
||||||
|
|
||||||
def vae_encode_crop_pixels(self, pixels):
|
def vae_encode_crop_pixels(self, pixels):
|
||||||
|
downscale_ratio = self.spacial_compression_encode()
|
||||||
|
|
||||||
dims = pixels.shape[1:-1]
|
dims = pixels.shape[1:-1]
|
||||||
for d in range(len(dims)):
|
for d in range(len(dims)):
|
||||||
x = (dims[d] // self.downscale_ratio) * self.downscale_ratio
|
x = (dims[d] // downscale_ratio) * downscale_ratio
|
||||||
x_offset = (dims[d] % self.downscale_ratio) // 2
|
x_offset = (dims[d] % downscale_ratio) // 2
|
||||||
if x != dims[d]:
|
if x != dims[d]:
|
||||||
pixels = pixels.narrow(d + 1, x_offset, x)
|
pixels = pixels.narrow(d + 1, x_offset, x)
|
||||||
return pixels
|
return pixels
|
||||||
@@ -397,11 +446,11 @@ class VAE:
|
|||||||
|
|
||||||
def decode_tiled_1d(self, samples, tile_x=128, overlap=32):
|
def decode_tiled_1d(self, samples, tile_x=128, overlap=32):
|
||||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
|
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
|
||||||
return comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device)
|
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
|
||||||
|
|
||||||
def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
|
def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
|
||||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
|
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
|
||||||
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
|
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device))
|
||||||
|
|
||||||
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
|
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
|
||||||
steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
|
steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
|
||||||
@@ -420,6 +469,10 @@ class VAE:
|
|||||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
|
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
|
||||||
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device)
|
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device)
|
||||||
|
|
||||||
|
def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)):
|
||||||
|
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
|
||||||
|
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
|
||||||
|
|
||||||
def decode(self, samples_in):
|
def decode(self, samples_in):
|
||||||
pixel_samples = None
|
pixel_samples = None
|
||||||
try:
|
try:
|
||||||
@@ -435,7 +488,7 @@ class VAE:
|
|||||||
if pixel_samples is None:
|
if pixel_samples is None:
|
||||||
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
|
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
|
||||||
pixel_samples[x:x+batch_number] = out
|
pixel_samples[x:x+batch_number] = out
|
||||||
except model_management.OOM_EXCEPTION as e:
|
except model_management.OOM_EXCEPTION:
|
||||||
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
||||||
dims = samples_in.ndim - 2
|
dims = samples_in.ndim - 2
|
||||||
if dims == 1:
|
if dims == 1:
|
||||||
@@ -450,7 +503,7 @@ class VAE:
|
|||||||
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
|
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
|
||||||
return pixel_samples
|
return pixel_samples
|
||||||
|
|
||||||
def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None):
|
def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
|
||||||
memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
|
memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
|
||||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
|
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
|
||||||
dims = samples.ndim - 2
|
dims = samples.ndim - 2
|
||||||
@@ -468,13 +521,20 @@ class VAE:
|
|||||||
elif dims == 2:
|
elif dims == 2:
|
||||||
output = self.decode_tiled_(samples, **args)
|
output = self.decode_tiled_(samples, **args)
|
||||||
elif dims == 3:
|
elif dims == 3:
|
||||||
|
if overlap_t is None:
|
||||||
|
args["overlap"] = (1, overlap, overlap)
|
||||||
|
else:
|
||||||
|
args["overlap"] = (max(1, overlap_t), overlap, overlap)
|
||||||
|
if tile_t is not None:
|
||||||
|
args["tile_t"] = max(2, tile_t)
|
||||||
|
|
||||||
output = self.decode_tiled_3d(samples, **args)
|
output = self.decode_tiled_3d(samples, **args)
|
||||||
return output.movedim(1, -1)
|
return output.movedim(1, -1)
|
||||||
|
|
||||||
def encode(self, pixel_samples):
|
def encode(self, pixel_samples):
|
||||||
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
|
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
|
||||||
pixel_samples = pixel_samples.movedim(-1, 1)
|
pixel_samples = pixel_samples.movedim(-1, 1)
|
||||||
if self.latent_dim == 3:
|
if self.latent_dim == 3 and pixel_samples.ndim < 5:
|
||||||
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
|
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
|
||||||
try:
|
try:
|
||||||
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
|
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
|
||||||
@@ -490,20 +550,58 @@ class VAE:
|
|||||||
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
|
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
|
||||||
samples[x:x + batch_number] = out
|
samples[x:x + batch_number] = out
|
||||||
|
|
||||||
except model_management.OOM_EXCEPTION as e:
|
except model_management.OOM_EXCEPTION:
|
||||||
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
|
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
|
||||||
if len(pixel_samples.shape) == 3:
|
if self.latent_dim == 3:
|
||||||
|
tile = 256
|
||||||
|
overlap = tile // 4
|
||||||
|
samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
||||||
|
elif self.latent_dim == 1:
|
||||||
samples = self.encode_tiled_1d(pixel_samples)
|
samples = self.encode_tiled_1d(pixel_samples)
|
||||||
else:
|
else:
|
||||||
samples = self.encode_tiled_(pixel_samples)
|
samples = self.encode_tiled_(pixel_samples)
|
||||||
|
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
|
def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
|
||||||
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
|
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
|
||||||
model_management.load_model_gpu(self.patcher)
|
dims = self.latent_dim
|
||||||
pixel_samples = pixel_samples.movedim(-1,1)
|
pixel_samples = pixel_samples.movedim(-1, 1)
|
||||||
samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap)
|
if dims == 3:
|
||||||
|
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
|
||||||
|
|
||||||
|
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) # TODO: calculate mem required for tile
|
||||||
|
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
|
||||||
|
|
||||||
|
args = {}
|
||||||
|
if tile_x is not None:
|
||||||
|
args["tile_x"] = tile_x
|
||||||
|
if tile_y is not None:
|
||||||
|
args["tile_y"] = tile_y
|
||||||
|
if overlap is not None:
|
||||||
|
args["overlap"] = overlap
|
||||||
|
|
||||||
|
if dims == 1:
|
||||||
|
args.pop("tile_y")
|
||||||
|
samples = self.encode_tiled_1d(pixel_samples, **args)
|
||||||
|
elif dims == 2:
|
||||||
|
samples = self.encode_tiled_(pixel_samples, **args)
|
||||||
|
elif dims == 3:
|
||||||
|
if tile_t is not None:
|
||||||
|
tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
|
||||||
|
else:
|
||||||
|
tile_t_latent = 9999
|
||||||
|
args["tile_t"] = self.upscale_ratio[0](tile_t_latent)
|
||||||
|
|
||||||
|
if overlap_t is None:
|
||||||
|
args["overlap"] = (1, overlap, overlap)
|
||||||
|
else:
|
||||||
|
args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap)
|
||||||
|
maximum = pixel_samples.shape[2]
|
||||||
|
maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum))
|
||||||
|
|
||||||
|
samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)
|
||||||
|
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
def get_sd(self):
|
def get_sd(self):
|
||||||
@@ -515,6 +613,18 @@ class VAE:
|
|||||||
except:
|
except:
|
||||||
return self.upscale_ratio
|
return self.upscale_ratio
|
||||||
|
|
||||||
|
def spacial_compression_encode(self):
|
||||||
|
try:
|
||||||
|
return self.downscale_ratio[-1]
|
||||||
|
except:
|
||||||
|
return self.downscale_ratio
|
||||||
|
|
||||||
|
def temporal_compression_decode(self):
|
||||||
|
try:
|
||||||
|
return round(self.upscale_ratio[0](8192) / 8192)
|
||||||
|
except:
|
||||||
|
return None
|
||||||
|
|
||||||
class StyleModel:
|
class StyleModel:
|
||||||
def __init__(self, model, device="cpu"):
|
def __init__(self, model, device="cpu"):
|
||||||
self.model = model
|
self.model = model
|
||||||
@@ -544,6 +654,10 @@ class CLIPType(Enum):
|
|||||||
FLUX = 6
|
FLUX = 6
|
||||||
MOCHI = 7
|
MOCHI = 7
|
||||||
LTXV = 8
|
LTXV = 8
|
||||||
|
HUNYUAN_VIDEO = 9
|
||||||
|
PIXART = 10
|
||||||
|
COSMOS = 11
|
||||||
|
|
||||||
|
|
||||||
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
||||||
clip_data = []
|
clip_data = []
|
||||||
@@ -559,6 +673,8 @@ class TEModel(Enum):
|
|||||||
T5_XXL = 4
|
T5_XXL = 4
|
||||||
T5_XL = 5
|
T5_XL = 5
|
||||||
T5_BASE = 6
|
T5_BASE = 6
|
||||||
|
LLAMA3_8 = 7
|
||||||
|
T5_XXL_OLD = 8
|
||||||
|
|
||||||
def detect_te_model(sd):
|
def detect_te_model(sd):
|
||||||
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
|
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
|
||||||
@@ -573,20 +689,33 @@ def detect_te_model(sd):
|
|||||||
return TEModel.T5_XXL
|
return TEModel.T5_XXL
|
||||||
elif weight.shape[-1] == 2048:
|
elif weight.shape[-1] == 2048:
|
||||||
return TEModel.T5_XL
|
return TEModel.T5_XL
|
||||||
|
if 'encoder.block.23.layer.1.DenseReluDense.wi.weight' in sd:
|
||||||
|
return TEModel.T5_XXL_OLD
|
||||||
if "encoder.block.0.layer.0.SelfAttention.k.weight" in sd:
|
if "encoder.block.0.layer.0.SelfAttention.k.weight" in sd:
|
||||||
return TEModel.T5_BASE
|
return TEModel.T5_BASE
|
||||||
|
if "model.layers.0.post_attention_layernorm.weight" in sd:
|
||||||
|
return TEModel.LLAMA3_8
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def t5xxl_detect(clip_data):
|
def t5xxl_detect(clip_data):
|
||||||
weight_name = "encoder.block.23.layer.1.DenseReluDense.wi_1.weight"
|
weight_name = "encoder.block.23.layer.1.DenseReluDense.wi_1.weight"
|
||||||
|
weight_name_old = "encoder.block.23.layer.1.DenseReluDense.wi.weight"
|
||||||
|
|
||||||
for sd in clip_data:
|
for sd in clip_data:
|
||||||
if weight_name in sd:
|
if weight_name in sd or weight_name_old in sd:
|
||||||
return comfy.text_encoders.sd3_clip.t5_xxl_detect(sd)
|
return comfy.text_encoders.sd3_clip.t5_xxl_detect(sd)
|
||||||
|
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
def llama_detect(clip_data):
|
||||||
|
weight_name = "model.layers.0.self_attn.k_proj.weight"
|
||||||
|
|
||||||
|
for sd in clip_data:
|
||||||
|
if weight_name in sd:
|
||||||
|
return comfy.text_encoders.hunyuan_video.llama_detect(sd)
|
||||||
|
|
||||||
|
return {}
|
||||||
|
|
||||||
def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
||||||
clip_data = state_dicts
|
clip_data = state_dicts
|
||||||
@@ -625,9 +754,15 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
|||||||
elif clip_type == CLIPType.LTXV:
|
elif clip_type == CLIPType.LTXV:
|
||||||
clip_target.clip = comfy.text_encoders.lt.ltxv_te(**t5xxl_detect(clip_data))
|
clip_target.clip = comfy.text_encoders.lt.ltxv_te(**t5xxl_detect(clip_data))
|
||||||
clip_target.tokenizer = comfy.text_encoders.lt.LTXVT5Tokenizer
|
clip_target.tokenizer = comfy.text_encoders.lt.LTXVT5Tokenizer
|
||||||
|
elif clip_type == CLIPType.PIXART:
|
||||||
|
clip_target.clip = comfy.text_encoders.pixart_t5.pixart_te(**t5xxl_detect(clip_data))
|
||||||
|
clip_target.tokenizer = comfy.text_encoders.pixart_t5.PixArtTokenizer
|
||||||
else: #CLIPType.MOCHI
|
else: #CLIPType.MOCHI
|
||||||
clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
|
clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
|
||||||
clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer
|
clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer
|
||||||
|
elif te_model == TEModel.T5_XXL_OLD:
|
||||||
|
clip_target.clip = comfy.text_encoders.cosmos.te(**t5xxl_detect(clip_data))
|
||||||
|
clip_target.tokenizer = comfy.text_encoders.cosmos.CosmosT5Tokenizer
|
||||||
elif te_model == TEModel.T5_XL:
|
elif te_model == TEModel.T5_XL:
|
||||||
clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model
|
clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model
|
||||||
clip_target.tokenizer = comfy.text_encoders.aura_t5.AuraT5Tokenizer
|
clip_target.tokenizer = comfy.text_encoders.aura_t5.AuraT5Tokenizer
|
||||||
@@ -652,6 +787,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
|||||||
elif clip_type == CLIPType.FLUX:
|
elif clip_type == CLIPType.FLUX:
|
||||||
clip_target.clip = comfy.text_encoders.flux.flux_clip(**t5xxl_detect(clip_data))
|
clip_target.clip = comfy.text_encoders.flux.flux_clip(**t5xxl_detect(clip_data))
|
||||||
clip_target.tokenizer = comfy.text_encoders.flux.FluxTokenizer
|
clip_target.tokenizer = comfy.text_encoders.flux.FluxTokenizer
|
||||||
|
elif clip_type == CLIPType.HUNYUAN_VIDEO:
|
||||||
|
clip_target.clip = comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**llama_detect(clip_data))
|
||||||
|
clip_target.tokenizer = comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer
|
||||||
else:
|
else:
|
||||||
clip_target.clip = sdxl_clip.SDXLClipModel
|
clip_target.clip = sdxl_clip.SDXLClipModel
|
||||||
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
||||||
@@ -691,7 +829,6 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
|
|||||||
config = yaml.safe_load(stream)
|
config = yaml.safe_load(stream)
|
||||||
model_config_params = config['model']['params']
|
model_config_params = config['model']['params']
|
||||||
clip_config = model_config_params['cond_stage_config']
|
clip_config = model_config_params['cond_stage_config']
|
||||||
scale_factor = model_config_params['scale_factor']
|
|
||||||
|
|
||||||
if "parameterization" in model_config_params:
|
if "parameterization" in model_config_params:
|
||||||
if model_config_params["parameterization"] == "v":
|
if model_config_params["parameterization"] == "v":
|
||||||
@@ -784,7 +921,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
|||||||
if output_model:
|
if output_model:
|
||||||
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
|
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
|
||||||
if inital_load_device != torch.device("cpu"):
|
if inital_load_device != torch.device("cpu"):
|
||||||
logging.info("loaded straight to GPU")
|
logging.info("loaded diffusion model directly to GPU")
|
||||||
model_management.load_models_gpu([model_patcher], force_full_load=True)
|
model_management.load_models_gpu([model_patcher], force_full_load=True)
|
||||||
|
|
||||||
return (model_patcher, clip, vae, clipvision)
|
return (model_patcher, clip, vae, clipvision)
|
||||||
@@ -861,11 +998,11 @@ def load_diffusion_model(unet_path, model_options={}):
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
def load_unet(unet_path, dtype=None):
|
def load_unet(unet_path, dtype=None):
|
||||||
print("WARNING: the load_unet function has been deprecated and will be removed please switch to: load_diffusion_model")
|
logging.warning("The load_unet function has been deprecated and will be removed please switch to: load_diffusion_model")
|
||||||
return load_diffusion_model(unet_path, model_options={"dtype": dtype})
|
return load_diffusion_model(unet_path, model_options={"dtype": dtype})
|
||||||
|
|
||||||
def load_unet_state_dict(sd, dtype=None):
|
def load_unet_state_dict(sd, dtype=None):
|
||||||
print("WARNING: the load_unet_state_dict function has been deprecated and will be removed please switch to: load_diffusion_model_state_dict")
|
logging.warning("The load_unet_state_dict function has been deprecated and will be removed please switch to: load_diffusion_model_state_dict")
|
||||||
return load_diffusion_model_state_dict(sd, model_options={"dtype": dtype})
|
return load_diffusion_model_state_dict(sd, model_options={"dtype": dtype})
|
||||||
|
|
||||||
def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None, extra_keys={}):
|
def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None, extra_keys={}):
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import comfy.clip_model
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import numbers
|
import numbers
|
||||||
|
import re
|
||||||
|
|
||||||
def gen_empty_tokens(special_tokens, length):
|
def gen_empty_tokens(special_tokens, length):
|
||||||
start_token = special_tokens.get("start", None)
|
start_token = special_tokens.get("start", None)
|
||||||
@@ -36,7 +37,10 @@ class ClipTokenWeightEncoder:
|
|||||||
|
|
||||||
sections = len(to_encode)
|
sections = len(to_encode)
|
||||||
if has_weights or sections == 0:
|
if has_weights or sections == 0:
|
||||||
to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))
|
if hasattr(self, "gen_empty_tokens"):
|
||||||
|
to_encode.append(self.gen_empty_tokens(self.special_tokens, max_token_len))
|
||||||
|
else:
|
||||||
|
to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))
|
||||||
|
|
||||||
o = self.encode(to_encode)
|
o = self.encode(to_encode)
|
||||||
out, pooled = o[:2]
|
out, pooled = o[:2]
|
||||||
@@ -90,8 +94,11 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||||||
if textmodel_json_config is None:
|
if textmodel_json_config is None:
|
||||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
|
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
|
||||||
|
|
||||||
with open(textmodel_json_config) as f:
|
if isinstance(textmodel_json_config, dict):
|
||||||
config = json.load(f)
|
config = textmodel_json_config
|
||||||
|
else:
|
||||||
|
with open(textmodel_json_config) as f:
|
||||||
|
config = json.load(f)
|
||||||
|
|
||||||
operations = model_options.get("custom_operations", None)
|
operations = model_options.get("custom_operations", None)
|
||||||
scaled_fp8 = None
|
scaled_fp8 = None
|
||||||
@@ -196,11 +203,18 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||||||
attention_mask = None
|
attention_mask = None
|
||||||
if self.enable_attention_masks or self.zero_out_masked or self.return_attention_masks:
|
if self.enable_attention_masks or self.zero_out_masked or self.return_attention_masks:
|
||||||
attention_mask = torch.zeros_like(tokens)
|
attention_mask = torch.zeros_like(tokens)
|
||||||
end_token = self.special_tokens.get("end", -1)
|
end_token = self.special_tokens.get("end", None)
|
||||||
|
if end_token is None:
|
||||||
|
cmp_token = self.special_tokens.get("pad", -1)
|
||||||
|
else:
|
||||||
|
cmp_token = end_token
|
||||||
|
|
||||||
for x in range(attention_mask.shape[0]):
|
for x in range(attention_mask.shape[0]):
|
||||||
for y in range(attention_mask.shape[1]):
|
for y in range(attention_mask.shape[1]):
|
||||||
attention_mask[x, y] = 1
|
attention_mask[x, y] = 1
|
||||||
if tokens[x, y] == end_token:
|
if tokens[x, y] == cmp_token:
|
||||||
|
if end_token is None:
|
||||||
|
attention_mask[x, y] = 0
|
||||||
break
|
break
|
||||||
|
|
||||||
attention_mask_model = None
|
attention_mask_model = None
|
||||||
@@ -326,7 +340,6 @@ def expand_directory_list(directories):
|
|||||||
return list(dirs)
|
return list(dirs)
|
||||||
|
|
||||||
def bundled_embed(embed, prefix, suffix): #bundled embedding in lora format
|
def bundled_embed(embed, prefix, suffix): #bundled embedding in lora format
|
||||||
i = 0
|
|
||||||
out_list = []
|
out_list = []
|
||||||
for k in embed:
|
for k in embed:
|
||||||
if k.startswith(prefix) and k.endswith(suffix):
|
if k.startswith(prefix) and k.endswith(suffix):
|
||||||
@@ -375,14 +388,11 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
|
|||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
embed = safetensors.torch.load_file(embed_path, device="cpu")
|
embed = safetensors.torch.load_file(embed_path, device="cpu")
|
||||||
else:
|
else:
|
||||||
if 'weights_only' in torch.load.__code__.co_varnames:
|
try:
|
||||||
try:
|
embed = torch.load(embed_path, weights_only=True, map_location="cpu")
|
||||||
embed = torch.load(embed_path, weights_only=True, map_location="cpu")
|
except:
|
||||||
except:
|
embed_out = safe_load_embed_zip(embed_path)
|
||||||
embed_out = safe_load_embed_zip(embed_path)
|
except Exception:
|
||||||
else:
|
|
||||||
embed = torch.load(embed_path, map_location="cpu")
|
|
||||||
except Exception as e:
|
|
||||||
logging.warning("{}\n\nerror loading embedding, skipping loading: {}".format(traceback.format_exc(), embedding_name))
|
logging.warning("{}\n\nerror loading embedding, skipping loading: {}".format(traceback.format_exc(), embedding_name))
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -411,22 +421,31 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
|
|||||||
return embed_out
|
return embed_out
|
||||||
|
|
||||||
class SDTokenizer:
|
class SDTokenizer:
|
||||||
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True, min_length=None, pad_token=None, tokenizer_data={}):
|
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, tokenizer_data={}):
|
||||||
if tokenizer_path is None:
|
if tokenizer_path is None:
|
||||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
|
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
|
||||||
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path)
|
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path)
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
self.min_length = min_length
|
self.min_length = min_length
|
||||||
|
self.end_token = None
|
||||||
|
|
||||||
empty = self.tokenizer('')["input_ids"]
|
empty = self.tokenizer('')["input_ids"]
|
||||||
|
self.tokenizer_adds_end_token = has_end_token
|
||||||
if has_start_token:
|
if has_start_token:
|
||||||
self.tokens_start = 1
|
self.tokens_start = 1
|
||||||
self.start_token = empty[0]
|
self.start_token = empty[0]
|
||||||
self.end_token = empty[1]
|
if end_token is not None:
|
||||||
|
self.end_token = end_token
|
||||||
|
else:
|
||||||
|
if has_end_token:
|
||||||
|
self.end_token = empty[1]
|
||||||
else:
|
else:
|
||||||
self.tokens_start = 0
|
self.tokens_start = 0
|
||||||
self.start_token = None
|
self.start_token = None
|
||||||
self.end_token = empty[0]
|
if end_token is not None:
|
||||||
|
self.end_token = end_token
|
||||||
|
else:
|
||||||
|
self.end_token = empty[0]
|
||||||
|
|
||||||
if pad_token is not None:
|
if pad_token is not None:
|
||||||
self.pad_token = pad_token
|
self.pad_token = pad_token
|
||||||
@@ -451,13 +470,16 @@ class SDTokenizer:
|
|||||||
Takes a potential embedding name and tries to retrieve it.
|
Takes a potential embedding name and tries to retrieve it.
|
||||||
Returns a Tuple consisting of the embedding and any leftover string, embedding can be None.
|
Returns a Tuple consisting of the embedding and any leftover string, embedding can be None.
|
||||||
'''
|
'''
|
||||||
|
split_embed = embedding_name.split()
|
||||||
|
embedding_name = split_embed[0]
|
||||||
|
leftover = ' '.join(split_embed[1:])
|
||||||
embed = load_embed(embedding_name, self.embedding_directory, self.embedding_size, self.embedding_key)
|
embed = load_embed(embedding_name, self.embedding_directory, self.embedding_size, self.embedding_key)
|
||||||
if embed is None:
|
if embed is None:
|
||||||
stripped = embedding_name.strip(',')
|
stripped = embedding_name.strip(',')
|
||||||
if len(stripped) < len(embedding_name):
|
if len(stripped) < len(embedding_name):
|
||||||
embed = load_embed(stripped, self.embedding_directory, self.embedding_size, self.embedding_key)
|
embed = load_embed(stripped, self.embedding_directory, self.embedding_size, self.embedding_key)
|
||||||
return (embed, embedding_name[len(stripped):])
|
return (embed, "{} {}".format(embedding_name[len(stripped):], leftover))
|
||||||
return (embed, "")
|
return (embed, leftover)
|
||||||
|
|
||||||
|
|
||||||
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
||||||
@@ -471,13 +493,18 @@ class SDTokenizer:
|
|||||||
text = escape_important(text)
|
text = escape_important(text)
|
||||||
parsed_weights = token_weights(text, 1.0)
|
parsed_weights = token_weights(text, 1.0)
|
||||||
|
|
||||||
#tokenize words
|
# tokenize words
|
||||||
tokens = []
|
tokens = []
|
||||||
for weighted_segment, weight in parsed_weights:
|
for weighted_segment, weight in parsed_weights:
|
||||||
to_tokenize = unescape_important(weighted_segment).replace("\n", " ").split(' ')
|
to_tokenize = unescape_important(weighted_segment)
|
||||||
|
split = re.split(' {0}|\n{0}'.format(self.embedding_identifier), to_tokenize)
|
||||||
|
to_tokenize = [split[0]]
|
||||||
|
for i in range(1, len(split)):
|
||||||
|
to_tokenize.append("{}{}".format(self.embedding_identifier, split[i]))
|
||||||
|
|
||||||
to_tokenize = [x for x in to_tokenize if x != ""]
|
to_tokenize = [x for x in to_tokenize if x != ""]
|
||||||
for word in to_tokenize:
|
for word in to_tokenize:
|
||||||
#if we find an embedding, deal with the embedding
|
# if we find an embedding, deal with the embedding
|
||||||
if word.startswith(self.embedding_identifier) and self.embedding_directory is not None:
|
if word.startswith(self.embedding_identifier) and self.embedding_directory is not None:
|
||||||
embedding_name = word[len(self.embedding_identifier):].strip('\n')
|
embedding_name = word[len(self.embedding_identifier):].strip('\n')
|
||||||
embed, leftover = self._try_get_embedding(embedding_name)
|
embed, leftover = self._try_get_embedding(embedding_name)
|
||||||
@@ -493,8 +520,11 @@ class SDTokenizer:
|
|||||||
word = leftover
|
word = leftover
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
|
end = 999999999999
|
||||||
|
if self.tokenizer_adds_end_token:
|
||||||
|
end = -1
|
||||||
#parse word
|
#parse word
|
||||||
tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][self.tokens_start:-1]])
|
tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][self.tokens_start:end]])
|
||||||
|
|
||||||
#reshape token array to CLIP input size
|
#reshape token array to CLIP input size
|
||||||
batched_tokens = []
|
batched_tokens = []
|
||||||
@@ -505,18 +535,24 @@ class SDTokenizer:
|
|||||||
for i, t_group in enumerate(tokens):
|
for i, t_group in enumerate(tokens):
|
||||||
#determine if we're going to try and keep the tokens in a single batch
|
#determine if we're going to try and keep the tokens in a single batch
|
||||||
is_large = len(t_group) >= self.max_word_length
|
is_large = len(t_group) >= self.max_word_length
|
||||||
|
if self.end_token is not None:
|
||||||
|
has_end_token = 1
|
||||||
|
else:
|
||||||
|
has_end_token = 0
|
||||||
|
|
||||||
while len(t_group) > 0:
|
while len(t_group) > 0:
|
||||||
if len(t_group) + len(batch) > self.max_length - 1:
|
if len(t_group) + len(batch) > self.max_length - has_end_token:
|
||||||
remaining_length = self.max_length - len(batch) - 1
|
remaining_length = self.max_length - len(batch) - has_end_token
|
||||||
#break word in two and add end token
|
#break word in two and add end token
|
||||||
if is_large:
|
if is_large:
|
||||||
batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]])
|
batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]])
|
||||||
batch.append((self.end_token, 1.0, 0))
|
if self.end_token is not None:
|
||||||
|
batch.append((self.end_token, 1.0, 0))
|
||||||
t_group = t_group[remaining_length:]
|
t_group = t_group[remaining_length:]
|
||||||
#add end token and pad
|
#add end token and pad
|
||||||
else:
|
else:
|
||||||
batch.append((self.end_token, 1.0, 0))
|
if self.end_token is not None:
|
||||||
|
batch.append((self.end_token, 1.0, 0))
|
||||||
if self.pad_to_max_length:
|
if self.pad_to_max_length:
|
||||||
batch.extend([(self.pad_token, 1.0, 0)] * (remaining_length))
|
batch.extend([(self.pad_token, 1.0, 0)] * (remaining_length))
|
||||||
#start new batch
|
#start new batch
|
||||||
@@ -529,7 +565,8 @@ class SDTokenizer:
|
|||||||
t_group = []
|
t_group = []
|
||||||
|
|
||||||
#fill last batch
|
#fill last batch
|
||||||
batch.append((self.end_token, 1.0, 0))
|
if self.end_token is not None:
|
||||||
|
batch.append((self.end_token, 1.0, 0))
|
||||||
if self.pad_to_max_length:
|
if self.pad_to_max_length:
|
||||||
batch.extend([(self.pad_token, 1.0, 0)] * (self.max_length - len(batch)))
|
batch.extend([(self.pad_token, 1.0, 0)] * (self.max_length - len(batch)))
|
||||||
if self.min_length is not None and len(batch) < self.min_length:
|
if self.min_length is not None and len(batch) < self.min_length:
|
||||||
|
|||||||
@@ -8,10 +8,13 @@ import comfy.text_encoders.sd2_clip
|
|||||||
import comfy.text_encoders.sd3_clip
|
import comfy.text_encoders.sd3_clip
|
||||||
import comfy.text_encoders.sa_t5
|
import comfy.text_encoders.sa_t5
|
||||||
import comfy.text_encoders.aura_t5
|
import comfy.text_encoders.aura_t5
|
||||||
|
import comfy.text_encoders.pixart_t5
|
||||||
import comfy.text_encoders.hydit
|
import comfy.text_encoders.hydit
|
||||||
import comfy.text_encoders.flux
|
import comfy.text_encoders.flux
|
||||||
import comfy.text_encoders.genmo
|
import comfy.text_encoders.genmo
|
||||||
import comfy.text_encoders.lt
|
import comfy.text_encoders.lt
|
||||||
|
import comfy.text_encoders.hunyuan_video
|
||||||
|
import comfy.text_encoders.cosmos
|
||||||
|
|
||||||
from . import supported_models_base
|
from . import supported_models_base
|
||||||
from . import latent_formats
|
from . import latent_formats
|
||||||
@@ -224,7 +227,6 @@ class SDXL(supported_models_base.BASE):
|
|||||||
|
|
||||||
def process_clip_state_dict_for_saving(self, state_dict):
|
def process_clip_state_dict_for_saving(self, state_dict):
|
||||||
replace_prefix = {}
|
replace_prefix = {}
|
||||||
keys_to_replace = {}
|
|
||||||
state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
|
state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
|
||||||
for k in state_dict:
|
for k in state_dict:
|
||||||
if k.startswith("clip_l"):
|
if k.startswith("clip_l"):
|
||||||
@@ -527,7 +529,6 @@ class SD3(supported_models_base.BASE):
|
|||||||
clip_l = False
|
clip_l = False
|
||||||
clip_g = False
|
clip_g = False
|
||||||
t5 = False
|
t5 = False
|
||||||
dtype_t5 = None
|
|
||||||
pref = self.text_encoder_key_prefix[0]
|
pref = self.text_encoder_key_prefix[0]
|
||||||
if "{}clip_l.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict:
|
if "{}clip_l.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict:
|
||||||
clip_l = True
|
clip_l = True
|
||||||
@@ -593,6 +594,39 @@ class AuraFlow(supported_models_base.BASE):
|
|||||||
def clip_target(self, state_dict={}):
|
def clip_target(self, state_dict={}):
|
||||||
return supported_models_base.ClipTarget(comfy.text_encoders.aura_t5.AuraT5Tokenizer, comfy.text_encoders.aura_t5.AuraT5Model)
|
return supported_models_base.ClipTarget(comfy.text_encoders.aura_t5.AuraT5Tokenizer, comfy.text_encoders.aura_t5.AuraT5Model)
|
||||||
|
|
||||||
|
class PixArtAlpha(supported_models_base.BASE):
|
||||||
|
unet_config = {
|
||||||
|
"image_model": "pixart_alpha",
|
||||||
|
}
|
||||||
|
|
||||||
|
sampling_settings = {
|
||||||
|
"beta_schedule" : "sqrt_linear",
|
||||||
|
"linear_start" : 0.0001,
|
||||||
|
"linear_end" : 0.02,
|
||||||
|
"timesteps" : 1000,
|
||||||
|
}
|
||||||
|
|
||||||
|
unet_extra_config = {}
|
||||||
|
latent_format = latent_formats.SD15
|
||||||
|
|
||||||
|
memory_usage_factor = 0.5
|
||||||
|
|
||||||
|
vae_key_prefix = ["vae."]
|
||||||
|
text_encoder_key_prefix = ["text_encoders."]
|
||||||
|
|
||||||
|
def get_model(self, state_dict, prefix="", device=None):
|
||||||
|
out = model_base.PixArt(self, device=device)
|
||||||
|
return out.eval()
|
||||||
|
|
||||||
|
def clip_target(self, state_dict={}):
|
||||||
|
return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.PixArtT5XXL)
|
||||||
|
|
||||||
|
class PixArtSigma(PixArtAlpha):
|
||||||
|
unet_config = {
|
||||||
|
"image_model": "pixart_sigma",
|
||||||
|
}
|
||||||
|
latent_format = latent_formats.SDXL
|
||||||
|
|
||||||
class HunyuanDiT(supported_models_base.BASE):
|
class HunyuanDiT(supported_models_base.BASE):
|
||||||
unet_config = {
|
unet_config = {
|
||||||
"image_model": "hydit",
|
"image_model": "hydit",
|
||||||
@@ -609,6 +643,8 @@ class HunyuanDiT(supported_models_base.BASE):
|
|||||||
|
|
||||||
latent_format = latent_formats.SDXL
|
latent_format = latent_formats.SDXL
|
||||||
|
|
||||||
|
memory_usage_factor = 1.3
|
||||||
|
|
||||||
vae_key_prefix = ["vae."]
|
vae_key_prefix = ["vae."]
|
||||||
text_encoder_key_prefix = ["text_encoders."]
|
text_encoder_key_prefix = ["text_encoders."]
|
||||||
|
|
||||||
@@ -740,6 +776,95 @@ class LTXV(supported_models_base.BASE):
|
|||||||
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
||||||
return supported_models_base.ClipTarget(comfy.text_encoders.lt.LTXVT5Tokenizer, comfy.text_encoders.lt.ltxv_te(**t5_detect))
|
return supported_models_base.ClipTarget(comfy.text_encoders.lt.LTXVT5Tokenizer, comfy.text_encoders.lt.ltxv_te(**t5_detect))
|
||||||
|
|
||||||
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV]
|
class HunyuanVideo(supported_models_base.BASE):
|
||||||
|
unet_config = {
|
||||||
|
"image_model": "hunyuan_video",
|
||||||
|
}
|
||||||
|
|
||||||
|
sampling_settings = {
|
||||||
|
"shift": 7.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
unet_extra_config = {}
|
||||||
|
latent_format = latent_formats.HunyuanVideo
|
||||||
|
|
||||||
|
memory_usage_factor = 1.8 #TODO
|
||||||
|
|
||||||
|
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||||
|
|
||||||
|
vae_key_prefix = ["vae."]
|
||||||
|
text_encoder_key_prefix = ["text_encoders."]
|
||||||
|
|
||||||
|
def get_model(self, state_dict, prefix="", device=None):
|
||||||
|
out = model_base.HunyuanVideo(self, device=device)
|
||||||
|
return out
|
||||||
|
|
||||||
|
def process_unet_state_dict(self, state_dict):
|
||||||
|
out_sd = {}
|
||||||
|
for k in list(state_dict.keys()):
|
||||||
|
key_out = k
|
||||||
|
key_out = key_out.replace("txt_in.t_embedder.mlp.0.", "txt_in.t_embedder.in_layer.").replace("txt_in.t_embedder.mlp.2.", "txt_in.t_embedder.out_layer.")
|
||||||
|
key_out = key_out.replace("txt_in.c_embedder.linear_1.", "txt_in.c_embedder.in_layer.").replace("txt_in.c_embedder.linear_2.", "txt_in.c_embedder.out_layer.")
|
||||||
|
key_out = key_out.replace("_mod.linear.", "_mod.lin.").replace("_attn_qkv.", "_attn.qkv.")
|
||||||
|
key_out = key_out.replace("mlp.fc1.", "mlp.0.").replace("mlp.fc2.", "mlp.2.")
|
||||||
|
key_out = key_out.replace("_attn_q_norm.weight", "_attn.norm.query_norm.scale").replace("_attn_k_norm.weight", "_attn.norm.key_norm.scale")
|
||||||
|
key_out = key_out.replace(".q_norm.weight", ".norm.query_norm.scale").replace(".k_norm.weight", ".norm.key_norm.scale")
|
||||||
|
key_out = key_out.replace("_attn_proj.", "_attn.proj.")
|
||||||
|
key_out = key_out.replace(".modulation.linear.", ".modulation.lin.")
|
||||||
|
key_out = key_out.replace("_in.mlp.2.", "_in.out_layer.").replace("_in.mlp.0.", "_in.in_layer.")
|
||||||
|
out_sd[key_out] = state_dict[k]
|
||||||
|
return out_sd
|
||||||
|
|
||||||
|
def process_unet_state_dict_for_saving(self, state_dict):
|
||||||
|
replace_prefix = {"": "model.model."}
|
||||||
|
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||||
|
|
||||||
|
def clip_target(self, state_dict={}):
|
||||||
|
pref = self.text_encoder_key_prefix[0]
|
||||||
|
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}llama.transformer.".format(pref))
|
||||||
|
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer, comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**hunyuan_detect))
|
||||||
|
|
||||||
|
class CosmosT2V(supported_models_base.BASE):
|
||||||
|
unet_config = {
|
||||||
|
"image_model": "cosmos",
|
||||||
|
"in_channels": 16,
|
||||||
|
}
|
||||||
|
|
||||||
|
sampling_settings = {
|
||||||
|
"sigma_data": 0.5,
|
||||||
|
"sigma_max": 80.0,
|
||||||
|
"sigma_min": 0.002,
|
||||||
|
}
|
||||||
|
|
||||||
|
unet_extra_config = {}
|
||||||
|
latent_format = latent_formats.Cosmos1CV8x8x8
|
||||||
|
|
||||||
|
memory_usage_factor = 1.6 #TODO
|
||||||
|
|
||||||
|
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] #TODO
|
||||||
|
|
||||||
|
vae_key_prefix = ["vae."]
|
||||||
|
text_encoder_key_prefix = ["text_encoders."]
|
||||||
|
|
||||||
|
def get_model(self, state_dict, prefix="", device=None):
|
||||||
|
out = model_base.CosmosVideo(self, device=device)
|
||||||
|
return out
|
||||||
|
|
||||||
|
def clip_target(self, state_dict={}):
|
||||||
|
pref = self.text_encoder_key_prefix[0]
|
||||||
|
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
||||||
|
return supported_models_base.ClipTarget(comfy.text_encoders.cosmos.CosmosT5Tokenizer, comfy.text_encoders.cosmos.te(**t5_detect))
|
||||||
|
|
||||||
|
class CosmosI2V(CosmosT2V):
|
||||||
|
unet_config = {
|
||||||
|
"image_model": "cosmos",
|
||||||
|
"in_channels": 17,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_model(self, state_dict, prefix="", device=None):
|
||||||
|
out = model_base.CosmosVideo(self, image_to_video=True, device=device)
|
||||||
|
return out
|
||||||
|
|
||||||
|
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo, CosmosT2V, CosmosI2V]
|
||||||
|
|
||||||
models += [SVD_img2vid]
|
models += [SVD_img2vid]
|
||||||
|
|||||||
42
comfy/text_encoders/cosmos.py
Normal file
42
comfy/text_encoders/cosmos.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
from comfy import sd1_clip
|
||||||
|
import comfy.text_encoders.t5
|
||||||
|
import os
|
||||||
|
from transformers import T5TokenizerFast
|
||||||
|
|
||||||
|
|
||||||
|
class T5XXLModel(sd1_clip.SDClipModel):
|
||||||
|
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
|
||||||
|
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_old_config_xxl.json")
|
||||||
|
t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None)
|
||||||
|
if t5xxl_scaled_fp8 is not None:
|
||||||
|
model_options = model_options.copy()
|
||||||
|
model_options["scaled_fp8"] = t5xxl_scaled_fp8
|
||||||
|
|
||||||
|
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, zero_out_masked=attention_mask, model_options=model_options)
|
||||||
|
|
||||||
|
class CosmosT5XXL(sd1_clip.SD1ClipModel):
|
||||||
|
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||||
|
super().__init__(device=device, dtype=dtype, name="t5xxl", clip_model=T5XXLModel, model_options=model_options)
|
||||||
|
|
||||||
|
|
||||||
|
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||||
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
|
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
|
||||||
|
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=1024, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512)
|
||||||
|
|
||||||
|
|
||||||
|
class CosmosT5Tokenizer(sd1_clip.SD1Tokenizer):
|
||||||
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
|
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
|
||||||
|
|
||||||
|
|
||||||
|
def te(dtype_t5=None, t5xxl_scaled_fp8=None):
|
||||||
|
class CosmosTEModel_(CosmosT5XXL):
|
||||||
|
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||||
|
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
|
||||||
|
model_options = model_options.copy()
|
||||||
|
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
|
||||||
|
if dtype is None:
|
||||||
|
dtype = dtype_t5
|
||||||
|
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||||
|
return CosmosTEModel_
|
||||||
112
comfy/text_encoders/hunyuan_video.py
Normal file
112
comfy/text_encoders/hunyuan_video.py
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
from comfy import sd1_clip
|
||||||
|
import comfy.model_management
|
||||||
|
import comfy.text_encoders.llama
|
||||||
|
from transformers import LlamaTokenizerFast
|
||||||
|
import torch
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def llama_detect(state_dict, prefix=""):
|
||||||
|
out = {}
|
||||||
|
t5_key = "{}model.norm.weight".format(prefix)
|
||||||
|
if t5_key in state_dict:
|
||||||
|
out["dtype_llama"] = state_dict[t5_key].dtype
|
||||||
|
|
||||||
|
scaled_fp8_key = "{}scaled_fp8".format(prefix)
|
||||||
|
if scaled_fp8_key in state_dict:
|
||||||
|
out["llama_scaled_fp8"] = state_dict[scaled_fp8_key].dtype
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class LLAMA3Tokenizer(sd1_clip.SDTokenizer):
|
||||||
|
def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=256):
|
||||||
|
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "llama_tokenizer")
|
||||||
|
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='llama', tokenizer_class=LlamaTokenizerFast, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, pad_token=128258, end_token=128009, min_length=min_length)
|
||||||
|
|
||||||
|
class LLAMAModel(sd1_clip.SDClipModel):
|
||||||
|
def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}):
|
||||||
|
llama_scaled_fp8 = model_options.get("llama_scaled_fp8", None)
|
||||||
|
if llama_scaled_fp8 is not None:
|
||||||
|
model_options = model_options.copy()
|
||||||
|
model_options["scaled_fp8"] = llama_scaled_fp8
|
||||||
|
|
||||||
|
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 128000, "pad": 128258}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Llama2, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||||
|
|
||||||
|
|
||||||
|
class HunyuanVideoTokenizer:
|
||||||
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
|
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
|
||||||
|
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
|
||||||
|
self.llama_template = """<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n""" # 95 tokens
|
||||||
|
self.llama = LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=1)
|
||||||
|
|
||||||
|
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
||||||
|
out = {}
|
||||||
|
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
|
||||||
|
|
||||||
|
llama_text = "{}{}".format(self.llama_template, text)
|
||||||
|
out["llama"] = self.llama.tokenize_with_weights(llama_text, return_word_ids)
|
||||||
|
return out
|
||||||
|
|
||||||
|
def untokenize(self, token_weight_pair):
|
||||||
|
return self.clip_l.untokenize(token_weight_pair)
|
||||||
|
|
||||||
|
def state_dict(self):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
class HunyuanVideoClipModel(torch.nn.Module):
|
||||||
|
def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
|
||||||
|
super().__init__()
|
||||||
|
dtype_llama = comfy.model_management.pick_weight_dtype(dtype_llama, dtype, device)
|
||||||
|
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
|
||||||
|
self.clip_l = clip_l_class(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
|
||||||
|
self.llama = LLAMAModel(device=device, dtype=dtype_llama, model_options=model_options)
|
||||||
|
self.dtypes = set([dtype, dtype_llama])
|
||||||
|
|
||||||
|
def set_clip_options(self, options):
|
||||||
|
self.clip_l.set_clip_options(options)
|
||||||
|
self.llama.set_clip_options(options)
|
||||||
|
|
||||||
|
def reset_clip_options(self):
|
||||||
|
self.clip_l.reset_clip_options()
|
||||||
|
self.llama.reset_clip_options()
|
||||||
|
|
||||||
|
def encode_token_weights(self, token_weight_pairs):
|
||||||
|
token_weight_pairs_l = token_weight_pairs["l"]
|
||||||
|
token_weight_pairs_llama = token_weight_pairs["llama"]
|
||||||
|
|
||||||
|
llama_out, llama_pooled, llama_extra_out = self.llama.encode_token_weights(token_weight_pairs_llama)
|
||||||
|
|
||||||
|
template_end = 0
|
||||||
|
for i, v in enumerate(token_weight_pairs_llama[0]):
|
||||||
|
if v[0] == 128007: # <|end_header_id|>
|
||||||
|
template_end = i
|
||||||
|
|
||||||
|
if llama_out.shape[1] > (template_end + 2):
|
||||||
|
if token_weight_pairs_llama[0][template_end + 1][0] == 271:
|
||||||
|
template_end += 2
|
||||||
|
llama_out = llama_out[:, template_end:]
|
||||||
|
llama_extra_out["attention_mask"] = llama_extra_out["attention_mask"][:, template_end:]
|
||||||
|
if llama_extra_out["attention_mask"].sum() == torch.numel(llama_extra_out["attention_mask"]):
|
||||||
|
llama_extra_out.pop("attention_mask") # attention mask is useless if no masked elements
|
||||||
|
|
||||||
|
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
|
||||||
|
return llama_out, l_pooled, llama_extra_out
|
||||||
|
|
||||||
|
def load_sd(self, sd):
|
||||||
|
if "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
|
||||||
|
return self.clip_l.load_sd(sd)
|
||||||
|
else:
|
||||||
|
return self.llama.load_sd(sd)
|
||||||
|
|
||||||
|
|
||||||
|
def hunyuan_video_clip(dtype_llama=None, llama_scaled_fp8=None):
|
||||||
|
class HunyuanVideoClipModel_(HunyuanVideoClipModel):
|
||||||
|
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||||
|
if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options:
|
||||||
|
model_options = model_options.copy()
|
||||||
|
model_options["llama_scaled_fp8"] = llama_scaled_fp8
|
||||||
|
super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
|
||||||
|
return HunyuanVideoClipModel_
|
||||||
226
comfy/text_encoders/llama.py
Normal file
226
comfy/text_encoders/llama.py
Normal file
@@ -0,0 +1,226 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional, Any
|
||||||
|
|
||||||
|
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||||
|
import comfy.model_management
|
||||||
|
import comfy.ldm.common_dit
|
||||||
|
|
||||||
|
import comfy.model_management
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Llama2Config:
|
||||||
|
vocab_size: int = 128320
|
||||||
|
hidden_size: int = 4096
|
||||||
|
intermediate_size: int = 14336
|
||||||
|
num_hidden_layers: int = 32
|
||||||
|
num_attention_heads: int = 32
|
||||||
|
num_key_value_heads: int = 8
|
||||||
|
max_position_embeddings: int = 8192
|
||||||
|
rms_norm_eps: float = 1e-5
|
||||||
|
rope_theta: float = 500000.0
|
||||||
|
|
||||||
|
class RMSNorm(nn.Module):
|
||||||
|
def __init__(self, dim: int, eps: float = 1e-5, device=None, dtype=None):
|
||||||
|
super().__init__()
|
||||||
|
self.eps = eps
|
||||||
|
self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)
|
||||||
|
|
||||||
|
|
||||||
|
def rotate_half(x):
|
||||||
|
"""Rotates half the hidden dims of the input."""
|
||||||
|
x1 = x[..., : x.shape[-1] // 2]
|
||||||
|
x2 = x[..., x.shape[-1] // 2 :]
|
||||||
|
return torch.cat((-x2, x1), dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def precompute_freqs_cis(head_dim, seq_len, theta, device=None):
|
||||||
|
theta_numerator = torch.arange(0, head_dim, 2, device=device).float()
|
||||||
|
inv_freq = 1.0 / (theta ** (theta_numerator / head_dim))
|
||||||
|
|
||||||
|
position_ids = torch.arange(0, seq_len, device=device).unsqueeze(0)
|
||||||
|
|
||||||
|
inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
||||||
|
position_ids_expanded = position_ids[:, None, :].float()
|
||||||
|
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
||||||
|
emb = torch.cat((freqs, freqs), dim=-1)
|
||||||
|
cos = emb.cos()
|
||||||
|
sin = emb.sin()
|
||||||
|
return (cos, sin)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_rope(xq, xk, freqs_cis):
|
||||||
|
cos = freqs_cis[0].unsqueeze(1)
|
||||||
|
sin = freqs_cis[1].unsqueeze(1)
|
||||||
|
q_embed = (xq * cos) + (rotate_half(xq) * sin)
|
||||||
|
k_embed = (xk * cos) + (rotate_half(xk) * sin)
|
||||||
|
return q_embed, k_embed
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
|
||||||
|
super().__init__()
|
||||||
|
self.num_heads = config.num_attention_heads
|
||||||
|
self.num_kv_heads = config.num_key_value_heads
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
self.head_dim = self.hidden_size // self.num_heads
|
||||||
|
|
||||||
|
ops = ops or nn
|
||||||
|
self.q_proj = ops.Linear(config.hidden_size, config.hidden_size, bias=False, device=device, dtype=dtype)
|
||||||
|
self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False, device=device, dtype=dtype)
|
||||||
|
self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False, device=device, dtype=dtype)
|
||||||
|
self.o_proj = ops.Linear(config.hidden_size, config.hidden_size, bias=False, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
freqs_cis: Optional[torch.Tensor] = None,
|
||||||
|
optimized_attention=None,
|
||||||
|
):
|
||||||
|
batch_size, seq_length, _ = hidden_states.shape
|
||||||
|
|
||||||
|
xq = self.q_proj(hidden_states)
|
||||||
|
xk = self.k_proj(hidden_states)
|
||||||
|
xv = self.v_proj(hidden_states)
|
||||||
|
|
||||||
|
xq = xq.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
xk = xk.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2)
|
||||||
|
xv = xv.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
xq, xk = apply_rope(xq, xk, freqs_cis=freqs_cis)
|
||||||
|
|
||||||
|
xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
|
||||||
|
xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
|
||||||
|
|
||||||
|
output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True)
|
||||||
|
return self.o_proj(output)
|
||||||
|
|
||||||
|
class MLP(nn.Module):
|
||||||
|
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
|
||||||
|
super().__init__()
|
||||||
|
ops = ops or nn
|
||||||
|
self.gate_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype)
|
||||||
|
self.up_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype)
|
||||||
|
self.down_proj = ops.Linear(config.intermediate_size, config.hidden_size, bias=False, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||||
|
|
||||||
|
class TransformerBlock(nn.Module):
|
||||||
|
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
|
||||||
|
super().__init__()
|
||||||
|
self.self_attn = Attention(config, device=device, dtype=dtype, ops=ops)
|
||||||
|
self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
|
||||||
|
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
|
||||||
|
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
freqs_cis: Optional[torch.Tensor] = None,
|
||||||
|
optimized_attention=None,
|
||||||
|
):
|
||||||
|
# Self Attention
|
||||||
|
residual = x
|
||||||
|
x = self.input_layernorm(x)
|
||||||
|
x = self.self_attn(
|
||||||
|
hidden_states=x,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
freqs_cis=freqs_cis,
|
||||||
|
optimized_attention=optimized_attention,
|
||||||
|
)
|
||||||
|
x = residual + x
|
||||||
|
|
||||||
|
# MLP
|
||||||
|
residual = x
|
||||||
|
x = self.post_attention_layernorm(x)
|
||||||
|
x = self.mlp(x)
|
||||||
|
x = residual + x
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
class Llama2_(nn.Module):
|
||||||
|
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.vocab_size = config.vocab_size
|
||||||
|
|
||||||
|
self.embed_tokens = ops.Embedding(
|
||||||
|
config.vocab_size,
|
||||||
|
config.hidden_size,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype
|
||||||
|
)
|
||||||
|
self.layers = nn.ModuleList([
|
||||||
|
TransformerBlock(config, device=device, dtype=dtype, ops=ops)
|
||||||
|
for _ in range(config.num_hidden_layers)
|
||||||
|
])
|
||||||
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
|
||||||
|
# self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
|
||||||
|
x = self.embed_tokens(x, out_dtype=dtype)
|
||||||
|
|
||||||
|
freqs_cis = precompute_freqs_cis(self.config.hidden_size // self.config.num_attention_heads,
|
||||||
|
x.shape[1],
|
||||||
|
self.config.rope_theta,
|
||||||
|
device=x.device)
|
||||||
|
|
||||||
|
mask = None
|
||||||
|
if attention_mask is not None:
|
||||||
|
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
|
||||||
|
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
|
||||||
|
|
||||||
|
causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
|
||||||
|
if mask is not None:
|
||||||
|
mask += causal_mask
|
||||||
|
else:
|
||||||
|
mask = causal_mask
|
||||||
|
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
|
||||||
|
|
||||||
|
intermediate = None
|
||||||
|
if intermediate_output is not None:
|
||||||
|
if intermediate_output < 0:
|
||||||
|
intermediate_output = len(self.layers) + intermediate_output
|
||||||
|
|
||||||
|
for i, layer in enumerate(self.layers):
|
||||||
|
x = layer(
|
||||||
|
x=x,
|
||||||
|
attention_mask=mask,
|
||||||
|
freqs_cis=freqs_cis,
|
||||||
|
optimized_attention=optimized_attention,
|
||||||
|
)
|
||||||
|
if i == intermediate_output:
|
||||||
|
intermediate = x.clone()
|
||||||
|
|
||||||
|
x = self.norm(x)
|
||||||
|
if intermediate is not None and final_layer_norm_intermediate:
|
||||||
|
intermediate = self.norm(intermediate)
|
||||||
|
|
||||||
|
return x, intermediate
|
||||||
|
|
||||||
|
|
||||||
|
class Llama2(torch.nn.Module):
|
||||||
|
def __init__(self, config_dict, dtype, device, operations):
|
||||||
|
super().__init__()
|
||||||
|
config = Llama2Config(**config_dict)
|
||||||
|
self.num_layers = config.num_hidden_layers
|
||||||
|
|
||||||
|
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||||
|
self.dtype = dtype
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.model.embed_tokens
|
||||||
|
|
||||||
|
def set_input_embeddings(self, embeddings):
|
||||||
|
self.model.embed_tokens = embeddings
|
||||||
|
|
||||||
|
def forward(self, input_ids, *args, **kwargs):
|
||||||
|
return self.model(input_ids, *args, **kwargs)
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user