Compare commits

81 Commits
desktop-re... yoland68-p

| SHA1 |
|---|
| c3fa3d269a |
| d9c80a85e5 |
| 3e62c5513a |
| cd18582578 |
| 80a44b97f5 |
| 9187a09483 |
| 3041e5c354 |
| 7689917113 |
| 486ad8fdc5 |
| 065d855f14 |
| 530494588d |
| 2ab9618732 |
| d9a87c1e6a |
| 551fe8dcee |
| ff99861650 |
| 8d0661d0ba |
| 6d32dc049e |
| aa9d759df3 |
| c6c19e9980 |
| 08ff5fa08a |
| 4ca3d84277 |
| 39c27a3705 |
| b1c7291569 |
| dbc726f80c |
| 7ee96455e2 |
| 0a66d4b0af |
| 5c5457a4ef |
| 45503f6499 |
| 005a91ce2b |
| 68f0d35296 |
| 83d04717b6 |
| 7d329771f9 |
| c15909bb62 |
| 772b4c5945 |
| 5a50c3c7e5 |
| 30159a7fe6 |
| cb9ac3db58 |
| 8115a7895b |
| c8cd7ad795 |
| 542b4b36b6 |
| ac10a0d69e |
| 0dcc75ca54 |
| b685b8a4e0 |
| 23e39f2ba7 |
| 78992c4b25 |
| f935d42d8e |
| a97f2f850a |
| 5acb705857 |
| 5c80da31db |
| e2eed9eb9b |
| 11b68ebd22 |
| 188b383c35 |
| 2c1d686ec6 |
| e8ddc2be95 |
| dea1c7474a |
| 154f2911aa |
| 3eaad0590e |
| 7eaff81be1 |
| 21a11ef817 |
| 552615235d |
| 0738e4ea5d |
| 92cdc692f4 |
| 2d6805ce57 |
| a8f63c0d5b |
| 454a635c1b |
| 966c43ce26 |
| 3ab231f01f |
| 1f3fba2af5 |
| 5d0d4ee98a |
| 9d57b8afd8 |
| 5d51794607 |
| ce22f687cc |
| b6fd3ffd10 |
| 11b72c9c55 |
| 2c735c13b4 |
| fd27494441 |
| f43e1d7f41 |
| 4486b0d0ff |
| 636d4bfb89 |
| dc300a4569 |
| f3b09b9f2d |
```diff
@@ -63,7 +63,12 @@ except:
     print("checking out master branch")  # noqa: T201
     branch = repo.lookup_branch('master')
     if branch is None:
-        ref = repo.lookup_reference('refs/remotes/origin/master')
+        try:
+            ref = repo.lookup_reference('refs/remotes/origin/master')
+        except:
+            print("pulling.")  # noqa: T201
+            pull(repo)
+            ref = repo.lookup_reference('refs/remotes/origin/master')
         repo.checkout(ref)
         branch = repo.lookup_branch('master')
         if branch is None:
```
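The new fallback works because pygit2 raises `KeyError` from `lookup_reference` when the reference does not exist locally, so the updater can fetch and retry. A minimal standalone sketch of the same pattern (the `pull` helper is assumed to be supplied, as in the updater script):

```python
import pygit2

def checkout_master(repo: pygit2.Repository, pull):
    """Check out master, fetching first when the remote-tracking ref is absent."""
    branch = repo.lookup_branch('master')
    if branch is None:
        try:
            # pygit2 raises KeyError when a reference does not exist locally.
            ref = repo.lookup_reference('refs/remotes/origin/master')
        except KeyError:
            pull(repo)  # caller-supplied helper, like the updater's own pull()
            ref = repo.lookup_reference('refs/remotes/origin/master')
        repo.checkout(ref)
        branch = repo.lookup_branch('master')
    return branch
```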
.github/workflows/stable-release.yml (12 changes, vendored)
```diff
@@ -12,7 +12,7 @@ on:
         description: 'CUDA version'
         required: true
         type: string
-        default: "126"
+        default: "128"
       python_minor:
         description: 'Python minor version'
         required: true
@@ -22,7 +22,7 @@ on:
         description: 'Python patch version'
         required: true
         type: string
-        default: "9"
+        default: "10"


 jobs:
@@ -36,7 +36,7 @@ jobs:
       - uses: actions/checkout@v4
         with:
           ref: ${{ inputs.git_tag }}
-          fetch-depth: 0
+          fetch-depth: 150
           persist-credentials: false
       - uses: actions/cache/restore@v4
         id: cache
@@ -70,7 +70,7 @@ jobs:
           cd ..

           git clone --depth 1 https://github.com/comfyanonymous/taesd
-          cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/
+          cp taesd/*.safetensors ./ComfyUI_copy/models/vae_approx/

           mkdir ComfyUI_windows_portable
           mv python_embeded ComfyUI_windows_portable
@@ -85,12 +85,14 @@ jobs:

           cd ..

-          "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
+          "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
           mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_nvidia.7z

           cd ComfyUI_windows_portable
           python_embeded/python.exe -s ComfyUI/main.py --quick-test-for-ci --cpu
+
+          python_embeded/python.exe -s ./update/update.py ComfyUI/

           ls

       - name: Upload binaries to release
```
.github/workflows/test-ci.yml (2 changes, vendored)
```diff
@@ -22,7 +22,7 @@ jobs:
     matrix:
       # os: [macos, linux, windows]
       os: [macos, linux]
-      python_version: ["3.9", "3.10", "3.11", "3.12"]
+      python_version: ["3.10", "3.11", "3.12"]
       cuda_version: ["12.1"]
       torch_version: ["stable"]
       include:
```
.github/workflows/update-api-stubs.yml (47 additions, new file, vendored)
```diff
@@ -0,0 +1,47 @@
+name: Generate Pydantic Stubs from api.comfy.org
+
+on:
+  schedule:
+    - cron: '0 0 * * 1'
+  workflow_dispatch:
+
+jobs:
+  generate-models:
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v4
+
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: '3.10'
+
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install 'datamodel-code-generator[http]'
+
+      - name: Generate API models
+        run: |
+          datamodel-codegen --use-subclass-enum --url https://api.comfy.org/openapi --output comfy_api_nodes/apis --output-model-type pydantic_v2.BaseModel
+
+      - name: Check for changes
+        id: git-check
+        run: |
+          git diff --exit-code comfy_api_nodes/apis || echo "changes=true" >> $GITHUB_OUTPUT
+
+      - name: Create Pull Request
+        if: steps.git-check.outputs.changes == 'true'
+        uses: peter-evans/create-pull-request@v5
+        with:
+          commit-message: 'chore: update API models from OpenAPI spec'
+          title: 'Update API models from api.comfy.org'
+          body: |
+            This PR updates the API models based on the latest api.comfy.org OpenAPI specification.
+
+            Generated automatically by a GitHub workflow.
+          branch: update-api-stubs
+          delete-branch: true
+          base: main
```
```diff
@@ -17,7 +17,7 @@ on:
         description: 'cuda version'
         required: true
         type: string
-        default: "126"
+        default: "128"

       python_minor:
         description: 'python minor version'
@@ -29,7 +29,7 @@ on:
         description: 'python patch version'
         required: true
         type: string
-        default: "9"
+        default: "10"
 #  push:
 #    branches:
 #      - master

@@ -56,7 +56,7 @@ jobs:
           cd ..

           git clone --depth 1 https://github.com/comfyanonymous/taesd
-          cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/
+          cp taesd/*.safetensors ./ComfyUI_copy/models/vae_approx/

           mkdir ComfyUI_windows_portable_nightly_pytorch
           mv python_embeded ComfyUI_windows_portable_nightly_pytorch
```
.github/workflows/windows_release_package.yml (12 changes, vendored)
```diff
@@ -7,7 +7,7 @@ on:
         description: 'cuda version'
         required: true
         type: string
-        default: "126"
+        default: "128"

       python_minor:
         description: 'python minor version'
@@ -19,7 +19,7 @@ on:
         description: 'python patch version'
         required: true
         type: string
-        default: "9"
+        default: "10"
 #  push:
 #    branches:
 #      - master
@@ -50,7 +50,7 @@ jobs:

       - uses: actions/checkout@v4
         with:
-          fetch-depth: 0
+          fetch-depth: 150
           persist-credentials: false
       - shell: bash
         run: |
@@ -67,7 +67,7 @@ jobs:
           cd ..

           git clone --depth 1 https://github.com/comfyanonymous/taesd
-          cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/
+          cp taesd/*.safetensors ./ComfyUI_copy/models/vae_approx/

           mkdir ComfyUI_windows_portable
           mv python_embeded ComfyUI_windows_portable
@@ -82,12 +82,14 @@ jobs:

           cd ..

-          "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
+          "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
           mv ComfyUI_windows_portable.7z ComfyUI/new_ComfyUI_windows_portable_nvidia_cu${{ inputs.cu }}_or_cpu.7z

           cd ComfyUI_windows_portable
           python_embeded/python.exe -s ComfyUI/main.py --quick-test-for-ci --cpu
+
+          python_embeded/python.exe -s ./update/update.py ComfyUI/

           ls

       - name: Upload binaries to release
```
CODEOWNERS (26 changes)
```diff
@@ -5,20 +5,20 @@
 # Inlined the team members for now.

 # Maintainers
-*.md @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
-/tests/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
-/tests-unit/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
-/notebooks/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
-/script_examples/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
-/.github/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
-/requirements.txt @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
-/pyproject.toml @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
+*.md @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
+/tests/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
+/tests-unit/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
+/notebooks/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
+/script_examples/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
+/.github/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
+/requirements.txt @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
+/pyproject.toml @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne

 # Python web server
-/api_server/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
-/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
-/utils/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
+/api_server/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne
+/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne
+/utils/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne

 # Node developers
-/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered
-/comfy/comfy_types/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered
+/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
+/comfy/comfy_types/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
```
README.md (24 changes)
```diff
@@ -49,7 +49,6 @@ Supports all operating systems and GPU types (NVIDIA, AMD, Intel, Apple Silicon,
 ## [Examples](https://comfyanonymous.github.io/ComfyUI_examples/)
 See what ComfyUI can do with the [example workflows](https://comfyanonymous.github.io/ComfyUI_examples/).
-

 ## Features
 - Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
 - Image Models
@@ -99,6 +98,23 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith

 Workflow examples can be found on the [Examples page](https://comfyanonymous.github.io/ComfyUI_examples/)

+## Release Process
+
+ComfyUI follows a weekly release cycle every Friday, with three interconnected repositories:
+
+1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)**
+   - Releases a new stable version (e.g., v0.7.0)
+   - Serves as the foundation for the desktop release
+
+2. **[ComfyUI Desktop](https://github.com/Comfy-Org/desktop)**
+   - Builds a new release using the latest stable core version
+   - Version numbers match the core release (e.g., Desktop v1.7.0 uses Core v1.7.0)
+
+3. **[ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend)**
+   - Weekly frontend updates are merged into the core repository
+   - Features are frozen for the upcoming core release
+   - Development continues for the next release cycle
+
 ## Shortcuts

 | Keybind | Explanation |
@@ -149,8 +165,6 @@ Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you

 If you have trouble extracting it, right click the file -> properties -> unblock

-If you have a 50 series Blackwell card like a 5090 or 5080 see [this discussion thread](https://github.com/comfyanonymous/ComfyUI/discussions/6643)
-
 #### How do I share models between another UI and ComfyUI?

 See the [Config file](extra_model_paths.yaml.example) to set the search paths for models. In the standalone windows build you can find this file in the ComfyUI directory. Rename this file to extra_model_paths.yaml and edit it with your favorite text editor.
@@ -216,9 +230,9 @@ Additional discussion and help can be found [here](https://github.com/comfyanony

 Nvidia users should install stable pytorch using this command:

-```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu126```
+```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu128```

-This is the command to install pytorch nightly instead which supports the new blackwell 50xx series GPUs and might have performance improvements.
+This is the command to install pytorch nightly instead which might have performance improvements.

 ```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128```
```
```diff
@@ -93,16 +93,20 @@ class CustomNodeManager:

     def add_routes(self, routes, webapp, loadedModules):

+        example_workflow_folder_names = ["example_workflows", "example", "examples", "workflow", "workflows"]
+
         @routes.get("/workflow_templates")
         async def get_workflow_templates(request):
             """Returns a web response that contains the map of custom_nodes names and their associated workflow templates. The ones without templates are omitted."""
-            files = [
-                file
-                for folder in folder_paths.get_folder_paths("custom_nodes")
-                for file in glob.glob(
-                    os.path.join(folder, "*/example_workflows/*.json")
-                )
-            ]
+            files = []
+
+            for folder in folder_paths.get_folder_paths("custom_nodes"):
+                for folder_name in example_workflow_folder_names:
+                    pattern = os.path.join(folder, f"*/{folder_name}/*.json")
+                    matched_files = glob.glob(pattern)
+                    files.extend(matched_files)
+
             workflow_templates_dict = (
                 {}
             )  # custom_nodes folder name -> example workflow names
@@ -118,15 +122,22 @@ class CustomNodeManager:

         # Serve workflow templates from custom nodes.
         for module_name, module_dir in loadedModules:
-            workflows_dir = os.path.join(module_dir, "example_workflows")
-            if os.path.exists(workflows_dir):
-                webapp.add_routes(
-                    [
-                        web.static(
-                            "/api/workflow_templates/" + module_name, workflows_dir
-                        )
-                    ]
-                )
+            for folder_name in example_workflow_folder_names:
+                workflows_dir = os.path.join(module_dir, folder_name)
+
+                if os.path.exists(workflows_dir):
+                    if folder_name != "example_workflows":
+                        logging.debug(
+                            "Found example workflow folder '%s' for custom node '%s', consider renaming it to 'example_workflows'",
+                            folder_name, module_name)
+
+                    webapp.add_routes(
+                        [
+                            web.static(
+                                "/api/workflow_templates/" + module_name, workflows_dir
+                            )
+                        ]
+                    )

         @routes.get("/i18n")
         async def get_i18n(request):
```
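A hypothetical client-side sketch of the two routes above, assuming a local server on ComfyUI's default port 8188 (the pack and template names are invented, and route prefixes may differ by deployment):

```python
import requests

BASE = "http://127.0.0.1:8188"  # assumed local ComfyUI instance

# Map of custom_nodes folder name -> example workflow names; packs without templates are omitted.
templates = requests.get(f"{BASE}/workflow_templates").json()

for pack, names in templates.items():
    for name in names:
        # Templates are served as static files under /api/workflow_templates/<module>/.
        url = f"{BASE}/api/workflow_templates/{pack}/{name}.json"
        workflow = requests.get(url).json()
        print(pack, name, "nodes:", len(workflow.get("nodes", [])))
```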
```diff
@@ -197,6 +197,112 @@ class UserManager():

             return web.json_response(results)

+        @routes.get("/v2/userdata")
+        async def list_userdata_v2(request):
+            """
+            List files and directories in a user's data directory.
+
+            This endpoint provides a structured listing of contents within a specified
+            subdirectory of the user's data storage.
+
+            Query Parameters:
+            - path (optional): The relative path within the user's data directory
+              to list. Defaults to the root ('').
+
+            Returns:
+            - 400: If the requested path is invalid, outside the user's data directory, or is not a directory.
+            - 404: If the requested path does not exist.
+            - 403: If the user is invalid.
+            - 500: If there is an error reading the directory contents.
+            - 200: JSON response containing a list of file and directory objects.
+              Each object includes:
+              - name: The name of the file or directory.
+              - type: 'file' or 'directory'.
+              - path: The relative path from the user's data root.
+              - size (for files): The size in bytes.
+              - modified (for files): The last modified timestamp (Unix epoch).
+            """
+            requested_rel_path = request.rel_url.query.get('path', '')
+
+            # URL-decode the path parameter
+            try:
+                requested_rel_path = parse.unquote(requested_rel_path)
+            except Exception as e:
+                logging.warning(f"Failed to decode path parameter: {requested_rel_path}, Error: {e}")
+                return web.Response(status=400, text="Invalid characters in path parameter")
+
+            # Check user validity and get the absolute path for the requested directory
+            try:
+                base_user_path = self.get_request_user_filepath(request, None, create_dir=False)
+
+                if requested_rel_path:
+                    target_abs_path = self.get_request_user_filepath(request, requested_rel_path, create_dir=False)
+                else:
+                    target_abs_path = base_user_path
+
+            except KeyError as e:
+                # Invalid user detected by get_request_user_id inside get_request_user_filepath
+                logging.warning(f"Access denied for user: {e}")
+                return web.Response(status=403, text="Invalid user specified in request")
+
+            if not target_abs_path:
+                # Path traversal or other issue detected by get_request_user_filepath
+                return web.Response(status=400, text="Invalid path requested")
+
+            # Handle cases where the user directory or target path doesn't exist
+            if not os.path.exists(target_abs_path):
+                # Check if it's the base user directory that's missing (new user case)
+                if target_abs_path == base_user_path:
+                    # It's okay if the base user directory doesn't exist yet, return empty list
+                    return web.json_response([])
+                else:
+                    # A specific subdirectory was requested but doesn't exist
+                    return web.Response(status=404, text="Requested path not found")
+
+            if not os.path.isdir(target_abs_path):
+                return web.Response(status=400, text="Requested path is not a directory")
+
+            results = []
+            try:
+                for root, dirs, files in os.walk(target_abs_path, topdown=True):
+                    # Process directories
+                    for dir_name in dirs:
+                        dir_path = os.path.join(root, dir_name)
+                        rel_path = os.path.relpath(dir_path, base_user_path).replace(os.sep, '/')
+                        results.append({
+                            "name": dir_name,
+                            "path": rel_path,
+                            "type": "directory"
+                        })
+
+                    # Process files
+                    for file_name in files:
+                        file_path = os.path.join(root, file_name)
+                        rel_path = os.path.relpath(file_path, base_user_path).replace(os.sep, '/')
+                        entry_info = {
+                            "name": file_name,
+                            "path": rel_path,
+                            "type": "file"
+                        }
+                        try:
+                            stats = os.stat(file_path)  # Use os.stat for potentially better performance with os.walk
+                            entry_info["size"] = stats.st_size
+                            entry_info["modified"] = stats.st_mtime
+                        except OSError as stat_error:
+                            logging.warning(f"Could not stat file {file_path}: {stat_error}")
+                            pass  # Include file with available info
+                        results.append(entry_info)
+            except OSError as e:
+                logging.error(f"Error listing directory {target_abs_path}: {e}")
+                return web.Response(status=500, text="Error reading directory contents")
+
+            # Sort results alphabetically, directories first then files
+            results.sort(key=lambda x: (x['type'] != 'directory', x['name'].lower()))
+
+            return web.json_response(results)
+
         def get_user_data_path(request, check_exists = False, param = "file"):
             file = request.match_info.get(param, None)
             if not file:
```
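A minimal sketch of calling the new endpoint from Python, assuming a local ComfyUI server on the default port 8188 (the `workflows` subdirectory is an invented example):

```python
import requests

BASE = "http://127.0.0.1:8188"  # assumed local ComfyUI instance

# Root listing of the current user's data directory (returns [] for a brand-new user).
entries = requests.get(f"{BASE}/v2/userdata").json()

# Listing of a subdirectory; requests URL-encodes the query parameter for us.
resp = requests.get(f"{BASE}/v2/userdata", params={"path": "workflows"})
if resp.status_code == 200:
    for entry in resp.json():
        # Directories sort first, then files, both case-insensitively by name.
        size = entry.get("size", "-")  # "size"/"modified" are present only for files
        print(entry["type"], entry["path"], size)
elif resp.status_code == 404:
    print("no such subdirectory")
```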
```diff
@@ -66,6 +66,7 @@ fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the diff
 fpunet_group.add_argument("--fp16-unet", action="store_true", help="Run the diffusion model in fp16")
 fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.")
 fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")
+fpunet_group.add_argument("--fp8_e8m0fnu-unet", action="store_true", help="Store unet weights in fp8_e8m0fnu.")

 fpvae_group = parser.add_mutually_exclusive_group()
 fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
@@ -127,6 +128,7 @@ vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for e

 parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.")

+parser.add_argument("--async-offload", action="store_true", help="Use async weight offloading.")

 parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
```
```diff
@@ -18,6 +18,7 @@ class Output:
         setattr(self, key, item)

 def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
+    image = image[:, :, :, :3] if image.shape[3] > 3 else image
     mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
     std = torch.tensor(std, device=image.device, dtype=image.dtype)
     image = image.movedim(-1, 1)
```
```diff
@@ -1,7 +1,7 @@
 """Comfy-specific type hinting"""

 from __future__ import annotations
-from typing import Literal, TypedDict
+from typing import Literal, TypedDict, Optional
 from typing_extensions import NotRequired
 from abc import ABC, abstractmethod
 from enum import Enum
@@ -48,6 +48,7 @@ class IO(StrEnum):
     FACE_ANALYSIS = "FACE_ANALYSIS"
     BBOX = "BBOX"
     SEGS = "SEGS"
+    VIDEO = "VIDEO"

     ANY = "*"
     """Always matches any type, but at a price.
@@ -115,6 +116,15 @@ class InputTypeOptions(TypedDict):
     """When a link exists, rather than receiving the evaluated value, you will receive the link (i.e. `["nodeId", <outputIndex>]`). Designed for node expansion."""
     tooltip: NotRequired[str]
     """Tooltip for the input (or widget), shown on pointer hover"""
+    socketless: NotRequired[bool]
+    """All inputs (including widgets) have an input socket to connect links. When ``true``, if there is a widget for this input, no socket will be created.
+    Available from frontend v1.17.5
+    Ref: https://github.com/Comfy-Org/ComfyUI_frontend/pull/3548
+    """
+    widgetType: NotRequired[str]
+    """Specifies a type to be used for widget initialization if different from the input type.
+    Available from frontend v1.18.0
+    https://github.com/Comfy-Org/ComfyUI_frontend/pull/3550"""
     # class InputTypeNumber(InputTypeOptions):
     #     default: float | int
     min: NotRequired[float]
@@ -224,6 +234,8 @@ class ComfyNodeABC(ABC):
     """Flags a node as experimental, informing users that it may change or not work as expected."""
     DEPRECATED: bool
     """Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
+    API_NODE: Optional[bool]
+    """Flags a node as an API node."""

     @classmethod
     @abstractmethod
@@ -262,7 +274,7 @@ class ComfyNodeABC(ABC):

     Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lists#list-processing
     """
-    OUTPUT_IS_LIST: tuple[bool]
+    OUTPUT_IS_LIST: tuple[bool, ...]
     """A tuple indicating which node outputs are lists, but will be connected to nodes that expect individual items.

     Connected nodes that do not implement `INPUT_IS_LIST` will be executed once for every item in the list.
@@ -281,7 +293,7 @@ class ComfyNodeABC(ABC):
     Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lists#list-processing
     """

-    RETURN_TYPES: tuple[IO]
+    RETURN_TYPES: tuple[IO, ...]
     """A tuple representing the outputs of this node.

     Usage::
@@ -290,12 +302,12 @@ class ComfyNodeABC(ABC):

     Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#return-types
     """
-    RETURN_NAMES: tuple[str]
+    RETURN_NAMES: tuple[str, ...]
     """The output slot names for each item in `RETURN_TYPES`, e.g. ``RETURN_NAMES = ("count", "filter_string")``

     Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#return-names
     """
-    OUTPUT_TOOLTIPS: tuple[str]
+    OUTPUT_TOOLTIPS: tuple[str, ...]
     """A tuple of strings to use as tooltips for node outputs, one for each item in `RETURN_TYPES`."""
     FUNCTION: str
     """The name of the function to execute as a literal string, e.g. `FUNCTION = "execute"`
```
```diff
@@ -736,6 +736,7 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
     return control

 def load_controlnet(ckpt_path, model=None, model_options={}):
+    model_options = model_options.copy()
     if "global_average_pooling" not in model_options:
         filename = os.path.splitext(ckpt_path)[0]
         if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
```
```diff
@@ -116,7 +116,7 @@ class Dino2Embeddings(torch.nn.Module):
     def forward(self, pixel_values):
         x = self.patch_embeddings(pixel_values)
         # TODO: mask_token?
-        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+        x = torch.cat((self.cls_token.to(device=x.device, dtype=x.dtype).expand(x.shape[0], -1, -1), x), dim=1)
         x = x + comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype)
         return x
```
```diff
@@ -1345,28 +1345,52 @@ def sample_res_multistep_ancestral_cfg_pp(model, x, sigmas, extra_args=None, cal
     return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=eta, cfg_pp=True)

 @torch.no_grad()
-def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2.):
+def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2., cfg_pp=False):
     """Gradient-estimation sampler. Paper: https://openreview.net/pdf?id=o2ND9v0CeK"""
     extra_args = {} if extra_args is None else extra_args
     s_in = x.new_ones([x.shape[0]])
     old_d = None

+    uncond_denoised = None
+    def post_cfg_function(args):
+        nonlocal uncond_denoised
+        uncond_denoised = args["uncond_denoised"]
+        return args["denoised"]
+
+    if cfg_pp:
+        model_options = extra_args.get("model_options", {}).copy()
+        extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
+
     for i in trange(len(sigmas) - 1, disable=disable):
         denoised = model(x, sigmas[i] * s_in, **extra_args)
-        d = to_d(x, sigmas[i], denoised)
+        if cfg_pp:
+            d = to_d(x, sigmas[i], uncond_denoised)
+        else:
+            d = to_d(x, sigmas[i], denoised)
         if callback is not None:
             callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
         dt = sigmas[i + 1] - sigmas[i]
         if i == 0:
             # Euler method
-            x = x + d * dt
+            if cfg_pp:
+                x = denoised + d * sigmas[i + 1]
+            else:
+                x = x + d * dt
         else:
             # Gradient estimation
-            d_bar = ge_gamma * d + (1 - ge_gamma) * old_d
-            x = x + d_bar * dt
+            if cfg_pp:
+                d_bar = (ge_gamma - 1) * (d - old_d)
+                x = denoised + d * sigmas[i + 1] + d_bar * dt
+            else:
+                d_bar = ge_gamma * d + (1 - ge_gamma) * old_d
+                x = x + d_bar * dt
         old_d = d
     return x

 @torch.no_grad()
+def sample_gradient_estimation_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2.):
+    return sample_gradient_estimation(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, ge_gamma=ge_gamma, cfg_pp=True)
+
+@torch.no_grad()
 def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None, noise_scaler=None, max_stage=3):
     """
```
comfy/ldm/chroma/layers.py (183 additions, new file)
```diff
@@ -0,0 +1,183 @@
+import torch
+from torch import Tensor, nn
+
+from comfy.ldm.flux.math import attention
+from comfy.ldm.flux.layers import (
+    MLPEmbedder,
+    RMSNorm,
+    QKNorm,
+    SelfAttention,
+    ModulationOut,
+)
+
+
+class ChromaModulationOut(ModulationOut):
+    @classmethod
+    def from_offset(cls, tensor: torch.Tensor, offset: int = 0) -> ModulationOut:
+        return cls(
+            shift=tensor[:, offset : offset + 1, :],
+            scale=tensor[:, offset + 1 : offset + 2, :],
+            gate=tensor[:, offset + 2 : offset + 3, :],
+        )
+
+
+class Approximator(nn.Module):
+    def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers=5, dtype=None, device=None, operations=None):
+        super().__init__()
+        self.in_proj = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
+        self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim, dtype=dtype, device=device, operations=operations) for x in range(n_layers)])
+        self.norms = nn.ModuleList([RMSNorm(hidden_dim, dtype=dtype, device=device, operations=operations) for x in range(n_layers)])
+        self.out_proj = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)
+
+    @property
+    def device(self):
+        # Get the device of the module (assumes all parameters are on the same device)
+        return next(self.parameters()).device
+
+    def forward(self, x: Tensor) -> Tensor:
+        x = self.in_proj(x)
+
+        for layer, norms in zip(self.layers, self.norms):
+            x = x + layer(norms(x))
+
+        x = self.out_proj(x)
+
+        return x
+
+
+class DoubleStreamBlock(nn.Module):
+    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
+        super().__init__()
+
+        mlp_hidden_dim = int(hidden_size * mlp_ratio)
+        self.num_heads = num_heads
+        self.hidden_size = hidden_size
+        self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
+        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
+
+        self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
+        self.img_mlp = nn.Sequential(
+            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
+            nn.GELU(approximate="tanh"),
+            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
+        )
+
+        self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
+        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
+
+        self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
+        self.txt_mlp = nn.Sequential(
+            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
+            nn.GELU(approximate="tanh"),
+            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
+        )
+        self.flipped_img_txt = flipped_img_txt
+
+    def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None):
+        (img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
+
+        # prepare image for attention
+        img_modulated = self.img_norm1(img)
+        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
+        img_qkv = self.img_attn.qkv(img_modulated)
+        img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
+
+        # prepare txt for attention
+        txt_modulated = self.txt_norm1(txt)
+        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
+        txt_qkv = self.txt_attn.qkv(txt_modulated)
+        txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
+
+        # run actual attention
+        attn = attention(torch.cat((txt_q, img_q), dim=2),
+                         torch.cat((txt_k, img_k), dim=2),
+                         torch.cat((txt_v, img_v), dim=2),
+                         pe=pe, mask=attn_mask)
+
+        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
+
+        # calculate the img blocks
+        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
+        img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
+
+        # calculate the txt blocks
+        txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
+        txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
+
+        if txt.dtype == torch.float16:
+            txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
+
+        return img, txt
+
+
+class SingleStreamBlock(nn.Module):
+    """
+    A DiT block with parallel linear layers as described in
+    https://arxiv.org/abs/2302.05442 and adapted modulation interface.
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        mlp_ratio: float = 4.0,
+        qk_scale: float = None,
+        dtype=None,
+        device=None,
+        operations=None
+    ):
+        super().__init__()
+        self.hidden_dim = hidden_size
+        self.num_heads = num_heads
+        head_dim = hidden_size // num_heads
+        self.scale = qk_scale or head_dim**-0.5
+
+        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
+        # qkv and mlp_in
+        self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
+        # proj and mlp_out
+        self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)
+
+        self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
+
+        self.hidden_size = hidden_size
+        self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
+
+        self.mlp_act = nn.GELU(approximate="tanh")
+
+    def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None) -> Tensor:
+        mod = vec
+        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
+        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
+
+        q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+        q, k = self.norm(q, k, v)
+
+        # compute attention
+        attn = attention(q, k, v, pe=pe, mask=attn_mask)
+        # compute activation in mlp stream, cat again and run second linear layer
+        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
+        x += mod.gate * output
+        if x.dtype == torch.float16:
+            x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
+        return x
+
+
+class LastLayer(nn.Module):
+    def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
+        super().__init__()
+        self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
+        self.linear = operations.Linear(hidden_size, out_channels, bias=True, dtype=dtype, device=device)
+
+    def forward(self, x: Tensor, vec: Tensor) -> Tensor:
+        shift, scale = vec
+        shift = shift.squeeze(1)
+        scale = scale.squeeze(1)
+        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
+        x = self.linear(x)
+        return x
```
comfy/ldm/chroma/model.py (271 additions, new file)
```diff
@@ -0,0 +1,271 @@
+#Original code can be found on: https://github.com/black-forest-labs/flux
+
+from dataclasses import dataclass
+
+import torch
+from torch import Tensor, nn
+from einops import rearrange, repeat
+import comfy.ldm.common_dit
+
+from comfy.ldm.flux.layers import (
+    EmbedND,
+    timestep_embedding,
+)
+
+from .layers import (
+    DoubleStreamBlock,
+    LastLayer,
+    SingleStreamBlock,
+    Approximator,
+    ChromaModulationOut,
+)
+
+
+@dataclass
+class ChromaParams:
+    in_channels: int
+    out_channels: int
+    context_in_dim: int
+    hidden_size: int
+    mlp_ratio: float
+    num_heads: int
+    depth: int
+    depth_single_blocks: int
+    axes_dim: list
+    theta: int
+    patch_size: int
+    qkv_bias: bool
+    in_dim: int
+    out_dim: int
+    hidden_dim: int
+    n_layers: int
+
+
+class Chroma(nn.Module):
+    """
+    Transformer model for flow matching on sequences.
+    """
+
+    def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
+        super().__init__()
+        self.dtype = dtype
+        params = ChromaParams(**kwargs)
+        self.params = params
+        self.patch_size = params.patch_size
+        self.in_channels = params.in_channels
+        self.out_channels = params.out_channels
+        if params.hidden_size % params.num_heads != 0:
+            raise ValueError(
+                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
+            )
+        pe_dim = params.hidden_size // params.num_heads
+        if sum(params.axes_dim) != pe_dim:
+            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
+        self.hidden_size = params.hidden_size
+        self.num_heads = params.num_heads
+        self.in_dim = params.in_dim
+        self.out_dim = params.out_dim
+        self.hidden_dim = params.hidden_dim
+        self.n_layers = params.n_layers
+        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
+        self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
+        self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device)
+        # set as nn identity for now, will overwrite it later.
+        self.distilled_guidance_layer = Approximator(
+            in_dim=self.in_dim,
+            hidden_dim=self.hidden_dim,
+            out_dim=self.out_dim,
+            n_layers=self.n_layers,
+            dtype=dtype, device=device, operations=operations
+        )
+
+        self.double_blocks = nn.ModuleList(
+            [
+                DoubleStreamBlock(
+                    self.hidden_size,
+                    self.num_heads,
+                    mlp_ratio=params.mlp_ratio,
+                    qkv_bias=params.qkv_bias,
+                    dtype=dtype, device=device, operations=operations
+                )
+                for _ in range(params.depth)
+            ]
+        )
+
+        self.single_blocks = nn.ModuleList(
+            [
+                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
+                for _ in range(params.depth_single_blocks)
+            ]
+        )
+
+        if final_layer:
+            self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
+
+        self.skip_mmdit = []
+        self.skip_dit = []
+        self.lite = False
+
+    def get_modulations(self, tensor: torch.Tensor, block_type: str, *, idx: int = 0):
+        # This function slices up the modulations tensor which has the following layout:
+        #   single     : num_single_blocks * 3 elements
+        #   double_img : num_double_blocks * 6 elements
+        #   double_txt : num_double_blocks * 6 elements
+        #   final      : 2 elements
+        if block_type == "final":
+            return (tensor[:, -2:-1, :], tensor[:, -1:, :])
+        single_block_count = self.params.depth_single_blocks
+        double_block_count = self.params.depth
+        offset = 3 * idx
+        if block_type == "single":
+            return ChromaModulationOut.from_offset(tensor, offset)
+        # Double block modulations are 6 elements so we double 3 * idx.
+        offset *= 2
+        if block_type in {"double_img", "double_txt"}:
+            # Advance past the single block modulations.
+            offset += 3 * single_block_count
+            if block_type == "double_txt":
+                # Advance past the double block img modulations.
+                offset += 6 * double_block_count
+            return (
+                ChromaModulationOut.from_offset(tensor, offset),
+                ChromaModulationOut.from_offset(tensor, offset + 3),
+            )
+        raise ValueError("Bad block_type")
+
+    def forward_orig(
+        self,
+        img: Tensor,
+        img_ids: Tensor,
+        txt: Tensor,
+        txt_ids: Tensor,
+        timesteps: Tensor,
+        guidance: Tensor = None,
+        control = None,
+        transformer_options={},
+        attn_mask: Tensor = None,
+    ) -> Tensor:
+        patches_replace = transformer_options.get("patches_replace", {})
+        if img.ndim != 3 or txt.ndim != 3:
+            raise ValueError("Input img and txt tensors must have 3 dimensions.")
+
+        # running on sequences img
+        img = self.img_in(img)
+
+        # distilled vector guidance
+        mod_index_length = 344
+        distill_timestep = timestep_embedding(timesteps.detach().clone(), 16).to(img.device, img.dtype)
+        # guidance = guidance *
+        distil_guidance = timestep_embedding(guidance.detach().clone(), 16).to(img.device, img.dtype)
+
+        # get all modulation index
+        modulation_index = timestep_embedding(torch.arange(mod_index_length), 32).to(img.device, img.dtype)
+        # we need to broadcast the modulation index here so each batch has all of the index
+        modulation_index = modulation_index.unsqueeze(0).repeat(img.shape[0], 1, 1).to(img.device, img.dtype)
+        # and we need to broadcast timestep and guidance along too
+        timestep_guidance = torch.cat([distill_timestep, distil_guidance], dim=1).unsqueeze(1).repeat(1, mod_index_length, 1).to(img.dtype).to(img.device, img.dtype)
+        # then and only then we could concatenate it together
+        input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1).to(img.device, img.dtype)
+
+        mod_vectors = self.distilled_guidance_layer(input_vec)
+
+        txt = self.txt_in(txt)
+
+        ids = torch.cat((txt_ids, img_ids), dim=1)
+        pe = self.pe_embedder(ids)
+
+        blocks_replace = patches_replace.get("dit", {})
+        for i, block in enumerate(self.double_blocks):
+            if i not in self.skip_mmdit:
+                double_mod = (
+                    self.get_modulations(mod_vectors, "double_img", idx=i),
+                    self.get_modulations(mod_vectors, "double_txt", idx=i),
+                )
+                if ("double_block", i) in blocks_replace:
+                    def block_wrap(args):
+                        out = {}
+                        out["img"], out["txt"] = block(img=args["img"],
+                                                       txt=args["txt"],
+                                                       vec=args["vec"],
+                                                       pe=args["pe"],
+                                                       attn_mask=args.get("attn_mask"))
+                        return out
+
+                    out = blocks_replace[("double_block", i)]({"img": img,
+                                                               "txt": txt,
+                                                               "vec": double_mod,
+                                                               "pe": pe,
+                                                               "attn_mask": attn_mask},
+                                                              {"original_block": block_wrap})
+                    txt = out["txt"]
+                    img = out["img"]
+                else:
+                    img, txt = block(img=img,
+                                     txt=txt,
+                                     vec=double_mod,
+                                     pe=pe,
+                                     attn_mask=attn_mask)
+
+                if control is not None: # Controlnet
+                    control_i = control.get("input")
+                    if i < len(control_i):
+                        add = control_i[i]
+                        if add is not None:
+                            img += add
+
+        img = torch.cat((txt, img), 1)
+
+        for i, block in enumerate(self.single_blocks):
+            if i not in self.skip_dit:
+                single_mod = self.get_modulations(mod_vectors, "single", idx=i)
+                if ("single_block", i) in blocks_replace:
+                    def block_wrap(args):
+                        out = {}
+                        out["img"] = block(args["img"],
+                                           vec=args["vec"],
+                                           pe=args["pe"],
+                                           attn_mask=args.get("attn_mask"))
+                        return out
+
+                    out = blocks_replace[("single_block", i)]({"img": img,
+                                                               "vec": single_mod,
+                                                               "pe": pe,
+                                                               "attn_mask": attn_mask},
+                                                              {"original_block": block_wrap})
+                    img = out["img"]
+                else:
+                    img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask)
+
+                if control is not None: # Controlnet
+                    control_o = control.get("output")
+                    if i < len(control_o):
+                        add = control_o[i]
+                        if add is not None:
+                            img[:, txt.shape[1] :, ...] += add
+
+        img = img[:, txt.shape[1] :, ...]
+        final_mod = self.get_modulations(mod_vectors, "final")
+        img = self.final_layer(img, vec=final_mod)  # (N, T, patch_size ** 2 * out_channels)
+        return img
+
+    def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
+        bs, c, h, w = x.shape
+        patch_size = 2
+        x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
+
+        img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
+
+        h_len = ((h + (patch_size // 2)) // patch_size)
+        w_len = ((w + (patch_size // 2)) // patch_size)
+        img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
+        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
+        img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
+        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
+
+        txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
+        out = self.forward_orig(img, img_ids, context, txt_ids, timestep, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
+        return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
```
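As a sanity check on `get_modulations` and the hard-coded `mod_index_length = 344` in `forward_orig`: the layout comment implies 3 entries per single block, 6 per double block for img plus another 6 for txt, and 2 for the final layer. A small sketch of that arithmetic, assuming block counts of 38 single and 19 double (these counts are not stated in this diff):

```python
# Hypothetical block counts; the layout formula comes from get_modulations above.
depth_single_blocks = 38  # assumption, not given in the diff
depth = 19                # assumption, not given in the diff

mod_index_length = (
    3 * depth_single_blocks   # "single": 3 elements per single block
    + 6 * depth               # "double_img": 6 elements per double block
    + 6 * depth               # "double_txt": 6 elements per double block
    + 2                       # "final": shift and scale
)
assert mod_index_length == 344  # matches the constant in forward_orig
```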
```diff
@@ -23,7 +23,6 @@ from einops import rearrange, repeat
 from einops.layers.torch import Rearrange
 from torch import nn

-from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
 from comfy.ldm.modules.attention import optimized_attention


@@ -37,11 +36,11 @@ def apply_rotary_pos_emb(
     return t_out


-def get_normalization(name: str, channels: int, weight_args={}):
+def get_normalization(name: str, channels: int, weight_args={}, operations=None):
     if name == "I":
         return nn.Identity()
     elif name == "R":
-        return RMSNorm(channels, elementwise_affine=True, eps=1e-6, **weight_args)
+        return operations.RMSNorm(channels, elementwise_affine=True, eps=1e-6, **weight_args)
     else:
         raise ValueError(f"Normalization {name} not found")

@@ -120,15 +119,15 @@ class Attention(nn.Module):

         self.to_q = nn.Sequential(
             operations.Linear(query_dim, inner_dim, bias=qkv_bias, **weight_args),
-            get_normalization(qkv_norm[0], norm_dim),
+            get_normalization(qkv_norm[0], norm_dim, weight_args=weight_args, operations=operations),
         )
         self.to_k = nn.Sequential(
             operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args),
-            get_normalization(qkv_norm[1], norm_dim),
+            get_normalization(qkv_norm[1], norm_dim, weight_args=weight_args, operations=operations),
         )
         self.to_v = nn.Sequential(
             operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args),
-            get_normalization(qkv_norm[2], norm_dim),
+            get_normalization(qkv_norm[2], norm_dim, weight_args=weight_args, operations=operations),
         )

         self.to_out = nn.Sequential(
```
```diff
@@ -27,8 +27,6 @@ from torchvision import transforms
 from enum import Enum
 import logging

-from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
-
 from .blocks import (
     FinalLayer,
     GeneralDITTransformerBlock,
@@ -195,7 +193,7 @@ class GeneralDIT(nn.Module):

         if self.affline_emb_norm:
             logging.debug("Building affine embedding normalization layer")
-            self.affline_norm = RMSNorm(model_channels, elementwise_affine=True, eps=1e-6)
+            self.affline_norm = operations.RMSNorm(model_channels, elementwise_affine=True, eps=1e-6, device=device, dtype=dtype)
         else:
             self.affline_norm = nn.Identity()
```
```diff
@@ -13,7 +13,6 @@ from comfy.ldm.modules.attention import optimized_attention
 from .layers import (
     FeedForward,
     PatchEmbed,
-    RMSNorm,
     TimestepEmbedder,
 )

@@ -90,10 +89,10 @@ class AsymmetricAttention(nn.Module):

         # Query and key normalization for stability.
         assert qk_norm
-        self.q_norm_x = RMSNorm(self.head_dim, device=device, dtype=dtype)
-        self.k_norm_x = RMSNorm(self.head_dim, device=device, dtype=dtype)
-        self.q_norm_y = RMSNorm(self.head_dim, device=device, dtype=dtype)
-        self.k_norm_y = RMSNorm(self.head_dim, device=device, dtype=dtype)
+        self.q_norm_x = operations.RMSNorm(self.head_dim, eps=1e-5, device=device, dtype=dtype)
+        self.k_norm_x = operations.RMSNorm(self.head_dim, eps=1e-5, device=device, dtype=dtype)
+        self.q_norm_y = operations.RMSNorm(self.head_dim, eps=1e-5, device=device, dtype=dtype)
+        self.k_norm_y = operations.RMSNorm(self.head_dim, eps=1e-5, device=device, dtype=dtype)

         # Output layers. y features go back down from dim_x -> dim_y.
         self.proj_x = operations.Linear(dim_x, dim_x, bias=out_bias, device=device, dtype=dtype)
```
```diff
@@ -151,14 +151,3 @@ class PatchEmbed(nn.Module):

         x = self.norm(x)
         return x
-
-
-class RMSNorm(torch.nn.Module):
-    def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
-        super().__init__()
-        self.eps = eps
-        self.weight = torch.nn.Parameter(torch.empty(hidden_size, device=device, dtype=dtype))
-        self.register_parameter("bias", None)
-
-    def forward(self, x):
-        return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)
```
```diff
@@ -699,10 +699,13 @@ class HiDreamImageTransformer2DModel(nn.Module):
         y: Optional[torch.Tensor] = None,
         context: Optional[torch.Tensor] = None,
         encoder_hidden_states_llama3=None,
+        image_cond=None,
         control = None,
         transformer_options = {},
     ) -> torch.Tensor:
         bs, c, h, w = x.shape
+        if image_cond is not None:
+            x = torch.cat([x, image_cond], dim=-1)
         hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
         timesteps = t
         pooled_embeds = y
```
```diff
@@ -3,7 +3,7 @@ import torch
 import torch.nn as nn

 import comfy.ops
-from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed, RMSNorm
+from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed
 from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
 from torch.utils import checkpoint

@@ -51,7 +51,7 @@ class HunYuanDiTBlock(nn.Module):
         if norm_type == "layer":
             norm_layer = operations.LayerNorm
         elif norm_type == "rms":
-            norm_layer = RMSNorm
+            norm_layer = operations.RMSNorm
         else:
             raise ValueError(f"Unknown norm_type: {norm_type}")
```
```diff
@@ -1,7 +1,6 @@
 import torch
 from torch import nn
 import comfy.ldm.modules.attention
-from comfy.ldm.genmo.joint_model.layers import RMSNorm
 import comfy.ldm.common_dit
 from einops import rearrange
 import math
@@ -262,8 +261,8 @@ class CrossAttention(nn.Module):
         self.heads = heads
         self.dim_head = dim_head

-        self.q_norm = RMSNorm(inner_dim, dtype=dtype, device=device)
-        self.k_norm = RMSNorm(inner_dim, dtype=dtype, device=device)
+        self.q_norm = operations.RMSNorm(inner_dim, dtype=dtype, device=device)
+        self.k_norm = operations.RMSNorm(inner_dim, dtype=dtype, device=device)

         self.to_q = operations.Linear(query_dim, inner_dim, bias=True, dtype=dtype, device=device)
         self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
```
@@ -8,7 +8,7 @@ import torch.nn as nn
import torch.nn.functional as F
import comfy.ldm.common_dit

from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, RMSNorm
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
from comfy.ldm.modules.attention import optimized_attention_masked
from comfy.ldm.flux.layers import EmbedND

@@ -64,8 +64,8 @@ class JointAttention(nn.Module):
        )

        if qk_norm:
            self.q_norm = RMSNorm(self.head_dim, elementwise_affine=True, **operation_settings)
            self.k_norm = RMSNorm(self.head_dim, elementwise_affine=True, **operation_settings)
            self.q_norm = operation_settings.get("operations").RMSNorm(self.head_dim, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
            self.k_norm = operation_settings.get("operations").RMSNorm(self.head_dim, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        else:
            self.q_norm = self.k_norm = nn.Identity()

@@ -242,11 +242,11 @@ class JointTransformerBlock(nn.Module):
            operation_settings=operation_settings,
        )
        self.layer_id = layer_id
        self.attention_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
        self.ffn_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
        self.attention_norm1 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.ffn_norm1 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))

        self.attention_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
        self.ffn_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
        self.attention_norm2 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.ffn_norm2 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))

        self.modulation = modulation
        if modulation:
@@ -431,7 +431,7 @@ class NextDiT(nn.Module):

        self.t_embedder = TimestepEmbedder(min(dim, 1024), **operation_settings)
        self.cap_embedder = nn.Sequential(
            RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, **operation_settings),
            operation_settings.get("operations").RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
            operation_settings.get("operations").Linear(
                cap_feat_dim,
                dim,
@@ -457,7 +457,7 @@
                for layer_id in range(n_layers)
            ]
        )
        self.norm_final = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
        self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.final_layer = FinalLayer(dim, patch_size, self.out_channels, operation_settings=operation_settings)

        assert (dim // n_heads) == sum(axes_dims)

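All of the hunks above make the same substitution: the bare RMSNorm module is swapped for the RMSNorm class exposed on the operations object, so norm weights go through the same cast-on-use path (cast_bias_weight in comfy/ops.py, later in this diff) as the Linear weights. A minimal sketch of the dispatch pattern, assuming the usual operation_settings dict; build_norm is a hypothetical helper, not code from the diff:

def build_norm(dim, operation_settings):
    # the dict bundles the ops class with the target device/dtype,
    # mirroring how these modules unpack it throughout the diff
    ops = operation_settings.get("operations")
    return ops.RMSNorm(
        dim,
        elementwise_affine=True,
        device=operation_settings.get("device"),
        dtype=operation_settings.get("dtype"),
    )
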
@@ -9,7 +9,6 @@ from einops import repeat
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.flux.math import apply_rope
from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
import comfy.ldm.common_dit
import comfy.model_management

@@ -49,8 +48,8 @@ class WanSelfAttention(nn.Module):
        self.k = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.v = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.o = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.norm_q = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
        self.norm_k = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
        self.norm_q = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
        self.norm_k = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()

    def forward(self, x, freqs):
        r"""
@@ -114,7 +113,7 @@ class WanI2VCrossAttention(WanSelfAttention):
        self.k_img = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.v_img = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        # self.alpha = nn.Parameter(torch.zeros((1, )))
        self.norm_k_img = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
        self.norm_k_img = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()

    def forward(self, x, context, context_img_len):
        r"""
@@ -220,6 +219,34 @@ class WanAttentionBlock(nn.Module):
        return x


class VaceWanAttentionBlock(WanAttentionBlock):
    def __init__(
        self,
        cross_attn_type,
        dim,
        ffn_dim,
        num_heads,
        window_size=(-1, -1),
        qk_norm=True,
        cross_attn_norm=False,
        eps=1e-6,
        block_id=0,
        operation_settings={}
    ):
        super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
        self.block_id = block_id
        if block_id == 0:
            self.before_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.after_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))

    def forward(self, c, x, **kwargs):
        if self.block_id == 0:
            c = self.before_proj(c) + x
        c = super().forward(c, **kwargs)
        c_skip = self.after_proj(c)
        return c_skip, c


class Head(nn.Module):

    def __init__(self, dim, out_dim, patch_size, eps=1e-6, operation_settings={}):
@@ -395,6 +422,7 @@ class WanModel(torch.nn.Module):
        clip_fea=None,
        freqs=None,
        transformer_options={},
        **kwargs,
    ):
        r"""
        Forward pass through the diffusion model
@@ -457,7 +485,7 @@ class WanModel(torch.nn.Module):
        x = self.unpatchify(x, grid_sizes)
        return x

    def forward(self, x, timestep, context, clip_fea=None, transformer_options={},**kwargs):
    def forward(self, x, timestep, context, clip_fea=None, transformer_options={}, **kwargs):
        bs, c, t, h, w = x.shape
        x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
        patch_size = self.patch_size
@@ -471,7 +499,7 @@ class WanModel(torch.nn.Module):
        img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)

        freqs = self.rope_embedder(img_ids).movedim(1, 2)
        return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options)[:, :, :t, :h, :w]
        return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w]

    def unpatchify(self, x, grid_sizes):
        r"""
@@ -496,3 +524,116 @@ class WanModel(torch.nn.Module):
        u = torch.einsum('bfhwpqrc->bcfphqwr', u)
        u = u.reshape(b, c, *[i * j for i, j in zip(grid_sizes, self.patch_size)])
        return u


class VaceWanModel(WanModel):
    r"""
    Wan diffusion backbone supporting both text-to-video and image-to-video.
    """

    def __init__(self,
                 model_type='vace',
                 patch_size=(1, 2, 2),
                 text_len=512,
                 in_dim=16,
                 dim=2048,
                 ffn_dim=8192,
                 freq_dim=256,
                 text_dim=4096,
                 out_dim=16,
                 num_heads=16,
                 num_layers=32,
                 window_size=(-1, -1),
                 qk_norm=True,
                 cross_attn_norm=True,
                 eps=1e-6,
                 flf_pos_embed_token_number=None,
                 image_model=None,
                 vace_layers=None,
                 vace_in_dim=None,
                 device=None,
                 dtype=None,
                 operations=None,
                 ):

        super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
        operation_settings = {"operations": operations, "device": device, "dtype": dtype}

        # Vace
        if vace_layers is not None:
            self.vace_layers = vace_layers
            self.vace_in_dim = vace_in_dim
            # vace blocks
            self.vace_blocks = nn.ModuleList([
                VaceWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm, self.cross_attn_norm, self.eps, block_id=i, operation_settings=operation_settings)
                for i in range(self.vace_layers)
            ])

            self.vace_layers_mapping = {i: n for n, i in enumerate(range(0, self.num_layers, self.num_layers // self.vace_layers))}
            # vace patch embeddings
            self.vace_patch_embedding = operations.Conv3d(
                self.vace_in_dim, self.dim, kernel_size=self.patch_size, stride=self.patch_size, device=device, dtype=torch.float32
            )

    def forward_orig(
        self,
        x,
        t,
        context,
        vace_context,
        vace_strength=1.0,
        clip_fea=None,
        freqs=None,
        transformer_options={},
        **kwargs,
    ):
        # embeddings
        x = self.patch_embedding(x.float()).to(x.dtype)
        grid_sizes = x.shape[2:]
        x = x.flatten(2).transpose(1, 2)

        # time embeddings
        e = self.time_embedding(
            sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype))
        e0 = self.time_projection(e).unflatten(1, (6, self.dim))

        # context
        context = self.text_embedding(context)

        context_img_len = None
        if clip_fea is not None:
            if self.img_emb is not None:
                context_clip = self.img_emb(clip_fea) # bs x 257 x dim
                context = torch.concat([context_clip, context], dim=1)
            context_img_len = clip_fea.shape[-2]

        c = self.vace_patch_embedding(vace_context.float()).to(vace_context.dtype)
        c = c.flatten(2).transpose(1, 2)

        # arguments
        x_orig = x

        patches_replace = transformer_options.get("patches_replace", {})
        blocks_replace = patches_replace.get("dit", {})
        for i, block in enumerate(self.blocks):
            if ("double_block", i) in blocks_replace:
                def block_wrap(args):
                    out = {}
                    out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
                    return out
                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
                x = out["img"]
            else:
                x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)

            ii = self.vace_layers_mapping.get(i, None)
            if ii is not None:
                c_skip, c = self.vace_blocks[ii](c, x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
                x += c_skip * vace_strength
                del c_skip
        # head
        x = self.head(x, e)

        # unpatchify
        x = self.unpatchify(x, grid_sizes)
        return x

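The vace_layers_mapping built in __init__ above decides which main transformer blocks receive a VACE skip connection. With the default num_layers=32 and, say, vace_layers=8 (an illustrative value; the real count comes from the checkpoint), the control stream taps every fourth block:

# illustrative sizes only
num_layers, vace_layers = 32, 8
mapping = {i: n for n, i in enumerate(range(0, num_layers, num_layers // vace_layers))}
print(mapping)  # {0: 0, 4: 1, 8: 2, 12: 3, 16: 4, 20: 5, 24: 6, 28: 7}
# in forward_orig: after main block i runs, if i is in the mapping,
# vace_blocks[mapping[i]] produces c_skip and x += c_skip * vace_strength
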
328
comfy/lora.py
@@ -20,6 +20,7 @@ from __future__ import annotations
import comfy.utils
import comfy.model_management
import comfy.model_base
import comfy.weight_adapter as weight_adapter
import logging
import torch

@@ -49,139 +50,12 @@ def load_lora(lora, to_load, log_missing=True):
            dora_scale = lora[dora_scale_name]
            loaded_keys.add(dora_scale_name)

        reshape_name = "{}.reshape_weight".format(x)
        reshape = None
        if reshape_name in lora.keys():
            try:
                reshape = lora[reshape_name].tolist()
                loaded_keys.add(reshape_name)
            except:
                pass

        regular_lora = "{}.lora_up.weight".format(x)
        diffusers_lora = "{}_lora.up.weight".format(x)
        diffusers2_lora = "{}.lora_B.weight".format(x)
        diffusers3_lora = "{}.lora.up.weight".format(x)
        mochi_lora = "{}.lora_B".format(x)
        transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
        A_name = None

        if regular_lora in lora.keys():
            A_name = regular_lora
            B_name = "{}.lora_down.weight".format(x)
            mid_name = "{}.lora_mid.weight".format(x)
        elif diffusers_lora in lora.keys():
            A_name = diffusers_lora
            B_name = "{}_lora.down.weight".format(x)
            mid_name = None
        elif diffusers2_lora in lora.keys():
            A_name = diffusers2_lora
            B_name = "{}.lora_A.weight".format(x)
            mid_name = None
        elif diffusers3_lora in lora.keys():
            A_name = diffusers3_lora
            B_name = "{}.lora.down.weight".format(x)
            mid_name = None
        elif mochi_lora in lora.keys():
            A_name = mochi_lora
            B_name = "{}.lora_A".format(x)
            mid_name = None
        elif transformers_lora in lora.keys():
            A_name = transformers_lora
            B_name = "{}.lora_linear_layer.down.weight".format(x)
            mid_name = None

        if A_name is not None:
            mid = None
            if mid_name is not None and mid_name in lora.keys():
                mid = lora[mid_name]
                loaded_keys.add(mid_name)
            patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale, reshape))
            loaded_keys.add(A_name)
            loaded_keys.add(B_name)


        ######## loha
        hada_w1_a_name = "{}.hada_w1_a".format(x)
        hada_w1_b_name = "{}.hada_w1_b".format(x)
        hada_w2_a_name = "{}.hada_w2_a".format(x)
        hada_w2_b_name = "{}.hada_w2_b".format(x)
        hada_t1_name = "{}.hada_t1".format(x)
        hada_t2_name = "{}.hada_t2".format(x)
        if hada_w1_a_name in lora.keys():
            hada_t1 = None
            hada_t2 = None
            if hada_t1_name in lora.keys():
                hada_t1 = lora[hada_t1_name]
                hada_t2 = lora[hada_t2_name]
                loaded_keys.add(hada_t1_name)
                loaded_keys.add(hada_t2_name)

            patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale))
            loaded_keys.add(hada_w1_a_name)
            loaded_keys.add(hada_w1_b_name)
            loaded_keys.add(hada_w2_a_name)
            loaded_keys.add(hada_w2_b_name)


        ######## lokr
        lokr_w1_name = "{}.lokr_w1".format(x)
        lokr_w2_name = "{}.lokr_w2".format(x)
        lokr_w1_a_name = "{}.lokr_w1_a".format(x)
        lokr_w1_b_name = "{}.lokr_w1_b".format(x)
        lokr_t2_name = "{}.lokr_t2".format(x)
        lokr_w2_a_name = "{}.lokr_w2_a".format(x)
        lokr_w2_b_name = "{}.lokr_w2_b".format(x)

        lokr_w1 = None
        if lokr_w1_name in lora.keys():
            lokr_w1 = lora[lokr_w1_name]
            loaded_keys.add(lokr_w1_name)

        lokr_w2 = None
        if lokr_w2_name in lora.keys():
            lokr_w2 = lora[lokr_w2_name]
            loaded_keys.add(lokr_w2_name)

        lokr_w1_a = None
        if lokr_w1_a_name in lora.keys():
            lokr_w1_a = lora[lokr_w1_a_name]
            loaded_keys.add(lokr_w1_a_name)

        lokr_w1_b = None
        if lokr_w1_b_name in lora.keys():
            lokr_w1_b = lora[lokr_w1_b_name]
            loaded_keys.add(lokr_w1_b_name)

        lokr_w2_a = None
        if lokr_w2_a_name in lora.keys():
            lokr_w2_a = lora[lokr_w2_a_name]
            loaded_keys.add(lokr_w2_a_name)

        lokr_w2_b = None
        if lokr_w2_b_name in lora.keys():
            lokr_w2_b = lora[lokr_w2_b_name]
            loaded_keys.add(lokr_w2_b_name)

        lokr_t2 = None
        if lokr_t2_name in lora.keys():
            lokr_t2 = lora[lokr_t2_name]
            loaded_keys.add(lokr_t2_name)

        if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
            patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale))

        #glora
        a1_name = "{}.a1.weight".format(x)
        a2_name = "{}.a2.weight".format(x)
        b1_name = "{}.b1.weight".format(x)
        b2_name = "{}.b2.weight".format(x)
        if a1_name in lora:
            patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale))
            loaded_keys.add(a1_name)
            loaded_keys.add(a2_name)
            loaded_keys.add(b1_name)
            loaded_keys.add(b2_name)
        for adapter_cls in weight_adapter.adapters:
            adapter = adapter_cls.load(x, lora, alpha, dora_scale, loaded_keys)
            if adapter is not None:
                patch_dict[to_load[x]] = adapter
                loaded_keys.update(adapter.loaded_keys)
                continue

        w_norm_name = "{}.w_norm".format(x)
        b_norm_name = "{}.b_norm".format(x)
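The for adapter_cls loop above is what replaces the per-format if/elif chains: each adapter class inspects the state dict for its own key pattern and returns None when it does not apply. A rough sketch of such a load hook under those assumptions; MyAdapter and its key names are hypothetical, and the real implementations live in comfy/weight_adapter/:

class MyAdapter(weight_adapter.WeightAdapterBase):
    name = "my_adapter"

    @classmethod
    def load(cls, x, lora, alpha, dora_scale, loaded_keys):
        up_name = "{}.my_up.weight".format(x)      # hypothetical key pattern
        down_name = "{}.my_down.weight".format(x)
        if up_name not in lora.keys():
            return None                            # format not present; let other adapters try
        adapter = cls()
        adapter.loaded_keys = {up_name, down_name}
        adapter.weights = [lora[up_name], lora[down_name]]
        # alpha and dora_scale would be stashed alongside the weights here
        return adapter
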
@@ -405,29 +279,16 @@ def model_lora_keys_unet(model, key_map={}):
                key_map["transformer.{}".format(key_lora)] = k
                key_map["diffusion_model.{}".format(key_lora)] = k # Old loras

    if isinstance(model, comfy.model_base.HiDream):
        for k in sdk:
            if k.startswith("diffusion_model."):
                if k.endswith(".weight"):
                    key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
                    key_map["lycoris_{}".format(key_lora)] = k #SimpleTuner lycoris format

    return key_map


def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function):
    dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
    lora_diff *= alpha
    weight_calc = weight + function(lora_diff).type(weight.dtype)
    weight_norm = (
        weight_calc.transpose(0, 1)
        .reshape(weight_calc.shape[1], -1)
        .norm(dim=1, keepdim=True)
        .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
        .transpose(0, 1)
    )

    weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
    if strength != 1.0:
        weight_calc -= weight
        weight += strength * (weight_calc)
    else:
        weight[:] = weight_calc
    return weight

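weight_decompose above is the DoRA path: after the low-rank diff is applied, each weight column is rescaled so its magnitude matches dora_scale. A toy check of the norm bookkeeping (shapes only; the values are arbitrary):

import torch

weight_calc = torch.randn(4, 3)  # stand-in for weight + function(lora_diff)
weight_norm = (
    weight_calc.transpose(0, 1)
    .reshape(weight_calc.shape[1], -1)
    .norm(dim=1, keepdim=True)
    .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
    .transpose(0, 1)
)
print(weight_norm.shape)  # torch.Size([1, 3]): one norm per column
# weight_calc *= dora_scale / weight_norm then restores each column's magnitude
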
def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:
    """
    Pad a tensor to a new shape with zeros.
@@ -482,6 +343,16 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori
        if isinstance(v, list):
            v = (calculate_weight(v[1:], v[0][1](comfy.model_management.cast_to_device(v[0][0], weight.device, intermediate_dtype, copy=True), inplace=True), key, intermediate_dtype=intermediate_dtype), )

        if isinstance(v, weight_adapter.WeightAdapterBase):
            output = v.calculate_weight(weight, key, strength, strength_model, offset, function, intermediate_dtype, original_weights)
            if output is None:
                logging.warning("Calculate Weight Failed: {} {}".format(v.name, key))
            else:
                weight = output
                if old_weight is not None:
                    weight = old_weight
            continue

        if len(v) == 1:
            patch_type = "diff"
        elif len(v) == 2:
@@ -508,157 +379,6 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori
            diff_weight = comfy.model_management.cast_to_device(target_weight, weight.device, intermediate_dtype) - \
                comfy.model_management.cast_to_device(original_weights[key][0][0], weight.device, intermediate_dtype)
            weight += function(strength * comfy.model_management.cast_to_device(diff_weight, weight.device, weight.dtype))
        elif patch_type == "lora": #lora/locon
            mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
            mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype)
            dora_scale = v[4]
            reshape = v[5]

            if reshape is not None:
                weight = pad_tensor_to_shape(weight, reshape)

            if v[2] is not None:
                alpha = v[2] / mat2.shape[0]
            else:
                alpha = 1.0

            if v[3] is not None:
                #locon mid weights, hopefully the math is fine because I didn't properly test it
                mat3 = comfy.model_management.cast_to_device(v[3], weight.device, intermediate_dtype)
                final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
                mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
            try:
                lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
                if dora_scale is not None:
                    weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
                else:
                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
            except Exception as e:
                logging.error("ERROR {} {} {}".format(patch_type, key, e))
        elif patch_type == "lokr":
            w1 = v[0]
            w2 = v[1]
            w1_a = v[3]
            w1_b = v[4]
            w2_a = v[5]
            w2_b = v[6]
            t2 = v[7]
            dora_scale = v[8]
            dim = None

            if w1 is None:
                dim = w1_b.shape[0]
                w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype),
                              comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype))
            else:
                w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype)

            if w2 is None:
                dim = w2_b.shape[0]
                if t2 is None:
                    w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype),
                                  comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype))
                else:
                    w2 = torch.einsum('i j k l, j r, i p -> p r k l',
                                      comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
                                      comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype),
                                      comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype))
            else:
                w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype)

            if len(w2.shape) == 4:
                w1 = w1.unsqueeze(2).unsqueeze(2)
            if v[2] is not None and dim is not None:
                alpha = v[2] / dim
            else:
                alpha = 1.0

            try:
                lora_diff = torch.kron(w1, w2).reshape(weight.shape)
                if dora_scale is not None:
                    weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
                else:
                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
            except Exception as e:
                logging.error("ERROR {} {} {}".format(patch_type, key, e))
        elif patch_type == "loha":
            w1a = v[0]
            w1b = v[1]
            if v[2] is not None:
                alpha = v[2] / w1b.shape[0]
            else:
                alpha = 1.0

            w2a = v[3]
            w2b = v[4]
            dora_scale = v[7]
            if v[5] is not None: #cp decomposition
                t1 = v[5]
                t2 = v[6]
                m1 = torch.einsum('i j k l, j r, i p -> p r k l',
                                  comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype),
                                  comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype),
                                  comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype))

                m2 = torch.einsum('i j k l, j r, i p -> p r k l',
                                  comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
                                  comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype),
                                  comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype))
            else:
                m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype),
                              comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype))
                m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype),
                              comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype))

            try:
                lora_diff = (m1 * m2).reshape(weight.shape)
                if dora_scale is not None:
                    weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
                else:
                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
            except Exception as e:
                logging.error("ERROR {} {} {}".format(patch_type, key, e))
        elif patch_type == "glora":
            dora_scale = v[5]

            old_glora = False
            if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]:
                rank = v[0].shape[0]
                old_glora = True

            if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]:
                if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]:
                    pass
                else:
                    old_glora = False
                    rank = v[1].shape[0]

            a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
            a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
            b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
            b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)

            if v[4] is not None:
                alpha = v[4] / rank
            else:
                alpha = 1.0

            try:
                if old_glora:
                    lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora
                else:
                    if weight.dim() > 2:
                        lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
                    else:
                        lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
                    lora_diff += torch.mm(b1, b2).reshape(weight.shape)

                if dora_scale is not None:
                    weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
                else:
                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
            except Exception as e:
                logging.error("ERROR {} {} {}".format(patch_type, key, e))
        else:
            logging.warning("patch type not recognized {} {}".format(patch_type, key))

@@ -38,6 +38,7 @@ import comfy.ldm.lumina.model
import comfy.ldm.wan.model
import comfy.ldm.hunyuan3d.model
import comfy.ldm.hidream.model
import comfy.ldm.chroma.model

import comfy.model_management
import comfy.patcher_extension
@@ -786,8 +787,8 @@ class PixArt(BaseModel):
        return out

class Flux(BaseModel):
    def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.flux.model.Flux)
    def __init__(self, model_config, model_type=ModelType.FLUX, device=None, unet_model=comfy.ldm.flux.model.Flux):
        super().__init__(model_config, model_type, device=device, unet_model=unet_model)

    def concat_cond(self, **kwargs):
        try:
@@ -1043,6 +1044,37 @@ class WAN21(BaseModel):
            out['clip_fea'] = comfy.conds.CONDRegular(clip_vision_output.penultimate_hidden_states)
        return out


class WAN21_Vace(WAN21):
    def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
        super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.VaceWanModel)
        self.image_to_video = image_to_video

    def extra_conds(self, **kwargs):
        out = super().extra_conds(**kwargs)
        noise = kwargs.get("noise", None)
        noise_shape = list(noise.shape)
        vace_frames = kwargs.get("vace_frames", None)
        if vace_frames is None:
            noise_shape[1] = 32
            vace_frames = torch.zeros(noise_shape, device=noise.device, dtype=noise.dtype)

        for i in range(0, vace_frames.shape[1], 16):
            vace_frames = vace_frames.clone()
            vace_frames[:, i:i + 16] = self.process_latent_in(vace_frames[:, i:i + 16])

        mask = kwargs.get("vace_mask", None)
        if mask is None:
            noise_shape[1] = 64
            mask = torch.ones(noise_shape, device=noise.device, dtype=noise.dtype)

        out['vace_context'] = comfy.conds.CONDRegular(torch.cat([vace_frames.to(noise), mask.to(noise)], dim=1))

        vace_strength = kwargs.get("vace_strength", 1.0)
        out['vace_strength'] = comfy.conds.CONDConstant(vace_strength)
        return out

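For reference, extra_conds above packs the control frames and the mask into a single channel-concatenated tensor, which is what VaceWanModel's vace_patch_embedding consumes. A shape sketch using the defaults that kick in when nothing is supplied (the batch/frame/spatial sizes are placeholders):

import torch

noise = torch.zeros(1, 16, 21, 60, 104)        # placeholder latent shape
frames = torch.zeros(1, 32, *noise.shape[2:])  # default when vace_frames is None
mask = torch.ones(1, 64, *noise.shape[2:])     # default when vace_mask is None
vace_context = torch.cat([frames, mask], dim=1)
print(vace_context.shape[1])                   # 96 channels handed to the VACE embedding
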
class Hunyuan3Dv2(BaseModel):
    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2)
@@ -1073,4 +1105,19 @@ class HiDream(BaseModel):
        conditioning_llama3 = kwargs.get("conditioning_llama3", None)
        if conditioning_llama3 is not None:
            out['encoder_hidden_states_llama3'] = comfy.conds.CONDRegular(conditioning_llama3)
        image_cond = kwargs.get("concat_latent_image", None)
        if image_cond is not None:
            out['image_cond'] = comfy.conds.CONDNoiseShape(self.process_latent_in(image_cond))
        return out

class Chroma(Flux):
    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma.model.Chroma)

    def extra_conds(self, **kwargs):
        out = super().extra_conds(**kwargs)

        guidance = kwargs.get("guidance", 0)
        if guidance is not None:
            out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
        return out

@@ -164,7 +164,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
        if in_key in state_dict_keys:
            dit_config["in_channels"] = state_dict[in_key].shape[1] // (patch_size * patch_size)
        dit_config["out_channels"] = 16
        dit_config["vec_in_dim"] = 768
        vec_in_key = '{}vector_in.in_layer.weight'.format(key_prefix)
        if vec_in_key in state_dict_keys:
            dit_config["vec_in_dim"] = state_dict[vec_in_key].shape[1]
        dit_config["context_in_dim"] = 4096
        dit_config["hidden_size"] = 3072
        dit_config["mlp_ratio"] = 4.0
@@ -174,7 +176,16 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
        dit_config["axes_dim"] = [16, 56, 56]
        dit_config["theta"] = 10000
        dit_config["qkv_bias"] = True
        dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
        if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma
            dit_config["image_model"] = "chroma"
            dit_config["in_channels"] = 64
            dit_config["out_channels"] = 64
            dit_config["in_dim"] = 64
            dit_config["out_dim"] = 3072
            dit_config["hidden_dim"] = 5120
            dit_config["n_layers"] = 5
        else:
            dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
        return dit_config

    if '{}t5_yproj.weight'.format(key_prefix) in state_dict_keys: #Genmo mochi preview
@@ -317,10 +328,15 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
        dit_config["cross_attn_norm"] = True
        dit_config["eps"] = 1e-6
        dit_config["in_dim"] = state_dict['{}patch_embedding.weight'.format(key_prefix)].shape[1]
        if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
            dit_config["model_type"] = "i2v"
        if '{}vace_patch_embedding.weight'.format(key_prefix) in state_dict_keys:
            dit_config["model_type"] = "vace"
            dit_config["vace_in_dim"] = state_dict['{}vace_patch_embedding.weight'.format(key_prefix)].shape[1]
            dit_config["vace_layers"] = count_blocks(state_dict_keys, '{}vace_blocks.'.format(key_prefix) + '{}.')
        else:
            dit_config["model_type"] = "t2v"
        if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
            dit_config["model_type"] = "i2v"
        else:
            dit_config["model_type"] = "t2v"
        flf_weight = state_dict.get('{}img_emb.emb_pos'.format(key_prefix))
        if flf_weight is not None:
            dit_config["flf_pos_embed_token_number"] = flf_weight.shape[1]

@@ -725,6 +725,8 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
        return torch.float8_e4m3fn
    if args.fp8_e5m2_unet:
        return torch.float8_e5m2
    if args.fp8_e8m0fnu_unet:
        return torch.float8_e8m0fnu

    fp8_dtype = None
    if weight_dtype in FLOAT8_TYPES:
@@ -937,15 +939,61 @@ def force_channels_last():
    #TODO
    return False

def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):

STREAMS = {}
NUM_STREAMS = 1
if args.async_offload:
    NUM_STREAMS = 2
    logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS))

stream_counters = {}
def get_offload_stream(device):
    stream_counter = stream_counters.get(device, 0)
    if NUM_STREAMS <= 1:
        return None

    if device in STREAMS:
        ss = STREAMS[device]
        s = ss[stream_counter]
        stream_counter = (stream_counter + 1) % len(ss)
        if is_device_cuda(device):
            ss[stream_counter].wait_stream(torch.cuda.current_stream())
        stream_counters[device] = stream_counter
        return s
    elif is_device_cuda(device):
        ss = []
        for k in range(NUM_STREAMS):
            ss.append(torch.cuda.Stream(device=device, priority=0))
        STREAMS[device] = ss
        s = ss[stream_counter]
        stream_counter = (stream_counter + 1) % len(ss)
        stream_counters[device] = stream_counter
        return s
    return None

def sync_stream(device, stream):
    if stream is None:
        return
    if is_device_cuda(device):
        torch.cuda.current_stream().wait_stream(stream)

def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
    if device is None or weight.device == device:
        if not copy:
            if dtype is None or weight.dtype == dtype:
                return weight
        if stream is not None:
            with stream:
                return weight.to(dtype=dtype, copy=copy)
        return weight.to(dtype=dtype, copy=copy)

    r = torch.empty_like(weight, dtype=dtype, device=device)
    r.copy_(weight, non_blocking=non_blocking)
    if stream is not None:
        with stream:
            r = torch.empty_like(weight, dtype=dtype, device=device)
            r.copy_(weight, non_blocking=non_blocking)
    else:
        r = torch.empty_like(weight, dtype=dtype, device=device)
        r.copy_(weight, non_blocking=non_blocking)
    return r

def cast_to_device(tensor, device, dtype, copy=False):

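Together these helpers let a weight transfer overlap with compute: the copy is enqueued on a side stream picked round-robin by get_offload_stream, and sync_stream makes the compute stream wait before the tensor is consumed. A sketch of the call pattern as cast_bias_weight in comfy/ops.py (later in this diff) uses it; the module here is a stand-in:

import torch
import comfy.model_management

module = torch.nn.Linear(8, 8)  # stand-in layer
device = torch.device("cuda")
stream = comfy.model_management.get_offload_stream(device)  # None unless args.async_offload is set
weight = comfy.model_management.cast_to(
    module.weight, dtype=torch.float16, device=device,
    non_blocking=True, copy=False, stream=stream,
)
comfy.model_management.sync_stream(device, stream)  # compute waits for the async copy
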
@@ -111,13 +111,14 @@ class ModelSamplingDiscrete(torch.nn.Module):
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        self.zsnr = zsnr

        # self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32))
        # self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
        # self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))

        sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
        if zsnr:
        if self.zsnr:
            sigmas = rescale_zero_terminal_snr_sigmas(sigmas)

        self.set_sigmas(sigmas)

24
comfy/ops.py
@@ -22,6 +22,7 @@ import comfy.model_management
from comfy.cli_args import args, PerformanceFeature
import comfy.float
import comfy.rmsnorm
import contextlib

cast_to = comfy.model_management.cast_to #TODO: remove once no more references

@@ -37,20 +38,31 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
        if device is None:
            device = input.device

    offload_stream = comfy.model_management.get_offload_stream(device)
    if offload_stream is not None:
        wf_context = offload_stream
    else:
        wf_context = contextlib.nullcontext()

    bias = None
    non_blocking = comfy.model_management.device_supports_non_blocking(device)
    if s.bias is not None:
        has_function = len(s.bias_function) > 0
        bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function)
        bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)

        if has_function:
            for f in s.bias_function:
                bias = f(bias)
            with wf_context:
                for f in s.bias_function:
                    bias = f(bias)

    has_function = len(s.weight_function) > 0
    weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function)
    weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
    if has_function:
        for f in s.weight_function:
            weight = f(weight)
        with wf_context:
            for f in s.weight_function:
                weight = f(weight)

    comfy.model_management.sync_stream(device, offload_stream)
    return weight, bias

class CastWeightBiasOp:

@@ -710,7 +710,7 @@ KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_c
                  "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
                  "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
                  "ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
                  "gradient_estimation", "er_sde", "seeds_2", "seeds_3"]
                  "gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3"]

class KSAMPLER(Sampler):
    def __init__(self, sampler_function, extra_options={}, inpaint_options={}):

47
comfy/sd.py
@@ -120,6 +120,7 @@ class CLIP:
        self.layer_idx = None
        self.use_clip_schedule = False
        logging.info("CLIP/text encoder model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype))
        self.tokenizer_options = {}

    def clone(self):
        n = CLIP(no_init=True)
@@ -127,6 +128,7 @@ class CLIP:
        n.cond_stage_model = self.cond_stage_model
        n.tokenizer = self.tokenizer
        n.layer_idx = self.layer_idx
        n.tokenizer_options = self.tokenizer_options.copy()
        n.use_clip_schedule = self.use_clip_schedule
        n.apply_hooks_to_conds = self.apply_hooks_to_conds
        return n
@@ -134,10 +136,18 @@ class CLIP:
    def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
        return self.patcher.add_patches(patches, strength_patch, strength_model)

    def set_tokenizer_option(self, option_name, value):
        self.tokenizer_options[option_name] = value

    def clip_layer(self, layer_idx):
        self.layer_idx = layer_idx

    def tokenize(self, text, return_word_ids=False, **kwargs):
        tokenizer_options = kwargs.get("tokenizer_options", {})
        if len(self.tokenizer_options) > 0:
            tokenizer_options = {**self.tokenizer_options, **tokenizer_options}
        if len(tokenizer_options) > 0:
            kwargs["tokenizer_options"] = tokenizer_options
        return self.tokenizer.tokenize_with_weights(text, return_word_ids, **kwargs)

    def add_hooks_to_dict(self, pooled_dict: dict[str]):
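With tokenizer_options threaded through tokenize(), a cloned CLIP can override tokenizer defaults without touching the shared tokenizer object. A usage sketch; the option key follows the "{embedding_key}_{option}" pattern that sd1_clip.py reads back below, and 128 is an arbitrary example value:

clip = clip.clone()
clip.set_tokenizer_option("t5xxl_min_padding", 128)  # appends 128 pad tokens to each t5xxl batch
tokens = clip.tokenize("a photo of a cat")
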
@@ -703,6 +713,8 @@ class CLIPType(Enum):
    COSMOS = 11
    LUMINA2 = 12
    WAN = 13
    HIDREAM = 14
    CHROMA = 15


def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
@@ -791,6 +803,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
        elif clip_type == CLIPType.SD3:
            clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=True, t5=False)
            clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
        elif clip_type == CLIPType.HIDREAM:
            clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=False, clip_g=True, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None)
            clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
        else:
            clip_target.clip = sdxl_clip.SDXLRefinerClipModel
            clip_target.tokenizer = sdxl_clip.SDXLTokenizer
@@ -804,13 +819,17 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
        elif clip_type == CLIPType.LTXV:
            clip_target.clip = comfy.text_encoders.lt.ltxv_te(**t5xxl_detect(clip_data))
            clip_target.tokenizer = comfy.text_encoders.lt.LTXVT5Tokenizer
        elif clip_type == CLIPType.PIXART:
        elif clip_type == CLIPType.PIXART or clip_type == CLIPType.CHROMA:
            clip_target.clip = comfy.text_encoders.pixart_t5.pixart_te(**t5xxl_detect(clip_data))
            clip_target.tokenizer = comfy.text_encoders.pixart_t5.PixArtTokenizer
        elif clip_type == CLIPType.WAN:
            clip_target.clip = comfy.text_encoders.wan.te(**t5xxl_detect(clip_data))
            clip_target.tokenizer = comfy.text_encoders.wan.WanT5Tokenizer
            tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
        elif clip_type == CLIPType.HIDREAM:
            clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data),
                                                                        clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None, llama_scaled_fp8=None)
            clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
        else: #CLIPType.MOCHI
            clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
            clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer
@@ -827,10 +846,18 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
            clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
            clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
            tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
        elif te_model == TEModel.LLAMA3_8:
            clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data),
                                                                        clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None)
            clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
        else:
            # clip_l
            if clip_type == CLIPType.SD3:
                clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=False, t5=False)
                clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
            elif clip_type == CLIPType.HIDREAM:
                clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=True, clip_g=False, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None)
                clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
            else:
                clip_target.clip = sd1_clip.SD1ClipModel
                clip_target.tokenizer = sd1_clip.SD1Tokenizer
@@ -848,6 +875,24 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
        elif clip_type == CLIPType.HUNYUAN_VIDEO:
            clip_target.clip = comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**llama_detect(clip_data))
            clip_target.tokenizer = comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer
        elif clip_type == CLIPType.HIDREAM:
            # Detect
            hidream_dualclip_classes = []
            for hidream_te in clip_data:
                te_model = detect_te_model(hidream_te)
                hidream_dualclip_classes.append(te_model)

            clip_l = TEModel.CLIP_L in hidream_dualclip_classes
            clip_g = TEModel.CLIP_G in hidream_dualclip_classes
            t5 = TEModel.T5_XXL in hidream_dualclip_classes
            llama = TEModel.LLAMA3_8 in hidream_dualclip_classes

            # Initialize t5xxl_detect and llama_detect kwargs if needed
            t5_kwargs = t5xxl_detect(clip_data) if t5 else {}
            llama_kwargs = llama_detect(clip_data) if llama else {}

            clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, **t5_kwargs, **llama_kwargs)
            clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
        else:
            clip_target.clip = sdxl_clip.SDXLClipModel
            clip_target.tokenizer = sdxl_clip.SDXLTokenizer

@@ -457,13 +457,14 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
    return embed_out

class SDTokenizer:
    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, tokenizer_data={}, tokenizer_args={}):
    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, tokenizer_data={}, tokenizer_args={}):
        if tokenizer_path is None:
            tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
        self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
        self.max_length = tokenizer_data.get("{}_max_length".format(embedding_key), max_length)
        self.min_length = min_length
        self.end_token = None
        self.min_padding = min_padding

        empty = self.tokenizer('')["input_ids"]
        self.tokenizer_adds_end_token = has_end_token
@@ -518,13 +519,15 @@ class SDTokenizer:
        return (embed, leftover)


    def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
    def tokenize_with_weights(self, text:str, return_word_ids=False, tokenizer_options={}, **kwargs):
        '''
        Takes a prompt and converts it to a list of (token, weight, word id) elements.
        Tokens can both be integer tokens and pre computed CLIP tensors.
        Word id values are unique per word and embedding, where the id 0 is reserved for non word tokens.
        Returned list has the dimensions NxM where M is the input size of CLIP
        '''
        min_length = tokenizer_options.get("{}_min_length".format(self.embedding_key), self.min_length)
        min_padding = tokenizer_options.get("{}_min_padding".format(self.embedding_key), self.min_padding)

        text = escape_important(text)
        parsed_weights = token_weights(text, 1.0)
@@ -603,10 +606,12 @@ class SDTokenizer:
        #fill last batch
        if self.end_token is not None:
            batch.append((self.end_token, 1.0, 0))
        if self.pad_to_max_length:
        if min_padding is not None:
            batch.extend([(self.pad_token, 1.0, 0)] * min_padding)
        if self.pad_to_max_length and len(batch) < self.max_length:
            batch.extend([(self.pad_token, 1.0, 0)] * (self.max_length - len(batch)))
        if self.min_length is not None and len(batch) < self.min_length:
            batch.extend([(self.pad_token, 1.0, 0)] * (self.min_length - len(batch)))
        if min_length is not None and len(batch) < min_length:
            batch.extend([(self.pad_token, 1.0, 0)] * (min_length - len(batch)))

        if not return_word_ids:
            batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]
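The order of the padding checks above matters: min_padding appends a fixed number of pad tokens unconditionally, and pad_to_max_length / min_length then only top the batch up if it is still short. A worked example, assuming CLIP's usual max_length=77 and a 10-token batch:

# batch: 10 tokens, min_padding=5, pad_to_max_length=True, max_length=77
# 1. min_padding:       10 + 5 = 15 tokens
# 2. pad_to_max_length: 15 -> 77 tokens (62 pads appended)
# 3. min_length:        batch already >= min_length, nothing appended
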
@@ -634,7 +639,7 @@ class SD1Tokenizer:

    def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
        out = {}
        out[self.clip_name] = getattr(self, self.clip).tokenize_with_weights(text, return_word_ids)
        out[self.clip_name] = getattr(self, self.clip).tokenize_with_weights(text, return_word_ids, **kwargs)
        return out

    def untokenize(self, token_weight_pair):

@@ -28,8 +28,8 @@ class SDXLTokenizer:

    def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
        out = {}
        out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
        out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
        out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids, **kwargs)
        out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
        return out

    def untokenize(self, token_weight_pair):

@@ -987,6 +987,20 @@ class WAN21_FunControl2V(WAN21_T2V):
        out = model_base.WAN21(self, image_to_video=False, device=device)
        return out

class WAN21_Vace(WAN21_T2V):
    unet_config = {
        "image_model": "wan2.1",
        "model_type": "vace",
    }

    def __init__(self, unet_config):
        super().__init__(unet_config)
        self.memory_usage_factor = 1.2 * self.memory_usage_factor

    def get_model(self, state_dict, prefix="", device=None):
        out = model_base.WAN21_Vace(self, image_to_video=False, device=device)
        return out

class Hunyuan3Dv2(supported_models_base.BASE):
    unet_config = {
        "image_model": "hunyuan3d2",
@@ -1054,7 +1068,34 @@ class HiDream(supported_models_base.BASE):
    def clip_target(self, state_dict={}):
        return None # TODO

class Chroma(supported_models_base.BASE):
    unet_config = {
        "image_model": "chroma",
    }

models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream]
    unet_extra_config = {
    }

    sampling_settings = {
        "multiplier": 1.0,
    }

    latent_format = comfy.latent_formats.Flux

    memory_usage_factor = 3.2

    supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]


    def get_model(self, state_dict, prefix="", device=None):
        out = model_base.Chroma(self, device=device)
        return out

    def clip_target(self, state_dict={}):
        pref = self.text_encoder_key_prefix[0]
        t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
        return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect))

models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma]

models += [SVD_img2vid]

@@ -19,8 +19,8 @@ class FluxTokenizer:

    def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
        out = {}
        out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
        out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids)
        out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
        out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs)
        return out

    def untokenize(self, token_weight_pair):

@@ -16,11 +16,11 @@ class HiDreamTokenizer:

    def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
        out = {}
        out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
        out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
        t5xxl = self.t5xxl.tokenize_with_weights(text, return_word_ids)
        out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids, **kwargs)
        out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
        t5xxl = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs)
        out["t5xxl"] = [t5xxl[0]] # Use only first 128 tokens
        out["llama"] = self.llama.tokenize_with_weights(text, return_word_ids)
        out["llama"] = self.llama.tokenize_with_weights(text, return_word_ids, **kwargs)
        return out

    def untokenize(self, token_weight_pair):
@@ -109,14 +109,18 @@ class HiDreamTEModel(torch.nn.Module):
        if self.t5xxl is not None:
            t5_output = self.t5xxl.encode_token_weights(token_weight_pairs_t5)
            t5_out, t5_pooled = t5_output[:2]
        else:
            t5_out = None

        if self.llama is not None:
            ll_output = self.llama.encode_token_weights(token_weight_pairs_llama)
            ll_out, ll_pooled = ll_output[:2]
            ll_out = ll_out[:, 1:]
        else:
            ll_out = None

        if t5_out is None:
            t5_out = torch.zeros((1, 1, 4096), device=comfy.model_management.intermediate_device())
            t5_out = torch.zeros((1, 128, 4096), device=comfy.model_management.intermediate_device())

        if ll_out is None:
            ll_out = torch.zeros((1, 32, 1, 4096), device=comfy.model_management.intermediate_device())

@@ -49,13 +49,13 @@ class HunyuanVideoTokenizer:

    def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, image_embeds=None, image_interleave=1, **kwargs):
        out = {}
        out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
        out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)

        if llama_template is None:
            llama_text = self.llama_template.format(text)
        else:
            llama_text = llama_template.format(text)
        llama_text_tokens = self.llama.tokenize_with_weights(llama_text, return_word_ids)
        llama_text_tokens = self.llama.tokenize_with_weights(llama_text, return_word_ids, **kwargs)
        embed_count = 0
        for r in llama_text_tokens:
            for i in range(len(r)):

@@ -41,8 +41,8 @@ class HyditTokenizer:

    def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
        out = {}
        out["hydit_clip"] = self.hydit_clip.tokenize_with_weights(text, return_word_ids)
        out["mt5xl"] = self.mt5xl.tokenize_with_weights(text, return_word_ids)
        out["hydit_clip"] = self.hydit_clip.tokenize_with_weights(text, return_word_ids, **kwargs)
        out["mt5xl"] = self.mt5xl.tokenize_with_weights(text, return_word_ids, **kwargs)
        return out

    def untokenize(self, token_weight_pair):

@@ -45,9 +45,9 @@ class SD3Tokenizer:

    def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
        out = {}
        out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
        out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
        out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids)
        out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids, **kwargs)
        out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
        out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs)
        return out

    def untokenize(self, token_weight_pair):

@@ -1,4 +1,5 @@
import torch
import os

class SPieceTokenizer:
    @staticmethod
@@ -15,6 +16,8 @@ class SPieceTokenizer:
        if isinstance(tokenizer_path, bytes):
            self.tokenizer = sentencepiece.SentencePieceProcessor(model_proto=tokenizer_path, add_bos=self.add_bos, add_eos=self.add_eos)
        else:
            if not os.path.isfile(tokenizer_path):
                raise ValueError("invalid tokenizer")
            self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=tokenizer_path, add_bos=self.add_bos, add_eos=self.add_eos)

    def get_vocab(self):

17 comfy/weight_adapter/__init__.py Normal file
@@ -0,0 +1,17 @@
from .base import WeightAdapterBase
from .lora import LoRAAdapter
from .loha import LoHaAdapter
from .lokr import LoKrAdapter
from .glora import GLoRAAdapter
from .oft import OFTAdapter
from .boft import BOFTAdapter


adapters: list[type[WeightAdapterBase]] = [
    LoRAAdapter,
    LoHaAdapter,
    LoKrAdapter,
    GLoRAAdapter,
    OFTAdapter,
    BOFTAdapter,
]

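A registry like this is typically consumed by trying each adapter's `load()` for a given key prefix and keeping the first match. The `load()` signature follows the adapter files below (`x`, `lora`, `alpha`, `dora_scale`, `loaded_keys`); the driver loop itself is a sketch, not quoted from ComfyUI:

```python
# Sketch: probe each registered adapter until one recognizes the keys.
from comfy.weight_adapter import adapters

def detect_adapter(key_prefix, lora_state_dict, alpha=None, dora_scale=None):
    loaded = set()
    for adapter_cls in adapters:
        adapter = adapter_cls.load(key_prefix, lora_state_dict, alpha, dora_scale, loaded)
        if adapter is not None:
            return adapter  # e.g. a LoRAAdapter for "*.lora_up.weight" keys
    return None
```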
104 comfy/weight_adapter/base.py Normal file
@@ -0,0 +1,104 @@
from typing import Optional

import torch
import torch.nn as nn

import comfy.model_management


class WeightAdapterBase:
    name: str
    loaded_keys: set[str]
    weights: list[torch.Tensor]

    @classmethod
    def load(cls, x: str, lora: dict[str, torch.Tensor]) -> Optional["WeightAdapterBase"]:
        raise NotImplementedError

    def to_train(self) -> "WeightAdapterTrainBase":
        raise NotImplementedError

    def calculate_weight(
        self,
        weight,
        key,
        strength,
        strength_model,
        offset,
        function,
        intermediate_dtype=torch.float32,
        original_weight=None,
    ):
        raise NotImplementedError


class WeightAdapterTrainBase(nn.Module):
    def __init__(self):
        super().__init__()

    # [TODO] Collaborate with LoRA training PR #7032


def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function):
    dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
    lora_diff *= alpha
    weight_calc = weight + function(lora_diff).type(weight.dtype)

    wd_on_output_axis = dora_scale.shape[0] == weight_calc.shape[0]
    if wd_on_output_axis:
        weight_norm = (
            weight.reshape(weight.shape[0], -1)
            .norm(dim=1, keepdim=True)
            .reshape(weight.shape[0], *[1] * (weight.dim() - 1))
        )
    else:
        weight_norm = (
            weight_calc.transpose(0, 1)
            .reshape(weight_calc.shape[1], -1)
            .norm(dim=1, keepdim=True)
            .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
            .transpose(0, 1)
        )
    weight_norm = weight_norm + torch.finfo(weight.dtype).eps

    weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
    if strength != 1.0:
        weight_calc -= weight
        weight += strength * (weight_calc)
    else:
        weight[:] = weight_calc
    return weight


def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:
    """
    Pad a tensor to a new shape with zeros.

    Args:
        tensor (torch.Tensor): The original tensor to be padded.
        new_shape (List[int]): The desired shape of the padded tensor.

    Returns:
        torch.Tensor: A new tensor padded with zeros to the specified shape.

    Note:
        If the new shape is smaller than the original tensor in any dimension,
        a ValueError is raised.
    """
    if any([new_shape[i] < tensor.shape[i] for i in range(len(new_shape))]):
        raise ValueError("The new shape must be larger than the original tensor in all dimensions")

    if len(new_shape) != len(tensor.shape):
        raise ValueError("The new shape must have the same number of dimensions as the original tensor")

    # Create a new tensor filled with zeros
    padded_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device)

    # Create slicing tuples for both tensors
    orig_slices = tuple(slice(0, dim) for dim in tensor.shape)
    new_slices = tuple(slice(0, dim) for dim in tensor.shape)

    # Copy the original tensor into the new tensor
    padded_tensor[new_slices] = tensor[orig_slices]

    return padded_tensor

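A quick check of `pad_tensor_to_shape`'s contract (zero-pad up, never shrink), runnable against the module path introduced above:

```python
# pad_tensor_to_shape copies the original into the top-left corner of a
# zero tensor of the target shape; a smaller target raises ValueError.
import torch
from comfy.weight_adapter.base import pad_tensor_to_shape

t = torch.ones(2, 3)
p = pad_tensor_to_shape(t, [4, 5])
assert p.shape == (4, 5)
assert p[:2, :3].eq(1).all() and p[2:, :].eq(0).all()
# pad_tensor_to_shape(t, [1, 3]) would raise ValueError (smaller dimension)
```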
115 comfy/weight_adapter/boft.py Normal file
@@ -0,0 +1,115 @@
import logging
from typing import Optional

import torch
import comfy.model_management
from .base import WeightAdapterBase, weight_decompose


class BOFTAdapter(WeightAdapterBase):
    name = "boft"

    def __init__(self, loaded_keys, weights):
        self.loaded_keys = loaded_keys
        self.weights = weights

    @classmethod
    def load(
        cls,
        x: str,
        lora: dict[str, torch.Tensor],
        alpha: float,
        dora_scale: torch.Tensor,
        loaded_keys: set[str] = None,
    ) -> Optional["BOFTAdapter"]:
        if loaded_keys is None:
            loaded_keys = set()
        blocks_name = "{}.oft_blocks".format(x)
        rescale_name = "{}.rescale".format(x)

        blocks = None
        if blocks_name in lora.keys():
            blocks = lora[blocks_name]
            if blocks.ndim == 4:
                loaded_keys.add(blocks_name)
            else:
                blocks = None
        if blocks is None:
            return None

        rescale = None
        if rescale_name in lora.keys():
            rescale = lora[rescale_name]
            loaded_keys.add(rescale_name)

        weights = (blocks, rescale, alpha, dora_scale)
        return cls(loaded_keys, weights)

    def calculate_weight(
        self,
        weight,
        key,
        strength,
        strength_model,
        offset,
        function,
        intermediate_dtype=torch.float32,
        original_weight=None,
    ):
        v = self.weights
        blocks = v[0]
        rescale = v[1]
        alpha = v[2]
        dora_scale = v[3]

        blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype)
        if rescale is not None:
            rescale = comfy.model_management.cast_to_device(rescale, weight.device, intermediate_dtype)

        boft_m, block_num, boft_b, *_ = blocks.shape

        try:
            # Get r
            I = torch.eye(boft_b, device=blocks.device, dtype=blocks.dtype)
            # for Q = -Q^T
            q = blocks - blocks.transpose(-1, -2)
            normed_q = q
            if alpha > 0: # alpha in boft/bboft is for constraint
                q_norm = torch.norm(q) + 1e-8
                if q_norm > alpha:
                    normed_q = q * alpha / q_norm
            # use float() to prevent unsupported type in .inverse()
            r = (I + normed_q) @ (I - normed_q).float().inverse()
            r = r.to(weight)
            inp = org = weight

            r_b = boft_b//2
            for i in range(boft_m):
                bi = r[i]
                g = 2
                k = 2**i * r_b
                if strength != 1:
                    bi = bi * strength + (1-strength) * I
                inp = (
                    inp.unflatten(0, (-1, g, k))
                    .transpose(1, 2)
                    .flatten(0, 2)
                    .unflatten(0, (-1, boft_b))
                )
                inp = torch.einsum("b i j, b j ...-> b i ...", bi, inp)
                inp = (
                    inp.flatten(0, 1).unflatten(0, (-1, k, g)).transpose(1, 2).flatten(0, 2)
                )

            if rescale is not None:
                inp = inp * rescale

            lora_diff = inp - org
            lora_diff = comfy.model_management.cast_to_device(lora_diff, weight.device, intermediate_dtype)
            if dora_scale is not None:
                weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
            else:
                weight += function((strength * lora_diff).type(weight.dtype))
        except Exception as e:
            logging.error("ERROR {} {} {}".format(self.name, key, e))
        return weight

93 comfy/weight_adapter/glora.py Normal file
@@ -0,0 +1,93 @@
import logging
from typing import Optional

import torch
import comfy.model_management
from .base import WeightAdapterBase, weight_decompose


class GLoRAAdapter(WeightAdapterBase):
    name = "glora"

    def __init__(self, loaded_keys, weights):
        self.loaded_keys = loaded_keys
        self.weights = weights

    @classmethod
    def load(
        cls,
        x: str,
        lora: dict[str, torch.Tensor],
        alpha: float,
        dora_scale: torch.Tensor,
        loaded_keys: set[str] = None,
    ) -> Optional["GLoRAAdapter"]:
        if loaded_keys is None:
            loaded_keys = set()
        a1_name = "{}.a1.weight".format(x)
        a2_name = "{}.a2.weight".format(x)
        b1_name = "{}.b1.weight".format(x)
        b2_name = "{}.b2.weight".format(x)
        if a1_name in lora:
            weights = (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale)
            loaded_keys.add(a1_name)
            loaded_keys.add(a2_name)
            loaded_keys.add(b1_name)
            loaded_keys.add(b2_name)
            return cls(loaded_keys, weights)
        else:
            return None

    def calculate_weight(
        self,
        weight,
        key,
        strength,
        strength_model,
        offset,
        function,
        intermediate_dtype=torch.float32,
        original_weight=None,
    ):
        v = self.weights
        dora_scale = v[5]

        old_glora = False
        if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]:
            rank = v[0].shape[0]
            old_glora = True

        if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]:
            if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]:
                pass
            else:
                old_glora = False
                rank = v[1].shape[0]

        a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
        a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
        b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
        b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)

        if v[4] is not None:
            alpha = v[4] / rank
        else:
            alpha = 1.0

        try:
            if old_glora:
                lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora
            else:
                if weight.dim() > 2:
                    lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
                else:
                    lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
                lora_diff += torch.mm(b1, b2).reshape(weight.shape)

            if dora_scale is not None:
                weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
            else:
                weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
        except Exception as e:
            logging.error("ERROR {} {} {}".format(self.name, key, e))
        return weight

100 comfy/weight_adapter/loha.py Normal file
@@ -0,0 +1,100 @@
import logging
from typing import Optional

import torch
import comfy.model_management
from .base import WeightAdapterBase, weight_decompose


class LoHaAdapter(WeightAdapterBase):
    name = "loha"

    def __init__(self, loaded_keys, weights):
        self.loaded_keys = loaded_keys
        self.weights = weights

    @classmethod
    def load(
        cls,
        x: str,
        lora: dict[str, torch.Tensor],
        alpha: float,
        dora_scale: torch.Tensor,
        loaded_keys: set[str] = None,
    ) -> Optional["LoHaAdapter"]:
        if loaded_keys is None:
            loaded_keys = set()

        hada_w1_a_name = "{}.hada_w1_a".format(x)
        hada_w1_b_name = "{}.hada_w1_b".format(x)
        hada_w2_a_name = "{}.hada_w2_a".format(x)
        hada_w2_b_name = "{}.hada_w2_b".format(x)
        hada_t1_name = "{}.hada_t1".format(x)
        hada_t2_name = "{}.hada_t2".format(x)
        if hada_w1_a_name in lora.keys():
            hada_t1 = None
            hada_t2 = None
            if hada_t1_name in lora.keys():
                hada_t1 = lora[hada_t1_name]
                hada_t2 = lora[hada_t2_name]
                loaded_keys.add(hada_t1_name)
                loaded_keys.add(hada_t2_name)

            weights = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale)
            loaded_keys.add(hada_w1_a_name)
            loaded_keys.add(hada_w1_b_name)
            loaded_keys.add(hada_w2_a_name)
            loaded_keys.add(hada_w2_b_name)
            return cls(loaded_keys, weights)
        else:
            return None

    def calculate_weight(
        self,
        weight,
        key,
        strength,
        strength_model,
        offset,
        function,
        intermediate_dtype=torch.float32,
        original_weight=None,
    ):
        v = self.weights
        w1a = v[0]
        w1b = v[1]
        if v[2] is not None:
            alpha = v[2] / w1b.shape[0]
        else:
            alpha = 1.0

        w2a = v[3]
        w2b = v[4]
        dora_scale = v[7]
        if v[5] is not None: #cp decomposition
            t1 = v[5]
            t2 = v[6]
            m1 = torch.einsum('i j k l, j r, i p -> p r k l',
                              comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype),
                              comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype),
                              comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype))

            m2 = torch.einsum('i j k l, j r, i p -> p r k l',
                              comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
                              comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype),
                              comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype))
        else:
            m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype),
                          comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype))
            m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype),
                          comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype))

        try:
            lora_diff = (m1 * m2).reshape(weight.shape)
            if dora_scale is not None:
                weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
            else:
                weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
        except Exception as e:
            logging.error("ERROR {} {} {}".format(self.name, key, e))
        return weight

133 comfy/weight_adapter/lokr.py Normal file
@@ -0,0 +1,133 @@
import logging
from typing import Optional

import torch
import comfy.model_management
from .base import WeightAdapterBase, weight_decompose


class LoKrAdapter(WeightAdapterBase):
    name = "lokr"

    def __init__(self, loaded_keys, weights):
        self.loaded_keys = loaded_keys
        self.weights = weights

    @classmethod
    def load(
        cls,
        x: str,
        lora: dict[str, torch.Tensor],
        alpha: float,
        dora_scale: torch.Tensor,
        loaded_keys: set[str] = None,
    ) -> Optional["LoKrAdapter"]:
        if loaded_keys is None:
            loaded_keys = set()
        lokr_w1_name = "{}.lokr_w1".format(x)
        lokr_w2_name = "{}.lokr_w2".format(x)
        lokr_w1_a_name = "{}.lokr_w1_a".format(x)
        lokr_w1_b_name = "{}.lokr_w1_b".format(x)
        lokr_t2_name = "{}.lokr_t2".format(x)
        lokr_w2_a_name = "{}.lokr_w2_a".format(x)
        lokr_w2_b_name = "{}.lokr_w2_b".format(x)

        lokr_w1 = None
        if lokr_w1_name in lora.keys():
            lokr_w1 = lora[lokr_w1_name]
            loaded_keys.add(lokr_w1_name)

        lokr_w2 = None
        if lokr_w2_name in lora.keys():
            lokr_w2 = lora[lokr_w2_name]
            loaded_keys.add(lokr_w2_name)

        lokr_w1_a = None
        if lokr_w1_a_name in lora.keys():
            lokr_w1_a = lora[lokr_w1_a_name]
            loaded_keys.add(lokr_w1_a_name)

        lokr_w1_b = None
        if lokr_w1_b_name in lora.keys():
            lokr_w1_b = lora[lokr_w1_b_name]
            loaded_keys.add(lokr_w1_b_name)

        lokr_w2_a = None
        if lokr_w2_a_name in lora.keys():
            lokr_w2_a = lora[lokr_w2_a_name]
            loaded_keys.add(lokr_w2_a_name)

        lokr_w2_b = None
        if lokr_w2_b_name in lora.keys():
            lokr_w2_b = lora[lokr_w2_b_name]
            loaded_keys.add(lokr_w2_b_name)

        lokr_t2 = None
        if lokr_t2_name in lora.keys():
            lokr_t2 = lora[lokr_t2_name]
            loaded_keys.add(lokr_t2_name)

        if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
            weights = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale)
            return cls(loaded_keys, weights)
        else:
            return None

    def calculate_weight(
        self,
        weight,
        key,
        strength,
        strength_model,
        offset,
        function,
        intermediate_dtype=torch.float32,
        original_weight=None,
    ):
        v = self.weights
        w1 = v[0]
        w2 = v[1]
        w1_a = v[3]
        w1_b = v[4]
        w2_a = v[5]
        w2_b = v[6]
        t2 = v[7]
        dora_scale = v[8]
        dim = None

        if w1 is None:
            dim = w1_b.shape[0]
            w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype),
                          comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype))
        else:
            w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype)

        if w2 is None:
            dim = w2_b.shape[0]
            if t2 is None:
                w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype),
                              comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype))
            else:
                w2 = torch.einsum('i j k l, j r, i p -> p r k l',
                                  comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
                                  comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype),
                                  comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype))
        else:
            w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype)

        if len(w2.shape) == 4:
            w1 = w1.unsqueeze(2).unsqueeze(2)
        if v[2] is not None and dim is not None:
            alpha = v[2] / dim
        else:
            alpha = 1.0

        try:
            lora_diff = torch.kron(w1, w2).reshape(weight.shape)
            if dora_scale is not None:
                weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
            else:
                weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
        except Exception as e:
            logging.error("ERROR {} {} {}".format(self.name, key, e))
        return weight

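Shape sanity check for the LoKr reconstruction above: `torch.kron` of an `(a, b)` factor with a `(c, d)` factor yields an `(a*c, b*d)` delta, which is then reshaped to the target weight's shape. Tensors here are synthetic stand-ins:

```python
# Kronecker factorization recovers a large delta from two small factors.
import torch

w1 = torch.randn(4, 8)    # stand-in for lokr_w1 (or w1_a @ w1_b)
w2 = torch.randn(16, 12)  # stand-in for lokr_w2
delta = torch.kron(w1, w2)
assert delta.shape == (4 * 16, 8 * 12)  # matches a 64x96 weight
```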
142 comfy/weight_adapter/lora.py Normal file
@@ -0,0 +1,142 @@
import logging
from typing import Optional

import torch
import comfy.model_management
from .base import WeightAdapterBase, weight_decompose, pad_tensor_to_shape


class LoRAAdapter(WeightAdapterBase):
    name = "lora"

    def __init__(self, loaded_keys, weights):
        self.loaded_keys = loaded_keys
        self.weights = weights

    @classmethod
    def load(
        cls,
        x: str,
        lora: dict[str, torch.Tensor],
        alpha: float,
        dora_scale: torch.Tensor,
        loaded_keys: set[str] = None,
    ) -> Optional["LoRAAdapter"]:
        if loaded_keys is None:
            loaded_keys = set()

        reshape_name = "{}.reshape_weight".format(x)
        regular_lora = "{}.lora_up.weight".format(x)
        diffusers_lora = "{}_lora.up.weight".format(x)
        diffusers2_lora = "{}.lora_B.weight".format(x)
        diffusers3_lora = "{}.lora.up.weight".format(x)
        mochi_lora = "{}.lora_B".format(x)
        transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
        A_name = None

        if regular_lora in lora.keys():
            A_name = regular_lora
            B_name = "{}.lora_down.weight".format(x)
            mid_name = "{}.lora_mid.weight".format(x)
        elif diffusers_lora in lora.keys():
            A_name = diffusers_lora
            B_name = "{}_lora.down.weight".format(x)
            mid_name = None
        elif diffusers2_lora in lora.keys():
            A_name = diffusers2_lora
            B_name = "{}.lora_A.weight".format(x)
            mid_name = None
        elif diffusers3_lora in lora.keys():
            A_name = diffusers3_lora
            B_name = "{}.lora.down.weight".format(x)
            mid_name = None
        elif mochi_lora in lora.keys():
            A_name = mochi_lora
            B_name = "{}.lora_A".format(x)
            mid_name = None
        elif transformers_lora in lora.keys():
            A_name = transformers_lora
            B_name = "{}.lora_linear_layer.down.weight".format(x)
            mid_name = None

        if A_name is not None:
            mid = None
            if mid_name is not None and mid_name in lora.keys():
                mid = lora[mid_name]
                loaded_keys.add(mid_name)
            reshape = None
            if reshape_name in lora.keys():
                try:
                    reshape = lora[reshape_name].tolist()
                    loaded_keys.add(reshape_name)
                except:
                    pass
            weights = (lora[A_name], lora[B_name], alpha, mid, dora_scale, reshape)
            loaded_keys.add(A_name)
            loaded_keys.add(B_name)
            return cls(loaded_keys, weights)
        else:
            return None

    def calculate_weight(
        self,
        weight,
        key,
        strength,
        strength_model,
        offset,
        function,
        intermediate_dtype=torch.float32,
        original_weight=None,
    ):
        v = self.weights
        mat1 = comfy.model_management.cast_to_device(
            v[0], weight.device, intermediate_dtype
        )
        mat2 = comfy.model_management.cast_to_device(
            v[1], weight.device, intermediate_dtype
        )
        dora_scale = v[4]
        reshape = v[5]

        if reshape is not None:
            weight = pad_tensor_to_shape(weight, reshape)

        if v[2] is not None:
            alpha = v[2] / mat2.shape[0]
        else:
            alpha = 1.0

        if v[3] is not None:
            # locon mid weights, hopefully the math is fine because I didn't properly test it
            mat3 = comfy.model_management.cast_to_device(
                v[3], weight.device, intermediate_dtype
            )
            final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
            mat2 = (
                torch.mm(
                    mat2.transpose(0, 1).flatten(start_dim=1),
                    mat3.transpose(0, 1).flatten(start_dim=1),
                )
                .reshape(final_shape)
                .transpose(0, 1)
            )
        try:
            lora_diff = torch.mm(
                mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)
            ).reshape(weight.shape)
            if dora_scale is not None:
                weight = weight_decompose(
                    dora_scale,
                    weight,
                    lora_diff,
                    alpha,
                    strength,
                    intermediate_dtype,
                    function,
                )
            else:
                weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
        except Exception as e:
            logging.error("ERROR {} {} {}".format(self.name, key, e))
        return weight

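The core of `LoRAAdapter.calculate_weight` in isolation: the weight delta is `up @ down`, scaled by `strength * (alpha / rank)`. A minimal sketch with synthetic tensors:

```python
# LoRA merge math: W += strength * (alpha / rank) * (up @ down).
import torch

out_dim, in_dim, rank = 64, 32, 4
up = torch.randn(out_dim, rank)    # v[0] / lora_up
down = torch.randn(rank, in_dim)   # v[1] / lora_down
alpha_raw, strength = 8.0, 1.0
alpha = alpha_raw / down.shape[0]  # alpha / rank, as in the code above
lora_diff = torch.mm(up, down)     # (out_dim, in_dim)
weight = torch.zeros(out_dim, in_dim)
weight += (strength * alpha) * lora_diff
```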
96 comfy/weight_adapter/oft.py Normal file
@@ -0,0 +1,96 @@
import logging
from typing import Optional

import torch
import comfy.model_management
from .base import WeightAdapterBase, weight_decompose


class OFTAdapter(WeightAdapterBase):
    name = "oft"

    def __init__(self, loaded_keys, weights):
        self.loaded_keys = loaded_keys
        self.weights = weights

    @classmethod
    def load(
        cls,
        x: str,
        lora: dict[str, torch.Tensor],
        alpha: float,
        dora_scale: torch.Tensor,
        loaded_keys: set[str] = None,
    ) -> Optional["OFTAdapter"]:
        if loaded_keys is None:
            loaded_keys = set()
        blocks_name = "{}.oft_blocks".format(x)
        rescale_name = "{}.rescale".format(x)

        blocks = None
        if blocks_name in lora.keys():
            blocks = lora[blocks_name]
            if blocks.ndim == 3:
                loaded_keys.add(blocks_name)
            else:
                blocks = None
        if blocks is None:
            return None

        rescale = None
        if rescale_name in lora.keys():
            rescale = lora[rescale_name]
            loaded_keys.add(rescale_name)

        weights = (blocks, rescale, alpha, dora_scale)
        return cls(loaded_keys, weights)

    def calculate_weight(
        self,
        weight,
        key,
        strength,
        strength_model,
        offset,
        function,
        intermediate_dtype=torch.float32,
        original_weight=None,
    ):
        v = self.weights
        blocks = v[0]
        rescale = v[1]
        alpha = v[2]
        dora_scale = v[3]

        blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype)
        if rescale is not None:
            rescale = comfy.model_management.cast_to_device(rescale, weight.device, intermediate_dtype)

        block_num, block_size, *_ = blocks.shape

        try:
            # Get r
            I = torch.eye(block_size, device=blocks.device, dtype=blocks.dtype)
            # for Q = -Q^T
            q = blocks - blocks.transpose(1, 2)
            normed_q = q
            if alpha > 0: # alpha in oft/boft is for constraint
                q_norm = torch.norm(q) + 1e-8
                if q_norm > alpha:
                    normed_q = q * alpha / q_norm
            # use float() to prevent unsupported type in .inverse()
            r = (I + normed_q) @ (I - normed_q).float().inverse()
            r = r.to(weight)
            _, *shape = weight.shape
            lora_diff = torch.einsum(
                "k n m, k n ... -> k m ...",
                (r * strength) - strength * I,
                weight.view(block_num, block_size, *shape),
            ).view(-1, *shape)
            if dora_scale is not None:
                weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
            else:
                weight += function((strength * lora_diff).type(weight.dtype))
        except Exception as e:
            logging.error("ERROR {} {} {}".format(self.name, key, e))
        return weight

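The Cayley transform used by both the OFT and BOFT adapters, checked in isolation: for a skew-symmetric Q (Q = -Q^T), R = (I + Q)(I - Q)^-1 is orthogonal, so applying it rotates the weight rather than rescaling it:

```python
# Cayley transform of a skew-symmetric matrix yields an orthogonal matrix.
import torch

b = torch.randn(8, 8, dtype=torch.float64)
q = b - b.T                                   # skew-symmetric: q == -q.T
I = torch.eye(8, dtype=torch.float64)
r = (I + q) @ torch.linalg.inv(I - q)         # Cayley transform
assert torch.allclose(r @ r.T, I, atol=1e-8)  # R is orthogonal
```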
8 comfy_api/input/__init__.py Normal file
@@ -0,0 +1,8 @@
from .basic_types import ImageInput, AudioInput
from .video_types import VideoInput

__all__ = [
    "ImageInput",
    "AudioInput",
    "VideoInput",
]

20 comfy_api/input/basic_types.py Normal file
@@ -0,0 +1,20 @@
import torch
from typing import TypedDict

ImageInput = torch.Tensor
"""
An image in format [B, H, W, C] where B is the batch size, H is the height,
W is the width, and C is the number of channels.
"""

class AudioInput(TypedDict):
    """
    TypedDict representing audio input.
    """

    waveform: torch.Tensor
    """
    Tensor in the format [B, C, T] where B is the batch size, C is the number
    of channels, and T is the number of samples.
    """

    sample_rate: int

45 comfy_api/input/video_types.py Normal file
@@ -0,0 +1,45 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Optional
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents

class VideoInput(ABC):
    """
    Abstract base class for video input types.
    """

    @abstractmethod
    def get_components(self) -> VideoComponents:
        """
        Abstract method to get the video components (images, audio, and frame rate).

        Returns:
            VideoComponents containing images, audio, and frame rate
        """
        pass

    @abstractmethod
    def save_to(
        self,
        path: str,
        format: VideoContainer = VideoContainer.AUTO,
        codec: VideoCodec = VideoCodec.AUTO,
        metadata: Optional[dict] = None
    ):
        """
        Abstract method to save the video input to a file.
        """
        pass

    # Provide a default implementation, but subclasses can provide optimized versions
    # if possible.
    def get_dimensions(self) -> tuple[int, int]:
        """
        Returns the dimensions of the video input.

        Returns:
            Tuple of (width, height)
        """
        components = self.get_components()
        return components.images.shape[2], components.images.shape[1]

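A minimal concrete `VideoInput`, assuming the `comfy_api` packages introduced in this change are importable; it shows what the ABC obligates subclasses to provide and how the default `get_dimensions` kicks in:

```python
# Sketch of a VideoInput subclass; BlackVideo is illustrative, not ComfyUI API.
from fractions import Fraction
import torch
from comfy_api.input import VideoInput
from comfy_api.util import VideoComponents, VideoContainer, VideoCodec

class BlackVideo(VideoInput):
    def get_components(self) -> VideoComponents:
        # 8 black 64x64 RGB frames at 24 fps, in the [B, H, W, C] layout
        return VideoComponents(images=torch.zeros(8, 64, 64, 3), frame_rate=Fraction(24))

    def save_to(self, path, format=VideoContainer.AUTO, codec=VideoCodec.AUTO, metadata=None):
        raise NotImplementedError("sketch only")

print(BlackVideo().get_dimensions())  # (64, 64) via the default implementation
```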
7 comfy_api/input_impl/__init__.py Normal file
@@ -0,0 +1,7 @@
from .video_types import VideoFromFile, VideoFromComponents

__all__ = [
    # Implementations
    "VideoFromFile",
    "VideoFromComponents",
]

271 comfy_api/input_impl/video_types.py Normal file
@@ -0,0 +1,271 @@
from __future__ import annotations
from av.container import InputContainer
from av.subtitles.stream import SubtitleStream
from fractions import Fraction
from typing import Optional
from comfy_api.input import AudioInput
import av
import io
import json
import numpy as np
import torch
from comfy_api.input import VideoInput
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents


def container_to_output_format(container_format: str | None) -> str | None:
    """
    A container's `format` may be a comma-separated list of formats.
    E.g., iso container's `format` may be `mov,mp4,m4a,3gp,3g2,mj2`.
    However, writing to a file/stream with `av.open` requires a single format,
    or `None` to auto-detect.
    """
    if not container_format:
        return None  # Auto-detect

    if "," not in container_format:
        return container_format

    formats = container_format.split(",")
    return formats[0]


def get_open_write_kwargs(
    dest: str | io.BytesIO, container_format: str, to_format: str | None
) -> dict:
    """Get kwargs for writing a `VideoFromFile` to a file/stream with `av.open`"""
    open_kwargs = {
        "mode": "w",
        # If isobmff, preserve custom metadata tags (workflow, prompt, extra_pnginfo)
        "options": {"movflags": "use_metadata_tags"},
    }

    is_write_to_buffer = isinstance(dest, io.BytesIO)
    if is_write_to_buffer:
        # Set output format explicitly, since it cannot be inferred from file extension
        if to_format == VideoContainer.AUTO:
            to_format = container_format.lower()
        elif isinstance(to_format, str):
            to_format = to_format.lower()
        open_kwargs["format"] = container_to_output_format(to_format)

    return open_kwargs


class VideoFromFile(VideoInput):
    """
    Class representing video input from a file.
    """

    def __init__(self, file: str | io.BytesIO):
        """
        Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object
        containing the file contents.
        """
        self.__file = file

    def get_dimensions(self) -> tuple[int, int]:
        """
        Returns the dimensions of the video input.

        Returns:
            Tuple of (width, height)
        """
        if isinstance(self.__file, io.BytesIO):
            self.__file.seek(0)  # Reset the BytesIO object to the beginning
        with av.open(self.__file, mode='r') as container:
            for stream in container.streams:
                if stream.type == 'video':
                    assert isinstance(stream, av.VideoStream)
                    return stream.width, stream.height
        raise ValueError(f"No video stream found in file '{self.__file}'")

    def get_components_internal(self, container: InputContainer) -> VideoComponents:
        # Get video frames
        frames = []
        for frame in container.decode(video=0):
            img = frame.to_ndarray(format='rgb24')  # shape: (H, W, 3)
            img = torch.from_numpy(img) / 255.0  # shape: (H, W, 3)
            frames.append(img)

        images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 3, 0, 0)

        # Get frame rate
        video_stream = next(s for s in container.streams if s.type == 'video')
        frame_rate = Fraction(video_stream.average_rate) if video_stream and video_stream.average_rate else Fraction(1)

        # Get audio if available
        audio = None
        try:
            container.seek(0)  # Reset the container to the beginning
            for stream in container.streams:
                if stream.type != 'audio':
                    continue
                assert isinstance(stream, av.AudioStream)
                audio_frames = []
                for packet in container.demux(stream):
                    for frame in packet.decode():
                        assert isinstance(frame, av.AudioFrame)
                        audio_frames.append(frame.to_ndarray())  # shape: (channels, samples)
                if len(audio_frames) > 0:
                    audio_data = np.concatenate(audio_frames, axis=1)  # shape: (channels, total_samples)
                    audio_tensor = torch.from_numpy(audio_data).unsqueeze(0)  # shape: (1, channels, total_samples)
                    audio = AudioInput({
                        "waveform": audio_tensor,
                        "sample_rate": int(stream.sample_rate) if stream.sample_rate else 1,
                    })
        except StopIteration:
            pass  # No audio stream

        metadata = container.metadata
        return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)

    def get_components(self) -> VideoComponents:
        if isinstance(self.__file, io.BytesIO):
            self.__file.seek(0)  # Reset the BytesIO object to the beginning
        with av.open(self.__file, mode='r') as container:
            return self.get_components_internal(container)
        raise ValueError(f"No video stream found in file '{self.__file}'")

    def save_to(
        self,
        path: str | io.BytesIO,
        format: VideoContainer = VideoContainer.AUTO,
        codec: VideoCodec = VideoCodec.AUTO,
        metadata: Optional[dict] = None
    ):
        if isinstance(self.__file, io.BytesIO):
            self.__file.seek(0)  # Reset the BytesIO object to the beginning
        with av.open(self.__file, mode='r') as container:
            container_format = container.format.name
            video_encoding = container.streams.video[0].codec.name if len(container.streams.video) > 0 else None
            reuse_streams = True
            if format != VideoContainer.AUTO and format not in container_format.split(","):
                reuse_streams = False
            if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
                reuse_streams = False

            if not reuse_streams:
                components = self.get_components_internal(container)
                video = VideoFromComponents(components)
                return video.save_to(
                    path,
                    format=format,
                    codec=codec,
                    metadata=metadata
                )

            streams = container.streams

            open_kwargs = get_open_write_kwargs(path, container_format, format)
            with av.open(path, **open_kwargs) as output_container:
                # Copy over the original metadata
                for key, value in container.metadata.items():
                    if metadata is None or key not in metadata:
                        output_container.metadata[key] = value

                # Add our new metadata
                if metadata is not None:
                    for key, value in metadata.items():
                        if isinstance(value, str):
                            output_container.metadata[key] = value
                        else:
                            output_container.metadata[key] = json.dumps(value)

                # Add streams to the new container
                stream_map = {}
                for stream in streams:
                    if isinstance(stream, (av.VideoStream, av.AudioStream, SubtitleStream)):
                        out_stream = output_container.add_stream_from_template(template=stream, opaque=True)
                        stream_map[stream] = out_stream

                # Write packets to the new container
                for packet in container.demux():
                    if packet.stream in stream_map and packet.dts is not None:
                        packet.stream = stream_map[packet.stream]
                        output_container.mux(packet)


class VideoFromComponents(VideoInput):
    """
    Class representing video input from tensors.
    """

    def __init__(self, components: VideoComponents):
        self.__components = components

    def get_components(self) -> VideoComponents:
        return VideoComponents(
            images=self.__components.images,
            audio=self.__components.audio,
            frame_rate=self.__components.frame_rate
        )

    def save_to(
        self,
        path: str,
        format: VideoContainer = VideoContainer.AUTO,
        codec: VideoCodec = VideoCodec.AUTO,
        metadata: Optional[dict] = None
    ):
        if format != VideoContainer.AUTO and format != VideoContainer.MP4:
            raise ValueError("Only MP4 format is supported for now")
        if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
            raise ValueError("Only H264 codec is supported for now")
        with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}) as output:
            # Add metadata before writing any streams
            if metadata is not None:
                for key, value in metadata.items():
                    output.metadata[key] = json.dumps(value)

            frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000)
            # Create a video stream
            video_stream = output.add_stream('h264', rate=frame_rate)
            video_stream.width = self.__components.images.shape[2]
            video_stream.height = self.__components.images.shape[1]
            video_stream.pix_fmt = 'yuv420p'

            # Create an audio stream
            audio_sample_rate = 1
            audio_stream: Optional[av.AudioStream] = None
            if self.__components.audio:
                audio_sample_rate = int(self.__components.audio['sample_rate'])
                audio_stream = output.add_stream('aac', rate=audio_sample_rate)
                audio_stream.sample_rate = audio_sample_rate
                audio_stream.format = 'fltp'

            # Encode video
            for i, frame in enumerate(self.__components.images):
                img = (frame * 255).clamp(0, 255).byte().cpu().numpy()  # shape: (H, W, 3)
                frame = av.VideoFrame.from_ndarray(img, format='rgb24')
                frame = frame.reformat(format='yuv420p')  # Convert to YUV420P as required by h264
                packet = video_stream.encode(frame)
                output.mux(packet)

            # Flush video
            packet = video_stream.encode(None)
            output.mux(packet)

            if audio_stream and self.__components.audio:
                # Encode audio
                samples_per_frame = int(audio_sample_rate / frame_rate)
                num_frames = self.__components.audio['waveform'].shape[2] // samples_per_frame
                for i in range(num_frames):
                    start = i * samples_per_frame
                    end = start + samples_per_frame
                    # TODO(Feature) - Add support for stereo audio
                    chunk = (
                        self.__components.audio["waveform"][0, 0, start:end]
                        .unsqueeze(0)
                        .contiguous()
                        .numpy()
                    )
                    audio_frame = av.AudioFrame.from_ndarray(chunk, format='fltp', layout='mono')
                    audio_frame.sample_rate = audio_sample_rate
                    audio_frame.pts = i * samples_per_frame
                    for packet in audio_stream.encode(audio_frame):
                        output.mux(packet)

                # Flush audio
                for packet in audio_stream.encode(None):
                    output.mux(packet)

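A hypothetical round trip through the two implementations above, using only the API defined in this file (the tensor values and the `/tmp/clip.mp4` path are illustrative): wrap tensors in `VideoFromComponents`, write an MP4, then reopen it with `VideoFromFile`:

```python
# Round trip: tensors -> MP4 on disk -> VideoFromFile.
from fractions import Fraction
import torch
from comfy_api.input_impl import VideoFromFile, VideoFromComponents
from comfy_api.util import VideoComponents

components = VideoComponents(images=torch.rand(24, 64, 64, 3), frame_rate=Fraction(24))
VideoFromComponents(components).save_to("/tmp/clip.mp4", metadata={"workflow": {"nodes": []}})

video = VideoFromFile("/tmp/clip.mp4")
print(video.get_dimensions())  # (64, 64)
```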
8 comfy_api/util/__init__.py Normal file
@@ -0,0 +1,8 @@
from .video_types import VideoContainer, VideoCodec, VideoComponents

__all__ = [
    # Utility Types
    "VideoContainer",
    "VideoCodec",
    "VideoComponents",
]

51 comfy_api/util/video_types.py Normal file
@@ -0,0 +1,51 @@
from __future__ import annotations
from dataclasses import dataclass
from enum import Enum
from fractions import Fraction
from typing import Optional
from comfy_api.input import ImageInput, AudioInput

class VideoCodec(str, Enum):
    AUTO = "auto"
    H264 = "h264"

    @classmethod
    def as_input(cls) -> list[str]:
        """
        Returns a list of codec names that can be used as node input.
        """
        return [member.value for member in cls]

class VideoContainer(str, Enum):
    AUTO = "auto"
    MP4 = "mp4"

    @classmethod
    def as_input(cls) -> list[str]:
        """
        Returns a list of container names that can be used as node input.
        """
        return [member.value for member in cls]

    @classmethod
    def get_extension(cls, value) -> str:
        """
        Returns the file extension for the container.
        """
        if isinstance(value, str):
            value = cls(value)
        if value == VideoContainer.MP4 or value == VideoContainer.AUTO:
            return "mp4"
        return ""

@dataclass
class VideoComponents:
    """
    Dataclass representing the components of a video.
    """

    images: ImageInput
    frame_rate: Fraction
    audio: Optional[AudioInput] = None
    metadata: Optional[dict] = None

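How nodes typically consume these enums, per the methods defined above: enumerate valid inputs and map a chosen container to a file extension:

```python
from comfy_api.util import VideoContainer

print(VideoContainer.as_input())             # ['auto', 'mp4']
print(VideoContainer.get_extension("auto"))  # 'mp4' (AUTO currently maps to mp4)
```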
0 comfy_api_nodes/__init__.py Normal file
17 comfy_api_nodes/apis/PixverseController.py Normal file
@@ -0,0 +1,17 @@
# generated by datamodel-codegen:
# filename: https://api.comfy.org/openapi
# timestamp: 2025-04-23T15:56:33+00:00

from __future__ import annotations

from typing import Optional

from pydantic import BaseModel

from . import PixverseDto


class ResponseData(BaseModel):
    ErrCode: Optional[int] = None
    ErrMsg: Optional[str] = None
    Resp: Optional[PixverseDto.V2OpenAPII2VResp] = None

57 comfy_api_nodes/apis/PixverseDto.py Normal file
@@ -0,0 +1,57 @@
# generated by datamodel-codegen:
# filename: https://api.comfy.org/openapi
# timestamp: 2025-04-23T15:56:33+00:00

from __future__ import annotations

from typing import Optional

from pydantic import BaseModel, Field, constr


class V2OpenAPII2VResp(BaseModel):
    video_id: Optional[int] = Field(None, description='Video_id')


class V2OpenAPIT2VReq(BaseModel):
    aspect_ratio: str = Field(
        ..., description='Aspect ratio (16:9, 4:3, 1:1, 3:4, 9:16)', examples=['16:9']
    )
    duration: int = Field(
        ...,
        description='Video duration (5, 8 seconds, --model=v3.5 only allows 5,8; --quality=1080p does not support 8s)',
        examples=[5],
    )
    model: str = Field(
        ..., description='Model version (only supports v3.5)', examples=['v3.5']
    )
    motion_mode: Optional[str] = Field(
        'normal',
        description='Motion mode (normal, fast, --fast only available when duration=5; --quality=1080p does not support fast)',
        examples=['normal'],
    )
    negative_prompt: Optional[constr(max_length=2048)] = Field(
        None, description='Negative prompt\n'
    )
    prompt: constr(max_length=2048) = Field(..., description='Prompt')
    quality: str = Field(
        ...,
        description='Video quality ("360p"(Turbo model), "540p", "720p", "1080p")',
        examples=['540p'],
    )
    seed: Optional[int] = Field(None, description='Random seed, range: 0 - 2147483647')
    style: Optional[str] = Field(
        None,
        description='Style (effective when model=v3.5, "anime", "3d_animation", "clay", "comic", "cyberpunk") Do not include style parameter unless needed',
        examples=['anime'],
    )
    template_id: Optional[int] = Field(
        None,
        description='Template ID (template_id must be activated before use)',
        examples=[302325299692608],
    )
    water_mark: Optional[bool] = Field(
        False,
        description='Watermark (true: add watermark, false: no watermark)',
        examples=[False],
    )

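Building a request against this generated schema; pydantic enforces the required fields and the 2048-character prompt cap. The values are illustrative, and `model_dump` assumes pydantic v2 (use `.dict()` under v1):

```python
# Constructing a Pixverse text-to-video request from the generated model.
from comfy_api_nodes.apis.PixverseDto import V2OpenAPIT2VReq

req = V2OpenAPIT2VReq(
    aspect_ratio="16:9",
    duration=5,
    model="v3.5",
    prompt="a paper crane folding itself",
    quality="540p",
)
print(req.model_dump(exclude_none=True))
```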
422 comfy_api_nodes/apis/__init__.py Normal file
@@ -0,0 +1,422 @@
# generated by datamodel-codegen:
# filename: https://api.comfy.org/openapi
# timestamp: 2025-04-23T15:56:33+00:00

from __future__ import annotations

from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional

from pydantic import AnyUrl, BaseModel, Field, confloat, conint


class Customer(BaseModel):
    createdAt: Optional[datetime] = Field(
        None, description='The date and time the user was created'
    )
    email: Optional[str] = Field(None, description='The email address for this user')
    id: str = Field(..., description='The firebase UID of the user')
    name: Optional[str] = Field(None, description='The name for this user')
    updatedAt: Optional[datetime] = Field(
        None, description='The date and time the user was last updated'
    )


class Error(BaseModel):
    details: Optional[List[str]] = Field(
        None,
        description='Optional detailed information about the error or hints for resolving it.',
    )
    message: Optional[str] = Field(
        None, description='A clear and concise description of the error.'
    )


class ErrorResponse(BaseModel):
    error: str
    message: str


class ImageRequest(BaseModel):
    aspect_ratio: Optional[str] = Field(
        None,
        description="Optional. The aspect ratio (e.g., 'ASPECT_16_9', 'ASPECT_1_1'). Cannot be used with resolution. Defaults to 'ASPECT_1_1' if unspecified.",
    )
    color_palette: Optional[Dict[str, Any]] = Field(
        None, description='Optional. Color palette object. Only for V_2, V_2_TURBO.'
    )
    magic_prompt_option: Optional[str] = Field(
        None, description="Optional. MagicPrompt usage ('AUTO', 'ON', 'OFF')."
    )
    model: str = Field(..., description="The model used (e.g., 'V_2', 'V_2A_TURBO')")
    negative_prompt: Optional[str] = Field(
        None,
        description='Optional. Description of what to exclude. Only for V_1, V_1_TURBO, V_2, V_2_TURBO.',
    )
    num_images: Optional[conint(ge=1, le=8)] = Field(
        1, description='Optional. Number of images to generate (1-8). Defaults to 1.'
    )
    prompt: str = Field(
        ..., description='Required. The prompt to use to generate the image.'
    )
    resolution: Optional[str] = Field(
        None,
        description="Optional. Resolution (e.g., 'RESOLUTION_1024_1024'). Only for model V_2. Cannot be used with aspect_ratio.",
    )
    seed: Optional[conint(ge=0, le=2147483647)] = Field(
        None, description='Optional. A number between 0 and 2147483647.'
    )
    style_type: Optional[str] = Field(
        None,
        description="Optional. Style type ('AUTO', 'GENERAL', 'REALISTIC', 'DESIGN', 'RENDER_3D', 'ANIME'). Only for models V_2 and above.",
    )


class Datum(BaseModel):
    is_image_safe: Optional[bool] = Field(
        None, description='Indicates whether the image is considered safe.'
    )
    prompt: Optional[str] = Field(
        None, description='The prompt used to generate this image.'
    )
    resolution: Optional[str] = Field(
        None, description="The resolution of the generated image (e.g., '1024x1024')."
    )
    seed: Optional[int] = Field(
        None, description='The seed value used for this generation.'
    )
    style_type: Optional[str] = Field(
        None,
        description="The style type used for generation (e.g., 'REALISTIC', 'ANIME').",
    )
    url: Optional[str] = Field(None, description='URL to the generated image.')


class Code(Enum):
    int_1100 = 1100
    int_1101 = 1101
    int_1102 = 1102
    int_1103 = 1103


class Code1(Enum):
    int_1000 = 1000
    int_1001 = 1001
    int_1002 = 1002
    int_1003 = 1003
    int_1004 = 1004


class AspectRatio(str, Enum):
    field_16_9 = '16:9'
    field_9_16 = '9:16'
    field_1_1 = '1:1'


class Config(BaseModel):
    horizontal: Optional[confloat(ge=-10.0, le=10.0)] = None
    pan: Optional[confloat(ge=-10.0, le=10.0)] = None
    roll: Optional[confloat(ge=-10.0, le=10.0)] = None
    tilt: Optional[confloat(ge=-10.0, le=10.0)] = None
    vertical: Optional[confloat(ge=-10.0, le=10.0)] = None
    zoom: Optional[confloat(ge=-10.0, le=10.0)] = None


class Type(str, Enum):
    simple = 'simple'
    down_back = 'down_back'
    forward_up = 'forward_up'
    right_turn_forward = 'right_turn_forward'
    left_turn_forward = 'left_turn_forward'


class CameraControl(BaseModel):
    config: Optional[Config] = None
    type: Optional[Type] = Field(None, description='Predefined camera movements type')


class Duration(str, Enum):
    field_5 = 5
    field_10 = 10


class Mode(str, Enum):
    std = 'std'
    pro = 'pro'


class TaskInfo(BaseModel):
    external_task_id: Optional[str] = None


class Video(BaseModel):
    duration: Optional[str] = Field(None, description='Total video duration')
    id: Optional[str] = Field(None, description='Generated video ID')
    url: Optional[AnyUrl] = Field(None, description='URL for generated video')


class TaskResult(BaseModel):
    videos: Optional[List[Video]] = None


class TaskStatus(str, Enum):
    submitted = 'submitted'
    processing = 'processing'
    succeed = 'succeed'
    failed = 'failed'


class Data(BaseModel):
    created_at: Optional[int] = Field(None, description='Task creation time')
    task_id: Optional[str] = Field(None, description='Task ID')
    task_info: Optional[TaskInfo] = None
    task_result: Optional[TaskResult] = None
    task_status: Optional[TaskStatus] = None
    updated_at: Optional[int] = Field(None, description='Task update time')


class AspectRatio1(str, Enum):
    field_16_9 = '16:9'
    field_9_16 = '9:16'
    field_1_1 = '1:1'
    field_4_3 = '4:3'
    field_3_4 = '3:4'
    field_3_2 = '3:2'
    field_2_3 = '2:3'
    field_21_9 = '21:9'


class ImageReference(str, Enum):
    subject = 'subject'
    face = 'face'


class Image(BaseModel):
    index: Optional[int] = Field(None, description='Image Number (0-9)')
    url: Optional[AnyUrl] = Field(None, description='URL for generated image')


class TaskResult1(BaseModel):
    images: Optional[List[Image]] = None


class Data1(BaseModel):
    created_at: Optional[int] = Field(None, description='Task creation time')
    task_id: Optional[str] = Field(None, description='Task ID')
    task_result: Optional[TaskResult1] = None
    task_status: Optional[TaskStatus] = None
    task_status_msg: Optional[str] = Field(None, description='Task status information')
    updated_at: Optional[int] = Field(None, description='Task update time')


class AspectRatio2(str, Enum):
    field_16_9 = '16:9'
    field_9_16 = '9:16'
    field_1_1 = '1:1'


class CameraControl1(BaseModel):
    config: Optional[Config] = None
    type: Optional[Type] = Field(None, description='Predefined camera movements type')


class ModelName2(str, Enum):
    kling_v1 = 'kling-v1'
    kling_v1_6 = 'kling-v1-6'


class TaskResult2(BaseModel):
    videos: Optional[List[Video]] = None


class Data2(BaseModel):
    created_at: Optional[int] = Field(None, description='Task creation time')
    task_id: Optional[str] = Field(None, description='Task ID')
    task_info: Optional[TaskInfo] = None
    task_result: Optional[TaskResult2] = None
    task_status: Optional[TaskStatus] = None
    updated_at: Optional[int] = Field(None, description='Task update time')


class Code2(Enum):
    int_1200 = 1200
    int_1201 = 1201
    int_1202 = 1202
    int_1203 = 1203


class ResourcePackType(str, Enum):
    decreasing_total = 'decreasing_total'
    constant_period = 'constant_period'


class Status(str, Enum):
    toBeOnline = 'toBeOnline'
    online = 'online'
    expired = 'expired'
    runOut = 'runOut'


class ResourcePackSubscribeInfo(BaseModel):
    effective_time: Optional[int] = Field(
        None, description='Effective time, Unix timestamp in ms'
    )
    invalid_time: Optional[int] = Field(
        None, description='Expiration time, Unix timestamp in ms'
    )
    purchase_time: Optional[int] = Field(
        None, description='Purchase time, Unix timestamp in ms'
    )
    remaining_quantity: Optional[float] = Field(
        None, description='Remaining quantity (updated with a 12-hour delay)'
    )
    resource_pack_id: Optional[str] = Field(None, description='Resource package ID')
    resource_pack_name: Optional[str] = Field(None, description='Resource package name')
    resource_pack_type: Optional[ResourcePackType] = Field(
        None,
        description='Resource package type (decreasing_total=decreasing total, constant_period=constant periodicity)',
    )
    status: Optional[Status] = Field(None, description='Resource Package Status')
    total_quantity: Optional[float] = Field(None, description='Total quantity')


class Background(str, Enum):
    transparent = 'transparent'
    opaque = 'opaque'


class Moderation(str, Enum):
    low = 'low'
    auto = 'auto'


class OutputFormat(str, Enum):
    png = 'png'
    webp = 'webp'
    jpeg = 'jpeg'


class Quality(str, Enum):
    low = 'low'
    medium = 'medium'
    high = 'high'


class OpenAIImageEditRequest(BaseModel):
    background: Optional[str] = Field(
        None, description='Background transparency', examples=['opaque']
    )
    model: str = Field(
        ..., description='The model to use for image editing', examples=['gpt-image-1']
    )
    moderation: Optional[Moderation] = Field(
        None, description='Content moderation setting', examples=['auto']
    )
    n: Optional[int] = Field(
        None, description='The number of images to generate', examples=[1]
    )
    output_compression: Optional[int] = Field(
        None, description='Compression level for JPEG or WebP (0-100)', examples=[100]
    )
    output_format: Optional[OutputFormat] = Field(
        None, description='Format of the output image', examples=['png']
    )
    prompt: str = Field(
        ...,
        description='A text description of the desired edit',
        examples=['Give the rocketship rainbow coloring'],
    )
    quality: Optional[str] = Field(
        None, description='The quality of the edited image', examples=['low']
    )
    size: Optional[str] = Field(
        None, description='Size of the output image', examples=['1024x1024']
    )
    user: Optional[str] = Field(
        None,
        description='A unique identifier for end-user monitoring',
        examples=['user-1234'],
    )


class Quality1(str, Enum):
    low = 'low'
    medium = 'medium'
    high = 'high'
    standard = 'standard'
    hd = 'hd'


class ResponseFormat(str, Enum):
    url = 'url'
    b64_json = 'b64_json'


class Style(str, Enum):
    vivid = 'vivid'
    natural = 'natural'


class OpenAIImageGenerationRequest(BaseModel):
    background: Optional[Background] = Field(
        None, description='Background transparency', examples=['opaque']
    )
    model: Optional[str] = Field(
        None, description='The model to use for image generation', examples=['dall-e-3']
    )
    moderation: Optional[Moderation] = Field(
        None, description='Content moderation setting', examples=['auto']
    )
    n: Optional[int] = Field(
        None,
        description='The number of images to generate (1-10). Only 1 supported for dall-e-3.',
        examples=[1],
    )
    output_compression: Optional[int] = Field(
        None, description='Compression level for JPEG or WebP (0-100)', examples=[100]
    )
    output_format: Optional[OutputFormat] = Field(
        None, description='Format of the output image', examples=['png']
    )
    prompt: str = Field(
        ...,
        description='A text description of the desired image',
        examples=['Draw a rocket in front of a blackhole in deep space'],
    )
    quality: Optional[Quality1] = Field(
        None, description='The quality of the generated image', examples=['high']
|
||||
)
|
||||
response_format: Optional[ResponseFormat] = Field(
|
||||
None, description='Response format of image data', examples=['b64_json']
|
||||
)
|
||||
size: Optional[str] = Field(
|
||||
None,
|
||||
description='Size of the image (e.g., 1024x1024, 1536x1024, auto)',
|
||||
examples=['1024x1536'],
|
||||
)
|
||||
style: Optional[Style] = Field(
|
||||
None, description='Style of the image (only for dall-e-3)', examples=['vivid']
|
||||
)
|
||||
user: Optional[str] = Field(
|
||||
None,
|
||||
description='A unique identifier for end-user monitoring',
|
||||
examples=['user-1234'],
|
||||
)
|
||||
|
||||
|
||||
class Datum1(BaseModel):
|
||||
b64_json: Optional[str] = Field(None, description='Base64 encoded image data')
|
||||
revised_prompt: Optional[str] = Field(None, description='Revised prompt')
|
||||
url: Optional[str] = Field(None, description='URL of the image')
|
||||
|
||||
|
||||
class OpenAIImageGenerationResponse(BaseModel):
|
||||
data: Optional[List[Datum1]] = None
|
||||
class User(BaseModel):
|
||||
email: Optional[str] = Field(None, description='The email address for this user.')
|
||||
id: Optional[str] = Field(None, description='The unique id for this user.')
|
||||
isAdmin: Optional[bool] = Field(
|
||||
None, description='Indicates if the user has admin privileges.'
|
||||
)
|
||||
isApproved: Optional[bool] = Field(
|
||||
None, description='Indicates if the user is approved.'
|
||||
)
|
||||
name: Optional[str] = Field(None, description='The name for this user.')
|
||||
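Since these are ordinary pydantic models, a request body can be checked in isolation before it ever reaches the client. A minimal sketch (field values are illustrative only):

req = OpenAIImageGenerationRequest(
    model="dall-e-3",
    prompt="Draw a rocket in front of a blackhole in deep space",
    quality=Quality1.hd,
    style=Style.vivid,
    size="1024x1792",
)
# exclude_none=True drops every optional field left unset; enum members are
# reduced to their plain values later, in SynchronousOperation.execute().
print(req.model_dump(exclude_none=True))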
341
comfy_api_nodes/apis/client.py
Normal file
@@ -0,0 +1,341 @@
import logging

"""
API Client Framework for api.comfy.org.

This module provides a flexible framework for making API requests from ComfyUI nodes.
It supports both synchronous and asynchronous API operations with proper type validation.

Key Components:
--------------
1. ApiClient - Handles HTTP requests with authentication and error handling
2. ApiEndpoint - Defines a single HTTP endpoint with its request/response models
3. SynchronousOperation - Executes a single synchronous API operation

Usage Examples:
--------------

# Example 1: Synchronous API Operation
# ------------------------------------
# For a simple API call that returns the result immediately:

# 1. Create the API client
api_client = ApiClient(
    base_url="https://api.example.com",
    api_key="your_api_key_here",
    timeout=30.0,
    verify_ssl=True
)

# 2. Define the endpoint
user_info_endpoint = ApiEndpoint(
    path="/v1/users/me",
    method=HttpMethod.GET,
    request_model=EmptyRequest,  # No request body needed
    response_model=UserProfile,  # Pydantic model for the response
    query_params=None
)

# 3. Create the request object
request = EmptyRequest()

# 4. Create and execute the operation
operation = SynchronousOperation(
    endpoint=user_info_endpoint,
    request=request
)
user_profile = operation.execute(client=api_client)  # Returns immediately with the result

"""

from typing import (
    Dict,
    Type,
    Optional,
    Any,
    TypeVar,
    Generic,
)
from pydantic import BaseModel
from enum import Enum
import json
import requests
from urllib.parse import urljoin

T = TypeVar("T", bound=BaseModel)
R = TypeVar("R", bound=BaseModel)


class EmptyRequest(BaseModel):
    """Base class for empty request bodies.
    For GET requests, fields will be sent as query parameters."""

    pass


class HttpMethod(str, Enum):
    GET = "GET"
    POST = "POST"
    PUT = "PUT"
    DELETE = "DELETE"
    PATCH = "PATCH"


class ApiClient:
    """
    Client for making HTTP requests to an API with authentication and error handling.
    """

    def __init__(
        self,
        base_url: str,
        api_key: Optional[str] = None,
        timeout: float = 30.0,
        verify_ssl: bool = True,
    ):
        self.base_url = base_url
        self.api_key = api_key
        self.timeout = timeout
        self.verify_ssl = verify_ssl

    def get_headers(self) -> Dict[str, str]:
        """Get headers for API requests, including authentication if available"""
        headers = {"Content-Type": "application/json", "Accept": "application/json"}

        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"

        return headers

    def request(
        self,
        method: str,
        path: str,
        params: Optional[Dict[str, Any]] = None,
        json: Optional[Dict[str, Any]] = None,
        files: Optional[Dict[str, Any]] = None,
        headers: Optional[Dict[str, str]] = None,
    ) -> Dict[str, Any]:
        """
        Make an HTTP request to the API

        Args:
            method: HTTP method (GET, POST, etc.)
            path: API endpoint path (will be joined with base_url)
            params: Query parameters
            json: JSON body data
            files: Files to upload
            headers: Additional headers

        Returns:
            Parsed JSON response

        Raises:
            requests.RequestException: If the request fails
        """
        url = urljoin(self.base_url, path)
        self.check_auth_token(self.api_key)
        # Combine default headers with any provided headers
        request_headers = self.get_headers()
        if headers:
            request_headers.update(headers)

        # Let requests handle the content type when files are present.
        if files:
            del request_headers["Content-Type"]

        logging.debug(f"[DEBUG] Request Headers: {request_headers}")
        logging.debug(f"[DEBUG] Files: {files}")
        logging.debug(f"[DEBUG] Params: {params}")
        logging.debug(f"[DEBUG] Json: {json}")

        try:
            # If files are present, use data parameter instead of json
            if files:
                form_data = {}
                if json:
                    form_data.update(json)
                response = requests.request(
                    method=method,
                    url=url,
                    params=params,
                    data=form_data,  # Use data instead of json
                    files=files,
                    headers=request_headers,
                    timeout=self.timeout,
                    verify=self.verify_ssl,
                )
            else:
                response = requests.request(
                    method=method,
                    url=url,
                    params=params,
                    json=json,
                    headers=request_headers,
                    timeout=self.timeout,
                    verify=self.verify_ssl,
                )

            # Raise exception for error status codes
            response.raise_for_status()
        except requests.ConnectionError:
            raise Exception(
                f"Unable to connect to the API server at {self.base_url}. Please check your internet connection or verify the service is available."
            )

        except requests.Timeout:
            raise Exception(
                f"Request timed out after {self.timeout} seconds. The server might be experiencing high load or the operation is taking longer than expected."
            )

        except requests.HTTPError as e:
            status_code = e.response.status_code if hasattr(e, "response") else None
            error_message = f"HTTP Error: {str(e)}"

            # Try to extract detailed error message from JSON response
            try:
                if hasattr(e, "response") and e.response.content:
                    error_json = e.response.json()
                    if "error" in error_json and "message" in error_json["error"]:
                        error_message = f"API Error: {error_json['error']['message']}"
                        if "type" in error_json["error"]:
                            error_message += f" (Type: {error_json['error']['type']})"
                    else:
                        error_message = f"API Error: {error_json}"
            except Exception as json_error:
                # If we can't parse the JSON, fall back to the original error message
                logging.debug(f"[DEBUG] Failed to parse error response: {str(json_error)}")

            logging.debug(f"[DEBUG] API Error: {error_message} (Status: {status_code})")
            if hasattr(e, "response") and e.response.content:
                logging.debug(f"[DEBUG] Response content: {e.response.content}")
            if status_code == 401:
                error_message = "Unauthorized: Please login first to use this node."
            if status_code == 402:
                error_message = "Payment Required: Please add credits to your account to use this node."
            if status_code == 409:
                error_message = "There is a problem with your account. Please contact support@comfy.org. "
            if status_code == 429:
                error_message = "Rate Limit Exceeded: Please try again later."
            raise Exception(error_message)

        # Parse and return JSON response
        if response.content:
            return response.json()
        return {}

    def check_auth_token(self, auth_token):
        """Verify that an auth token is present."""
        if auth_token is None:
            raise Exception("Unauthorized: Please login first to use this node.")
        return auth_token


class ApiEndpoint(Generic[T, R]):
    """Defines an API endpoint with its request and response types"""

    def __init__(
        self,
        path: str,
        method: HttpMethod,
        request_model: Type[T],
        response_model: Type[R],
        query_params: Optional[Dict[str, Any]] = None,
    ):
        """Initialize an API endpoint definition.

        Args:
            path: The URL path for this endpoint, can include placeholders like {id}
            method: The HTTP method to use (GET, POST, etc.)
            request_model: Pydantic model class that defines the structure and validation rules for API requests to this endpoint
            response_model: Pydantic model class that defines the structure and validation rules for API responses from this endpoint
            query_params: Optional dictionary of query parameters to include in the request
        """
        self.path = path
        self.method = method
        self.request_model = request_model
        self.response_model = response_model
        self.query_params = query_params or {}


class SynchronousOperation(Generic[T, R]):
    """
    Represents a single synchronous API operation.
    """

    def __init__(
        self,
        endpoint: ApiEndpoint[T, R],
        request: T,
        files: Optional[Dict[str, Any]] = None,
        api_base: str = "https://api.comfy.org",
        auth_token: Optional[str] = None,
        timeout: float = 604800.0,
        verify_ssl: bool = True,
    ):
        self.endpoint = endpoint
        self.request = request
        self.response = None
        self.error = None
        self.api_base = api_base
        self.auth_token = auth_token
        self.timeout = timeout
        self.verify_ssl = verify_ssl
        self.files = files

    def execute(self, client: Optional[ApiClient] = None) -> R:
        """Execute the API operation using the provided client or create one"""
        try:
            # Create client if not provided
            if client is None:
                if self.api_base is None:
                    raise ValueError("Either client or api_base must be provided")
                client = ApiClient(
                    base_url=self.api_base,
                    api_key=self.auth_token,
                    timeout=self.timeout,
                    verify_ssl=self.verify_ssl,
                )

            # Convert request model to dict, but use None for EmptyRequest
            request_dict = None if isinstance(self.request, EmptyRequest) else self.request.model_dump(exclude_none=True)
            if request_dict:
                for key, value in request_dict.items():
                    if isinstance(value, Enum):
                        request_dict[key] = value.value

            # Debug log for request
            logging.debug(f"[DEBUG] API Request: {self.endpoint.method.value} {self.endpoint.path}")
            logging.debug(f"[DEBUG] Request Data: {json.dumps(request_dict, indent=2)}")
            logging.debug(f"[DEBUG] Query Params: {self.endpoint.query_params}")

            # Make the request
            resp = client.request(
                method=self.endpoint.method.value,
                path=self.endpoint.path,
                json=request_dict,
                params=self.endpoint.query_params,
                files=self.files,
            )

            # Debug log for response
            logging.debug("=" * 50)
            logging.debug("[DEBUG] RESPONSE DETAILS:")
            logging.debug("[DEBUG] Status Code: 200 (Success)")
            logging.debug(f"[DEBUG] Response Body: {json.dumps(resp, indent=2)}")
            logging.debug("=" * 50)

            # Parse and return the response
            return self._parse_response(resp)

        except Exception as e:
            logging.debug(f"[DEBUG] API Exception: {str(e)}")
            raise Exception(str(e))

    def _parse_response(self, resp):
        """Parse response data - can be overridden by subclasses"""
        # The response is already the complete object, don't extract just the "data" field
        # as that would lose the outer structure (created timestamp, etc.)

        # Parse response using the provided model
        self.response = self.endpoint.response_model.model_validate(resp)
        logging.debug(f"[DEBUG] Parsed Response: {self.response}")
        return self.response
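The docstring on _parse_response marks it as the designated override point: a subclass can reshape what execute() returns. A hypothetical sketch (this subclass is not part of the diff):

class DataOnlyOperation(SynchronousOperation):
    # Hypothetical variant: validate only a top-level "data" field when an
    # endpoint wraps its payload in an envelope.
    def _parse_response(self, resp):
        self.response = self.endpoint.response_model.model_validate(resp.get("data", resp))
        return self.response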
449
comfy_api_nodes/nodes_api.py
Normal file
@@ -0,0 +1,449 @@
import base64
import io
import math
from inspect import cleandoc

import numpy as np
import requests
import torch
from PIL import Image

from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict
from comfy.utils import common_upscale
from comfy_api_nodes.apis import (
    OpenAIImageEditRequest,
    OpenAIImageGenerationRequest,
    OpenAIImageGenerationResponse,
)
from comfy_api_nodes.apis.client import ApiEndpoint, HttpMethod, SynchronousOperation


def downscale_input(image):
    samples = image.movedim(-1, 1)
    # Downscale input images to roughly the same size as the outputs
    total = int(1536 * 1024)
    scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
    if scale_by >= 1:
        return image
    width = round(samples.shape[3] * scale_by)
    height = round(samples.shape[2] * scale_by)

    s = common_upscale(samples, width, height, "lanczos", "disabled")
    s = s.movedim(1, -1)
    return s
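As a quick check of the downscale budget (shapes are illustrative): an input larger than 1536·1024 pixels is resampled, smaller inputs pass through unchanged.

example = torch.zeros(1, 2048, 2048, 3)
# sqrt(1536*1024 / 2048**2) ≈ 0.61, so both sides shrink to round(2048 * 0.61) = 1254.
print(downscale_input(example).shape)                       # torch.Size([1, 1254, 1254, 3])
print(downscale_input(torch.zeros(1, 512, 512, 3)).shape)   # unchanged: torch.Size([1, 512, 512, 3])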
def validate_and_cast_response(response):
    # Validate raw JSON response
    data = response.data
    if not data or len(data) == 0:
        raise Exception("No images returned from API endpoint")

    # Initialize list to store image tensors
    image_tensors = []

    # Process each image in the data array
    for image_data in data:
        image_url = image_data.url
        b64_data = image_data.b64_json

        if not image_url and not b64_data:
            raise Exception("No image was generated in the response")

        if b64_data:
            img_data = base64.b64decode(b64_data)
            img = Image.open(io.BytesIO(img_data))

        elif image_url:
            img_response = requests.get(image_url)
            if img_response.status_code != 200:
                raise Exception("Failed to download the image")
            img = Image.open(io.BytesIO(img_response.content))

        img = img.convert("RGBA")

        # Convert to numpy array, normalize to float32 between 0 and 1
        img_array = np.array(img).astype(np.float32) / 255.0
        img_tensor = torch.from_numpy(img_array)

        # Add to list of tensors
        image_tensors.append(img_tensor)

    return torch.stack(image_tensors, dim=0)


class OpenAIDalle2(ComfyNodeABC):
    """
    Generates images synchronously via OpenAI's DALL·E 2 endpoint.

    Uses the proxy at /proxy/openai/images/generations. Returned URLs are short-lived,
    so download or cache results if you need to keep them.
    """
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls) -> InputTypeDict:
        return {
            "required": {
                "prompt": (IO.STRING, {
                    "multiline": True,
                    "default": "",
                    "tooltip": "Text prompt for DALL·E",
                }),
            },
            "optional": {
                "seed": (IO.INT, {
                    "default": 0,
                    "min": 0,
                    "max": 2**31 - 1,
                    "step": 1,
                    "display": "number",
                    "tooltip": "not implemented yet in backend",
                }),
                "size": (IO.COMBO, {
                    "options": ["256x256", "512x512", "1024x1024"],
                    "default": "1024x1024",
                    "tooltip": "Image size",
                }),
                "n": (IO.INT, {
                    "default": 1,
                    "min": 1,
                    "max": 8,
                    "step": 1,
                    "display": "number",
                    "tooltip": "How many images to generate",
                }),
                "image": (IO.IMAGE, {
                    "default": None,
                    "tooltip": "Optional reference image for image editing.",
                }),
                "mask": (IO.MASK, {
                    "default": None,
                    "tooltip": "Optional mask for inpainting (white areas will be replaced)",
                }),
            },
            "hidden": {
                "auth_token": "AUTH_TOKEN_COMFY_ORG"
            }
        }

    RETURN_TYPES = (IO.IMAGE,)
    FUNCTION = "api_call"
    CATEGORY = "api node"
    DESCRIPTION = cleandoc(__doc__ or "")
    API_NODE = True

    def api_call(self, prompt, seed=0, image=None, mask=None, n=1, size="1024x1024", auth_token=None):
        model = "dall-e-2"
        path = "/proxy/openai/images/generations"
        request_class = OpenAIImageGenerationRequest
        img_binary = None

        if image is not None and mask is not None:
            path = "/proxy/openai/images/edits"
            request_class = OpenAIImageEditRequest

            input_tensor = image.squeeze().cpu()
            height, width, channels = input_tensor.shape
            rgba_tensor = torch.ones(height, width, 4, device="cpu")
            rgba_tensor[:, :, :channels] = input_tensor

            if mask.shape[1:] != image.shape[1:-1]:
                raise Exception("Mask and Image must be the same size")
            rgba_tensor[:, :, 3] = (1 - mask.squeeze().cpu())

            rgba_tensor = downscale_input(rgba_tensor.unsqueeze(0)).squeeze()

            image_np = (rgba_tensor.numpy() * 255).astype(np.uint8)
            img = Image.fromarray(image_np)
            img_byte_arr = io.BytesIO()
            img.save(img_byte_arr, format='PNG')
            img_byte_arr.seek(0)
            img_binary = img_byte_arr
            img_binary.name = "image.png"
        elif image is not None or mask is not None:
            raise Exception("Dall-E 2 image editing requires an image AND a mask")

        # Build the operation
        operation = SynchronousOperation(
            endpoint=ApiEndpoint(
                path=path,
                method=HttpMethod.POST,
                request_model=request_class,
                response_model=OpenAIImageGenerationResponse
            ),
            request=request_class(
                model=model,
                prompt=prompt,
                n=n,
                size=size,
                seed=seed,
            ),
            files={
                "image": img_binary,
            } if img_binary else None,
            auth_token=auth_token
        )

        response = operation.execute()

        img_tensor = validate_and_cast_response(response)
        return (img_tensor,)


class OpenAIDalle3(ComfyNodeABC):
    """
    Generates images synchronously via OpenAI's DALL·E 3 endpoint.

    Uses the proxy at /proxy/openai/images/generations. Returned URLs are short-lived,
    so download or cache results if you need to keep them.
    """
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls) -> InputTypeDict:
        return {
            "required": {
                "prompt": (IO.STRING, {
                    "multiline": True,
                    "default": "",
                    "tooltip": "Text prompt for DALL·E",
                }),
            },
            "optional": {
                "seed": (IO.INT, {
                    "default": 0,
                    "min": 0,
                    "max": 2**31 - 1,
                    "step": 1,
                    "display": "number",
                    "tooltip": "not implemented yet in backend",
                }),
                "quality": (IO.COMBO, {
                    "options": ["standard", "hd"],
                    "default": "standard",
                    "tooltip": "Image quality",
                }),
                "style": (IO.COMBO, {
                    "options": ["natural", "vivid"],
                    "default": "natural",
                    "tooltip": "Vivid causes the model to lean towards generating hyper-real and dramatic images. Natural causes the model to produce more natural, less hyper-real looking images.",
                }),
                "size": (IO.COMBO, {
                    "options": ["1024x1024", "1024x1792", "1792x1024"],
                    "default": "1024x1024",
                    "tooltip": "Image size",
                }),
            },
            "hidden": {
                "auth_token": "AUTH_TOKEN_COMFY_ORG"
            }
        }

    RETURN_TYPES = (IO.IMAGE,)
    FUNCTION = "api_call"
    CATEGORY = "api node"
    DESCRIPTION = cleandoc(__doc__ or "")
    API_NODE = True

    def api_call(self, prompt, seed=0, style="natural", quality="standard", size="1024x1024", auth_token=None):
        model = "dall-e-3"

        # Build the operation
        operation = SynchronousOperation(
            endpoint=ApiEndpoint(
                path="/proxy/openai/images/generations",
                method=HttpMethod.POST,
                request_model=OpenAIImageGenerationRequest,
                response_model=OpenAIImageGenerationResponse
            ),
            request=OpenAIImageGenerationRequest(
                model=model,
                prompt=prompt,
                quality=quality,
                size=size,
                style=style,
                seed=seed,
            ),
            auth_token=auth_token
        )

        response = operation.execute()

        img_tensor = validate_and_cast_response(response)
        return (img_tensor,)


class OpenAIGPTImage1(ComfyNodeABC):
    """
    Generates images synchronously via OpenAI's GPT Image 1 endpoint.

    Uses the proxy at /proxy/openai/images/generations. Returned URLs are short-lived,
    so download or cache results if you need to keep them.
    """
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls) -> InputTypeDict:
        return {
            "required": {
                "prompt": (IO.STRING, {
                    "multiline": True,
                    "default": "",
                    "tooltip": "Text prompt for GPT Image 1",
                }),
            },
            "optional": {
                "seed": (IO.INT, {
                    "default": 0,
                    "min": 0,
                    "max": 2**31 - 1,
                    "step": 1,
                    "display": "number",
                    "tooltip": "not implemented yet in backend",
                }),
                "quality": (IO.COMBO, {
                    "options": ["low", "medium", "high"],
                    "default": "low",
                    "tooltip": "Image quality, affects cost and generation time.",
                }),
                "background": (IO.COMBO, {
                    "options": ["opaque", "transparent"],
                    "default": "opaque",
                    "tooltip": "Return image with or without background",
                }),
                "size": (IO.COMBO, {
                    "options": ["auto", "1024x1024", "1024x1536", "1536x1024"],
                    "default": "auto",
                    "tooltip": "Image size",
                }),
                "n": (IO.INT, {
                    "default": 1,
                    "min": 1,
                    "max": 8,
                    "step": 1,
                    "display": "number",
                    "tooltip": "How many images to generate",
                }),
                "image": (IO.IMAGE, {
                    "default": None,
                    "tooltip": "Optional reference image for image editing.",
                }),
                "mask": (IO.MASK, {
                    "default": None,
                    "tooltip": "Optional mask for inpainting (white areas will be replaced)",
                }),
                "moderation": (IO.COMBO, {
                    "options": ["low", "auto"],
                    "default": "low",
                    "tooltip": "Moderation level",
                }),
            },
            "hidden": {
                "auth_token": "AUTH_TOKEN_COMFY_ORG"
            }
        }

    RETURN_TYPES = (IO.IMAGE,)
    FUNCTION = "api_call"
    CATEGORY = "api node"
    DESCRIPTION = cleandoc(__doc__ or "")
    API_NODE = True

    def api_call(self, prompt, seed=0, quality="low", background="opaque", image=None, mask=None, n=1, size="1024x1024", auth_token=None, moderation="low"):
        model = "gpt-image-1"
        path = "/proxy/openai/images/generations"
        request_class = OpenAIImageGenerationRequest
        img_binaries = []
        mask_binary = None
        files = []

        if image is not None:
            path = "/proxy/openai/images/edits"
            request_class = OpenAIImageEditRequest

            batch_size = image.shape[0]

            for i in range(batch_size):
                single_image = image[i:i+1]
                scaled_image = downscale_input(single_image).squeeze()

                image_np = (scaled_image.numpy() * 255).astype(np.uint8)
                img = Image.fromarray(image_np)
                img_byte_arr = io.BytesIO()
                img.save(img_byte_arr, format='PNG')
                img_byte_arr.seek(0)
                img_binary = img_byte_arr
                img_binary.name = f"image_{i}.png"

                img_binaries.append(img_binary)
                if batch_size == 1:
                    files.append(("image", img_binary))
                else:
                    files.append(("image[]", img_binary))

        if mask is not None:
            if image is None:
                raise Exception("Cannot use a mask without an input image")
            if image.shape[0] != 1:
                raise Exception("Cannot use a mask with multiple images")
            if mask.shape[1:] != image.shape[1:-1]:
                raise Exception("Mask and Image must be the same size")
            batch, height, width = mask.shape
            rgba_mask = torch.zeros(height, width, 4, device="cpu")
            rgba_mask[:, :, 3] = (1 - mask.squeeze().cpu())

            scaled_mask = downscale_input(rgba_mask.unsqueeze(0)).squeeze()

            mask_np = (scaled_mask.numpy() * 255).astype(np.uint8)
            mask_img = Image.fromarray(mask_np)
            mask_img_byte_arr = io.BytesIO()
            mask_img.save(mask_img_byte_arr, format='PNG')
            mask_img_byte_arr.seek(0)
            mask_binary = mask_img_byte_arr
            mask_binary.name = "mask.png"
            files.append(("mask", mask_binary))

        # Build the operation
        operation = SynchronousOperation(
            endpoint=ApiEndpoint(
                path=path,
                method=HttpMethod.POST,
                request_model=request_class,
                response_model=OpenAIImageGenerationResponse
            ),
            request=request_class(
                model=model,
                prompt=prompt,
                quality=quality,
                background=background,
                n=n,
                seed=seed,
                size=size,
                moderation=moderation,
            ),
            files=files if files else None,
            auth_token=auth_token
        )

        response = operation.execute()

        img_tensor = validate_and_cast_response(response)
        return (img_tensor,)


# A dictionary that contains all nodes you want to export with their names
# NOTE: names should be globally unique
NODE_CLASS_MAPPINGS = {
    "OpenAIDalle2": OpenAIDalle2,
    "OpenAIDalle3": OpenAIDalle3,
    "OpenAIGPTImage1": OpenAIGPTImage1,
}

# A dictionary that contains the friendly/humanly readable titles for the nodes
NODE_DISPLAY_NAME_MAPPINGS = {
    "OpenAIDalle2": "OpenAI DALL·E 2",
    "OpenAIDalle3": "OpenAI DALL·E 3",
    "OpenAIGPTImage1": "OpenAI GPT Image 1",
}
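All three nodes share the same call shape, so outside a graph they can be driven directly, given a valid Comfy API token. A hedged sketch (token and prompt are placeholders):

images, = OpenAIDalle3().api_call(
    "a watercolor lighthouse at dusk",
    quality="hd",
    size="1792x1024",
    auth_token="<comfy-org-token>",  # normally injected via the hidden AUTH_TOKEN_COMFY_ORG input
)
print(images.shape)  # e.g. torch.Size([1, 1024, 1792, 4]) — RGBA floats in [0, 1]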
@@ -20,6 +20,29 @@ class CLIPTextEncodeControlnet:
            c.append(n)
        return (c, )

class T5TokenizerOptions:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "clip": ("CLIP", ),
                "min_padding": ("INT", {"default": 0, "min": 0, "max": 10000, "step": 1}),
                "min_length": ("INT", {"default": 0, "min": 0, "max": 10000, "step": 1}),
            }
        }

    RETURN_TYPES = ("CLIP",)
    FUNCTION = "set_options"

    def set_options(self, clip, min_padding, min_length):
        clip = clip.clone()
        for t5_type in ["t5xxl", "pile_t5xl", "t5base", "mt5xl", "umt5xxl"]:
            clip.set_tokenizer_option("{}_min_padding".format(t5_type), min_padding)
            clip.set_tokenizer_option("{}_min_length".format(t5_type), min_length)

        return (clip, )

NODE_CLASS_MAPPINGS = {
    "CLIPTextEncodeControlnet": CLIPTextEncodeControlnet
    "CLIPTextEncodeControlnet": CLIPTextEncodeControlnet,
    "T5TokenizerOptions": T5TokenizerOptions,
}
@@ -1,3 +1,4 @@
import math
import comfy.samplers
import comfy.sample
from comfy.k_diffusion import sampling as k_diffusion_sampling

@@ -249,6 +250,55 @@ class SetFirstSigma:
        sigmas[0] = sigma
        return (sigmas, )

class ExtendIntermediateSigmas:
    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {"sigmas": ("SIGMAS", ),
                     "steps": ("INT", {"default": 2, "min": 1, "max": 100}),
                     "start_at_sigma": ("FLOAT", {"default": -1.0, "min": -1.0, "max": 20000.0, "step": 0.01, "round": False}),
                     "end_at_sigma": ("FLOAT", {"default": 12.0, "min": 0.0, "max": 20000.0, "step": 0.01, "round": False}),
                     "spacing": (['linear', 'cosine', 'sine'],),
                    }
               }
    RETURN_TYPES = ("SIGMAS",)
    CATEGORY = "sampling/custom_sampling/sigmas"

    FUNCTION = "extend"

    def extend(self, sigmas: torch.Tensor, steps: int, start_at_sigma: float, end_at_sigma: float, spacing: str):
        if start_at_sigma < 0:
            start_at_sigma = float("inf")

        interpolator = {
            'linear': lambda x: x,
            'cosine': lambda x: torch.sin(x * math.pi / 2),
            'sine': lambda x: 1 - torch.cos(x * math.pi / 2)
        }[spacing]

        # linear space for our interpolation function
        x = torch.linspace(0, 1, steps + 1, device=sigmas.device)[1:-1]
        computed_spacing = interpolator(x)

        extended_sigmas = []
        for i in range(len(sigmas) - 1):
            sigma_current = sigmas[i]
            sigma_next = sigmas[i + 1]

            extended_sigmas.append(sigma_current)

            if end_at_sigma <= sigma_current <= start_at_sigma:
                interpolated_steps = computed_spacing * (sigma_next - sigma_current) + sigma_current
                extended_sigmas.extend(interpolated_steps.tolist())

        # Add the last sigma value
        if len(sigmas) > 0:
            extended_sigmas.append(sigmas[-1])

        extended_sigmas = torch.FloatTensor(extended_sigmas)

        return (extended_sigmas,)
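A quick sanity check of the spacing math (values are exact for steps=4): the interior points split each qualifying sigma interval into `steps` segments, and the nonlinear interpolators only change where those points fall.

x = torch.linspace(0, 1, 4 + 1)[1:-1]   # tensor([0.2500, 0.5000, 0.7500])
torch.sin(x * math.pi / 2)              # cosine spacing: ≈ [0.3827, 0.7071, 0.9239]
1 - torch.cos(x * math.pi / 2)          # sine spacing:   ≈ [0.0761, 0.2929, 0.6173]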
class KSamplerSelect:
    @classmethod
    def INPUT_TYPES(s):

@@ -735,6 +785,7 @@ NODE_CLASS_MAPPINGS = {
    "SplitSigmasDenoise": SplitSigmasDenoise,
    "FlipSigmas": FlipSigmas,
    "SetFirstSigma": SetFirstSigma,
    "ExtendIntermediateSigmas": ExtendIntermediateSigmas,

    "CFGGuider": CFGGuider,
    "DualCFGGuider": DualCFGGuider,

@@ -26,7 +26,30 @@ class QuadrupleCLIPLoader:
        clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3, clip_path4], embedding_directory=folder_paths.get_folder_paths("embeddings"))
        return (clip,)

class CLIPTextEncodeHiDream:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "clip": ("CLIP", ),
            "clip_l": ("STRING", {"multiline": True, "dynamicPrompts": True}),
            "clip_g": ("STRING", {"multiline": True, "dynamicPrompts": True}),
            "t5xxl": ("STRING", {"multiline": True, "dynamicPrompts": True}),
            "llama": ("STRING", {"multiline": True, "dynamicPrompts": True})
        }}
    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "encode"

    CATEGORY = "advanced/conditioning"

    def encode(self, clip, clip_l, clip_g, t5xxl, llama):
        tokens = clip.tokenize(clip_g)
        tokens["l"] = clip.tokenize(clip_l)["l"]
        tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]
        tokens["llama"] = clip.tokenize(llama)["llama"]
        return (clip.encode_from_tokens_scheduled(tokens), )

NODE_CLASS_MAPPINGS = {
    "QuadrupleCLIPLoader": QuadrupleCLIPLoader,
    "CLIPTextEncodeHiDream": CLIPTextEncodeHiDream,
}
@@ -38,6 +38,7 @@ class LTXVImgToVideo:
        "height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
        "length": ("INT", {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8}),
        "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
        "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0}),
    }}

    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")

@@ -46,7 +47,7 @@ class LTXVImgToVideo:
    CATEGORY = "conditioning/video_models"
    FUNCTION = "generate"

    def generate(self, positive, negative, image, vae, width, height, length, batch_size):
    def generate(self, positive, negative, image, vae, width, height, length, batch_size, strength):
        pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
        encode_pixels = pixels[:, :, :, :3]
        t = vae.encode(encode_pixels)

@@ -59,7 +60,7 @@ class LTXVImgToVideo:
            dtype=torch.float32,
            device=latent.device,
        )
        conditioning_latent_frames_mask[:, :, :t.shape[2]] = 0
        conditioning_latent_frames_mask[:, :, :t.shape[2]] = 1.0 - strength

        return (positive, negative, {"samples": latent, "noise_mask": conditioning_latent_frames_mask}, )

@@ -152,6 +153,15 @@ class LTXVAddGuide:
        return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs})

    def append_keyframe(self, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors):
        _, latent_idx = self.get_latent_index(
            cond=positive,
            latent_length=latent_image.shape[2],
            guide_length=guiding_latent.shape[2],
            frame_idx=frame_idx,
            scale_factors=scale_factors,
        )
        noise_mask[:, :, latent_idx:latent_idx + guiding_latent.shape[2]] = 1.0

        positive = self.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors)
        negative = self.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors)

@@ -385,7 +395,7 @@ def encode_single_frame(output_file, image_array: np.ndarray, crf):
    container = av.open(output_file, "w", format="mp4")
    try:
        stream = container.add_stream(
            "h264", rate=1, options={"crf": str(crf), "preset": "veryfast"}
            "libx264", rate=1, options={"crf": str(crf), "preset": "veryfast"}
        )
        stream.height = image_array.shape[0]
        stream.width = image_array.shape[1]
@@ -3,7 +3,10 @@ import scipy.ndimage
import torch
import comfy.utils
import node_helpers
import folder_paths
import random

import nodes
from nodes import MAX_RESOLUTION

def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False):

@@ -362,6 +365,30 @@ class ThresholdMask:
        mask = (mask > value).float()
        return (mask,)

# Mask Preview - original implementation from
# https://github.com/cubiq/ComfyUI_essentials/blob/9d9f4bedfc9f0321c19faf71855e228c93bd0dc9/mask.py#L81
# upstream requested in https://github.com/Kosinkadink/rfcs/blob/main/rfcs/0000-corenodes.md#preview-nodes
class MaskPreview(nodes.SaveImage):
    def __init__(self):
        self.output_dir = folder_paths.get_temp_directory()
        self.type = "temp"
        self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5))
        self.compress_level = 4

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {"mask": ("MASK",), },
            "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
        }

    FUNCTION = "execute"
    CATEGORY = "mask"

    def execute(self, mask, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
        preview = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
        return self.save_images(preview, filename_prefix, prompt, extra_pnginfo)
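The reshape chain in execute turns any mask layout into a greyscale image batch; a quick shape sketch (sizes are illustrative):

# A (2, 512, 512) mask batch becomes a (2, 512, 512, 3) image batch:
# reshape -> (2, 1, 512, 512), movedim -> (2, 512, 512, 1), expand -> 3 channels.
m = torch.rand(2, 512, 512)
preview = m.reshape((-1, 1, 512, 512)).movedim(1, -1).expand(-1, -1, -1, 3)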
NODE_CLASS_MAPPINGS = {
    "LatentCompositeMasked": LatentCompositeMasked,

@@ -376,6 +403,7 @@ NODE_CLASS_MAPPINGS = {
    "FeatherMask": FeatherMask,
    "GrowMask": GrowMask,
    "ThresholdMask": ThresholdMask,
    "MaskPreview": MaskPreview
}

NODE_DISPLAY_NAME_MAPPINGS = {
@@ -209,6 +209,9 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi
        metadata["modelspec.predict_key"] = "epsilon"
    elif model.model.model_type == comfy.model_base.ModelType.V_PREDICTION:
        metadata["modelspec.predict_key"] = "v"
        extra_keys["v_pred"] = torch.tensor([])
        if getattr(model_sampling, "zsnr", False):
            extra_keys["ztsnr"] = torch.tensor([])

    if not args.disable_metadata:
        metadata["prompt"] = prompt_info

@@ -273,7 +276,7 @@ class CLIPSave:
        comfy.model_management.load_models_gpu([clip.load_model()], force_patch_weights=True)
        clip_sd = clip.get_sd()

        for prefix in ["clip_l.", "clip_g.", ""]:
        for prefix in ["clip_l.", "clip_g.", "clip_h.", "t5xxl.", "pile_t5xl.", "mt5xl.", "umt5xxl.", "t5base.", "gemma2_2b.", "llama.", "hydit_clip.", ""]:
            k = list(filter(lambda a: a.startswith(prefix), clip_sd.keys()))
            current_clip_sd = {}
            for x in k:

@@ -20,13 +20,14 @@ def loglinear_interp(t_steps, num_steps):

NOISE_LEVELS = {"FLUX": [0.9968, 0.9886, 0.9819, 0.975, 0.966, 0.9471, 0.9158, 0.8287, 0.5512, 0.2808, 0.001],
                "Wan": [1.0, 0.997, 0.995, 0.993, 0.991, 0.989, 0.987, 0.985, 0.98, 0.975, 0.973, 0.968, 0.96, 0.946, 0.927, 0.902, 0.864, 0.776, 0.539, 0.208, 0.001],
                "Chroma": [0.992, 0.99, 0.988, 0.985, 0.982, 0.978, 0.973, 0.968, 0.961, 0.953, 0.943, 0.931, 0.917, 0.9, 0.881, 0.858, 0.832, 0.802, 0.769, 0.731, 0.69, 0.646, 0.599, 0.55, 0.501, 0.451, 0.402, 0.355, 0.311, 0.27, 0.232, 0.199, 0.169, 0.143, 0.12, 0.101, 0.084, 0.07, 0.058, 0.048, 0.001],
                }

class OptimalStepsScheduler:
    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {"model_type": (["FLUX", "Wan"], ),
                    {"model_type": (["FLUX", "Wan", "Chroma"], ),
                     "steps": ("INT", {"default": 20, "min": 3, "max": 1000}),
                     "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                    }

@@ -141,6 +141,7 @@ class Quantize:

    CATEGORY = "image/postprocessing"

    @staticmethod
    def bayer(im, pal_im, order):
        def normalized_bayer_matrix(n):
            if n == 0:
43
comfy_extras/nodes_preview_any.py
Normal file
@@ -0,0 +1,43 @@
import json
from comfy.comfy_types.node_typing import IO

# Preview Any - original implementation from
# https://github.com/rgthree/rgthree-comfy/blob/main/py/display_any.py
# upstream requested in https://github.com/Kosinkadink/rfcs/blob/main/rfcs/0000-corenodes.md#preview-nodes
class PreviewAny():
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {"source": (IO.ANY, {})},
        }

    RETURN_TYPES = ()
    FUNCTION = "main"
    OUTPUT_NODE = True

    CATEGORY = "utils"

    def main(self, source=None):
        value = 'None'
        if isinstance(source, str):
            value = source
        elif isinstance(source, (int, float, bool)):
            value = str(source)
        elif source is not None:
            try:
                value = json.dumps(source)
            except Exception:
                try:
                    value = str(source)
                except Exception:
                    value = 'source exists, but could not be serialized.'

        return {"ui": {"text": (value,)}}
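The fallback chain prefers raw strings, then JSON, then str(); a quick sketch of what reaches the UI (values are illustrative):

node = PreviewAny()
print(node.main({"steps": 20}))   # {'ui': {'text': ('{"steps": 20}',)}}
print(node.main(3.5))             # {'ui': {'text': ('3.5',)}}
print(node.main(None))            # {'ui': {'text': ('None',)}}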
NODE_CLASS_MAPPINGS = {
    "PreviewAny": PreviewAny,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "PreviewAny": "Preview Any",
}
@@ -1,6 +1,8 @@
# Primitive nodes that are evaluated at backend.
from __future__ import annotations

import sys

from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, IO

@@ -23,7 +25,7 @@ class Int(ComfyNodeABC):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypeDict:
        return {
            "required": {"value": (IO.INT, {"control_after_generate": True})},
            "required": {"value": (IO.INT, {"min": -sys.maxsize, "max": sys.maxsize, "control_after_generate": True})},
        }

    RETURN_TYPES = (IO.INT,)

@@ -38,7 +40,7 @@ class Float(ComfyNodeABC):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypeDict:
        return {
            "required": {"value": (IO.FLOAT, {})},
            "required": {"value": (IO.FLOAT, {"min": -sys.maxsize, "max": sys.maxsize})},
        }

    RETURN_TYPES = (IO.FLOAT,)
@@ -5,9 +5,13 @@ import av
import torch
import folder_paths
import json
from typing import Optional, Literal
from fractions import Fraction
from comfy.comfy_types import FileLocator
from comfy.comfy_types import IO, FileLocator, ComfyNodeABC
from comfy_api.input import ImageInput, AudioInput, VideoInput
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
from comfy_api.input_impl import VideoFromFile, VideoFromComponents
from comfy.cli_args import args

class SaveWEBM:
    def __init__(self):

@@ -50,13 +54,15 @@ class SaveWEBM:
        for x in extra_pnginfo:
            container.metadata[x] = json.dumps(extra_pnginfo[x])

        codec_map = {"vp9": "libvpx-vp9", "av1": "libaom-av1"}
        codec_map = {"vp9": "libvpx-vp9", "av1": "libsvtav1"}
        stream = container.add_stream(codec_map[codec], rate=Fraction(round(fps * 1000), 1000))
        stream.width = images.shape[-2]
        stream.height = images.shape[-3]
        stream.pix_fmt = "yuv420p"
        stream.pix_fmt = "yuv420p10le" if codec == "av1" else "yuv420p"
        stream.bit_rate = 0
        stream.options = {'crf': str(crf)}
        if codec == "av1":
            stream.options["preset"] = "6"

        for frame in images:
            frame = av.VideoFrame.from_ndarray(torch.clamp(frame[..., :3] * 255, min=0, max=255).to(device=torch.device("cpu"), dtype=torch.uint8).numpy(), format="rgb24")

@@ -73,7 +79,163 @@ class SaveWEBM:

        return {"ui": {"images": results, "animated": (True,)}}  # TODO: frontend side

class SaveVideo(ComfyNodeABC):
    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()
        self.type: Literal["output"] = "output"
        self.prefix_append = ""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "video": (IO.VIDEO, {"tooltip": "The video to save."}),
                "filename_prefix": ("STRING", {"default": "video/ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."}),
                "format": (VideoContainer.as_input(), {"default": "auto", "tooltip": "The format to save the video as."}),
                "codec": (VideoCodec.as_input(), {"default": "auto", "tooltip": "The codec to use for the video."}),
            },
            "hidden": {
                "prompt": "PROMPT",
                "extra_pnginfo": "EXTRA_PNGINFO"
            },
        }

    RETURN_TYPES = ()
    FUNCTION = "save_video"

    OUTPUT_NODE = True

    CATEGORY = "image/video"
    DESCRIPTION = "Saves the input video to your ComfyUI output directory."

    def save_video(self, video: VideoInput, filename_prefix, format, codec, prompt=None, extra_pnginfo=None):
        filename_prefix += self.prefix_append
        width, height = video.get_dimensions()
        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
            filename_prefix,
            self.output_dir,
            width,
            height
        )
        results: list[FileLocator] = list()
        saved_metadata = None
        if not args.disable_metadata:
            metadata = {}
            if extra_pnginfo is not None:
                metadata.update(extra_pnginfo)
            if prompt is not None:
                metadata["prompt"] = prompt
            if len(metadata) > 0:
                saved_metadata = metadata
        file = f"{filename}_{counter:05}_.{VideoContainer.get_extension(format)}"
        video.save_to(
            os.path.join(full_output_folder, file),
            format=format,
            codec=codec,
            metadata=saved_metadata
        )

        results.append({
            "filename": file,
            "subfolder": subfolder,
            "type": self.type
        })
        counter += 1

        return {"ui": {"images": results, "animated": (True,)}}

class CreateVideo(ComfyNodeABC):
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "images": (IO.IMAGE, {"tooltip": "The images to create a video from."}),
                "fps": ("FLOAT", {"default": 30.0, "min": 1.0, "max": 120.0, "step": 1.0}),
            },
            "optional": {
                "audio": (IO.AUDIO, {"tooltip": "The audio to add to the video."}),
            }
        }

    RETURN_TYPES = (IO.VIDEO,)
    FUNCTION = "create_video"

    CATEGORY = "image/video"
    DESCRIPTION = "Create a video from images."

    def create_video(self, images: ImageInput, fps: float, audio: Optional[AudioInput] = None):
        return (VideoFromComponents(
            VideoComponents(
                images=images,
                audio=audio,
                frame_rate=Fraction(fps),
            )
        ),)

class GetVideoComponents(ComfyNodeABC):
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "video": (IO.VIDEO, {"tooltip": "The video to extract components from."}),
            }
        }
    RETURN_TYPES = (IO.IMAGE, IO.AUDIO, IO.FLOAT)
    RETURN_NAMES = ("images", "audio", "fps")
    FUNCTION = "get_components"

    CATEGORY = "image/video"
    DESCRIPTION = "Extracts all components from a video: frames, audio, and framerate."

    def get_components(self, video: VideoInput):
        components = video.get_components()

        return (components.images, components.audio, float(components.frame_rate))

class LoadVideo(ComfyNodeABC):
    @classmethod
    def INPUT_TYPES(cls):
        input_dir = folder_paths.get_input_directory()
        files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
        files = folder_paths.filter_files_content_types(files, ["video"])
        return {"required":
                    {"file": (sorted(files), {"video_upload": True})},
                }

    CATEGORY = "image/video"

    RETURN_TYPES = (IO.VIDEO,)
    FUNCTION = "load_video"
    def load_video(self, file):
        video_path = folder_paths.get_annotated_filepath(file)
        return (VideoFromFile(video_path),)

    @classmethod
    def IS_CHANGED(cls, file):
        video_path = folder_paths.get_annotated_filepath(file)
        mod_time = os.path.getmtime(video_path)
        # Instead of hashing the file, we can just use the modification time to avoid
        # rehashing large files.
        return mod_time

    @classmethod
    def VALIDATE_INPUTS(cls, file):
        if not folder_paths.exists_annotated_filepath(file):
            return "Invalid video file: {}".format(file)

        return True

NODE_CLASS_MAPPINGS = {
    "SaveWEBM": SaveWEBM,
    "SaveVideo": SaveVideo,
    "CreateVideo": CreateVideo,
    "GetVideoComponents": GetVideoComponents,
    "LoadVideo": LoadVideo,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "SaveVideo": "Save Video",
    "CreateVideo": "Create Video",
    "GetVideoComponents": "Get Video Components",
    "LoadVideo": "Load Video",
}
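Taken together, LoadVideo → GetVideoComponents → CreateVideo → SaveVideo round-trips a clip through image space. A minimal sketch of driving the nodes directly (filename and settings are illustrative; "example.mp4" is assumed to sit in the input directory):

video = LoadVideo().load_video("example.mp4")[0]
images, audio, fps = GetVideoComponents().get_components(video)
rebuilt = CreateVideo().create_video(images, fps, audio)[0]
SaveVideo().save_video(rebuilt, "video/ComfyUI", "auto", "auto")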
@@ -193,9 +193,116 @@ class WanFunInpaintToVideo:
|
||||
return flfv.encode(positive, negative, vae, width, height, length, batch_size, start_image=start_image, end_image=end_image, clip_vision_start_image=clip_vision_output)
|
||||
|
||||
|
||||
class WanVaceToVideo:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"positive": ("CONDITIONING", ),
|
||||
"negative": ("CONDITIONING", ),
|
||||
"vae": ("VAE", ),
|
||||
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
|
||||
},
|
||||
"optional": {"control_video": ("IMAGE", ),
|
||||
"control_masks": ("MASK", ),
|
||||
"reference_image": ("IMAGE", ),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT", "INT")
|
||||
RETURN_NAMES = ("positive", "negative", "latent", "trim_latent")
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "conditioning/video_models"
|
||||
|
||||
EXPERIMENTAL = True
|
||||
|
||||
def encode(self, positive, negative, vae, width, height, length, batch_size, strength, control_video=None, control_masks=None, reference_image=None):
|
||||
latent_length = ((length - 1) // 4) + 1
|
||||
if control_video is not None:
|
||||
control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
if control_video.shape[0] < length:
|
||||
control_video = torch.nn.functional.pad(control_video, (0, 0, 0, 0, 0, 0, 0, length - control_video.shape[0]), value=0.5)
|
||||
else:
|
||||
control_video = torch.ones((length, height, width, 3)) * 0.5
|
||||
|
||||
if reference_image is not None:
|
||||
reference_image = comfy.utils.common_upscale(reference_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
reference_image = vae.encode(reference_image[:, :, :, :3])
|
||||
reference_image = torch.cat([reference_image, comfy.latent_formats.Wan21().process_out(torch.zeros_like(reference_image))], dim=1)
|
||||
|
||||
if control_masks is None:
|
||||
mask = torch.ones((length, height, width, 1))
|
||||
else:
|
||||
mask = control_masks
|
||||
if mask.ndim == 3:
|
||||
mask = mask.unsqueeze(1)
|
||||
mask = comfy.utils.common_upscale(mask[:length], width, height, "bilinear", "center").movedim(1, -1)
|
||||
if mask.shape[0] < length:
|
||||
mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, 0, 0, length - mask.shape[0]), value=1.0)
|
||||
|
||||
control_video = control_video - 0.5
|
||||
inactive = (control_video * (1 - mask)) + 0.5
|
||||
reactive = (control_video * mask) + 0.5
|
||||
|
||||
inactive = vae.encode(inactive[:, :, :, :3])
|
||||
reactive = vae.encode(reactive[:, :, :, :3])
|
||||
control_video_latent = torch.cat((inactive, reactive), dim=1)
|
||||
if reference_image is not None:
|
||||
control_video_latent = torch.cat((reference_image, control_video_latent), dim=2)
|
||||
|
||||
vae_stride = 8
|
||||
height_mask = height // vae_stride
|
||||
width_mask = width // vae_stride
|
||||
mask = mask.view(length, height_mask, vae_stride, width_mask, vae_stride)
|
||||
mask = mask.permute(2, 4, 0, 1, 3)
|
||||
mask = mask.reshape(vae_stride * vae_stride, length, height_mask, width_mask)
|
||||
mask = torch.nn.functional.interpolate(mask.unsqueeze(0), size=(latent_length, height_mask, width_mask), mode='nearest-exact').squeeze(0)
|
||||
|
||||
trim_latent = 0
|
||||
if reference_image is not None:
|
||||
mask_pad = torch.zeros_like(mask[:, :reference_image.shape[2], :, :])
|
||||
mask = torch.cat((mask_pad, mask), dim=1)
|
||||
latent_length += reference_image.shape[2]
|
||||
trim_latent = reference_image.shape[2]
|
||||
|
||||
mask = mask.unsqueeze(0)
|
||||
positive = node_helpers.conditioning_set_values(positive, {"vace_frames": control_video_latent, "vace_mask": mask, "vace_strength": strength})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"vace_frames": control_video_latent, "vace_mask": mask, "vace_strength": strength})
|
||||
|
||||
latent = torch.zeros([batch_size, 16, latent_length, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
out_latent = {}
|
||||
out_latent["samples"] = latent
|
||||
return (positive, negative, out_latent, trim_latent)
|
||||

class TrimVideoLatent:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"samples": ("LATENT",),
                             "trim_amount": ("INT", {"default": 0, "min": 0, "max": 99999}),
                             }}

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "op"

    CATEGORY = "latent/video"

    EXPERIMENTAL = True

    def op(self, samples, trim_amount):
        samples_out = samples.copy()

        s1 = samples["samples"]
        samples_out["samples"] = s1[:, :, trim_amount:]
        return (samples_out,)
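WanVaceToVideo returns trim_latent precisely so that the reference frames it prepends to the latent can be cut off after sampling; TrimVideoLatent above performs that cut along the frame axis (dim 2). A hedged sketch with plain tensors (shapes illustrative only) of how the two values pair up:

import torch

trim_latent = 1                                        # e.g. one reference frame was prepended
samples = {"samples": torch.zeros(1, 16, 5, 60, 104)}  # [batch, channels, frames, h, w]

samples_out = samples.copy()
samples_out["samples"] = samples["samples"][:, :, trim_latent:]
assert samples_out["samples"].shape[2] == 4            # the prepended frame is gone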

NODE_CLASS_MAPPINGS = {
    "WanImageToVideo": WanImageToVideo,
    "WanFunControlToVideo": WanFunControlToVideo,
    "WanFunInpaintToVideo": WanFunInpaintToVideo,
    "WanFirstLastFrameToVideo": WanFirstLastFrameToVideo,
    "WanVaceToVideo": WanVaceToVideo,
    "TrimVideoLatent": TrimVideoLatent,
}
@@ -20,7 +20,7 @@ class WebcamCapture(nodes.LoadImage):
 
     CATEGORY = "image"
 
-    def load_capture(s, image, **kwargs):
+    def load_capture(self, image, **kwargs):
         return super().load_image(folder_paths.get_annotated_filepath(image))
@@ -1,3 +1,3 @@
 # This file is automatically generated by the build process when version is
 # updated in pyproject.toml.
-__version__ = "0.3.29"
+__version__ = "0.3.31"
@@ -144,6 +144,8 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
             input_data_all[x] = [extra_data.get('extra_pnginfo', None)]
         if h[x] == "UNIQUE_ID":
            input_data_all[x] = [unique_id]
+        if h[x] == "AUTH_TOKEN_COMFY_ORG":
+            input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)]
     return input_data_all, missing_keys
 
 map_node_over_list = None #Don't hook this please
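The new AUTH_TOKEN_COMFY_ORG hidden input follows the same pattern as the existing UNIQUE_ID and EXTRA_PNGINFO hidden inputs: a node opts in through the "hidden" dict of INPUT_TYPES. A hedged sketch (ExampleAPINode and its inputs are hypothetical, not from this diff):

class ExampleAPINode:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {"prompt": ("STRING", {})},
            "hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
        }

    RETURN_TYPES = ("STRING",)
    FUNCTION = "call"

    def call(self, prompt, auth_token=None):
        # get_input_data fills auth_token from extra_data["auth_token_comfy_org"];
        # it is None when the request carried no token.
        return (prompt,)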
@@ -4,7 +4,7 @@ import os
 import time
 import mimetypes
 import logging
-from typing import Literal
+from typing import Literal, List
 from collections.abc import Collection
 
 from comfy.cli_args import args

@@ -141,7 +141,7 @@ def get_directory_by_type(type_name: str) -> str | None:
         return get_input_directory()
     return None
 
-def filter_files_content_types(files: list[str], content_types: Literal["image", "video", "audio", "model"]) -> list[str]:
+def filter_files_content_types(files: list[str], content_types: List[Literal["image", "video", "audio", "model"]]) -> list[str]:
     """
     Example:
         files = os.listdir(folder_paths.get_input_directory())
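The signature change means callers now pass a list of content types rather than a single literal, so several kinds of media can be requested in one call. A hedged usage sketch (the docstring's own example is truncated in this hunk; directory contents are illustrative only):

import os
import folder_paths

files = os.listdir(folder_paths.get_input_directory())
# one call can now match more than one content type:
media = folder_paths.filter_files_content_types(files, ["image", "video"])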
hook_breaker_ac10a0.py (new file, 17 lines)
@@ -0,0 +1,17 @@
# Prevent custom nodes from hooking anything important
import comfy.model_management

HOOK_BREAK = [(comfy.model_management, "cast_to")]


SAVED_FUNCTIONS = []


def save_functions():
    for f in HOOK_BREAK:
        SAVED_FUNCTIONS.append((f[0], f[1], getattr(f[0], f[1])))


def restore_functions():
    for f in SAVED_FUNCTIONS:
        setattr(f[0], f[1], f[2])
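main.py (below) calls save_functions() before custom nodes are imported and restore_functions() afterwards and after each prompt. A minimal round-trip sketch, assuming a custom node that monkey-patches cast_to on import:

import comfy.model_management
import hook_breaker_ac10a0

hook_breaker_ac10a0.save_functions()                   # snapshot the original cast_to
comfy.model_management.cast_to = lambda *a, **k: None  # a rogue hook from a custom node
hook_breaker_ac10a0.restore_functions()                # the original is put back

original = hook_breaker_ac10a0.SAVED_FUNCTIONS[0][2]
assert comfy.model_management.cast_to is original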
main.py (7 changes)
@@ -13,7 +13,7 @@ import logging
 import sys
 
 if __name__ == "__main__":
-    #NOTE: These do not do anything on core ComfyUI which should already have no communication with the internet, they are for custom nodes.
+    #NOTE: These do not do anything on core ComfyUI, they are for custom nodes.
     os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
     os.environ['DO_NOT_TRACK'] = '1'

@@ -141,7 +141,7 @@ import nodes
 import comfy.model_management
 import comfyui_version
 import app.logger
+import hook_breaker_ac10a0
 
 def cuda_malloc_warning():
     device = comfy.model_management.get_torch_device()

@@ -215,6 +215,7 @@ def prompt_worker(q, server_instance):
                 comfy.model_management.soft_empty_cache()
                 last_gc_collect = current_time
                 need_gc = False
+        hook_breaker_ac10a0.restore_functions()
 
 
 async def run(server_instance, address='', port=8188, verbose=True, call_on_start=None):

@@ -268,7 +269,9 @@ def start_comfyui(asyncio_loop=None):
     prompt_server = server.PromptServer(asyncio_loop)
     q = execution.PromptQueue(prompt_server)
 
+    hook_breaker_ac10a0.save_functions()
     nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes)
+    hook_breaker_ac10a0.restore_functions()
 
     cuda_malloc_warning()
nodes.py (18 changes)
@@ -917,7 +917,7 @@ class CLIPLoader:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
-                              "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan"], ),
+                              "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma"], ),
                              },
                "optional": {
                              "device": (["default", "cpu"], {"advanced": True}),

@@ -927,7 +927,7 @@ class CLIPLoader:
 
     CATEGORY = "advanced/loaders"
 
-    DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl"
+    DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\nhidream: llama-3.1 (Recommend) or t5"
 
     def load_clip(self, clip_name, type="stable_diffusion", device="default"):
         clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)

@@ -945,7 +945,7 @@ class DualCLIPLoader:
     def INPUT_TYPES(s):
         return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
                               "clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
-                              "type": (["sdxl", "sd3", "flux", "hunyuan_video"], ),
+                              "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream"], ),
                              },
                "optional": {
                              "device": (["default", "cpu"], {"advanced": True}),

@@ -955,7 +955,7 @@ class DualCLIPLoader:
 
     CATEGORY = "advanced/loaders"
 
-    DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5"
+    DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama"
 
     def load_clip(self, clip_name1, clip_name2, type, device="default"):
         clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)

@@ -2258,6 +2258,12 @@ def init_builtin_extra_nodes():
         "nodes_optimalsteps.py",
+        "nodes_hidream.py",
         "nodes_fresca.py",
+        "nodes_preview_any.py",
     ]
 
+    api_nodes_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_api_nodes")
+    api_nodes_files = [
+        "nodes_api.py",
+    ]
+
     import_failed = []

@@ -2265,6 +2271,10 @@ def init_builtin_extra_nodes():
         if not load_custom_node(os.path.join(extras_dir, node_file), module_parent="comfy_extras"):
             import_failed.append(node_file)
 
+    for node_file in api_nodes_files:
+        if not load_custom_node(os.path.join(api_nodes_dir, node_file), module_parent="comfy_api_nodes"):
+            import_failed.append(node_file)
+
     return import_failed
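Note how both loaders resolve the new type strings: getattr on comfy.sd.CLIPType with a STABLE_DIFFUSION fallback, so "hidream" and "chroma" only take effect because matching enum members exist. A one-line sketch of that resolution (assumes CLIPType.HIDREAM is defined, as the new option implies):

import comfy.sd

clip_type = getattr(comfy.sd.CLIPType, "hidream".upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)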
@@ -1,6 +1,6 @@
 [project]
 name = "ComfyUI"
-version = "0.3.29"
+version = "0.3.31"
 readme = "README.md"
 license = { file = "LICENSE" }
 requires-python = ">=3.9"

@@ -12,6 +12,7 @@ documentation = "https://docs.comfy.org/"
 
 [tool.ruff]
 lint.select = [
+    "N805", # invalid-first-argument-name-for-method
     "S307", # suspicious-eval-usage
     "S102", # exec
     "T", # print-usage
@@ -1,5 +1,5 @@
-comfyui-frontend-package==1.16.8
-comfyui-workflow-templates==0.1.1
+comfyui-frontend-package==1.18.6
+comfyui-workflow-templates==0.1.3
 torch
 torchsde
 torchvision

@@ -22,4 +22,5 @@ psutil
 kornia>=0.7.1
 spandrel
 soundfile
-av
+av>=14.2.0
+pydantic~=2.0
@@ -580,6 +580,9 @@ class PromptServer():
             info['deprecated'] = True
         if getattr(obj_class, "EXPERIMENTAL", False):
             info['experimental'] = True
+
+        if hasattr(obj_class, 'API_NODE'):
+            info['api_node'] = obj_class.API_NODE
         return info
 
        @routes.get("/object_info")
tests-unit/comfy_api_test/input_impl_test.py (new file, 91 lines)
@@ -0,0 +1,91 @@
import io
from comfy_api.input_impl.video_types import (
    container_to_output_format,
    get_open_write_kwargs,
)
from comfy_api.util import VideoContainer


def test_container_to_output_format_empty_string():
    """Test that an empty string input returns None. `None` arg allows default auto-detection."""
    assert container_to_output_format("") is None


def test_container_to_output_format_none():
    """Test that None input returns None."""
    assert container_to_output_format(None) is None


def test_container_to_output_format_comma_separated():
    """Test that a comma-separated list returns a valid singular format from the list."""
    comma_separated_format = "mp4,mov,m4a"
    output_format = container_to_output_format(comma_separated_format)
    assert output_format in comma_separated_format


def test_container_to_output_format_single():
    """Test that a single format string (not comma-separated list) is returned as is."""
    assert container_to_output_format("mp4") == "mp4"


def test_get_open_write_kwargs_filepath_no_format():
    """Test that 'format' kwarg is NOT set when dest is a file path."""
    kwargs_auto = get_open_write_kwargs("output.mp4", "mp4", VideoContainer.AUTO)
    assert "format" not in kwargs_auto, "Format should not be set for file paths (AUTO)"

    kwargs_specific = get_open_write_kwargs("output.avi", "mp4", "avi")
    fail_msg = "Format should not be set for file paths (Specific)"
    assert "format" not in kwargs_specific, fail_msg


def test_get_open_write_kwargs_base_options_mode():
    """Test basic kwargs for file path: mode and movflags."""
    kwargs = get_open_write_kwargs("output.mp4", "mp4", VideoContainer.AUTO)
    assert kwargs["mode"] == "w", "mode should be set to write"

    fail_msg = "movflags should be set to preserve custom metadata tags"
    assert "movflags" in kwargs["options"], fail_msg
    assert kwargs["options"]["movflags"] == "use_metadata_tags", fail_msg


def test_get_open_write_kwargs_bytesio_auto_format():
    """Test kwargs for BytesIO dest with AUTO format."""
    dest = io.BytesIO()
    container_fmt = "mov,mp4,m4a"
    kwargs = get_open_write_kwargs(dest, container_fmt, VideoContainer.AUTO)

    assert kwargs["mode"] == "w"
    assert kwargs["options"]["movflags"] == "use_metadata_tags"

    fail_msg = (
        "Format should be a valid format from the container's format list when AUTO"
    )
    assert kwargs["format"] in container_fmt, fail_msg


def test_get_open_write_kwargs_bytesio_specific_format():
    """Test kwargs for BytesIO dest with a specific single format."""
    dest = io.BytesIO()
    container_fmt = "avi"
    to_fmt = VideoContainer.MP4
    kwargs = get_open_write_kwargs(dest, container_fmt, to_fmt)

    assert kwargs["mode"] == "w"
    assert kwargs["options"]["movflags"] == "use_metadata_tags"

    fail_msg = "Format should be the specified format (lowercased) when output format is not AUTO"
    assert kwargs["format"] == "mp4", fail_msg


def test_get_open_write_kwargs_bytesio_specific_format_list():
    """Test kwargs for BytesIO dest with a specific comma-separated format."""
    dest = io.BytesIO()
    container_fmt = "avi"
    to_fmt = "mov,mp4,m4a"  # A format string that is a list
    kwargs = get_open_write_kwargs(dest, container_fmt, to_fmt)

    assert kwargs["mode"] == "w"
    assert kwargs["options"]["movflags"] == "use_metadata_tags"

    fail_msg = "Format should be a valid format from the specified format list when output format is not AUTO"
    assert kwargs["format"] in to_fmt, fail_msg
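The two helpers under test are not shown in this compare view. A hedged reference implementation consistent with the assertions above (AUTO is a stand-in for VideoContainer.AUTO, and real enum handling is elided) might look like:

import io

AUTO = "auto"  # stand-in for VideoContainer.AUTO


def container_to_output_format(container_format=None):
    """'' or None -> None (lets PyAV auto-detect); 'mov,mp4,m4a' -> 'mov'."""
    if not container_format:
        return None
    return container_format.split(",")[0]


def get_open_write_kwargs(dest, container_format, to_format):
    """Build kwargs for av.open(dest, ...); only byte streams need an explicit format."""
    kwargs = {"mode": "w", "options": {"movflags": "use_metadata_tags"}}
    if isinstance(dest, io.BytesIO):
        fmt = container_format if to_format == AUTO else str(to_format).lower()
        kwargs["format"] = container_to_output_format(fmt)
    return kwargs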
@@ -229,3 +229,61 @@ async def test_move_userdata_full_info(aiohttp_client, app, tmp_path):
     assert not os.path.exists(tmp_path / "source.txt")
     with open(tmp_path / "dest.txt", "r") as f:
         assert f.read() == "test content"
+
+
+async def test_listuserdata_v2_empty_root(aiohttp_client, app):
+    client = await aiohttp_client(app)
+    resp = await client.get("/v2/userdata")
+    assert resp.status == 200
+    assert await resp.json() == []
+
+
+async def test_listuserdata_v2_nonexistent_subdirectory(aiohttp_client, app):
+    client = await aiohttp_client(app)
+    resp = await client.get("/v2/userdata?path=does_not_exist")
+    assert resp.status == 404
+
+
+async def test_listuserdata_v2_default(aiohttp_client, app, tmp_path):
+    os.makedirs(tmp_path / "test_dir" / "subdir")
+    (tmp_path / "test_dir" / "file1.txt").write_text("content")
+    (tmp_path / "test_dir" / "subdir" / "file2.txt").write_text("content")
+
+    client = await aiohttp_client(app)
+    resp = await client.get("/v2/userdata?path=test_dir")
+    assert resp.status == 200
+    data = await resp.json()
+    file_paths = {item["path"] for item in data if item["type"] == "file"}
+    assert file_paths == {"test_dir/file1.txt", "test_dir/subdir/file2.txt"}
+
+
+async def test_listuserdata_v2_normalized_separators(aiohttp_client, app, tmp_path, monkeypatch):
+    # Force backslash as os separator
+    monkeypatch.setattr(os, 'sep', '\\')
+    monkeypatch.setattr(os.path, 'sep', '\\')
+    os.makedirs(tmp_path / "test_dir" / "subdir")
+    (tmp_path / "test_dir" / "subdir" / "file1.txt").write_text("x")
+
+    client = await aiohttp_client(app)
+    resp = await client.get("/v2/userdata?path=test_dir")
+    assert resp.status == 200
+    data = await resp.json()
+    for item in data:
+        assert "/" in item["path"]
+        assert "\\" not in item["path"]
+
+
+async def test_listuserdata_v2_url_encoded_path(aiohttp_client, app, tmp_path):
+    # Create a directory with a space in its name and a file inside
+    os.makedirs(tmp_path / "my dir")
+    (tmp_path / "my dir" / "file.txt").write_text("content")
+
+    client = await aiohttp_client(app)
+    # Use URL-encoded space in path parameter
+    resp = await client.get("/v2/userdata?path=my%20dir&recurse=false")
+    assert resp.status == 200
+    data = await resp.json()
+    assert len(data) == 1
+    entry = data[0]
+    assert entry["name"] == "file.txt"
+    # Ensure the path is correctly decoded and uses forward slash
+    assert entry["path"] == "my dir/file.txt"
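Taken together, these tests pin down the entry shape returned by the new GET /v2/userdata endpoint. A hedged sketch of one entry (field set inferred from the assertions above; other fields may exist):

entry = {
    "name": "file.txt",         # basename only
    "path": "my dir/file.txt",  # decoded, forward slashes, relative to the user root
    "type": "file",             # directories presumably use a different value
}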