Compare commits
127 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f1c2301697 | ||
|
|
8d31a6632f | ||
|
|
b643eae08b | ||
|
|
baa6b4dc36 | ||
|
|
d4aeefc297 | ||
|
|
587e7ca654 | ||
|
|
c90459eba0 | ||
|
|
04278afb10 | ||
|
|
935ae153e1 | ||
|
|
e91662e784 | ||
|
|
63fafaef45 | ||
|
|
ec28cd9136 | ||
|
|
6eb5d64522 | ||
|
|
10a79e9898 | ||
|
|
ea3f39bd69 | ||
|
|
b33cd61070 | ||
|
|
34eda0f853 | ||
|
|
d31e226650 | ||
|
|
b79fd7d92c | ||
|
|
38c22e631a | ||
|
|
6bbdcd28ae | ||
|
|
ab130001a8 | ||
|
|
ca4b8f30e0 | ||
|
|
70b84058c1 | ||
|
|
2ca8f6e23d | ||
|
|
7985ff88b9 | ||
|
|
c6812947e9 | ||
|
|
9230f65823 | ||
|
|
6ab1e6fd4a | ||
|
|
07dcbc3a3e | ||
|
|
8ae23d8e80 | ||
|
|
7df42b9a23 | ||
|
|
5d8bbb7281 | ||
|
|
2c1d2375d6 | ||
|
|
64ccb3c7e3 | ||
|
|
9465b23432 | ||
|
|
bb4416dd5b | ||
|
|
c0b0da264b | ||
|
|
c26ca27207 | ||
|
|
7c6bb84016 | ||
|
|
c54d3ed5e6 | ||
|
|
c7ee4b37a1 | ||
|
|
7b70b266d8 | ||
|
|
8f60d093ba | ||
|
|
dafbe321d2 | ||
|
|
5f84ea63e8 | ||
|
|
843a7ff70c | ||
|
|
a60620dcea | ||
|
|
015f73dc49 | ||
|
|
904bf58e7d | ||
|
|
5f50263088 | ||
|
|
5e806f555d | ||
|
|
f07e5bb522 | ||
|
|
03ec517afb | ||
|
|
f257fc999f | ||
|
|
bb50e69839 | ||
|
|
510f3438c1 | ||
|
|
ea63b1c092 | ||
|
|
9953f22fce | ||
|
|
d1a6bd6845 | ||
|
|
83dbac28eb | ||
|
|
538cb068bc | ||
|
|
1b3eee672c | ||
|
|
5a69f84c3c | ||
|
|
9eee470244 | ||
|
|
045377ea89 | ||
|
|
4d341b78e8 | ||
|
|
6138f92084 | ||
|
|
be0726c1ed | ||
|
|
766ae119a8 | ||
|
|
fc90ceb6ba | ||
|
|
4506ddc86a | ||
|
|
20ace7c853 | ||
|
|
b29b3b86c5 | ||
|
|
22ec02afc0 | ||
|
|
39f114c44b | ||
|
|
6730f3e1a3 | ||
|
|
73332160c8 | ||
|
|
2622c55aff | ||
|
|
1beb348ee2 | ||
|
|
9aa39e743c | ||
|
|
d31df04c8a | ||
|
|
e68763f40c | ||
|
|
310ad09258 | ||
|
|
4f7a3cb6fb | ||
|
|
bb222ceddb | ||
|
|
14af129c55 | ||
|
|
fca42836f2 | ||
|
|
858d51f91a | ||
|
|
cd5017c1c9 | ||
|
|
83f343146a | ||
|
|
b021cf67c7 | ||
|
|
1770fc77ed | ||
|
|
05a9f3faa1 | ||
|
|
86c5970ac0 | ||
|
|
bfc214f434 | ||
|
|
3f5939add6 | ||
|
|
5960f946a9 | ||
|
|
5cfe38f41c | ||
|
|
0f9c2a7822 | ||
|
|
153d0a8142 | ||
|
|
ab4dd19b91 | ||
|
|
f1d6cef71c | ||
|
|
33fb282d5c | ||
|
|
50bf66e5c4 | ||
|
|
e60e19b175 | ||
|
|
a5af64d3ce | ||
|
|
3e52e0364c | ||
|
|
34608de2e9 | ||
|
|
39fb74c5bd | ||
|
|
74e124f4d7 | ||
|
|
a562c17e8a | ||
|
|
5942c17d55 | ||
|
|
c032b11e07 | ||
|
|
b8ffb2937f | ||
|
|
ce37c11164 | ||
|
|
b5c3906b38 | ||
|
|
5d43e75e5b | ||
|
|
517f4a94e4 | ||
|
|
52a471c5c7 | ||
|
|
ad76574cb8 | ||
|
|
9acfe4df41 | ||
|
|
9829b013ea | ||
|
|
5c69cde037 | ||
|
|
e9589d6d92 | ||
|
|
0d82a798a5 | ||
|
|
925fff26fd |
@@ -75,6 +75,25 @@ else:
|
|||||||
print("pulling latest changes")
|
print("pulling latest changes")
|
||||||
pull(repo)
|
pull(repo)
|
||||||
|
|
||||||
|
if "--stable" in sys.argv:
|
||||||
|
def latest_tag(repo):
|
||||||
|
versions = []
|
||||||
|
for k in repo.references:
|
||||||
|
try:
|
||||||
|
prefix = "refs/tags/v"
|
||||||
|
if k.startswith(prefix):
|
||||||
|
version = list(map(int, k[len(prefix):].split(".")))
|
||||||
|
versions.append((version[0] * 10000000000 + version[1] * 100000 + version[2], k))
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
versions.sort()
|
||||||
|
if len(versions) > 0:
|
||||||
|
return versions[-1][1]
|
||||||
|
return None
|
||||||
|
latest_tag = latest_tag(repo)
|
||||||
|
if latest_tag is not None:
|
||||||
|
repo.checkout(latest_tag)
|
||||||
|
|
||||||
print("Done!")
|
print("Done!")
|
||||||
|
|
||||||
self_update = True
|
self_update = True
|
||||||
@@ -115,3 +134,13 @@ if not os.path.exists(req_path) or not files_equal(repo_req_path, req_path):
|
|||||||
shutil.copy(repo_req_path, req_path)
|
shutil.copy(repo_req_path, req_path)
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
stable_update_script = os.path.join(repo_path, ".ci/update_windows/update_comfyui_stable.bat")
|
||||||
|
stable_update_script_to = os.path.join(cur_path, "update_comfyui_stable.bat")
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not file_size(stable_update_script_to) > 10:
|
||||||
|
shutil.copy(stable_update_script, stable_update_script_to)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|||||||
8
.ci/update_windows/update_comfyui_stable.bat
Executable file
8
.ci/update_windows/update_comfyui_stable.bat
Executable file
@@ -0,0 +1,8 @@
|
|||||||
|
@echo off
|
||||||
|
..\python_embeded\python.exe .\update.py ..\ComfyUI\ --stable
|
||||||
|
if exist update_new.py (
|
||||||
|
move /y update_new.py update.py
|
||||||
|
echo Running updater again since it got updated.
|
||||||
|
..\python_embeded\python.exe .\update.py ..\ComfyUI\ --skip_self_update --stable
|
||||||
|
)
|
||||||
|
if "%~1"=="" pause
|
||||||
@@ -14,7 +14,7 @@ run_cpu.bat
|
|||||||
|
|
||||||
IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints
|
IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints
|
||||||
|
|
||||||
You can download the stable diffusion 1.5 one from: https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt
|
You can download the stable diffusion 1.5 one from: https://huggingface.co/Comfy-Org/stable-diffusion-v1-5-archive/blob/main/v1-5-pruned-emaonly-fp16.safetensors
|
||||||
|
|
||||||
|
|
||||||
RECOMMENDED WAY TO UPDATE:
|
RECOMMENDED WAY TO UPDATE:
|
||||||
|
|||||||
2
.ci/windows_nightly_base_files/run_nvidia_gpu_fast.bat
Normal file
2
.ci/windows_nightly_base_files/run_nvidia_gpu_fast.bat
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast
|
||||||
|
pause
|
||||||
2
.gitattributes
vendored
Normal file
2
.gitattributes
vendored
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
/web/assets/** linguist-generated
|
||||||
|
/web/** linguist-vendored
|
||||||
3
.github/ISSUE_TEMPLATE/config.yml
vendored
3
.github/ISSUE_TEMPLATE/config.yml
vendored
@@ -1,5 +1,8 @@
|
|||||||
blank_issues_enabled: true
|
blank_issues_enabled: true
|
||||||
contact_links:
|
contact_links:
|
||||||
|
- name: ComfyUI Frontend Issues
|
||||||
|
url: https://github.com/Comfy-Org/ComfyUI_frontend/issues
|
||||||
|
about: Issues related to the ComfyUI frontend (display issues, user interaction bugs), please go to the frontend repo to file the issue
|
||||||
- name: ComfyUI Matrix Space
|
- name: ComfyUI Matrix Space
|
||||||
url: https://app.element.io/#/room/%23comfyui_space%3Amatrix.org
|
url: https://app.element.io/#/room/%23comfyui_space%3Amatrix.org
|
||||||
about: The ComfyUI Matrix Space is available for support and general discussion related to ComfyUI (Matrix is like Discord but open source).
|
about: The ComfyUI Matrix Space is available for support and general discussion related to ComfyUI (Matrix is like Discord but open source).
|
||||||
|
|||||||
16
.github/workflows/pullrequest-ci-run.yml
vendored
16
.github/workflows/pullrequest-ci-run.yml
vendored
@@ -35,3 +35,19 @@ jobs:
|
|||||||
torch_version: ${{ matrix.torch_version }}
|
torch_version: ${{ matrix.torch_version }}
|
||||||
google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
|
google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
|
||||||
comfyui_flags: ${{ matrix.flags }}
|
comfyui_flags: ${{ matrix.flags }}
|
||||||
|
use_prior_commit: 'true'
|
||||||
|
comment:
|
||||||
|
if: ${{ github.event.label.name == 'Run-CI-Test' }}
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
permissions:
|
||||||
|
pull-requests: write
|
||||||
|
steps:
|
||||||
|
- uses: actions/github-script@v6
|
||||||
|
with:
|
||||||
|
script: |
|
||||||
|
github.rest.issues.createComment({
|
||||||
|
issue_number: context.issue.number,
|
||||||
|
owner: context.repo.owner,
|
||||||
|
repo: context.repo.repo,
|
||||||
|
body: '(Automated Bot Message) CI Tests are running, you can view the results at https://ci.comfy.org/?branch=${{ github.event.pull_request.number }}%2Fmerge'
|
||||||
|
})
|
||||||
|
|||||||
21
.github/workflows/stale-issues.yml
vendored
Normal file
21
.github/workflows/stale-issues.yml
vendored
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
name: 'Close stale issues'
|
||||||
|
on:
|
||||||
|
schedule:
|
||||||
|
# Run daily at 430 am PT
|
||||||
|
- cron: '30 11 * * *'
|
||||||
|
permissions:
|
||||||
|
issues: write
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
stale:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/stale@v9
|
||||||
|
with:
|
||||||
|
stale-issue-message: "This issue is being marked stale because it has not had any activity for 30 days. Reply below within 7 days if your issue still isn't solved, and it will be left open. Otherwise, the issue will be closed automatically."
|
||||||
|
days-before-stale: 30
|
||||||
|
days-before-close: 7
|
||||||
|
stale-issue-label: 'Stale'
|
||||||
|
only-labels: 'User Support'
|
||||||
|
exempt-all-assignees: true
|
||||||
|
exempt-all-milestones: true
|
||||||
@@ -1,10 +1,4 @@
|
|||||||
# This is a temporary action during frontend TS migration.
|
name: Test server launches without errors
|
||||||
# This file should be removed after TS migration is completed.
|
|
||||||
# The browser test is here to ensure TS repo is working the same way as the
|
|
||||||
# current JS code.
|
|
||||||
# If you are adding UI feature, please sync your changes to the TS repo:
|
|
||||||
# huchenlei/ComfyUI_frontend and update test expectation files accordingly.
|
|
||||||
name: Playwright Browser Tests CI
|
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
@@ -21,15 +15,6 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
repository: "comfyanonymous/ComfyUI"
|
repository: "comfyanonymous/ComfyUI"
|
||||||
path: "ComfyUI"
|
path: "ComfyUI"
|
||||||
- name: Checkout ComfyUI_frontend
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
with:
|
|
||||||
repository: "huchenlei/ComfyUI_frontend"
|
|
||||||
path: "ComfyUI_frontend"
|
|
||||||
ref: "fcc54d803e5b6a9b08a462a1d94899318c96dcbb"
|
|
||||||
- uses: actions/setup-node@v3
|
|
||||||
with:
|
|
||||||
node-version: lts/*
|
|
||||||
- uses: actions/setup-python@v4
|
- uses: actions/setup-python@v4
|
||||||
with:
|
with:
|
||||||
python-version: '3.8'
|
python-version: '3.8'
|
||||||
@@ -45,16 +30,6 @@ jobs:
|
|||||||
python main.py --cpu 2>&1 | tee console_output.log &
|
python main.py --cpu 2>&1 | tee console_output.log &
|
||||||
wait-for-it --service 127.0.0.1:8188 -t 600
|
wait-for-it --service 127.0.0.1:8188 -t 600
|
||||||
working-directory: ComfyUI
|
working-directory: ComfyUI
|
||||||
- name: Install ComfyUI_frontend dependencies
|
|
||||||
run: |
|
|
||||||
npm ci
|
|
||||||
working-directory: ComfyUI_frontend
|
|
||||||
- name: Install Playwright Browsers
|
|
||||||
run: npx playwright install --with-deps
|
|
||||||
working-directory: ComfyUI_frontend
|
|
||||||
- name: Run Playwright tests
|
|
||||||
run: npx playwright test
|
|
||||||
working-directory: ComfyUI_frontend
|
|
||||||
- name: Check for unhandled exceptions in server log
|
- name: Check for unhandled exceptions in server log
|
||||||
run: |
|
run: |
|
||||||
if grep -qE "Exception|Error" console_output.log; then
|
if grep -qE "Exception|Error" console_output.log; then
|
||||||
@@ -62,12 +37,6 @@ jobs:
|
|||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
working-directory: ComfyUI
|
working-directory: ComfyUI
|
||||||
- uses: actions/upload-artifact@v4
|
|
||||||
if: always()
|
|
||||||
with:
|
|
||||||
name: playwright-report
|
|
||||||
path: ComfyUI_frontend/playwright-report/
|
|
||||||
retention-days: 30
|
|
||||||
- uses: actions/upload-artifact@v4
|
- uses: actions/upload-artifact@v4
|
||||||
if: always()
|
if: always()
|
||||||
with:
|
with:
|
||||||
30
.github/workflows/test-ui.yaml
vendored
30
.github/workflows/test-ui.yaml
vendored
@@ -1,30 +0,0 @@
|
|||||||
name: Tests CI
|
|
||||||
|
|
||||||
on: [push, pull_request]
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
test:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- uses: actions/checkout@v4
|
|
||||||
- uses: actions/setup-node@v3
|
|
||||||
with:
|
|
||||||
node-version: 18
|
|
||||||
- uses: actions/setup-python@v4
|
|
||||||
with:
|
|
||||||
python-version: '3.10'
|
|
||||||
- name: Install requirements
|
|
||||||
run: |
|
|
||||||
python -m pip install --upgrade pip
|
|
||||||
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
|
||||||
pip install -r requirements.txt
|
|
||||||
- name: Run Tests
|
|
||||||
run: |
|
|
||||||
npm ci
|
|
||||||
npm run test:generate
|
|
||||||
npm test -- --verbose
|
|
||||||
working-directory: ./tests-ui
|
|
||||||
- name: Run Unit Tests
|
|
||||||
run: |
|
|
||||||
pip install -r tests-unit/requirements.txt
|
|
||||||
python -m pytest tests-unit
|
|
||||||
@@ -67,6 +67,7 @@ jobs:
|
|||||||
mkdir update
|
mkdir update
|
||||||
cp -r ComfyUI/.ci/update_windows/* ./update/
|
cp -r ComfyUI/.ci/update_windows/* ./update/
|
||||||
cp -r ComfyUI/.ci/windows_base_files/* ./
|
cp -r ComfyUI/.ci/windows_base_files/* ./
|
||||||
|
cp -r ComfyUI/.ci/windows_nightly_base_files/* ./
|
||||||
|
|
||||||
echo "call update_comfyui.bat nopause
|
echo "call update_comfyui.bat nopause
|
||||||
..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2
|
..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2
|
||||||
|
|||||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -18,4 +18,5 @@ venv/
|
|||||||
/tests-ui/data/object_info.json
|
/tests-ui/data/object_info.json
|
||||||
/user/
|
/user/
|
||||||
*.log
|
*.log
|
||||||
web_custom_versions/
|
web_custom_versions/
|
||||||
|
.DS_Store
|
||||||
|
|||||||
89
README.md
89
README.md
@@ -1,8 +1,35 @@
|
|||||||
ComfyUI
|
<div align="center">
|
||||||
=======
|
|
||||||
The most powerful and modular stable diffusion GUI and backend.
|
# ComfyUI
|
||||||
-----------
|
**The most powerful and modular diffusion model GUI and backend.**
|
||||||
|
|
||||||
|
|
||||||
|
[![Website][website-shield]][website-url]
|
||||||
|
[![Dynamic JSON Badge][discord-shield]][discord-url]
|
||||||
|
[![Matrix][matrix-shield]][matrix-url]
|
||||||
|
<br>
|
||||||
|
[![][github-release-shield]][github-release-link]
|
||||||
|
[![][github-release-date-shield]][github-release-link]
|
||||||
|
[![][github-downloads-shield]][github-downloads-link]
|
||||||
|
[![][github-downloads-latest-shield]][github-downloads-link]
|
||||||
|
|
||||||
|
[matrix-shield]: https://img.shields.io/badge/Matrix-000000?style=flat&logo=matrix&logoColor=white
|
||||||
|
[matrix-url]: https://app.element.io/#/room/%23comfyui_space%3Amatrix.org
|
||||||
|
[website-shield]: https://img.shields.io/badge/ComfyOrg-4285F4?style=flat
|
||||||
|
[website-url]: https://www.comfy.org/
|
||||||
|
<!-- Workaround to display total user from https://github.com/badges/shields/issues/4500#issuecomment-2060079995 -->
|
||||||
|
[discord-shield]: https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fdiscord.com%2Fapi%2Finvites%2Fcomfyorg%3Fwith_counts%3Dtrue&query=%24.approximate_member_count&logo=discord&logoColor=white&label=Discord&color=green&suffix=%20total
|
||||||
|
[discord-url]: https://www.comfy.org/discord
|
||||||
|
|
||||||
|
[github-release-shield]: https://img.shields.io/github/v/release/comfyanonymous/ComfyUI?style=flat&sort=semver
|
||||||
|
[github-release-link]: https://github.com/comfyanonymous/ComfyUI/releases
|
||||||
|
[github-release-date-shield]: https://img.shields.io/github/release-date/comfyanonymous/ComfyUI?style=flat
|
||||||
|
[github-downloads-shield]: https://img.shields.io/github/downloads/comfyanonymous/ComfyUI/total?style=flat
|
||||||
|
[github-downloads-latest-shield]: https://img.shields.io/github/downloads/comfyanonymous/ComfyUI/latest/total?style=flat&label=downloads%40latest
|
||||||
|
[github-downloads-link]: https://github.com/comfyanonymous/ComfyUI/releases
|
||||||
|
|
||||||

|

|
||||||
|
</div>
|
||||||
|
|
||||||
This ui will let you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. For some workflow examples and see what ComfyUI can do you can check out:
|
This ui will let you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. For some workflow examples and see what ComfyUI can do you can check out:
|
||||||
### [ComfyUI Examples](https://comfyanonymous.github.io/ComfyUI_examples/)
|
### [ComfyUI Examples](https://comfyanonymous.github.io/ComfyUI_examples/)
|
||||||
@@ -48,6 +75,7 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
|
|||||||
|------------------------------------|--------------------------------------------------------------------------------------------------------------------|
|
|------------------------------------|--------------------------------------------------------------------------------------------------------------------|
|
||||||
| Ctrl + Enter | Queue up current graph for generation |
|
| Ctrl + Enter | Queue up current graph for generation |
|
||||||
| Ctrl + Shift + Enter | Queue up current graph as first for generation |
|
| Ctrl + Shift + Enter | Queue up current graph as first for generation |
|
||||||
|
| Ctrl + Alt + Enter | Cancel current generation |
|
||||||
| Ctrl + Z/Ctrl + Y | Undo/Redo |
|
| Ctrl + Z/Ctrl + Y | Undo/Redo |
|
||||||
| Ctrl + S | Save workflow |
|
| Ctrl + S | Save workflow |
|
||||||
| Ctrl + O | Load workflow |
|
| Ctrl + O | Load workflow |
|
||||||
@@ -70,6 +98,8 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
|
|||||||
| H | Toggle visibility of history |
|
| H | Toggle visibility of history |
|
||||||
| R | Refresh graph |
|
| R | Refresh graph |
|
||||||
| Double-Click LMB | Open node quick search palette |
|
| Double-Click LMB | Open node quick search palette |
|
||||||
|
| Shift + Drag | Move multiple wires at once |
|
||||||
|
| Ctrl + Alt + LMB | Disconnect all wires from clicked slot |
|
||||||
|
|
||||||
Ctrl can also be replaced with Cmd instead for macOS users
|
Ctrl can also be replaced with Cmd instead for macOS users
|
||||||
|
|
||||||
@@ -105,17 +135,17 @@ Put your VAE in: models/vae
|
|||||||
### AMD GPUs (Linux only)
|
### AMD GPUs (Linux only)
|
||||||
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
|
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
|
||||||
|
|
||||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.0```
|
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.1```
|
||||||
|
|
||||||
This is the command to install the nightly with ROCm 6.0 which might have some performance improvements:
|
This is the command to install the nightly with ROCm 6.2 which might have some performance improvements:
|
||||||
|
|
||||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.1```
|
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.2```
|
||||||
|
|
||||||
### NVIDIA
|
### NVIDIA
|
||||||
|
|
||||||
Nvidia users should install stable pytorch using this command:
|
Nvidia users should install stable pytorch using this command:
|
||||||
|
|
||||||
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu121```
|
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu124```
|
||||||
|
|
||||||
This is the command to install pytorch nightly instead which might have performance improvements:
|
This is the command to install pytorch nightly instead which might have performance improvements:
|
||||||
|
|
||||||
@@ -200,7 +230,7 @@ To use a textual inversion concepts/embeddings in a text prompt put them in the
|
|||||||
|
|
||||||
Use ```--preview-method auto``` to enable previews.
|
Use ```--preview-method auto``` to enable previews.
|
||||||
|
|
||||||
The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth) (for SD1.x and SD2.x) and [taesdxl_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesdxl_decoder.pth) (for SDXL) models and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI to enable high-quality previews.
|
The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_decoder.pth, taesdxl_decoder.pth, taesd3_decoder.pth and taef1_decoder.pth](https://github.com/madebyollin/taesd/) and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI and launch it with `--preview-method taesd` to enable high-quality previews.
|
||||||
|
|
||||||
## How to use TLS/SSL?
|
## How to use TLS/SSL?
|
||||||
Generate a self-signed certificate (not appropriate for shared/production use) and key by running the command: `openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem -sha256 -days 3650 -nodes -subj "/C=XX/ST=StateName/L=CityName/O=CompanyName/OU=CompanySectionName/CN=CommonNameOrHostname"`
|
Generate a self-signed certificate (not appropriate for shared/production use) and key by running the command: `openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem -sha256 -days 3650 -nodes -subj "/C=XX/ST=StateName/L=CityName/O=CompanyName/OU=CompanySectionName/CN=CommonNameOrHostname"`
|
||||||
@@ -216,6 +246,47 @@ Use `--tls-keyfile key.pem --tls-certfile cert.pem` to enable TLS/SSL, the app w
|
|||||||
|
|
||||||
See also: [https://www.comfy.org/](https://www.comfy.org/)
|
See also: [https://www.comfy.org/](https://www.comfy.org/)
|
||||||
|
|
||||||
|
## Frontend Development
|
||||||
|
|
||||||
|
As of August 15, 2024, we have transitioned to a new frontend, which is now hosted in a separate repository: [ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend). This repository now hosts the compiled JS (from TS/Vue) under the `web/` directory.
|
||||||
|
|
||||||
|
### Reporting Issues and Requesting Features
|
||||||
|
|
||||||
|
For any bugs, issues, or feature requests related to the frontend, please use the [ComfyUI Frontend repository](https://github.com/Comfy-Org/ComfyUI_frontend). This will help us manage and address frontend-specific concerns more efficiently.
|
||||||
|
|
||||||
|
### Using the Latest Frontend
|
||||||
|
|
||||||
|
The new frontend is now the default for ComfyUI. However, please note:
|
||||||
|
|
||||||
|
1. The frontend in the main ComfyUI repository is updated weekly.
|
||||||
|
2. Daily releases are available in the separate frontend repository.
|
||||||
|
|
||||||
|
To use the most up-to-date frontend version:
|
||||||
|
|
||||||
|
1. For the latest daily release, launch ComfyUI with this command line argument:
|
||||||
|
|
||||||
|
```
|
||||||
|
--front-end-version Comfy-Org/ComfyUI_frontend@latest
|
||||||
|
```
|
||||||
|
|
||||||
|
2. For a specific version, replace `latest` with the desired version number:
|
||||||
|
|
||||||
|
```
|
||||||
|
--front-end-version Comfy-Org/ComfyUI_frontend@1.2.2
|
||||||
|
```
|
||||||
|
|
||||||
|
This approach allows you to easily switch between the stable weekly release and the cutting-edge daily updates, or even specific versions for testing purposes.
|
||||||
|
|
||||||
|
### Accessing the Legacy Frontend
|
||||||
|
|
||||||
|
If you need to use the legacy frontend for any reason, you can access it using the following command line argument:
|
||||||
|
|
||||||
|
```
|
||||||
|
--front-end-version Comfy-Org/ComfyUI_legacy_frontend@latest
|
||||||
|
```
|
||||||
|
|
||||||
|
This will use a snapshot of the legacy frontend preserved in the [ComfyUI Legacy Frontend repository](https://github.com/Comfy-Org/ComfyUI_legacy_frontend).
|
||||||
|
|
||||||
# QA
|
# QA
|
||||||
|
|
||||||
### Which GPU should I buy for this?
|
### Which GPU should I buy for this?
|
||||||
|
|||||||
0
api_server/__init__.py
Normal file
0
api_server/__init__.py
Normal file
0
api_server/routes/__init__.py
Normal file
0
api_server/routes/__init__.py
Normal file
3
api_server/routes/internal/README.md
Normal file
3
api_server/routes/internal/README.md
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
# ComfyUI Internal Routes
|
||||||
|
|
||||||
|
All routes under the `/internal` path are designated for **internal use by ComfyUI only**. These routes are not intended for use by external applications may change at any time without notice.
|
||||||
0
api_server/routes/internal/__init__.py
Normal file
0
api_server/routes/internal/__init__.py
Normal file
44
api_server/routes/internal/internal_routes.py
Normal file
44
api_server/routes/internal/internal_routes.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
from aiohttp import web
|
||||||
|
from typing import Optional
|
||||||
|
from folder_paths import models_dir, user_directory, output_directory
|
||||||
|
from api_server.services.file_service import FileService
|
||||||
|
import app.logger
|
||||||
|
|
||||||
|
class InternalRoutes:
|
||||||
|
'''
|
||||||
|
The top level web router for internal routes: /internal/*
|
||||||
|
The endpoints here should NOT be depended upon. It is for ComfyUI frontend use only.
|
||||||
|
Check README.md for more information.
|
||||||
|
|
||||||
|
'''
|
||||||
|
def __init__(self):
|
||||||
|
self.routes: web.RouteTableDef = web.RouteTableDef()
|
||||||
|
self._app: Optional[web.Application] = None
|
||||||
|
self.file_service = FileService({
|
||||||
|
"models": models_dir,
|
||||||
|
"user": user_directory,
|
||||||
|
"output": output_directory
|
||||||
|
})
|
||||||
|
|
||||||
|
def setup_routes(self):
|
||||||
|
@self.routes.get('/files')
|
||||||
|
async def list_files(request):
|
||||||
|
directory_key = request.query.get('directory', '')
|
||||||
|
try:
|
||||||
|
file_list = self.file_service.list_files(directory_key)
|
||||||
|
return web.json_response({"files": file_list})
|
||||||
|
except ValueError as e:
|
||||||
|
return web.json_response({"error": str(e)}, status=400)
|
||||||
|
except Exception as e:
|
||||||
|
return web.json_response({"error": str(e)}, status=500)
|
||||||
|
|
||||||
|
@self.routes.get('/logs')
|
||||||
|
async def get_logs(request):
|
||||||
|
return web.json_response(app.logger.get_logs())
|
||||||
|
|
||||||
|
def get_app(self):
|
||||||
|
if self._app is None:
|
||||||
|
self._app = web.Application()
|
||||||
|
self.setup_routes()
|
||||||
|
self._app.add_routes(self.routes)
|
||||||
|
return self._app
|
||||||
0
api_server/services/__init__.py
Normal file
0
api_server/services/__init__.py
Normal file
13
api_server/services/file_service.py
Normal file
13
api_server/services/file_service.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
from typing import Dict, List, Optional
|
||||||
|
from api_server.utils.file_operations import FileSystemOperations, FileSystemItem
|
||||||
|
|
||||||
|
class FileService:
|
||||||
|
def __init__(self, allowed_directories: Dict[str, str], file_system_ops: Optional[FileSystemOperations] = None):
|
||||||
|
self.allowed_directories: Dict[str, str] = allowed_directories
|
||||||
|
self.file_system_ops: FileSystemOperations = file_system_ops or FileSystemOperations()
|
||||||
|
|
||||||
|
def list_files(self, directory_key: str) -> List[FileSystemItem]:
|
||||||
|
if directory_key not in self.allowed_directories:
|
||||||
|
raise ValueError("Invalid directory key")
|
||||||
|
directory_path: str = self.allowed_directories[directory_key]
|
||||||
|
return self.file_system_ops.walk_directory(directory_path)
|
||||||
42
api_server/utils/file_operations.py
Normal file
42
api_server/utils/file_operations.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
import os
|
||||||
|
from typing import List, Union, TypedDict, Literal
|
||||||
|
from typing_extensions import TypeGuard
|
||||||
|
class FileInfo(TypedDict):
|
||||||
|
name: str
|
||||||
|
path: str
|
||||||
|
type: Literal["file"]
|
||||||
|
size: int
|
||||||
|
|
||||||
|
class DirectoryInfo(TypedDict):
|
||||||
|
name: str
|
||||||
|
path: str
|
||||||
|
type: Literal["directory"]
|
||||||
|
|
||||||
|
FileSystemItem = Union[FileInfo, DirectoryInfo]
|
||||||
|
|
||||||
|
def is_file_info(item: FileSystemItem) -> TypeGuard[FileInfo]:
|
||||||
|
return item["type"] == "file"
|
||||||
|
|
||||||
|
class FileSystemOperations:
|
||||||
|
@staticmethod
|
||||||
|
def walk_directory(directory: str) -> List[FileSystemItem]:
|
||||||
|
file_list: List[FileSystemItem] = []
|
||||||
|
for root, dirs, files in os.walk(directory):
|
||||||
|
for name in files:
|
||||||
|
file_path = os.path.join(root, name)
|
||||||
|
relative_path = os.path.relpath(file_path, directory)
|
||||||
|
file_list.append({
|
||||||
|
"name": name,
|
||||||
|
"path": relative_path,
|
||||||
|
"type": "file",
|
||||||
|
"size": os.path.getsize(file_path)
|
||||||
|
})
|
||||||
|
for name in dirs:
|
||||||
|
dir_path = os.path.join(root, name)
|
||||||
|
relative_path = os.path.relpath(dir_path, directory)
|
||||||
|
file_list.append({
|
||||||
|
"name": name,
|
||||||
|
"path": relative_path,
|
||||||
|
"type": "directory"
|
||||||
|
})
|
||||||
|
return file_list
|
||||||
@@ -8,7 +8,7 @@ import zipfile
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TypedDict
|
from typing import TypedDict, Optional
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from typing_extensions import NotRequired
|
from typing_extensions import NotRequired
|
||||||
@@ -132,12 +132,13 @@ class FrontendManager:
|
|||||||
return match_result.group(1), match_result.group(2), match_result.group(3)
|
return match_result.group(1), match_result.group(2), match_result.group(3)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def init_frontend_unsafe(cls, version_string: str) -> str:
|
def init_frontend_unsafe(cls, version_string: str, provider: Optional[FrontEndProvider] = None) -> str:
|
||||||
"""
|
"""
|
||||||
Initializes the frontend for the specified version.
|
Initializes the frontend for the specified version.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
version_string (str): The version string.
|
version_string (str): The version string.
|
||||||
|
provider (FrontEndProvider, optional): The provider to use. Defaults to None.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: The path to the initialized frontend.
|
str: The path to the initialized frontend.
|
||||||
@@ -150,7 +151,7 @@ class FrontendManager:
|
|||||||
return cls.DEFAULT_FRONTEND_PATH
|
return cls.DEFAULT_FRONTEND_PATH
|
||||||
|
|
||||||
repo_owner, repo_name, version = cls.parse_version_string(version_string)
|
repo_owner, repo_name, version = cls.parse_version_string(version_string)
|
||||||
provider = FrontEndProvider(repo_owner, repo_name)
|
provider = provider or FrontEndProvider(repo_owner, repo_name)
|
||||||
release = provider.get_release(version)
|
release = provider.get_release(version)
|
||||||
|
|
||||||
semantic_version = release["tag_name"].lstrip("v")
|
semantic_version = release["tag_name"].lstrip("v")
|
||||||
@@ -158,15 +159,21 @@ class FrontendManager:
|
|||||||
Path(cls.CUSTOM_FRONTENDS_ROOT) / provider.folder_name / semantic_version
|
Path(cls.CUSTOM_FRONTENDS_ROOT) / provider.folder_name / semantic_version
|
||||||
)
|
)
|
||||||
if not os.path.exists(web_root):
|
if not os.path.exists(web_root):
|
||||||
os.makedirs(web_root, exist_ok=True)
|
try:
|
||||||
logging.info(
|
os.makedirs(web_root, exist_ok=True)
|
||||||
"Downloading frontend(%s) version(%s) to (%s)",
|
logging.info(
|
||||||
provider.folder_name,
|
"Downloading frontend(%s) version(%s) to (%s)",
|
||||||
semantic_version,
|
provider.folder_name,
|
||||||
web_root,
|
semantic_version,
|
||||||
)
|
web_root,
|
||||||
logging.debug(release)
|
)
|
||||||
download_release_asset_zip(release, destination_path=web_root)
|
logging.debug(release)
|
||||||
|
download_release_asset_zip(release, destination_path=web_root)
|
||||||
|
finally:
|
||||||
|
# Clean up the directory if it is empty, i.e. the download failed
|
||||||
|
if not os.listdir(web_root):
|
||||||
|
os.rmdir(web_root)
|
||||||
|
|
||||||
return web_root
|
return web_root
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
31
app/logger.py
Normal file
31
app/logger.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
import logging
|
||||||
|
from logging.handlers import MemoryHandler
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
|
logs = None
|
||||||
|
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||||
|
|
||||||
|
|
||||||
|
def get_logs():
|
||||||
|
return "\n".join([formatter.format(x) for x in logs])
|
||||||
|
|
||||||
|
|
||||||
|
def setup_logger(verbose: bool = False, capacity: int = 300):
|
||||||
|
global logs
|
||||||
|
if logs:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Setup default global logger
|
||||||
|
logger = logging.getLogger()
|
||||||
|
logger.setLevel(logging.DEBUG if verbose else logging.INFO)
|
||||||
|
|
||||||
|
stream_handler = logging.StreamHandler()
|
||||||
|
stream_handler.setFormatter(logging.Formatter("%(message)s"))
|
||||||
|
logger.addHandler(stream_handler)
|
||||||
|
|
||||||
|
# Create a memory handler with a deque as its buffer
|
||||||
|
logs = deque(maxlen=capacity)
|
||||||
|
memory_handler = MemoryHandler(capacity, flushLevel=logging.INFO)
|
||||||
|
memory_handler.buffer = logs
|
||||||
|
memory_handler.setFormatter(formatter)
|
||||||
|
logger.addHandler(memory_handler)
|
||||||
@@ -92,6 +92,10 @@ class LatentPreviewMethod(enum.Enum):
|
|||||||
|
|
||||||
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
|
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
|
||||||
|
|
||||||
|
cache_group = parser.add_mutually_exclusive_group()
|
||||||
|
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
|
||||||
|
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
|
||||||
|
|
||||||
attn_group = parser.add_mutually_exclusive_group()
|
attn_group = parser.add_mutually_exclusive_group()
|
||||||
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
|
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
|
||||||
attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
|
attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
|
||||||
@@ -112,10 +116,14 @@ vram_group.add_argument("--lowvram", action="store_true", help="Split the unet i
|
|||||||
vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
|
vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
|
||||||
vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
|
vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
|
||||||
|
|
||||||
|
parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reverved depending on your OS.")
|
||||||
|
|
||||||
|
|
||||||
parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
|
parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
|
||||||
|
|
||||||
parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
|
parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
|
||||||
parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
|
parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
|
||||||
|
parser.add_argument("--fast", action="store_true", help="Enable some untested and potentially quality deteriorating optimizations.")
|
||||||
|
|
||||||
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
|
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
|
||||||
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
|
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
|
||||||
@@ -171,10 +179,3 @@ if args.windows_standalone_build:
|
|||||||
|
|
||||||
if args.disable_auto_launch:
|
if args.disable_auto_launch:
|
||||||
args.auto_launch = False
|
args.auto_launch = False
|
||||||
|
|
||||||
import logging
|
|
||||||
logging_level = logging.INFO
|
|
||||||
if args.verbose:
|
|
||||||
logging_level = logging.DEBUG
|
|
||||||
|
|
||||||
logging.basicConfig(format="%(message)s", level=logging_level)
|
|
||||||
|
|||||||
@@ -88,10 +88,11 @@ class CLIPTextModel_(torch.nn.Module):
|
|||||||
heads = config_dict["num_attention_heads"]
|
heads = config_dict["num_attention_heads"]
|
||||||
intermediate_size = config_dict["intermediate_size"]
|
intermediate_size = config_dict["intermediate_size"]
|
||||||
intermediate_activation = config_dict["hidden_act"]
|
intermediate_activation = config_dict["hidden_act"]
|
||||||
|
num_positions = config_dict["max_position_embeddings"]
|
||||||
self.eos_token_id = config_dict["eos_token_id"]
|
self.eos_token_id = config_dict["eos_token_id"]
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.embeddings = CLIPEmbeddings(embed_dim, dtype=dtype, device=device, operations=operations)
|
self.embeddings = CLIPEmbeddings(embed_dim, num_positions=num_positions, dtype=dtype, device=device, operations=operations)
|
||||||
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
|
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
|
||||||
self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
|
self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
|
||||||
|
|
||||||
@@ -123,7 +124,6 @@ class CLIPTextModel(torch.nn.Module):
|
|||||||
self.text_model = CLIPTextModel_(config_dict, dtype, device, operations)
|
self.text_model = CLIPTextModel_(config_dict, dtype, device, operations)
|
||||||
embed_dim = config_dict["hidden_size"]
|
embed_dim = config_dict["hidden_size"]
|
||||||
self.text_projection = operations.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
|
self.text_projection = operations.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
|
||||||
self.text_projection.weight.copy_(torch.eye(embed_dim))
|
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
|
|
||||||
def get_input_embeddings(self):
|
def get_input_embeddings(self):
|
||||||
|
|||||||
@@ -34,6 +34,8 @@ import comfy.t2i_adapter.adapter
|
|||||||
import comfy.ldm.cascade.controlnet
|
import comfy.ldm.cascade.controlnet
|
||||||
import comfy.cldm.mmdit
|
import comfy.cldm.mmdit
|
||||||
import comfy.ldm.hydit.controlnet
|
import comfy.ldm.hydit.controlnet
|
||||||
|
import comfy.ldm.flux.controlnet
|
||||||
|
|
||||||
|
|
||||||
def broadcast_image_to(tensor, target_batch_size, batched_number):
|
def broadcast_image_to(tensor, target_batch_size, batched_number):
|
||||||
current_batch_size = tensor.shape[0]
|
current_batch_size = tensor.shape[0]
|
||||||
@@ -146,7 +148,7 @@ class ControlBase:
|
|||||||
elif self.strength_type == StrengthType.LINEAR_UP:
|
elif self.strength_type == StrengthType.LINEAR_UP:
|
||||||
x *= (self.strength ** float(len(control_output) - i))
|
x *= (self.strength ** float(len(control_output) - i))
|
||||||
|
|
||||||
if x.dtype != output_dtype:
|
if output_dtype is not None and x.dtype != output_dtype:
|
||||||
x = x.to(output_dtype)
|
x = x.to(output_dtype)
|
||||||
|
|
||||||
out[key].append(x)
|
out[key].append(x)
|
||||||
@@ -204,7 +206,6 @@ class ControlNet(ControlBase):
|
|||||||
if self.manual_cast_dtype is not None:
|
if self.manual_cast_dtype is not None:
|
||||||
dtype = self.manual_cast_dtype
|
dtype = self.manual_cast_dtype
|
||||||
|
|
||||||
output_dtype = x_noisy.dtype
|
|
||||||
if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
|
if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
|
||||||
if self.cond_hint is not None:
|
if self.cond_hint is not None:
|
||||||
del self.cond_hint
|
del self.cond_hint
|
||||||
@@ -234,7 +235,7 @@ class ControlNet(ControlBase):
|
|||||||
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
|
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
|
||||||
|
|
||||||
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra)
|
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra)
|
||||||
return self.control_merge(control, control_prev, output_dtype)
|
return self.control_merge(control, control_prev, output_dtype=None)
|
||||||
|
|
||||||
def copy(self):
|
def copy(self):
|
||||||
c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
|
c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
|
||||||
@@ -389,7 +390,8 @@ def controlnet_config(sd):
|
|||||||
else:
|
else:
|
||||||
operations = comfy.ops.disable_weight_init
|
operations = comfy.ops.disable_weight_init
|
||||||
|
|
||||||
return model_config, operations, load_device, unet_dtype, manual_cast_dtype
|
offload_device = comfy.model_management.unet_offload_device()
|
||||||
|
return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device
|
||||||
|
|
||||||
def controlnet_load_state_dict(control_model, sd):
|
def controlnet_load_state_dict(control_model, sd):
|
||||||
missing, unexpected = control_model.load_state_dict(sd, strict=False)
|
missing, unexpected = control_model.load_state_dict(sd, strict=False)
|
||||||
@@ -403,12 +405,12 @@ def controlnet_load_state_dict(control_model, sd):
|
|||||||
|
|
||||||
def load_controlnet_mmdit(sd):
|
def load_controlnet_mmdit(sd):
|
||||||
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
|
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
|
||||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(new_sd)
|
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd)
|
||||||
num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
|
num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
|
||||||
for k in sd:
|
for k in sd:
|
||||||
new_sd[k] = sd[k]
|
new_sd[k] = sd[k]
|
||||||
|
|
||||||
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config)
|
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||||
control_model = controlnet_load_state_dict(control_model, new_sd)
|
control_model = controlnet_load_state_dict(control_model, new_sd)
|
||||||
|
|
||||||
latent_format = comfy.latent_formats.SD3()
|
latent_format = comfy.latent_formats.SD3()
|
||||||
@@ -416,10 +418,11 @@ def load_controlnet_mmdit(sd):
|
|||||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
||||||
return control
|
return control
|
||||||
|
|
||||||
def load_controlnet_hunyuandit(controlnet_data):
|
|
||||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(controlnet_data)
|
|
||||||
|
|
||||||
control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=load_device, dtype=unet_dtype)
|
def load_controlnet_hunyuandit(controlnet_data):
|
||||||
|
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data)
|
||||||
|
|
||||||
|
control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=offload_device, dtype=unet_dtype)
|
||||||
control_model = controlnet_load_state_dict(control_model, controlnet_data)
|
control_model = controlnet_load_state_dict(control_model, controlnet_data)
|
||||||
|
|
||||||
latent_format = comfy.latent_formats.SDXL()
|
latent_format = comfy.latent_formats.SDXL()
|
||||||
@@ -427,6 +430,33 @@ def load_controlnet_hunyuandit(controlnet_data):
|
|||||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT)
|
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT)
|
||||||
return control
|
return control
|
||||||
|
|
||||||
|
def load_controlnet_flux_xlabs(sd):
|
||||||
|
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd)
|
||||||
|
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||||
|
control_model = controlnet_load_state_dict(control_model, sd)
|
||||||
|
extra_conds = ['y', 'guidance']
|
||||||
|
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||||
|
return control
|
||||||
|
|
||||||
|
def load_controlnet_flux_instantx(sd):
|
||||||
|
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
|
||||||
|
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd)
|
||||||
|
for k in sd:
|
||||||
|
new_sd[k] = sd[k]
|
||||||
|
|
||||||
|
num_union_modes = 0
|
||||||
|
union_cnet = "controlnet_mode_embedder.weight"
|
||||||
|
if union_cnet in new_sd:
|
||||||
|
num_union_modes = new_sd[union_cnet].shape[0]
|
||||||
|
|
||||||
|
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(latent_input=True, num_union_modes=num_union_modes, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||||
|
control_model = controlnet_load_state_dict(control_model, new_sd)
|
||||||
|
|
||||||
|
latent_format = comfy.latent_formats.Flux()
|
||||||
|
extra_conds = ['y', 'guidance']
|
||||||
|
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||||
|
return control
|
||||||
|
|
||||||
def load_controlnet(ckpt_path, model=None):
|
def load_controlnet(ckpt_path, model=None):
|
||||||
controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
|
controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
|
||||||
if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT
|
if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT
|
||||||
@@ -489,7 +519,12 @@ def load_controlnet(ckpt_path, model=None):
|
|||||||
logging.warning("leftover keys: {}".format(leftover_keys))
|
logging.warning("leftover keys: {}".format(leftover_keys))
|
||||||
controlnet_data = new_sd
|
controlnet_data = new_sd
|
||||||
elif "controlnet_blocks.0.weight" in controlnet_data: #SD3 diffusers format
|
elif "controlnet_blocks.0.weight" in controlnet_data: #SD3 diffusers format
|
||||||
return load_controlnet_mmdit(controlnet_data)
|
if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
|
||||||
|
return load_controlnet_flux_xlabs(controlnet_data)
|
||||||
|
elif "pos_embed_input.proj.weight" in controlnet_data:
|
||||||
|
return load_controlnet_mmdit(controlnet_data)
|
||||||
|
elif "controlnet_x_embedder.weight" in controlnet_data:
|
||||||
|
return load_controlnet_flux_instantx(controlnet_data)
|
||||||
|
|
||||||
pth_key = 'control_model.zero_convs.0.0.weight'
|
pth_key = 'control_model.zero_convs.0.0.weight'
|
||||||
pth = False
|
pth = False
|
||||||
@@ -521,6 +556,7 @@ def load_controlnet(ckpt_path, model=None):
|
|||||||
if manual_cast_dtype is not None:
|
if manual_cast_dtype is not None:
|
||||||
controlnet_config["operations"] = comfy.ops.manual_cast
|
controlnet_config["operations"] = comfy.ops.manual_cast
|
||||||
controlnet_config["dtype"] = unet_dtype
|
controlnet_config["dtype"] = unet_dtype
|
||||||
|
controlnet_config["device"] = comfy.model_management.unet_offload_device()
|
||||||
controlnet_config.pop("out_channels")
|
controlnet_config.pop("out_channels")
|
||||||
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
|
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
|
||||||
control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
|
control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_dire
|
|||||||
if text_encoder2_path is not None:
|
if text_encoder2_path is not None:
|
||||||
text_encoder_paths.append(text_encoder2_path)
|
text_encoder_paths.append(text_encoder2_path)
|
||||||
|
|
||||||
unet = comfy.sd.load_unet(unet_path)
|
unet = comfy.sd.load_diffusion_model(unet_path)
|
||||||
|
|
||||||
clip = None
|
clip = None
|
||||||
if output_clip:
|
if output_clip:
|
||||||
|
|||||||
62
comfy/float.py
Normal file
62
comfy/float.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
import torch
|
||||||
|
import math
|
||||||
|
|
||||||
|
def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None):
|
||||||
|
mantissa_scaled = torch.where(
|
||||||
|
normal_mask,
|
||||||
|
(abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0) * (2**MANTISSA_BITS),
|
||||||
|
(abs_x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS)))
|
||||||
|
)
|
||||||
|
|
||||||
|
mantissa_scaled += torch.rand(mantissa_scaled.size(), dtype=mantissa_scaled.dtype, layout=mantissa_scaled.layout, device=mantissa_scaled.device, generator=generator)
|
||||||
|
return mantissa_scaled.floor() / (2**MANTISSA_BITS)
|
||||||
|
|
||||||
|
#Not 100% sure about this
|
||||||
|
def manual_stochastic_round_to_float8(x, dtype, generator=None):
|
||||||
|
if dtype == torch.float8_e4m3fn:
|
||||||
|
EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 4, 3, 7
|
||||||
|
elif dtype == torch.float8_e5m2:
|
||||||
|
EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 5, 2, 15
|
||||||
|
else:
|
||||||
|
raise ValueError("Unsupported dtype")
|
||||||
|
|
||||||
|
x = x.half()
|
||||||
|
sign = torch.sign(x)
|
||||||
|
abs_x = x.abs()
|
||||||
|
sign = torch.where(abs_x == 0, 0, sign)
|
||||||
|
|
||||||
|
# Combine exponent calculation and clamping
|
||||||
|
exponent = torch.clamp(
|
||||||
|
torch.floor(torch.log2(abs_x)) + EXPONENT_BIAS,
|
||||||
|
0, 2**EXPONENT_BITS - 1
|
||||||
|
)
|
||||||
|
|
||||||
|
# Combine mantissa calculation and rounding
|
||||||
|
normal_mask = ~(exponent == 0)
|
||||||
|
|
||||||
|
abs_x[:] = calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=generator)
|
||||||
|
|
||||||
|
sign *= torch.where(
|
||||||
|
normal_mask,
|
||||||
|
(2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + abs_x),
|
||||||
|
(2.0 ** (-EXPONENT_BIAS + 1)) * abs_x
|
||||||
|
)
|
||||||
|
del abs_x
|
||||||
|
|
||||||
|
return sign.to(dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def stochastic_rounding(value, dtype, seed=0):
|
||||||
|
if dtype == torch.float32:
|
||||||
|
return value.to(dtype=torch.float32)
|
||||||
|
if dtype == torch.float16:
|
||||||
|
return value.to(dtype=torch.float16)
|
||||||
|
if dtype == torch.bfloat16:
|
||||||
|
return value.to(dtype=torch.bfloat16)
|
||||||
|
if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
|
||||||
|
generator = torch.Generator(device=value.device)
|
||||||
|
generator.manual_seed(seed)
|
||||||
|
return manual_stochastic_round_to_float8(value, dtype, generator=generator)
|
||||||
|
|
||||||
|
return value.to(dtype=dtype)
|
||||||
@@ -9,6 +9,7 @@ from tqdm.auto import trange, tqdm
|
|||||||
from . import utils
|
from . import utils
|
||||||
from . import deis
|
from . import deis
|
||||||
import comfy.model_patcher
|
import comfy.model_patcher
|
||||||
|
import comfy.model_sampling
|
||||||
|
|
||||||
def append_zero(x):
|
def append_zero(x):
|
||||||
return torch.cat([x, x.new_zeros([1])])
|
return torch.cat([x, x.new_zeros([1])])
|
||||||
@@ -509,6 +510,9 @@ def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callbac
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||||
|
if isinstance(model.inner_model.inner_model.model_sampling, comfy.model_sampling.CONST):
|
||||||
|
return sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
|
||||||
|
|
||||||
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
|
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
|
||||||
extra_args = {} if extra_args is None else extra_args
|
extra_args = {} if extra_args is None else extra_args
|
||||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||||
@@ -541,6 +545,55 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None,
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||||
|
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
|
||||||
|
extra_args = {} if extra_args is None else extra_args
|
||||||
|
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||||
|
s_in = x.new_ones([x.shape[0]])
|
||||||
|
sigma_fn = lambda lbda: (lbda.exp() + 1) ** -1
|
||||||
|
lambda_fn = lambda sigma: ((1-sigma)/sigma).log()
|
||||||
|
|
||||||
|
# logged_x = x.unsqueeze(0)
|
||||||
|
|
||||||
|
for i in trange(len(sigmas) - 1, disable=disable):
|
||||||
|
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||||
|
downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta
|
||||||
|
sigma_down = sigmas[i+1] * downstep_ratio
|
||||||
|
alpha_ip1 = 1 - sigmas[i+1]
|
||||||
|
alpha_down = 1 - sigma_down
|
||||||
|
renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5
|
||||||
|
# sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
|
||||||
|
if callback is not None:
|
||||||
|
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||||
|
if sigmas[i + 1] == 0:
|
||||||
|
# Euler method
|
||||||
|
d = to_d(x, sigmas[i], denoised)
|
||||||
|
dt = sigma_down - sigmas[i]
|
||||||
|
x = x + d * dt
|
||||||
|
else:
|
||||||
|
# DPM-Solver++(2S)
|
||||||
|
if sigmas[i] == 1.0:
|
||||||
|
sigma_s = 0.9999
|
||||||
|
else:
|
||||||
|
t_i, t_down = lambda_fn(sigmas[i]), lambda_fn(sigma_down)
|
||||||
|
r = 1 / 2
|
||||||
|
h = t_down - t_i
|
||||||
|
s = t_i + r * h
|
||||||
|
sigma_s = sigma_fn(s)
|
||||||
|
# sigma_s = sigmas[i+1]
|
||||||
|
sigma_s_i_ratio = sigma_s / sigmas[i]
|
||||||
|
u = sigma_s_i_ratio * x + (1 - sigma_s_i_ratio) * denoised
|
||||||
|
D_i = model(u, sigma_s * s_in, **extra_args)
|
||||||
|
sigma_down_i_ratio = sigma_down / sigmas[i]
|
||||||
|
x = sigma_down_i_ratio * x + (1 - sigma_down_i_ratio) * D_i
|
||||||
|
# print("sigma_i", sigmas[i], "sigma_ip1", sigmas[i+1],"sigma_down", sigma_down, "sigma_down_i_ratio", sigma_down_i_ratio, "sigma_s_i_ratio", sigma_s_i_ratio, "renoise_coeff", renoise_coeff)
|
||||||
|
# Noise addition
|
||||||
|
if sigmas[i + 1] > 0 and eta > 0:
|
||||||
|
x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
|
||||||
|
# logged_x = torch.cat((logged_x, x.unsqueeze(0)), dim=0)
|
||||||
|
return x
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
|
def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
|
||||||
"""DPM-Solver++ (stochastic)."""
|
"""DPM-Solver++ (stochastic)."""
|
||||||
|
|||||||
@@ -141,6 +141,7 @@ class StableAudio1(LatentFormat):
|
|||||||
latent_channels = 64
|
latent_channels = 64
|
||||||
|
|
||||||
class Flux(SD3):
|
class Flux(SD3):
|
||||||
|
latent_channels = 16
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.scale_factor = 0.3611
|
self.scale_factor = 0.3611
|
||||||
self.shift_factor = 0.1159
|
self.shift_factor = 0.1159
|
||||||
@@ -162,6 +163,7 @@ class Flux(SD3):
|
|||||||
[-0.0005, -0.0530, -0.0020],
|
[-0.0005, -0.0530, -0.0020],
|
||||||
[-0.1273, -0.0932, -0.0680]
|
[-0.1273, -0.0932, -0.0680]
|
||||||
]
|
]
|
||||||
|
self.taesd_decoder_name = "taef1_decoder"
|
||||||
|
|
||||||
def process_in(self, latent):
|
def process_in(self, latent):
|
||||||
return (latent - self.shift_factor) * self.scale_factor
|
return (latent - self.shift_factor) * self.scale_factor
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import torch
|
import torch
|
||||||
|
import comfy.ops
|
||||||
|
|
||||||
def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
|
def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
|
||||||
if padding_mode == "circular" and torch.jit.is_tracing() or torch.jit.is_scripting():
|
if padding_mode == "circular" and torch.jit.is_tracing() or torch.jit.is_scripting():
|
||||||
@@ -6,3 +7,15 @@ def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
|
|||||||
pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0]
|
pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0]
|
||||||
pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1]
|
pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1]
|
||||||
return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode)
|
return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode)
|
||||||
|
|
||||||
|
try:
|
||||||
|
rms_norm_torch = torch.nn.functional.rms_norm
|
||||||
|
except:
|
||||||
|
rms_norm_torch = None
|
||||||
|
|
||||||
|
def rms_norm(x, weight, eps=1e-6):
|
||||||
|
if rms_norm_torch is not None:
|
||||||
|
return rms_norm_torch(x, weight.shape, weight=comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
|
||||||
|
else:
|
||||||
|
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
|
||||||
|
return (x * rrms) * comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device)
|
||||||
|
|||||||
151
comfy/ldm/flux/controlnet.py
Normal file
151
comfy/ldm/flux/controlnet.py
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
#Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import math
|
||||||
|
from torch import Tensor, nn
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
|
||||||
|
from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
|
||||||
|
MLPEmbedder, SingleStreamBlock,
|
||||||
|
timestep_embedding)
|
||||||
|
|
||||||
|
from .model import Flux
|
||||||
|
import comfy.ldm.common_dit
|
||||||
|
|
||||||
|
|
||||||
|
class ControlNetFlux(Flux):
|
||||||
|
def __init__(self, latent_input=False, num_union_modes=0, image_model=None, dtype=None, device=None, operations=None, **kwargs):
|
||||||
|
super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
|
||||||
|
|
||||||
|
self.main_model_double = 19
|
||||||
|
self.main_model_single = 38
|
||||||
|
# add ControlNet blocks
|
||||||
|
self.controlnet_blocks = nn.ModuleList([])
|
||||||
|
for _ in range(self.params.depth):
|
||||||
|
controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
|
||||||
|
self.controlnet_blocks.append(controlnet_block)
|
||||||
|
|
||||||
|
self.controlnet_single_blocks = nn.ModuleList([])
|
||||||
|
for _ in range(self.params.depth_single_blocks):
|
||||||
|
self.controlnet_single_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device))
|
||||||
|
|
||||||
|
self.num_union_modes = num_union_modes
|
||||||
|
self.controlnet_mode_embedder = None
|
||||||
|
if self.num_union_modes > 0:
|
||||||
|
self.controlnet_mode_embedder = operations.Embedding(self.num_union_modes, self.hidden_size, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
self.gradient_checkpointing = False
|
||||||
|
self.latent_input = latent_input
|
||||||
|
self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
|
||||||
|
if not self.latent_input:
|
||||||
|
self.input_hint_block = nn.Sequential(
|
||||||
|
operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward_orig(
|
||||||
|
self,
|
||||||
|
img: Tensor,
|
||||||
|
img_ids: Tensor,
|
||||||
|
controlnet_cond: Tensor,
|
||||||
|
txt: Tensor,
|
||||||
|
txt_ids: Tensor,
|
||||||
|
timesteps: Tensor,
|
||||||
|
y: Tensor,
|
||||||
|
guidance: Tensor = None,
|
||||||
|
control_type: Tensor = None,
|
||||||
|
) -> Tensor:
|
||||||
|
if img.ndim != 3 or txt.ndim != 3:
|
||||||
|
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||||
|
|
||||||
|
# running on sequences img
|
||||||
|
img = self.img_in(img)
|
||||||
|
if not self.latent_input:
|
||||||
|
controlnet_cond = self.input_hint_block(controlnet_cond)
|
||||||
|
controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
||||||
|
|
||||||
|
controlnet_cond = self.pos_embed_input(controlnet_cond)
|
||||||
|
img = img + controlnet_cond
|
||||||
|
vec = self.time_in(timestep_embedding(timesteps, 256))
|
||||||
|
if self.params.guidance_embed:
|
||||||
|
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
|
||||||
|
vec = vec + self.vector_in(y)
|
||||||
|
txt = self.txt_in(txt)
|
||||||
|
|
||||||
|
if self.controlnet_mode_embedder is not None and len(control_type) > 0:
|
||||||
|
control_cond = self.controlnet_mode_embedder(torch.tensor(control_type, device=img.device), out_dtype=img.dtype).unsqueeze(0).repeat((txt.shape[0], 1, 1))
|
||||||
|
txt = torch.cat([control_cond, txt], dim=1)
|
||||||
|
txt_ids = torch.cat([txt_ids[:,:1], txt_ids], dim=1)
|
||||||
|
|
||||||
|
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||||
|
pe = self.pe_embedder(ids)
|
||||||
|
|
||||||
|
controlnet_double = ()
|
||||||
|
|
||||||
|
for i in range(len(self.double_blocks)):
|
||||||
|
img, txt = self.double_blocks[i](img=img, txt=txt, vec=vec, pe=pe)
|
||||||
|
controlnet_double = controlnet_double + (self.controlnet_blocks[i](img),)
|
||||||
|
|
||||||
|
img = torch.cat((txt, img), 1)
|
||||||
|
|
||||||
|
controlnet_single = ()
|
||||||
|
|
||||||
|
for i in range(len(self.single_blocks)):
|
||||||
|
img = self.single_blocks[i](img, vec=vec, pe=pe)
|
||||||
|
controlnet_single = controlnet_single + (self.controlnet_single_blocks[i](img[:, txt.shape[1] :, ...]),)
|
||||||
|
|
||||||
|
repeat = math.ceil(self.main_model_double / len(controlnet_double))
|
||||||
|
if self.latent_input:
|
||||||
|
out_input = ()
|
||||||
|
for x in controlnet_double:
|
||||||
|
out_input += (x,) * repeat
|
||||||
|
else:
|
||||||
|
out_input = (controlnet_double * repeat)
|
||||||
|
|
||||||
|
out = {"input": out_input[:self.main_model_double]}
|
||||||
|
if len(controlnet_single) > 0:
|
||||||
|
repeat = math.ceil(self.main_model_single / len(controlnet_single))
|
||||||
|
out_output = ()
|
||||||
|
if self.latent_input:
|
||||||
|
for x in controlnet_single:
|
||||||
|
out_output += (x,) * repeat
|
||||||
|
else:
|
||||||
|
out_output = (controlnet_single * repeat)
|
||||||
|
out["output"] = out_output[:self.main_model_single]
|
||||||
|
return out
|
||||||
|
|
||||||
|
def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
|
||||||
|
patch_size = 2
|
||||||
|
if self.latent_input:
|
||||||
|
hint = comfy.ldm.common_dit.pad_to_patch_size(hint, (patch_size, patch_size))
|
||||||
|
hint = rearrange(hint, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
||||||
|
else:
|
||||||
|
hint = hint * 2.0 - 1.0
|
||||||
|
|
||||||
|
bs, c, h, w = x.shape
|
||||||
|
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
||||||
|
|
||||||
|
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
||||||
|
|
||||||
|
h_len = ((h + (patch_size // 2)) // patch_size)
|
||||||
|
w_len = ((w + (patch_size // 2)) // patch_size)
|
||||||
|
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||||
|
img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
|
||||||
|
img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
|
||||||
|
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||||
|
|
||||||
|
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||||
|
return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance, control_type=kwargs.get("control_type", []))
|
||||||
@@ -6,6 +6,7 @@ from torch import Tensor, nn
|
|||||||
|
|
||||||
from .math import attention, rope
|
from .math import attention, rope
|
||||||
import comfy.ops
|
import comfy.ops
|
||||||
|
import comfy.ldm.common_dit
|
||||||
|
|
||||||
|
|
||||||
class EmbedND(nn.Module):
|
class EmbedND(nn.Module):
|
||||||
@@ -63,10 +64,7 @@ class RMSNorm(torch.nn.Module):
|
|||||||
self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
|
self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
|
||||||
|
|
||||||
def forward(self, x: Tensor):
|
def forward(self, x: Tensor):
|
||||||
x_dtype = x.dtype
|
return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6)
|
||||||
x = x.float()
|
|
||||||
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
|
|
||||||
return (x * rrms).to(dtype=x_dtype) * comfy.ops.cast_to(self.scale, dtype=x_dtype, device=x.device)
|
|
||||||
|
|
||||||
|
|
||||||
class QKNorm(torch.nn.Module):
|
class QKNorm(torch.nn.Module):
|
||||||
@@ -170,15 +168,15 @@ class DoubleStreamBlock(nn.Module):
|
|||||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
||||||
|
|
||||||
# calculate the img bloks
|
# calculate the img bloks
|
||||||
img += img_mod1.gate * self.img_attn.proj(img_attn)
|
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
|
||||||
img += img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
|
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
|
||||||
|
|
||||||
# calculate the txt bloks
|
# calculate the txt bloks
|
||||||
txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
|
txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
|
||||||
txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
|
txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
|
||||||
|
|
||||||
if txt.dtype == torch.float16:
|
if txt.dtype == torch.float16:
|
||||||
txt = txt.clip(-65504, 65504)
|
txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
|
||||||
|
|
||||||
return img, txt
|
return img, txt
|
||||||
|
|
||||||
@@ -233,7 +231,7 @@ class SingleStreamBlock(nn.Module):
|
|||||||
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
||||||
x += mod.gate * output
|
x += mod.gate * output
|
||||||
if x.dtype == torch.float16:
|
if x.dtype == torch.float16:
|
||||||
x = x.clip(-65504, 65504)
|
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ class Flux(nn.Module):
|
|||||||
Transformer model for flow matching on sequences.
|
Transformer model for flow matching on sequences.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, image_model=None, dtype=None, device=None, operations=None, **kwargs):
|
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
params = FluxParams(**kwargs)
|
params = FluxParams(**kwargs)
|
||||||
@@ -83,7 +83,8 @@ class Flux(nn.Module):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
|
if final_layer:
|
||||||
|
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||||
|
|
||||||
def forward_orig(
|
def forward_orig(
|
||||||
self,
|
self,
|
||||||
@@ -94,6 +95,7 @@ class Flux(nn.Module):
|
|||||||
timesteps: Tensor,
|
timesteps: Tensor,
|
||||||
y: Tensor,
|
y: Tensor,
|
||||||
guidance: Tensor = None,
|
guidance: Tensor = None,
|
||||||
|
control=None,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
if img.ndim != 3 or txt.ndim != 3:
|
if img.ndim != 3 or txt.ndim != 3:
|
||||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||||
@@ -112,18 +114,34 @@ class Flux(nn.Module):
|
|||||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||||
pe = self.pe_embedder(ids)
|
pe = self.pe_embedder(ids)
|
||||||
|
|
||||||
for block in self.double_blocks:
|
for i, block in enumerate(self.double_blocks):
|
||||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
||||||
|
|
||||||
|
if control is not None: # Controlnet
|
||||||
|
control_i = control.get("input")
|
||||||
|
if i < len(control_i):
|
||||||
|
add = control_i[i]
|
||||||
|
if add is not None:
|
||||||
|
img += add
|
||||||
|
|
||||||
img = torch.cat((txt, img), 1)
|
img = torch.cat((txt, img), 1)
|
||||||
for block in self.single_blocks:
|
|
||||||
|
for i, block in enumerate(self.single_blocks):
|
||||||
img = block(img, vec=vec, pe=pe)
|
img = block(img, vec=vec, pe=pe)
|
||||||
|
|
||||||
|
if control is not None: # Controlnet
|
||||||
|
control_o = control.get("output")
|
||||||
|
if i < len(control_o):
|
||||||
|
add = control_o[i]
|
||||||
|
if add is not None:
|
||||||
|
img[:, txt.shape[1] :, ...] += add
|
||||||
|
|
||||||
img = img[:, txt.shape[1] :, ...]
|
img = img[:, txt.shape[1] :, ...]
|
||||||
|
|
||||||
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
||||||
return img
|
return img
|
||||||
|
|
||||||
def forward(self, x, timestep, context, y, guidance, **kwargs):
|
def forward(self, x, timestep, context, y, guidance, control=None, **kwargs):
|
||||||
bs, c, h, w = x.shape
|
bs, c, h, w = x.shape
|
||||||
patch_size = 2
|
patch_size = 2
|
||||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
||||||
@@ -138,5 +156,5 @@ class Flux(nn.Module):
|
|||||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||||
|
|
||||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||||
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance)
|
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control)
|
||||||
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
|
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
|
||||||
|
|||||||
@@ -372,7 +372,7 @@ class HunYuanDiT(nn.Module):
|
|||||||
for layer, block in enumerate(self.blocks):
|
for layer, block in enumerate(self.blocks):
|
||||||
if layer > self.depth // 2:
|
if layer > self.depth // 2:
|
||||||
if controls is not None:
|
if controls is not None:
|
||||||
skip = skips.pop() + controls.pop()
|
skip = skips.pop() + controls.pop().to(dtype=x.dtype)
|
||||||
else:
|
else:
|
||||||
skip = skips.pop()
|
skip = skips.pop()
|
||||||
x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)
|
x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)
|
||||||
|
|||||||
@@ -358,7 +358,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
|
|||||||
disabled_xformers = True
|
disabled_xformers = True
|
||||||
|
|
||||||
if disabled_xformers:
|
if disabled_xformers:
|
||||||
return attention_pytorch(q, k, v, heads, mask)
|
return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape)
|
||||||
|
|
||||||
if skip_reshape:
|
if skip_reshape:
|
||||||
q, k, v = map(
|
q, k, v = map(
|
||||||
|
|||||||
@@ -355,29 +355,9 @@ class RMSNorm(torch.nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.register_parameter("weight", None)
|
self.register_parameter("weight", None)
|
||||||
|
|
||||||
def _norm(self, x):
|
|
||||||
"""
|
|
||||||
Apply the RMSNorm normalization to the input tensor.
|
|
||||||
Args:
|
|
||||||
x (torch.Tensor): The input tensor.
|
|
||||||
Returns:
|
|
||||||
torch.Tensor: The normalized tensor.
|
|
||||||
"""
|
|
||||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
"""
|
return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)
|
||||||
Forward pass through the RMSNorm layer.
|
|
||||||
Args:
|
|
||||||
x (torch.Tensor): The input tensor.
|
|
||||||
Returns:
|
|
||||||
torch.Tensor: The output tensor after applying RMSNorm.
|
|
||||||
"""
|
|
||||||
x = self._norm(x)
|
|
||||||
if self.learnable_scale:
|
|
||||||
return x * self.weight.to(device=x.device, dtype=x.dtype)
|
|
||||||
else:
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class SwiGLUFeedForward(nn.Module):
|
class SwiGLUFeedForward(nn.Module):
|
||||||
|
|||||||
236
comfy/lora.py
236
comfy/lora.py
@@ -16,8 +16,12 @@
|
|||||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
|
import comfy.model_management
|
||||||
|
import comfy.model_base
|
||||||
import logging
|
import logging
|
||||||
|
import torch
|
||||||
|
|
||||||
LORA_CLIP_MAP = {
|
LORA_CLIP_MAP = {
|
||||||
"mlp.fc1": "mlp_fc1",
|
"mlp.fc1": "mlp_fc1",
|
||||||
@@ -318,7 +322,235 @@ def model_lora_keys_unet(model, key_map={}):
|
|||||||
for k in diffusers_keys:
|
for k in diffusers_keys:
|
||||||
if k.endswith(".weight"):
|
if k.endswith(".weight"):
|
||||||
to = diffusers_keys[k]
|
to = diffusers_keys[k]
|
||||||
key_lora = "transformer.{}".format(k[:-len(".weight")]) #simpletrainer and probably regular diffusers flux lora format
|
key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format
|
||||||
key_map[key_lora] = to
|
key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
|
||||||
|
|
||||||
return key_map
|
return key_map
|
||||||
|
|
||||||
|
|
||||||
|
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype):
|
||||||
|
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
|
||||||
|
lora_diff *= alpha
|
||||||
|
weight_calc = weight + lora_diff.type(weight.dtype)
|
||||||
|
weight_norm = (
|
||||||
|
weight_calc.transpose(0, 1)
|
||||||
|
.reshape(weight_calc.shape[1], -1)
|
||||||
|
.norm(dim=1, keepdim=True)
|
||||||
|
.reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
|
||||||
|
.transpose(0, 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
|
||||||
|
if strength != 1.0:
|
||||||
|
weight_calc -= weight
|
||||||
|
weight += strength * (weight_calc)
|
||||||
|
else:
|
||||||
|
weight[:] = weight_calc
|
||||||
|
return weight
|
||||||
|
|
||||||
|
def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Pad a tensor to a new shape with zeros.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tensor (torch.Tensor): The original tensor to be padded.
|
||||||
|
new_shape (List[int]): The desired shape of the padded tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: A new tensor padded with zeros to the specified shape.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
If the new shape is smaller than the original tensor in any dimension,
|
||||||
|
the original tensor will be truncated in that dimension.
|
||||||
|
"""
|
||||||
|
if any([new_shape[i] < tensor.shape[i] for i in range(len(new_shape))]):
|
||||||
|
raise ValueError("The new shape must be larger than the original tensor in all dimensions")
|
||||||
|
|
||||||
|
if len(new_shape) != len(tensor.shape):
|
||||||
|
raise ValueError("The new shape must have the same number of dimensions as the original tensor")
|
||||||
|
|
||||||
|
# Create a new tensor filled with zeros
|
||||||
|
padded_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device)
|
||||||
|
|
||||||
|
# Create slicing tuples for both tensors
|
||||||
|
orig_slices = tuple(slice(0, dim) for dim in tensor.shape)
|
||||||
|
new_slices = tuple(slice(0, dim) for dim in tensor.shape)
|
||||||
|
|
||||||
|
# Copy the original tensor into the new tensor
|
||||||
|
padded_tensor[new_slices] = tensor[orig_slices]
|
||||||
|
|
||||||
|
return padded_tensor
|
||||||
|
|
||||||
|
def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
|
||||||
|
for p in patches:
|
||||||
|
strength = p[0]
|
||||||
|
v = p[1]
|
||||||
|
strength_model = p[2]
|
||||||
|
offset = p[3]
|
||||||
|
function = p[4]
|
||||||
|
if function is None:
|
||||||
|
function = lambda a: a
|
||||||
|
|
||||||
|
old_weight = None
|
||||||
|
if offset is not None:
|
||||||
|
old_weight = weight
|
||||||
|
weight = weight.narrow(offset[0], offset[1], offset[2])
|
||||||
|
|
||||||
|
if strength_model != 1.0:
|
||||||
|
weight *= strength_model
|
||||||
|
|
||||||
|
if isinstance(v, list):
|
||||||
|
v = (calculate_weight(v[1:], v[0].clone(), key, intermediate_dtype=intermediate_dtype), )
|
||||||
|
|
||||||
|
if len(v) == 1:
|
||||||
|
patch_type = "diff"
|
||||||
|
elif len(v) == 2:
|
||||||
|
patch_type = v[0]
|
||||||
|
v = v[1]
|
||||||
|
|
||||||
|
if patch_type == "diff":
|
||||||
|
diff: torch.Tensor = v[0]
|
||||||
|
# An extra flag to pad the weight if the diff's shape is larger than the weight
|
||||||
|
do_pad_weight = len(v) > 1 and v[1]['pad_weight']
|
||||||
|
if do_pad_weight and diff.shape != weight.shape:
|
||||||
|
logging.info("Pad weight {} from {} to shape: {}".format(key, weight.shape, diff.shape))
|
||||||
|
weight = pad_tensor_to_shape(weight, diff.shape)
|
||||||
|
|
||||||
|
if strength != 0.0:
|
||||||
|
if diff.shape != weight.shape:
|
||||||
|
logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, diff.shape, weight.shape))
|
||||||
|
else:
|
||||||
|
weight += function(strength * comfy.model_management.cast_to_device(diff, weight.device, weight.dtype))
|
||||||
|
elif patch_type == "lora": #lora/locon
|
||||||
|
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
|
||||||
|
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype)
|
||||||
|
dora_scale = v[4]
|
||||||
|
if v[2] is not None:
|
||||||
|
alpha = v[2] / mat2.shape[0]
|
||||||
|
else:
|
||||||
|
alpha = 1.0
|
||||||
|
|
||||||
|
if v[3] is not None:
|
||||||
|
#locon mid weights, hopefully the math is fine because I didn't properly test it
|
||||||
|
mat3 = comfy.model_management.cast_to_device(v[3], weight.device, intermediate_dtype)
|
||||||
|
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
|
||||||
|
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
|
||||||
|
try:
|
||||||
|
lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
|
||||||
|
if dora_scale is not None:
|
||||||
|
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
|
||||||
|
else:
|
||||||
|
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||||
|
except Exception as e:
|
||||||
|
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||||
|
elif patch_type == "lokr":
|
||||||
|
w1 = v[0]
|
||||||
|
w2 = v[1]
|
||||||
|
w1_a = v[3]
|
||||||
|
w1_b = v[4]
|
||||||
|
w2_a = v[5]
|
||||||
|
w2_b = v[6]
|
||||||
|
t2 = v[7]
|
||||||
|
dora_scale = v[8]
|
||||||
|
dim = None
|
||||||
|
|
||||||
|
if w1 is None:
|
||||||
|
dim = w1_b.shape[0]
|
||||||
|
w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype))
|
||||||
|
else:
|
||||||
|
w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype)
|
||||||
|
|
||||||
|
if w2 is None:
|
||||||
|
dim = w2_b.shape[0]
|
||||||
|
if t2 is None:
|
||||||
|
w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype))
|
||||||
|
else:
|
||||||
|
w2 = torch.einsum('i j k l, j r, i p -> p r k l',
|
||||||
|
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype))
|
||||||
|
else:
|
||||||
|
w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype)
|
||||||
|
|
||||||
|
if len(w2.shape) == 4:
|
||||||
|
w1 = w1.unsqueeze(2).unsqueeze(2)
|
||||||
|
if v[2] is not None and dim is not None:
|
||||||
|
alpha = v[2] / dim
|
||||||
|
else:
|
||||||
|
alpha = 1.0
|
||||||
|
|
||||||
|
try:
|
||||||
|
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
|
||||||
|
if dora_scale is not None:
|
||||||
|
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
|
||||||
|
else:
|
||||||
|
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||||
|
except Exception as e:
|
||||||
|
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||||
|
elif patch_type == "loha":
|
||||||
|
w1a = v[0]
|
||||||
|
w1b = v[1]
|
||||||
|
if v[2] is not None:
|
||||||
|
alpha = v[2] / w1b.shape[0]
|
||||||
|
else:
|
||||||
|
alpha = 1.0
|
||||||
|
|
||||||
|
w2a = v[3]
|
||||||
|
w2b = v[4]
|
||||||
|
dora_scale = v[7]
|
||||||
|
if v[5] is not None: #cp decomposition
|
||||||
|
t1 = v[5]
|
||||||
|
t2 = v[6]
|
||||||
|
m1 = torch.einsum('i j k l, j r, i p -> p r k l',
|
||||||
|
comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype))
|
||||||
|
|
||||||
|
m2 = torch.einsum('i j k l, j r, i p -> p r k l',
|
||||||
|
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype))
|
||||||
|
else:
|
||||||
|
m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype))
|
||||||
|
m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype))
|
||||||
|
|
||||||
|
try:
|
||||||
|
lora_diff = (m1 * m2).reshape(weight.shape)
|
||||||
|
if dora_scale is not None:
|
||||||
|
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
|
||||||
|
else:
|
||||||
|
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||||
|
except Exception as e:
|
||||||
|
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||||
|
elif patch_type == "glora":
|
||||||
|
if v[4] is not None:
|
||||||
|
alpha = v[4] / v[0].shape[0]
|
||||||
|
else:
|
||||||
|
alpha = 1.0
|
||||||
|
|
||||||
|
dora_scale = v[5]
|
||||||
|
|
||||||
|
a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||||
|
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||||
|
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||||
|
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||||
|
|
||||||
|
try:
|
||||||
|
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape)
|
||||||
|
if dora_scale is not None:
|
||||||
|
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
|
||||||
|
else:
|
||||||
|
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||||
|
except Exception as e:
|
||||||
|
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||||
|
else:
|
||||||
|
logging.warning("patch type not recognized {} {}".format(patch_type, key))
|
||||||
|
|
||||||
|
if old_weight is not None:
|
||||||
|
weight = old_weight
|
||||||
|
|
||||||
|
return weight
|
||||||
|
|||||||
@@ -1,3 +1,21 @@
|
|||||||
|
"""
|
||||||
|
This file is part of ComfyUI.
|
||||||
|
Copyright (C) 2024 Comfy
|
||||||
|
|
||||||
|
This program is free software: you can redistribute it and/or modify
|
||||||
|
it under the terms of the GNU General Public License as published by
|
||||||
|
the Free Software Foundation, either version 3 of the License, or
|
||||||
|
(at your option) any later version.
|
||||||
|
|
||||||
|
This program is distributed in the hope that it will be useful,
|
||||||
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
GNU General Public License for more details.
|
||||||
|
|
||||||
|
You should have received a copy of the GNU General Public License
|
||||||
|
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import logging
|
import logging
|
||||||
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
|
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
|
||||||
@@ -77,10 +95,10 @@ class BaseModel(torch.nn.Module):
|
|||||||
self.device = device
|
self.device = device
|
||||||
|
|
||||||
if not unet_config.get("disable_unet_model_creation", False):
|
if not unet_config.get("disable_unet_model_creation", False):
|
||||||
if self.manual_cast_dtype is not None:
|
if model_config.custom_operations is None:
|
||||||
operations = comfy.ops.manual_cast
|
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype)
|
||||||
else:
|
else:
|
||||||
operations = comfy.ops.disable_weight_init
|
operations = model_config.custom_operations
|
||||||
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
|
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
|
||||||
if comfy.model_management.force_channels_last():
|
if comfy.model_management.force_channels_last():
|
||||||
self.diffusion_model.to(memory_format=torch.channels_last)
|
self.diffusion_model.to(memory_format=torch.channels_last)
|
||||||
|
|||||||
@@ -472,9 +472,15 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
|
|||||||
'transformer_depth': [0, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': False,
|
'transformer_depth': [0, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': False,
|
||||||
'context_dim': 768, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 1, 1, 1, 1],
|
'context_dim': 768, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 1, 1, 1, 1],
|
||||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||||
|
|
||||||
|
SD15_diffusers_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': None,
|
||||||
|
'dtype': dtype, 'in_channels': 9, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0],
|
||||||
|
'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, 'num_heads': 8,
|
||||||
|
'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
|
||||||
|
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||||
|
|
||||||
|
|
||||||
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p]
|
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p, SD15_diffusers_inpaint]
|
||||||
|
|
||||||
for unet_config in supported_models:
|
for unet_config in supported_models:
|
||||||
matches = True
|
matches = True
|
||||||
|
|||||||
@@ -44,9 +44,15 @@ cpu_state = CPUState.GPU
|
|||||||
|
|
||||||
total_vram = 0
|
total_vram = 0
|
||||||
|
|
||||||
lowvram_available = True
|
|
||||||
xpu_available = False
|
xpu_available = False
|
||||||
|
torch_version = ""
|
||||||
|
try:
|
||||||
|
torch_version = torch.version.__version__
|
||||||
|
xpu_available = (int(torch_version[0]) < 2 or (int(torch_version[0]) == 2 and int(torch_version[2]) <= 4)) and torch.xpu.is_available()
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
lowvram_available = True
|
||||||
if args.deterministic:
|
if args.deterministic:
|
||||||
logging.info("Using deterministic algorithms for pytorch")
|
logging.info("Using deterministic algorithms for pytorch")
|
||||||
torch.use_deterministic_algorithms(True, warn_only=True)
|
torch.use_deterministic_algorithms(True, warn_only=True)
|
||||||
@@ -66,10 +72,10 @@ if args.directml is not None:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
import intel_extension_for_pytorch as ipex
|
import intel_extension_for_pytorch as ipex
|
||||||
if torch.xpu.is_available():
|
_ = torch.xpu.device_count()
|
||||||
xpu_available = True
|
xpu_available = torch.xpu.is_available()
|
||||||
except:
|
except:
|
||||||
pass
|
xpu_available = xpu_available or (hasattr(torch, "xpu") and torch.xpu.is_available())
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if torch.backends.mps.is_available():
|
if torch.backends.mps.is_available():
|
||||||
@@ -189,7 +195,6 @@ VAE_DTYPES = [torch.float32]
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
if is_nvidia():
|
if is_nvidia():
|
||||||
torch_version = torch.version.__version__
|
|
||||||
if int(torch_version[0]) >= 2:
|
if int(torch_version[0]) >= 2:
|
||||||
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||||
ENABLE_PYTORCH_ATTENTION = True
|
ENABLE_PYTORCH_ATTENTION = True
|
||||||
@@ -315,17 +320,15 @@ class LoadedModel:
|
|||||||
self.model_use_more_vram(use_more_vram)
|
self.model_use_more_vram(use_more_vram)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
if lowvram_model_memory > 0 and load_weights:
|
self.real_model = self.model.patch_model(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, load_weights=load_weights, force_patch_weights=force_patch_weights)
|
||||||
self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
|
|
||||||
else:
|
|
||||||
self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.model.unpatch_model(self.model.offload_device)
|
self.model.unpatch_model(self.model.offload_device)
|
||||||
self.model_unload()
|
self.model_unload()
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
if is_intel_xpu() and not args.disable_ipex_optimize:
|
if is_intel_xpu() and not args.disable_ipex_optimize and self.real_model is not None:
|
||||||
self.real_model = ipex.optimize(self.real_model.eval(), graph_mode=True, concat_linear=True)
|
with torch.no_grad():
|
||||||
|
self.real_model = ipex.optimize(self.real_model.eval(), inplace=True, graph_mode=True, concat_linear=True)
|
||||||
|
|
||||||
self.weights_loaded = True
|
self.weights_loaded = True
|
||||||
return self.real_model
|
return self.real_model
|
||||||
@@ -338,8 +341,9 @@ class LoadedModel:
|
|||||||
def model_unload(self, memory_to_free=None, unpatch_weights=True):
|
def model_unload(self, memory_to_free=None, unpatch_weights=True):
|
||||||
if memory_to_free is not None:
|
if memory_to_free is not None:
|
||||||
if memory_to_free < self.model.loaded_size():
|
if memory_to_free < self.model.loaded_size():
|
||||||
self.model.partially_unload(self.model.offload_device, memory_to_free)
|
freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
|
||||||
return False
|
if freed >= memory_to_free:
|
||||||
|
return False
|
||||||
self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
|
self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
|
||||||
self.model.model_patches_to(self.model.offload_device)
|
self.model.model_patches_to(self.model.offload_device)
|
||||||
self.weights_loaded = self.weights_loaded and not unpatch_weights
|
self.weights_loaded = self.weights_loaded and not unpatch_weights
|
||||||
@@ -366,8 +370,21 @@ def offloaded_memory(loaded_models, device):
|
|||||||
offloaded_mem += m.model_offloaded_memory()
|
offloaded_mem += m.model_offloaded_memory()
|
||||||
return offloaded_mem
|
return offloaded_mem
|
||||||
|
|
||||||
|
WINDOWS = any(platform.win32_ver())
|
||||||
|
|
||||||
|
EXTRA_RESERVED_VRAM = 400 * 1024 * 1024
|
||||||
|
if WINDOWS:
|
||||||
|
EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue
|
||||||
|
|
||||||
|
if args.reserve_vram is not None:
|
||||||
|
EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024
|
||||||
|
logging.debug("Reserving {}MB vram for other applications.".format(EXTRA_RESERVED_VRAM / (1024 * 1024)))
|
||||||
|
|
||||||
|
def extra_reserved_memory():
|
||||||
|
return EXTRA_RESERVED_VRAM
|
||||||
|
|
||||||
def minimum_inference_memory():
|
def minimum_inference_memory():
|
||||||
return (1024 * 1024 * 1024) * 1.2
|
return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
|
||||||
|
|
||||||
def unload_model_clones(model, unload_weights_only=True, force_unload=True):
|
def unload_model_clones(model, unload_weights_only=True, force_unload=True):
|
||||||
to_unload = []
|
to_unload = []
|
||||||
@@ -391,6 +408,8 @@ def unload_model_clones(model, unload_weights_only=True, force_unload=True):
|
|||||||
if not force_unload:
|
if not force_unload:
|
||||||
if unload_weights_only and unload_weight == False:
|
if unload_weights_only and unload_weight == False:
|
||||||
return None
|
return None
|
||||||
|
else:
|
||||||
|
unload_weight = True
|
||||||
|
|
||||||
for i in to_unload:
|
for i in to_unload:
|
||||||
logging.debug("unload clone {} {}".format(i, unload_weight))
|
logging.debug("unload clone {} {}".format(i, unload_weight))
|
||||||
@@ -434,15 +453,15 @@ def free_memory(memory_required, device, keep_loaded=[]):
|
|||||||
soft_empty_cache()
|
soft_empty_cache()
|
||||||
return unloaded_models
|
return unloaded_models
|
||||||
|
|
||||||
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None):
|
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
|
||||||
global vram_state
|
global vram_state
|
||||||
|
|
||||||
inference_memory = minimum_inference_memory()
|
inference_memory = minimum_inference_memory()
|
||||||
extra_mem = max(inference_memory, memory_required) + 100 * 1024 * 1024
|
extra_mem = max(inference_memory, memory_required + extra_reserved_memory())
|
||||||
if minimum_memory_required is None:
|
if minimum_memory_required is None:
|
||||||
minimum_memory_required = extra_mem
|
minimum_memory_required = extra_mem
|
||||||
else:
|
else:
|
||||||
minimum_memory_required = max(inference_memory, minimum_memory_required) + 100 * 1024 * 1024
|
minimum_memory_required = max(inference_memory, minimum_memory_required + extra_reserved_memory())
|
||||||
|
|
||||||
models = set(models)
|
models = set(models)
|
||||||
|
|
||||||
@@ -513,7 +532,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
|||||||
else:
|
else:
|
||||||
vram_set_state = vram_state
|
vram_set_state = vram_state
|
||||||
lowvram_model_memory = 0
|
lowvram_model_memory = 0
|
||||||
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
|
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM) and not force_full_load:
|
||||||
model_size = loaded_model.model_memory_required(torch_dev)
|
model_size = loaded_model.model_memory_required(torch_dev)
|
||||||
current_free_mem = get_free_memory(torch_dev)
|
current_free_mem = get_free_memory(torch_dev)
|
||||||
lowvram_model_memory = max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required), min(current_free_mem * 0.4, current_free_mem - minimum_inference_memory()))
|
lowvram_model_memory = max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required), min(current_free_mem * 0.4, current_free_mem - minimum_inference_memory()))
|
||||||
@@ -552,7 +571,9 @@ def loaded_models(only_currently_used=False):
|
|||||||
def cleanup_models(keep_clone_weights_loaded=False):
|
def cleanup_models(keep_clone_weights_loaded=False):
|
||||||
to_delete = []
|
to_delete = []
|
||||||
for i in range(len(current_loaded_models)):
|
for i in range(len(current_loaded_models)):
|
||||||
if sys.getrefcount(current_loaded_models[i].model) <= 2:
|
#TODO: very fragile function needs improvement
|
||||||
|
num_refs = sys.getrefcount(current_loaded_models[i].model)
|
||||||
|
if num_refs <= 2:
|
||||||
if not keep_clone_weights_loaded:
|
if not keep_clone_weights_loaded:
|
||||||
to_delete = [i] + to_delete
|
to_delete = [i] + to_delete
|
||||||
#TODO: find a less fragile way to do this.
|
#TODO: find a less fragile way to do this.
|
||||||
@@ -659,6 +680,7 @@ def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.flo
|
|||||||
if bf16_supported and weight_dtype == torch.bfloat16:
|
if bf16_supported and weight_dtype == torch.bfloat16:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
fp16_supported = should_use_fp16(inference_device, prioritize_performance=True)
|
||||||
for dt in supported_dtypes:
|
for dt in supported_dtypes:
|
||||||
if dt == torch.float16 and fp16_supported:
|
if dt == torch.float16 and fp16_supported:
|
||||||
return torch.float16
|
return torch.float16
|
||||||
@@ -684,6 +706,20 @@ def text_encoder_device():
|
|||||||
else:
|
else:
|
||||||
return torch.device("cpu")
|
return torch.device("cpu")
|
||||||
|
|
||||||
|
def text_encoder_initial_device(load_device, offload_device, model_size=0):
|
||||||
|
if load_device == offload_device or model_size <= 1024 * 1024 * 1024:
|
||||||
|
return offload_device
|
||||||
|
|
||||||
|
if is_device_mps(load_device):
|
||||||
|
return offload_device
|
||||||
|
|
||||||
|
mem_l = get_free_memory(load_device)
|
||||||
|
mem_o = get_free_memory(offload_device)
|
||||||
|
if mem_l > (mem_o * 0.5) and model_size * 1.2 < mem_l:
|
||||||
|
return load_device
|
||||||
|
else:
|
||||||
|
return offload_device
|
||||||
|
|
||||||
def text_encoder_dtype(device=None):
|
def text_encoder_dtype(device=None):
|
||||||
if args.fp8_e4m3fn_text_enc:
|
if args.fp8_e4m3fn_text_enc:
|
||||||
return torch.float8_e4m3fn
|
return torch.float8_e4m3fn
|
||||||
@@ -860,7 +896,8 @@ def pytorch_attention_flash_attention():
|
|||||||
def force_upcast_attention_dtype():
|
def force_upcast_attention_dtype():
|
||||||
upcast = args.force_upcast_attention
|
upcast = args.force_upcast_attention
|
||||||
try:
|
try:
|
||||||
if platform.mac_ver()[0] in ['14.5']: #black image bug on OSX Sonoma 14.5
|
macos_version = tuple(int(n) for n in platform.mac_ver()[0].split("."))
|
||||||
|
if (14, 5) <= macos_version < (14, 7): # black image bug on recent versions of MacOS
|
||||||
upcast = True
|
upcast = True
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
@@ -956,23 +993,23 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
|
|||||||
if torch.version.hip:
|
if torch.version.hip:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
props = torch.cuda.get_device_properties("cuda")
|
props = torch.cuda.get_device_properties(device)
|
||||||
if props.major >= 8:
|
if props.major >= 8:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if props.major < 6:
|
if props.major < 6:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
fp16_works = False
|
#FP16 is confirmed working on a 1080 (GP104) and on latest pytorch actually seems faster than fp32
|
||||||
#FP16 is confirmed working on a 1080 (GP104) but it's a bit slower than FP32 so it should only be enabled
|
|
||||||
#when the model doesn't actually fit on the card
|
|
||||||
#TODO: actually test if GP106 and others have the same type of behavior
|
|
||||||
nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050", "p40", "p100", "p6", "p4"]
|
nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050", "p40", "p100", "p6", "p4"]
|
||||||
for x in nvidia_10_series:
|
for x in nvidia_10_series:
|
||||||
if x in props.name.lower():
|
if x in props.name.lower():
|
||||||
fp16_works = True
|
if WINDOWS or manual_cast:
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False #weird linux behavior where fp32 is faster
|
||||||
|
|
||||||
if fp16_works or manual_cast:
|
if manual_cast:
|
||||||
free_model_memory = maximum_vram_for_weights(device)
|
free_model_memory = maximum_vram_for_weights(device)
|
||||||
if (not prioritize_performance) or model_params * 4 > free_model_memory:
|
if (not prioritize_performance) or model_params * 4 > free_model_memory:
|
||||||
return True
|
return True
|
||||||
@@ -1012,7 +1049,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
|
|||||||
if is_intel_xpu():
|
if is_intel_xpu():
|
||||||
return True
|
return True
|
||||||
|
|
||||||
props = torch.cuda.get_device_properties("cuda")
|
props = torch.cuda.get_device_properties(device)
|
||||||
if props.major >= 8:
|
if props.major >= 8:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -1025,6 +1062,16 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
|
|||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def supports_fp8_compute(device=None):
|
||||||
|
props = torch.cuda.get_device_properties(device)
|
||||||
|
if props.major >= 9:
|
||||||
|
return True
|
||||||
|
if props.major < 8:
|
||||||
|
return False
|
||||||
|
if props.minor < 9:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
def soft_empty_cache(force=False):
|
def soft_empty_cache(force=False):
|
||||||
global cpu_state
|
global cpu_state
|
||||||
if cpu_state == CPUState.MPS:
|
if cpu_state == CPUState.MPS:
|
||||||
|
|||||||
@@ -22,32 +22,26 @@ import inspect
|
|||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
import collections
|
import collections
|
||||||
|
import math
|
||||||
|
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
|
import comfy.float
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
|
import comfy.lora
|
||||||
from comfy.types import UnetWrapperFunction
|
from comfy.types import UnetWrapperFunction
|
||||||
|
|
||||||
|
def string_to_seed(data):
|
||||||
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength):
|
crc = 0xFFFFFFFF
|
||||||
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32)
|
for byte in data:
|
||||||
lora_diff *= alpha
|
if isinstance(byte, str):
|
||||||
weight_calc = weight + lora_diff.type(weight.dtype)
|
byte = ord(byte)
|
||||||
weight_norm = (
|
crc ^= byte
|
||||||
weight_calc.transpose(0, 1)
|
for _ in range(8):
|
||||||
.reshape(weight_calc.shape[1], -1)
|
if crc & 1:
|
||||||
.norm(dim=1, keepdim=True)
|
crc = (crc >> 1) ^ 0xEDB88320
|
||||||
.reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
|
else:
|
||||||
.transpose(0, 1)
|
crc >>= 1
|
||||||
)
|
return crc ^ 0xFFFFFFFF
|
||||||
|
|
||||||
weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
|
|
||||||
if strength != 1.0:
|
|
||||||
weight_calc -= weight
|
|
||||||
weight += strength * (weight_calc)
|
|
||||||
else:
|
|
||||||
weight[:] = weight_calc
|
|
||||||
return weight
|
|
||||||
|
|
||||||
|
|
||||||
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
|
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
|
||||||
to = model_options["transformer_options"].copy()
|
to = model_options["transformer_options"].copy()
|
||||||
@@ -90,12 +84,11 @@ def wipe_lowvram_weight(m):
|
|||||||
m.bias_function = None
|
m.bias_function = None
|
||||||
|
|
||||||
class LowVramPatch:
|
class LowVramPatch:
|
||||||
def __init__(self, key, model_patcher):
|
def __init__(self, key, patches):
|
||||||
self.key = key
|
self.key = key
|
||||||
self.model_patcher = model_patcher
|
self.patches = patches
|
||||||
def __call__(self, weight):
|
def __call__(self, weight):
|
||||||
return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
|
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)
|
||||||
|
|
||||||
|
|
||||||
class ModelPatcher:
|
class ModelPatcher:
|
||||||
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
|
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
|
||||||
@@ -327,46 +320,36 @@ class ModelPatcher:
|
|||||||
temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
|
temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
|
||||||
else:
|
else:
|
||||||
temp_weight = weight.to(torch.float32, copy=True)
|
temp_weight = weight.to(torch.float32, copy=True)
|
||||||
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
|
out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key)
|
||||||
|
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
|
||||||
if inplace_update:
|
if inplace_update:
|
||||||
comfy.utils.copy_to_param(self.model, key, out_weight)
|
comfy.utils.copy_to_param(self.model, key, out_weight)
|
||||||
else:
|
else:
|
||||||
comfy.utils.set_attr_param(self.model, key, out_weight)
|
comfy.utils.set_attr_param(self.model, key, out_weight)
|
||||||
|
|
||||||
def patch_model(self, device_to=None, patch_weights=True):
|
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
|
||||||
for k in self.object_patches:
|
|
||||||
old = comfy.utils.set_attr(self.model, k, self.object_patches[k])
|
|
||||||
if k not in self.object_patches_backup:
|
|
||||||
self.object_patches_backup[k] = old
|
|
||||||
|
|
||||||
if patch_weights:
|
|
||||||
model_sd = self.model_state_dict()
|
|
||||||
for key in self.patches:
|
|
||||||
if key not in model_sd:
|
|
||||||
logging.warning("could not patch. key doesn't exist in model: {}".format(key))
|
|
||||||
continue
|
|
||||||
|
|
||||||
self.patch_weight_to_device(key, device_to)
|
|
||||||
|
|
||||||
if device_to is not None:
|
|
||||||
self.model.to(device_to)
|
|
||||||
self.model.device = device_to
|
|
||||||
self.model.model_loaded_weight_memory = self.model_size()
|
|
||||||
|
|
||||||
return self.model
|
|
||||||
|
|
||||||
def lowvram_load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
|
|
||||||
mem_counter = 0
|
mem_counter = 0
|
||||||
patch_counter = 0
|
patch_counter = 0
|
||||||
lowvram_counter = 0
|
lowvram_counter = 0
|
||||||
|
loading = []
|
||||||
for n, m in self.model.named_modules():
|
for n, m in self.model.named_modules():
|
||||||
|
if hasattr(m, "comfy_cast_weights") or hasattr(m, "weight"):
|
||||||
|
loading.append((comfy.model_management.module_size(m), n, m))
|
||||||
|
|
||||||
|
load_completely = []
|
||||||
|
loading.sort(reverse=True)
|
||||||
|
for x in loading:
|
||||||
|
n = x[1]
|
||||||
|
m = x[2]
|
||||||
|
module_mem = x[0]
|
||||||
|
|
||||||
lowvram_weight = False
|
lowvram_weight = False
|
||||||
if hasattr(m, "comfy_cast_weights"):
|
|
||||||
module_mem = comfy.model_management.module_size(m)
|
if not full_load and hasattr(m, "comfy_cast_weights"):
|
||||||
if mem_counter + module_mem >= lowvram_model_memory:
|
if mem_counter + module_mem >= lowvram_model_memory:
|
||||||
lowvram_weight = True
|
lowvram_weight = True
|
||||||
lowvram_counter += 1
|
lowvram_counter += 1
|
||||||
if m.comfy_cast_weights:
|
if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
|
||||||
continue
|
continue
|
||||||
|
|
||||||
weight_key = "{}.weight".format(n)
|
weight_key = "{}.weight".format(n)
|
||||||
@@ -377,13 +360,13 @@ class ModelPatcher:
|
|||||||
if force_patch_weights:
|
if force_patch_weights:
|
||||||
self.patch_weight_to_device(weight_key)
|
self.patch_weight_to_device(weight_key)
|
||||||
else:
|
else:
|
||||||
m.weight_function = LowVramPatch(weight_key, self)
|
m.weight_function = LowVramPatch(weight_key, self.patches)
|
||||||
patch_counter += 1
|
patch_counter += 1
|
||||||
if bias_key in self.patches:
|
if bias_key in self.patches:
|
||||||
if force_patch_weights:
|
if force_patch_weights:
|
||||||
self.patch_weight_to_device(bias_key)
|
self.patch_weight_to_device(bias_key)
|
||||||
else:
|
else:
|
||||||
m.bias_function = LowVramPatch(bias_key, self)
|
m.bias_function = LowVramPatch(bias_key, self.patches)
|
||||||
patch_counter += 1
|
patch_counter += 1
|
||||||
|
|
||||||
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
||||||
@@ -394,202 +377,56 @@ class ModelPatcher:
|
|||||||
wipe_lowvram_weight(m)
|
wipe_lowvram_weight(m)
|
||||||
|
|
||||||
if hasattr(m, "weight"):
|
if hasattr(m, "weight"):
|
||||||
mem_counter += comfy.model_management.module_size(m)
|
mem_counter += module_mem
|
||||||
param = list(m.parameters())
|
load_completely.append((module_mem, n, m))
|
||||||
if len(param) > 0:
|
|
||||||
weight = param[0]
|
|
||||||
if weight.device == device_to:
|
|
||||||
continue
|
|
||||||
|
|
||||||
self.patch_weight_to_device(weight_key) #TODO: speed this up without OOM
|
load_completely.sort(reverse=True)
|
||||||
self.patch_weight_to_device(bias_key)
|
for x in load_completely:
|
||||||
m.to(device_to)
|
n = x[1]
|
||||||
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
|
m = x[2]
|
||||||
|
weight_key = "{}.weight".format(n)
|
||||||
|
bias_key = "{}.bias".format(n)
|
||||||
|
if hasattr(m, "comfy_patched_weights"):
|
||||||
|
if m.comfy_patched_weights == True:
|
||||||
|
continue
|
||||||
|
|
||||||
|
self.patch_weight_to_device(weight_key, device_to=device_to)
|
||||||
|
self.patch_weight_to_device(bias_key, device_to=device_to)
|
||||||
|
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
|
||||||
|
m.comfy_patched_weights = True
|
||||||
|
|
||||||
|
for x in load_completely:
|
||||||
|
x[2].to(device_to)
|
||||||
|
|
||||||
if lowvram_counter > 0:
|
if lowvram_counter > 0:
|
||||||
logging.info("loaded in lowvram mode {}".format(lowvram_model_memory / (1024 * 1024)))
|
logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
|
||||||
self.model.model_lowvram = True
|
self.model.model_lowvram = True
|
||||||
else:
|
else:
|
||||||
logging.info("loaded completely {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024)))
|
logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
|
||||||
self.model.model_lowvram = False
|
self.model.model_lowvram = False
|
||||||
|
if full_load:
|
||||||
|
self.model.to(device_to)
|
||||||
|
mem_counter = self.model_size()
|
||||||
|
|
||||||
self.model.lowvram_patch_counter += patch_counter
|
self.model.lowvram_patch_counter += patch_counter
|
||||||
self.model.device = device_to
|
self.model.device = device_to
|
||||||
self.model.model_loaded_weight_memory = mem_counter
|
self.model.model_loaded_weight_memory = mem_counter
|
||||||
|
|
||||||
|
def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
|
||||||
|
for k in self.object_patches:
|
||||||
|
old = comfy.utils.set_attr(self.model, k, self.object_patches[k])
|
||||||
|
if k not in self.object_patches_backup:
|
||||||
|
self.object_patches_backup[k] = old
|
||||||
|
|
||||||
def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
|
if lowvram_model_memory == 0:
|
||||||
self.patch_model(device_to, patch_weights=False)
|
full_load = True
|
||||||
self.lowvram_load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
|
else:
|
||||||
|
full_load = False
|
||||||
|
|
||||||
|
if load_weights:
|
||||||
|
self.load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights, full_load=full_load)
|
||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
def calculate_weight(self, patches, weight, key):
|
|
||||||
for p in patches:
|
|
||||||
strength = p[0]
|
|
||||||
v = p[1]
|
|
||||||
strength_model = p[2]
|
|
||||||
offset = p[3]
|
|
||||||
function = p[4]
|
|
||||||
if function is None:
|
|
||||||
function = lambda a: a
|
|
||||||
|
|
||||||
old_weight = None
|
|
||||||
if offset is not None:
|
|
||||||
old_weight = weight
|
|
||||||
weight = weight.narrow(offset[0], offset[1], offset[2])
|
|
||||||
|
|
||||||
if strength_model != 1.0:
|
|
||||||
weight *= strength_model
|
|
||||||
|
|
||||||
if isinstance(v, list):
|
|
||||||
v = (self.calculate_weight(v[1:], v[0].clone(), key), )
|
|
||||||
|
|
||||||
if len(v) == 1:
|
|
||||||
patch_type = "diff"
|
|
||||||
elif len(v) == 2:
|
|
||||||
patch_type = v[0]
|
|
||||||
v = v[1]
|
|
||||||
|
|
||||||
if patch_type == "diff":
|
|
||||||
w1 = v[0]
|
|
||||||
if strength != 0.0:
|
|
||||||
if w1.shape != weight.shape:
|
|
||||||
logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
|
|
||||||
else:
|
|
||||||
weight += function(strength * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype))
|
|
||||||
elif patch_type == "lora": #lora/locon
|
|
||||||
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32)
|
|
||||||
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32)
|
|
||||||
dora_scale = v[4]
|
|
||||||
if v[2] is not None:
|
|
||||||
alpha = v[2] / mat2.shape[0]
|
|
||||||
else:
|
|
||||||
alpha = 1.0
|
|
||||||
|
|
||||||
if v[3] is not None:
|
|
||||||
#locon mid weights, hopefully the math is fine because I didn't properly test it
|
|
||||||
mat3 = comfy.model_management.cast_to_device(v[3], weight.device, torch.float32)
|
|
||||||
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
|
|
||||||
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
|
|
||||||
try:
|
|
||||||
lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
|
|
||||||
if dora_scale is not None:
|
|
||||||
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
|
|
||||||
else:
|
|
||||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
|
||||||
except Exception as e:
|
|
||||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
|
||||||
elif patch_type == "lokr":
|
|
||||||
w1 = v[0]
|
|
||||||
w2 = v[1]
|
|
||||||
w1_a = v[3]
|
|
||||||
w1_b = v[4]
|
|
||||||
w2_a = v[5]
|
|
||||||
w2_b = v[6]
|
|
||||||
t2 = v[7]
|
|
||||||
dora_scale = v[8]
|
|
||||||
dim = None
|
|
||||||
|
|
||||||
if w1 is None:
|
|
||||||
dim = w1_b.shape[0]
|
|
||||||
w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, torch.float32),
|
|
||||||
comfy.model_management.cast_to_device(w1_b, weight.device, torch.float32))
|
|
||||||
else:
|
|
||||||
w1 = comfy.model_management.cast_to_device(w1, weight.device, torch.float32)
|
|
||||||
|
|
||||||
if w2 is None:
|
|
||||||
dim = w2_b.shape[0]
|
|
||||||
if t2 is None:
|
|
||||||
w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, torch.float32),
|
|
||||||
comfy.model_management.cast_to_device(w2_b, weight.device, torch.float32))
|
|
||||||
else:
|
|
||||||
w2 = torch.einsum('i j k l, j r, i p -> p r k l',
|
|
||||||
comfy.model_management.cast_to_device(t2, weight.device, torch.float32),
|
|
||||||
comfy.model_management.cast_to_device(w2_b, weight.device, torch.float32),
|
|
||||||
comfy.model_management.cast_to_device(w2_a, weight.device, torch.float32))
|
|
||||||
else:
|
|
||||||
w2 = comfy.model_management.cast_to_device(w2, weight.device, torch.float32)
|
|
||||||
|
|
||||||
if len(w2.shape) == 4:
|
|
||||||
w1 = w1.unsqueeze(2).unsqueeze(2)
|
|
||||||
if v[2] is not None and dim is not None:
|
|
||||||
alpha = v[2] / dim
|
|
||||||
else:
|
|
||||||
alpha = 1.0
|
|
||||||
|
|
||||||
try:
|
|
||||||
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
|
|
||||||
if dora_scale is not None:
|
|
||||||
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
|
|
||||||
else:
|
|
||||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
|
||||||
except Exception as e:
|
|
||||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
|
||||||
elif patch_type == "loha":
|
|
||||||
w1a = v[0]
|
|
||||||
w1b = v[1]
|
|
||||||
if v[2] is not None:
|
|
||||||
alpha = v[2] / w1b.shape[0]
|
|
||||||
else:
|
|
||||||
alpha = 1.0
|
|
||||||
|
|
||||||
w2a = v[3]
|
|
||||||
w2b = v[4]
|
|
||||||
dora_scale = v[7]
|
|
||||||
if v[5] is not None: #cp decomposition
|
|
||||||
t1 = v[5]
|
|
||||||
t2 = v[6]
|
|
||||||
m1 = torch.einsum('i j k l, j r, i p -> p r k l',
|
|
||||||
comfy.model_management.cast_to_device(t1, weight.device, torch.float32),
|
|
||||||
comfy.model_management.cast_to_device(w1b, weight.device, torch.float32),
|
|
||||||
comfy.model_management.cast_to_device(w1a, weight.device, torch.float32))
|
|
||||||
|
|
||||||
m2 = torch.einsum('i j k l, j r, i p -> p r k l',
|
|
||||||
comfy.model_management.cast_to_device(t2, weight.device, torch.float32),
|
|
||||||
comfy.model_management.cast_to_device(w2b, weight.device, torch.float32),
|
|
||||||
comfy.model_management.cast_to_device(w2a, weight.device, torch.float32))
|
|
||||||
else:
|
|
||||||
m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, torch.float32),
|
|
||||||
comfy.model_management.cast_to_device(w1b, weight.device, torch.float32))
|
|
||||||
m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, torch.float32),
|
|
||||||
comfy.model_management.cast_to_device(w2b, weight.device, torch.float32))
|
|
||||||
|
|
||||||
try:
|
|
||||||
lora_diff = (m1 * m2).reshape(weight.shape)
|
|
||||||
if dora_scale is not None:
|
|
||||||
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
|
|
||||||
else:
|
|
||||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
|
||||||
except Exception as e:
|
|
||||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
|
||||||
elif patch_type == "glora":
|
|
||||||
if v[4] is not None:
|
|
||||||
alpha = v[4] / v[0].shape[0]
|
|
||||||
else:
|
|
||||||
alpha = 1.0
|
|
||||||
|
|
||||||
dora_scale = v[5]
|
|
||||||
|
|
||||||
a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32)
|
|
||||||
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32)
|
|
||||||
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32)
|
|
||||||
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32)
|
|
||||||
|
|
||||||
try:
|
|
||||||
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape)
|
|
||||||
if dora_scale is not None:
|
|
||||||
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
|
|
||||||
else:
|
|
||||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
|
||||||
except Exception as e:
|
|
||||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
|
||||||
else:
|
|
||||||
logging.warning("patch type not recognized {} {}".format(patch_type, key))
|
|
||||||
|
|
||||||
if old_weight is not None:
|
|
||||||
weight = old_weight
|
|
||||||
|
|
||||||
return weight
|
|
||||||
|
|
||||||
def unpatch_model(self, device_to=None, unpatch_weights=True):
|
def unpatch_model(self, device_to=None, unpatch_weights=True):
|
||||||
if unpatch_weights:
|
if unpatch_weights:
|
||||||
if self.model.model_lowvram:
|
if self.model.model_lowvram:
|
||||||
@@ -615,6 +452,10 @@ class ModelPatcher:
|
|||||||
self.model.device = device_to
|
self.model.device = device_to
|
||||||
self.model.model_loaded_weight_memory = 0
|
self.model.model_loaded_weight_memory = 0
|
||||||
|
|
||||||
|
for m in self.model.modules():
|
||||||
|
if hasattr(m, "comfy_patched_weights"):
|
||||||
|
del m.comfy_patched_weights
|
||||||
|
|
||||||
keys = list(self.object_patches_backup.keys())
|
keys = list(self.object_patches_backup.keys())
|
||||||
for k in keys:
|
for k in keys:
|
||||||
comfy.utils.set_attr(self.model, k, self.object_patches_backup[k])
|
comfy.utils.set_attr(self.model, k, self.object_patches_backup[k])
|
||||||
@@ -624,40 +465,47 @@ class ModelPatcher:
|
|||||||
def partially_unload(self, device_to, memory_to_free=0):
|
def partially_unload(self, device_to, memory_to_free=0):
|
||||||
memory_freed = 0
|
memory_freed = 0
|
||||||
patch_counter = 0
|
patch_counter = 0
|
||||||
|
unload_list = []
|
||||||
|
|
||||||
for n, m in list(self.model.named_modules())[::-1]:
|
for n, m in self.model.named_modules():
|
||||||
if memory_to_free < memory_freed:
|
|
||||||
break
|
|
||||||
|
|
||||||
shift_lowvram = False
|
shift_lowvram = False
|
||||||
if hasattr(m, "comfy_cast_weights"):
|
if hasattr(m, "comfy_cast_weights"):
|
||||||
module_mem = comfy.model_management.module_size(m)
|
module_mem = comfy.model_management.module_size(m)
|
||||||
weight_key = "{}.weight".format(n)
|
unload_list.append((module_mem, n, m))
|
||||||
bias_key = "{}.bias".format(n)
|
|
||||||
|
|
||||||
|
unload_list.sort()
|
||||||
|
for unload in unload_list:
|
||||||
|
if memory_to_free < memory_freed:
|
||||||
|
break
|
||||||
|
module_mem = unload[0]
|
||||||
|
n = unload[1]
|
||||||
|
m = unload[2]
|
||||||
|
weight_key = "{}.weight".format(n)
|
||||||
|
bias_key = "{}.bias".format(n)
|
||||||
|
|
||||||
if m.weight is not None and m.weight.device != device_to:
|
if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
|
||||||
for key in [weight_key, bias_key]:
|
for key in [weight_key, bias_key]:
|
||||||
bk = self.backup.get(key, None)
|
bk = self.backup.get(key, None)
|
||||||
if bk is not None:
|
if bk is not None:
|
||||||
if bk.inplace_update:
|
if bk.inplace_update:
|
||||||
comfy.utils.copy_to_param(self.model, key, bk.weight)
|
comfy.utils.copy_to_param(self.model, key, bk.weight)
|
||||||
else:
|
else:
|
||||||
comfy.utils.set_attr_param(self.model, key, bk.weight)
|
comfy.utils.set_attr_param(self.model, key, bk.weight)
|
||||||
self.backup.pop(key)
|
self.backup.pop(key)
|
||||||
|
|
||||||
m.to(device_to)
|
m.to(device_to)
|
||||||
if weight_key in self.patches:
|
if weight_key in self.patches:
|
||||||
m.weight_function = LowVramPatch(weight_key, self)
|
m.weight_function = LowVramPatch(weight_key, self.patches)
|
||||||
patch_counter += 1
|
patch_counter += 1
|
||||||
if bias_key in self.patches:
|
if bias_key in self.patches:
|
||||||
m.bias_function = LowVramPatch(bias_key, self)
|
m.bias_function = LowVramPatch(bias_key, self.patches)
|
||||||
patch_counter += 1
|
patch_counter += 1
|
||||||
|
|
||||||
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
||||||
m.comfy_cast_weights = True
|
m.comfy_cast_weights = True
|
||||||
memory_freed += module_mem
|
m.comfy_patched_weights = False
|
||||||
logging.debug("freed {}".format(n))
|
memory_freed += module_mem
|
||||||
|
logging.debug("freed {}".format(n))
|
||||||
|
|
||||||
self.model.model_lowvram = True
|
self.model.model_lowvram = True
|
||||||
self.model.lowvram_patch_counter += patch_counter
|
self.model.lowvram_patch_counter += patch_counter
|
||||||
@@ -665,13 +513,20 @@ class ModelPatcher:
|
|||||||
return memory_freed
|
return memory_freed
|
||||||
|
|
||||||
def partially_load(self, device_to, extra_memory=0):
|
def partially_load(self, device_to, extra_memory=0):
|
||||||
|
self.unpatch_model(unpatch_weights=False)
|
||||||
|
self.patch_model(load_weights=False)
|
||||||
|
full_load = False
|
||||||
if self.model.model_lowvram == False:
|
if self.model.model_lowvram == False:
|
||||||
return 0
|
return 0
|
||||||
if self.model.model_loaded_weight_memory + extra_memory > self.model_size():
|
if self.model.model_loaded_weight_memory + extra_memory > self.model_size():
|
||||||
pass #TODO: Full load
|
full_load = True
|
||||||
current_used = self.model.model_loaded_weight_memory
|
current_used = self.model.model_loaded_weight_memory
|
||||||
self.lowvram_load(device_to, lowvram_model_memory=current_used + extra_memory)
|
self.load(device_to, lowvram_model_memory=current_used + extra_memory, full_load=full_load)
|
||||||
return self.model.model_loaded_weight_memory - current_used
|
return self.model.model_loaded_weight_memory - current_used
|
||||||
|
|
||||||
def current_loaded_device(self):
|
def current_loaded_device(self):
|
||||||
return self.model.device
|
return self.model.device
|
||||||
|
|
||||||
|
def calculate_weight(self, patches, weight, key, intermediate_dtype=torch.float32):
|
||||||
|
print("WARNING the ModelPatcher.calculate_weight function is deprecated, please use: comfy.lora.calculate_weight instead")
|
||||||
|
return comfy.lora.calculate_weight(patches, weight, key, intermediate_dtype=intermediate_dtype)
|
||||||
|
|||||||
89
comfy/ops.py
89
comfy/ops.py
@@ -18,29 +18,42 @@
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
|
from comfy.cli_args import args
|
||||||
|
|
||||||
|
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
|
||||||
|
if device is None or weight.device == device:
|
||||||
|
if not copy:
|
||||||
|
if dtype is None or weight.dtype == dtype:
|
||||||
|
return weight
|
||||||
|
return weight.to(dtype=dtype, copy=copy)
|
||||||
|
|
||||||
def cast_to(weight, dtype=None, device=None, non_blocking=False):
|
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||||
return weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
r.copy_(weight, non_blocking=non_blocking)
|
||||||
|
return r
|
||||||
|
|
||||||
def cast_to_input(weight, input, non_blocking=False):
|
def cast_to_input(weight, input, non_blocking=False, copy=True):
|
||||||
return cast_to(weight, input.dtype, input.device, non_blocking=non_blocking)
|
return cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
|
||||||
|
|
||||||
def cast_bias_weight(s, input=None, dtype=None, device=None):
|
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
|
||||||
if input is not None:
|
if input is not None:
|
||||||
if dtype is None:
|
if dtype is None:
|
||||||
dtype = input.dtype
|
dtype = input.dtype
|
||||||
|
if bias_dtype is None:
|
||||||
|
bias_dtype = dtype
|
||||||
if device is None:
|
if device is None:
|
||||||
device = input.device
|
device = input.device
|
||||||
|
|
||||||
bias = None
|
bias = None
|
||||||
non_blocking = comfy.model_management.device_should_use_non_blocking(device)
|
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
||||||
if s.bias is not None:
|
if s.bias is not None:
|
||||||
bias = cast_to(s.bias, dtype, device, non_blocking=non_blocking)
|
has_function = s.bias_function is not None
|
||||||
if s.bias_function is not None:
|
bias = cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function)
|
||||||
|
if has_function:
|
||||||
bias = s.bias_function(bias)
|
bias = s.bias_function(bias)
|
||||||
weight = cast_to(s.weight, dtype, device, non_blocking=non_blocking)
|
|
||||||
if s.weight_function is not None:
|
has_function = s.weight_function is not None
|
||||||
|
weight = cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function)
|
||||||
|
if has_function:
|
||||||
weight = s.weight_function(weight)
|
weight = s.weight_function(weight)
|
||||||
return weight, bias
|
return weight, bias
|
||||||
|
|
||||||
@@ -238,3 +251,59 @@ class manual_cast(disable_weight_init):
|
|||||||
|
|
||||||
class Embedding(disable_weight_init.Embedding):
|
class Embedding(disable_weight_init.Embedding):
|
||||||
comfy_cast_weights = True
|
comfy_cast_weights = True
|
||||||
|
|
||||||
|
|
||||||
|
def fp8_linear(self, input):
|
||||||
|
dtype = self.weight.dtype
|
||||||
|
if dtype not in [torch.float8_e4m3fn]:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if len(input.shape) == 3:
|
||||||
|
inn = input.reshape(-1, input.shape[2]).to(dtype)
|
||||||
|
non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
|
||||||
|
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype)
|
||||||
|
w = w.t()
|
||||||
|
|
||||||
|
scale_weight = self.scale_weight
|
||||||
|
scale_input = self.scale_input
|
||||||
|
if scale_weight is None:
|
||||||
|
scale_weight = torch.ones((1), device=input.device, dtype=torch.float32)
|
||||||
|
if scale_input is None:
|
||||||
|
scale_input = scale_weight
|
||||||
|
if scale_input is None:
|
||||||
|
scale_input = torch.ones((1), device=input.device, dtype=torch.float32)
|
||||||
|
|
||||||
|
if bias is not None:
|
||||||
|
o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
|
||||||
|
else:
|
||||||
|
o = torch._scaled_mm(inn, w, out_dtype=input.dtype, scale_a=scale_input, scale_b=scale_weight)
|
||||||
|
|
||||||
|
if isinstance(o, tuple):
|
||||||
|
o = o[0]
|
||||||
|
|
||||||
|
return o.reshape((-1, input.shape[1], self.weight.shape[0]))
|
||||||
|
return None
|
||||||
|
|
||||||
|
class fp8_ops(manual_cast):
|
||||||
|
class Linear(manual_cast.Linear):
|
||||||
|
def reset_parameters(self):
|
||||||
|
self.scale_weight = None
|
||||||
|
self.scale_input = None
|
||||||
|
return None
|
||||||
|
|
||||||
|
def forward_comfy_cast_weights(self, input):
|
||||||
|
out = fp8_linear(self, input)
|
||||||
|
if out is not None:
|
||||||
|
return out
|
||||||
|
|
||||||
|
weight, bias = cast_bias_weight(self, input)
|
||||||
|
return torch.nn.functional.linear(input, weight, bias)
|
||||||
|
|
||||||
|
|
||||||
|
def pick_operations(weight_dtype, compute_dtype, load_device=None):
|
||||||
|
if compute_dtype is None or weight_dtype == compute_dtype:
|
||||||
|
return disable_weight_init
|
||||||
|
if args.fast:
|
||||||
|
if comfy.model_management.supports_fp8_compute(load_device):
|
||||||
|
return fp8_ops
|
||||||
|
return manual_cast
|
||||||
|
|||||||
84
comfy/sd.py
84
comfy/sd.py
@@ -24,6 +24,7 @@ import comfy.text_encoders.sa_t5
|
|||||||
import comfy.text_encoders.aura_t5
|
import comfy.text_encoders.aura_t5
|
||||||
import comfy.text_encoders.hydit
|
import comfy.text_encoders.hydit
|
||||||
import comfy.text_encoders.flux
|
import comfy.text_encoders.flux
|
||||||
|
import comfy.text_encoders.long_clipl
|
||||||
|
|
||||||
import comfy.model_patcher
|
import comfy.model_patcher
|
||||||
import comfy.lora
|
import comfy.lora
|
||||||
@@ -62,7 +63,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
|
|||||||
|
|
||||||
|
|
||||||
class CLIP:
|
class CLIP:
|
||||||
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}):
|
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, model_options={}):
|
||||||
if no_init:
|
if no_init:
|
||||||
return
|
return
|
||||||
params = target.params.copy()
|
params = target.params.copy()
|
||||||
@@ -71,20 +72,29 @@ class CLIP:
|
|||||||
|
|
||||||
load_device = model_management.text_encoder_device()
|
load_device = model_management.text_encoder_device()
|
||||||
offload_device = model_management.text_encoder_offload_device()
|
offload_device = model_management.text_encoder_offload_device()
|
||||||
params['device'] = offload_device
|
dtype = model_options.get("dtype", None)
|
||||||
dtype = model_management.text_encoder_dtype(load_device)
|
if dtype is None:
|
||||||
|
dtype = model_management.text_encoder_dtype(load_device)
|
||||||
|
|
||||||
params['dtype'] = dtype
|
params['dtype'] = dtype
|
||||||
|
params['device'] = model_management.text_encoder_initial_device(load_device, offload_device, parameters * model_management.dtype_size(dtype))
|
||||||
|
params['model_options'] = model_options
|
||||||
|
|
||||||
self.cond_stage_model = clip(**(params))
|
self.cond_stage_model = clip(**(params))
|
||||||
|
|
||||||
for dt in self.cond_stage_model.dtypes:
|
for dt in self.cond_stage_model.dtypes:
|
||||||
if not model_management.supports_cast(load_device, dt):
|
if not model_management.supports_cast(load_device, dt):
|
||||||
load_device = offload_device
|
load_device = offload_device
|
||||||
|
if params['device'] != offload_device:
|
||||||
|
self.cond_stage_model.to(offload_device)
|
||||||
|
logging.warning("Had to shift TE back.")
|
||||||
|
|
||||||
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||||
self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
|
self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
|
||||||
|
if params['device'] == load_device:
|
||||||
|
model_management.load_models_gpu([self.patcher], force_full_load=True)
|
||||||
self.layer_idx = None
|
self.layer_idx = None
|
||||||
logging.debug("CLIP model load device: {}, offload device: {}".format(load_device, offload_device))
|
logging.debug("CLIP model load device: {}, offload device: {}, current: {}".format(load_device, offload_device, params['device']))
|
||||||
|
|
||||||
def clone(self):
|
def clone(self):
|
||||||
n = CLIP(no_init=True)
|
n = CLIP(no_init=True)
|
||||||
@@ -390,11 +400,14 @@ class CLIPType(Enum):
|
|||||||
HUNYUAN_DIT = 5
|
HUNYUAN_DIT = 5
|
||||||
FLUX = 6
|
FLUX = 6
|
||||||
|
|
||||||
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION):
|
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
||||||
clip_data = []
|
clip_data = []
|
||||||
for p in ckpt_paths:
|
for p in ckpt_paths:
|
||||||
clip_data.append(comfy.utils.load_torch_file(p, safe_load=True))
|
clip_data.append(comfy.utils.load_torch_file(p, safe_load=True))
|
||||||
|
return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)
|
||||||
|
|
||||||
|
def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
||||||
|
clip_data = state_dicts
|
||||||
class EmptyClass:
|
class EmptyClass:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -431,8 +444,13 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
|
|||||||
clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
|
clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
|
||||||
clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
|
clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
|
||||||
else:
|
else:
|
||||||
clip_target.clip = sd1_clip.SD1ClipModel
|
w = clip_data[0].get("text_model.embeddings.position_embedding.weight", None)
|
||||||
clip_target.tokenizer = sd1_clip.SD1Tokenizer
|
if w is not None and w.shape[0] == 248:
|
||||||
|
clip_target.clip = comfy.text_encoders.long_clipl.LongClipModel
|
||||||
|
clip_target.tokenizer = comfy.text_encoders.long_clipl.LongClipTokenizer
|
||||||
|
else:
|
||||||
|
clip_target.clip = sd1_clip.SD1ClipModel
|
||||||
|
clip_target.tokenizer = sd1_clip.SD1Tokenizer
|
||||||
elif len(clip_data) == 2:
|
elif len(clip_data) == 2:
|
||||||
if clip_type == CLIPType.SD3:
|
if clip_type == CLIPType.SD3:
|
||||||
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=True, t5=False)
|
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=True, t5=False)
|
||||||
@@ -456,7 +474,11 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
|
|||||||
clip_target.clip = comfy.text_encoders.sd3_clip.SD3ClipModel
|
clip_target.clip = comfy.text_encoders.sd3_clip.SD3ClipModel
|
||||||
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||||
|
|
||||||
clip = CLIP(clip_target, embedding_directory=embedding_directory)
|
parameters = 0
|
||||||
|
for c in clip_data:
|
||||||
|
parameters += comfy.utils.calculate_parameters(c)
|
||||||
|
|
||||||
|
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, model_options=model_options)
|
||||||
for c in clip_data:
|
for c in clip_data:
|
||||||
m, u = clip.load_sd(c)
|
m, u = clip.load_sd(c)
|
||||||
if len(m) > 0:
|
if len(m) > 0:
|
||||||
@@ -498,15 +520,19 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
|
|||||||
|
|
||||||
return (model, clip, vae)
|
return (model, clip, vae)
|
||||||
|
|
||||||
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True):
|
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}):
|
||||||
sd = comfy.utils.load_torch_file(ckpt_path)
|
sd = comfy.utils.load_torch_file(ckpt_path)
|
||||||
sd_keys = sd.keys()
|
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options)
|
||||||
|
if out is None:
|
||||||
|
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
|
||||||
|
return out
|
||||||
|
|
||||||
|
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}):
|
||||||
clip = None
|
clip = None
|
||||||
clipvision = None
|
clipvision = None
|
||||||
vae = None
|
vae = None
|
||||||
model = None
|
model = None
|
||||||
model_patcher = None
|
model_patcher = None
|
||||||
clip_target = None
|
|
||||||
|
|
||||||
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
|
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
|
||||||
parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix)
|
parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix)
|
||||||
@@ -515,13 +541,18 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
|||||||
|
|
||||||
model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix)
|
model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix)
|
||||||
if model_config is None:
|
if model_config is None:
|
||||||
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
|
return None
|
||||||
|
|
||||||
unet_weight_dtype = list(model_config.supported_inference_dtypes)
|
unet_weight_dtype = list(model_config.supported_inference_dtypes)
|
||||||
if weight_dtype is not None:
|
if weight_dtype is not None:
|
||||||
unet_weight_dtype.append(weight_dtype)
|
unet_weight_dtype.append(weight_dtype)
|
||||||
|
|
||||||
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
|
model_config.custom_operations = model_options.get("custom_operations", None)
|
||||||
|
unet_dtype = model_options.get("weight_dtype", None)
|
||||||
|
|
||||||
|
if unet_dtype is None:
|
||||||
|
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
|
||||||
|
|
||||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||||
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
||||||
|
|
||||||
@@ -545,7 +576,8 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
|||||||
if clip_target is not None:
|
if clip_target is not None:
|
||||||
clip_sd = model_config.process_clip_state_dict(sd)
|
clip_sd = model_config.process_clip_state_dict(sd)
|
||||||
if len(clip_sd) > 0:
|
if len(clip_sd) > 0:
|
||||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd)
|
parameters = comfy.utils.calculate_parameters(clip_sd)
|
||||||
|
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, model_options=te_model_options)
|
||||||
m, u = clip.load_sd(clip_sd, full_model=True)
|
m, u = clip.load_sd(clip_sd, full_model=True)
|
||||||
if len(m) > 0:
|
if len(m) > 0:
|
||||||
m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
|
m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
|
||||||
@@ -567,12 +599,13 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
|||||||
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
|
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
|
||||||
if inital_load_device != torch.device("cpu"):
|
if inital_load_device != torch.device("cpu"):
|
||||||
logging.info("loaded straight to GPU")
|
logging.info("loaded straight to GPU")
|
||||||
model_management.load_model_gpu(model_patcher)
|
model_management.load_models_gpu([model_patcher], force_full_load=True)
|
||||||
|
|
||||||
return (model_patcher, clip, vae, clipvision)
|
return (model_patcher, clip, vae, clipvision)
|
||||||
|
|
||||||
|
|
||||||
def load_unet_state_dict(sd, dtype=None): #load unet in diffusers or regular format
|
def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffusers or regular format
|
||||||
|
dtype = model_options.get("dtype", None)
|
||||||
|
|
||||||
#Allow loading unets from checkpoint files
|
#Allow loading unets from checkpoint files
|
||||||
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
|
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
|
||||||
@@ -614,6 +647,7 @@ def load_unet_state_dict(sd, dtype=None): #load unet in diffusers or regular for
|
|||||||
|
|
||||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||||
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
||||||
|
model_config.custom_operations = model_options.get("custom_operations", None)
|
||||||
model = model_config.get_model(new_sd, "")
|
model = model_config.get_model(new_sd, "")
|
||||||
model = model.to(offload_device)
|
model = model.to(offload_device)
|
||||||
model.load_model_weights(new_sd, "")
|
model.load_model_weights(new_sd, "")
|
||||||
@@ -622,24 +656,36 @@ def load_unet_state_dict(sd, dtype=None): #load unet in diffusers or regular for
|
|||||||
logging.info("left over keys in unet: {}".format(left_over))
|
logging.info("left over keys in unet: {}".format(left_over))
|
||||||
return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
|
return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
|
||||||
|
|
||||||
def load_unet(unet_path, dtype=None):
|
|
||||||
|
def load_diffusion_model(unet_path, model_options={}):
|
||||||
sd = comfy.utils.load_torch_file(unet_path)
|
sd = comfy.utils.load_torch_file(unet_path)
|
||||||
model = load_unet_state_dict(sd, dtype=dtype)
|
model = load_diffusion_model_state_dict(sd, model_options=model_options)
|
||||||
if model is None:
|
if model is None:
|
||||||
logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path))
|
logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path))
|
||||||
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
|
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
def load_unet(unet_path, dtype=None):
|
||||||
|
print("WARNING: the load_unet function has been deprecated and will be removed please switch to: load_diffusion_model")
|
||||||
|
return load_diffusion_model(unet_path, model_options={"dtype": dtype})
|
||||||
|
|
||||||
|
def load_unet_state_dict(sd, dtype=None):
|
||||||
|
print("WARNING: the load_unet_state_dict function has been deprecated and will be removed please switch to: load_diffusion_model_state_dict")
|
||||||
|
return load_diffusion_model_state_dict(sd, model_options={"dtype": dtype})
|
||||||
|
|
||||||
def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None, extra_keys={}):
|
def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None, extra_keys={}):
|
||||||
clip_sd = None
|
clip_sd = None
|
||||||
load_models = [model]
|
load_models = [model]
|
||||||
if clip is not None:
|
if clip is not None:
|
||||||
load_models.append(clip.load_model())
|
load_models.append(clip.load_model())
|
||||||
clip_sd = clip.get_sd()
|
clip_sd = clip.get_sd()
|
||||||
|
vae_sd = None
|
||||||
|
if vae is not None:
|
||||||
|
vae_sd = vae.get_sd()
|
||||||
|
|
||||||
model_management.load_models_gpu(load_models, force_patch_weights=True)
|
model_management.load_models_gpu(load_models, force_patch_weights=True)
|
||||||
clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
|
clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
|
||||||
sd = model.model.state_dict_for_saving(clip_sd, vae.get_sd(), clip_vision_sd)
|
sd = model.model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd)
|
||||||
for k in extra_keys:
|
for k in extra_keys:
|
||||||
sd[k] = extra_keys[k]
|
sd[k] = extra_keys[k]
|
||||||
|
|
||||||
|
|||||||
@@ -75,7 +75,6 @@ class ClipTokenWeightEncoder:
|
|||||||
return r
|
return r
|
||||||
|
|
||||||
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||||
"""Uses the CLIP transformer encoder for text (from huggingface)"""
|
|
||||||
LAYERS = [
|
LAYERS = [
|
||||||
"last",
|
"last",
|
||||||
"pooled",
|
"pooled",
|
||||||
@@ -84,7 +83,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||||||
def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
|
def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
|
||||||
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,
|
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,
|
||||||
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
|
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
|
||||||
return_projected_pooled=True, return_attention_masks=False): # clip-vit-base-patch32
|
return_projected_pooled=True, return_attention_masks=False, model_options={}): # clip-vit-base-patch32
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert layer in self.LAYERS
|
assert layer in self.LAYERS
|
||||||
|
|
||||||
@@ -94,7 +93,11 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||||||
with open(textmodel_json_config) as f:
|
with open(textmodel_json_config) as f:
|
||||||
config = json.load(f)
|
config = json.load(f)
|
||||||
|
|
||||||
self.operations = comfy.ops.manual_cast
|
operations = model_options.get("custom_operations", None)
|
||||||
|
if operations is None:
|
||||||
|
operations = comfy.ops.manual_cast
|
||||||
|
|
||||||
|
self.operations = operations
|
||||||
self.transformer = model_class(config, dtype, device, self.operations)
|
self.transformer = model_class(config, dtype, device, self.operations)
|
||||||
self.num_layers = self.transformer.num_layers
|
self.num_layers = self.transformer.num_layers
|
||||||
|
|
||||||
@@ -552,8 +555,12 @@ class SD1Tokenizer:
|
|||||||
def state_dict(self):
|
def state_dict(self):
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
class SD1CheckpointClipModel(SDClipModel):
|
||||||
|
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||||
|
super().__init__(device=device, return_projected_pooled=False, dtype=dtype, model_options=model_options)
|
||||||
|
|
||||||
class SD1ClipModel(torch.nn.Module):
|
class SD1ClipModel(torch.nn.Module):
|
||||||
def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, name=None, **kwargs):
|
def __init__(self, device="cpu", dtype=None, model_options={}, clip_name="l", clip_model=SD1CheckpointClipModel, name=None, **kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
if name is not None:
|
if name is not None:
|
||||||
@@ -563,7 +570,7 @@ class SD1ClipModel(torch.nn.Module):
|
|||||||
self.clip_name = clip_name
|
self.clip_name = clip_name
|
||||||
self.clip = "clip_{}".format(self.clip_name)
|
self.clip = "clip_{}".format(self.clip_name)
|
||||||
|
|
||||||
setattr(self, self.clip, clip_model(device=device, dtype=dtype, **kwargs))
|
setattr(self, self.clip, clip_model(device=device, dtype=dtype, model_options=model_options, **kwargs))
|
||||||
|
|
||||||
self.dtypes = set()
|
self.dtypes = set()
|
||||||
if dtype is not None:
|
if dtype is not None:
|
||||||
|
|||||||
@@ -3,14 +3,14 @@ import torch
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
class SDXLClipG(sd1_clip.SDClipModel):
|
class SDXLClipG(sd1_clip.SDClipModel):
|
||||||
def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None):
|
def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None, model_options={}):
|
||||||
if layer == "penultimate":
|
if layer == "penultimate":
|
||||||
layer="hidden"
|
layer="hidden"
|
||||||
layer_idx=-2
|
layer_idx=-2
|
||||||
|
|
||||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
|
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
|
||||||
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
|
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
|
||||||
special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False)
|
special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False, return_projected_pooled=True, model_options=model_options)
|
||||||
|
|
||||||
def load_sd(self, sd):
|
def load_sd(self, sd):
|
||||||
return super().load_sd(sd)
|
return super().load_sd(sd)
|
||||||
@@ -38,10 +38,10 @@ class SDXLTokenizer:
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
class SDXLClipModel(torch.nn.Module):
|
class SDXLClipModel(torch.nn.Module):
|
||||||
def __init__(self, device="cpu", dtype=None):
|
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False)
|
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options)
|
||||||
self.clip_g = SDXLClipG(device=device, dtype=dtype)
|
self.clip_g = SDXLClipG(device=device, dtype=dtype, model_options=model_options)
|
||||||
self.dtypes = set([dtype])
|
self.dtypes = set([dtype])
|
||||||
|
|
||||||
def set_clip_options(self, options):
|
def set_clip_options(self, options):
|
||||||
@@ -66,8 +66,8 @@ class SDXLClipModel(torch.nn.Module):
|
|||||||
return self.clip_l.load_sd(sd)
|
return self.clip_l.load_sd(sd)
|
||||||
|
|
||||||
class SDXLRefinerClipModel(sd1_clip.SD1ClipModel):
|
class SDXLRefinerClipModel(sd1_clip.SD1ClipModel):
|
||||||
def __init__(self, device="cpu", dtype=None):
|
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||||
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG)
|
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG, model_options=model_options)
|
||||||
|
|
||||||
|
|
||||||
class StableCascadeClipGTokenizer(sd1_clip.SDTokenizer):
|
class StableCascadeClipGTokenizer(sd1_clip.SDTokenizer):
|
||||||
@@ -79,14 +79,14 @@ class StableCascadeTokenizer(sd1_clip.SD1Tokenizer):
|
|||||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="g", tokenizer=StableCascadeClipGTokenizer)
|
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="g", tokenizer=StableCascadeClipGTokenizer)
|
||||||
|
|
||||||
class StableCascadeClipG(sd1_clip.SDClipModel):
|
class StableCascadeClipG(sd1_clip.SDClipModel):
|
||||||
def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None):
|
def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None, model_options={}):
|
||||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
|
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
|
||||||
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
|
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
|
||||||
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True)
|
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True, return_projected_pooled=True, model_options=model_options)
|
||||||
|
|
||||||
def load_sd(self, sd):
|
def load_sd(self, sd):
|
||||||
return super().load_sd(sd)
|
return super().load_sd(sd)
|
||||||
|
|
||||||
class StableCascadeClipModel(sd1_clip.SD1ClipModel):
|
class StableCascadeClipModel(sd1_clip.SD1ClipModel):
|
||||||
def __init__(self, device="cpu", dtype=None):
|
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||||
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=StableCascadeClipG)
|
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=StableCascadeClipG, model_options=model_options)
|
||||||
|
|||||||
@@ -181,7 +181,7 @@ class SDXL(supported_models_base.BASE):
|
|||||||
|
|
||||||
latent_format = latent_formats.SDXL
|
latent_format = latent_formats.SDXL
|
||||||
|
|
||||||
memory_usage_factor = 0.7
|
memory_usage_factor = 0.8
|
||||||
|
|
||||||
def model_type(self, state_dict, prefix=""):
|
def model_type(self, state_dict, prefix=""):
|
||||||
if 'edm_mean' in state_dict and 'edm_std' in state_dict: #Playground V2.5
|
if 'edm_mean' in state_dict and 'edm_std' in state_dict: #Playground V2.5
|
||||||
@@ -654,6 +654,7 @@ class Flux(supported_models_base.BASE):
|
|||||||
def clip_target(self, state_dict={}):
|
def clip_target(self, state_dict={}):
|
||||||
pref = self.text_encoder_key_prefix[0]
|
pref = self.text_encoder_key_prefix[0]
|
||||||
t5_key = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref)
|
t5_key = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref)
|
||||||
|
dtype_t5 = None
|
||||||
if t5_key in state_dict:
|
if t5_key in state_dict:
|
||||||
dtype_t5 = state_dict[t5_key].dtype
|
dtype_t5 = state_dict[t5_key].dtype
|
||||||
return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(dtype_t5=dtype_t5))
|
return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(dtype_t5=dtype_t5))
|
||||||
|
|||||||
@@ -1,3 +1,21 @@
|
|||||||
|
"""
|
||||||
|
This file is part of ComfyUI.
|
||||||
|
Copyright (C) 2024 Comfy
|
||||||
|
|
||||||
|
This program is free software: you can redistribute it and/or modify
|
||||||
|
it under the terms of the GNU General Public License as published by
|
||||||
|
the Free Software Foundation, either version 3 of the License, or
|
||||||
|
(at your option) any later version.
|
||||||
|
|
||||||
|
This program is distributed in the hope that it will be useful,
|
||||||
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
GNU General Public License for more details.
|
||||||
|
|
||||||
|
You should have received a copy of the GNU General Public License
|
||||||
|
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from . import model_base
|
from . import model_base
|
||||||
from . import utils
|
from . import utils
|
||||||
@@ -30,6 +48,7 @@ class BASE:
|
|||||||
memory_usage_factor = 2.0
|
memory_usage_factor = 2.0
|
||||||
|
|
||||||
manual_cast_dtype = None
|
manual_cast_dtype = None
|
||||||
|
custom_operations = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def matches(s, unet_config, state_dict=None):
|
def matches(s, unet_config, state_dict=None):
|
||||||
|
|||||||
@@ -4,9 +4,9 @@ import comfy.text_encoders.t5
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
class PT5XlModel(sd1_clip.SDClipModel):
|
class PT5XlModel(sd1_clip.SDClipModel):
|
||||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
|
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
|
||||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_pile_config_xl.json")
|
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_pile_config_xl.json")
|
||||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 2, "pad": 1}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True)
|
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 2, "pad": 1}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True, model_options=model_options)
|
||||||
|
|
||||||
class PT5XlTokenizer(sd1_clip.SDTokenizer):
|
class PT5XlTokenizer(sd1_clip.SDTokenizer):
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
@@ -18,5 +18,5 @@ class AuraT5Tokenizer(sd1_clip.SD1Tokenizer):
|
|||||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="pile_t5xl", tokenizer=PT5XlTokenizer)
|
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="pile_t5xl", tokenizer=PT5XlTokenizer)
|
||||||
|
|
||||||
class AuraT5Model(sd1_clip.SD1ClipModel):
|
class AuraT5Model(sd1_clip.SD1ClipModel):
|
||||||
def __init__(self, device="cpu", dtype=None, **kwargs):
|
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
|
||||||
super().__init__(device=device, dtype=dtype, name="pile_t5xl", clip_model=PT5XlModel, **kwargs)
|
super().__init__(device=device, dtype=dtype, model_options=model_options, name="pile_t5xl", clip_model=PT5XlModel, **kwargs)
|
||||||
|
|||||||
@@ -6,9 +6,9 @@ import torch
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
class T5XXLModel(sd1_clip.SDClipModel):
|
class T5XXLModel(sd1_clip.SDClipModel):
|
||||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
|
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
|
||||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
|
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
|
||||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5)
|
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, model_options=model_options)
|
||||||
|
|
||||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
@@ -35,11 +35,11 @@ class FluxTokenizer:
|
|||||||
|
|
||||||
|
|
||||||
class FluxClipModel(torch.nn.Module):
|
class FluxClipModel(torch.nn.Module):
|
||||||
def __init__(self, dtype_t5=None, device="cpu", dtype=None):
|
def __init__(self, dtype_t5=None, device="cpu", dtype=None, model_options={}):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
|
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
|
||||||
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False)
|
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
|
||||||
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5)
|
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
|
||||||
self.dtypes = set([dtype, dtype_t5])
|
self.dtypes = set([dtype, dtype_t5])
|
||||||
|
|
||||||
def set_clip_options(self, options):
|
def set_clip_options(self, options):
|
||||||
@@ -66,6 +66,6 @@ class FluxClipModel(torch.nn.Module):
|
|||||||
|
|
||||||
def flux_clip(dtype_t5=None):
|
def flux_clip(dtype_t5=None):
|
||||||
class FluxClipModel_(FluxClipModel):
|
class FluxClipModel_(FluxClipModel):
|
||||||
def __init__(self, device="cpu", dtype=None):
|
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||||
super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype)
|
super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
|
||||||
return FluxClipModel_
|
return FluxClipModel_
|
||||||
|
|||||||
@@ -7,9 +7,9 @@ import os
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
class HyditBertModel(sd1_clip.SDClipModel):
|
class HyditBertModel(sd1_clip.SDClipModel):
|
||||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
|
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
|
||||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hydit_clip.json")
|
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hydit_clip.json")
|
||||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 101, "end": 102, "pad": 0}, model_class=BertModel, enable_attention_masks=True, return_attention_masks=True)
|
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 101, "end": 102, "pad": 0}, model_class=BertModel, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)
|
||||||
|
|
||||||
class HyditBertTokenizer(sd1_clip.SDTokenizer):
|
class HyditBertTokenizer(sd1_clip.SDTokenizer):
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
@@ -18,9 +18,9 @@ class HyditBertTokenizer(sd1_clip.SDTokenizer):
|
|||||||
|
|
||||||
|
|
||||||
class MT5XLModel(sd1_clip.SDClipModel):
|
class MT5XLModel(sd1_clip.SDClipModel):
|
||||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
|
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
|
||||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mt5_config_xl.json")
|
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mt5_config_xl.json")
|
||||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, return_attention_masks=True)
|
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)
|
||||||
|
|
||||||
class MT5XLTokenizer(sd1_clip.SDTokenizer):
|
class MT5XLTokenizer(sd1_clip.SDTokenizer):
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
@@ -50,10 +50,10 @@ class HyditTokenizer:
|
|||||||
return {"mt5xl.spiece_model": self.mt5xl.state_dict()["spiece_model"]}
|
return {"mt5xl.spiece_model": self.mt5xl.state_dict()["spiece_model"]}
|
||||||
|
|
||||||
class HyditModel(torch.nn.Module):
|
class HyditModel(torch.nn.Module):
|
||||||
def __init__(self, device="cpu", dtype=None):
|
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hydit_clip = HyditBertModel(dtype=dtype)
|
self.hydit_clip = HyditBertModel(dtype=dtype, model_options=model_options)
|
||||||
self.mt5xl = MT5XLModel(dtype=dtype)
|
self.mt5xl = MT5XLModel(dtype=dtype, model_options=model_options)
|
||||||
|
|
||||||
self.dtypes = set()
|
self.dtypes = set()
|
||||||
if dtype is not None:
|
if dtype is not None:
|
||||||
|
|||||||
25
comfy/text_encoders/long_clipl.json
Normal file
25
comfy/text_encoders/long_clipl.json
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
{
|
||||||
|
"_name_or_path": "openai/clip-vit-large-patch14",
|
||||||
|
"architectures": [
|
||||||
|
"CLIPTextModel"
|
||||||
|
],
|
||||||
|
"attention_dropout": 0.0,
|
||||||
|
"bos_token_id": 0,
|
||||||
|
"dropout": 0.0,
|
||||||
|
"eos_token_id": 49407,
|
||||||
|
"hidden_act": "quick_gelu",
|
||||||
|
"hidden_size": 768,
|
||||||
|
"initializer_factor": 1.0,
|
||||||
|
"initializer_range": 0.02,
|
||||||
|
"intermediate_size": 3072,
|
||||||
|
"layer_norm_eps": 1e-05,
|
||||||
|
"max_position_embeddings": 248,
|
||||||
|
"model_type": "clip_text_model",
|
||||||
|
"num_attention_heads": 12,
|
||||||
|
"num_hidden_layers": 12,
|
||||||
|
"pad_token_id": 1,
|
||||||
|
"projection_dim": 768,
|
||||||
|
"torch_dtype": "float32",
|
||||||
|
"transformers_version": "4.24.0",
|
||||||
|
"vocab_size": 49408
|
||||||
|
}
|
||||||
19
comfy/text_encoders/long_clipl.py
Normal file
19
comfy/text_encoders/long_clipl.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
from comfy import sd1_clip
|
||||||
|
import os
|
||||||
|
|
||||||
|
class LongClipTokenizer_(sd1_clip.SDTokenizer):
|
||||||
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
|
super().__init__(max_length=248, embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||||
|
|
||||||
|
class LongClipModel_(sd1_clip.SDClipModel):
|
||||||
|
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||||
|
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "long_clipl.json")
|
||||||
|
super().__init__(device=device, textmodel_json_config=textmodel_json_config, return_projected_pooled=False, dtype=dtype, model_options=model_options)
|
||||||
|
|
||||||
|
class LongClipTokenizer(sd1_clip.SD1Tokenizer):
|
||||||
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
|
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, tokenizer=LongClipTokenizer_)
|
||||||
|
|
||||||
|
class LongClipModel(sd1_clip.SD1ClipModel):
|
||||||
|
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
|
||||||
|
super().__init__(device=device, dtype=dtype, model_options=model_options, clip_model=LongClipModel_, **kwargs)
|
||||||
@@ -4,9 +4,9 @@ import comfy.text_encoders.t5
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
class T5BaseModel(sd1_clip.SDClipModel):
|
class T5BaseModel(sd1_clip.SDClipModel):
|
||||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
|
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
|
||||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_base.json")
|
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_base.json")
|
||||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True)
|
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, model_options=model_options, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True)
|
||||||
|
|
||||||
class T5BaseTokenizer(sd1_clip.SDTokenizer):
|
class T5BaseTokenizer(sd1_clip.SDTokenizer):
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
@@ -18,5 +18,5 @@ class SAT5Tokenizer(sd1_clip.SD1Tokenizer):
|
|||||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5base", tokenizer=T5BaseTokenizer)
|
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5base", tokenizer=T5BaseTokenizer)
|
||||||
|
|
||||||
class SAT5Model(sd1_clip.SD1ClipModel):
|
class SAT5Model(sd1_clip.SD1ClipModel):
|
||||||
def __init__(self, device="cpu", dtype=None, **kwargs):
|
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
|
||||||
super().__init__(device=device, dtype=dtype, name="t5base", clip_model=T5BaseModel, **kwargs)
|
super().__init__(device=device, dtype=dtype, model_options=model_options, name="t5base", clip_model=T5BaseModel, **kwargs)
|
||||||
|
|||||||
@@ -2,13 +2,13 @@ from comfy import sd1_clip
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
class SD2ClipHModel(sd1_clip.SDClipModel):
|
class SD2ClipHModel(sd1_clip.SDClipModel):
|
||||||
def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None):
|
def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None, model_options={}):
|
||||||
if layer == "penultimate":
|
if layer == "penultimate":
|
||||||
layer="hidden"
|
layer="hidden"
|
||||||
layer_idx=-2
|
layer_idx=-2
|
||||||
|
|
||||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json")
|
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json")
|
||||||
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0})
|
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0}, return_projected_pooled=True, model_options=model_options)
|
||||||
|
|
||||||
class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
|
class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
|
||||||
def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}):
|
||||||
@@ -19,5 +19,5 @@ class SD2Tokenizer(sd1_clip.SD1Tokenizer):
|
|||||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="h", tokenizer=SD2ClipHTokenizer)
|
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="h", tokenizer=SD2ClipHTokenizer)
|
||||||
|
|
||||||
class SD2ClipModel(sd1_clip.SD1ClipModel):
|
class SD2ClipModel(sd1_clip.SD1ClipModel):
|
||||||
def __init__(self, device="cpu", dtype=None, **kwargs):
|
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
|
||||||
super().__init__(device=device, dtype=dtype, clip_name="h", clip_model=SD2ClipHModel, **kwargs)
|
super().__init__(device=device, dtype=dtype, model_options=model_options, clip_name="h", clip_model=SD2ClipHModel, **kwargs)
|
||||||
|
|||||||
@@ -8,14 +8,14 @@ import comfy.model_management
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
class T5XXLModel(sd1_clip.SDClipModel):
|
class T5XXLModel(sd1_clip.SDClipModel):
|
||||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
|
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
|
||||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
|
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
|
||||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5)
|
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, model_options=model_options)
|
||||||
|
|
||||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
|
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
|
||||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77)
|
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77)
|
||||||
|
|
||||||
|
|
||||||
class SD3Tokenizer:
|
class SD3Tokenizer:
|
||||||
@@ -38,24 +38,24 @@ class SD3Tokenizer:
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
class SD3ClipModel(torch.nn.Module):
|
class SD3ClipModel(torch.nn.Module):
|
||||||
def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu", dtype=None):
|
def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu", dtype=None, model_options={}):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dtypes = set()
|
self.dtypes = set()
|
||||||
if clip_l:
|
if clip_l:
|
||||||
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False)
|
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options)
|
||||||
self.dtypes.add(dtype)
|
self.dtypes.add(dtype)
|
||||||
else:
|
else:
|
||||||
self.clip_l = None
|
self.clip_l = None
|
||||||
|
|
||||||
if clip_g:
|
if clip_g:
|
||||||
self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype)
|
self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype, model_options=model_options)
|
||||||
self.dtypes.add(dtype)
|
self.dtypes.add(dtype)
|
||||||
else:
|
else:
|
||||||
self.clip_g = None
|
self.clip_g = None
|
||||||
|
|
||||||
if t5:
|
if t5:
|
||||||
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
|
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
|
||||||
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5)
|
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
|
||||||
self.dtypes.add(dtype_t5)
|
self.dtypes.add(dtype_t5)
|
||||||
else:
|
else:
|
||||||
self.t5xxl = None
|
self.t5xxl = None
|
||||||
@@ -132,6 +132,6 @@ class SD3ClipModel(torch.nn.Module):
|
|||||||
|
|
||||||
def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None):
|
def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None):
|
||||||
class SD3ClipModel_(SD3ClipModel):
|
class SD3ClipModel_(SD3ClipModel):
|
||||||
def __init__(self, device="cpu", dtype=None):
|
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||||
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, device=device, dtype=dtype)
|
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
|
||||||
return SD3ClipModel_
|
return SD3ClipModel_
|
||||||
|
|||||||
@@ -528,6 +528,8 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
|
|||||||
("guidance_in.out_layer.weight", "time_text_embed.guidance_embedder.linear_2.weight"),
|
("guidance_in.out_layer.weight", "time_text_embed.guidance_embedder.linear_2.weight"),
|
||||||
("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift),
|
("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift),
|
||||||
("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift),
|
("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift),
|
||||||
|
("pos_embed_input.bias", "controlnet_x_embedder.bias"),
|
||||||
|
("pos_embed_input.weight", "controlnet_x_embedder.weight"),
|
||||||
}
|
}
|
||||||
|
|
||||||
for k in MAP_BASIC:
|
for k in MAP_BASIC:
|
||||||
|
|||||||
308
comfy_execution/caching.py
Normal file
308
comfy_execution/caching.py
Normal file
@@ -0,0 +1,308 @@
|
|||||||
|
import itertools
|
||||||
|
from typing import Sequence, Mapping
|
||||||
|
from comfy_execution.graph import DynamicPrompt
|
||||||
|
|
||||||
|
import nodes
|
||||||
|
|
||||||
|
from comfy_execution.graph_utils import is_link
|
||||||
|
|
||||||
|
class CacheKeySet:
|
||||||
|
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
||||||
|
self.keys = {}
|
||||||
|
self.subcache_keys = {}
|
||||||
|
|
||||||
|
def add_keys(self, node_ids):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def all_node_ids(self):
|
||||||
|
return set(self.keys.keys())
|
||||||
|
|
||||||
|
def get_used_keys(self):
|
||||||
|
return self.keys.values()
|
||||||
|
|
||||||
|
def get_used_subcache_keys(self):
|
||||||
|
return self.subcache_keys.values()
|
||||||
|
|
||||||
|
def get_data_key(self, node_id):
|
||||||
|
return self.keys.get(node_id, None)
|
||||||
|
|
||||||
|
def get_subcache_key(self, node_id):
|
||||||
|
return self.subcache_keys.get(node_id, None)
|
||||||
|
|
||||||
|
class Unhashable:
|
||||||
|
def __init__(self):
|
||||||
|
self.value = float("NaN")
|
||||||
|
|
||||||
|
def to_hashable(obj):
|
||||||
|
# So that we don't infinitely recurse since frozenset and tuples
|
||||||
|
# are Sequences.
|
||||||
|
if isinstance(obj, (int, float, str, bool, type(None))):
|
||||||
|
return obj
|
||||||
|
elif isinstance(obj, Mapping):
|
||||||
|
return frozenset([(to_hashable(k), to_hashable(v)) for k, v in sorted(obj.items())])
|
||||||
|
elif isinstance(obj, Sequence):
|
||||||
|
return frozenset(zip(itertools.count(), [to_hashable(i) for i in obj]))
|
||||||
|
else:
|
||||||
|
# TODO - Support other objects like tensors?
|
||||||
|
return Unhashable()
|
||||||
|
|
||||||
|
class CacheKeySetID(CacheKeySet):
|
||||||
|
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
||||||
|
super().__init__(dynprompt, node_ids, is_changed_cache)
|
||||||
|
self.dynprompt = dynprompt
|
||||||
|
self.add_keys(node_ids)
|
||||||
|
|
||||||
|
def add_keys(self, node_ids):
|
||||||
|
for node_id in node_ids:
|
||||||
|
if node_id in self.keys:
|
||||||
|
continue
|
||||||
|
if not self.dynprompt.has_node(node_id):
|
||||||
|
continue
|
||||||
|
node = self.dynprompt.get_node(node_id)
|
||||||
|
self.keys[node_id] = (node_id, node["class_type"])
|
||||||
|
self.subcache_keys[node_id] = (node_id, node["class_type"])
|
||||||
|
|
||||||
|
class CacheKeySetInputSignature(CacheKeySet):
|
||||||
|
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
||||||
|
super().__init__(dynprompt, node_ids, is_changed_cache)
|
||||||
|
self.dynprompt = dynprompt
|
||||||
|
self.is_changed_cache = is_changed_cache
|
||||||
|
self.add_keys(node_ids)
|
||||||
|
|
||||||
|
def include_node_id_in_input(self) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def add_keys(self, node_ids):
|
||||||
|
for node_id in node_ids:
|
||||||
|
if node_id in self.keys:
|
||||||
|
continue
|
||||||
|
if not self.dynprompt.has_node(node_id):
|
||||||
|
continue
|
||||||
|
node = self.dynprompt.get_node(node_id)
|
||||||
|
self.keys[node_id] = self.get_node_signature(self.dynprompt, node_id)
|
||||||
|
self.subcache_keys[node_id] = (node_id, node["class_type"])
|
||||||
|
|
||||||
|
def get_node_signature(self, dynprompt, node_id):
|
||||||
|
signature = []
|
||||||
|
ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id)
|
||||||
|
signature.append(self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
|
||||||
|
for ancestor_id in ancestors:
|
||||||
|
signature.append(self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
|
||||||
|
return to_hashable(signature)
|
||||||
|
|
||||||
|
def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
|
||||||
|
if not dynprompt.has_node(node_id):
|
||||||
|
# This node doesn't exist -- we can't cache it.
|
||||||
|
return [float("NaN")]
|
||||||
|
node = dynprompt.get_node(node_id)
|
||||||
|
class_type = node["class_type"]
|
||||||
|
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||||
|
signature = [class_type, self.is_changed_cache.get(node_id)]
|
||||||
|
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT):
|
||||||
|
signature.append(node_id)
|
||||||
|
inputs = node["inputs"]
|
||||||
|
for key in sorted(inputs.keys()):
|
||||||
|
if is_link(inputs[key]):
|
||||||
|
(ancestor_id, ancestor_socket) = inputs[key]
|
||||||
|
ancestor_index = ancestor_order_mapping[ancestor_id]
|
||||||
|
signature.append((key,("ANCESTOR", ancestor_index, ancestor_socket)))
|
||||||
|
else:
|
||||||
|
signature.append((key, inputs[key]))
|
||||||
|
return signature
|
||||||
|
|
||||||
|
# This function returns a list of all ancestors of the given node. The order of the list is
|
||||||
|
# deterministic based on which specific inputs the ancestor is connected by.
|
||||||
|
def get_ordered_ancestry(self, dynprompt, node_id):
|
||||||
|
ancestors = []
|
||||||
|
order_mapping = {}
|
||||||
|
self.get_ordered_ancestry_internal(dynprompt, node_id, ancestors, order_mapping)
|
||||||
|
return ancestors, order_mapping
|
||||||
|
|
||||||
|
def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_mapping):
|
||||||
|
if not dynprompt.has_node(node_id):
|
||||||
|
return
|
||||||
|
inputs = dynprompt.get_node(node_id)["inputs"]
|
||||||
|
input_keys = sorted(inputs.keys())
|
||||||
|
for key in input_keys:
|
||||||
|
if is_link(inputs[key]):
|
||||||
|
ancestor_id = inputs[key][0]
|
||||||
|
if ancestor_id not in order_mapping:
|
||||||
|
ancestors.append(ancestor_id)
|
||||||
|
order_mapping[ancestor_id] = len(ancestors) - 1
|
||||||
|
self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping)
|
||||||
|
|
||||||
|
class BasicCache:
|
||||||
|
def __init__(self, key_class):
|
||||||
|
self.key_class = key_class
|
||||||
|
self.initialized = False
|
||||||
|
self.dynprompt: DynamicPrompt
|
||||||
|
self.cache_key_set: CacheKeySet
|
||||||
|
self.cache = {}
|
||||||
|
self.subcaches = {}
|
||||||
|
|
||||||
|
def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||||
|
self.dynprompt = dynprompt
|
||||||
|
self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache)
|
||||||
|
self.is_changed_cache = is_changed_cache
|
||||||
|
self.initialized = True
|
||||||
|
|
||||||
|
def all_node_ids(self):
|
||||||
|
assert self.initialized
|
||||||
|
node_ids = self.cache_key_set.all_node_ids()
|
||||||
|
for subcache in self.subcaches.values():
|
||||||
|
node_ids = node_ids.union(subcache.all_node_ids())
|
||||||
|
return node_ids
|
||||||
|
|
||||||
|
def _clean_cache(self):
|
||||||
|
preserve_keys = set(self.cache_key_set.get_used_keys())
|
||||||
|
to_remove = []
|
||||||
|
for key in self.cache:
|
||||||
|
if key not in preserve_keys:
|
||||||
|
to_remove.append(key)
|
||||||
|
for key in to_remove:
|
||||||
|
del self.cache[key]
|
||||||
|
|
||||||
|
def _clean_subcaches(self):
|
||||||
|
preserve_subcaches = set(self.cache_key_set.get_used_subcache_keys())
|
||||||
|
|
||||||
|
to_remove = []
|
||||||
|
for key in self.subcaches:
|
||||||
|
if key not in preserve_subcaches:
|
||||||
|
to_remove.append(key)
|
||||||
|
for key in to_remove:
|
||||||
|
del self.subcaches[key]
|
||||||
|
|
||||||
|
def clean_unused(self):
|
||||||
|
assert self.initialized
|
||||||
|
self._clean_cache()
|
||||||
|
self._clean_subcaches()
|
||||||
|
|
||||||
|
def _set_immediate(self, node_id, value):
|
||||||
|
assert self.initialized
|
||||||
|
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||||
|
self.cache[cache_key] = value
|
||||||
|
|
||||||
|
def _get_immediate(self, node_id):
|
||||||
|
if not self.initialized:
|
||||||
|
return None
|
||||||
|
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||||
|
if cache_key in self.cache:
|
||||||
|
return self.cache[cache_key]
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _ensure_subcache(self, node_id, children_ids):
|
||||||
|
subcache_key = self.cache_key_set.get_subcache_key(node_id)
|
||||||
|
subcache = self.subcaches.get(subcache_key, None)
|
||||||
|
if subcache is None:
|
||||||
|
subcache = BasicCache(self.key_class)
|
||||||
|
self.subcaches[subcache_key] = subcache
|
||||||
|
subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache)
|
||||||
|
return subcache
|
||||||
|
|
||||||
|
def _get_subcache(self, node_id):
|
||||||
|
assert self.initialized
|
||||||
|
subcache_key = self.cache_key_set.get_subcache_key(node_id)
|
||||||
|
if subcache_key in self.subcaches:
|
||||||
|
return self.subcaches[subcache_key]
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def recursive_debug_dump(self):
|
||||||
|
result = []
|
||||||
|
for key in self.cache:
|
||||||
|
result.append({"key": key, "value": self.cache[key]})
|
||||||
|
for key in self.subcaches:
|
||||||
|
result.append({"subcache_key": key, "subcache": self.subcaches[key].recursive_debug_dump()})
|
||||||
|
return result
|
||||||
|
|
||||||
|
class HierarchicalCache(BasicCache):
|
||||||
|
def __init__(self, key_class):
|
||||||
|
super().__init__(key_class)
|
||||||
|
|
||||||
|
def _get_cache_for(self, node_id):
|
||||||
|
assert self.dynprompt is not None
|
||||||
|
parent_id = self.dynprompt.get_parent_node_id(node_id)
|
||||||
|
if parent_id is None:
|
||||||
|
return self
|
||||||
|
|
||||||
|
hierarchy = []
|
||||||
|
while parent_id is not None:
|
||||||
|
hierarchy.append(parent_id)
|
||||||
|
parent_id = self.dynprompt.get_parent_node_id(parent_id)
|
||||||
|
|
||||||
|
cache = self
|
||||||
|
for parent_id in reversed(hierarchy):
|
||||||
|
cache = cache._get_subcache(parent_id)
|
||||||
|
if cache is None:
|
||||||
|
return None
|
||||||
|
return cache
|
||||||
|
|
||||||
|
def get(self, node_id):
|
||||||
|
cache = self._get_cache_for(node_id)
|
||||||
|
if cache is None:
|
||||||
|
return None
|
||||||
|
return cache._get_immediate(node_id)
|
||||||
|
|
||||||
|
def set(self, node_id, value):
|
||||||
|
cache = self._get_cache_for(node_id)
|
||||||
|
assert cache is not None
|
||||||
|
cache._set_immediate(node_id, value)
|
||||||
|
|
||||||
|
def ensure_subcache_for(self, node_id, children_ids):
|
||||||
|
cache = self._get_cache_for(node_id)
|
||||||
|
assert cache is not None
|
||||||
|
return cache._ensure_subcache(node_id, children_ids)
|
||||||
|
|
||||||
|
class LRUCache(BasicCache):
|
||||||
|
def __init__(self, key_class, max_size=100):
|
||||||
|
super().__init__(key_class)
|
||||||
|
self.max_size = max_size
|
||||||
|
self.min_generation = 0
|
||||||
|
self.generation = 0
|
||||||
|
self.used_generation = {}
|
||||||
|
self.children = {}
|
||||||
|
|
||||||
|
def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||||
|
super().set_prompt(dynprompt, node_ids, is_changed_cache)
|
||||||
|
self.generation += 1
|
||||||
|
for node_id in node_ids:
|
||||||
|
self._mark_used(node_id)
|
||||||
|
|
||||||
|
def clean_unused(self):
|
||||||
|
while len(self.cache) > self.max_size and self.min_generation < self.generation:
|
||||||
|
self.min_generation += 1
|
||||||
|
to_remove = [key for key in self.cache if self.used_generation[key] < self.min_generation]
|
||||||
|
for key in to_remove:
|
||||||
|
del self.cache[key]
|
||||||
|
del self.used_generation[key]
|
||||||
|
if key in self.children:
|
||||||
|
del self.children[key]
|
||||||
|
self._clean_subcaches()
|
||||||
|
|
||||||
|
def get(self, node_id):
|
||||||
|
self._mark_used(node_id)
|
||||||
|
return self._get_immediate(node_id)
|
||||||
|
|
||||||
|
def _mark_used(self, node_id):
|
||||||
|
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||||
|
if cache_key is not None:
|
||||||
|
self.used_generation[cache_key] = self.generation
|
||||||
|
|
||||||
|
def set(self, node_id, value):
|
||||||
|
self._mark_used(node_id)
|
||||||
|
return self._set_immediate(node_id, value)
|
||||||
|
|
||||||
|
def ensure_subcache_for(self, node_id, children_ids):
|
||||||
|
# Just uses subcaches for tracking 'live' nodes
|
||||||
|
super()._ensure_subcache(node_id, children_ids)
|
||||||
|
|
||||||
|
self.cache_key_set.add_keys(children_ids)
|
||||||
|
self._mark_used(node_id)
|
||||||
|
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||||
|
self.children[cache_key] = []
|
||||||
|
for child_id in children_ids:
|
||||||
|
self._mark_used(child_id)
|
||||||
|
self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
|
||||||
|
return self
|
||||||
|
|
||||||
259
comfy_execution/graph.py
Normal file
259
comfy_execution/graph.py
Normal file
@@ -0,0 +1,259 @@
|
|||||||
|
import nodes
|
||||||
|
|
||||||
|
from comfy_execution.graph_utils import is_link
|
||||||
|
|
||||||
|
class DependencyCycleError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class NodeInputError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class NodeNotFoundError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class DynamicPrompt:
|
||||||
|
def __init__(self, original_prompt):
|
||||||
|
# The original prompt provided by the user
|
||||||
|
self.original_prompt = original_prompt
|
||||||
|
# Any extra pieces of the graph created during execution
|
||||||
|
self.ephemeral_prompt = {}
|
||||||
|
self.ephemeral_parents = {}
|
||||||
|
self.ephemeral_display = {}
|
||||||
|
|
||||||
|
def get_node(self, node_id):
|
||||||
|
if node_id in self.ephemeral_prompt:
|
||||||
|
return self.ephemeral_prompt[node_id]
|
||||||
|
if node_id in self.original_prompt:
|
||||||
|
return self.original_prompt[node_id]
|
||||||
|
raise NodeNotFoundError(f"Node {node_id} not found")
|
||||||
|
|
||||||
|
def has_node(self, node_id):
|
||||||
|
return node_id in self.original_prompt or node_id in self.ephemeral_prompt
|
||||||
|
|
||||||
|
def add_ephemeral_node(self, node_id, node_info, parent_id, display_id):
|
||||||
|
self.ephemeral_prompt[node_id] = node_info
|
||||||
|
self.ephemeral_parents[node_id] = parent_id
|
||||||
|
self.ephemeral_display[node_id] = display_id
|
||||||
|
|
||||||
|
def get_real_node_id(self, node_id):
|
||||||
|
while node_id in self.ephemeral_parents:
|
||||||
|
node_id = self.ephemeral_parents[node_id]
|
||||||
|
return node_id
|
||||||
|
|
||||||
|
def get_parent_node_id(self, node_id):
|
||||||
|
return self.ephemeral_parents.get(node_id, None)
|
||||||
|
|
||||||
|
def get_display_node_id(self, node_id):
|
||||||
|
while node_id in self.ephemeral_display:
|
||||||
|
node_id = self.ephemeral_display[node_id]
|
||||||
|
return node_id
|
||||||
|
|
||||||
|
def all_node_ids(self):
|
||||||
|
return set(self.original_prompt.keys()).union(set(self.ephemeral_prompt.keys()))
|
||||||
|
|
||||||
|
def get_original_prompt(self):
|
||||||
|
return self.original_prompt
|
||||||
|
|
||||||
|
def get_input_info(class_def, input_name):
|
||||||
|
valid_inputs = class_def.INPUT_TYPES()
|
||||||
|
input_info = None
|
||||||
|
input_category = None
|
||||||
|
if "required" in valid_inputs and input_name in valid_inputs["required"]:
|
||||||
|
input_category = "required"
|
||||||
|
input_info = valid_inputs["required"][input_name]
|
||||||
|
elif "optional" in valid_inputs and input_name in valid_inputs["optional"]:
|
||||||
|
input_category = "optional"
|
||||||
|
input_info = valid_inputs["optional"][input_name]
|
||||||
|
elif "hidden" in valid_inputs and input_name in valid_inputs["hidden"]:
|
||||||
|
input_category = "hidden"
|
||||||
|
input_info = valid_inputs["hidden"][input_name]
|
||||||
|
if input_info is None:
|
||||||
|
return None, None, None
|
||||||
|
input_type = input_info[0]
|
||||||
|
if len(input_info) > 1:
|
||||||
|
extra_info = input_info[1]
|
||||||
|
else:
|
||||||
|
extra_info = {}
|
||||||
|
return input_type, input_category, extra_info
|
||||||
|
|
||||||
|
class TopologicalSort:
|
||||||
|
def __init__(self, dynprompt):
|
||||||
|
self.dynprompt = dynprompt
|
||||||
|
self.pendingNodes = {}
|
||||||
|
self.blockCount = {} # Number of nodes this node is directly blocked by
|
||||||
|
self.blocking = {} # Which nodes are blocked by this node
|
||||||
|
|
||||||
|
def get_input_info(self, unique_id, input_name):
|
||||||
|
class_type = self.dynprompt.get_node(unique_id)["class_type"]
|
||||||
|
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||||
|
return get_input_info(class_def, input_name)
|
||||||
|
|
||||||
|
def make_input_strong_link(self, to_node_id, to_input):
|
||||||
|
inputs = self.dynprompt.get_node(to_node_id)["inputs"]
|
||||||
|
if to_input not in inputs:
|
||||||
|
raise NodeInputError(f"Node {to_node_id} says it needs input {to_input}, but there is no input to that node at all")
|
||||||
|
value = inputs[to_input]
|
||||||
|
if not is_link(value):
|
||||||
|
raise NodeInputError(f"Node {to_node_id} says it needs input {to_input}, but that value is a constant")
|
||||||
|
from_node_id, from_socket = value
|
||||||
|
self.add_strong_link(from_node_id, from_socket, to_node_id)
|
||||||
|
|
||||||
|
def add_strong_link(self, from_node_id, from_socket, to_node_id):
|
||||||
|
self.add_node(from_node_id)
|
||||||
|
if to_node_id not in self.blocking[from_node_id]:
|
||||||
|
self.blocking[from_node_id][to_node_id] = {}
|
||||||
|
self.blockCount[to_node_id] += 1
|
||||||
|
self.blocking[from_node_id][to_node_id][from_socket] = True
|
||||||
|
|
||||||
|
def add_node(self, unique_id, include_lazy=False, subgraph_nodes=None):
|
||||||
|
if unique_id in self.pendingNodes:
|
||||||
|
return
|
||||||
|
self.pendingNodes[unique_id] = True
|
||||||
|
self.blockCount[unique_id] = 0
|
||||||
|
self.blocking[unique_id] = {}
|
||||||
|
|
||||||
|
inputs = self.dynprompt.get_node(unique_id)["inputs"]
|
||||||
|
for input_name in inputs:
|
||||||
|
value = inputs[input_name]
|
||||||
|
if is_link(value):
|
||||||
|
from_node_id, from_socket = value
|
||||||
|
if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
|
||||||
|
continue
|
||||||
|
input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
|
||||||
|
is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
|
||||||
|
if include_lazy or not is_lazy:
|
||||||
|
self.add_strong_link(from_node_id, from_socket, unique_id)
|
||||||
|
|
||||||
|
def get_ready_nodes(self):
|
||||||
|
return [node_id for node_id in self.pendingNodes if self.blockCount[node_id] == 0]
|
||||||
|
|
||||||
|
def pop_node(self, unique_id):
|
||||||
|
del self.pendingNodes[unique_id]
|
||||||
|
for blocked_node_id in self.blocking[unique_id]:
|
||||||
|
self.blockCount[blocked_node_id] -= 1
|
||||||
|
del self.blocking[unique_id]
|
||||||
|
|
||||||
|
def is_empty(self):
|
||||||
|
return len(self.pendingNodes) == 0
|
||||||
|
|
||||||
|
class ExecutionList(TopologicalSort):
|
||||||
|
"""
|
||||||
|
ExecutionList implements a topological dissolve of the graph. After a node is staged for execution,
|
||||||
|
it can still be returned to the graph after having further dependencies added.
|
||||||
|
"""
|
||||||
|
def __init__(self, dynprompt, output_cache):
|
||||||
|
super().__init__(dynprompt)
|
||||||
|
self.output_cache = output_cache
|
||||||
|
self.staged_node_id = None
|
||||||
|
|
||||||
|
def add_strong_link(self, from_node_id, from_socket, to_node_id):
|
||||||
|
if self.output_cache.get(from_node_id) is not None:
|
||||||
|
# Nothing to do
|
||||||
|
return
|
||||||
|
super().add_strong_link(from_node_id, from_socket, to_node_id)
|
||||||
|
|
||||||
|
def stage_node_execution(self):
|
||||||
|
assert self.staged_node_id is None
|
||||||
|
if self.is_empty():
|
||||||
|
return None, None, None
|
||||||
|
available = self.get_ready_nodes()
|
||||||
|
if len(available) == 0:
|
||||||
|
cycled_nodes = self.get_nodes_in_cycle()
|
||||||
|
# Because cycles composed entirely of static nodes are caught during initial validation,
|
||||||
|
# we will 'blame' the first node in the cycle that is not a static node.
|
||||||
|
blamed_node = cycled_nodes[0]
|
||||||
|
for node_id in cycled_nodes:
|
||||||
|
display_node_id = self.dynprompt.get_display_node_id(node_id)
|
||||||
|
if display_node_id != node_id:
|
||||||
|
blamed_node = display_node_id
|
||||||
|
break
|
||||||
|
ex = DependencyCycleError("Dependency cycle detected")
|
||||||
|
error_details = {
|
||||||
|
"node_id": blamed_node,
|
||||||
|
"exception_message": str(ex),
|
||||||
|
"exception_type": "graph.DependencyCycleError",
|
||||||
|
"traceback": [],
|
||||||
|
"current_inputs": []
|
||||||
|
}
|
||||||
|
return None, error_details, ex
|
||||||
|
|
||||||
|
self.staged_node_id = self.ux_friendly_pick_node(available)
|
||||||
|
return self.staged_node_id, None, None
|
||||||
|
|
||||||
|
def ux_friendly_pick_node(self, node_list):
|
||||||
|
# If an output node is available, do that first.
|
||||||
|
# Technically this has no effect on the overall length of execution, but it feels better as a user
|
||||||
|
# for a PreviewImage to display a result as soon as it can
|
||||||
|
# Some other heuristics could probably be used here to improve the UX further.
|
||||||
|
def is_output(node_id):
|
||||||
|
class_type = self.dynprompt.get_node(node_id)["class_type"]
|
||||||
|
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||||
|
if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
for node_id in node_list:
|
||||||
|
if is_output(node_id):
|
||||||
|
return node_id
|
||||||
|
|
||||||
|
#This should handle the VAEDecode -> preview case
|
||||||
|
for node_id in node_list:
|
||||||
|
for blocked_node_id in self.blocking[node_id]:
|
||||||
|
if is_output(blocked_node_id):
|
||||||
|
return node_id
|
||||||
|
|
||||||
|
#This should handle the VAELoader -> VAEDecode -> preview case
|
||||||
|
for node_id in node_list:
|
||||||
|
for blocked_node_id in self.blocking[node_id]:
|
||||||
|
for blocked_node_id1 in self.blocking[blocked_node_id]:
|
||||||
|
if is_output(blocked_node_id1):
|
||||||
|
return node_id
|
||||||
|
|
||||||
|
#TODO: this function should be improved
|
||||||
|
return node_list[0]
|
||||||
|
|
||||||
|
def unstage_node_execution(self):
|
||||||
|
assert self.staged_node_id is not None
|
||||||
|
self.staged_node_id = None
|
||||||
|
|
||||||
|
def complete_node_execution(self):
|
||||||
|
node_id = self.staged_node_id
|
||||||
|
self.pop_node(node_id)
|
||||||
|
self.staged_node_id = None
|
||||||
|
|
||||||
|
def get_nodes_in_cycle(self):
|
||||||
|
# We'll dissolve the graph in reverse topological order to leave only the nodes in the cycle.
|
||||||
|
# We're skipping some of the performance optimizations from the original TopologicalSort to keep
|
||||||
|
# the code simple (and because having a cycle in the first place is a catastrophic error)
|
||||||
|
blocked_by = { node_id: {} for node_id in self.pendingNodes }
|
||||||
|
for from_node_id in self.blocking:
|
||||||
|
for to_node_id in self.blocking[from_node_id]:
|
||||||
|
if True in self.blocking[from_node_id][to_node_id].values():
|
||||||
|
blocked_by[to_node_id][from_node_id] = True
|
||||||
|
to_remove = [node_id for node_id in blocked_by if len(blocked_by[node_id]) == 0]
|
||||||
|
while len(to_remove) > 0:
|
||||||
|
for node_id in to_remove:
|
||||||
|
for to_node_id in blocked_by:
|
||||||
|
if node_id in blocked_by[to_node_id]:
|
||||||
|
del blocked_by[to_node_id][node_id]
|
||||||
|
del blocked_by[node_id]
|
||||||
|
to_remove = [node_id for node_id in blocked_by if len(blocked_by[node_id]) == 0]
|
||||||
|
return list(blocked_by.keys())
|
||||||
|
|
||||||
|
class ExecutionBlocker:
|
||||||
|
"""
|
||||||
|
Return this from a node and any users will be blocked with the given error message.
|
||||||
|
If the message is None, execution will be blocked silently instead.
|
||||||
|
Generally, you should avoid using this functionality unless absolutely necessary. Whenever it's
|
||||||
|
possible, a lazy input will be more efficient and have a better user experience.
|
||||||
|
This functionality is useful in two cases:
|
||||||
|
1. You want to conditionally prevent an output node from executing. (Particularly a built-in node
|
||||||
|
like SaveImage. For your own output nodes, I would recommend just adding a BOOL input and using
|
||||||
|
lazy evaluation to let it conditionally disable itself.)
|
||||||
|
2. You have a node with multiple possible outputs, some of which are invalid and should not be used.
|
||||||
|
(I would recommend not making nodes like this in the future -- instead, make multiple nodes with
|
||||||
|
different outputs. Unfortunately, there are several popular existing nodes using this pattern.)
|
||||||
|
"""
|
||||||
|
def __init__(self, message):
|
||||||
|
self.message = message
|
||||||
|
|
||||||
139
comfy_execution/graph_utils.py
Normal file
139
comfy_execution/graph_utils.py
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
def is_link(obj):
|
||||||
|
if not isinstance(obj, list):
|
||||||
|
return False
|
||||||
|
if len(obj) != 2:
|
||||||
|
return False
|
||||||
|
if not isinstance(obj[0], str):
|
||||||
|
return False
|
||||||
|
if not isinstance(obj[1], int) and not isinstance(obj[1], float):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
# The GraphBuilder is just a utility class that outputs graphs in the form expected by the ComfyUI back-end
|
||||||
|
class GraphBuilder:
|
||||||
|
_default_prefix_root = ""
|
||||||
|
_default_prefix_call_index = 0
|
||||||
|
_default_prefix_graph_index = 0
|
||||||
|
|
||||||
|
def __init__(self, prefix = None):
|
||||||
|
if prefix is None:
|
||||||
|
self.prefix = GraphBuilder.alloc_prefix()
|
||||||
|
else:
|
||||||
|
self.prefix = prefix
|
||||||
|
self.nodes = {}
|
||||||
|
self.id_gen = 1
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def set_default_prefix(cls, prefix_root, call_index, graph_index = 0):
|
||||||
|
cls._default_prefix_root = prefix_root
|
||||||
|
cls._default_prefix_call_index = call_index
|
||||||
|
cls._default_prefix_graph_index = graph_index
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def alloc_prefix(cls, root=None, call_index=None, graph_index=None):
|
||||||
|
if root is None:
|
||||||
|
root = GraphBuilder._default_prefix_root
|
||||||
|
if call_index is None:
|
||||||
|
call_index = GraphBuilder._default_prefix_call_index
|
||||||
|
if graph_index is None:
|
||||||
|
graph_index = GraphBuilder._default_prefix_graph_index
|
||||||
|
result = f"{root}.{call_index}.{graph_index}."
|
||||||
|
GraphBuilder._default_prefix_graph_index += 1
|
||||||
|
return result
|
||||||
|
|
||||||
|
def node(self, class_type, id=None, **kwargs):
|
||||||
|
if id is None:
|
||||||
|
id = str(self.id_gen)
|
||||||
|
self.id_gen += 1
|
||||||
|
id = self.prefix + id
|
||||||
|
if id in self.nodes:
|
||||||
|
return self.nodes[id]
|
||||||
|
|
||||||
|
node = Node(id, class_type, kwargs)
|
||||||
|
self.nodes[id] = node
|
||||||
|
return node
|
||||||
|
|
||||||
|
def lookup_node(self, id):
|
||||||
|
id = self.prefix + id
|
||||||
|
return self.nodes.get(id)
|
||||||
|
|
||||||
|
def finalize(self):
|
||||||
|
output = {}
|
||||||
|
for node_id, node in self.nodes.items():
|
||||||
|
output[node_id] = node.serialize()
|
||||||
|
return output
|
||||||
|
|
||||||
|
def replace_node_output(self, node_id, index, new_value):
|
||||||
|
node_id = self.prefix + node_id
|
||||||
|
to_remove = []
|
||||||
|
for node in self.nodes.values():
|
||||||
|
for key, value in node.inputs.items():
|
||||||
|
if is_link(value) and value[0] == node_id and value[1] == index:
|
||||||
|
if new_value is None:
|
||||||
|
to_remove.append((node, key))
|
||||||
|
else:
|
||||||
|
node.inputs[key] = new_value
|
||||||
|
for node, key in to_remove:
|
||||||
|
del node.inputs[key]
|
||||||
|
|
||||||
|
def remove_node(self, id):
|
||||||
|
id = self.prefix + id
|
||||||
|
del self.nodes[id]
|
||||||
|
|
||||||
|
class Node:
|
||||||
|
def __init__(self, id, class_type, inputs):
|
||||||
|
self.id = id
|
||||||
|
self.class_type = class_type
|
||||||
|
self.inputs = inputs
|
||||||
|
self.override_display_id = None
|
||||||
|
|
||||||
|
def out(self, index):
|
||||||
|
return [self.id, index]
|
||||||
|
|
||||||
|
def set_input(self, key, value):
|
||||||
|
if value is None:
|
||||||
|
if key in self.inputs:
|
||||||
|
del self.inputs[key]
|
||||||
|
else:
|
||||||
|
self.inputs[key] = value
|
||||||
|
|
||||||
|
def get_input(self, key):
|
||||||
|
return self.inputs.get(key)
|
||||||
|
|
||||||
|
def set_override_display_id(self, override_display_id):
|
||||||
|
self.override_display_id = override_display_id
|
||||||
|
|
||||||
|
def serialize(self):
|
||||||
|
serialized = {
|
||||||
|
"class_type": self.class_type,
|
||||||
|
"inputs": self.inputs
|
||||||
|
}
|
||||||
|
if self.override_display_id is not None:
|
||||||
|
serialized["override_display_id"] = self.override_display_id
|
||||||
|
return serialized
|
||||||
|
|
||||||
|
def add_graph_prefix(graph, outputs, prefix):
|
||||||
|
# Change the node IDs and any internal links
|
||||||
|
new_graph = {}
|
||||||
|
for node_id, node_info in graph.items():
|
||||||
|
# Make sure the added nodes have unique IDs
|
||||||
|
new_node_id = prefix + node_id
|
||||||
|
new_node = { "class_type": node_info["class_type"], "inputs": {} }
|
||||||
|
for input_name, input_value in node_info.get("inputs", {}).items():
|
||||||
|
if is_link(input_value):
|
||||||
|
new_node["inputs"][input_name] = [prefix + input_value[0], input_value[1]]
|
||||||
|
else:
|
||||||
|
new_node["inputs"][input_name] = input_value
|
||||||
|
new_graph[new_node_id] = new_node
|
||||||
|
|
||||||
|
# Change the node IDs in the outputs
|
||||||
|
new_outputs = []
|
||||||
|
for n in range(len(outputs)):
|
||||||
|
output = outputs[n]
|
||||||
|
if is_link(output):
|
||||||
|
new_outputs.append([prefix + output[0], output[1]])
|
||||||
|
else:
|
||||||
|
new_outputs.append(output)
|
||||||
|
|
||||||
|
return new_graph, tuple(new_outputs)
|
||||||
|
|
||||||
@@ -333,6 +333,25 @@ class VAESave:
|
|||||||
comfy.utils.save_torch_file(vae.get_sd(), output_checkpoint, metadata=metadata)
|
comfy.utils.save_torch_file(vae.get_sd(), output_checkpoint, metadata=metadata)
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
class ModelSave:
|
||||||
|
def __init__(self):
|
||||||
|
self.output_dir = folder_paths.get_output_directory()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "model": ("MODEL",),
|
||||||
|
"filename_prefix": ("STRING", {"default": "diffusion_models/ComfyUI"}),},
|
||||||
|
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
|
||||||
|
RETURN_TYPES = ()
|
||||||
|
FUNCTION = "save"
|
||||||
|
OUTPUT_NODE = True
|
||||||
|
|
||||||
|
CATEGORY = "advanced/model_merging"
|
||||||
|
|
||||||
|
def save(self, model, filename_prefix, prompt=None, extra_pnginfo=None):
|
||||||
|
save_checkpoint(model, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo)
|
||||||
|
return {}
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"ModelMergeSimple": ModelMergeSimple,
|
"ModelMergeSimple": ModelMergeSimple,
|
||||||
"ModelMergeBlocks": ModelMergeBlocks,
|
"ModelMergeBlocks": ModelMergeBlocks,
|
||||||
@@ -344,4 +363,9 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"CLIPMergeAdd": CLIPAdd,
|
"CLIPMergeAdd": CLIPAdd,
|
||||||
"CLIPSave": CLIPSave,
|
"CLIPSave": CLIPSave,
|
||||||
"VAESave": VAESave,
|
"VAESave": VAESave,
|
||||||
|
"ModelSave": ModelSave,
|
||||||
|
}
|
||||||
|
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
|
"CheckpointSave": "Save Checkpoint",
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,14 +4,14 @@ class Example:
|
|||||||
|
|
||||||
Class methods
|
Class methods
|
||||||
-------------
|
-------------
|
||||||
INPUT_TYPES (dict):
|
INPUT_TYPES (dict):
|
||||||
Tell the main program input parameters of nodes.
|
Tell the main program input parameters of nodes.
|
||||||
IS_CHANGED:
|
IS_CHANGED:
|
||||||
optional method to control when the node is re executed.
|
optional method to control when the node is re executed.
|
||||||
|
|
||||||
Attributes
|
Attributes
|
||||||
----------
|
----------
|
||||||
RETURN_TYPES (`tuple`):
|
RETURN_TYPES (`tuple`):
|
||||||
The type of each element in the output tuple.
|
The type of each element in the output tuple.
|
||||||
RETURN_NAMES (`tuple`):
|
RETURN_NAMES (`tuple`):
|
||||||
Optional: The name of each output in the output tuple.
|
Optional: The name of each output in the output tuple.
|
||||||
@@ -23,13 +23,19 @@ class Example:
|
|||||||
Assumed to be False if not present.
|
Assumed to be False if not present.
|
||||||
CATEGORY (`str`):
|
CATEGORY (`str`):
|
||||||
The category the node should appear in the UI.
|
The category the node should appear in the UI.
|
||||||
|
DEPRECATED (`bool`):
|
||||||
|
Indicates whether the node is deprecated. Deprecated nodes are hidden by default in the UI, but remain
|
||||||
|
functional in existing workflows that use them.
|
||||||
|
EXPERIMENTAL (`bool`):
|
||||||
|
Indicates whether the node is experimental. Experimental nodes are marked as such in the UI and may be subject to
|
||||||
|
significant changes or removal in future versions. Use with caution in production workflows.
|
||||||
execute(s) -> tuple || None:
|
execute(s) -> tuple || None:
|
||||||
The entry point method. The name of this method must be the same as the value of property `FUNCTION`.
|
The entry point method. The name of this method must be the same as the value of property `FUNCTION`.
|
||||||
For example, if `FUNCTION = "execute"` then this method's name must be `execute`, if `FUNCTION = "foo"` then it must be `foo`.
|
For example, if `FUNCTION = "execute"` then this method's name must be `execute`, if `FUNCTION = "foo"` then it must be `foo`.
|
||||||
"""
|
"""
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
"""
|
"""
|
||||||
@@ -54,7 +60,8 @@ class Example:
|
|||||||
"min": 0, #Minimum value
|
"min": 0, #Minimum value
|
||||||
"max": 4096, #Maximum value
|
"max": 4096, #Maximum value
|
||||||
"step": 64, #Slider's step
|
"step": 64, #Slider's step
|
||||||
"display": "number" # Cosmetic only: display as "number" or "slider"
|
"display": "number", # Cosmetic only: display as "number" or "slider"
|
||||||
|
"lazy": True # Will only be evaluated if check_lazy_status requires it
|
||||||
}),
|
}),
|
||||||
"float_field": ("FLOAT", {
|
"float_field": ("FLOAT", {
|
||||||
"default": 1.0,
|
"default": 1.0,
|
||||||
@@ -62,11 +69,14 @@ class Example:
|
|||||||
"max": 10.0,
|
"max": 10.0,
|
||||||
"step": 0.01,
|
"step": 0.01,
|
||||||
"round": 0.001, #The value representing the precision to round to, will be set to the step value by default. Can be set to False to disable rounding.
|
"round": 0.001, #The value representing the precision to round to, will be set to the step value by default. Can be set to False to disable rounding.
|
||||||
"display": "number"}),
|
"display": "number",
|
||||||
|
"lazy": True
|
||||||
|
}),
|
||||||
"print_to_screen": (["enable", "disable"],),
|
"print_to_screen": (["enable", "disable"],),
|
||||||
"string_field": ("STRING", {
|
"string_field": ("STRING", {
|
||||||
"multiline": False, #True if you want the field to look like the one on the ClipTextEncode node
|
"multiline": False, #True if you want the field to look like the one on the ClipTextEncode node
|
||||||
"default": "Hello World!"
|
"default": "Hello World!",
|
||||||
|
"lazy": True
|
||||||
}),
|
}),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -80,6 +90,23 @@ class Example:
|
|||||||
|
|
||||||
CATEGORY = "Example"
|
CATEGORY = "Example"
|
||||||
|
|
||||||
|
def check_lazy_status(self, image, string_field, int_field, float_field, print_to_screen):
|
||||||
|
"""
|
||||||
|
Return a list of input names that need to be evaluated.
|
||||||
|
|
||||||
|
This function will be called if there are any lazy inputs which have not yet been
|
||||||
|
evaluated. As long as you return at least one field which has not yet been evaluated
|
||||||
|
(and more exist), this function will be called again once the value of the requested
|
||||||
|
field is available.
|
||||||
|
|
||||||
|
Any evaluated inputs will be passed as arguments to this function. Any unevaluated
|
||||||
|
inputs will have the value None.
|
||||||
|
"""
|
||||||
|
if print_to_screen == "enable":
|
||||||
|
return ["int_field", "float_field", "string_field"]
|
||||||
|
else:
|
||||||
|
return []
|
||||||
|
|
||||||
def test(self, image, string_field, int_field, float_field, print_to_screen):
|
def test(self, image, string_field, int_field, float_field, print_to_screen):
|
||||||
if print_to_screen == "enable":
|
if print_to_screen == "enable":
|
||||||
print(f"""Your input contains:
|
print(f"""Your input contains:
|
||||||
|
|||||||
648
execution.py
648
execution.py
@@ -5,6 +5,7 @@ import threading
|
|||||||
import heapq
|
import heapq
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
|
from enum import Enum
|
||||||
import inspect
|
import inspect
|
||||||
from typing import List, Literal, NamedTuple, Optional
|
from typing import List, Literal, NamedTuple, Optional
|
||||||
|
|
||||||
@@ -12,102 +13,219 @@ import torch
|
|||||||
import nodes
|
import nodes
|
||||||
|
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
|
from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
|
||||||
|
from comfy_execution.graph_utils import is_link, GraphBuilder
|
||||||
|
from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
|
||||||
|
from comfy.cli_args import args
|
||||||
|
|
||||||
def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}):
|
class ExecutionResult(Enum):
|
||||||
|
SUCCESS = 0
|
||||||
|
FAILURE = 1
|
||||||
|
PENDING = 2
|
||||||
|
|
||||||
|
class DuplicateNodeError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class IsChangedCache:
|
||||||
|
def __init__(self, dynprompt, outputs_cache):
|
||||||
|
self.dynprompt = dynprompt
|
||||||
|
self.outputs_cache = outputs_cache
|
||||||
|
self.is_changed = {}
|
||||||
|
|
||||||
|
def get(self, node_id):
|
||||||
|
if node_id in self.is_changed:
|
||||||
|
return self.is_changed[node_id]
|
||||||
|
|
||||||
|
node = self.dynprompt.get_node(node_id)
|
||||||
|
class_type = node["class_type"]
|
||||||
|
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||||
|
if not hasattr(class_def, "IS_CHANGED"):
|
||||||
|
self.is_changed[node_id] = False
|
||||||
|
return self.is_changed[node_id]
|
||||||
|
|
||||||
|
if "is_changed" in node:
|
||||||
|
self.is_changed[node_id] = node["is_changed"]
|
||||||
|
return self.is_changed[node_id]
|
||||||
|
|
||||||
|
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
|
||||||
|
input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, None)
|
||||||
|
try:
|
||||||
|
is_changed = _map_node_over_list(class_def, input_data_all, "IS_CHANGED")
|
||||||
|
node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning("WARNING: {}".format(e))
|
||||||
|
node["is_changed"] = float("NaN")
|
||||||
|
finally:
|
||||||
|
self.is_changed[node_id] = node["is_changed"]
|
||||||
|
return self.is_changed[node_id]
|
||||||
|
|
||||||
|
class CacheSet:
|
||||||
|
def __init__(self, lru_size=None):
|
||||||
|
if lru_size is None or lru_size == 0:
|
||||||
|
self.init_classic_cache()
|
||||||
|
else:
|
||||||
|
self.init_lru_cache(lru_size)
|
||||||
|
self.all = [self.outputs, self.ui, self.objects]
|
||||||
|
|
||||||
|
# Useful for those with ample RAM/VRAM -- allows experimenting without
|
||||||
|
# blowing away the cache every time
|
||||||
|
def init_lru_cache(self, cache_size):
|
||||||
|
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
|
||||||
|
self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
|
||||||
|
self.objects = HierarchicalCache(CacheKeySetID)
|
||||||
|
|
||||||
|
# Performs like the old cache -- dump data ASAP
|
||||||
|
def init_classic_cache(self):
|
||||||
|
self.outputs = HierarchicalCache(CacheKeySetInputSignature)
|
||||||
|
self.ui = HierarchicalCache(CacheKeySetInputSignature)
|
||||||
|
self.objects = HierarchicalCache(CacheKeySetID)
|
||||||
|
|
||||||
|
def recursive_debug_dump(self):
|
||||||
|
result = {
|
||||||
|
"outputs": self.outputs.recursive_debug_dump(),
|
||||||
|
"ui": self.ui.recursive_debug_dump(),
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
|
||||||
|
def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data={}):
|
||||||
valid_inputs = class_def.INPUT_TYPES()
|
valid_inputs = class_def.INPUT_TYPES()
|
||||||
input_data_all = {}
|
input_data_all = {}
|
||||||
|
missing_keys = {}
|
||||||
for x in inputs:
|
for x in inputs:
|
||||||
input_data = inputs[x]
|
input_data = inputs[x]
|
||||||
if isinstance(input_data, list):
|
input_type, input_category, input_info = get_input_info(class_def, x)
|
||||||
|
def mark_missing():
|
||||||
|
missing_keys[x] = True
|
||||||
|
input_data_all[x] = (None,)
|
||||||
|
if is_link(input_data) and (not input_info or not input_info.get("rawLink", False)):
|
||||||
input_unique_id = input_data[0]
|
input_unique_id = input_data[0]
|
||||||
output_index = input_data[1]
|
output_index = input_data[1]
|
||||||
if input_unique_id not in outputs:
|
if outputs is None:
|
||||||
input_data_all[x] = (None,)
|
mark_missing()
|
||||||
|
continue # This might be a lazily-evaluated input
|
||||||
|
cached_output = outputs.get(input_unique_id)
|
||||||
|
if cached_output is None:
|
||||||
|
mark_missing()
|
||||||
continue
|
continue
|
||||||
obj = outputs[input_unique_id][output_index]
|
if output_index >= len(cached_output):
|
||||||
|
mark_missing()
|
||||||
|
continue
|
||||||
|
obj = cached_output[output_index]
|
||||||
input_data_all[x] = obj
|
input_data_all[x] = obj
|
||||||
else:
|
elif input_category is not None:
|
||||||
if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]):
|
input_data_all[x] = [input_data]
|
||||||
input_data_all[x] = [input_data]
|
|
||||||
|
|
||||||
if "hidden" in valid_inputs:
|
if "hidden" in valid_inputs:
|
||||||
h = valid_inputs["hidden"]
|
h = valid_inputs["hidden"]
|
||||||
for x in h:
|
for x in h:
|
||||||
if h[x] == "PROMPT":
|
if h[x] == "PROMPT":
|
||||||
input_data_all[x] = [prompt]
|
input_data_all[x] = [dynprompt.get_original_prompt() if dynprompt is not None else {}]
|
||||||
|
if h[x] == "DYNPROMPT":
|
||||||
|
input_data_all[x] = [dynprompt]
|
||||||
if h[x] == "EXTRA_PNGINFO":
|
if h[x] == "EXTRA_PNGINFO":
|
||||||
input_data_all[x] = [extra_data.get('extra_pnginfo', None)]
|
input_data_all[x] = [extra_data.get('extra_pnginfo', None)]
|
||||||
if h[x] == "UNIQUE_ID":
|
if h[x] == "UNIQUE_ID":
|
||||||
input_data_all[x] = [unique_id]
|
input_data_all[x] = [unique_id]
|
||||||
return input_data_all
|
return input_data_all, missing_keys
|
||||||
|
|
||||||
def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
|
map_node_over_list = None #Don't hook this please
|
||||||
|
|
||||||
|
def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
|
||||||
# check if node wants the lists
|
# check if node wants the lists
|
||||||
input_is_list = False
|
input_is_list = getattr(obj, "INPUT_IS_LIST", False)
|
||||||
if hasattr(obj, "INPUT_IS_LIST"):
|
|
||||||
input_is_list = obj.INPUT_IS_LIST
|
|
||||||
|
|
||||||
if len(input_data_all) == 0:
|
if len(input_data_all) == 0:
|
||||||
max_len_input = 0
|
max_len_input = 0
|
||||||
else:
|
else:
|
||||||
max_len_input = max([len(x) for x in input_data_all.values()])
|
max_len_input = max(len(x) for x in input_data_all.values())
|
||||||
|
|
||||||
# get a slice of inputs, repeat last input when list isn't long enough
|
# get a slice of inputs, repeat last input when list isn't long enough
|
||||||
def slice_dict(d, i):
|
def slice_dict(d, i):
|
||||||
d_new = dict()
|
return {k: v[i if len(v) > i else -1] for k, v in d.items()}
|
||||||
for k,v in d.items():
|
|
||||||
d_new[k] = v[i if len(v) > i else -1]
|
|
||||||
return d_new
|
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
|
def process_inputs(inputs, index=None):
|
||||||
|
if allow_interrupt:
|
||||||
|
nodes.before_node_execution()
|
||||||
|
execution_block = None
|
||||||
|
for k, v in inputs.items():
|
||||||
|
if isinstance(v, ExecutionBlocker):
|
||||||
|
execution_block = execution_block_cb(v) if execution_block_cb else v
|
||||||
|
break
|
||||||
|
if execution_block is None:
|
||||||
|
if pre_execute_cb is not None and index is not None:
|
||||||
|
pre_execute_cb(index)
|
||||||
|
results.append(getattr(obj, func)(**inputs))
|
||||||
|
else:
|
||||||
|
results.append(execution_block)
|
||||||
|
|
||||||
if input_is_list:
|
if input_is_list:
|
||||||
if allow_interrupt:
|
process_inputs(input_data_all, 0)
|
||||||
nodes.before_node_execution()
|
|
||||||
results.append(getattr(obj, func)(**input_data_all))
|
|
||||||
elif max_len_input == 0:
|
elif max_len_input == 0:
|
||||||
if allow_interrupt:
|
process_inputs({})
|
||||||
nodes.before_node_execution()
|
else:
|
||||||
results.append(getattr(obj, func)())
|
|
||||||
else:
|
|
||||||
for i in range(max_len_input):
|
for i in range(max_len_input):
|
||||||
if allow_interrupt:
|
input_dict = slice_dict(input_data_all, i)
|
||||||
nodes.before_node_execution()
|
process_inputs(input_dict, i)
|
||||||
results.append(getattr(obj, func)(**slice_dict(input_data_all, i)))
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def get_output_data(obj, input_data_all):
|
def merge_result_data(results, obj):
|
||||||
|
# check which outputs need concatenating
|
||||||
|
output = []
|
||||||
|
output_is_list = [False] * len(results[0])
|
||||||
|
if hasattr(obj, "OUTPUT_IS_LIST"):
|
||||||
|
output_is_list = obj.OUTPUT_IS_LIST
|
||||||
|
|
||||||
|
# merge node execution results
|
||||||
|
for i, is_list in zip(range(len(results[0])), output_is_list):
|
||||||
|
if is_list:
|
||||||
|
output.append([x for o in results for x in o[i]])
|
||||||
|
else:
|
||||||
|
output.append([o[i] for o in results])
|
||||||
|
return output
|
||||||
|
|
||||||
|
def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb=None):
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
uis = []
|
uis = []
|
||||||
return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True)
|
subgraph_results = []
|
||||||
|
return_values = _map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
|
||||||
for r in return_values:
|
has_subgraph = False
|
||||||
|
for i in range(len(return_values)):
|
||||||
|
r = return_values[i]
|
||||||
if isinstance(r, dict):
|
if isinstance(r, dict):
|
||||||
if 'ui' in r:
|
if 'ui' in r:
|
||||||
uis.append(r['ui'])
|
uis.append(r['ui'])
|
||||||
if 'result' in r:
|
if 'expand' in r:
|
||||||
results.append(r['result'])
|
# Perform an expansion, but do not append results
|
||||||
|
has_subgraph = True
|
||||||
|
new_graph = r['expand']
|
||||||
|
result = r.get("result", None)
|
||||||
|
if isinstance(result, ExecutionBlocker):
|
||||||
|
result = tuple([result] * len(obj.RETURN_TYPES))
|
||||||
|
subgraph_results.append((new_graph, result))
|
||||||
|
elif 'result' in r:
|
||||||
|
result = r.get("result", None)
|
||||||
|
if isinstance(result, ExecutionBlocker):
|
||||||
|
result = tuple([result] * len(obj.RETURN_TYPES))
|
||||||
|
results.append(result)
|
||||||
|
subgraph_results.append((None, result))
|
||||||
else:
|
else:
|
||||||
|
if isinstance(r, ExecutionBlocker):
|
||||||
|
r = tuple([r] * len(obj.RETURN_TYPES))
|
||||||
results.append(r)
|
results.append(r)
|
||||||
|
subgraph_results.append((None, r))
|
||||||
|
|
||||||
output = []
|
if has_subgraph:
|
||||||
if len(results) > 0:
|
output = subgraph_results
|
||||||
# check which outputs need concatenating
|
elif len(results) > 0:
|
||||||
output_is_list = [False] * len(results[0])
|
output = merge_result_data(results, obj)
|
||||||
if hasattr(obj, "OUTPUT_IS_LIST"):
|
else:
|
||||||
output_is_list = obj.OUTPUT_IS_LIST
|
output = []
|
||||||
|
|
||||||
# merge node execution results
|
|
||||||
for i, is_list in zip(range(len(results[0])), output_is_list):
|
|
||||||
if is_list:
|
|
||||||
output.append([x for o in results for x in o[i]])
|
|
||||||
else:
|
|
||||||
output.append([o[i] for o in results])
|
|
||||||
|
|
||||||
ui = dict()
|
ui = dict()
|
||||||
if len(uis) > 0:
|
if len(uis) > 0:
|
||||||
ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()}
|
ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()}
|
||||||
return output, ui
|
return output, ui, has_subgraph
|
||||||
|
|
||||||
def format_value(x):
|
def format_value(x):
|
||||||
if x is None:
|
if x is None:
|
||||||
@@ -117,53 +235,145 @@ def format_value(x):
|
|||||||
else:
|
else:
|
||||||
return str(x)
|
return str(x)
|
||||||
|
|
||||||
def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui, object_storage):
|
def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results):
|
||||||
unique_id = current_item
|
unique_id = current_item
|
||||||
inputs = prompt[unique_id]['inputs']
|
real_node_id = dynprompt.get_real_node_id(unique_id)
|
||||||
class_type = prompt[unique_id]['class_type']
|
display_node_id = dynprompt.get_display_node_id(unique_id)
|
||||||
|
parent_node_id = dynprompt.get_parent_node_id(unique_id)
|
||||||
|
inputs = dynprompt.get_node(unique_id)['inputs']
|
||||||
|
class_type = dynprompt.get_node(unique_id)['class_type']
|
||||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||||
if unique_id in outputs:
|
if caches.outputs.get(unique_id) is not None:
|
||||||
return (True, None, None)
|
if server.client_id is not None:
|
||||||
|
cached_output = caches.ui.get(unique_id) or {}
|
||||||
for x in inputs:
|
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id)
|
||||||
input_data = inputs[x]
|
return (ExecutionResult.SUCCESS, None, None)
|
||||||
|
|
||||||
if isinstance(input_data, list):
|
|
||||||
input_unique_id = input_data[0]
|
|
||||||
output_index = input_data[1]
|
|
||||||
if input_unique_id not in outputs:
|
|
||||||
result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui, object_storage)
|
|
||||||
if result[0] is not True:
|
|
||||||
# Another node failed further upstream
|
|
||||||
return result
|
|
||||||
|
|
||||||
input_data_all = None
|
input_data_all = None
|
||||||
try:
|
try:
|
||||||
input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)
|
if unique_id in pending_subgraph_results:
|
||||||
if server.client_id is not None:
|
cached_results = pending_subgraph_results[unique_id]
|
||||||
server.last_node_id = unique_id
|
resolved_outputs = []
|
||||||
server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id)
|
for is_subgraph, result in cached_results:
|
||||||
|
if not is_subgraph:
|
||||||
|
resolved_outputs.append(result)
|
||||||
|
else:
|
||||||
|
resolved_output = []
|
||||||
|
for r in result:
|
||||||
|
if is_link(r):
|
||||||
|
source_node, source_output = r[0], r[1]
|
||||||
|
node_output = caches.outputs.get(source_node)[source_output]
|
||||||
|
for o in node_output:
|
||||||
|
resolved_output.append(o)
|
||||||
|
|
||||||
obj = object_storage.get((unique_id, class_type), None)
|
else:
|
||||||
if obj is None:
|
resolved_output.append(r)
|
||||||
obj = class_def()
|
resolved_outputs.append(tuple(resolved_output))
|
||||||
object_storage[(unique_id, class_type)] = obj
|
output_data = merge_result_data(resolved_outputs, class_def)
|
||||||
|
output_ui = []
|
||||||
output_data, output_ui = get_output_data(obj, input_data_all)
|
has_subgraph = False
|
||||||
outputs[unique_id] = output_data
|
else:
|
||||||
if len(output_ui) > 0:
|
input_data_all, missing_keys = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
|
||||||
outputs_ui[unique_id] = output_ui
|
|
||||||
if server.client_id is not None:
|
if server.client_id is not None:
|
||||||
server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
|
server.last_node_id = display_node_id
|
||||||
|
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
|
||||||
|
|
||||||
|
obj = caches.objects.get(unique_id)
|
||||||
|
if obj is None:
|
||||||
|
obj = class_def()
|
||||||
|
caches.objects.set(unique_id, obj)
|
||||||
|
|
||||||
|
if hasattr(obj, "check_lazy_status"):
|
||||||
|
required_inputs = _map_node_over_list(obj, input_data_all, "check_lazy_status", allow_interrupt=True)
|
||||||
|
required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
|
||||||
|
required_inputs = [x for x in required_inputs if isinstance(x,str) and (
|
||||||
|
x not in input_data_all or x in missing_keys
|
||||||
|
)]
|
||||||
|
if len(required_inputs) > 0:
|
||||||
|
for i in required_inputs:
|
||||||
|
execution_list.make_input_strong_link(unique_id, i)
|
||||||
|
return (ExecutionResult.PENDING, None, None)
|
||||||
|
|
||||||
|
def execution_block_cb(block):
|
||||||
|
if block.message is not None:
|
||||||
|
mes = {
|
||||||
|
"prompt_id": prompt_id,
|
||||||
|
"node_id": unique_id,
|
||||||
|
"node_type": class_type,
|
||||||
|
"executed": list(executed),
|
||||||
|
|
||||||
|
"exception_message": f"Execution Blocked: {block.message}",
|
||||||
|
"exception_type": "ExecutionBlocked",
|
||||||
|
"traceback": [],
|
||||||
|
"current_inputs": [],
|
||||||
|
"current_outputs": [],
|
||||||
|
}
|
||||||
|
server.send_sync("execution_error", mes, server.client_id)
|
||||||
|
return ExecutionBlocker(None)
|
||||||
|
else:
|
||||||
|
return block
|
||||||
|
def pre_execute_cb(call_index):
|
||||||
|
GraphBuilder.set_default_prefix(unique_id, call_index, 0)
|
||||||
|
output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
|
||||||
|
if len(output_ui) > 0:
|
||||||
|
caches.ui.set(unique_id, {
|
||||||
|
"meta": {
|
||||||
|
"node_id": unique_id,
|
||||||
|
"display_node": display_node_id,
|
||||||
|
"parent_node": parent_node_id,
|
||||||
|
"real_node_id": real_node_id,
|
||||||
|
},
|
||||||
|
"output": output_ui
|
||||||
|
})
|
||||||
|
if server.client_id is not None:
|
||||||
|
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
|
||||||
|
if has_subgraph:
|
||||||
|
cached_outputs = []
|
||||||
|
new_node_ids = []
|
||||||
|
new_output_ids = []
|
||||||
|
new_output_links = []
|
||||||
|
for i in range(len(output_data)):
|
||||||
|
new_graph, node_outputs = output_data[i]
|
||||||
|
if new_graph is None:
|
||||||
|
cached_outputs.append((False, node_outputs))
|
||||||
|
else:
|
||||||
|
# Check for conflicts
|
||||||
|
for node_id in new_graph.keys():
|
||||||
|
if dynprompt.has_node(node_id):
|
||||||
|
raise DuplicateNodeError(f"Attempt to add duplicate node {node_id}. Ensure node ids are unique and deterministic or use graph_utils.GraphBuilder.")
|
||||||
|
for node_id, node_info in new_graph.items():
|
||||||
|
new_node_ids.append(node_id)
|
||||||
|
display_id = node_info.get("override_display_id", unique_id)
|
||||||
|
dynprompt.add_ephemeral_node(node_id, node_info, unique_id, display_id)
|
||||||
|
# Figure out if the newly created node is an output node
|
||||||
|
class_type = node_info["class_type"]
|
||||||
|
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||||
|
if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
|
||||||
|
new_output_ids.append(node_id)
|
||||||
|
for i in range(len(node_outputs)):
|
||||||
|
if is_link(node_outputs[i]):
|
||||||
|
from_node_id, from_socket = node_outputs[i][0], node_outputs[i][1]
|
||||||
|
new_output_links.append((from_node_id, from_socket))
|
||||||
|
cached_outputs.append((True, node_outputs))
|
||||||
|
new_node_ids = set(new_node_ids)
|
||||||
|
for cache in caches.all:
|
||||||
|
cache.ensure_subcache_for(unique_id, new_node_ids).clean_unused()
|
||||||
|
for node_id in new_output_ids:
|
||||||
|
execution_list.add_node(node_id)
|
||||||
|
for link in new_output_links:
|
||||||
|
execution_list.add_strong_link(link[0], link[1], unique_id)
|
||||||
|
pending_subgraph_results[unique_id] = cached_outputs
|
||||||
|
return (ExecutionResult.PENDING, None, None)
|
||||||
|
caches.outputs.set(unique_id, output_data)
|
||||||
except comfy.model_management.InterruptProcessingException as iex:
|
except comfy.model_management.InterruptProcessingException as iex:
|
||||||
logging.info("Processing interrupted")
|
logging.info("Processing interrupted")
|
||||||
|
|
||||||
# skip formatting inputs/outputs
|
# skip formatting inputs/outputs
|
||||||
error_details = {
|
error_details = {
|
||||||
"node_id": unique_id,
|
"node_id": real_node_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
return (False, error_details, iex)
|
return (ExecutionResult.FAILURE, error_details, iex)
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
typ, _, tb = sys.exc_info()
|
typ, _, tb = sys.exc_info()
|
||||||
exception_type = full_type_name(typ)
|
exception_type = full_type_name(typ)
|
||||||
@@ -173,121 +383,36 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
|
|||||||
for name, inputs in input_data_all.items():
|
for name, inputs in input_data_all.items():
|
||||||
input_data_formatted[name] = [format_value(x) for x in inputs]
|
input_data_formatted[name] = [format_value(x) for x in inputs]
|
||||||
|
|
||||||
output_data_formatted = {}
|
logging.error(f"!!! Exception during processing !!! {ex}")
|
||||||
for node_id, node_outputs in outputs.items():
|
|
||||||
output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs]
|
|
||||||
|
|
||||||
logging.error(f"!!! Exception during processing!!! {ex}")
|
|
||||||
logging.error(traceback.format_exc())
|
logging.error(traceback.format_exc())
|
||||||
|
|
||||||
error_details = {
|
error_details = {
|
||||||
"node_id": unique_id,
|
"node_id": real_node_id,
|
||||||
"exception_message": str(ex),
|
"exception_message": str(ex),
|
||||||
"exception_type": exception_type,
|
"exception_type": exception_type,
|
||||||
"traceback": traceback.format_tb(tb),
|
"traceback": traceback.format_tb(tb),
|
||||||
"current_inputs": input_data_formatted,
|
"current_inputs": input_data_formatted
|
||||||
"current_outputs": output_data_formatted
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if isinstance(ex, comfy.model_management.OOM_EXCEPTION):
|
if isinstance(ex, comfy.model_management.OOM_EXCEPTION):
|
||||||
logging.error("Got an OOM, unloading all loaded models.")
|
logging.error("Got an OOM, unloading all loaded models.")
|
||||||
comfy.model_management.unload_all_models()
|
comfy.model_management.unload_all_models()
|
||||||
|
|
||||||
return (False, error_details, ex)
|
return (ExecutionResult.FAILURE, error_details, ex)
|
||||||
|
|
||||||
executed.add(unique_id)
|
executed.add(unique_id)
|
||||||
|
|
||||||
return (True, None, None)
|
return (ExecutionResult.SUCCESS, None, None)
|
||||||
|
|
||||||
def recursive_will_execute(prompt, outputs, current_item, memo={}):
|
|
||||||
unique_id = current_item
|
|
||||||
|
|
||||||
if unique_id in memo:
|
|
||||||
return memo[unique_id]
|
|
||||||
|
|
||||||
inputs = prompt[unique_id]['inputs']
|
|
||||||
will_execute = []
|
|
||||||
if unique_id in outputs:
|
|
||||||
return []
|
|
||||||
|
|
||||||
for x in inputs:
|
|
||||||
input_data = inputs[x]
|
|
||||||
if isinstance(input_data, list):
|
|
||||||
input_unique_id = input_data[0]
|
|
||||||
output_index = input_data[1]
|
|
||||||
if input_unique_id not in outputs:
|
|
||||||
will_execute += recursive_will_execute(prompt, outputs, input_unique_id, memo)
|
|
||||||
|
|
||||||
memo[unique_id] = will_execute + [unique_id]
|
|
||||||
return memo[unique_id]
|
|
||||||
|
|
||||||
def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item):
|
|
||||||
unique_id = current_item
|
|
||||||
inputs = prompt[unique_id]['inputs']
|
|
||||||
class_type = prompt[unique_id]['class_type']
|
|
||||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
|
||||||
|
|
||||||
is_changed_old = ''
|
|
||||||
is_changed = ''
|
|
||||||
to_delete = False
|
|
||||||
if hasattr(class_def, 'IS_CHANGED'):
|
|
||||||
if unique_id in old_prompt and 'is_changed' in old_prompt[unique_id]:
|
|
||||||
is_changed_old = old_prompt[unique_id]['is_changed']
|
|
||||||
if 'is_changed' not in prompt[unique_id]:
|
|
||||||
input_data_all = get_input_data(inputs, class_def, unique_id, outputs)
|
|
||||||
if input_data_all is not None:
|
|
||||||
try:
|
|
||||||
#is_changed = class_def.IS_CHANGED(**input_data_all)
|
|
||||||
is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED")
|
|
||||||
prompt[unique_id]['is_changed'] = is_changed
|
|
||||||
except:
|
|
||||||
to_delete = True
|
|
||||||
else:
|
|
||||||
is_changed = prompt[unique_id]['is_changed']
|
|
||||||
|
|
||||||
if unique_id not in outputs:
|
|
||||||
return True
|
|
||||||
|
|
||||||
if not to_delete:
|
|
||||||
if is_changed != is_changed_old:
|
|
||||||
to_delete = True
|
|
||||||
elif unique_id not in old_prompt:
|
|
||||||
to_delete = True
|
|
||||||
elif class_type != old_prompt[unique_id]['class_type']:
|
|
||||||
to_delete = True
|
|
||||||
elif inputs == old_prompt[unique_id]['inputs']:
|
|
||||||
for x in inputs:
|
|
||||||
input_data = inputs[x]
|
|
||||||
|
|
||||||
if isinstance(input_data, list):
|
|
||||||
input_unique_id = input_data[0]
|
|
||||||
output_index = input_data[1]
|
|
||||||
if input_unique_id in outputs:
|
|
||||||
to_delete = recursive_output_delete_if_changed(prompt, old_prompt, outputs, input_unique_id)
|
|
||||||
else:
|
|
||||||
to_delete = True
|
|
||||||
if to_delete:
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
to_delete = True
|
|
||||||
|
|
||||||
if to_delete:
|
|
||||||
d = outputs.pop(unique_id)
|
|
||||||
del d
|
|
||||||
return to_delete
|
|
||||||
|
|
||||||
class PromptExecutor:
|
class PromptExecutor:
|
||||||
def __init__(self, server):
|
def __init__(self, server, lru_size=None):
|
||||||
|
self.lru_size = lru_size
|
||||||
self.server = server
|
self.server = server
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
self.outputs = {}
|
self.caches = CacheSet(self.lru_size)
|
||||||
self.object_storage = {}
|
|
||||||
self.outputs_ui = {}
|
|
||||||
self.status_messages = []
|
self.status_messages = []
|
||||||
self.success = True
|
self.success = True
|
||||||
self.old_prompt = {}
|
|
||||||
|
|
||||||
def add_message(self, event, data: dict, broadcast: bool):
|
def add_message(self, event, data: dict, broadcast: bool):
|
||||||
data = {
|
data = {
|
||||||
@@ -318,26 +443,13 @@ class PromptExecutor:
|
|||||||
"node_id": node_id,
|
"node_id": node_id,
|
||||||
"node_type": class_type,
|
"node_type": class_type,
|
||||||
"executed": list(executed),
|
"executed": list(executed),
|
||||||
|
|
||||||
"exception_message": error["exception_message"],
|
"exception_message": error["exception_message"],
|
||||||
"exception_type": error["exception_type"],
|
"exception_type": error["exception_type"],
|
||||||
"traceback": error["traceback"],
|
"traceback": error["traceback"],
|
||||||
"current_inputs": error["current_inputs"],
|
"current_inputs": error["current_inputs"],
|
||||||
"current_outputs": error["current_outputs"],
|
"current_outputs": list(current_outputs),
|
||||||
}
|
}
|
||||||
self.add_message("execution_error", mes, broadcast=False)
|
self.add_message("execution_error", mes, broadcast=False)
|
||||||
|
|
||||||
# Next, remove the subsequent outputs since they will not be executed
|
|
||||||
to_delete = []
|
|
||||||
for o in self.outputs:
|
|
||||||
if (o not in current_outputs) and (o not in executed):
|
|
||||||
to_delete += [o]
|
|
||||||
if o in self.old_prompt:
|
|
||||||
d = self.old_prompt.pop(o)
|
|
||||||
del d
|
|
||||||
for o in to_delete:
|
|
||||||
d = self.outputs.pop(o)
|
|
||||||
del d
|
|
||||||
|
|
||||||
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
|
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
|
||||||
nodes.interrupt_processing(False)
|
nodes.interrupt_processing(False)
|
||||||
@@ -351,65 +463,59 @@ class PromptExecutor:
|
|||||||
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
|
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
|
||||||
|
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
#delete cached outputs if nodes don't exist for them
|
dynamic_prompt = DynamicPrompt(prompt)
|
||||||
to_delete = []
|
is_changed_cache = IsChangedCache(dynamic_prompt, self.caches.outputs)
|
||||||
for o in self.outputs:
|
for cache in self.caches.all:
|
||||||
if o not in prompt:
|
cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
|
||||||
to_delete += [o]
|
cache.clean_unused()
|
||||||
for o in to_delete:
|
|
||||||
d = self.outputs.pop(o)
|
|
||||||
del d
|
|
||||||
to_delete = []
|
|
||||||
for o in self.object_storage:
|
|
||||||
if o[0] not in prompt:
|
|
||||||
to_delete += [o]
|
|
||||||
else:
|
|
||||||
p = prompt[o[0]]
|
|
||||||
if o[1] != p['class_type']:
|
|
||||||
to_delete += [o]
|
|
||||||
for o in to_delete:
|
|
||||||
d = self.object_storage.pop(o)
|
|
||||||
del d
|
|
||||||
|
|
||||||
for x in prompt:
|
cached_nodes = []
|
||||||
recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x)
|
for node_id in prompt:
|
||||||
|
if self.caches.outputs.get(node_id) is not None:
|
||||||
current_outputs = set(self.outputs.keys())
|
cached_nodes.append(node_id)
|
||||||
for x in list(self.outputs_ui.keys()):
|
|
||||||
if x not in current_outputs:
|
|
||||||
d = self.outputs_ui.pop(x)
|
|
||||||
del d
|
|
||||||
|
|
||||||
comfy.model_management.cleanup_models(keep_clone_weights_loaded=True)
|
comfy.model_management.cleanup_models(keep_clone_weights_loaded=True)
|
||||||
self.add_message("execution_cached",
|
self.add_message("execution_cached",
|
||||||
{ "nodes": list(current_outputs) , "prompt_id": prompt_id},
|
{ "nodes": cached_nodes, "prompt_id": prompt_id},
|
||||||
broadcast=False)
|
broadcast=False)
|
||||||
|
pending_subgraph_results = {}
|
||||||
executed = set()
|
executed = set()
|
||||||
output_node_id = None
|
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
|
||||||
to_execute = []
|
current_outputs = self.caches.outputs.all_node_ids()
|
||||||
|
|
||||||
for node_id in list(execute_outputs):
|
for node_id in list(execute_outputs):
|
||||||
to_execute += [(0, node_id)]
|
execution_list.add_node(node_id)
|
||||||
|
|
||||||
while len(to_execute) > 0:
|
while not execution_list.is_empty():
|
||||||
#always execute the output that depends on the least amount of unexecuted nodes first
|
node_id, error, ex = execution_list.stage_node_execution()
|
||||||
memo = {}
|
if error is not None:
|
||||||
to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1], memo)), a[-1]), to_execute)))
|
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
|
||||||
output_node_id = to_execute.pop(0)[-1]
|
|
||||||
|
|
||||||
# This call shouldn't raise anything if there's an error deep in
|
|
||||||
# the actual SD code, instead it will report the node where the
|
|
||||||
# error was raised
|
|
||||||
self.success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage)
|
|
||||||
if self.success is not True:
|
|
||||||
self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex)
|
|
||||||
break
|
break
|
||||||
|
|
||||||
|
result, error, ex = execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results)
|
||||||
|
self.success = result != ExecutionResult.FAILURE
|
||||||
|
if result == ExecutionResult.FAILURE:
|
||||||
|
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
|
||||||
|
break
|
||||||
|
elif result == ExecutionResult.PENDING:
|
||||||
|
execution_list.unstage_node_execution()
|
||||||
|
else: # result == ExecutionResult.SUCCESS:
|
||||||
|
execution_list.complete_node_execution()
|
||||||
else:
|
else:
|
||||||
# Only execute when the while-loop ends without break
|
# Only execute when the while-loop ends without break
|
||||||
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
|
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
|
||||||
|
|
||||||
for x in executed:
|
ui_outputs = {}
|
||||||
self.old_prompt[x] = copy.deepcopy(prompt[x])
|
meta_outputs = {}
|
||||||
|
all_node_ids = self.caches.ui.all_node_ids()
|
||||||
|
for node_id in all_node_ids:
|
||||||
|
ui_info = self.caches.ui.get(node_id)
|
||||||
|
if ui_info is not None:
|
||||||
|
ui_outputs[node_id] = ui_info["output"]
|
||||||
|
meta_outputs[node_id] = ui_info["meta"]
|
||||||
|
self.history_result = {
|
||||||
|
"outputs": ui_outputs,
|
||||||
|
"meta": meta_outputs,
|
||||||
|
}
|
||||||
self.server.last_node_id = None
|
self.server.last_node_id = None
|
||||||
if comfy.model_management.DISABLE_SMART_MEMORY:
|
if comfy.model_management.DISABLE_SMART_MEMORY:
|
||||||
comfy.model_management.unload_all_models()
|
comfy.model_management.unload_all_models()
|
||||||
@@ -426,31 +532,37 @@ def validate_inputs(prompt, item, validated):
|
|||||||
obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]
|
obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||||
|
|
||||||
class_inputs = obj_class.INPUT_TYPES()
|
class_inputs = obj_class.INPUT_TYPES()
|
||||||
required_inputs = class_inputs['required']
|
valid_inputs = set(class_inputs.get('required',{})).union(set(class_inputs.get('optional',{})))
|
||||||
|
|
||||||
errors = []
|
errors = []
|
||||||
valid = True
|
valid = True
|
||||||
|
|
||||||
validate_function_inputs = []
|
validate_function_inputs = []
|
||||||
|
validate_has_kwargs = False
|
||||||
if hasattr(obj_class, "VALIDATE_INPUTS"):
|
if hasattr(obj_class, "VALIDATE_INPUTS"):
|
||||||
validate_function_inputs = inspect.getfullargspec(obj_class.VALIDATE_INPUTS).args
|
argspec = inspect.getfullargspec(obj_class.VALIDATE_INPUTS)
|
||||||
|
validate_function_inputs = argspec.args
|
||||||
|
validate_has_kwargs = argspec.varkw is not None
|
||||||
|
received_types = {}
|
||||||
|
|
||||||
for x in required_inputs:
|
for x in valid_inputs:
|
||||||
|
type_input, input_category, extra_info = get_input_info(obj_class, x)
|
||||||
|
assert extra_info is not None
|
||||||
if x not in inputs:
|
if x not in inputs:
|
||||||
error = {
|
if input_category == "required":
|
||||||
"type": "required_input_missing",
|
error = {
|
||||||
"message": "Required input is missing",
|
"type": "required_input_missing",
|
||||||
"details": f"{x}",
|
"message": "Required input is missing",
|
||||||
"extra_info": {
|
"details": f"{x}",
|
||||||
"input_name": x
|
"extra_info": {
|
||||||
|
"input_name": x
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
errors.append(error)
|
||||||
errors.append(error)
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
val = inputs[x]
|
val = inputs[x]
|
||||||
info = required_inputs[x]
|
info = (type_input, extra_info)
|
||||||
type_input = info[0]
|
|
||||||
if isinstance(val, list):
|
if isinstance(val, list):
|
||||||
if len(val) != 2:
|
if len(val) != 2:
|
||||||
error = {
|
error = {
|
||||||
@@ -469,8 +581,9 @@ def validate_inputs(prompt, item, validated):
|
|||||||
o_id = val[0]
|
o_id = val[0]
|
||||||
o_class_type = prompt[o_id]['class_type']
|
o_class_type = prompt[o_id]['class_type']
|
||||||
r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
|
r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
|
||||||
if r[val[1]] != type_input:
|
received_type = r[val[1]]
|
||||||
received_type = r[val[1]]
|
received_types[x] = received_type
|
||||||
|
if 'input_types' not in validate_function_inputs and received_type != type_input:
|
||||||
details = f"{x}, {received_type} != {type_input}"
|
details = f"{x}, {received_type} != {type_input}"
|
||||||
error = {
|
error = {
|
||||||
"type": "return_type_mismatch",
|
"type": "return_type_mismatch",
|
||||||
@@ -521,6 +634,9 @@ def validate_inputs(prompt, item, validated):
|
|||||||
if type_input == "STRING":
|
if type_input == "STRING":
|
||||||
val = str(val)
|
val = str(val)
|
||||||
inputs[x] = val
|
inputs[x] = val
|
||||||
|
if type_input == "BOOLEAN":
|
||||||
|
val = bool(val)
|
||||||
|
inputs[x] = val
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
error = {
|
error = {
|
||||||
"type": "invalid_input_type",
|
"type": "invalid_input_type",
|
||||||
@@ -536,11 +652,11 @@ def validate_inputs(prompt, item, validated):
|
|||||||
errors.append(error)
|
errors.append(error)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if len(info) > 1:
|
if x not in validate_function_inputs and not validate_has_kwargs:
|
||||||
if "min" in info[1] and val < info[1]["min"]:
|
if "min" in extra_info and val < extra_info["min"]:
|
||||||
error = {
|
error = {
|
||||||
"type": "value_smaller_than_min",
|
"type": "value_smaller_than_min",
|
||||||
"message": "Value {} smaller than min of {}".format(val, info[1]["min"]),
|
"message": "Value {} smaller than min of {}".format(val, extra_info["min"]),
|
||||||
"details": f"{x}",
|
"details": f"{x}",
|
||||||
"extra_info": {
|
"extra_info": {
|
||||||
"input_name": x,
|
"input_name": x,
|
||||||
@@ -550,10 +666,10 @@ def validate_inputs(prompt, item, validated):
|
|||||||
}
|
}
|
||||||
errors.append(error)
|
errors.append(error)
|
||||||
continue
|
continue
|
||||||
if "max" in info[1] and val > info[1]["max"]:
|
if "max" in extra_info and val > extra_info["max"]:
|
||||||
error = {
|
error = {
|
||||||
"type": "value_bigger_than_max",
|
"type": "value_bigger_than_max",
|
||||||
"message": "Value {} bigger than max of {}".format(val, info[1]["max"]),
|
"message": "Value {} bigger than max of {}".format(val, extra_info["max"]),
|
||||||
"details": f"{x}",
|
"details": f"{x}",
|
||||||
"extra_info": {
|
"extra_info": {
|
||||||
"input_name": x,
|
"input_name": x,
|
||||||
@@ -564,7 +680,6 @@ def validate_inputs(prompt, item, validated):
|
|||||||
errors.append(error)
|
errors.append(error)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if x not in validate_function_inputs:
|
|
||||||
if isinstance(type_input, list):
|
if isinstance(type_input, list):
|
||||||
if val not in type_input:
|
if val not in type_input:
|
||||||
input_config = info
|
input_config = info
|
||||||
@@ -591,18 +706,20 @@ def validate_inputs(prompt, item, validated):
|
|||||||
errors.append(error)
|
errors.append(error)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if len(validate_function_inputs) > 0:
|
if len(validate_function_inputs) > 0 or validate_has_kwargs:
|
||||||
input_data_all = get_input_data(inputs, obj_class, unique_id)
|
input_data_all, _ = get_input_data(inputs, obj_class, unique_id)
|
||||||
input_filtered = {}
|
input_filtered = {}
|
||||||
for x in input_data_all:
|
for x in input_data_all:
|
||||||
if x in validate_function_inputs:
|
if x in validate_function_inputs or validate_has_kwargs:
|
||||||
input_filtered[x] = input_data_all[x]
|
input_filtered[x] = input_data_all[x]
|
||||||
|
if 'input_types' in validate_function_inputs:
|
||||||
|
input_filtered['input_types'] = [received_types]
|
||||||
|
|
||||||
#ret = obj_class.VALIDATE_INPUTS(**input_filtered)
|
#ret = obj_class.VALIDATE_INPUTS(**input_filtered)
|
||||||
ret = map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS")
|
ret = _map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS")
|
||||||
for x in input_filtered:
|
for x in input_filtered:
|
||||||
for i, r in enumerate(ret):
|
for i, r in enumerate(ret):
|
||||||
if r is not True:
|
if r is not True and not isinstance(r, ExecutionBlocker):
|
||||||
details = f"{x}"
|
details = f"{x}"
|
||||||
if r is not False:
|
if r is not False:
|
||||||
details += f" - {str(r)}"
|
details += f" - {str(r)}"
|
||||||
@@ -613,8 +730,6 @@ def validate_inputs(prompt, item, validated):
|
|||||||
"details": details,
|
"details": details,
|
||||||
"extra_info": {
|
"extra_info": {
|
||||||
"input_name": x,
|
"input_name": x,
|
||||||
"input_config": info,
|
|
||||||
"received_value": val,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
errors.append(error)
|
errors.append(error)
|
||||||
@@ -780,7 +895,7 @@ class PromptQueue:
|
|||||||
completed: bool
|
completed: bool
|
||||||
messages: List[str]
|
messages: List[str]
|
||||||
|
|
||||||
def task_done(self, item_id, outputs,
|
def task_done(self, item_id, history_result,
|
||||||
status: Optional['PromptQueue.ExecutionStatus']):
|
status: Optional['PromptQueue.ExecutionStatus']):
|
||||||
with self.mutex:
|
with self.mutex:
|
||||||
prompt = self.currently_running.pop(item_id)
|
prompt = self.currently_running.pop(item_id)
|
||||||
@@ -793,9 +908,10 @@ class PromptQueue:
|
|||||||
|
|
||||||
self.history[prompt[1]] = {
|
self.history[prompt[1]] = {
|
||||||
"prompt": prompt,
|
"prompt": prompt,
|
||||||
"outputs": copy.deepcopy(outputs),
|
"outputs": {},
|
||||||
'status': status_dict,
|
'status': status_dict,
|
||||||
}
|
}
|
||||||
|
self.history[prompt[1]].update(history_result)
|
||||||
self.server.queue_updated()
|
self.server.queue_updated()
|
||||||
|
|
||||||
def get_current_queue(self):
|
def get_current_queue(self):
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ folder_names_and_paths["configs"] = ([os.path.join(models_dir, "configs")], [".y
|
|||||||
folder_names_and_paths["loras"] = ([os.path.join(models_dir, "loras")], supported_pt_extensions)
|
folder_names_and_paths["loras"] = ([os.path.join(models_dir, "loras")], supported_pt_extensions)
|
||||||
folder_names_and_paths["vae"] = ([os.path.join(models_dir, "vae")], supported_pt_extensions)
|
folder_names_and_paths["vae"] = ([os.path.join(models_dir, "vae")], supported_pt_extensions)
|
||||||
folder_names_and_paths["clip"] = ([os.path.join(models_dir, "clip")], supported_pt_extensions)
|
folder_names_and_paths["clip"] = ([os.path.join(models_dir, "clip")], supported_pt_extensions)
|
||||||
folder_names_and_paths["unet"] = ([os.path.join(models_dir, "unet")], supported_pt_extensions)
|
folder_names_and_paths["diffusion_models"] = ([os.path.join(models_dir, "unet"), os.path.join(models_dir, "diffusion_models")], supported_pt_extensions)
|
||||||
folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision")], supported_pt_extensions)
|
folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision")], supported_pt_extensions)
|
||||||
folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions)
|
folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions)
|
||||||
folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions)
|
folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions)
|
||||||
@@ -44,6 +44,10 @@ user_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "user
|
|||||||
|
|
||||||
filename_list_cache: dict[str, tuple[list[str], dict[str, float], float]] = {}
|
filename_list_cache: dict[str, tuple[list[str], dict[str, float], float]] = {}
|
||||||
|
|
||||||
|
def map_legacy(folder_name: str) -> str:
|
||||||
|
legacy = {"unet": "diffusion_models"}
|
||||||
|
return legacy.get(folder_name, folder_name)
|
||||||
|
|
||||||
if not os.path.exists(input_directory):
|
if not os.path.exists(input_directory):
|
||||||
try:
|
try:
|
||||||
os.makedirs(input_directory)
|
os.makedirs(input_directory)
|
||||||
@@ -128,12 +132,14 @@ def exists_annotated_filepath(name) -> bool:
|
|||||||
|
|
||||||
def add_model_folder_path(folder_name: str, full_folder_path: str) -> None:
|
def add_model_folder_path(folder_name: str, full_folder_path: str) -> None:
|
||||||
global folder_names_and_paths
|
global folder_names_and_paths
|
||||||
|
folder_name = map_legacy(folder_name)
|
||||||
if folder_name in folder_names_and_paths:
|
if folder_name in folder_names_and_paths:
|
||||||
folder_names_and_paths[folder_name][0].append(full_folder_path)
|
folder_names_and_paths[folder_name][0].append(full_folder_path)
|
||||||
else:
|
else:
|
||||||
folder_names_and_paths[folder_name] = ([full_folder_path], set())
|
folder_names_and_paths[folder_name] = ([full_folder_path], set())
|
||||||
|
|
||||||
def get_folder_paths(folder_name: str) -> list[str]:
|
def get_folder_paths(folder_name: str) -> list[str]:
|
||||||
|
folder_name = map_legacy(folder_name)
|
||||||
return folder_names_and_paths[folder_name][0][:]
|
return folder_names_and_paths[folder_name][0][:]
|
||||||
|
|
||||||
def recursive_search(directory: str, excluded_dir_names: list[str] | None=None) -> tuple[list[str], dict[str, float]]:
|
def recursive_search(directory: str, excluded_dir_names: list[str] | None=None) -> tuple[list[str], dict[str, float]]:
|
||||||
@@ -180,6 +186,7 @@ def filter_files_extensions(files: Collection[str], extensions: Collection[str])
|
|||||||
|
|
||||||
def get_full_path(folder_name: str, filename: str) -> str | None:
|
def get_full_path(folder_name: str, filename: str) -> str | None:
|
||||||
global folder_names_and_paths
|
global folder_names_and_paths
|
||||||
|
folder_name = map_legacy(folder_name)
|
||||||
if folder_name not in folder_names_and_paths:
|
if folder_name not in folder_names_and_paths:
|
||||||
return None
|
return None
|
||||||
folders = folder_names_and_paths[folder_name]
|
folders = folder_names_and_paths[folder_name]
|
||||||
@@ -194,6 +201,7 @@ def get_full_path(folder_name: str, filename: str) -> str | None:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]:
|
def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]:
|
||||||
|
folder_name = map_legacy(folder_name)
|
||||||
global folder_names_and_paths
|
global folder_names_and_paths
|
||||||
output_list = set()
|
output_list = set()
|
||||||
folders = folder_names_and_paths[folder_name]
|
folders = folder_names_and_paths[folder_name]
|
||||||
@@ -208,6 +216,7 @@ def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], f
|
|||||||
def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float] | None:
|
def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float] | None:
|
||||||
global filename_list_cache
|
global filename_list_cache
|
||||||
global folder_names_and_paths
|
global folder_names_and_paths
|
||||||
|
folder_name = map_legacy(folder_name)
|
||||||
if folder_name not in filename_list_cache:
|
if folder_name not in filename_list_cache:
|
||||||
return None
|
return None
|
||||||
out = filename_list_cache[folder_name]
|
out = filename_list_cache[folder_name]
|
||||||
@@ -227,6 +236,7 @@ def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float]
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
def get_filename_list(folder_name: str) -> list[str]:
|
def get_filename_list(folder_name: str) -> list[str]:
|
||||||
|
folder_name = map_legacy(folder_name)
|
||||||
out = cached_filename_list_(folder_name)
|
out = cached_filename_list_(folder_name)
|
||||||
if out is None:
|
if out is None:
|
||||||
out = get_filename_list_(folder_name)
|
out = get_filename_list_(folder_name)
|
||||||
|
|||||||
10
main.py
10
main.py
@@ -6,6 +6,10 @@ import importlib.util
|
|||||||
import folder_paths
|
import folder_paths
|
||||||
import time
|
import time
|
||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
|
from app.logger import setup_logger
|
||||||
|
|
||||||
|
|
||||||
|
setup_logger(verbose=args.verbose)
|
||||||
|
|
||||||
|
|
||||||
def execute_prestartup_script():
|
def execute_prestartup_script():
|
||||||
@@ -101,7 +105,7 @@ def cuda_malloc_warning():
|
|||||||
logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
|
logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
|
||||||
|
|
||||||
def prompt_worker(q, server):
|
def prompt_worker(q, server):
|
||||||
e = execution.PromptExecutor(server)
|
e = execution.PromptExecutor(server, lru_size=args.cache_lru)
|
||||||
last_gc_collect = 0
|
last_gc_collect = 0
|
||||||
need_gc = False
|
need_gc = False
|
||||||
gc_collect_interval = 10.0
|
gc_collect_interval = 10.0
|
||||||
@@ -121,7 +125,7 @@ def prompt_worker(q, server):
|
|||||||
e.execute(item[2], prompt_id, item[3], item[4])
|
e.execute(item[2], prompt_id, item[3], item[4])
|
||||||
need_gc = True
|
need_gc = True
|
||||||
q.task_done(item_id,
|
q.task_done(item_id,
|
||||||
e.outputs_ui,
|
e.history_result,
|
||||||
status=execution.PromptQueue.ExecutionStatus(
|
status=execution.PromptQueue.ExecutionStatus(
|
||||||
status_str='success' if e.success else 'error',
|
status_str='success' if e.success else 'error',
|
||||||
completed=e.success,
|
completed=e.success,
|
||||||
@@ -242,6 +246,7 @@ if __name__ == "__main__":
|
|||||||
folder_paths.add_model_folder_path("checkpoints", os.path.join(folder_paths.get_output_directory(), "checkpoints"))
|
folder_paths.add_model_folder_path("checkpoints", os.path.join(folder_paths.get_output_directory(), "checkpoints"))
|
||||||
folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip"))
|
folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip"))
|
||||||
folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae"))
|
folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae"))
|
||||||
|
folder_paths.add_model_folder_path("diffusion_models", os.path.join(folder_paths.get_output_directory(), "diffusion_models"))
|
||||||
|
|
||||||
if args.input_directory:
|
if args.input_directory:
|
||||||
input_dir = os.path.abspath(args.input_directory)
|
input_dir = os.path.abspath(args.input_directory)
|
||||||
@@ -261,6 +266,7 @@ if __name__ == "__main__":
|
|||||||
call_on_start = startup_server
|
call_on_start = startup_server
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
loop.run_until_complete(server.setup())
|
||||||
loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start))
|
loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start))
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
logging.info("\nStopped server")
|
logging.info("\nStopped server")
|
||||||
|
|||||||
2
model_filemanager/__init__.py
Normal file
2
model_filemanager/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
# model_manager/__init__.py
|
||||||
|
from .download_models import download_model, DownloadModelStatus, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_model_subdirectory, validate_filename
|
||||||
240
model_filemanager/download_models.py
Normal file
240
model_filemanager/download_models.py
Normal file
@@ -0,0 +1,240 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
import aiohttp
|
||||||
|
import os
|
||||||
|
import traceback
|
||||||
|
import logging
|
||||||
|
from folder_paths import models_dir
|
||||||
|
import re
|
||||||
|
from typing import Callable, Any, Optional, Awaitable, Dict
|
||||||
|
from enum import Enum
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
class DownloadStatusType(Enum):
|
||||||
|
PENDING = "pending"
|
||||||
|
IN_PROGRESS = "in_progress"
|
||||||
|
COMPLETED = "completed"
|
||||||
|
ERROR = "error"
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DownloadModelStatus():
|
||||||
|
status: str
|
||||||
|
progress_percentage: float
|
||||||
|
message: str
|
||||||
|
already_existed: bool = False
|
||||||
|
|
||||||
|
def __init__(self, status: DownloadStatusType, progress_percentage: float, message: str, already_existed: bool):
|
||||||
|
self.status = status.value # Store the string value of the Enum
|
||||||
|
self.progress_percentage = progress_percentage
|
||||||
|
self.message = message
|
||||||
|
self.already_existed = already_existed
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"status": self.status,
|
||||||
|
"progress_percentage": self.progress_percentage,
|
||||||
|
"message": self.message,
|
||||||
|
"already_existed": self.already_existed
|
||||||
|
}
|
||||||
|
|
||||||
|
async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
|
||||||
|
model_name: str,
|
||||||
|
model_url: str,
|
||||||
|
model_sub_directory: str,
|
||||||
|
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
||||||
|
progress_interval: float = 1.0) -> DownloadModelStatus:
|
||||||
|
"""
|
||||||
|
Download a model file from a given URL into the models directory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_download_request (Callable[[str], Awaitable[aiohttp.ClientResponse]]):
|
||||||
|
A function that makes an HTTP request. This makes it easier to mock in unit tests.
|
||||||
|
model_name (str):
|
||||||
|
The name of the model file to be downloaded. This will be the filename on disk.
|
||||||
|
model_url (str):
|
||||||
|
The URL from which to download the model.
|
||||||
|
model_sub_directory (str):
|
||||||
|
The subdirectory within the main models directory where the model
|
||||||
|
should be saved (e.g., 'checkpoints', 'loras', etc.).
|
||||||
|
progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]):
|
||||||
|
An asynchronous function to call with progress updates.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DownloadModelStatus: The result of the download operation.
|
||||||
|
"""
|
||||||
|
if not validate_model_subdirectory(model_sub_directory):
|
||||||
|
return DownloadModelStatus(
|
||||||
|
DownloadStatusType.ERROR,
|
||||||
|
0,
|
||||||
|
"Invalid model subdirectory",
|
||||||
|
False
|
||||||
|
)
|
||||||
|
|
||||||
|
if not validate_filename(model_name):
|
||||||
|
return DownloadModelStatus(
|
||||||
|
DownloadStatusType.ERROR,
|
||||||
|
0,
|
||||||
|
"Invalid model name",
|
||||||
|
False
|
||||||
|
)
|
||||||
|
|
||||||
|
file_path, relative_path = create_model_path(model_name, model_sub_directory, models_dir)
|
||||||
|
existing_file = await check_file_exists(file_path, model_name, progress_callback, relative_path)
|
||||||
|
if existing_file:
|
||||||
|
return existing_file
|
||||||
|
|
||||||
|
try:
|
||||||
|
status = DownloadModelStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}", False)
|
||||||
|
await progress_callback(relative_path, status)
|
||||||
|
|
||||||
|
response = await model_download_request(model_url)
|
||||||
|
if response.status != 200:
|
||||||
|
error_message = f"Failed to download {model_name}. Status code: {response.status}"
|
||||||
|
logging.error(error_message)
|
||||||
|
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
||||||
|
await progress_callback(relative_path, status)
|
||||||
|
return DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
||||||
|
|
||||||
|
return await track_download_progress(response, file_path, model_name, progress_callback, relative_path, progress_interval)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error in downloading model: {e}")
|
||||||
|
return await handle_download_error(e, model_name, progress_callback, relative_path)
|
||||||
|
|
||||||
|
|
||||||
|
def create_model_path(model_name: str, model_directory: str, models_base_dir: str) -> tuple[str, str]:
|
||||||
|
full_model_dir = os.path.join(models_base_dir, model_directory)
|
||||||
|
os.makedirs(full_model_dir, exist_ok=True)
|
||||||
|
file_path = os.path.join(full_model_dir, model_name)
|
||||||
|
|
||||||
|
# Ensure the resulting path is still within the base directory
|
||||||
|
abs_file_path = os.path.abspath(file_path)
|
||||||
|
abs_base_dir = os.path.abspath(str(models_base_dir))
|
||||||
|
if os.path.commonprefix([abs_file_path, abs_base_dir]) != abs_base_dir:
|
||||||
|
raise Exception(f"Invalid model directory: {model_directory}/{model_name}")
|
||||||
|
|
||||||
|
|
||||||
|
relative_path = '/'.join([model_directory, model_name])
|
||||||
|
return file_path, relative_path
|
||||||
|
|
||||||
|
async def check_file_exists(file_path: str,
|
||||||
|
model_name: str,
|
||||||
|
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
||||||
|
relative_path: str) -> Optional[DownloadModelStatus]:
|
||||||
|
if os.path.exists(file_path):
|
||||||
|
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True)
|
||||||
|
await progress_callback(relative_path, status)
|
||||||
|
return status
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def track_download_progress(response: aiohttp.ClientResponse,
|
||||||
|
file_path: str,
|
||||||
|
model_name: str,
|
||||||
|
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
||||||
|
relative_path: str,
|
||||||
|
interval: float = 1.0) -> DownloadModelStatus:
|
||||||
|
try:
|
||||||
|
total_size = int(response.headers.get('Content-Length', 0))
|
||||||
|
downloaded = 0
|
||||||
|
last_update_time = time.time()
|
||||||
|
|
||||||
|
async def update_progress():
|
||||||
|
nonlocal last_update_time
|
||||||
|
progress = (downloaded / total_size) * 100 if total_size > 0 else 0
|
||||||
|
status = DownloadModelStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}", False)
|
||||||
|
await progress_callback(relative_path, status)
|
||||||
|
last_update_time = time.time()
|
||||||
|
|
||||||
|
with open(file_path, 'wb') as f:
|
||||||
|
chunk_iterator = response.content.iter_chunked(8192)
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
chunk = await chunk_iterator.__anext__()
|
||||||
|
except StopAsyncIteration:
|
||||||
|
break
|
||||||
|
f.write(chunk)
|
||||||
|
downloaded += len(chunk)
|
||||||
|
|
||||||
|
if time.time() - last_update_time >= interval:
|
||||||
|
await update_progress()
|
||||||
|
|
||||||
|
await update_progress()
|
||||||
|
|
||||||
|
logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}")
|
||||||
|
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False)
|
||||||
|
await progress_callback(relative_path, status)
|
||||||
|
|
||||||
|
return status
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error in track_download_progress: {e}")
|
||||||
|
logging.error(traceback.format_exc())
|
||||||
|
return await handle_download_error(e, model_name, progress_callback, relative_path)
|
||||||
|
|
||||||
|
async def handle_download_error(e: Exception,
|
||||||
|
model_name: str,
|
||||||
|
progress_callback: Callable[[str, DownloadModelStatus], Any],
|
||||||
|
relative_path: str) -> DownloadModelStatus:
|
||||||
|
error_message = f"Error downloading {model_name}: {str(e)}"
|
||||||
|
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
||||||
|
await progress_callback(relative_path, status)
|
||||||
|
return status
|
||||||
|
|
||||||
|
def validate_model_subdirectory(model_subdirectory: str) -> bool:
|
||||||
|
"""
|
||||||
|
Validate that the model subdirectory is safe to install into.
|
||||||
|
Must not contain relative paths, nested paths or special characters
|
||||||
|
other than underscores and hyphens.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_subdirectory (str): The subdirectory for the specific model type.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the subdirectory is safe, False otherwise.
|
||||||
|
"""
|
||||||
|
if len(model_subdirectory) > 50:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if '..' in model_subdirectory or '/' in model_subdirectory:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if not re.match(r'^[a-zA-Z0-9_-]+$', model_subdirectory):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def validate_filename(filename: str)-> bool:
|
||||||
|
"""
|
||||||
|
Validate a filename to ensure it's safe and doesn't contain any path traversal attempts.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename (str): The filename to validate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the filename is valid, False otherwise
|
||||||
|
"""
|
||||||
|
if not filename.lower().endswith(('.sft', '.safetensors')):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check if the filename is empty, None, or just whitespace
|
||||||
|
if not filename or not filename.strip():
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check for any directory traversal attempts or invalid characters
|
||||||
|
if any(char in filename for char in ['..', '/', '\\', '\n', '\r', '\t', '\0']):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check if the filename starts with a dot (hidden file)
|
||||||
|
if filename.startswith('.'):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Use a whitelist of allowed characters
|
||||||
|
if not re.match(r'^[a-zA-Z0-9_\-. ]+$', filename):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Ensure the filename isn't too long
|
||||||
|
if len(filename) > 255:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
130
nodes.py
130
nodes.py
@@ -47,11 +47,18 @@ MAX_RESOLUTION=16384
|
|||||||
class CLIPTextEncode:
|
class CLIPTextEncode:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": {"text": ("STRING", {"multiline": True, "dynamicPrompts": True}), "clip": ("CLIP", )}}
|
return {
|
||||||
|
"required": {
|
||||||
|
"text": ("STRING", {"multiline": True, "dynamicPrompts": True, "tooltip": "The text to be encoded."}),
|
||||||
|
"clip": ("CLIP", {"tooltip": "The CLIP model used for encoding the text."})
|
||||||
|
}
|
||||||
|
}
|
||||||
RETURN_TYPES = ("CONDITIONING",)
|
RETURN_TYPES = ("CONDITIONING",)
|
||||||
|
OUTPUT_TOOLTIPS = ("A conditioning containing the embedded text used to guide the diffusion model.",)
|
||||||
FUNCTION = "encode"
|
FUNCTION = "encode"
|
||||||
|
|
||||||
CATEGORY = "conditioning"
|
CATEGORY = "conditioning"
|
||||||
|
DESCRIPTION = "Encodes a text prompt using a CLIP model into an embedding that can be used to guide the diffusion model towards generating specific images."
|
||||||
|
|
||||||
def encode(self, clip, text):
|
def encode(self, clip, text):
|
||||||
tokens = clip.tokenize(text)
|
tokens = clip.tokenize(text)
|
||||||
@@ -260,11 +267,18 @@ class ConditioningSetTimestepRange:
|
|||||||
class VAEDecode:
|
class VAEDecode:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
|
return {
|
||||||
|
"required": {
|
||||||
|
"samples": ("LATENT", {"tooltip": "The latent to be decoded."}),
|
||||||
|
"vae": ("VAE", {"tooltip": "The VAE model used for decoding the latent."})
|
||||||
|
}
|
||||||
|
}
|
||||||
RETURN_TYPES = ("IMAGE",)
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
OUTPUT_TOOLTIPS = ("The decoded image.",)
|
||||||
FUNCTION = "decode"
|
FUNCTION = "decode"
|
||||||
|
|
||||||
CATEGORY = "latent"
|
CATEGORY = "latent"
|
||||||
|
DESCRIPTION = "Decodes latent images back into pixel space images."
|
||||||
|
|
||||||
def decode(self, vae, samples):
|
def decode(self, vae, samples):
|
||||||
return (vae.decode(samples["samples"]), )
|
return (vae.decode(samples["samples"]), )
|
||||||
@@ -506,12 +520,19 @@ class CheckpointLoader:
|
|||||||
class CheckpointLoaderSimple:
|
class CheckpointLoaderSimple:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
|
return {
|
||||||
}}
|
"required": {
|
||||||
|
"ckpt_name": (folder_paths.get_filename_list("checkpoints"), {"tooltip": "The name of the checkpoint (model) to load."}),
|
||||||
|
}
|
||||||
|
}
|
||||||
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
|
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
|
||||||
|
OUTPUT_TOOLTIPS = ("The model used for denoising latents.",
|
||||||
|
"The CLIP model used for encoding text prompts.",
|
||||||
|
"The VAE model used for encoding and decoding images to and from latent space.")
|
||||||
FUNCTION = "load_checkpoint"
|
FUNCTION = "load_checkpoint"
|
||||||
|
|
||||||
CATEGORY = "loaders"
|
CATEGORY = "loaders"
|
||||||
|
DESCRIPTION = "Loads a diffusion model checkpoint, diffusion models are used to denoise latents."
|
||||||
|
|
||||||
def load_checkpoint(self, ckpt_name):
|
def load_checkpoint(self, ckpt_name):
|
||||||
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
|
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
|
||||||
@@ -582,16 +603,22 @@ class LoraLoader:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "model": ("MODEL",),
|
return {
|
||||||
"clip": ("CLIP", ),
|
"required": {
|
||||||
"lora_name": (folder_paths.get_filename_list("loras"), ),
|
"model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}),
|
||||||
"strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}),
|
"clip": ("CLIP", {"tooltip": "The CLIP model the LoRA will be applied to."}),
|
||||||
"strength_clip": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}),
|
"lora_name": (folder_paths.get_filename_list("loras"), {"tooltip": "The name of the LoRA."}),
|
||||||
}}
|
"strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the diffusion model. This value can be negative."}),
|
||||||
|
"strength_clip": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the CLIP model. This value can be negative."}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
RETURN_TYPES = ("MODEL", "CLIP")
|
RETURN_TYPES = ("MODEL", "CLIP")
|
||||||
|
OUTPUT_TOOLTIPS = ("The modified diffusion model.", "The modified CLIP model.")
|
||||||
FUNCTION = "load_lora"
|
FUNCTION = "load_lora"
|
||||||
|
|
||||||
CATEGORY = "loaders"
|
CATEGORY = "loaders"
|
||||||
|
DESCRIPTION = "LoRAs are used to modify diffusion and CLIP models, altering the way in which latents are denoised such as applying styles. Multiple LoRA nodes can be linked together."
|
||||||
|
|
||||||
def load_lora(self, model, clip, lora_name, strength_model, strength_clip):
|
def load_lora(self, model, clip, lora_name, strength_model, strength_clip):
|
||||||
if strength_model == 0 and strength_clip == 0:
|
if strength_model == 0 and strength_clip == 0:
|
||||||
@@ -638,6 +665,8 @@ class VAELoader:
|
|||||||
sd1_taesd_dec = False
|
sd1_taesd_dec = False
|
||||||
sd3_taesd_enc = False
|
sd3_taesd_enc = False
|
||||||
sd3_taesd_dec = False
|
sd3_taesd_dec = False
|
||||||
|
f1_taesd_enc = False
|
||||||
|
f1_taesd_dec = False
|
||||||
|
|
||||||
for v in approx_vaes:
|
for v in approx_vaes:
|
||||||
if v.startswith("taesd_decoder."):
|
if v.startswith("taesd_decoder."):
|
||||||
@@ -652,12 +681,18 @@ class VAELoader:
|
|||||||
sd3_taesd_dec = True
|
sd3_taesd_dec = True
|
||||||
elif v.startswith("taesd3_encoder."):
|
elif v.startswith("taesd3_encoder."):
|
||||||
sd3_taesd_enc = True
|
sd3_taesd_enc = True
|
||||||
|
elif v.startswith("taef1_encoder."):
|
||||||
|
f1_taesd_dec = True
|
||||||
|
elif v.startswith("taef1_decoder."):
|
||||||
|
f1_taesd_enc = True
|
||||||
if sd1_taesd_dec and sd1_taesd_enc:
|
if sd1_taesd_dec and sd1_taesd_enc:
|
||||||
vaes.append("taesd")
|
vaes.append("taesd")
|
||||||
if sdxl_taesd_dec and sdxl_taesd_enc:
|
if sdxl_taesd_dec and sdxl_taesd_enc:
|
||||||
vaes.append("taesdxl")
|
vaes.append("taesdxl")
|
||||||
if sd3_taesd_dec and sd3_taesd_enc:
|
if sd3_taesd_dec and sd3_taesd_enc:
|
||||||
vaes.append("taesd3")
|
vaes.append("taesd3")
|
||||||
|
if f1_taesd_dec and f1_taesd_enc:
|
||||||
|
vaes.append("taef1")
|
||||||
return vaes
|
return vaes
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -685,6 +720,9 @@ class VAELoader:
|
|||||||
elif name == "taesd3":
|
elif name == "taesd3":
|
||||||
sd["vae_scale"] = torch.tensor(1.5305)
|
sd["vae_scale"] = torch.tensor(1.5305)
|
||||||
sd["vae_shift"] = torch.tensor(0.0609)
|
sd["vae_shift"] = torch.tensor(0.0609)
|
||||||
|
elif name == "taef1":
|
||||||
|
sd["vae_scale"] = torch.tensor(0.3611)
|
||||||
|
sd["vae_shift"] = torch.tensor(0.1159)
|
||||||
return sd
|
return sd
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -697,7 +735,7 @@ class VAELoader:
|
|||||||
|
|
||||||
#TODO: scale factor?
|
#TODO: scale factor?
|
||||||
def load_vae(self, vae_name):
|
def load_vae(self, vae_name):
|
||||||
if vae_name in ["taesd", "taesdxl", "taesd3"]:
|
if vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]:
|
||||||
sd = self.load_taesd(vae_name)
|
sd = self.load_taesd(vae_name)
|
||||||
else:
|
else:
|
||||||
vae_path = folder_paths.get_full_path("vae", vae_name)
|
vae_path = folder_paths.get_full_path("vae", vae_name)
|
||||||
@@ -817,7 +855,7 @@ class ControlNetApplyAdvanced:
|
|||||||
class UNETLoader:
|
class UNETLoader:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "unet_name": (folder_paths.get_filename_list("unet"), ),
|
return {"required": { "unet_name": (folder_paths.get_filename_list("diffusion_models"), ),
|
||||||
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e5m2"],)
|
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e5m2"],)
|
||||||
}}
|
}}
|
||||||
RETURN_TYPES = ("MODEL",)
|
RETURN_TYPES = ("MODEL",)
|
||||||
@@ -826,14 +864,14 @@ class UNETLoader:
|
|||||||
CATEGORY = "advanced/loaders"
|
CATEGORY = "advanced/loaders"
|
||||||
|
|
||||||
def load_unet(self, unet_name, weight_dtype):
|
def load_unet(self, unet_name, weight_dtype):
|
||||||
dtype = None
|
model_options = {}
|
||||||
if weight_dtype == "fp8_e4m3fn":
|
if weight_dtype == "fp8_e4m3fn":
|
||||||
dtype = torch.float8_e4m3fn
|
model_options["dtype"] = torch.float8_e4m3fn
|
||||||
elif weight_dtype == "fp8_e5m2":
|
elif weight_dtype == "fp8_e5m2":
|
||||||
dtype = torch.float8_e5m2
|
model_options["dtype"] = torch.float8_e5m2
|
||||||
|
|
||||||
unet_path = folder_paths.get_full_path("unet", unet_name)
|
unet_path = folder_paths.get_full_path("diffusion_models", unet_name)
|
||||||
model = comfy.sd.load_unet(unet_path, dtype=dtype)
|
model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
|
||||||
return (model,)
|
return (model,)
|
||||||
|
|
||||||
class CLIPLoader:
|
class CLIPLoader:
|
||||||
@@ -1033,13 +1071,19 @@ class EmptyLatentImage:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "width": ("INT", {"default": 512, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
|
return {
|
||||||
"height": ("INT", {"default": 512, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
|
"required": {
|
||||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
|
"width": ("INT", {"default": 512, "min": 16, "max": MAX_RESOLUTION, "step": 8, "tooltip": "The width of the latent images in pixels."}),
|
||||||
|
"height": ("INT", {"default": 512, "min": 16, "max": MAX_RESOLUTION, "step": 8, "tooltip": "The height of the latent images in pixels."}),
|
||||||
|
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."})
|
||||||
|
}
|
||||||
|
}
|
||||||
RETURN_TYPES = ("LATENT",)
|
RETURN_TYPES = ("LATENT",)
|
||||||
|
OUTPUT_TOOLTIPS = ("The empty latent image batch.",)
|
||||||
FUNCTION = "generate"
|
FUNCTION = "generate"
|
||||||
|
|
||||||
CATEGORY = "latent"
|
CATEGORY = "latent"
|
||||||
|
DESCRIPTION = "Create a new batch of empty latent images to be denoised via sampling."
|
||||||
|
|
||||||
def generate(self, width, height, batch_size=1):
|
def generate(self, width, height, batch_size=1):
|
||||||
latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=self.device)
|
latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=self.device)
|
||||||
@@ -1359,24 +1403,27 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
|
|||||||
class KSampler:
|
class KSampler:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required":
|
return {
|
||||||
{"model": ("MODEL",),
|
"required": {
|
||||||
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
|
"model": ("MODEL", {"tooltip": "The model used for denoising the input latent."}),
|
||||||
"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
|
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "tooltip": "The random seed used for creating the noise."}),
|
||||||
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
|
"steps": ("INT", {"default": 20, "min": 1, "max": 10000, "tooltip": "The number of steps used in the denoising process."}),
|
||||||
"sampler_name": (comfy.samplers.KSampler.SAMPLERS, ),
|
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01, "tooltip": "The Classifier-Free Guidance scale balances creativity and adherence to the prompt. Higher values result in images more closely matching the prompt however too high values will negatively impact quality."}),
|
||||||
"scheduler": (comfy.samplers.KSampler.SCHEDULERS, ),
|
"sampler_name": (comfy.samplers.KSampler.SAMPLERS, {"tooltip": "The algorithm used when sampling, this can affect the quality, speed, and style of the generated output."}),
|
||||||
"positive": ("CONDITIONING", ),
|
"scheduler": (comfy.samplers.KSampler.SCHEDULERS, {"tooltip": "The scheduler controls how noise is gradually removed to form the image."}),
|
||||||
"negative": ("CONDITIONING", ),
|
"positive": ("CONDITIONING", {"tooltip": "The conditioning describing the attributes you want to include in the image."}),
|
||||||
"latent_image": ("LATENT", ),
|
"negative": ("CONDITIONING", {"tooltip": "The conditioning describing the attributes you want to exclude from the image."}),
|
||||||
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
"latent_image": ("LATENT", {"tooltip": "The latent image to denoise."}),
|
||||||
}
|
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The amount of denoising applied, lower values will maintain the structure of the initial image allowing for image to image sampling."}),
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
RETURN_TYPES = ("LATENT",)
|
RETURN_TYPES = ("LATENT",)
|
||||||
|
OUTPUT_TOOLTIPS = ("The denoised latent.",)
|
||||||
FUNCTION = "sample"
|
FUNCTION = "sample"
|
||||||
|
|
||||||
CATEGORY = "sampling"
|
CATEGORY = "sampling"
|
||||||
|
DESCRIPTION = "Uses the provided model, positive and negative conditioning to denoise the latent image."
|
||||||
|
|
||||||
def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0):
|
def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0):
|
||||||
return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise)
|
return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise)
|
||||||
@@ -1424,11 +1471,15 @@ class SaveImage:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required":
|
return {
|
||||||
{"images": ("IMAGE", ),
|
"required": {
|
||||||
"filename_prefix": ("STRING", {"default": "ComfyUI"})},
|
"images": ("IMAGE", {"tooltip": "The images to save."}),
|
||||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
"filename_prefix": ("STRING", {"default": "ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."})
|
||||||
}
|
},
|
||||||
|
"hidden": {
|
||||||
|
"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
RETURN_TYPES = ()
|
RETURN_TYPES = ()
|
||||||
FUNCTION = "save_images"
|
FUNCTION = "save_images"
|
||||||
@@ -1436,6 +1487,7 @@ class SaveImage:
|
|||||||
OUTPUT_NODE = True
|
OUTPUT_NODE = True
|
||||||
|
|
||||||
CATEGORY = "image"
|
CATEGORY = "image"
|
||||||
|
DESCRIPTION = "Saves the input images to your ComfyUI output directory."
|
||||||
|
|
||||||
def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
|
def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
|
||||||
filename_prefix += self.prefix_append
|
filename_prefix += self.prefix_append
|
||||||
@@ -2077,3 +2129,5 @@ def init_extra_nodes(init_custom_nodes=True):
|
|||||||
else:
|
else:
|
||||||
logging.warning("Please do a: pip install -r requirements.txt")
|
logging.warning("Please do a: pip install -r requirements.txt")
|
||||||
logging.warning("")
|
logging.warning("")
|
||||||
|
|
||||||
|
return import_failed
|
||||||
|
|||||||
@@ -79,7 +79,7 @@
|
|||||||
"#!wget -c https://huggingface.co/comfyanonymous/clip_vision_g/resolve/main/clip_vision_g.safetensors -P ./models/clip_vision/\n",
|
"#!wget -c https://huggingface.co/comfyanonymous/clip_vision_g/resolve/main/clip_vision_g.safetensors -P ./models/clip_vision/\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# SD1.5\n",
|
"# SD1.5\n",
|
||||||
"!wget -c https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt -P ./models/checkpoints/\n",
|
"!wget -c https://huggingface.co/Comfy-Org/stable-diffusion-v1-5-archive/resolve/main/v1-5-pruned-emaonly-fp16.safetensors -P ./models/checkpoints/\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# SD2\n",
|
"# SD2\n",
|
||||||
"#!wget -c https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.safetensors -P ./models/checkpoints/\n",
|
"#!wget -c https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.safetensors -P ./models/checkpoints/\n",
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
[pytest]
|
[pytest]
|
||||||
markers =
|
markers =
|
||||||
inference: mark as inference test (deselect with '-m "not inference"')
|
inference: mark as inference test (deselect with '-m "not inference"')
|
||||||
|
execution: mark as execution test (deselect with '-m "not execution"')
|
||||||
testpaths =
|
testpaths =
|
||||||
tests
|
tests
|
||||||
tests-unit
|
tests-unit
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ prompt_text = """
|
|||||||
"4": {
|
"4": {
|
||||||
"class_type": "CheckpointLoaderSimple",
|
"class_type": "CheckpointLoaderSimple",
|
||||||
"inputs": {
|
"inputs": {
|
||||||
"ckpt_name": "v1-5-pruned-emaonly.ckpt"
|
"ckpt_name": "v1-5-pruned-emaonly.safetensors"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"5": {
|
"5": {
|
||||||
|
|||||||
@@ -41,15 +41,14 @@ def get_images(ws, prompt):
|
|||||||
continue #previews are binary data
|
continue #previews are binary data
|
||||||
|
|
||||||
history = get_history(prompt_id)[prompt_id]
|
history = get_history(prompt_id)[prompt_id]
|
||||||
for o in history['outputs']:
|
for node_id in history['outputs']:
|
||||||
for node_id in history['outputs']:
|
node_output = history['outputs'][node_id]
|
||||||
node_output = history['outputs'][node_id]
|
images_output = []
|
||||||
if 'images' in node_output:
|
if 'images' in node_output:
|
||||||
images_output = []
|
for image in node_output['images']:
|
||||||
for image in node_output['images']:
|
image_data = get_image(image['filename'], image['subfolder'], image['type'])
|
||||||
image_data = get_image(image['filename'], image['subfolder'], image['type'])
|
images_output.append(image_data)
|
||||||
images_output.append(image_data)
|
output_images[node_id] = images_output
|
||||||
output_images[node_id] = images_output
|
|
||||||
|
|
||||||
return output_images
|
return output_images
|
||||||
|
|
||||||
@@ -85,7 +84,7 @@ prompt_text = """
|
|||||||
"4": {
|
"4": {
|
||||||
"class_type": "CheckpointLoaderSimple",
|
"class_type": "CheckpointLoaderSimple",
|
||||||
"inputs": {
|
"inputs": {
|
||||||
"ckpt_name": "v1-5-pruned-emaonly.ckpt"
|
"ckpt_name": "v1-5-pruned-emaonly.safetensors"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"5": {
|
"5": {
|
||||||
|
|||||||
@@ -81,7 +81,7 @@ prompt_text = """
|
|||||||
"4": {
|
"4": {
|
||||||
"class_type": "CheckpointLoaderSimple",
|
"class_type": "CheckpointLoaderSimple",
|
||||||
"inputs": {
|
"inputs": {
|
||||||
"ckpt_name": "v1-5-pruned-emaonly.ckpt"
|
"ckpt_name": "v1-5-pruned-emaonly.safetensors"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"5": {
|
"5": {
|
||||||
|
|||||||
83
server.py
83
server.py
@@ -12,7 +12,6 @@ import json
|
|||||||
import glob
|
import glob
|
||||||
import struct
|
import struct
|
||||||
import ssl
|
import ssl
|
||||||
import hashlib
|
|
||||||
from PIL import Image, ImageOps
|
from PIL import Image, ImageOps
|
||||||
from PIL.PngImagePlugin import PngInfo
|
from PIL.PngImagePlugin import PngInfo
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
@@ -28,7 +27,9 @@ import comfy.model_management
|
|||||||
import node_helpers
|
import node_helpers
|
||||||
from app.frontend_management import FrontendManager
|
from app.frontend_management import FrontendManager
|
||||||
from app.user_manager import UserManager
|
from app.user_manager import UserManager
|
||||||
|
from model_filemanager import download_model, DownloadModelStatus
|
||||||
|
from typing import Optional
|
||||||
|
from api_server.routes.internal.internal_routes import InternalRoutes
|
||||||
|
|
||||||
class BinaryEventTypes:
|
class BinaryEventTypes:
|
||||||
PREVIEW_IMAGE = 1
|
PREVIEW_IMAGE = 1
|
||||||
@@ -40,6 +41,21 @@ async def send_socket_catch_exception(function, message):
|
|||||||
except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError) as err:
|
except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError) as err:
|
||||||
logging.warning("send error: {}".format(err))
|
logging.warning("send error: {}".format(err))
|
||||||
|
|
||||||
|
def get_comfyui_version():
|
||||||
|
comfyui_version = "unknown"
|
||||||
|
repo_path = os.path.dirname(os.path.realpath(__file__))
|
||||||
|
try:
|
||||||
|
import pygit2
|
||||||
|
repo = pygit2.Repository(repo_path)
|
||||||
|
comfyui_version = repo.describe(describe_strategy=pygit2.GIT_DESCRIBE_TAGS)
|
||||||
|
except Exception:
|
||||||
|
try:
|
||||||
|
import subprocess
|
||||||
|
comfyui_version = subprocess.check_output(["git", "describe", "--tags"], cwd=repo_path).decode('utf-8')
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning(f"Failed to get ComfyUI version: {e}")
|
||||||
|
return comfyui_version.strip()
|
||||||
|
|
||||||
@web.middleware
|
@web.middleware
|
||||||
async def cache_control(request: web.Request, handler):
|
async def cache_control(request: web.Request, handler):
|
||||||
response: web.Response = await handler(request)
|
response: web.Response = await handler(request)
|
||||||
@@ -72,10 +88,12 @@ class PromptServer():
|
|||||||
mimetypes.types_map['.js'] = 'application/javascript; charset=utf-8'
|
mimetypes.types_map['.js'] = 'application/javascript; charset=utf-8'
|
||||||
|
|
||||||
self.user_manager = UserManager()
|
self.user_manager = UserManager()
|
||||||
|
self.internal_routes = InternalRoutes()
|
||||||
self.supports = ["custom_nodes_from_web"]
|
self.supports = ["custom_nodes_from_web"]
|
||||||
self.prompt_queue = None
|
self.prompt_queue = None
|
||||||
self.loop = loop
|
self.loop = loop
|
||||||
self.messages = asyncio.Queue()
|
self.messages = asyncio.Queue()
|
||||||
|
self.client_session:Optional[aiohttp.ClientSession] = None
|
||||||
self.number = 0
|
self.number = 0
|
||||||
|
|
||||||
middlewares = [cache_control]
|
middlewares = [cache_control]
|
||||||
@@ -138,6 +156,14 @@ class PromptServer():
|
|||||||
embeddings = folder_paths.get_filename_list("embeddings")
|
embeddings = folder_paths.get_filename_list("embeddings")
|
||||||
return web.json_response(list(map(lambda a: os.path.splitext(a)[0], embeddings)))
|
return web.json_response(list(map(lambda a: os.path.splitext(a)[0], embeddings)))
|
||||||
|
|
||||||
|
@routes.get("/models/{folder}")
|
||||||
|
async def get_models(request):
|
||||||
|
folder = request.match_info.get("folder", None)
|
||||||
|
if not folder in folder_paths.folder_names_and_paths:
|
||||||
|
return web.Response(status=404)
|
||||||
|
files = folder_paths.get_filename_list(folder)
|
||||||
|
return web.json_response(files)
|
||||||
|
|
||||||
@routes.get("/extensions")
|
@routes.get("/extensions")
|
||||||
async def get_extensions(request):
|
async def get_extensions(request):
|
||||||
files = glob.glob(os.path.join(
|
files = glob.glob(os.path.join(
|
||||||
@@ -389,16 +415,20 @@ class PromptServer():
|
|||||||
return web.json_response(dt["__metadata__"])
|
return web.json_response(dt["__metadata__"])
|
||||||
|
|
||||||
@routes.get("/system_stats")
|
@routes.get("/system_stats")
|
||||||
async def get_queue(request):
|
async def system_stats(request):
|
||||||
device = comfy.model_management.get_torch_device()
|
device = comfy.model_management.get_torch_device()
|
||||||
device_name = comfy.model_management.get_torch_device_name(device)
|
device_name = comfy.model_management.get_torch_device_name(device)
|
||||||
vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
|
vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
|
||||||
vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
|
vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
|
||||||
|
|
||||||
system_stats = {
|
system_stats = {
|
||||||
"system": {
|
"system": {
|
||||||
"os": os.name,
|
"os": os.name,
|
||||||
|
"comfyui_version": get_comfyui_version(),
|
||||||
"python_version": sys.version,
|
"python_version": sys.version,
|
||||||
"embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded"
|
"pytorch_version": comfy.model_management.torch_version,
|
||||||
|
"embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded",
|
||||||
|
"argv": sys.argv
|
||||||
},
|
},
|
||||||
"devices": [
|
"devices": [
|
||||||
{
|
{
|
||||||
@@ -422,6 +452,7 @@ class PromptServer():
|
|||||||
obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
|
obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
|
||||||
info = {}
|
info = {}
|
||||||
info['input'] = obj_class.INPUT_TYPES()
|
info['input'] = obj_class.INPUT_TYPES()
|
||||||
|
info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()}
|
||||||
info['output'] = obj_class.RETURN_TYPES
|
info['output'] = obj_class.RETURN_TYPES
|
||||||
info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
|
info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
|
||||||
info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
|
info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
|
||||||
@@ -437,6 +468,14 @@ class PromptServer():
|
|||||||
|
|
||||||
if hasattr(obj_class, 'CATEGORY'):
|
if hasattr(obj_class, 'CATEGORY'):
|
||||||
info['category'] = obj_class.CATEGORY
|
info['category'] = obj_class.CATEGORY
|
||||||
|
|
||||||
|
if hasattr(obj_class, 'OUTPUT_TOOLTIPS'):
|
||||||
|
info['output_tooltips'] = obj_class.OUTPUT_TOOLTIPS
|
||||||
|
|
||||||
|
if getattr(obj_class, "DEPRECATED", False):
|
||||||
|
info['deprecated'] = True
|
||||||
|
if getattr(obj_class, "EXPERIMENTAL", False):
|
||||||
|
info['experimental'] = True
|
||||||
return info
|
return info
|
||||||
|
|
||||||
@routes.get("/object_info")
|
@routes.get("/object_info")
|
||||||
@@ -559,9 +598,42 @@ class PromptServer():
|
|||||||
self.prompt_queue.delete_history_item(id_to_delete)
|
self.prompt_queue.delete_history_item(id_to_delete)
|
||||||
|
|
||||||
return web.Response(status=200)
|
return web.Response(status=200)
|
||||||
|
|
||||||
|
# Internal route. Should not be depended upon and is subject to change at any time.
|
||||||
|
# TODO(robinhuang): Move to internal route table class once we refactor PromptServer to pass around Websocket.
|
||||||
|
@routes.post("/internal/models/download")
|
||||||
|
async def download_handler(request):
|
||||||
|
async def report_progress(filename: str, status: DownloadModelStatus):
|
||||||
|
payload = status.to_dict()
|
||||||
|
payload['download_path'] = filename
|
||||||
|
await self.send_json("download_progress", payload)
|
||||||
|
|
||||||
|
data = await request.json()
|
||||||
|
url = data.get('url')
|
||||||
|
model_directory = data.get('model_directory')
|
||||||
|
model_filename = data.get('model_filename')
|
||||||
|
progress_interval = data.get('progress_interval', 1.0) # In seconds, how often to report download progress.
|
||||||
|
|
||||||
|
if not url or not model_directory or not model_filename:
|
||||||
|
return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400)
|
||||||
|
|
||||||
|
session = self.client_session
|
||||||
|
if session is None:
|
||||||
|
logging.error("Client session is not initialized")
|
||||||
|
return web.Response(status=500)
|
||||||
|
|
||||||
|
task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, report_progress, progress_interval))
|
||||||
|
await task
|
||||||
|
|
||||||
|
return web.json_response(task.result().to_dict())
|
||||||
|
|
||||||
|
async def setup(self):
|
||||||
|
timeout = aiohttp.ClientTimeout(total=None) # no timeout
|
||||||
|
self.client_session = aiohttp.ClientSession(timeout=timeout)
|
||||||
|
|
||||||
def add_routes(self):
|
def add_routes(self):
|
||||||
self.user_manager.add_routes(self.routes)
|
self.user_manager.add_routes(self.routes)
|
||||||
|
self.app.add_subapp('/internal', self.internal_routes.get_app())
|
||||||
|
|
||||||
# Prefix every route with /api for easier matching for delegation.
|
# Prefix every route with /api for easier matching for delegation.
|
||||||
# This is very useful for frontend dev server, which need to forward
|
# This is very useful for frontend dev server, which need to forward
|
||||||
@@ -680,6 +752,9 @@ class PromptServer():
|
|||||||
site = web.TCPSite(runner, address, port, ssl_context=ssl_ctx)
|
site = web.TCPSite(runner, address, port, ssl_context=ssl_ctx)
|
||||||
await site.start()
|
await site.start()
|
||||||
|
|
||||||
|
self.address = address
|
||||||
|
self.port = port
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
logging.info("Starting server\n")
|
logging.info("Starting server\n")
|
||||||
logging.info("To see the GUI go to: {}://{}:{}".format(scheme, address, port))
|
logging.info("To see the GUI go to: {}://{}:{}".format(scheme, address, port))
|
||||||
|
|||||||
1
tests-ui/.gitignore
vendored
1
tests-ui/.gitignore
vendored
@@ -1 +0,0 @@
|
|||||||
node_modules
|
|
||||||
@@ -1,9 +0,0 @@
|
|||||||
const { start } = require("./utils");
|
|
||||||
const lg = require("./utils/litegraph");
|
|
||||||
|
|
||||||
// Load things once per test file before to ensure its all warmed up for the tests
|
|
||||||
beforeAll(async () => {
|
|
||||||
lg.setup(global);
|
|
||||||
await start({ resetEnv: true });
|
|
||||||
lg.teardown(global);
|
|
||||||
});
|
|
||||||
@@ -1,4 +0,0 @@
|
|||||||
{
|
|
||||||
"presets": ["@babel/preset-env"],
|
|
||||||
"plugins": ["babel-plugin-transform-import-meta"]
|
|
||||||
}
|
|
||||||
@@ -1,14 +0,0 @@
|
|||||||
module.exports = async function () {
|
|
||||||
global.ResizeObserver = class ResizeObserver {
|
|
||||||
observe() {}
|
|
||||||
unobserve() {}
|
|
||||||
disconnect() {}
|
|
||||||
};
|
|
||||||
|
|
||||||
const { nop } = require("./utils/nopProxy");
|
|
||||||
global.enableWebGLCanvas = nop;
|
|
||||||
|
|
||||||
HTMLCanvasElement.prototype.getContext = nop;
|
|
||||||
|
|
||||||
localStorage["Comfy.Settings.Comfy.Logging.Enabled"] = "false";
|
|
||||||
};
|
|
||||||
@@ -1,11 +0,0 @@
|
|||||||
/** @type {import('jest').Config} */
|
|
||||||
const config = {
|
|
||||||
testEnvironment: "jsdom",
|
|
||||||
setupFiles: ["./globalSetup.js"],
|
|
||||||
setupFilesAfterEnv: ["./afterSetup.js"],
|
|
||||||
clearMocks: true,
|
|
||||||
resetModules: true,
|
|
||||||
testTimeout: 10000
|
|
||||||
};
|
|
||||||
|
|
||||||
module.exports = config;
|
|
||||||
5586
tests-ui/package-lock.json
generated
5586
tests-ui/package-lock.json
generated
File diff suppressed because it is too large
Load Diff
@@ -1,31 +0,0 @@
|
|||||||
{
|
|
||||||
"name": "comfui-tests",
|
|
||||||
"version": "1.0.0",
|
|
||||||
"description": "UI tests",
|
|
||||||
"main": "index.js",
|
|
||||||
"scripts": {
|
|
||||||
"test": "jest",
|
|
||||||
"test:generate": "node setup.js"
|
|
||||||
},
|
|
||||||
"repository": {
|
|
||||||
"type": "git",
|
|
||||||
"url": "git+https://github.com/comfyanonymous/ComfyUI.git"
|
|
||||||
},
|
|
||||||
"keywords": [
|
|
||||||
"comfyui",
|
|
||||||
"test"
|
|
||||||
],
|
|
||||||
"author": "comfyanonymous",
|
|
||||||
"license": "GPL-3.0",
|
|
||||||
"bugs": {
|
|
||||||
"url": "https://github.com/comfyanonymous/ComfyUI/issues"
|
|
||||||
},
|
|
||||||
"homepage": "https://github.com/comfyanonymous/ComfyUI#readme",
|
|
||||||
"devDependencies": {
|
|
||||||
"@babel/preset-env": "^7.22.20",
|
|
||||||
"@types/jest": "^29.5.5",
|
|
||||||
"babel-plugin-transform-import-meta": "^2.2.1",
|
|
||||||
"jest": "^29.7.0",
|
|
||||||
"jest-environment-jsdom": "^29.7.0"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,88 +0,0 @@
|
|||||||
const { spawn } = require("child_process");
|
|
||||||
const { resolve } = require("path");
|
|
||||||
const { existsSync, mkdirSync, writeFileSync } = require("fs");
|
|
||||||
const http = require("http");
|
|
||||||
|
|
||||||
async function setup() {
|
|
||||||
// Wait up to 30s for it to start
|
|
||||||
let success = false;
|
|
||||||
let child;
|
|
||||||
for (let i = 0; i < 30; i++) {
|
|
||||||
try {
|
|
||||||
await new Promise((res, rej) => {
|
|
||||||
http
|
|
||||||
.get("http://127.0.0.1:8188/object_info", (resp) => {
|
|
||||||
let data = "";
|
|
||||||
resp.on("data", (chunk) => {
|
|
||||||
data += chunk;
|
|
||||||
});
|
|
||||||
resp.on("end", () => {
|
|
||||||
// Modify the response data to add some checkpoints
|
|
||||||
const objectInfo = JSON.parse(data);
|
|
||||||
objectInfo.CheckpointLoaderSimple.input.required.ckpt_name[0] = ["model1.safetensors", "model2.ckpt"];
|
|
||||||
objectInfo.VAELoader.input.required.vae_name[0] = ["vae1.safetensors", "vae2.ckpt"];
|
|
||||||
|
|
||||||
data = JSON.stringify(objectInfo, undefined, "\t");
|
|
||||||
|
|
||||||
const outDir = resolve("./data");
|
|
||||||
if (!existsSync(outDir)) {
|
|
||||||
mkdirSync(outDir);
|
|
||||||
}
|
|
||||||
|
|
||||||
const outPath = resolve(outDir, "object_info.json");
|
|
||||||
console.log(`Writing ${Object.keys(objectInfo).length} nodes to ${outPath}`);
|
|
||||||
writeFileSync(outPath, data, {
|
|
||||||
encoding: "utf8",
|
|
||||||
});
|
|
||||||
res();
|
|
||||||
});
|
|
||||||
})
|
|
||||||
.on("error", rej);
|
|
||||||
});
|
|
||||||
success = true;
|
|
||||||
break;
|
|
||||||
} catch (error) {
|
|
||||||
console.log(i + "/30", error);
|
|
||||||
if (i === 0) {
|
|
||||||
// Start the server on first iteration if it fails to connect
|
|
||||||
console.log("Starting ComfyUI server...");
|
|
||||||
|
|
||||||
let python = resolve("../../python_embeded/python.exe");
|
|
||||||
let args;
|
|
||||||
let cwd;
|
|
||||||
if (existsSync(python)) {
|
|
||||||
args = ["-s", "ComfyUI/main.py"];
|
|
||||||
cwd = "../..";
|
|
||||||
} else {
|
|
||||||
python = "python";
|
|
||||||
args = ["main.py"];
|
|
||||||
cwd = "..";
|
|
||||||
}
|
|
||||||
args.push("--cpu");
|
|
||||||
console.log(python, ...args);
|
|
||||||
child = spawn(python, args, { cwd });
|
|
||||||
child.on("error", (err) => {
|
|
||||||
console.log(`Server error (${err})`);
|
|
||||||
i = 30;
|
|
||||||
});
|
|
||||||
child.on("exit", (code) => {
|
|
||||||
if (!success) {
|
|
||||||
console.log(`Server exited (${code})`);
|
|
||||||
i = 30;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
await new Promise((r) => {
|
|
||||||
setTimeout(r, 1000);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
child?.kill();
|
|
||||||
|
|
||||||
if (!success) {
|
|
||||||
throw new Error("Waiting for server failed...");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
setup();
|
|
||||||
@@ -1,196 +0,0 @@
|
|||||||
// @ts-check
|
|
||||||
/// <reference path="../node_modules/@types/jest/index.d.ts" />
|
|
||||||
const { start } = require("../utils");
|
|
||||||
const lg = require("../utils/litegraph");
|
|
||||||
|
|
||||||
describe("extensions", () => {
|
|
||||||
beforeEach(() => {
|
|
||||||
lg.setup(global);
|
|
||||||
});
|
|
||||||
|
|
||||||
afterEach(() => {
|
|
||||||
lg.teardown(global);
|
|
||||||
});
|
|
||||||
|
|
||||||
it("calls each extension hook", async () => {
|
|
||||||
const mockExtension = {
|
|
||||||
name: "TestExtension",
|
|
||||||
init: jest.fn(),
|
|
||||||
setup: jest.fn(),
|
|
||||||
addCustomNodeDefs: jest.fn(),
|
|
||||||
getCustomWidgets: jest.fn(),
|
|
||||||
beforeRegisterNodeDef: jest.fn(),
|
|
||||||
registerCustomNodes: jest.fn(),
|
|
||||||
loadedGraphNode: jest.fn(),
|
|
||||||
nodeCreated: jest.fn(),
|
|
||||||
beforeConfigureGraph: jest.fn(),
|
|
||||||
afterConfigureGraph: jest.fn(),
|
|
||||||
};
|
|
||||||
|
|
||||||
const { app, ez, graph } = await start({
|
|
||||||
async preSetup(app) {
|
|
||||||
app.registerExtension(mockExtension);
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
// Basic initialisation hooks should be called once, with app
|
|
||||||
expect(mockExtension.init).toHaveBeenCalledTimes(1);
|
|
||||||
expect(mockExtension.init).toHaveBeenCalledWith(app);
|
|
||||||
|
|
||||||
// Adding custom node defs should be passed the full list of nodes
|
|
||||||
expect(mockExtension.addCustomNodeDefs).toHaveBeenCalledTimes(1);
|
|
||||||
expect(mockExtension.addCustomNodeDefs.mock.calls[0][1]).toStrictEqual(app);
|
|
||||||
const defs = mockExtension.addCustomNodeDefs.mock.calls[0][0];
|
|
||||||
expect(defs).toHaveProperty("KSampler");
|
|
||||||
expect(defs).toHaveProperty("LoadImage");
|
|
||||||
|
|
||||||
// Get custom widgets is called once and should return new widget types
|
|
||||||
expect(mockExtension.getCustomWidgets).toHaveBeenCalledTimes(1);
|
|
||||||
expect(mockExtension.getCustomWidgets).toHaveBeenCalledWith(app);
|
|
||||||
|
|
||||||
// Before register node def will be called once per node type
|
|
||||||
const nodeNames = Object.keys(defs);
|
|
||||||
const nodeCount = nodeNames.length;
|
|
||||||
expect(mockExtension.beforeRegisterNodeDef).toHaveBeenCalledTimes(nodeCount);
|
|
||||||
for (let i = 0; i < 10; i++) {
|
|
||||||
// It should be send the JS class and the original JSON definition
|
|
||||||
const nodeClass = mockExtension.beforeRegisterNodeDef.mock.calls[i][0];
|
|
||||||
const nodeDef = mockExtension.beforeRegisterNodeDef.mock.calls[i][1];
|
|
||||||
|
|
||||||
expect(nodeClass.name).toBe("ComfyNode");
|
|
||||||
expect(nodeClass.comfyClass).toBe(nodeNames[i]);
|
|
||||||
expect(nodeDef.name).toBe(nodeNames[i]);
|
|
||||||
expect(nodeDef).toHaveProperty("input");
|
|
||||||
expect(nodeDef).toHaveProperty("output");
|
|
||||||
}
|
|
||||||
|
|
||||||
// Register custom nodes is called once after registerNode defs to allow adding other frontend nodes
|
|
||||||
expect(mockExtension.registerCustomNodes).toHaveBeenCalledTimes(1);
|
|
||||||
|
|
||||||
// Before configure graph will be called here as the default graph is being loaded
|
|
||||||
expect(mockExtension.beforeConfigureGraph).toHaveBeenCalledTimes(1);
|
|
||||||
// it gets sent the graph data that is going to be loaded
|
|
||||||
const graphData = mockExtension.beforeConfigureGraph.mock.calls[0][0];
|
|
||||||
|
|
||||||
// A node created is fired for each node constructor that is called
|
|
||||||
expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(graphData.nodes.length);
|
|
||||||
for (let i = 0; i < graphData.nodes.length; i++) {
|
|
||||||
expect(mockExtension.nodeCreated.mock.calls[i][0].type).toBe(graphData.nodes[i].type);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Each node then calls loadedGraphNode to allow them to be updated
|
|
||||||
expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(graphData.nodes.length);
|
|
||||||
for (let i = 0; i < graphData.nodes.length; i++) {
|
|
||||||
expect(mockExtension.loadedGraphNode.mock.calls[i][0].type).toBe(graphData.nodes[i].type);
|
|
||||||
}
|
|
||||||
|
|
||||||
// After configure is then called once all the setup is done
|
|
||||||
expect(mockExtension.afterConfigureGraph).toHaveBeenCalledTimes(1);
|
|
||||||
|
|
||||||
expect(mockExtension.setup).toHaveBeenCalledTimes(1);
|
|
||||||
expect(mockExtension.setup).toHaveBeenCalledWith(app);
|
|
||||||
|
|
||||||
// Ensure hooks are called in the correct order
|
|
||||||
const callOrder = [
|
|
||||||
"init",
|
|
||||||
"addCustomNodeDefs",
|
|
||||||
"getCustomWidgets",
|
|
||||||
"beforeRegisterNodeDef",
|
|
||||||
"registerCustomNodes",
|
|
||||||
"beforeConfigureGraph",
|
|
||||||
"nodeCreated",
|
|
||||||
"loadedGraphNode",
|
|
||||||
"afterConfigureGraph",
|
|
||||||
"setup",
|
|
||||||
];
|
|
||||||
for (let i = 1; i < callOrder.length; i++) {
|
|
||||||
const fn1 = mockExtension[callOrder[i - 1]];
|
|
||||||
const fn2 = mockExtension[callOrder[i]];
|
|
||||||
expect(fn1.mock.invocationCallOrder[0]).toBeLessThan(fn2.mock.invocationCallOrder[0]);
|
|
||||||
}
|
|
||||||
|
|
||||||
graph.clear();
|
|
||||||
|
|
||||||
// Ensure adding a new node calls the correct callback
|
|
||||||
ez.LoadImage();
|
|
||||||
expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(graphData.nodes.length);
|
|
||||||
expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(graphData.nodes.length + 1);
|
|
||||||
expect(mockExtension.nodeCreated.mock.lastCall[0].type).toBe("LoadImage");
|
|
||||||
|
|
||||||
// Reload the graph to ensure correct hooks are fired
|
|
||||||
await graph.reload();
|
|
||||||
|
|
||||||
// These hooks should not be fired again
|
|
||||||
expect(mockExtension.init).toHaveBeenCalledTimes(1);
|
|
||||||
expect(mockExtension.addCustomNodeDefs).toHaveBeenCalledTimes(1);
|
|
||||||
expect(mockExtension.getCustomWidgets).toHaveBeenCalledTimes(1);
|
|
||||||
expect(mockExtension.registerCustomNodes).toHaveBeenCalledTimes(1);
|
|
||||||
expect(mockExtension.beforeRegisterNodeDef).toHaveBeenCalledTimes(nodeCount);
|
|
||||||
expect(mockExtension.setup).toHaveBeenCalledTimes(1);
|
|
||||||
|
|
||||||
// These should be called again
|
|
||||||
expect(mockExtension.beforeConfigureGraph).toHaveBeenCalledTimes(2);
|
|
||||||
expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(graphData.nodes.length + 2);
|
|
||||||
expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(graphData.nodes.length + 1);
|
|
||||||
expect(mockExtension.afterConfigureGraph).toHaveBeenCalledTimes(2);
|
|
||||||
}, 15000);
|
|
||||||
|
|
||||||
it("allows custom nodeDefs and widgets to be registered", async () => {
|
|
||||||
const widgetMock = jest.fn((node, inputName, inputData, app) => {
|
|
||||||
expect(node.constructor.comfyClass).toBe("TestNode");
|
|
||||||
expect(inputName).toBe("test_input");
|
|
||||||
expect(inputData[0]).toBe("CUSTOMWIDGET");
|
|
||||||
expect(inputData[1]?.hello).toBe("world");
|
|
||||||
expect(app).toStrictEqual(app);
|
|
||||||
|
|
||||||
return {
|
|
||||||
widget: node.addWidget("button", inputName, "hello", () => {}),
|
|
||||||
};
|
|
||||||
});
|
|
||||||
|
|
||||||
// Register our extension that adds a custom node + widget type
|
|
||||||
const mockExtension = {
|
|
||||||
name: "TestExtension",
|
|
||||||
addCustomNodeDefs: (nodeDefs) => {
|
|
||||||
nodeDefs["TestNode"] = {
|
|
||||||
output: [],
|
|
||||||
output_name: [],
|
|
||||||
output_is_list: [],
|
|
||||||
name: "TestNode",
|
|
||||||
display_name: "TestNode",
|
|
||||||
category: "Test",
|
|
||||||
input: {
|
|
||||||
required: {
|
|
||||||
test_input: ["CUSTOMWIDGET", { hello: "world" }],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
};
|
|
||||||
},
|
|
||||||
getCustomWidgets: jest.fn(() => {
|
|
||||||
return {
|
|
||||||
CUSTOMWIDGET: widgetMock,
|
|
||||||
};
|
|
||||||
}),
|
|
||||||
};
|
|
||||||
|
|
||||||
const { graph, ez } = await start({
|
|
||||||
async preSetup(app) {
|
|
||||||
app.registerExtension(mockExtension);
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
expect(mockExtension.getCustomWidgets).toBeCalledTimes(1);
|
|
||||||
|
|
||||||
graph.clear();
|
|
||||||
expect(widgetMock).toBeCalledTimes(0);
|
|
||||||
const node = ez.TestNode();
|
|
||||||
expect(widgetMock).toBeCalledTimes(1);
|
|
||||||
|
|
||||||
// Ensure our custom widget is created
|
|
||||||
expect(node.inputs.length).toBe(0);
|
|
||||||
expect(node.widgets.length).toBe(1);
|
|
||||||
const w = node.widgets[0].widget;
|
|
||||||
expect(w.name).toBe("test_input");
|
|
||||||
expect(w.type).toBe("button");
|
|
||||||
});
|
|
||||||
});
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,295 +0,0 @@
|
|||||||
// @ts-check
|
|
||||||
/// <reference path="../node_modules/@types/jest/index.d.ts" />
|
|
||||||
const { start } = require("../utils");
|
|
||||||
const lg = require("../utils/litegraph");
|
|
||||||
|
|
||||||
describe("users", () => {
|
|
||||||
beforeEach(() => {
|
|
||||||
lg.setup(global);
|
|
||||||
});
|
|
||||||
|
|
||||||
afterEach(() => {
|
|
||||||
lg.teardown(global);
|
|
||||||
});
|
|
||||||
|
|
||||||
function expectNoUserScreen() {
|
|
||||||
// Ensure login isnt visible
|
|
||||||
const selection = document.querySelectorAll("#comfy-user-selection")?.[0];
|
|
||||||
expect(selection["style"].display).toBe("none");
|
|
||||||
const menu = document.querySelectorAll(".comfy-menu")?.[0];
|
|
||||||
expect(window.getComputedStyle(menu)?.display).not.toBe("none");
|
|
||||||
}
|
|
||||||
|
|
||||||
describe("multi-user", () => {
|
|
||||||
function mockAddStylesheet() {
|
|
||||||
const utils = require("../../web/scripts/utils");
|
|
||||||
utils.addStylesheet = jest.fn().mockReturnValue(Promise.resolve());
|
|
||||||
}
|
|
||||||
|
|
||||||
async function waitForUserScreenShow() {
|
|
||||||
mockAddStylesheet();
|
|
||||||
|
|
||||||
// Wait for "show" to be called
|
|
||||||
const { UserSelectionScreen } = require("../../web/scripts/ui/userSelection");
|
|
||||||
let resolve, reject;
|
|
||||||
const fn = UserSelectionScreen.prototype.show;
|
|
||||||
const p = new Promise((res, rej) => {
|
|
||||||
resolve = res;
|
|
||||||
reject = rej;
|
|
||||||
});
|
|
||||||
jest.spyOn(UserSelectionScreen.prototype, "show").mockImplementation(async (...args) => {
|
|
||||||
const res = fn(...args);
|
|
||||||
await new Promise(process.nextTick); // wait for promises to resolve
|
|
||||||
resolve();
|
|
||||||
return res;
|
|
||||||
});
|
|
||||||
// @ts-ignore
|
|
||||||
setTimeout(() => reject("timeout waiting for UserSelectionScreen to be shown."), 500);
|
|
||||||
await p;
|
|
||||||
await new Promise(process.nextTick); // wait for promises to resolve
|
|
||||||
}
|
|
||||||
|
|
||||||
async function testUserScreen(onShown, users) {
|
|
||||||
if (!users) {
|
|
||||||
users = {};
|
|
||||||
}
|
|
||||||
const starting = start({
|
|
||||||
resetEnv: true,
|
|
||||||
userConfig: { storage: "server", users },
|
|
||||||
});
|
|
||||||
|
|
||||||
// Ensure no current user
|
|
||||||
expect(localStorage["Comfy.userId"]).toBeFalsy();
|
|
||||||
expect(localStorage["Comfy.userName"]).toBeFalsy();
|
|
||||||
|
|
||||||
await waitForUserScreenShow();
|
|
||||||
|
|
||||||
const selection = document.querySelectorAll("#comfy-user-selection")?.[0];
|
|
||||||
expect(selection).toBeTruthy();
|
|
||||||
|
|
||||||
// Ensure login is visible
|
|
||||||
expect(window.getComputedStyle(selection)?.display).not.toBe("none");
|
|
||||||
// Ensure menu is hidden
|
|
||||||
const menu = document.querySelectorAll(".comfy-menu")?.[0];
|
|
||||||
expect(window.getComputedStyle(menu)?.display).toBe("none");
|
|
||||||
|
|
||||||
const isCreate = await onShown(selection);
|
|
||||||
|
|
||||||
// Submit form
|
|
||||||
selection.querySelectorAll("form")[0].submit();
|
|
||||||
await new Promise(process.nextTick); // wait for promises to resolve
|
|
||||||
|
|
||||||
// Wait for start
|
|
||||||
const s = await starting;
|
|
||||||
|
|
||||||
// Ensure login is removed
|
|
||||||
expect(document.querySelectorAll("#comfy-user-selection")).toHaveLength(0);
|
|
||||||
expect(window.getComputedStyle(menu)?.display).not.toBe("none");
|
|
||||||
|
|
||||||
// Ensure settings + templates are saved
|
|
||||||
const { api } = require("../../web/scripts/api");
|
|
||||||
expect(api.createUser).toHaveBeenCalledTimes(+isCreate);
|
|
||||||
expect(api.storeSettings).toHaveBeenCalledTimes(+isCreate);
|
|
||||||
expect(api.storeUserData).toHaveBeenCalledTimes(+isCreate);
|
|
||||||
if (isCreate) {
|
|
||||||
expect(api.storeUserData).toHaveBeenCalledWith("comfy.templates.json", null, { stringify: false });
|
|
||||||
expect(s.app.isNewUserSession).toBeTruthy();
|
|
||||||
} else {
|
|
||||||
expect(s.app.isNewUserSession).toBeFalsy();
|
|
||||||
}
|
|
||||||
|
|
||||||
return { users, selection, ...s };
|
|
||||||
}
|
|
||||||
|
|
||||||
it("allows user creation if no users", async () => {
|
|
||||||
const { users } = await testUserScreen((selection) => {
|
|
||||||
// Ensure we have no users flag added
|
|
||||||
expect(selection.classList.contains("no-users")).toBeTruthy();
|
|
||||||
|
|
||||||
// Enter a username
|
|
||||||
const input = selection.getElementsByTagName("input")[0];
|
|
||||||
input.focus();
|
|
||||||
input.value = "Test User";
|
|
||||||
|
|
||||||
return true;
|
|
||||||
});
|
|
||||||
|
|
||||||
expect(users).toStrictEqual({
|
|
||||||
"Test User!": "Test User",
|
|
||||||
});
|
|
||||||
|
|
||||||
expect(localStorage["Comfy.userId"]).toBe("Test User!");
|
|
||||||
expect(localStorage["Comfy.userName"]).toBe("Test User");
|
|
||||||
});
|
|
||||||
it("allows user creation if no current user but other users", async () => {
|
|
||||||
const users = {
|
|
||||||
"Test User 2!": "Test User 2",
|
|
||||||
};
|
|
||||||
|
|
||||||
await testUserScreen((selection) => {
|
|
||||||
expect(selection.classList.contains("no-users")).toBeFalsy();
|
|
||||||
|
|
||||||
// Enter a username
|
|
||||||
const input = selection.getElementsByTagName("input")[0];
|
|
||||||
input.focus();
|
|
||||||
input.value = "Test User 3";
|
|
||||||
return true;
|
|
||||||
}, users);
|
|
||||||
|
|
||||||
expect(users).toStrictEqual({
|
|
||||||
"Test User 2!": "Test User 2",
|
|
||||||
"Test User 3!": "Test User 3",
|
|
||||||
});
|
|
||||||
|
|
||||||
expect(localStorage["Comfy.userId"]).toBe("Test User 3!");
|
|
||||||
expect(localStorage["Comfy.userName"]).toBe("Test User 3");
|
|
||||||
});
|
|
||||||
it("allows user selection if no current user but other users", async () => {
|
|
||||||
const users = {
|
|
||||||
"A!": "A",
|
|
||||||
"B!": "B",
|
|
||||||
"C!": "C",
|
|
||||||
};
|
|
||||||
|
|
||||||
await testUserScreen((selection) => {
|
|
||||||
expect(selection.classList.contains("no-users")).toBeFalsy();
|
|
||||||
|
|
||||||
// Check user list
|
|
||||||
const select = selection.getElementsByTagName("select")[0];
|
|
||||||
const options = select.getElementsByTagName("option");
|
|
||||||
expect(
|
|
||||||
[...options]
|
|
||||||
.filter((o) => !o.disabled)
|
|
||||||
.reduce((p, n) => {
|
|
||||||
p[n.getAttribute("value")] = n.textContent;
|
|
||||||
return p;
|
|
||||||
}, {})
|
|
||||||
).toStrictEqual(users);
|
|
||||||
|
|
||||||
// Select an option
|
|
||||||
select.focus();
|
|
||||||
select.value = options[2].value;
|
|
||||||
|
|
||||||
return false;
|
|
||||||
}, users);
|
|
||||||
|
|
||||||
expect(users).toStrictEqual(users);
|
|
||||||
|
|
||||||
expect(localStorage["Comfy.userId"]).toBe("B!");
|
|
||||||
expect(localStorage["Comfy.userName"]).toBe("B");
|
|
||||||
});
|
|
||||||
it("doesnt show user screen if current user", async () => {
|
|
||||||
const starting = start({
|
|
||||||
resetEnv: true,
|
|
||||||
userConfig: {
|
|
||||||
storage: "server",
|
|
||||||
users: {
|
|
||||||
"User!": "User",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
localStorage: {
|
|
||||||
"Comfy.userId": "User!",
|
|
||||||
"Comfy.userName": "User",
|
|
||||||
},
|
|
||||||
});
|
|
||||||
await new Promise(process.nextTick); // wait for promises to resolve
|
|
||||||
|
|
||||||
expectNoUserScreen();
|
|
||||||
|
|
||||||
await starting;
|
|
||||||
});
|
|
||||||
it("allows user switching", async () => {
|
|
||||||
const { app } = await start({
|
|
||||||
resetEnv: true,
|
|
||||||
userConfig: {
|
|
||||||
storage: "server",
|
|
||||||
users: {
|
|
||||||
"User!": "User",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
localStorage: {
|
|
||||||
"Comfy.userId": "User!",
|
|
||||||
"Comfy.userName": "User",
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
// cant actually test switching user easily but can check the setting is present
|
|
||||||
expect(app.ui.settings.settingsLookup["Comfy.SwitchUser"]).toBeTruthy();
|
|
||||||
});
|
|
||||||
});
|
|
||||||
describe("single-user", () => {
|
|
||||||
it("doesnt show user creation if no default user", async () => {
|
|
||||||
const { app } = await start({
|
|
||||||
resetEnv: true,
|
|
||||||
userConfig: { migrated: false, storage: "server" },
|
|
||||||
});
|
|
||||||
expectNoUserScreen();
|
|
||||||
|
|
||||||
// It should store the settings
|
|
||||||
const { api } = require("../../web/scripts/api");
|
|
||||||
expect(api.storeSettings).toHaveBeenCalledTimes(1);
|
|
||||||
expect(api.storeUserData).toHaveBeenCalledTimes(1);
|
|
||||||
expect(api.storeUserData).toHaveBeenCalledWith("comfy.templates.json", null, { stringify: false });
|
|
||||||
expect(app.isNewUserSession).toBeTruthy();
|
|
||||||
});
|
|
||||||
it("doesnt show user creation if default user", async () => {
|
|
||||||
const { app } = await start({
|
|
||||||
resetEnv: true,
|
|
||||||
userConfig: { migrated: true, storage: "server" },
|
|
||||||
});
|
|
||||||
expectNoUserScreen();
|
|
||||||
|
|
||||||
// It should store the settings
|
|
||||||
const { api } = require("../../web/scripts/api");
|
|
||||||
expect(api.storeSettings).toHaveBeenCalledTimes(0);
|
|
||||||
expect(api.storeUserData).toHaveBeenCalledTimes(0);
|
|
||||||
expect(app.isNewUserSession).toBeFalsy();
|
|
||||||
});
|
|
||||||
it("doesnt allow user switching", async () => {
|
|
||||||
const { app } = await start({
|
|
||||||
resetEnv: true,
|
|
||||||
userConfig: { migrated: true, storage: "server" },
|
|
||||||
});
|
|
||||||
expectNoUserScreen();
|
|
||||||
|
|
||||||
expect(app.ui.settings.settingsLookup["Comfy.SwitchUser"]).toBeFalsy();
|
|
||||||
});
|
|
||||||
});
|
|
||||||
describe("browser-user", () => {
|
|
||||||
it("doesnt show user creation if no default user", async () => {
|
|
||||||
const { app } = await start({
|
|
||||||
resetEnv: true,
|
|
||||||
userConfig: { migrated: false, storage: "browser" },
|
|
||||||
});
|
|
||||||
expectNoUserScreen();
|
|
||||||
|
|
||||||
// It should store the settings
|
|
||||||
const { api } = require("../../web/scripts/api");
|
|
||||||
expect(api.storeSettings).toHaveBeenCalledTimes(0);
|
|
||||||
expect(api.storeUserData).toHaveBeenCalledTimes(0);
|
|
||||||
expect(app.isNewUserSession).toBeFalsy();
|
|
||||||
});
|
|
||||||
it("doesnt show user creation if default user", async () => {
|
|
||||||
const { app } = await start({
|
|
||||||
resetEnv: true,
|
|
||||||
userConfig: { migrated: true, storage: "server" },
|
|
||||||
});
|
|
||||||
expectNoUserScreen();
|
|
||||||
|
|
||||||
// It should store the settings
|
|
||||||
const { api } = require("../../web/scripts/api");
|
|
||||||
expect(api.storeSettings).toHaveBeenCalledTimes(0);
|
|
||||||
expect(api.storeUserData).toHaveBeenCalledTimes(0);
|
|
||||||
expect(app.isNewUserSession).toBeFalsy();
|
|
||||||
});
|
|
||||||
it("doesnt allow user switching", async () => {
|
|
||||||
const { app } = await start({
|
|
||||||
resetEnv: true,
|
|
||||||
userConfig: { migrated: true, storage: "browser" },
|
|
||||||
});
|
|
||||||
expectNoUserScreen();
|
|
||||||
|
|
||||||
expect(app.ui.settings.settingsLookup["Comfy.SwitchUser"]).toBeFalsy();
|
|
||||||
});
|
|
||||||
});
|
|
||||||
});
|
|
||||||
@@ -1,557 +0,0 @@
|
|||||||
// @ts-check
|
|
||||||
/// <reference path="../node_modules/@types/jest/index.d.ts" />
|
|
||||||
|
|
||||||
const {
|
|
||||||
start,
|
|
||||||
makeNodeDef,
|
|
||||||
checkBeforeAndAfterReload,
|
|
||||||
assertNotNullOrUndefined,
|
|
||||||
createDefaultWorkflow,
|
|
||||||
} = require("../utils");
|
|
||||||
const lg = require("../utils/litegraph");
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @typedef { import("../utils/ezgraph") } Ez
|
|
||||||
* @typedef { ReturnType<Ez["Ez"]["graph"]>["ez"] } EzNodeFactory
|
|
||||||
*/
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @param { EzNodeFactory } ez
|
|
||||||
* @param { InstanceType<Ez["EzGraph"]> } graph
|
|
||||||
* @param { InstanceType<Ez["EzInput"]> } input
|
|
||||||
* @param { string } widgetType
|
|
||||||
* @param { number } controlWidgetCount
|
|
||||||
* @returns
|
|
||||||
*/
|
|
||||||
async function connectPrimitiveAndReload(ez, graph, input, widgetType, controlWidgetCount = 0) {
|
|
||||||
// Connect to primitive and ensure its still connected after
|
|
||||||
let primitive = ez.PrimitiveNode();
|
|
||||||
primitive.outputs[0].connectTo(input);
|
|
||||||
|
|
||||||
await checkBeforeAndAfterReload(graph, async () => {
|
|
||||||
primitive = graph.find(primitive);
|
|
||||||
let { connections } = primitive.outputs[0];
|
|
||||||
expect(connections).toHaveLength(1);
|
|
||||||
expect(connections[0].targetNode.id).toBe(input.node.node.id);
|
|
||||||
|
|
||||||
// Ensure widget is correct type
|
|
||||||
const valueWidget = primitive.widgets.value;
|
|
||||||
expect(valueWidget.widget.type).toBe(widgetType);
|
|
||||||
|
|
||||||
// Check if control_after_generate should be added
|
|
||||||
if (controlWidgetCount) {
|
|
||||||
const controlWidget = primitive.widgets.control_after_generate;
|
|
||||||
expect(controlWidget.widget.type).toBe("combo");
|
|
||||||
if (widgetType === "combo") {
|
|
||||||
const filterWidget = primitive.widgets.control_filter_list;
|
|
||||||
expect(filterWidget.widget.type).toBe("string");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure we dont have other widgets
|
|
||||||
expect(primitive.node.widgets).toHaveLength(1 + controlWidgetCount);
|
|
||||||
});
|
|
||||||
|
|
||||||
return primitive;
|
|
||||||
}
|
|
||||||
|
|
||||||
describe("widget inputs", () => {
|
|
||||||
beforeEach(() => {
|
|
||||||
lg.setup(global);
|
|
||||||
});
|
|
||||||
|
|
||||||
afterEach(() => {
|
|
||||||
lg.teardown(global);
|
|
||||||
});
|
|
||||||
|
|
||||||
[
|
|
||||||
{ name: "int", type: "INT", widget: "number", control: 1 },
|
|
||||||
{ name: "float", type: "FLOAT", widget: "number", control: 1 },
|
|
||||||
{ name: "text", type: "STRING" },
|
|
||||||
{
|
|
||||||
name: "customtext",
|
|
||||||
type: "STRING",
|
|
||||||
opt: { multiline: true },
|
|
||||||
},
|
|
||||||
{ name: "toggle", type: "BOOLEAN" },
|
|
||||||
{ name: "combo", type: ["a", "b", "c"], control: 2 },
|
|
||||||
].forEach((c) => {
|
|
||||||
test(`widget conversion + primitive works on ${c.name}`, async () => {
|
|
||||||
const { ez, graph } = await start({
|
|
||||||
mockNodeDefs: makeNodeDef("TestNode", { [c.name]: [c.type, c.opt ?? {}] }),
|
|
||||||
});
|
|
||||||
|
|
||||||
// Create test node and convert to input
|
|
||||||
const n = ez.TestNode();
|
|
||||||
const w = n.widgets[c.name];
|
|
||||||
w.convertToInput();
|
|
||||||
expect(w.isConvertedToInput).toBeTruthy();
|
|
||||||
const input = w.getConvertedInput();
|
|
||||||
expect(input).toBeTruthy();
|
|
||||||
|
|
||||||
// @ts-ignore : input is valid here
|
|
||||||
await connectPrimitiveAndReload(ez, graph, input, c.widget ?? c.name, c.control);
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
test("converted widget works after reload", async () => {
|
|
||||||
const { ez, graph } = await start();
|
|
||||||
let n = ez.CheckpointLoaderSimple();
|
|
||||||
|
|
||||||
const inputCount = n.inputs.length;
|
|
||||||
|
|
||||||
// Convert ckpt name to an input
|
|
||||||
n.widgets.ckpt_name.convertToInput();
|
|
||||||
expect(n.widgets.ckpt_name.isConvertedToInput).toBeTruthy();
|
|
||||||
expect(n.inputs.ckpt_name).toBeTruthy();
|
|
||||||
expect(n.inputs.length).toEqual(inputCount + 1);
|
|
||||||
|
|
||||||
// Convert back to widget and ensure input is removed
|
|
||||||
n.widgets.ckpt_name.convertToWidget();
|
|
||||||
expect(n.widgets.ckpt_name.isConvertedToInput).toBeFalsy();
|
|
||||||
expect(n.inputs.ckpt_name).toBeFalsy();
|
|
||||||
expect(n.inputs.length).toEqual(inputCount);
|
|
||||||
|
|
||||||
// Convert again and reload the graph to ensure it maintains state
|
|
||||||
n.widgets.ckpt_name.convertToInput();
|
|
||||||
expect(n.inputs.length).toEqual(inputCount + 1);
|
|
||||||
|
|
||||||
const primitive = await connectPrimitiveAndReload(ez, graph, n.inputs.ckpt_name, "combo", 2);
|
|
||||||
|
|
||||||
// Disconnect & reconnect
|
|
||||||
primitive.outputs[0].connections[0].disconnect();
|
|
||||||
let { connections } = primitive.outputs[0];
|
|
||||||
expect(connections).toHaveLength(0);
|
|
||||||
|
|
||||||
primitive.outputs[0].connectTo(n.inputs.ckpt_name);
|
|
||||||
({ connections } = primitive.outputs[0]);
|
|
||||||
expect(connections).toHaveLength(1);
|
|
||||||
expect(connections[0].targetNode.id).toBe(n.node.id);
|
|
||||||
|
|
||||||
// Convert back to widget and ensure input is removed
|
|
||||||
n.widgets.ckpt_name.convertToWidget();
|
|
||||||
expect(n.widgets.ckpt_name.isConvertedToInput).toBeFalsy();
|
|
||||||
expect(n.inputs.ckpt_name).toBeFalsy();
|
|
||||||
expect(n.inputs.length).toEqual(inputCount);
|
|
||||||
});
|
|
||||||
|
|
||||||
test("converted widget works on clone", async () => {
|
|
||||||
const { graph, ez } = await start();
|
|
||||||
let n = ez.CheckpointLoaderSimple();
|
|
||||||
|
|
||||||
// Convert the widget to an input
|
|
||||||
n.widgets.ckpt_name.convertToInput();
|
|
||||||
expect(n.widgets.ckpt_name.isConvertedToInput).toBeTruthy();
|
|
||||||
|
|
||||||
// Clone the node
|
|
||||||
n.menu["Clone"].call();
|
|
||||||
expect(graph.nodes).toHaveLength(2);
|
|
||||||
const clone = graph.nodes[1];
|
|
||||||
expect(clone.id).not.toEqual(n.id);
|
|
||||||
|
|
||||||
// Ensure the clone has an input
|
|
||||||
expect(clone.widgets.ckpt_name.isConvertedToInput).toBeTruthy();
|
|
||||||
expect(clone.inputs.ckpt_name).toBeTruthy();
|
|
||||||
|
|
||||||
// Ensure primitive connects to both nodes
|
|
||||||
let primitive = ez.PrimitiveNode();
|
|
||||||
primitive.outputs[0].connectTo(n.inputs.ckpt_name);
|
|
||||||
primitive.outputs[0].connectTo(clone.inputs.ckpt_name);
|
|
||||||
expect(primitive.outputs[0].connections).toHaveLength(2);
|
|
||||||
|
|
||||||
// Convert back to widget and ensure input is removed
|
|
||||||
clone.widgets.ckpt_name.convertToWidget();
|
|
||||||
expect(clone.widgets.ckpt_name.isConvertedToInput).toBeFalsy();
|
|
||||||
expect(clone.inputs.ckpt_name).toBeFalsy();
|
|
||||||
});
|
|
||||||
|
|
||||||
test("shows missing node error on custom node with converted input", async () => {
|
|
||||||
const { graph } = await start();
|
|
||||||
|
|
||||||
const dialogShow = jest.spyOn(graph.app.ui.dialog, "show");
|
|
||||||
|
|
||||||
await graph.app.loadGraphData({
|
|
||||||
last_node_id: 3,
|
|
||||||
last_link_id: 4,
|
|
||||||
nodes: [
|
|
||||||
{
|
|
||||||
id: 1,
|
|
||||||
type: "TestNode",
|
|
||||||
pos: [41.87329101561909, 389.7381480823742],
|
|
||||||
size: { 0: 220, 1: 374 },
|
|
||||||
flags: {},
|
|
||||||
order: 1,
|
|
||||||
mode: 0,
|
|
||||||
inputs: [{ name: "test", type: "FLOAT", link: 4, widget: { name: "test" }, slot_index: 0 }],
|
|
||||||
outputs: [],
|
|
||||||
properties: { "Node name for S&R": "TestNode" },
|
|
||||||
widgets_values: [1],
|
|
||||||
},
|
|
||||||
{
|
|
||||||
id: 3,
|
|
||||||
type: "PrimitiveNode",
|
|
||||||
pos: [-312, 433],
|
|
||||||
size: { 0: 210, 1: 82 },
|
|
||||||
flags: {},
|
|
||||||
order: 0,
|
|
||||||
mode: 0,
|
|
||||||
outputs: [{ links: [4], widget: { name: "test" } }],
|
|
||||||
title: "test",
|
|
||||||
properties: {},
|
|
||||||
},
|
|
||||||
],
|
|
||||||
links: [[4, 3, 0, 1, 6, "FLOAT"]],
|
|
||||||
groups: [],
|
|
||||||
config: {},
|
|
||||||
extra: {},
|
|
||||||
version: 0.4,
|
|
||||||
});
|
|
||||||
|
|
||||||
expect(dialogShow).toBeCalledTimes(1);
|
|
||||||
expect(dialogShow.mock.calls[0][0].innerHTML).toContain("the following node types were not found");
|
|
||||||
expect(dialogShow.mock.calls[0][0].innerHTML).toContain("TestNode");
|
|
||||||
});
|
|
||||||
|
|
||||||
test("defaultInput widgets can be converted back to inputs", async () => {
|
|
||||||
const { graph, ez } = await start({
|
|
||||||
mockNodeDefs: makeNodeDef("TestNode", { example: ["INT", { defaultInput: true }] }),
|
|
||||||
});
|
|
||||||
|
|
||||||
// Create test node and ensure it starts as an input
|
|
||||||
let n = ez.TestNode();
|
|
||||||
let w = n.widgets.example;
|
|
||||||
expect(w.isConvertedToInput).toBeTruthy();
|
|
||||||
let input = w.getConvertedInput();
|
|
||||||
expect(input).toBeTruthy();
|
|
||||||
|
|
||||||
// Ensure it can be converted to
|
|
||||||
w.convertToWidget();
|
|
||||||
expect(w.isConvertedToInput).toBeFalsy();
|
|
||||||
expect(n.inputs.length).toEqual(0);
|
|
||||||
// and from
|
|
||||||
w.convertToInput();
|
|
||||||
expect(w.isConvertedToInput).toBeTruthy();
|
|
||||||
input = w.getConvertedInput();
|
|
||||||
|
|
||||||
// Reload and ensure it still only has 1 converted widget
|
|
||||||
if (!assertNotNullOrUndefined(input)) return;
|
|
||||||
|
|
||||||
await connectPrimitiveAndReload(ez, graph, input, "number", 1);
|
|
||||||
n = graph.find(n);
|
|
||||||
expect(n.widgets).toHaveLength(1);
|
|
||||||
w = n.widgets.example;
|
|
||||||
expect(w.isConvertedToInput).toBeTruthy();
|
|
||||||
|
|
||||||
// Convert back to widget and ensure it is still a widget after reload
|
|
||||||
w.convertToWidget();
|
|
||||||
await graph.reload();
|
|
||||||
n = graph.find(n);
|
|
||||||
expect(n.widgets).toHaveLength(1);
|
|
||||||
expect(n.widgets[0].isConvertedToInput).toBeFalsy();
|
|
||||||
expect(n.inputs.length).toEqual(0);
|
|
||||||
});
|
|
||||||
|
|
||||||
test("forceInput widgets can not be converted back to inputs", async () => {
|
|
||||||
const { graph, ez } = await start({
|
|
||||||
mockNodeDefs: makeNodeDef("TestNode", { example: ["INT", { forceInput: true }] }),
|
|
||||||
});
|
|
||||||
|
|
||||||
// Create test node and ensure it starts as an input
|
|
||||||
let n = ez.TestNode();
|
|
||||||
let w = n.widgets.example;
|
|
||||||
expect(w.isConvertedToInput).toBeTruthy();
|
|
||||||
const input = w.getConvertedInput();
|
|
||||||
expect(input).toBeTruthy();
|
|
||||||
|
|
||||||
// Convert to widget should error
|
|
||||||
expect(() => w.convertToWidget()).toThrow();
|
|
||||||
|
|
||||||
// Reload and ensure it still only has 1 converted widget
|
|
||||||
if (assertNotNullOrUndefined(input)) {
|
|
||||||
await connectPrimitiveAndReload(ez, graph, input, "number", 1);
|
|
||||||
n = graph.find(n);
|
|
||||||
expect(n.widgets).toHaveLength(1);
|
|
||||||
expect(n.widgets.example.isConvertedToInput).toBeTruthy();
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
test("primitive can connect to matching combos on converted widgets", async () => {
|
|
||||||
const { ez } = await start({
|
|
||||||
mockNodeDefs: {
|
|
||||||
...makeNodeDef("TestNode1", { example: [["A", "B", "C"], { forceInput: true }] }),
|
|
||||||
...makeNodeDef("TestNode2", { example: [["A", "B", "C"], { forceInput: true }] }),
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
const n1 = ez.TestNode1();
|
|
||||||
const n2 = ez.TestNode2();
|
|
||||||
const p = ez.PrimitiveNode();
|
|
||||||
p.outputs[0].connectTo(n1.inputs[0]);
|
|
||||||
p.outputs[0].connectTo(n2.inputs[0]);
|
|
||||||
expect(p.outputs[0].connections).toHaveLength(2);
|
|
||||||
const valueWidget = p.widgets.value;
|
|
||||||
expect(valueWidget.widget.type).toBe("combo");
|
|
||||||
expect(valueWidget.widget.options.values).toEqual(["A", "B", "C"]);
|
|
||||||
});
|
|
||||||
|
|
||||||
test("primitive can not connect to non matching combos on converted widgets", async () => {
|
|
||||||
const { ez } = await start({
|
|
||||||
mockNodeDefs: {
|
|
||||||
...makeNodeDef("TestNode1", { example: [["A", "B", "C"], { forceInput: true }] }),
|
|
||||||
...makeNodeDef("TestNode2", { example: [["A", "B"], { forceInput: true }] }),
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
const n1 = ez.TestNode1();
|
|
||||||
const n2 = ez.TestNode2();
|
|
||||||
const p = ez.PrimitiveNode();
|
|
||||||
p.outputs[0].connectTo(n1.inputs[0]);
|
|
||||||
expect(() => p.outputs[0].connectTo(n2.inputs[0])).toThrow();
|
|
||||||
expect(p.outputs[0].connections).toHaveLength(1);
|
|
||||||
});
|
|
||||||
|
|
||||||
test("combo output can not connect to non matching combos list input", async () => {
|
|
||||||
const { ez } = await start({
|
|
||||||
mockNodeDefs: {
|
|
||||||
...makeNodeDef("TestNode1", {}, [["A", "B"]]),
|
|
||||||
...makeNodeDef("TestNode2", { example: [["A", "B"], { forceInput: true }] }),
|
|
||||||
...makeNodeDef("TestNode3", { example: [["A", "B", "C"], { forceInput: true }] }),
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
const n1 = ez.TestNode1();
|
|
||||||
const n2 = ez.TestNode2();
|
|
||||||
const n3 = ez.TestNode3();
|
|
||||||
|
|
||||||
n1.outputs[0].connectTo(n2.inputs[0]);
|
|
||||||
expect(() => n1.outputs[0].connectTo(n3.inputs[0])).toThrow();
|
|
||||||
});
|
|
||||||
|
|
||||||
test("combo primitive can filter list when control_after_generate called", async () => {
|
|
||||||
const { ez } = await start({
|
|
||||||
mockNodeDefs: {
|
|
||||||
...makeNodeDef("TestNode1", { example: [["A", "B", "C", "D", "AA", "BB", "CC", "DD", "AAA", "BBB"], {}] }),
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
const n1 = ez.TestNode1();
|
|
||||||
n1.widgets.example.convertToInput();
|
|
||||||
const p = ez.PrimitiveNode();
|
|
||||||
p.outputs[0].connectTo(n1.inputs[0]);
|
|
||||||
|
|
||||||
const value = p.widgets.value;
|
|
||||||
const control = p.widgets.control_after_generate.widget;
|
|
||||||
const filter = p.widgets.control_filter_list;
|
|
||||||
|
|
||||||
expect(p.widgets.length).toBe(3);
|
|
||||||
control.value = "increment";
|
|
||||||
expect(value.value).toBe("A");
|
|
||||||
|
|
||||||
// Manually trigger after queue when set to increment
|
|
||||||
control["afterQueued"]();
|
|
||||||
expect(value.value).toBe("B");
|
|
||||||
|
|
||||||
// Filter to items containing D
|
|
||||||
filter.value = "D";
|
|
||||||
control["afterQueued"]();
|
|
||||||
expect(value.value).toBe("D");
|
|
||||||
control["afterQueued"]();
|
|
||||||
expect(value.value).toBe("DD");
|
|
||||||
|
|
||||||
// Check decrement
|
|
||||||
value.value = "BBB";
|
|
||||||
control.value = "decrement";
|
|
||||||
filter.value = "B";
|
|
||||||
control["afterQueued"]();
|
|
||||||
expect(value.value).toBe("BB");
|
|
||||||
control["afterQueued"]();
|
|
||||||
expect(value.value).toBe("B");
|
|
||||||
|
|
||||||
// Check regex works
|
|
||||||
value.value = "BBB";
|
|
||||||
filter.value = "/[AB]|^C$/";
|
|
||||||
control["afterQueued"]();
|
|
||||||
expect(value.value).toBe("AAA");
|
|
||||||
control["afterQueued"]();
|
|
||||||
expect(value.value).toBe("BB");
|
|
||||||
control["afterQueued"]();
|
|
||||||
expect(value.value).toBe("AA");
|
|
||||||
control["afterQueued"]();
|
|
||||||
expect(value.value).toBe("C");
|
|
||||||
control["afterQueued"]();
|
|
||||||
expect(value.value).toBe("B");
|
|
||||||
control["afterQueued"]();
|
|
||||||
expect(value.value).toBe("A");
|
|
||||||
|
|
||||||
// Check random
|
|
||||||
control.value = "randomize";
|
|
||||||
filter.value = "/D/";
|
|
||||||
for (let i = 0; i < 100; i++) {
|
|
||||||
control["afterQueued"]();
|
|
||||||
expect(value.value === "D" || value.value === "DD").toBeTruthy();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure it doesnt apply when fixed
|
|
||||||
control.value = "fixed";
|
|
||||||
value.value = "B";
|
|
||||||
filter.value = "C";
|
|
||||||
control["afterQueued"]();
|
|
||||||
expect(value.value).toBe("B");
|
|
||||||
});
|
|
||||||
|
|
||||||
describe("reroutes", () => {
|
|
||||||
async function checkOutput(graph, values) {
|
|
||||||
expect((await graph.toPrompt()).output).toStrictEqual({
|
|
||||||
1: { inputs: { ckpt_name: "model1.safetensors" }, class_type: "CheckpointLoaderSimple" },
|
|
||||||
2: { inputs: { text: "positive", clip: ["1", 1] }, class_type: "CLIPTextEncode" },
|
|
||||||
3: { inputs: { text: "negative", clip: ["1", 1] }, class_type: "CLIPTextEncode" },
|
|
||||||
4: {
|
|
||||||
inputs: { width: values.width ?? 512, height: values.height ?? 512, batch_size: values?.batch_size ?? 1 },
|
|
||||||
class_type: "EmptyLatentImage",
|
|
||||||
},
|
|
||||||
5: {
|
|
||||||
inputs: {
|
|
||||||
seed: 0,
|
|
||||||
steps: 20,
|
|
||||||
cfg: 8,
|
|
||||||
sampler_name: "euler",
|
|
||||||
scheduler: values?.scheduler ?? "normal",
|
|
||||||
denoise: 1,
|
|
||||||
model: ["1", 0],
|
|
||||||
positive: ["2", 0],
|
|
||||||
negative: ["3", 0],
|
|
||||||
latent_image: ["4", 0],
|
|
||||||
},
|
|
||||||
class_type: "KSampler",
|
|
||||||
},
|
|
||||||
6: { inputs: { samples: ["5", 0], vae: ["1", 2] }, class_type: "VAEDecode" },
|
|
||||||
7: {
|
|
||||||
inputs: { filename_prefix: values.filename_prefix ?? "ComfyUI", images: ["6", 0] },
|
|
||||||
class_type: "SaveImage",
|
|
||||||
},
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
async function waitForWidget(node) {
|
|
||||||
// widgets are created slightly after the graph is ready
|
|
||||||
// hard to find an exact hook to get these so just wait for them to be ready
|
|
||||||
for (let i = 0; i < 10; i++) {
|
|
||||||
await new Promise((r) => setTimeout(r, 10));
|
|
||||||
if (node.widgets?.value) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
it("can connect primitive via a reroute path to a widget input", async () => {
|
|
||||||
const { ez, graph } = await start();
|
|
||||||
const nodes = createDefaultWorkflow(ez, graph);
|
|
||||||
|
|
||||||
nodes.empty.widgets.width.convertToInput();
|
|
||||||
nodes.sampler.widgets.scheduler.convertToInput();
|
|
||||||
nodes.save.widgets.filename_prefix.convertToInput();
|
|
||||||
|
|
||||||
let widthReroute = ez.Reroute();
|
|
||||||
let schedulerReroute = ez.Reroute();
|
|
||||||
let fileReroute = ez.Reroute();
|
|
||||||
|
|
||||||
let widthNext = widthReroute;
|
|
||||||
let schedulerNext = schedulerReroute;
|
|
||||||
let fileNext = fileReroute;
|
|
||||||
|
|
||||||
for (let i = 0; i < 5; i++) {
|
|
||||||
let next = ez.Reroute();
|
|
||||||
widthNext.outputs[0].connectTo(next.inputs[0]);
|
|
||||||
widthNext = next;
|
|
||||||
|
|
||||||
next = ez.Reroute();
|
|
||||||
schedulerNext.outputs[0].connectTo(next.inputs[0]);
|
|
||||||
schedulerNext = next;
|
|
||||||
|
|
||||||
next = ez.Reroute();
|
|
||||||
fileNext.outputs[0].connectTo(next.inputs[0]);
|
|
||||||
fileNext = next;
|
|
||||||
}
|
|
||||||
|
|
||||||
widthNext.outputs[0].connectTo(nodes.empty.inputs.width);
|
|
||||||
schedulerNext.outputs[0].connectTo(nodes.sampler.inputs.scheduler);
|
|
||||||
fileNext.outputs[0].connectTo(nodes.save.inputs.filename_prefix);
|
|
||||||
|
|
||||||
let widthPrimitive = ez.PrimitiveNode();
|
|
||||||
let schedulerPrimitive = ez.PrimitiveNode();
|
|
||||||
let filePrimitive = ez.PrimitiveNode();
|
|
||||||
|
|
||||||
widthPrimitive.outputs[0].connectTo(widthReroute.inputs[0]);
|
|
||||||
schedulerPrimitive.outputs[0].connectTo(schedulerReroute.inputs[0]);
|
|
||||||
filePrimitive.outputs[0].connectTo(fileReroute.inputs[0]);
|
|
||||||
expect(widthPrimitive.widgets.value.value).toBe(512);
|
|
||||||
widthPrimitive.widgets.value.value = 1024;
|
|
||||||
expect(schedulerPrimitive.widgets.value.value).toBe("normal");
|
|
||||||
schedulerPrimitive.widgets.value.value = "simple";
|
|
||||||
expect(filePrimitive.widgets.value.value).toBe("ComfyUI");
|
|
||||||
filePrimitive.widgets.value.value = "ComfyTest";
|
|
||||||
|
|
||||||
await checkBeforeAndAfterReload(graph, async () => {
|
|
||||||
widthPrimitive = graph.find(widthPrimitive);
|
|
||||||
schedulerPrimitive = graph.find(schedulerPrimitive);
|
|
||||||
filePrimitive = graph.find(filePrimitive);
|
|
||||||
await waitForWidget(filePrimitive);
|
|
||||||
expect(widthPrimitive.widgets.length).toBe(2);
|
|
||||||
expect(schedulerPrimitive.widgets.length).toBe(3);
|
|
||||||
expect(filePrimitive.widgets.length).toBe(1);
|
|
||||||
|
|
||||||
await checkOutput(graph, {
|
|
||||||
width: 1024,
|
|
||||||
scheduler: "simple",
|
|
||||||
filename_prefix: "ComfyTest",
|
|
||||||
});
|
|
||||||
});
|
|
||||||
});
|
|
||||||
it("can connect primitive via a reroute path to multiple widget inputs", async () => {
|
|
||||||
const { ez, graph } = await start();
|
|
||||||
const nodes = createDefaultWorkflow(ez, graph);
|
|
||||||
|
|
||||||
nodes.empty.widgets.width.convertToInput();
|
|
||||||
nodes.empty.widgets.height.convertToInput();
|
|
||||||
nodes.empty.widgets.batch_size.convertToInput();
|
|
||||||
|
|
||||||
let reroute = ez.Reroute();
|
|
||||||
let prevReroute = reroute;
|
|
||||||
for (let i = 0; i < 5; i++) {
|
|
||||||
const next = ez.Reroute();
|
|
||||||
prevReroute.outputs[0].connectTo(next.inputs[0]);
|
|
||||||
prevReroute = next;
|
|
||||||
}
|
|
||||||
|
|
||||||
const r1 = ez.Reroute(prevReroute.outputs[0]);
|
|
||||||
const r2 = ez.Reroute(prevReroute.outputs[0]);
|
|
||||||
const r3 = ez.Reroute(r2.outputs[0]);
|
|
||||||
const r4 = ez.Reroute(r2.outputs[0]);
|
|
||||||
|
|
||||||
r1.outputs[0].connectTo(nodes.empty.inputs.width);
|
|
||||||
r3.outputs[0].connectTo(nodes.empty.inputs.height);
|
|
||||||
r4.outputs[0].connectTo(nodes.empty.inputs.batch_size);
|
|
||||||
|
|
||||||
let primitive = ez.PrimitiveNode();
|
|
||||||
primitive.outputs[0].connectTo(reroute.inputs[0]);
|
|
||||||
expect(primitive.widgets.value.value).toBe(1);
|
|
||||||
primitive.widgets.value.value = 64;
|
|
||||||
|
|
||||||
await checkBeforeAndAfterReload(graph, async (r) => {
|
|
||||||
primitive = graph.find(primitive);
|
|
||||||
await waitForWidget(primitive);
|
|
||||||
|
|
||||||
// Ensure widget configs are merged
|
|
||||||
expect(primitive.widgets.value.widget.options?.min).toBe(16); // width/height min
|
|
||||||
expect(primitive.widgets.value.widget.options?.max).toBe(4096); // batch max
|
|
||||||
expect(primitive.widgets.value.widget.options?.step).toBe(80); // width/height step * 10
|
|
||||||
|
|
||||||
await checkOutput(graph, {
|
|
||||||
width: 64,
|
|
||||||
height: 64,
|
|
||||||
batch_size: 64,
|
|
||||||
});
|
|
||||||
});
|
|
||||||
});
|
|
||||||
});
|
|
||||||
});
|
|
||||||
@@ -1,452 +0,0 @@
|
|||||||
// @ts-check
|
|
||||||
/// <reference path="../../web/types/litegraph.d.ts" />
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @typedef { import("../../web/scripts/app")["app"] } app
|
|
||||||
* @typedef { import("../../web/types/litegraph") } LG
|
|
||||||
* @typedef { import("../../web/types/litegraph").IWidget } IWidget
|
|
||||||
* @typedef { import("../../web/types/litegraph").ContextMenuItem } ContextMenuItem
|
|
||||||
* @typedef { import("../../web/types/litegraph").INodeInputSlot } INodeInputSlot
|
|
||||||
* @typedef { import("../../web/types/litegraph").INodeOutputSlot } INodeOutputSlot
|
|
||||||
* @typedef { InstanceType<LG["LGraphNode"]> & { widgets?: Array<IWidget> } } LGNode
|
|
||||||
* @typedef { (...args: EzOutput[] | [...EzOutput[], Record<string, unknown>]) => EzNode } EzNodeFactory
|
|
||||||
*/
|
|
||||||
|
|
||||||
export class EzConnection {
|
|
||||||
/** @type { app } */
|
|
||||||
app;
|
|
||||||
/** @type { InstanceType<LG["LLink"]> } */
|
|
||||||
link;
|
|
||||||
|
|
||||||
get originNode() {
|
|
||||||
return new EzNode(this.app, this.app.graph.getNodeById(this.link.origin_id));
|
|
||||||
}
|
|
||||||
|
|
||||||
get originOutput() {
|
|
||||||
return this.originNode.outputs[this.link.origin_slot];
|
|
||||||
}
|
|
||||||
|
|
||||||
get targetNode() {
|
|
||||||
return new EzNode(this.app, this.app.graph.getNodeById(this.link.target_id));
|
|
||||||
}
|
|
||||||
|
|
||||||
get targetInput() {
|
|
||||||
return this.targetNode.inputs[this.link.target_slot];
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @param { app } app
|
|
||||||
* @param { InstanceType<LG["LLink"]> } link
|
|
||||||
*/
|
|
||||||
constructor(app, link) {
|
|
||||||
this.app = app;
|
|
||||||
this.link = link;
|
|
||||||
}
|
|
||||||
|
|
||||||
disconnect() {
|
|
||||||
this.targetInput.disconnect();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
export class EzSlot {
|
|
||||||
/** @type { EzNode } */
|
|
||||||
node;
|
|
||||||
/** @type { number } */
|
|
||||||
index;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @param { EzNode } node
|
|
||||||
* @param { number } index
|
|
||||||
*/
|
|
||||||
constructor(node, index) {
|
|
||||||
this.node = node;
|
|
||||||
this.index = index;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
export class EzInput extends EzSlot {
|
|
||||||
/** @type { INodeInputSlot } */
|
|
||||||
input;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @param { EzNode } node
|
|
||||||
* @param { number } index
|
|
||||||
* @param { INodeInputSlot } input
|
|
||||||
*/
|
|
||||||
constructor(node, index, input) {
|
|
||||||
super(node, index);
|
|
||||||
this.input = input;
|
|
||||||
}
|
|
||||||
|
|
||||||
get connection() {
|
|
||||||
const link = this.node.node.inputs?.[this.index]?.link;
|
|
||||||
if (link == null) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
return new EzConnection(this.node.app, this.node.app.graph.links[link]);
|
|
||||||
}
|
|
||||||
|
|
||||||
disconnect() {
|
|
||||||
this.node.node.disconnectInput(this.index);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
export class EzOutput extends EzSlot {
|
|
||||||
/** @type { INodeOutputSlot } */
|
|
||||||
output;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @param { EzNode } node
|
|
||||||
* @param { number } index
|
|
||||||
* @param { INodeOutputSlot } output
|
|
||||||
*/
|
|
||||||
constructor(node, index, output) {
|
|
||||||
super(node, index);
|
|
||||||
this.output = output;
|
|
||||||
}
|
|
||||||
|
|
||||||
get connections() {
|
|
||||||
return (this.node.node.outputs?.[this.index]?.links ?? []).map(
|
|
||||||
(l) => new EzConnection(this.node.app, this.node.app.graph.links[l])
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @param { EzInput } input
|
|
||||||
*/
|
|
||||||
connectTo(input) {
|
|
||||||
if (!input) throw new Error("Invalid input");
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @type { LG["LLink"] | null }
|
|
||||||
*/
|
|
||||||
const link = this.node.node.connect(this.index, input.node.node, input.index);
|
|
||||||
if (!link) {
|
|
||||||
const inp = input.input;
|
|
||||||
const inName = inp.name || inp.label || inp.type;
|
|
||||||
throw new Error(
|
|
||||||
`Connecting from ${input.node.node.type}#${input.node.id}[${inName}#${input.index}] -> ${this.node.node.type}#${this.node.id}[${
|
|
||||||
this.output.name ?? this.output.type
|
|
||||||
}#${this.index}] failed.`
|
|
||||||
);
|
|
||||||
}
|
|
||||||
return link;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
export class EzNodeMenuItem {
|
|
||||||
/** @type { EzNode } */
|
|
||||||
node;
|
|
||||||
/** @type { number } */
|
|
||||||
index;
|
|
||||||
/** @type { ContextMenuItem } */
|
|
||||||
item;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @param { EzNode } node
|
|
||||||
* @param { number } index
|
|
||||||
* @param { ContextMenuItem } item
|
|
||||||
*/
|
|
||||||
constructor(node, index, item) {
|
|
||||||
this.node = node;
|
|
||||||
this.index = index;
|
|
||||||
this.item = item;
|
|
||||||
}
|
|
||||||
|
|
||||||
call(selectNode = true) {
|
|
||||||
if (!this.item?.callback) throw new Error(`Menu Item ${this.item?.content ?? "[null]"} has no callback.`);
|
|
||||||
if (selectNode) {
|
|
||||||
this.node.select();
|
|
||||||
}
|
|
||||||
return this.item.callback.call(this.node.node, undefined, undefined, undefined, undefined, this.node.node);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
export class EzWidget {
|
|
||||||
/** @type { EzNode } */
|
|
||||||
node;
|
|
||||||
/** @type { number } */
|
|
||||||
index;
|
|
||||||
/** @type { IWidget } */
|
|
||||||
widget;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @param { EzNode } node
|
|
||||||
* @param { number } index
|
|
||||||
* @param { IWidget } widget
|
|
||||||
*/
|
|
||||||
constructor(node, index, widget) {
|
|
||||||
this.node = node;
|
|
||||||
this.index = index;
|
|
||||||
this.widget = widget;
|
|
||||||
}
|
|
||||||
|
|
||||||
get value() {
|
|
||||||
return this.widget.value;
|
|
||||||
}
|
|
||||||
|
|
||||||
set value(v) {
|
|
||||||
this.widget.value = v;
|
|
||||||
this.widget.callback?.call?.(this.widget, v)
|
|
||||||
}
|
|
||||||
|
|
||||||
get isConvertedToInput() {
|
|
||||||
// @ts-ignore : this type is valid for converted widgets
|
|
||||||
return this.widget.type === "converted-widget";
|
|
||||||
}
|
|
||||||
|
|
||||||
getConvertedInput() {
|
|
||||||
if (!this.isConvertedToInput) throw new Error(`Widget ${this.widget.name} is not converted to input.`);
|
|
||||||
|
|
||||||
return this.node.inputs.find((inp) => inp.input["widget"]?.name === this.widget.name);
|
|
||||||
}
|
|
||||||
|
|
||||||
convertToWidget() {
|
|
||||||
if (!this.isConvertedToInput)
|
|
||||||
throw new Error(`Widget ${this.widget.name} cannot be converted as it is already a widget.`);
|
|
||||||
var menu = this.node.menu["Convert Input to Widget"].item.submenu.options;
|
|
||||||
var index = menu.findIndex(a => a.content == `Convert ${this.widget.name} to widget`);
|
|
||||||
menu[index].callback.call();
|
|
||||||
}
|
|
||||||
|
|
||||||
convertToInput() {
|
|
||||||
if (this.isConvertedToInput)
|
|
||||||
throw new Error(`Widget ${this.widget.name} cannot be converted as it is already an input.`);
|
|
||||||
var menu = this.node.menu["Convert Widget to Input"].item.submenu.options;
|
|
||||||
var index = menu.findIndex(a => a.content == `Convert ${this.widget.name} to input`);
|
|
||||||
menu[index].callback.call();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
export class EzNode {
|
|
||||||
/** @type { app } */
|
|
||||||
app;
|
|
||||||
/** @type { LGNode } */
|
|
||||||
node;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @param { app } app
|
|
||||||
* @param { LGNode } node
|
|
||||||
*/
|
|
||||||
constructor(app, node) {
|
|
||||||
this.app = app;
|
|
||||||
this.node = node;
|
|
||||||
}
|
|
||||||
|
|
||||||
get id() {
|
|
||||||
return this.node.id;
|
|
||||||
}
|
|
||||||
|
|
||||||
get inputs() {
|
|
||||||
return this.#makeLookupArray("inputs", "name", EzInput);
|
|
||||||
}
|
|
||||||
|
|
||||||
get outputs() {
|
|
||||||
return this.#makeLookupArray("outputs", "name", EzOutput);
|
|
||||||
}
|
|
||||||
|
|
||||||
get widgets() {
|
|
||||||
return this.#makeLookupArray("widgets", "name", EzWidget);
|
|
||||||
}
|
|
||||||
|
|
||||||
get menu() {
|
|
||||||
return this.#makeLookupArray(() => this.app.canvas.getNodeMenuOptions(this.node), "content", EzNodeMenuItem);
|
|
||||||
}
|
|
||||||
|
|
||||||
get isRemoved() {
|
|
||||||
return !this.app.graph.getNodeById(this.id);
|
|
||||||
}
|
|
||||||
|
|
||||||
select(addToSelection = false) {
|
|
||||||
this.app.canvas.selectNode(this.node, addToSelection);
|
|
||||||
}
|
|
||||||
|
|
||||||
// /**
|
|
||||||
// * @template { "inputs" | "outputs" } T
|
|
||||||
// * @param { T } type
|
|
||||||
// * @returns { Record<string, type extends "inputs" ? EzInput : EzOutput> & (type extends "inputs" ? EzInput [] : EzOutput[]) }
|
|
||||||
// */
|
|
||||||
// #getSlotItems(type) {
|
|
||||||
// // @ts-ignore : these items are correct
|
|
||||||
// return (this.node[type] ?? []).reduce((p, s, i) => {
|
|
||||||
// if (s.name in p) {
|
|
||||||
// throw new Error(`Unable to store input ${s.name} on array as name conflicts.`);
|
|
||||||
// }
|
|
||||||
// // @ts-ignore
|
|
||||||
// p.push((p[s.name] = new (type === "inputs" ? EzInput : EzOutput)(this, i, s)));
|
|
||||||
// return p;
|
|
||||||
// }, Object.assign([], { $: this }));
|
|
||||||
// }
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @template { { new(node: EzNode, index: number, obj: any): any } } T
|
|
||||||
* @param { "inputs" | "outputs" | "widgets" | (() => Array<unknown>) } nodeProperty
|
|
||||||
* @param { string } nameProperty
|
|
||||||
* @param { T } ctor
|
|
||||||
* @returns { Record<string, InstanceType<T>> & Array<InstanceType<T>> }
|
|
||||||
*/
|
|
||||||
#makeLookupArray(nodeProperty, nameProperty, ctor) {
|
|
||||||
const items = typeof nodeProperty === "function" ? nodeProperty() : this.node[nodeProperty];
|
|
||||||
// @ts-ignore
|
|
||||||
return (items ?? []).reduce((p, s, i) => {
|
|
||||||
if (!s) return p;
|
|
||||||
|
|
||||||
const name = s[nameProperty];
|
|
||||||
const item = new ctor(this, i, s);
|
|
||||||
// @ts-ignore
|
|
||||||
p.push(item);
|
|
||||||
if (name) {
|
|
||||||
// @ts-ignore
|
|
||||||
if (name in p) {
|
|
||||||
throw new Error(`Unable to store ${nodeProperty} ${name} on array as name conflicts.`);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// @ts-ignore
|
|
||||||
p[name] = item;
|
|
||||||
return p;
|
|
||||||
}, Object.assign([], { $: this }));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
export class EzGraph {
|
|
||||||
/** @type { app } */
|
|
||||||
app;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @param { app } app
|
|
||||||
*/
|
|
||||||
constructor(app) {
|
|
||||||
this.app = app;
|
|
||||||
}
|
|
||||||
|
|
||||||
get nodes() {
|
|
||||||
return this.app.graph._nodes.map((n) => new EzNode(this.app, n));
|
|
||||||
}
|
|
||||||
|
|
||||||
clear() {
|
|
||||||
this.app.graph.clear();
|
|
||||||
}
|
|
||||||
|
|
||||||
arrange() {
|
|
||||||
this.app.graph.arrange();
|
|
||||||
}
|
|
||||||
|
|
||||||
stringify() {
|
|
||||||
return JSON.stringify(this.app.graph.serialize(), undefined);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @param { number | LGNode | EzNode } obj
|
|
||||||
* @returns { EzNode }
|
|
||||||
*/
|
|
||||||
find(obj) {
|
|
||||||
let match;
|
|
||||||
let id;
|
|
||||||
if (typeof obj === "number") {
|
|
||||||
id = obj;
|
|
||||||
} else {
|
|
||||||
id = obj.id;
|
|
||||||
}
|
|
||||||
|
|
||||||
match = this.app.graph.getNodeById(id);
|
|
||||||
|
|
||||||
if (!match) {
|
|
||||||
throw new Error(`Unable to find node with ID ${id}.`);
|
|
||||||
}
|
|
||||||
|
|
||||||
return new EzNode(this.app, match);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @returns { Promise<void> }
|
|
||||||
*/
|
|
||||||
reload() {
|
|
||||||
const graph = JSON.parse(JSON.stringify(this.app.graph.serialize()));
|
|
||||||
return new Promise((r) => {
|
|
||||||
this.app.graph.clear();
|
|
||||||
setTimeout(async () => {
|
|
||||||
await this.app.loadGraphData(graph);
|
|
||||||
r();
|
|
||||||
}, 10);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @returns { Promise<{
|
|
||||||
* workflow: {},
|
|
||||||
* output: Record<string, {
|
|
||||||
* class_name: string,
|
|
||||||
* inputs: Record<string, [string, number] | unknown>
|
|
||||||
* }>}> }
|
|
||||||
*/
|
|
||||||
toPrompt() {
|
|
||||||
// @ts-ignore
|
|
||||||
return this.app.graphToPrompt();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
export const Ez = {
|
|
||||||
/**
|
|
||||||
* Quickly build and interact with a ComfyUI graph
|
|
||||||
* @example
|
|
||||||
* const { ez, graph } = Ez.graph(app);
|
|
||||||
* graph.clear();
|
|
||||||
* const [model, clip, vae] = ez.CheckpointLoaderSimple().outputs;
|
|
||||||
* const [pos] = ez.CLIPTextEncode(clip, { text: "positive" }).outputs;
|
|
||||||
* const [neg] = ez.CLIPTextEncode(clip, { text: "negative" }).outputs;
|
|
||||||
* const [latent] = ez.KSampler(model, pos, neg, ...ez.EmptyLatentImage().outputs).outputs;
|
|
||||||
* const [image] = ez.VAEDecode(latent, vae).outputs;
|
|
||||||
* const saveNode = ez.SaveImage(image);
|
|
||||||
* console.log(saveNode);
|
|
||||||
* graph.arrange();
|
|
||||||
* @param { app } app
|
|
||||||
* @param { LG["LiteGraph"] } LiteGraph
|
|
||||||
* @param { LG["LGraphCanvas"] } LGraphCanvas
|
|
||||||
* @param { boolean } clearGraph
|
|
||||||
* @returns { { graph: EzGraph, ez: Record<string, EzNodeFactory> } }
|
|
||||||
*/
|
|
||||||
graph(app, LiteGraph = window["LiteGraph"], LGraphCanvas = window["LGraphCanvas"], clearGraph = true) {
|
|
||||||
// Always set the active canvas so things work
|
|
||||||
LGraphCanvas.active_canvas = app.canvas;
|
|
||||||
|
|
||||||
if (clearGraph) {
|
|
||||||
app.graph.clear();
|
|
||||||
}
|
|
||||||
|
|
||||||
// @ts-ignore : this proxy handles utility methods & node creation
|
|
||||||
const factory = new Proxy(
|
|
||||||
{},
|
|
||||||
{
|
|
||||||
get(_, p) {
|
|
||||||
if (typeof p !== "string") throw new Error("Invalid node");
|
|
||||||
const node = LiteGraph.createNode(p);
|
|
||||||
if (!node) throw new Error(`Unknown node "${p}"`);
|
|
||||||
app.graph.add(node);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @param {Parameters<EzNodeFactory>} args
|
|
||||||
*/
|
|
||||||
return function (...args) {
|
|
||||||
const ezNode = new EzNode(app, node);
|
|
||||||
const inputs = ezNode.inputs;
|
|
||||||
|
|
||||||
let slot = 0;
|
|
||||||
for (const arg of args) {
|
|
||||||
if (arg instanceof EzOutput) {
|
|
||||||
arg.connectTo(inputs[slot++]);
|
|
||||||
} else {
|
|
||||||
for (const k in arg) {
|
|
||||||
ezNode.widgets[k].value = arg[k];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return ezNode;
|
|
||||||
};
|
|
||||||
},
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
return { graph: new EzGraph(app), ez: factory };
|
|
||||||
},
|
|
||||||
};
|
|
||||||
@@ -1,129 +0,0 @@
|
|||||||
const { mockApi } = require("./setup");
|
|
||||||
const { Ez } = require("./ezgraph");
|
|
||||||
const lg = require("./litegraph");
|
|
||||||
const fs = require("fs");
|
|
||||||
const path = require("path");
|
|
||||||
|
|
||||||
const html = fs.readFileSync(path.resolve(__dirname, "../../web/index.html"))
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @param { Parameters<typeof mockApi>[0] & {
|
|
||||||
* resetEnv?: boolean,
|
|
||||||
* preSetup?(app): Promise<void>,
|
|
||||||
* localStorage?: Record<string, string>
|
|
||||||
* } } config
|
|
||||||
* @returns
|
|
||||||
*/
|
|
||||||
export async function start(config = {}) {
|
|
||||||
if(config.resetEnv) {
|
|
||||||
jest.resetModules();
|
|
||||||
jest.resetAllMocks();
|
|
||||||
lg.setup(global);
|
|
||||||
localStorage.clear();
|
|
||||||
sessionStorage.clear();
|
|
||||||
}
|
|
||||||
|
|
||||||
Object.assign(localStorage, config.localStorage ?? {});
|
|
||||||
document.body.innerHTML = html;
|
|
||||||
|
|
||||||
mockApi(config);
|
|
||||||
const { app } = require("../../web/scripts/app");
|
|
||||||
config.preSetup?.(app);
|
|
||||||
await app.setup();
|
|
||||||
|
|
||||||
return { ...Ez.graph(app, global["LiteGraph"], global["LGraphCanvas"]), app };
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @param { ReturnType<Ez["graph"]>["graph"] } graph
|
|
||||||
* @param { (hasReloaded: boolean) => (Promise<void> | void) } cb
|
|
||||||
*/
|
|
||||||
export async function checkBeforeAndAfterReload(graph, cb) {
|
|
||||||
await cb(false);
|
|
||||||
await graph.reload();
|
|
||||||
await cb(true);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @param { string } name
|
|
||||||
* @param { Record<string, string | [string | string[], any]> } input
|
|
||||||
* @param { (string | string[])[] | Record<string, string | string[]> } output
|
|
||||||
* @returns { Record<string, import("../../web/types/comfy").ComfyObjectInfo> }
|
|
||||||
*/
|
|
||||||
export function makeNodeDef(name, input, output = {}) {
|
|
||||||
const nodeDef = {
|
|
||||||
name,
|
|
||||||
category: "test",
|
|
||||||
output: [],
|
|
||||||
output_name: [],
|
|
||||||
output_is_list: [],
|
|
||||||
input: {
|
|
||||||
required: {},
|
|
||||||
},
|
|
||||||
};
|
|
||||||
for (const k in input) {
|
|
||||||
nodeDef.input.required[k] = typeof input[k] === "string" ? [input[k], {}] : [...input[k]];
|
|
||||||
}
|
|
||||||
if (output instanceof Array) {
|
|
||||||
output = output.reduce((p, c) => {
|
|
||||||
p[c] = c;
|
|
||||||
return p;
|
|
||||||
}, {});
|
|
||||||
}
|
|
||||||
for (const k in output) {
|
|
||||||
nodeDef.output.push(output[k]);
|
|
||||||
nodeDef.output_name.push(k);
|
|
||||||
nodeDef.output_is_list.push(false);
|
|
||||||
}
|
|
||||||
|
|
||||||
return { [name]: nodeDef };
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
/**
|
|
||||||
* @template { any } T
|
|
||||||
* @param { T } x
|
|
||||||
* @returns { x is Exclude<T, null | undefined> }
|
|
||||||
*/
|
|
||||||
export function assertNotNullOrUndefined(x) {
|
|
||||||
expect(x).not.toEqual(null);
|
|
||||||
expect(x).not.toEqual(undefined);
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @param { ReturnType<Ez["graph"]>["ez"] } ez
|
|
||||||
* @param { ReturnType<Ez["graph"]>["graph"] } graph
|
|
||||||
*/
|
|
||||||
export function createDefaultWorkflow(ez, graph) {
|
|
||||||
graph.clear();
|
|
||||||
const ckpt = ez.CheckpointLoaderSimple();
|
|
||||||
|
|
||||||
const pos = ez.CLIPTextEncode(ckpt.outputs.CLIP, { text: "positive" });
|
|
||||||
const neg = ez.CLIPTextEncode(ckpt.outputs.CLIP, { text: "negative" });
|
|
||||||
|
|
||||||
const empty = ez.EmptyLatentImage();
|
|
||||||
const sampler = ez.KSampler(
|
|
||||||
ckpt.outputs.MODEL,
|
|
||||||
pos.outputs.CONDITIONING,
|
|
||||||
neg.outputs.CONDITIONING,
|
|
||||||
empty.outputs.LATENT
|
|
||||||
);
|
|
||||||
|
|
||||||
const decode = ez.VAEDecode(sampler.outputs.LATENT, ckpt.outputs.VAE);
|
|
||||||
const save = ez.SaveImage(decode.outputs.IMAGE);
|
|
||||||
graph.arrange();
|
|
||||||
|
|
||||||
return { ckpt, pos, neg, empty, sampler, decode, save };
|
|
||||||
}
|
|
||||||
|
|
||||||
export async function getNodeDefs() {
|
|
||||||
const { api } = require("../../web/scripts/api");
|
|
||||||
return api.getNodeDefs();
|
|
||||||
}
|
|
||||||
|
|
||||||
export async function getNodeDef(nodeId) {
|
|
||||||
return (await getNodeDefs())[nodeId];
|
|
||||||
}
|
|
||||||
@@ -1,36 +0,0 @@
|
|||||||
const fs = require("fs");
|
|
||||||
const path = require("path");
|
|
||||||
const { nop } = require("../utils/nopProxy");
|
|
||||||
|
|
||||||
function forEachKey(cb) {
|
|
||||||
for (const k of [
|
|
||||||
"LiteGraph",
|
|
||||||
"LGraph",
|
|
||||||
"LLink",
|
|
||||||
"LGraphNode",
|
|
||||||
"LGraphGroup",
|
|
||||||
"DragAndScale",
|
|
||||||
"LGraphCanvas",
|
|
||||||
"ContextMenu",
|
|
||||||
]) {
|
|
||||||
cb(k);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
export function setup(ctx) {
|
|
||||||
const lg = fs.readFileSync(path.resolve("../web/lib/litegraph.core.js"), "utf-8");
|
|
||||||
const globalTemp = {};
|
|
||||||
(function (console) {
|
|
||||||
eval(lg);
|
|
||||||
}).call(globalTemp, nop);
|
|
||||||
|
|
||||||
forEachKey((k) => (ctx[k] = globalTemp[k]));
|
|
||||||
require(path.resolve("../web/lib/litegraph.extensions.js"));
|
|
||||||
}
|
|
||||||
|
|
||||||
export function teardown(ctx) {
|
|
||||||
forEachKey((k) => delete ctx[k]);
|
|
||||||
|
|
||||||
// Clear document after each run
|
|
||||||
document.getElementsByTagName("html")[0].innerHTML = "";
|
|
||||||
}
|
|
||||||
@@ -1,6 +0,0 @@
|
|||||||
export const nop = new Proxy(function () {}, {
|
|
||||||
get: () => nop,
|
|
||||||
set: () => true,
|
|
||||||
apply: () => nop,
|
|
||||||
construct: () => nop,
|
|
||||||
});
|
|
||||||
@@ -1,82 +0,0 @@
|
|||||||
require("../../web/scripts/api");
|
|
||||||
|
|
||||||
const fs = require("fs");
|
|
||||||
const path = require("path");
|
|
||||||
function* walkSync(dir) {
|
|
||||||
const files = fs.readdirSync(dir, { withFileTypes: true });
|
|
||||||
for (const file of files) {
|
|
||||||
if (file.isDirectory()) {
|
|
||||||
yield* walkSync(path.join(dir, file.name));
|
|
||||||
} else {
|
|
||||||
yield path.join(dir, file.name);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @typedef { import("../../web/types/comfy").ComfyObjectInfo } ComfyObjectInfo
|
|
||||||
*/
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @param {{
|
|
||||||
* mockExtensions?: string[],
|
|
||||||
* mockNodeDefs?: Record<string, ComfyObjectInfo>,
|
|
||||||
* settings?: Record<string, string>
|
|
||||||
* userConfig?: {storage: "server" | "browser", users?: Record<string, any>, migrated?: boolean },
|
|
||||||
* userData?: Record<string, any>
|
|
||||||
* }} config
|
|
||||||
*/
|
|
||||||
export function mockApi(config = {}) {
|
|
||||||
let { mockExtensions, mockNodeDefs, userConfig, settings, userData } = {
|
|
||||||
userConfig,
|
|
||||||
settings: {},
|
|
||||||
userData: {},
|
|
||||||
...config,
|
|
||||||
};
|
|
||||||
if (!mockExtensions) {
|
|
||||||
mockExtensions = Array.from(walkSync(path.resolve("../web/extensions/core")))
|
|
||||||
.filter((x) => x.endsWith(".js"))
|
|
||||||
.map((x) => path.relative(path.resolve("../web"), x));
|
|
||||||
}
|
|
||||||
if (!mockNodeDefs) {
|
|
||||||
mockNodeDefs = JSON.parse(fs.readFileSync(path.resolve("./data/object_info.json")));
|
|
||||||
}
|
|
||||||
|
|
||||||
const events = new EventTarget();
|
|
||||||
const mockApi = {
|
|
||||||
addEventListener: events.addEventListener.bind(events),
|
|
||||||
removeEventListener: events.removeEventListener.bind(events),
|
|
||||||
dispatchEvent: events.dispatchEvent.bind(events),
|
|
||||||
getSystemStats: jest.fn(),
|
|
||||||
getExtensions: jest.fn(() => mockExtensions),
|
|
||||||
getNodeDefs: jest.fn(() => mockNodeDefs),
|
|
||||||
init: jest.fn(),
|
|
||||||
apiURL: jest.fn((x) => "../../web/" + x),
|
|
||||||
createUser: jest.fn((username) => {
|
|
||||||
if(username in userConfig.users) {
|
|
||||||
return { status: 400, json: () => "Duplicate" }
|
|
||||||
}
|
|
||||||
userConfig.users[username + "!"] = username;
|
|
||||||
return { status: 200, json: () => username + "!" }
|
|
||||||
}),
|
|
||||||
getUserConfig: jest.fn(() => userConfig ?? { storage: "browser", migrated: false }),
|
|
||||||
getSettings: jest.fn(() => settings),
|
|
||||||
storeSettings: jest.fn((v) => Object.assign(settings, v)),
|
|
||||||
getUserData: jest.fn((f) => {
|
|
||||||
if (f in userData) {
|
|
||||||
return { status: 200, json: () => userData[f] };
|
|
||||||
} else {
|
|
||||||
return { status: 404 };
|
|
||||||
}
|
|
||||||
}),
|
|
||||||
storeUserData: jest.fn((file, data) => {
|
|
||||||
userData[file] = data;
|
|
||||||
}),
|
|
||||||
listUserData: jest.fn(() => [])
|
|
||||||
};
|
|
||||||
jest.mock("../../web/scripts/api", () => ({
|
|
||||||
get api() {
|
|
||||||
return mockApi;
|
|
||||||
},
|
|
||||||
}));
|
|
||||||
}
|
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import pytest
|
import pytest
|
||||||
from requests.exceptions import HTTPError
|
from requests.exceptions import HTTPError
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
from app.frontend_management import (
|
from app.frontend_management import (
|
||||||
FrontendManager,
|
FrontendManager,
|
||||||
@@ -83,6 +84,35 @@ def test_init_frontend_invalid_provider():
|
|||||||
with pytest.raises(HTTPError):
|
with pytest.raises(HTTPError):
|
||||||
FrontendManager.init_frontend_unsafe(version_string)
|
FrontendManager.init_frontend_unsafe(version_string)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_os_functions():
|
||||||
|
with patch('app.frontend_management.os.makedirs') as mock_makedirs, \
|
||||||
|
patch('app.frontend_management.os.listdir') as mock_listdir, \
|
||||||
|
patch('app.frontend_management.os.rmdir') as mock_rmdir:
|
||||||
|
mock_listdir.return_value = [] # Simulate empty directory
|
||||||
|
yield mock_makedirs, mock_listdir, mock_rmdir
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_download():
|
||||||
|
with patch('app.frontend_management.download_release_asset_zip') as mock:
|
||||||
|
mock.side_effect = Exception("Download failed") # Simulate download failure
|
||||||
|
yield mock
|
||||||
|
|
||||||
|
def test_finally_block(mock_os_functions, mock_download, mock_provider):
|
||||||
|
# Arrange
|
||||||
|
mock_makedirs, mock_listdir, mock_rmdir = mock_os_functions
|
||||||
|
version_string = 'test-owner/test-repo@1.0.0'
|
||||||
|
|
||||||
|
# Act & Assert
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
FrontendManager.init_frontend_unsafe(version_string, mock_provider)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
mock_makedirs.assert_called_once()
|
||||||
|
mock_download.assert_called_once()
|
||||||
|
mock_listdir.assert_called_once()
|
||||||
|
mock_rmdir.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
def test_parse_version_string():
|
def test_parse_version_string():
|
||||||
version_string = "owner/repo@1.0.0"
|
version_string = "owner/repo@1.0.0"
|
||||||
|
|||||||
0
tests-unit/prompt_server_test/__init__.py
Normal file
0
tests-unit/prompt_server_test/__init__.py
Normal file
321
tests-unit/prompt_server_test/download_models_test.py
Normal file
321
tests-unit/prompt_server_test/download_models_test.py
Normal file
@@ -0,0 +1,321 @@
|
|||||||
|
import pytest
|
||||||
|
import aiohttp
|
||||||
|
from aiohttp import ClientResponse
|
||||||
|
import itertools
|
||||||
|
import os
|
||||||
|
from unittest.mock import AsyncMock, patch, MagicMock
|
||||||
|
from model_filemanager import download_model, validate_model_subdirectory, track_download_progress, create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus, validate_filename
|
||||||
|
|
||||||
|
class AsyncIteratorMock:
|
||||||
|
"""
|
||||||
|
A mock class that simulates an asynchronous iterator.
|
||||||
|
This is used to mimic the behavior of aiohttp's content iterator.
|
||||||
|
"""
|
||||||
|
def __init__(self, seq):
|
||||||
|
# Convert the input sequence into an iterator
|
||||||
|
self.iter = iter(seq)
|
||||||
|
|
||||||
|
def __aiter__(self):
|
||||||
|
# This method is called when 'async for' is used
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __anext__(self):
|
||||||
|
# This method is called for each iteration in an 'async for' loop
|
||||||
|
try:
|
||||||
|
return next(self.iter)
|
||||||
|
except StopIteration:
|
||||||
|
# This is the asynchronous equivalent of StopIteration
|
||||||
|
raise StopAsyncIteration
|
||||||
|
|
||||||
|
class ContentMock:
|
||||||
|
"""
|
||||||
|
A mock class that simulates the content attribute of an aiohttp ClientResponse.
|
||||||
|
This class provides the iter_chunked method which returns an async iterator of chunks.
|
||||||
|
"""
|
||||||
|
def __init__(self, chunks):
|
||||||
|
# Store the chunks that will be returned by the iterator
|
||||||
|
self.chunks = chunks
|
||||||
|
|
||||||
|
def iter_chunked(self, chunk_size):
|
||||||
|
# This method mimics aiohttp's content.iter_chunked()
|
||||||
|
# For simplicity in testing, we ignore chunk_size and just return our predefined chunks
|
||||||
|
return AsyncIteratorMock(self.chunks)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_download_model_success():
|
||||||
|
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
|
||||||
|
mock_response.status = 200
|
||||||
|
mock_response.headers = {'Content-Length': '1000'}
|
||||||
|
# Create a mock for content that returns an async iterator directly
|
||||||
|
chunks = [b'a' * 500, b'b' * 300, b'c' * 200]
|
||||||
|
mock_response.content = ContentMock(chunks)
|
||||||
|
|
||||||
|
mock_make_request = AsyncMock(return_value=mock_response)
|
||||||
|
mock_progress_callback = AsyncMock()
|
||||||
|
|
||||||
|
# Mock file operations
|
||||||
|
mock_open = MagicMock()
|
||||||
|
mock_file = MagicMock()
|
||||||
|
mock_open.return_value.__enter__.return_value = mock_file
|
||||||
|
time_values = itertools.count(0, 0.1)
|
||||||
|
|
||||||
|
with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'checkpoints/model.sft')), \
|
||||||
|
patch('model_filemanager.check_file_exists', return_value=None), \
|
||||||
|
patch('builtins.open', mock_open), \
|
||||||
|
patch('time.time', side_effect=time_values): # Simulate time passing
|
||||||
|
|
||||||
|
result = await download_model(
|
||||||
|
mock_make_request,
|
||||||
|
'model.sft',
|
||||||
|
'http://example.com/model.sft',
|
||||||
|
'checkpoints',
|
||||||
|
mock_progress_callback
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert the result
|
||||||
|
assert isinstance(result, DownloadModelStatus)
|
||||||
|
assert result.message == 'Successfully downloaded model.sft'
|
||||||
|
assert result.status == 'completed'
|
||||||
|
assert result.already_existed is False
|
||||||
|
|
||||||
|
# Check progress callback calls
|
||||||
|
assert mock_progress_callback.call_count >= 3 # At least start, one progress update, and completion
|
||||||
|
|
||||||
|
# Check initial call
|
||||||
|
mock_progress_callback.assert_any_call(
|
||||||
|
'checkpoints/model.sft',
|
||||||
|
DownloadModelStatus(DownloadStatusType.PENDING, 0, "Starting download of model.sft", False)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check final call
|
||||||
|
mock_progress_callback.assert_any_call(
|
||||||
|
'checkpoints/model.sft',
|
||||||
|
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.sft", False)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify file writing
|
||||||
|
mock_file.write.assert_any_call(b'a' * 500)
|
||||||
|
mock_file.write.assert_any_call(b'b' * 300)
|
||||||
|
mock_file.write.assert_any_call(b'c' * 200)
|
||||||
|
|
||||||
|
# Verify request was made
|
||||||
|
mock_make_request.assert_called_once_with('http://example.com/model.sft')
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_download_model_url_request_failure():
|
||||||
|
# Mock dependencies
|
||||||
|
mock_response = AsyncMock(spec=ClientResponse)
|
||||||
|
mock_response.status = 404 # Simulate a "Not Found" error
|
||||||
|
mock_get = AsyncMock(return_value=mock_response)
|
||||||
|
mock_progress_callback = AsyncMock()
|
||||||
|
|
||||||
|
# Mock the create_model_path function
|
||||||
|
with patch('model_filemanager.create_model_path', return_value=('/mock/path/model.safetensors', 'mock/path/model.safetensors')):
|
||||||
|
# Mock the check_file_exists function to return None (file doesn't exist)
|
||||||
|
with patch('model_filemanager.check_file_exists', return_value=None):
|
||||||
|
# Call the function
|
||||||
|
result = await download_model(
|
||||||
|
mock_get,
|
||||||
|
'model.safetensors',
|
||||||
|
'http://example.com/model.safetensors',
|
||||||
|
'mock_directory',
|
||||||
|
mock_progress_callback
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert the expected behavior
|
||||||
|
assert isinstance(result, DownloadModelStatus)
|
||||||
|
assert result.status == 'error'
|
||||||
|
assert result.message == 'Failed to download model.safetensors. Status code: 404'
|
||||||
|
assert result.already_existed is False
|
||||||
|
|
||||||
|
# Check that progress_callback was called with the correct arguments
|
||||||
|
mock_progress_callback.assert_any_call(
|
||||||
|
'mock_directory/model.safetensors',
|
||||||
|
DownloadModelStatus(
|
||||||
|
status=DownloadStatusType.PENDING,
|
||||||
|
progress_percentage=0,
|
||||||
|
message='Starting download of model.safetensors',
|
||||||
|
already_existed=False
|
||||||
|
)
|
||||||
|
)
|
||||||
|
mock_progress_callback.assert_called_with(
|
||||||
|
'mock_directory/model.safetensors',
|
||||||
|
DownloadModelStatus(
|
||||||
|
status=DownloadStatusType.ERROR,
|
||||||
|
progress_percentage=0,
|
||||||
|
message='Failed to download model.safetensors. Status code: 404',
|
||||||
|
already_existed=False
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify that the get method was called with the correct URL
|
||||||
|
mock_get.assert_called_once_with('http://example.com/model.safetensors')
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_download_model_invalid_model_subdirectory():
|
||||||
|
|
||||||
|
mock_make_request = AsyncMock()
|
||||||
|
mock_progress_callback = AsyncMock()
|
||||||
|
|
||||||
|
|
||||||
|
result = await download_model(
|
||||||
|
mock_make_request,
|
||||||
|
'model.sft',
|
||||||
|
'http://example.com/model.sft',
|
||||||
|
'../bad_path',
|
||||||
|
mock_progress_callback
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert the result
|
||||||
|
assert isinstance(result, DownloadModelStatus)
|
||||||
|
assert result.message == 'Invalid model subdirectory'
|
||||||
|
assert result.status == 'error'
|
||||||
|
assert result.already_existed is False
|
||||||
|
|
||||||
|
|
||||||
|
# For create_model_path function
|
||||||
|
def test_create_model_path(tmp_path, monkeypatch):
|
||||||
|
mock_models_dir = tmp_path / "models"
|
||||||
|
monkeypatch.setattr('folder_paths.models_dir', str(mock_models_dir))
|
||||||
|
|
||||||
|
model_name = "test_model.sft"
|
||||||
|
model_directory = "test_dir"
|
||||||
|
|
||||||
|
file_path, relative_path = create_model_path(model_name, model_directory, mock_models_dir)
|
||||||
|
|
||||||
|
assert file_path == str(mock_models_dir / model_directory / model_name)
|
||||||
|
assert relative_path == f"{model_directory}/{model_name}"
|
||||||
|
assert os.path.exists(os.path.dirname(file_path))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_check_file_exists_when_file_exists(tmp_path):
|
||||||
|
file_path = tmp_path / "existing_model.sft"
|
||||||
|
file_path.touch() # Create an empty file
|
||||||
|
|
||||||
|
mock_callback = AsyncMock()
|
||||||
|
|
||||||
|
result = await check_file_exists(str(file_path), "existing_model.sft", mock_callback, "test/existing_model.sft")
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result.status == "completed"
|
||||||
|
assert result.message == "existing_model.sft already exists"
|
||||||
|
assert result.already_existed is True
|
||||||
|
|
||||||
|
mock_callback.assert_called_once_with(
|
||||||
|
"test/existing_model.sft",
|
||||||
|
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "existing_model.sft already exists", already_existed=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_check_file_exists_when_file_does_not_exist(tmp_path):
|
||||||
|
file_path = tmp_path / "non_existing_model.sft"
|
||||||
|
|
||||||
|
mock_callback = AsyncMock()
|
||||||
|
|
||||||
|
result = await check_file_exists(str(file_path), "non_existing_model.sft", mock_callback, "test/non_existing_model.sft")
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
mock_callback.assert_not_called()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_track_download_progress_no_content_length():
|
||||||
|
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
|
||||||
|
mock_response.headers = {} # No Content-Length header
|
||||||
|
mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 500, b'b' * 500])
|
||||||
|
|
||||||
|
mock_callback = AsyncMock()
|
||||||
|
mock_open = MagicMock(return_value=MagicMock())
|
||||||
|
|
||||||
|
with patch('builtins.open', mock_open):
|
||||||
|
result = await track_download_progress(
|
||||||
|
mock_response, '/mock/path/model.sft', 'model.sft',
|
||||||
|
mock_callback, 'models/model.sft', interval=0.1
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.status == "completed"
|
||||||
|
# Check that progress was reported even without knowing the total size
|
||||||
|
mock_callback.assert_any_call(
|
||||||
|
'models/model.sft',
|
||||||
|
DownloadModelStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.sft", already_existed=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_track_download_progress_interval():
|
||||||
|
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
|
||||||
|
mock_response.headers = {'Content-Length': '1000'}
|
||||||
|
mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 100] * 10)
|
||||||
|
|
||||||
|
mock_callback = AsyncMock()
|
||||||
|
mock_open = MagicMock(return_value=MagicMock())
|
||||||
|
|
||||||
|
# Create a mock time function that returns incremental float values
|
||||||
|
mock_time = MagicMock()
|
||||||
|
mock_time.side_effect = [i * 0.5 for i in range(30)] # This should be enough for 10 chunks
|
||||||
|
|
||||||
|
with patch('builtins.open', mock_open), \
|
||||||
|
patch('time.time', mock_time):
|
||||||
|
await track_download_progress(
|
||||||
|
mock_response, '/mock/path/model.sft', 'model.sft',
|
||||||
|
mock_callback, 'models/model.sft', interval=1.0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Print out the actual call count and the arguments of each call for debugging
|
||||||
|
print(f"mock_callback was called {mock_callback.call_count} times")
|
||||||
|
for i, call in enumerate(mock_callback.call_args_list):
|
||||||
|
args, kwargs = call
|
||||||
|
print(f"Call {i + 1}: {args[1].status}, Progress: {args[1].progress_percentage:.2f}%")
|
||||||
|
|
||||||
|
# Assert that progress was updated at least 3 times (start, at least one interval, and end)
|
||||||
|
assert mock_callback.call_count >= 3, f"Expected at least 3 calls, but got {mock_callback.call_count}"
|
||||||
|
|
||||||
|
# Verify the first and last calls
|
||||||
|
first_call = mock_callback.call_args_list[0]
|
||||||
|
assert first_call[0][1].status == "in_progress"
|
||||||
|
# Allow for some initial progress, but it should be less than 50%
|
||||||
|
assert 0 <= first_call[0][1].progress_percentage < 50, f"First call progress was {first_call[0][1].progress_percentage}%"
|
||||||
|
|
||||||
|
last_call = mock_callback.call_args_list[-1]
|
||||||
|
assert last_call[0][1].status == "completed"
|
||||||
|
assert last_call[0][1].progress_percentage == 100
|
||||||
|
|
||||||
|
def test_valid_subdirectory():
|
||||||
|
assert validate_model_subdirectory("valid-model123") is True
|
||||||
|
|
||||||
|
def test_subdirectory_too_long():
|
||||||
|
assert validate_model_subdirectory("a" * 51) is False
|
||||||
|
|
||||||
|
def test_subdirectory_with_double_dots():
|
||||||
|
assert validate_model_subdirectory("model/../unsafe") is False
|
||||||
|
|
||||||
|
def test_subdirectory_with_slash():
|
||||||
|
assert validate_model_subdirectory("model/unsafe") is False
|
||||||
|
|
||||||
|
def test_subdirectory_with_special_characters():
|
||||||
|
assert validate_model_subdirectory("model@unsafe") is False
|
||||||
|
|
||||||
|
def test_subdirectory_with_underscore_and_dash():
|
||||||
|
assert validate_model_subdirectory("valid_model-name") is True
|
||||||
|
|
||||||
|
def test_empty_subdirectory():
|
||||||
|
assert validate_model_subdirectory("") is False
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("filename, expected", [
|
||||||
|
("valid_model.safetensors", True),
|
||||||
|
("valid_model.sft", True),
|
||||||
|
("valid model.safetensors", True), # Test with space
|
||||||
|
("UPPERCASE_MODEL.SAFETENSORS", True),
|
||||||
|
("model_with.multiple.dots.pt", False),
|
||||||
|
("", False), # Empty string
|
||||||
|
("../../../etc/passwd", False), # Path traversal attempt
|
||||||
|
("/etc/passwd", False), # Absolute path
|
||||||
|
("\\windows\\system32\\config\\sam", False), # Windows path
|
||||||
|
(".hidden_file.pt", False), # Hidden file
|
||||||
|
("invalid<char>.ckpt", False), # Invalid character
|
||||||
|
("invalid?.ckpt", False), # Another invalid character
|
||||||
|
("very" * 100 + ".safetensors", False), # Too long filename
|
||||||
|
("\nmodel_with_newline.pt", False), # Newline character
|
||||||
|
("model_with_emoji😊.pt", False), # Emoji in filename
|
||||||
|
])
|
||||||
|
def test_validate_filename(filename, expected):
|
||||||
|
assert validate_filename(filename) == expected
|
||||||
@@ -1 +1,3 @@
|
|||||||
pytest>=7.8.0
|
pytest>=7.8.0
|
||||||
|
pytest-aiohttp
|
||||||
|
pytest-asyncio
|
||||||
|
|||||||
115
tests-unit/server/routes/internal_routes_test.py
Normal file
115
tests-unit/server/routes/internal_routes_test.py
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
import pytest
|
||||||
|
from aiohttp import web
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
from api_server.routes.internal.internal_routes import InternalRoutes
|
||||||
|
from api_server.services.file_service import FileService
|
||||||
|
from folder_paths import models_dir, user_directory, output_directory
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def internal_routes():
|
||||||
|
return InternalRoutes()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def aiohttp_client_factory(aiohttp_client, internal_routes):
|
||||||
|
async def _get_client():
|
||||||
|
app = internal_routes.get_app()
|
||||||
|
return await aiohttp_client(app)
|
||||||
|
return _get_client
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_files_valid_directory(aiohttp_client_factory, internal_routes):
|
||||||
|
mock_file_list = [
|
||||||
|
{"name": "file1.txt", "path": "file1.txt", "type": "file", "size": 100},
|
||||||
|
{"name": "dir1", "path": "dir1", "type": "directory"}
|
||||||
|
]
|
||||||
|
internal_routes.file_service.list_files = MagicMock(return_value=mock_file_list)
|
||||||
|
client = await aiohttp_client_factory()
|
||||||
|
resp = await client.get('/files?directory=models')
|
||||||
|
assert resp.status == 200
|
||||||
|
data = await resp.json()
|
||||||
|
assert 'files' in data
|
||||||
|
assert len(data['files']) == 2
|
||||||
|
assert data['files'] == mock_file_list
|
||||||
|
|
||||||
|
# Check other valid directories
|
||||||
|
resp = await client.get('/files?directory=user')
|
||||||
|
assert resp.status == 200
|
||||||
|
resp = await client.get('/files?directory=output')
|
||||||
|
assert resp.status == 200
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_files_invalid_directory(aiohttp_client_factory, internal_routes):
|
||||||
|
internal_routes.file_service.list_files = MagicMock(side_effect=ValueError("Invalid directory key"))
|
||||||
|
client = await aiohttp_client_factory()
|
||||||
|
resp = await client.get('/files?directory=invalid')
|
||||||
|
assert resp.status == 400
|
||||||
|
data = await resp.json()
|
||||||
|
assert 'error' in data
|
||||||
|
assert data['error'] == "Invalid directory key"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_files_exception(aiohttp_client_factory, internal_routes):
|
||||||
|
internal_routes.file_service.list_files = MagicMock(side_effect=Exception("Unexpected error"))
|
||||||
|
client = await aiohttp_client_factory()
|
||||||
|
resp = await client.get('/files?directory=models')
|
||||||
|
assert resp.status == 500
|
||||||
|
data = await resp.json()
|
||||||
|
assert 'error' in data
|
||||||
|
assert data['error'] == "Unexpected error"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_files_no_directory_param(aiohttp_client_factory, internal_routes):
|
||||||
|
mock_file_list = []
|
||||||
|
internal_routes.file_service.list_files = MagicMock(return_value=mock_file_list)
|
||||||
|
client = await aiohttp_client_factory()
|
||||||
|
resp = await client.get('/files')
|
||||||
|
assert resp.status == 200
|
||||||
|
data = await resp.json()
|
||||||
|
assert 'files' in data
|
||||||
|
assert len(data['files']) == 0
|
||||||
|
|
||||||
|
def test_setup_routes(internal_routes):
|
||||||
|
internal_routes.setup_routes()
|
||||||
|
routes = internal_routes.routes
|
||||||
|
assert any(route.method == 'GET' and str(route.path) == '/files' for route in routes)
|
||||||
|
|
||||||
|
def test_get_app(internal_routes):
|
||||||
|
app = internal_routes.get_app()
|
||||||
|
assert isinstance(app, web.Application)
|
||||||
|
assert internal_routes._app is not None
|
||||||
|
|
||||||
|
def test_get_app_reuse(internal_routes):
|
||||||
|
app1 = internal_routes.get_app()
|
||||||
|
app2 = internal_routes.get_app()
|
||||||
|
assert app1 is app2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_routes_added_to_app(aiohttp_client_factory, internal_routes):
|
||||||
|
client = await aiohttp_client_factory()
|
||||||
|
try:
|
||||||
|
resp = await client.get('/files')
|
||||||
|
print(f"Response received: status {resp.status}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Exception occurred during GET request: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
assert resp.status != 404, "Route /files does not exist"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_file_service_initialization():
|
||||||
|
with patch('api_server.routes.internal.internal_routes.FileService') as MockFileService:
|
||||||
|
# Create a mock instance
|
||||||
|
mock_file_service_instance = MagicMock(spec=FileService)
|
||||||
|
MockFileService.return_value = mock_file_service_instance
|
||||||
|
internal_routes = InternalRoutes()
|
||||||
|
|
||||||
|
# Check if FileService was initialized with the correct parameters
|
||||||
|
MockFileService.assert_called_once_with({
|
||||||
|
"models": models_dir,
|
||||||
|
"user": user_directory,
|
||||||
|
"output": output_directory
|
||||||
|
})
|
||||||
|
|
||||||
|
# Verify that the file_service attribute of InternalRoutes is set
|
||||||
|
assert internal_routes.file_service == mock_file_service_instance
|
||||||
54
tests-unit/server/services/file_service_test.py
Normal file
54
tests-unit/server/services/file_service_test.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
from api_server.services.file_service import FileService
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_file_system_ops():
|
||||||
|
return MagicMock()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def file_service(mock_file_system_ops):
|
||||||
|
allowed_directories = {
|
||||||
|
"models": "/path/to/models",
|
||||||
|
"user": "/path/to/user",
|
||||||
|
"output": "/path/to/output"
|
||||||
|
}
|
||||||
|
return FileService(allowed_directories, file_system_ops=mock_file_system_ops)
|
||||||
|
|
||||||
|
def test_list_files_valid_directory(file_service, mock_file_system_ops):
|
||||||
|
mock_file_system_ops.walk_directory.return_value = [
|
||||||
|
{"name": "file1.txt", "path": "file1.txt", "type": "file", "size": 100},
|
||||||
|
{"name": "dir1", "path": "dir1", "type": "directory"}
|
||||||
|
]
|
||||||
|
|
||||||
|
result = file_service.list_files("models")
|
||||||
|
|
||||||
|
assert len(result) == 2
|
||||||
|
assert result[0]["name"] == "file1.txt"
|
||||||
|
assert result[1]["name"] == "dir1"
|
||||||
|
mock_file_system_ops.walk_directory.assert_called_once_with("/path/to/models")
|
||||||
|
|
||||||
|
def test_list_files_invalid_directory(file_service):
|
||||||
|
# Does not support walking directories outside of the allowed directories
|
||||||
|
with pytest.raises(ValueError, match="Invalid directory key"):
|
||||||
|
file_service.list_files("invalid_key")
|
||||||
|
|
||||||
|
def test_list_files_empty_directory(file_service, mock_file_system_ops):
|
||||||
|
mock_file_system_ops.walk_directory.return_value = []
|
||||||
|
|
||||||
|
result = file_service.list_files("models")
|
||||||
|
|
||||||
|
assert len(result) == 0
|
||||||
|
mock_file_system_ops.walk_directory.assert_called_once_with("/path/to/models")
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("directory_key", ["models", "user", "output"])
|
||||||
|
def test_list_files_all_allowed_directories(file_service, mock_file_system_ops, directory_key):
|
||||||
|
mock_file_system_ops.walk_directory.return_value = [
|
||||||
|
{"name": f"file_{directory_key}.txt", "path": f"file_{directory_key}.txt", "type": "file", "size": 100}
|
||||||
|
]
|
||||||
|
|
||||||
|
result = file_service.list_files(directory_key)
|
||||||
|
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0]["name"] == f"file_{directory_key}.txt"
|
||||||
|
mock_file_system_ops.walk_directory.assert_called_once_with(f"/path/to/{directory_key}")
|
||||||
42
tests-unit/server/utils/file_operations_test.py
Normal file
42
tests-unit/server/utils/file_operations_test.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
import pytest
|
||||||
|
from typing import List
|
||||||
|
from api_server.utils.file_operations import FileSystemOperations, FileSystemItem, is_file_info
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_directory(tmp_path):
|
||||||
|
# Create a temporary directory structure
|
||||||
|
dir1 = tmp_path / "dir1"
|
||||||
|
dir2 = tmp_path / "dir2"
|
||||||
|
dir1.mkdir()
|
||||||
|
dir2.mkdir()
|
||||||
|
(dir1 / "file1.txt").write_text("content1")
|
||||||
|
(dir2 / "file2.txt").write_text("content2")
|
||||||
|
(tmp_path / "file3.txt").write_text("content3")
|
||||||
|
return tmp_path
|
||||||
|
|
||||||
|
def test_walk_directory(temp_directory):
|
||||||
|
result: List[FileSystemItem] = FileSystemOperations.walk_directory(str(temp_directory))
|
||||||
|
|
||||||
|
assert len(result) == 5 # 2 directories and 3 files
|
||||||
|
|
||||||
|
files = [item for item in result if item['type'] == 'file']
|
||||||
|
dirs = [item for item in result if item['type'] == 'directory']
|
||||||
|
|
||||||
|
assert len(files) == 3
|
||||||
|
assert len(dirs) == 2
|
||||||
|
|
||||||
|
file_names = {file['name'] for file in files}
|
||||||
|
assert file_names == {'file1.txt', 'file2.txt', 'file3.txt'}
|
||||||
|
|
||||||
|
dir_names = {dir['name'] for dir in dirs}
|
||||||
|
assert dir_names == {'dir1', 'dir2'}
|
||||||
|
|
||||||
|
def test_walk_directory_empty(tmp_path):
|
||||||
|
result = FileSystemOperations.walk_directory(str(tmp_path))
|
||||||
|
assert len(result) == 0
|
||||||
|
|
||||||
|
def test_walk_directory_file_size(temp_directory):
|
||||||
|
result: List[FileSystemItem] = FileSystemOperations.walk_directory(str(temp_directory))
|
||||||
|
files = [item for item in result if is_file_info(item)]
|
||||||
|
for file in files:
|
||||||
|
assert file['size'] > 0 # Assuming all files have some content
|
||||||
4
tests/inference/extra_model_paths.yaml
Normal file
4
tests/inference/extra_model_paths.yaml
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
# Config for testing nodes
|
||||||
|
testing:
|
||||||
|
custom_nodes: tests/inference/testing_nodes
|
||||||
|
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user