Compare commits
82 Commits
yoland68-m
...
v0.3.32
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0cf2e46b17 | ||
|
|
094e9ef126 | ||
|
|
1271c4ef9d | ||
|
|
d9c80a85e5 | ||
|
|
3e62c5513a | ||
|
|
cd18582578 | ||
|
|
80a44b97f5 | ||
|
|
9187a09483 | ||
|
|
3041e5c354 | ||
|
|
7689917113 | ||
|
|
486ad8fdc5 | ||
|
|
065d855f14 | ||
|
|
530494588d | ||
|
|
2ab9618732 | ||
|
|
d9a87c1e6a | ||
|
|
551fe8dcee | ||
|
|
ff99861650 | ||
|
|
8d0661d0ba | ||
|
|
6d32dc049e | ||
|
|
aa9d759df3 | ||
|
|
c6c19e9980 | ||
|
|
08ff5fa08a | ||
|
|
4ca3d84277 | ||
|
|
39c27a3705 | ||
|
|
b1c7291569 | ||
|
|
dbc726f80c | ||
|
|
7ee96455e2 | ||
|
|
0a66d4b0af | ||
|
|
5c5457a4ef | ||
|
|
45503f6499 | ||
|
|
005a91ce2b | ||
|
|
68f0d35296 | ||
|
|
83d04717b6 | ||
|
|
7d329771f9 | ||
|
|
c15909bb62 | ||
|
|
772b4c5945 | ||
|
|
5a50c3c7e5 | ||
|
|
30159a7fe6 | ||
|
|
cb9ac3db58 | ||
|
|
8115a7895b | ||
|
|
c8cd7ad795 | ||
|
|
542b4b36b6 | ||
|
|
ac10a0d69e | ||
|
|
0dcc75ca54 | ||
|
|
b685b8a4e0 | ||
|
|
23e39f2ba7 | ||
|
|
78992c4b25 | ||
|
|
f935d42d8e | ||
|
|
a97f2f850a | ||
|
|
5acb705857 | ||
|
|
5c80da31db | ||
|
|
e2eed9eb9b | ||
|
|
11b68ebd22 | ||
|
|
188b383c35 | ||
|
|
2c1d686ec6 | ||
|
|
e8ddc2be95 | ||
|
|
dea1c7474a | ||
|
|
154f2911aa | ||
|
|
3eaad0590e | ||
|
|
7eaff81be1 | ||
|
|
21a11ef817 | ||
|
|
552615235d | ||
|
|
0738e4ea5d | ||
|
|
92cdc692f4 | ||
|
|
2d6805ce57 | ||
|
|
a8f63c0d5b | ||
|
|
454a635c1b | ||
|
|
966c43ce26 | ||
|
|
3ab231f01f | ||
|
|
1f3fba2af5 | ||
|
|
5d0d4ee98a | ||
|
|
9d57b8afd8 | ||
|
|
5d51794607 | ||
|
|
ce22f687cc | ||
|
|
b6fd3ffd10 | ||
|
|
11b72c9c55 | ||
|
|
2c735c13b4 | ||
|
|
fd27494441 | ||
|
|
f43e1d7f41 | ||
|
|
4486b0d0ff | ||
|
|
636d4bfb89 | ||
|
|
dc300a4569 |
@@ -63,6 +63,11 @@ except:
|
|||||||
print("checking out master branch") # noqa: T201
|
print("checking out master branch") # noqa: T201
|
||||||
branch = repo.lookup_branch('master')
|
branch = repo.lookup_branch('master')
|
||||||
if branch is None:
|
if branch is None:
|
||||||
|
try:
|
||||||
|
ref = repo.lookup_reference('refs/remotes/origin/master')
|
||||||
|
except:
|
||||||
|
print("pulling.") # noqa: T201
|
||||||
|
pull(repo)
|
||||||
ref = repo.lookup_reference('refs/remotes/origin/master')
|
ref = repo.lookup_reference('refs/remotes/origin/master')
|
||||||
repo.checkout(ref)
|
repo.checkout(ref)
|
||||||
branch = repo.lookup_branch('master')
|
branch = repo.lookup_branch('master')
|
||||||
|
|||||||
12
.github/workflows/stable-release.yml
vendored
12
.github/workflows/stable-release.yml
vendored
@@ -12,7 +12,7 @@ on:
|
|||||||
description: 'CUDA version'
|
description: 'CUDA version'
|
||||||
required: true
|
required: true
|
||||||
type: string
|
type: string
|
||||||
default: "126"
|
default: "128"
|
||||||
python_minor:
|
python_minor:
|
||||||
description: 'Python minor version'
|
description: 'Python minor version'
|
||||||
required: true
|
required: true
|
||||||
@@ -22,7 +22,7 @@ on:
|
|||||||
description: 'Python patch version'
|
description: 'Python patch version'
|
||||||
required: true
|
required: true
|
||||||
type: string
|
type: string
|
||||||
default: "9"
|
default: "10"
|
||||||
|
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
@@ -36,7 +36,7 @@ jobs:
|
|||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
ref: ${{ inputs.git_tag }}
|
ref: ${{ inputs.git_tag }}
|
||||||
fetch-depth: 0
|
fetch-depth: 150
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
- uses: actions/cache/restore@v4
|
- uses: actions/cache/restore@v4
|
||||||
id: cache
|
id: cache
|
||||||
@@ -70,7 +70,7 @@ jobs:
|
|||||||
cd ..
|
cd ..
|
||||||
|
|
||||||
git clone --depth 1 https://github.com/comfyanonymous/taesd
|
git clone --depth 1 https://github.com/comfyanonymous/taesd
|
||||||
cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/
|
cp taesd/*.safetensors ./ComfyUI_copy/models/vae_approx/
|
||||||
|
|
||||||
mkdir ComfyUI_windows_portable
|
mkdir ComfyUI_windows_portable
|
||||||
mv python_embeded ComfyUI_windows_portable
|
mv python_embeded ComfyUI_windows_portable
|
||||||
@@ -85,12 +85,14 @@ jobs:
|
|||||||
|
|
||||||
cd ..
|
cd ..
|
||||||
|
|
||||||
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
|
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
|
||||||
mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_nvidia.7z
|
mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_nvidia.7z
|
||||||
|
|
||||||
cd ComfyUI_windows_portable
|
cd ComfyUI_windows_portable
|
||||||
python_embeded/python.exe -s ComfyUI/main.py --quick-test-for-ci --cpu
|
python_embeded/python.exe -s ComfyUI/main.py --quick-test-for-ci --cpu
|
||||||
|
|
||||||
|
python_embeded/python.exe -s ./update/update.py ComfyUI/
|
||||||
|
|
||||||
ls
|
ls
|
||||||
|
|
||||||
- name: Upload binaries to release
|
- name: Upload binaries to release
|
||||||
|
|||||||
2
.github/workflows/test-launch.yml
vendored
2
.github/workflows/test-launch.yml
vendored
@@ -17,7 +17,7 @@ jobs:
|
|||||||
path: "ComfyUI"
|
path: "ComfyUI"
|
||||||
- uses: actions/setup-python@v4
|
- uses: actions/setup-python@v4
|
||||||
with:
|
with:
|
||||||
python-version: '3.9'
|
python-version: '3.10'
|
||||||
- name: Install requirements
|
- name: Install requirements
|
||||||
run: |
|
run: |
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
|
|||||||
56
.github/workflows/update-api-stubs.yml
vendored
Normal file
56
.github/workflows/update-api-stubs.yml
vendored
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
name: Generate Pydantic Stubs from api.comfy.org
|
||||||
|
|
||||||
|
on:
|
||||||
|
schedule:
|
||||||
|
- cron: '0 0 * * 1'
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
generate-models:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout repository
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v4
|
||||||
|
with:
|
||||||
|
python-version: '3.10'
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
pip install 'datamodel-code-generator[http]'
|
||||||
|
npm install @redocly/cli
|
||||||
|
|
||||||
|
- name: Download OpenAPI spec
|
||||||
|
run: |
|
||||||
|
curl -o openapi.yaml https://api.comfy.org/openapi
|
||||||
|
|
||||||
|
- name: Filter OpenAPI spec with Redocly
|
||||||
|
run: |
|
||||||
|
npx @redocly/cli bundle openapi.yaml --output filtered-openapi.yaml --config comfy_api_nodes/redocly.yaml --remove-unused-components
|
||||||
|
|
||||||
|
- name: Generate API models
|
||||||
|
run: |
|
||||||
|
datamodel-codegen --use-subclass-enum --input filtered-openapi.yaml --output comfy_api_nodes/apis --output-model-type pydantic_v2.BaseModel
|
||||||
|
|
||||||
|
- name: Check for changes
|
||||||
|
id: git-check
|
||||||
|
run: |
|
||||||
|
git diff --exit-code comfy_api_nodes/apis || echo "changes=true" >> $GITHUB_OUTPUT
|
||||||
|
|
||||||
|
- name: Create Pull Request
|
||||||
|
if: steps.git-check.outputs.changes == 'true'
|
||||||
|
uses: peter-evans/create-pull-request@v5
|
||||||
|
with:
|
||||||
|
commit-message: 'chore: update API models from OpenAPI spec'
|
||||||
|
title: 'Update API models from api.comfy.org'
|
||||||
|
body: |
|
||||||
|
This PR updates the API models based on the latest api.comfy.org OpenAPI specification.
|
||||||
|
|
||||||
|
Generated automatically by the a Github workflow.
|
||||||
|
branch: update-api-stubs
|
||||||
|
delete-branch: true
|
||||||
|
base: master
|
||||||
@@ -17,7 +17,7 @@ on:
|
|||||||
description: 'cuda version'
|
description: 'cuda version'
|
||||||
required: true
|
required: true
|
||||||
type: string
|
type: string
|
||||||
default: "126"
|
default: "128"
|
||||||
|
|
||||||
python_minor:
|
python_minor:
|
||||||
description: 'python minor version'
|
description: 'python minor version'
|
||||||
@@ -29,7 +29,7 @@ on:
|
|||||||
description: 'python patch version'
|
description: 'python patch version'
|
||||||
required: true
|
required: true
|
||||||
type: string
|
type: string
|
||||||
default: "9"
|
default: "10"
|
||||||
# push:
|
# push:
|
||||||
# branches:
|
# branches:
|
||||||
# - master
|
# - master
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ jobs:
|
|||||||
cd ..
|
cd ..
|
||||||
|
|
||||||
git clone --depth 1 https://github.com/comfyanonymous/taesd
|
git clone --depth 1 https://github.com/comfyanonymous/taesd
|
||||||
cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/
|
cp taesd/*.safetensors ./ComfyUI_copy/models/vae_approx/
|
||||||
|
|
||||||
mkdir ComfyUI_windows_portable_nightly_pytorch
|
mkdir ComfyUI_windows_portable_nightly_pytorch
|
||||||
mv python_embeded ComfyUI_windows_portable_nightly_pytorch
|
mv python_embeded ComfyUI_windows_portable_nightly_pytorch
|
||||||
|
|||||||
12
.github/workflows/windows_release_package.yml
vendored
12
.github/workflows/windows_release_package.yml
vendored
@@ -7,7 +7,7 @@ on:
|
|||||||
description: 'cuda version'
|
description: 'cuda version'
|
||||||
required: true
|
required: true
|
||||||
type: string
|
type: string
|
||||||
default: "126"
|
default: "128"
|
||||||
|
|
||||||
python_minor:
|
python_minor:
|
||||||
description: 'python minor version'
|
description: 'python minor version'
|
||||||
@@ -19,7 +19,7 @@ on:
|
|||||||
description: 'python patch version'
|
description: 'python patch version'
|
||||||
required: true
|
required: true
|
||||||
type: string
|
type: string
|
||||||
default: "9"
|
default: "10"
|
||||||
# push:
|
# push:
|
||||||
# branches:
|
# branches:
|
||||||
# - master
|
# - master
|
||||||
@@ -50,7 +50,7 @@ jobs:
|
|||||||
|
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 150
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
- shell: bash
|
- shell: bash
|
||||||
run: |
|
run: |
|
||||||
@@ -67,7 +67,7 @@ jobs:
|
|||||||
cd ..
|
cd ..
|
||||||
|
|
||||||
git clone --depth 1 https://github.com/comfyanonymous/taesd
|
git clone --depth 1 https://github.com/comfyanonymous/taesd
|
||||||
cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/
|
cp taesd/*.safetensors ./ComfyUI_copy/models/vae_approx/
|
||||||
|
|
||||||
mkdir ComfyUI_windows_portable
|
mkdir ComfyUI_windows_portable
|
||||||
mv python_embeded ComfyUI_windows_portable
|
mv python_embeded ComfyUI_windows_portable
|
||||||
@@ -82,12 +82,14 @@ jobs:
|
|||||||
|
|
||||||
cd ..
|
cd ..
|
||||||
|
|
||||||
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
|
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
|
||||||
mv ComfyUI_windows_portable.7z ComfyUI/new_ComfyUI_windows_portable_nvidia_cu${{ inputs.cu }}_or_cpu.7z
|
mv ComfyUI_windows_portable.7z ComfyUI/new_ComfyUI_windows_portable_nvidia_cu${{ inputs.cu }}_or_cpu.7z
|
||||||
|
|
||||||
cd ComfyUI_windows_portable
|
cd ComfyUI_windows_portable
|
||||||
python_embeded/python.exe -s ComfyUI/main.py --quick-test-for-ci --cpu
|
python_embeded/python.exe -s ComfyUI/main.py --quick-test-for-ci --cpu
|
||||||
|
|
||||||
|
python_embeded/python.exe -s ./update/update.py ComfyUI/
|
||||||
|
|
||||||
ls
|
ls
|
||||||
|
|
||||||
- name: Upload binaries to release
|
- name: Upload binaries to release
|
||||||
|
|||||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -21,3 +21,6 @@ venv/
|
|||||||
*.log
|
*.log
|
||||||
web_custom_versions/
|
web_custom_versions/
|
||||||
.DS_Store
|
.DS_Store
|
||||||
|
openapi.yaml
|
||||||
|
filtered-openapi.yaml
|
||||||
|
uv.lock
|
||||||
|
|||||||
26
CODEOWNERS
26
CODEOWNERS
@@ -5,20 +5,20 @@
|
|||||||
# Inlined the team members for now.
|
# Inlined the team members for now.
|
||||||
|
|
||||||
# Maintainers
|
# Maintainers
|
||||||
*.md @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
|
*.md @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||||
/tests/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
|
/tests/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||||
/tests-unit/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
|
/tests-unit/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||||
/notebooks/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
|
/notebooks/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||||
/script_examples/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
|
/script_examples/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||||
/.github/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
|
/.github/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||||
/requirements.txt @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
|
/requirements.txt @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||||
/pyproject.toml @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
|
/pyproject.toml @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||||
|
|
||||||
# Python web server
|
# Python web server
|
||||||
/api_server/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
|
/api_server/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne
|
||||||
/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
|
/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne
|
||||||
/utils/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
|
/utils/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne
|
||||||
|
|
||||||
# Node developers
|
# Node developers
|
||||||
/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered
|
/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
|
||||||
/comfy/comfy_types/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered
|
/comfy/comfy_types/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
|
||||||
|
|||||||
24
README.md
24
README.md
@@ -49,7 +49,6 @@ Supports all operating systems and GPU types (NVIDIA, AMD, Intel, Apple Silicon,
|
|||||||
## [Examples](https://comfyanonymous.github.io/ComfyUI_examples/)
|
## [Examples](https://comfyanonymous.github.io/ComfyUI_examples/)
|
||||||
See what ComfyUI can do with the [example workflows](https://comfyanonymous.github.io/ComfyUI_examples/).
|
See what ComfyUI can do with the [example workflows](https://comfyanonymous.github.io/ComfyUI_examples/).
|
||||||
|
|
||||||
|
|
||||||
## Features
|
## Features
|
||||||
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
|
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
|
||||||
- Image Models
|
- Image Models
|
||||||
@@ -99,6 +98,23 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
|
|||||||
|
|
||||||
Workflow examples can be found on the [Examples page](https://comfyanonymous.github.io/ComfyUI_examples/)
|
Workflow examples can be found on the [Examples page](https://comfyanonymous.github.io/ComfyUI_examples/)
|
||||||
|
|
||||||
|
## Release Process
|
||||||
|
|
||||||
|
ComfyUI follows a weekly release cycle every Friday, with three interconnected repositories:
|
||||||
|
|
||||||
|
1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)**
|
||||||
|
- Releases a new stable version (e.g., v0.7.0)
|
||||||
|
- Serves as the foundation for the desktop release
|
||||||
|
|
||||||
|
2. **[ComfyUI Desktop](https://github.com/Comfy-Org/desktop)**
|
||||||
|
- Builds a new release using the latest stable core version
|
||||||
|
- Version numbers match the core release (e.g., Desktop v1.7.0 uses Core v1.7.0)
|
||||||
|
|
||||||
|
3. **[ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend)**
|
||||||
|
- Weekly frontend updates are merged into the core repository
|
||||||
|
- Features are frozen for the upcoming core release
|
||||||
|
- Development continues for the next release cycle
|
||||||
|
|
||||||
## Shortcuts
|
## Shortcuts
|
||||||
|
|
||||||
| Keybind | Explanation |
|
| Keybind | Explanation |
|
||||||
@@ -149,8 +165,6 @@ Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you
|
|||||||
|
|
||||||
If you have trouble extracting it, right click the file -> properties -> unblock
|
If you have trouble extracting it, right click the file -> properties -> unblock
|
||||||
|
|
||||||
If you have a 50 series Blackwell card like a 5090 or 5080 see [this discussion thread](https://github.com/comfyanonymous/ComfyUI/discussions/6643)
|
|
||||||
|
|
||||||
#### How do I share models between another UI and ComfyUI?
|
#### How do I share models between another UI and ComfyUI?
|
||||||
|
|
||||||
See the [Config file](extra_model_paths.yaml.example) to set the search paths for models. In the standalone windows build you can find this file in the ComfyUI directory. Rename this file to extra_model_paths.yaml and edit it with your favorite text editor.
|
See the [Config file](extra_model_paths.yaml.example) to set the search paths for models. In the standalone windows build you can find this file in the ComfyUI directory. Rename this file to extra_model_paths.yaml and edit it with your favorite text editor.
|
||||||
@@ -216,9 +230,9 @@ Additional discussion and help can be found [here](https://github.com/comfyanony
|
|||||||
|
|
||||||
Nvidia users should install stable pytorch using this command:
|
Nvidia users should install stable pytorch using this command:
|
||||||
|
|
||||||
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu126```
|
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu128```
|
||||||
|
|
||||||
This is the command to install pytorch nightly instead which supports the new blackwell 50xx series GPUs and might have performance improvements.
|
This is the command to install pytorch nightly instead which might have performance improvements.
|
||||||
|
|
||||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128```
|
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128```
|
||||||
|
|
||||||
|
|||||||
@@ -93,16 +93,20 @@ class CustomNodeManager:
|
|||||||
|
|
||||||
def add_routes(self, routes, webapp, loadedModules):
|
def add_routes(self, routes, webapp, loadedModules):
|
||||||
|
|
||||||
|
example_workflow_folder_names = ["example_workflows", "example", "examples", "workflow", "workflows"]
|
||||||
|
|
||||||
@routes.get("/workflow_templates")
|
@routes.get("/workflow_templates")
|
||||||
async def get_workflow_templates(request):
|
async def get_workflow_templates(request):
|
||||||
"""Returns a web response that contains the map of custom_nodes names and their associated workflow templates. The ones without templates are omitted."""
|
"""Returns a web response that contains the map of custom_nodes names and their associated workflow templates. The ones without templates are omitted."""
|
||||||
files = [
|
|
||||||
file
|
files = []
|
||||||
for folder in folder_paths.get_folder_paths("custom_nodes")
|
|
||||||
for file in glob.glob(
|
for folder in folder_paths.get_folder_paths("custom_nodes"):
|
||||||
os.path.join(folder, "*/example_workflows/*.json")
|
for folder_name in example_workflow_folder_names:
|
||||||
)
|
pattern = os.path.join(folder, f"*/{folder_name}/*.json")
|
||||||
]
|
matched_files = glob.glob(pattern)
|
||||||
|
files.extend(matched_files)
|
||||||
|
|
||||||
workflow_templates_dict = (
|
workflow_templates_dict = (
|
||||||
{}
|
{}
|
||||||
) # custom_nodes folder name -> example workflow names
|
) # custom_nodes folder name -> example workflow names
|
||||||
@@ -118,8 +122,15 @@ class CustomNodeManager:
|
|||||||
|
|
||||||
# Serve workflow templates from custom nodes.
|
# Serve workflow templates from custom nodes.
|
||||||
for module_name, module_dir in loadedModules:
|
for module_name, module_dir in loadedModules:
|
||||||
workflows_dir = os.path.join(module_dir, "example_workflows")
|
for folder_name in example_workflow_folder_names:
|
||||||
|
workflows_dir = os.path.join(module_dir, folder_name)
|
||||||
|
|
||||||
if os.path.exists(workflows_dir):
|
if os.path.exists(workflows_dir):
|
||||||
|
if folder_name != "example_workflows":
|
||||||
|
logging.debug(
|
||||||
|
"Found example workflow folder '%s' for custom node '%s', consider renaming it to 'example_workflows'",
|
||||||
|
folder_name, module_name)
|
||||||
|
|
||||||
webapp.add_routes(
|
webapp.add_routes(
|
||||||
[
|
[
|
||||||
web.static(
|
web.static(
|
||||||
|
|||||||
@@ -197,6 +197,112 @@ class UserManager():
|
|||||||
|
|
||||||
return web.json_response(results)
|
return web.json_response(results)
|
||||||
|
|
||||||
|
@routes.get("/v2/userdata")
|
||||||
|
async def list_userdata_v2(request):
|
||||||
|
"""
|
||||||
|
List files and directories in a user's data directory.
|
||||||
|
|
||||||
|
This endpoint provides a structured listing of contents within a specified
|
||||||
|
subdirectory of the user's data storage.
|
||||||
|
|
||||||
|
Query Parameters:
|
||||||
|
- path (optional): The relative path within the user's data directory
|
||||||
|
to list. Defaults to the root ('').
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- 400: If the requested path is invalid, outside the user's data directory, or is not a directory.
|
||||||
|
- 404: If the requested path does not exist.
|
||||||
|
- 403: If the user is invalid.
|
||||||
|
- 500: If there is an error reading the directory contents.
|
||||||
|
- 200: JSON response containing a list of file and directory objects.
|
||||||
|
Each object includes:
|
||||||
|
- name: The name of the file or directory.
|
||||||
|
- type: 'file' or 'directory'.
|
||||||
|
- path: The relative path from the user's data root.
|
||||||
|
- size (for files): The size in bytes.
|
||||||
|
- modified (for files): The last modified timestamp (Unix epoch).
|
||||||
|
"""
|
||||||
|
requested_rel_path = request.rel_url.query.get('path', '')
|
||||||
|
|
||||||
|
# URL-decode the path parameter
|
||||||
|
try:
|
||||||
|
requested_rel_path = parse.unquote(requested_rel_path)
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning(f"Failed to decode path parameter: {requested_rel_path}, Error: {e}")
|
||||||
|
return web.Response(status=400, text="Invalid characters in path parameter")
|
||||||
|
|
||||||
|
|
||||||
|
# Check user validity and get the absolute path for the requested directory
|
||||||
|
try:
|
||||||
|
base_user_path = self.get_request_user_filepath(request, None, create_dir=False)
|
||||||
|
|
||||||
|
if requested_rel_path:
|
||||||
|
target_abs_path = self.get_request_user_filepath(request, requested_rel_path, create_dir=False)
|
||||||
|
else:
|
||||||
|
target_abs_path = base_user_path
|
||||||
|
|
||||||
|
except KeyError as e:
|
||||||
|
# Invalid user detected by get_request_user_id inside get_request_user_filepath
|
||||||
|
logging.warning(f"Access denied for user: {e}")
|
||||||
|
return web.Response(status=403, text="Invalid user specified in request")
|
||||||
|
|
||||||
|
|
||||||
|
if not target_abs_path:
|
||||||
|
# Path traversal or other issue detected by get_request_user_filepath
|
||||||
|
return web.Response(status=400, text="Invalid path requested")
|
||||||
|
|
||||||
|
# Handle cases where the user directory or target path doesn't exist
|
||||||
|
if not os.path.exists(target_abs_path):
|
||||||
|
# Check if it's the base user directory that's missing (new user case)
|
||||||
|
if target_abs_path == base_user_path:
|
||||||
|
# It's okay if the base user directory doesn't exist yet, return empty list
|
||||||
|
return web.json_response([])
|
||||||
|
else:
|
||||||
|
# A specific subdirectory was requested but doesn't exist
|
||||||
|
return web.Response(status=404, text="Requested path not found")
|
||||||
|
|
||||||
|
if not os.path.isdir(target_abs_path):
|
||||||
|
return web.Response(status=400, text="Requested path is not a directory")
|
||||||
|
|
||||||
|
results = []
|
||||||
|
try:
|
||||||
|
for root, dirs, files in os.walk(target_abs_path, topdown=True):
|
||||||
|
# Process directories
|
||||||
|
for dir_name in dirs:
|
||||||
|
dir_path = os.path.join(root, dir_name)
|
||||||
|
rel_path = os.path.relpath(dir_path, base_user_path).replace(os.sep, '/')
|
||||||
|
results.append({
|
||||||
|
"name": dir_name,
|
||||||
|
"path": rel_path,
|
||||||
|
"type": "directory"
|
||||||
|
})
|
||||||
|
|
||||||
|
# Process files
|
||||||
|
for file_name in files:
|
||||||
|
file_path = os.path.join(root, file_name)
|
||||||
|
rel_path = os.path.relpath(file_path, base_user_path).replace(os.sep, '/')
|
||||||
|
entry_info = {
|
||||||
|
"name": file_name,
|
||||||
|
"path": rel_path,
|
||||||
|
"type": "file"
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
stats = os.stat(file_path) # Use os.stat for potentially better performance with os.walk
|
||||||
|
entry_info["size"] = stats.st_size
|
||||||
|
entry_info["modified"] = stats.st_mtime
|
||||||
|
except OSError as stat_error:
|
||||||
|
logging.warning(f"Could not stat file {file_path}: {stat_error}")
|
||||||
|
pass # Include file with available info
|
||||||
|
results.append(entry_info)
|
||||||
|
except OSError as e:
|
||||||
|
logging.error(f"Error listing directory {target_abs_path}: {e}")
|
||||||
|
return web.Response(status=500, text="Error reading directory contents")
|
||||||
|
|
||||||
|
# Sort results alphabetically, directories first then files
|
||||||
|
results.sort(key=lambda x: (x['type'] != 'directory', x['name'].lower()))
|
||||||
|
|
||||||
|
return web.json_response(results)
|
||||||
|
|
||||||
def get_user_data_path(request, check_exists = False, param = "file"):
|
def get_user_data_path(request, check_exists = False, param = "file"):
|
||||||
file = request.match_info.get(param, None)
|
file = request.match_info.get(param, None)
|
||||||
if not file:
|
if not file:
|
||||||
|
|||||||
@@ -66,6 +66,7 @@ fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the diff
|
|||||||
fpunet_group.add_argument("--fp16-unet", action="store_true", help="Run the diffusion model in fp16")
|
fpunet_group.add_argument("--fp16-unet", action="store_true", help="Run the diffusion model in fp16")
|
||||||
fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.")
|
fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.")
|
||||||
fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")
|
fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")
|
||||||
|
fpunet_group.add_argument("--fp8_e8m0fnu-unet", action="store_true", help="Store unet weights in fp8_e8m0fnu.")
|
||||||
|
|
||||||
fpvae_group = parser.add_mutually_exclusive_group()
|
fpvae_group = parser.add_mutually_exclusive_group()
|
||||||
fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
|
fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
|
||||||
@@ -127,6 +128,7 @@ vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for e
|
|||||||
|
|
||||||
parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.")
|
parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.")
|
||||||
|
|
||||||
|
parser.add_argument("--async-offload", action="store_true", help="Use async weight offloading.")
|
||||||
|
|
||||||
parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
|
parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
|
||||||
|
|
||||||
@@ -146,6 +148,7 @@ parser.add_argument("--windows-standalone-build", action="store_true", help="Win
|
|||||||
|
|
||||||
parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
|
parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
|
||||||
parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Disable loading all custom nodes.")
|
parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Disable loading all custom nodes.")
|
||||||
|
parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes.")
|
||||||
|
|
||||||
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
|
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
|
||||||
|
|
||||||
@@ -190,6 +193,13 @@ parser.add_argument("--user-directory", type=is_valid_directory, default=None, h
|
|||||||
|
|
||||||
parser.add_argument("--enable-compress-response-body", action="store_true", help="Enable compressing response body.")
|
parser.add_argument("--enable-compress-response-body", action="store_true", help="Enable compressing response body.")
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--comfy-api-base",
|
||||||
|
type=str,
|
||||||
|
default="https://api.comfy.org",
|
||||||
|
help="Set the base URL for the ComfyUI API. (default: https://api.comfy.org)",
|
||||||
|
)
|
||||||
|
|
||||||
if comfy.options.args_parsing:
|
if comfy.options.args_parsing:
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ class Output:
|
|||||||
setattr(self, key, item)
|
setattr(self, key, item)
|
||||||
|
|
||||||
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
|
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
|
||||||
|
image = image[:, :, :, :3] if image.shape[3] > 3 else image
|
||||||
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
|
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
|
||||||
std = torch.tensor(std, device=image.device, dtype=image.dtype)
|
std = torch.tensor(std, device=image.device, dtype=image.dtype)
|
||||||
image = image.movedim(-1, 1)
|
image = image.movedim(-1, 1)
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
"""Comfy-specific type hinting"""
|
"""Comfy-specific type hinting"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from typing import Literal, TypedDict
|
from typing import Literal, TypedDict, Optional
|
||||||
from typing_extensions import NotRequired
|
from typing_extensions import NotRequired
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
@@ -48,6 +48,7 @@ class IO(StrEnum):
|
|||||||
FACE_ANALYSIS = "FACE_ANALYSIS"
|
FACE_ANALYSIS = "FACE_ANALYSIS"
|
||||||
BBOX = "BBOX"
|
BBOX = "BBOX"
|
||||||
SEGS = "SEGS"
|
SEGS = "SEGS"
|
||||||
|
VIDEO = "VIDEO"
|
||||||
|
|
||||||
ANY = "*"
|
ANY = "*"
|
||||||
"""Always matches any type, but at a price.
|
"""Always matches any type, but at a price.
|
||||||
@@ -115,6 +116,15 @@ class InputTypeOptions(TypedDict):
|
|||||||
"""When a link exists, rather than receiving the evaluated value, you will receive the link (i.e. `["nodeId", <outputIndex>]`). Designed for node expansion."""
|
"""When a link exists, rather than receiving the evaluated value, you will receive the link (i.e. `["nodeId", <outputIndex>]`). Designed for node expansion."""
|
||||||
tooltip: NotRequired[str]
|
tooltip: NotRequired[str]
|
||||||
"""Tooltip for the input (or widget), shown on pointer hover"""
|
"""Tooltip for the input (or widget), shown on pointer hover"""
|
||||||
|
socketless: NotRequired[bool]
|
||||||
|
"""All inputs (including widgets) have an input socket to connect links. When ``true``, if there is a widget for this input, no socket will be created.
|
||||||
|
Available from frontend v1.17.5
|
||||||
|
Ref: https://github.com/Comfy-Org/ComfyUI_frontend/pull/3548
|
||||||
|
"""
|
||||||
|
widgetType: NotRequired[str]
|
||||||
|
"""Specifies a type to be used for widget initialization if different from the input type.
|
||||||
|
Available from frontend v1.18.0
|
||||||
|
https://github.com/Comfy-Org/ComfyUI_frontend/pull/3550"""
|
||||||
# class InputTypeNumber(InputTypeOptions):
|
# class InputTypeNumber(InputTypeOptions):
|
||||||
# default: float | int
|
# default: float | int
|
||||||
min: NotRequired[float]
|
min: NotRequired[float]
|
||||||
@@ -224,6 +234,8 @@ class ComfyNodeABC(ABC):
|
|||||||
"""Flags a node as experimental, informing users that it may change or not work as expected."""
|
"""Flags a node as experimental, informing users that it may change or not work as expected."""
|
||||||
DEPRECATED: bool
|
DEPRECATED: bool
|
||||||
"""Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
|
"""Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
|
||||||
|
API_NODE: Optional[bool]
|
||||||
|
"""Flags a node as an API node."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@@ -262,7 +274,7 @@ class ComfyNodeABC(ABC):
|
|||||||
|
|
||||||
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lists#list-processing
|
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lists#list-processing
|
||||||
"""
|
"""
|
||||||
OUTPUT_IS_LIST: tuple[bool]
|
OUTPUT_IS_LIST: tuple[bool, ...]
|
||||||
"""A tuple indicating which node outputs are lists, but will be connected to nodes that expect individual items.
|
"""A tuple indicating which node outputs are lists, but will be connected to nodes that expect individual items.
|
||||||
|
|
||||||
Connected nodes that do not implement `INPUT_IS_LIST` will be executed once for every item in the list.
|
Connected nodes that do not implement `INPUT_IS_LIST` will be executed once for every item in the list.
|
||||||
@@ -281,7 +293,7 @@ class ComfyNodeABC(ABC):
|
|||||||
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lists#list-processing
|
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lists#list-processing
|
||||||
"""
|
"""
|
||||||
|
|
||||||
RETURN_TYPES: tuple[IO]
|
RETURN_TYPES: tuple[IO, ...]
|
||||||
"""A tuple representing the outputs of this node.
|
"""A tuple representing the outputs of this node.
|
||||||
|
|
||||||
Usage::
|
Usage::
|
||||||
@@ -290,12 +302,12 @@ class ComfyNodeABC(ABC):
|
|||||||
|
|
||||||
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#return-types
|
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#return-types
|
||||||
"""
|
"""
|
||||||
RETURN_NAMES: tuple[str]
|
RETURN_NAMES: tuple[str, ...]
|
||||||
"""The output slot names for each item in `RETURN_TYPES`, e.g. ``RETURN_NAMES = ("count", "filter_string")``
|
"""The output slot names for each item in `RETURN_TYPES`, e.g. ``RETURN_NAMES = ("count", "filter_string")``
|
||||||
|
|
||||||
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#return-names
|
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#return-names
|
||||||
"""
|
"""
|
||||||
OUTPUT_TOOLTIPS: tuple[str]
|
OUTPUT_TOOLTIPS: tuple[str, ...]
|
||||||
"""A tuple of strings to use as tooltips for node outputs, one for each item in `RETURN_TYPES`."""
|
"""A tuple of strings to use as tooltips for node outputs, one for each item in `RETURN_TYPES`."""
|
||||||
FUNCTION: str
|
FUNCTION: str
|
||||||
"""The name of the function to execute as a literal string, e.g. `FUNCTION = "execute"`
|
"""The name of the function to execute as a literal string, e.g. `FUNCTION = "execute"`
|
||||||
|
|||||||
@@ -736,6 +736,7 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
|
|||||||
return control
|
return control
|
||||||
|
|
||||||
def load_controlnet(ckpt_path, model=None, model_options={}):
|
def load_controlnet(ckpt_path, model=None, model_options={}):
|
||||||
|
model_options = model_options.copy()
|
||||||
if "global_average_pooling" not in model_options:
|
if "global_average_pooling" not in model_options:
|
||||||
filename = os.path.splitext(ckpt_path)[0]
|
filename = os.path.splitext(ckpt_path)[0]
|
||||||
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
|
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
|
||||||
|
|||||||
@@ -116,7 +116,7 @@ class Dino2Embeddings(torch.nn.Module):
|
|||||||
def forward(self, pixel_values):
|
def forward(self, pixel_values):
|
||||||
x = self.patch_embeddings(pixel_values)
|
x = self.patch_embeddings(pixel_values)
|
||||||
# TODO: mask_token?
|
# TODO: mask_token?
|
||||||
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
|
x = torch.cat((self.cls_token.to(device=x.device, dtype=x.dtype).expand(x.shape[0], -1, -1), x), dim=1)
|
||||||
x = x + comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype)
|
x = x + comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|||||||
@@ -1345,28 +1345,52 @@ def sample_res_multistep_ancestral_cfg_pp(model, x, sigmas, extra_args=None, cal
|
|||||||
return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=eta, cfg_pp=True)
|
return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=eta, cfg_pp=True)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2.):
|
def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2., cfg_pp=False):
|
||||||
"""Gradient-estimation sampler. Paper: https://openreview.net/pdf?id=o2ND9v0CeK"""
|
"""Gradient-estimation sampler. Paper: https://openreview.net/pdf?id=o2ND9v0CeK"""
|
||||||
extra_args = {} if extra_args is None else extra_args
|
extra_args = {} if extra_args is None else extra_args
|
||||||
s_in = x.new_ones([x.shape[0]])
|
s_in = x.new_ones([x.shape[0]])
|
||||||
old_d = None
|
old_d = None
|
||||||
|
|
||||||
|
uncond_denoised = None
|
||||||
|
def post_cfg_function(args):
|
||||||
|
nonlocal uncond_denoised
|
||||||
|
uncond_denoised = args["uncond_denoised"]
|
||||||
|
return args["denoised"]
|
||||||
|
|
||||||
|
if cfg_pp:
|
||||||
|
model_options = extra_args.get("model_options", {}).copy()
|
||||||
|
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
|
||||||
|
|
||||||
for i in trange(len(sigmas) - 1, disable=disable):
|
for i in trange(len(sigmas) - 1, disable=disable):
|
||||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||||
|
if cfg_pp:
|
||||||
|
d = to_d(x, sigmas[i], uncond_denoised)
|
||||||
|
else:
|
||||||
d = to_d(x, sigmas[i], denoised)
|
d = to_d(x, sigmas[i], denoised)
|
||||||
if callback is not None:
|
if callback is not None:
|
||||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||||
dt = sigmas[i + 1] - sigmas[i]
|
dt = sigmas[i + 1] - sigmas[i]
|
||||||
if i == 0:
|
if i == 0:
|
||||||
# Euler method
|
# Euler method
|
||||||
|
if cfg_pp:
|
||||||
|
x = denoised + d * sigmas[i + 1]
|
||||||
|
else:
|
||||||
x = x + d * dt
|
x = x + d * dt
|
||||||
else:
|
else:
|
||||||
# Gradient estimation
|
# Gradient estimation
|
||||||
|
if cfg_pp:
|
||||||
|
d_bar = (ge_gamma - 1) * (d - old_d)
|
||||||
|
x = denoised + d * sigmas[i + 1] + d_bar * dt
|
||||||
|
else:
|
||||||
d_bar = ge_gamma * d + (1 - ge_gamma) * old_d
|
d_bar = ge_gamma * d + (1 - ge_gamma) * old_d
|
||||||
x = x + d_bar * dt
|
x = x + d_bar * dt
|
||||||
old_d = d
|
old_d = d
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample_gradient_estimation_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2.):
|
||||||
|
return sample_gradient_estimation(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, ge_gamma=ge_gamma, cfg_pp=True)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None, noise_scaler=None, max_stage=3):
|
def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None, noise_scaler=None, max_stage=3):
|
||||||
"""
|
"""
|
||||||
|
|||||||
183
comfy/ldm/chroma/layers.py
Normal file
183
comfy/ldm/chroma/layers.py
Normal file
@@ -0,0 +1,183 @@
|
|||||||
|
import torch
|
||||||
|
from torch import Tensor, nn
|
||||||
|
|
||||||
|
from comfy.ldm.flux.math import attention
|
||||||
|
from comfy.ldm.flux.layers import (
|
||||||
|
MLPEmbedder,
|
||||||
|
RMSNorm,
|
||||||
|
QKNorm,
|
||||||
|
SelfAttention,
|
||||||
|
ModulationOut,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class ChromaModulationOut(ModulationOut):
|
||||||
|
@classmethod
|
||||||
|
def from_offset(cls, tensor: torch.Tensor, offset: int = 0) -> ModulationOut:
|
||||||
|
return cls(
|
||||||
|
shift=tensor[:, offset : offset + 1, :],
|
||||||
|
scale=tensor[:, offset + 1 : offset + 2, :],
|
||||||
|
gate=tensor[:, offset + 2 : offset + 3, :],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class Approximator(nn.Module):
|
||||||
|
def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers = 5, dtype=None, device=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
self.in_proj = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
|
||||||
|
self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)])
|
||||||
|
self.norms = nn.ModuleList([RMSNorm(hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)])
|
||||||
|
self.out_proj = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self):
|
||||||
|
# Get the device of the module (assumes all parameters are on the same device)
|
||||||
|
return next(self.parameters()).device
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
x = self.in_proj(x)
|
||||||
|
|
||||||
|
for layer, norms in zip(self.layers, self.norms):
|
||||||
|
x = x + layer(norms(x))
|
||||||
|
|
||||||
|
x = self.out_proj(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class DoubleStreamBlock(nn.Module):
|
||||||
|
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||||
|
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
|
||||||
|
|
||||||
|
self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||||
|
self.img_mlp = nn.Sequential(
|
||||||
|
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
|
||||||
|
nn.GELU(approximate="tanh"),
|
||||||
|
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||||
|
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
|
||||||
|
|
||||||
|
self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||||
|
self.txt_mlp = nn.Sequential(
|
||||||
|
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
|
||||||
|
nn.GELU(approximate="tanh"),
|
||||||
|
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||||
|
)
|
||||||
|
self.flipped_img_txt = flipped_img_txt
|
||||||
|
|
||||||
|
def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None):
|
||||||
|
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
|
||||||
|
|
||||||
|
# prepare image for attention
|
||||||
|
img_modulated = self.img_norm1(img)
|
||||||
|
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
|
||||||
|
img_qkv = self.img_attn.qkv(img_modulated)
|
||||||
|
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||||
|
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
|
||||||
|
|
||||||
|
# prepare txt for attention
|
||||||
|
txt_modulated = self.txt_norm1(txt)
|
||||||
|
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
|
||||||
|
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
||||||
|
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||||
|
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
||||||
|
|
||||||
|
# run actual attention
|
||||||
|
attn = attention(torch.cat((txt_q, img_q), dim=2),
|
||||||
|
torch.cat((txt_k, img_k), dim=2),
|
||||||
|
torch.cat((txt_v, img_v), dim=2),
|
||||||
|
pe=pe, mask=attn_mask)
|
||||||
|
|
||||||
|
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
||||||
|
|
||||||
|
# calculate the img bloks
|
||||||
|
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
|
||||||
|
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
|
||||||
|
|
||||||
|
# calculate the txt bloks
|
||||||
|
txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
|
||||||
|
txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
|
||||||
|
|
||||||
|
if txt.dtype == torch.float16:
|
||||||
|
txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
|
||||||
|
|
||||||
|
return img, txt
|
||||||
|
|
||||||
|
|
||||||
|
class SingleStreamBlock(nn.Module):
|
||||||
|
"""
|
||||||
|
A DiT block with parallel linear layers as described in
|
||||||
|
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int,
|
||||||
|
num_heads: int,
|
||||||
|
mlp_ratio: float = 4.0,
|
||||||
|
qk_scale: float = None,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_dim = hidden_size
|
||||||
|
self.num_heads = num_heads
|
||||||
|
head_dim = hidden_size // num_heads
|
||||||
|
self.scale = qk_scale or head_dim**-0.5
|
||||||
|
|
||||||
|
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||||
|
# qkv and mlp_in
|
||||||
|
self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
|
||||||
|
# proj and mlp_out
|
||||||
|
self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
|
||||||
|
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
self.mlp_act = nn.GELU(approximate="tanh")
|
||||||
|
|
||||||
|
def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None) -> Tensor:
|
||||||
|
mod = vec
|
||||||
|
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
|
||||||
|
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||||
|
|
||||||
|
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||||
|
q, k = self.norm(q, k, v)
|
||||||
|
|
||||||
|
# compute attention
|
||||||
|
attn = attention(q, k, v, pe=pe, mask=attn_mask)
|
||||||
|
# compute activation in mlp stream, cat again and run second linear layer
|
||||||
|
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
||||||
|
x += mod.gate * output
|
||||||
|
if x.dtype == torch.float16:
|
||||||
|
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class LastLayer(nn.Module):
|
||||||
|
def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||||
|
self.linear = operations.Linear(hidden_size, out_channels, bias=True, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
def forward(self, x: Tensor, vec: Tensor) -> Tensor:
|
||||||
|
shift, scale = vec
|
||||||
|
shift = shift.squeeze(1)
|
||||||
|
scale = scale.squeeze(1)
|
||||||
|
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
|
||||||
|
x = self.linear(x)
|
||||||
|
return x
|
||||||
271
comfy/ldm/chroma/model.py
Normal file
271
comfy/ldm/chroma/model.py
Normal file
@@ -0,0 +1,271 @@
|
|||||||
|
#Original code can be found on: https://github.com/black-forest-labs/flux
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import Tensor, nn
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
import comfy.ldm.common_dit
|
||||||
|
|
||||||
|
from comfy.ldm.flux.layers import (
|
||||||
|
EmbedND,
|
||||||
|
timestep_embedding,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .layers import (
|
||||||
|
DoubleStreamBlock,
|
||||||
|
LastLayer,
|
||||||
|
SingleStreamBlock,
|
||||||
|
Approximator,
|
||||||
|
ChromaModulationOut,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ChromaParams:
|
||||||
|
in_channels: int
|
||||||
|
out_channels: int
|
||||||
|
context_in_dim: int
|
||||||
|
hidden_size: int
|
||||||
|
mlp_ratio: float
|
||||||
|
num_heads: int
|
||||||
|
depth: int
|
||||||
|
depth_single_blocks: int
|
||||||
|
axes_dim: list
|
||||||
|
theta: int
|
||||||
|
patch_size: int
|
||||||
|
qkv_bias: bool
|
||||||
|
in_dim: int
|
||||||
|
out_dim: int
|
||||||
|
hidden_dim: int
|
||||||
|
n_layers: int
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class Chroma(nn.Module):
|
||||||
|
"""
|
||||||
|
Transformer model for flow matching on sequences.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
self.dtype = dtype
|
||||||
|
params = ChromaParams(**kwargs)
|
||||||
|
self.params = params
|
||||||
|
self.patch_size = params.patch_size
|
||||||
|
self.in_channels = params.in_channels
|
||||||
|
self.out_channels = params.out_channels
|
||||||
|
if params.hidden_size % params.num_heads != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
|
||||||
|
)
|
||||||
|
pe_dim = params.hidden_size // params.num_heads
|
||||||
|
if sum(params.axes_dim) != pe_dim:
|
||||||
|
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
|
||||||
|
self.hidden_size = params.hidden_size
|
||||||
|
self.num_heads = params.num_heads
|
||||||
|
self.in_dim = params.in_dim
|
||||||
|
self.out_dim = params.out_dim
|
||||||
|
self.hidden_dim = params.hidden_dim
|
||||||
|
self.n_layers = params.n_layers
|
||||||
|
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
|
||||||
|
self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
|
||||||
|
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device)
|
||||||
|
# set as nn identity for now, will overwrite it later.
|
||||||
|
self.distilled_guidance_layer = Approximator(
|
||||||
|
in_dim=self.in_dim,
|
||||||
|
hidden_dim=self.hidden_dim,
|
||||||
|
out_dim=self.out_dim,
|
||||||
|
n_layers=self.n_layers,
|
||||||
|
dtype=dtype, device=device, operations=operations
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
self.double_blocks = nn.ModuleList(
|
||||||
|
[
|
||||||
|
DoubleStreamBlock(
|
||||||
|
self.hidden_size,
|
||||||
|
self.num_heads,
|
||||||
|
mlp_ratio=params.mlp_ratio,
|
||||||
|
qkv_bias=params.qkv_bias,
|
||||||
|
dtype=dtype, device=device, operations=operations
|
||||||
|
)
|
||||||
|
for _ in range(params.depth)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.single_blocks = nn.ModuleList(
|
||||||
|
[
|
||||||
|
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
|
||||||
|
for _ in range(params.depth_single_blocks)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
if final_layer:
|
||||||
|
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||||
|
|
||||||
|
self.skip_mmdit = []
|
||||||
|
self.skip_dit = []
|
||||||
|
self.lite = False
|
||||||
|
|
||||||
|
def get_modulations(self, tensor: torch.Tensor, block_type: str, *, idx: int = 0):
|
||||||
|
# This function slices up the modulations tensor which has the following layout:
|
||||||
|
# single : num_single_blocks * 3 elements
|
||||||
|
# double_img : num_double_blocks * 6 elements
|
||||||
|
# double_txt : num_double_blocks * 6 elements
|
||||||
|
# final : 2 elements
|
||||||
|
if block_type == "final":
|
||||||
|
return (tensor[:, -2:-1, :], tensor[:, -1:, :])
|
||||||
|
single_block_count = self.params.depth_single_blocks
|
||||||
|
double_block_count = self.params.depth
|
||||||
|
offset = 3 * idx
|
||||||
|
if block_type == "single":
|
||||||
|
return ChromaModulationOut.from_offset(tensor, offset)
|
||||||
|
# Double block modulations are 6 elements so we double 3 * idx.
|
||||||
|
offset *= 2
|
||||||
|
if block_type in {"double_img", "double_txt"}:
|
||||||
|
# Advance past the single block modulations.
|
||||||
|
offset += 3 * single_block_count
|
||||||
|
if block_type == "double_txt":
|
||||||
|
# Advance past the double block img modulations.
|
||||||
|
offset += 6 * double_block_count
|
||||||
|
return (
|
||||||
|
ChromaModulationOut.from_offset(tensor, offset),
|
||||||
|
ChromaModulationOut.from_offset(tensor, offset + 3),
|
||||||
|
)
|
||||||
|
raise ValueError("Bad block_type")
|
||||||
|
|
||||||
|
|
||||||
|
def forward_orig(
|
||||||
|
self,
|
||||||
|
img: Tensor,
|
||||||
|
img_ids: Tensor,
|
||||||
|
txt: Tensor,
|
||||||
|
txt_ids: Tensor,
|
||||||
|
timesteps: Tensor,
|
||||||
|
guidance: Tensor = None,
|
||||||
|
control = None,
|
||||||
|
transformer_options={},
|
||||||
|
attn_mask: Tensor = None,
|
||||||
|
) -> Tensor:
|
||||||
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
|
if img.ndim != 3 or txt.ndim != 3:
|
||||||
|
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||||
|
|
||||||
|
# running on sequences img
|
||||||
|
img = self.img_in(img)
|
||||||
|
|
||||||
|
# distilled vector guidance
|
||||||
|
mod_index_length = 344
|
||||||
|
distill_timestep = timestep_embedding(timesteps.detach().clone(), 16).to(img.device, img.dtype)
|
||||||
|
# guidance = guidance *
|
||||||
|
distil_guidance = timestep_embedding(guidance.detach().clone(), 16).to(img.device, img.dtype)
|
||||||
|
|
||||||
|
# get all modulation index
|
||||||
|
modulation_index = timestep_embedding(torch.arange(mod_index_length), 32).to(img.device, img.dtype)
|
||||||
|
# we need to broadcast the modulation index here so each batch has all of the index
|
||||||
|
modulation_index = modulation_index.unsqueeze(0).repeat(img.shape[0], 1, 1).to(img.device, img.dtype)
|
||||||
|
# and we need to broadcast timestep and guidance along too
|
||||||
|
timestep_guidance = torch.cat([distill_timestep, distil_guidance], dim=1).unsqueeze(1).repeat(1, mod_index_length, 1).to(img.dtype).to(img.device, img.dtype)
|
||||||
|
# then and only then we could concatenate it together
|
||||||
|
input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1).to(img.device, img.dtype)
|
||||||
|
|
||||||
|
mod_vectors = self.distilled_guidance_layer(input_vec)
|
||||||
|
|
||||||
|
txt = self.txt_in(txt)
|
||||||
|
|
||||||
|
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||||
|
pe = self.pe_embedder(ids)
|
||||||
|
|
||||||
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
|
for i, block in enumerate(self.double_blocks):
|
||||||
|
if i not in self.skip_mmdit:
|
||||||
|
double_mod = (
|
||||||
|
self.get_modulations(mod_vectors, "double_img", idx=i),
|
||||||
|
self.get_modulations(mod_vectors, "double_txt", idx=i),
|
||||||
|
)
|
||||||
|
if ("double_block", i) in blocks_replace:
|
||||||
|
def block_wrap(args):
|
||||||
|
out = {}
|
||||||
|
out["img"], out["txt"] = block(img=args["img"],
|
||||||
|
txt=args["txt"],
|
||||||
|
vec=args["vec"],
|
||||||
|
pe=args["pe"],
|
||||||
|
attn_mask=args.get("attn_mask"))
|
||||||
|
return out
|
||||||
|
|
||||||
|
out = blocks_replace[("double_block", i)]({"img": img,
|
||||||
|
"txt": txt,
|
||||||
|
"vec": double_mod,
|
||||||
|
"pe": pe,
|
||||||
|
"attn_mask": attn_mask},
|
||||||
|
{"original_block": block_wrap})
|
||||||
|
txt = out["txt"]
|
||||||
|
img = out["img"]
|
||||||
|
else:
|
||||||
|
img, txt = block(img=img,
|
||||||
|
txt=txt,
|
||||||
|
vec=double_mod,
|
||||||
|
pe=pe,
|
||||||
|
attn_mask=attn_mask)
|
||||||
|
|
||||||
|
if control is not None: # Controlnet
|
||||||
|
control_i = control.get("input")
|
||||||
|
if i < len(control_i):
|
||||||
|
add = control_i[i]
|
||||||
|
if add is not None:
|
||||||
|
img += add
|
||||||
|
|
||||||
|
img = torch.cat((txt, img), 1)
|
||||||
|
|
||||||
|
for i, block in enumerate(self.single_blocks):
|
||||||
|
if i not in self.skip_dit:
|
||||||
|
single_mod = self.get_modulations(mod_vectors, "single", idx=i)
|
||||||
|
if ("single_block", i) in blocks_replace:
|
||||||
|
def block_wrap(args):
|
||||||
|
out = {}
|
||||||
|
out["img"] = block(args["img"],
|
||||||
|
vec=args["vec"],
|
||||||
|
pe=args["pe"],
|
||||||
|
attn_mask=args.get("attn_mask"))
|
||||||
|
return out
|
||||||
|
|
||||||
|
out = blocks_replace[("single_block", i)]({"img": img,
|
||||||
|
"vec": single_mod,
|
||||||
|
"pe": pe,
|
||||||
|
"attn_mask": attn_mask},
|
||||||
|
{"original_block": block_wrap})
|
||||||
|
img = out["img"]
|
||||||
|
else:
|
||||||
|
img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask)
|
||||||
|
|
||||||
|
if control is not None: # Controlnet
|
||||||
|
control_o = control.get("output")
|
||||||
|
if i < len(control_o):
|
||||||
|
add = control_o[i]
|
||||||
|
if add is not None:
|
||||||
|
img[:, txt.shape[1] :, ...] += add
|
||||||
|
|
||||||
|
img = img[:, txt.shape[1] :, ...]
|
||||||
|
final_mod = self.get_modulations(mod_vectors, "final")
|
||||||
|
img = self.final_layer(img, vec=final_mod) # (N, T, patch_size ** 2 * out_channels)
|
||||||
|
return img
|
||||||
|
|
||||||
|
def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
|
||||||
|
bs, c, h, w = x.shape
|
||||||
|
patch_size = 2
|
||||||
|
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
||||||
|
|
||||||
|
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
||||||
|
|
||||||
|
h_len = ((h + (patch_size // 2)) // patch_size)
|
||||||
|
w_len = ((w + (patch_size // 2)) // patch_size)
|
||||||
|
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||||
|
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
|
||||||
|
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||||
|
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||||
|
|
||||||
|
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||||
|
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
|
||||||
|
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
|
||||||
@@ -23,7 +23,6 @@ from einops import rearrange, repeat
|
|||||||
from einops.layers.torch import Rearrange
|
from einops.layers.torch import Rearrange
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
|
|
||||||
from comfy.ldm.modules.attention import optimized_attention
|
from comfy.ldm.modules.attention import optimized_attention
|
||||||
|
|
||||||
|
|
||||||
@@ -37,11 +36,11 @@ def apply_rotary_pos_emb(
|
|||||||
return t_out
|
return t_out
|
||||||
|
|
||||||
|
|
||||||
def get_normalization(name: str, channels: int, weight_args={}):
|
def get_normalization(name: str, channels: int, weight_args={}, operations=None):
|
||||||
if name == "I":
|
if name == "I":
|
||||||
return nn.Identity()
|
return nn.Identity()
|
||||||
elif name == "R":
|
elif name == "R":
|
||||||
return RMSNorm(channels, elementwise_affine=True, eps=1e-6, **weight_args)
|
return operations.RMSNorm(channels, elementwise_affine=True, eps=1e-6, **weight_args)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Normalization {name} not found")
|
raise ValueError(f"Normalization {name} not found")
|
||||||
|
|
||||||
@@ -120,15 +119,15 @@ class Attention(nn.Module):
|
|||||||
|
|
||||||
self.to_q = nn.Sequential(
|
self.to_q = nn.Sequential(
|
||||||
operations.Linear(query_dim, inner_dim, bias=qkv_bias, **weight_args),
|
operations.Linear(query_dim, inner_dim, bias=qkv_bias, **weight_args),
|
||||||
get_normalization(qkv_norm[0], norm_dim),
|
get_normalization(qkv_norm[0], norm_dim, weight_args=weight_args, operations=operations),
|
||||||
)
|
)
|
||||||
self.to_k = nn.Sequential(
|
self.to_k = nn.Sequential(
|
||||||
operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args),
|
operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args),
|
||||||
get_normalization(qkv_norm[1], norm_dim),
|
get_normalization(qkv_norm[1], norm_dim, weight_args=weight_args, operations=operations),
|
||||||
)
|
)
|
||||||
self.to_v = nn.Sequential(
|
self.to_v = nn.Sequential(
|
||||||
operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args),
|
operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args),
|
||||||
get_normalization(qkv_norm[2], norm_dim),
|
get_normalization(qkv_norm[2], norm_dim, weight_args=weight_args, operations=operations),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.to_out = nn.Sequential(
|
self.to_out = nn.Sequential(
|
||||||
|
|||||||
@@ -27,8 +27,6 @@ from torchvision import transforms
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
|
|
||||||
|
|
||||||
from .blocks import (
|
from .blocks import (
|
||||||
FinalLayer,
|
FinalLayer,
|
||||||
GeneralDITTransformerBlock,
|
GeneralDITTransformerBlock,
|
||||||
@@ -195,7 +193,7 @@ class GeneralDIT(nn.Module):
|
|||||||
|
|
||||||
if self.affline_emb_norm:
|
if self.affline_emb_norm:
|
||||||
logging.debug("Building affine embedding normalization layer")
|
logging.debug("Building affine embedding normalization layer")
|
||||||
self.affline_norm = RMSNorm(model_channels, elementwise_affine=True, eps=1e-6)
|
self.affline_norm = operations.RMSNorm(model_channels, elementwise_affine=True, eps=1e-6, device=device, dtype=dtype)
|
||||||
else:
|
else:
|
||||||
self.affline_norm = nn.Identity()
|
self.affline_norm = nn.Identity()
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ from comfy.ldm.modules.attention import optimized_attention
|
|||||||
from .layers import (
|
from .layers import (
|
||||||
FeedForward,
|
FeedForward,
|
||||||
PatchEmbed,
|
PatchEmbed,
|
||||||
RMSNorm,
|
|
||||||
TimestepEmbedder,
|
TimestepEmbedder,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -90,10 +89,10 @@ class AsymmetricAttention(nn.Module):
|
|||||||
|
|
||||||
# Query and key normalization for stability.
|
# Query and key normalization for stability.
|
||||||
assert qk_norm
|
assert qk_norm
|
||||||
self.q_norm_x = RMSNorm(self.head_dim, device=device, dtype=dtype)
|
self.q_norm_x = operations.RMSNorm(self.head_dim, eps=1e-5, device=device, dtype=dtype)
|
||||||
self.k_norm_x = RMSNorm(self.head_dim, device=device, dtype=dtype)
|
self.k_norm_x = operations.RMSNorm(self.head_dim, eps=1e-5, device=device, dtype=dtype)
|
||||||
self.q_norm_y = RMSNorm(self.head_dim, device=device, dtype=dtype)
|
self.q_norm_y = operations.RMSNorm(self.head_dim, eps=1e-5, device=device, dtype=dtype)
|
||||||
self.k_norm_y = RMSNorm(self.head_dim, device=device, dtype=dtype)
|
self.k_norm_y = operations.RMSNorm(self.head_dim, eps=1e-5, device=device, dtype=dtype)
|
||||||
|
|
||||||
# Output layers. y features go back down from dim_x -> dim_y.
|
# Output layers. y features go back down from dim_x -> dim_y.
|
||||||
self.proj_x = operations.Linear(dim_x, dim_x, bias=out_bias, device=device, dtype=dtype)
|
self.proj_x = operations.Linear(dim_x, dim_x, bias=out_bias, device=device, dtype=dtype)
|
||||||
|
|||||||
@@ -151,14 +151,3 @@ class PatchEmbed(nn.Module):
|
|||||||
|
|
||||||
x = self.norm(x)
|
x = self.norm(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class RMSNorm(torch.nn.Module):
|
|
||||||
def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
|
|
||||||
super().__init__()
|
|
||||||
self.eps = eps
|
|
||||||
self.weight = torch.nn.Parameter(torch.empty(hidden_size, device=device, dtype=dtype))
|
|
||||||
self.register_parameter("bias", None)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)
|
|
||||||
|
|||||||
@@ -699,10 +699,13 @@ class HiDreamImageTransformer2DModel(nn.Module):
|
|||||||
y: Optional[torch.Tensor] = None,
|
y: Optional[torch.Tensor] = None,
|
||||||
context: Optional[torch.Tensor] = None,
|
context: Optional[torch.Tensor] = None,
|
||||||
encoder_hidden_states_llama3=None,
|
encoder_hidden_states_llama3=None,
|
||||||
|
image_cond=None,
|
||||||
control = None,
|
control = None,
|
||||||
transformer_options = {},
|
transformer_options = {},
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
bs, c, h, w = x.shape
|
bs, c, h, w = x.shape
|
||||||
|
if image_cond is not None:
|
||||||
|
x = torch.cat([x, image_cond], dim=-1)
|
||||||
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
|
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
|
||||||
timesteps = t
|
timesteps = t
|
||||||
pooled_embeds = y
|
pooled_embeds = y
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
import comfy.ops
|
import comfy.ops
|
||||||
from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed, RMSNorm
|
from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed
|
||||||
from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
|
from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
|
||||||
from torch.utils import checkpoint
|
from torch.utils import checkpoint
|
||||||
|
|
||||||
@@ -51,7 +51,7 @@ class HunYuanDiTBlock(nn.Module):
|
|||||||
if norm_type == "layer":
|
if norm_type == "layer":
|
||||||
norm_layer = operations.LayerNorm
|
norm_layer = operations.LayerNorm
|
||||||
elif norm_type == "rms":
|
elif norm_type == "rms":
|
||||||
norm_layer = RMSNorm
|
norm_layer = operations.RMSNorm
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown norm_type: {norm_type}")
|
raise ValueError(f"Unknown norm_type: {norm_type}")
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
import comfy.ldm.modules.attention
|
import comfy.ldm.modules.attention
|
||||||
from comfy.ldm.genmo.joint_model.layers import RMSNorm
|
|
||||||
import comfy.ldm.common_dit
|
import comfy.ldm.common_dit
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
import math
|
import math
|
||||||
@@ -262,8 +261,8 @@ class CrossAttention(nn.Module):
|
|||||||
self.heads = heads
|
self.heads = heads
|
||||||
self.dim_head = dim_head
|
self.dim_head = dim_head
|
||||||
|
|
||||||
self.q_norm = RMSNorm(inner_dim, dtype=dtype, device=device)
|
self.q_norm = operations.RMSNorm(inner_dim, dtype=dtype, device=device)
|
||||||
self.k_norm = RMSNorm(inner_dim, dtype=dtype, device=device)
|
self.k_norm = operations.RMSNorm(inner_dim, dtype=dtype, device=device)
|
||||||
|
|
||||||
self.to_q = operations.Linear(query_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
self.to_q = operations.Linear(query_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
||||||
self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import torch.nn as nn
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import comfy.ldm.common_dit
|
import comfy.ldm.common_dit
|
||||||
|
|
||||||
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, RMSNorm
|
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
|
||||||
from comfy.ldm.modules.attention import optimized_attention_masked
|
from comfy.ldm.modules.attention import optimized_attention_masked
|
||||||
from comfy.ldm.flux.layers import EmbedND
|
from comfy.ldm.flux.layers import EmbedND
|
||||||
|
|
||||||
@@ -64,8 +64,8 @@ class JointAttention(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if qk_norm:
|
if qk_norm:
|
||||||
self.q_norm = RMSNorm(self.head_dim, elementwise_affine=True, **operation_settings)
|
self.q_norm = operation_settings.get("operations").RMSNorm(self.head_dim, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||||
self.k_norm = RMSNorm(self.head_dim, elementwise_affine=True, **operation_settings)
|
self.k_norm = operation_settings.get("operations").RMSNorm(self.head_dim, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||||
else:
|
else:
|
||||||
self.q_norm = self.k_norm = nn.Identity()
|
self.q_norm = self.k_norm = nn.Identity()
|
||||||
|
|
||||||
@@ -242,11 +242,11 @@ class JointTransformerBlock(nn.Module):
|
|||||||
operation_settings=operation_settings,
|
operation_settings=operation_settings,
|
||||||
)
|
)
|
||||||
self.layer_id = layer_id
|
self.layer_id = layer_id
|
||||||
self.attention_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
|
self.attention_norm1 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||||
self.ffn_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
|
self.ffn_norm1 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||||
|
|
||||||
self.attention_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
|
self.attention_norm2 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||||
self.ffn_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
|
self.ffn_norm2 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||||
|
|
||||||
self.modulation = modulation
|
self.modulation = modulation
|
||||||
if modulation:
|
if modulation:
|
||||||
@@ -431,7 +431,7 @@ class NextDiT(nn.Module):
|
|||||||
|
|
||||||
self.t_embedder = TimestepEmbedder(min(dim, 1024), **operation_settings)
|
self.t_embedder = TimestepEmbedder(min(dim, 1024), **operation_settings)
|
||||||
self.cap_embedder = nn.Sequential(
|
self.cap_embedder = nn.Sequential(
|
||||||
RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, **operation_settings),
|
operation_settings.get("operations").RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
|
||||||
operation_settings.get("operations").Linear(
|
operation_settings.get("operations").Linear(
|
||||||
cap_feat_dim,
|
cap_feat_dim,
|
||||||
dim,
|
dim,
|
||||||
@@ -457,7 +457,7 @@ class NextDiT(nn.Module):
|
|||||||
for layer_id in range(n_layers)
|
for layer_id in range(n_layers)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
self.norm_final = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
|
self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||||
self.final_layer = FinalLayer(dim, patch_size, self.out_channels, operation_settings=operation_settings)
|
self.final_layer = FinalLayer(dim, patch_size, self.out_channels, operation_settings=operation_settings)
|
||||||
|
|
||||||
assert (dim // n_heads) == sum(axes_dims)
|
assert (dim // n_heads) == sum(axes_dims)
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ from einops import repeat
|
|||||||
from comfy.ldm.modules.attention import optimized_attention
|
from comfy.ldm.modules.attention import optimized_attention
|
||||||
from comfy.ldm.flux.layers import EmbedND
|
from comfy.ldm.flux.layers import EmbedND
|
||||||
from comfy.ldm.flux.math import apply_rope
|
from comfy.ldm.flux.math import apply_rope
|
||||||
from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
|
|
||||||
import comfy.ldm.common_dit
|
import comfy.ldm.common_dit
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
|
|
||||||
@@ -49,8 +48,8 @@ class WanSelfAttention(nn.Module):
|
|||||||
self.k = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
self.k = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||||
self.v = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
self.v = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||||
self.o = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
self.o = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||||
self.norm_q = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
|
self.norm_q = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
|
||||||
self.norm_k = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
|
self.norm_k = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
|
||||||
|
|
||||||
def forward(self, x, freqs):
|
def forward(self, x, freqs):
|
||||||
r"""
|
r"""
|
||||||
@@ -114,7 +113,7 @@ class WanI2VCrossAttention(WanSelfAttention):
|
|||||||
self.k_img = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
self.k_img = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||||
self.v_img = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
self.v_img = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||||
# self.alpha = nn.Parameter(torch.zeros((1, )))
|
# self.alpha = nn.Parameter(torch.zeros((1, )))
|
||||||
self.norm_k_img = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
|
self.norm_k_img = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
|
||||||
|
|
||||||
def forward(self, x, context, context_img_len):
|
def forward(self, x, context, context_img_len):
|
||||||
r"""
|
r"""
|
||||||
@@ -220,6 +219,34 @@ class WanAttentionBlock(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class VaceWanAttentionBlock(WanAttentionBlock):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
cross_attn_type,
|
||||||
|
dim,
|
||||||
|
ffn_dim,
|
||||||
|
num_heads,
|
||||||
|
window_size=(-1, -1),
|
||||||
|
qk_norm=True,
|
||||||
|
cross_attn_norm=False,
|
||||||
|
eps=1e-6,
|
||||||
|
block_id=0,
|
||||||
|
operation_settings={}
|
||||||
|
):
|
||||||
|
super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
|
||||||
|
self.block_id = block_id
|
||||||
|
if block_id == 0:
|
||||||
|
self.before_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||||
|
self.after_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||||
|
|
||||||
|
def forward(self, c, x, **kwargs):
|
||||||
|
if self.block_id == 0:
|
||||||
|
c = self.before_proj(c) + x
|
||||||
|
c = super().forward(c, **kwargs)
|
||||||
|
c_skip = self.after_proj(c)
|
||||||
|
return c_skip, c
|
||||||
|
|
||||||
|
|
||||||
class Head(nn.Module):
|
class Head(nn.Module):
|
||||||
|
|
||||||
def __init__(self, dim, out_dim, patch_size, eps=1e-6, operation_settings={}):
|
def __init__(self, dim, out_dim, patch_size, eps=1e-6, operation_settings={}):
|
||||||
@@ -395,6 +422,7 @@ class WanModel(torch.nn.Module):
|
|||||||
clip_fea=None,
|
clip_fea=None,
|
||||||
freqs=None,
|
freqs=None,
|
||||||
transformer_options={},
|
transformer_options={},
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Forward pass through the diffusion model
|
Forward pass through the diffusion model
|
||||||
@@ -457,7 +485,7 @@ class WanModel(torch.nn.Module):
|
|||||||
x = self.unpatchify(x, grid_sizes)
|
x = self.unpatchify(x, grid_sizes)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def forward(self, x, timestep, context, clip_fea=None, transformer_options={},**kwargs):
|
def forward(self, x, timestep, context, clip_fea=None, transformer_options={}, **kwargs):
|
||||||
bs, c, t, h, w = x.shape
|
bs, c, t, h, w = x.shape
|
||||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
||||||
patch_size = self.patch_size
|
patch_size = self.patch_size
|
||||||
@@ -471,7 +499,7 @@ class WanModel(torch.nn.Module):
|
|||||||
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
|
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
|
||||||
|
|
||||||
freqs = self.rope_embedder(img_ids).movedim(1, 2)
|
freqs = self.rope_embedder(img_ids).movedim(1, 2)
|
||||||
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options)[:, :, :t, :h, :w]
|
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w]
|
||||||
|
|
||||||
def unpatchify(self, x, grid_sizes):
|
def unpatchify(self, x, grid_sizes):
|
||||||
r"""
|
r"""
|
||||||
@@ -496,3 +524,116 @@ class WanModel(torch.nn.Module):
|
|||||||
u = torch.einsum('bfhwpqrc->bcfphqwr', u)
|
u = torch.einsum('bfhwpqrc->bcfphqwr', u)
|
||||||
u = u.reshape(b, c, *[i * j for i, j in zip(grid_sizes, self.patch_size)])
|
u = u.reshape(b, c, *[i * j for i, j in zip(grid_sizes, self.patch_size)])
|
||||||
return u
|
return u
|
||||||
|
|
||||||
|
|
||||||
|
class VaceWanModel(WanModel):
|
||||||
|
r"""
|
||||||
|
Wan diffusion backbone supporting both text-to-video and image-to-video.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
model_type='vace',
|
||||||
|
patch_size=(1, 2, 2),
|
||||||
|
text_len=512,
|
||||||
|
in_dim=16,
|
||||||
|
dim=2048,
|
||||||
|
ffn_dim=8192,
|
||||||
|
freq_dim=256,
|
||||||
|
text_dim=4096,
|
||||||
|
out_dim=16,
|
||||||
|
num_heads=16,
|
||||||
|
num_layers=32,
|
||||||
|
window_size=(-1, -1),
|
||||||
|
qk_norm=True,
|
||||||
|
cross_attn_norm=True,
|
||||||
|
eps=1e-6,
|
||||||
|
flf_pos_embed_token_number=None,
|
||||||
|
image_model=None,
|
||||||
|
vace_layers=None,
|
||||||
|
vace_in_dim=None,
|
||||||
|
device=None,
|
||||||
|
dtype=None,
|
||||||
|
operations=None,
|
||||||
|
):
|
||||||
|
|
||||||
|
super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
|
||||||
|
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
|
||||||
|
|
||||||
|
# Vace
|
||||||
|
if vace_layers is not None:
|
||||||
|
self.vace_layers = vace_layers
|
||||||
|
self.vace_in_dim = vace_in_dim
|
||||||
|
# vace blocks
|
||||||
|
self.vace_blocks = nn.ModuleList([
|
||||||
|
VaceWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm, self.cross_attn_norm, self.eps, block_id=i, operation_settings=operation_settings)
|
||||||
|
for i in range(self.vace_layers)
|
||||||
|
])
|
||||||
|
|
||||||
|
self.vace_layers_mapping = {i: n for n, i in enumerate(range(0, self.num_layers, self.num_layers // self.vace_layers))}
|
||||||
|
# vace patch embeddings
|
||||||
|
self.vace_patch_embedding = operations.Conv3d(
|
||||||
|
self.vace_in_dim, self.dim, kernel_size=self.patch_size, stride=self.patch_size, device=device, dtype=torch.float32
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward_orig(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
t,
|
||||||
|
context,
|
||||||
|
vace_context,
|
||||||
|
vace_strength=1.0,
|
||||||
|
clip_fea=None,
|
||||||
|
freqs=None,
|
||||||
|
transformer_options={},
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
# embeddings
|
||||||
|
x = self.patch_embedding(x.float()).to(x.dtype)
|
||||||
|
grid_sizes = x.shape[2:]
|
||||||
|
x = x.flatten(2).transpose(1, 2)
|
||||||
|
|
||||||
|
# time embeddings
|
||||||
|
e = self.time_embedding(
|
||||||
|
sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype))
|
||||||
|
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
|
||||||
|
|
||||||
|
# context
|
||||||
|
context = self.text_embedding(context)
|
||||||
|
|
||||||
|
context_img_len = None
|
||||||
|
if clip_fea is not None:
|
||||||
|
if self.img_emb is not None:
|
||||||
|
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
|
||||||
|
context = torch.concat([context_clip, context], dim=1)
|
||||||
|
context_img_len = clip_fea.shape[-2]
|
||||||
|
|
||||||
|
c = self.vace_patch_embedding(vace_context.float()).to(vace_context.dtype)
|
||||||
|
c = c.flatten(2).transpose(1, 2)
|
||||||
|
|
||||||
|
# arguments
|
||||||
|
x_orig = x
|
||||||
|
|
||||||
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
|
for i, block in enumerate(self.blocks):
|
||||||
|
if ("double_block", i) in blocks_replace:
|
||||||
|
def block_wrap(args):
|
||||||
|
out = {}
|
||||||
|
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
|
||||||
|
return out
|
||||||
|
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
|
||||||
|
x = out["img"]
|
||||||
|
else:
|
||||||
|
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
|
||||||
|
|
||||||
|
ii = self.vace_layers_mapping.get(i, None)
|
||||||
|
if ii is not None:
|
||||||
|
c_skip, c = self.vace_blocks[ii](c, x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
|
||||||
|
x += c_skip * vace_strength
|
||||||
|
del c_skip
|
||||||
|
# head
|
||||||
|
x = self.head(x, e)
|
||||||
|
|
||||||
|
# unpatchify
|
||||||
|
x = self.unpatchify(x, grid_sizes)
|
||||||
|
return x
|
||||||
|
|||||||
328
comfy/lora.py
328
comfy/lora.py
@@ -20,6 +20,7 @@ from __future__ import annotations
|
|||||||
import comfy.utils
|
import comfy.utils
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.model_base
|
import comfy.model_base
|
||||||
|
import comfy.weight_adapter as weight_adapter
|
||||||
import logging
|
import logging
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -49,139 +50,12 @@ def load_lora(lora, to_load, log_missing=True):
|
|||||||
dora_scale = lora[dora_scale_name]
|
dora_scale = lora[dora_scale_name]
|
||||||
loaded_keys.add(dora_scale_name)
|
loaded_keys.add(dora_scale_name)
|
||||||
|
|
||||||
reshape_name = "{}.reshape_weight".format(x)
|
for adapter_cls in weight_adapter.adapters:
|
||||||
reshape = None
|
adapter = adapter_cls.load(x, lora, alpha, dora_scale, loaded_keys)
|
||||||
if reshape_name in lora.keys():
|
if adapter is not None:
|
||||||
try:
|
patch_dict[to_load[x]] = adapter
|
||||||
reshape = lora[reshape_name].tolist()
|
loaded_keys.update(adapter.loaded_keys)
|
||||||
loaded_keys.add(reshape_name)
|
continue
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
regular_lora = "{}.lora_up.weight".format(x)
|
|
||||||
diffusers_lora = "{}_lora.up.weight".format(x)
|
|
||||||
diffusers2_lora = "{}.lora_B.weight".format(x)
|
|
||||||
diffusers3_lora = "{}.lora.up.weight".format(x)
|
|
||||||
mochi_lora = "{}.lora_B".format(x)
|
|
||||||
transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
|
|
||||||
A_name = None
|
|
||||||
|
|
||||||
if regular_lora in lora.keys():
|
|
||||||
A_name = regular_lora
|
|
||||||
B_name = "{}.lora_down.weight".format(x)
|
|
||||||
mid_name = "{}.lora_mid.weight".format(x)
|
|
||||||
elif diffusers_lora in lora.keys():
|
|
||||||
A_name = diffusers_lora
|
|
||||||
B_name = "{}_lora.down.weight".format(x)
|
|
||||||
mid_name = None
|
|
||||||
elif diffusers2_lora in lora.keys():
|
|
||||||
A_name = diffusers2_lora
|
|
||||||
B_name = "{}.lora_A.weight".format(x)
|
|
||||||
mid_name = None
|
|
||||||
elif diffusers3_lora in lora.keys():
|
|
||||||
A_name = diffusers3_lora
|
|
||||||
B_name = "{}.lora.down.weight".format(x)
|
|
||||||
mid_name = None
|
|
||||||
elif mochi_lora in lora.keys():
|
|
||||||
A_name = mochi_lora
|
|
||||||
B_name = "{}.lora_A".format(x)
|
|
||||||
mid_name = None
|
|
||||||
elif transformers_lora in lora.keys():
|
|
||||||
A_name = transformers_lora
|
|
||||||
B_name ="{}.lora_linear_layer.down.weight".format(x)
|
|
||||||
mid_name = None
|
|
||||||
|
|
||||||
if A_name is not None:
|
|
||||||
mid = None
|
|
||||||
if mid_name is not None and mid_name in lora.keys():
|
|
||||||
mid = lora[mid_name]
|
|
||||||
loaded_keys.add(mid_name)
|
|
||||||
patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale, reshape))
|
|
||||||
loaded_keys.add(A_name)
|
|
||||||
loaded_keys.add(B_name)
|
|
||||||
|
|
||||||
|
|
||||||
######## loha
|
|
||||||
hada_w1_a_name = "{}.hada_w1_a".format(x)
|
|
||||||
hada_w1_b_name = "{}.hada_w1_b".format(x)
|
|
||||||
hada_w2_a_name = "{}.hada_w2_a".format(x)
|
|
||||||
hada_w2_b_name = "{}.hada_w2_b".format(x)
|
|
||||||
hada_t1_name = "{}.hada_t1".format(x)
|
|
||||||
hada_t2_name = "{}.hada_t2".format(x)
|
|
||||||
if hada_w1_a_name in lora.keys():
|
|
||||||
hada_t1 = None
|
|
||||||
hada_t2 = None
|
|
||||||
if hada_t1_name in lora.keys():
|
|
||||||
hada_t1 = lora[hada_t1_name]
|
|
||||||
hada_t2 = lora[hada_t2_name]
|
|
||||||
loaded_keys.add(hada_t1_name)
|
|
||||||
loaded_keys.add(hada_t2_name)
|
|
||||||
|
|
||||||
patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale))
|
|
||||||
loaded_keys.add(hada_w1_a_name)
|
|
||||||
loaded_keys.add(hada_w1_b_name)
|
|
||||||
loaded_keys.add(hada_w2_a_name)
|
|
||||||
loaded_keys.add(hada_w2_b_name)
|
|
||||||
|
|
||||||
|
|
||||||
######## lokr
|
|
||||||
lokr_w1_name = "{}.lokr_w1".format(x)
|
|
||||||
lokr_w2_name = "{}.lokr_w2".format(x)
|
|
||||||
lokr_w1_a_name = "{}.lokr_w1_a".format(x)
|
|
||||||
lokr_w1_b_name = "{}.lokr_w1_b".format(x)
|
|
||||||
lokr_t2_name = "{}.lokr_t2".format(x)
|
|
||||||
lokr_w2_a_name = "{}.lokr_w2_a".format(x)
|
|
||||||
lokr_w2_b_name = "{}.lokr_w2_b".format(x)
|
|
||||||
|
|
||||||
lokr_w1 = None
|
|
||||||
if lokr_w1_name in lora.keys():
|
|
||||||
lokr_w1 = lora[lokr_w1_name]
|
|
||||||
loaded_keys.add(lokr_w1_name)
|
|
||||||
|
|
||||||
lokr_w2 = None
|
|
||||||
if lokr_w2_name in lora.keys():
|
|
||||||
lokr_w2 = lora[lokr_w2_name]
|
|
||||||
loaded_keys.add(lokr_w2_name)
|
|
||||||
|
|
||||||
lokr_w1_a = None
|
|
||||||
if lokr_w1_a_name in lora.keys():
|
|
||||||
lokr_w1_a = lora[lokr_w1_a_name]
|
|
||||||
loaded_keys.add(lokr_w1_a_name)
|
|
||||||
|
|
||||||
lokr_w1_b = None
|
|
||||||
if lokr_w1_b_name in lora.keys():
|
|
||||||
lokr_w1_b = lora[lokr_w1_b_name]
|
|
||||||
loaded_keys.add(lokr_w1_b_name)
|
|
||||||
|
|
||||||
lokr_w2_a = None
|
|
||||||
if lokr_w2_a_name in lora.keys():
|
|
||||||
lokr_w2_a = lora[lokr_w2_a_name]
|
|
||||||
loaded_keys.add(lokr_w2_a_name)
|
|
||||||
|
|
||||||
lokr_w2_b = None
|
|
||||||
if lokr_w2_b_name in lora.keys():
|
|
||||||
lokr_w2_b = lora[lokr_w2_b_name]
|
|
||||||
loaded_keys.add(lokr_w2_b_name)
|
|
||||||
|
|
||||||
lokr_t2 = None
|
|
||||||
if lokr_t2_name in lora.keys():
|
|
||||||
lokr_t2 = lora[lokr_t2_name]
|
|
||||||
loaded_keys.add(lokr_t2_name)
|
|
||||||
|
|
||||||
if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
|
|
||||||
patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale))
|
|
||||||
|
|
||||||
#glora
|
|
||||||
a1_name = "{}.a1.weight".format(x)
|
|
||||||
a2_name = "{}.a2.weight".format(x)
|
|
||||||
b1_name = "{}.b1.weight".format(x)
|
|
||||||
b2_name = "{}.b2.weight".format(x)
|
|
||||||
if a1_name in lora:
|
|
||||||
patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale))
|
|
||||||
loaded_keys.add(a1_name)
|
|
||||||
loaded_keys.add(a2_name)
|
|
||||||
loaded_keys.add(b1_name)
|
|
||||||
loaded_keys.add(b2_name)
|
|
||||||
|
|
||||||
w_norm_name = "{}.w_norm".format(x)
|
w_norm_name = "{}.w_norm".format(x)
|
||||||
b_norm_name = "{}.b_norm".format(x)
|
b_norm_name = "{}.b_norm".format(x)
|
||||||
@@ -405,29 +279,16 @@ def model_lora_keys_unet(model, key_map={}):
|
|||||||
key_map["transformer.{}".format(key_lora)] = k
|
key_map["transformer.{}".format(key_lora)] = k
|
||||||
key_map["diffusion_model.{}".format(key_lora)] = k # Old loras
|
key_map["diffusion_model.{}".format(key_lora)] = k # Old loras
|
||||||
|
|
||||||
|
if isinstance(model, comfy.model_base.HiDream):
|
||||||
|
for k in sdk:
|
||||||
|
if k.startswith("diffusion_model."):
|
||||||
|
if k.endswith(".weight"):
|
||||||
|
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
|
||||||
|
key_map["lycoris_{}".format(key_lora)] = k #SimpleTuner lycoris format
|
||||||
|
|
||||||
return key_map
|
return key_map
|
||||||
|
|
||||||
|
|
||||||
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function):
|
|
||||||
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
|
|
||||||
lora_diff *= alpha
|
|
||||||
weight_calc = weight + function(lora_diff).type(weight.dtype)
|
|
||||||
weight_norm = (
|
|
||||||
weight_calc.transpose(0, 1)
|
|
||||||
.reshape(weight_calc.shape[1], -1)
|
|
||||||
.norm(dim=1, keepdim=True)
|
|
||||||
.reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
|
|
||||||
.transpose(0, 1)
|
|
||||||
)
|
|
||||||
|
|
||||||
weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
|
|
||||||
if strength != 1.0:
|
|
||||||
weight_calc -= weight
|
|
||||||
weight += strength * (weight_calc)
|
|
||||||
else:
|
|
||||||
weight[:] = weight_calc
|
|
||||||
return weight
|
|
||||||
|
|
||||||
def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:
|
def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Pad a tensor to a new shape with zeros.
|
Pad a tensor to a new shape with zeros.
|
||||||
@@ -482,6 +343,16 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori
|
|||||||
if isinstance(v, list):
|
if isinstance(v, list):
|
||||||
v = (calculate_weight(v[1:], v[0][1](comfy.model_management.cast_to_device(v[0][0], weight.device, intermediate_dtype, copy=True), inplace=True), key, intermediate_dtype=intermediate_dtype), )
|
v = (calculate_weight(v[1:], v[0][1](comfy.model_management.cast_to_device(v[0][0], weight.device, intermediate_dtype, copy=True), inplace=True), key, intermediate_dtype=intermediate_dtype), )
|
||||||
|
|
||||||
|
if isinstance(v, weight_adapter.WeightAdapterBase):
|
||||||
|
output = v.calculate_weight(weight, key, strength, strength_model, offset, function, intermediate_dtype, original_weights)
|
||||||
|
if output is None:
|
||||||
|
logging.warning("Calculate Weight Failed: {} {}".format(v.name, key))
|
||||||
|
else:
|
||||||
|
weight = output
|
||||||
|
if old_weight is not None:
|
||||||
|
weight = old_weight
|
||||||
|
continue
|
||||||
|
|
||||||
if len(v) == 1:
|
if len(v) == 1:
|
||||||
patch_type = "diff"
|
patch_type = "diff"
|
||||||
elif len(v) == 2:
|
elif len(v) == 2:
|
||||||
@@ -508,157 +379,6 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori
|
|||||||
diff_weight = comfy.model_management.cast_to_device(target_weight, weight.device, intermediate_dtype) - \
|
diff_weight = comfy.model_management.cast_to_device(target_weight, weight.device, intermediate_dtype) - \
|
||||||
comfy.model_management.cast_to_device(original_weights[key][0][0], weight.device, intermediate_dtype)
|
comfy.model_management.cast_to_device(original_weights[key][0][0], weight.device, intermediate_dtype)
|
||||||
weight += function(strength * comfy.model_management.cast_to_device(diff_weight, weight.device, weight.dtype))
|
weight += function(strength * comfy.model_management.cast_to_device(diff_weight, weight.device, weight.dtype))
|
||||||
elif patch_type == "lora": #lora/locon
|
|
||||||
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
|
|
||||||
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype)
|
|
||||||
dora_scale = v[4]
|
|
||||||
reshape = v[5]
|
|
||||||
|
|
||||||
if reshape is not None:
|
|
||||||
weight = pad_tensor_to_shape(weight, reshape)
|
|
||||||
|
|
||||||
if v[2] is not None:
|
|
||||||
alpha = v[2] / mat2.shape[0]
|
|
||||||
else:
|
|
||||||
alpha = 1.0
|
|
||||||
|
|
||||||
if v[3] is not None:
|
|
||||||
#locon mid weights, hopefully the math is fine because I didn't properly test it
|
|
||||||
mat3 = comfy.model_management.cast_to_device(v[3], weight.device, intermediate_dtype)
|
|
||||||
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
|
|
||||||
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
|
|
||||||
try:
|
|
||||||
lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
|
|
||||||
if dora_scale is not None:
|
|
||||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
|
||||||
else:
|
|
||||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
|
||||||
except Exception as e:
|
|
||||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
|
||||||
elif patch_type == "lokr":
|
|
||||||
w1 = v[0]
|
|
||||||
w2 = v[1]
|
|
||||||
w1_a = v[3]
|
|
||||||
w1_b = v[4]
|
|
||||||
w2_a = v[5]
|
|
||||||
w2_b = v[6]
|
|
||||||
t2 = v[7]
|
|
||||||
dora_scale = v[8]
|
|
||||||
dim = None
|
|
||||||
|
|
||||||
if w1 is None:
|
|
||||||
dim = w1_b.shape[0]
|
|
||||||
w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype),
|
|
||||||
comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype))
|
|
||||||
else:
|
|
||||||
w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype)
|
|
||||||
|
|
||||||
if w2 is None:
|
|
||||||
dim = w2_b.shape[0]
|
|
||||||
if t2 is None:
|
|
||||||
w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype),
|
|
||||||
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype))
|
|
||||||
else:
|
|
||||||
w2 = torch.einsum('i j k l, j r, i p -> p r k l',
|
|
||||||
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
|
|
||||||
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype),
|
|
||||||
comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype))
|
|
||||||
else:
|
|
||||||
w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype)
|
|
||||||
|
|
||||||
if len(w2.shape) == 4:
|
|
||||||
w1 = w1.unsqueeze(2).unsqueeze(2)
|
|
||||||
if v[2] is not None and dim is not None:
|
|
||||||
alpha = v[2] / dim
|
|
||||||
else:
|
|
||||||
alpha = 1.0
|
|
||||||
|
|
||||||
try:
|
|
||||||
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
|
|
||||||
if dora_scale is not None:
|
|
||||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
|
||||||
else:
|
|
||||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
|
||||||
except Exception as e:
|
|
||||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
|
||||||
elif patch_type == "loha":
|
|
||||||
w1a = v[0]
|
|
||||||
w1b = v[1]
|
|
||||||
if v[2] is not None:
|
|
||||||
alpha = v[2] / w1b.shape[0]
|
|
||||||
else:
|
|
||||||
alpha = 1.0
|
|
||||||
|
|
||||||
w2a = v[3]
|
|
||||||
w2b = v[4]
|
|
||||||
dora_scale = v[7]
|
|
||||||
if v[5] is not None: #cp decomposition
|
|
||||||
t1 = v[5]
|
|
||||||
t2 = v[6]
|
|
||||||
m1 = torch.einsum('i j k l, j r, i p -> p r k l',
|
|
||||||
comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype),
|
|
||||||
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype),
|
|
||||||
comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype))
|
|
||||||
|
|
||||||
m2 = torch.einsum('i j k l, j r, i p -> p r k l',
|
|
||||||
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
|
|
||||||
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype),
|
|
||||||
comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype))
|
|
||||||
else:
|
|
||||||
m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype),
|
|
||||||
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype))
|
|
||||||
m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype),
|
|
||||||
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype))
|
|
||||||
|
|
||||||
try:
|
|
||||||
lora_diff = (m1 * m2).reshape(weight.shape)
|
|
||||||
if dora_scale is not None:
|
|
||||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
|
||||||
else:
|
|
||||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
|
||||||
except Exception as e:
|
|
||||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
|
||||||
elif patch_type == "glora":
|
|
||||||
dora_scale = v[5]
|
|
||||||
|
|
||||||
old_glora = False
|
|
||||||
if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]:
|
|
||||||
rank = v[0].shape[0]
|
|
||||||
old_glora = True
|
|
||||||
|
|
||||||
if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]:
|
|
||||||
if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
old_glora = False
|
|
||||||
rank = v[1].shape[0]
|
|
||||||
|
|
||||||
a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
|
|
||||||
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
|
|
||||||
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
|
|
||||||
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
|
|
||||||
|
|
||||||
if v[4] is not None:
|
|
||||||
alpha = v[4] / rank
|
|
||||||
else:
|
|
||||||
alpha = 1.0
|
|
||||||
|
|
||||||
try:
|
|
||||||
if old_glora:
|
|
||||||
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora
|
|
||||||
else:
|
|
||||||
if weight.dim() > 2:
|
|
||||||
lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
|
|
||||||
else:
|
|
||||||
lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
|
|
||||||
lora_diff += torch.mm(b1, b2).reshape(weight.shape)
|
|
||||||
|
|
||||||
if dora_scale is not None:
|
|
||||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
|
||||||
else:
|
|
||||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
|
||||||
except Exception as e:
|
|
||||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
|
||||||
else:
|
else:
|
||||||
logging.warning("patch type not recognized {} {}".format(patch_type, key))
|
logging.warning("patch type not recognized {} {}".format(patch_type, key))
|
||||||
|
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ import comfy.ldm.lumina.model
|
|||||||
import comfy.ldm.wan.model
|
import comfy.ldm.wan.model
|
||||||
import comfy.ldm.hunyuan3d.model
|
import comfy.ldm.hunyuan3d.model
|
||||||
import comfy.ldm.hidream.model
|
import comfy.ldm.hidream.model
|
||||||
|
import comfy.ldm.chroma.model
|
||||||
|
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.patcher_extension
|
import comfy.patcher_extension
|
||||||
@@ -786,8 +787,8 @@ class PixArt(BaseModel):
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
class Flux(BaseModel):
|
class Flux(BaseModel):
|
||||||
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
|
def __init__(self, model_config, model_type=ModelType.FLUX, device=None, unet_model=comfy.ldm.flux.model.Flux):
|
||||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.flux.model.Flux)
|
super().__init__(model_config, model_type, device=device, unet_model=unet_model)
|
||||||
|
|
||||||
def concat_cond(self, **kwargs):
|
def concat_cond(self, **kwargs):
|
||||||
try:
|
try:
|
||||||
@@ -1043,6 +1044,37 @@ class WAN21(BaseModel):
|
|||||||
out['clip_fea'] = comfy.conds.CONDRegular(clip_vision_output.penultimate_hidden_states)
|
out['clip_fea'] = comfy.conds.CONDRegular(clip_vision_output.penultimate_hidden_states)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class WAN21_Vace(WAN21):
|
||||||
|
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||||
|
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.VaceWanModel)
|
||||||
|
self.image_to_video = image_to_video
|
||||||
|
|
||||||
|
def extra_conds(self, **kwargs):
|
||||||
|
out = super().extra_conds(**kwargs)
|
||||||
|
noise = kwargs.get("noise", None)
|
||||||
|
noise_shape = list(noise.shape)
|
||||||
|
vace_frames = kwargs.get("vace_frames", None)
|
||||||
|
if vace_frames is None:
|
||||||
|
noise_shape[1] = 32
|
||||||
|
vace_frames = torch.zeros(noise_shape, device=noise.device, dtype=noise.dtype)
|
||||||
|
|
||||||
|
for i in range(0, vace_frames.shape[1], 16):
|
||||||
|
vace_frames = vace_frames.clone()
|
||||||
|
vace_frames[:, i:i + 16] = self.process_latent_in(vace_frames[:, i:i + 16])
|
||||||
|
|
||||||
|
mask = kwargs.get("vace_mask", None)
|
||||||
|
if mask is None:
|
||||||
|
noise_shape[1] = 64
|
||||||
|
mask = torch.ones(noise_shape, device=noise.device, dtype=noise.dtype)
|
||||||
|
|
||||||
|
out['vace_context'] = comfy.conds.CONDRegular(torch.cat([vace_frames.to(noise), mask.to(noise)], dim=1))
|
||||||
|
|
||||||
|
vace_strength = kwargs.get("vace_strength", 1.0)
|
||||||
|
out['vace_strength'] = comfy.conds.CONDConstant(vace_strength)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
class Hunyuan3Dv2(BaseModel):
|
class Hunyuan3Dv2(BaseModel):
|
||||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2)
|
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2)
|
||||||
@@ -1073,4 +1105,19 @@ class HiDream(BaseModel):
|
|||||||
conditioning_llama3 = kwargs.get("conditioning_llama3", None)
|
conditioning_llama3 = kwargs.get("conditioning_llama3", None)
|
||||||
if conditioning_llama3 is not None:
|
if conditioning_llama3 is not None:
|
||||||
out['encoder_hidden_states_llama3'] = comfy.conds.CONDRegular(conditioning_llama3)
|
out['encoder_hidden_states_llama3'] = comfy.conds.CONDRegular(conditioning_llama3)
|
||||||
|
image_cond = kwargs.get("concat_latent_image", None)
|
||||||
|
if image_cond is not None:
|
||||||
|
out['image_cond'] = comfy.conds.CONDNoiseShape(self.process_latent_in(image_cond))
|
||||||
|
return out
|
||||||
|
|
||||||
|
class Chroma(Flux):
|
||||||
|
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||||
|
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma.model.Chroma)
|
||||||
|
|
||||||
|
def extra_conds(self, **kwargs):
|
||||||
|
out = super().extra_conds(**kwargs)
|
||||||
|
|
||||||
|
guidance = kwargs.get("guidance", 0)
|
||||||
|
if guidance is not None:
|
||||||
|
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
|
||||||
return out
|
return out
|
||||||
|
|||||||
@@ -164,7 +164,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
if in_key in state_dict_keys:
|
if in_key in state_dict_keys:
|
||||||
dit_config["in_channels"] = state_dict[in_key].shape[1] // (patch_size * patch_size)
|
dit_config["in_channels"] = state_dict[in_key].shape[1] // (patch_size * patch_size)
|
||||||
dit_config["out_channels"] = 16
|
dit_config["out_channels"] = 16
|
||||||
dit_config["vec_in_dim"] = 768
|
vec_in_key = '{}vector_in.in_layer.weight'.format(key_prefix)
|
||||||
|
if vec_in_key in state_dict_keys:
|
||||||
|
dit_config["vec_in_dim"] = state_dict[vec_in_key].shape[1]
|
||||||
dit_config["context_in_dim"] = 4096
|
dit_config["context_in_dim"] = 4096
|
||||||
dit_config["hidden_size"] = 3072
|
dit_config["hidden_size"] = 3072
|
||||||
dit_config["mlp_ratio"] = 4.0
|
dit_config["mlp_ratio"] = 4.0
|
||||||
@@ -174,6 +176,15 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
dit_config["axes_dim"] = [16, 56, 56]
|
dit_config["axes_dim"] = [16, 56, 56]
|
||||||
dit_config["theta"] = 10000
|
dit_config["theta"] = 10000
|
||||||
dit_config["qkv_bias"] = True
|
dit_config["qkv_bias"] = True
|
||||||
|
if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma
|
||||||
|
dit_config["image_model"] = "chroma"
|
||||||
|
dit_config["in_channels"] = 64
|
||||||
|
dit_config["out_channels"] = 64
|
||||||
|
dit_config["in_dim"] = 64
|
||||||
|
dit_config["out_dim"] = 3072
|
||||||
|
dit_config["hidden_dim"] = 5120
|
||||||
|
dit_config["n_layers"] = 5
|
||||||
|
else:
|
||||||
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
|
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
|
||||||
return dit_config
|
return dit_config
|
||||||
|
|
||||||
@@ -317,6 +328,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
dit_config["cross_attn_norm"] = True
|
dit_config["cross_attn_norm"] = True
|
||||||
dit_config["eps"] = 1e-6
|
dit_config["eps"] = 1e-6
|
||||||
dit_config["in_dim"] = state_dict['{}patch_embedding.weight'.format(key_prefix)].shape[1]
|
dit_config["in_dim"] = state_dict['{}patch_embedding.weight'.format(key_prefix)].shape[1]
|
||||||
|
if '{}vace_patch_embedding.weight'.format(key_prefix) in state_dict_keys:
|
||||||
|
dit_config["model_type"] = "vace"
|
||||||
|
dit_config["vace_in_dim"] = state_dict['{}vace_patch_embedding.weight'.format(key_prefix)].shape[1]
|
||||||
|
dit_config["vace_layers"] = count_blocks(state_dict_keys, '{}vace_blocks.'.format(key_prefix) + '{}.')
|
||||||
|
else:
|
||||||
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
|
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
|
||||||
dit_config["model_type"] = "i2v"
|
dit_config["model_type"] = "i2v"
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -725,6 +725,8 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
|
|||||||
return torch.float8_e4m3fn
|
return torch.float8_e4m3fn
|
||||||
if args.fp8_e5m2_unet:
|
if args.fp8_e5m2_unet:
|
||||||
return torch.float8_e5m2
|
return torch.float8_e5m2
|
||||||
|
if args.fp8_e8m0fnu_unet:
|
||||||
|
return torch.float8_e8m0fnu
|
||||||
|
|
||||||
fp8_dtype = None
|
fp8_dtype = None
|
||||||
if weight_dtype in FLOAT8_TYPES:
|
if weight_dtype in FLOAT8_TYPES:
|
||||||
@@ -937,13 +939,59 @@ def force_channels_last():
|
|||||||
#TODO
|
#TODO
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
|
|
||||||
|
STREAMS = {}
|
||||||
|
NUM_STREAMS = 1
|
||||||
|
if args.async_offload:
|
||||||
|
NUM_STREAMS = 2
|
||||||
|
logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS))
|
||||||
|
|
||||||
|
stream_counters = {}
|
||||||
|
def get_offload_stream(device):
|
||||||
|
stream_counter = stream_counters.get(device, 0)
|
||||||
|
if NUM_STREAMS <= 1:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if device in STREAMS:
|
||||||
|
ss = STREAMS[device]
|
||||||
|
s = ss[stream_counter]
|
||||||
|
stream_counter = (stream_counter + 1) % len(ss)
|
||||||
|
if is_device_cuda(device):
|
||||||
|
ss[stream_counter].wait_stream(torch.cuda.current_stream())
|
||||||
|
stream_counters[device] = stream_counter
|
||||||
|
return s
|
||||||
|
elif is_device_cuda(device):
|
||||||
|
ss = []
|
||||||
|
for k in range(NUM_STREAMS):
|
||||||
|
ss.append(torch.cuda.Stream(device=device, priority=0))
|
||||||
|
STREAMS[device] = ss
|
||||||
|
s = ss[stream_counter]
|
||||||
|
stream_counter = (stream_counter + 1) % len(ss)
|
||||||
|
stream_counters[device] = stream_counter
|
||||||
|
return s
|
||||||
|
return None
|
||||||
|
|
||||||
|
def sync_stream(device, stream):
|
||||||
|
if stream is None:
|
||||||
|
return
|
||||||
|
if is_device_cuda(device):
|
||||||
|
torch.cuda.current_stream().wait_stream(stream)
|
||||||
|
|
||||||
|
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
|
||||||
if device is None or weight.device == device:
|
if device is None or weight.device == device:
|
||||||
if not copy:
|
if not copy:
|
||||||
if dtype is None or weight.dtype == dtype:
|
if dtype is None or weight.dtype == dtype:
|
||||||
return weight
|
return weight
|
||||||
|
if stream is not None:
|
||||||
|
with stream:
|
||||||
|
return weight.to(dtype=dtype, copy=copy)
|
||||||
return weight.to(dtype=dtype, copy=copy)
|
return weight.to(dtype=dtype, copy=copy)
|
||||||
|
|
||||||
|
if stream is not None:
|
||||||
|
with stream:
|
||||||
|
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||||
|
r.copy_(weight, non_blocking=non_blocking)
|
||||||
|
else:
|
||||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||||
r.copy_(weight, non_blocking=non_blocking)
|
r.copy_(weight, non_blocking=non_blocking)
|
||||||
return r
|
return r
|
||||||
|
|||||||
@@ -111,13 +111,14 @@ class ModelSamplingDiscrete(torch.nn.Module):
|
|||||||
self.num_timesteps = int(timesteps)
|
self.num_timesteps = int(timesteps)
|
||||||
self.linear_start = linear_start
|
self.linear_start = linear_start
|
||||||
self.linear_end = linear_end
|
self.linear_end = linear_end
|
||||||
|
self.zsnr = zsnr
|
||||||
|
|
||||||
# self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32))
|
# self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32))
|
||||||
# self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
|
# self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
|
||||||
# self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
|
# self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
|
||||||
|
|
||||||
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
|
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
|
||||||
if zsnr:
|
if self.zsnr:
|
||||||
sigmas = rescale_zero_terminal_snr_sigmas(sigmas)
|
sigmas = rescale_zero_terminal_snr_sigmas(sigmas)
|
||||||
|
|
||||||
self.set_sigmas(sigmas)
|
self.set_sigmas(sigmas)
|
||||||
|
|||||||
16
comfy/ops.py
16
comfy/ops.py
@@ -22,6 +22,7 @@ import comfy.model_management
|
|||||||
from comfy.cli_args import args, PerformanceFeature
|
from comfy.cli_args import args, PerformanceFeature
|
||||||
import comfy.float
|
import comfy.float
|
||||||
import comfy.rmsnorm
|
import comfy.rmsnorm
|
||||||
|
import contextlib
|
||||||
|
|
||||||
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
|
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
|
||||||
|
|
||||||
@@ -37,20 +38,31 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
|
|||||||
if device is None:
|
if device is None:
|
||||||
device = input.device
|
device = input.device
|
||||||
|
|
||||||
|
offload_stream = comfy.model_management.get_offload_stream(device)
|
||||||
|
if offload_stream is not None:
|
||||||
|
wf_context = offload_stream
|
||||||
|
else:
|
||||||
|
wf_context = contextlib.nullcontext()
|
||||||
|
|
||||||
bias = None
|
bias = None
|
||||||
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
||||||
if s.bias is not None:
|
if s.bias is not None:
|
||||||
has_function = len(s.bias_function) > 0
|
has_function = len(s.bias_function) > 0
|
||||||
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function)
|
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
|
||||||
|
|
||||||
if has_function:
|
if has_function:
|
||||||
|
with wf_context:
|
||||||
for f in s.bias_function:
|
for f in s.bias_function:
|
||||||
bias = f(bias)
|
bias = f(bias)
|
||||||
|
|
||||||
has_function = len(s.weight_function) > 0
|
has_function = len(s.weight_function) > 0
|
||||||
weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function)
|
weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
|
||||||
if has_function:
|
if has_function:
|
||||||
|
with wf_context:
|
||||||
for f in s.weight_function:
|
for f in s.weight_function:
|
||||||
weight = f(weight)
|
weight = f(weight)
|
||||||
|
|
||||||
|
comfy.model_management.sync_stream(device, offload_stream)
|
||||||
return weight, bias
|
return weight, bias
|
||||||
|
|
||||||
class CastWeightBiasOp:
|
class CastWeightBiasOp:
|
||||||
|
|||||||
@@ -710,7 +710,7 @@ KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_c
|
|||||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
|
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||||
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
||||||
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
|
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
|
||||||
"gradient_estimation", "er_sde", "seeds_2", "seeds_3"]
|
"gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3"]
|
||||||
|
|
||||||
class KSAMPLER(Sampler):
|
class KSAMPLER(Sampler):
|
||||||
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
|
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
|
||||||
|
|||||||
47
comfy/sd.py
47
comfy/sd.py
@@ -120,6 +120,7 @@ class CLIP:
|
|||||||
self.layer_idx = None
|
self.layer_idx = None
|
||||||
self.use_clip_schedule = False
|
self.use_clip_schedule = False
|
||||||
logging.info("CLIP/text encoder model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype))
|
logging.info("CLIP/text encoder model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype))
|
||||||
|
self.tokenizer_options = {}
|
||||||
|
|
||||||
def clone(self):
|
def clone(self):
|
||||||
n = CLIP(no_init=True)
|
n = CLIP(no_init=True)
|
||||||
@@ -127,6 +128,7 @@ class CLIP:
|
|||||||
n.cond_stage_model = self.cond_stage_model
|
n.cond_stage_model = self.cond_stage_model
|
||||||
n.tokenizer = self.tokenizer
|
n.tokenizer = self.tokenizer
|
||||||
n.layer_idx = self.layer_idx
|
n.layer_idx = self.layer_idx
|
||||||
|
n.tokenizer_options = self.tokenizer_options.copy()
|
||||||
n.use_clip_schedule = self.use_clip_schedule
|
n.use_clip_schedule = self.use_clip_schedule
|
||||||
n.apply_hooks_to_conds = self.apply_hooks_to_conds
|
n.apply_hooks_to_conds = self.apply_hooks_to_conds
|
||||||
return n
|
return n
|
||||||
@@ -134,10 +136,18 @@ class CLIP:
|
|||||||
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
|
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
|
||||||
return self.patcher.add_patches(patches, strength_patch, strength_model)
|
return self.patcher.add_patches(patches, strength_patch, strength_model)
|
||||||
|
|
||||||
|
def set_tokenizer_option(self, option_name, value):
|
||||||
|
self.tokenizer_options[option_name] = value
|
||||||
|
|
||||||
def clip_layer(self, layer_idx):
|
def clip_layer(self, layer_idx):
|
||||||
self.layer_idx = layer_idx
|
self.layer_idx = layer_idx
|
||||||
|
|
||||||
def tokenize(self, text, return_word_ids=False, **kwargs):
|
def tokenize(self, text, return_word_ids=False, **kwargs):
|
||||||
|
tokenizer_options = kwargs.get("tokenizer_options", {})
|
||||||
|
if len(self.tokenizer_options) > 0:
|
||||||
|
tokenizer_options = {**self.tokenizer_options, **tokenizer_options}
|
||||||
|
if len(tokenizer_options) > 0:
|
||||||
|
kwargs["tokenizer_options"] = tokenizer_options
|
||||||
return self.tokenizer.tokenize_with_weights(text, return_word_ids, **kwargs)
|
return self.tokenizer.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||||
|
|
||||||
def add_hooks_to_dict(self, pooled_dict: dict[str]):
|
def add_hooks_to_dict(self, pooled_dict: dict[str]):
|
||||||
@@ -703,6 +713,8 @@ class CLIPType(Enum):
|
|||||||
COSMOS = 11
|
COSMOS = 11
|
||||||
LUMINA2 = 12
|
LUMINA2 = 12
|
||||||
WAN = 13
|
WAN = 13
|
||||||
|
HIDREAM = 14
|
||||||
|
CHROMA = 15
|
||||||
|
|
||||||
|
|
||||||
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
||||||
@@ -791,6 +803,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
|||||||
elif clip_type == CLIPType.SD3:
|
elif clip_type == CLIPType.SD3:
|
||||||
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=True, t5=False)
|
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=True, t5=False)
|
||||||
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||||
|
elif clip_type == CLIPType.HIDREAM:
|
||||||
|
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=False, clip_g=True, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None)
|
||||||
|
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
|
||||||
else:
|
else:
|
||||||
clip_target.clip = sdxl_clip.SDXLRefinerClipModel
|
clip_target.clip = sdxl_clip.SDXLRefinerClipModel
|
||||||
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
||||||
@@ -804,13 +819,17 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
|||||||
elif clip_type == CLIPType.LTXV:
|
elif clip_type == CLIPType.LTXV:
|
||||||
clip_target.clip = comfy.text_encoders.lt.ltxv_te(**t5xxl_detect(clip_data))
|
clip_target.clip = comfy.text_encoders.lt.ltxv_te(**t5xxl_detect(clip_data))
|
||||||
clip_target.tokenizer = comfy.text_encoders.lt.LTXVT5Tokenizer
|
clip_target.tokenizer = comfy.text_encoders.lt.LTXVT5Tokenizer
|
||||||
elif clip_type == CLIPType.PIXART:
|
elif clip_type == CLIPType.PIXART or clip_type == CLIPType.CHROMA:
|
||||||
clip_target.clip = comfy.text_encoders.pixart_t5.pixart_te(**t5xxl_detect(clip_data))
|
clip_target.clip = comfy.text_encoders.pixart_t5.pixart_te(**t5xxl_detect(clip_data))
|
||||||
clip_target.tokenizer = comfy.text_encoders.pixart_t5.PixArtTokenizer
|
clip_target.tokenizer = comfy.text_encoders.pixart_t5.PixArtTokenizer
|
||||||
elif clip_type == CLIPType.WAN:
|
elif clip_type == CLIPType.WAN:
|
||||||
clip_target.clip = comfy.text_encoders.wan.te(**t5xxl_detect(clip_data))
|
clip_target.clip = comfy.text_encoders.wan.te(**t5xxl_detect(clip_data))
|
||||||
clip_target.tokenizer = comfy.text_encoders.wan.WanT5Tokenizer
|
clip_target.tokenizer = comfy.text_encoders.wan.WanT5Tokenizer
|
||||||
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
||||||
|
elif clip_type == CLIPType.HIDREAM:
|
||||||
|
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data),
|
||||||
|
clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None, llama_scaled_fp8=None)
|
||||||
|
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
|
||||||
else: #CLIPType.MOCHI
|
else: #CLIPType.MOCHI
|
||||||
clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
|
clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
|
||||||
clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer
|
clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer
|
||||||
@@ -827,10 +846,18 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
|||||||
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
|
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
|
||||||
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
|
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
|
||||||
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
||||||
|
elif te_model == TEModel.LLAMA3_8:
|
||||||
|
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data),
|
||||||
|
clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None)
|
||||||
|
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
|
||||||
else:
|
else:
|
||||||
|
# clip_l
|
||||||
if clip_type == CLIPType.SD3:
|
if clip_type == CLIPType.SD3:
|
||||||
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=False, t5=False)
|
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=False, t5=False)
|
||||||
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||||
|
elif clip_type == CLIPType.HIDREAM:
|
||||||
|
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=True, clip_g=False, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None)
|
||||||
|
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
|
||||||
else:
|
else:
|
||||||
clip_target.clip = sd1_clip.SD1ClipModel
|
clip_target.clip = sd1_clip.SD1ClipModel
|
||||||
clip_target.tokenizer = sd1_clip.SD1Tokenizer
|
clip_target.tokenizer = sd1_clip.SD1Tokenizer
|
||||||
@@ -848,6 +875,24 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
|||||||
elif clip_type == CLIPType.HUNYUAN_VIDEO:
|
elif clip_type == CLIPType.HUNYUAN_VIDEO:
|
||||||
clip_target.clip = comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**llama_detect(clip_data))
|
clip_target.clip = comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**llama_detect(clip_data))
|
||||||
clip_target.tokenizer = comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer
|
clip_target.tokenizer = comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer
|
||||||
|
elif clip_type == CLIPType.HIDREAM:
|
||||||
|
# Detect
|
||||||
|
hidream_dualclip_classes = []
|
||||||
|
for hidream_te in clip_data:
|
||||||
|
te_model = detect_te_model(hidream_te)
|
||||||
|
hidream_dualclip_classes.append(te_model)
|
||||||
|
|
||||||
|
clip_l = TEModel.CLIP_L in hidream_dualclip_classes
|
||||||
|
clip_g = TEModel.CLIP_G in hidream_dualclip_classes
|
||||||
|
t5 = TEModel.T5_XXL in hidream_dualclip_classes
|
||||||
|
llama = TEModel.LLAMA3_8 in hidream_dualclip_classes
|
||||||
|
|
||||||
|
# Initialize t5xxl_detect and llama_detect kwargs if needed
|
||||||
|
t5_kwargs = t5xxl_detect(clip_data) if t5 else {}
|
||||||
|
llama_kwargs = llama_detect(clip_data) if llama else {}
|
||||||
|
|
||||||
|
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, **t5_kwargs, **llama_kwargs)
|
||||||
|
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
|
||||||
else:
|
else:
|
||||||
clip_target.clip = sdxl_clip.SDXLClipModel
|
clip_target.clip = sdxl_clip.SDXLClipModel
|
||||||
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
||||||
|
|||||||
@@ -457,13 +457,14 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
|
|||||||
return embed_out
|
return embed_out
|
||||||
|
|
||||||
class SDTokenizer:
|
class SDTokenizer:
|
||||||
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, tokenizer_data={}, tokenizer_args={}):
|
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, tokenizer_data={}, tokenizer_args={}):
|
||||||
if tokenizer_path is None:
|
if tokenizer_path is None:
|
||||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
|
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
|
||||||
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
|
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
|
||||||
self.max_length = tokenizer_data.get("{}_max_length".format(embedding_key), max_length)
|
self.max_length = tokenizer_data.get("{}_max_length".format(embedding_key), max_length)
|
||||||
self.min_length = min_length
|
self.min_length = min_length
|
||||||
self.end_token = None
|
self.end_token = None
|
||||||
|
self.min_padding = min_padding
|
||||||
|
|
||||||
empty = self.tokenizer('')["input_ids"]
|
empty = self.tokenizer('')["input_ids"]
|
||||||
self.tokenizer_adds_end_token = has_end_token
|
self.tokenizer_adds_end_token = has_end_token
|
||||||
@@ -518,13 +519,15 @@ class SDTokenizer:
|
|||||||
return (embed, leftover)
|
return (embed, leftover)
|
||||||
|
|
||||||
|
|
||||||
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
def tokenize_with_weights(self, text:str, return_word_ids=False, tokenizer_options={}, **kwargs):
|
||||||
'''
|
'''
|
||||||
Takes a prompt and converts it to a list of (token, weight, word id) elements.
|
Takes a prompt and converts it to a list of (token, weight, word id) elements.
|
||||||
Tokens can both be integer tokens and pre computed CLIP tensors.
|
Tokens can both be integer tokens and pre computed CLIP tensors.
|
||||||
Word id values are unique per word and embedding, where the id 0 is reserved for non word tokens.
|
Word id values are unique per word and embedding, where the id 0 is reserved for non word tokens.
|
||||||
Returned list has the dimensions NxM where M is the input size of CLIP
|
Returned list has the dimensions NxM where M is the input size of CLIP
|
||||||
'''
|
'''
|
||||||
|
min_length = tokenizer_options.get("{}_min_length".format(self.embedding_key), self.min_length)
|
||||||
|
min_padding = tokenizer_options.get("{}_min_padding".format(self.embedding_key), self.min_padding)
|
||||||
|
|
||||||
text = escape_important(text)
|
text = escape_important(text)
|
||||||
parsed_weights = token_weights(text, 1.0)
|
parsed_weights = token_weights(text, 1.0)
|
||||||
@@ -603,10 +606,12 @@ class SDTokenizer:
|
|||||||
#fill last batch
|
#fill last batch
|
||||||
if self.end_token is not None:
|
if self.end_token is not None:
|
||||||
batch.append((self.end_token, 1.0, 0))
|
batch.append((self.end_token, 1.0, 0))
|
||||||
if self.pad_to_max_length:
|
if min_padding is not None:
|
||||||
|
batch.extend([(self.pad_token, 1.0, 0)] * min_padding)
|
||||||
|
if self.pad_to_max_length and len(batch) < self.max_length:
|
||||||
batch.extend([(self.pad_token, 1.0, 0)] * (self.max_length - len(batch)))
|
batch.extend([(self.pad_token, 1.0, 0)] * (self.max_length - len(batch)))
|
||||||
if self.min_length is not None and len(batch) < self.min_length:
|
if min_length is not None and len(batch) < min_length:
|
||||||
batch.extend([(self.pad_token, 1.0, 0)] * (self.min_length - len(batch)))
|
batch.extend([(self.pad_token, 1.0, 0)] * (min_length - len(batch)))
|
||||||
|
|
||||||
if not return_word_ids:
|
if not return_word_ids:
|
||||||
batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]
|
batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]
|
||||||
@@ -634,7 +639,7 @@ class SD1Tokenizer:
|
|||||||
|
|
||||||
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
||||||
out = {}
|
out = {}
|
||||||
out[self.clip_name] = getattr(self, self.clip).tokenize_with_weights(text, return_word_ids)
|
out[self.clip_name] = getattr(self, self.clip).tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def untokenize(self, token_weight_pair):
|
def untokenize(self, token_weight_pair):
|
||||||
|
|||||||
@@ -28,8 +28,8 @@ class SDXLTokenizer:
|
|||||||
|
|
||||||
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
||||||
out = {}
|
out = {}
|
||||||
out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
|
out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||||
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
|
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def untokenize(self, token_weight_pair):
|
def untokenize(self, token_weight_pair):
|
||||||
|
|||||||
@@ -987,6 +987,20 @@ class WAN21_FunControl2V(WAN21_T2V):
|
|||||||
out = model_base.WAN21(self, image_to_video=False, device=device)
|
out = model_base.WAN21(self, image_to_video=False, device=device)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
class WAN21_Vace(WAN21_T2V):
|
||||||
|
unet_config = {
|
||||||
|
"image_model": "wan2.1",
|
||||||
|
"model_type": "vace",
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, unet_config):
|
||||||
|
super().__init__(unet_config)
|
||||||
|
self.memory_usage_factor = 1.2 * self.memory_usage_factor
|
||||||
|
|
||||||
|
def get_model(self, state_dict, prefix="", device=None):
|
||||||
|
out = model_base.WAN21_Vace(self, image_to_video=False, device=device)
|
||||||
|
return out
|
||||||
|
|
||||||
class Hunyuan3Dv2(supported_models_base.BASE):
|
class Hunyuan3Dv2(supported_models_base.BASE):
|
||||||
unet_config = {
|
unet_config = {
|
||||||
"image_model": "hunyuan3d2",
|
"image_model": "hunyuan3d2",
|
||||||
@@ -1054,7 +1068,34 @@ class HiDream(supported_models_base.BASE):
|
|||||||
def clip_target(self, state_dict={}):
|
def clip_target(self, state_dict={}):
|
||||||
return None # TODO
|
return None # TODO
|
||||||
|
|
||||||
|
class Chroma(supported_models_base.BASE):
|
||||||
|
unet_config = {
|
||||||
|
"image_model": "chroma",
|
||||||
|
}
|
||||||
|
|
||||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream]
|
unet_extra_config = {
|
||||||
|
}
|
||||||
|
|
||||||
|
sampling_settings = {
|
||||||
|
"multiplier": 1.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
latent_format = comfy.latent_formats.Flux
|
||||||
|
|
||||||
|
memory_usage_factor = 3.2
|
||||||
|
|
||||||
|
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||||
|
|
||||||
|
|
||||||
|
def get_model(self, state_dict, prefix="", device=None):
|
||||||
|
out = model_base.Chroma(self, device=device)
|
||||||
|
return out
|
||||||
|
|
||||||
|
def clip_target(self, state_dict={}):
|
||||||
|
pref = self.text_encoder_key_prefix[0]
|
||||||
|
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
||||||
|
return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect))
|
||||||
|
|
||||||
|
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma]
|
||||||
|
|
||||||
models += [SVD_img2vid]
|
models += [SVD_img2vid]
|
||||||
|
|||||||
@@ -19,8 +19,8 @@ class FluxTokenizer:
|
|||||||
|
|
||||||
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
||||||
out = {}
|
out = {}
|
||||||
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
|
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||||
out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids)
|
out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def untokenize(self, token_weight_pair):
|
def untokenize(self, token_weight_pair):
|
||||||
|
|||||||
@@ -16,11 +16,11 @@ class HiDreamTokenizer:
|
|||||||
|
|
||||||
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
||||||
out = {}
|
out = {}
|
||||||
out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
|
out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||||
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
|
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||||
t5xxl = self.t5xxl.tokenize_with_weights(text, return_word_ids)
|
t5xxl = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||||
out["t5xxl"] = [t5xxl[0]] # Use only first 128 tokens
|
out["t5xxl"] = [t5xxl[0]] # Use only first 128 tokens
|
||||||
out["llama"] = self.llama.tokenize_with_weights(text, return_word_ids)
|
out["llama"] = self.llama.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def untokenize(self, token_weight_pair):
|
def untokenize(self, token_weight_pair):
|
||||||
@@ -109,14 +109,18 @@ class HiDreamTEModel(torch.nn.Module):
|
|||||||
if self.t5xxl is not None:
|
if self.t5xxl is not None:
|
||||||
t5_output = self.t5xxl.encode_token_weights(token_weight_pairs_t5)
|
t5_output = self.t5xxl.encode_token_weights(token_weight_pairs_t5)
|
||||||
t5_out, t5_pooled = t5_output[:2]
|
t5_out, t5_pooled = t5_output[:2]
|
||||||
|
else:
|
||||||
|
t5_out = None
|
||||||
|
|
||||||
if self.llama is not None:
|
if self.llama is not None:
|
||||||
ll_output = self.llama.encode_token_weights(token_weight_pairs_llama)
|
ll_output = self.llama.encode_token_weights(token_weight_pairs_llama)
|
||||||
ll_out, ll_pooled = ll_output[:2]
|
ll_out, ll_pooled = ll_output[:2]
|
||||||
ll_out = ll_out[:, 1:]
|
ll_out = ll_out[:, 1:]
|
||||||
|
else:
|
||||||
|
ll_out = None
|
||||||
|
|
||||||
if t5_out is None:
|
if t5_out is None:
|
||||||
t5_out = torch.zeros((1, 1, 4096), device=comfy.model_management.intermediate_device())
|
t5_out = torch.zeros((1, 128, 4096), device=comfy.model_management.intermediate_device())
|
||||||
|
|
||||||
if ll_out is None:
|
if ll_out is None:
|
||||||
ll_out = torch.zeros((1, 32, 1, 4096), device=comfy.model_management.intermediate_device())
|
ll_out = torch.zeros((1, 32, 1, 4096), device=comfy.model_management.intermediate_device())
|
||||||
|
|||||||
@@ -49,13 +49,13 @@ class HunyuanVideoTokenizer:
|
|||||||
|
|
||||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, image_embeds=None, image_interleave=1, **kwargs):
|
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, image_embeds=None, image_interleave=1, **kwargs):
|
||||||
out = {}
|
out = {}
|
||||||
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
|
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||||
|
|
||||||
if llama_template is None:
|
if llama_template is None:
|
||||||
llama_text = self.llama_template.format(text)
|
llama_text = self.llama_template.format(text)
|
||||||
else:
|
else:
|
||||||
llama_text = llama_template.format(text)
|
llama_text = llama_template.format(text)
|
||||||
llama_text_tokens = self.llama.tokenize_with_weights(llama_text, return_word_ids)
|
llama_text_tokens = self.llama.tokenize_with_weights(llama_text, return_word_ids, **kwargs)
|
||||||
embed_count = 0
|
embed_count = 0
|
||||||
for r in llama_text_tokens:
|
for r in llama_text_tokens:
|
||||||
for i in range(len(r)):
|
for i in range(len(r)):
|
||||||
|
|||||||
@@ -41,8 +41,8 @@ class HyditTokenizer:
|
|||||||
|
|
||||||
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
||||||
out = {}
|
out = {}
|
||||||
out["hydit_clip"] = self.hydit_clip.tokenize_with_weights(text, return_word_ids)
|
out["hydit_clip"] = self.hydit_clip.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||||
out["mt5xl"] = self.mt5xl.tokenize_with_weights(text, return_word_ids)
|
out["mt5xl"] = self.mt5xl.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def untokenize(self, token_weight_pair):
|
def untokenize(self, token_weight_pair):
|
||||||
|
|||||||
@@ -45,9 +45,9 @@ class SD3Tokenizer:
|
|||||||
|
|
||||||
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
||||||
out = {}
|
out = {}
|
||||||
out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
|
out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||||
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
|
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||||
out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids)
|
out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def untokenize(self, token_weight_pair):
|
def untokenize(self, token_weight_pair):
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import torch
|
import torch
|
||||||
|
import os
|
||||||
|
|
||||||
class SPieceTokenizer:
|
class SPieceTokenizer:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -15,6 +16,8 @@ class SPieceTokenizer:
|
|||||||
if isinstance(tokenizer_path, bytes):
|
if isinstance(tokenizer_path, bytes):
|
||||||
self.tokenizer = sentencepiece.SentencePieceProcessor(model_proto=tokenizer_path, add_bos=self.add_bos, add_eos=self.add_eos)
|
self.tokenizer = sentencepiece.SentencePieceProcessor(model_proto=tokenizer_path, add_bos=self.add_bos, add_eos=self.add_eos)
|
||||||
else:
|
else:
|
||||||
|
if not os.path.isfile(tokenizer_path):
|
||||||
|
raise ValueError("invalid tokenizer")
|
||||||
self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=tokenizer_path, add_bos=self.add_bos, add_eos=self.add_eos)
|
self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=tokenizer_path, add_bos=self.add_bos, add_eos=self.add_eos)
|
||||||
|
|
||||||
def get_vocab(self):
|
def get_vocab(self):
|
||||||
|
|||||||
17
comfy/weight_adapter/__init__.py
Normal file
17
comfy/weight_adapter/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
from .base import WeightAdapterBase
|
||||||
|
from .lora import LoRAAdapter
|
||||||
|
from .loha import LoHaAdapter
|
||||||
|
from .lokr import LoKrAdapter
|
||||||
|
from .glora import GLoRAAdapter
|
||||||
|
from .oft import OFTAdapter
|
||||||
|
from .boft import BOFTAdapter
|
||||||
|
|
||||||
|
|
||||||
|
adapters: list[type[WeightAdapterBase]] = [
|
||||||
|
LoRAAdapter,
|
||||||
|
LoHaAdapter,
|
||||||
|
LoKrAdapter,
|
||||||
|
GLoRAAdapter,
|
||||||
|
OFTAdapter,
|
||||||
|
BOFTAdapter,
|
||||||
|
]
|
||||||
104
comfy/weight_adapter/base.py
Normal file
104
comfy/weight_adapter/base.py
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
import comfy.model_management
|
||||||
|
|
||||||
|
|
||||||
|
class WeightAdapterBase:
|
||||||
|
name: str
|
||||||
|
loaded_keys: set[str]
|
||||||
|
weights: list[torch.Tensor]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(cls, x: str, lora: dict[str, torch.Tensor]) -> Optional["WeightAdapterBase"]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def to_train(self) -> "WeightAdapterTrainBase":
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def calculate_weight(
|
||||||
|
self,
|
||||||
|
weight,
|
||||||
|
key,
|
||||||
|
strength,
|
||||||
|
strength_model,
|
||||||
|
offset,
|
||||||
|
function,
|
||||||
|
intermediate_dtype=torch.float32,
|
||||||
|
original_weight=None,
|
||||||
|
):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class WeightAdapterTrainBase(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# [TODO] Collaborate with LoRA training PR #7032
|
||||||
|
|
||||||
|
|
||||||
|
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function):
|
||||||
|
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
|
||||||
|
lora_diff *= alpha
|
||||||
|
weight_calc = weight + function(lora_diff).type(weight.dtype)
|
||||||
|
|
||||||
|
wd_on_output_axis = dora_scale.shape[0] == weight_calc.shape[0]
|
||||||
|
if wd_on_output_axis:
|
||||||
|
weight_norm = (
|
||||||
|
weight.reshape(weight.shape[0], -1)
|
||||||
|
.norm(dim=1, keepdim=True)
|
||||||
|
.reshape(weight.shape[0], *[1] * (weight.dim() - 1))
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
weight_norm = (
|
||||||
|
weight_calc.transpose(0, 1)
|
||||||
|
.reshape(weight_calc.shape[1], -1)
|
||||||
|
.norm(dim=1, keepdim=True)
|
||||||
|
.reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
|
||||||
|
.transpose(0, 1)
|
||||||
|
)
|
||||||
|
weight_norm = weight_norm + torch.finfo(weight.dtype).eps
|
||||||
|
|
||||||
|
weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
|
||||||
|
if strength != 1.0:
|
||||||
|
weight_calc -= weight
|
||||||
|
weight += strength * (weight_calc)
|
||||||
|
else:
|
||||||
|
weight[:] = weight_calc
|
||||||
|
return weight
|
||||||
|
|
||||||
|
|
||||||
|
def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Pad a tensor to a new shape with zeros.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tensor (torch.Tensor): The original tensor to be padded.
|
||||||
|
new_shape (List[int]): The desired shape of the padded tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: A new tensor padded with zeros to the specified shape.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
If the new shape is smaller than the original tensor in any dimension,
|
||||||
|
the original tensor will be truncated in that dimension.
|
||||||
|
"""
|
||||||
|
if any([new_shape[i] < tensor.shape[i] for i in range(len(new_shape))]):
|
||||||
|
raise ValueError("The new shape must be larger than the original tensor in all dimensions")
|
||||||
|
|
||||||
|
if len(new_shape) != len(tensor.shape):
|
||||||
|
raise ValueError("The new shape must have the same number of dimensions as the original tensor")
|
||||||
|
|
||||||
|
# Create a new tensor filled with zeros
|
||||||
|
padded_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device)
|
||||||
|
|
||||||
|
# Create slicing tuples for both tensors
|
||||||
|
orig_slices = tuple(slice(0, dim) for dim in tensor.shape)
|
||||||
|
new_slices = tuple(slice(0, dim) for dim in tensor.shape)
|
||||||
|
|
||||||
|
# Copy the original tensor into the new tensor
|
||||||
|
padded_tensor[new_slices] = tensor[orig_slices]
|
||||||
|
|
||||||
|
return padded_tensor
|
||||||
115
comfy/weight_adapter/boft.py
Normal file
115
comfy/weight_adapter/boft.py
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import comfy.model_management
|
||||||
|
from .base import WeightAdapterBase, weight_decompose
|
||||||
|
|
||||||
|
|
||||||
|
class BOFTAdapter(WeightAdapterBase):
|
||||||
|
name = "boft"
|
||||||
|
|
||||||
|
def __init__(self, loaded_keys, weights):
|
||||||
|
self.loaded_keys = loaded_keys
|
||||||
|
self.weights = weights
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(
|
||||||
|
cls,
|
||||||
|
x: str,
|
||||||
|
lora: dict[str, torch.Tensor],
|
||||||
|
alpha: float,
|
||||||
|
dora_scale: torch.Tensor,
|
||||||
|
loaded_keys: set[str] = None,
|
||||||
|
) -> Optional["BOFTAdapter"]:
|
||||||
|
if loaded_keys is None:
|
||||||
|
loaded_keys = set()
|
||||||
|
blocks_name = "{}.oft_blocks".format(x)
|
||||||
|
rescale_name = "{}.rescale".format(x)
|
||||||
|
|
||||||
|
blocks = None
|
||||||
|
if blocks_name in lora.keys():
|
||||||
|
blocks = lora[blocks_name]
|
||||||
|
if blocks.ndim == 4:
|
||||||
|
loaded_keys.add(blocks_name)
|
||||||
|
else:
|
||||||
|
blocks = None
|
||||||
|
if blocks is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
rescale = None
|
||||||
|
if rescale_name in lora.keys():
|
||||||
|
rescale = lora[rescale_name]
|
||||||
|
loaded_keys.add(rescale_name)
|
||||||
|
|
||||||
|
weights = (blocks, rescale, alpha, dora_scale)
|
||||||
|
return cls(loaded_keys, weights)
|
||||||
|
|
||||||
|
def calculate_weight(
|
||||||
|
self,
|
||||||
|
weight,
|
||||||
|
key,
|
||||||
|
strength,
|
||||||
|
strength_model,
|
||||||
|
offset,
|
||||||
|
function,
|
||||||
|
intermediate_dtype=torch.float32,
|
||||||
|
original_weight=None,
|
||||||
|
):
|
||||||
|
v = self.weights
|
||||||
|
blocks = v[0]
|
||||||
|
rescale = v[1]
|
||||||
|
alpha = v[2]
|
||||||
|
dora_scale = v[3]
|
||||||
|
|
||||||
|
blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype)
|
||||||
|
if rescale is not None:
|
||||||
|
rescale = comfy.model_management.cast_to_device(rescale, weight.device, intermediate_dtype)
|
||||||
|
|
||||||
|
boft_m, block_num, boft_b, *_ = blocks.shape
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Get r
|
||||||
|
I = torch.eye(boft_b, device=blocks.device, dtype=blocks.dtype)
|
||||||
|
# for Q = -Q^T
|
||||||
|
q = blocks - blocks.transpose(-1, -2)
|
||||||
|
normed_q = q
|
||||||
|
if alpha > 0: # alpha in boft/bboft is for constraint
|
||||||
|
q_norm = torch.norm(q) + 1e-8
|
||||||
|
if q_norm > alpha:
|
||||||
|
normed_q = q * alpha / q_norm
|
||||||
|
# use float() to prevent unsupported type in .inverse()
|
||||||
|
r = (I + normed_q) @ (I - normed_q).float().inverse()
|
||||||
|
r = r.to(weight)
|
||||||
|
inp = org = weight
|
||||||
|
|
||||||
|
r_b = boft_b//2
|
||||||
|
for i in range(boft_m):
|
||||||
|
bi = r[i]
|
||||||
|
g = 2
|
||||||
|
k = 2**i * r_b
|
||||||
|
if strength != 1:
|
||||||
|
bi = bi * strength + (1-strength) * I
|
||||||
|
inp = (
|
||||||
|
inp.unflatten(0, (-1, g, k))
|
||||||
|
.transpose(1, 2)
|
||||||
|
.flatten(0, 2)
|
||||||
|
.unflatten(0, (-1, boft_b))
|
||||||
|
)
|
||||||
|
inp = torch.einsum("b i j, b j ...-> b i ...", bi, inp)
|
||||||
|
inp = (
|
||||||
|
inp.flatten(0, 1).unflatten(0, (-1, k, g)).transpose(1, 2).flatten(0, 2)
|
||||||
|
)
|
||||||
|
|
||||||
|
if rescale is not None:
|
||||||
|
inp = inp * rescale
|
||||||
|
|
||||||
|
lora_diff = inp - org
|
||||||
|
lora_diff = comfy.model_management.cast_to_device(lora_diff, weight.device, intermediate_dtype)
|
||||||
|
if dora_scale is not None:
|
||||||
|
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||||
|
else:
|
||||||
|
weight += function((strength * lora_diff).type(weight.dtype))
|
||||||
|
except Exception as e:
|
||||||
|
logging.error("ERROR {} {} {}".format(self.name, key, e))
|
||||||
|
return weight
|
||||||
93
comfy/weight_adapter/glora.py
Normal file
93
comfy/weight_adapter/glora.py
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import comfy.model_management
|
||||||
|
from .base import WeightAdapterBase, weight_decompose
|
||||||
|
|
||||||
|
|
||||||
|
class GLoRAAdapter(WeightAdapterBase):
|
||||||
|
name = "glora"
|
||||||
|
|
||||||
|
def __init__(self, loaded_keys, weights):
|
||||||
|
self.loaded_keys = loaded_keys
|
||||||
|
self.weights = weights
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(
|
||||||
|
cls,
|
||||||
|
x: str,
|
||||||
|
lora: dict[str, torch.Tensor],
|
||||||
|
alpha: float,
|
||||||
|
dora_scale: torch.Tensor,
|
||||||
|
loaded_keys: set[str] = None,
|
||||||
|
) -> Optional["GLoRAAdapter"]:
|
||||||
|
if loaded_keys is None:
|
||||||
|
loaded_keys = set()
|
||||||
|
a1_name = "{}.a1.weight".format(x)
|
||||||
|
a2_name = "{}.a2.weight".format(x)
|
||||||
|
b1_name = "{}.b1.weight".format(x)
|
||||||
|
b2_name = "{}.b2.weight".format(x)
|
||||||
|
if a1_name in lora:
|
||||||
|
weights = (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale)
|
||||||
|
loaded_keys.add(a1_name)
|
||||||
|
loaded_keys.add(a2_name)
|
||||||
|
loaded_keys.add(b1_name)
|
||||||
|
loaded_keys.add(b2_name)
|
||||||
|
return cls(loaded_keys, weights)
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def calculate_weight(
|
||||||
|
self,
|
||||||
|
weight,
|
||||||
|
key,
|
||||||
|
strength,
|
||||||
|
strength_model,
|
||||||
|
offset,
|
||||||
|
function,
|
||||||
|
intermediate_dtype=torch.float32,
|
||||||
|
original_weight=None,
|
||||||
|
):
|
||||||
|
v = self.weights
|
||||||
|
dora_scale = v[5]
|
||||||
|
|
||||||
|
old_glora = False
|
||||||
|
if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]:
|
||||||
|
rank = v[0].shape[0]
|
||||||
|
old_glora = True
|
||||||
|
|
||||||
|
if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]:
|
||||||
|
if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
old_glora = False
|
||||||
|
rank = v[1].shape[0]
|
||||||
|
|
||||||
|
a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||||
|
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||||
|
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||||
|
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||||
|
|
||||||
|
if v[4] is not None:
|
||||||
|
alpha = v[4] / rank
|
||||||
|
else:
|
||||||
|
alpha = 1.0
|
||||||
|
|
||||||
|
try:
|
||||||
|
if old_glora:
|
||||||
|
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora
|
||||||
|
else:
|
||||||
|
if weight.dim() > 2:
|
||||||
|
lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
|
||||||
|
else:
|
||||||
|
lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
|
||||||
|
lora_diff += torch.mm(b1, b2).reshape(weight.shape)
|
||||||
|
|
||||||
|
if dora_scale is not None:
|
||||||
|
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||||
|
else:
|
||||||
|
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||||
|
except Exception as e:
|
||||||
|
logging.error("ERROR {} {} {}".format(self.name, key, e))
|
||||||
|
return weight
|
||||||
100
comfy/weight_adapter/loha.py
Normal file
100
comfy/weight_adapter/loha.py
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import comfy.model_management
|
||||||
|
from .base import WeightAdapterBase, weight_decompose
|
||||||
|
|
||||||
|
|
||||||
|
class LoHaAdapter(WeightAdapterBase):
|
||||||
|
name = "loha"
|
||||||
|
|
||||||
|
def __init__(self, loaded_keys, weights):
|
||||||
|
self.loaded_keys = loaded_keys
|
||||||
|
self.weights = weights
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(
|
||||||
|
cls,
|
||||||
|
x: str,
|
||||||
|
lora: dict[str, torch.Tensor],
|
||||||
|
alpha: float,
|
||||||
|
dora_scale: torch.Tensor,
|
||||||
|
loaded_keys: set[str] = None,
|
||||||
|
) -> Optional["LoHaAdapter"]:
|
||||||
|
if loaded_keys is None:
|
||||||
|
loaded_keys = set()
|
||||||
|
|
||||||
|
hada_w1_a_name = "{}.hada_w1_a".format(x)
|
||||||
|
hada_w1_b_name = "{}.hada_w1_b".format(x)
|
||||||
|
hada_w2_a_name = "{}.hada_w2_a".format(x)
|
||||||
|
hada_w2_b_name = "{}.hada_w2_b".format(x)
|
||||||
|
hada_t1_name = "{}.hada_t1".format(x)
|
||||||
|
hada_t2_name = "{}.hada_t2".format(x)
|
||||||
|
if hada_w1_a_name in lora.keys():
|
||||||
|
hada_t1 = None
|
||||||
|
hada_t2 = None
|
||||||
|
if hada_t1_name in lora.keys():
|
||||||
|
hada_t1 = lora[hada_t1_name]
|
||||||
|
hada_t2 = lora[hada_t2_name]
|
||||||
|
loaded_keys.add(hada_t1_name)
|
||||||
|
loaded_keys.add(hada_t2_name)
|
||||||
|
|
||||||
|
weights = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale)
|
||||||
|
loaded_keys.add(hada_w1_a_name)
|
||||||
|
loaded_keys.add(hada_w1_b_name)
|
||||||
|
loaded_keys.add(hada_w2_a_name)
|
||||||
|
loaded_keys.add(hada_w2_b_name)
|
||||||
|
return cls(loaded_keys, weights)
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def calculate_weight(
|
||||||
|
self,
|
||||||
|
weight,
|
||||||
|
key,
|
||||||
|
strength,
|
||||||
|
strength_model,
|
||||||
|
offset,
|
||||||
|
function,
|
||||||
|
intermediate_dtype=torch.float32,
|
||||||
|
original_weight=None,
|
||||||
|
):
|
||||||
|
v = self.weights
|
||||||
|
w1a = v[0]
|
||||||
|
w1b = v[1]
|
||||||
|
if v[2] is not None:
|
||||||
|
alpha = v[2] / w1b.shape[0]
|
||||||
|
else:
|
||||||
|
alpha = 1.0
|
||||||
|
|
||||||
|
w2a = v[3]
|
||||||
|
w2b = v[4]
|
||||||
|
dora_scale = v[7]
|
||||||
|
if v[5] is not None: #cp decomposition
|
||||||
|
t1 = v[5]
|
||||||
|
t2 = v[6]
|
||||||
|
m1 = torch.einsum('i j k l, j r, i p -> p r k l',
|
||||||
|
comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype))
|
||||||
|
|
||||||
|
m2 = torch.einsum('i j k l, j r, i p -> p r k l',
|
||||||
|
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype))
|
||||||
|
else:
|
||||||
|
m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype))
|
||||||
|
m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype))
|
||||||
|
|
||||||
|
try:
|
||||||
|
lora_diff = (m1 * m2).reshape(weight.shape)
|
||||||
|
if dora_scale is not None:
|
||||||
|
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||||
|
else:
|
||||||
|
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||||
|
except Exception as e:
|
||||||
|
logging.error("ERROR {} {} {}".format(self.name, key, e))
|
||||||
|
return weight
|
||||||
133
comfy/weight_adapter/lokr.py
Normal file
133
comfy/weight_adapter/lokr.py
Normal file
@@ -0,0 +1,133 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import comfy.model_management
|
||||||
|
from .base import WeightAdapterBase, weight_decompose
|
||||||
|
|
||||||
|
|
||||||
|
class LoKrAdapter(WeightAdapterBase):
|
||||||
|
name = "lokr"
|
||||||
|
|
||||||
|
def __init__(self, loaded_keys, weights):
|
||||||
|
self.loaded_keys = loaded_keys
|
||||||
|
self.weights = weights
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(
|
||||||
|
cls,
|
||||||
|
x: str,
|
||||||
|
lora: dict[str, torch.Tensor],
|
||||||
|
alpha: float,
|
||||||
|
dora_scale: torch.Tensor,
|
||||||
|
loaded_keys: set[str] = None,
|
||||||
|
) -> Optional["LoKrAdapter"]:
|
||||||
|
if loaded_keys is None:
|
||||||
|
loaded_keys = set()
|
||||||
|
lokr_w1_name = "{}.lokr_w1".format(x)
|
||||||
|
lokr_w2_name = "{}.lokr_w2".format(x)
|
||||||
|
lokr_w1_a_name = "{}.lokr_w1_a".format(x)
|
||||||
|
lokr_w1_b_name = "{}.lokr_w1_b".format(x)
|
||||||
|
lokr_t2_name = "{}.lokr_t2".format(x)
|
||||||
|
lokr_w2_a_name = "{}.lokr_w2_a".format(x)
|
||||||
|
lokr_w2_b_name = "{}.lokr_w2_b".format(x)
|
||||||
|
|
||||||
|
lokr_w1 = None
|
||||||
|
if lokr_w1_name in lora.keys():
|
||||||
|
lokr_w1 = lora[lokr_w1_name]
|
||||||
|
loaded_keys.add(lokr_w1_name)
|
||||||
|
|
||||||
|
lokr_w2 = None
|
||||||
|
if lokr_w2_name in lora.keys():
|
||||||
|
lokr_w2 = lora[lokr_w2_name]
|
||||||
|
loaded_keys.add(lokr_w2_name)
|
||||||
|
|
||||||
|
lokr_w1_a = None
|
||||||
|
if lokr_w1_a_name in lora.keys():
|
||||||
|
lokr_w1_a = lora[lokr_w1_a_name]
|
||||||
|
loaded_keys.add(lokr_w1_a_name)
|
||||||
|
|
||||||
|
lokr_w1_b = None
|
||||||
|
if lokr_w1_b_name in lora.keys():
|
||||||
|
lokr_w1_b = lora[lokr_w1_b_name]
|
||||||
|
loaded_keys.add(lokr_w1_b_name)
|
||||||
|
|
||||||
|
lokr_w2_a = None
|
||||||
|
if lokr_w2_a_name in lora.keys():
|
||||||
|
lokr_w2_a = lora[lokr_w2_a_name]
|
||||||
|
loaded_keys.add(lokr_w2_a_name)
|
||||||
|
|
||||||
|
lokr_w2_b = None
|
||||||
|
if lokr_w2_b_name in lora.keys():
|
||||||
|
lokr_w2_b = lora[lokr_w2_b_name]
|
||||||
|
loaded_keys.add(lokr_w2_b_name)
|
||||||
|
|
||||||
|
lokr_t2 = None
|
||||||
|
if lokr_t2_name in lora.keys():
|
||||||
|
lokr_t2 = lora[lokr_t2_name]
|
||||||
|
loaded_keys.add(lokr_t2_name)
|
||||||
|
|
||||||
|
if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
|
||||||
|
weights = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale)
|
||||||
|
return cls(loaded_keys, weights)
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def calculate_weight(
|
||||||
|
self,
|
||||||
|
weight,
|
||||||
|
key,
|
||||||
|
strength,
|
||||||
|
strength_model,
|
||||||
|
offset,
|
||||||
|
function,
|
||||||
|
intermediate_dtype=torch.float32,
|
||||||
|
original_weight=None,
|
||||||
|
):
|
||||||
|
v = self.weights
|
||||||
|
w1 = v[0]
|
||||||
|
w2 = v[1]
|
||||||
|
w1_a = v[3]
|
||||||
|
w1_b = v[4]
|
||||||
|
w2_a = v[5]
|
||||||
|
w2_b = v[6]
|
||||||
|
t2 = v[7]
|
||||||
|
dora_scale = v[8]
|
||||||
|
dim = None
|
||||||
|
|
||||||
|
if w1 is None:
|
||||||
|
dim = w1_b.shape[0]
|
||||||
|
w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype))
|
||||||
|
else:
|
||||||
|
w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype)
|
||||||
|
|
||||||
|
if w2 is None:
|
||||||
|
dim = w2_b.shape[0]
|
||||||
|
if t2 is None:
|
||||||
|
w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype))
|
||||||
|
else:
|
||||||
|
w2 = torch.einsum('i j k l, j r, i p -> p r k l',
|
||||||
|
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype))
|
||||||
|
else:
|
||||||
|
w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype)
|
||||||
|
|
||||||
|
if len(w2.shape) == 4:
|
||||||
|
w1 = w1.unsqueeze(2).unsqueeze(2)
|
||||||
|
if v[2] is not None and dim is not None:
|
||||||
|
alpha = v[2] / dim
|
||||||
|
else:
|
||||||
|
alpha = 1.0
|
||||||
|
|
||||||
|
try:
|
||||||
|
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
|
||||||
|
if dora_scale is not None:
|
||||||
|
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||||
|
else:
|
||||||
|
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||||
|
except Exception as e:
|
||||||
|
logging.error("ERROR {} {} {}".format(self.name, key, e))
|
||||||
|
return weight
|
||||||
142
comfy/weight_adapter/lora.py
Normal file
142
comfy/weight_adapter/lora.py
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import comfy.model_management
|
||||||
|
from .base import WeightAdapterBase, weight_decompose, pad_tensor_to_shape
|
||||||
|
|
||||||
|
|
||||||
|
class LoRAAdapter(WeightAdapterBase):
|
||||||
|
name = "lora"
|
||||||
|
|
||||||
|
def __init__(self, loaded_keys, weights):
|
||||||
|
self.loaded_keys = loaded_keys
|
||||||
|
self.weights = weights
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(
|
||||||
|
cls,
|
||||||
|
x: str,
|
||||||
|
lora: dict[str, torch.Tensor],
|
||||||
|
alpha: float,
|
||||||
|
dora_scale: torch.Tensor,
|
||||||
|
loaded_keys: set[str] = None,
|
||||||
|
) -> Optional["LoRAAdapter"]:
|
||||||
|
if loaded_keys is None:
|
||||||
|
loaded_keys = set()
|
||||||
|
|
||||||
|
reshape_name = "{}.reshape_weight".format(x)
|
||||||
|
regular_lora = "{}.lora_up.weight".format(x)
|
||||||
|
diffusers_lora = "{}_lora.up.weight".format(x)
|
||||||
|
diffusers2_lora = "{}.lora_B.weight".format(x)
|
||||||
|
diffusers3_lora = "{}.lora.up.weight".format(x)
|
||||||
|
mochi_lora = "{}.lora_B".format(x)
|
||||||
|
transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
|
||||||
|
A_name = None
|
||||||
|
|
||||||
|
if regular_lora in lora.keys():
|
||||||
|
A_name = regular_lora
|
||||||
|
B_name = "{}.lora_down.weight".format(x)
|
||||||
|
mid_name = "{}.lora_mid.weight".format(x)
|
||||||
|
elif diffusers_lora in lora.keys():
|
||||||
|
A_name = diffusers_lora
|
||||||
|
B_name = "{}_lora.down.weight".format(x)
|
||||||
|
mid_name = None
|
||||||
|
elif diffusers2_lora in lora.keys():
|
||||||
|
A_name = diffusers2_lora
|
||||||
|
B_name = "{}.lora_A.weight".format(x)
|
||||||
|
mid_name = None
|
||||||
|
elif diffusers3_lora in lora.keys():
|
||||||
|
A_name = diffusers3_lora
|
||||||
|
B_name = "{}.lora.down.weight".format(x)
|
||||||
|
mid_name = None
|
||||||
|
elif mochi_lora in lora.keys():
|
||||||
|
A_name = mochi_lora
|
||||||
|
B_name = "{}.lora_A".format(x)
|
||||||
|
mid_name = None
|
||||||
|
elif transformers_lora in lora.keys():
|
||||||
|
A_name = transformers_lora
|
||||||
|
B_name = "{}.lora_linear_layer.down.weight".format(x)
|
||||||
|
mid_name = None
|
||||||
|
|
||||||
|
if A_name is not None:
|
||||||
|
mid = None
|
||||||
|
if mid_name is not None and mid_name in lora.keys():
|
||||||
|
mid = lora[mid_name]
|
||||||
|
loaded_keys.add(mid_name)
|
||||||
|
reshape = None
|
||||||
|
if reshape_name in lora.keys():
|
||||||
|
try:
|
||||||
|
reshape = lora[reshape_name].tolist()
|
||||||
|
loaded_keys.add(reshape_name)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
weights = (lora[A_name], lora[B_name], alpha, mid, dora_scale, reshape)
|
||||||
|
loaded_keys.add(A_name)
|
||||||
|
loaded_keys.add(B_name)
|
||||||
|
return cls(loaded_keys, weights)
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def calculate_weight(
|
||||||
|
self,
|
||||||
|
weight,
|
||||||
|
key,
|
||||||
|
strength,
|
||||||
|
strength_model,
|
||||||
|
offset,
|
||||||
|
function,
|
||||||
|
intermediate_dtype=torch.float32,
|
||||||
|
original_weight=None,
|
||||||
|
):
|
||||||
|
v = self.weights
|
||||||
|
mat1 = comfy.model_management.cast_to_device(
|
||||||
|
v[0], weight.device, intermediate_dtype
|
||||||
|
)
|
||||||
|
mat2 = comfy.model_management.cast_to_device(
|
||||||
|
v[1], weight.device, intermediate_dtype
|
||||||
|
)
|
||||||
|
dora_scale = v[4]
|
||||||
|
reshape = v[5]
|
||||||
|
|
||||||
|
if reshape is not None:
|
||||||
|
weight = pad_tensor_to_shape(weight, reshape)
|
||||||
|
|
||||||
|
if v[2] is not None:
|
||||||
|
alpha = v[2] / mat2.shape[0]
|
||||||
|
else:
|
||||||
|
alpha = 1.0
|
||||||
|
|
||||||
|
if v[3] is not None:
|
||||||
|
# locon mid weights, hopefully the math is fine because I didn't properly test it
|
||||||
|
mat3 = comfy.model_management.cast_to_device(
|
||||||
|
v[3], weight.device, intermediate_dtype
|
||||||
|
)
|
||||||
|
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
|
||||||
|
mat2 = (
|
||||||
|
torch.mm(
|
||||||
|
mat2.transpose(0, 1).flatten(start_dim=1),
|
||||||
|
mat3.transpose(0, 1).flatten(start_dim=1),
|
||||||
|
)
|
||||||
|
.reshape(final_shape)
|
||||||
|
.transpose(0, 1)
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
lora_diff = torch.mm(
|
||||||
|
mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)
|
||||||
|
).reshape(weight.shape)
|
||||||
|
if dora_scale is not None:
|
||||||
|
weight = weight_decompose(
|
||||||
|
dora_scale,
|
||||||
|
weight,
|
||||||
|
lora_diff,
|
||||||
|
alpha,
|
||||||
|
strength,
|
||||||
|
intermediate_dtype,
|
||||||
|
function,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||||
|
except Exception as e:
|
||||||
|
logging.error("ERROR {} {} {}".format(self.name, key, e))
|
||||||
|
return weight
|
||||||
96
comfy/weight_adapter/oft.py
Normal file
96
comfy/weight_adapter/oft.py
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import comfy.model_management
|
||||||
|
from .base import WeightAdapterBase, weight_decompose
|
||||||
|
|
||||||
|
|
||||||
|
class OFTAdapter(WeightAdapterBase):
|
||||||
|
name = "oft"
|
||||||
|
|
||||||
|
def __init__(self, loaded_keys, weights):
|
||||||
|
self.loaded_keys = loaded_keys
|
||||||
|
self.weights = weights
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(
|
||||||
|
cls,
|
||||||
|
x: str,
|
||||||
|
lora: dict[str, torch.Tensor],
|
||||||
|
alpha: float,
|
||||||
|
dora_scale: torch.Tensor,
|
||||||
|
loaded_keys: set[str] = None,
|
||||||
|
) -> Optional["OFTAdapter"]:
|
||||||
|
if loaded_keys is None:
|
||||||
|
loaded_keys = set()
|
||||||
|
blocks_name = "{}.oft_blocks".format(x)
|
||||||
|
rescale_name = "{}.rescale".format(x)
|
||||||
|
|
||||||
|
blocks = None
|
||||||
|
if blocks_name in lora.keys():
|
||||||
|
blocks = lora[blocks_name]
|
||||||
|
if blocks.ndim == 3:
|
||||||
|
loaded_keys.add(blocks_name)
|
||||||
|
else:
|
||||||
|
blocks = None
|
||||||
|
if blocks is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
rescale = None
|
||||||
|
if rescale_name in lora.keys():
|
||||||
|
rescale = lora[rescale_name]
|
||||||
|
loaded_keys.add(rescale_name)
|
||||||
|
|
||||||
|
weights = (blocks, rescale, alpha, dora_scale)
|
||||||
|
return cls(loaded_keys, weights)
|
||||||
|
|
||||||
|
def calculate_weight(
|
||||||
|
self,
|
||||||
|
weight,
|
||||||
|
key,
|
||||||
|
strength,
|
||||||
|
strength_model,
|
||||||
|
offset,
|
||||||
|
function,
|
||||||
|
intermediate_dtype=torch.float32,
|
||||||
|
original_weight=None,
|
||||||
|
):
|
||||||
|
v = self.weights
|
||||||
|
blocks = v[0]
|
||||||
|
rescale = v[1]
|
||||||
|
alpha = v[2]
|
||||||
|
dora_scale = v[3]
|
||||||
|
|
||||||
|
blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype)
|
||||||
|
if rescale is not None:
|
||||||
|
rescale = comfy.model_management.cast_to_device(rescale, weight.device, intermediate_dtype)
|
||||||
|
|
||||||
|
block_num, block_size, *_ = blocks.shape
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Get r
|
||||||
|
I = torch.eye(block_size, device=blocks.device, dtype=blocks.dtype)
|
||||||
|
# for Q = -Q^T
|
||||||
|
q = blocks - blocks.transpose(1, 2)
|
||||||
|
normed_q = q
|
||||||
|
if alpha > 0: # alpha in oft/boft is for constraint
|
||||||
|
q_norm = torch.norm(q) + 1e-8
|
||||||
|
if q_norm > alpha:
|
||||||
|
normed_q = q * alpha / q_norm
|
||||||
|
# use float() to prevent unsupported type in .inverse()
|
||||||
|
r = (I + normed_q) @ (I - normed_q).float().inverse()
|
||||||
|
r = r.to(weight)
|
||||||
|
_, *shape = weight.shape
|
||||||
|
lora_diff = torch.einsum(
|
||||||
|
"k n m, k n ... -> k m ...",
|
||||||
|
(r * strength) - strength * I,
|
||||||
|
weight.view(block_num, block_size, *shape),
|
||||||
|
).view(-1, *shape)
|
||||||
|
if dora_scale is not None:
|
||||||
|
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||||
|
else:
|
||||||
|
weight += function((strength * lora_diff).type(weight.dtype))
|
||||||
|
except Exception as e:
|
||||||
|
logging.error("ERROR {} {} {}".format(self.name, key, e))
|
||||||
|
return weight
|
||||||
8
comfy_api/input/__init__.py
Normal file
8
comfy_api/input/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
from .basic_types import ImageInput, AudioInput
|
||||||
|
from .video_types import VideoInput
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ImageInput",
|
||||||
|
"AudioInput",
|
||||||
|
"VideoInput",
|
||||||
|
]
|
||||||
20
comfy_api/input/basic_types.py
Normal file
20
comfy_api/input/basic_types.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
import torch
|
||||||
|
from typing import TypedDict
|
||||||
|
|
||||||
|
ImageInput = torch.Tensor
|
||||||
|
"""
|
||||||
|
An image in format [B, H, W, C] where B is the batch size, C is the number of channels,
|
||||||
|
"""
|
||||||
|
|
||||||
|
class AudioInput(TypedDict):
|
||||||
|
"""
|
||||||
|
TypedDict representing audio input.
|
||||||
|
"""
|
||||||
|
|
||||||
|
waveform: torch.Tensor
|
||||||
|
"""
|
||||||
|
Tensor in the format [B, C, T] where B is the batch size, C is the number of channels,
|
||||||
|
"""
|
||||||
|
|
||||||
|
sample_rate: int
|
||||||
|
|
||||||
45
comfy_api/input/video_types.py
Normal file
45
comfy_api/input/video_types.py
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Optional
|
||||||
|
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
|
||||||
|
|
||||||
|
class VideoInput(ABC):
|
||||||
|
"""
|
||||||
|
Abstract base class for video input types.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_components(self) -> VideoComponents:
|
||||||
|
"""
|
||||||
|
Abstract method to get the video components (images, audio, and frame rate).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
VideoComponents containing images, audio, and frame rate
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def save_to(
|
||||||
|
self,
|
||||||
|
path: str,
|
||||||
|
format: VideoContainer = VideoContainer.AUTO,
|
||||||
|
codec: VideoCodec = VideoCodec.AUTO,
|
||||||
|
metadata: Optional[dict] = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Abstract method to save the video input to a file.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Provide a default implementation, but subclasses can provide optimized versions
|
||||||
|
# if possible.
|
||||||
|
def get_dimensions(self) -> tuple[int, int]:
|
||||||
|
"""
|
||||||
|
Returns the dimensions of the video input.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (width, height)
|
||||||
|
"""
|
||||||
|
components = self.get_components()
|
||||||
|
return components.images.shape[2], components.images.shape[1]
|
||||||
|
|
||||||
7
comfy_api/input_impl/__init__.py
Normal file
7
comfy_api/input_impl/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
from .video_types import VideoFromFile, VideoFromComponents
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# Implementations
|
||||||
|
"VideoFromFile",
|
||||||
|
"VideoFromComponents",
|
||||||
|
]
|
||||||
271
comfy_api/input_impl/video_types.py
Normal file
271
comfy_api/input_impl/video_types.py
Normal file
@@ -0,0 +1,271 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
from av.container import InputContainer
|
||||||
|
from av.subtitles.stream import SubtitleStream
|
||||||
|
from fractions import Fraction
|
||||||
|
from typing import Optional
|
||||||
|
from comfy_api.input import AudioInput
|
||||||
|
import av
|
||||||
|
import io
|
||||||
|
import json
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from comfy_api.input import VideoInput
|
||||||
|
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
|
||||||
|
|
||||||
|
|
||||||
|
def container_to_output_format(container_format: str | None) -> str | None:
|
||||||
|
"""
|
||||||
|
A container's `format` may be a comma-separated list of formats.
|
||||||
|
E.g., iso container's `format` may be `mov,mp4,m4a,3gp,3g2,mj2`.
|
||||||
|
However, writing to a file/stream with `av.open` requires a single format,
|
||||||
|
or `None` to auto-detect.
|
||||||
|
"""
|
||||||
|
if not container_format:
|
||||||
|
return None # Auto-detect
|
||||||
|
|
||||||
|
if "," not in container_format:
|
||||||
|
return container_format
|
||||||
|
|
||||||
|
formats = container_format.split(",")
|
||||||
|
return formats[0]
|
||||||
|
|
||||||
|
|
||||||
|
def get_open_write_kwargs(
|
||||||
|
dest: str | io.BytesIO, container_format: str, to_format: str | None
|
||||||
|
) -> dict:
|
||||||
|
"""Get kwargs for writing a `VideoFromFile` to a file/stream with `av.open`"""
|
||||||
|
open_kwargs = {
|
||||||
|
"mode": "w",
|
||||||
|
# If isobmff, preserve custom metadata tags (workflow, prompt, extra_pnginfo)
|
||||||
|
"options": {"movflags": "use_metadata_tags"},
|
||||||
|
}
|
||||||
|
|
||||||
|
is_write_to_buffer = isinstance(dest, io.BytesIO)
|
||||||
|
if is_write_to_buffer:
|
||||||
|
# Set output format explicitly, since it cannot be inferred from file extension
|
||||||
|
if to_format == VideoContainer.AUTO:
|
||||||
|
to_format = container_format.lower()
|
||||||
|
elif isinstance(to_format, str):
|
||||||
|
to_format = to_format.lower()
|
||||||
|
open_kwargs["format"] = container_to_output_format(to_format)
|
||||||
|
|
||||||
|
return open_kwargs
|
||||||
|
|
||||||
|
|
||||||
|
class VideoFromFile(VideoInput):
|
||||||
|
"""
|
||||||
|
Class representing video input from a file.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, file: str | io.BytesIO):
|
||||||
|
"""
|
||||||
|
Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object
|
||||||
|
containing the file contents.
|
||||||
|
"""
|
||||||
|
self.__file = file
|
||||||
|
|
||||||
|
def get_dimensions(self) -> tuple[int, int]:
|
||||||
|
"""
|
||||||
|
Returns the dimensions of the video input.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (width, height)
|
||||||
|
"""
|
||||||
|
if isinstance(self.__file, io.BytesIO):
|
||||||
|
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
||||||
|
with av.open(self.__file, mode='r') as container:
|
||||||
|
for stream in container.streams:
|
||||||
|
if stream.type == 'video':
|
||||||
|
assert isinstance(stream, av.VideoStream)
|
||||||
|
return stream.width, stream.height
|
||||||
|
raise ValueError(f"No video stream found in file '{self.__file}'")
|
||||||
|
|
||||||
|
def get_components_internal(self, container: InputContainer) -> VideoComponents:
|
||||||
|
# Get video frames
|
||||||
|
frames = []
|
||||||
|
for frame in container.decode(video=0):
|
||||||
|
img = frame.to_ndarray(format='rgb24') # shape: (H, W, 3)
|
||||||
|
img = torch.from_numpy(img) / 255.0 # shape: (H, W, 3)
|
||||||
|
frames.append(img)
|
||||||
|
|
||||||
|
images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 3, 0, 0)
|
||||||
|
|
||||||
|
# Get frame rate
|
||||||
|
video_stream = next(s for s in container.streams if s.type == 'video')
|
||||||
|
frame_rate = Fraction(video_stream.average_rate) if video_stream and video_stream.average_rate else Fraction(1)
|
||||||
|
|
||||||
|
# Get audio if available
|
||||||
|
audio = None
|
||||||
|
try:
|
||||||
|
container.seek(0) # Reset the container to the beginning
|
||||||
|
for stream in container.streams:
|
||||||
|
if stream.type != 'audio':
|
||||||
|
continue
|
||||||
|
assert isinstance(stream, av.AudioStream)
|
||||||
|
audio_frames = []
|
||||||
|
for packet in container.demux(stream):
|
||||||
|
for frame in packet.decode():
|
||||||
|
assert isinstance(frame, av.AudioFrame)
|
||||||
|
audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
|
||||||
|
if len(audio_frames) > 0:
|
||||||
|
audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples)
|
||||||
|
audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples)
|
||||||
|
audio = AudioInput({
|
||||||
|
"waveform": audio_tensor,
|
||||||
|
"sample_rate": int(stream.sample_rate) if stream.sample_rate else 1,
|
||||||
|
})
|
||||||
|
except StopIteration:
|
||||||
|
pass # No audio stream
|
||||||
|
|
||||||
|
metadata = container.metadata
|
||||||
|
return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)
|
||||||
|
|
||||||
|
def get_components(self) -> VideoComponents:
|
||||||
|
if isinstance(self.__file, io.BytesIO):
|
||||||
|
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
||||||
|
with av.open(self.__file, mode='r') as container:
|
||||||
|
return self.get_components_internal(container)
|
||||||
|
raise ValueError(f"No video stream found in file '{self.__file}'")
|
||||||
|
|
||||||
|
def save_to(
|
||||||
|
self,
|
||||||
|
path: str | io.BytesIO,
|
||||||
|
format: VideoContainer = VideoContainer.AUTO,
|
||||||
|
codec: VideoCodec = VideoCodec.AUTO,
|
||||||
|
metadata: Optional[dict] = None
|
||||||
|
):
|
||||||
|
if isinstance(self.__file, io.BytesIO):
|
||||||
|
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
||||||
|
with av.open(self.__file, mode='r') as container:
|
||||||
|
container_format = container.format.name
|
||||||
|
video_encoding = container.streams.video[0].codec.name if len(container.streams.video) > 0 else None
|
||||||
|
reuse_streams = True
|
||||||
|
if format != VideoContainer.AUTO and format not in container_format.split(","):
|
||||||
|
reuse_streams = False
|
||||||
|
if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
|
||||||
|
reuse_streams = False
|
||||||
|
|
||||||
|
if not reuse_streams:
|
||||||
|
components = self.get_components_internal(container)
|
||||||
|
video = VideoFromComponents(components)
|
||||||
|
return video.save_to(
|
||||||
|
path,
|
||||||
|
format=format,
|
||||||
|
codec=codec,
|
||||||
|
metadata=metadata
|
||||||
|
)
|
||||||
|
|
||||||
|
streams = container.streams
|
||||||
|
|
||||||
|
open_kwargs = get_open_write_kwargs(path, container_format, format)
|
||||||
|
with av.open(path, **open_kwargs) as output_container:
|
||||||
|
# Copy over the original metadata
|
||||||
|
for key, value in container.metadata.items():
|
||||||
|
if metadata is None or key not in metadata:
|
||||||
|
output_container.metadata[key] = value
|
||||||
|
|
||||||
|
# Add our new metadata
|
||||||
|
if metadata is not None:
|
||||||
|
for key, value in metadata.items():
|
||||||
|
if isinstance(value, str):
|
||||||
|
output_container.metadata[key] = value
|
||||||
|
else:
|
||||||
|
output_container.metadata[key] = json.dumps(value)
|
||||||
|
|
||||||
|
# Add streams to the new container
|
||||||
|
stream_map = {}
|
||||||
|
for stream in streams:
|
||||||
|
if isinstance(stream, (av.VideoStream, av.AudioStream, SubtitleStream)):
|
||||||
|
out_stream = output_container.add_stream_from_template(template=stream, opaque=True)
|
||||||
|
stream_map[stream] = out_stream
|
||||||
|
|
||||||
|
# Write packets to the new container
|
||||||
|
for packet in container.demux():
|
||||||
|
if packet.stream in stream_map and packet.dts is not None:
|
||||||
|
packet.stream = stream_map[packet.stream]
|
||||||
|
output_container.mux(packet)
|
||||||
|
|
||||||
|
class VideoFromComponents(VideoInput):
|
||||||
|
"""
|
||||||
|
Class representing video input from tensors.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, components: VideoComponents):
|
||||||
|
self.__components = components
|
||||||
|
|
||||||
|
def get_components(self) -> VideoComponents:
|
||||||
|
return VideoComponents(
|
||||||
|
images=self.__components.images,
|
||||||
|
audio=self.__components.audio,
|
||||||
|
frame_rate=self.__components.frame_rate
|
||||||
|
)
|
||||||
|
|
||||||
|
def save_to(
|
||||||
|
self,
|
||||||
|
path: str,
|
||||||
|
format: VideoContainer = VideoContainer.AUTO,
|
||||||
|
codec: VideoCodec = VideoCodec.AUTO,
|
||||||
|
metadata: Optional[dict] = None
|
||||||
|
):
|
||||||
|
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
|
||||||
|
raise ValueError("Only MP4 format is supported for now")
|
||||||
|
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
|
||||||
|
raise ValueError("Only H264 codec is supported for now")
|
||||||
|
with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}) as output:
|
||||||
|
# Add metadata before writing any streams
|
||||||
|
if metadata is not None:
|
||||||
|
for key, value in metadata.items():
|
||||||
|
output.metadata[key] = json.dumps(value)
|
||||||
|
|
||||||
|
frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000)
|
||||||
|
# Create a video stream
|
||||||
|
video_stream = output.add_stream('h264', rate=frame_rate)
|
||||||
|
video_stream.width = self.__components.images.shape[2]
|
||||||
|
video_stream.height = self.__components.images.shape[1]
|
||||||
|
video_stream.pix_fmt = 'yuv420p'
|
||||||
|
|
||||||
|
# Create an audio stream
|
||||||
|
audio_sample_rate = 1
|
||||||
|
audio_stream: Optional[av.AudioStream] = None
|
||||||
|
if self.__components.audio:
|
||||||
|
audio_sample_rate = int(self.__components.audio['sample_rate'])
|
||||||
|
audio_stream = output.add_stream('aac', rate=audio_sample_rate)
|
||||||
|
audio_stream.sample_rate = audio_sample_rate
|
||||||
|
audio_stream.format = 'fltp'
|
||||||
|
|
||||||
|
# Encode video
|
||||||
|
for i, frame in enumerate(self.__components.images):
|
||||||
|
img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3)
|
||||||
|
frame = av.VideoFrame.from_ndarray(img, format='rgb24')
|
||||||
|
frame = frame.reformat(format='yuv420p') # Convert to YUV420P as required by h264
|
||||||
|
packet = video_stream.encode(frame)
|
||||||
|
output.mux(packet)
|
||||||
|
|
||||||
|
# Flush video
|
||||||
|
packet = video_stream.encode(None)
|
||||||
|
output.mux(packet)
|
||||||
|
|
||||||
|
if audio_stream and self.__components.audio:
|
||||||
|
# Encode audio
|
||||||
|
samples_per_frame = int(audio_sample_rate / frame_rate)
|
||||||
|
num_frames = self.__components.audio['waveform'].shape[2] // samples_per_frame
|
||||||
|
for i in range(num_frames):
|
||||||
|
start = i * samples_per_frame
|
||||||
|
end = start + samples_per_frame
|
||||||
|
# TODO(Feature) - Add support for stereo audio
|
||||||
|
chunk = (
|
||||||
|
self.__components.audio["waveform"][0, 0, start:end]
|
||||||
|
.unsqueeze(0)
|
||||||
|
.contiguous()
|
||||||
|
.numpy()
|
||||||
|
)
|
||||||
|
audio_frame = av.AudioFrame.from_ndarray(chunk, format='fltp', layout='mono')
|
||||||
|
audio_frame.sample_rate = audio_sample_rate
|
||||||
|
audio_frame.pts = i * samples_per_frame
|
||||||
|
for packet in audio_stream.encode(audio_frame):
|
||||||
|
output.mux(packet)
|
||||||
|
|
||||||
|
# Flush audio
|
||||||
|
for packet in audio_stream.encode(None):
|
||||||
|
output.mux(packet)
|
||||||
|
|
||||||
8
comfy_api/util/__init__.py
Normal file
8
comfy_api/util/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
from .video_types import VideoContainer, VideoCodec, VideoComponents
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# Utility Types
|
||||||
|
"VideoContainer",
|
||||||
|
"VideoCodec",
|
||||||
|
"VideoComponents",
|
||||||
|
]
|
||||||
51
comfy_api/util/video_types.py
Normal file
51
comfy_api/util/video_types.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
from fractions import Fraction
|
||||||
|
from typing import Optional
|
||||||
|
from comfy_api.input import ImageInput, AudioInput
|
||||||
|
|
||||||
|
class VideoCodec(str, Enum):
|
||||||
|
AUTO = "auto"
|
||||||
|
H264 = "h264"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def as_input(cls) -> list[str]:
|
||||||
|
"""
|
||||||
|
Returns a list of codec names that can be used as node input.
|
||||||
|
"""
|
||||||
|
return [member.value for member in cls]
|
||||||
|
|
||||||
|
class VideoContainer(str, Enum):
|
||||||
|
AUTO = "auto"
|
||||||
|
MP4 = "mp4"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def as_input(cls) -> list[str]:
|
||||||
|
"""
|
||||||
|
Returns a list of container names that can be used as node input.
|
||||||
|
"""
|
||||||
|
return [member.value for member in cls]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_extension(cls, value) -> str:
|
||||||
|
"""
|
||||||
|
Returns the file extension for the container.
|
||||||
|
"""
|
||||||
|
if isinstance(value, str):
|
||||||
|
value = cls(value)
|
||||||
|
if value == VideoContainer.MP4 or value == VideoContainer.AUTO:
|
||||||
|
return "mp4"
|
||||||
|
return ""
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class VideoComponents:
|
||||||
|
"""
|
||||||
|
Dataclass representing the components of a video.
|
||||||
|
"""
|
||||||
|
|
||||||
|
images: ImageInput
|
||||||
|
frame_rate: Fraction
|
||||||
|
audio: Optional[AudioInput] = None
|
||||||
|
metadata: Optional[dict] = None
|
||||||
|
|
||||||
41
comfy_api_nodes/README.md
Normal file
41
comfy_api_nodes/README.md
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
# ComfyUI API Nodes
|
||||||
|
|
||||||
|
## Introduction
|
||||||
|
|
||||||
|
Below are a collection of nodes that work by calling external APIs. More information available in our [docs](https://docs.comfy.org/tutorials/api-nodes/overview#api-nodes).
|
||||||
|
|
||||||
|
## Development
|
||||||
|
|
||||||
|
While developing, you should be testing against the Staging environment. To test against staging:
|
||||||
|
|
||||||
|
**Install ComfyUI_frontend**
|
||||||
|
|
||||||
|
Follow the instructions [here](https://github.com/Comfy-Org/ComfyUI_frontend) to start the frontend server. By default, it will connect to Staging authentication.
|
||||||
|
|
||||||
|
> **Hint:** If you use --front-end-version argument for ComfyUI, it will use production authentication.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python run main.py --comfy-api-base https://stagingapi.comfy.org
|
||||||
|
```
|
||||||
|
|
||||||
|
API stubs are generated through automatic codegen tools from OpenAPI definitions. Since the Comfy Org OpenAPI definition contains many things from the Comfy Registry as well, we use redocly/cli to filter out only the paths relevant for API nodes.
|
||||||
|
|
||||||
|
### Redocly Instructions
|
||||||
|
|
||||||
|
**Tip**
|
||||||
|
When developing locally, use the `redocly-dev.yaml` file to generate pydantic models. This lets you use stubs for APIs that are not marked `Released` yet.
|
||||||
|
|
||||||
|
Before your API node PR merges, make sure to add the `Released` tag to the `openapi.yaml` file and test in staging.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Download the OpenAPI file from prod server.
|
||||||
|
curl -o openapi.yaml https://stagingapi.comfy.org/openapi
|
||||||
|
|
||||||
|
# Filter out unneeded API definitions.
|
||||||
|
npm install -g @redocly/cli
|
||||||
|
redocly bundle openapi.yaml --output filtered-openapi.yaml --config comfy_api_nodes/redocly-dev.yaml --remove-unused-components
|
||||||
|
|
||||||
|
# Generate the pydantic datamodels for validation.
|
||||||
|
datamodel-codegen --use-subclass-enum --field-constraints --strict-types bytes --input filtered-openapi.yaml --output comfy_api_nodes/apis/__init__.py --output-model-type pydantic_v2.BaseModel
|
||||||
|
|
||||||
|
```
|
||||||
0
comfy_api_nodes/__init__.py
Normal file
0
comfy_api_nodes/__init__.py
Normal file
575
comfy_api_nodes/apinode_utils.py
Normal file
575
comfy_api_nodes/apinode_utils.py
Normal file
@@ -0,0 +1,575 @@
|
|||||||
|
import io
|
||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
from comfy.utils import common_upscale
|
||||||
|
from comfy_api.input_impl import VideoFromFile
|
||||||
|
from comfy_api.util import VideoContainer, VideoCodec
|
||||||
|
from comfy_api.input.video_types import VideoInput
|
||||||
|
from comfy_api.input.basic_types import AudioInput
|
||||||
|
from comfy_api_nodes.apis.client import (
|
||||||
|
ApiClient,
|
||||||
|
ApiEndpoint,
|
||||||
|
HttpMethod,
|
||||||
|
SynchronousOperation,
|
||||||
|
UploadRequest,
|
||||||
|
UploadResponse,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
import requests
|
||||||
|
import torch
|
||||||
|
import math
|
||||||
|
import base64
|
||||||
|
import uuid
|
||||||
|
from io import BytesIO
|
||||||
|
import av
|
||||||
|
|
||||||
|
|
||||||
|
def download_url_to_video_output(video_url: str, timeout: int = None) -> VideoFromFile:
|
||||||
|
"""Downloads a video from a URL and returns a `VIDEO` output.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
video_url: The URL of the video to download.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A Comfy node `VIDEO` output.
|
||||||
|
"""
|
||||||
|
video_io = download_url_to_bytesio(video_url, timeout)
|
||||||
|
if video_io is None:
|
||||||
|
error_msg = f"Failed to download video from {video_url}"
|
||||||
|
logging.error(error_msg)
|
||||||
|
raise ValueError(error_msg)
|
||||||
|
return VideoFromFile(video_io)
|
||||||
|
|
||||||
|
|
||||||
|
def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor:
|
||||||
|
"""Downscale input image tensor to roughly the specified total pixels."""
|
||||||
|
samples = image.movedim(-1, 1)
|
||||||
|
total = int(total_pixels)
|
||||||
|
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
|
||||||
|
if scale_by >= 1:
|
||||||
|
return image
|
||||||
|
width = round(samples.shape[3] * scale_by)
|
||||||
|
height = round(samples.shape[2] * scale_by)
|
||||||
|
|
||||||
|
s = common_upscale(samples, width, height, "lanczos", "disabled")
|
||||||
|
s = s.movedim(1, -1)
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
def validate_and_cast_response(response, timeout: int = None) -> torch.Tensor:
|
||||||
|
"""Validates and casts a response to a torch.Tensor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
response: The response to validate and cast.
|
||||||
|
timeout: Request timeout in seconds. Defaults to None (no timeout).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A torch.Tensor representing the image (1, H, W, C).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the response is not valid.
|
||||||
|
"""
|
||||||
|
# validate raw JSON response
|
||||||
|
data = response.data
|
||||||
|
if not data or len(data) == 0:
|
||||||
|
raise ValueError("No images returned from API endpoint")
|
||||||
|
|
||||||
|
# Initialize list to store image tensors
|
||||||
|
image_tensors: list[torch.Tensor] = []
|
||||||
|
|
||||||
|
# Process each image in the data array
|
||||||
|
for image_data in data:
|
||||||
|
image_url = image_data.url
|
||||||
|
b64_data = image_data.b64_json
|
||||||
|
|
||||||
|
if not image_url and not b64_data:
|
||||||
|
raise ValueError("No image was generated in the response")
|
||||||
|
|
||||||
|
if b64_data:
|
||||||
|
img_data = base64.b64decode(b64_data)
|
||||||
|
img = Image.open(io.BytesIO(img_data))
|
||||||
|
|
||||||
|
elif image_url:
|
||||||
|
img_response = requests.get(image_url, timeout=timeout)
|
||||||
|
if img_response.status_code != 200:
|
||||||
|
raise ValueError("Failed to download the image")
|
||||||
|
img = Image.open(io.BytesIO(img_response.content))
|
||||||
|
|
||||||
|
img = img.convert("RGBA")
|
||||||
|
|
||||||
|
# Convert to numpy array, normalize to float32 between 0 and 1
|
||||||
|
img_array = np.array(img).astype(np.float32) / 255.0
|
||||||
|
img_tensor = torch.from_numpy(img_array)
|
||||||
|
|
||||||
|
# Add to list of tensors
|
||||||
|
image_tensors.append(img_tensor)
|
||||||
|
|
||||||
|
return torch.stack(image_tensors, dim=0)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_aspect_ratio(
|
||||||
|
aspect_ratio: str,
|
||||||
|
minimum_ratio: float,
|
||||||
|
maximum_ratio: float,
|
||||||
|
minimum_ratio_str: str,
|
||||||
|
maximum_ratio_str: str,
|
||||||
|
) -> float:
|
||||||
|
"""Validates and casts an aspect ratio string to a float.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
aspect_ratio: The aspect ratio string to validate.
|
||||||
|
minimum_ratio: The minimum aspect ratio.
|
||||||
|
maximum_ratio: The maximum aspect ratio.
|
||||||
|
minimum_ratio_str: The minimum aspect ratio string.
|
||||||
|
maximum_ratio_str: The maximum aspect ratio string.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The validated and cast aspect ratio.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: If the aspect ratio is not valid.
|
||||||
|
"""
|
||||||
|
# get ratio values
|
||||||
|
numbers = aspect_ratio.split(":")
|
||||||
|
if len(numbers) != 2:
|
||||||
|
raise TypeError(
|
||||||
|
f"Aspect ratio must be in the format X:Y, such as 16:9, but was {aspect_ratio}."
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
numerator = int(numbers[0])
|
||||||
|
denominator = int(numbers[1])
|
||||||
|
except ValueError as exc:
|
||||||
|
raise TypeError(
|
||||||
|
f"Aspect ratio must contain numbers separated by ':', such as 16:9, but was {aspect_ratio}."
|
||||||
|
) from exc
|
||||||
|
calculated_ratio = numerator / denominator
|
||||||
|
# if not close to minimum and maximum, check bounds
|
||||||
|
if not math.isclose(calculated_ratio, minimum_ratio) or not math.isclose(
|
||||||
|
calculated_ratio, maximum_ratio
|
||||||
|
):
|
||||||
|
if calculated_ratio < minimum_ratio:
|
||||||
|
raise TypeError(
|
||||||
|
f"Aspect ratio cannot reduce to any less than {minimum_ratio_str} ({minimum_ratio}), but was {aspect_ratio} ({calculated_ratio})."
|
||||||
|
)
|
||||||
|
elif calculated_ratio > maximum_ratio:
|
||||||
|
raise TypeError(
|
||||||
|
f"Aspect ratio cannot reduce to any greater than {maximum_ratio_str} ({maximum_ratio}), but was {aspect_ratio} ({calculated_ratio})."
|
||||||
|
)
|
||||||
|
return aspect_ratio
|
||||||
|
|
||||||
|
|
||||||
|
def mimetype_to_extension(mime_type: str) -> str:
|
||||||
|
"""Converts a MIME type to a file extension."""
|
||||||
|
return mime_type.split("/")[-1].lower()
|
||||||
|
|
||||||
|
|
||||||
|
def download_url_to_bytesio(url: str, timeout: int = None) -> BytesIO:
|
||||||
|
"""Downloads content from a URL using requests and returns it as BytesIO.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: The URL to download.
|
||||||
|
timeout: Request timeout in seconds. Defaults to None (no timeout).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
BytesIO object containing the downloaded content.
|
||||||
|
"""
|
||||||
|
response = requests.get(url, stream=True, timeout=timeout)
|
||||||
|
response.raise_for_status() # Raises HTTPError for bad responses (4XX or 5XX)
|
||||||
|
return BytesIO(response.content)
|
||||||
|
|
||||||
|
|
||||||
|
def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch.Tensor:
|
||||||
|
"""Converts image data from BytesIO to a torch.Tensor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_bytesio: BytesIO object containing the image data.
|
||||||
|
mode: The PIL mode to convert the image to (e.g., "RGB", "RGBA").
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A torch.Tensor representing the image (1, H, W, C).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
PIL.UnidentifiedImageError: If the image data cannot be identified.
|
||||||
|
ValueError: If the specified mode is invalid.
|
||||||
|
"""
|
||||||
|
image = Image.open(image_bytesio)
|
||||||
|
image = image.convert(mode)
|
||||||
|
image_array = np.array(image).astype(np.float32) / 255.0
|
||||||
|
return torch.from_numpy(image_array).unsqueeze(0)
|
||||||
|
|
||||||
|
|
||||||
|
def download_url_to_image_tensor(url: str, timeout: int = None) -> torch.Tensor:
|
||||||
|
"""Downloads an image from a URL and returns a [B, H, W, C] tensor."""
|
||||||
|
image_bytesio = download_url_to_bytesio(url, timeout)
|
||||||
|
return bytesio_to_image_tensor(image_bytesio)
|
||||||
|
|
||||||
|
def process_image_response(response: requests.Response) -> torch.Tensor:
|
||||||
|
"""Uses content from a Response object and converts it to a torch.Tensor"""
|
||||||
|
return bytesio_to_image_tensor(BytesIO(response.content))
|
||||||
|
|
||||||
|
|
||||||
|
def _tensor_to_pil(image: torch.Tensor, total_pixels: int = 2048 * 2048) -> Image.Image:
|
||||||
|
"""Converts a single torch.Tensor image [H, W, C] to a PIL Image, optionally downscaling."""
|
||||||
|
if len(image.shape) > 3:
|
||||||
|
image = image[0]
|
||||||
|
# TODO: remove alpha if not allowed and present
|
||||||
|
input_tensor = image.cpu()
|
||||||
|
input_tensor = downscale_image_tensor(
|
||||||
|
input_tensor.unsqueeze(0), total_pixels=total_pixels
|
||||||
|
).squeeze()
|
||||||
|
image_np = (input_tensor.numpy() * 255).astype(np.uint8)
|
||||||
|
img = Image.fromarray(image_np)
|
||||||
|
return img
|
||||||
|
|
||||||
|
|
||||||
|
def _pil_to_bytesio(img: Image.Image, mime_type: str = "image/png") -> BytesIO:
|
||||||
|
"""Converts a PIL Image to a BytesIO object."""
|
||||||
|
if not mime_type:
|
||||||
|
mime_type = "image/png"
|
||||||
|
|
||||||
|
img_byte_arr = io.BytesIO()
|
||||||
|
# Derive PIL format from MIME type (e.g., 'image/png' -> 'PNG')
|
||||||
|
pil_format = mime_type.split("/")[-1].upper()
|
||||||
|
if pil_format == "JPG":
|
||||||
|
pil_format = "JPEG"
|
||||||
|
img.save(img_byte_arr, format=pil_format)
|
||||||
|
img_byte_arr.seek(0)
|
||||||
|
return img_byte_arr
|
||||||
|
|
||||||
|
|
||||||
|
def tensor_to_bytesio(
|
||||||
|
image: torch.Tensor,
|
||||||
|
name: Optional[str] = None,
|
||||||
|
total_pixels: int = 2048 * 2048,
|
||||||
|
mime_type: str = "image/png",
|
||||||
|
) -> BytesIO:
|
||||||
|
"""Converts a torch.Tensor image to a named BytesIO object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: Input torch.Tensor image.
|
||||||
|
name: Optional filename for the BytesIO object.
|
||||||
|
total_pixels: Maximum total pixels for potential downscaling.
|
||||||
|
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4').
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Named BytesIO object containing the image data.
|
||||||
|
"""
|
||||||
|
if not mime_type:
|
||||||
|
mime_type = "image/png"
|
||||||
|
|
||||||
|
pil_image = _tensor_to_pil(image, total_pixels=total_pixels)
|
||||||
|
img_binary = _pil_to_bytesio(pil_image, mime_type=mime_type)
|
||||||
|
img_binary.name = (
|
||||||
|
f"{name if name else uuid.uuid4()}.{mimetype_to_extension(mime_type)}"
|
||||||
|
)
|
||||||
|
return img_binary
|
||||||
|
|
||||||
|
|
||||||
|
def tensor_to_base64_string(
|
||||||
|
image_tensor: torch.Tensor,
|
||||||
|
total_pixels: int = 2048 * 2048,
|
||||||
|
mime_type: str = "image/png",
|
||||||
|
) -> str:
|
||||||
|
"""Convert [B, H, W, C] or [H, W, C] tensor to a base64 string.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_tensor: Input torch.Tensor image.
|
||||||
|
total_pixels: Maximum total pixels for potential downscaling.
|
||||||
|
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4').
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Base64 encoded string of the image.
|
||||||
|
"""
|
||||||
|
pil_image = _tensor_to_pil(image_tensor, total_pixels=total_pixels)
|
||||||
|
img_byte_arr = _pil_to_bytesio(pil_image, mime_type=mime_type)
|
||||||
|
img_bytes = img_byte_arr.getvalue()
|
||||||
|
# Encode bytes to base64 string
|
||||||
|
base64_encoded_string = base64.b64encode(img_bytes).decode("utf-8")
|
||||||
|
return base64_encoded_string
|
||||||
|
|
||||||
|
|
||||||
|
def tensor_to_data_uri(
|
||||||
|
image_tensor: torch.Tensor,
|
||||||
|
total_pixels: int = 2048 * 2048,
|
||||||
|
mime_type: str = "image/png",
|
||||||
|
) -> str:
|
||||||
|
"""Converts a tensor image to a Data URI string.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_tensor: Input torch.Tensor image.
|
||||||
|
total_pixels: Maximum total pixels for potential downscaling.
|
||||||
|
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp').
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Data URI string (e.g., 'data:image/png;base64,...').
|
||||||
|
"""
|
||||||
|
base64_string = tensor_to_base64_string(image_tensor, total_pixels, mime_type)
|
||||||
|
return f"data:{mime_type};base64,{base64_string}"
|
||||||
|
|
||||||
|
|
||||||
|
def upload_file_to_comfyapi(
|
||||||
|
file_bytes_io: BytesIO,
|
||||||
|
filename: str,
|
||||||
|
upload_mime_type: str,
|
||||||
|
auth_token: Optional[str] = None,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Uploads a single file to ComfyUI API and returns its download URL.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_bytes_io: BytesIO object containing the file data.
|
||||||
|
filename: The filename of the file.
|
||||||
|
upload_mime_type: MIME type of the file.
|
||||||
|
auth_token: Optional authentication token.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The download URL for the uploaded file.
|
||||||
|
"""
|
||||||
|
request_object = UploadRequest(file_name=filename, content_type=upload_mime_type)
|
||||||
|
operation = SynchronousOperation(
|
||||||
|
endpoint=ApiEndpoint(
|
||||||
|
path="/customers/storage",
|
||||||
|
method=HttpMethod.POST,
|
||||||
|
request_model=UploadRequest,
|
||||||
|
response_model=UploadResponse,
|
||||||
|
),
|
||||||
|
request=request_object,
|
||||||
|
auth_token=auth_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
response: UploadResponse = operation.execute()
|
||||||
|
upload_response = ApiClient.upload_file(
|
||||||
|
response.upload_url, file_bytes_io, content_type=upload_mime_type
|
||||||
|
)
|
||||||
|
upload_response.raise_for_status()
|
||||||
|
|
||||||
|
return response.download_url
|
||||||
|
|
||||||
|
|
||||||
|
def upload_video_to_comfyapi(
|
||||||
|
video: VideoInput,
|
||||||
|
auth_token: Optional[str] = None,
|
||||||
|
container: VideoContainer = VideoContainer.MP4,
|
||||||
|
codec: VideoCodec = VideoCodec.H264,
|
||||||
|
max_duration: Optional[int] = None,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Uploads a single video to ComfyUI API and returns its download URL.
|
||||||
|
Uses the specified container and codec for saving the video before upload.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
video: VideoInput object (Comfy VIDEO type).
|
||||||
|
auth_token: Optional authentication token.
|
||||||
|
container: The video container format to use (default: MP4).
|
||||||
|
codec: The video codec to use (default: H264).
|
||||||
|
max_duration: Optional maximum duration of the video in seconds. If the video is longer than this, an error will be raised.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The download URL for the uploaded video file.
|
||||||
|
"""
|
||||||
|
if max_duration is not None:
|
||||||
|
try:
|
||||||
|
actual_duration = video.duration_seconds
|
||||||
|
if actual_duration is not None and actual_duration > max_duration:
|
||||||
|
raise ValueError(
|
||||||
|
f"Video duration ({actual_duration:.2f}s) exceeds the maximum allowed ({max_duration}s)."
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error getting video duration: {e}")
|
||||||
|
raise ValueError(f"Could not verify video duration from source: {e}") from e
|
||||||
|
|
||||||
|
upload_mime_type = f"video/{container.value.lower()}"
|
||||||
|
filename = f"uploaded_video.{container.value.lower()}"
|
||||||
|
|
||||||
|
# Convert VideoInput to BytesIO using specified container/codec
|
||||||
|
video_bytes_io = io.BytesIO()
|
||||||
|
video.save_to(video_bytes_io, format=container, codec=codec)
|
||||||
|
video_bytes_io.seek(0)
|
||||||
|
|
||||||
|
return upload_file_to_comfyapi(
|
||||||
|
video_bytes_io, filename, upload_mime_type, auth_token
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def audio_tensor_to_contiguous_ndarray(waveform: torch.Tensor) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Prepares audio waveform for av library by converting to a contiguous numpy array.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
waveform: a tensor of shape (1, channels, samples) derived from a Comfy `AUDIO` type.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Contiguous numpy array of the audio waveform. If the audio was batched,
|
||||||
|
the first item is taken.
|
||||||
|
"""
|
||||||
|
if waveform.ndim != 3 or waveform.shape[0] != 1:
|
||||||
|
raise ValueError("Expected waveform tensor shape (1, channels, samples)")
|
||||||
|
|
||||||
|
# If batch is > 1, take first item
|
||||||
|
if waveform.shape[0] > 1:
|
||||||
|
waveform = waveform[0]
|
||||||
|
|
||||||
|
# Prepare for av: remove batch dim, move to CPU, make contiguous, convert to numpy array
|
||||||
|
audio_data_np = waveform.squeeze(0).cpu().contiguous().numpy()
|
||||||
|
if audio_data_np.dtype != np.float32:
|
||||||
|
audio_data_np = audio_data_np.astype(np.float32)
|
||||||
|
|
||||||
|
return audio_data_np
|
||||||
|
|
||||||
|
|
||||||
|
def audio_ndarray_to_bytesio(
|
||||||
|
audio_data_np: np.ndarray,
|
||||||
|
sample_rate: int,
|
||||||
|
container_format: str = "mp4",
|
||||||
|
codec_name: str = "aac",
|
||||||
|
) -> BytesIO:
|
||||||
|
"""
|
||||||
|
Encodes a numpy array of audio data into a BytesIO object.
|
||||||
|
"""
|
||||||
|
audio_bytes_io = io.BytesIO()
|
||||||
|
with av.open(audio_bytes_io, mode="w", format=container_format) as output_container:
|
||||||
|
audio_stream = output_container.add_stream(codec_name, rate=sample_rate)
|
||||||
|
frame = av.AudioFrame.from_ndarray(
|
||||||
|
audio_data_np,
|
||||||
|
format="fltp",
|
||||||
|
layout="stereo" if audio_data_np.shape[0] > 1 else "mono",
|
||||||
|
)
|
||||||
|
frame.sample_rate = sample_rate
|
||||||
|
frame.pts = 0
|
||||||
|
|
||||||
|
for packet in audio_stream.encode(frame):
|
||||||
|
output_container.mux(packet)
|
||||||
|
|
||||||
|
# Flush stream
|
||||||
|
for packet in audio_stream.encode(None):
|
||||||
|
output_container.mux(packet)
|
||||||
|
|
||||||
|
audio_bytes_io.seek(0)
|
||||||
|
return audio_bytes_io
|
||||||
|
|
||||||
|
|
||||||
|
def upload_audio_to_comfyapi(
|
||||||
|
audio: AudioInput,
|
||||||
|
auth_token: Optional[str] = None,
|
||||||
|
container_format: str = "mp4",
|
||||||
|
codec_name: str = "aac",
|
||||||
|
mime_type: str = "audio/mp4",
|
||||||
|
filename: str = "uploaded_audio.mp4",
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Uploads a single audio input to ComfyUI API and returns its download URL.
|
||||||
|
Encodes the raw waveform into the specified format before uploading.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
audio: a Comfy `AUDIO` type (contains waveform tensor and sample_rate)
|
||||||
|
auth_token: Optional authentication token.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The download URL for the uploaded audio file.
|
||||||
|
"""
|
||||||
|
sample_rate: int = audio["sample_rate"]
|
||||||
|
waveform: torch.Tensor = audio["waveform"]
|
||||||
|
audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
|
||||||
|
audio_bytes_io = audio_ndarray_to_bytesio(
|
||||||
|
audio_data_np, sample_rate, container_format, codec_name
|
||||||
|
)
|
||||||
|
|
||||||
|
return upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_token)
|
||||||
|
|
||||||
|
|
||||||
|
def upload_images_to_comfyapi(
|
||||||
|
image: torch.Tensor, max_images=8, auth_token=None, mime_type: Optional[str] = None
|
||||||
|
) -> list[str]:
|
||||||
|
"""
|
||||||
|
Uploads images to ComfyUI API and returns download URLs.
|
||||||
|
To upload multiple images, stack them in the batch dimension first.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: Input torch.Tensor image.
|
||||||
|
max_images: Maximum number of images to upload.
|
||||||
|
auth_token: Optional authentication token.
|
||||||
|
mime_type: Optional MIME type for the image.
|
||||||
|
"""
|
||||||
|
# if batch, try to upload each file if max_images is greater than 0
|
||||||
|
idx_image = 0
|
||||||
|
download_urls: list[str] = []
|
||||||
|
is_batch = len(image.shape) > 3
|
||||||
|
batch_length = 1
|
||||||
|
if is_batch:
|
||||||
|
batch_length = image.shape[0]
|
||||||
|
while True:
|
||||||
|
curr_image = image
|
||||||
|
if len(image.shape) > 3:
|
||||||
|
curr_image = image[idx_image]
|
||||||
|
# get BytesIO version of image
|
||||||
|
img_binary = tensor_to_bytesio(curr_image, mime_type=mime_type)
|
||||||
|
# first, request upload/download urls from comfy API
|
||||||
|
if not mime_type:
|
||||||
|
request_object = UploadRequest(file_name=img_binary.name)
|
||||||
|
else:
|
||||||
|
request_object = UploadRequest(
|
||||||
|
file_name=img_binary.name, content_type=mime_type
|
||||||
|
)
|
||||||
|
operation = SynchronousOperation(
|
||||||
|
endpoint=ApiEndpoint(
|
||||||
|
path="/customers/storage",
|
||||||
|
method=HttpMethod.POST,
|
||||||
|
request_model=UploadRequest,
|
||||||
|
response_model=UploadResponse,
|
||||||
|
),
|
||||||
|
request=request_object,
|
||||||
|
auth_token=auth_token,
|
||||||
|
)
|
||||||
|
response = operation.execute()
|
||||||
|
|
||||||
|
upload_response = ApiClient.upload_file(
|
||||||
|
response.upload_url, img_binary, content_type=mime_type
|
||||||
|
)
|
||||||
|
# verify success
|
||||||
|
try:
|
||||||
|
upload_response.raise_for_status()
|
||||||
|
except requests.exceptions.HTTPError as e:
|
||||||
|
raise ValueError(f"Could not upload one or more images: {e}") from e
|
||||||
|
# add download_url to list
|
||||||
|
download_urls.append(response.download_url)
|
||||||
|
|
||||||
|
idx_image += 1
|
||||||
|
# stop uploading additional files if done
|
||||||
|
if is_batch and max_images > 0:
|
||||||
|
if idx_image >= max_images:
|
||||||
|
break
|
||||||
|
if idx_image >= batch_length:
|
||||||
|
break
|
||||||
|
return download_urls
|
||||||
|
|
||||||
|
|
||||||
|
def resize_mask_to_image(mask: torch.Tensor, image: torch.Tensor,
|
||||||
|
upscale_method="nearest-exact", crop="disabled",
|
||||||
|
allow_gradient=True, add_channel_dim=False):
|
||||||
|
"""
|
||||||
|
Resize mask to be the same dimensions as an image, while maintaining proper format for API calls.
|
||||||
|
"""
|
||||||
|
_, H, W, _ = image.shape
|
||||||
|
mask = mask.unsqueeze(-1)
|
||||||
|
mask = mask.movedim(-1,1)
|
||||||
|
mask = common_upscale(mask, width=W, height=H, upscale_method=upscale_method, crop=crop)
|
||||||
|
mask = mask.movedim(1,-1)
|
||||||
|
if not add_channel_dim:
|
||||||
|
mask = mask.squeeze(-1)
|
||||||
|
if not allow_gradient:
|
||||||
|
mask = (mask > 0.5).float()
|
||||||
|
return mask
|
||||||
|
|
||||||
|
|
||||||
|
def validate_string(string: str, strip_whitespace=True, field_name="prompt", min_length=None, max_length=None):
|
||||||
|
if strip_whitespace:
|
||||||
|
string = string.strip()
|
||||||
|
if min_length and len(string) < min_length:
|
||||||
|
raise Exception(f"Field '{field_name}' cannot be shorter than {min_length} characters; was {len(string)} characters long.")
|
||||||
|
if max_length and len(string) > max_length:
|
||||||
|
raise Exception(f" Field '{field_name} cannot be longer than {max_length} characters; was {len(string)} characters long.")
|
||||||
|
if not string:
|
||||||
|
raise Exception(f"Field '{field_name}' cannot be empty.")
|
||||||
17
comfy_api_nodes/apis/PixverseController.py
Normal file
17
comfy_api_nodes/apis/PixverseController.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
# generated by datamodel-codegen:
|
||||||
|
# filename: filtered-openapi.yaml
|
||||||
|
# timestamp: 2025-04-29T23:44:54+00:00
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from . import PixverseDto
|
||||||
|
|
||||||
|
|
||||||
|
class ResponseData(BaseModel):
|
||||||
|
ErrCode: Optional[int] = None
|
||||||
|
ErrMsg: Optional[str] = None
|
||||||
|
Resp: Optional[PixverseDto.V2OpenAPII2VResp] = None
|
||||||
57
comfy_api_nodes/apis/PixverseDto.py
Normal file
57
comfy_api_nodes/apis/PixverseDto.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
# generated by datamodel-codegen:
|
||||||
|
# filename: filtered-openapi.yaml
|
||||||
|
# timestamp: 2025-04-29T23:44:54+00:00
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class V2OpenAPII2VResp(BaseModel):
|
||||||
|
video_id: Optional[int] = Field(None, description='Video_id')
|
||||||
|
|
||||||
|
|
||||||
|
class V2OpenAPIT2VReq(BaseModel):
|
||||||
|
aspect_ratio: str = Field(
|
||||||
|
..., description='Aspect ratio (16:9, 4:3, 1:1, 3:4, 9:16)', examples=['16:9']
|
||||||
|
)
|
||||||
|
duration: int = Field(
|
||||||
|
...,
|
||||||
|
description='Video duration (5, 8 seconds, --model=v3.5 only allows 5,8; --quality=1080p does not support 8s)',
|
||||||
|
examples=[5],
|
||||||
|
)
|
||||||
|
model: str = Field(
|
||||||
|
..., description='Model version (only supports v3.5)', examples=['v3.5']
|
||||||
|
)
|
||||||
|
motion_mode: Optional[str] = Field(
|
||||||
|
'normal',
|
||||||
|
description='Motion mode (normal, fast, --fast only available when duration=5; --quality=1080p does not support fast)',
|
||||||
|
examples=['normal'],
|
||||||
|
)
|
||||||
|
negative_prompt: Optional[str] = Field(
|
||||||
|
None, description='Negative prompt\n', max_length=2048
|
||||||
|
)
|
||||||
|
prompt: str = Field(..., description='Prompt', max_length=2048)
|
||||||
|
quality: str = Field(
|
||||||
|
...,
|
||||||
|
description='Video quality ("360p"(Turbo model), "540p", "720p", "1080p")',
|
||||||
|
examples=['540p'],
|
||||||
|
)
|
||||||
|
seed: Optional[int] = Field(None, description='Random seed, range: 0 - 2147483647')
|
||||||
|
style: Optional[str] = Field(
|
||||||
|
None,
|
||||||
|
description='Style (effective when model=v3.5, "anime", "3d_animation", "clay", "comic", "cyberpunk") Do not include style parameter unless needed',
|
||||||
|
examples=['anime'],
|
||||||
|
)
|
||||||
|
template_id: Optional[int] = Field(
|
||||||
|
None,
|
||||||
|
description='Template ID (template_id must be activated before use)',
|
||||||
|
examples=[302325299692608],
|
||||||
|
)
|
||||||
|
water_mark: Optional[bool] = Field(
|
||||||
|
False,
|
||||||
|
description='Watermark (true: add watermark, false: no watermark)',
|
||||||
|
examples=[False],
|
||||||
|
)
|
||||||
3829
comfy_api_nodes/apis/__init__.py
Normal file
3829
comfy_api_nodes/apis/__init__.py
Normal file
File diff suppressed because it is too large
Load Diff
156
comfy_api_nodes/apis/bfl_api.py
Normal file
156
comfy_api_nodes/apis/bfl_api.py
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, confloat, conint
|
||||||
|
|
||||||
|
|
||||||
|
class BFLOutputFormat(str, Enum):
|
||||||
|
png = 'png'
|
||||||
|
jpeg = 'jpeg'
|
||||||
|
|
||||||
|
|
||||||
|
class BFLFluxExpandImageRequest(BaseModel):
|
||||||
|
prompt: str = Field(..., description='The description of the changes you want to make. This text guides the expansion process, allowing you to specify features, styles, or modifications for the expanded areas.')
|
||||||
|
prompt_upsampling: Optional[bool] = Field(
|
||||||
|
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
|
||||||
|
)
|
||||||
|
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
|
||||||
|
top: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the top of the image')
|
||||||
|
bottom: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the bottom of the image')
|
||||||
|
left: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the left side of the image')
|
||||||
|
right: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the right side of the image')
|
||||||
|
steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process')
|
||||||
|
guidance: confloat(ge=1.5, le=100) = Field(..., description='Guidance strength for the image generation process')
|
||||||
|
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
|
||||||
|
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
|
||||||
|
)
|
||||||
|
output_format: Optional[BFLOutputFormat] = Field(
|
||||||
|
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
|
||||||
|
)
|
||||||
|
image: str = Field(None, description='A Base64-encoded string representing the image you wish to expand')
|
||||||
|
|
||||||
|
|
||||||
|
class BFLFluxFillImageRequest(BaseModel):
|
||||||
|
prompt: str = Field(..., description='The description of the changes you want to make. This text guides the expansion process, allowing you to specify features, styles, or modifications for the expanded areas.')
|
||||||
|
prompt_upsampling: Optional[bool] = Field(
|
||||||
|
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
|
||||||
|
)
|
||||||
|
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
|
||||||
|
steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process')
|
||||||
|
guidance: confloat(ge=1.5, le=100) = Field(..., description='Guidance strength for the image generation process')
|
||||||
|
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
|
||||||
|
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
|
||||||
|
)
|
||||||
|
output_format: Optional[BFLOutputFormat] = Field(
|
||||||
|
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
|
||||||
|
)
|
||||||
|
image: str = Field(None, description='A Base64-encoded string representing the image you wish to modify. Can contain alpha mask if desired.')
|
||||||
|
mask: str = Field(None, description='A Base64-encoded string representing the mask of the areas you with to modify.')
|
||||||
|
|
||||||
|
|
||||||
|
class BFLFluxCannyImageRequest(BaseModel):
|
||||||
|
prompt: str = Field(..., description='Text prompt for image generation')
|
||||||
|
prompt_upsampling: Optional[bool] = Field(
|
||||||
|
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
|
||||||
|
)
|
||||||
|
canny_low_threshold: Optional[int] = Field(None, description='Low threshold for Canny edge detection')
|
||||||
|
canny_high_threshold: Optional[int] = Field(None, description='High threshold for Canny edge detection')
|
||||||
|
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
|
||||||
|
steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process')
|
||||||
|
guidance: confloat(ge=1, le=100) = Field(..., description='Guidance strength for the image generation process')
|
||||||
|
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
|
||||||
|
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
|
||||||
|
)
|
||||||
|
output_format: Optional[BFLOutputFormat] = Field(
|
||||||
|
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
|
||||||
|
)
|
||||||
|
control_image: Optional[str] = Field(None, description='Base64 encoded image to use as control input if no preprocessed image is provided')
|
||||||
|
preprocessed_image: Optional[str] = Field(None, description='Optional pre-processed image that will bypass the control preprocessing step')
|
||||||
|
|
||||||
|
|
||||||
|
class BFLFluxDepthImageRequest(BaseModel):
|
||||||
|
prompt: str = Field(..., description='Text prompt for image generation')
|
||||||
|
prompt_upsampling: Optional[bool] = Field(
|
||||||
|
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
|
||||||
|
)
|
||||||
|
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
|
||||||
|
steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process')
|
||||||
|
guidance: confloat(ge=1, le=100) = Field(..., description='Guidance strength for the image generation process')
|
||||||
|
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
|
||||||
|
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
|
||||||
|
)
|
||||||
|
output_format: Optional[BFLOutputFormat] = Field(
|
||||||
|
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
|
||||||
|
)
|
||||||
|
control_image: Optional[str] = Field(None, description='Base64 encoded image to use as control input if no preprocessed image is provided')
|
||||||
|
preprocessed_image: Optional[str] = Field(None, description='Optional pre-processed image that will bypass the control preprocessing step')
|
||||||
|
|
||||||
|
|
||||||
|
class BFLFluxProGenerateRequest(BaseModel):
|
||||||
|
prompt: str = Field(..., description='The text prompt for image generation.')
|
||||||
|
prompt_upsampling: Optional[bool] = Field(
|
||||||
|
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
|
||||||
|
)
|
||||||
|
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
|
||||||
|
width: conint(ge=256, le=1440) = Field(1024, description='Width of the generated image in pixels. Must be a multiple of 32.')
|
||||||
|
height: conint(ge=256, le=1440) = Field(768, description='Height of the generated image in pixels. Must be a multiple of 32.')
|
||||||
|
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
|
||||||
|
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
|
||||||
|
)
|
||||||
|
output_format: Optional[BFLOutputFormat] = Field(
|
||||||
|
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
|
||||||
|
)
|
||||||
|
image_prompt: Optional[str] = Field(None, description='Optional image to remix in base64 format')
|
||||||
|
# image_prompt_strength: Optional[confloat(ge=0.0, le=1.0)] = Field(
|
||||||
|
# None, description='Blend between the prompt and the image prompt.'
|
||||||
|
# )
|
||||||
|
|
||||||
|
|
||||||
|
class BFLFluxProUltraGenerateRequest(BaseModel):
|
||||||
|
prompt: str = Field(..., description='The text prompt for image generation.')
|
||||||
|
prompt_upsampling: Optional[bool] = Field(
|
||||||
|
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
|
||||||
|
)
|
||||||
|
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
|
||||||
|
aspect_ratio: Optional[str] = Field(None, description='Aspect ratio of the image between 21:9 and 9:21.')
|
||||||
|
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
|
||||||
|
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
|
||||||
|
)
|
||||||
|
output_format: Optional[BFLOutputFormat] = Field(
|
||||||
|
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
|
||||||
|
)
|
||||||
|
raw: Optional[bool] = Field(None, description='Generate less processed, more natural-looking images.')
|
||||||
|
image_prompt: Optional[str] = Field(None, description='Optional image to remix in base64 format')
|
||||||
|
image_prompt_strength: Optional[confloat(ge=0.0, le=1.0)] = Field(
|
||||||
|
None, description='Blend between the prompt and the image prompt.'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BFLFluxProGenerateResponse(BaseModel):
|
||||||
|
id: str = Field(..., description='The unique identifier for the generation task.')
|
||||||
|
polling_url: str = Field(..., description='URL to poll for the generation result.')
|
||||||
|
|
||||||
|
|
||||||
|
class BFLStatus(str, Enum):
|
||||||
|
task_not_found = "Task not found"
|
||||||
|
pending = "Pending"
|
||||||
|
request_moderated = "Request Moderated"
|
||||||
|
content_moderated = "Content Moderated"
|
||||||
|
ready = "Ready"
|
||||||
|
error = "Error"
|
||||||
|
|
||||||
|
|
||||||
|
class BFLFluxProStatusResponse(BaseModel):
|
||||||
|
id: str = Field(..., description="The unique identifier for the generation task.")
|
||||||
|
status: BFLStatus = Field(..., description="The status of the task.")
|
||||||
|
result: Optional[Dict[str, Any]] = Field(
|
||||||
|
None, description="The result of the task (null if not completed)."
|
||||||
|
)
|
||||||
|
progress: confloat(ge=0.0, le=1.0) = Field(
|
||||||
|
..., description="The progress of the task (0.0 to 1.0)."
|
||||||
|
)
|
||||||
|
details: Optional[Dict[str, Any]] = Field(
|
||||||
|
None, description="Additional details about the task (null if not available)."
|
||||||
|
)
|
||||||
616
comfy_api_nodes/apis/client.py
Normal file
616
comfy_api_nodes/apis/client.py
Normal file
@@ -0,0 +1,616 @@
|
|||||||
|
"""
|
||||||
|
API Client Framework for api.comfy.org.
|
||||||
|
|
||||||
|
This module provides a flexible framework for making API requests from ComfyUI nodes.
|
||||||
|
It supports both synchronous and asynchronous API operations with proper type validation.
|
||||||
|
|
||||||
|
Key Components:
|
||||||
|
--------------
|
||||||
|
1. ApiClient - Handles HTTP requests with authentication and error handling
|
||||||
|
2. ApiEndpoint - Defines a single HTTP endpoint with its request/response models
|
||||||
|
3. ApiOperation - Executes a single synchronous API operation
|
||||||
|
|
||||||
|
Usage Examples:
|
||||||
|
--------------
|
||||||
|
|
||||||
|
# Example 1: Synchronous API Operation
|
||||||
|
# ------------------------------------
|
||||||
|
# For a simple API call that returns the result immediately:
|
||||||
|
|
||||||
|
# 1. Create the API client
|
||||||
|
api_client = ApiClient(
|
||||||
|
base_url="https://api.example.com",
|
||||||
|
api_key="your_api_key_here",
|
||||||
|
timeout=30.0,
|
||||||
|
verify_ssl=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. Define the endpoint
|
||||||
|
user_info_endpoint = ApiEndpoint(
|
||||||
|
path="/v1/users/me",
|
||||||
|
method=HttpMethod.GET,
|
||||||
|
request_model=EmptyRequest, # No request body needed
|
||||||
|
response_model=UserProfile, # Pydantic model for the response
|
||||||
|
query_params=None
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. Create the request object
|
||||||
|
request = EmptyRequest()
|
||||||
|
|
||||||
|
# 4. Create and execute the operation
|
||||||
|
operation = ApiOperation(
|
||||||
|
endpoint=user_info_endpoint,
|
||||||
|
request=request
|
||||||
|
)
|
||||||
|
user_profile = operation.execute(client=api_client) # Returns immediately with the result
|
||||||
|
|
||||||
|
|
||||||
|
# Example 2: Asynchronous API Operation with Polling
|
||||||
|
# -------------------------------------------------
|
||||||
|
# For an API that starts a task and requires polling for completion:
|
||||||
|
|
||||||
|
# 1. Define the endpoints (initial request and polling)
|
||||||
|
generate_image_endpoint = ApiEndpoint(
|
||||||
|
path="/v1/images/generate",
|
||||||
|
method=HttpMethod.POST,
|
||||||
|
request_model=ImageGenerationRequest,
|
||||||
|
response_model=TaskCreatedResponse,
|
||||||
|
query_params=None
|
||||||
|
)
|
||||||
|
|
||||||
|
check_task_endpoint = ApiEndpoint(
|
||||||
|
path="/v1/tasks/{task_id}",
|
||||||
|
method=HttpMethod.GET,
|
||||||
|
request_model=EmptyRequest,
|
||||||
|
response_model=ImageGenerationResult,
|
||||||
|
query_params=None
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. Create the request object
|
||||||
|
request = ImageGenerationRequest(
|
||||||
|
prompt="a beautiful sunset over mountains",
|
||||||
|
width=1024,
|
||||||
|
height=1024,
|
||||||
|
num_images=1
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. Create and execute the polling operation
|
||||||
|
operation = PollingOperation(
|
||||||
|
initial_endpoint=generate_image_endpoint,
|
||||||
|
initial_request=request,
|
||||||
|
poll_endpoint=check_task_endpoint,
|
||||||
|
task_id_field="task_id",
|
||||||
|
status_field="status",
|
||||||
|
completed_statuses=["completed"],
|
||||||
|
failed_statuses=["failed", "error"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# This will make the initial request and then poll until completion
|
||||||
|
result = operation.execute(client=api_client) # Returns the final ImageGenerationResult when done
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import io
|
||||||
|
from typing import Dict, Type, Optional, Any, TypeVar, Generic, Callable
|
||||||
|
from enum import Enum
|
||||||
|
import json
|
||||||
|
import requests
|
||||||
|
from urllib.parse import urljoin
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from comfy.cli_args import args
|
||||||
|
from comfy import utils
|
||||||
|
|
||||||
|
T = TypeVar("T", bound=BaseModel)
|
||||||
|
R = TypeVar("R", bound=BaseModel)
|
||||||
|
P = TypeVar("P", bound=BaseModel) # For poll response
|
||||||
|
|
||||||
|
PROGRESS_BAR_MAX = 100
|
||||||
|
|
||||||
|
|
||||||
|
class EmptyRequest(BaseModel):
|
||||||
|
"""Base class for empty request bodies.
|
||||||
|
For GET requests, fields will be sent as query parameters."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class UploadRequest(BaseModel):
|
||||||
|
file_name: str = Field(..., description="Filename to upload")
|
||||||
|
content_type: str | None = Field(
|
||||||
|
None,
|
||||||
|
description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class UploadResponse(BaseModel):
|
||||||
|
download_url: str = Field(..., description="URL to GET uploaded file")
|
||||||
|
upload_url: str = Field(..., description="URL to PUT file to upload")
|
||||||
|
|
||||||
|
|
||||||
|
class HttpMethod(str, Enum):
|
||||||
|
GET = "GET"
|
||||||
|
POST = "POST"
|
||||||
|
PUT = "PUT"
|
||||||
|
DELETE = "DELETE"
|
||||||
|
PATCH = "PATCH"
|
||||||
|
|
||||||
|
|
||||||
|
class ApiClient:
|
||||||
|
"""
|
||||||
|
Client for making HTTP requests to an API with authentication and error handling.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
base_url: str,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
timeout: float = 3600.0,
|
||||||
|
verify_ssl: bool = True,
|
||||||
|
):
|
||||||
|
self.base_url = base_url
|
||||||
|
self.api_key = api_key
|
||||||
|
self.timeout = timeout
|
||||||
|
self.verify_ssl = verify_ssl
|
||||||
|
|
||||||
|
def _create_json_payload_args(
|
||||||
|
self,
|
||||||
|
data: Optional[Dict[str, Any]] = None,
|
||||||
|
headers: Optional[Dict[str, str]] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"json": data,
|
||||||
|
"headers": headers,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _create_form_data_args(
|
||||||
|
self,
|
||||||
|
data: Dict[str, Any],
|
||||||
|
files: Dict[str, Any],
|
||||||
|
headers: Optional[Dict[str, str]] = None,
|
||||||
|
multipart_parser = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
if headers and "Content-Type" in headers:
|
||||||
|
del headers["Content-Type"]
|
||||||
|
|
||||||
|
if multipart_parser:
|
||||||
|
data = multipart_parser(data)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"data": data,
|
||||||
|
"files": files,
|
||||||
|
"headers": headers,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _create_urlencoded_form_data_args(
|
||||||
|
self,
|
||||||
|
data: Dict[str, Any],
|
||||||
|
headers: Optional[Dict[str, str]] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
headers = headers or {}
|
||||||
|
headers["Content-Type"] = "application/x-www-form-urlencoded"
|
||||||
|
|
||||||
|
return {
|
||||||
|
"data": data,
|
||||||
|
"headers": headers,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_headers(self) -> Dict[str, str]:
|
||||||
|
"""Get headers for API requests, including authentication if available"""
|
||||||
|
headers = {"Content-Type": "application/json", "Accept": "application/json"}
|
||||||
|
|
||||||
|
if self.api_key:
|
||||||
|
headers["Authorization"] = f"Bearer {self.api_key}"
|
||||||
|
|
||||||
|
return headers
|
||||||
|
|
||||||
|
def request(
|
||||||
|
self,
|
||||||
|
method: str,
|
||||||
|
path: str,
|
||||||
|
params: Optional[Dict[str, Any]] = None,
|
||||||
|
data: Optional[Dict[str, Any]] = None,
|
||||||
|
files: Optional[Dict[str, Any]] = None,
|
||||||
|
headers: Optional[Dict[str, str]] = None,
|
||||||
|
content_type: str = "application/json",
|
||||||
|
multipart_parser: Callable = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Make an HTTP request to the API
|
||||||
|
|
||||||
|
Args:
|
||||||
|
method: HTTP method (GET, POST, etc.)
|
||||||
|
path: API endpoint path (will be joined with base_url)
|
||||||
|
params: Query parameters
|
||||||
|
data: body data
|
||||||
|
files: Files to upload
|
||||||
|
headers: Additional headers
|
||||||
|
content_type: Content type of the request. Defaults to application/json.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Parsed JSON response
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
requests.RequestException: If the request fails
|
||||||
|
"""
|
||||||
|
url = urljoin(self.base_url, path)
|
||||||
|
self.check_auth_token(self.api_key)
|
||||||
|
# Combine default headers with any provided headers
|
||||||
|
request_headers = self.get_headers()
|
||||||
|
if headers:
|
||||||
|
request_headers.update(headers)
|
||||||
|
|
||||||
|
# Let requests handle the content type when files are present.
|
||||||
|
if files:
|
||||||
|
del request_headers["Content-Type"]
|
||||||
|
|
||||||
|
logging.debug(f"[DEBUG] Request Headers: {request_headers}")
|
||||||
|
logging.debug(f"[DEBUG] Files: {files}")
|
||||||
|
logging.debug(f"[DEBUG] Params: {params}")
|
||||||
|
logging.debug(f"[DEBUG] Data: {data}")
|
||||||
|
|
||||||
|
if content_type == "application/x-www-form-urlencoded":
|
||||||
|
payload_args = self._create_urlencoded_form_data_args(data, request_headers)
|
||||||
|
elif content_type == "multipart/form-data":
|
||||||
|
payload_args = self._create_form_data_args(
|
||||||
|
data, files, request_headers, multipart_parser
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
payload_args = self._create_json_payload_args(data, request_headers)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = requests.request(
|
||||||
|
method=method,
|
||||||
|
url=url,
|
||||||
|
params=params,
|
||||||
|
timeout=self.timeout,
|
||||||
|
verify=self.verify_ssl,
|
||||||
|
**payload_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Raise exception for error status codes
|
||||||
|
response.raise_for_status()
|
||||||
|
except requests.ConnectionError:
|
||||||
|
raise Exception(
|
||||||
|
f"Unable to connect to the API server at {self.base_url}. Please check your internet connection or verify the service is available."
|
||||||
|
)
|
||||||
|
|
||||||
|
except requests.Timeout:
|
||||||
|
raise Exception(
|
||||||
|
f"Request timed out after {self.timeout} seconds. The server might be experiencing high load or the operation is taking longer than expected."
|
||||||
|
)
|
||||||
|
|
||||||
|
except requests.HTTPError as e:
|
||||||
|
status_code = e.response.status_code if hasattr(e, "response") else None
|
||||||
|
error_message = f"HTTP Error: {str(e)}"
|
||||||
|
|
||||||
|
# Try to extract detailed error message from JSON response
|
||||||
|
try:
|
||||||
|
if hasattr(e, "response") and e.response.content:
|
||||||
|
error_json = e.response.json()
|
||||||
|
if "error" in error_json and "message" in error_json["error"]:
|
||||||
|
error_message = f"API Error: {error_json['error']['message']}"
|
||||||
|
if "type" in error_json["error"]:
|
||||||
|
error_message += f" (Type: {error_json['error']['type']})"
|
||||||
|
else:
|
||||||
|
error_message = f"API Error: {error_json}"
|
||||||
|
except Exception as json_error:
|
||||||
|
# If we can't parse the JSON, fall back to the original error message
|
||||||
|
logging.debug(
|
||||||
|
f"[DEBUG] Failed to parse error response: {str(json_error)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.debug(f"[DEBUG] API Error: {error_message} (Status: {status_code})")
|
||||||
|
if hasattr(e, "response") and e.response.content:
|
||||||
|
logging.debug(f"[DEBUG] Response content: {e.response.content}")
|
||||||
|
if status_code == 401:
|
||||||
|
error_message = "Unauthorized: Please login first to use this node."
|
||||||
|
if status_code == 402:
|
||||||
|
error_message = "Payment Required: Please add credits to your account to use this node."
|
||||||
|
if status_code == 409:
|
||||||
|
error_message = "There is a problem with your account. Please contact support@comfy.org. "
|
||||||
|
if status_code == 429:
|
||||||
|
error_message = "Rate Limit Exceeded: Please try again later."
|
||||||
|
raise Exception(error_message)
|
||||||
|
|
||||||
|
# Parse and return JSON response
|
||||||
|
if response.content:
|
||||||
|
return response.json()
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def check_auth_token(self, auth_token):
|
||||||
|
"""Verify that an auth token is present."""
|
||||||
|
if auth_token is None:
|
||||||
|
raise Exception("Unauthorized: Please login first to use this node.")
|
||||||
|
return auth_token
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def upload_file(
|
||||||
|
upload_url: str,
|
||||||
|
file: io.BytesIO | str,
|
||||||
|
content_type: str | None = None,
|
||||||
|
):
|
||||||
|
"""Upload a file to the API. Make sure the file has a filename equal to what the url expects.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
upload_url: The URL to upload to
|
||||||
|
file: Either a file path string, BytesIO object, or tuple of (file_path, filename)
|
||||||
|
mime_type: Optional mime type to set for the upload
|
||||||
|
"""
|
||||||
|
headers = {}
|
||||||
|
if content_type:
|
||||||
|
headers["Content-Type"] = content_type
|
||||||
|
|
||||||
|
if isinstance(file, io.BytesIO):
|
||||||
|
file.seek(0) # Ensure we're at the start of the file
|
||||||
|
data = file.read()
|
||||||
|
return requests.put(upload_url, data=data, headers=headers)
|
||||||
|
elif isinstance(file, str):
|
||||||
|
with open(file, "rb") as f:
|
||||||
|
data = f.read()
|
||||||
|
return requests.put(upload_url, data=data, headers=headers)
|
||||||
|
|
||||||
|
|
||||||
|
class ApiEndpoint(Generic[T, R]):
|
||||||
|
"""Defines an API endpoint with its request and response types"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
path: str,
|
||||||
|
method: HttpMethod,
|
||||||
|
request_model: Type[T],
|
||||||
|
response_model: Type[R],
|
||||||
|
query_params: Optional[Dict[str, Any]] = None,
|
||||||
|
):
|
||||||
|
"""Initialize an API endpoint definition.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: The URL path for this endpoint, can include placeholders like {id}
|
||||||
|
method: The HTTP method to use (GET, POST, etc.)
|
||||||
|
request_model: Pydantic model class that defines the structure and validation rules for API requests to this endpoint
|
||||||
|
response_model: Pydantic model class that defines the structure and validation rules for API responses from this endpoint
|
||||||
|
query_params: Optional dictionary of query parameters to include in the request
|
||||||
|
"""
|
||||||
|
self.path = path
|
||||||
|
self.method = method
|
||||||
|
self.request_model = request_model
|
||||||
|
self.response_model = response_model
|
||||||
|
self.query_params = query_params or {}
|
||||||
|
|
||||||
|
|
||||||
|
class SynchronousOperation(Generic[T, R]):
|
||||||
|
"""
|
||||||
|
Represents a single synchronous API operation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
endpoint: ApiEndpoint[T, R],
|
||||||
|
request: T,
|
||||||
|
files: Optional[Dict[str, Any]] = None,
|
||||||
|
api_base: str | None = None,
|
||||||
|
auth_token: Optional[str] = None,
|
||||||
|
timeout: float = 604800.0,
|
||||||
|
verify_ssl: bool = True,
|
||||||
|
content_type: str = "application/json",
|
||||||
|
multipart_parser: Callable = None,
|
||||||
|
):
|
||||||
|
self.endpoint = endpoint
|
||||||
|
self.request = request
|
||||||
|
self.response = None
|
||||||
|
self.error = None
|
||||||
|
self.api_base: str = api_base or args.comfy_api_base
|
||||||
|
self.auth_token = auth_token
|
||||||
|
self.timeout = timeout
|
||||||
|
self.verify_ssl = verify_ssl
|
||||||
|
self.files = files
|
||||||
|
self.content_type = content_type
|
||||||
|
self.multipart_parser = multipart_parser
|
||||||
|
def execute(self, client: Optional[ApiClient] = None) -> R:
|
||||||
|
"""Execute the API operation using the provided client or create one"""
|
||||||
|
try:
|
||||||
|
# Create client if not provided
|
||||||
|
if client is None:
|
||||||
|
client = ApiClient(
|
||||||
|
base_url=self.api_base,
|
||||||
|
api_key=self.auth_token,
|
||||||
|
timeout=self.timeout,
|
||||||
|
verify_ssl=self.verify_ssl,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert request model to dict, but use None for EmptyRequest
|
||||||
|
request_dict = (
|
||||||
|
None
|
||||||
|
if isinstance(self.request, EmptyRequest)
|
||||||
|
else self.request.model_dump(exclude_none=True)
|
||||||
|
)
|
||||||
|
if request_dict:
|
||||||
|
for key, value in request_dict.items():
|
||||||
|
if isinstance(value, Enum):
|
||||||
|
request_dict[key] = value.value
|
||||||
|
|
||||||
|
if request_dict:
|
||||||
|
for key, value in request_dict.items():
|
||||||
|
if isinstance(value, Enum):
|
||||||
|
request_dict[key] = value.value
|
||||||
|
|
||||||
|
# Debug log for request
|
||||||
|
logging.debug(
|
||||||
|
f"[DEBUG] API Request: {self.endpoint.method.value} {self.endpoint.path}"
|
||||||
|
)
|
||||||
|
logging.debug(f"[DEBUG] Request Data: {json.dumps(request_dict, indent=2)}")
|
||||||
|
logging.debug(f"[DEBUG] Query Params: {self.endpoint.query_params}")
|
||||||
|
|
||||||
|
# Make the request
|
||||||
|
resp = client.request(
|
||||||
|
method=self.endpoint.method.value,
|
||||||
|
path=self.endpoint.path,
|
||||||
|
data=request_dict,
|
||||||
|
params=self.endpoint.query_params,
|
||||||
|
files=self.files,
|
||||||
|
content_type=self.content_type,
|
||||||
|
multipart_parser=self.multipart_parser
|
||||||
|
)
|
||||||
|
|
||||||
|
# Debug log for response
|
||||||
|
logging.debug("=" * 50)
|
||||||
|
logging.debug("[DEBUG] RESPONSE DETAILS:")
|
||||||
|
logging.debug("[DEBUG] Status Code: 200 (Success)")
|
||||||
|
logging.debug(f"[DEBUG] Response Body: {json.dumps(resp, indent=2)}")
|
||||||
|
logging.debug("=" * 50)
|
||||||
|
|
||||||
|
# Parse and return the response
|
||||||
|
return self._parse_response(resp)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"[DEBUG] API Exception: {str(e)}")
|
||||||
|
raise Exception(str(e))
|
||||||
|
|
||||||
|
def _parse_response(self, resp):
|
||||||
|
"""Parse response data - can be overridden by subclasses"""
|
||||||
|
# The response is already the complete object, don't extract just the "data" field
|
||||||
|
# as that would lose the outer structure (created timestamp, etc.)
|
||||||
|
|
||||||
|
# Parse response using the provided model
|
||||||
|
self.response = self.endpoint.response_model.model_validate(resp)
|
||||||
|
logging.debug(f"[DEBUG] Parsed Response: {self.response}")
|
||||||
|
return self.response
|
||||||
|
|
||||||
|
|
||||||
|
class TaskStatus(str, Enum):
|
||||||
|
"""Enum for task status values"""
|
||||||
|
|
||||||
|
COMPLETED = "completed"
|
||||||
|
FAILED = "failed"
|
||||||
|
PENDING = "pending"
|
||||||
|
|
||||||
|
|
||||||
|
class PollingOperation(Generic[T, R]):
|
||||||
|
"""
|
||||||
|
Represents an asynchronous API operation that requires polling for completion.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
poll_endpoint: ApiEndpoint[EmptyRequest, R],
|
||||||
|
completed_statuses: list,
|
||||||
|
failed_statuses: list,
|
||||||
|
status_extractor: Callable[[R], str],
|
||||||
|
progress_extractor: Callable[[R], float] = None,
|
||||||
|
request: Optional[T] = None,
|
||||||
|
api_base: str | None = None,
|
||||||
|
auth_token: Optional[str] = None,
|
||||||
|
poll_interval: float = 5.0,
|
||||||
|
):
|
||||||
|
self.poll_endpoint = poll_endpoint
|
||||||
|
self.request = request
|
||||||
|
self.api_base: str = api_base or args.comfy_api_base
|
||||||
|
self.auth_token = auth_token
|
||||||
|
self.poll_interval = poll_interval
|
||||||
|
|
||||||
|
# Polling configuration
|
||||||
|
self.status_extractor = status_extractor or (
|
||||||
|
lambda x: getattr(x, "status", None)
|
||||||
|
)
|
||||||
|
self.progress_extractor = progress_extractor
|
||||||
|
self.completed_statuses = completed_statuses
|
||||||
|
self.failed_statuses = failed_statuses
|
||||||
|
|
||||||
|
# For storing response data
|
||||||
|
self.final_response = None
|
||||||
|
self.error = None
|
||||||
|
|
||||||
|
def execute(self, client: Optional[ApiClient] = None) -> R:
|
||||||
|
"""Execute the polling operation using the provided client. If failed, raise an exception."""
|
||||||
|
try:
|
||||||
|
if client is None:
|
||||||
|
client = ApiClient(
|
||||||
|
base_url=self.api_base,
|
||||||
|
api_key=self.auth_token,
|
||||||
|
)
|
||||||
|
return self._poll_until_complete(client)
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"Error during polling: {str(e)}")
|
||||||
|
|
||||||
|
def _check_task_status(self, response: R) -> TaskStatus:
|
||||||
|
"""Check task status using the status extractor function"""
|
||||||
|
try:
|
||||||
|
status = self.status_extractor(response)
|
||||||
|
if status in self.completed_statuses:
|
||||||
|
return TaskStatus.COMPLETED
|
||||||
|
elif status in self.failed_statuses:
|
||||||
|
return TaskStatus.FAILED
|
||||||
|
return TaskStatus.PENDING
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error extracting status: {e}")
|
||||||
|
return TaskStatus.PENDING
|
||||||
|
|
||||||
|
def _poll_until_complete(self, client: ApiClient) -> R:
|
||||||
|
"""Poll until the task is complete"""
|
||||||
|
poll_count = 0
|
||||||
|
if self.progress_extractor:
|
||||||
|
progress = utils.ProgressBar(PROGRESS_BAR_MAX)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
poll_count += 1
|
||||||
|
logging.debug(f"[DEBUG] Polling attempt #{poll_count}")
|
||||||
|
|
||||||
|
request_dict = (
|
||||||
|
self.request.model_dump(exclude_none=True)
|
||||||
|
if self.request is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
if poll_count == 1:
|
||||||
|
logging.debug(
|
||||||
|
f"[DEBUG] Poll Request: {self.poll_endpoint.method.value} {self.poll_endpoint.path}"
|
||||||
|
)
|
||||||
|
logging.debug(
|
||||||
|
f"[DEBUG] Poll Request Data: {json.dumps(request_dict, indent=2) if request_dict else 'None'}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Query task status
|
||||||
|
resp = client.request(
|
||||||
|
method=self.poll_endpoint.method.value,
|
||||||
|
path=self.poll_endpoint.path,
|
||||||
|
params=self.poll_endpoint.query_params,
|
||||||
|
data=request_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Parse response
|
||||||
|
response_obj = self.poll_endpoint.response_model.model_validate(resp)
|
||||||
|
# Check if task is complete
|
||||||
|
status = self._check_task_status(response_obj)
|
||||||
|
logging.debug(f"[DEBUG] Task Status: {status}")
|
||||||
|
|
||||||
|
# If progress extractor is provided, extract progress
|
||||||
|
if self.progress_extractor:
|
||||||
|
new_progress = self.progress_extractor(response_obj)
|
||||||
|
if new_progress is not None:
|
||||||
|
progress.update_absolute(new_progress, total=PROGRESS_BAR_MAX)
|
||||||
|
|
||||||
|
if status == TaskStatus.COMPLETED:
|
||||||
|
logging.debug("[DEBUG] Task completed successfully")
|
||||||
|
self.final_response = response_obj
|
||||||
|
if self.progress_extractor:
|
||||||
|
progress.update(100)
|
||||||
|
return self.final_response
|
||||||
|
elif status == TaskStatus.FAILED:
|
||||||
|
message = f"Task failed: {json.dumps(resp)}"
|
||||||
|
logging.error(f"[DEBUG] {message}")
|
||||||
|
raise Exception(message)
|
||||||
|
else:
|
||||||
|
logging.debug("[DEBUG] Task still pending, continuing to poll...")
|
||||||
|
|
||||||
|
# Wait before polling again
|
||||||
|
logging.debug(
|
||||||
|
f"[DEBUG] Waiting {self.poll_interval} seconds before next poll"
|
||||||
|
)
|
||||||
|
time.sleep(self.poll_interval)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"[DEBUG] Polling error: {str(e)}")
|
||||||
|
raise Exception(f"Error while polling: {str(e)}")
|
||||||
253
comfy_api_nodes/apis/luma_api.py
Normal file
253
comfy_api_nodes/apis/luma_api.py
Normal file
@@ -0,0 +1,253 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, confloat
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class LumaIO:
|
||||||
|
LUMA_REF = "LUMA_REF"
|
||||||
|
LUMA_CONCEPTS = "LUMA_CONCEPTS"
|
||||||
|
|
||||||
|
|
||||||
|
class LumaReference:
|
||||||
|
def __init__(self, image: torch.Tensor, weight: float):
|
||||||
|
self.image = image
|
||||||
|
self.weight = weight
|
||||||
|
|
||||||
|
def create_api_model(self, download_url: str):
|
||||||
|
return LumaImageRef(url=download_url, weight=self.weight)
|
||||||
|
|
||||||
|
class LumaReferenceChain:
|
||||||
|
def __init__(self, first_ref: LumaReference=None):
|
||||||
|
self.refs: list[LumaReference] = []
|
||||||
|
if first_ref:
|
||||||
|
self.refs.append(first_ref)
|
||||||
|
|
||||||
|
def add(self, luma_ref: LumaReference=None):
|
||||||
|
self.refs.append(luma_ref)
|
||||||
|
|
||||||
|
def create_api_model(self, download_urls: list[str], max_refs=4):
|
||||||
|
if len(self.refs) == 0:
|
||||||
|
return None
|
||||||
|
api_refs: list[LumaImageRef] = []
|
||||||
|
for ref, url in zip(self.refs, download_urls):
|
||||||
|
api_ref = LumaImageRef(url=url, weight=ref.weight)
|
||||||
|
api_refs.append(api_ref)
|
||||||
|
return api_refs
|
||||||
|
|
||||||
|
def clone(self):
|
||||||
|
c = LumaReferenceChain()
|
||||||
|
for ref in self.refs:
|
||||||
|
c.add(ref)
|
||||||
|
return c
|
||||||
|
|
||||||
|
|
||||||
|
class LumaConcept:
|
||||||
|
def __init__(self, key: str):
|
||||||
|
self.key = key
|
||||||
|
|
||||||
|
|
||||||
|
class LumaConceptChain:
|
||||||
|
def __init__(self, str_list: list[str] = None):
|
||||||
|
self.concepts: list[LumaConcept] = []
|
||||||
|
if str_list is not None:
|
||||||
|
for c in str_list:
|
||||||
|
if c != "None":
|
||||||
|
self.add(LumaConcept(key=c))
|
||||||
|
|
||||||
|
def add(self, concept: LumaConcept):
|
||||||
|
self.concepts.append(concept)
|
||||||
|
|
||||||
|
def create_api_model(self):
|
||||||
|
if len(self.concepts) == 0:
|
||||||
|
return None
|
||||||
|
api_concepts: list[LumaConceptObject] = []
|
||||||
|
for concept in self.concepts:
|
||||||
|
if concept.key == "None":
|
||||||
|
continue
|
||||||
|
api_concepts.append(LumaConceptObject(key=concept.key))
|
||||||
|
if len(api_concepts) == 0:
|
||||||
|
return None
|
||||||
|
return api_concepts
|
||||||
|
|
||||||
|
def clone(self):
|
||||||
|
c = LumaConceptChain()
|
||||||
|
for concept in self.concepts:
|
||||||
|
c.add(concept)
|
||||||
|
return c
|
||||||
|
|
||||||
|
def clone_and_merge(self, other: LumaConceptChain):
|
||||||
|
c = self.clone()
|
||||||
|
for concept in other.concepts:
|
||||||
|
c.add(concept)
|
||||||
|
return c
|
||||||
|
|
||||||
|
|
||||||
|
def get_luma_concepts(include_none=False):
|
||||||
|
concepts = []
|
||||||
|
if include_none:
|
||||||
|
concepts.append("None")
|
||||||
|
return concepts + [
|
||||||
|
"truck_left",
|
||||||
|
"pan_right",
|
||||||
|
"pedestal_down",
|
||||||
|
"low_angle",
|
||||||
|
"pedestal_up",
|
||||||
|
"selfie",
|
||||||
|
"pan_left",
|
||||||
|
"roll_right",
|
||||||
|
"zoom_in",
|
||||||
|
"over_the_shoulder",
|
||||||
|
"orbit_right",
|
||||||
|
"orbit_left",
|
||||||
|
"static",
|
||||||
|
"tiny_planet",
|
||||||
|
"high_angle",
|
||||||
|
"bolt_cam",
|
||||||
|
"dolly_zoom",
|
||||||
|
"overhead",
|
||||||
|
"zoom_out",
|
||||||
|
"handheld",
|
||||||
|
"roll_left",
|
||||||
|
"pov",
|
||||||
|
"aerial_drone",
|
||||||
|
"push_in",
|
||||||
|
"crane_down",
|
||||||
|
"truck_right",
|
||||||
|
"tilt_down",
|
||||||
|
"elevator_doors",
|
||||||
|
"tilt_up",
|
||||||
|
"ground_level",
|
||||||
|
"pull_out",
|
||||||
|
"aerial",
|
||||||
|
"crane_up",
|
||||||
|
"eye_level"
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class LumaImageModel(str, Enum):
|
||||||
|
photon_1 = "photon-1"
|
||||||
|
photon_flash_1 = "photon-flash-1"
|
||||||
|
|
||||||
|
|
||||||
|
class LumaVideoModel(str, Enum):
|
||||||
|
ray_2 = "ray-2"
|
||||||
|
ray_flash_2 = "ray-flash-2"
|
||||||
|
ray_1_6 = "ray-1-6"
|
||||||
|
|
||||||
|
|
||||||
|
class LumaAspectRatio(str, Enum):
|
||||||
|
ratio_1_1 = "1:1"
|
||||||
|
ratio_16_9 = "16:9"
|
||||||
|
ratio_9_16 = "9:16"
|
||||||
|
ratio_4_3 = "4:3"
|
||||||
|
ratio_3_4 = "3:4"
|
||||||
|
ratio_21_9 = "21:9"
|
||||||
|
ratio_9_21 = "9:21"
|
||||||
|
|
||||||
|
|
||||||
|
class LumaVideoOutputResolution(str, Enum):
|
||||||
|
res_540p = "540p"
|
||||||
|
res_720p = "720p"
|
||||||
|
res_1080p = "1080p"
|
||||||
|
res_4k = "4k"
|
||||||
|
|
||||||
|
|
||||||
|
class LumaVideoModelOutputDuration(str, Enum):
|
||||||
|
dur_5s = "5s"
|
||||||
|
dur_9s = "9s"
|
||||||
|
|
||||||
|
|
||||||
|
class LumaGenerationType(str, Enum):
|
||||||
|
video = 'video'
|
||||||
|
image = 'image'
|
||||||
|
|
||||||
|
|
||||||
|
class LumaState(str, Enum):
|
||||||
|
queued = "queued"
|
||||||
|
dreaming = "dreaming"
|
||||||
|
completed = "completed"
|
||||||
|
failed = "failed"
|
||||||
|
|
||||||
|
|
||||||
|
class LumaAssets(BaseModel):
|
||||||
|
video: Optional[str] = Field(None, description='The URL of the video')
|
||||||
|
image: Optional[str] = Field(None, description='The URL of the image')
|
||||||
|
progress_video: Optional[str] = Field(None, description='The URL of the progress video')
|
||||||
|
|
||||||
|
|
||||||
|
class LumaImageRef(BaseModel):
|
||||||
|
'''Used for image gen'''
|
||||||
|
url: str = Field(..., description='The URL of the image reference')
|
||||||
|
weight: confloat(ge=0.0, le=1.0) = Field(..., description='The weight of the image reference')
|
||||||
|
|
||||||
|
|
||||||
|
class LumaImageReference(BaseModel):
|
||||||
|
'''Used for video gen'''
|
||||||
|
type: Optional[str] = Field('image', description='Input type, defaults to image')
|
||||||
|
url: str = Field(..., description='The URL of the image')
|
||||||
|
|
||||||
|
|
||||||
|
class LumaModifyImageRef(BaseModel):
|
||||||
|
url: str = Field(..., description='The URL of the image reference')
|
||||||
|
weight: confloat(ge=0.0, le=1.0) = Field(..., description='The weight of the image reference')
|
||||||
|
|
||||||
|
|
||||||
|
class LumaCharacterRef(BaseModel):
|
||||||
|
identity0: LumaImageIdentity = Field(..., description='The image identity object')
|
||||||
|
|
||||||
|
|
||||||
|
class LumaImageIdentity(BaseModel):
|
||||||
|
images: list[str] = Field(..., description='The URLs of the image identity')
|
||||||
|
|
||||||
|
|
||||||
|
class LumaGenerationReference(BaseModel):
|
||||||
|
type: str = Field('generation', description='Input type, defaults to generation')
|
||||||
|
id: str = Field(..., description='The ID of the generation')
|
||||||
|
|
||||||
|
|
||||||
|
class LumaKeyframes(BaseModel):
|
||||||
|
frame0: Optional[Union[LumaImageReference, LumaGenerationReference]] = Field(None, description='')
|
||||||
|
frame1: Optional[Union[LumaImageReference, LumaGenerationReference]] = Field(None, description='')
|
||||||
|
|
||||||
|
|
||||||
|
class LumaConceptObject(BaseModel):
|
||||||
|
key: str = Field(..., description='Camera Concept name')
|
||||||
|
|
||||||
|
|
||||||
|
class LumaImageGenerationRequest(BaseModel):
|
||||||
|
prompt: str = Field(..., description='The prompt of the generation')
|
||||||
|
model: LumaImageModel = Field(LumaImageModel.photon_1, description='The image model used for the generation')
|
||||||
|
aspect_ratio: Optional[LumaAspectRatio] = Field(LumaAspectRatio.ratio_16_9, description='The aspect ratio of the generation')
|
||||||
|
image_ref: Optional[list[LumaImageRef]] = Field(None, description='List of image reference objects')
|
||||||
|
style_ref: Optional[list[LumaImageRef]] = Field(None, description='List of style reference objects')
|
||||||
|
character_ref: Optional[LumaCharacterRef] = Field(None, description='The image identity object')
|
||||||
|
modify_image_ref: Optional[LumaModifyImageRef] = Field(None, description='The modify image reference object')
|
||||||
|
|
||||||
|
|
||||||
|
class LumaGenerationRequest(BaseModel):
|
||||||
|
prompt: str = Field(..., description='The prompt of the generation')
|
||||||
|
model: LumaVideoModel = Field(LumaVideoModel.ray_2, description='The video model used for the generation')
|
||||||
|
duration: Optional[LumaVideoModelOutputDuration] = Field(None, description='The duration of the generation')
|
||||||
|
aspect_ratio: Optional[LumaAspectRatio] = Field(None, description='The aspect ratio of the generation')
|
||||||
|
resolution: Optional[LumaVideoOutputResolution] = Field(None, description='The resolution of the generation')
|
||||||
|
loop: Optional[bool] = Field(None, description='Whether to loop the video')
|
||||||
|
keyframes: Optional[LumaKeyframes] = Field(None, description='The keyframes of the generation')
|
||||||
|
concepts: Optional[list[LumaConceptObject]] = Field(None, description='Camera Concepts to apply to generation')
|
||||||
|
|
||||||
|
|
||||||
|
class LumaGeneration(BaseModel):
|
||||||
|
id: str = Field(..., description='The ID of the generation')
|
||||||
|
generation_type: LumaGenerationType = Field(..., description='Generation type, image or video')
|
||||||
|
state: LumaState = Field(..., description='The state of the generation')
|
||||||
|
failure_reason: Optional[str] = Field(None, description='The reason for the state of the generation')
|
||||||
|
created_at: str = Field(..., description='The date and time when the generation was created')
|
||||||
|
assets: Optional[LumaAssets] = Field(None, description='The assets of the generation')
|
||||||
|
model: str = Field(..., description='The model used for the generation')
|
||||||
|
request: Union[LumaGenerationRequest, LumaImageGenerationRequest] = Field(..., description="The request used for the generation")
|
||||||
146
comfy_api_nodes/apis/pixverse_api.py
Normal file
146
comfy_api_nodes/apis/pixverse_api.py
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
pixverse_templates = {
|
||||||
|
"Microwave": 324641385496960,
|
||||||
|
"Suit Swagger": 328545151283968,
|
||||||
|
"Anything, Robot": 313358700761536,
|
||||||
|
"Subject 3 Fever": 327828816843648,
|
||||||
|
"kiss kiss": 315446315336768,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class PixverseIO:
|
||||||
|
TEMPLATE = "PIXVERSE_TEMPLATE"
|
||||||
|
|
||||||
|
|
||||||
|
class PixverseStatus(int, Enum):
|
||||||
|
successful = 1
|
||||||
|
generating = 5
|
||||||
|
deleted = 6
|
||||||
|
contents_moderation = 7
|
||||||
|
failed = 8
|
||||||
|
|
||||||
|
|
||||||
|
class PixverseAspectRatio(str, Enum):
|
||||||
|
ratio_16_9 = "16:9"
|
||||||
|
ratio_4_3 = "4:3"
|
||||||
|
ratio_1_1 = "1:1"
|
||||||
|
ratio_3_4 = "3:4"
|
||||||
|
ratio_9_16 = "9:16"
|
||||||
|
|
||||||
|
|
||||||
|
class PixverseQuality(str, Enum):
|
||||||
|
res_360p = "360p"
|
||||||
|
res_540p = "540p"
|
||||||
|
res_720p = "720p"
|
||||||
|
res_1080p = "1080p"
|
||||||
|
|
||||||
|
|
||||||
|
class PixverseDuration(int, Enum):
|
||||||
|
dur_5 = 5
|
||||||
|
dur_8 = 8
|
||||||
|
|
||||||
|
|
||||||
|
class PixverseMotionMode(str, Enum):
|
||||||
|
normal = "normal"
|
||||||
|
fast = "fast"
|
||||||
|
|
||||||
|
|
||||||
|
class PixverseStyle(str, Enum):
|
||||||
|
anime = "anime"
|
||||||
|
animation_3d = "3d_animation"
|
||||||
|
clay = "clay"
|
||||||
|
comic = "comic"
|
||||||
|
cyberpunk = "cyberpunk"
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE: forgoing descriptions for now in return for dev speed
|
||||||
|
class PixverseTextVideoRequest(BaseModel):
|
||||||
|
aspect_ratio: PixverseAspectRatio = Field(...)
|
||||||
|
quality: PixverseQuality = Field(...)
|
||||||
|
duration: PixverseDuration = Field(...)
|
||||||
|
model: Optional[str] = Field("v3.5")
|
||||||
|
motion_mode: Optional[PixverseMotionMode] = Field(PixverseMotionMode.normal)
|
||||||
|
prompt: str = Field(...)
|
||||||
|
negative_prompt: Optional[str] = Field(None)
|
||||||
|
seed: Optional[int] = Field(None)
|
||||||
|
style: Optional[str] = Field(None)
|
||||||
|
template_id: Optional[int] = Field(None)
|
||||||
|
water_mark: Optional[bool] = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class PixverseImageVideoRequest(BaseModel):
|
||||||
|
quality: PixverseQuality = Field(...)
|
||||||
|
duration: PixverseDuration = Field(...)
|
||||||
|
img_id: int = Field(...)
|
||||||
|
model: Optional[str] = Field("v3.5")
|
||||||
|
motion_mode: Optional[PixverseMotionMode] = Field(PixverseMotionMode.normal)
|
||||||
|
prompt: str = Field(...)
|
||||||
|
negative_prompt: Optional[str] = Field(None)
|
||||||
|
seed: Optional[int] = Field(None)
|
||||||
|
style: Optional[str] = Field(None)
|
||||||
|
template_id: Optional[int] = Field(None)
|
||||||
|
water_mark: Optional[bool] = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class PixverseTransitionVideoRequest(BaseModel):
|
||||||
|
quality: PixverseQuality = Field(...)
|
||||||
|
duration: PixverseDuration = Field(...)
|
||||||
|
first_frame_img: int = Field(...)
|
||||||
|
last_frame_img: int = Field(...)
|
||||||
|
model: Optional[str] = Field("v3.5")
|
||||||
|
motion_mode: Optional[PixverseMotionMode] = Field(PixverseMotionMode.normal)
|
||||||
|
prompt: str = Field(...)
|
||||||
|
# negative_prompt: Optional[str] = Field(None)
|
||||||
|
seed: Optional[int] = Field(None)
|
||||||
|
# style: Optional[str] = Field(None)
|
||||||
|
# template_id: Optional[int] = Field(None)
|
||||||
|
# water_mark: Optional[bool] = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class PixverseImageUploadResponse(BaseModel):
|
||||||
|
ErrCode: Optional[int] = None
|
||||||
|
ErrMsg: Optional[str] = None
|
||||||
|
Resp: Optional[PixverseImgIdResponseObject] = Field(None, alias='Resp')
|
||||||
|
|
||||||
|
|
||||||
|
class PixverseImgIdResponseObject(BaseModel):
|
||||||
|
img_id: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
|
class PixverseVideoResponse(BaseModel):
|
||||||
|
ErrCode: Optional[int] = Field(None)
|
||||||
|
ErrMsg: Optional[str] = Field(None)
|
||||||
|
Resp: Optional[PixverseVideoIdResponseObject] = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class PixverseVideoIdResponseObject(BaseModel):
|
||||||
|
video_id: int = Field(..., description='Video_id')
|
||||||
|
|
||||||
|
|
||||||
|
class PixverseGenerationStatusResponse(BaseModel):
|
||||||
|
ErrCode: Optional[int] = Field(None)
|
||||||
|
ErrMsg: Optional[str] = Field(None)
|
||||||
|
Resp: Optional[PixverseGenerationStatusResponseObject] = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class PixverseGenerationStatusResponseObject(BaseModel):
|
||||||
|
create_time: Optional[str] = Field(None)
|
||||||
|
id: Optional[int] = Field(None)
|
||||||
|
modify_time: Optional[str] = Field(None)
|
||||||
|
negative_prompt: Optional[str] = Field(None)
|
||||||
|
outputHeight: Optional[int] = Field(None)
|
||||||
|
outputWidth: Optional[int] = Field(None)
|
||||||
|
prompt: Optional[str] = Field(None)
|
||||||
|
resolution_ratio: Optional[int] = Field(None)
|
||||||
|
seed: Optional[int] = Field(None)
|
||||||
|
size: Optional[int] = Field(None)
|
||||||
|
status: Optional[int] = Field(None)
|
||||||
|
style: Optional[str] = Field(None)
|
||||||
|
url: Optional[str] = Field(None)
|
||||||
263
comfy_api_nodes/apis/recraft_api.py
Normal file
263
comfy_api_nodes/apis/recraft_api.py
Normal file
@@ -0,0 +1,263 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, conint, confloat
|
||||||
|
|
||||||
|
|
||||||
|
class RecraftColor:
|
||||||
|
def __init__(self, r: int, g: int, b: int):
|
||||||
|
self.color = [r, g, b]
|
||||||
|
|
||||||
|
def create_api_model(self):
|
||||||
|
return RecraftColorObject(rgb=self.color)
|
||||||
|
|
||||||
|
|
||||||
|
class RecraftColorChain:
|
||||||
|
def __init__(self):
|
||||||
|
self.colors: list[RecraftColor] = []
|
||||||
|
|
||||||
|
def get_first(self):
|
||||||
|
if len(self.colors) > 0:
|
||||||
|
return self.colors[0]
|
||||||
|
return None
|
||||||
|
|
||||||
|
def add(self, color: RecraftColor):
|
||||||
|
self.colors.append(color)
|
||||||
|
|
||||||
|
def create_api_model(self):
|
||||||
|
if not self.colors:
|
||||||
|
return None
|
||||||
|
colors_api = [x.create_api_model() for x in self.colors]
|
||||||
|
return colors_api
|
||||||
|
|
||||||
|
def clone(self):
|
||||||
|
c = RecraftColorChain()
|
||||||
|
for color in self.colors:
|
||||||
|
c.add(color)
|
||||||
|
return c
|
||||||
|
|
||||||
|
def clone_and_merge(self, other: RecraftColorChain):
|
||||||
|
c = self.clone()
|
||||||
|
for color in other.colors:
|
||||||
|
c.add(color)
|
||||||
|
return c
|
||||||
|
|
||||||
|
|
||||||
|
class RecraftControls:
|
||||||
|
def __init__(self, colors: RecraftColorChain=None, background_color: RecraftColorChain=None,
|
||||||
|
artistic_level: int=None, no_text: bool=None):
|
||||||
|
self.colors = colors
|
||||||
|
self.background_color = background_color
|
||||||
|
self.artistic_level = artistic_level
|
||||||
|
self.no_text = no_text
|
||||||
|
|
||||||
|
def create_api_model(self):
|
||||||
|
if self.colors is None and self.background_color is None and self.artistic_level is None and self.no_text is None:
|
||||||
|
return None
|
||||||
|
colors_api = None
|
||||||
|
background_color_api = None
|
||||||
|
if self.colors:
|
||||||
|
colors_api = self.colors.create_api_model()
|
||||||
|
if self.background_color:
|
||||||
|
first_background = self.background_color.get_first()
|
||||||
|
background_color_api = first_background.create_api_model() if first_background else None
|
||||||
|
|
||||||
|
return RecraftControlsObject(colors=colors_api, background_color=background_color_api,
|
||||||
|
artistic_level=self.artistic_level, no_text=self.no_text)
|
||||||
|
|
||||||
|
|
||||||
|
class RecraftStyle:
|
||||||
|
def __init__(self, style: str=None, substyle: str=None, style_id: str=None):
|
||||||
|
self.style = style
|
||||||
|
if substyle == "None":
|
||||||
|
substyle = None
|
||||||
|
self.substyle = substyle
|
||||||
|
self.style_id = style_id
|
||||||
|
|
||||||
|
|
||||||
|
class RecraftIO:
|
||||||
|
STYLEV3 = "RECRAFT_V3_STYLE"
|
||||||
|
SVG = "SVG" # TODO: if acceptable, move into ComfyUI's typing class
|
||||||
|
COLOR = "RECRAFT_COLOR"
|
||||||
|
CONTROLS = "RECRAFT_CONTROLS"
|
||||||
|
|
||||||
|
|
||||||
|
class RecraftStyleV3(str, Enum):
|
||||||
|
#any = 'any' NOTE: this does not work for some reason... why?
|
||||||
|
realistic_image = 'realistic_image'
|
||||||
|
digital_illustration = 'digital_illustration'
|
||||||
|
vector_illustration = 'vector_illustration'
|
||||||
|
logo_raster = 'logo_raster'
|
||||||
|
|
||||||
|
|
||||||
|
def get_v3_substyles(style_v3: str, include_none=True) -> list[str]:
|
||||||
|
substyles: list[str] = []
|
||||||
|
if include_none:
|
||||||
|
substyles.append("None")
|
||||||
|
return substyles + dict_recraft_substyles_v3.get(style_v3, [])
|
||||||
|
|
||||||
|
|
||||||
|
dict_recraft_substyles_v3 = {
|
||||||
|
RecraftStyleV3.realistic_image: [
|
||||||
|
"b_and_w",
|
||||||
|
"enterprise",
|
||||||
|
"evening_light",
|
||||||
|
"faded_nostalgia",
|
||||||
|
"forest_life",
|
||||||
|
"hard_flash",
|
||||||
|
"hdr",
|
||||||
|
"motion_blur",
|
||||||
|
"mystic_naturalism",
|
||||||
|
"natural_light",
|
||||||
|
"natural_tones",
|
||||||
|
"organic_calm",
|
||||||
|
"real_life_glow",
|
||||||
|
"retro_realism",
|
||||||
|
"retro_snapshot",
|
||||||
|
"studio_portrait",
|
||||||
|
"urban_drama",
|
||||||
|
"village_realism",
|
||||||
|
"warm_folk"
|
||||||
|
],
|
||||||
|
RecraftStyleV3.digital_illustration: [
|
||||||
|
"2d_art_poster",
|
||||||
|
"2d_art_poster_2",
|
||||||
|
"antiquarian",
|
||||||
|
"bold_fantasy",
|
||||||
|
"child_book",
|
||||||
|
"child_books",
|
||||||
|
"cover",
|
||||||
|
"crosshatch",
|
||||||
|
"digital_engraving",
|
||||||
|
"engraving_color",
|
||||||
|
"expressionism",
|
||||||
|
"freehand_details",
|
||||||
|
"grain",
|
||||||
|
"grain_20",
|
||||||
|
"graphic_intensity",
|
||||||
|
"hand_drawn",
|
||||||
|
"hand_drawn_outline",
|
||||||
|
"handmade_3d",
|
||||||
|
"hard_comics",
|
||||||
|
"infantile_sketch",
|
||||||
|
"long_shadow",
|
||||||
|
"modern_folk",
|
||||||
|
"multicolor",
|
||||||
|
"neon_calm",
|
||||||
|
"noir",
|
||||||
|
"nostalgic_pastel",
|
||||||
|
"outline_details",
|
||||||
|
"pastel_gradient",
|
||||||
|
"pastel_sketch",
|
||||||
|
"pixel_art",
|
||||||
|
"plastic",
|
||||||
|
"pop_art",
|
||||||
|
"pop_renaissance",
|
||||||
|
"seamless",
|
||||||
|
"street_art",
|
||||||
|
"tablet_sketch",
|
||||||
|
"urban_glow",
|
||||||
|
"urban_sketching",
|
||||||
|
"vanilla_dreams",
|
||||||
|
"young_adult_book",
|
||||||
|
"young_adult_book_2"
|
||||||
|
],
|
||||||
|
RecraftStyleV3.vector_illustration: [
|
||||||
|
"bold_stroke",
|
||||||
|
"chemistry",
|
||||||
|
"colored_stencil",
|
||||||
|
"contour_pop_art",
|
||||||
|
"cosmics",
|
||||||
|
"cutout",
|
||||||
|
"depressive",
|
||||||
|
"editorial",
|
||||||
|
"emotional_flat",
|
||||||
|
"engraving",
|
||||||
|
"infographical",
|
||||||
|
"line_art",
|
||||||
|
"line_circuit",
|
||||||
|
"linocut",
|
||||||
|
"marker_outline",
|
||||||
|
"mosaic",
|
||||||
|
"naivector",
|
||||||
|
"roundish_flat",
|
||||||
|
"seamless",
|
||||||
|
"segmented_colors",
|
||||||
|
"sharp_contrast",
|
||||||
|
"thin",
|
||||||
|
"vector_photo",
|
||||||
|
"vivid_shapes"
|
||||||
|
],
|
||||||
|
RecraftStyleV3.logo_raster: [
|
||||||
|
"emblem_graffiti",
|
||||||
|
"emblem_pop_art",
|
||||||
|
"emblem_punk",
|
||||||
|
"emblem_stamp",
|
||||||
|
"emblem_vintage"
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class RecraftModel(str, Enum):
|
||||||
|
recraftv3 = 'recraftv3'
|
||||||
|
recraftv2 = 'recraftv2'
|
||||||
|
|
||||||
|
|
||||||
|
class RecraftImageSize(str, Enum):
|
||||||
|
res_1024x1024 = '1024x1024'
|
||||||
|
res_1365x1024 = '1365x1024'
|
||||||
|
res_1024x1365 = '1024x1365'
|
||||||
|
res_1536x1024 = '1536x1024'
|
||||||
|
res_1024x1536 = '1024x1536'
|
||||||
|
res_1820x1024 = '1820x1024'
|
||||||
|
res_1024x1820 = '1024x1820'
|
||||||
|
res_1024x2048 = '1024x2048'
|
||||||
|
res_2048x1024 = '2048x1024'
|
||||||
|
res_1434x1024 = '1434x1024'
|
||||||
|
res_1024x1434 = '1024x1434'
|
||||||
|
res_1024x1280 = '1024x1280'
|
||||||
|
res_1280x1024 = '1280x1024'
|
||||||
|
res_1024x1707 = '1024x1707'
|
||||||
|
res_1707x1024 = '1707x1024'
|
||||||
|
|
||||||
|
|
||||||
|
class RecraftColorObject(BaseModel):
|
||||||
|
rgb: list[int] = Field(..., description='An array of 3 integer values in range of 0...255 defining RGB Color Model')
|
||||||
|
|
||||||
|
|
||||||
|
class RecraftControlsObject(BaseModel):
|
||||||
|
colors: Optional[list[RecraftColorObject]] = Field(None, description='An array of preferable colors')
|
||||||
|
background_color: Optional[RecraftColorObject] = Field(None, description='Use given color as a desired background color')
|
||||||
|
no_text: Optional[bool] = Field(None, description='Do not embed text layouts')
|
||||||
|
artistic_level: Optional[conint(ge=0, le=5)] = Field(None, description='Defines artistic tone of your image. At a simple level, the person looks straight at the camera in a static and clean style. Dynamic and eccentric levels introduce movement and creativity. The value should be in range [0..5].')
|
||||||
|
|
||||||
|
|
||||||
|
class RecraftImageGenerationRequest(BaseModel):
|
||||||
|
prompt: str = Field(..., description='The text prompt describing the image to generate')
|
||||||
|
size: Optional[RecraftImageSize] = Field(None, description='The size of the generated image (e.g., "1024x1024")')
|
||||||
|
n: conint(ge=1, le=6) = Field(..., description='The number of images to generate')
|
||||||
|
negative_prompt: Optional[str] = Field(None, description='A text description of undesired elements on an image')
|
||||||
|
model: Optional[RecraftModel] = Field(RecraftModel.recraftv3, description='The model to use for generation (e.g., "recraftv3")')
|
||||||
|
style: Optional[str] = Field(None, description='The style to apply to the generated image (e.g., "digital_illustration")')
|
||||||
|
substyle: Optional[str] = Field(None, description='The substyle to apply to the generated image, depending on the style input')
|
||||||
|
controls: Optional[RecraftControlsObject] = Field(None, description='A set of custom parameters to tweak generation process')
|
||||||
|
style_id: Optional[str] = Field(None, description='Use a previously uploaded style as a reference; UUID')
|
||||||
|
strength: Optional[confloat(ge=0.0, le=1.0)] = Field(None, description='Defines the difference with the original image, should lie in [0, 1], where 0 means almost identical, and 1 means miserable similarity')
|
||||||
|
random_seed: Optional[int] = Field(None, description="Seed for video generation")
|
||||||
|
# text_layout
|
||||||
|
|
||||||
|
|
||||||
|
class RecraftReturnedObject(BaseModel):
|
||||||
|
image_id: str = Field(..., description='Unique identifier for the generated image')
|
||||||
|
url: str = Field(..., description='URL to access the generated image')
|
||||||
|
|
||||||
|
|
||||||
|
class RecraftImageGenerationResponse(BaseModel):
|
||||||
|
created: int = Field(..., description='Unix timestamp when the generation was created')
|
||||||
|
credits: int = Field(..., description='Number of credits used for the generation')
|
||||||
|
data: Optional[list[RecraftReturnedObject]] = Field(None, description='Array of generated image information')
|
||||||
|
image: Optional[RecraftReturnedObject] = Field(None, description='Single generated image')
|
||||||
127
comfy_api_nodes/apis/stability_api.py
Normal file
127
comfy_api_nodes/apis/stability_api.py
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, confloat
|
||||||
|
|
||||||
|
|
||||||
|
class StabilityFormat(str, Enum):
|
||||||
|
png = 'png'
|
||||||
|
jpeg = 'jpeg'
|
||||||
|
webp = 'webp'
|
||||||
|
|
||||||
|
|
||||||
|
class StabilityAspectRatio(str, Enum):
|
||||||
|
ratio_1_1 = "1:1"
|
||||||
|
ratio_16_9 = "16:9"
|
||||||
|
ratio_9_16 = "9:16"
|
||||||
|
ratio_3_2 = "3:2"
|
||||||
|
ratio_2_3 = "2:3"
|
||||||
|
ratio_5_4 = "5:4"
|
||||||
|
ratio_4_5 = "4:5"
|
||||||
|
ratio_21_9 = "21:9"
|
||||||
|
ratio_9_21 = "9:21"
|
||||||
|
|
||||||
|
|
||||||
|
def get_stability_style_presets(include_none=True):
|
||||||
|
presets = []
|
||||||
|
if include_none:
|
||||||
|
presets.append("None")
|
||||||
|
return presets + [x.value for x in StabilityStylePreset]
|
||||||
|
|
||||||
|
|
||||||
|
class StabilityStylePreset(str, Enum):
|
||||||
|
_3d_model = "3d-model"
|
||||||
|
analog_film = "analog-film"
|
||||||
|
anime = "anime"
|
||||||
|
cinematic = "cinematic"
|
||||||
|
comic_book = "comic-book"
|
||||||
|
digital_art = "digital-art"
|
||||||
|
enhance = "enhance"
|
||||||
|
fantasy_art = "fantasy-art"
|
||||||
|
isometric = "isometric"
|
||||||
|
line_art = "line-art"
|
||||||
|
low_poly = "low-poly"
|
||||||
|
modeling_compound = "modeling-compound"
|
||||||
|
neon_punk = "neon-punk"
|
||||||
|
origami = "origami"
|
||||||
|
photographic = "photographic"
|
||||||
|
pixel_art = "pixel-art"
|
||||||
|
tile_texture = "tile-texture"
|
||||||
|
|
||||||
|
|
||||||
|
class Stability_SD3_5_Model(str, Enum):
|
||||||
|
sd3_5_large = "sd3.5-large"
|
||||||
|
# sd3_5_large_turbo = "sd3.5-large-turbo"
|
||||||
|
sd3_5_medium = "sd3.5-medium"
|
||||||
|
|
||||||
|
|
||||||
|
class Stability_SD3_5_GenerationMode(str, Enum):
|
||||||
|
text_to_image = "text-to-image"
|
||||||
|
image_to_image = "image-to-image"
|
||||||
|
|
||||||
|
|
||||||
|
class StabilityStable3_5Request(BaseModel):
|
||||||
|
model: str = Field(...)
|
||||||
|
mode: str = Field(...)
|
||||||
|
prompt: str = Field(...)
|
||||||
|
negative_prompt: Optional[str] = Field(None)
|
||||||
|
aspect_ratio: Optional[str] = Field(None)
|
||||||
|
seed: Optional[int] = Field(None)
|
||||||
|
output_format: Optional[str] = Field(StabilityFormat.png.value)
|
||||||
|
image: Optional[str] = Field(None)
|
||||||
|
style_preset: Optional[str] = Field(None)
|
||||||
|
cfg_scale: float = Field(...)
|
||||||
|
strength: Optional[confloat(ge=0.0, le=1.0)] = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class StabilityUpscaleConservativeRequest(BaseModel):
|
||||||
|
prompt: str = Field(...)
|
||||||
|
negative_prompt: Optional[str] = Field(None)
|
||||||
|
seed: Optional[int] = Field(None)
|
||||||
|
output_format: Optional[str] = Field(StabilityFormat.png.value)
|
||||||
|
image: Optional[str] = Field(None)
|
||||||
|
creativity: Optional[confloat(ge=0.2, le=0.5)] = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class StabilityUpscaleCreativeRequest(BaseModel):
|
||||||
|
prompt: str = Field(...)
|
||||||
|
negative_prompt: Optional[str] = Field(None)
|
||||||
|
seed: Optional[int] = Field(None)
|
||||||
|
output_format: Optional[str] = Field(StabilityFormat.png.value)
|
||||||
|
image: Optional[str] = Field(None)
|
||||||
|
creativity: Optional[confloat(ge=0.1, le=0.5)] = Field(None)
|
||||||
|
style_preset: Optional[str] = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class StabilityStableUltraRequest(BaseModel):
|
||||||
|
prompt: str = Field(...)
|
||||||
|
negative_prompt: Optional[str] = Field(None)
|
||||||
|
aspect_ratio: Optional[str] = Field(None)
|
||||||
|
seed: Optional[int] = Field(None)
|
||||||
|
output_format: Optional[str] = Field(StabilityFormat.png.value)
|
||||||
|
image: Optional[str] = Field(None)
|
||||||
|
style_preset: Optional[str] = Field(None)
|
||||||
|
strength: Optional[confloat(ge=0.0, le=1.0)] = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class StabilityStableUltraResponse(BaseModel):
|
||||||
|
image: Optional[str] = Field(None)
|
||||||
|
finish_reason: Optional[str] = Field(None)
|
||||||
|
seed: Optional[int] = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class StabilityResultsGetResponse(BaseModel):
|
||||||
|
image: Optional[str] = Field(None)
|
||||||
|
finish_reason: Optional[str] = Field(None)
|
||||||
|
seed: Optional[int] = Field(None)
|
||||||
|
id: Optional[str] = Field(None)
|
||||||
|
name: Optional[str] = Field(None)
|
||||||
|
errors: Optional[list[str]] = Field(None)
|
||||||
|
status: Optional[str] = Field(None)
|
||||||
|
result: Optional[str] = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class StabilityAsyncResponse(BaseModel):
|
||||||
|
id: Optional[str] = Field(None)
|
||||||
116
comfy_api_nodes/mapper_utils.py
Normal file
116
comfy_api_nodes/mapper_utils.py
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
from pydantic.fields import FieldInfo
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from pydantic_core import PydanticUndefined
|
||||||
|
|
||||||
|
from comfy.comfy_types.node_typing import IO, InputTypeOptions
|
||||||
|
|
||||||
|
NodeInput = tuple[IO, InputTypeOptions]
|
||||||
|
|
||||||
|
|
||||||
|
def _create_base_config(field_info: FieldInfo) -> InputTypeOptions:
|
||||||
|
config = {}
|
||||||
|
if hasattr(field_info, "default") and field_info.default is not PydanticUndefined:
|
||||||
|
config["default"] = field_info.default
|
||||||
|
if hasattr(field_info, "description") and field_info.description is not None:
|
||||||
|
config["tooltip"] = field_info.description
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def _get_number_constraints_config(field_info: FieldInfo) -> dict:
|
||||||
|
config = {}
|
||||||
|
if hasattr(field_info, "metadata"):
|
||||||
|
metadata = field_info.metadata
|
||||||
|
for constraint in metadata:
|
||||||
|
if hasattr(constraint, "ge"):
|
||||||
|
config["min"] = constraint.ge
|
||||||
|
if hasattr(constraint, "le"):
|
||||||
|
config["max"] = constraint.le
|
||||||
|
if hasattr(constraint, "multiple_of"):
|
||||||
|
config["step"] = constraint.multiple_of
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def _model_field_to_image_input(field_info: FieldInfo, **kwargs) -> NodeInput:
|
||||||
|
return IO.IMAGE, {
|
||||||
|
**_create_base_config(field_info),
|
||||||
|
**kwargs,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _model_field_to_string_input(field_info: FieldInfo, **kwargs) -> NodeInput:
|
||||||
|
return IO.STRING, {
|
||||||
|
**_create_base_config(field_info),
|
||||||
|
**kwargs,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _model_field_to_float_input(field_info: FieldInfo, **kwargs) -> NodeInput:
|
||||||
|
return IO.FLOAT, {
|
||||||
|
**_create_base_config(field_info),
|
||||||
|
**_get_number_constraints_config(field_info),
|
||||||
|
**kwargs,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _model_field_to_int_input(field_info: FieldInfo, **kwargs) -> NodeInput:
|
||||||
|
return IO.INT, {
|
||||||
|
**_create_base_config(field_info),
|
||||||
|
**_get_number_constraints_config(field_info),
|
||||||
|
**kwargs,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _model_field_to_combo_input(
|
||||||
|
field_info: FieldInfo, enum_type: type[Enum] = None, **kwargs
|
||||||
|
) -> NodeInput:
|
||||||
|
combo_config = {}
|
||||||
|
if enum_type is not None:
|
||||||
|
combo_config["options"] = [option.value for option in enum_type]
|
||||||
|
combo_config = {
|
||||||
|
**combo_config,
|
||||||
|
**_create_base_config(field_info),
|
||||||
|
**kwargs,
|
||||||
|
}
|
||||||
|
return IO.COMBO, combo_config
|
||||||
|
|
||||||
|
|
||||||
|
def model_field_to_node_input(
|
||||||
|
input_type: IO, base_model: type[BaseModel], field_name: str, **kwargs
|
||||||
|
) -> NodeInput:
|
||||||
|
"""
|
||||||
|
Maps a field from a Pydantic model to a Comfy node input.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_type: The type of the input.
|
||||||
|
base_model: The Pydantic model to map the field from.
|
||||||
|
field_name: The name of the field to map.
|
||||||
|
**kwargs: Additional key/values to include in the input options.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
For combo inputs, pass an `Enum` to the `enum_type` keyword argument to populate the options automatically.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> model_field_to_node_input(IO.STRING, MyModel, "my_field", multiline=True)
|
||||||
|
>>> model_field_to_node_input(IO.COMBO, MyModel, "my_field", enum_type=MyEnum)
|
||||||
|
>>> model_field_to_node_input(IO.FLOAT, MyModel, "my_field", slider=True)
|
||||||
|
"""
|
||||||
|
field_info: FieldInfo = base_model.model_fields[field_name]
|
||||||
|
result: NodeInput
|
||||||
|
|
||||||
|
if input_type == IO.IMAGE:
|
||||||
|
result = _model_field_to_image_input(field_info, **kwargs)
|
||||||
|
elif input_type == IO.STRING:
|
||||||
|
result = _model_field_to_string_input(field_info, **kwargs)
|
||||||
|
elif input_type == IO.FLOAT:
|
||||||
|
result = _model_field_to_float_input(field_info, **kwargs)
|
||||||
|
elif input_type == IO.INT:
|
||||||
|
result = _model_field_to_int_input(field_info, **kwargs)
|
||||||
|
elif input_type == IO.COMBO:
|
||||||
|
result = _model_field_to_combo_input(field_info, **kwargs)
|
||||||
|
else:
|
||||||
|
message = f"Invalid input type: {input_type}"
|
||||||
|
raise ValueError(message)
|
||||||
|
|
||||||
|
return result
|
||||||
906
comfy_api_nodes/nodes_bfl.py
Normal file
906
comfy_api_nodes/nodes_bfl.py
Normal file
@@ -0,0 +1,906 @@
|
|||||||
|
import io
|
||||||
|
from inspect import cleandoc
|
||||||
|
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
|
||||||
|
from comfy_api_nodes.apis.bfl_api import (
|
||||||
|
BFLStatus,
|
||||||
|
BFLFluxExpandImageRequest,
|
||||||
|
BFLFluxFillImageRequest,
|
||||||
|
BFLFluxCannyImageRequest,
|
||||||
|
BFLFluxDepthImageRequest,
|
||||||
|
BFLFluxProGenerateRequest,
|
||||||
|
BFLFluxProUltraGenerateRequest,
|
||||||
|
BFLFluxProGenerateResponse,
|
||||||
|
)
|
||||||
|
from comfy_api_nodes.apis.client import (
|
||||||
|
ApiEndpoint,
|
||||||
|
HttpMethod,
|
||||||
|
SynchronousOperation,
|
||||||
|
)
|
||||||
|
from comfy_api_nodes.apinode_utils import (
|
||||||
|
downscale_image_tensor,
|
||||||
|
validate_aspect_ratio,
|
||||||
|
process_image_response,
|
||||||
|
resize_mask_to_image,
|
||||||
|
validate_string,
|
||||||
|
)
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
import requests
|
||||||
|
import torch
|
||||||
|
import base64
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
def convert_mask_to_image(mask: torch.Tensor):
|
||||||
|
"""
|
||||||
|
Make mask have the expected amount of dims (4) and channels (3) to be recognized as an image.
|
||||||
|
"""
|
||||||
|
mask = mask.unsqueeze(-1)
|
||||||
|
mask = torch.cat([mask]*3, dim=-1)
|
||||||
|
return mask
|
||||||
|
|
||||||
|
|
||||||
|
def handle_bfl_synchronous_operation(
|
||||||
|
operation: SynchronousOperation, timeout_bfl_calls=360
|
||||||
|
):
|
||||||
|
response_api: BFLFluxProGenerateResponse = operation.execute()
|
||||||
|
return _poll_until_generated(
|
||||||
|
response_api.polling_url, timeout=timeout_bfl_calls
|
||||||
|
)
|
||||||
|
|
||||||
|
def _poll_until_generated(polling_url: str, timeout=360):
|
||||||
|
# used bfl-comfy-nodes to verify code implementation:
|
||||||
|
# https://github.com/black-forest-labs/bfl-comfy-nodes/tree/main
|
||||||
|
start_time = time.time()
|
||||||
|
retries_404 = 0
|
||||||
|
max_retries_404 = 5
|
||||||
|
retry_404_seconds = 2
|
||||||
|
retry_202_seconds = 2
|
||||||
|
retry_pending_seconds = 1
|
||||||
|
request = requests.Request(method=HttpMethod.GET, url=polling_url)
|
||||||
|
# NOTE: should True loop be replaced with checking if workflow has been interrupted?
|
||||||
|
while True:
|
||||||
|
response = requests.Session().send(request.prepare())
|
||||||
|
if response.status_code == 200:
|
||||||
|
result = response.json()
|
||||||
|
if result["status"] == BFLStatus.ready:
|
||||||
|
img_url = result["result"]["sample"]
|
||||||
|
img_response = requests.get(img_url)
|
||||||
|
return process_image_response(img_response)
|
||||||
|
elif result["status"] in [
|
||||||
|
BFLStatus.request_moderated,
|
||||||
|
BFLStatus.content_moderated,
|
||||||
|
]:
|
||||||
|
status = result["status"]
|
||||||
|
raise Exception(
|
||||||
|
f"BFL API did not return an image due to: {status}."
|
||||||
|
)
|
||||||
|
elif result["status"] == BFLStatus.error:
|
||||||
|
raise Exception(f"BFL API encountered an error: {result}.")
|
||||||
|
elif result["status"] == BFLStatus.pending:
|
||||||
|
time.sleep(retry_pending_seconds)
|
||||||
|
continue
|
||||||
|
elif response.status_code == 404:
|
||||||
|
if retries_404 < max_retries_404:
|
||||||
|
retries_404 += 1
|
||||||
|
time.sleep(retry_404_seconds)
|
||||||
|
continue
|
||||||
|
raise Exception(
|
||||||
|
f"BFL API could not find task after {max_retries_404} tries."
|
||||||
|
)
|
||||||
|
elif response.status_code == 202:
|
||||||
|
time.sleep(retry_202_seconds)
|
||||||
|
elif time.time() - start_time > timeout:
|
||||||
|
raise Exception(
|
||||||
|
f"BFL API experienced a timeout; could not return request under {timeout} seconds."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise Exception(f"BFL API encountered an error: {response.json()}")
|
||||||
|
|
||||||
|
def convert_image_to_base64(image: torch.Tensor):
|
||||||
|
scaled_image = downscale_image_tensor(image, total_pixels=2048 * 2048)
|
||||||
|
# remove batch dimension if present
|
||||||
|
if len(scaled_image.shape) > 3:
|
||||||
|
scaled_image = scaled_image[0]
|
||||||
|
image_np = (scaled_image.numpy() * 255).astype(np.uint8)
|
||||||
|
img = Image.fromarray(image_np)
|
||||||
|
img_byte_arr = io.BytesIO()
|
||||||
|
img.save(img_byte_arr, format="PNG")
|
||||||
|
return base64.b64encode(img_byte_arr.getvalue()).decode()
|
||||||
|
|
||||||
|
|
||||||
|
class FluxProUltraImageNode(ComfyNodeABC):
|
||||||
|
"""
|
||||||
|
Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution.
|
||||||
|
"""
|
||||||
|
|
||||||
|
MINIMUM_RATIO = 1 / 4
|
||||||
|
MAXIMUM_RATIO = 4 / 1
|
||||||
|
MINIMUM_RATIO_STR = "1:4"
|
||||||
|
MAXIMUM_RATIO_STR = "4:1"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"prompt": (
|
||||||
|
IO.STRING,
|
||||||
|
{
|
||||||
|
"multiline": True,
|
||||||
|
"default": "",
|
||||||
|
"tooltip": "Prompt for the image generation",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"prompt_upsampling": (
|
||||||
|
IO.BOOLEAN,
|
||||||
|
{
|
||||||
|
"default": False,
|
||||||
|
"tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"seed": (
|
||||||
|
IO.INT,
|
||||||
|
{
|
||||||
|
"default": 0,
|
||||||
|
"min": 0,
|
||||||
|
"max": 0xFFFFFFFFFFFFFFFF,
|
||||||
|
"control_after_generate": True,
|
||||||
|
"tooltip": "The random seed used for creating the noise.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"aspect_ratio": (
|
||||||
|
IO.STRING,
|
||||||
|
{
|
||||||
|
"default": "16:9",
|
||||||
|
"tooltip": "Aspect ratio of image; must be between 1:4 and 4:1.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"raw": (
|
||||||
|
IO.BOOLEAN,
|
||||||
|
{
|
||||||
|
"default": False,
|
||||||
|
"tooltip": "When True, generate less processed, more natural-looking images.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"image_prompt": (IO.IMAGE,),
|
||||||
|
"image_prompt_strength": (
|
||||||
|
IO.FLOAT,
|
||||||
|
{
|
||||||
|
"default": 0.1,
|
||||||
|
"min": 0.0,
|
||||||
|
"max": 1.0,
|
||||||
|
"step": 0.01,
|
||||||
|
"tooltip": "Blend between the prompt and the image prompt.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def VALIDATE_INPUTS(cls, aspect_ratio: str):
|
||||||
|
try:
|
||||||
|
validate_aspect_ratio(
|
||||||
|
aspect_ratio,
|
||||||
|
minimum_ratio=cls.MINIMUM_RATIO,
|
||||||
|
maximum_ratio=cls.MAXIMUM_RATIO,
|
||||||
|
minimum_ratio_str=cls.MINIMUM_RATIO_STR,
|
||||||
|
maximum_ratio_str=cls.MAXIMUM_RATIO_STR,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
return str(e)
|
||||||
|
return True
|
||||||
|
|
||||||
|
RETURN_TYPES = (IO.IMAGE,)
|
||||||
|
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||||
|
FUNCTION = "api_call"
|
||||||
|
API_NODE = True
|
||||||
|
CATEGORY = "api node/image/BFL"
|
||||||
|
|
||||||
|
def api_call(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
aspect_ratio: str,
|
||||||
|
prompt_upsampling=False,
|
||||||
|
raw=False,
|
||||||
|
seed=0,
|
||||||
|
image_prompt=None,
|
||||||
|
image_prompt_strength=0.1,
|
||||||
|
auth_token=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
if image_prompt is None:
|
||||||
|
validate_string(prompt, strip_whitespace=False)
|
||||||
|
operation = SynchronousOperation(
|
||||||
|
endpoint=ApiEndpoint(
|
||||||
|
path="/proxy/bfl/flux-pro-1.1-ultra/generate",
|
||||||
|
method=HttpMethod.POST,
|
||||||
|
request_model=BFLFluxProUltraGenerateRequest,
|
||||||
|
response_model=BFLFluxProGenerateResponse,
|
||||||
|
),
|
||||||
|
request=BFLFluxProUltraGenerateRequest(
|
||||||
|
prompt=prompt,
|
||||||
|
prompt_upsampling=prompt_upsampling,
|
||||||
|
seed=seed,
|
||||||
|
aspect_ratio=validate_aspect_ratio(
|
||||||
|
aspect_ratio,
|
||||||
|
minimum_ratio=self.MINIMUM_RATIO,
|
||||||
|
maximum_ratio=self.MAXIMUM_RATIO,
|
||||||
|
minimum_ratio_str=self.MINIMUM_RATIO_STR,
|
||||||
|
maximum_ratio_str=self.MAXIMUM_RATIO_STR,
|
||||||
|
),
|
||||||
|
raw=raw,
|
||||||
|
image_prompt=(
|
||||||
|
image_prompt
|
||||||
|
if image_prompt is None
|
||||||
|
else convert_image_to_base64(image_prompt)
|
||||||
|
),
|
||||||
|
image_prompt_strength=(
|
||||||
|
None if image_prompt is None else round(image_prompt_strength, 2)
|
||||||
|
),
|
||||||
|
),
|
||||||
|
auth_token=auth_token,
|
||||||
|
)
|
||||||
|
output_image = handle_bfl_synchronous_operation(operation)
|
||||||
|
return (output_image,)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class FluxProImageNode(ComfyNodeABC):
|
||||||
|
"""
|
||||||
|
Generates images synchronously based on prompt and resolution.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"prompt": (
|
||||||
|
IO.STRING,
|
||||||
|
{
|
||||||
|
"multiline": True,
|
||||||
|
"default": "",
|
||||||
|
"tooltip": "Prompt for the image generation",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"prompt_upsampling": (
|
||||||
|
IO.BOOLEAN,
|
||||||
|
{
|
||||||
|
"default": False,
|
||||||
|
"tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"width": (
|
||||||
|
IO.INT,
|
||||||
|
{
|
||||||
|
"default": 1024,
|
||||||
|
"min": 256,
|
||||||
|
"max": 1440,
|
||||||
|
"step": 32,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"height": (
|
||||||
|
IO.INT,
|
||||||
|
{
|
||||||
|
"default": 768,
|
||||||
|
"min": 256,
|
||||||
|
"max": 1440,
|
||||||
|
"step": 32,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"seed": (
|
||||||
|
IO.INT,
|
||||||
|
{
|
||||||
|
"default": 0,
|
||||||
|
"min": 0,
|
||||||
|
"max": 0xFFFFFFFFFFFFFFFF,
|
||||||
|
"control_after_generate": True,
|
||||||
|
"tooltip": "The random seed used for creating the noise.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"image_prompt": (IO.IMAGE,),
|
||||||
|
# "image_prompt_strength": (
|
||||||
|
# IO.FLOAT,
|
||||||
|
# {
|
||||||
|
# "default": 0.1,
|
||||||
|
# "min": 0.0,
|
||||||
|
# "max": 1.0,
|
||||||
|
# "step": 0.01,
|
||||||
|
# "tooltip": "Blend between the prompt and the image prompt.",
|
||||||
|
# },
|
||||||
|
# ),
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = (IO.IMAGE,)
|
||||||
|
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||||
|
FUNCTION = "api_call"
|
||||||
|
API_NODE = True
|
||||||
|
CATEGORY = "api node/image/BFL"
|
||||||
|
|
||||||
|
def api_call(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
prompt_upsampling,
|
||||||
|
width: int,
|
||||||
|
height: int,
|
||||||
|
seed=0,
|
||||||
|
image_prompt=None,
|
||||||
|
# image_prompt_strength=0.1,
|
||||||
|
auth_token=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
image_prompt = (
|
||||||
|
image_prompt
|
||||||
|
if image_prompt is None
|
||||||
|
else convert_image_to_base64(image_prompt)
|
||||||
|
)
|
||||||
|
|
||||||
|
operation = SynchronousOperation(
|
||||||
|
endpoint=ApiEndpoint(
|
||||||
|
path="/proxy/bfl/flux-pro-1.1/generate",
|
||||||
|
method=HttpMethod.POST,
|
||||||
|
request_model=BFLFluxProGenerateRequest,
|
||||||
|
response_model=BFLFluxProGenerateResponse,
|
||||||
|
),
|
||||||
|
request=BFLFluxProGenerateRequest(
|
||||||
|
prompt=prompt,
|
||||||
|
prompt_upsampling=prompt_upsampling,
|
||||||
|
width=width,
|
||||||
|
height=height,
|
||||||
|
seed=seed,
|
||||||
|
image_prompt=image_prompt,
|
||||||
|
),
|
||||||
|
auth_token=auth_token,
|
||||||
|
)
|
||||||
|
output_image = handle_bfl_synchronous_operation(operation)
|
||||||
|
return (output_image,)
|
||||||
|
|
||||||
|
|
||||||
|
class FluxProExpandNode(ComfyNodeABC):
|
||||||
|
"""
|
||||||
|
Outpaints image based on prompt.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image": (IO.IMAGE,),
|
||||||
|
"prompt": (
|
||||||
|
IO.STRING,
|
||||||
|
{
|
||||||
|
"multiline": True,
|
||||||
|
"default": "",
|
||||||
|
"tooltip": "Prompt for the image generation",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"prompt_upsampling": (
|
||||||
|
IO.BOOLEAN,
|
||||||
|
{
|
||||||
|
"default": False,
|
||||||
|
"tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"top": (
|
||||||
|
IO.INT,
|
||||||
|
{
|
||||||
|
"default": 0,
|
||||||
|
"min": 0,
|
||||||
|
"max": 2048,
|
||||||
|
"tooltip": "Number of pixels to expand at the top of the image"
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"bottom": (
|
||||||
|
IO.INT,
|
||||||
|
{
|
||||||
|
"default": 0,
|
||||||
|
"min": 0,
|
||||||
|
"max": 2048,
|
||||||
|
"tooltip": "Number of pixels to expand at the bottom of the image"
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"left": (
|
||||||
|
IO.INT,
|
||||||
|
{
|
||||||
|
"default": 0,
|
||||||
|
"min": 0,
|
||||||
|
"max": 2048,
|
||||||
|
"tooltip": "Number of pixels to expand at the left side of the image"
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"right": (
|
||||||
|
IO.INT,
|
||||||
|
{
|
||||||
|
"default": 0,
|
||||||
|
"min": 0,
|
||||||
|
"max": 2048,
|
||||||
|
"tooltip": "Number of pixels to expand at the right side of the image"
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"guidance": (
|
||||||
|
IO.FLOAT,
|
||||||
|
{
|
||||||
|
"default": 60,
|
||||||
|
"min": 1.5,
|
||||||
|
"max": 100,
|
||||||
|
"tooltip": "Guidance strength for the image generation process"
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"steps": (
|
||||||
|
IO.INT,
|
||||||
|
{
|
||||||
|
"default": 50,
|
||||||
|
"min": 15,
|
||||||
|
"max": 50,
|
||||||
|
"tooltip": "Number of steps for the image generation process"
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"seed": (
|
||||||
|
IO.INT,
|
||||||
|
{
|
||||||
|
"default": 0,
|
||||||
|
"min": 0,
|
||||||
|
"max": 0xFFFFFFFFFFFFFFFF,
|
||||||
|
"control_after_generate": True,
|
||||||
|
"tooltip": "The random seed used for creating the noise.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = (IO.IMAGE,)
|
||||||
|
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||||
|
FUNCTION = "api_call"
|
||||||
|
API_NODE = True
|
||||||
|
CATEGORY = "api node/image/BFL"
|
||||||
|
|
||||||
|
def api_call(
|
||||||
|
self,
|
||||||
|
image: torch.Tensor,
|
||||||
|
prompt: str,
|
||||||
|
prompt_upsampling: bool,
|
||||||
|
top: int,
|
||||||
|
bottom: int,
|
||||||
|
left: int,
|
||||||
|
right: int,
|
||||||
|
steps: int,
|
||||||
|
guidance: float,
|
||||||
|
seed=0,
|
||||||
|
auth_token=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
image = convert_image_to_base64(image)
|
||||||
|
|
||||||
|
operation = SynchronousOperation(
|
||||||
|
endpoint=ApiEndpoint(
|
||||||
|
path="/proxy/bfl/flux-pro-1.0-expand/generate",
|
||||||
|
method=HttpMethod.POST,
|
||||||
|
request_model=BFLFluxExpandImageRequest,
|
||||||
|
response_model=BFLFluxProGenerateResponse,
|
||||||
|
),
|
||||||
|
request=BFLFluxExpandImageRequest(
|
||||||
|
prompt=prompt,
|
||||||
|
prompt_upsampling=prompt_upsampling,
|
||||||
|
top=top,
|
||||||
|
bottom=bottom,
|
||||||
|
left=left,
|
||||||
|
right=right,
|
||||||
|
steps=steps,
|
||||||
|
guidance=guidance,
|
||||||
|
seed=seed,
|
||||||
|
image=image,
|
||||||
|
),
|
||||||
|
auth_token=auth_token,
|
||||||
|
)
|
||||||
|
output_image = handle_bfl_synchronous_operation(operation)
|
||||||
|
return (output_image,)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class FluxProFillNode(ComfyNodeABC):
|
||||||
|
"""
|
||||||
|
Inpaints image based on mask and prompt.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image": (IO.IMAGE,),
|
||||||
|
"mask": (IO.MASK,),
|
||||||
|
"prompt": (
|
||||||
|
IO.STRING,
|
||||||
|
{
|
||||||
|
"multiline": True,
|
||||||
|
"default": "",
|
||||||
|
"tooltip": "Prompt for the image generation",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"prompt_upsampling": (
|
||||||
|
IO.BOOLEAN,
|
||||||
|
{
|
||||||
|
"default": False,
|
||||||
|
"tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"guidance": (
|
||||||
|
IO.FLOAT,
|
||||||
|
{
|
||||||
|
"default": 60,
|
||||||
|
"min": 1.5,
|
||||||
|
"max": 100,
|
||||||
|
"tooltip": "Guidance strength for the image generation process"
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"steps": (
|
||||||
|
IO.INT,
|
||||||
|
{
|
||||||
|
"default": 50,
|
||||||
|
"min": 15,
|
||||||
|
"max": 50,
|
||||||
|
"tooltip": "Number of steps for the image generation process"
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"seed": (
|
||||||
|
IO.INT,
|
||||||
|
{
|
||||||
|
"default": 0,
|
||||||
|
"min": 0,
|
||||||
|
"max": 0xFFFFFFFFFFFFFFFF,
|
||||||
|
"control_after_generate": True,
|
||||||
|
"tooltip": "The random seed used for creating the noise.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = (IO.IMAGE,)
|
||||||
|
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||||
|
FUNCTION = "api_call"
|
||||||
|
API_NODE = True
|
||||||
|
CATEGORY = "api node/image/BFL"
|
||||||
|
|
||||||
|
def api_call(
|
||||||
|
self,
|
||||||
|
image: torch.Tensor,
|
||||||
|
mask: torch.Tensor,
|
||||||
|
prompt: str,
|
||||||
|
prompt_upsampling: bool,
|
||||||
|
steps: int,
|
||||||
|
guidance: float,
|
||||||
|
seed=0,
|
||||||
|
auth_token=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
# prepare mask
|
||||||
|
mask = resize_mask_to_image(mask, image)
|
||||||
|
mask = convert_image_to_base64(convert_mask_to_image(mask))
|
||||||
|
# make sure image will have alpha channel removed
|
||||||
|
image = convert_image_to_base64(image[:,:,:,:3])
|
||||||
|
|
||||||
|
operation = SynchronousOperation(
|
||||||
|
endpoint=ApiEndpoint(
|
||||||
|
path="/proxy/bfl/flux-pro-1.0-fill/generate",
|
||||||
|
method=HttpMethod.POST,
|
||||||
|
request_model=BFLFluxFillImageRequest,
|
||||||
|
response_model=BFLFluxProGenerateResponse,
|
||||||
|
),
|
||||||
|
request=BFLFluxFillImageRequest(
|
||||||
|
prompt=prompt,
|
||||||
|
prompt_upsampling=prompt_upsampling,
|
||||||
|
steps=steps,
|
||||||
|
guidance=guidance,
|
||||||
|
seed=seed,
|
||||||
|
image=image,
|
||||||
|
mask=mask,
|
||||||
|
),
|
||||||
|
auth_token=auth_token,
|
||||||
|
)
|
||||||
|
output_image = handle_bfl_synchronous_operation(operation)
|
||||||
|
return (output_image,)
|
||||||
|
|
||||||
|
|
||||||
|
class FluxProCannyNode(ComfyNodeABC):
|
||||||
|
"""
|
||||||
|
Generate image using a control image (canny).
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"control_image": (IO.IMAGE,),
|
||||||
|
"prompt": (
|
||||||
|
IO.STRING,
|
||||||
|
{
|
||||||
|
"multiline": True,
|
||||||
|
"default": "",
|
||||||
|
"tooltip": "Prompt for the image generation",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"prompt_upsampling": (
|
||||||
|
IO.BOOLEAN,
|
||||||
|
{
|
||||||
|
"default": False,
|
||||||
|
"tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"canny_low_threshold": (
|
||||||
|
IO.FLOAT,
|
||||||
|
{
|
||||||
|
"default": 0.1,
|
||||||
|
"min": 0.01,
|
||||||
|
"max": 0.99,
|
||||||
|
"step": 0.01,
|
||||||
|
"tooltip": "Low threshold for Canny edge detection; ignored if skip_processing is True"
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"canny_high_threshold": (
|
||||||
|
IO.FLOAT,
|
||||||
|
{
|
||||||
|
"default": 0.4,
|
||||||
|
"min": 0.01,
|
||||||
|
"max": 0.99,
|
||||||
|
"step": 0.01,
|
||||||
|
"tooltip": "High threshold for Canny edge detection; ignored if skip_processing is True"
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"skip_preprocessing": (
|
||||||
|
IO.BOOLEAN,
|
||||||
|
{
|
||||||
|
"default": False,
|
||||||
|
"tooltip": "Whether to skip preprocessing; set to True if control_image already is canny-fied, False if it is a raw image.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"guidance": (
|
||||||
|
IO.FLOAT,
|
||||||
|
{
|
||||||
|
"default": 30,
|
||||||
|
"min": 1,
|
||||||
|
"max": 100,
|
||||||
|
"tooltip": "Guidance strength for the image generation process"
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"steps": (
|
||||||
|
IO.INT,
|
||||||
|
{
|
||||||
|
"default": 50,
|
||||||
|
"min": 15,
|
||||||
|
"max": 50,
|
||||||
|
"tooltip": "Number of steps for the image generation process"
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"seed": (
|
||||||
|
IO.INT,
|
||||||
|
{
|
||||||
|
"default": 0,
|
||||||
|
"min": 0,
|
||||||
|
"max": 0xFFFFFFFFFFFFFFFF,
|
||||||
|
"control_after_generate": True,
|
||||||
|
"tooltip": "The random seed used for creating the noise.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = (IO.IMAGE,)
|
||||||
|
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||||
|
FUNCTION = "api_call"
|
||||||
|
API_NODE = True
|
||||||
|
CATEGORY = "api node/image/BFL"
|
||||||
|
|
||||||
|
def api_call(
|
||||||
|
self,
|
||||||
|
control_image: torch.Tensor,
|
||||||
|
prompt: str,
|
||||||
|
prompt_upsampling: bool,
|
||||||
|
canny_low_threshold: float,
|
||||||
|
canny_high_threshold: float,
|
||||||
|
skip_preprocessing: bool,
|
||||||
|
steps: int,
|
||||||
|
guidance: float,
|
||||||
|
seed=0,
|
||||||
|
auth_token=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
control_image = convert_image_to_base64(control_image[:,:,:,:3])
|
||||||
|
preprocessed_image = None
|
||||||
|
|
||||||
|
# scale canny threshold between 0-500, to match BFL's API
|
||||||
|
def scale_value(value: float, min_val=0, max_val=500):
|
||||||
|
return min_val + value * (max_val - min_val)
|
||||||
|
canny_low_threshold = int(round(scale_value(canny_low_threshold)))
|
||||||
|
canny_high_threshold = int(round(scale_value(canny_high_threshold)))
|
||||||
|
|
||||||
|
|
||||||
|
if skip_preprocessing:
|
||||||
|
preprocessed_image = control_image
|
||||||
|
control_image = None
|
||||||
|
canny_low_threshold = None
|
||||||
|
canny_high_threshold = None
|
||||||
|
|
||||||
|
operation = SynchronousOperation(
|
||||||
|
endpoint=ApiEndpoint(
|
||||||
|
path="/proxy/bfl/flux-pro-1.0-canny/generate",
|
||||||
|
method=HttpMethod.POST,
|
||||||
|
request_model=BFLFluxCannyImageRequest,
|
||||||
|
response_model=BFLFluxProGenerateResponse,
|
||||||
|
),
|
||||||
|
request=BFLFluxCannyImageRequest(
|
||||||
|
prompt=prompt,
|
||||||
|
prompt_upsampling=prompt_upsampling,
|
||||||
|
steps=steps,
|
||||||
|
guidance=guidance,
|
||||||
|
seed=seed,
|
||||||
|
control_image=control_image,
|
||||||
|
canny_low_threshold=canny_low_threshold,
|
||||||
|
canny_high_threshold=canny_high_threshold,
|
||||||
|
preprocessed_image=preprocessed_image,
|
||||||
|
),
|
||||||
|
auth_token=auth_token,
|
||||||
|
)
|
||||||
|
output_image = handle_bfl_synchronous_operation(operation)
|
||||||
|
return (output_image,)
|
||||||
|
|
||||||
|
|
||||||
|
class FluxProDepthNode(ComfyNodeABC):
|
||||||
|
"""
|
||||||
|
Generate image using a control image (depth).
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"control_image": (IO.IMAGE,),
|
||||||
|
"prompt": (
|
||||||
|
IO.STRING,
|
||||||
|
{
|
||||||
|
"multiline": True,
|
||||||
|
"default": "",
|
||||||
|
"tooltip": "Prompt for the image generation",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"prompt_upsampling": (
|
||||||
|
IO.BOOLEAN,
|
||||||
|
{
|
||||||
|
"default": False,
|
||||||
|
"tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"skip_preprocessing": (
|
||||||
|
IO.BOOLEAN,
|
||||||
|
{
|
||||||
|
"default": False,
|
||||||
|
"tooltip": "Whether to skip preprocessing; set to True if control_image already is depth-ified, False if it is a raw image.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"guidance": (
|
||||||
|
IO.FLOAT,
|
||||||
|
{
|
||||||
|
"default": 15,
|
||||||
|
"min": 1,
|
||||||
|
"max": 100,
|
||||||
|
"tooltip": "Guidance strength for the image generation process"
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"steps": (
|
||||||
|
IO.INT,
|
||||||
|
{
|
||||||
|
"default": 50,
|
||||||
|
"min": 15,
|
||||||
|
"max": 50,
|
||||||
|
"tooltip": "Number of steps for the image generation process"
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"seed": (
|
||||||
|
IO.INT,
|
||||||
|
{
|
||||||
|
"default": 0,
|
||||||
|
"min": 0,
|
||||||
|
"max": 0xFFFFFFFFFFFFFFFF,
|
||||||
|
"control_after_generate": True,
|
||||||
|
"tooltip": "The random seed used for creating the noise.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = (IO.IMAGE,)
|
||||||
|
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||||
|
FUNCTION = "api_call"
|
||||||
|
API_NODE = True
|
||||||
|
CATEGORY = "api node/image/BFL"
|
||||||
|
|
||||||
|
def api_call(
|
||||||
|
self,
|
||||||
|
control_image: torch.Tensor,
|
||||||
|
prompt: str,
|
||||||
|
prompt_upsampling: bool,
|
||||||
|
skip_preprocessing: bool,
|
||||||
|
steps: int,
|
||||||
|
guidance: float,
|
||||||
|
seed=0,
|
||||||
|
auth_token=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
control_image = convert_image_to_base64(control_image[:,:,:,:3])
|
||||||
|
preprocessed_image = None
|
||||||
|
|
||||||
|
if skip_preprocessing:
|
||||||
|
preprocessed_image = control_image
|
||||||
|
control_image = None
|
||||||
|
|
||||||
|
operation = SynchronousOperation(
|
||||||
|
endpoint=ApiEndpoint(
|
||||||
|
path="/proxy/bfl/flux-pro-1.0-depth/generate",
|
||||||
|
method=HttpMethod.POST,
|
||||||
|
request_model=BFLFluxDepthImageRequest,
|
||||||
|
response_model=BFLFluxProGenerateResponse,
|
||||||
|
),
|
||||||
|
request=BFLFluxDepthImageRequest(
|
||||||
|
prompt=prompt,
|
||||||
|
prompt_upsampling=prompt_upsampling,
|
||||||
|
steps=steps,
|
||||||
|
guidance=guidance,
|
||||||
|
seed=seed,
|
||||||
|
control_image=control_image,
|
||||||
|
preprocessed_image=preprocessed_image,
|
||||||
|
),
|
||||||
|
auth_token=auth_token,
|
||||||
|
)
|
||||||
|
output_image = handle_bfl_synchronous_operation(operation)
|
||||||
|
return (output_image,)
|
||||||
|
|
||||||
|
|
||||||
|
# A dictionary that contains all nodes you want to export with their names
|
||||||
|
# NOTE: names should be globally unique
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"FluxProUltraImageNode": FluxProUltraImageNode,
|
||||||
|
# "FluxProImageNode": FluxProImageNode,
|
||||||
|
"FluxProExpandNode": FluxProExpandNode,
|
||||||
|
"FluxProFillNode": FluxProFillNode,
|
||||||
|
"FluxProCannyNode": FluxProCannyNode,
|
||||||
|
"FluxProDepthNode": FluxProDepthNode,
|
||||||
|
}
|
||||||
|
|
||||||
|
# A dictionary that contains the friendly/humanly readable titles for the nodes
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
|
"FluxProUltraImageNode": "Flux 1.1 [pro] Ultra Image",
|
||||||
|
# "FluxProImageNode": "Flux 1.1 [pro] Image",
|
||||||
|
"FluxProExpandNode": "Flux.1 Expand Image",
|
||||||
|
"FluxProFillNode": "Flux.1 Fill Image",
|
||||||
|
"FluxProCannyNode": "Flux.1 Canny Control Image",
|
||||||
|
"FluxProDepthNode": "Flux.1 Depth Control Image",
|
||||||
|
}
|
||||||
777
comfy_api_nodes/nodes_ideogram.py
Normal file
777
comfy_api_nodes/nodes_ideogram.py
Normal file
@@ -0,0 +1,777 @@
|
|||||||
|
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict
|
||||||
|
from inspect import cleandoc
|
||||||
|
from PIL import Image
|
||||||
|
import numpy as np
|
||||||
|
import io
|
||||||
|
import torch
|
||||||
|
from comfy_api_nodes.apis import (
|
||||||
|
IdeogramGenerateRequest,
|
||||||
|
IdeogramGenerateResponse,
|
||||||
|
ImageRequest,
|
||||||
|
IdeogramV3Request,
|
||||||
|
IdeogramV3EditRequest,
|
||||||
|
)
|
||||||
|
|
||||||
|
from comfy_api_nodes.apis.client import (
|
||||||
|
ApiEndpoint,
|
||||||
|
HttpMethod,
|
||||||
|
SynchronousOperation,
|
||||||
|
)
|
||||||
|
|
||||||
|
from comfy_api_nodes.apinode_utils import (
|
||||||
|
download_url_to_bytesio,
|
||||||
|
bytesio_to_image_tensor,
|
||||||
|
resize_mask_to_image,
|
||||||
|
)
|
||||||
|
|
||||||
|
V1_V1_RES_MAP = {
|
||||||
|
"Auto":"AUTO",
|
||||||
|
"512 x 1536":"RESOLUTION_512_1536",
|
||||||
|
"576 x 1408":"RESOLUTION_576_1408",
|
||||||
|
"576 x 1472":"RESOLUTION_576_1472",
|
||||||
|
"576 x 1536":"RESOLUTION_576_1536",
|
||||||
|
"640 x 1024":"RESOLUTION_640_1024",
|
||||||
|
"640 x 1344":"RESOLUTION_640_1344",
|
||||||
|
"640 x 1408":"RESOLUTION_640_1408",
|
||||||
|
"640 x 1472":"RESOLUTION_640_1472",
|
||||||
|
"640 x 1536":"RESOLUTION_640_1536",
|
||||||
|
"704 x 1152":"RESOLUTION_704_1152",
|
||||||
|
"704 x 1216":"RESOLUTION_704_1216",
|
||||||
|
"704 x 1280":"RESOLUTION_704_1280",
|
||||||
|
"704 x 1344":"RESOLUTION_704_1344",
|
||||||
|
"704 x 1408":"RESOLUTION_704_1408",
|
||||||
|
"704 x 1472":"RESOLUTION_704_1472",
|
||||||
|
"720 x 1280":"RESOLUTION_720_1280",
|
||||||
|
"736 x 1312":"RESOLUTION_736_1312",
|
||||||
|
"768 x 1024":"RESOLUTION_768_1024",
|
||||||
|
"768 x 1088":"RESOLUTION_768_1088",
|
||||||
|
"768 x 1152":"RESOLUTION_768_1152",
|
||||||
|
"768 x 1216":"RESOLUTION_768_1216",
|
||||||
|
"768 x 1232":"RESOLUTION_768_1232",
|
||||||
|
"768 x 1280":"RESOLUTION_768_1280",
|
||||||
|
"768 x 1344":"RESOLUTION_768_1344",
|
||||||
|
"832 x 960":"RESOLUTION_832_960",
|
||||||
|
"832 x 1024":"RESOLUTION_832_1024",
|
||||||
|
"832 x 1088":"RESOLUTION_832_1088",
|
||||||
|
"832 x 1152":"RESOLUTION_832_1152",
|
||||||
|
"832 x 1216":"RESOLUTION_832_1216",
|
||||||
|
"832 x 1248":"RESOLUTION_832_1248",
|
||||||
|
"864 x 1152":"RESOLUTION_864_1152",
|
||||||
|
"896 x 960":"RESOLUTION_896_960",
|
||||||
|
"896 x 1024":"RESOLUTION_896_1024",
|
||||||
|
"896 x 1088":"RESOLUTION_896_1088",
|
||||||
|
"896 x 1120":"RESOLUTION_896_1120",
|
||||||
|
"896 x 1152":"RESOLUTION_896_1152",
|
||||||
|
"960 x 832":"RESOLUTION_960_832",
|
||||||
|
"960 x 896":"RESOLUTION_960_896",
|
||||||
|
"960 x 1024":"RESOLUTION_960_1024",
|
||||||
|
"960 x 1088":"RESOLUTION_960_1088",
|
||||||
|
"1024 x 640":"RESOLUTION_1024_640",
|
||||||
|
"1024 x 768":"RESOLUTION_1024_768",
|
||||||
|
"1024 x 832":"RESOLUTION_1024_832",
|
||||||
|
"1024 x 896":"RESOLUTION_1024_896",
|
||||||
|
"1024 x 960":"RESOLUTION_1024_960",
|
||||||
|
"1024 x 1024":"RESOLUTION_1024_1024",
|
||||||
|
"1088 x 768":"RESOLUTION_1088_768",
|
||||||
|
"1088 x 832":"RESOLUTION_1088_832",
|
||||||
|
"1088 x 896":"RESOLUTION_1088_896",
|
||||||
|
"1088 x 960":"RESOLUTION_1088_960",
|
||||||
|
"1120 x 896":"RESOLUTION_1120_896",
|
||||||
|
"1152 x 704":"RESOLUTION_1152_704",
|
||||||
|
"1152 x 768":"RESOLUTION_1152_768",
|
||||||
|
"1152 x 832":"RESOLUTION_1152_832",
|
||||||
|
"1152 x 864":"RESOLUTION_1152_864",
|
||||||
|
"1152 x 896":"RESOLUTION_1152_896",
|
||||||
|
"1216 x 704":"RESOLUTION_1216_704",
|
||||||
|
"1216 x 768":"RESOLUTION_1216_768",
|
||||||
|
"1216 x 832":"RESOLUTION_1216_832",
|
||||||
|
"1232 x 768":"RESOLUTION_1232_768",
|
||||||
|
"1248 x 832":"RESOLUTION_1248_832",
|
||||||
|
"1280 x 704":"RESOLUTION_1280_704",
|
||||||
|
"1280 x 720":"RESOLUTION_1280_720",
|
||||||
|
"1280 x 768":"RESOLUTION_1280_768",
|
||||||
|
"1280 x 800":"RESOLUTION_1280_800",
|
||||||
|
"1312 x 736":"RESOLUTION_1312_736",
|
||||||
|
"1344 x 640":"RESOLUTION_1344_640",
|
||||||
|
"1344 x 704":"RESOLUTION_1344_704",
|
||||||
|
"1344 x 768":"RESOLUTION_1344_768",
|
||||||
|
"1408 x 576":"RESOLUTION_1408_576",
|
||||||
|
"1408 x 640":"RESOLUTION_1408_640",
|
||||||
|
"1408 x 704":"RESOLUTION_1408_704",
|
||||||
|
"1472 x 576":"RESOLUTION_1472_576",
|
||||||
|
"1472 x 640":"RESOLUTION_1472_640",
|
||||||
|
"1472 x 704":"RESOLUTION_1472_704",
|
||||||
|
"1536 x 512":"RESOLUTION_1536_512",
|
||||||
|
"1536 x 576":"RESOLUTION_1536_576",
|
||||||
|
"1536 x 640":"RESOLUTION_1536_640",
|
||||||
|
}
|
||||||
|
|
||||||
|
V1_V2_RATIO_MAP = {
|
||||||
|
"1:1":"ASPECT_1_1",
|
||||||
|
"4:3":"ASPECT_4_3",
|
||||||
|
"3:4":"ASPECT_3_4",
|
||||||
|
"16:9":"ASPECT_16_9",
|
||||||
|
"9:16":"ASPECT_9_16",
|
||||||
|
"2:1":"ASPECT_2_1",
|
||||||
|
"1:2":"ASPECT_1_2",
|
||||||
|
"3:2":"ASPECT_3_2",
|
||||||
|
"2:3":"ASPECT_2_3",
|
||||||
|
"4:5":"ASPECT_4_5",
|
||||||
|
"5:4":"ASPECT_5_4",
|
||||||
|
}
|
||||||
|
|
||||||
|
V3_RATIO_MAP = {
|
||||||
|
"1:3":"1x3",
|
||||||
|
"3:1":"3x1",
|
||||||
|
"1:2":"1x2",
|
||||||
|
"2:1":"2x1",
|
||||||
|
"9:16":"9x16",
|
||||||
|
"16:9":"16x9",
|
||||||
|
"10:16":"10x16",
|
||||||
|
"16:10":"16x10",
|
||||||
|
"2:3":"2x3",
|
||||||
|
"3:2":"3x2",
|
||||||
|
"3:4":"3x4",
|
||||||
|
"4:3":"4x3",
|
||||||
|
"4:5":"4x5",
|
||||||
|
"5:4":"5x4",
|
||||||
|
"1:1":"1x1",
|
||||||
|
}
|
||||||
|
|
||||||
|
V3_RESOLUTIONS= [
|
||||||
|
"Auto",
|
||||||
|
"512x1536",
|
||||||
|
"576x1408",
|
||||||
|
"576x1472",
|
||||||
|
"576x1536",
|
||||||
|
"640x1344",
|
||||||
|
"640x1408",
|
||||||
|
"640x1472",
|
||||||
|
"640x1536",
|
||||||
|
"704x1152",
|
||||||
|
"704x1216",
|
||||||
|
"704x1280",
|
||||||
|
"704x1344",
|
||||||
|
"704x1408",
|
||||||
|
"704x1472",
|
||||||
|
"736x1312",
|
||||||
|
"768x1088",
|
||||||
|
"768x1216",
|
||||||
|
"768x1280",
|
||||||
|
"768x1344",
|
||||||
|
"800x1280",
|
||||||
|
"832x960",
|
||||||
|
"832x1024",
|
||||||
|
"832x1088",
|
||||||
|
"832x1152",
|
||||||
|
"832x1216",
|
||||||
|
"832x1248",
|
||||||
|
"864x1152",
|
||||||
|
"896x960",
|
||||||
|
"896x1024",
|
||||||
|
"896x1088",
|
||||||
|
"896x1120",
|
||||||
|
"896x1152",
|
||||||
|
"960x832",
|
||||||
|
"960x896",
|
||||||
|
"960x1024",
|
||||||
|
"960x1088",
|
||||||
|
"1024x832",
|
||||||
|
"1024x896",
|
||||||
|
"1024x960",
|
||||||
|
"1024x1024",
|
||||||
|
"1088x768",
|
||||||
|
"1088x832",
|
||||||
|
"1088x896",
|
||||||
|
"1088x960",
|
||||||
|
"1120x896",
|
||||||
|
"1152x704",
|
||||||
|
"1152x832",
|
||||||
|
"1152x864",
|
||||||
|
"1152x896",
|
||||||
|
"1216x704",
|
||||||
|
"1216x768",
|
||||||
|
"1216x832",
|
||||||
|
"1248x832",
|
||||||
|
"1280x704",
|
||||||
|
"1280x768",
|
||||||
|
"1280x800",
|
||||||
|
"1312x736",
|
||||||
|
"1344x640",
|
||||||
|
"1344x704",
|
||||||
|
"1344x768",
|
||||||
|
"1408x576",
|
||||||
|
"1408x640",
|
||||||
|
"1408x704",
|
||||||
|
"1472x576",
|
||||||
|
"1472x640",
|
||||||
|
"1472x704",
|
||||||
|
"1536x512",
|
||||||
|
"1536x576",
|
||||||
|
"1536x640"
|
||||||
|
]
|
||||||
|
|
||||||
|
def download_and_process_images(image_urls):
|
||||||
|
"""Helper function to download and process multiple images from URLs"""
|
||||||
|
|
||||||
|
# Initialize list to store image tensors
|
||||||
|
image_tensors = []
|
||||||
|
|
||||||
|
for image_url in image_urls:
|
||||||
|
# Using functions from apinode_utils.py to handle downloading and processing
|
||||||
|
image_bytesio = download_url_to_bytesio(image_url) # Download image content to BytesIO
|
||||||
|
img_tensor = bytesio_to_image_tensor(image_bytesio, mode="RGB") # Convert to torch.Tensor with RGB mode
|
||||||
|
image_tensors.append(img_tensor)
|
||||||
|
|
||||||
|
# Stack tensors to match (N, width, height, channels)
|
||||||
|
if image_tensors:
|
||||||
|
stacked_tensors = torch.cat(image_tensors, dim=0)
|
||||||
|
else:
|
||||||
|
raise Exception("No valid images were processed")
|
||||||
|
|
||||||
|
return stacked_tensors
|
||||||
|
|
||||||
|
|
||||||
|
class IdeogramV1(ComfyNodeABC):
|
||||||
|
"""
|
||||||
|
Generates images synchronously using the Ideogram V1 model.
|
||||||
|
|
||||||
|
Images links are available for a limited period of time; if you would like to keep the image, you must download it.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"prompt": (
|
||||||
|
IO.STRING,
|
||||||
|
{
|
||||||
|
"multiline": True,
|
||||||
|
"default": "",
|
||||||
|
"tooltip": "Prompt for the image generation",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"turbo": (
|
||||||
|
IO.BOOLEAN,
|
||||||
|
{
|
||||||
|
"default": False,
|
||||||
|
"tooltip": "Whether to use turbo mode (faster generation, potentially lower quality)",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"aspect_ratio": (
|
||||||
|
IO.COMBO,
|
||||||
|
{
|
||||||
|
"options": list(V1_V2_RATIO_MAP.keys()),
|
||||||
|
"default": "1:1",
|
||||||
|
"tooltip": "The aspect ratio for image generation.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"magic_prompt_option": (
|
||||||
|
IO.COMBO,
|
||||||
|
{
|
||||||
|
"options": ["AUTO", "ON", "OFF"],
|
||||||
|
"default": "AUTO",
|
||||||
|
"tooltip": "Determine if MagicPrompt should be used in generation",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"seed": (
|
||||||
|
IO.INT,
|
||||||
|
{
|
||||||
|
"default": 0,
|
||||||
|
"min": 0,
|
||||||
|
"max": 2147483647,
|
||||||
|
"step": 1,
|
||||||
|
"control_after_generate": True,
|
||||||
|
"display": "number",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"negative_prompt": (
|
||||||
|
IO.STRING,
|
||||||
|
{
|
||||||
|
"multiline": True,
|
||||||
|
"default": "",
|
||||||
|
"tooltip": "Description of what to exclude from the image",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"num_images": (
|
||||||
|
IO.INT,
|
||||||
|
{"default": 1, "min": 1, "max": 8, "step": 1, "display": "number"},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = (IO.IMAGE,)
|
||||||
|
FUNCTION = "api_call"
|
||||||
|
CATEGORY = "api node/image/Ideogram/v1"
|
||||||
|
DESCRIPTION = cleandoc(__doc__ or "")
|
||||||
|
API_NODE = True
|
||||||
|
|
||||||
|
def api_call(
|
||||||
|
self,
|
||||||
|
prompt,
|
||||||
|
turbo=False,
|
||||||
|
aspect_ratio="1:1",
|
||||||
|
magic_prompt_option="AUTO",
|
||||||
|
seed=0,
|
||||||
|
negative_prompt="",
|
||||||
|
num_images=1,
|
||||||
|
auth_token=None,
|
||||||
|
):
|
||||||
|
# Determine the model based on turbo setting
|
||||||
|
aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None)
|
||||||
|
model = "V_1_TURBO" if turbo else "V_1"
|
||||||
|
|
||||||
|
operation = SynchronousOperation(
|
||||||
|
endpoint=ApiEndpoint(
|
||||||
|
path="/proxy/ideogram/generate",
|
||||||
|
method=HttpMethod.POST,
|
||||||
|
request_model=IdeogramGenerateRequest,
|
||||||
|
response_model=IdeogramGenerateResponse,
|
||||||
|
),
|
||||||
|
request=IdeogramGenerateRequest(
|
||||||
|
image_request=ImageRequest(
|
||||||
|
prompt=prompt,
|
||||||
|
model=model,
|
||||||
|
num_images=num_images,
|
||||||
|
seed=seed,
|
||||||
|
aspect_ratio=aspect_ratio if aspect_ratio != "ASPECT_1_1" else None,
|
||||||
|
magic_prompt_option=(
|
||||||
|
magic_prompt_option if magic_prompt_option != "AUTO" else None
|
||||||
|
),
|
||||||
|
negative_prompt=negative_prompt if negative_prompt else None,
|
||||||
|
)
|
||||||
|
),
|
||||||
|
auth_token=auth_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = operation.execute()
|
||||||
|
|
||||||
|
if not response.data or len(response.data) == 0:
|
||||||
|
raise Exception("No images were generated in the response")
|
||||||
|
|
||||||
|
image_urls = [image_data.url for image_data in response.data if image_data.url]
|
||||||
|
|
||||||
|
if not image_urls:
|
||||||
|
raise Exception("No image URLs were generated in the response")
|
||||||
|
|
||||||
|
return (download_and_process_images(image_urls),)
|
||||||
|
|
||||||
|
|
||||||
|
class IdeogramV2(ComfyNodeABC):
|
||||||
|
"""
|
||||||
|
Generates images synchronously using the Ideogram V2 model.
|
||||||
|
|
||||||
|
Images links are available for a limited period of time; if you would like to keep the image, you must download it.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"prompt": (
|
||||||
|
IO.STRING,
|
||||||
|
{
|
||||||
|
"multiline": True,
|
||||||
|
"default": "",
|
||||||
|
"tooltip": "Prompt for the image generation",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"turbo": (
|
||||||
|
IO.BOOLEAN,
|
||||||
|
{
|
||||||
|
"default": False,
|
||||||
|
"tooltip": "Whether to use turbo mode (faster generation, potentially lower quality)",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"aspect_ratio": (
|
||||||
|
IO.COMBO,
|
||||||
|
{
|
||||||
|
"options": list(V1_V2_RATIO_MAP.keys()),
|
||||||
|
"default": "1:1",
|
||||||
|
"tooltip": "The aspect ratio for image generation. Ignored if resolution is not set to AUTO.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"resolution": (
|
||||||
|
IO.COMBO,
|
||||||
|
{
|
||||||
|
"options": list(V1_V1_RES_MAP.keys()),
|
||||||
|
"default": "Auto",
|
||||||
|
"tooltip": "The resolution for image generation. If not set to AUTO, this overrides the aspect_ratio setting.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"magic_prompt_option": (
|
||||||
|
IO.COMBO,
|
||||||
|
{
|
||||||
|
"options": ["AUTO", "ON", "OFF"],
|
||||||
|
"default": "AUTO",
|
||||||
|
"tooltip": "Determine if MagicPrompt should be used in generation",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"seed": (
|
||||||
|
IO.INT,
|
||||||
|
{
|
||||||
|
"default": 0,
|
||||||
|
"min": 0,
|
||||||
|
"max": 2147483647,
|
||||||
|
"step": 1,
|
||||||
|
"control_after_generate": True,
|
||||||
|
"display": "number",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"style_type": (
|
||||||
|
IO.COMBO,
|
||||||
|
{
|
||||||
|
"options": ["AUTO", "GENERAL", "REALISTIC", "DESIGN", "RENDER_3D", "ANIME"],
|
||||||
|
"default": "NONE",
|
||||||
|
"tooltip": "Style type for generation (V2 only)",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"negative_prompt": (
|
||||||
|
IO.STRING,
|
||||||
|
{
|
||||||
|
"multiline": True,
|
||||||
|
"default": "",
|
||||||
|
"tooltip": "Description of what to exclude from the image",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"num_images": (
|
||||||
|
IO.INT,
|
||||||
|
{"default": 1, "min": 1, "max": 8, "step": 1, "display": "number"},
|
||||||
|
),
|
||||||
|
#"color_palette": (
|
||||||
|
# IO.STRING,
|
||||||
|
# {
|
||||||
|
# "multiline": False,
|
||||||
|
# "default": "",
|
||||||
|
# "tooltip": "Color palette preset name or hex colors with weights",
|
||||||
|
# },
|
||||||
|
#),
|
||||||
|
},
|
||||||
|
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = (IO.IMAGE,)
|
||||||
|
FUNCTION = "api_call"
|
||||||
|
CATEGORY = "api node/image/Ideogram/v2"
|
||||||
|
DESCRIPTION = cleandoc(__doc__ or "")
|
||||||
|
API_NODE = True
|
||||||
|
|
||||||
|
def api_call(
|
||||||
|
self,
|
||||||
|
prompt,
|
||||||
|
turbo=False,
|
||||||
|
aspect_ratio="1:1",
|
||||||
|
resolution="Auto",
|
||||||
|
magic_prompt_option="AUTO",
|
||||||
|
seed=0,
|
||||||
|
style_type="NONE",
|
||||||
|
negative_prompt="",
|
||||||
|
num_images=1,
|
||||||
|
color_palette="",
|
||||||
|
auth_token=None,
|
||||||
|
):
|
||||||
|
aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None)
|
||||||
|
resolution = V1_V1_RES_MAP.get(resolution, None)
|
||||||
|
# Determine the model based on turbo setting
|
||||||
|
model = "V_2_TURBO" if turbo else "V_2"
|
||||||
|
|
||||||
|
# Handle resolution vs aspect_ratio logic
|
||||||
|
# If resolution is not AUTO, it overrides aspect_ratio
|
||||||
|
final_resolution = None
|
||||||
|
final_aspect_ratio = None
|
||||||
|
|
||||||
|
if resolution != "AUTO":
|
||||||
|
final_resolution = resolution
|
||||||
|
else:
|
||||||
|
final_aspect_ratio = aspect_ratio if aspect_ratio != "ASPECT_1_1" else None
|
||||||
|
|
||||||
|
operation = SynchronousOperation(
|
||||||
|
endpoint=ApiEndpoint(
|
||||||
|
path="/proxy/ideogram/generate",
|
||||||
|
method=HttpMethod.POST,
|
||||||
|
request_model=IdeogramGenerateRequest,
|
||||||
|
response_model=IdeogramGenerateResponse,
|
||||||
|
),
|
||||||
|
request=IdeogramGenerateRequest(
|
||||||
|
image_request=ImageRequest(
|
||||||
|
prompt=prompt,
|
||||||
|
model=model,
|
||||||
|
num_images=num_images,
|
||||||
|
seed=seed,
|
||||||
|
aspect_ratio=final_aspect_ratio,
|
||||||
|
resolution=final_resolution,
|
||||||
|
magic_prompt_option=(
|
||||||
|
magic_prompt_option if magic_prompt_option != "AUTO" else None
|
||||||
|
),
|
||||||
|
style_type=style_type if style_type != "NONE" else None,
|
||||||
|
negative_prompt=negative_prompt if negative_prompt else None,
|
||||||
|
color_palette=color_palette if color_palette else None,
|
||||||
|
)
|
||||||
|
),
|
||||||
|
auth_token=auth_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = operation.execute()
|
||||||
|
|
||||||
|
if not response.data or len(response.data) == 0:
|
||||||
|
raise Exception("No images were generated in the response")
|
||||||
|
|
||||||
|
image_urls = [image_data.url for image_data in response.data if image_data.url]
|
||||||
|
|
||||||
|
if not image_urls:
|
||||||
|
raise Exception("No image URLs were generated in the response")
|
||||||
|
|
||||||
|
return (download_and_process_images(image_urls),)
|
||||||
|
|
||||||
|
class IdeogramV3(ComfyNodeABC):
|
||||||
|
"""
|
||||||
|
Generates images synchronously using the Ideogram V3 model.
|
||||||
|
|
||||||
|
Supports both regular image generation from text prompts and image editing with mask.
|
||||||
|
Images links are available for a limited period of time; if you would like to keep the image, you must download it.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"prompt": (
|
||||||
|
IO.STRING,
|
||||||
|
{
|
||||||
|
"multiline": True,
|
||||||
|
"default": "",
|
||||||
|
"tooltip": "Prompt for the image generation or editing",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"image": (
|
||||||
|
IO.IMAGE,
|
||||||
|
{
|
||||||
|
"default": None,
|
||||||
|
"tooltip": "Optional reference image for image editing.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"mask": (
|
||||||
|
IO.MASK,
|
||||||
|
{
|
||||||
|
"default": None,
|
||||||
|
"tooltip": "Optional mask for inpainting (white areas will be replaced)",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"aspect_ratio": (
|
||||||
|
IO.COMBO,
|
||||||
|
{
|
||||||
|
"options": list(V3_RATIO_MAP.keys()),
|
||||||
|
"default": "1:1",
|
||||||
|
"tooltip": "The aspect ratio for image generation. Ignored if resolution is not set to Auto.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"resolution": (
|
||||||
|
IO.COMBO,
|
||||||
|
{
|
||||||
|
"options": V3_RESOLUTIONS,
|
||||||
|
"default": "Auto",
|
||||||
|
"tooltip": "The resolution for image generation. If not set to Auto, this overrides the aspect_ratio setting.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"magic_prompt_option": (
|
||||||
|
IO.COMBO,
|
||||||
|
{
|
||||||
|
"options": ["AUTO", "ON", "OFF"],
|
||||||
|
"default": "AUTO",
|
||||||
|
"tooltip": "Determine if MagicPrompt should be used in generation",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"seed": (
|
||||||
|
IO.INT,
|
||||||
|
{
|
||||||
|
"default": 0,
|
||||||
|
"min": 0,
|
||||||
|
"max": 2147483647,
|
||||||
|
"step": 1,
|
||||||
|
"control_after_generate": True,
|
||||||
|
"display": "number",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"num_images": (
|
||||||
|
IO.INT,
|
||||||
|
{"default": 1, "min": 1, "max": 8, "step": 1, "display": "number"},
|
||||||
|
),
|
||||||
|
"rendering_speed": (
|
||||||
|
IO.COMBO,
|
||||||
|
{
|
||||||
|
"options": ["BALANCED", "TURBO", "QUALITY"],
|
||||||
|
"default": "BALANCED",
|
||||||
|
"tooltip": "Controls the trade-off between generation speed and quality",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = (IO.IMAGE,)
|
||||||
|
FUNCTION = "api_call"
|
||||||
|
CATEGORY = "api node/image/Ideogram/v3"
|
||||||
|
DESCRIPTION = cleandoc(__doc__ or "")
|
||||||
|
API_NODE = True
|
||||||
|
|
||||||
|
def api_call(
|
||||||
|
self,
|
||||||
|
prompt,
|
||||||
|
image=None,
|
||||||
|
mask=None,
|
||||||
|
resolution="Auto",
|
||||||
|
aspect_ratio="1:1",
|
||||||
|
magic_prompt_option="AUTO",
|
||||||
|
seed=0,
|
||||||
|
num_images=1,
|
||||||
|
rendering_speed="BALANCED",
|
||||||
|
auth_token=None,
|
||||||
|
):
|
||||||
|
# Check if both image and mask are provided for editing mode
|
||||||
|
if image is not None and mask is not None:
|
||||||
|
# Edit mode
|
||||||
|
path = "/proxy/ideogram/ideogram-v3/edit"
|
||||||
|
|
||||||
|
# Process image and mask
|
||||||
|
input_tensor = image.squeeze().cpu()
|
||||||
|
# Resize mask to match image dimension
|
||||||
|
mask = resize_mask_to_image(mask, image, allow_gradient=False)
|
||||||
|
# Invert mask, as Ideogram API will edit black areas instead of white areas (opposite of convention).
|
||||||
|
mask = 1.0 - mask
|
||||||
|
|
||||||
|
# Validate mask dimensions match image
|
||||||
|
if mask.shape[1:] != image.shape[1:-1]:
|
||||||
|
raise Exception("Mask and Image must be the same size")
|
||||||
|
|
||||||
|
# Process image
|
||||||
|
img_np = (input_tensor.numpy() * 255).astype(np.uint8)
|
||||||
|
img = Image.fromarray(img_np)
|
||||||
|
img_byte_arr = io.BytesIO()
|
||||||
|
img.save(img_byte_arr, format="PNG")
|
||||||
|
img_byte_arr.seek(0)
|
||||||
|
img_binary = img_byte_arr
|
||||||
|
img_binary.name = "image.png"
|
||||||
|
|
||||||
|
# Process mask - white areas will be replaced
|
||||||
|
mask_np = (mask.squeeze().cpu().numpy() * 255).astype(np.uint8)
|
||||||
|
mask_img = Image.fromarray(mask_np)
|
||||||
|
mask_byte_arr = io.BytesIO()
|
||||||
|
mask_img.save(mask_byte_arr, format="PNG")
|
||||||
|
mask_byte_arr.seek(0)
|
||||||
|
mask_binary = mask_byte_arr
|
||||||
|
mask_binary.name = "mask.png"
|
||||||
|
|
||||||
|
# Create edit request
|
||||||
|
edit_request = IdeogramV3EditRequest(
|
||||||
|
prompt=prompt,
|
||||||
|
rendering_speed=rendering_speed,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add optional parameters
|
||||||
|
if magic_prompt_option != "AUTO":
|
||||||
|
edit_request.magic_prompt = magic_prompt_option
|
||||||
|
if seed != 0:
|
||||||
|
edit_request.seed = seed
|
||||||
|
if num_images > 1:
|
||||||
|
edit_request.num_images = num_images
|
||||||
|
|
||||||
|
# Execute the operation for edit mode
|
||||||
|
operation = SynchronousOperation(
|
||||||
|
endpoint=ApiEndpoint(
|
||||||
|
path=path,
|
||||||
|
method=HttpMethod.POST,
|
||||||
|
request_model=IdeogramV3EditRequest,
|
||||||
|
response_model=IdeogramGenerateResponse,
|
||||||
|
),
|
||||||
|
request=edit_request,
|
||||||
|
files={
|
||||||
|
"image": img_binary,
|
||||||
|
"mask": mask_binary,
|
||||||
|
},
|
||||||
|
content_type="multipart/form-data",
|
||||||
|
auth_token=auth_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
elif image is not None or mask is not None:
|
||||||
|
# If only one of image or mask is provided, raise an error
|
||||||
|
raise Exception("Ideogram V3 image editing requires both an image AND a mask")
|
||||||
|
else:
|
||||||
|
# Generation mode
|
||||||
|
path = "/proxy/ideogram/ideogram-v3/generate"
|
||||||
|
|
||||||
|
# Create generation request
|
||||||
|
gen_request = IdeogramV3Request(
|
||||||
|
prompt=prompt,
|
||||||
|
rendering_speed=rendering_speed,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle resolution vs aspect ratio
|
||||||
|
if resolution != "Auto":
|
||||||
|
gen_request.resolution = resolution
|
||||||
|
elif aspect_ratio != "1:1":
|
||||||
|
v3_aspect = V3_RATIO_MAP.get(aspect_ratio)
|
||||||
|
if v3_aspect:
|
||||||
|
gen_request.aspect_ratio = v3_aspect
|
||||||
|
|
||||||
|
# Add optional parameters
|
||||||
|
if magic_prompt_option != "AUTO":
|
||||||
|
gen_request.magic_prompt = magic_prompt_option
|
||||||
|
if seed != 0:
|
||||||
|
gen_request.seed = seed
|
||||||
|
if num_images > 1:
|
||||||
|
gen_request.num_images = num_images
|
||||||
|
|
||||||
|
# Execute the operation for generation mode
|
||||||
|
operation = SynchronousOperation(
|
||||||
|
endpoint=ApiEndpoint(
|
||||||
|
path=path,
|
||||||
|
method=HttpMethod.POST,
|
||||||
|
request_model=IdeogramV3Request,
|
||||||
|
response_model=IdeogramGenerateResponse,
|
||||||
|
),
|
||||||
|
request=gen_request,
|
||||||
|
auth_token=auth_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Execute the operation and process response
|
||||||
|
response = operation.execute()
|
||||||
|
|
||||||
|
if not response.data or len(response.data) == 0:
|
||||||
|
raise Exception("No images were generated in the response")
|
||||||
|
|
||||||
|
image_urls = [image_data.url for image_data in response.data if image_data.url]
|
||||||
|
|
||||||
|
if not image_urls:
|
||||||
|
raise Exception("No image URLs were generated in the response")
|
||||||
|
|
||||||
|
return (download_and_process_images(image_urls),)
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"IdeogramV1": IdeogramV1,
|
||||||
|
"IdeogramV2": IdeogramV2,
|
||||||
|
"IdeogramV3": IdeogramV3,
|
||||||
|
}
|
||||||
|
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
|
"IdeogramV1": "Ideogram V1",
|
||||||
|
"IdeogramV2": "Ideogram V2",
|
||||||
|
"IdeogramV3": "Ideogram V3",
|
||||||
|
}
|
||||||
|
|
||||||
1563
comfy_api_nodes/nodes_kling.py
Normal file
1563
comfy_api_nodes/nodes_kling.py
Normal file
File diff suppressed because it is too large
Load Diff
702
comfy_api_nodes/nodes_luma.py
Normal file
702
comfy_api_nodes/nodes_luma.py
Normal file
@@ -0,0 +1,702 @@
|
|||||||
|
from inspect import cleandoc
|
||||||
|
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
|
||||||
|
from comfy_api.input_impl.video_types import VideoFromFile
|
||||||
|
from comfy_api_nodes.apis.luma_api import (
|
||||||
|
LumaImageModel,
|
||||||
|
LumaVideoModel,
|
||||||
|
LumaVideoOutputResolution,
|
||||||
|
LumaVideoModelOutputDuration,
|
||||||
|
LumaAspectRatio,
|
||||||
|
LumaState,
|
||||||
|
LumaImageGenerationRequest,
|
||||||
|
LumaGenerationRequest,
|
||||||
|
LumaGeneration,
|
||||||
|
LumaCharacterRef,
|
||||||
|
LumaModifyImageRef,
|
||||||
|
LumaImageIdentity,
|
||||||
|
LumaReference,
|
||||||
|
LumaReferenceChain,
|
||||||
|
LumaImageReference,
|
||||||
|
LumaKeyframes,
|
||||||
|
LumaConceptChain,
|
||||||
|
LumaIO,
|
||||||
|
get_luma_concepts,
|
||||||
|
)
|
||||||
|
from comfy_api_nodes.apis.client import (
|
||||||
|
ApiEndpoint,
|
||||||
|
HttpMethod,
|
||||||
|
SynchronousOperation,
|
||||||
|
PollingOperation,
|
||||||
|
EmptyRequest,
|
||||||
|
)
|
||||||
|
from comfy_api_nodes.apinode_utils import (
|
||||||
|
upload_images_to_comfyapi,
|
||||||
|
process_image_response,
|
||||||
|
validate_string,
|
||||||
|
)
|
||||||
|
|
||||||
|
import requests
|
||||||
|
import torch
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
|
||||||
|
class LumaReferenceNode(ComfyNodeABC):
|
||||||
|
"""
|
||||||
|
Holds an image and weight for use with Luma Generate Image node.
|
||||||
|
"""
|
||||||
|
|
||||||
|
RETURN_TYPES = (LumaIO.LUMA_REF,)
|
||||||
|
RETURN_NAMES = ("luma_ref",)
|
||||||
|
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||||
|
FUNCTION = "create_luma_reference"
|
||||||
|
CATEGORY = "api node/image/Luma"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image": (
|
||||||
|
IO.IMAGE,
|
||||||
|
{
|
||||||
|
"tooltip": "Image to use as reference.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"weight": (
|
||||||
|
IO.FLOAT,
|
||||||
|
{
|
||||||
|
"default": 1.0,
|
||||||
|
"min": 0.0,
|
||||||
|
"max": 1.0,
|
||||||
|
"step": 0.01,
|
||||||
|
"tooltip": "Weight of image reference.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"optional": {"luma_ref": (LumaIO.LUMA_REF,)},
|
||||||
|
}
|
||||||
|
|
||||||
|
def create_luma_reference(
|
||||||
|
self, image: torch.Tensor, weight: float, luma_ref: LumaReferenceChain = None
|
||||||
|
):
|
||||||
|
if luma_ref is not None:
|
||||||
|
luma_ref = luma_ref.clone()
|
||||||
|
else:
|
||||||
|
luma_ref = LumaReferenceChain()
|
||||||
|
luma_ref.add(LumaReference(image=image, weight=round(weight, 2)))
|
||||||
|
return (luma_ref,)
|
||||||
|
|
||||||
|
|
||||||
|
class LumaConceptsNode(ComfyNodeABC):
|
||||||
|
"""
|
||||||
|
Holds one or more Camera Concepts for use with Luma Text to Video and Luma Image to Video nodes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
RETURN_TYPES = (LumaIO.LUMA_CONCEPTS,)
|
||||||
|
RETURN_NAMES = ("luma_concepts",)
|
||||||
|
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||||
|
FUNCTION = "create_concepts"
|
||||||
|
CATEGORY = "api node/video/Luma"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"concept1": (get_luma_concepts(include_none=True),),
|
||||||
|
"concept2": (get_luma_concepts(include_none=True),),
|
||||||
|
"concept3": (get_luma_concepts(include_none=True),),
|
||||||
|
"concept4": (get_luma_concepts(include_none=True),),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"luma_concepts": (
|
||||||
|
LumaIO.LUMA_CONCEPTS,
|
||||||
|
{
|
||||||
|
"tooltip": "Optional Camera Concepts to add to the ones chosen here."
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def create_concepts(
|
||||||
|
self,
|
||||||
|
concept1: str,
|
||||||
|
concept2: str,
|
||||||
|
concept3: str,
|
||||||
|
concept4: str,
|
||||||
|
luma_concepts: LumaConceptChain = None,
|
||||||
|
):
|
||||||
|
chain = LumaConceptChain(str_list=[concept1, concept2, concept3, concept4])
|
||||||
|
if luma_concepts is not None:
|
||||||
|
chain = luma_concepts.clone_and_merge(chain)
|
||||||
|
return (chain,)
|
||||||
|
|
||||||
|
|
||||||
|
class LumaImageGenerationNode(ComfyNodeABC):
|
||||||
|
"""
|
||||||
|
Generates images synchronously based on prompt and aspect ratio.
|
||||||
|
"""
|
||||||
|
|
||||||
|
RETURN_TYPES = (IO.IMAGE,)
|
||||||
|
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||||
|
FUNCTION = "api_call"
|
||||||
|
API_NODE = True
|
||||||
|
CATEGORY = "api node/image/Luma"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"prompt": (
|
||||||
|
IO.STRING,
|
||||||
|
{
|
||||||
|
"multiline": True,
|
||||||
|
"default": "",
|
||||||
|
"tooltip": "Prompt for the image generation",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"model": ([model.value for model in LumaImageModel],),
|
||||||
|
"aspect_ratio": (
|
||||||
|
[ratio.value for ratio in LumaAspectRatio],
|
||||||
|
{
|
||||||
|
"default": LumaAspectRatio.ratio_16_9,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"seed": (
|
||||||
|
IO.INT,
|
||||||
|
{
|
||||||
|
"default": 0,
|
||||||
|
"min": 0,
|
||||||
|
"max": 0xFFFFFFFFFFFFFFFF,
|
||||||
|
"control_after_generate": True,
|
||||||
|
"tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"style_image_weight": (
|
||||||
|
IO.FLOAT,
|
||||||
|
{
|
||||||
|
"default": 1.0,
|
||||||
|
"min": 0.0,
|
||||||
|
"max": 1.0,
|
||||||
|
"step": 0.01,
|
||||||
|
"tooltip": "Weight of style image. Ignored if no style_image provided.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"image_luma_ref": (
|
||||||
|
LumaIO.LUMA_REF,
|
||||||
|
{
|
||||||
|
"tooltip": "Luma Reference node connection to influence generation with input images; up to 4 images can be considered."
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"style_image": (
|
||||||
|
IO.IMAGE,
|
||||||
|
{"tooltip": "Style reference image; only 1 image will be used."},
|
||||||
|
),
|
||||||
|
"character_image": (
|
||||||
|
IO.IMAGE,
|
||||||
|
{
|
||||||
|
"tooltip": "Character reference images; can be a batch of multiple, up to 4 images can be considered."
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def api_call(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
model: str,
|
||||||
|
aspect_ratio: str,
|
||||||
|
seed,
|
||||||
|
style_image_weight: float,
|
||||||
|
image_luma_ref: LumaReferenceChain = None,
|
||||||
|
style_image: torch.Tensor = None,
|
||||||
|
character_image: torch.Tensor = None,
|
||||||
|
auth_token=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
validate_string(prompt, strip_whitespace=True, min_length=3)
|
||||||
|
# handle image_luma_ref
|
||||||
|
api_image_ref = None
|
||||||
|
if image_luma_ref is not None:
|
||||||
|
api_image_ref = self._convert_luma_refs(
|
||||||
|
image_luma_ref, max_refs=4, auth_token=auth_token
|
||||||
|
)
|
||||||
|
# handle style_luma_ref
|
||||||
|
api_style_ref = None
|
||||||
|
if style_image is not None:
|
||||||
|
api_style_ref = self._convert_style_image(
|
||||||
|
style_image, weight=style_image_weight, auth_token=auth_token
|
||||||
|
)
|
||||||
|
# handle character_ref images
|
||||||
|
character_ref = None
|
||||||
|
if character_image is not None:
|
||||||
|
download_urls = upload_images_to_comfyapi(
|
||||||
|
character_image, max_images=4, auth_token=auth_token
|
||||||
|
)
|
||||||
|
character_ref = LumaCharacterRef(
|
||||||
|
identity0=LumaImageIdentity(images=download_urls)
|
||||||
|
)
|
||||||
|
|
||||||
|
operation = SynchronousOperation(
|
||||||
|
endpoint=ApiEndpoint(
|
||||||
|
path="/proxy/luma/generations/image",
|
||||||
|
method=HttpMethod.POST,
|
||||||
|
request_model=LumaImageGenerationRequest,
|
||||||
|
response_model=LumaGeneration,
|
||||||
|
),
|
||||||
|
request=LumaImageGenerationRequest(
|
||||||
|
prompt=prompt,
|
||||||
|
model=model,
|
||||||
|
aspect_ratio=aspect_ratio,
|
||||||
|
image_ref=api_image_ref,
|
||||||
|
style_ref=api_style_ref,
|
||||||
|
character_ref=character_ref,
|
||||||
|
),
|
||||||
|
auth_token=auth_token,
|
||||||
|
)
|
||||||
|
response_api: LumaGeneration = operation.execute()
|
||||||
|
|
||||||
|
operation = PollingOperation(
|
||||||
|
poll_endpoint=ApiEndpoint(
|
||||||
|
path=f"/proxy/luma/generations/{response_api.id}",
|
||||||
|
method=HttpMethod.GET,
|
||||||
|
request_model=EmptyRequest,
|
||||||
|
response_model=LumaGeneration,
|
||||||
|
),
|
||||||
|
completed_statuses=[LumaState.completed],
|
||||||
|
failed_statuses=[LumaState.failed],
|
||||||
|
status_extractor=lambda x: x.state,
|
||||||
|
auth_token=auth_token,
|
||||||
|
)
|
||||||
|
response_poll = operation.execute()
|
||||||
|
|
||||||
|
img_response = requests.get(response_poll.assets.image)
|
||||||
|
img = process_image_response(img_response)
|
||||||
|
return (img,)
|
||||||
|
|
||||||
|
def _convert_luma_refs(
|
||||||
|
self, luma_ref: LumaReferenceChain, max_refs: int, auth_token=None
|
||||||
|
):
|
||||||
|
luma_urls = []
|
||||||
|
ref_count = 0
|
||||||
|
for ref in luma_ref.refs:
|
||||||
|
download_urls = upload_images_to_comfyapi(
|
||||||
|
ref.image, max_images=1, auth_token=auth_token
|
||||||
|
)
|
||||||
|
luma_urls.append(download_urls[0])
|
||||||
|
ref_count += 1
|
||||||
|
if ref_count >= max_refs:
|
||||||
|
break
|
||||||
|
return luma_ref.create_api_model(download_urls=luma_urls, max_refs=max_refs)
|
||||||
|
|
||||||
|
def _convert_style_image(
|
||||||
|
self, style_image: torch.Tensor, weight: float, auth_token=None
|
||||||
|
):
|
||||||
|
chain = LumaReferenceChain(
|
||||||
|
first_ref=LumaReference(image=style_image, weight=weight)
|
||||||
|
)
|
||||||
|
return self._convert_luma_refs(chain, max_refs=1, auth_token=auth_token)
|
||||||
|
|
||||||
|
|
||||||
|
class LumaImageModifyNode(ComfyNodeABC):
|
||||||
|
"""
|
||||||
|
Modifies images synchronously based on prompt and aspect ratio.
|
||||||
|
"""
|
||||||
|
|
||||||
|
RETURN_TYPES = (IO.IMAGE,)
|
||||||
|
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||||
|
FUNCTION = "api_call"
|
||||||
|
API_NODE = True
|
||||||
|
CATEGORY = "api node/image/Luma"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image": (IO.IMAGE,),
|
||||||
|
"prompt": (
|
||||||
|
IO.STRING,
|
||||||
|
{
|
||||||
|
"multiline": True,
|
||||||
|
"default": "",
|
||||||
|
"tooltip": "Prompt for the image generation",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"image_weight": (
|
||||||
|
IO.FLOAT,
|
||||||
|
{
|
||||||
|
"default": 0.1,
|
||||||
|
"min": 0.0,
|
||||||
|
"max": 0.98,
|
||||||
|
"step": 0.01,
|
||||||
|
"tooltip": "Weight of the image; the closer to 1.0, the less the image will be modified.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"model": ([model.value for model in LumaImageModel],),
|
||||||
|
"seed": (
|
||||||
|
IO.INT,
|
||||||
|
{
|
||||||
|
"default": 0,
|
||||||
|
"min": 0,
|
||||||
|
"max": 0xFFFFFFFFFFFFFFFF,
|
||||||
|
"control_after_generate": True,
|
||||||
|
"tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"optional": {},
|
||||||
|
"hidden": {
|
||||||
|
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def api_call(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
model: str,
|
||||||
|
image: torch.Tensor,
|
||||||
|
image_weight: float,
|
||||||
|
seed,
|
||||||
|
auth_token=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
# first, upload image
|
||||||
|
download_urls = upload_images_to_comfyapi(
|
||||||
|
image, max_images=1, auth_token=auth_token
|
||||||
|
)
|
||||||
|
image_url = download_urls[0]
|
||||||
|
# next, make Luma call with download url provided
|
||||||
|
operation = SynchronousOperation(
|
||||||
|
endpoint=ApiEndpoint(
|
||||||
|
path="/proxy/luma/generations/image",
|
||||||
|
method=HttpMethod.POST,
|
||||||
|
request_model=LumaImageGenerationRequest,
|
||||||
|
response_model=LumaGeneration,
|
||||||
|
),
|
||||||
|
request=LumaImageGenerationRequest(
|
||||||
|
prompt=prompt,
|
||||||
|
model=model,
|
||||||
|
modify_image_ref=LumaModifyImageRef(
|
||||||
|
url=image_url, weight=round(max(min(1.0-image_weight, 0.98), 0.0), 2)
|
||||||
|
),
|
||||||
|
),
|
||||||
|
auth_token=auth_token,
|
||||||
|
)
|
||||||
|
response_api: LumaGeneration = operation.execute()
|
||||||
|
|
||||||
|
operation = PollingOperation(
|
||||||
|
poll_endpoint=ApiEndpoint(
|
||||||
|
path=f"/proxy/luma/generations/{response_api.id}",
|
||||||
|
method=HttpMethod.GET,
|
||||||
|
request_model=EmptyRequest,
|
||||||
|
response_model=LumaGeneration,
|
||||||
|
),
|
||||||
|
completed_statuses=[LumaState.completed],
|
||||||
|
failed_statuses=[LumaState.failed],
|
||||||
|
status_extractor=lambda x: x.state,
|
||||||
|
auth_token=auth_token,
|
||||||
|
)
|
||||||
|
response_poll = operation.execute()
|
||||||
|
|
||||||
|
img_response = requests.get(response_poll.assets.image)
|
||||||
|
img = process_image_response(img_response)
|
||||||
|
return (img,)
|
||||||
|
|
||||||
|
|
||||||
|
class LumaTextToVideoGenerationNode(ComfyNodeABC):
|
||||||
|
"""
|
||||||
|
Generates videos synchronously based on prompt and output_size.
|
||||||
|
"""
|
||||||
|
|
||||||
|
RETURN_TYPES = (IO.VIDEO,)
|
||||||
|
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||||
|
FUNCTION = "api_call"
|
||||||
|
API_NODE = True
|
||||||
|
CATEGORY = "api node/video/Luma"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"prompt": (
|
||||||
|
IO.STRING,
|
||||||
|
{
|
||||||
|
"multiline": True,
|
||||||
|
"default": "",
|
||||||
|
"tooltip": "Prompt for the video generation",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"model": ([model.value for model in LumaVideoModel],),
|
||||||
|
"aspect_ratio": (
|
||||||
|
[ratio.value for ratio in LumaAspectRatio],
|
||||||
|
{
|
||||||
|
"default": LumaAspectRatio.ratio_16_9,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"resolution": (
|
||||||
|
[resolution.value for resolution in LumaVideoOutputResolution],
|
||||||
|
{
|
||||||
|
"default": LumaVideoOutputResolution.res_540p,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"duration": ([dur.value for dur in LumaVideoModelOutputDuration],),
|
||||||
|
"loop": (
|
||||||
|
IO.BOOLEAN,
|
||||||
|
{
|
||||||
|
"default": False,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"seed": (
|
||||||
|
IO.INT,
|
||||||
|
{
|
||||||
|
"default": 0,
|
||||||
|
"min": 0,
|
||||||
|
"max": 0xFFFFFFFFFFFFFFFF,
|
||||||
|
"control_after_generate": True,
|
||||||
|
"tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"luma_concepts": (
|
||||||
|
LumaIO.LUMA_CONCEPTS,
|
||||||
|
{
|
||||||
|
"tooltip": "Optional Camera Concepts to dictate camera motion via the Luma Concepts node."
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def api_call(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
model: str,
|
||||||
|
aspect_ratio: str,
|
||||||
|
resolution: str,
|
||||||
|
duration: str,
|
||||||
|
loop: bool,
|
||||||
|
seed,
|
||||||
|
luma_concepts: LumaConceptChain = None,
|
||||||
|
auth_token=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
validate_string(prompt, strip_whitespace=False, min_length=3)
|
||||||
|
duration = duration if model != LumaVideoModel.ray_1_6 else None
|
||||||
|
resolution = resolution if model != LumaVideoModel.ray_1_6 else None
|
||||||
|
|
||||||
|
operation = SynchronousOperation(
|
||||||
|
endpoint=ApiEndpoint(
|
||||||
|
path="/proxy/luma/generations",
|
||||||
|
method=HttpMethod.POST,
|
||||||
|
request_model=LumaGenerationRequest,
|
||||||
|
response_model=LumaGeneration,
|
||||||
|
),
|
||||||
|
request=LumaGenerationRequest(
|
||||||
|
prompt=prompt,
|
||||||
|
model=model,
|
||||||
|
resolution=resolution,
|
||||||
|
aspect_ratio=aspect_ratio,
|
||||||
|
duration=duration,
|
||||||
|
loop=loop,
|
||||||
|
concepts=luma_concepts.create_api_model() if luma_concepts else None,
|
||||||
|
),
|
||||||
|
auth_token=auth_token,
|
||||||
|
)
|
||||||
|
response_api: LumaGeneration = operation.execute()
|
||||||
|
|
||||||
|
operation = PollingOperation(
|
||||||
|
poll_endpoint=ApiEndpoint(
|
||||||
|
path=f"/proxy/luma/generations/{response_api.id}",
|
||||||
|
method=HttpMethod.GET,
|
||||||
|
request_model=EmptyRequest,
|
||||||
|
response_model=LumaGeneration,
|
||||||
|
),
|
||||||
|
completed_statuses=[LumaState.completed],
|
||||||
|
failed_statuses=[LumaState.failed],
|
||||||
|
status_extractor=lambda x: x.state,
|
||||||
|
auth_token=auth_token,
|
||||||
|
)
|
||||||
|
response_poll = operation.execute()
|
||||||
|
|
||||||
|
vid_response = requests.get(response_poll.assets.video)
|
||||||
|
return (VideoFromFile(BytesIO(vid_response.content)),)
|
||||||
|
|
||||||
|
|
||||||
|
class LumaImageToVideoGenerationNode(ComfyNodeABC):
|
||||||
|
"""
|
||||||
|
Generates videos synchronously based on prompt, input images, and output_size.
|
||||||
|
"""
|
||||||
|
|
||||||
|
RETURN_TYPES = (IO.VIDEO,)
|
||||||
|
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||||
|
FUNCTION = "api_call"
|
||||||
|
API_NODE = True
|
||||||
|
CATEGORY = "api node/video/Luma"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"prompt": (
|
||||||
|
IO.STRING,
|
||||||
|
{
|
||||||
|
"multiline": True,
|
||||||
|
"default": "",
|
||||||
|
"tooltip": "Prompt for the video generation",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"model": ([model.value for model in LumaVideoModel],),
|
||||||
|
# "aspect_ratio": ([ratio.value for ratio in LumaAspectRatio], {
|
||||||
|
# "default": LumaAspectRatio.ratio_16_9,
|
||||||
|
# }),
|
||||||
|
"resolution": (
|
||||||
|
[resolution.value for resolution in LumaVideoOutputResolution],
|
||||||
|
{
|
||||||
|
"default": LumaVideoOutputResolution.res_540p,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"duration": ([dur.value for dur in LumaVideoModelOutputDuration],),
|
||||||
|
"loop": (
|
||||||
|
IO.BOOLEAN,
|
||||||
|
{
|
||||||
|
"default": False,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"seed": (
|
||||||
|
IO.INT,
|
||||||
|
{
|
||||||
|
"default": 0,
|
||||||
|
"min": 0,
|
||||||
|
"max": 0xFFFFFFFFFFFFFFFF,
|
||||||
|
"control_after_generate": True,
|
||||||
|
"tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"first_image": (
|
||||||
|
IO.IMAGE,
|
||||||
|
{"tooltip": "First frame of generated video."},
|
||||||
|
),
|
||||||
|
"last_image": (IO.IMAGE, {"tooltip": "Last frame of generated video."}),
|
||||||
|
"luma_concepts": (
|
||||||
|
LumaIO.LUMA_CONCEPTS,
|
||||||
|
{
|
||||||
|
"tooltip": "Optional Camera Concepts to dictate camera motion via the Luma Concepts node."
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def api_call(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
model: str,
|
||||||
|
resolution: str,
|
||||||
|
duration: str,
|
||||||
|
loop: bool,
|
||||||
|
seed,
|
||||||
|
first_image: torch.Tensor = None,
|
||||||
|
last_image: torch.Tensor = None,
|
||||||
|
luma_concepts: LumaConceptChain = None,
|
||||||
|
auth_token=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
if first_image is None and last_image is None:
|
||||||
|
raise Exception(
|
||||||
|
"At least one of first_image and last_image requires an input."
|
||||||
|
)
|
||||||
|
keyframes = self._convert_to_keyframes(first_image, last_image, auth_token)
|
||||||
|
duration = duration if model != LumaVideoModel.ray_1_6 else None
|
||||||
|
resolution = resolution if model != LumaVideoModel.ray_1_6 else None
|
||||||
|
|
||||||
|
operation = SynchronousOperation(
|
||||||
|
endpoint=ApiEndpoint(
|
||||||
|
path="/proxy/luma/generations",
|
||||||
|
method=HttpMethod.POST,
|
||||||
|
request_model=LumaGenerationRequest,
|
||||||
|
response_model=LumaGeneration,
|
||||||
|
),
|
||||||
|
request=LumaGenerationRequest(
|
||||||
|
prompt=prompt,
|
||||||
|
model=model,
|
||||||
|
aspect_ratio=LumaAspectRatio.ratio_16_9, # ignored, but still needed by the API for some reason
|
||||||
|
resolution=resolution,
|
||||||
|
duration=duration,
|
||||||
|
loop=loop,
|
||||||
|
keyframes=keyframes,
|
||||||
|
concepts=luma_concepts.create_api_model() if luma_concepts else None,
|
||||||
|
),
|
||||||
|
auth_token=auth_token,
|
||||||
|
)
|
||||||
|
response_api: LumaGeneration = operation.execute()
|
||||||
|
|
||||||
|
operation = PollingOperation(
|
||||||
|
poll_endpoint=ApiEndpoint(
|
||||||
|
path=f"/proxy/luma/generations/{response_api.id}",
|
||||||
|
method=HttpMethod.GET,
|
||||||
|
request_model=EmptyRequest,
|
||||||
|
response_model=LumaGeneration,
|
||||||
|
),
|
||||||
|
completed_statuses=[LumaState.completed],
|
||||||
|
failed_statuses=[LumaState.failed],
|
||||||
|
status_extractor=lambda x: x.state,
|
||||||
|
auth_token=auth_token,
|
||||||
|
)
|
||||||
|
response_poll = operation.execute()
|
||||||
|
|
||||||
|
vid_response = requests.get(response_poll.assets.video)
|
||||||
|
return (VideoFromFile(BytesIO(vid_response.content)),)
|
||||||
|
|
||||||
|
def _convert_to_keyframes(
|
||||||
|
self,
|
||||||
|
first_image: torch.Tensor = None,
|
||||||
|
last_image: torch.Tensor = None,
|
||||||
|
auth_token=None,
|
||||||
|
):
|
||||||
|
if first_image is None and last_image is None:
|
||||||
|
return None
|
||||||
|
frame0 = None
|
||||||
|
frame1 = None
|
||||||
|
if first_image is not None:
|
||||||
|
download_urls = upload_images_to_comfyapi(
|
||||||
|
first_image, max_images=1, auth_token=auth_token
|
||||||
|
)
|
||||||
|
frame0 = LumaImageReference(type="image", url=download_urls[0])
|
||||||
|
if last_image is not None:
|
||||||
|
download_urls = upload_images_to_comfyapi(
|
||||||
|
last_image, max_images=1, auth_token=auth_token
|
||||||
|
)
|
||||||
|
frame1 = LumaImageReference(type="image", url=download_urls[0])
|
||||||
|
return LumaKeyframes(frame0=frame0, frame1=frame1)
|
||||||
|
|
||||||
|
|
||||||
|
# A dictionary that contains all nodes you want to export with their names
|
||||||
|
# NOTE: names should be globally unique
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"LumaImageNode": LumaImageGenerationNode,
|
||||||
|
"LumaImageModifyNode": LumaImageModifyNode,
|
||||||
|
"LumaVideoNode": LumaTextToVideoGenerationNode,
|
||||||
|
"LumaImageToVideoNode": LumaImageToVideoGenerationNode,
|
||||||
|
"LumaReferenceNode": LumaReferenceNode,
|
||||||
|
"LumaConceptsNode": LumaConceptsNode,
|
||||||
|
}
|
||||||
|
|
||||||
|
# A dictionary that contains the friendly/humanly readable titles for the nodes
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
|
"LumaImageNode": "Luma Text to Image",
|
||||||
|
"LumaImageModifyNode": "Luma Image to Image",
|
||||||
|
"LumaVideoNode": "Luma Text to Video",
|
||||||
|
"LumaImageToVideoNode": "Luma Image to Video",
|
||||||
|
"LumaReferenceNode": "Luma Reference",
|
||||||
|
"LumaConceptsNode": "Luma Concepts",
|
||||||
|
}
|
||||||
306
comfy_api_nodes/nodes_minimax.py
Normal file
306
comfy_api_nodes/nodes_minimax.py
Normal file
@@ -0,0 +1,306 @@
|
|||||||
|
from comfy.comfy_types.node_typing import IO
|
||||||
|
from comfy_api.input_impl.video_types import VideoFromFile
|
||||||
|
from comfy_api_nodes.apis import (
|
||||||
|
MinimaxVideoGenerationRequest,
|
||||||
|
MinimaxVideoGenerationResponse,
|
||||||
|
MinimaxFileRetrieveResponse,
|
||||||
|
MinimaxTaskResultResponse,
|
||||||
|
SubjectReferenceItem,
|
||||||
|
Model
|
||||||
|
)
|
||||||
|
from comfy_api_nodes.apis.client import (
|
||||||
|
ApiEndpoint,
|
||||||
|
HttpMethod,
|
||||||
|
SynchronousOperation,
|
||||||
|
PollingOperation,
|
||||||
|
EmptyRequest,
|
||||||
|
)
|
||||||
|
from comfy_api_nodes.apinode_utils import (
|
||||||
|
download_url_to_bytesio,
|
||||||
|
upload_images_to_comfyapi,
|
||||||
|
validate_string,
|
||||||
|
)
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import logging
|
||||||
|
|
||||||
|
|
||||||
|
class MinimaxTextToVideoNode:
|
||||||
|
"""
|
||||||
|
Generates videos synchronously based on a prompt, and optional parameters using MiniMax's API.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"prompt_text": (
|
||||||
|
"STRING",
|
||||||
|
{
|
||||||
|
"multiline": True,
|
||||||
|
"default": "",
|
||||||
|
"tooltip": "Text prompt to guide the video generation",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"model": (
|
||||||
|
[
|
||||||
|
"T2V-01",
|
||||||
|
"T2V-01-Director",
|
||||||
|
],
|
||||||
|
{
|
||||||
|
"default": "T2V-01",
|
||||||
|
"tooltip": "Model to use for video generation",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"seed": (
|
||||||
|
IO.INT,
|
||||||
|
{
|
||||||
|
"default": 0,
|
||||||
|
"min": 0,
|
||||||
|
"max": 0xFFFFFFFFFFFFFFFF,
|
||||||
|
"control_after_generate": True,
|
||||||
|
"tooltip": "The random seed used for creating the noise.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("VIDEO",)
|
||||||
|
DESCRIPTION = "Generates videos from prompts using MiniMax's API"
|
||||||
|
FUNCTION = "generate_video"
|
||||||
|
CATEGORY = "api node/video/MiniMax"
|
||||||
|
API_NODE = True
|
||||||
|
OUTPUT_NODE = True
|
||||||
|
|
||||||
|
def generate_video(
|
||||||
|
self,
|
||||||
|
prompt_text,
|
||||||
|
seed=0,
|
||||||
|
model="T2V-01",
|
||||||
|
image: torch.Tensor=None, # used for ImageToVideo
|
||||||
|
subject: torch.Tensor=None, # used for SubjectToVideo
|
||||||
|
auth_token=None,
|
||||||
|
):
|
||||||
|
'''
|
||||||
|
Function used between MiniMax nodes - supports T2V, I2V, and S2V, based on provided arguments.
|
||||||
|
'''
|
||||||
|
if image is None:
|
||||||
|
validate_string(prompt_text, field_name="prompt_text")
|
||||||
|
# upload image, if passed in
|
||||||
|
image_url = None
|
||||||
|
if image is not None:
|
||||||
|
image_url = upload_images_to_comfyapi(image, max_images=1, auth_token=auth_token)[0]
|
||||||
|
|
||||||
|
# TODO: figure out how to deal with subject properly, API returns invalid params when using S2V-01 model
|
||||||
|
subject_reference = None
|
||||||
|
if subject is not None:
|
||||||
|
subject_url = upload_images_to_comfyapi(subject, max_images=1, auth_token=auth_token)[0]
|
||||||
|
subject_reference = [SubjectReferenceItem(image=subject_url)]
|
||||||
|
|
||||||
|
|
||||||
|
video_generate_operation = SynchronousOperation(
|
||||||
|
endpoint=ApiEndpoint(
|
||||||
|
path="/proxy/minimax/video_generation",
|
||||||
|
method=HttpMethod.POST,
|
||||||
|
request_model=MinimaxVideoGenerationRequest,
|
||||||
|
response_model=MinimaxVideoGenerationResponse,
|
||||||
|
),
|
||||||
|
request=MinimaxVideoGenerationRequest(
|
||||||
|
model=Model(model),
|
||||||
|
prompt=prompt_text,
|
||||||
|
callback_url=None,
|
||||||
|
first_frame_image=image_url,
|
||||||
|
subject_reference=subject_reference,
|
||||||
|
prompt_optimizer=None,
|
||||||
|
),
|
||||||
|
auth_token=auth_token,
|
||||||
|
)
|
||||||
|
response = video_generate_operation.execute()
|
||||||
|
|
||||||
|
task_id = response.task_id
|
||||||
|
if not task_id:
|
||||||
|
raise Exception(f"MiniMax generation failed: {response.base_resp}")
|
||||||
|
|
||||||
|
video_generate_operation = PollingOperation(
|
||||||
|
poll_endpoint=ApiEndpoint(
|
||||||
|
path="/proxy/minimax/query/video_generation",
|
||||||
|
method=HttpMethod.GET,
|
||||||
|
request_model=EmptyRequest,
|
||||||
|
response_model=MinimaxTaskResultResponse,
|
||||||
|
query_params={"task_id": task_id},
|
||||||
|
),
|
||||||
|
completed_statuses=["Success"],
|
||||||
|
failed_statuses=["Fail"],
|
||||||
|
status_extractor=lambda x: x.status.value,
|
||||||
|
auth_token=auth_token,
|
||||||
|
)
|
||||||
|
task_result = video_generate_operation.execute()
|
||||||
|
|
||||||
|
file_id = task_result.file_id
|
||||||
|
if file_id is None:
|
||||||
|
raise Exception("Request was not successful. Missing file ID.")
|
||||||
|
file_retrieve_operation = SynchronousOperation(
|
||||||
|
endpoint=ApiEndpoint(
|
||||||
|
path="/proxy/minimax/files/retrieve",
|
||||||
|
method=HttpMethod.GET,
|
||||||
|
request_model=EmptyRequest,
|
||||||
|
response_model=MinimaxFileRetrieveResponse,
|
||||||
|
query_params={"file_id": int(file_id)},
|
||||||
|
),
|
||||||
|
request=EmptyRequest(),
|
||||||
|
auth_token=auth_token,
|
||||||
|
)
|
||||||
|
file_result = file_retrieve_operation.execute()
|
||||||
|
|
||||||
|
file_url = file_result.file.download_url
|
||||||
|
if file_url is None:
|
||||||
|
raise Exception(
|
||||||
|
f"No video was found in the response. Full response: {file_result.model_dump()}"
|
||||||
|
)
|
||||||
|
logging.info(f"Generated video URL: {file_url}")
|
||||||
|
|
||||||
|
video_io = download_url_to_bytesio(file_url)
|
||||||
|
if video_io is None:
|
||||||
|
error_msg = f"Failed to download video from {file_url}"
|
||||||
|
logging.error(error_msg)
|
||||||
|
raise Exception(error_msg)
|
||||||
|
return (VideoFromFile(video_io),)
|
||||||
|
|
||||||
|
|
||||||
|
class MinimaxImageToVideoNode(MinimaxTextToVideoNode):
|
||||||
|
"""
|
||||||
|
Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image": (
|
||||||
|
IO.IMAGE,
|
||||||
|
{
|
||||||
|
"tooltip": "Image to use as first frame of video generation"
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"prompt_text": (
|
||||||
|
"STRING",
|
||||||
|
{
|
||||||
|
"multiline": True,
|
||||||
|
"default": "",
|
||||||
|
"tooltip": "Text prompt to guide the video generation",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"model": (
|
||||||
|
[
|
||||||
|
"I2V-01-Director",
|
||||||
|
"I2V-01",
|
||||||
|
"I2V-01-live",
|
||||||
|
],
|
||||||
|
{
|
||||||
|
"default": "I2V-01",
|
||||||
|
"tooltip": "Model to use for video generation",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"seed": (
|
||||||
|
IO.INT,
|
||||||
|
{
|
||||||
|
"default": 0,
|
||||||
|
"min": 0,
|
||||||
|
"max": 0xFFFFFFFFFFFFFFFF,
|
||||||
|
"control_after_generate": True,
|
||||||
|
"tooltip": "The random seed used for creating the noise.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("VIDEO",)
|
||||||
|
DESCRIPTION = "Generates videos from an image and prompts using MiniMax's API"
|
||||||
|
FUNCTION = "generate_video"
|
||||||
|
CATEGORY = "api node/video/MiniMax"
|
||||||
|
API_NODE = True
|
||||||
|
OUTPUT_NODE = True
|
||||||
|
|
||||||
|
|
||||||
|
class MinimaxSubjectToVideoNode(MinimaxTextToVideoNode):
|
||||||
|
"""
|
||||||
|
Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"subject": (
|
||||||
|
IO.IMAGE,
|
||||||
|
{
|
||||||
|
"tooltip": "Image of subject to reference video generation"
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"prompt_text": (
|
||||||
|
"STRING",
|
||||||
|
{
|
||||||
|
"multiline": True,
|
||||||
|
"default": "",
|
||||||
|
"tooltip": "Text prompt to guide the video generation",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"model": (
|
||||||
|
[
|
||||||
|
"S2V-01",
|
||||||
|
],
|
||||||
|
{
|
||||||
|
"default": "S2V-01",
|
||||||
|
"tooltip": "Model to use for video generation",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"seed": (
|
||||||
|
IO.INT,
|
||||||
|
{
|
||||||
|
"default": 0,
|
||||||
|
"min": 0,
|
||||||
|
"max": 0xFFFFFFFFFFFFFFFF,
|
||||||
|
"control_after_generate": True,
|
||||||
|
"tooltip": "The random seed used for creating the noise.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("VIDEO",)
|
||||||
|
DESCRIPTION = "Generates videos from an image and prompts using MiniMax's API"
|
||||||
|
FUNCTION = "generate_video"
|
||||||
|
CATEGORY = "api node/video/MiniMax"
|
||||||
|
API_NODE = True
|
||||||
|
OUTPUT_NODE = True
|
||||||
|
|
||||||
|
|
||||||
|
# A dictionary that contains all nodes you want to export with their names
|
||||||
|
# NOTE: names should be globally unique
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"MinimaxTextToVideoNode": MinimaxTextToVideoNode,
|
||||||
|
"MinimaxImageToVideoNode": MinimaxImageToVideoNode,
|
||||||
|
# "MinimaxSubjectToVideoNode": MinimaxSubjectToVideoNode,
|
||||||
|
}
|
||||||
|
|
||||||
|
# A dictionary that contains the friendly/humanly readable titles for the nodes
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
|
"MinimaxTextToVideoNode": "MiniMax Text to Video",
|
||||||
|
"MinimaxImageToVideoNode": "MiniMax Image to Video",
|
||||||
|
"MinimaxSubjectToVideoNode": "MiniMax Subject to Video",
|
||||||
|
}
|
||||||
487
comfy_api_nodes/nodes_openai.py
Normal file
487
comfy_api_nodes/nodes_openai.py
Normal file
@@ -0,0 +1,487 @@
|
|||||||
|
import io
|
||||||
|
from inspect import cleandoc
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict
|
||||||
|
|
||||||
|
|
||||||
|
from comfy_api_nodes.apis import (
|
||||||
|
OpenAIImageGenerationRequest,
|
||||||
|
OpenAIImageEditRequest,
|
||||||
|
OpenAIImageGenerationResponse,
|
||||||
|
)
|
||||||
|
|
||||||
|
from comfy_api_nodes.apis.client import (
|
||||||
|
ApiEndpoint,
|
||||||
|
HttpMethod,
|
||||||
|
SynchronousOperation,
|
||||||
|
)
|
||||||
|
|
||||||
|
from comfy_api_nodes.apinode_utils import (
|
||||||
|
downscale_image_tensor,
|
||||||
|
validate_and_cast_response,
|
||||||
|
validate_string,
|
||||||
|
)
|
||||||
|
|
||||||
|
class OpenAIDalle2(ComfyNodeABC):
|
||||||
|
"""
|
||||||
|
Generates images synchronously via OpenAI's DALL·E 2 endpoint.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"prompt": (
|
||||||
|
IO.STRING,
|
||||||
|
{
|
||||||
|
"multiline": True,
|
||||||
|
"default": "",
|
||||||
|
"tooltip": "Text prompt for DALL·E",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"seed": (
|
||||||
|
IO.INT,
|
||||||
|
{
|
||||||
|
"default": 0,
|
||||||
|
"min": 0,
|
||||||
|
"max": 2**31 - 1,
|
||||||
|
"step": 1,
|
||||||
|
"display": "number",
|
||||||
|
"control_after_generate": True,
|
||||||
|
"tooltip": "not implemented yet in backend",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"size": (
|
||||||
|
IO.COMBO,
|
||||||
|
{
|
||||||
|
"options": ["256x256", "512x512", "1024x1024"],
|
||||||
|
"default": "1024x1024",
|
||||||
|
"tooltip": "Image size",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"n": (
|
||||||
|
IO.INT,
|
||||||
|
{
|
||||||
|
"default": 1,
|
||||||
|
"min": 1,
|
||||||
|
"max": 8,
|
||||||
|
"step": 1,
|
||||||
|
"display": "number",
|
||||||
|
"tooltip": "How many images to generate",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"image": (
|
||||||
|
IO.IMAGE,
|
||||||
|
{
|
||||||
|
"default": None,
|
||||||
|
"tooltip": "Optional reference image for image editing.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"mask": (
|
||||||
|
IO.MASK,
|
||||||
|
{
|
||||||
|
"default": None,
|
||||||
|
"tooltip": "Optional mask for inpainting (white areas will be replaced)",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = (IO.IMAGE,)
|
||||||
|
FUNCTION = "api_call"
|
||||||
|
CATEGORY = "api node/image/OpenAI"
|
||||||
|
DESCRIPTION = cleandoc(__doc__ or "")
|
||||||
|
API_NODE = True
|
||||||
|
|
||||||
|
def api_call(
|
||||||
|
self,
|
||||||
|
prompt,
|
||||||
|
seed=0,
|
||||||
|
image=None,
|
||||||
|
mask=None,
|
||||||
|
n=1,
|
||||||
|
size="1024x1024",
|
||||||
|
auth_token=None,
|
||||||
|
):
|
||||||
|
validate_string(prompt, strip_whitespace=False)
|
||||||
|
model = "dall-e-2"
|
||||||
|
path = "/proxy/openai/images/generations"
|
||||||
|
content_type = "application/json"
|
||||||
|
request_class = OpenAIImageGenerationRequest
|
||||||
|
img_binary = None
|
||||||
|
|
||||||
|
if image is not None and mask is not None:
|
||||||
|
path = "/proxy/openai/images/edits"
|
||||||
|
content_type = "multipart/form-data"
|
||||||
|
request_class = OpenAIImageEditRequest
|
||||||
|
|
||||||
|
input_tensor = image.squeeze().cpu()
|
||||||
|
height, width, channels = input_tensor.shape
|
||||||
|
rgba_tensor = torch.ones(height, width, 4, device="cpu")
|
||||||
|
rgba_tensor[:, :, :channels] = input_tensor
|
||||||
|
|
||||||
|
if mask.shape[1:] != image.shape[1:-1]:
|
||||||
|
raise Exception("Mask and Image must be the same size")
|
||||||
|
rgba_tensor[:, :, 3] = 1 - mask.squeeze().cpu()
|
||||||
|
|
||||||
|
rgba_tensor = downscale_image_tensor(rgba_tensor.unsqueeze(0)).squeeze()
|
||||||
|
|
||||||
|
image_np = (rgba_tensor.numpy() * 255).astype(np.uint8)
|
||||||
|
img = Image.fromarray(image_np)
|
||||||
|
img_byte_arr = io.BytesIO()
|
||||||
|
img.save(img_byte_arr, format="PNG")
|
||||||
|
img_byte_arr.seek(0)
|
||||||
|
img_binary = img_byte_arr # .getvalue()
|
||||||
|
img_binary.name = "image.png"
|
||||||
|
elif image is not None or mask is not None:
|
||||||
|
raise Exception("Dall-E 2 image editing requires an image AND a mask")
|
||||||
|
|
||||||
|
# Build the operation
|
||||||
|
operation = SynchronousOperation(
|
||||||
|
endpoint=ApiEndpoint(
|
||||||
|
path=path,
|
||||||
|
method=HttpMethod.POST,
|
||||||
|
request_model=request_class,
|
||||||
|
response_model=OpenAIImageGenerationResponse,
|
||||||
|
),
|
||||||
|
request=request_class(
|
||||||
|
model=model,
|
||||||
|
prompt=prompt,
|
||||||
|
n=n,
|
||||||
|
size=size,
|
||||||
|
seed=seed,
|
||||||
|
),
|
||||||
|
files=(
|
||||||
|
{
|
||||||
|
"image": img_binary,
|
||||||
|
}
|
||||||
|
if img_binary
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
content_type=content_type,
|
||||||
|
auth_token=auth_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = operation.execute()
|
||||||
|
|
||||||
|
img_tensor = validate_and_cast_response(response)
|
||||||
|
return (img_tensor,)
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIDalle3(ComfyNodeABC):
|
||||||
|
"""
|
||||||
|
Generates images synchronously via OpenAI's DALL·E 3 endpoint.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"prompt": (
|
||||||
|
IO.STRING,
|
||||||
|
{
|
||||||
|
"multiline": True,
|
||||||
|
"default": "",
|
||||||
|
"tooltip": "Text prompt for DALL·E",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"seed": (
|
||||||
|
IO.INT,
|
||||||
|
{
|
||||||
|
"default": 0,
|
||||||
|
"min": 0,
|
||||||
|
"max": 2**31 - 1,
|
||||||
|
"step": 1,
|
||||||
|
"display": "number",
|
||||||
|
"control_after_generate": True,
|
||||||
|
"tooltip": "not implemented yet in backend",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"quality": (
|
||||||
|
IO.COMBO,
|
||||||
|
{
|
||||||
|
"options": ["standard", "hd"],
|
||||||
|
"default": "standard",
|
||||||
|
"tooltip": "Image quality",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"style": (
|
||||||
|
IO.COMBO,
|
||||||
|
{
|
||||||
|
"options": ["natural", "vivid"],
|
||||||
|
"default": "natural",
|
||||||
|
"tooltip": "Vivid causes the model to lean towards generating hyper-real and dramatic images. Natural causes the model to produce more natural, less hyper-real looking images.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"size": (
|
||||||
|
IO.COMBO,
|
||||||
|
{
|
||||||
|
"options": ["1024x1024", "1024x1792", "1792x1024"],
|
||||||
|
"default": "1024x1024",
|
||||||
|
"tooltip": "Image size",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = (IO.IMAGE,)
|
||||||
|
FUNCTION = "api_call"
|
||||||
|
CATEGORY = "api node/image/OpenAI"
|
||||||
|
DESCRIPTION = cleandoc(__doc__ or "")
|
||||||
|
API_NODE = True
|
||||||
|
|
||||||
|
def api_call(
|
||||||
|
self,
|
||||||
|
prompt,
|
||||||
|
seed=0,
|
||||||
|
style="natural",
|
||||||
|
quality="standard",
|
||||||
|
size="1024x1024",
|
||||||
|
auth_token=None,
|
||||||
|
):
|
||||||
|
validate_string(prompt, strip_whitespace=False)
|
||||||
|
model = "dall-e-3"
|
||||||
|
|
||||||
|
# build the operation
|
||||||
|
operation = SynchronousOperation(
|
||||||
|
endpoint=ApiEndpoint(
|
||||||
|
path="/proxy/openai/images/generations",
|
||||||
|
method=HttpMethod.POST,
|
||||||
|
request_model=OpenAIImageGenerationRequest,
|
||||||
|
response_model=OpenAIImageGenerationResponse,
|
||||||
|
),
|
||||||
|
request=OpenAIImageGenerationRequest(
|
||||||
|
model=model,
|
||||||
|
prompt=prompt,
|
||||||
|
quality=quality,
|
||||||
|
size=size,
|
||||||
|
style=style,
|
||||||
|
seed=seed,
|
||||||
|
),
|
||||||
|
auth_token=auth_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = operation.execute()
|
||||||
|
|
||||||
|
img_tensor = validate_and_cast_response(response)
|
||||||
|
return (img_tensor,)
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIGPTImage1(ComfyNodeABC):
|
||||||
|
"""
|
||||||
|
Generates images synchronously via OpenAI's GPT Image 1 endpoint.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"prompt": (
|
||||||
|
IO.STRING,
|
||||||
|
{
|
||||||
|
"multiline": True,
|
||||||
|
"default": "",
|
||||||
|
"tooltip": "Text prompt for GPT Image 1",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"seed": (
|
||||||
|
IO.INT,
|
||||||
|
{
|
||||||
|
"default": 0,
|
||||||
|
"min": 0,
|
||||||
|
"max": 2**31 - 1,
|
||||||
|
"step": 1,
|
||||||
|
"display": "number",
|
||||||
|
"control_after_generate": True,
|
||||||
|
"tooltip": "not implemented yet in backend",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"quality": (
|
||||||
|
IO.COMBO,
|
||||||
|
{
|
||||||
|
"options": ["low", "medium", "high"],
|
||||||
|
"default": "low",
|
||||||
|
"tooltip": "Image quality, affects cost and generation time.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"background": (
|
||||||
|
IO.COMBO,
|
||||||
|
{
|
||||||
|
"options": ["opaque", "transparent"],
|
||||||
|
"default": "opaque",
|
||||||
|
"tooltip": "Return image with or without background",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"size": (
|
||||||
|
IO.COMBO,
|
||||||
|
{
|
||||||
|
"options": ["auto", "1024x1024", "1024x1536", "1536x1024"],
|
||||||
|
"default": "auto",
|
||||||
|
"tooltip": "Image size",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"n": (
|
||||||
|
IO.INT,
|
||||||
|
{
|
||||||
|
"default": 1,
|
||||||
|
"min": 1,
|
||||||
|
"max": 8,
|
||||||
|
"step": 1,
|
||||||
|
"display": "number",
|
||||||
|
"tooltip": "How many images to generate",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"image": (
|
||||||
|
IO.IMAGE,
|
||||||
|
{
|
||||||
|
"default": None,
|
||||||
|
"tooltip": "Optional reference image for image editing.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"mask": (
|
||||||
|
IO.MASK,
|
||||||
|
{
|
||||||
|
"default": None,
|
||||||
|
"tooltip": "Optional mask for inpainting (white areas will be replaced)",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = (IO.IMAGE,)
|
||||||
|
FUNCTION = "api_call"
|
||||||
|
CATEGORY = "api node/image/OpenAI"
|
||||||
|
DESCRIPTION = cleandoc(__doc__ or "")
|
||||||
|
API_NODE = True
|
||||||
|
|
||||||
|
def api_call(
|
||||||
|
self,
|
||||||
|
prompt,
|
||||||
|
seed=0,
|
||||||
|
quality="low",
|
||||||
|
background="opaque",
|
||||||
|
image=None,
|
||||||
|
mask=None,
|
||||||
|
n=1,
|
||||||
|
size="1024x1024",
|
||||||
|
auth_token=None,
|
||||||
|
):
|
||||||
|
validate_string(prompt, strip_whitespace=False)
|
||||||
|
model = "gpt-image-1"
|
||||||
|
path = "/proxy/openai/images/generations"
|
||||||
|
content_type="application/json"
|
||||||
|
request_class = OpenAIImageGenerationRequest
|
||||||
|
img_binaries = []
|
||||||
|
mask_binary = None
|
||||||
|
files = []
|
||||||
|
|
||||||
|
if image is not None:
|
||||||
|
path = "/proxy/openai/images/edits"
|
||||||
|
request_class = OpenAIImageEditRequest
|
||||||
|
content_type ="multipart/form-data"
|
||||||
|
|
||||||
|
batch_size = image.shape[0]
|
||||||
|
|
||||||
|
for i in range(batch_size):
|
||||||
|
single_image = image[i : i + 1]
|
||||||
|
scaled_image = downscale_image_tensor(single_image).squeeze()
|
||||||
|
|
||||||
|
image_np = (scaled_image.numpy() * 255).astype(np.uint8)
|
||||||
|
img = Image.fromarray(image_np)
|
||||||
|
img_byte_arr = io.BytesIO()
|
||||||
|
img.save(img_byte_arr, format="PNG")
|
||||||
|
img_byte_arr.seek(0)
|
||||||
|
img_binary = img_byte_arr
|
||||||
|
img_binary.name = f"image_{i}.png"
|
||||||
|
|
||||||
|
img_binaries.append(img_binary)
|
||||||
|
if batch_size == 1:
|
||||||
|
files.append(("image", img_binary))
|
||||||
|
else:
|
||||||
|
files.append(("image[]", img_binary))
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
if image is None:
|
||||||
|
raise Exception("Cannot use a mask without an input image")
|
||||||
|
if image.shape[0] != 1:
|
||||||
|
raise Exception("Cannot use a mask with multiple image")
|
||||||
|
if mask.shape[1:] != image.shape[1:-1]:
|
||||||
|
raise Exception("Mask and Image must be the same size")
|
||||||
|
batch, height, width = mask.shape
|
||||||
|
rgba_mask = torch.zeros(height, width, 4, device="cpu")
|
||||||
|
rgba_mask[:, :, 3] = 1 - mask.squeeze().cpu()
|
||||||
|
|
||||||
|
scaled_mask = downscale_image_tensor(rgba_mask.unsqueeze(0)).squeeze()
|
||||||
|
|
||||||
|
mask_np = (scaled_mask.numpy() * 255).astype(np.uint8)
|
||||||
|
mask_img = Image.fromarray(mask_np)
|
||||||
|
mask_img_byte_arr = io.BytesIO()
|
||||||
|
mask_img.save(mask_img_byte_arr, format="PNG")
|
||||||
|
mask_img_byte_arr.seek(0)
|
||||||
|
mask_binary = mask_img_byte_arr
|
||||||
|
mask_binary.name = "mask.png"
|
||||||
|
files.append(("mask", mask_binary))
|
||||||
|
|
||||||
|
# Build the operation
|
||||||
|
operation = SynchronousOperation(
|
||||||
|
endpoint=ApiEndpoint(
|
||||||
|
path=path,
|
||||||
|
method=HttpMethod.POST,
|
||||||
|
request_model=request_class,
|
||||||
|
response_model=OpenAIImageGenerationResponse,
|
||||||
|
),
|
||||||
|
request=request_class(
|
||||||
|
model=model,
|
||||||
|
prompt=prompt,
|
||||||
|
quality=quality,
|
||||||
|
background=background,
|
||||||
|
n=n,
|
||||||
|
seed=seed,
|
||||||
|
size=size,
|
||||||
|
),
|
||||||
|
files=files if files else None,
|
||||||
|
content_type=content_type,
|
||||||
|
auth_token=auth_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = operation.execute()
|
||||||
|
|
||||||
|
img_tensor = validate_and_cast_response(response)
|
||||||
|
return (img_tensor,)
|
||||||
|
|
||||||
|
|
||||||
|
# A dictionary that contains all nodes you want to export with their names
|
||||||
|
# NOTE: names should be globally unique
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"OpenAIDalle2": OpenAIDalle2,
|
||||||
|
"OpenAIDalle3": OpenAIDalle3,
|
||||||
|
"OpenAIGPTImage1": OpenAIGPTImage1,
|
||||||
|
}
|
||||||
|
|
||||||
|
# A dictionary that contains the friendly/humanly readable titles for the nodes
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
|
"OpenAIDalle2": "OpenAI DALL·E 2",
|
||||||
|
"OpenAIDalle3": "OpenAI DALL·E 3",
|
||||||
|
"OpenAIGPTImage1": "OpenAI GPT Image 1",
|
||||||
|
}
|
||||||
749
comfy_api_nodes/nodes_pika.py
Normal file
749
comfy_api_nodes/nodes_pika.py
Normal file
@@ -0,0 +1,749 @@
|
|||||||
|
"""
|
||||||
|
Pika x ComfyUI API Nodes
|
||||||
|
|
||||||
|
Pika API docs: https://pika-827374fb.mintlify.app/api-reference
|
||||||
|
"""
|
||||||
|
|
||||||
|
import io
|
||||||
|
from typing import Optional, TypeVar
|
||||||
|
import logging
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from comfy_api_nodes.apis import (
|
||||||
|
PikaBodyGenerate22T2vGenerate22T2vPost,
|
||||||
|
PikaGenerateResponse,
|
||||||
|
PikaBodyGenerate22I2vGenerate22I2vPost,
|
||||||
|
PikaVideoResponse,
|
||||||
|
PikaBodyGenerate22C2vGenerate22PikascenesPost,
|
||||||
|
IngredientsMode,
|
||||||
|
PikaDurationEnum,
|
||||||
|
PikaResolutionEnum,
|
||||||
|
PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
|
||||||
|
PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
|
||||||
|
PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
|
||||||
|
PikaBodyGenerate22KeyframeGenerate22PikaframesPost,
|
||||||
|
Pikaffect,
|
||||||
|
)
|
||||||
|
from comfy_api_nodes.apis.client import (
|
||||||
|
ApiEndpoint,
|
||||||
|
HttpMethod,
|
||||||
|
SynchronousOperation,
|
||||||
|
PollingOperation,
|
||||||
|
EmptyRequest,
|
||||||
|
)
|
||||||
|
from comfy_api_nodes.apinode_utils import (
|
||||||
|
tensor_to_bytesio,
|
||||||
|
download_url_to_video_output,
|
||||||
|
)
|
||||||
|
from comfy_api_nodes.mapper_utils import model_field_to_node_input
|
||||||
|
from comfy_api.input_impl.video_types import VideoInput, VideoContainer, VideoCodec
|
||||||
|
from comfy_api.input_impl import VideoFromFile
|
||||||
|
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeOptions
|
||||||
|
|
||||||
|
R = TypeVar("R")
|
||||||
|
|
||||||
|
PATH_PIKADDITIONS = "/proxy/pika/generate/pikadditions"
|
||||||
|
PATH_PIKASWAPS = "/proxy/pika/generate/pikaswaps"
|
||||||
|
PATH_PIKAFFECTS = "/proxy/pika/generate/pikaffects"
|
||||||
|
|
||||||
|
PIKA_API_VERSION = "2.2"
|
||||||
|
PATH_TEXT_TO_VIDEO = f"/proxy/pika/generate/{PIKA_API_VERSION}/t2v"
|
||||||
|
PATH_IMAGE_TO_VIDEO = f"/proxy/pika/generate/{PIKA_API_VERSION}/i2v"
|
||||||
|
PATH_PIKAFRAMES = f"/proxy/pika/generate/{PIKA_API_VERSION}/pikaframes"
|
||||||
|
PATH_PIKASCENES = f"/proxy/pika/generate/{PIKA_API_VERSION}/pikascenes"
|
||||||
|
|
||||||
|
PATH_VIDEO_GET = "/proxy/pika/videos"
|
||||||
|
|
||||||
|
|
||||||
|
class PikaApiError(Exception):
|
||||||
|
"""Exception for Pika API errors."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def is_valid_video_response(response: PikaVideoResponse) -> bool:
|
||||||
|
"""Check if the video response is valid."""
|
||||||
|
return hasattr(response, "url") and response.url is not None
|
||||||
|
|
||||||
|
|
||||||
|
def is_valid_initial_response(response: PikaGenerateResponse) -> bool:
|
||||||
|
"""Check if the initial response is valid."""
|
||||||
|
return hasattr(response, "video_id") and response.video_id is not None
|
||||||
|
|
||||||
|
|
||||||
|
class PikaNodeBase(ComfyNodeABC):
|
||||||
|
"""Base class for Pika nodes."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_base_inputs_types(
|
||||||
|
cls, request_model
|
||||||
|
) -> dict[str, tuple[IO, InputTypeOptions]]:
|
||||||
|
"""Get the base required inputs types common to all Pika nodes."""
|
||||||
|
return {
|
||||||
|
"prompt_text": model_field_to_node_input(
|
||||||
|
IO.STRING,
|
||||||
|
request_model,
|
||||||
|
"promptText",
|
||||||
|
multiline=True,
|
||||||
|
),
|
||||||
|
"negative_prompt": model_field_to_node_input(
|
||||||
|
IO.STRING,
|
||||||
|
request_model,
|
||||||
|
"negativePrompt",
|
||||||
|
multiline=True,
|
||||||
|
),
|
||||||
|
"seed": model_field_to_node_input(
|
||||||
|
IO.INT,
|
||||||
|
request_model,
|
||||||
|
"seed",
|
||||||
|
min=0,
|
||||||
|
max=0xFFFFFFFF,
|
||||||
|
control_after_generate=True,
|
||||||
|
),
|
||||||
|
"resolution": model_field_to_node_input(
|
||||||
|
IO.COMBO,
|
||||||
|
request_model,
|
||||||
|
"resolution",
|
||||||
|
enum_type=PikaResolutionEnum,
|
||||||
|
),
|
||||||
|
"duration": model_field_to_node_input(
|
||||||
|
IO.COMBO,
|
||||||
|
request_model,
|
||||||
|
"duration",
|
||||||
|
enum_type=PikaDurationEnum,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
CATEGORY = "api node/video/Pika"
|
||||||
|
API_NODE = True
|
||||||
|
FUNCTION = "api_call"
|
||||||
|
RETURN_TYPES = ("VIDEO",)
|
||||||
|
|
||||||
|
def poll_for_task_status(
|
||||||
|
self, task_id: str, auth_token: str
|
||||||
|
) -> PikaGenerateResponse:
|
||||||
|
polling_operation = PollingOperation(
|
||||||
|
poll_endpoint=ApiEndpoint(
|
||||||
|
path=f"{PATH_VIDEO_GET}/{task_id}",
|
||||||
|
method=HttpMethod.GET,
|
||||||
|
request_model=EmptyRequest,
|
||||||
|
response_model=PikaVideoResponse,
|
||||||
|
),
|
||||||
|
completed_statuses=[
|
||||||
|
"finished",
|
||||||
|
],
|
||||||
|
failed_statuses=["failed", "cancelled"],
|
||||||
|
status_extractor=lambda response: (
|
||||||
|
response.status.value if response.status else None
|
||||||
|
),
|
||||||
|
progress_extractor=lambda response: (
|
||||||
|
response.progress if hasattr(response, "progress") else None
|
||||||
|
),
|
||||||
|
auth_token=auth_token,
|
||||||
|
)
|
||||||
|
return polling_operation.execute()
|
||||||
|
|
||||||
|
def execute_task(
|
||||||
|
self,
|
||||||
|
initial_operation: SynchronousOperation[R, PikaGenerateResponse],
|
||||||
|
auth_token: Optional[str] = None,
|
||||||
|
) -> tuple[VideoFromFile]:
|
||||||
|
"""Executes the initial operation then polls for the task status until it is completed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
initial_operation: The initial operation to execute.
|
||||||
|
auth_token: The authentication token to use for the API call.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple containing the video file as a VIDEO output.
|
||||||
|
"""
|
||||||
|
initial_response = initial_operation.execute()
|
||||||
|
if not is_valid_initial_response(initial_response):
|
||||||
|
error_msg = f"Pika initial request failed. Code: {initial_response.code}, Message: {initial_response.message}, Data: {initial_response.data}"
|
||||||
|
logging.error(error_msg)
|
||||||
|
raise PikaApiError(error_msg)
|
||||||
|
|
||||||
|
task_id = initial_response.video_id
|
||||||
|
final_response = self.poll_for_task_status(task_id, auth_token)
|
||||||
|
if not is_valid_video_response(final_response):
|
||||||
|
error_msg = (
|
||||||
|
f"Pika task {task_id} succeeded but no video data found in response."
|
||||||
|
)
|
||||||
|
logging.error(error_msg)
|
||||||
|
raise PikaApiError(error_msg)
|
||||||
|
|
||||||
|
video_url = str(final_response.url)
|
||||||
|
logging.info("Pika task %s succeeded. Video URL: %s", task_id, video_url)
|
||||||
|
|
||||||
|
return (download_url_to_video_output(video_url),)
|
||||||
|
|
||||||
|
|
||||||
|
class PikaImageToVideoV2_2(PikaNodeBase):
|
||||||
|
"""Pika 2.2 Image to Video Node."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image": (
|
||||||
|
IO.IMAGE,
|
||||||
|
{"tooltip": "The image to convert to video"},
|
||||||
|
),
|
||||||
|
**cls.get_base_inputs_types(PikaBodyGenerate22I2vGenerate22I2vPost),
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
DESCRIPTION = "Sends an image and prompt to the Pika API v2.2 to generate a video."
|
||||||
|
|
||||||
|
def api_call(
|
||||||
|
self,
|
||||||
|
image: torch.Tensor,
|
||||||
|
prompt_text: str,
|
||||||
|
negative_prompt: str,
|
||||||
|
seed: int,
|
||||||
|
resolution: str,
|
||||||
|
duration: int,
|
||||||
|
auth_token: Optional[str] = None,
|
||||||
|
) -> tuple[VideoFromFile]:
|
||||||
|
# Convert image to BytesIO
|
||||||
|
image_bytes_io = tensor_to_bytesio(image)
|
||||||
|
image_bytes_io.seek(0)
|
||||||
|
|
||||||
|
pika_files = {"image": ("image.png", image_bytes_io, "image/png")}
|
||||||
|
|
||||||
|
# Prepare non-file data
|
||||||
|
pika_request_data = PikaBodyGenerate22I2vGenerate22I2vPost(
|
||||||
|
promptText=prompt_text,
|
||||||
|
negativePrompt=negative_prompt,
|
||||||
|
seed=seed,
|
||||||
|
resolution=resolution,
|
||||||
|
duration=duration,
|
||||||
|
)
|
||||||
|
|
||||||
|
initial_operation = SynchronousOperation(
|
||||||
|
endpoint=ApiEndpoint(
|
||||||
|
path=PATH_IMAGE_TO_VIDEO,
|
||||||
|
method=HttpMethod.POST,
|
||||||
|
request_model=PikaBodyGenerate22I2vGenerate22I2vPost,
|
||||||
|
response_model=PikaGenerateResponse,
|
||||||
|
),
|
||||||
|
request=pika_request_data,
|
||||||
|
files=pika_files,
|
||||||
|
content_type="multipart/form-data",
|
||||||
|
auth_token=auth_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.execute_task(initial_operation, auth_token)
|
||||||
|
|
||||||
|
|
||||||
|
class PikaTextToVideoNodeV2_2(PikaNodeBase):
|
||||||
|
"""Pika Text2Video v2.2 Node."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
**cls.get_base_inputs_types(PikaBodyGenerate22T2vGenerate22T2vPost),
|
||||||
|
"aspect_ratio": model_field_to_node_input(
|
||||||
|
IO.FLOAT,
|
||||||
|
PikaBodyGenerate22T2vGenerate22T2vPost,
|
||||||
|
"aspectRatio",
|
||||||
|
step=0.001,
|
||||||
|
min=0.4,
|
||||||
|
max=2.5,
|
||||||
|
default=1.7777777777777777,
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
DESCRIPTION = "Sends a text prompt to the Pika API v2.2 to generate a video."
|
||||||
|
|
||||||
|
def api_call(
|
||||||
|
self,
|
||||||
|
prompt_text: str,
|
||||||
|
negative_prompt: str,
|
||||||
|
seed: int,
|
||||||
|
resolution: str,
|
||||||
|
duration: int,
|
||||||
|
aspect_ratio: float,
|
||||||
|
auth_token: Optional[str] = None,
|
||||||
|
) -> tuple[VideoFromFile]:
|
||||||
|
initial_operation = SynchronousOperation(
|
||||||
|
endpoint=ApiEndpoint(
|
||||||
|
path=PATH_TEXT_TO_VIDEO,
|
||||||
|
method=HttpMethod.POST,
|
||||||
|
request_model=PikaBodyGenerate22T2vGenerate22T2vPost,
|
||||||
|
response_model=PikaGenerateResponse,
|
||||||
|
),
|
||||||
|
request=PikaBodyGenerate22T2vGenerate22T2vPost(
|
||||||
|
promptText=prompt_text,
|
||||||
|
negativePrompt=negative_prompt,
|
||||||
|
seed=seed,
|
||||||
|
resolution=resolution,
|
||||||
|
duration=duration,
|
||||||
|
aspectRatio=aspect_ratio,
|
||||||
|
),
|
||||||
|
auth_token=auth_token,
|
||||||
|
content_type="application/x-www-form-urlencoded",
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.execute_task(initial_operation, auth_token)
|
||||||
|
|
||||||
|
|
||||||
|
class PikaScenesV2_2(PikaNodeBase):
|
||||||
|
"""PikaScenes v2.2 Node."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
image_ingredient_input = (
|
||||||
|
IO.IMAGE,
|
||||||
|
{"tooltip": "Image that will be used as ingredient to create a video."},
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
**cls.get_base_inputs_types(
|
||||||
|
PikaBodyGenerate22C2vGenerate22PikascenesPost,
|
||||||
|
),
|
||||||
|
"ingredients_mode": model_field_to_node_input(
|
||||||
|
IO.COMBO,
|
||||||
|
PikaBodyGenerate22C2vGenerate22PikascenesPost,
|
||||||
|
"ingredientsMode",
|
||||||
|
enum_type=IngredientsMode,
|
||||||
|
default="creative",
|
||||||
|
),
|
||||||
|
"aspect_ratio": model_field_to_node_input(
|
||||||
|
IO.FLOAT,
|
||||||
|
PikaBodyGenerate22C2vGenerate22PikascenesPost,
|
||||||
|
"aspectRatio",
|
||||||
|
step=0.001,
|
||||||
|
min=0.4,
|
||||||
|
max=2.5,
|
||||||
|
default=1.7777777777777777,
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"image_ingredient_1": image_ingredient_input,
|
||||||
|
"image_ingredient_2": image_ingredient_input,
|
||||||
|
"image_ingredient_3": image_ingredient_input,
|
||||||
|
"image_ingredient_4": image_ingredient_input,
|
||||||
|
"image_ingredient_5": image_ingredient_input,
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
DESCRIPTION = "Combine your images to create a video with the objects in them. Upload multiple images as ingredients and generate a high-quality video that incorporates all of them."
|
||||||
|
|
||||||
|
def api_call(
|
||||||
|
self,
|
||||||
|
prompt_text: str,
|
||||||
|
negative_prompt: str,
|
||||||
|
seed: int,
|
||||||
|
resolution: str,
|
||||||
|
duration: int,
|
||||||
|
ingredients_mode: str,
|
||||||
|
aspect_ratio: float,
|
||||||
|
image_ingredient_1: Optional[torch.Tensor] = None,
|
||||||
|
image_ingredient_2: Optional[torch.Tensor] = None,
|
||||||
|
image_ingredient_3: Optional[torch.Tensor] = None,
|
||||||
|
image_ingredient_4: Optional[torch.Tensor] = None,
|
||||||
|
image_ingredient_5: Optional[torch.Tensor] = None,
|
||||||
|
auth_token: Optional[str] = None,
|
||||||
|
) -> tuple[VideoFromFile]:
|
||||||
|
# Convert all passed images to BytesIO
|
||||||
|
all_image_bytes_io = []
|
||||||
|
for image in [
|
||||||
|
image_ingredient_1,
|
||||||
|
image_ingredient_2,
|
||||||
|
image_ingredient_3,
|
||||||
|
image_ingredient_4,
|
||||||
|
image_ingredient_5,
|
||||||
|
]:
|
||||||
|
if image is not None:
|
||||||
|
image_bytes_io = tensor_to_bytesio(image)
|
||||||
|
image_bytes_io.seek(0)
|
||||||
|
all_image_bytes_io.append(image_bytes_io)
|
||||||
|
|
||||||
|
pika_files = [
|
||||||
|
("images", (f"image_{i}.png", image_bytes_io, "image/png"))
|
||||||
|
for i, image_bytes_io in enumerate(all_image_bytes_io)
|
||||||
|
]
|
||||||
|
|
||||||
|
pika_request_data = PikaBodyGenerate22C2vGenerate22PikascenesPost(
|
||||||
|
ingredientsMode=ingredients_mode,
|
||||||
|
promptText=prompt_text,
|
||||||
|
negativePrompt=negative_prompt,
|
||||||
|
seed=seed,
|
||||||
|
resolution=resolution,
|
||||||
|
duration=duration,
|
||||||
|
aspectRatio=aspect_ratio,
|
||||||
|
)
|
||||||
|
|
||||||
|
initial_operation = SynchronousOperation(
|
||||||
|
endpoint=ApiEndpoint(
|
||||||
|
path=PATH_PIKASCENES,
|
||||||
|
method=HttpMethod.POST,
|
||||||
|
request_model=PikaBodyGenerate22C2vGenerate22PikascenesPost,
|
||||||
|
response_model=PikaGenerateResponse,
|
||||||
|
),
|
||||||
|
request=pika_request_data,
|
||||||
|
files=pika_files,
|
||||||
|
content_type="multipart/form-data",
|
||||||
|
auth_token=auth_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.execute_task(initial_operation, auth_token)
|
||||||
|
|
||||||
|
|
||||||
|
class PikAdditionsNode(PikaNodeBase):
|
||||||
|
"""Pika Pikadditions Node. Add an image into a video."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"video": (IO.VIDEO, {"tooltip": "The video to add an image to."}),
|
||||||
|
"image": (IO.IMAGE, {"tooltip": "The image to add to the video."}),
|
||||||
|
"prompt_text": model_field_to_node_input(
|
||||||
|
IO.STRING,
|
||||||
|
PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
|
||||||
|
"promptText",
|
||||||
|
multiline=True,
|
||||||
|
),
|
||||||
|
"negative_prompt": model_field_to_node_input(
|
||||||
|
IO.STRING,
|
||||||
|
PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
|
||||||
|
"negativePrompt",
|
||||||
|
multiline=True,
|
||||||
|
),
|
||||||
|
"seed": model_field_to_node_input(
|
||||||
|
IO.INT,
|
||||||
|
PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
|
||||||
|
"seed",
|
||||||
|
min=0,
|
||||||
|
max=0xFFFFFFFF,
|
||||||
|
control_after_generate=True,
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
DESCRIPTION = "Add any object or image into your video. Upload a video and specify what you’d like to add to create a seamlessly integrated result."
|
||||||
|
|
||||||
|
def api_call(
|
||||||
|
self,
|
||||||
|
video: VideoInput,
|
||||||
|
image: torch.Tensor,
|
||||||
|
prompt_text: str,
|
||||||
|
negative_prompt: str,
|
||||||
|
seed: int,
|
||||||
|
auth_token: Optional[str] = None,
|
||||||
|
) -> tuple[VideoFromFile]:
|
||||||
|
# Convert video to BytesIO
|
||||||
|
video_bytes_io = io.BytesIO()
|
||||||
|
video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
|
||||||
|
video_bytes_io.seek(0)
|
||||||
|
|
||||||
|
# Convert image to BytesIO
|
||||||
|
image_bytes_io = tensor_to_bytesio(image)
|
||||||
|
image_bytes_io.seek(0)
|
||||||
|
|
||||||
|
pika_files = [
|
||||||
|
("video", ("video.mp4", video_bytes_io, "video/mp4")),
|
||||||
|
("image", ("image.png", image_bytes_io, "image/png")),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Prepare non-file data
|
||||||
|
pika_request_data = PikaBodyGeneratePikadditionsGeneratePikadditionsPost(
|
||||||
|
promptText=prompt_text,
|
||||||
|
negativePrompt=negative_prompt,
|
||||||
|
seed=seed,
|
||||||
|
)
|
||||||
|
|
||||||
|
initial_operation = SynchronousOperation(
|
||||||
|
endpoint=ApiEndpoint(
|
||||||
|
path=PATH_PIKADDITIONS,
|
||||||
|
method=HttpMethod.POST,
|
||||||
|
request_model=PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
|
||||||
|
response_model=PikaGenerateResponse,
|
||||||
|
),
|
||||||
|
request=pika_request_data,
|
||||||
|
files=pika_files,
|
||||||
|
content_type="multipart/form-data",
|
||||||
|
auth_token=auth_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.execute_task(initial_operation, auth_token)
|
||||||
|
|
||||||
|
|
||||||
|
class PikaSwapsNode(PikaNodeBase):
|
||||||
|
"""Pika Pikaswaps Node."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"video": (IO.VIDEO, {"tooltip": "The video to swap an object in."}),
|
||||||
|
"image": (
|
||||||
|
IO.IMAGE,
|
||||||
|
{
|
||||||
|
"tooltip": "The image used to replace the masked object in the video."
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"mask": (
|
||||||
|
IO.MASK,
|
||||||
|
{"tooltip": "Use the mask to define areas in the video to replace"},
|
||||||
|
),
|
||||||
|
"prompt_text": model_field_to_node_input(
|
||||||
|
IO.STRING,
|
||||||
|
PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
|
||||||
|
"promptText",
|
||||||
|
multiline=True,
|
||||||
|
),
|
||||||
|
"negative_prompt": model_field_to_node_input(
|
||||||
|
IO.STRING,
|
||||||
|
PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
|
||||||
|
"negativePrompt",
|
||||||
|
multiline=True,
|
||||||
|
),
|
||||||
|
"seed": model_field_to_node_input(
|
||||||
|
IO.INT,
|
||||||
|
PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
|
||||||
|
"seed",
|
||||||
|
min=0,
|
||||||
|
max=0xFFFFFFFF,
|
||||||
|
control_after_generate=True,
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
DESCRIPTION = "Swap out any object or region of your video with a new image or object. Define areas to replace either with a mask or coordinates."
|
||||||
|
RETURN_TYPES = ("VIDEO",)
|
||||||
|
|
||||||
|
def api_call(
|
||||||
|
self,
|
||||||
|
video: VideoInput,
|
||||||
|
image: torch.Tensor,
|
||||||
|
mask: torch.Tensor,
|
||||||
|
prompt_text: str,
|
||||||
|
negative_prompt: str,
|
||||||
|
seed: int,
|
||||||
|
auth_token: Optional[str] = None,
|
||||||
|
) -> tuple[VideoFromFile]:
|
||||||
|
# Convert video to BytesIO
|
||||||
|
video_bytes_io = io.BytesIO()
|
||||||
|
video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
|
||||||
|
video_bytes_io.seek(0)
|
||||||
|
|
||||||
|
# Convert mask to binary mask with three channels
|
||||||
|
mask = torch.round(mask)
|
||||||
|
mask = mask.repeat(1, 3, 1, 1)
|
||||||
|
|
||||||
|
# Convert 3-channel binary mask to BytesIO
|
||||||
|
mask_bytes_io = io.BytesIO()
|
||||||
|
mask_bytes_io.write(mask.numpy().astype(np.uint8))
|
||||||
|
mask_bytes_io.seek(0)
|
||||||
|
|
||||||
|
# Convert image to BytesIO
|
||||||
|
image_bytes_io = tensor_to_bytesio(image)
|
||||||
|
image_bytes_io.seek(0)
|
||||||
|
|
||||||
|
pika_files = [
|
||||||
|
("video", ("video.mp4", video_bytes_io, "video/mp4")),
|
||||||
|
("image", ("image.png", image_bytes_io, "image/png")),
|
||||||
|
("modifyRegionMask", ("mask.png", mask_bytes_io, "image/png")),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Prepare non-file data
|
||||||
|
pika_request_data = PikaBodyGeneratePikaswapsGeneratePikaswapsPost(
|
||||||
|
promptText=prompt_text,
|
||||||
|
negativePrompt=negative_prompt,
|
||||||
|
seed=seed,
|
||||||
|
)
|
||||||
|
|
||||||
|
initial_operation = SynchronousOperation(
|
||||||
|
endpoint=ApiEndpoint(
|
||||||
|
path=PATH_PIKADDITIONS,
|
||||||
|
method=HttpMethod.POST,
|
||||||
|
request_model=PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
|
||||||
|
response_model=PikaGenerateResponse,
|
||||||
|
),
|
||||||
|
request=pika_request_data,
|
||||||
|
files=pika_files,
|
||||||
|
content_type="multipart/form-data",
|
||||||
|
auth_token=auth_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.execute_task(initial_operation, auth_token)
|
||||||
|
|
||||||
|
|
||||||
|
class PikaffectsNode(PikaNodeBase):
|
||||||
|
"""Pika Pikaffects Node."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image": (
|
||||||
|
IO.IMAGE,
|
||||||
|
{"tooltip": "The reference image to apply the Pikaffect to."},
|
||||||
|
),
|
||||||
|
"pikaffect": model_field_to_node_input(
|
||||||
|
IO.COMBO,
|
||||||
|
PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
|
||||||
|
"pikaffect",
|
||||||
|
enum_type=Pikaffect,
|
||||||
|
default="Cake-ify",
|
||||||
|
),
|
||||||
|
"prompt_text": model_field_to_node_input(
|
||||||
|
IO.STRING,
|
||||||
|
PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
|
||||||
|
"promptText",
|
||||||
|
multiline=True,
|
||||||
|
),
|
||||||
|
"negative_prompt": model_field_to_node_input(
|
||||||
|
IO.STRING,
|
||||||
|
PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
|
||||||
|
"negativePrompt",
|
||||||
|
multiline=True,
|
||||||
|
),
|
||||||
|
"seed": model_field_to_node_input(
|
||||||
|
IO.INT,
|
||||||
|
PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
|
||||||
|
"seed",
|
||||||
|
min=0,
|
||||||
|
max=0xFFFFFFFF,
|
||||||
|
control_after_generate=True,
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
DESCRIPTION = "Generate a video with a specific Pikaffect. Supported Pikaffects: Cake-ify, Crumble, Crush, Decapitate, Deflate, Dissolve, Explode, Eye-pop, Inflate, Levitate, Melt, Peel, Poke, Squish, Ta-da, Tear"
|
||||||
|
|
||||||
|
def api_call(
|
||||||
|
self,
|
||||||
|
image: torch.Tensor,
|
||||||
|
pikaffect: str,
|
||||||
|
prompt_text: str,
|
||||||
|
negative_prompt: str,
|
||||||
|
seed: int,
|
||||||
|
auth_token: Optional[str] = None,
|
||||||
|
) -> tuple[VideoFromFile]:
|
||||||
|
|
||||||
|
initial_operation = SynchronousOperation(
|
||||||
|
endpoint=ApiEndpoint(
|
||||||
|
path=PATH_PIKAFFECTS,
|
||||||
|
method=HttpMethod.POST,
|
||||||
|
request_model=PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
|
||||||
|
response_model=PikaGenerateResponse,
|
||||||
|
),
|
||||||
|
request=PikaBodyGeneratePikaffectsGeneratePikaffectsPost(
|
||||||
|
pikaffect=pikaffect,
|
||||||
|
promptText=prompt_text,
|
||||||
|
negativePrompt=negative_prompt,
|
||||||
|
seed=seed,
|
||||||
|
),
|
||||||
|
files={"image": ("image.png", tensor_to_bytesio(image), "image/png")},
|
||||||
|
content_type="multipart/form-data",
|
||||||
|
auth_token=auth_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.execute_task(initial_operation, auth_token)
|
||||||
|
|
||||||
|
|
||||||
|
class PikaStartEndFrameNode2_2(PikaNodeBase):
|
||||||
|
"""PikaFrames v2.2 Node."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image_start": (IO.IMAGE, {"tooltip": "The first image to combine."}),
|
||||||
|
"image_end": (IO.IMAGE, {"tooltip": "The last image to combine."}),
|
||||||
|
**cls.get_base_inputs_types(
|
||||||
|
PikaBodyGenerate22KeyframeGenerate22PikaframesPost
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
DESCRIPTION = "Generate a video by combining your first and last frame. Upload two images to define the start and end points, and let the AI create a smooth transition between them."
|
||||||
|
|
||||||
|
def api_call(
|
||||||
|
self,
|
||||||
|
image_start: torch.Tensor,
|
||||||
|
image_end: torch.Tensor,
|
||||||
|
prompt_text: str,
|
||||||
|
negative_prompt: str,
|
||||||
|
seed: int,
|
||||||
|
resolution: str,
|
||||||
|
duration: int,
|
||||||
|
auth_token: Optional[str] = None,
|
||||||
|
) -> tuple[VideoFromFile]:
|
||||||
|
|
||||||
|
pika_files = [
|
||||||
|
(
|
||||||
|
"keyFrames",
|
||||||
|
("image_start.png", tensor_to_bytesio(image_start), "image/png"),
|
||||||
|
),
|
||||||
|
("keyFrames", ("image_end.png", tensor_to_bytesio(image_end), "image/png")),
|
||||||
|
]
|
||||||
|
|
||||||
|
initial_operation = SynchronousOperation(
|
||||||
|
endpoint=ApiEndpoint(
|
||||||
|
path=PATH_PIKAFRAMES,
|
||||||
|
method=HttpMethod.POST,
|
||||||
|
request_model=PikaBodyGenerate22KeyframeGenerate22PikaframesPost,
|
||||||
|
response_model=PikaGenerateResponse,
|
||||||
|
),
|
||||||
|
request=PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
|
||||||
|
promptText=prompt_text,
|
||||||
|
negativePrompt=negative_prompt,
|
||||||
|
seed=seed,
|
||||||
|
resolution=resolution,
|
||||||
|
duration=duration,
|
||||||
|
),
|
||||||
|
files=pika_files,
|
||||||
|
content_type="multipart/form-data",
|
||||||
|
auth_token=auth_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.execute_task(initial_operation, auth_token)
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"PikaImageToVideoNode2_2": PikaImageToVideoV2_2,
|
||||||
|
"PikaTextToVideoNode2_2": PikaTextToVideoNodeV2_2,
|
||||||
|
"PikaScenesV2_2": PikaScenesV2_2,
|
||||||
|
"Pikadditions": PikAdditionsNode,
|
||||||
|
"Pikaswaps": PikaSwapsNode,
|
||||||
|
"Pikaffects": PikaffectsNode,
|
||||||
|
"PikaStartEndFrameNode2_2": PikaStartEndFrameNode2_2,
|
||||||
|
}
|
||||||
|
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
|
"PikaImageToVideoNode2_2": "Pika Image to Video",
|
||||||
|
"PikaTextToVideoNode2_2": "Pika Text to Video",
|
||||||
|
"PikaScenesV2_2": "Pika Scenes (Video Image Composition)",
|
||||||
|
"Pikadditions": "Pikadditions (Video Object Insertion)",
|
||||||
|
"Pikaswaps": "Pika Swaps (Video Object Replacement)",
|
||||||
|
"Pikaffects": "Pikaffects (Video Effects)",
|
||||||
|
"PikaStartEndFrameNode2_2": "Pika Start and End Frame to Video",
|
||||||
|
}
|
||||||
492
comfy_api_nodes/nodes_pixverse.py
Normal file
492
comfy_api_nodes/nodes_pixverse.py
Normal file
@@ -0,0 +1,492 @@
|
|||||||
|
from inspect import cleandoc
|
||||||
|
|
||||||
|
from comfy_api_nodes.apis.pixverse_api import (
|
||||||
|
PixverseTextVideoRequest,
|
||||||
|
PixverseImageVideoRequest,
|
||||||
|
PixverseTransitionVideoRequest,
|
||||||
|
PixverseImageUploadResponse,
|
||||||
|
PixverseVideoResponse,
|
||||||
|
PixverseGenerationStatusResponse,
|
||||||
|
PixverseAspectRatio,
|
||||||
|
PixverseQuality,
|
||||||
|
PixverseDuration,
|
||||||
|
PixverseMotionMode,
|
||||||
|
PixverseStatus,
|
||||||
|
PixverseIO,
|
||||||
|
pixverse_templates,
|
||||||
|
)
|
||||||
|
from comfy_api_nodes.apis.client import (
|
||||||
|
ApiEndpoint,
|
||||||
|
HttpMethod,
|
||||||
|
SynchronousOperation,
|
||||||
|
PollingOperation,
|
||||||
|
EmptyRequest,
|
||||||
|
)
|
||||||
|
from comfy_api_nodes.apinode_utils import (
|
||||||
|
tensor_to_bytesio,
|
||||||
|
validate_string,
|
||||||
|
)
|
||||||
|
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
|
||||||
|
from comfy_api.input_impl import VideoFromFile
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import requests
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
|
||||||
|
def upload_image_to_pixverse(image: torch.Tensor, auth_token=None):
|
||||||
|
# first, upload image to Pixverse and get image id to use in actual generation call
|
||||||
|
files = {
|
||||||
|
"image": tensor_to_bytesio(image)
|
||||||
|
}
|
||||||
|
operation = SynchronousOperation(
|
||||||
|
endpoint=ApiEndpoint(
|
||||||
|
path="/proxy/pixverse/image/upload",
|
||||||
|
method=HttpMethod.POST,
|
||||||
|
request_model=EmptyRequest,
|
||||||
|
response_model=PixverseImageUploadResponse,
|
||||||
|
),
|
||||||
|
request=EmptyRequest(),
|
||||||
|
files=files,
|
||||||
|
content_type="multipart/form-data",
|
||||||
|
auth_token=auth_token,
|
||||||
|
)
|
||||||
|
response_upload: PixverseImageUploadResponse = operation.execute()
|
||||||
|
|
||||||
|
if response_upload.Resp is None:
|
||||||
|
raise Exception(f"PixVerse image upload request failed: '{response_upload.ErrMsg}'")
|
||||||
|
|
||||||
|
return response_upload.Resp.img_id
|
||||||
|
|
||||||
|
|
||||||
|
class PixverseTemplateNode:
|
||||||
|
"""
|
||||||
|
Select template for PixVerse Video generation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
RETURN_TYPES = (PixverseIO.TEMPLATE,)
|
||||||
|
RETURN_NAMES = ("pixverse_template",)
|
||||||
|
FUNCTION = "create_template"
|
||||||
|
CATEGORY = "api node/video/PixVerse"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"template": (list(pixverse_templates.keys()), ),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
def create_template(self, template: str):
|
||||||
|
template_id = pixverse_templates.get(template, None)
|
||||||
|
if template_id is None:
|
||||||
|
raise Exception(f"Template '{template}' is not recognized.")
|
||||||
|
# just return the integer
|
||||||
|
return (template_id,)
|
||||||
|
|
||||||
|
|
||||||
|
class PixverseTextToVideoNode(ComfyNodeABC):
|
||||||
|
"""
|
||||||
|
Generates videos synchronously based on prompt and output_size.
|
||||||
|
"""
|
||||||
|
|
||||||
|
RETURN_TYPES = (IO.VIDEO,)
|
||||||
|
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||||
|
FUNCTION = "api_call"
|
||||||
|
API_NODE = True
|
||||||
|
CATEGORY = "api node/video/PixVerse"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"prompt": (
|
||||||
|
IO.STRING,
|
||||||
|
{
|
||||||
|
"multiline": True,
|
||||||
|
"default": "",
|
||||||
|
"tooltip": "Prompt for the video generation",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"aspect_ratio": (
|
||||||
|
[ratio.value for ratio in PixverseAspectRatio],
|
||||||
|
),
|
||||||
|
"quality": (
|
||||||
|
[resolution.value for resolution in PixverseQuality],
|
||||||
|
{
|
||||||
|
"default": PixverseQuality.res_540p,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"duration_seconds": ([dur.value for dur in PixverseDuration],),
|
||||||
|
"motion_mode": ([mode.value for mode in PixverseMotionMode],),
|
||||||
|
"seed": (
|
||||||
|
IO.INT,
|
||||||
|
{
|
||||||
|
"default": 0,
|
||||||
|
"min": 0,
|
||||||
|
"max": 2147483647,
|
||||||
|
"control_after_generate": True,
|
||||||
|
"tooltip": "Seed for video generation.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"negative_prompt": (
|
||||||
|
IO.STRING,
|
||||||
|
{
|
||||||
|
"default": "",
|
||||||
|
"forceInput": True,
|
||||||
|
"tooltip": "An optional text description of undesired elements on an image.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"pixverse_template": (
|
||||||
|
PixverseIO.TEMPLATE,
|
||||||
|
{
|
||||||
|
"tooltip": "An optional template to influence style of generation, created by the PixVerse Template node."
|
||||||
|
}
|
||||||
|
)
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def api_call(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
aspect_ratio: str,
|
||||||
|
quality: str,
|
||||||
|
duration_seconds: int,
|
||||||
|
motion_mode: str,
|
||||||
|
seed,
|
||||||
|
negative_prompt: str=None,
|
||||||
|
pixverse_template: int=None,
|
||||||
|
auth_token=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
validate_string(prompt, strip_whitespace=False)
|
||||||
|
# 1080p is limited to 5 seconds duration
|
||||||
|
# only normal motion_mode supported for 1080p or for non-5 second duration
|
||||||
|
if quality == PixverseQuality.res_1080p:
|
||||||
|
motion_mode = PixverseMotionMode.normal
|
||||||
|
duration_seconds = PixverseDuration.dur_5
|
||||||
|
elif duration_seconds != PixverseDuration.dur_5:
|
||||||
|
motion_mode = PixverseMotionMode.normal
|
||||||
|
|
||||||
|
operation = SynchronousOperation(
|
||||||
|
endpoint=ApiEndpoint(
|
||||||
|
path="/proxy/pixverse/video/text/generate",
|
||||||
|
method=HttpMethod.POST,
|
||||||
|
request_model=PixverseTextVideoRequest,
|
||||||
|
response_model=PixverseVideoResponse,
|
||||||
|
),
|
||||||
|
request=PixverseTextVideoRequest(
|
||||||
|
prompt=prompt,
|
||||||
|
aspect_ratio=aspect_ratio,
|
||||||
|
quality=quality,
|
||||||
|
duration=duration_seconds,
|
||||||
|
motion_mode=motion_mode,
|
||||||
|
negative_prompt=negative_prompt if negative_prompt else None,
|
||||||
|
template_id=pixverse_template,
|
||||||
|
seed=seed,
|
||||||
|
),
|
||||||
|
auth_token=auth_token,
|
||||||
|
)
|
||||||
|
response_api = operation.execute()
|
||||||
|
|
||||||
|
if response_api.Resp is None:
|
||||||
|
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
|
||||||
|
|
||||||
|
operation = PollingOperation(
|
||||||
|
poll_endpoint=ApiEndpoint(
|
||||||
|
path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}",
|
||||||
|
method=HttpMethod.GET,
|
||||||
|
request_model=EmptyRequest,
|
||||||
|
response_model=PixverseGenerationStatusResponse,
|
||||||
|
),
|
||||||
|
completed_statuses=[PixverseStatus.successful],
|
||||||
|
failed_statuses=[PixverseStatus.contents_moderation, PixverseStatus.failed, PixverseStatus.deleted],
|
||||||
|
status_extractor=lambda x: x.Resp.status,
|
||||||
|
auth_token=auth_token,
|
||||||
|
)
|
||||||
|
response_poll = operation.execute()
|
||||||
|
|
||||||
|
vid_response = requests.get(response_poll.Resp.url)
|
||||||
|
return (VideoFromFile(BytesIO(vid_response.content)),)
|
||||||
|
|
||||||
|
|
||||||
|
class PixverseImageToVideoNode(ComfyNodeABC):
|
||||||
|
"""
|
||||||
|
Generates videos synchronously based on prompt and output_size.
|
||||||
|
"""
|
||||||
|
|
||||||
|
RETURN_TYPES = (IO.VIDEO,)
|
||||||
|
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||||
|
FUNCTION = "api_call"
|
||||||
|
API_NODE = True
|
||||||
|
CATEGORY = "api node/video/PixVerse"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image": (
|
||||||
|
IO.IMAGE,
|
||||||
|
),
|
||||||
|
"prompt": (
|
||||||
|
IO.STRING,
|
||||||
|
{
|
||||||
|
"multiline": True,
|
||||||
|
"default": "",
|
||||||
|
"tooltip": "Prompt for the video generation",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"quality": (
|
||||||
|
[resolution.value for resolution in PixverseQuality],
|
||||||
|
{
|
||||||
|
"default": PixverseQuality.res_540p,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"duration_seconds": ([dur.value for dur in PixverseDuration],),
|
||||||
|
"motion_mode": ([mode.value for mode in PixverseMotionMode],),
|
||||||
|
"seed": (
|
||||||
|
IO.INT,
|
||||||
|
{
|
||||||
|
"default": 0,
|
||||||
|
"min": 0,
|
||||||
|
"max": 2147483647,
|
||||||
|
"control_after_generate": True,
|
||||||
|
"tooltip": "Seed for video generation.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"negative_prompt": (
|
||||||
|
IO.STRING,
|
||||||
|
{
|
||||||
|
"default": "",
|
||||||
|
"forceInput": True,
|
||||||
|
"tooltip": "An optional text description of undesired elements on an image.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"pixverse_template": (
|
||||||
|
PixverseIO.TEMPLATE,
|
||||||
|
{
|
||||||
|
"tooltip": "An optional template to influence style of generation, created by the PixVerse Template node."
|
||||||
|
}
|
||||||
|
)
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def api_call(
|
||||||
|
self,
|
||||||
|
image: torch.Tensor,
|
||||||
|
prompt: str,
|
||||||
|
quality: str,
|
||||||
|
duration_seconds: int,
|
||||||
|
motion_mode: str,
|
||||||
|
seed,
|
||||||
|
negative_prompt: str=None,
|
||||||
|
pixverse_template: int=None,
|
||||||
|
auth_token=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
validate_string(prompt, strip_whitespace=False)
|
||||||
|
img_id = upload_image_to_pixverse(image, auth_token=auth_token)
|
||||||
|
|
||||||
|
# 1080p is limited to 5 seconds duration
|
||||||
|
# only normal motion_mode supported for 1080p or for non-5 second duration
|
||||||
|
if quality == PixverseQuality.res_1080p:
|
||||||
|
motion_mode = PixverseMotionMode.normal
|
||||||
|
duration_seconds = PixverseDuration.dur_5
|
||||||
|
elif duration_seconds != PixverseDuration.dur_5:
|
||||||
|
motion_mode = PixverseMotionMode.normal
|
||||||
|
|
||||||
|
operation = SynchronousOperation(
|
||||||
|
endpoint=ApiEndpoint(
|
||||||
|
path="/proxy/pixverse/video/img/generate",
|
||||||
|
method=HttpMethod.POST,
|
||||||
|
request_model=PixverseImageVideoRequest,
|
||||||
|
response_model=PixverseVideoResponse,
|
||||||
|
),
|
||||||
|
request=PixverseImageVideoRequest(
|
||||||
|
img_id=img_id,
|
||||||
|
prompt=prompt,
|
||||||
|
quality=quality,
|
||||||
|
duration=duration_seconds,
|
||||||
|
motion_mode=motion_mode,
|
||||||
|
negative_prompt=negative_prompt if negative_prompt else None,
|
||||||
|
template_id=pixverse_template,
|
||||||
|
seed=seed,
|
||||||
|
),
|
||||||
|
auth_token=auth_token,
|
||||||
|
)
|
||||||
|
response_api = operation.execute()
|
||||||
|
|
||||||
|
if response_api.Resp is None:
|
||||||
|
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
|
||||||
|
|
||||||
|
operation = PollingOperation(
|
||||||
|
poll_endpoint=ApiEndpoint(
|
||||||
|
path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}",
|
||||||
|
method=HttpMethod.GET,
|
||||||
|
request_model=EmptyRequest,
|
||||||
|
response_model=PixverseGenerationStatusResponse,
|
||||||
|
),
|
||||||
|
completed_statuses=[PixverseStatus.successful],
|
||||||
|
failed_statuses=[PixverseStatus.contents_moderation, PixverseStatus.failed, PixverseStatus.deleted],
|
||||||
|
status_extractor=lambda x: x.Resp.status,
|
||||||
|
auth_token=auth_token,
|
||||||
|
)
|
||||||
|
response_poll = operation.execute()
|
||||||
|
|
||||||
|
vid_response = requests.get(response_poll.Resp.url)
|
||||||
|
return (VideoFromFile(BytesIO(vid_response.content)),)
|
||||||
|
|
||||||
|
|
||||||
|
class PixverseTransitionVideoNode(ComfyNodeABC):
|
||||||
|
"""
|
||||||
|
Generates videos synchronously based on prompt and output_size.
|
||||||
|
"""
|
||||||
|
|
||||||
|
RETURN_TYPES = (IO.VIDEO,)
|
||||||
|
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||||
|
FUNCTION = "api_call"
|
||||||
|
API_NODE = True
|
||||||
|
CATEGORY = "api node/video/PixVerse"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"first_frame": (
|
||||||
|
IO.IMAGE,
|
||||||
|
),
|
||||||
|
"last_frame": (
|
||||||
|
IO.IMAGE,
|
||||||
|
),
|
||||||
|
"prompt": (
|
||||||
|
IO.STRING,
|
||||||
|
{
|
||||||
|
"multiline": True,
|
||||||
|
"default": "",
|
||||||
|
"tooltip": "Prompt for the video generation",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"quality": (
|
||||||
|
[resolution.value for resolution in PixverseQuality],
|
||||||
|
{
|
||||||
|
"default": PixverseQuality.res_540p,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"duration_seconds": ([dur.value for dur in PixverseDuration],),
|
||||||
|
"motion_mode": ([mode.value for mode in PixverseMotionMode],),
|
||||||
|
"seed": (
|
||||||
|
IO.INT,
|
||||||
|
{
|
||||||
|
"default": 0,
|
||||||
|
"min": 0,
|
||||||
|
"max": 2147483647,
|
||||||
|
"control_after_generate": True,
|
||||||
|
"tooltip": "Seed for video generation.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"negative_prompt": (
|
||||||
|
IO.STRING,
|
||||||
|
{
|
||||||
|
"default": "",
|
||||||
|
"forceInput": True,
|
||||||
|
"tooltip": "An optional text description of undesired elements on an image.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def api_call(
|
||||||
|
self,
|
||||||
|
first_frame: torch.Tensor,
|
||||||
|
last_frame: torch.Tensor,
|
||||||
|
prompt: str,
|
||||||
|
quality: str,
|
||||||
|
duration_seconds: int,
|
||||||
|
motion_mode: str,
|
||||||
|
seed,
|
||||||
|
negative_prompt: str=None,
|
||||||
|
auth_token=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
validate_string(prompt, strip_whitespace=False)
|
||||||
|
first_frame_id = upload_image_to_pixverse(first_frame, auth_token=auth_token)
|
||||||
|
last_frame_id = upload_image_to_pixverse(last_frame, auth_token=auth_token)
|
||||||
|
|
||||||
|
# 1080p is limited to 5 seconds duration
|
||||||
|
# only normal motion_mode supported for 1080p or for non-5 second duration
|
||||||
|
if quality == PixverseQuality.res_1080p:
|
||||||
|
motion_mode = PixverseMotionMode.normal
|
||||||
|
duration_seconds = PixverseDuration.dur_5
|
||||||
|
elif duration_seconds != PixverseDuration.dur_5:
|
||||||
|
motion_mode = PixverseMotionMode.normal
|
||||||
|
|
||||||
|
operation = SynchronousOperation(
|
||||||
|
endpoint=ApiEndpoint(
|
||||||
|
path="/proxy/pixverse/video/transition/generate",
|
||||||
|
method=HttpMethod.POST,
|
||||||
|
request_model=PixverseTransitionVideoRequest,
|
||||||
|
response_model=PixverseVideoResponse,
|
||||||
|
),
|
||||||
|
request=PixverseTransitionVideoRequest(
|
||||||
|
first_frame_img=first_frame_id,
|
||||||
|
last_frame_img=last_frame_id,
|
||||||
|
prompt=prompt,
|
||||||
|
quality=quality,
|
||||||
|
duration=duration_seconds,
|
||||||
|
motion_mode=motion_mode,
|
||||||
|
negative_prompt=negative_prompt if negative_prompt else None,
|
||||||
|
seed=seed,
|
||||||
|
),
|
||||||
|
auth_token=auth_token,
|
||||||
|
)
|
||||||
|
response_api = operation.execute()
|
||||||
|
|
||||||
|
if response_api.Resp is None:
|
||||||
|
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
|
||||||
|
|
||||||
|
operation = PollingOperation(
|
||||||
|
poll_endpoint=ApiEndpoint(
|
||||||
|
path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}",
|
||||||
|
method=HttpMethod.GET,
|
||||||
|
request_model=EmptyRequest,
|
||||||
|
response_model=PixverseGenerationStatusResponse,
|
||||||
|
),
|
||||||
|
completed_statuses=[PixverseStatus.successful],
|
||||||
|
failed_statuses=[PixverseStatus.contents_moderation, PixverseStatus.failed, PixverseStatus.deleted],
|
||||||
|
status_extractor=lambda x: x.Resp.status,
|
||||||
|
auth_token=auth_token,
|
||||||
|
)
|
||||||
|
response_poll = operation.execute()
|
||||||
|
|
||||||
|
vid_response = requests.get(response_poll.Resp.url)
|
||||||
|
return (VideoFromFile(BytesIO(vid_response.content)),)
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"PixverseTextToVideoNode": PixverseTextToVideoNode,
|
||||||
|
"PixverseImageToVideoNode": PixverseImageToVideoNode,
|
||||||
|
"PixverseTransitionVideoNode": PixverseTransitionVideoNode,
|
||||||
|
"PixverseTemplateNode": PixverseTemplateNode,
|
||||||
|
}
|
||||||
|
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
|
"PixverseTextToVideoNode": "PixVerse Text to Video",
|
||||||
|
"PixverseImageToVideoNode": "PixVerse Image to Video",
|
||||||
|
"PixverseTransitionVideoNode": "PixVerse Transition Video",
|
||||||
|
"PixverseTemplateNode": "PixVerse Template",
|
||||||
|
}
|
||||||
1217
comfy_api_nodes/nodes_recraft.py
Normal file
1217
comfy_api_nodes/nodes_recraft.py
Normal file
File diff suppressed because it is too large
Load Diff
609
comfy_api_nodes/nodes_stability.py
Normal file
609
comfy_api_nodes/nodes_stability.py
Normal file
@@ -0,0 +1,609 @@
|
|||||||
|
from inspect import cleandoc
|
||||||
|
from comfy.comfy_types.node_typing import IO
|
||||||
|
from comfy_api_nodes.apis.stability_api import (
|
||||||
|
StabilityUpscaleConservativeRequest,
|
||||||
|
StabilityUpscaleCreativeRequest,
|
||||||
|
StabilityAsyncResponse,
|
||||||
|
StabilityResultsGetResponse,
|
||||||
|
StabilityStable3_5Request,
|
||||||
|
StabilityStableUltraRequest,
|
||||||
|
StabilityStableUltraResponse,
|
||||||
|
StabilityAspectRatio,
|
||||||
|
Stability_SD3_5_Model,
|
||||||
|
Stability_SD3_5_GenerationMode,
|
||||||
|
get_stability_style_presets,
|
||||||
|
)
|
||||||
|
from comfy_api_nodes.apis.client import (
|
||||||
|
ApiEndpoint,
|
||||||
|
HttpMethod,
|
||||||
|
SynchronousOperation,
|
||||||
|
PollingOperation,
|
||||||
|
EmptyRequest,
|
||||||
|
)
|
||||||
|
from comfy_api_nodes.apinode_utils import (
|
||||||
|
bytesio_to_image_tensor,
|
||||||
|
tensor_to_bytesio,
|
||||||
|
validate_string,
|
||||||
|
)
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import base64
|
||||||
|
from io import BytesIO
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class StabilityPollStatus(str, Enum):
|
||||||
|
finished = "finished"
|
||||||
|
in_progress = "in_progress"
|
||||||
|
failed = "failed"
|
||||||
|
|
||||||
|
|
||||||
|
def get_async_dummy_status(x: StabilityResultsGetResponse):
|
||||||
|
if x.name is not None or x.errors is not None:
|
||||||
|
return StabilityPollStatus.failed
|
||||||
|
elif x.finish_reason is not None:
|
||||||
|
return StabilityPollStatus.finished
|
||||||
|
return StabilityPollStatus.in_progress
|
||||||
|
|
||||||
|
|
||||||
|
class StabilityStableImageUltraNode:
|
||||||
|
"""
|
||||||
|
Generates images synchronously based on prompt and resolution.
|
||||||
|
"""
|
||||||
|
|
||||||
|
RETURN_TYPES = (IO.IMAGE,)
|
||||||
|
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||||
|
FUNCTION = "api_call"
|
||||||
|
API_NODE = True
|
||||||
|
CATEGORY = "api node/image/Stability AI"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"prompt": (
|
||||||
|
IO.STRING,
|
||||||
|
{
|
||||||
|
"multiline": True,
|
||||||
|
"default": "",
|
||||||
|
"tooltip": "What you wish to see in the output image. A strong, descriptive prompt that clearly defines" +
|
||||||
|
"What you wish to see in the output image. A strong, descriptive prompt that clearly defines" +
|
||||||
|
"elements, colors, and subjects will lead to better results. " +
|
||||||
|
"To control the weight of a given word use the format `(word:weight)`," +
|
||||||
|
"where `word` is the word you'd like to control the weight of and `weight`" +
|
||||||
|
"is a value between 0 and 1. For example: `The sky was a crisp (blue:0.3) and (green:0.8)`" +
|
||||||
|
"would convey a sky that was blue and green, but more green than blue."
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"aspect_ratio": ([x.value for x in StabilityAspectRatio],
|
||||||
|
{
|
||||||
|
"default": StabilityAspectRatio.ratio_1_1,
|
||||||
|
"tooltip": "Aspect ratio of generated image.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"style_preset": (get_stability_style_presets(),
|
||||||
|
{
|
||||||
|
"tooltip": "Optional desired style of generated image.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"seed": (
|
||||||
|
IO.INT,
|
||||||
|
{
|
||||||
|
"default": 0,
|
||||||
|
"min": 0,
|
||||||
|
"max": 4294967294,
|
||||||
|
"control_after_generate": True,
|
||||||
|
"tooltip": "The random seed used for creating the noise.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"image": (IO.IMAGE,),
|
||||||
|
"negative_prompt": (
|
||||||
|
IO.STRING,
|
||||||
|
{
|
||||||
|
"default": "",
|
||||||
|
"forceInput": True,
|
||||||
|
"tooltip": "A blurb of text describing what you do not wish to see in the output image. This is an advanced feature."
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"image_denoise": (
|
||||||
|
IO.FLOAT,
|
||||||
|
{
|
||||||
|
"default": 0.5,
|
||||||
|
"min": 0.0,
|
||||||
|
"max": 1.0,
|
||||||
|
"step": 0.01,
|
||||||
|
"tooltip": "Denoise of input image; 0.0 yields image identical to input, 1.0 is as if no image was provided at all.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def api_call(self, prompt: str, aspect_ratio: str, style_preset: str, seed: int,
|
||||||
|
negative_prompt: str=None, image: torch.Tensor = None, image_denoise: float=None,
|
||||||
|
auth_token=None):
|
||||||
|
validate_string(prompt, strip_whitespace=False)
|
||||||
|
# prepare image binary if image present
|
||||||
|
image_binary = None
|
||||||
|
if image is not None:
|
||||||
|
image_binary = tensor_to_bytesio(image, total_pixels=1504*1504).read()
|
||||||
|
else:
|
||||||
|
image_denoise = None
|
||||||
|
|
||||||
|
if not negative_prompt:
|
||||||
|
negative_prompt = None
|
||||||
|
if style_preset == "None":
|
||||||
|
style_preset = None
|
||||||
|
|
||||||
|
files = {
|
||||||
|
"image": image_binary
|
||||||
|
}
|
||||||
|
|
||||||
|
operation = SynchronousOperation(
|
||||||
|
endpoint=ApiEndpoint(
|
||||||
|
path="/proxy/stability/v2beta/stable-image/generate/ultra",
|
||||||
|
method=HttpMethod.POST,
|
||||||
|
request_model=StabilityStableUltraRequest,
|
||||||
|
response_model=StabilityStableUltraResponse,
|
||||||
|
),
|
||||||
|
request=StabilityStableUltraRequest(
|
||||||
|
prompt=prompt,
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
|
aspect_ratio=aspect_ratio,
|
||||||
|
seed=seed,
|
||||||
|
strength=image_denoise,
|
||||||
|
style_preset=style_preset,
|
||||||
|
),
|
||||||
|
files=files,
|
||||||
|
content_type="multipart/form-data",
|
||||||
|
auth_token=auth_token,
|
||||||
|
)
|
||||||
|
response_api = operation.execute()
|
||||||
|
|
||||||
|
if response_api.finish_reason != "SUCCESS":
|
||||||
|
raise Exception(f"Stable Image Ultra generation failed: {response_api.finish_reason}.")
|
||||||
|
|
||||||
|
image_data = base64.b64decode(response_api.image)
|
||||||
|
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||||
|
|
||||||
|
return (returned_image,)
|
||||||
|
|
||||||
|
|
||||||
|
class StabilityStableImageSD_3_5Node:
|
||||||
|
"""
|
||||||
|
Generates images synchronously based on prompt and resolution.
|
||||||
|
"""
|
||||||
|
|
||||||
|
RETURN_TYPES = (IO.IMAGE,)
|
||||||
|
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||||
|
FUNCTION = "api_call"
|
||||||
|
API_NODE = True
|
||||||
|
CATEGORY = "api node/image/Stability AI"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"prompt": (
|
||||||
|
IO.STRING,
|
||||||
|
{
|
||||||
|
"multiline": True,
|
||||||
|
"default": "",
|
||||||
|
"tooltip": "What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results."
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"model": ([x.value for x in Stability_SD3_5_Model],),
|
||||||
|
"aspect_ratio": ([x.value for x in StabilityAspectRatio],
|
||||||
|
{
|
||||||
|
"default": StabilityAspectRatio.ratio_1_1,
|
||||||
|
"tooltip": "Aspect ratio of generated image.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"style_preset": (get_stability_style_presets(),
|
||||||
|
{
|
||||||
|
"tooltip": "Optional desired style of generated image.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"cfg_scale": (
|
||||||
|
IO.FLOAT,
|
||||||
|
{
|
||||||
|
"default": 4.0,
|
||||||
|
"min": 1.0,
|
||||||
|
"max": 10.0,
|
||||||
|
"step": 0.1,
|
||||||
|
"tooltip": "How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"seed": (
|
||||||
|
IO.INT,
|
||||||
|
{
|
||||||
|
"default": 0,
|
||||||
|
"min": 0,
|
||||||
|
"max": 4294967294,
|
||||||
|
"control_after_generate": True,
|
||||||
|
"tooltip": "The random seed used for creating the noise.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"image": (IO.IMAGE,),
|
||||||
|
"negative_prompt": (
|
||||||
|
IO.STRING,
|
||||||
|
{
|
||||||
|
"default": "",
|
||||||
|
"forceInput": True,
|
||||||
|
"tooltip": "Keywords of what you do not wish to see in the output image. This is an advanced feature."
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"image_denoise": (
|
||||||
|
IO.FLOAT,
|
||||||
|
{
|
||||||
|
"default": 0.5,
|
||||||
|
"min": 0.0,
|
||||||
|
"max": 1.0,
|
||||||
|
"step": 0.01,
|
||||||
|
"tooltip": "Denoise of input image; 0.0 yields image identical to input, 1.0 is as if no image was provided at all.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def api_call(self, model: str, prompt: str, aspect_ratio: str, style_preset: str, seed: int, cfg_scale: float,
|
||||||
|
negative_prompt: str=None, image: torch.Tensor = None, image_denoise: float=None,
|
||||||
|
auth_token=None):
|
||||||
|
validate_string(prompt, strip_whitespace=False)
|
||||||
|
# prepare image binary if image present
|
||||||
|
image_binary = None
|
||||||
|
mode = Stability_SD3_5_GenerationMode.text_to_image
|
||||||
|
if image is not None:
|
||||||
|
image_binary = tensor_to_bytesio(image, total_pixels=1504*1504).read()
|
||||||
|
mode = Stability_SD3_5_GenerationMode.image_to_image
|
||||||
|
aspect_ratio = None
|
||||||
|
else:
|
||||||
|
image_denoise = None
|
||||||
|
|
||||||
|
if not negative_prompt:
|
||||||
|
negative_prompt = None
|
||||||
|
if style_preset == "None":
|
||||||
|
style_preset = None
|
||||||
|
|
||||||
|
files = {
|
||||||
|
"image": image_binary
|
||||||
|
}
|
||||||
|
|
||||||
|
operation = SynchronousOperation(
|
||||||
|
endpoint=ApiEndpoint(
|
||||||
|
path="/proxy/stability/v2beta/stable-image/generate/sd3",
|
||||||
|
method=HttpMethod.POST,
|
||||||
|
request_model=StabilityStable3_5Request,
|
||||||
|
response_model=StabilityStableUltraResponse,
|
||||||
|
),
|
||||||
|
request=StabilityStable3_5Request(
|
||||||
|
prompt=prompt,
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
|
aspect_ratio=aspect_ratio,
|
||||||
|
seed=seed,
|
||||||
|
strength=image_denoise,
|
||||||
|
style_preset=style_preset,
|
||||||
|
cfg_scale=cfg_scale,
|
||||||
|
model=model,
|
||||||
|
mode=mode,
|
||||||
|
),
|
||||||
|
files=files,
|
||||||
|
content_type="multipart/form-data",
|
||||||
|
auth_token=auth_token,
|
||||||
|
)
|
||||||
|
response_api = operation.execute()
|
||||||
|
|
||||||
|
if response_api.finish_reason != "SUCCESS":
|
||||||
|
raise Exception(f"Stable Diffusion 3.5 Image generation failed: {response_api.finish_reason}.")
|
||||||
|
|
||||||
|
image_data = base64.b64decode(response_api.image)
|
||||||
|
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||||
|
|
||||||
|
return (returned_image,)
|
||||||
|
|
||||||
|
|
||||||
|
class StabilityUpscaleConservativeNode:
|
||||||
|
"""
|
||||||
|
Upscale image with minimal alterations to 4K resolution.
|
||||||
|
"""
|
||||||
|
|
||||||
|
RETURN_TYPES = (IO.IMAGE,)
|
||||||
|
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||||
|
FUNCTION = "api_call"
|
||||||
|
API_NODE = True
|
||||||
|
CATEGORY = "api node/image/Stability AI"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image": (IO.IMAGE,),
|
||||||
|
"prompt": (
|
||||||
|
IO.STRING,
|
||||||
|
{
|
||||||
|
"multiline": True,
|
||||||
|
"default": "",
|
||||||
|
"tooltip": "What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results."
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"creativity": (
|
||||||
|
IO.FLOAT,
|
||||||
|
{
|
||||||
|
"default": 0.35,
|
||||||
|
"min": 0.2,
|
||||||
|
"max": 0.5,
|
||||||
|
"step": 0.01,
|
||||||
|
"tooltip": "Controls the likelihood of creating additional details not heavily conditioned by the init image.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"seed": (
|
||||||
|
IO.INT,
|
||||||
|
{
|
||||||
|
"default": 0,
|
||||||
|
"min": 0,
|
||||||
|
"max": 4294967294,
|
||||||
|
"control_after_generate": True,
|
||||||
|
"tooltip": "The random seed used for creating the noise.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"negative_prompt": (
|
||||||
|
IO.STRING,
|
||||||
|
{
|
||||||
|
"default": "",
|
||||||
|
"forceInput": True,
|
||||||
|
"tooltip": "Keywords of what you do not wish to see in the output image. This is an advanced feature."
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def api_call(self, image: torch.Tensor, prompt: str, creativity: float, seed: int, negative_prompt: str=None,
|
||||||
|
auth_token=None):
|
||||||
|
validate_string(prompt, strip_whitespace=False)
|
||||||
|
image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()
|
||||||
|
|
||||||
|
if not negative_prompt:
|
||||||
|
negative_prompt = None
|
||||||
|
|
||||||
|
files = {
|
||||||
|
"image": image_binary
|
||||||
|
}
|
||||||
|
|
||||||
|
operation = SynchronousOperation(
|
||||||
|
endpoint=ApiEndpoint(
|
||||||
|
path="/proxy/stability/v2beta/stable-image/upscale/conservative",
|
||||||
|
method=HttpMethod.POST,
|
||||||
|
request_model=StabilityUpscaleConservativeRequest,
|
||||||
|
response_model=StabilityStableUltraResponse,
|
||||||
|
),
|
||||||
|
request=StabilityUpscaleConservativeRequest(
|
||||||
|
prompt=prompt,
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
|
creativity=round(creativity,2),
|
||||||
|
seed=seed,
|
||||||
|
),
|
||||||
|
files=files,
|
||||||
|
content_type="multipart/form-data",
|
||||||
|
auth_token=auth_token,
|
||||||
|
)
|
||||||
|
response_api = operation.execute()
|
||||||
|
|
||||||
|
if response_api.finish_reason != "SUCCESS":
|
||||||
|
raise Exception(f"Stability Upscale Conservative generation failed: {response_api.finish_reason}.")
|
||||||
|
|
||||||
|
image_data = base64.b64decode(response_api.image)
|
||||||
|
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||||
|
|
||||||
|
return (returned_image,)
|
||||||
|
|
||||||
|
|
||||||
|
class StabilityUpscaleCreativeNode:
|
||||||
|
"""
|
||||||
|
Upscale image with minimal alterations to 4K resolution.
|
||||||
|
"""
|
||||||
|
|
||||||
|
RETURN_TYPES = (IO.IMAGE,)
|
||||||
|
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||||
|
FUNCTION = "api_call"
|
||||||
|
API_NODE = True
|
||||||
|
CATEGORY = "api node/image/Stability AI"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image": (IO.IMAGE,),
|
||||||
|
"prompt": (
|
||||||
|
IO.STRING,
|
||||||
|
{
|
||||||
|
"multiline": True,
|
||||||
|
"default": "",
|
||||||
|
"tooltip": "What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results."
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"creativity": (
|
||||||
|
IO.FLOAT,
|
||||||
|
{
|
||||||
|
"default": 0.3,
|
||||||
|
"min": 0.1,
|
||||||
|
"max": 0.5,
|
||||||
|
"step": 0.01,
|
||||||
|
"tooltip": "Controls the likelihood of creating additional details not heavily conditioned by the init image.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"style_preset": (get_stability_style_presets(),
|
||||||
|
{
|
||||||
|
"tooltip": "Optional desired style of generated image.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"seed": (
|
||||||
|
IO.INT,
|
||||||
|
{
|
||||||
|
"default": 0,
|
||||||
|
"min": 0,
|
||||||
|
"max": 4294967294,
|
||||||
|
"control_after_generate": True,
|
||||||
|
"tooltip": "The random seed used for creating the noise.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"negative_prompt": (
|
||||||
|
IO.STRING,
|
||||||
|
{
|
||||||
|
"default": "",
|
||||||
|
"forceInput": True,
|
||||||
|
"tooltip": "Keywords of what you do not wish to see in the output image. This is an advanced feature."
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def api_call(self, image: torch.Tensor, prompt: str, creativity: float, style_preset: str, seed: int, negative_prompt: str=None,
|
||||||
|
auth_token=None):
|
||||||
|
validate_string(prompt, strip_whitespace=False)
|
||||||
|
image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()
|
||||||
|
|
||||||
|
if not negative_prompt:
|
||||||
|
negative_prompt = None
|
||||||
|
if style_preset == "None":
|
||||||
|
style_preset = None
|
||||||
|
|
||||||
|
files = {
|
||||||
|
"image": image_binary
|
||||||
|
}
|
||||||
|
|
||||||
|
operation = SynchronousOperation(
|
||||||
|
endpoint=ApiEndpoint(
|
||||||
|
path="/proxy/stability/v2beta/stable-image/upscale/creative",
|
||||||
|
method=HttpMethod.POST,
|
||||||
|
request_model=StabilityUpscaleCreativeRequest,
|
||||||
|
response_model=StabilityAsyncResponse,
|
||||||
|
),
|
||||||
|
request=StabilityUpscaleCreativeRequest(
|
||||||
|
prompt=prompt,
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
|
creativity=round(creativity,2),
|
||||||
|
style_preset=style_preset,
|
||||||
|
seed=seed,
|
||||||
|
),
|
||||||
|
files=files,
|
||||||
|
content_type="multipart/form-data",
|
||||||
|
auth_token=auth_token,
|
||||||
|
)
|
||||||
|
response_api = operation.execute()
|
||||||
|
|
||||||
|
operation = PollingOperation(
|
||||||
|
poll_endpoint=ApiEndpoint(
|
||||||
|
path=f"/proxy/stability/v2beta/results/{response_api.id}",
|
||||||
|
method=HttpMethod.GET,
|
||||||
|
request_model=EmptyRequest,
|
||||||
|
response_model=StabilityResultsGetResponse,
|
||||||
|
),
|
||||||
|
poll_interval=3,
|
||||||
|
completed_statuses=[StabilityPollStatus.finished],
|
||||||
|
failed_statuses=[StabilityPollStatus.failed],
|
||||||
|
status_extractor=lambda x: get_async_dummy_status(x),
|
||||||
|
auth_token=auth_token,
|
||||||
|
)
|
||||||
|
response_poll: StabilityResultsGetResponse = operation.execute()
|
||||||
|
|
||||||
|
if response_poll.finish_reason != "SUCCESS":
|
||||||
|
raise Exception(f"Stability Upscale Creative generation failed: {response_poll.finish_reason}.")
|
||||||
|
|
||||||
|
image_data = base64.b64decode(response_poll.result)
|
||||||
|
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||||
|
|
||||||
|
return (returned_image,)
|
||||||
|
|
||||||
|
|
||||||
|
class StabilityUpscaleFastNode:
|
||||||
|
"""
|
||||||
|
Quickly upscales an image via Stability API call to 4x its original size; intended for upscaling low-quality/compressed images.
|
||||||
|
"""
|
||||||
|
|
||||||
|
RETURN_TYPES = (IO.IMAGE,)
|
||||||
|
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||||
|
FUNCTION = "api_call"
|
||||||
|
API_NODE = True
|
||||||
|
CATEGORY = "api node/image/Stability AI"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image": (IO.IMAGE,),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def api_call(self, image: torch.Tensor,
|
||||||
|
auth_token=None):
|
||||||
|
image_binary = tensor_to_bytesio(image, total_pixels=4096*4096).read()
|
||||||
|
|
||||||
|
files = {
|
||||||
|
"image": image_binary
|
||||||
|
}
|
||||||
|
|
||||||
|
operation = SynchronousOperation(
|
||||||
|
endpoint=ApiEndpoint(
|
||||||
|
path="/proxy/stability/v2beta/stable-image/upscale/fast",
|
||||||
|
method=HttpMethod.POST,
|
||||||
|
request_model=EmptyRequest,
|
||||||
|
response_model=StabilityStableUltraResponse,
|
||||||
|
),
|
||||||
|
request=EmptyRequest(),
|
||||||
|
files=files,
|
||||||
|
content_type="multipart/form-data",
|
||||||
|
auth_token=auth_token,
|
||||||
|
)
|
||||||
|
response_api = operation.execute()
|
||||||
|
|
||||||
|
if response_api.finish_reason != "SUCCESS":
|
||||||
|
raise Exception(f"Stability Upscale Fast failed: {response_api.finish_reason}.")
|
||||||
|
|
||||||
|
image_data = base64.b64decode(response_api.image)
|
||||||
|
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||||
|
|
||||||
|
return (returned_image,)
|
||||||
|
|
||||||
|
|
||||||
|
# A dictionary that contains all nodes you want to export with their names
|
||||||
|
# NOTE: names should be globally unique
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"StabilityStableImageUltraNode": StabilityStableImageUltraNode,
|
||||||
|
"StabilityStableImageSD_3_5Node": StabilityStableImageSD_3_5Node,
|
||||||
|
"StabilityUpscaleConservativeNode": StabilityUpscaleConservativeNode,
|
||||||
|
"StabilityUpscaleCreativeNode": StabilityUpscaleCreativeNode,
|
||||||
|
"StabilityUpscaleFastNode": StabilityUpscaleFastNode,
|
||||||
|
}
|
||||||
|
|
||||||
|
# A dictionary that contains the friendly/humanly readable titles for the nodes
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
|
"StabilityStableImageUltraNode": "Stability AI Stable Image Ultra",
|
||||||
|
"StabilityStableImageSD_3_5Node": "Stability AI Stable Diffusion 3.5 Image",
|
||||||
|
"StabilityUpscaleConservativeNode": "Stability AI Upscale Conservative",
|
||||||
|
"StabilityUpscaleCreativeNode": "Stability AI Upscale Creative",
|
||||||
|
"StabilityUpscaleFastNode": "Stability AI Upscale Fast",
|
||||||
|
}
|
||||||
283
comfy_api_nodes/nodes_veo2.py
Normal file
283
comfy_api_nodes/nodes_veo2.py
Normal file
@@ -0,0 +1,283 @@
|
|||||||
|
import io
|
||||||
|
import logging
|
||||||
|
import base64
|
||||||
|
import requests
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
|
||||||
|
from comfy_api.input_impl.video_types import VideoFromFile
|
||||||
|
from comfy_api_nodes.apis import (
|
||||||
|
Veo2GenVidRequest,
|
||||||
|
Veo2GenVidResponse,
|
||||||
|
Veo2GenVidPollRequest,
|
||||||
|
Veo2GenVidPollResponse
|
||||||
|
)
|
||||||
|
from comfy_api_nodes.apis.client import (
|
||||||
|
ApiEndpoint,
|
||||||
|
HttpMethod,
|
||||||
|
SynchronousOperation,
|
||||||
|
PollingOperation,
|
||||||
|
)
|
||||||
|
|
||||||
|
from comfy_api_nodes.apinode_utils import (
|
||||||
|
downscale_image_tensor,
|
||||||
|
tensor_to_base64_string
|
||||||
|
)
|
||||||
|
|
||||||
|
def convert_image_to_base64(image: torch.Tensor):
|
||||||
|
if image is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
scaled_image = downscale_image_tensor(image, total_pixels=2048*2048)
|
||||||
|
return tensor_to_base64_string(scaled_image)
|
||||||
|
|
||||||
|
class VeoVideoGenerationNode(ComfyNodeABC):
|
||||||
|
"""
|
||||||
|
Generates videos from text prompts using Google's Veo API.
|
||||||
|
|
||||||
|
This node can create videos from text descriptions and optional image inputs,
|
||||||
|
with control over parameters like aspect ratio, duration, and more.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"prompt": (
|
||||||
|
IO.STRING,
|
||||||
|
{
|
||||||
|
"multiline": True,
|
||||||
|
"default": "",
|
||||||
|
"tooltip": "Text description of the video",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"aspect_ratio": (
|
||||||
|
IO.COMBO,
|
||||||
|
{
|
||||||
|
"options": ["16:9", "9:16"],
|
||||||
|
"default": "16:9",
|
||||||
|
"tooltip": "Aspect ratio of the output video",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"negative_prompt": (
|
||||||
|
IO.STRING,
|
||||||
|
{
|
||||||
|
"multiline": True,
|
||||||
|
"default": "",
|
||||||
|
"tooltip": "Negative text prompt to guide what to avoid in the video",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"duration_seconds": (
|
||||||
|
IO.INT,
|
||||||
|
{
|
||||||
|
"default": 5,
|
||||||
|
"min": 5,
|
||||||
|
"max": 8,
|
||||||
|
"step": 1,
|
||||||
|
"display": "number",
|
||||||
|
"tooltip": "Duration of the output video in seconds",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"enhance_prompt": (
|
||||||
|
IO.BOOLEAN,
|
||||||
|
{
|
||||||
|
"default": True,
|
||||||
|
"tooltip": "Whether to enhance the prompt with AI assistance",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
"person_generation": (
|
||||||
|
IO.COMBO,
|
||||||
|
{
|
||||||
|
"options": ["ALLOW", "BLOCK"],
|
||||||
|
"default": "ALLOW",
|
||||||
|
"tooltip": "Whether to allow generating people in the video",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"seed": (
|
||||||
|
IO.INT,
|
||||||
|
{
|
||||||
|
"default": 0,
|
||||||
|
"min": 0,
|
||||||
|
"max": 0xFFFFFFFF,
|
||||||
|
"step": 1,
|
||||||
|
"display": "number",
|
||||||
|
"control_after_generate": True,
|
||||||
|
"tooltip": "Seed for video generation (0 for random)",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"image": (IO.IMAGE, {
|
||||||
|
"default": None,
|
||||||
|
"tooltip": "Optional reference image to guide video generation",
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = (IO.VIDEO,)
|
||||||
|
FUNCTION = "generate_video"
|
||||||
|
CATEGORY = "api node/video/Veo"
|
||||||
|
DESCRIPTION = "Generates videos from text prompts using Google's Veo API"
|
||||||
|
API_NODE = True
|
||||||
|
|
||||||
|
def generate_video(
|
||||||
|
self,
|
||||||
|
prompt,
|
||||||
|
aspect_ratio="16:9",
|
||||||
|
negative_prompt="",
|
||||||
|
duration_seconds=5,
|
||||||
|
enhance_prompt=True,
|
||||||
|
person_generation="ALLOW",
|
||||||
|
seed=0,
|
||||||
|
image=None,
|
||||||
|
auth_token=None,
|
||||||
|
):
|
||||||
|
# Prepare the instances for the request
|
||||||
|
instances = []
|
||||||
|
|
||||||
|
instance = {
|
||||||
|
"prompt": prompt
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add image if provided
|
||||||
|
if image is not None:
|
||||||
|
image_base64 = convert_image_to_base64(image)
|
||||||
|
if image_base64:
|
||||||
|
instance["image"] = {
|
||||||
|
"bytesBase64Encoded": image_base64,
|
||||||
|
"mimeType": "image/png"
|
||||||
|
}
|
||||||
|
|
||||||
|
instances.append(instance)
|
||||||
|
|
||||||
|
# Create parameters dictionary
|
||||||
|
parameters = {
|
||||||
|
"aspectRatio": aspect_ratio,
|
||||||
|
"personGeneration": person_generation,
|
||||||
|
"durationSeconds": duration_seconds,
|
||||||
|
"enhancePrompt": enhance_prompt,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add optional parameters if provided
|
||||||
|
if negative_prompt:
|
||||||
|
parameters["negativePrompt"] = negative_prompt
|
||||||
|
if seed > 0:
|
||||||
|
parameters["seed"] = seed
|
||||||
|
|
||||||
|
# Initial request to start video generation
|
||||||
|
initial_operation = SynchronousOperation(
|
||||||
|
endpoint=ApiEndpoint(
|
||||||
|
path="/proxy/veo/generate",
|
||||||
|
method=HttpMethod.POST,
|
||||||
|
request_model=Veo2GenVidRequest,
|
||||||
|
response_model=Veo2GenVidResponse
|
||||||
|
),
|
||||||
|
request=Veo2GenVidRequest(
|
||||||
|
instances=instances,
|
||||||
|
parameters=parameters
|
||||||
|
),
|
||||||
|
auth_token=auth_token
|
||||||
|
)
|
||||||
|
|
||||||
|
initial_response = initial_operation.execute()
|
||||||
|
operation_name = initial_response.name
|
||||||
|
|
||||||
|
logging.info(f"Veo generation started with operation name: {operation_name}")
|
||||||
|
|
||||||
|
# Define status extractor function
|
||||||
|
def status_extractor(response):
|
||||||
|
# Only return "completed" if the operation is done, regardless of success or failure
|
||||||
|
# We'll check for errors after polling completes
|
||||||
|
return "completed" if response.done else "pending"
|
||||||
|
|
||||||
|
# Define progress extractor function
|
||||||
|
def progress_extractor(response):
|
||||||
|
# Could be enhanced if the API provides progress information
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Define the polling operation
|
||||||
|
poll_operation = PollingOperation(
|
||||||
|
poll_endpoint=ApiEndpoint(
|
||||||
|
path="/proxy/veo/poll",
|
||||||
|
method=HttpMethod.POST,
|
||||||
|
request_model=Veo2GenVidPollRequest,
|
||||||
|
response_model=Veo2GenVidPollResponse
|
||||||
|
),
|
||||||
|
completed_statuses=["completed"],
|
||||||
|
failed_statuses=[], # No failed statuses, we'll handle errors after polling
|
||||||
|
status_extractor=status_extractor,
|
||||||
|
progress_extractor=progress_extractor,
|
||||||
|
request=Veo2GenVidPollRequest(
|
||||||
|
operationName=operation_name
|
||||||
|
),
|
||||||
|
auth_token=auth_token,
|
||||||
|
poll_interval=5.0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Execute the polling operation
|
||||||
|
poll_response = poll_operation.execute()
|
||||||
|
|
||||||
|
# Now check for errors in the final response
|
||||||
|
# Check for error in poll response
|
||||||
|
if hasattr(poll_response, 'error') and poll_response.error:
|
||||||
|
error_message = f"Veo API error: {poll_response.error.message} (code: {poll_response.error.code})"
|
||||||
|
logging.error(error_message)
|
||||||
|
raise Exception(error_message)
|
||||||
|
|
||||||
|
# Check for RAI filtered content
|
||||||
|
if (hasattr(poll_response.response, 'raiMediaFilteredCount') and
|
||||||
|
poll_response.response.raiMediaFilteredCount > 0):
|
||||||
|
|
||||||
|
# Extract reason message if available
|
||||||
|
if (hasattr(poll_response.response, 'raiMediaFilteredReasons') and
|
||||||
|
poll_response.response.raiMediaFilteredReasons):
|
||||||
|
reason = poll_response.response.raiMediaFilteredReasons[0]
|
||||||
|
error_message = f"Content filtered by Google's Responsible AI practices: {reason} ({poll_response.response.raiMediaFilteredCount} videos filtered.)"
|
||||||
|
else:
|
||||||
|
error_message = f"Content filtered by Google's Responsible AI practices ({poll_response.response.raiMediaFilteredCount} videos filtered.)"
|
||||||
|
|
||||||
|
logging.error(error_message)
|
||||||
|
raise Exception(error_message)
|
||||||
|
|
||||||
|
# Extract video data
|
||||||
|
video_data = None
|
||||||
|
if poll_response.response and hasattr(poll_response.response, 'videos') and poll_response.response.videos and len(poll_response.response.videos) > 0:
|
||||||
|
video = poll_response.response.videos[0]
|
||||||
|
|
||||||
|
# Check if video is provided as base64 or URL
|
||||||
|
if hasattr(video, 'bytesBase64Encoded') and video.bytesBase64Encoded:
|
||||||
|
# Decode base64 string to bytes
|
||||||
|
video_data = base64.b64decode(video.bytesBase64Encoded)
|
||||||
|
elif hasattr(video, 'gcsUri') and video.gcsUri:
|
||||||
|
# Download from URL
|
||||||
|
video_url = video.gcsUri
|
||||||
|
video_response = requests.get(video_url)
|
||||||
|
video_data = video_response.content
|
||||||
|
else:
|
||||||
|
raise Exception("Video returned but no data or URL was provided")
|
||||||
|
else:
|
||||||
|
raise Exception("Video generation completed but no video was returned")
|
||||||
|
|
||||||
|
if not video_data:
|
||||||
|
raise Exception("No video data was returned")
|
||||||
|
|
||||||
|
logging.info("Video generation completed successfully")
|
||||||
|
|
||||||
|
# Convert video data to BytesIO object
|
||||||
|
video_io = io.BytesIO(video_data)
|
||||||
|
|
||||||
|
# Return VideoFromFile object
|
||||||
|
return (VideoFromFile(video_io),)
|
||||||
|
|
||||||
|
|
||||||
|
# Register the node
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"VeoVideoGenerationNode": VeoVideoGenerationNode,
|
||||||
|
}
|
||||||
|
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
|
"VeoVideoGenerationNode": "Google Veo2 Video Generation",
|
||||||
|
}
|
||||||
10
comfy_api_nodes/redocly-dev.yaml
Normal file
10
comfy_api_nodes/redocly-dev.yaml
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
# This file is used to filter the Comfy Org OpenAPI spec for schemas related to API Nodes.
|
||||||
|
# This is used for development purposes to generate stubs for unreleased API endpoints.
|
||||||
|
apis:
|
||||||
|
filter:
|
||||||
|
root: openapi.yaml
|
||||||
|
decorators:
|
||||||
|
filter-in:
|
||||||
|
property: tags
|
||||||
|
value: ['API Nodes']
|
||||||
|
matchStrategy: all
|
||||||
10
comfy_api_nodes/redocly.yaml
Normal file
10
comfy_api_nodes/redocly.yaml
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
# This file is used to filter the Comfy Org OpenAPI spec for schemas related to API Nodes.
|
||||||
|
|
||||||
|
apis:
|
||||||
|
filter:
|
||||||
|
root: openapi.yaml
|
||||||
|
decorators:
|
||||||
|
filter-in:
|
||||||
|
property: tags
|
||||||
|
value: ['API Nodes', 'Released']
|
||||||
|
matchStrategy: all
|
||||||
@@ -20,6 +20,29 @@ class CLIPTextEncodeControlnet:
|
|||||||
c.append(n)
|
c.append(n)
|
||||||
return (c, )
|
return (c, )
|
||||||
|
|
||||||
|
class T5TokenizerOptions:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"clip": ("CLIP", ),
|
||||||
|
"min_padding": ("INT", {"default": 0, "min": 0, "max": 10000, "step": 1}),
|
||||||
|
"min_length": ("INT", {"default": 0, "min": 0, "max": 10000, "step": 1}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("CLIP",)
|
||||||
|
FUNCTION = "set_options"
|
||||||
|
|
||||||
|
def set_options(self, clip, min_padding, min_length):
|
||||||
|
clip = clip.clone()
|
||||||
|
for t5_type in ["t5xxl", "pile_t5xl", "t5base", "mt5xl", "umt5xxl"]:
|
||||||
|
clip.set_tokenizer_option("{}_min_padding".format(t5_type), min_padding)
|
||||||
|
clip.set_tokenizer_option("{}_min_length".format(t5_type), min_length)
|
||||||
|
|
||||||
|
return (clip, )
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"CLIPTextEncodeControlnet": CLIPTextEncodeControlnet
|
"CLIPTextEncodeControlnet": CLIPTextEncodeControlnet,
|
||||||
|
"T5TokenizerOptions": T5TokenizerOptions,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import math
|
||||||
import comfy.samplers
|
import comfy.samplers
|
||||||
import comfy.sample
|
import comfy.sample
|
||||||
from comfy.k_diffusion import sampling as k_diffusion_sampling
|
from comfy.k_diffusion import sampling as k_diffusion_sampling
|
||||||
@@ -249,6 +250,55 @@ class SetFirstSigma:
|
|||||||
sigmas[0] = sigma
|
sigmas[0] = sigma
|
||||||
return (sigmas, )
|
return (sigmas, )
|
||||||
|
|
||||||
|
class ExtendIntermediateSigmas:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required":
|
||||||
|
{"sigmas": ("SIGMAS", ),
|
||||||
|
"steps": ("INT", {"default": 2, "min": 1, "max": 100}),
|
||||||
|
"start_at_sigma": ("FLOAT", {"default": -1.0, "min": -1.0, "max": 20000.0, "step": 0.01, "round": False}),
|
||||||
|
"end_at_sigma": ("FLOAT", {"default": 12.0, "min": 0.0, "max": 20000.0, "step": 0.01, "round": False}),
|
||||||
|
"spacing": (['linear', 'cosine', 'sine'],),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
RETURN_TYPES = ("SIGMAS",)
|
||||||
|
CATEGORY = "sampling/custom_sampling/sigmas"
|
||||||
|
|
||||||
|
FUNCTION = "extend"
|
||||||
|
|
||||||
|
def extend(self, sigmas: torch.Tensor, steps: int, start_at_sigma: float, end_at_sigma: float, spacing: str):
|
||||||
|
if start_at_sigma < 0:
|
||||||
|
start_at_sigma = float("inf")
|
||||||
|
|
||||||
|
interpolator = {
|
||||||
|
'linear': lambda x: x,
|
||||||
|
'cosine': lambda x: torch.sin(x*math.pi/2),
|
||||||
|
'sine': lambda x: 1 - torch.cos(x*math.pi/2)
|
||||||
|
}[spacing]
|
||||||
|
|
||||||
|
# linear space for our interpolation function
|
||||||
|
x = torch.linspace(0, 1, steps + 1, device=sigmas.device)[1:-1]
|
||||||
|
computed_spacing = interpolator(x)
|
||||||
|
|
||||||
|
extended_sigmas = []
|
||||||
|
for i in range(len(sigmas) - 1):
|
||||||
|
sigma_current = sigmas[i]
|
||||||
|
sigma_next = sigmas[i+1]
|
||||||
|
|
||||||
|
extended_sigmas.append(sigma_current)
|
||||||
|
|
||||||
|
if end_at_sigma <= sigma_current <= start_at_sigma:
|
||||||
|
interpolated_steps = computed_spacing * (sigma_next - sigma_current) + sigma_current
|
||||||
|
extended_sigmas.extend(interpolated_steps.tolist())
|
||||||
|
|
||||||
|
# Add the last sigma value
|
||||||
|
if len(sigmas) > 0:
|
||||||
|
extended_sigmas.append(sigmas[-1])
|
||||||
|
|
||||||
|
extended_sigmas = torch.FloatTensor(extended_sigmas)
|
||||||
|
|
||||||
|
return (extended_sigmas,)
|
||||||
|
|
||||||
class KSamplerSelect:
|
class KSamplerSelect:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
@@ -735,6 +785,7 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"SplitSigmasDenoise": SplitSigmasDenoise,
|
"SplitSigmasDenoise": SplitSigmasDenoise,
|
||||||
"FlipSigmas": FlipSigmas,
|
"FlipSigmas": FlipSigmas,
|
||||||
"SetFirstSigma": SetFirstSigma,
|
"SetFirstSigma": SetFirstSigma,
|
||||||
|
"ExtendIntermediateSigmas": ExtendIntermediateSigmas,
|
||||||
|
|
||||||
"CFGGuider": CFGGuider,
|
"CFGGuider": CFGGuider,
|
||||||
"DualCFGGuider": DualCFGGuider,
|
"DualCFGGuider": DualCFGGuider,
|
||||||
|
|||||||
@@ -26,7 +26,30 @@ class QuadrupleCLIPLoader:
|
|||||||
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3, clip_path4], embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3, clip_path4], embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||||
return (clip,)
|
return (clip,)
|
||||||
|
|
||||||
|
class CLIPTextEncodeHiDream:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": {
|
||||||
|
"clip": ("CLIP", ),
|
||||||
|
"clip_l": ("STRING", {"multiline": True, "dynamicPrompts": True}),
|
||||||
|
"clip_g": ("STRING", {"multiline": True, "dynamicPrompts": True}),
|
||||||
|
"t5xxl": ("STRING", {"multiline": True, "dynamicPrompts": True}),
|
||||||
|
"llama": ("STRING", {"multiline": True, "dynamicPrompts": True})
|
||||||
|
}}
|
||||||
|
RETURN_TYPES = ("CONDITIONING",)
|
||||||
|
FUNCTION = "encode"
|
||||||
|
|
||||||
|
CATEGORY = "advanced/conditioning"
|
||||||
|
|
||||||
|
def encode(self, clip, clip_l, clip_g, t5xxl, llama):
|
||||||
|
|
||||||
|
tokens = clip.tokenize(clip_g)
|
||||||
|
tokens["l"] = clip.tokenize(clip_l)["l"]
|
||||||
|
tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]
|
||||||
|
tokens["llama"] = clip.tokenize(llama)["llama"]
|
||||||
|
return (clip.encode_from_tokens_scheduled(tokens), )
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"QuadrupleCLIPLoader": QuadrupleCLIPLoader,
|
"QuadrupleCLIPLoader": QuadrupleCLIPLoader,
|
||||||
|
"CLIPTextEncodeHiDream": CLIPTextEncodeHiDream,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ class LTXVImgToVideo:
|
|||||||
"height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
|
"height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
|
||||||
"length": ("INT", {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8}),
|
"length": ("INT", {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8}),
|
||||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||||
|
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0}),
|
||||||
}}
|
}}
|
||||||
|
|
||||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
|
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
|
||||||
@@ -46,7 +47,7 @@ class LTXVImgToVideo:
|
|||||||
CATEGORY = "conditioning/video_models"
|
CATEGORY = "conditioning/video_models"
|
||||||
FUNCTION = "generate"
|
FUNCTION = "generate"
|
||||||
|
|
||||||
def generate(self, positive, negative, image, vae, width, height, length, batch_size):
|
def generate(self, positive, negative, image, vae, width, height, length, batch_size, strength):
|
||||||
pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||||
encode_pixels = pixels[:, :, :, :3]
|
encode_pixels = pixels[:, :, :, :3]
|
||||||
t = vae.encode(encode_pixels)
|
t = vae.encode(encode_pixels)
|
||||||
@@ -59,7 +60,7 @@ class LTXVImgToVideo:
|
|||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
device=latent.device,
|
device=latent.device,
|
||||||
)
|
)
|
||||||
conditioning_latent_frames_mask[:, :, :t.shape[2]] = 0
|
conditioning_latent_frames_mask[:, :, :t.shape[2]] = 1.0 - strength
|
||||||
|
|
||||||
return (positive, negative, {"samples": latent, "noise_mask": conditioning_latent_frames_mask}, )
|
return (positive, negative, {"samples": latent, "noise_mask": conditioning_latent_frames_mask}, )
|
||||||
|
|
||||||
@@ -152,6 +153,15 @@ class LTXVAddGuide:
|
|||||||
return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs})
|
return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs})
|
||||||
|
|
||||||
def append_keyframe(self, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors):
|
def append_keyframe(self, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors):
|
||||||
|
_, latent_idx = self.get_latent_index(
|
||||||
|
cond=positive,
|
||||||
|
latent_length=latent_image.shape[2],
|
||||||
|
guide_length=guiding_latent.shape[2],
|
||||||
|
frame_idx=frame_idx,
|
||||||
|
scale_factors=scale_factors,
|
||||||
|
)
|
||||||
|
noise_mask[:, :, latent_idx:latent_idx + guiding_latent.shape[2]] = 1.0
|
||||||
|
|
||||||
positive = self.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors)
|
positive = self.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors)
|
||||||
negative = self.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors)
|
negative = self.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors)
|
||||||
|
|
||||||
@@ -385,7 +395,7 @@ def encode_single_frame(output_file, image_array: np.ndarray, crf):
|
|||||||
container = av.open(output_file, "w", format="mp4")
|
container = av.open(output_file, "w", format="mp4")
|
||||||
try:
|
try:
|
||||||
stream = container.add_stream(
|
stream = container.add_stream(
|
||||||
"h264", rate=1, options={"crf": str(crf), "preset": "veryfast"}
|
"libx264", rate=1, options={"crf": str(crf), "preset": "veryfast"}
|
||||||
)
|
)
|
||||||
stream.height = image_array.shape[0]
|
stream.height = image_array.shape[0]
|
||||||
stream.width = image_array.shape[1]
|
stream.width = image_array.shape[1]
|
||||||
|
|||||||
@@ -3,7 +3,10 @@ import scipy.ndimage
|
|||||||
import torch
|
import torch
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import node_helpers
|
import node_helpers
|
||||||
|
import folder_paths
|
||||||
|
import random
|
||||||
|
|
||||||
|
import nodes
|
||||||
from nodes import MAX_RESOLUTION
|
from nodes import MAX_RESOLUTION
|
||||||
|
|
||||||
def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False):
|
def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False):
|
||||||
@@ -362,6 +365,30 @@ class ThresholdMask:
|
|||||||
mask = (mask > value).float()
|
mask = (mask > value).float()
|
||||||
return (mask,)
|
return (mask,)
|
||||||
|
|
||||||
|
# Mask Preview - original implement from
|
||||||
|
# https://github.com/cubiq/ComfyUI_essentials/blob/9d9f4bedfc9f0321c19faf71855e228c93bd0dc9/mask.py#L81
|
||||||
|
# upstream requested in https://github.com/Kosinkadink/rfcs/blob/main/rfcs/0000-corenodes.md#preview-nodes
|
||||||
|
class MaskPreview(nodes.SaveImage):
|
||||||
|
def __init__(self):
|
||||||
|
self.output_dir = folder_paths.get_temp_directory()
|
||||||
|
self.type = "temp"
|
||||||
|
self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5))
|
||||||
|
self.compress_level = 4
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {"mask": ("MASK",), },
|
||||||
|
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
||||||
|
}
|
||||||
|
|
||||||
|
FUNCTION = "execute"
|
||||||
|
CATEGORY = "mask"
|
||||||
|
|
||||||
|
def execute(self, mask, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
|
||||||
|
preview = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
|
||||||
|
return self.save_images(preview, filename_prefix, prompt, extra_pnginfo)
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"LatentCompositeMasked": LatentCompositeMasked,
|
"LatentCompositeMasked": LatentCompositeMasked,
|
||||||
@@ -376,6 +403,7 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"FeatherMask": FeatherMask,
|
"FeatherMask": FeatherMask,
|
||||||
"GrowMask": GrowMask,
|
"GrowMask": GrowMask,
|
||||||
"ThresholdMask": ThresholdMask,
|
"ThresholdMask": ThresholdMask,
|
||||||
|
"MaskPreview": MaskPreview
|
||||||
}
|
}
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
|
|||||||
@@ -209,6 +209,9 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi
|
|||||||
metadata["modelspec.predict_key"] = "epsilon"
|
metadata["modelspec.predict_key"] = "epsilon"
|
||||||
elif model.model.model_type == comfy.model_base.ModelType.V_PREDICTION:
|
elif model.model.model_type == comfy.model_base.ModelType.V_PREDICTION:
|
||||||
metadata["modelspec.predict_key"] = "v"
|
metadata["modelspec.predict_key"] = "v"
|
||||||
|
extra_keys["v_pred"] = torch.tensor([])
|
||||||
|
if getattr(model_sampling, "zsnr", False):
|
||||||
|
extra_keys["ztsnr"] = torch.tensor([])
|
||||||
|
|
||||||
if not args.disable_metadata:
|
if not args.disable_metadata:
|
||||||
metadata["prompt"] = prompt_info
|
metadata["prompt"] = prompt_info
|
||||||
@@ -273,7 +276,7 @@ class CLIPSave:
|
|||||||
comfy.model_management.load_models_gpu([clip.load_model()], force_patch_weights=True)
|
comfy.model_management.load_models_gpu([clip.load_model()], force_patch_weights=True)
|
||||||
clip_sd = clip.get_sd()
|
clip_sd = clip.get_sd()
|
||||||
|
|
||||||
for prefix in ["clip_l.", "clip_g.", ""]:
|
for prefix in ["clip_l.", "clip_g.", "clip_h.", "t5xxl.", "pile_t5xl.", "mt5xl.", "umt5xxl.", "t5base.", "gemma2_2b.", "llama.", "hydit_clip.", ""]:
|
||||||
k = list(filter(lambda a: a.startswith(prefix), clip_sd.keys()))
|
k = list(filter(lambda a: a.startswith(prefix), clip_sd.keys()))
|
||||||
current_clip_sd = {}
|
current_clip_sd = {}
|
||||||
for x in k:
|
for x in k:
|
||||||
|
|||||||
@@ -20,13 +20,14 @@ def loglinear_interp(t_steps, num_steps):
|
|||||||
|
|
||||||
NOISE_LEVELS = {"FLUX": [0.9968, 0.9886, 0.9819, 0.975, 0.966, 0.9471, 0.9158, 0.8287, 0.5512, 0.2808, 0.001],
|
NOISE_LEVELS = {"FLUX": [0.9968, 0.9886, 0.9819, 0.975, 0.966, 0.9471, 0.9158, 0.8287, 0.5512, 0.2808, 0.001],
|
||||||
"Wan":[1.0, 0.997, 0.995, 0.993, 0.991, 0.989, 0.987, 0.985, 0.98, 0.975, 0.973, 0.968, 0.96, 0.946, 0.927, 0.902, 0.864, 0.776, 0.539, 0.208, 0.001],
|
"Wan":[1.0, 0.997, 0.995, 0.993, 0.991, 0.989, 0.987, 0.985, 0.98, 0.975, 0.973, 0.968, 0.96, 0.946, 0.927, 0.902, 0.864, 0.776, 0.539, 0.208, 0.001],
|
||||||
|
"Chroma": [0.992, 0.99, 0.988, 0.985, 0.982, 0.978, 0.973, 0.968, 0.961, 0.953, 0.943, 0.931, 0.917, 0.9, 0.881, 0.858, 0.832, 0.802, 0.769, 0.731, 0.69, 0.646, 0.599, 0.55, 0.501, 0.451, 0.402, 0.355, 0.311, 0.27, 0.232, 0.199, 0.169, 0.143, 0.12, 0.101, 0.084, 0.07, 0.058, 0.048, 0.001],
|
||||||
}
|
}
|
||||||
|
|
||||||
class OptimalStepsScheduler:
|
class OptimalStepsScheduler:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required":
|
return {"required":
|
||||||
{"model_type": (["FLUX", "Wan"], ),
|
{"model_type": (["FLUX", "Wan", "Chroma"], ),
|
||||||
"steps": ("INT", {"default": 20, "min": 3, "max": 1000}),
|
"steps": ("INT", {"default": 20, "min": 3, "max": 1000}),
|
||||||
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -141,6 +141,7 @@ class Quantize:
|
|||||||
|
|
||||||
CATEGORY = "image/postprocessing"
|
CATEGORY = "image/postprocessing"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
def bayer(im, pal_im, order):
|
def bayer(im, pal_im, order):
|
||||||
def normalized_bayer_matrix(n):
|
def normalized_bayer_matrix(n):
|
||||||
if n == 0:
|
if n == 0:
|
||||||
|
|||||||
43
comfy_extras/nodes_preview_any.py
Normal file
43
comfy_extras/nodes_preview_any.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
import json
|
||||||
|
from comfy.comfy_types.node_typing import IO
|
||||||
|
|
||||||
|
# Preview Any - original implement from
|
||||||
|
# https://github.com/rgthree/rgthree-comfy/blob/main/py/display_any.py
|
||||||
|
# upstream requested in https://github.com/Kosinkadink/rfcs/blob/main/rfcs/0000-corenodes.md#preview-nodes
|
||||||
|
class PreviewAny():
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {"source": (IO.ANY, {})},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ()
|
||||||
|
FUNCTION = "main"
|
||||||
|
OUTPUT_NODE = True
|
||||||
|
|
||||||
|
CATEGORY = "utils"
|
||||||
|
|
||||||
|
def main(self, source=None):
|
||||||
|
value = 'None'
|
||||||
|
if isinstance(source, str):
|
||||||
|
value = source
|
||||||
|
elif isinstance(source, (int, float, bool)):
|
||||||
|
value = str(source)
|
||||||
|
elif source is not None:
|
||||||
|
try:
|
||||||
|
value = json.dumps(source)
|
||||||
|
except Exception:
|
||||||
|
try:
|
||||||
|
value = str(source)
|
||||||
|
except Exception:
|
||||||
|
value = 'source exists, but could not be serialized.'
|
||||||
|
|
||||||
|
return {"ui": {"text": (value,)}}
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"PreviewAny": PreviewAny,
|
||||||
|
}
|
||||||
|
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
|
"PreviewAny": "Preview Any",
|
||||||
|
}
|
||||||
@@ -1,6 +1,8 @@
|
|||||||
# Primitive nodes that are evaluated at backend.
|
# Primitive nodes that are evaluated at backend.
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import sys
|
||||||
|
|
||||||
from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, IO
|
from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, IO
|
||||||
|
|
||||||
|
|
||||||
@@ -19,11 +21,26 @@ class String(ComfyNodeABC):
|
|||||||
return (value,)
|
return (value,)
|
||||||
|
|
||||||
|
|
||||||
|
class StringMultiline(ComfyNodeABC):
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||||
|
return {
|
||||||
|
"required": {"value": (IO.STRING, {"multiline": True,},)},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = (IO.STRING,)
|
||||||
|
FUNCTION = "execute"
|
||||||
|
CATEGORY = "utils/primitive"
|
||||||
|
|
||||||
|
def execute(self, value: str) -> tuple[str]:
|
||||||
|
return (value,)
|
||||||
|
|
||||||
|
|
||||||
class Int(ComfyNodeABC):
|
class Int(ComfyNodeABC):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||||
return {
|
return {
|
||||||
"required": {"value": (IO.INT, {"control_after_generate": True})},
|
"required": {"value": (IO.INT, {"min": -sys.maxsize, "max": sys.maxsize, "control_after_generate": True})},
|
||||||
}
|
}
|
||||||
|
|
||||||
RETURN_TYPES = (IO.INT,)
|
RETURN_TYPES = (IO.INT,)
|
||||||
@@ -38,7 +55,7 @@ class Float(ComfyNodeABC):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||||
return {
|
return {
|
||||||
"required": {"value": (IO.FLOAT, {})},
|
"required": {"value": (IO.FLOAT, {"min": -sys.maxsize, "max": sys.maxsize})},
|
||||||
}
|
}
|
||||||
|
|
||||||
RETURN_TYPES = (IO.FLOAT,)
|
RETURN_TYPES = (IO.FLOAT,)
|
||||||
@@ -66,6 +83,7 @@ class Boolean(ComfyNodeABC):
|
|||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"PrimitiveString": String,
|
"PrimitiveString": String,
|
||||||
|
"PrimitiveStringMultiline": StringMultiline,
|
||||||
"PrimitiveInt": Int,
|
"PrimitiveInt": Int,
|
||||||
"PrimitiveFloat": Float,
|
"PrimitiveFloat": Float,
|
||||||
"PrimitiveBoolean": Boolean,
|
"PrimitiveBoolean": Boolean,
|
||||||
@@ -73,6 +91,7 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
"PrimitiveString": "String",
|
"PrimitiveString": "String",
|
||||||
|
"PrimitiveStringMultiline": "String (Multiline)",
|
||||||
"PrimitiveInt": "Int",
|
"PrimitiveInt": "Int",
|
||||||
"PrimitiveFloat": "Float",
|
"PrimitiveFloat": "Float",
|
||||||
"PrimitiveBoolean": "Boolean",
|
"PrimitiveBoolean": "Boolean",
|
||||||
|
|||||||
@@ -5,9 +5,13 @@ import av
|
|||||||
import torch
|
import torch
|
||||||
import folder_paths
|
import folder_paths
|
||||||
import json
|
import json
|
||||||
|
from typing import Optional, Literal
|
||||||
from fractions import Fraction
|
from fractions import Fraction
|
||||||
from comfy.comfy_types import FileLocator
|
from comfy.comfy_types import IO, FileLocator, ComfyNodeABC
|
||||||
|
from comfy_api.input import ImageInput, AudioInput, VideoInput
|
||||||
|
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
|
||||||
|
from comfy_api.input_impl import VideoFromFile, VideoFromComponents
|
||||||
|
from comfy.cli_args import args
|
||||||
|
|
||||||
class SaveWEBM:
|
class SaveWEBM:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -50,13 +54,15 @@ class SaveWEBM:
|
|||||||
for x in extra_pnginfo:
|
for x in extra_pnginfo:
|
||||||
container.metadata[x] = json.dumps(extra_pnginfo[x])
|
container.metadata[x] = json.dumps(extra_pnginfo[x])
|
||||||
|
|
||||||
codec_map = {"vp9": "libvpx-vp9", "av1": "libaom-av1"}
|
codec_map = {"vp9": "libvpx-vp9", "av1": "libsvtav1"}
|
||||||
stream = container.add_stream(codec_map[codec], rate=Fraction(round(fps * 1000), 1000))
|
stream = container.add_stream(codec_map[codec], rate=Fraction(round(fps * 1000), 1000))
|
||||||
stream.width = images.shape[-2]
|
stream.width = images.shape[-2]
|
||||||
stream.height = images.shape[-3]
|
stream.height = images.shape[-3]
|
||||||
stream.pix_fmt = "yuv420p"
|
stream.pix_fmt = "yuv420p10le" if codec == "av1" else "yuv420p"
|
||||||
stream.bit_rate = 0
|
stream.bit_rate = 0
|
||||||
stream.options = {'crf': str(crf)}
|
stream.options = {'crf': str(crf)}
|
||||||
|
if codec == "av1":
|
||||||
|
stream.options["preset"] = "6"
|
||||||
|
|
||||||
for frame in images:
|
for frame in images:
|
||||||
frame = av.VideoFrame.from_ndarray(torch.clamp(frame[..., :3] * 255, min=0, max=255).to(device=torch.device("cpu"), dtype=torch.uint8).numpy(), format="rgb24")
|
frame = av.VideoFrame.from_ndarray(torch.clamp(frame[..., :3] * 255, min=0, max=255).to(device=torch.device("cpu"), dtype=torch.uint8).numpy(), format="rgb24")
|
||||||
@@ -73,7 +79,163 @@ class SaveWEBM:
|
|||||||
|
|
||||||
return {"ui": {"images": results, "animated": (True,)}} # TODO: frontend side
|
return {"ui": {"images": results, "animated": (True,)}} # TODO: frontend side
|
||||||
|
|
||||||
|
class SaveVideo(ComfyNodeABC):
|
||||||
|
def __init__(self):
|
||||||
|
self.output_dir = folder_paths.get_output_directory()
|
||||||
|
self.type: Literal["output"] = "output"
|
||||||
|
self.prefix_append = ""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"video": (IO.VIDEO, {"tooltip": "The video to save."}),
|
||||||
|
"filename_prefix": ("STRING", {"default": "video/ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."}),
|
||||||
|
"format": (VideoContainer.as_input(), {"default": "auto", "tooltip": "The format to save the video as."}),
|
||||||
|
"codec": (VideoCodec.as_input(), {"default": "auto", "tooltip": "The codec to use for the video."}),
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"prompt": "PROMPT",
|
||||||
|
"extra_pnginfo": "EXTRA_PNGINFO"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ()
|
||||||
|
FUNCTION = "save_video"
|
||||||
|
|
||||||
|
OUTPUT_NODE = True
|
||||||
|
|
||||||
|
CATEGORY = "image/video"
|
||||||
|
DESCRIPTION = "Saves the input images to your ComfyUI output directory."
|
||||||
|
|
||||||
|
def save_video(self, video: VideoInput, filename_prefix, format, codec, prompt=None, extra_pnginfo=None):
|
||||||
|
filename_prefix += self.prefix_append
|
||||||
|
width, height = video.get_dimensions()
|
||||||
|
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
|
||||||
|
filename_prefix,
|
||||||
|
self.output_dir,
|
||||||
|
width,
|
||||||
|
height
|
||||||
|
)
|
||||||
|
results: list[FileLocator] = list()
|
||||||
|
saved_metadata = None
|
||||||
|
if not args.disable_metadata:
|
||||||
|
metadata = {}
|
||||||
|
if extra_pnginfo is not None:
|
||||||
|
metadata.update(extra_pnginfo)
|
||||||
|
if prompt is not None:
|
||||||
|
metadata["prompt"] = prompt
|
||||||
|
if len(metadata) > 0:
|
||||||
|
saved_metadata = metadata
|
||||||
|
file = f"{filename}_{counter:05}_.{VideoContainer.get_extension(format)}"
|
||||||
|
video.save_to(
|
||||||
|
os.path.join(full_output_folder, file),
|
||||||
|
format=format,
|
||||||
|
codec=codec,
|
||||||
|
metadata=saved_metadata
|
||||||
|
)
|
||||||
|
|
||||||
|
results.append({
|
||||||
|
"filename": file,
|
||||||
|
"subfolder": subfolder,
|
||||||
|
"type": self.type
|
||||||
|
})
|
||||||
|
counter += 1
|
||||||
|
|
||||||
|
return { "ui": { "images": results, "animated": (True,) } }
|
||||||
|
|
||||||
|
class CreateVideo(ComfyNodeABC):
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"images": (IO.IMAGE, {"tooltip": "The images to create a video from."}),
|
||||||
|
"fps": ("FLOAT", {"default": 30.0, "min": 1.0, "max": 120.0, "step": 1.0}),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"audio": (IO.AUDIO, {"tooltip": "The audio to add to the video."}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = (IO.VIDEO,)
|
||||||
|
FUNCTION = "create_video"
|
||||||
|
|
||||||
|
CATEGORY = "image/video"
|
||||||
|
DESCRIPTION = "Create a video from images."
|
||||||
|
|
||||||
|
def create_video(self, images: ImageInput, fps: float, audio: Optional[AudioInput] = None):
|
||||||
|
return (VideoFromComponents(
|
||||||
|
VideoComponents(
|
||||||
|
images=images,
|
||||||
|
audio=audio,
|
||||||
|
frame_rate=Fraction(fps),
|
||||||
|
)
|
||||||
|
),)
|
||||||
|
|
||||||
|
class GetVideoComponents(ComfyNodeABC):
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"video": (IO.VIDEO, {"tooltip": "The video to extract components from."}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
RETURN_TYPES = (IO.IMAGE, IO.AUDIO, IO.FLOAT)
|
||||||
|
RETURN_NAMES = ("images", "audio", "fps")
|
||||||
|
FUNCTION = "get_components"
|
||||||
|
|
||||||
|
CATEGORY = "image/video"
|
||||||
|
DESCRIPTION = "Extracts all components from a video: frames, audio, and framerate."
|
||||||
|
|
||||||
|
def get_components(self, video: VideoInput):
|
||||||
|
components = video.get_components()
|
||||||
|
|
||||||
|
return (components.images, components.audio, float(components.frame_rate))
|
||||||
|
|
||||||
|
class LoadVideo(ComfyNodeABC):
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
input_dir = folder_paths.get_input_directory()
|
||||||
|
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
|
||||||
|
files = folder_paths.filter_files_content_types(files, ["video"])
|
||||||
|
return {"required":
|
||||||
|
{"file": (sorted(files), {"video_upload": True})},
|
||||||
|
}
|
||||||
|
|
||||||
|
CATEGORY = "image/video"
|
||||||
|
|
||||||
|
RETURN_TYPES = (IO.VIDEO,)
|
||||||
|
FUNCTION = "load_video"
|
||||||
|
def load_video(self, file):
|
||||||
|
video_path = folder_paths.get_annotated_filepath(file)
|
||||||
|
return (VideoFromFile(video_path),)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def IS_CHANGED(cls, file):
|
||||||
|
video_path = folder_paths.get_annotated_filepath(file)
|
||||||
|
mod_time = os.path.getmtime(video_path)
|
||||||
|
# Instead of hashing the file, we can just use the modification time to avoid
|
||||||
|
# rehashing large files.
|
||||||
|
return mod_time
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def VALIDATE_INPUTS(cls, file):
|
||||||
|
if not folder_paths.exists_annotated_filepath(file):
|
||||||
|
return "Invalid video file: {}".format(file)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"SaveWEBM": SaveWEBM,
|
"SaveWEBM": SaveWEBM,
|
||||||
|
"SaveVideo": SaveVideo,
|
||||||
|
"CreateVideo": CreateVideo,
|
||||||
|
"GetVideoComponents": GetVideoComponents,
|
||||||
|
"LoadVideo": LoadVideo,
|
||||||
|
}
|
||||||
|
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
|
"SaveVideo": "Save Video",
|
||||||
|
"CreateVideo": "Create Video",
|
||||||
|
"GetVideoComponents": "Get Video Components",
|
||||||
|
"LoadVideo": "Load Video",
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -193,9 +193,116 @@ class WanFunInpaintToVideo:
|
|||||||
return flfv.encode(positive, negative, vae, width, height, length, batch_size, start_image=start_image, end_image=end_image, clip_vision_start_image=clip_vision_output)
|
return flfv.encode(positive, negative, vae, width, height, length, batch_size, start_image=start_image, end_image=end_image, clip_vision_start_image=clip_vision_output)
|
||||||
|
|
||||||
|
|
||||||
|
class WanVaceToVideo:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": {"positive": ("CONDITIONING", ),
|
||||||
|
"negative": ("CONDITIONING", ),
|
||||||
|
"vae": ("VAE", ),
|
||||||
|
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||||
|
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||||
|
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
|
||||||
|
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||||
|
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
|
||||||
|
},
|
||||||
|
"optional": {"control_video": ("IMAGE", ),
|
||||||
|
"control_masks": ("MASK", ),
|
||||||
|
"reference_image": ("IMAGE", ),
|
||||||
|
}}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT", "INT")
|
||||||
|
RETURN_NAMES = ("positive", "negative", "latent", "trim_latent")
|
||||||
|
FUNCTION = "encode"
|
||||||
|
|
||||||
|
CATEGORY = "conditioning/video_models"
|
||||||
|
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
|
||||||
|
def encode(self, positive, negative, vae, width, height, length, batch_size, strength, control_video=None, control_masks=None, reference_image=None):
|
||||||
|
latent_length = ((length - 1) // 4) + 1
|
||||||
|
if control_video is not None:
|
||||||
|
control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||||
|
if control_video.shape[0] < length:
|
||||||
|
control_video = torch.nn.functional.pad(control_video, (0, 0, 0, 0, 0, 0, 0, length - control_video.shape[0]), value=0.5)
|
||||||
|
else:
|
||||||
|
control_video = torch.ones((length, height, width, 3)) * 0.5
|
||||||
|
|
||||||
|
if reference_image is not None:
|
||||||
|
reference_image = comfy.utils.common_upscale(reference_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||||
|
reference_image = vae.encode(reference_image[:, :, :, :3])
|
||||||
|
reference_image = torch.cat([reference_image, comfy.latent_formats.Wan21().process_out(torch.zeros_like(reference_image))], dim=1)
|
||||||
|
|
||||||
|
if control_masks is None:
|
||||||
|
mask = torch.ones((length, height, width, 1))
|
||||||
|
else:
|
||||||
|
mask = control_masks
|
||||||
|
if mask.ndim == 3:
|
||||||
|
mask = mask.unsqueeze(1)
|
||||||
|
mask = comfy.utils.common_upscale(mask[:length], width, height, "bilinear", "center").movedim(1, -1)
|
||||||
|
if mask.shape[0] < length:
|
||||||
|
mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, 0, 0, length - mask.shape[0]), value=1.0)
|
||||||
|
|
||||||
|
control_video = control_video - 0.5
|
||||||
|
inactive = (control_video * (1 - mask)) + 0.5
|
||||||
|
reactive = (control_video * mask) + 0.5
|
||||||
|
|
||||||
|
inactive = vae.encode(inactive[:, :, :, :3])
|
||||||
|
reactive = vae.encode(reactive[:, :, :, :3])
|
||||||
|
control_video_latent = torch.cat((inactive, reactive), dim=1)
|
||||||
|
if reference_image is not None:
|
||||||
|
control_video_latent = torch.cat((reference_image, control_video_latent), dim=2)
|
||||||
|
|
||||||
|
vae_stride = 8
|
||||||
|
height_mask = height // vae_stride
|
||||||
|
width_mask = width // vae_stride
|
||||||
|
mask = mask.view(length, height_mask, vae_stride, width_mask, vae_stride)
|
||||||
|
mask = mask.permute(2, 4, 0, 1, 3)
|
||||||
|
mask = mask.reshape(vae_stride * vae_stride, length, height_mask, width_mask)
|
||||||
|
mask = torch.nn.functional.interpolate(mask.unsqueeze(0), size=(latent_length, height_mask, width_mask), mode='nearest-exact').squeeze(0)
|
||||||
|
|
||||||
|
trim_latent = 0
|
||||||
|
if reference_image is not None:
|
||||||
|
mask_pad = torch.zeros_like(mask[:, :reference_image.shape[2], :, :])
|
||||||
|
mask = torch.cat((mask_pad, mask), dim=1)
|
||||||
|
latent_length += reference_image.shape[2]
|
||||||
|
trim_latent = reference_image.shape[2]
|
||||||
|
|
||||||
|
mask = mask.unsqueeze(0)
|
||||||
|
positive = node_helpers.conditioning_set_values(positive, {"vace_frames": control_video_latent, "vace_mask": mask, "vace_strength": strength})
|
||||||
|
negative = node_helpers.conditioning_set_values(negative, {"vace_frames": control_video_latent, "vace_mask": mask, "vace_strength": strength})
|
||||||
|
|
||||||
|
latent = torch.zeros([batch_size, 16, latent_length, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||||
|
out_latent = {}
|
||||||
|
out_latent["samples"] = latent
|
||||||
|
return (positive, negative, out_latent, trim_latent)
|
||||||
|
|
||||||
|
class TrimVideoLatent:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "samples": ("LATENT",),
|
||||||
|
"trim_amount": ("INT", {"default": 0, "min": 0, "max": 99999}),
|
||||||
|
}}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("LATENT",)
|
||||||
|
FUNCTION = "op"
|
||||||
|
|
||||||
|
CATEGORY = "latent/video"
|
||||||
|
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
|
||||||
|
def op(self, samples, trim_amount):
|
||||||
|
samples_out = samples.copy()
|
||||||
|
|
||||||
|
s1 = samples["samples"]
|
||||||
|
samples_out["samples"] = s1[:, :, trim_amount:]
|
||||||
|
return (samples_out,)
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"WanImageToVideo": WanImageToVideo,
|
"WanImageToVideo": WanImageToVideo,
|
||||||
"WanFunControlToVideo": WanFunControlToVideo,
|
"WanFunControlToVideo": WanFunControlToVideo,
|
||||||
"WanFunInpaintToVideo": WanFunInpaintToVideo,
|
"WanFunInpaintToVideo": WanFunInpaintToVideo,
|
||||||
"WanFirstLastFrameToVideo": WanFirstLastFrameToVideo,
|
"WanFirstLastFrameToVideo": WanFirstLastFrameToVideo,
|
||||||
|
"WanVaceToVideo": WanVaceToVideo,
|
||||||
|
"TrimVideoLatent": TrimVideoLatent,
|
||||||
}
|
}
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user