Compare commits
93 Commits
v0.3.10...base-path-
| SHA1 |
|---|
| 17b70728ec |
| f3566f0894 |
| ca69b41cee |
| a058f52090 |
| d6bbe8c40f |
| a7fe0a94de |
| e857dd48b8 |
| d303cb5341 |
| fb2ad645a3 |
| d8a7a32779 |
| a00e1489d2 |
| ebf038d4fa |
| b4de04a1c1 |
| b1a02131c9 |
| 3a3910f91d |
| 507199d9a8 |
| 2f3ab40b62 |
| 7fc3ccdcc2 |
| 55add50220 |
| 0aa2368e46 |
| cca96a85ae |
| 619b8cde74 |
| 31831e6ef1 |
| 88ceb28e20 |
| 23289a6a5c |
| 9d8b6c1f46 |
| 6320d05696 |
| 25683b5b02 |
| 4758fb64b9 |
| 008761166f |
| bfd5dfd611 |
| 55ade36d01 |
| 2e20e399ea |
| 3baf92d120 |
| 1709a8441e |
| cba58fff0b |
| 2feb8d0b77 |
| 5b657f8c15 |
| 2cdbaf5169 |
| c78a45685d |
| 3aaabb12d4 |
| 1f1c7b7b56 |
| 90f349f93d |
| b9d9bcba14 |
| 42086af123 |
| 6c9bd11fa3 |
| ee8a7ab69d |
| 9c773a241b |
| adea2beb5c |
| 2ff3104f70 |
| 129d8908f7 |
| ff838657fa |
| 2307ff6746 |
| d0f3752e33 |
| c515bdf371 |
| 4209edf48d |
| d055325783 |
| eeab420c70 |
| 916d1e14a9 |
| c496e53519 |
| 7da85fac3f |
| b65b83af6f |
| c8a3492c22 |
| 5cbf79787f |
| d45ebb63f6 |
| caa6476a69 |
| 45671cda0b |
| 8f29664057 |
| 0b9839ef43 |
| 953693b137 |
| a39ea87bca |
| 9e9c8a1c64 |
| 0f11d60afb |
| 79eea51a1d |
| c0338a46a4 |
| 1c99734e5a |
| 67758f50f3 |
| 02eef72bf5 |
| b7572b2f87 |
| a90aafafc1 |
| d9b7cfac7e |
| 3507870535 |
| 82ecb02c1e |
| a618f768e0 |
| e1dec3c792 |
| 96697c4bc5 |
| b504bd606d |
| d170292594 |
| 9cfd185676 |
| 4b5bcd8ac4 |
| ceb50b2cbf |
| 160ca08138 |
| c4bfdba330 |
.github/workflows/stable-release.yml (2 changes)
@@ -22,7 +22,7 @@ on:
         description: 'Python patch version'
         required: true
         type: string
-        default: "7"
+        default: "8"


 jobs:
.github/workflows/test-build.yml (2 changes)
@@ -18,7 +18,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python-version: ["3.8", "3.9", "3.10", "3.11"]
+        python-version: ["3.9", "3.10", "3.11", "3.12"]
     steps:
     - uses: actions/checkout@v4
     - name: Set up Python ${{ matrix.python-version }}
.github/workflows/test-launch.yml (2 changes)
@@ -17,7 +17,7 @@ jobs:
           path: "ComfyUI"
       - uses: actions/setup-python@v4
         with:
-          python-version: '3.8'
+          python-version: '3.9'
       - name: Install requirements
         run: |
           python -m pip install --upgrade pip
.github/workflows/update-frontend.yml (new file, 58 lines)
@@ -0,0 +1,58 @@
+name: Update Frontend Release
+
+on:
+  workflow_dispatch:
+    inputs:
+      version:
+        description: "Frontend version to update to (e.g., 1.0.0)"
+        required: true
+        type: string
+
+jobs:
+  update-frontend:
+    runs-on: ubuntu-latest
+    permissions:
+      contents: write
+      pull-requests: write
+
+    steps:
+      - name: Checkout ComfyUI
+        uses: actions/checkout@v4
+      - uses: actions/setup-python@v4
+        with:
+          python-version: '3.10'
+      - name: Install requirements
+        run: |
+          python -m pip install --upgrade pip
+          pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
+          pip install -r requirements.txt
+          pip install wait-for-it
+      # Frontend asset will be downloaded to ComfyUI/web_custom_versions/Comfy-Org_ComfyUI_frontend/{version}
+      - name: Start ComfyUI server
+        run: |
+          python main.py --cpu --front-end-version Comfy-Org/ComfyUI_frontend@${{ github.event.inputs.version }} 2>&1 | tee console_output.log &
+          wait-for-it --service 127.0.0.1:8188 -t 30
+      - name: Configure Git
+        run: |
+          git config --global user.name "GitHub Action"
+          git config --global user.email "action@github.com"
+      # Replace existing frontend content with the new version and remove .js.map files
+      # See https://github.com/Comfy-Org/ComfyUI_frontend/issues/2145 for why we remove .js.map files
+      - name: Update frontend content
+        run: |
+          rm -rf web/
+          cp -r web_custom_versions/Comfy-Org_ComfyUI_frontend/${{ github.event.inputs.version }} web/
+          rm web/**/*.js.map
+      - name: Create Pull Request
+        uses: peter-evans/create-pull-request@v7
+        with:
+          token: ${{ secrets.PR_BOT_PAT }}
+          commit-message: "Update frontend to v${{ github.event.inputs.version }}"
+          title: "Frontend Update: v${{ github.event.inputs.version }}"
+          body: |
+            Automated PR to update frontend content to version ${{ github.event.inputs.version }}
+
+            This PR was created automatically by the frontend update workflow.
+          branch: release-${{ github.event.inputs.version }}
+          base: master
+          labels: Frontend,dependencies
.github/workflows/update-version.yml (new file, 58 lines)
@@ -0,0 +1,58 @@
+name: Update Version File
+
+on:
+  pull_request:
+    paths:
+      - "pyproject.toml"
+    branches:
+      - master
+
+jobs:
+  update-version:
+    runs-on: ubuntu-latest
+    # Don't run on fork PRs
+    if: github.event.pull_request.head.repo.full_name == github.repository
+    permissions:
+      pull-requests: write
+      contents: write
+
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v4
+
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: "3.11"
+
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+
+      - name: Update comfyui_version.py
+        run: |
+          # Read version from pyproject.toml and update comfyui_version.py
+          python -c '
+          import tomllib
+
+          # Read version from pyproject.toml
+          with open("pyproject.toml", "rb") as f:
+              config = tomllib.load(f)
+          version = config["project"]["version"]
+
+          # Write version to comfyui_version.py
+          with open("comfyui_version.py", "w") as f:
+              f.write("# This file is automatically generated by the build process when version is\n")
+              f.write("# updated in pyproject.toml.\n")
+              f.write(f"__version__ = \"{version}\"\n")
+          '
+
+      - name: Commit changes
+        run: |
+          git config --local user.name "github-actions"
+          git config --local user.email "github-actions@github.com"
+          git fetch origin ${{ github.head_ref }}
+          git checkout -B ${{ github.head_ref }} origin/${{ github.head_ref }}
+          git add comfyui_version.py
+          git diff --quiet && git diff --staged --quiet || git commit -m "chore: Update comfyui_version.py to match pyproject.toml"
+          git push origin HEAD:${{ github.head_ref }}
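A note on the inline script above: `tomllib` ships in the standard library only from Python 3.11 onward, which is why this workflow pins `python-version: "3.11"`. A minimal standalone sketch of the same version read; the `tomli` fallback for older interpreters is an assumption for illustration, not part of the workflow:

```python
import sys

if sys.version_info >= (3, 11):
    import tomllib  # stdlib TOML parser, Python 3.11+
else:
    import tomli as tomllib  # API-compatible backport; assumes `pip install tomli`

# Read the project version the same way the workflow's python -c script does.
with open("pyproject.toml", "rb") as f:
    config = tomllib.load(f)

print(config["project"]["version"])
```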
@@ -29,7 +29,7 @@ on:
         description: 'python patch version'
         required: true
         type: string
-        default: "7"
+        default: "8"
 # push:
 #   branches:
 #   - master
@@ -7,19 +7,19 @@ on:
         description: 'cuda version'
         required: true
         type: string
-        default: "124"
+        default: "126"

       python_minor:
         description: 'python minor version'
         required: true
         type: string
-        default: "12"
+        default: "13"

       python_patch:
         description: 'python patch version'
         required: true
         type: string
-        default: "4"
+        default: "1"
 # push:
 #   branches:
 #   - master
@@ -19,7 +19,7 @@ on:
         description: 'python patch version'
         required: true
         type: string
-        default: "7"
+        default: "8"
 # push:
 #   branches:
 #   - master
@@ -15,9 +15,10 @@
 # Python web server
 /api_server/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
 /app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
+/utils/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata

 # Frontend assets
-/web/ @huchenlei @webfiltered @pythongosssss
+/web/ @huchenlei @webfiltered @pythongosssss @yoland68 @robinjhuang

 # Extra nodes
 /comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink
README.md (11 changes)
@@ -52,6 +52,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
 - [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
 - [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/)
 - [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
+- [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/)
 - [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
 - Asynchronous Queue system
 - Many optimizations: Only re-executes the parts of the workflow that changes between executions.
@@ -224,6 +225,16 @@ You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS ve

 ```pip install torch-directml``` Then you can launch ComfyUI with: ```python main.py --directml```

+#### Ascend NPUs
+
+For models compatible with Ascend Extension for PyTorch (torch_npu). To get started, ensure your environment meets the prerequisites outlined on the [installation](https://ascend.github.io/docs/sources/ascend/quick_install.html) page. Here's a step-by-step guide tailored to your platform and installation method:
+
+1. Begin by installing the recommended or newer kernel version for Linux as specified in the Installation page of torch-npu, if necessary.
+2. Proceed with the installation of Ascend Basekit, which includes the driver, firmware, and CANN, following the instructions provided for your specific platform.
+3. Next, install the necessary packages for torch-npu by adhering to the platform-specific instructions on the [Installation](https://ascend.github.io/docs/sources/pytorch/install.html#pytorch) page.
+4. Finally, adhere to the [ComfyUI manual installation](#manual-install-windows-linux) guide for Linux. Once all components are installed, you can run ComfyUI as described earlier.
+
+
 # Running

 ```python main.py```
@@ -1,6 +1,7 @@
 import os
 import json
 from aiohttp import web
+import logging


 class AppSettings():
@@ -11,8 +12,12 @@ class AppSettings():
         file = self.user_manager.get_request_user_filepath(
             request, "comfy.settings.json")
         if os.path.isfile(file):
-            with open(file) as f:
-                return json.load(f)
+            try:
+                with open(file) as f:
+                    return json.load(f)
+            except:
+                logging.error(f"The user settings file is corrupted: {file}")
+                return {}
         else:
             return {}
app/custom_node_manager.py (new file, 134 lines)
@@ -0,0 +1,134 @@
+from __future__ import annotations
+
+import os
+import folder_paths
+import glob
+from aiohttp import web
+import json
+import logging
+from functools import lru_cache
+
+from utils.json_util import merge_json_recursive
+
+
+# Extra locale files to load into main.json
+EXTRA_LOCALE_FILES = [
+    "nodeDefs.json",
+    "commands.json",
+    "settings.json",
+]
+
+
+def safe_load_json_file(file_path: str) -> dict:
+    if not os.path.exists(file_path):
+        return {}
+
+    try:
+        with open(file_path, "r", encoding="utf-8") as f:
+            return json.load(f)
+    except json.JSONDecodeError:
+        logging.error(f"Error loading {file_path}")
+        return {}
+
+
+class CustomNodeManager:
+    @lru_cache(maxsize=1)
+    def build_translations(self):
+        """Load all custom nodes translations during initialization. Translations are
+        expected to be loaded from `locales/` folder.
+
+        The folder structure is expected to be the following:
+        - custom_nodes/
+            - custom_node_1/
+                - locales/
+                    - en/
+                        - main.json
+                        - commands.json
+                        - settings.json
+
+        returned translations are expected to be in the following format:
+        {
+            "en": {
+                "nodeDefs": {...},
+                "commands": {...},
+                "settings": {...},
+                ...{other main.json keys}
+            }
+        }
+        """
+
+        translations = {}
+
+        for folder in folder_paths.get_folder_paths("custom_nodes"):
+            # Sort glob results for deterministic ordering
+            for custom_node_dir in sorted(glob.glob(os.path.join(folder, "*/"))):
+                locales_dir = os.path.join(custom_node_dir, "locales")
+                if not os.path.exists(locales_dir):
+                    continue
+
+                for lang_dir in glob.glob(os.path.join(locales_dir, "*/")):
+                    lang_code = os.path.basename(os.path.dirname(lang_dir))
+
+                    if lang_code not in translations:
+                        translations[lang_code] = {}
+
+                    # Load main.json
+                    main_file = os.path.join(lang_dir, "main.json")
+                    node_translations = safe_load_json_file(main_file)
+
+                    # Load extra locale files
+                    for extra_file in EXTRA_LOCALE_FILES:
+                        extra_file_path = os.path.join(lang_dir, extra_file)
+                        key = extra_file.split(".")[0]
+                        json_data = safe_load_json_file(extra_file_path)
+                        if json_data:
+                            node_translations[key] = json_data
+
+                    if node_translations:
+                        translations[lang_code] = merge_json_recursive(
+                            translations[lang_code], node_translations
+                        )
+
+        return translations
+
+    def add_routes(self, routes, webapp, loadedModules):
+
+        @routes.get("/workflow_templates")
+        async def get_workflow_templates(request):
+            """Returns a web response that contains the map of custom_nodes names and their associated workflow templates. The ones without templates are omitted."""
+            files = [
+                file
+                for folder in folder_paths.get_folder_paths("custom_nodes")
+                for file in glob.glob(
+                    os.path.join(folder, "*/example_workflows/*.json")
+                )
+            ]
+            workflow_templates_dict = (
+                {}
+            )  # custom_nodes folder name -> example workflow names
+            for file in files:
+                custom_nodes_name = os.path.basename(
+                    os.path.dirname(os.path.dirname(file))
+                )
+                workflow_name = os.path.splitext(os.path.basename(file))[0]
+                workflow_templates_dict.setdefault(custom_nodes_name, []).append(
+                    workflow_name
+                )
+            return web.json_response(workflow_templates_dict)
+
+        # Serve workflow templates from custom nodes.
+        for module_name, module_dir in loadedModules:
+            workflows_dir = os.path.join(module_dir, "example_workflows")
+            if os.path.exists(workflows_dir):
+                webapp.add_routes(
+                    [
+                        web.static(
+                            "/api/workflow_templates/" + module_name, workflows_dir
+                        )
+                    ]
+                )
+
+        @routes.get("/i18n")
+        async def get_i18n(request):
+            """Returns translations from all custom nodes' locales folders."""
+            return web.json_response(self.build_translations())
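One behavioral detail of `build_translations` above: `functools.lru_cache` on an instance method keys the cache on `self`, so the expensive locale scan runs once per `CustomNodeManager` instance and is reused by every `/i18n` request. A minimal standalone sketch of that caching behavior (class and method names here are illustrative, not from the source):

```python
from functools import lru_cache


class Manager:
    @lru_cache(maxsize=1)
    def build(self):
        print("scanning locales...")  # stands in for the expensive disk walk
        return {"en": {}}


m = Manager()
m.build()   # prints "scanning locales..." and caches the result (keyed on self)
m.build()   # cache hit: nothing printed

m2 = Manager()
m2.build()  # different self -> cache miss; with maxsize=1 this evicts m's entry
```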
@@ -51,7 +51,7 @@ def on_flush(callback):
     if stderr_interceptor is not None:
         stderr_interceptor.on_flush(callback)

-def setup_logger(log_level: str = 'INFO', capacity: int = 300):
+def setup_logger(log_level: str = 'INFO', capacity: int = 300, use_stdout: bool = False):
     global logs
     if logs:
         return
@@ -70,4 +70,15 @@ def setup_logger(log_level: str = 'INFO', capacity: int = 300):

     stream_handler = logging.StreamHandler()
     stream_handler.setFormatter(logging.Formatter("%(message)s"))
+
+    if use_stdout:
+        # Only errors and critical to stderr
+        stream_handler.addFilter(lambda record: not record.levelno < logging.ERROR)
+
+        # Lesser to stdout
+        stdout_handler = logging.StreamHandler(sys.stdout)
+        stdout_handler.setFormatter(logging.Formatter("%(message)s"))
+        stdout_handler.addFilter(lambda record: record.levelno < logging.ERROR)
+        logger.addHandler(stdout_handler)
+
     logger.addHandler(stream_handler)
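The `use_stdout` branch above splits one logger across two streams with level filters. A self-contained sketch of the same routing (logger name is illustrative): records below ERROR go to stdout, ERROR and above stay on stderr.

```python
import logging
import sys

logger = logging.getLogger("demo")
logger.setLevel(logging.DEBUG)

# ERROR and above -> stderr (StreamHandler defaults to sys.stderr)
stderr_handler = logging.StreamHandler()
stderr_handler.addFilter(lambda record: record.levelno >= logging.ERROR)

# Everything below ERROR -> stdout
stdout_handler = logging.StreamHandler(sys.stdout)
stdout_handler.addFilter(lambda record: record.levelno < logging.ERROR)

logger.addHandler(stderr_handler)
logger.addHandler(stdout_handler)

logger.info("progress message")   # written to stdout
logger.error("something failed")  # written to stderr
```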
@@ -122,7 +122,7 @@ vram_group.add_argument("--lowvram", action="store_true", help="Split the unet i
 vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
 vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")

-parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reverved depending on your OS.")
+parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.")


 parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
@@ -141,6 +141,7 @@ parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Dis
 parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")

 parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
+parser.add_argument("--log-stdout", action="store_true", help="Send normal process output to stdout instead of stderr (default).")

 # The default built-in provider hosted under web/
 DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
@@ -5,7 +5,7 @@ This module provides type hinting and concrete convenience types for node develo
 If cloned to the custom_nodes directory of ComfyUI, types can be imported using:

 ```python
-from comfy_types import IO, ComfyNodeABC, CheckLazyMixin
+from comfy.comfy_types import IO, ComfyNodeABC, CheckLazyMixin

 class ExampleNode(ComfyNodeABC):
     @classmethod
@@ -1,12 +1,12 @@
-from comfy_types import IO, ComfyNodeABC, InputTypeDict
+from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
 from inspect import cleandoc


 class ExampleNode(ComfyNodeABC):
     """An example node that just adds 1 to an input integer.

-    * Requires an IDE configured with analysis paths etc to be worth looking at.
-    * Not intended for use in ComfyUI.
+    * Requires a modern IDE to provide any benefit (detail: an IDE configured with analysis paths etc).
+    * This node is intended as an example for developers only.
     """

     DESCRIPTION = cleandoc(__doc__)
@@ -46,7 +46,7 @@ class CONDCrossAttn(CONDRegular):
         if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen
             return False

-        mult_min = lcm(s1[1], s2[1])
+        mult_min = math.lcm(s1[1], s2[1])
         diff = mult_min // min(s1[1], s2[1])
         if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
             return False
@@ -57,7 +57,7 @@ class CONDCrossAttn(CONDRegular):
         crossattn_max_len = self.cond.shape[1]
         for x in others:
             c = x.cond
-            crossattn_max_len = lcm(crossattn_max_len, c.shape[1])
+            crossattn_max_len = math.lcm(crossattn_max_len, c.shape[1])
             conds.append(c)

         out = []
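The move from a local `lcm` helper to `math.lcm` fits the Python floor raised elsewhere in this diff: `math.lcm` only exists on Python 3.9+, the new minimum the CI workflows test against. A quick sketch of the padding check it supports (`can_concat` is an illustrative name, not from the source; the limit of 4 mirrors the comment above):

```python
import math


def can_concat(len1: int, len2: int, max_pad_factor: int = 4) -> bool:
    # Pad both conds to the least common multiple of their lengths, but
    # refuse if the shorter one would need to repeat too many times.
    mult_min = math.lcm(len1, len2)
    return mult_min // min(len1, len2) <= max_pad_factor


print(can_concat(77, 154))  # True: lcm is 154, shorter side repeats 2x
print(can_concat(77, 308))  # True: repeats 4x, exactly at the limit
print(can_concat(77, 385))  # False: repeats 5x
```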
@@ -661,7 +661,7 @@ class UniPC:

         if x_t is None:
             if use_predictor:
-                pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
+                pred_res = torch.tensordot(D1s, rhos_p, dims=([1], [0])) # torch.einsum('k,bkchw->bchw', rhos_p, D1s)
             else:
                 pred_res = 0
             x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * pred_res
@@ -669,7 +669,7 @@
         if use_corrector:
             model_t = self.model_fn(x_t, t)
             if D1s is not None:
-                corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
+                corr_res = torch.tensordot(D1s, rhos_c[:-1], dims=([1], [0])) # torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
             else:
                 corr_res = 0
             D1_t = (model_t - model_prev_0)
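The `tensordot` rewrite is a drop-in equivalent of the einsum it replaces: contracting axis 1 of `D1s` (the `k` axis) against axis 0 of the rho vector leaves the remaining `(b, c, h, w)` axes in order. A small self-contained check (shapes are arbitrary):

```python
import torch

b, k, c, h, w = 2, 3, 4, 8, 8
D1s = torch.randn(b, k, c, h, w)
rhos = torch.randn(k)

via_einsum = torch.einsum('k,bkchw->bchw', rhos, D1s)
via_tensordot = torch.tensordot(D1s, rhos, dims=([1], [0]))

# Both contract away k; the surviving D1s axes keep their order.
assert via_tensordot.shape == (b, c, h, w)
assert torch.allclose(via_einsum, via_tensordot, atol=1e-6)
```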
comfy/hooks.py (416 changes)
@@ -16,91 +16,132 @@ import comfy.model_management
 import comfy.patcher_extension
 from node_helpers import conditioning_set_values

+# #######################################################################################################
+# Hooks explanation
+# -------------------
+# The purpose of hooks is to allow conds to influence sampling without the need for ComfyUI core code to
+# make explicit special cases like it does for ControlNet and GLIGEN.
+#
+# This is necessary for nodes/features that are intended for use with masked or scheduled conds, or those
+# that should run special code when a 'marked' cond is used in sampling.
+# #######################################################################################################
+
 class EnumHookMode(enum.Enum):
+    '''
+    Priority of hook memory optimization vs. speed, mostly related to WeightHooks.
+
+    MinVram: No caching will occur for any operations related to hooks.
+    MaxSpeed: Excess VRAM (and RAM, once VRAM is sufficiently depleted) will be used to cache hook weights when switching hook groups.
+    '''
     MinVram = "minvram"
     MaxSpeed = "maxspeed"

 class EnumHookType(enum.Enum):
+    '''
+    Hook types, each of which has different expected behavior.
+    '''
     Weight = "weight"
-    Patch = "patch"
     ObjectPatch = "object_patch"
-    AddModels = "add_models"
-    Callbacks = "callbacks"
-    Wrappers = "wrappers"
-    SetInjections = "add_injections"
+    AdditionalModels = "add_models"
+    TransformerOptions = "transformer_options"
+    Injections = "add_injections"

 class EnumWeightTarget(enum.Enum):
     Model = "model"
     Clip = "clip"

+class EnumHookScope(enum.Enum):
+    '''
+    Determines if hook should be limited in its influence over sampling.
+
+    AllConditioning: hook will affect all conds used in sampling.
+    HookedOnly: hook will only affect the conds it was attached to.
+    '''
+    AllConditioning = "all_conditioning"
+    HookedOnly = "hooked_only"
+
+
 class _HookRef:
     pass

-# NOTE: this is an example of how the should_register function should look
-def default_should_register(hook: 'Hook', model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
+
+def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
+    '''Example for how custom_should_register function can look like.'''
     return True


+def create_target_dict(target: EnumWeightTarget=None, **kwargs) -> dict[str]:
+    '''Creates base dictionary for use with Hooks' target param.'''
+    d = {}
+    if target is not None:
+        d['target'] = target
+    d.update(kwargs)
+    return d
+
+
 class Hook:
     def __init__(self, hook_type: EnumHookType=None, hook_ref: _HookRef=None, hook_id: str=None,
-                 hook_keyframe: 'HookKeyframeGroup'=None):
+                 hook_keyframe: HookKeyframeGroup=None, hook_scope=EnumHookScope.AllConditioning):
         self.hook_type = hook_type
+        '''Enum identifying the general class of this hook.'''
         self.hook_ref = hook_ref if hook_ref else _HookRef()
+        '''Reference shared between hook clones that have the same value. Should NOT be modified.'''
         self.hook_id = hook_id
+        '''Optional string ID to identify hook; useful if need to consolidate duplicates at registration time.'''
         self.hook_keyframe = hook_keyframe if hook_keyframe else HookKeyframeGroup()
+        '''Keyframe storage that can be referenced to get strength for current sampling step.'''
+        self.hook_scope = hook_scope
+        '''Scope of where this hook should apply in terms of the conds used in sampling run.'''
         self.custom_should_register = default_should_register
-        self.auto_apply_to_nonpositive = False
+        '''Can be overriden with a compatible function to decide if this hook should be registered without the need to override .should_register'''

     @property
     def strength(self):
         return self.hook_keyframe.strength

-    def initialize_timesteps(self, model: 'BaseModel'):
+    def initialize_timesteps(self, model: BaseModel):
         self.reset()
         self.hook_keyframe.initialize_timesteps(model)

     def reset(self):
         self.hook_keyframe.reset()

-    def clone(self, subtype: Callable=None):
-        if subtype is None:
-            subtype = type(self)
-        c: Hook = subtype()
+    def clone(self):
+        c: Hook = self.__class__()
         c.hook_type = self.hook_type
         c.hook_ref = self.hook_ref
         c.hook_id = self.hook_id
         c.hook_keyframe = self.hook_keyframe
+        c.hook_scope = self.hook_scope
         c.custom_should_register = self.custom_should_register
-        # TODO: make this do something
-        c.auto_apply_to_nonpositive = self.auto_apply_to_nonpositive
         return c

-    def should_register(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
-        return self.custom_should_register(self, model, model_options, target, registered)
+    def should_register(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
+        return self.custom_should_register(self, model, model_options, target_dict, registered)

-    def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
+    def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
         raise NotImplementedError("add_hook_patches should be defined for Hook subclasses")

-    def on_apply(self, model: 'ModelPatcher', transformer_options: dict[str]):
-        pass
-
-    def on_unapply(self, model: 'ModelPatcher', transformer_options: dict[str]):
-        pass
-
-    def __eq__(self, other: 'Hook'):
+    def __eq__(self, other: Hook):
         return self.__class__ == other.__class__ and self.hook_ref == other.hook_ref

     def __hash__(self):
         return hash(self.hook_ref)

 class WeightHook(Hook):
+    '''
+    Hook responsible for tracking weights to be applied to some model/clip.
+
+    Note, value of hook_scope is ignored and is treated as HookedOnly.
+    '''
     def __init__(self, strength_model=1.0, strength_clip=1.0):
-        super().__init__(hook_type=EnumHookType.Weight)
+        super().__init__(hook_type=EnumHookType.Weight, hook_scope=EnumHookScope.HookedOnly)
         self.weights: dict = None
         self.weights_clip: dict = None
         self.need_weight_init = True
         self._strength_model = strength_model
         self._strength_clip = strength_clip
+        self.hook_scope = EnumHookScope.HookedOnly # this value does not matter for WeightHooks, just for docs

     @property
     def strength_model(self):
@@ -110,36 +151,36 @@ class WeightHook(Hook):
     def strength_clip(self):
         return self._strength_clip * self.strength

-    def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
-        if not self.should_register(model, model_options, target, registered):
+    def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
+        if not self.should_register(model, model_options, target_dict, registered):
             return False
         weights = None
-        if target == EnumWeightTarget.Model:
-            strength = self._strength_model
-        else:
+        target = target_dict.get('target', None)
+        if target == EnumWeightTarget.Clip:
             strength = self._strength_clip
+        else:
+            strength = self._strength_model

         if self.need_weight_init:
             key_map = {}
-            if target == EnumWeightTarget.Model:
-                key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
-            else:
+            if target == EnumWeightTarget.Clip:
                 key_map = comfy.lora.model_lora_keys_clip(model.model, key_map)
+            else:
+                key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
             weights = comfy.lora.load_lora(self.weights, key_map, log_missing=False)
         else:
-            if target == EnumWeightTarget.Model:
-                weights = self.weights
-            else:
+            if target == EnumWeightTarget.Clip:
                 weights = self.weights_clip
+            else:
+                weights = self.weights
         model.add_hook_patches(hook=self, patches=weights, strength_patch=strength)
-        registered.append(self)
+        registered.add(self)
         return True
         # TODO: add logs about any keys that were not applied

-    def clone(self, subtype: Callable=None):
-        if subtype is None:
-            subtype = type(self)
-        c: WeightHook = super().clone(subtype)
+    def clone(self):
+        c: WeightHook = super().clone()
         c.weights = self.weights
         c.weights_clip = self.weights_clip
         c.need_weight_init = self.need_weight_init
@@ -147,127 +188,158 @@ class WeightHook(Hook):
         c._strength_clip = self._strength_clip
         return c

-class PatchHook(Hook):
-    def __init__(self):
-        super().__init__(hook_type=EnumHookType.Patch)
-        self.patches: dict = None
-
-    def clone(self, subtype: Callable=None):
-        if subtype is None:
-            subtype = type(self)
-        c: PatchHook = super().clone(subtype)
-        c.patches = self.patches
-        return c
-    # TODO: add functionality
-
 class ObjectPatchHook(Hook):
-    def __init__(self):
+    def __init__(self, object_patches: dict[str]=None,
+                 hook_scope=EnumHookScope.AllConditioning):
         super().__init__(hook_type=EnumHookType.ObjectPatch)
-        self.object_patches: dict = None
+        self.object_patches = object_patches
+        self.hook_scope = hook_scope

-    def clone(self, subtype: Callable=None):
-        if subtype is None:
-            subtype = type(self)
-        c: ObjectPatchHook = super().clone(subtype)
+    def clone(self):
+        c: ObjectPatchHook = super().clone()
         c.object_patches = self.object_patches
         return c
-    # TODO: add functionality

-class AddModelsHook(Hook):
-    def __init__(self, key: str=None, models: list['ModelPatcher']=None):
-        super().__init__(hook_type=EnumHookType.AddModels)
-        self.key = key
+    def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
+        raise NotImplementedError("ObjectPatchHook is not supported yet in ComfyUI.")
+
+class AdditionalModelsHook(Hook):
+    '''
+    Hook responsible for telling model management any additional models that should be loaded.
+
+    Note, value of hook_scope is ignored and is treated as AllConditioning.
+    '''
+    def __init__(self, models: list[ModelPatcher]=None, key: str=None):
+        super().__init__(hook_type=EnumHookType.AdditionalModels)
         self.models = models
-        self.append_when_same = True
-
-    def clone(self, subtype: Callable=None):
-        if subtype is None:
-            subtype = type(self)
-        c: AddModelsHook = super().clone(subtype)
-        c.key = self.key
-        c.models = self.models.copy() if self.models else self.models
-        c.append_when_same = self.append_when_same
-        return c
-    # TODO: add functionality
-
-class CallbackHook(Hook):
-    def __init__(self, key: str=None, callback: Callable=None):
-        super().__init__(hook_type=EnumHookType.Callbacks)
         self.key = key
-        self.callback = callback

-    def clone(self, subtype: Callable=None):
-        if subtype is None:
-            subtype = type(self)
-        c: CallbackHook = super().clone(subtype)
+    def clone(self):
+        c: AdditionalModelsHook = super().clone()
+        c.models = self.models.copy() if self.models else self.models
         c.key = self.key
-        c.callback = self.callback
         return c
-    # TODO: add functionality

-class WrapperHook(Hook):
-    def __init__(self, wrappers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None):
-        super().__init__(hook_type=EnumHookType.Wrappers)
-        self.wrappers_dict = wrappers_dict
-
-    def clone(self, subtype: Callable=None):
-        if subtype is None:
-            subtype = type(self)
-        c: WrapperHook = super().clone(subtype)
-        c.wrappers_dict = self.wrappers_dict
-        return c
-
-    def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
-        if not self.should_register(model, model_options, target, registered):
+    def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
+        if not self.should_register(model, model_options, target_dict, registered):
             return False
-        add_model_options = {"transformer_options": self.wrappers_dict}
-        comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
-        registered.append(self)
+        registered.add(self)
         return True

-class SetInjectionsHook(Hook):
-    def __init__(self, key: str=None, injections: list['PatcherInjection']=None):
-        super().__init__(hook_type=EnumHookType.SetInjections)
+class TransformerOptionsHook(Hook):
+    '''
+    Hook responsible for adding wrappers, callbacks, patches, or anything else related to transformer_options.
+    '''
+    def __init__(self, transformers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None,
+                 hook_scope=EnumHookScope.AllConditioning):
+        super().__init__(hook_type=EnumHookType.TransformerOptions)
+        self.transformers_dict = transformers_dict
+        self.hook_scope = hook_scope
+        self._skip_adding = False
+        '''Internal value used to avoid double load of transformer_options when hook_scope is AllConditioning.'''
+
+    def clone(self):
+        c: TransformerOptionsHook = super().clone()
+        c.transformers_dict = self.transformers_dict
+        c._skip_adding = self._skip_adding
+        return c
+
+    def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
+        if not self.should_register(model, model_options, target_dict, registered):
+            return False
+        # NOTE: to_load_options will be used to manually load patches/wrappers/callbacks from hooks
+        self._skip_adding = False
+        if self.hook_scope == EnumHookScope.AllConditioning:
+            add_model_options = {"transformer_options": self.transformers_dict,
+                                 "to_load_options": self.transformers_dict}
+            # skip_adding if included in AllConditioning to avoid double loading
+            self._skip_adding = True
+        else:
+            add_model_options = {"to_load_options": self.transformers_dict}
+        registered.add(self)
+        comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
+        return True
+
+    def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]):
+        if not self._skip_adding:
+            comfy.patcher_extension.merge_nested_dicts(transformer_options, self.transformers_dict, copy_dict1=False)
+
+WrapperHook = TransformerOptionsHook
+'''Only here for backwards compatibility, WrapperHook is identical to TransformerOptionsHook.'''
+
+class InjectionsHook(Hook):
+    def __init__(self, key: str=None, injections: list[PatcherInjection]=None,
+                 hook_scope=EnumHookScope.AllConditioning):
+        super().__init__(hook_type=EnumHookType.Injections)
         self.key = key
         self.injections = injections
+        self.hook_scope = hook_scope

-    def clone(self, subtype: Callable=None):
-        if subtype is None:
-            subtype = type(self)
-        c: SetInjectionsHook = super().clone(subtype)
+    def clone(self):
+        c: InjectionsHook = super().clone()
         c.key = self.key
         c.injections = self.injections.copy() if self.injections else self.injections
         return c

-    def add_hook_injections(self, model: 'ModelPatcher'):
-        # TODO: add functionality
-        pass
+    def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
+        raise NotImplementedError("InjectionsHook is not supported yet in ComfyUI.")

 class HookGroup:
+    '''
+    Stores groups of hooks, and allows them to be queried by type.
+
+    To prevent breaking their functionality, never modify the underlying self.hooks or self._hook_dict vars directly;
+    always use the provided functions on HookGroup.
+    '''
     def __init__(self):
         self.hooks: list[Hook] = []
+        self._hook_dict: dict[EnumHookType, list[Hook]] = {}
+
+    def __len__(self):
+        return len(self.hooks)

     def add(self, hook: Hook):
         if hook not in self.hooks:
             self.hooks.append(hook)
+            self._hook_dict.setdefault(hook.hook_type, []).append(hook)
+
+    def remove(self, hook: Hook):
+        if hook in self.hooks:
+            self.hooks.remove(hook)
+            self._hook_dict[hook.hook_type].remove(hook)
+
+    def get_type(self, hook_type: EnumHookType):
+        return self._hook_dict.get(hook_type, [])

     def contains(self, hook: Hook):
         return hook in self.hooks

+    def is_subset_of(self, other: HookGroup):
+        self_hooks = set(self.hooks)
+        other_hooks = set(other.hooks)
+        return self_hooks.issubset(other_hooks)
+
+    def new_with_common_hooks(self, other: HookGroup):
+        c = HookGroup()
+        for hook in self.hooks:
+            if other.contains(hook):
+                c.add(hook.clone())
+        return c
+
     def clone(self):
         c = HookGroup()
         for hook in self.hooks:
             c.add(hook.clone())
         return c

-    def clone_and_combine(self, other: 'HookGroup'):
+    def clone_and_combine(self, other: HookGroup):
         c = self.clone()
         if other is not None:
             for hook in other.hooks:
                 c.add(hook.clone())
         return c

-    def set_keyframes_on_hooks(self, hook_kf: 'HookKeyframeGroup'):
+    def set_keyframes_on_hooks(self, hook_kf: HookKeyframeGroup):
         if hook_kf is None:
             hook_kf = HookKeyframeGroup()
         else:
@@ -275,36 +347,29 @@ class HookGroup:
         for hook in self.hooks:
             hook.hook_keyframe = hook_kf

-    def get_dict_repr(self):
-        d: dict[EnumHookType, dict[Hook, None]] = {}
-        for hook in self.hooks:
-            with_type = d.setdefault(hook.hook_type, {})
-            with_type[hook] = None
-        return d
-
     def get_hooks_for_clip_schedule(self):
         scheduled_hooks: dict[WeightHook, list[tuple[tuple[float,float], HookKeyframe]]] = {}
-        for hook in self.hooks:
-            # only care about WeightHooks, for now
-            if hook.hook_type == EnumHookType.Weight:
-                hook_schedule = []
-                # if no hook keyframes, assign default value
-                if len(hook.hook_keyframe.keyframes) == 0:
-                    hook_schedule.append(((0.0, 1.0), None))
-                    scheduled_hooks[hook] = hook_schedule
-                    continue
-                # find ranges of values
-                prev_keyframe = hook.hook_keyframe.keyframes[0]
-                for keyframe in hook.hook_keyframe.keyframes:
-                    if keyframe.start_percent > prev_keyframe.start_percent and not math.isclose(keyframe.strength, prev_keyframe.strength):
-                        hook_schedule.append(((prev_keyframe.start_percent, keyframe.start_percent), prev_keyframe))
-                        prev_keyframe = keyframe
-                    elif keyframe.start_percent == prev_keyframe.start_percent:
-                        prev_keyframe = keyframe
-                # create final range, assuming last start_percent was not 1.0
-                if not math.isclose(prev_keyframe.start_percent, 1.0):
-                    hook_schedule.append(((prev_keyframe.start_percent, 1.0), prev_keyframe))
-                scheduled_hooks[hook] = hook_schedule
+        # only care about WeightHooks, for now
+        for hook in self.get_type(EnumHookType.Weight):
+            hook: WeightHook
+            hook_schedule = []
+            # if no hook keyframes, assign default value
+            if len(hook.hook_keyframe.keyframes) == 0:
+                hook_schedule.append(((0.0, 1.0), None))
+                scheduled_hooks[hook] = hook_schedule
+                continue
+            # find ranges of values
+            prev_keyframe = hook.hook_keyframe.keyframes[0]
+            for keyframe in hook.hook_keyframe.keyframes:
+                if keyframe.start_percent > prev_keyframe.start_percent and not math.isclose(keyframe.strength, prev_keyframe.strength):
+                    hook_schedule.append(((prev_keyframe.start_percent, keyframe.start_percent), prev_keyframe))
+                    prev_keyframe = keyframe
+                elif keyframe.start_percent == prev_keyframe.start_percent:
+                    prev_keyframe = keyframe
+            # create final range, assuming last start_percent was not 1.0
+            if not math.isclose(prev_keyframe.start_percent, 1.0):
+                hook_schedule.append(((prev_keyframe.start_percent, 1.0), prev_keyframe))
+            scheduled_hooks[hook] = hook_schedule
         # hooks should not have their schedules in a list of tuples
         all_ranges: list[tuple[float, float]] = []
         for range_kfs in scheduled_hooks.values():
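The scan in `get_hooks_for_clip_schedule` above turns a sorted list of keyframes into `(start, end)` ranges, merging consecutive keyframes whose strengths are effectively equal. A standalone restatement of that range-building loop (the `Keyframe` dataclass and `schedule_ranges` name are illustrative simplifications, not from the source):

```python
import math
from dataclasses import dataclass


@dataclass
class Keyframe:
    start_percent: float
    strength: float


def schedule_ranges(keyframes: list[Keyframe]):
    ranges = []
    prev = keyframes[0]
    for kf in keyframes:
        # Open a new range only when the position advances AND the strength changes.
        if kf.start_percent > prev.start_percent and not math.isclose(kf.strength, prev.strength):
            ranges.append(((prev.start_percent, kf.start_percent), prev))
            prev = kf
        elif kf.start_percent == prev.start_percent:
            prev = kf
    # Close the final range if the last kept keyframe did not start at 1.0.
    if not math.isclose(prev.start_percent, 1.0):
        ranges.append(((prev.start_percent, 1.0), prev))
    return ranges


kfs = [Keyframe(0.0, 1.0), Keyframe(0.5, 0.25), Keyframe(0.8, 0.25)]
print(schedule_ranges(kfs))
# [((0.0, 0.5), Keyframe(start_percent=0.0, strength=1.0)),
#  ((0.5, 1.0), Keyframe(start_percent=0.5, strength=0.25))]
```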
@@ -336,7 +401,7 @@ class HookGroup:
             hook.reset()

     @staticmethod
-    def combine_all_hooks(hooks_list: list['HookGroup'], require_count=0) -> 'HookGroup':
+    def combine_all_hooks(hooks_list: list[HookGroup], require_count=0) -> HookGroup:
         actual: list[HookGroup] = []
         for group in hooks_list:
             if group is not None:
@@ -366,9 +431,15 @@ class HookKeyframe:
         self.start_t = 999999999.9
         self.guarantee_steps = guarantee_steps

+    def get_effective_guarantee_steps(self, max_sigma: torch.Tensor):
+        '''If keyframe starts before current sampling range (max_sigma), treat as 0.'''
+        if self.start_t > max_sigma:
+            return 0
+        return self.guarantee_steps
+
     def clone(self):
         c = HookKeyframe(strength=self.strength,
                          start_percent=self.start_percent, guarantee_steps=self.guarantee_steps)
         c.start_t = self.start_t
         return c
@@ -408,6 +479,12 @@ class HookKeyframeGroup:
         else:
             self._current_keyframe = None

+    def has_guarantee_steps(self):
+        for kf in self.keyframes:
+            if kf.guarantee_steps > 0:
+                return True
+        return False
+
     def has_index(self, index: int):
         return index >= 0 and index < len(self.keyframes)
@@ -421,19 +498,20 @@ class HookKeyframeGroup:
         c._set_first_as_current()
         return c

-    def initialize_timesteps(self, model: 'BaseModel'):
+    def initialize_timesteps(self, model: BaseModel):
         for keyframe in self.keyframes:
             keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent)

-    def prepare_current_keyframe(self, curr_t: float) -> bool:
+    def prepare_current_keyframe(self, curr_t: float, transformer_options: dict[str, torch.Tensor]) -> bool:
         if self.is_empty():
             return False
         if curr_t == self._curr_t:
             return False
+        max_sigma = torch.max(transformer_options["sample_sigmas"])
         prev_index = self._current_index
         prev_strength = self._current_strength
         # if met guaranteed steps, look for next keyframe in case need to switch
-        if self._current_used_steps >= self._current_keyframe.guarantee_steps:
+        if self._current_used_steps >= self._current_keyframe.get_effective_guarantee_steps(max_sigma):
             # if has next index, loop through and see if need to switch
             if self.has_index(self._current_index+1):
                 for i in range(self._current_index+1, len(self.keyframes)):
@@ -446,7 +524,7 @@ class HookKeyframeGroup:
|
|||||||
self._current_keyframe = eval_c
|
self._current_keyframe = eval_c
|
||||||
self._current_used_steps = 0
|
self._current_used_steps = 0
|
||||||
# if guarantee_steps greater than zero, stop searching for other keyframes
|
# if guarantee_steps greater than zero, stop searching for other keyframes
|
||||||
if self._current_keyframe.guarantee_steps > 0:
|
if self._current_keyframe.get_effective_guarantee_steps(max_sigma) > 0:
|
||||||
break
|
break
|
||||||
# if eval_c is outside the percent range, stop looking further
|
# if eval_c is outside the percent range, stop looking further
|
||||||
else: break
|
else: break
|
||||||
@@ -509,6 +587,17 @@ def get_sorted_list_via_attr(objects: list, attr: str) -> list:
|
|||||||
sorted_list.extend(object_list)
|
sorted_list.extend(object_list)
|
||||||
return sorted_list
|
return sorted_list
|
||||||
|
|
||||||
|
def create_transformer_options_from_hooks(model: ModelPatcher, hooks: HookGroup, transformer_options: dict[str]=None):
|
||||||
|
# if no hooks or is not a ModelPatcher for sampling, return empty dict
|
||||||
|
if hooks is None or model.is_clip:
|
||||||
|
return {}
|
||||||
|
if transformer_options is None:
|
||||||
|
transformer_options = {}
|
||||||
|
for hook in hooks.get_type(EnumHookType.TransformerOptions):
|
||||||
|
hook: TransformerOptionsHook
|
||||||
|
hook.on_apply_hooks(model, transformer_options)
|
||||||
|
return transformer_options
|
||||||
|
|
||||||
def create_hook_lora(lora: dict[str, torch.Tensor], strength_model: float, strength_clip: float):
|
def create_hook_lora(lora: dict[str, torch.Tensor], strength_model: float, strength_clip: float):
|
||||||
hook_group = HookGroup()
|
hook_group = HookGroup()
|
||||||
hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip)
|
hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip)
|
||||||
@@ -535,7 +624,7 @@ def create_hook_model_as_lora(weights_model, weights_clip, strength_model: float
|
|||||||
hook.need_weight_init = False
|
hook.need_weight_init = False
|
||||||
return hook_group
|
return hook_group
|
||||||
|
|
||||||
def get_patch_weights_from_model(model: 'ModelPatcher', discard_model_sampling=True):
|
def get_patch_weights_from_model(model: ModelPatcher, discard_model_sampling=True):
|
||||||
if model is None:
|
if model is None:
|
||||||
return None
|
return None
|
||||||
patches_model: dict[str, torch.Tensor] = model.model.state_dict()
|
patches_model: dict[str, torch.Tensor] = model.model.state_dict()
|
||||||
@@ -547,7 +636,7 @@ def get_patch_weights_from_model(model: 'ModelPatcher', discard_model_sampling=T
|
|||||||
return patches_model
|
return patches_model
|
||||||
|
|
||||||
# NOTE: this function shows how to register weight hooks directly on the ModelPatchers
|
# NOTE: this function shows how to register weight hooks directly on the ModelPatchers
|
||||||
def load_hook_lora_for_models(model: 'ModelPatcher', clip: 'CLIP', lora: dict[str, torch.Tensor],
|
def load_hook_lora_for_models(model: ModelPatcher, clip: CLIP, lora: dict[str, torch.Tensor],
|
||||||
strength_model: float, strength_clip: float):
|
strength_model: float, strength_clip: float):
|
||||||
key_map = {}
|
key_map = {}
|
||||||
if model is not None:
|
if model is not None:
|
||||||
@@ -599,24 +688,26 @@ def _combine_hooks_from_values(c_dict: dict[str, HookGroup], values: dict[str, H
|
|||||||
else:
|
else:
|
||||||
c_dict[hooks_key] = cache[hooks_tuple]
|
c_dict[hooks_key] = cache[hooks_tuple]
|
||||||
|
|
||||||
def conditioning_set_values_with_hooks(conditioning, values={}, append_hooks=True):
|
def conditioning_set_values_with_hooks(conditioning, values={}, append_hooks=True,
|
||||||
|
cache: dict[tuple[HookGroup, HookGroup], HookGroup]=None):
|
||||||
c = []
|
c = []
|
||||||
hooks_combine_cache: dict[tuple[HookGroup, HookGroup], HookGroup] = {}
|
if cache is None:
|
||||||
|
cache = {}
|
||||||
for t in conditioning:
|
for t in conditioning:
|
||||||
n = [t[0], t[1].copy()]
|
n = [t[0], t[1].copy()]
|
||||||
for k in values:
|
for k in values:
|
||||||
if append_hooks and k == 'hooks':
|
if append_hooks and k == 'hooks':
|
||||||
_combine_hooks_from_values(n[1], values, hooks_combine_cache)
|
_combine_hooks_from_values(n[1], values, cache)
|
||||||
else:
|
else:
|
||||||
n[1][k] = values[k]
|
n[1][k] = values[k]
|
||||||
c.append(n)
|
c.append(n)
|
||||||
|
|
||||||
return c
|
return c
|
||||||
|
|
||||||
def set_hooks_for_conditioning(cond, hooks: HookGroup, append_hooks=True):
|
def set_hooks_for_conditioning(cond, hooks: HookGroup, append_hooks=True, cache: dict[tuple[HookGroup, HookGroup], HookGroup]=None):
|
||||||
if hooks is None:
|
if hooks is None:
|
||||||
return cond
|
return cond
|
||||||
return conditioning_set_values_with_hooks(cond, {'hooks': hooks}, append_hooks=append_hooks)
|
return conditioning_set_values_with_hooks(cond, {'hooks': hooks}, append_hooks=append_hooks, cache=cache)
|
||||||
|
|
||||||
def set_timesteps_for_conditioning(cond, timestep_range: tuple[float,float]):
|
def set_timesteps_for_conditioning(cond, timestep_range: tuple[float,float]):
|
||||||
if timestep_range is None:
|
if timestep_range is None:
|
||||||
@@ -651,9 +742,10 @@ def combine_with_new_conds(conds: list, new_conds: list):
|
|||||||
def set_conds_props(conds: list, strength: float, set_cond_area: str,
|
def set_conds_props(conds: list, strength: float, set_cond_area: str,
|
||||||
mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
|
mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
|
||||||
final_conds = []
|
final_conds = []
|
||||||
|
cache = {}
|
||||||
for c in conds:
|
for c in conds:
|
||||||
# first, apply lora_hook to conditioning, if provided
|
# first, apply lora_hook to conditioning, if provided
|
||||||
c = set_hooks_for_conditioning(c, hooks, append_hooks=append_hooks)
|
c = set_hooks_for_conditioning(c, hooks, append_hooks=append_hooks, cache=cache)
|
||||||
# next, apply mask to conditioning
|
# next, apply mask to conditioning
|
||||||
c = set_mask_for_conditioning(cond=c, mask=mask, strength=strength, set_cond_area=set_cond_area)
|
c = set_mask_for_conditioning(cond=c, mask=mask, strength=strength, set_cond_area=set_cond_area)
|
||||||
# apply timesteps, if present
|
# apply timesteps, if present
|
||||||
@@ -665,9 +757,10 @@ def set_conds_props(conds: list, strength: float, set_cond_area: str,
|
|||||||
def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1.0, set_cond_area: str="default",
|
def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1.0, set_cond_area: str="default",
|
||||||
mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
|
mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
|
||||||
combined_conds = []
|
combined_conds = []
|
||||||
|
cache = {}
|
||||||
for c, masked_c in zip(conds, new_conds):
|
for c, masked_c in zip(conds, new_conds):
|
||||||
# first, apply lora_hook to new conditioning, if provided
|
# first, apply lora_hook to new conditioning, if provided
|
||||||
masked_c = set_hooks_for_conditioning(masked_c, hooks, append_hooks=append_hooks)
|
masked_c = set_hooks_for_conditioning(masked_c, hooks, append_hooks=append_hooks, cache=cache)
|
||||||
# next, apply mask to new conditioning, if provided
|
# next, apply mask to new conditioning, if provided
|
||||||
masked_c = set_mask_for_conditioning(cond=masked_c, mask=mask, set_cond_area=set_cond_area, strength=strength)
|
masked_c = set_mask_for_conditioning(cond=masked_c, mask=mask, set_cond_area=set_cond_area, strength=strength)
|
||||||
# apply timesteps, if present
|
# apply timesteps, if present
|
||||||
@@ -679,9 +772,10 @@ def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1.
|
|||||||
def set_default_conds_and_combine(conds: list, new_conds: list,
|
def set_default_conds_and_combine(conds: list, new_conds: list,
|
||||||
hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
|
hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
|
||||||
combined_conds = []
|
combined_conds = []
|
||||||
|
cache = {}
|
||||||
for c, new_c in zip(conds, new_conds):
|
for c, new_c in zip(conds, new_conds):
|
||||||
# first, apply lora_hook to new conditioning, if provided
|
# first, apply lora_hook to new conditioning, if provided
|
||||||
new_c = set_hooks_for_conditioning(new_c, hooks, append_hooks=append_hooks)
|
new_c = set_hooks_for_conditioning(new_c, hooks, append_hooks=append_hooks, cache=cache)
|
||||||
# next, add default_cond key to cond so that during sampling, it can be identified
|
# next, add default_cond key to cond so that during sampling, it can be identified
|
||||||
new_c = conditioning_set_values(new_c, {'default': True})
|
new_c = conditioning_set_values(new_c, {'default': True})
|
||||||
# apply timesteps, if present
|
# apply timesteps, if present
|
||||||
|
|||||||
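The cache threaded through set_hooks_for_conditioning and conditioning_set_values_with_hooks memoizes combined hook groups, so an identical (existing, new) pair is merged once per call site instead of once per cond. A minimal sketch of the same memoization pattern, with Group as a hypothetical stand-in for HookGroup:

class Group:
    """Hypothetical stand-in for HookGroup; hashable by identity."""
    def __init__(self, items=()):
        self.items = tuple(items)

    def clone_and_combine(self, other: "Group") -> "Group":
        # the expensive part in the real code: clones every hook in both groups
        return Group(self.items + other.items)

def combine_cached(existing: Group, new: Group, cache: dict) -> Group:
    key = (existing, new)
    if key not in cache:
        cache[key] = existing.clone_and_combine(new)
    return cache[key]

cache = {}
a, b = Group(("lora_a",)), Group(("lora_b",))
assert combine_cached(a, b, cache) is combine_cached(a, b, cache)  # merged once, reused afterwards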
comfy/k_diffusion/sampling.py
@@ -40,7 +40,7 @@ def get_sigmas_polyexponential(n, sigma_min, sigma_max, rho=1., device='cpu'):
 def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
     """Constructs a continuous VP noise schedule."""
     t = torch.linspace(1, eps_s, n, device=device)
-    sigmas = torch.sqrt(torch.exp(beta_d * t ** 2 / 2 + beta_min * t) - 1)
+    sigmas = torch.sqrt(torch.special.expm1(beta_d * t ** 2 / 2 + beta_min * t))
     return append_zero(sigmas)
 
 
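torch.special.expm1 avoids the catastrophic cancellation in exp(x) - 1 when the exponent is small, which matters for get_sigmas_vp near t = eps_s. A quick comparison (plain torch, not from the diff):

import torch

x = torch.tensor([1e-10], dtype=torch.float32)
naive = torch.exp(x) - 1          # rounds to 0.0 in float32
stable = torch.special.expm1(x)   # ~1e-10, accurate
print(naive.item(), stable.item())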
@@ -70,8 +70,14 @@ def get_ancestral_step(sigma_from, sigma_to, eta=1.):
     return sigma_down, sigma_up
 
 
-def default_noise_sampler(x):
-    return lambda sigma, sigma_next: torch.randn_like(x)
+def default_noise_sampler(x, seed=None):
+    if seed is not None:
+        generator = torch.Generator(device=x.device)
+        generator.manual_seed(seed)
+    else:
+        generator = None
+
+    return lambda sigma, sigma_next: torch.randn(x.size(), dtype=x.dtype, layout=x.layout, device=x.device, generator=generator)
 
 
 class BatchedBrownianTree:
@@ -168,7 +174,8 @@ def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
         return sample_euler_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
     """Ancestral sampling with Euler method steps."""
     extra_args = {} if extra_args is None else extra_args
-    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
+    seed = extra_args.get("seed", None)
+    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
     s_in = x.new_ones([x.shape[0]])
     for i in trange(len(sigmas) - 1, disable=disable):
         denoised = model(x, sigmas[i] * s_in, **extra_args)
@@ -189,7 +196,8 @@ def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=No
 def sample_euler_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1., noise_sampler=None):
     """Ancestral sampling with Euler method steps."""
     extra_args = {} if extra_args is None else extra_args
-    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
+    seed = extra_args.get("seed", None)
+    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
     s_in = x.new_ones([x.shape[0]])
     for i in trange(len(sigmas) - 1, disable=disable):
         denoised = model(x, sigmas[i] * s_in, **extra_args)
@@ -290,7 +298,8 @@ def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
 
     """Ancestral sampling with DPM-Solver second-order steps."""
     extra_args = {} if extra_args is None else extra_args
-    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
+    seed = extra_args.get("seed", None)
+    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
     s_in = x.new_ones([x.shape[0]])
     for i in trange(len(sigmas) - 1, disable=disable):
         denoised = model(x, sigmas[i] * s_in, **extra_args)
@@ -318,7 +327,8 @@ def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
 def sample_dpm_2_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
     """Ancestral sampling with DPM-Solver second-order steps."""
     extra_args = {} if extra_args is None else extra_args
-    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
+    seed = extra_args.get("seed", None)
+    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
     s_in = x.new_ones([x.shape[0]])
     for i in trange(len(sigmas) - 1, disable=disable):
         denoised = model(x, sigmas[i] * s_in, **extra_args)
@@ -465,7 +475,7 @@ class DPMSolver(nn.Module):
         return x_3, eps_cache
 
     def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0., s_noise=1., noise_sampler=None):
-        noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
+        noise_sampler = default_noise_sampler(x, seed=self.extra_args.get("seed", None)) if noise_sampler is None else noise_sampler
         if not t_end > t_start and eta:
             raise ValueError('eta must be 0 for reverse sampling')
 
@@ -504,7 +514,7 @@ class DPMSolver(nn.Module):
         return x
 
     def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None):
-        noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
+        noise_sampler = default_noise_sampler(x, seed=self.extra_args.get("seed", None)) if noise_sampler is None else noise_sampler
         if order not in {2, 3}:
             raise ValueError('order should be 2 or 3')
         forward = t_end > t_start
@@ -591,7 +601,8 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None,
 
     """Ancestral sampling with DPM-Solver++(2S) second-order steps."""
     extra_args = {} if extra_args is None else extra_args
-    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
+    seed = extra_args.get("seed", None)
+    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
     s_in = x.new_ones([x.shape[0]])
     sigma_fn = lambda t: t.neg().exp()
     t_fn = lambda sigma: sigma.log().neg()
@@ -625,7 +636,8 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None,
 def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
     """Ancestral sampling with DPM-Solver++(2S) second-order steps."""
     extra_args = {} if extra_args is None else extra_args
-    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
+    seed = extra_args.get("seed", None)
+    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
     s_in = x.new_ones([x.shape[0]])
     sigma_fn = lambda lbda: (lbda.exp() + 1) ** -1
     lambda_fn = lambda sigma: ((1-sigma)/sigma).log()
@@ -882,7 +894,8 @@ def DDPMSampler_step(x, sigma, sigma_prev, noise, noise_sampler):
 
 def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, step_function=None):
     extra_args = {} if extra_args is None else extra_args
-    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
+    seed = extra_args.get("seed", None)
+    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
     s_in = x.new_ones([x.shape[0]])
 
     for i in trange(len(sigmas) - 1, disable=disable):
@@ -902,7 +915,8 @@ def sample_ddpm(model, x, sigmas, extra_args=None, callback=None, disable=None,
 @torch.no_grad()
 def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
     extra_args = {} if extra_args is None else extra_args
-    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
+    seed = extra_args.get("seed", None)
+    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
     s_in = x.new_ones([x.shape[0]])
     for i in trange(len(sigmas) - 1, disable=disable):
         denoised = model(x, sigmas[i] * s_in, **extra_args)
@@ -1153,7 +1167,8 @@ def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disabl
 def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
     """Ancestral sampling with Euler method steps."""
     extra_args = {} if extra_args is None else extra_args
-    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
+    seed = extra_args.get("seed", None)
+    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
 
     temp = [0]
     def post_cfg_function(args):
@@ -1179,7 +1194,8 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No
 def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
     """Ancestral sampling with DPM-Solver++(2S) second-order steps."""
     extra_args = {} if extra_args is None else extra_args
-    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
+    seed = extra_args.get("seed", None)
+    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
 
     temp = [0]
     def post_cfg_function(args):
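Threading the sampler seed from extra_args into default_noise_sampler makes ancestral noise reproducible regardless of global RNG state. A small demonstration (plain torch; the function mirrors the diff above):

import torch

def default_noise_sampler(x, seed=None):
    if seed is not None:
        generator = torch.Generator(device=x.device)
        generator.manual_seed(seed)
    else:
        generator = None
    return lambda sigma, sigma_next: torch.randn(
        x.size(), dtype=x.dtype, layout=x.layout, device=x.device, generator=generator)

x = torch.zeros(1, 4, 8, 8)
n1 = default_noise_sampler(x, seed=42)(None, None)
torch.manual_seed(0)                       # perturb the global RNG in between
n2 = default_noise_sampler(x, seed=42)(None, None)
assert torch.equal(n1, n2)                 # same seed -> identical draws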
@@ -1249,3 +1265,97 @@ def sample_dpmpp_2m_cfg_pp(model, x, sigmas, extra_args=None, callback=None, dis
             x = denoised + denoised_mix + torch.exp(-h) * x
         old_uncond_denoised = uncond_denoised
     return x
+
+@torch.no_grad()
+def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None, cfg_pp=False):
+    extra_args = {} if extra_args is None else extra_args
+    seed = extra_args.get("seed", None)
+    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
+    s_in = x.new_ones([x.shape[0]])
+    sigma_fn = lambda t: t.neg().exp()
+    t_fn = lambda sigma: sigma.log().neg()
+    phi1_fn = lambda t: torch.expm1(t) / t
+    phi2_fn = lambda t: (phi1_fn(t) - 1.0) / t
+
+    old_denoised = None
+    uncond_denoised = None
+    def post_cfg_function(args):
+        nonlocal uncond_denoised
+        uncond_denoised = args["uncond_denoised"]
+        return args["denoised"]
+
+    if cfg_pp:
+        model_options = extra_args.get("model_options", {}).copy()
+        extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
+
+    for i in trange(len(sigmas) - 1, disable=disable):
+        if s_churn > 0:
+            gamma = min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.0
+            sigma_hat = sigmas[i] * (gamma + 1)
+        else:
+            gamma = 0
+            sigma_hat = sigmas[i]
+
+        if gamma > 0:
+            eps = torch.randn_like(x) * s_noise
+            x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5
+        denoised = model(x, sigma_hat * s_in, **extra_args)
+        if callback is not None:
+            callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigma_hat, "denoised": denoised})
+        if sigmas[i + 1] == 0 or old_denoised is None:
+            # Euler method
+            if cfg_pp:
+                d = to_d(x, sigma_hat, uncond_denoised)
+                x = denoised + d * sigmas[i + 1]
+            else:
+                d = to_d(x, sigma_hat, denoised)
+                dt = sigmas[i + 1] - sigma_hat
+                x = x + d * dt
+        else:
+            # Second order multistep method in https://arxiv.org/pdf/2308.02157
+            t, t_next, t_prev = t_fn(sigmas[i]), t_fn(sigmas[i + 1]), t_fn(sigmas[i - 1])
+            h = t_next - t
+            c2 = (t_prev - t) / h
+
+            phi1_val, phi2_val = phi1_fn(-h), phi2_fn(-h)
+            b1 = torch.nan_to_num(phi1_val - 1.0 / c2 * phi2_val, nan=0.0)
+            b2 = torch.nan_to_num(1.0 / c2 * phi2_val, nan=0.0)
+
+            if cfg_pp:
+                x = x + (denoised - uncond_denoised)
+
+            x = (sigma_fn(t_next) / sigma_fn(t)) * x + h * (b1 * denoised + b2 * old_denoised)
+
+        old_denoised = denoised
+    return x
+
+@torch.no_grad()
+def sample_res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
+    return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise, noise_sampler=noise_sampler, cfg_pp=False)
+
+@torch.no_grad()
+def sample_res_multistep_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
+    return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise, noise_sampler=noise_sampler, cfg_pp=True)
+
+@torch.no_grad()
+def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2.):
+    """Gradient-estimation sampler. Paper: https://openreview.net/pdf?id=o2ND9v0CeK"""
+    extra_args = {} if extra_args is None else extra_args
+    s_in = x.new_ones([x.shape[0]])
+    old_d = None
+
+    for i in trange(len(sigmas) - 1, disable=disable):
+        denoised = model(x, sigmas[i] * s_in, **extra_args)
+        d = to_d(x, sigmas[i], denoised)
+        if callback is not None:
+            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+        dt = sigmas[i + 1] - sigmas[i]
+        if i == 0:
+            # Euler method
+            x = x + d * dt
+        else:
+            # Gradient estimation
+            d_bar = ge_gamma * d + (1 - ge_gamma) * old_d
+            x = x + d_bar * dt
+        old_d = d
+    return x
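sample_gradient_estimation replaces the plain Euler direction with an extrapolated one: with the default ge_gamma = 2, d_bar = 2*d - old_d, a linear extrapolation of the denoising slope across the last two evaluations. A worked one-step check (toy numbers, plain torch):

import torch

ge_gamma = 2.0
old_d = torch.tensor([0.8])   # slope from the previous step
d = torch.tensor([1.0])       # slope at the current step
d_bar = ge_gamma * d + (1 - ge_gamma) * old_d
print(d_bar)                  # tensor([1.2000]): extrapolated past d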
comfy/latent_formats.py
@@ -3,6 +3,7 @@ import torch
 class LatentFormat:
     scale_factor = 1.0
     latent_channels = 4
+    latent_dimensions = 2
     latent_rgb_factors = None
     latent_rgb_factors_bias = None
     taesd_decoder_name = None
@@ -143,6 +144,7 @@ class SD3(LatentFormat):
 
 class StableAudio1(LatentFormat):
     latent_channels = 64
+    latent_dimensions = 1
 
 class Flux(SD3):
     latent_channels = 16
@@ -178,6 +180,7 @@ class Flux(SD3):
 
 class Mochi(LatentFormat):
     latent_channels = 12
+    latent_dimensions = 3
 
     def __init__(self):
         self.scale_factor = 1.0
@@ -219,6 +222,8 @@ class Mochi(LatentFormat):
 
 class LTXV(LatentFormat):
     latent_channels = 128
+    latent_dimensions = 3
+
     def __init__(self):
         self.latent_rgb_factors = [
             [ 1.1202e-02, -6.3815e-04, -1.0021e-02],
@@ -355,6 +360,7 @@ class LTXV(LatentFormat):
 
 class HunyuanVideo(LatentFormat):
     latent_channels = 16
+    latent_dimensions = 3
     scale_factor = 0.476986
     latent_rgb_factors = [
         [-0.0395, -0.0331, 0.0445],
@@ -376,3 +382,28 @@ class HunyuanVideo(LatentFormat):
         ]
 
     latent_rgb_factors_bias = [ 0.0259, -0.0192, -0.0761]
+
+class Cosmos1CV8x8x8(LatentFormat):
+    latent_channels = 16
+    latent_dimensions = 3
+
+    latent_rgb_factors = [
+        [ 0.1817, 0.2284, 0.2423],
+        [-0.0586, -0.0862, -0.3108],
+        [-0.4703, -0.4255, -0.3995],
+        [ 0.0803, 0.1963, 0.1001],
+        [-0.0820, -0.1050, 0.0400],
+        [ 0.2511, 0.3098, 0.2787],
+        [-0.1830, -0.2117, -0.0040],
+        [-0.0621, -0.2187, -0.0939],
+        [ 0.3619, 0.1082, 0.1455],
+        [ 0.3164, 0.3922, 0.2575],
+        [ 0.1152, 0.0231, -0.0462],
+        [-0.1434, -0.3609, -0.3665],
+        [ 0.0635, 0.1471, 0.1680],
+        [-0.3635, -0.1963, -0.3248],
+        [-0.1865, 0.0365, 0.2346],
+        [ 0.0447, 0.0994, 0.0881]
+    ]
+
+    latent_rgb_factors_bias = [-0.1223, -0.1889, -0.1976]
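latent_dimensions records the rank of the non-batch, non-channel part of the latent: 1 for audio (B, C, T), 2 for images (B, C, H, W), 3 for video (B, C, T, H, W). An illustrative helper (not from the diff) showing how a caller can allocate latents generically:

import torch

def empty_latent(latent_format, batch=1, spatial=(64, 64), frames=16, length=1000):
    # Hypothetical helper: picks the tensor rank from latent_dimensions.
    c = latent_format.latent_channels
    if latent_format.latent_dimensions == 1:
        return torch.zeros(batch, c, length)          # audio: (B, C, T)
    if latent_format.latent_dimensions == 2:
        return torch.zeros(batch, c, *spatial)        # image: (B, C, H, W)
    return torch.zeros(batch, c, frames, *spatial)    # video: (B, C, T, H, W)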
comfy/ldm/cosmos/blocks.py (new file, 808 lines)
@@ -0,0 +1,808 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Optional
import logging

import numpy as np
import torch
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from torch import nn

from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
from comfy.ldm.modules.attention import optimized_attention


def apply_rotary_pos_emb(
    t: torch.Tensor,
    freqs: torch.Tensor,
) -> torch.Tensor:
    t_ = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).float()
    t_out = freqs[..., 0] * t_[..., 0] + freqs[..., 1] * t_[..., 1]
    t_out = t_out.movedim(-1, -2).reshape(*t.shape).type_as(t)
    return t_out
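apply_rotary_pos_emb splits the last dimension into (real, imaginary) halves and applies the 2x2 maps stored in freqs; when those maps are pure rotations, every pair keeps its norm. A sketch checking that invariant with the function above (shapes chosen for illustration, freqs built as explicit rotation matrices):

import torch

D = 8
t = torch.randn(3, D)                     # (seq, dim), dim split into D//2 pairs
theta = torch.rand(3, D // 2)             # one rotation angle per pair
cos, sin = theta.cos(), theta.sin()
freqs = torch.stack([torch.stack([cos, -sin], dim=-1),
                     torch.stack([sin,  cos], dim=-1)], dim=-2)  # (3, D//2, 2, 2)

out = apply_rotary_pos_emb(t, freqs)
pairs_in = t.reshape(3, 2, D // 2).movedim(-2, -1)    # (3, D//2, 2) pairs
pairs_out = out.reshape(3, 2, D // 2).movedim(-2, -1)
assert torch.allclose(pairs_in.norm(dim=-1), pairs_out.norm(dim=-1), atol=1e-5)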
def get_normalization(name: str, channels: int, weight_args={}):
    if name == "I":
        return nn.Identity()
    elif name == "R":
        return RMSNorm(channels, elementwise_affine=True, eps=1e-6, **weight_args)
    else:
        raise ValueError(f"Normalization {name} not found")


class BaseAttentionOp(nn.Module):
    def __init__(self):
        super().__init__()


class Attention(nn.Module):
    """
    Generalized attention impl.

    Allowing for both self-attention and cross-attention configurations depending on whether a `context_dim` is provided.
    If `context_dim` is None, self-attention is assumed.

    Parameters:
        query_dim (int): Dimension of each query vector.
        context_dim (int, optional): Dimension of each context vector. If None, self-attention is assumed.
        heads (int, optional): Number of attention heads. Defaults to 8.
        dim_head (int, optional): Dimension of each head. Defaults to 64.
        dropout (float, optional): Dropout rate applied to the output of the attention block. Defaults to 0.0.
        attn_op (BaseAttentionOp, optional): Custom attention operation to be used instead of the default.
        qkv_bias (bool, optional): If True, adds a learnable bias to query, key, and value projections. Defaults to False.
        out_bias (bool, optional): If True, adds a learnable bias to the output projection. Defaults to False.
        qkv_norm (str, optional): A string representing normalization strategies for query, key, and value projections.
            Defaults to "SSI".
        qkv_norm_mode (str, optional): A string representing normalization mode for query, key, and value projections.
            Defaults to 'per_head'. Only support 'per_head'.

    Examples:
        >>> attn = Attention(query_dim=128, context_dim=256, heads=4, dim_head=32, dropout=0.1)
        >>> query = torch.randn(10, 128)   # Batch size of 10
        >>> context = torch.randn(10, 256) # Batch size of 10
        >>> output = attn(query, context)  # Perform the attention operation

    Note:
        https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
    """

    def __init__(
        self,
        query_dim: int,
        context_dim=None,
        heads=8,
        dim_head=64,
        dropout=0.0,
        attn_op: Optional[BaseAttentionOp] = None,
        qkv_bias: bool = False,
        out_bias: bool = False,
        qkv_norm: str = "SSI",
        qkv_norm_mode: str = "per_head",
        backend: str = "transformer_engine",
        qkv_format: str = "bshd",
        weight_args={},
        operations=None,
    ) -> None:
        super().__init__()

        self.is_selfattn = context_dim is None  # self attention

        inner_dim = dim_head * heads
        context_dim = query_dim if context_dim is None else context_dim

        self.heads = heads
        self.dim_head = dim_head
        self.qkv_norm_mode = qkv_norm_mode
        self.qkv_format = qkv_format

        if self.qkv_norm_mode == "per_head":
            norm_dim = dim_head
        else:
            raise ValueError(f"Normalization mode {self.qkv_norm_mode} not found, only support 'per_head'")

        self.backend = backend

        self.to_q = nn.Sequential(
            operations.Linear(query_dim, inner_dim, bias=qkv_bias, **weight_args),
            get_normalization(qkv_norm[0], norm_dim),
        )
        self.to_k = nn.Sequential(
            operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args),
            get_normalization(qkv_norm[1], norm_dim),
        )
        self.to_v = nn.Sequential(
            operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args),
            get_normalization(qkv_norm[2], norm_dim),
        )

        self.to_out = nn.Sequential(
            operations.Linear(inner_dim, query_dim, bias=out_bias, **weight_args),
            nn.Dropout(dropout),
        )

    def cal_qkv(
        self, x, context=None, mask=None, rope_emb=None, **kwargs
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        del kwargs

        """
        self.to_q, self.to_k, self.to_v are nn.Sequential with projection + normalization layers.
        Before 07/24/2024, these modules normalize across all heads.
        After 07/24/2024, to support tensor parallelism and follow the common practice in the community,
        we support to normalize per head.
        To keep the checkpoint compatibility with the previous code,
        we keep the nn.Sequential but call the projection and the normalization layers separately.
        We use a flag `self.qkv_norm_mode` to control the normalization behavior.
        The default value of `self.qkv_norm_mode` is "per_head", which means we normalize per head.
        """
        if self.qkv_norm_mode == "per_head":
            q = self.to_q[0](x)
            context = x if context is None else context
            k = self.to_k[0](context)
            v = self.to_v[0](context)
            q, k, v = map(
                lambda t: rearrange(t, "s b (n c) -> b n s c", n=self.heads, c=self.dim_head),
                (q, k, v),
            )
        else:
            raise ValueError(f"Normalization mode {self.qkv_norm_mode} not found, only support 'per_head'")

        q = self.to_q[1](q)
        k = self.to_k[1](k)
        v = self.to_v[1](v)
        if self.is_selfattn and rope_emb is not None:  # only apply to self-attention!
            # apply_rotary_pos_emb inlined
            q_shape = q.shape
            q = q.reshape(*q.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2)
            q = rope_emb[..., 0] * q[..., 0] + rope_emb[..., 1] * q[..., 1]
            q = q.movedim(-1, -2).reshape(*q_shape).to(x.dtype)

            # apply_rotary_pos_emb inlined
            k_shape = k.shape
            k = k.reshape(*k.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2)
            k = rope_emb[..., 0] * k[..., 0] + rope_emb[..., 1] * k[..., 1]
            k = k.movedim(-1, -2).reshape(*k_shape).to(x.dtype)
        return q, k, v

    def forward(
        self,
        x,
        context=None,
        mask=None,
        rope_emb=None,
        **kwargs,
    ):
        """
        Args:
            x (Tensor): The query tensor of shape [B, Mq, K]
            context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
        """
        q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs)
        out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True)
        del q, k, v
        out = rearrange(out, " b n s c -> s b (n c)")
        return self.to_out(out)
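The qkv_norm string maps one letter per projection through get_normalization: VideoAttn below passes "RRI", i.e. RMSNorm for queries and keys and identity for values, each applied over the per-head channel axis after the rearrange to b n s c. A sketch of that wiring, assuming the comfy RMSNorm import above resolves:

import torch

# "RRI" -> RMSNorm for q, RMSNorm for k, Identity for v (per-head channels).
heads, dim_head = 4, 32
q_norm = get_normalization("R", dim_head)  # acts on the trailing dim_head axis
v_norm = get_normalization("I", dim_head)  # identity: values stay unnormalized

q = torch.randn(2, heads, 16, dim_head)    # (b, n, s, c) layout used in cal_qkv
assert q_norm(q).shape == q.shape
assert torch.equal(v_norm(q), q)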
class FeedForward(nn.Module):
    """
    Transformer FFN with optional gating

    Parameters:
        d_model (int): Dimensionality of input features.
        d_ff (int): Dimensionality of the hidden layer.
        dropout (float, optional): Dropout rate applied after the activation function. Defaults to 0.1.
        activation (callable, optional): The activation function applied after the first linear layer.
            Defaults to nn.ReLU().
        is_gated (bool, optional): If set to True, incorporates gating mechanism to the feed-forward layer.
            Defaults to False.
        bias (bool, optional): If set to True, adds a bias to the linear layers. Defaults to True.

    Example:
        >>> ff = FeedForward(d_model=512, d_ff=2048)
        >>> x = torch.randn(64, 10, 512)  # Example input tensor
        >>> output = ff(x)
        >>> print(output.shape)  # Expected shape: (64, 10, 512)
    """

    def __init__(
        self,
        d_model: int,
        d_ff: int,
        dropout: float = 0.1,
        activation=nn.ReLU(),
        is_gated: bool = False,
        bias: bool = False,
        weight_args={},
        operations=None,
    ) -> None:
        super().__init__()

        self.layer1 = operations.Linear(d_model, d_ff, bias=bias, **weight_args)
        self.layer2 = operations.Linear(d_ff, d_model, bias=bias, **weight_args)

        self.dropout = nn.Dropout(dropout)
        self.activation = activation
        self.is_gated = is_gated
        if is_gated:
            self.linear_gate = operations.Linear(d_model, d_ff, bias=False, **weight_args)

    def forward(self, x: torch.Tensor):
        g = self.activation(self.layer1(x))
        if self.is_gated:
            x = g * self.linear_gate(x)
        else:
            x = g
        assert self.dropout.p == 0.0, "we skip dropout"
        return self.layer2(x)


class GPT2FeedForward(FeedForward):
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1, bias: bool = False, weight_args={}, operations=None):
        super().__init__(
            d_model=d_model,
            d_ff=d_ff,
            dropout=dropout,
            activation=nn.GELU(),
            is_gated=False,
            bias=bias,
            weight_args=weight_args,
            operations=operations,
        )

    def forward(self, x: torch.Tensor):
        assert self.dropout.p == 0.0, "we skip dropout"

        x = self.layer1(x)
        x = self.activation(x)
        x = self.layer2(x)

        return x
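With is_gated=True the hidden state is activation(layer1(x)) * linear_gate(x), the SwiGLU-style gating used in many recent transformers (GPT2FeedForward keeps it off). A standalone sketch of the gated path, with plain nn.Linear in place of operations.Linear and SiLU as an example activation:

import torch
from torch import nn
import torch.nn.functional as F

d_model, d_ff = 512, 2048
layer1 = nn.Linear(d_model, d_ff, bias=False)
linear_gate = nn.Linear(d_model, d_ff, bias=False)
layer2 = nn.Linear(d_ff, d_model, bias=False)

x = torch.randn(4, d_model)
hidden = F.silu(layer1(x)) * linear_gate(x)   # gated hidden state
print(layer2(hidden).shape)                   # torch.Size([4, 512])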
def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


class Timesteps(nn.Module):
    def __init__(self, num_channels):
        super().__init__()
        self.num_channels = num_channels

    def forward(self, timesteps):
        half_dim = self.num_channels // 2
        exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
        exponent = exponent / (half_dim - 0.0)

        emb = torch.exp(exponent)
        emb = timesteps[:, None].float() * emb[None, :]

        sin_emb = torch.sin(emb)
        cos_emb = torch.cos(emb)
        emb = torch.cat([cos_emb, sin_emb], dim=-1)

        return emb


class TimestepEmbedding(nn.Module):
    def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False, weight_args={}, operations=None):
        super().__init__()
        logging.debug(
            f"Using AdaLN LoRA Flag: {use_adaln_lora}. We enable bias if no AdaLN LoRA for backward compatibility."
        )
        self.linear_1 = operations.Linear(in_features, out_features, bias=not use_adaln_lora, **weight_args)
        self.activation = nn.SiLU()
        self.use_adaln_lora = use_adaln_lora
        if use_adaln_lora:
            self.linear_2 = operations.Linear(out_features, 3 * out_features, bias=False, **weight_args)
        else:
            self.linear_2 = operations.Linear(out_features, out_features, bias=True, **weight_args)

    def forward(self, sample: torch.Tensor) -> torch.Tensor:
        emb = self.linear_1(sample)
        emb = self.activation(emb)
        emb = self.linear_2(emb)

        if self.use_adaln_lora:
            adaln_lora_B_3D = emb
            emb_B_D = sample
        else:
            emb_B_D = emb
            adaln_lora_B_3D = None

        return emb_B_D, adaln_lora_B_3D
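Timesteps builds the standard sinusoidal embedding: half_dim geometrically spaced frequencies from 1 down to 1/10000, with the cosine half concatenated before the sine half. A quick check with the class defined above:

import torch

emb_layer = Timesteps(num_channels=256)
t = torch.tensor([0.0, 500.0, 999.0])
emb = emb_layer(t)
print(emb.shape)       # torch.Size([3, 256]): 128 cos features then 128 sin
print(emb[0, :4])      # all 1.0 at t=0, since cos(0) == 1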
class FourierFeatures(nn.Module):
    """
    Implements a layer that generates Fourier features from input tensors, based on randomly sampled
    frequencies and phases. This can help in learning high-frequency functions in low-dimensional problems.

    [B] -> [B, D]

    Parameters:
        num_channels (int): The number of Fourier features to generate.
        bandwidth (float, optional): The scaling factor for the frequency of the Fourier features. Defaults to 1.
        normalize (bool, optional): If set to True, the outputs are scaled by sqrt(2), usually to normalize
            the variance of the features. Defaults to False.

    Example:
        >>> layer = FourierFeatures(num_channels=256, bandwidth=0.5, normalize=True)
        >>> x = torch.randn(10)  # Example input tensor of shape [B]
        >>> output = layer(x)
        >>> print(output.shape)  # Expected shape: (10, 256)
    """

    def __init__(self, num_channels, bandwidth=1, normalize=False):
        super().__init__()
        self.register_buffer("freqs", 2 * np.pi * bandwidth * torch.randn(num_channels), persistent=True)
        self.register_buffer("phases", 2 * np.pi * torch.rand(num_channels), persistent=True)
        self.gain = np.sqrt(2) if normalize else 1

    def forward(self, x, gain: float = 1.0):
        """
        Apply the Fourier feature transformation to the input tensor.

        Args:
            x (torch.Tensor): The input tensor.
            gain (float, optional): An additional gain factor applied during the forward pass. Defaults to 1.

        Returns:
            torch.Tensor: The transformed tensor, with Fourier features applied.
        """
        in_dtype = x.dtype
        x = x.to(torch.float32).ger(self.freqs.to(torch.float32)).add(self.phases.to(torch.float32))
        x = x.cos().mul(self.gain * gain).to(in_dtype)
        return x
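forward is an outer product (ger), so the input must be one-dimensional: each scalar b becomes cos(freq_d * b + phase_d) over D random frequencies, matching the [B] -> [B, D] contract in the docstring. Checked with the class above:

import torch

layer = FourierFeatures(num_channels=64, bandwidth=1.0, normalize=True)
b = torch.randn(10)    # one scalar per batch element: shape [B]
print(layer(b).shape)  # torch.Size([10, 64])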
class PatchEmbed(nn.Module):
    """
    PatchEmbed is a module for embedding patches from an input tensor by applying either 3D or 2D convolutional layers,
    depending on the input dimensionality. This module can process inputs with temporal (video) and spatial (image) dimensions,
    making it suitable for video and image processing tasks. It supports dividing the input into patches
    and embedding each patch into a vector of size `out_channels`.

    Parameters:
    - spatial_patch_size (int): The size of each spatial patch.
    - temporal_patch_size (int): The size of each temporal patch.
    - in_channels (int): Number of input channels. Default: 3.
    - out_channels (int): The dimension of the embedding vector for each patch. Default: 768.
    - bias (bool): If True, adds a learnable bias to the output of the convolutional layers. Default: True.
    """

    def __init__(
        self,
        spatial_patch_size,
        temporal_patch_size,
        in_channels=3,
        out_channels=768,
        bias=True,
        weight_args={},
        operations=None,
    ):
        super().__init__()
        self.spatial_patch_size = spatial_patch_size
        self.temporal_patch_size = temporal_patch_size

        self.proj = nn.Sequential(
            Rearrange(
                "b c (t r) (h m) (w n) -> b t h w (c r m n)",
                r=temporal_patch_size,
                m=spatial_patch_size,
                n=spatial_patch_size,
            ),
            operations.Linear(
                in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=bias, **weight_args
            ),
        )
        self.out = nn.Identity()

    def forward(self, x):
        """
        Forward pass of the PatchEmbed module.

        Parameters:
        - x (torch.Tensor): The input tensor of shape (B, C, T, H, W) where
            B is the batch size,
            C is the number of channels,
            T is the temporal dimension,
            H is the height, and
            W is the width of the input.

        Returns:
        - torch.Tensor: The embedded patches as a tensor, with shape b t h w c.
        """
        assert x.dim() == 5
        _, _, T, H, W = x.shape
        assert H % self.spatial_patch_size == 0 and W % self.spatial_patch_size == 0
        assert T % self.temporal_patch_size == 0
        x = self.proj(x)
        return self.out(x)
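Despite the docstring's mention of convolutional layers, the implementation is a Rearrange plus a single Linear: each temporal_patch_size x spatial_patch_size x spatial_patch_size block is flattened and projected to out_channels. A shape check with hypothetical sizes, using plain torch.nn as the operations module so the sketch is standalone:

import torch
from torch import nn

patcher = PatchEmbed(spatial_patch_size=2, temporal_patch_size=1,
                     in_channels=16, out_channels=256, operations=nn)  # nn.Linear stands in
x = torch.randn(1, 16, 8, 32, 32)   # (B, C, T, H, W)
print(patcher(x).shape)             # torch.Size([1, 8, 16, 16, 256]) = (B, T/pt, H/ps, W/ps, D)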
class FinalLayer(nn.Module):
    """
    The final layer of video DiT.
    """

    def __init__(
        self,
        hidden_size,
        spatial_patch_size,
        temporal_patch_size,
        out_channels,
        use_adaln_lora: bool = False,
        adaln_lora_dim: int = 256,
        weight_args={},
        operations=None,
    ):
        super().__init__()
        self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **weight_args)
        self.linear = operations.Linear(
            hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, **weight_args
        )
        self.hidden_size = hidden_size
        self.n_adaln_chunks = 2
        self.use_adaln_lora = use_adaln_lora
        if use_adaln_lora:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                operations.Linear(hidden_size, adaln_lora_dim, bias=False, **weight_args),
                operations.Linear(adaln_lora_dim, self.n_adaln_chunks * hidden_size, bias=False, **weight_args),
            )
        else:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(), operations.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False, **weight_args)
            )

    def forward(
        self,
        x_BT_HW_D,
        emb_B_D,
        adaln_lora_B_3D: Optional[torch.Tensor] = None,
    ):
        if self.use_adaln_lora:
            assert adaln_lora_B_3D is not None
            shift_B_D, scale_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D[:, : 2 * self.hidden_size]).chunk(
                2, dim=1
            )
        else:
            shift_B_D, scale_B_D = self.adaLN_modulation(emb_B_D).chunk(2, dim=1)

        B = emb_B_D.shape[0]
        T = x_BT_HW_D.shape[0] // B
        shift_BT_D, scale_BT_D = repeat(shift_B_D, "b d -> (b t) d", t=T), repeat(scale_B_D, "b d -> (b t) d", t=T)
        x_BT_HW_D = modulate(self.norm_final(x_BT_HW_D), shift_BT_D, scale_BT_D)

        x_BT_HW_D = self.linear(x_BT_HW_D)
        return x_BT_HW_D
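modulate applies the adaLN transform x * (1 + scale) + shift, broadcasting scale and shift over the token axis, so a zero-initialized modulation head leaves the normalized activations untouched at the start of training. A tiny check using modulate as defined above:

import torch

x = torch.randn(2, 10, 64)                 # (batch, tokens, dim)
shift = torch.zeros(2, 64)
scale = torch.zeros(2, 64)
assert torch.equal(modulate(x, shift, scale), x)  # zero shift/scale = identity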
class VideoAttn(nn.Module):
    """
    Implements video attention with optional cross-attention capabilities.

    This module processes video features while maintaining their spatio-temporal structure. It can perform
    self-attention within the video features or cross-attention with external context features.

    Parameters:
        x_dim (int): Dimension of input feature vectors
        context_dim (Optional[int]): Dimension of context features for cross-attention. None for self-attention
        num_heads (int): Number of attention heads
        bias (bool): Whether to include bias in attention projections. Default: False
        qkv_norm_mode (str): Normalization mode for query/key/value projections. Must be "per_head". Default: "per_head"
        x_format (str): Format of input tensor. Must be "BTHWD". Default: "BTHWD"

    Input shape:
        - x: (T, H, W, B, D) video features
        - context (optional): (M, B, D) context features for cross-attention
        where:
            T: temporal dimension
            H: height
            W: width
            B: batch size
            D: feature dimension
            M: context sequence length
    """

    def __init__(
        self,
        x_dim: int,
        context_dim: Optional[int],
        num_heads: int,
        bias: bool = False,
        qkv_norm_mode: str = "per_head",
        x_format: str = "BTHWD",
        weight_args={},
        operations=None,
    ) -> None:
        super().__init__()
        self.x_format = x_format

        self.attn = Attention(
            x_dim,
            context_dim,
            num_heads,
            x_dim // num_heads,
            qkv_bias=bias,
            qkv_norm="RRI",
            out_bias=bias,
            qkv_norm_mode=qkv_norm_mode,
            qkv_format="sbhd",
            weight_args=weight_args,
            operations=operations,
        )

    def forward(
        self,
        x: torch.Tensor,
        context: Optional[torch.Tensor] = None,
        crossattn_mask: Optional[torch.Tensor] = None,
        rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass for video attention.

        Args:
            x (Tensor): Input tensor of shape (B, T, H, W, D) or (T, H, W, B, D) representing batches of video data.
            context (Tensor): Context tensor of shape (B, M, D) or (M, B, D),
                where M is the sequence length of the context.
            crossattn_mask (Optional[Tensor]): An optional mask for cross-attention mechanisms.
            rope_emb_L_1_1_D (Optional[Tensor]):
                Rotary positional embedding tensor of shape (L, 1, 1, D). L == THW for current video training.

        Returns:
            Tensor: The output tensor with applied attention, maintaining the input shape.
        """

        x_T_H_W_B_D = x
        context_M_B_D = context
        T, H, W, B, D = x_T_H_W_B_D.shape
        x_THW_B_D = rearrange(x_T_H_W_B_D, "t h w b d -> (t h w) b d")
        x_THW_B_D = self.attn(
            x_THW_B_D,
            context_M_B_D,
            crossattn_mask,
            rope_emb=rope_emb_L_1_1_D,
        )
        x_T_H_W_B_D = rearrange(x_THW_B_D, "(t h w) b d -> t h w b d", h=H, w=W)
        return x_T_H_W_B_D
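VideoAttn flattens the (T, H, W) grid into one sequence before attention and restores it afterwards, so attention is global over all T*H*W tokens. The round trip is just two einops rearranges; a standalone shape check:

import torch
from einops import rearrange

T, H, W, B, D = 4, 8, 8, 2, 64            # hypothetical grid sizes
x = torch.randn(T, H, W, B, D)
seq = rearrange(x, "t h w b d -> (t h w) b d")        # (256, 2, 64)
back = rearrange(seq, "(t h w) b d -> t h w b d", h=H, w=W)
assert torch.equal(back, x)               # lossless round trip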
|
def adaln_norm_state(norm_state, x, scale, shift):
|
||||||
|
normalized = norm_state(x)
|
||||||
|
return normalized * (1 + scale) + shift


class DITBuildingBlock(nn.Module):
    """
    A building block for the DiT (Diffusion Transformer) architecture that supports different types of
    attention and MLP operations with adaptive layer normalization.

    Parameters:
        block_type (str): Type of block - one of:
            - "cross_attn"/"ca": Cross-attention
            - "full_attn"/"fa": Full self-attention
            - "mlp"/"ff": MLP/feedforward block
        x_dim (int): Dimension of input features
        context_dim (Optional[int]): Dimension of context features for cross-attention
        num_heads (int): Number of attention heads
        mlp_ratio (float): MLP hidden dimension multiplier. Default: 4.0
        bias (bool): Whether to use bias in layers. Default: False
        mlp_dropout (float): Dropout rate for MLP. Default: 0.0
        qkv_norm_mode (str): QKV normalization mode. Default: "per_head"
        x_format (str): Input tensor format. Default: "BTHWD"
        use_adaln_lora (bool): Whether to use AdaLN-LoRA. Default: False
        adaln_lora_dim (int): Dimension for AdaLN-LoRA. Default: 256
    """

    def __init__(
        self,
        block_type: str,
        x_dim: int,
        context_dim: Optional[int],
        num_heads: int,
        mlp_ratio: float = 4.0,
        bias: bool = False,
        mlp_dropout: float = 0.0,
        qkv_norm_mode: str = "per_head",
        x_format: str = "BTHWD",
        use_adaln_lora: bool = False,
        adaln_lora_dim: int = 256,
        weight_args={},
        operations=None
    ) -> None:
        block_type = block_type.lower()

        super().__init__()
        self.x_format = x_format
        if block_type in ["cross_attn", "ca"]:
            self.block = VideoAttn(
                x_dim,
                context_dim,
                num_heads,
                bias=bias,
                qkv_norm_mode=qkv_norm_mode,
                x_format=self.x_format,
                weight_args=weight_args,
                operations=operations,
            )
        elif block_type in ["full_attn", "fa"]:
            self.block = VideoAttn(
                x_dim, None, num_heads, bias=bias, qkv_norm_mode=qkv_norm_mode, x_format=self.x_format, weight_args=weight_args, operations=operations
            )
        elif block_type in ["mlp", "ff"]:
            self.block = GPT2FeedForward(x_dim, int(x_dim * mlp_ratio), dropout=mlp_dropout, bias=bias, weight_args=weight_args, operations=operations)
        else:
            raise ValueError(f"Unknown block type: {block_type}")

        self.block_type = block_type
        self.use_adaln_lora = use_adaln_lora

        self.norm_state = nn.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6)
        self.n_adaln_chunks = 3
        if use_adaln_lora:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                operations.Linear(x_dim, adaln_lora_dim, bias=False, **weight_args),
                operations.Linear(adaln_lora_dim, self.n_adaln_chunks * x_dim, bias=False, **weight_args),
            )
        else:
            self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, self.n_adaln_chunks * x_dim, bias=False, **weight_args))

    def forward(
        self,
        x: torch.Tensor,
        emb_B_D: torch.Tensor,
        crossattn_emb: torch.Tensor,
        crossattn_mask: Optional[torch.Tensor] = None,
        rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
        adaln_lora_B_3D: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass for dynamically configured blocks with adaptive normalization.

        Args:
            x (Tensor): Input tensor of shape (B, T, H, W, D) or (T, H, W, B, D).
            emb_B_D (Tensor): Embedding tensor for adaptive layer normalization modulation.
            crossattn_emb (Tensor): Tensor for cross-attention blocks.
            crossattn_mask (Optional[Tensor]): Optional mask for cross-attention.
            rope_emb_L_1_1_D (Optional[Tensor]):
                Rotary positional embedding tensor of shape (L, 1, 1, D). L == THW for current video training.

        Returns:
            Tensor: The output tensor after processing through the configured block and adaptive normalization.
        """
        if self.use_adaln_lora:
            shift_B_D, scale_B_D, gate_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D).chunk(
                self.n_adaln_chunks, dim=1
            )
        else:
            shift_B_D, scale_B_D, gate_B_D = self.adaLN_modulation(emb_B_D).chunk(self.n_adaln_chunks, dim=1)

        shift_1_1_1_B_D, scale_1_1_1_B_D, gate_1_1_1_B_D = (
            shift_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0),
            scale_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0),
            gate_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0),
        )

        if self.block_type in ["mlp", "ff"]:
            x = x + gate_1_1_1_B_D * self.block(
                adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D),
            )
        elif self.block_type in ["full_attn", "fa"]:
            x = x + gate_1_1_1_B_D * self.block(
                adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D),
                context=None,
                rope_emb_L_1_1_D=rope_emb_L_1_1_D,
            )
        elif self.block_type in ["cross_attn", "ca"]:
            x = x + gate_1_1_1_B_D * self.block(
                adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D),
                context=crossattn_emb,
                crossattn_mask=crossattn_mask,
                rope_emb_L_1_1_D=rope_emb_L_1_1_D,
            )
        else:
            raise ValueError(f"Unknown block type: {self.block_type}")

        return x
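
# The gated residual pattern above, in miniature: shift/scale/gate come from
# one SiLU+Linear projection of the conditioning vector, chunked in three.
#
#   shift, scale, gate = modulation(emb).chunk(3, dim=1)   # each (B, D)
#   x = x + gate * block(norm(x) * (1 + scale) + shift)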


class GeneralDITTransformerBlock(nn.Module):
    """
    A wrapper module that manages a sequence of DITBuildingBlocks to form a complete transformer layer.
    Each block in the sequence is specified by a block configuration string.

    Parameters:
        x_dim (int): Dimension of input features
        context_dim (int): Dimension of context features for cross-attention blocks
        num_heads (int): Number of attention heads
        block_config (str): String specifying block sequence (e.g. "ca-fa-mlp" for cross-attention,
            full-attention, then MLP)
        mlp_ratio (float): MLP hidden dimension multiplier. Default: 4.0
        x_format (str): Input tensor format. Default: "BTHWD"
        use_adaln_lora (bool): Whether to use AdaLN-LoRA. Default: False
        adaln_lora_dim (int): Dimension for AdaLN-LoRA. Default: 256

    The block_config string uses "-" to separate block types:
        - "ca"/"cross_attn": Cross-attention block
        - "fa"/"full_attn": Full self-attention block
        - "mlp"/"ff": MLP/feedforward block

    Example:
        block_config = "ca-fa-mlp" creates a sequence of:
        1. Cross-attention block
        2. Full self-attention block
        3. MLP block
    """

    def __init__(
        self,
        x_dim: int,
        context_dim: int,
        num_heads: int,
        block_config: str,
        mlp_ratio: float = 4.0,
        x_format: str = "BTHWD",
        use_adaln_lora: bool = False,
        adaln_lora_dim: int = 256,
        weight_args={},
        operations=None
    ):
        super().__init__()
        self.blocks = nn.ModuleList()
        self.x_format = x_format
        for block_type in block_config.split("-"):
            self.blocks.append(
                DITBuildingBlock(
                    block_type,
                    x_dim,
                    context_dim,
                    num_heads,
                    mlp_ratio,
                    x_format=self.x_format,
                    use_adaln_lora=use_adaln_lora,
                    adaln_lora_dim=adaln_lora_dim,
                    weight_args=weight_args,
                    operations=operations,
                )
            )

    def forward(
        self,
        x: torch.Tensor,
        emb_B_D: torch.Tensor,
        crossattn_emb: torch.Tensor,
        crossattn_mask: Optional[torch.Tensor] = None,
        rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
        adaln_lora_B_3D: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        for block in self.blocks:
            x = block(
                x,
                emb_B_D,
                crossattn_emb,
                crossattn_mask,
                rope_emb_L_1_1_D=rope_emb_L_1_1_D,
                adaln_lora_B_3D=adaln_lora_B_3D,
            )
        return x
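
# The config string maps one-to-one onto sub-blocks. A hypothetical sketch:
#
#   layer = GeneralDITTransformerBlock(
#       x_dim=768, context_dim=1024, num_heads=16,
#       block_config="FA-CA-MLP",  # lowercased per block inside DITBuildingBlock
#       operations=comfy.ops.disable_weight_init,
#   )
#   # -> layer.blocks holds [full self-attention, cross-attention, feedforward]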
1041
comfy/ldm/cosmos/cosmos_tokenizer/layers3d.py
Normal file
File diff suppressed because it is too large
377
comfy/ldm/cosmos/cosmos_tokenizer/patching.py
Normal file
@@ -0,0 +1,377 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The patcher and unpatcher implementation for 2D and 3D data.

The idea of the Haar wavelet is to compute the LL, LH, HL, HH components as two 1D convolutions:
one on the rows and one on the columns.
For example, for a 1D signal [a, b], the low-freq component is [a + b] / 2 and the high-freq component is [a - b] / 2.
We can use a 1D convolution with kernel [1, 1] and stride 2 to represent the L component.
For the H component, we can use a 1D convolution with kernel [1, -1] and stride 2.
In principle, an additional Haar wavelet is typically applied only to the LL component, but here we apply it to all
components, as we need to support downsampling by more than 2x.
For example, 4x downsampling can be done as a 2x Haar followed by an additional 2x Haar, and the shapes would be:
[3, 256, 256] -> [12, 128, 128] -> [48, 64, 64]
"""

import torch
import torch.nn.functional as F
from einops import rearrange

_WAVELETS = {
    "haar": torch.tensor([0.7071067811865476, 0.7071067811865476]),
    "rearrange": torch.tensor([1.0, 1.0]),
}
_PERSISTENT = False


class Patcher(torch.nn.Module):
    """A module to convert image tensors into patches using torch operations.

    The main difference from `class Patching` is that this module implements
    all operations using torch, rather than python or numpy, for efficiency purposes.

    It's bit-wise identical to the Patching module outputs, with the added
    benefit of being torch.jit scriptable.
    """

    def __init__(self, patch_size=1, patch_method="haar"):
        super().__init__()
        self.patch_size = patch_size
        self.patch_method = patch_method
        self.register_buffer(
            "wavelets", _WAVELETS[patch_method], persistent=_PERSISTENT
        )
        self.range = range(int(torch.log2(torch.tensor(self.patch_size)).item()))
        self.register_buffer(
            "_arange",
            torch.arange(_WAVELETS[patch_method].shape[0]),
            persistent=_PERSISTENT,
        )
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        if self.patch_method == "haar":
            return self._haar(x)
        elif self.patch_method == "rearrange":
            return self._arrange(x)
        else:
            raise ValueError("Unknown patch method: " + self.patch_method)

    def _dwt(self, x, mode="reflect", rescale=False):
        dtype = x.dtype
        h = self.wavelets.to(device=x.device)

        n = h.shape[0]
        g = x.shape[1]
        hl = h.flip(0).reshape(1, 1, -1).repeat(g, 1, 1)
        hh = (h * ((-1) ** self._arange.to(device=x.device))).reshape(1, 1, -1).repeat(g, 1, 1)
        hh = hh.to(dtype=dtype)
        hl = hl.to(dtype=dtype)

        x = F.pad(x, pad=(n - 2, n - 1, n - 2, n - 1), mode=mode).to(dtype)
        xl = F.conv2d(x, hl.unsqueeze(2), groups=g, stride=(1, 2))
        xh = F.conv2d(x, hh.unsqueeze(2), groups=g, stride=(1, 2))
        xll = F.conv2d(xl, hl.unsqueeze(3), groups=g, stride=(2, 1))
        xlh = F.conv2d(xl, hh.unsqueeze(3), groups=g, stride=(2, 1))
        xhl = F.conv2d(xh, hl.unsqueeze(3), groups=g, stride=(2, 1))
        xhh = F.conv2d(xh, hh.unsqueeze(3), groups=g, stride=(2, 1))

        out = torch.cat([xll, xlh, xhl, xhh], dim=1)
        if rescale:
            out = out / 2
        return out

    def _haar(self, x):
        for _ in self.range:
            x = self._dwt(x, rescale=True)
        return x

    def _arrange(self, x):
        x = rearrange(
            x,
            "b c (h p1) (w p2) -> b (c p1 p2) h w",
            p1=self.patch_size,
            p2=self.patch_size,
        ).contiguous()
        return x
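
# Shape sketch (hypothetical sizes): each Haar level quadruples the channels
# and halves both spatial dims, so patch_size=4 runs two levels:
#
#   patcher = Patcher(patch_size=4, patch_method="haar")
#   x = torch.randn(1, 3, 256, 256)
#   patcher(x).shape  # -> torch.Size([1, 48, 64, 64])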


class Patcher3D(Patcher):
    """A 3D discrete wavelet transform for video data, expects a 5D tensor, i.e. a batch of videos."""

    def __init__(self, patch_size=1, patch_method="haar"):
        super().__init__(patch_method=patch_method, patch_size=patch_size)
        self.register_buffer(
            "patch_size_buffer",
            patch_size * torch.ones([1], dtype=torch.int32),
            persistent=_PERSISTENT,
        )

    def _dwt(self, x, wavelet, mode="reflect", rescale=False):
        dtype = x.dtype
        h = self.wavelets.to(device=x.device)

        n = h.shape[0]
        g = x.shape[1]
        hl = h.flip(0).reshape(1, 1, -1).repeat(g, 1, 1)
        hh = (h * ((-1) ** self._arange.to(device=x.device))).reshape(1, 1, -1).repeat(g, 1, 1)
        hh = hh.to(dtype=dtype)
        hl = hl.to(dtype=dtype)

        # Handles temporal axis.
        x = F.pad(
            x, pad=(max(0, n - 2), n - 1, n - 2, n - 1, n - 2, n - 1), mode=mode
        ).to(dtype)
        xl = F.conv3d(x, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1))
        xh = F.conv3d(x, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1))

        # Handles spatial axes.
        xll = F.conv3d(xl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
        xlh = F.conv3d(xl, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
        xhl = F.conv3d(xh, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
        xhh = F.conv3d(xh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))

        xlll = F.conv3d(xll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
        xllh = F.conv3d(xll, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
        xlhl = F.conv3d(xlh, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
        xlhh = F.conv3d(xlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
        xhll = F.conv3d(xhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
        xhlh = F.conv3d(xhl, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
        xhhl = F.conv3d(xhh, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
        xhhh = F.conv3d(xhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))

        out = torch.cat([xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh], dim=1)
        if rescale:
            out = out / (2 * torch.sqrt(torch.tensor(2.0)))
        return out

    def _haar(self, x):
        xi, xv = torch.split(x, [1, x.shape[2] - 1], dim=2)
        x = torch.cat([xi.repeat_interleave(self.patch_size, dim=2), xv], dim=2)
        for _ in self.range:
            x = self._dwt(x, "haar", rescale=True)
        return x

    def _arrange(self, x):
        xi, xv = torch.split(x, [1, x.shape[2] - 1], dim=2)
        x = torch.cat([xi.repeat_interleave(self.patch_size, dim=2), xv], dim=2)
        x = rearrange(
            x,
            "b c (t p1) (h p2) (w p3) -> b (c p1 p2 p3) t h w",
            p1=self.patch_size,
            p2=self.patch_size,
            p3=self.patch_size,
        ).contiguous()
        return x
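
# Causal-padding sketch: the first frame is replicated so T becomes compatible
# with the temporal stride before the wavelet runs (hypothetical sizes):
#
#   patcher = Patcher3D(patch_size=2, patch_method="haar")
#   video = torch.randn(1, 3, 9, 64, 64)   # (B, C, T, H, W), T = 1 + 8
#   patcher(video).shape                   # -> torch.Size([1, 24, 5, 32, 32])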


class UnPatcher(torch.nn.Module):
    """A module to convert patches into image tensors using torch operations.

    The main difference from `class Unpatching` is that this module implements
    all operations using torch, rather than python or numpy, for efficiency purposes.

    It's bit-wise identical to the Unpatching module outputs, with the added
    benefit of being torch.jit scriptable.
    """

    def __init__(self, patch_size=1, patch_method="haar"):
        super().__init__()
        self.patch_size = patch_size
        self.patch_method = patch_method
        self.register_buffer(
            "wavelets", _WAVELETS[patch_method], persistent=_PERSISTENT
        )
        self.range = range(int(torch.log2(torch.tensor(self.patch_size)).item()))
        self.register_buffer(
            "_arange",
            torch.arange(_WAVELETS[patch_method].shape[0]),
            persistent=_PERSISTENT,
        )
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        if self.patch_method == "haar":
            return self._ihaar(x)
        elif self.patch_method == "rearrange":
            return self._iarrange(x)
        else:
            raise ValueError("Unknown patch method: " + self.patch_method)

    def _idwt(self, x, wavelet="haar", mode="reflect", rescale=False):
        dtype = x.dtype
        h = self.wavelets.to(device=x.device)
        n = h.shape[0]

        g = x.shape[1] // 4
        hl = h.flip([0]).reshape(1, 1, -1).repeat([g, 1, 1])
        hh = (h * ((-1) ** self._arange.to(device=x.device))).reshape(1, 1, -1).repeat(g, 1, 1)
        hh = hh.to(dtype=dtype)
        hl = hl.to(dtype=dtype)

        xll, xlh, xhl, xhh = torch.chunk(x.to(dtype), 4, dim=1)

        # Inverse transform.
        yl = torch.nn.functional.conv_transpose2d(
            xll, hl.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)
        )
        yl += torch.nn.functional.conv_transpose2d(
            xlh, hh.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)
        )
        yh = torch.nn.functional.conv_transpose2d(
            xhl, hl.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)
        )
        yh += torch.nn.functional.conv_transpose2d(
            xhh, hh.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)
        )
        y = torch.nn.functional.conv_transpose2d(
            yl, hl.unsqueeze(2), groups=g, stride=(1, 2), padding=(0, n - 2)
        )
        y += torch.nn.functional.conv_transpose2d(
            yh, hh.unsqueeze(2), groups=g, stride=(1, 2), padding=(0, n - 2)
        )

        if rescale:
            y = y * 2
        return y

    def _ihaar(self, x):
        for _ in self.range:
            x = self._idwt(x, "haar", rescale=True)
        return x

    def _iarrange(self, x):
        x = rearrange(
            x,
            "b (c p1 p2) h w -> b c (h p1) (w p2)",
            p1=self.patch_size,
            p2=self.patch_size,
        )
        return x
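
# Round-trip sketch: UnPatcher is the inverse of Patcher, up to floating-point
# error (hypothetical sizes):
#
#   p, u = Patcher(patch_size=2), UnPatcher(patch_size=2)
#   x = torch.randn(1, 3, 64, 64)
#   assert torch.allclose(u(p(x)), x, atol=1e-5)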


class UnPatcher3D(UnPatcher):
    """A 3D inverse discrete wavelet transform for video wavelet decompositions."""

    def __init__(self, patch_size=1, patch_method="haar"):
        super().__init__(patch_method=patch_method, patch_size=patch_size)

    def _idwt(self, x, wavelet="haar", mode="reflect", rescale=False):
        dtype = x.dtype
        h = self.wavelets.to(device=x.device)

        g = x.shape[1] // 8  # split into 8 spatio-temporal filtered tensors.
        hl = h.flip([0]).reshape(1, 1, -1).repeat([g, 1, 1])
        hh = (h * ((-1) ** self._arange.to(device=x.device))).reshape(1, 1, -1).repeat(g, 1, 1)
        hl = hl.to(dtype=dtype)
        hh = hh.to(dtype=dtype)

        xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh = torch.chunk(x, 8, dim=1)
        del x

        # Handles width transposed convolutions.
        xll = F.conv_transpose3d(
            xlll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
        )
        del xlll

        xll += F.conv_transpose3d(
            xllh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
        )
        del xllh

        xlh = F.conv_transpose3d(
            xlhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
        )
        del xlhl

        xlh += F.conv_transpose3d(
            xlhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
        )
        del xlhh

        xhl = F.conv_transpose3d(
            xhll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
        )
        del xhll

        xhl += F.conv_transpose3d(
            xhlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
        )
        del xhlh

        xhh = F.conv_transpose3d(
            xhhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
        )
        del xhhl

        xhh += F.conv_transpose3d(
            xhhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
        )
        del xhhh

        # Handles height transposed convolutions.
        xl = F.conv_transpose3d(
            xll, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
        )
        del xll

        xl += F.conv_transpose3d(
            xlh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
        )
        del xlh

        xh = F.conv_transpose3d(
            xhl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
        )
        del xhl

        xh += F.conv_transpose3d(
            xhh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
        )
        del xhh

        # Handles time axis transposed convolutions.
        x = F.conv_transpose3d(
            xl, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)
        )
        del xl

        x += F.conv_transpose3d(
            xh, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)
        )

        if rescale:
            x = x * (2 * torch.sqrt(torch.tensor(2.0)))
        return x

    def _ihaar(self, x):
        for _ in self.range:
            x = self._idwt(x, "haar", rescale=True)
        x = x[:, :, self.patch_size - 1 :, ...]
        return x

    def _iarrange(self, x):
        x = rearrange(
            x,
            "b (c p1 p2 p3) t h w -> b c (t p1) (h p2) (w p3)",
            p1=self.patch_size,
            p2=self.patch_size,
            p3=self.patch_size,
        )
        x = x[:, :, self.patch_size - 1 :, ...]
        return x
112
comfy/ldm/cosmos/cosmos_tokenizer/utils.py
Normal file
@@ -0,0 +1,112 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared utilities for the networks module."""

from typing import Any

import torch
from einops import rearrange


import comfy.ops
ops = comfy.ops.disable_weight_init


def time2batch(x: torch.Tensor) -> tuple[torch.Tensor, int]:
    batch_size = x.shape[0]
    return rearrange(x, "b c t h w -> (b t) c h w"), batch_size


def batch2time(x: torch.Tensor, batch_size: int) -> torch.Tensor:
    return rearrange(x, "(b t) c h w -> b c t h w", b=batch_size)
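
# Sketch of the fold/unfold pair: time is absorbed into the batch so 2D ops
# can run per-frame, then restored (hypothetical sizes):
#
#   v = torch.randn(2, 8, 5, 32, 32)        # (B, C, T, H, W)
#   frames, b = time2batch(v)               # (10, 8, 32, 32)
#   assert batch2time(frames, b).shape == v.shape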


def space2batch(x: torch.Tensor) -> tuple[torch.Tensor, int]:
    batch_size, height = x.shape[0], x.shape[-2]
    return rearrange(x, "b c t h w -> (b h w) c t"), batch_size, height


def batch2space(x: torch.Tensor, batch_size: int, height: int) -> torch.Tensor:
    return rearrange(x, "(b h w) c t -> b c t h w", b=batch_size, h=height)


def cast_tuple(t: Any, length: int = 1) -> Any:
    return t if isinstance(t, tuple) else ((t,) * length)


def replication_pad(x):
    return torch.cat([x[:, :, :1, ...], x], dim=2)


def divisible_by(num: int, den: int) -> bool:
    return (num % den) == 0


def is_odd(n: int) -> bool:
    return not divisible_by(n, 2)


def nonlinearity(x):
    return x * torch.sigmoid(x)


def Normalize(in_channels, num_groups=32):
    return ops.GroupNorm(
        num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
    )


class CausalNormalize(torch.nn.Module):
    def __init__(self, in_channels, num_groups=1):
        super().__init__()
        self.norm = ops.GroupNorm(
            num_groups=num_groups,
            num_channels=in_channels,
            eps=1e-6,
            affine=True,
        )
        self.num_groups = num_groups

    def forward(self, x):
        # If num_groups != 1, we apply a spatio-temporal groupnorm for backward compatibility purposes.
        # All new models should use num_groups=1; otherwise causality is not guaranteed.
        if self.num_groups == 1:
            x, batch_size = time2batch(x)
            return batch2time(self.norm(x), batch_size)
        return self.norm(x)


def exists(v):
    return v is not None


def default(*args):
    for arg in args:
        if exists(arg):
            return arg
    return None


def round_ste(z: torch.Tensor) -> torch.Tensor:
    """Round with straight through gradients."""
    zhat = z.round()
    return z + (zhat - z).detach()
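
# Straight-through estimator in one line: the forward value is round(z), but
# because the rounding residual is detached, d(round_ste)/dz == 1 everywhere.
#
#   z = torch.tensor([0.4, 1.6], requires_grad=True)
#   round_ste(z).sum().backward()
#   assert torch.equal(z.grad, torch.ones_like(z))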


def log(t, eps=1e-5):
    return t.clamp(min=eps).log()


def entropy(prob):
    return (-prob * log(prob)).sum(dim=-1)
514
comfy/ldm/cosmos/model.py
Normal file
@@ -0,0 +1,514 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
A general implementation of an adaln-modulated ViT-like (DiT) transformer for video processing.
"""

from typing import Optional, Tuple

import torch
from einops import rearrange
from torch import nn
from torchvision import transforms

from enum import Enum
import logging

from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm

from .blocks import (
    FinalLayer,
    GeneralDITTransformerBlock,
    PatchEmbed,
    TimestepEmbedding,
    Timesteps,
)

from .position_embedding import LearnablePosEmbAxis, VideoRopePosition3DEmb


class DataType(Enum):
    IMAGE = "image"
    VIDEO = "video"


class GeneralDIT(nn.Module):
    """
    A general implementation of an adaln-modulated ViT-like (DiT) transformer for video processing.

    Args:
        max_img_h (int): Maximum height of the input images.
        max_img_w (int): Maximum width of the input images.
        max_frames (int): Maximum number of frames in the video sequence.
        in_channels (int): Number of input channels (e.g., RGB channels for color images).
        out_channels (int): Number of output channels.
        patch_spatial (tuple): Spatial resolution of patches for input processing.
        patch_temporal (int): Temporal resolution of patches for input processing.
        concat_padding_mask (bool): If True, includes a mask channel in the input to handle padding.
        block_config (str): Configuration of the transformer block. See Notes for supported block types.
        model_channels (int): Base number of channels used throughout the model.
        num_blocks (int): Number of transformer blocks.
        num_heads (int): Number of heads in the multi-head attention layers.
        mlp_ratio (float): Expansion ratio for MLP blocks.
        block_x_format (str): Format of input tensor for transformer blocks ('BTHWD' or 'THWBD').
        crossattn_emb_channels (int): Number of embedding channels for cross-attention.
        use_cross_attn_mask (bool): Whether to use mask in cross-attention.
        pos_emb_cls (str): Type of positional embeddings.
        pos_emb_learnable (bool): Whether positional embeddings are learnable.
        pos_emb_interpolation (str): Method for interpolating positional embeddings.
        affline_emb_norm (bool): Whether to normalize affine embeddings.
        use_adaln_lora (bool): Whether to use AdaLN-LoRA.
        adaln_lora_dim (int): Dimension for AdaLN-LoRA.
        rope_h_extrapolation_ratio (float): Height extrapolation ratio for RoPE.
        rope_w_extrapolation_ratio (float): Width extrapolation ratio for RoPE.
        rope_t_extrapolation_ratio (float): Temporal extrapolation ratio for RoPE.
        extra_per_block_abs_pos_emb (bool): Whether to use extra per-block absolute positional embeddings.
        extra_per_block_abs_pos_emb_type (str): Type of extra per-block positional embeddings.
        extra_h_extrapolation_ratio (float): Height extrapolation ratio for extra embeddings.
        extra_w_extrapolation_ratio (float): Width extrapolation ratio for extra embeddings.
        extra_t_extrapolation_ratio (float): Temporal extrapolation ratio for extra embeddings.

    Notes:
        Supported block types in block_config:
        * cross_attn, ca: Cross attention
        * full_attn: Full attention on all flattened tokens
        * mlp, ff: Feed forward block
    """

    def __init__(
        self,
        max_img_h: int,
        max_img_w: int,
        max_frames: int,
        in_channels: int,
        out_channels: int,
        patch_spatial: tuple,
        patch_temporal: int,
        concat_padding_mask: bool = True,
        # attention settings
        block_config: str = "FA-CA-MLP",
        model_channels: int = 768,
        num_blocks: int = 10,
        num_heads: int = 16,
        mlp_ratio: float = 4.0,
        block_x_format: str = "BTHWD",
        # cross attention settings
        crossattn_emb_channels: int = 1024,
        use_cross_attn_mask: bool = False,
        # positional embedding settings
        pos_emb_cls: str = "sincos",
        pos_emb_learnable: bool = False,
        pos_emb_interpolation: str = "crop",
        affline_emb_norm: bool = False,  # whether or not to normalize the affine embedding
        use_adaln_lora: bool = False,
        adaln_lora_dim: int = 256,
        rope_h_extrapolation_ratio: float = 1.0,
        rope_w_extrapolation_ratio: float = 1.0,
        rope_t_extrapolation_ratio: float = 1.0,
        extra_per_block_abs_pos_emb: bool = False,
        extra_per_block_abs_pos_emb_type: str = "sincos",
        extra_h_extrapolation_ratio: float = 1.0,
        extra_w_extrapolation_ratio: float = 1.0,
        extra_t_extrapolation_ratio: float = 1.0,
        image_model=None,
        device=None,
        dtype=None,
        operations=None,
    ) -> None:
        super().__init__()
        self.max_img_h = max_img_h
        self.max_img_w = max_img_w
        self.max_frames = max_frames
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.patch_spatial = patch_spatial
        self.patch_temporal = patch_temporal
        self.num_heads = num_heads
        self.num_blocks = num_blocks
        self.model_channels = model_channels
        self.use_cross_attn_mask = use_cross_attn_mask
        self.concat_padding_mask = concat_padding_mask
        # positional embedding settings
        self.pos_emb_cls = pos_emb_cls
        self.pos_emb_learnable = pos_emb_learnable
        self.pos_emb_interpolation = pos_emb_interpolation
        self.affline_emb_norm = affline_emb_norm
        self.rope_h_extrapolation_ratio = rope_h_extrapolation_ratio
        self.rope_w_extrapolation_ratio = rope_w_extrapolation_ratio
        self.rope_t_extrapolation_ratio = rope_t_extrapolation_ratio
        self.extra_per_block_abs_pos_emb = extra_per_block_abs_pos_emb
        self.extra_per_block_abs_pos_emb_type = extra_per_block_abs_pos_emb_type.lower()
        self.extra_h_extrapolation_ratio = extra_h_extrapolation_ratio
        self.extra_w_extrapolation_ratio = extra_w_extrapolation_ratio
        self.extra_t_extrapolation_ratio = extra_t_extrapolation_ratio
        self.dtype = dtype
        weight_args = {"device": device, "dtype": dtype}

        in_channels = in_channels + 1 if concat_padding_mask else in_channels
        self.x_embedder = PatchEmbed(
            spatial_patch_size=patch_spatial,
            temporal_patch_size=patch_temporal,
            in_channels=in_channels,
            out_channels=model_channels,
            bias=False,
            weight_args=weight_args,
            operations=operations,
        )

        self.build_pos_embed(device=device, dtype=dtype)
        self.block_x_format = block_x_format
        self.use_adaln_lora = use_adaln_lora
        self.adaln_lora_dim = adaln_lora_dim
        self.t_embedder = nn.ModuleList(
            [Timesteps(model_channels),
             TimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora, weight_args=weight_args, operations=operations),]
        )

        self.blocks = nn.ModuleDict()

        for idx in range(num_blocks):
            self.blocks[f"block{idx}"] = GeneralDITTransformerBlock(
                x_dim=model_channels,
                context_dim=crossattn_emb_channels,
                num_heads=num_heads,
                block_config=block_config,
                mlp_ratio=mlp_ratio,
                x_format=self.block_x_format,
                use_adaln_lora=use_adaln_lora,
                adaln_lora_dim=adaln_lora_dim,
                weight_args=weight_args,
                operations=operations,
            )

        if self.affline_emb_norm:
            logging.debug("Building affine embedding normalization layer")
            self.affline_norm = RMSNorm(model_channels, elementwise_affine=True, eps=1e-6)
        else:
            self.affline_norm = nn.Identity()

        self.final_layer = FinalLayer(
            hidden_size=self.model_channels,
            spatial_patch_size=self.patch_spatial,
            temporal_patch_size=self.patch_temporal,
            out_channels=self.out_channels,
            use_adaln_lora=self.use_adaln_lora,
            adaln_lora_dim=self.adaln_lora_dim,
            weight_args=weight_args,
            operations=operations,
        )

    def build_pos_embed(self, device=None, dtype=None):
        if self.pos_emb_cls == "rope3d":
            cls_type = VideoRopePosition3DEmb
        else:
            raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}")

        logging.debug(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}")
        kwargs = dict(
            model_channels=self.model_channels,
            len_h=self.max_img_h // self.patch_spatial,
            len_w=self.max_img_w // self.patch_spatial,
            len_t=self.max_frames // self.patch_temporal,
            is_learnable=self.pos_emb_learnable,
            interpolation=self.pos_emb_interpolation,
            head_dim=self.model_channels // self.num_heads,
            h_extrapolation_ratio=self.rope_h_extrapolation_ratio,
            w_extrapolation_ratio=self.rope_w_extrapolation_ratio,
            t_extrapolation_ratio=self.rope_t_extrapolation_ratio,
            device=device,
        )
        self.pos_embedder = cls_type(
            **kwargs,
        )

        if self.extra_per_block_abs_pos_emb:
            assert self.extra_per_block_abs_pos_emb_type in [
                "learnable",
            ], f"Unknown extra_per_block_abs_pos_emb_type {self.extra_per_block_abs_pos_emb_type}"
            kwargs["h_extrapolation_ratio"] = self.extra_h_extrapolation_ratio
            kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio
            kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio
            kwargs["device"] = device
            kwargs["dtype"] = dtype
            self.extra_pos_embedder = LearnablePosEmbAxis(
                **kwargs,
            )
    def prepare_embedded_sequence(
        self,
        x_B_C_T_H_W: torch.Tensor,
        fps: Optional[torch.Tensor] = None,
        padding_mask: Optional[torch.Tensor] = None,
        latent_condition: Optional[torch.Tensor] = None,
        latent_condition_sigma: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Prepares an embedded sequence tensor by applying positional embeddings and handling padding masks.

        Args:
            x_B_C_T_H_W (torch.Tensor): video
            fps (Optional[torch.Tensor]): Frames per second tensor to be used for positional embedding when required.
                If None, a default value (`self.base_fps`) will be used.
            padding_mask (Optional[torch.Tensor]): currently not used

        Returns:
            Tuple[torch.Tensor, Optional[torch.Tensor]]:
                - A tensor of shape (B, T, H, W, D) with the embedded sequence.
                - An optional positional embedding tensor, returned only if the positional embedding class
                  (`self.pos_emb_cls`) includes 'rope'. Otherwise, None.

        Notes:
            - If `self.concat_padding_mask` is True, a padding mask channel is concatenated to the input tensor.
            - The method of applying positional embeddings depends on the value of `self.pos_emb_cls`.
            - If 'rope' is in `self.pos_emb_cls` (case insensitive), the positional embeddings are generated using
              the `self.pos_embedder` with the shape [T, H, W].
            - If "fps_aware" is in `self.pos_emb_cls`, the positional embeddings are generated using the
              `self.pos_embedder` with the fps tensor.
            - Otherwise, the positional embeddings are generated without considering fps.
        """
        if self.concat_padding_mask:
            if padding_mask is not None:
                padding_mask = transforms.functional.resize(
                    padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
                )
            else:
                padding_mask = torch.zeros((x_B_C_T_H_W.shape[0], 1, x_B_C_T_H_W.shape[-2], x_B_C_T_H_W.shape[-1]), dtype=x_B_C_T_H_W.dtype, device=x_B_C_T_H_W.device)

            x_B_C_T_H_W = torch.cat(
                [x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1
            )
        x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W)

        if self.extra_per_block_abs_pos_emb:
            extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device, dtype=x_B_C_T_H_W.dtype)
        else:
            extra_pos_emb = None

        if "rope" in self.pos_emb_cls.lower():
            return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device), extra_pos_emb

        if "fps_aware" in self.pos_emb_cls:
            x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device)  # [B, T, H, W, D]
        else:
            x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, device=x_B_C_T_H_W.device)  # [B, T, H, W, D]

        return x_B_T_H_W_D, None, extra_pos_emb
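
    # Padding-mask concat in isolation (hypothetical sizes): a per-pixel 2D
    # mask is broadcast across time and appended as one extra input channel,
    # which is why the patch embedder is built with in_channels + 1 above.
    #
    #   x = torch.randn(1, 16, 8, 44, 80)                # (B, C, T, H, W)
    #   mask = torch.zeros(1, 1, 44, 80)                 # (B, 1, H, W)
    #   x = torch.cat([x, mask.unsqueeze(1).repeat(1, 1, 8, 1, 1)], dim=1)
    #   x.shape                                          # -> (1, 17, 8, 44, 80)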

    def decoder_head(
        self,
        x_B_T_H_W_D: torch.Tensor,
        emb_B_D: torch.Tensor,
        crossattn_emb: torch.Tensor,
        origin_shape: Tuple[int, int, int, int, int],  # [B, C, T, H, W]
        crossattn_mask: Optional[torch.Tensor] = None,
        adaln_lora_B_3D: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        del crossattn_emb, crossattn_mask
        B, C, T_before_patchify, H_before_patchify, W_before_patchify = origin_shape
        x_BT_HW_D = rearrange(x_B_T_H_W_D, "B T H W D -> (B T) (H W) D")
        x_BT_HW_D = self.final_layer(x_BT_HW_D, emb_B_D, adaln_lora_B_3D=adaln_lora_B_3D)
        # This is to ensure x_BT_HW_D has the correct shape because
        # when we merge T, H, W into one dimension, x_BT_HW_D has shape (B * T * H * W, 1*1, D).
        x_BT_HW_D = x_BT_HW_D.view(
            B * T_before_patchify // self.patch_temporal,
            H_before_patchify // self.patch_spatial * W_before_patchify // self.patch_spatial,
            -1,
        )
        x_B_D_T_H_W = rearrange(
            x_BT_HW_D,
            "(B T) (H W) (p1 p2 t C) -> B C (T t) (H p1) (W p2)",
            p1=self.patch_spatial,
            p2=self.patch_spatial,
            H=H_before_patchify // self.patch_spatial,
            W=W_before_patchify // self.patch_spatial,
            t=self.patch_temporal,
            B=B,
        )
        return x_B_D_T_H_W
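
    # Unpatchify shape math for the rearrange above (hypothetical sizes, with
    # patch_spatial=2 and patch_temporal=1): each token's last dim folds back
    # into a (p1, p2, t, C) pixel block.
    #
    #   tokens: (B*T, H*W, p1*p2*t*C) = (8, 880, 64)   # T=8, H=22, W=40, C=16
    #   output: (B, C, T*t, H*p1, W*p2) = (1, 16, 8, 44, 80)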

    def forward_before_blocks(
        self,
        x: torch.Tensor,
        timesteps: torch.Tensor,
        crossattn_emb: torch.Tensor,
        crossattn_mask: Optional[torch.Tensor] = None,
        fps: Optional[torch.Tensor] = None,
        image_size: Optional[torch.Tensor] = None,
        padding_mask: Optional[torch.Tensor] = None,
        scalar_feature: Optional[torch.Tensor] = None,
        data_type: Optional[DataType] = DataType.VIDEO,
        latent_condition: Optional[torch.Tensor] = None,
        latent_condition_sigma: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Args:
            x: (B, C, T, H, W) tensor of spatial-temp inputs
            timesteps: (B, ) tensor of timesteps
            crossattn_emb: (B, N, D) tensor of cross-attention embeddings
            crossattn_mask: (B, N) tensor of cross-attention masks
        """
        del kwargs
        assert isinstance(
            data_type, DataType
        ), f"Expected DataType, got {type(data_type)}. We need discuss this flag later."
        original_shape = x.shape
        x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence(
            x,
            fps=fps,
            padding_mask=padding_mask,
            latent_condition=latent_condition,
            latent_condition_sigma=latent_condition_sigma,
        )
        # logging affline scale information
        affline_scale_log_info = {}

        timesteps_B_D, adaln_lora_B_3D = self.t_embedder[1](self.t_embedder[0](timesteps.flatten()).to(x.dtype))
        affline_emb_B_D = timesteps_B_D
        affline_scale_log_info["timesteps_B_D"] = timesteps_B_D.detach()

        if scalar_feature is not None:
            raise NotImplementedError("Scalar feature is not implemented yet.")

        affline_scale_log_info["affline_emb_B_D"] = affline_emb_B_D.detach()
        affline_emb_B_D = self.affline_norm(affline_emb_B_D)

        if self.use_cross_attn_mask:
            if crossattn_mask is not None and not torch.is_floating_point(crossattn_mask):
                crossattn_mask = (crossattn_mask - 1).to(x.dtype) * torch.finfo(x.dtype).max
            crossattn_mask = crossattn_mask[:, None, None, :]  # .to(dtype=torch.bool)  # [B, 1, 1, length]
        else:
            crossattn_mask = None

        if self.blocks["block0"].x_format == "THWBD":
            x = rearrange(x_B_T_H_W_D, "B T H W D -> T H W B D")
            if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
                extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = rearrange(
                    extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "B T H W D -> T H W B D"
                )
            crossattn_emb = rearrange(crossattn_emb, "B M D -> M B D")

            if crossattn_mask:
                crossattn_mask = rearrange(crossattn_mask, "B M -> M B")

        elif self.blocks["block0"].x_format == "BTHWD":
            x = x_B_T_H_W_D
        else:
            raise ValueError(f"Unknown x_format {self.blocks[0].x_format}")
        output = {
            "x": x,
            "affline_emb_B_D": affline_emb_B_D,
            "crossattn_emb": crossattn_emb,
            "crossattn_mask": crossattn_mask,
            "rope_emb_L_1_1_D": rope_emb_L_1_1_D,
            "adaln_lora_B_3D": adaln_lora_B_3D,
            "original_shape": original_shape,
            "extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
        }
        return output

    def forward(
        self,
        x: torch.Tensor,
        timesteps: torch.Tensor,
        context: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        # crossattn_emb: torch.Tensor,
        # crossattn_mask: Optional[torch.Tensor] = None,
        fps: Optional[torch.Tensor] = None,
        image_size: Optional[torch.Tensor] = None,
        padding_mask: Optional[torch.Tensor] = None,
        scalar_feature: Optional[torch.Tensor] = None,
        data_type: Optional[DataType] = DataType.VIDEO,
        latent_condition: Optional[torch.Tensor] = None,
        latent_condition_sigma: Optional[torch.Tensor] = None,
        condition_video_augment_sigma: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """
        Args:
            x: (B, C, T, H, W) tensor of spatial-temp inputs
            timesteps: (B, ) tensor of timesteps
            crossattn_emb: (B, N, D) tensor of cross-attention embeddings
            crossattn_mask: (B, N) tensor of cross-attention masks
            condition_video_augment_sigma: (B,) used in lvg (long video generation); we add noise with this sigma to
                augment the condition input, and the lvg model conditions on the condition_video_augment_sigma value,
                so it is passed through to the forward_before_blocks function.
        """

        crossattn_emb = context
        crossattn_mask = attention_mask

        inputs = self.forward_before_blocks(
            x=x,
            timesteps=timesteps,
            crossattn_emb=crossattn_emb,
            crossattn_mask=crossattn_mask,
            fps=fps,
            image_size=image_size,
            padding_mask=padding_mask,
            scalar_feature=scalar_feature,
            data_type=data_type,
            latent_condition=latent_condition,
            latent_condition_sigma=latent_condition_sigma,
            condition_video_augment_sigma=condition_video_augment_sigma,
            **kwargs,
        )
        x, affline_emb_B_D, crossattn_emb, crossattn_mask, rope_emb_L_1_1_D, adaln_lora_B_3D, original_shape = (
            inputs["x"],
            inputs["affline_emb_B_D"],
            inputs["crossattn_emb"],
            inputs["crossattn_mask"],
            inputs["rope_emb_L_1_1_D"],
            inputs["adaln_lora_B_3D"],
            inputs["original_shape"],
        )
        extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = inputs["extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D"].to(x.dtype)
        del inputs

        if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
            assert (
                x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
            ), f"{x.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape} {original_shape}"

        for _, block in self.blocks.items():
            assert (
                self.blocks["block0"].x_format == block.x_format
            ), f"First block has x_format {self.blocks[0].x_format}, got {block.x_format}"

            if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
                x += extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D
            x = block(
                x,
                affline_emb_B_D,
                crossattn_emb,
                crossattn_mask,
                rope_emb_L_1_1_D=rope_emb_L_1_1_D,
                adaln_lora_B_3D=adaln_lora_B_3D,
            )

        x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D")

        x_B_D_T_H_W = self.decoder_head(
            x_B_T_H_W_D=x_B_T_H_W_D,
            emb_B_D=affline_emb_B_D,
            crossattn_emb=None,
            origin_shape=original_shape,
            crossattn_mask=None,
            adaln_lora_B_3D=adaln_lora_B_3D,
        )

        return x_B_D_T_H_W
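
    # End-to-end call sketch (hypothetical latent sizes; ComfyUI supplies the
    # operations/dtype plumbing in practice):
    #
    #   model = GeneralDIT(max_img_h=240, max_img_w=240, max_frames=128,
    #                      in_channels=16, out_channels=16,
    #                      patch_spatial=2, patch_temporal=1,
    #                      pos_emb_cls="rope3d",
    #                      operations=comfy.ops.disable_weight_init)
    #   out = model(x, timesteps, context)   # (B, C, T, H, W) -> (B, C, T, H, W)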
208
comfy/ldm/cosmos/position_embedding.py
Normal file
@@ -0,0 +1,208 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional

import torch
from einops import rearrange, repeat
from torch import nn
import math


def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 0) -> torch.Tensor:
    """
    Normalizes the input tensor along specified dimensions such that the average square norm of elements is adjusted.

    Args:
        x (torch.Tensor): The input tensor to normalize.
        dim (list, optional): The dimensions over which to normalize. If None, normalizes over all dimensions except the first.
        eps (float, optional): A small constant to ensure numerical stability during division.

    Returns:
        torch.Tensor: The normalized tensor.
    """
    if dim is None:
        dim = list(range(1, x.ndim))
    norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
    norm = torch.add(eps, norm, alpha=math.sqrt(norm.numel() / x.numel()))
    return x / norm.to(x.dtype)
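
# In effect this divides by eps + ||x|| * sqrt(1/n), where n is the number of
# elements per norm group, so a standard-normal input comes out with roughly
# unit-variance entries:
#
#   x = torch.randn(4, 1024)
#   normalize(x).std()   # ~= 1.0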


class VideoPositionEmb(nn.Module):
    def forward(self, x_B_T_H_W_C: torch.Tensor, fps=Optional[torch.Tensor], device=None, dtype=None) -> torch.Tensor:
        """
        Delegates embedding generation to the generate_embeddings function.
        """
        B_T_H_W_C = x_B_T_H_W_C.shape
        embeddings = self.generate_embeddings(B_T_H_W_C, fps=fps, device=device, dtype=dtype)

        return embeddings

    def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor], device=None):
        raise NotImplementedError


class VideoRopePosition3DEmb(VideoPositionEmb):
    def __init__(
        self,
        *,  # enforce keyword arguments
        head_dim: int,
        len_h: int,
        len_w: int,
        len_t: int,
        base_fps: int = 24,
        h_extrapolation_ratio: float = 1.0,
        w_extrapolation_ratio: float = 1.0,
        t_extrapolation_ratio: float = 1.0,
        device=None,
        **kwargs,  # used for compatibility with other positional embeddings; unused in this class
    ):
        del kwargs
        super().__init__()
        self.register_buffer("seq", torch.arange(max(len_h, len_w, len_t), dtype=torch.float, device=device))
        self.base_fps = base_fps
        self.max_h = len_h
        self.max_w = len_w

        dim = head_dim
        dim_h = dim // 6 * 2
        dim_w = dim_h
        dim_t = dim - 2 * dim_h
        assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}"
        self.register_buffer(
            "dim_spatial_range",
            torch.arange(0, dim_h, 2, device=device)[: (dim_h // 2)].float() / dim_h,
            persistent=False,
        )
        self.register_buffer(
            "dim_temporal_range",
            torch.arange(0, dim_t, 2, device=device)[: (dim_t // 2)].float() / dim_t,
            persistent=False,
        )

        self.h_ntk_factor = h_extrapolation_ratio ** (dim_h / (dim_h - 2))
        self.w_ntk_factor = w_extrapolation_ratio ** (dim_w / (dim_w - 2))
        self.t_ntk_factor = t_extrapolation_ratio ** (dim_t / (dim_t - 2))

    def generate_embeddings(
        self,
        B_T_H_W_C: torch.Size,
        fps: Optional[torch.Tensor] = None,
        h_ntk_factor: Optional[float] = None,
        w_ntk_factor: Optional[float] = None,
        t_ntk_factor: Optional[float] = None,
        device=None,
        dtype=None,
    ):
        """
        Generate embeddings for the given input size.

        Args:
            B_T_H_W_C (torch.Size): Input tensor size (Batch, Time, Height, Width, Channels).
            fps (Optional[torch.Tensor], optional): Frames per second. Defaults to None.
            h_ntk_factor (Optional[float], optional): Height NTK factor. If None, uses self.h_ntk_factor.
            w_ntk_factor (Optional[float], optional): Width NTK factor. If None, uses self.w_ntk_factor.
            t_ntk_factor (Optional[float], optional): Time NTK factor. If None, uses self.t_ntk_factor.

        Returns:
            Not specified in the original code snippet.
        """
        h_ntk_factor = h_ntk_factor if h_ntk_factor is not None else self.h_ntk_factor
        w_ntk_factor = w_ntk_factor if w_ntk_factor is not None else self.w_ntk_factor
        t_ntk_factor = t_ntk_factor if t_ntk_factor is not None else self.t_ntk_factor

        h_theta = 10000.0 * h_ntk_factor
        w_theta = 10000.0 * w_ntk_factor
        t_theta = 10000.0 * t_ntk_factor

        h_spatial_freqs = 1.0 / (h_theta**self.dim_spatial_range.to(device=device))
        w_spatial_freqs = 1.0 / (w_theta**self.dim_spatial_range.to(device=device))
        temporal_freqs = 1.0 / (t_theta**self.dim_temporal_range.to(device=device))

        B, T, H, W, _ = B_T_H_W_C
        uniform_fps = (fps is None) or isinstance(fps, (int, float)) or (fps.min() == fps.max())
        assert (
            uniform_fps or B == 1 or T == 1
        ), "For video batch, batch size should be 1 for non-uniform fps. For image batch, T should be 1"
        assert (
            H <= self.max_h and W <= self.max_w
        ), f"Input dimensions (H={H}, W={W}) exceed the maximum dimensions (max_h={self.max_h}, max_w={self.max_w})"
        half_emb_h = torch.outer(self.seq[:H].to(device=device), h_spatial_freqs)
        half_emb_w = torch.outer(self.seq[:W].to(device=device), w_spatial_freqs)

        # apply sequence scaling in temporal dimension
        if fps is None:  # image case
            half_emb_t = torch.outer(self.seq[:T].to(device=device), temporal_freqs)
        else:
            half_emb_t = torch.outer(self.seq[:T].to(device=device) / fps * self.base_fps, temporal_freqs)

        half_emb_h = torch.stack([torch.cos(half_emb_h), -torch.sin(half_emb_h), torch.sin(half_emb_h), torch.cos(half_emb_h)], dim=-1)
        half_emb_w = torch.stack([torch.cos(half_emb_w), -torch.sin(half_emb_w), torch.sin(half_emb_w), torch.cos(half_emb_w)], dim=-1)
        half_emb_t = torch.stack([torch.cos(half_emb_t), -torch.sin(half_emb_t), torch.sin(half_emb_t), torch.cos(half_emb_t)], dim=-1)

        em_T_H_W_D = torch.cat(
            [
                repeat(half_emb_t, "t d x -> t h w d x", h=H, w=W),
                repeat(half_emb_h, "h d x -> t h w d x", t=T, w=W),
                repeat(half_emb_w, "w d x -> t h w d x", t=T, h=H),
            ],
            dim=-2,
        )

        return rearrange(em_T_H_W_D, "t h w d (i j) -> (t h w) d i j", i=2, j=2).float()
|
||||||
|
|
||||||
|
|
||||||
|
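Editor's note: the tensor returned above holds one 2x2 rotation matrix per (t, h, w) position and per frequency pair. A hedged sketch of how such matrices would rotate a channel-paired query vector; the pe below is a random stand-in, and the real consumer lives in the model's attention code:

import torch

seq_len, dim = 16, 8                         # dim must be even
pe = torch.randn(seq_len, dim // 2, 2, 2)    # stand-in for generate_embeddings output
q = torch.randn(seq_len, dim)

q_pairs = q.reshape(seq_len, dim // 2, 2, 1) # group channels into rotation pairs
q_rot = (pe @ q_pairs).reshape(seq_len, dim) # one 2x2 matmul per pair
print(q_rot.shape)                           # torch.Size([16, 8])
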
class LearnablePosEmbAxis(VideoPositionEmb):
    def __init__(
        self,
        *,  # enforce keyword arguments
        interpolation: str,
        model_channels: int,
        len_h: int,
        len_w: int,
        len_t: int,
        device=None,
        dtype=None,
        **kwargs,
    ):
        """
        Args:
            interpolation (str): only "crop" is currently supported. Ideally, when extrapolation capacity is needed, the frequencies (or other more advanced methods) should be adjusted; those are not implemented yet.
        """
        del kwargs  # unused
        super().__init__()
        self.interpolation = interpolation
        assert self.interpolation in ["crop"], f"Unknown interpolation method {self.interpolation}"

        self.pos_emb_h = nn.Parameter(torch.empty(len_h, model_channels, device=device, dtype=dtype))
        self.pos_emb_w = nn.Parameter(torch.empty(len_w, model_channels, device=device, dtype=dtype))
        self.pos_emb_t = nn.Parameter(torch.empty(len_t, model_channels, device=device, dtype=dtype))

    def generate_embeddings(self, B_T_H_W_C: torch.Size, fps: Optional[torch.Tensor] = None, device=None, dtype=None) -> torch.Tensor:
        B, T, H, W, _ = B_T_H_W_C
        if self.interpolation == "crop":
            emb_h_H = self.pos_emb_h[:H].to(device=device, dtype=dtype)
            emb_w_W = self.pos_emb_w[:W].to(device=device, dtype=dtype)
            emb_t_T = self.pos_emb_t[:T].to(device=device, dtype=dtype)
            emb = (
                repeat(emb_t_T, "t d -> b t h w d", b=B, h=H, w=W)
                + repeat(emb_h_H, "h d -> b t h w d", b=B, t=T, w=W)
                + repeat(emb_w_W, "w d -> b t h w d", b=B, t=T, h=H)
            )
            assert list(emb.shape)[:4] == [B, T, H, W], f"bad shape: {list(emb.shape)[:4]} != {B, T, H, W}"
        else:
            raise ValueError(f"Unknown interpolation method {self.interpolation}")

        return normalize(emb, dim=-1, eps=1e-6)

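Editor's note: a usage sketch for the axis-factorized learnable embedding above, with made-up sizes and assuming the module's helpers (repeat, normalize) are in scope; the parameters are allocated with torch.empty, so they must be initialized before use:

import torch

emb = LearnablePosEmbAxis(interpolation="crop", model_channels=64, len_h=32, len_w=32, len_t=8)
for p in (emb.pos_emb_h, emb.pos_emb_w, emb.pos_emb_t):
    torch.nn.init.normal_(p)                 # torch.empty leaves them uninitialized
x = torch.randn(1, 4, 16, 16, 64)            # (B, T, H, W, C), cropped to the first T/H/W rows
print(emb(x).shape)                          # torch.Size([1, 4, 16, 16, 64])
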
131 comfy/ldm/cosmos/vae.py Normal file
@@ -0,0 +1,131 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The causal continuous video tokenizer with VAE or AE formulation for 3D data."""

import logging
import torch
from torch import nn
from enum import Enum
import math

from .cosmos_tokenizer.layers3d import (
    EncoderFactorized,
    DecoderFactorized,
    CausalConv3d,
)

class IdentityDistribution(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, parameters):
        return parameters, (torch.tensor([0.0]), torch.tensor([0.0]))


class GaussianDistribution(torch.nn.Module):
    def __init__(self, min_logvar: float = -30.0, max_logvar: float = 20.0):
        super().__init__()
        self.min_logvar = min_logvar
        self.max_logvar = max_logvar

    def sample(self, mean, logvar):
        std = torch.exp(0.5 * logvar)
        return mean + std * torch.randn_like(mean)

    def forward(self, parameters):
        mean, logvar = torch.chunk(parameters, 2, dim=1)
        logvar = torch.clamp(logvar, self.min_logvar, self.max_logvar)
        return self.sample(mean, logvar), (mean, logvar)

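Editor's note: GaussianDistribution implements the usual VAE reparameterization trick; a short sketch of the interface, assuming the class above is in scope:

import torch

params = torch.randn(1, 8, 4, 4, 4)          # channel dim holds [mean | logvar]
dist = GaussianDistribution()
z, (mean, logvar) = dist(params)
# z = mean + exp(0.5 * logvar) * noise, so gradients flow through mean and logvar
print(z.shape, mean.shape, logvar.shape)     # each torch.Size([1, 4, 4, 4, 4])
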
class ContinuousFormulation(Enum):
    VAE = GaussianDistribution
    AE = IdentityDistribution


class CausalContinuousVideoTokenizer(nn.Module):
    def __init__(
        self, z_channels: int, z_factor: int, latent_channels: int, **kwargs
    ) -> None:
        super().__init__()
        self.name = kwargs.get("name", "CausalContinuousVideoTokenizer")
        self.latent_channels = latent_channels
        self.sigma_data = 0.5

        # encoder_name = kwargs.get("encoder", Encoder3DType.BASE.name)
        self.encoder = EncoderFactorized(
            z_channels=z_factor * z_channels, **kwargs
        )
        if kwargs.get("temporal_compression", 4) == 4:
            kwargs["channels_mult"] = [2, 4]
        # decoder_name = kwargs.get("decoder", Decoder3DType.BASE.name)
        self.decoder = DecoderFactorized(
            z_channels=z_channels, **kwargs
        )

        self.quant_conv = CausalConv3d(
            z_factor * z_channels,
            z_factor * latent_channels,
            kernel_size=1,
            padding=0,
        )
        self.post_quant_conv = CausalConv3d(
            latent_channels, z_channels, kernel_size=1, padding=0
        )

        # formulation_name = kwargs.get("formulation", ContinuousFormulation.AE.name)
        self.distribution = IdentityDistribution()  # ContinuousFormulation[formulation_name].value()

        num_parameters = sum(param.numel() for param in self.parameters())
        logging.debug(f"model={self.name}, num_parameters={num_parameters:,}")
        logging.debug(
            f"z_channels={z_channels}, latent_channels={self.latent_channels}."
        )

        latent_temporal_chunk = 16
        self.latent_mean = nn.Parameter(torch.zeros([self.latent_channels * latent_temporal_chunk], dtype=torch.float32))
        self.latent_std = nn.Parameter(torch.ones([self.latent_channels * latent_temporal_chunk], dtype=torch.float32))

    def encode(self, x):
        h = self.encoder(x)
        moments = self.quant_conv(h)
        z, posteriors = self.distribution(moments)
        latent_ch = z.shape[1]
        latent_t = z.shape[2]
        in_dtype = z.dtype
        mean = self.latent_mean.view(latent_ch, -1)
        std = self.latent_std.view(latent_ch, -1)

        mean = mean.repeat(1, math.ceil(latent_t / mean.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
        std = std.repeat(1, math.ceil(latent_t / std.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
        return ((z - mean) / std) * self.sigma_data

    def decode(self, z):
        in_dtype = z.dtype
        latent_ch = z.shape[1]
        latent_t = z.shape[2]
        mean = self.latent_mean.view(latent_ch, -1)
        std = self.latent_std.view(latent_ch, -1)

        mean = mean.repeat(1, math.ceil(latent_t / mean.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
        std = std.repeat(1, math.ceil(latent_t / std.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)

        z = z / self.sigma_data
        z = z * std + mean
        z = self.post_quant_conv(z)
        return self.decoder(z)
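Editor's note: encode() whitens the latent with the learned per-channel statistics and rescales by sigma_data; decode() undoes it. A minimal sketch of that round trip on scalars (illustrative values):

import torch

sigma_data = 0.5
mean, std = torch.tensor(0.2), torch.tensor(1.5)
z_raw = torch.randn(4)
z_enc = ((z_raw - mean) / std) * sigma_data  # applied after quant_conv in encode()
z_dec = (z_enc / sigma_data) * std + mean    # applied before post_quant_conv in decode()
print(torch.allclose(z_raw, z_dec))          # True
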
comfy/ldm/flux/layers.py
@@ -230,8 +230,7 @@ class SingleStreamBlock(nn.Module):

     def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None) -> Tensor:
         mod, _ = self.modulation(vec)
-        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
-        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
+        qkv, mlp = torch.split(self.linear1((1 + mod.scale) * self.pre_norm(x) + mod.shift), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

         q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         q, k = self.norm(q, k, v)
comfy/ldm/flux/math.py
@@ -5,8 +5,15 @@ from torch import Tensor
 from comfy.ldm.modules.attention import optimized_attention
 import comfy.model_management


 def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
-    q, k = apply_rope(q, k, pe)
+    q_shape = q.shape
+    k_shape = k.shape
+
+    q = q.float().reshape(*q.shape[:-1], -1, 1, 2)
+    k = k.float().reshape(*k.shape[:-1], -1, 1, 2)
+    q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
+    k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
+
     heads = q.shape[1]
     x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
comfy/ldm/flux/model.py
@@ -109,9 +109,8 @@ class Flux(nn.Module):
         img = self.img_in(img)
         vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype))
         if self.params.guidance_embed:
-            if guidance is None:
-                raise ValueError("Didn't get guidance strength for guidance distilled model.")
-            vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
+            if guidance is not None:
+                vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))

         vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
         txt = self.txt_in(txt)

@@ -186,7 +185,7 @@ class Flux(nn.Module):
         img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
         return img

-    def forward(self, x, timestep, context, y, guidance, control=None, transformer_options={}, **kwargs):
+    def forward(self, x, timestep, context, y, guidance=None, control=None, transformer_options={}, **kwargs):
         bs, c, h, w = x.shape
         patch_size = self.patch_size
         x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
comfy/ldm/hunyuan_video/model.py
@@ -240,9 +240,8 @@ class HunyuanVideo(nn.Module):
         vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])

         if self.params.guidance_embed:
-            if guidance is None:
-                raise ValueError("Didn't get guidance strength for guidance distilled model.")
-            vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
+            if guidance is not None:
+                vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))

         if txt_mask is not None and not torch.is_floating_point(txt_mask):
             txt_mask = (txt_mask - 1).to(img.dtype) * torch.finfo(img.dtype).max

@@ -314,7 +313,7 @@ class HunyuanVideo(nn.Module):
         img = img.reshape(initial_shape)
         return img

-    def forward(self, x, timestep, context, y, guidance, attention_mask=None, control=None, transformer_options={}, **kwargs):
+    def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, control=None, transformer_options={}, **kwargs):
         bs, c, t, h, w = x.shape
         patch_size = self.patch_size
         t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
comfy/ldm/lightricks/model.py
@@ -456,9 +456,8 @@ class LTXVModel(torch.nn.Module):
         x = self.patchify_proj(x)
         timestep = timestep * 1000.0

-        attention_mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1]))
-        attention_mask = attention_mask.masked_fill(attention_mask.to(torch.bool), float("-inf"))  # not sure about this
-        # attention_mask = (context != 0).any(dim=2).to(dtype=x.dtype)
+        if attention_mask is not None and not torch.is_floating_point(attention_mask):
+            attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max

         pe = precompute_freqs_cis(indices_grid, dim=self.inner_dim, out_dtype=x.dtype)
comfy/ldm/modules/attention.py
@@ -89,7 +89,7 @@ class FeedForward(nn.Module):
 def Normalize(in_channels, dtype=None, device=None):
     return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)

-def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
+def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
     attn_precision = get_attn_precision(attn_precision)

     if skip_reshape:

@@ -142,16 +142,23 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
     sim = sim.softmax(dim=-1)

     out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
-    out = (
-        out.unsqueeze(0)
-        .reshape(b, heads, -1, dim_head)
-        .permute(0, 2, 1, 3)
-        .reshape(b, -1, heads * dim_head)
-    )
+    if skip_output_reshape:
+        out = (
+            out.unsqueeze(0)
+            .reshape(b, heads, -1, dim_head)
+        )
+    else:
+        out = (
+            out.unsqueeze(0)
+            .reshape(b, heads, -1, dim_head)
+            .permute(0, 2, 1, 3)
+            .reshape(b, -1, heads * dim_head)
+        )
     return out


-def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False):
+def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
     attn_precision = get_attn_precision(attn_precision)

     if skip_reshape:

@@ -215,11 +222,13 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False):
     )

     hidden_states = hidden_states.to(dtype)
-
-    hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
+    if skip_output_reshape:
+        hidden_states = hidden_states.unflatten(0, (-1, heads))
+    else:
+        hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
     return hidden_states

-def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
+def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
     attn_precision = get_attn_precision(attn_precision)

     if skip_reshape:

@@ -326,12 +335,18 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):

     del q, k, v

-    r1 = (
-        r1.unsqueeze(0)
-        .reshape(b, heads, -1, dim_head)
-        .permute(0, 2, 1, 3)
-        .reshape(b, -1, heads * dim_head)
-    )
+    if skip_output_reshape:
+        r1 = (
+            r1.unsqueeze(0)
+            .reshape(b, heads, -1, dim_head)
+        )
+    else:
+        r1 = (
+            r1.unsqueeze(0)
+            .reshape(b, heads, -1, dim_head)
+            .permute(0, 2, 1, 3)
+            .reshape(b, -1, heads * dim_head)
+        )
     return r1

 BROKEN_XFORMERS = False

@@ -342,7 +357,7 @@ try:
 except:
     pass

-def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
+def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
     b = q.shape[0]
     dim_head = q.shape[-1]
     # check to make sure xformers isn't broken

@@ -395,9 +410,12 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):

     out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)

-    out = (
-        out.reshape(b, -1, heads * dim_head)
-    )
+    if skip_output_reshape:
+        out = out.permute(0, 2, 1, 3)
+    else:
+        out = (
+            out.reshape(b, -1, heads * dim_head)
+        )

     return out

@@ -408,7 +426,7 @@ else:
     SDP_BATCH_LIMIT = 2**31


-def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
+def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
     if skip_reshape:
         b, _, _, dim_head = q.shape
     else:

@@ -429,9 +447,10 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):

     if SDP_BATCH_LIMIT >= b:
         out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
-        out = (
-            out.transpose(1, 2).reshape(b, -1, heads * dim_head)
-        )
+        if not skip_output_reshape:
+            out = (
+                out.transpose(1, 2).reshape(b, -1, heads * dim_head)
+            )
     else:
         out = torch.empty((b, q.shape[2], heads * dim_head), dtype=q.dtype, layout=q.layout, device=q.device)
         for i in range(0, b, SDP_BATCH_LIMIT):

@@ -450,7 +469,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
     return out


-def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
+def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
     if skip_reshape:
         b, _, _, dim_head = q.shape
         tensor_layout="HND"

@@ -473,11 +492,15 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):

     out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
     if tensor_layout == "HND":
-        out = (
-            out.transpose(1, 2).reshape(b, -1, heads * dim_head)
-        )
+        if not skip_output_reshape:
+            out = (
+                out.transpose(1, 2).reshape(b, -1, heads * dim_head)
+            )
     else:
-        out = out.reshape(b, -1, heads * dim_head)
+        if skip_output_reshape:
+            out = out.transpose(1, 2)
+        else:
+            out = out.reshape(b, -1, heads * dim_head)
     return out
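Editor's note: across all of these backends the new skip_output_reshape flag means the same thing: return the attention output in (batch, heads, seq, dim_head) layout instead of collapsing it to (batch, seq, heads * dim_head). A standalone illustration of the two layouts:

import torch

b, heads, seq, dim_head = 2, 4, 16, 8
out = torch.randn(b, heads, seq, dim_head)   # layout produced by SDPA-style kernels

flat = out.transpose(1, 2).reshape(b, -1, heads * dim_head)
print(flat.shape)   # torch.Size([2, 16, 32]), the default output
print(out.shape)    # torch.Size([2, 4, 16, 8]), kept when skip_output_reshape=True
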
comfy/ldm/modules/diffusionmodules/model.py
@@ -293,6 +293,17 @@ def pytorch_attention(q, k, v):
     return out


+def vae_attention():
+    if model_management.xformers_enabled_vae():
+        logging.info("Using xformers attention in VAE")
+        return xformers_attention
+    elif model_management.pytorch_attention_enabled():
+        logging.info("Using pytorch attention in VAE")
+        return pytorch_attention
+    else:
+        logging.info("Using split attention in VAE")
+        return normal_attention
+
 class AttnBlock(nn.Module):
     def __init__(self, in_channels, conv_op=ops.Conv2d):
         super().__init__()

@@ -320,15 +331,7 @@ class AttnBlock(nn.Module):
                                 stride=1,
                                 padding=0)

-        if model_management.xformers_enabled_vae():
-            logging.info("Using xformers attention in VAE")
-            self.optimized_attention = xformers_attention
-        elif model_management.pytorch_attention_enabled():
-            logging.info("Using pytorch attention in VAE")
-            self.optimized_attention = pytorch_attention
-        else:
-            logging.info("Using split attention in VAE")
-            self.optimized_attention = normal_attention
+        self.optimized_attention = vae_attention()

     def forward(self, x):
         h_ = x
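Editor's note: the refactor hoists the backend choice out of AttnBlock.__init__ into the shared vae_attention() helper, so the selection logic and its log line live in one place. A hedged usage sketch (the tensor sizes are made up; the helper and backends come from this module):

import torch

attn_fn = vae_attention()                    # resolves the backend once, logging the choice
q = k = v = torch.randn(1, 64, 32, 32)       # (B, C, H, W) feature maps inside the VAE
print(attn_fn(q, k, v).shape)                # torch.Size([1, 64, 32, 32])
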
comfy/sub_quadratic_attention.py
@@ -17,10 +17,10 @@ import math
 import logging

 try:
-    from typing import Optional, NamedTuple, List, Protocol
+    from typing import Optional, NamedTuple, List
 except ImportError:
     from typing import Optional, NamedTuple, List
     from typing_extensions import Protocol

 from typing import List
comfy/model_base.py
@@ -33,6 +33,7 @@ import comfy.ldm.audio.embedders
 import comfy.ldm.flux.model
 import comfy.ldm.lightricks.model
 import comfy.ldm.hunyuan_video.model
+import comfy.ldm.cosmos.model

 import comfy.model_management
 import comfy.patcher_extension
@@ -147,7 +148,9 @@ class BaseModel(torch.nn.Module):

         xc = xc.to(dtype)
         t = self.model_sampling.timestep(t).float()
-        context = context.to(dtype)
+        if context is not None:
+            context = context.to(dtype)
+
         extra_conds = {}
         for o in kwargs:
             extra = kwargs[o]

@@ -188,9 +191,10 @@ class BaseModel(torch.nn.Module):

         if denoise_mask is not None:
             if len(denoise_mask.shape) == len(noise.shape):
-                denoise_mask = denoise_mask[:,:1]
+                denoise_mask = denoise_mask[:, :1]

-            denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1]))
+            num_dim = noise.ndim - 2
+            denoise_mask = denoise_mask.reshape((-1, 1) + tuple(denoise_mask.shape[-num_dim:]))
             if denoise_mask.shape[-2:] != noise.shape[-2:]:
                 denoise_mask = utils.common_upscale(denoise_mask, noise.shape[-1], noise.shape[-2], "bilinear", "center")
             denoise_mask = utils.resize_to_batch_size(denoise_mask.round(), noise.shape[0])

@@ -200,12 +204,16 @@ class BaseModel(torch.nn.Module):
                 if ck == "mask":
                     cond_concat.append(denoise_mask.to(device))
                 elif ck == "masked_image":
-                    cond_concat.append(concat_latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space
+                    cond_concat.append(concat_latent_image.to(device))  # NOTE: the latent_image should be masked by the mask in pixel space
+                elif ck == "mask_inverted":
+                    cond_concat.append(1.0 - denoise_mask.to(device))
             else:
                 if ck == "mask":
-                    cond_concat.append(torch.ones_like(noise)[:,:1])
+                    cond_concat.append(torch.ones_like(noise)[:, :1])
                 elif ck == "masked_image":
                     cond_concat.append(self.blank_inpaint_image_like(noise))
+                elif ck == "mask_inverted":
+                    cond_concat.append(torch.zeros_like(noise)[:, :1])
             data = torch.cat(cond_concat, dim=1)
             return data
         return None
@@ -293,6 +301,9 @@ class BaseModel(torch.nn.Module):
                 return blank_image
             self.blank_inpaint_image_like = blank_inpaint_image_like

+    def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
+        return self.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(noise.shape) - 1)), noise, latent_image)
+
     def memory_required(self, input_shape):
         if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
             dtype = self.get_dtype()

@@ -540,6 +551,10 @@ class SD_X4Upscaler(BaseModel):

         out['c_concat'] = comfy.conds.CONDNoiseShape(image)
         out['y'] = comfy.conds.CONDRegular(noise_level)
+
+        cross_attn = kwargs.get("cross_attn", None)
+        if cross_attn is not None:
+            out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn)
         return out

 class IP2P:
@@ -797,7 +812,10 @@ class Flux(BaseModel):
             (h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size))
             attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok))
             out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
-        out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 3.5)]))
+
+        guidance = kwargs.get("guidance", 3.5)
+        if guidance is not None:
+            out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
         return out

 class GenmoMochi(BaseModel):
@@ -854,5 +872,35 @@ class HunyuanVideo(BaseModel):
         cross_attn = kwargs.get("cross_attn", None)
         if cross_attn is not None:
             out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
-        out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 6.0)]))
+
+        guidance = kwargs.get("guidance", 6.0)
+        if guidance is not None:
+            out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
         return out
+
+class CosmosVideo(BaseModel):
+    def __init__(self, model_config, model_type=ModelType.EDM, image_to_video=False, device=None):
+        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.cosmos.model.GeneralDIT)
+        self.image_to_video = image_to_video
+        if self.image_to_video:
+            self.concat_keys = ("mask_inverted",)
+
+    def extra_conds(self, **kwargs):
+        out = super().extra_conds(**kwargs)
+        attention_mask = kwargs.get("attention_mask", None)
+        if attention_mask is not None:
+            out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
+        cross_attn = kwargs.get("cross_attn", None)
+        if cross_attn is not None:
+            out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+
+        out['fps'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", None))
+        return out
+
+    def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
+        sigma = sigma.reshape([sigma.shape[0]] + [1] * (len(noise.shape) - 1))
+        sigma_noise_augmentation = 0 #TODO
+        if sigma_noise_augmentation != 0:
+            latent_image = latent_image + noise
+        latent_image = self.model_sampling.calculate_input(torch.tensor([sigma_noise_augmentation], device=latent_image.device, dtype=latent_image.dtype), latent_image)
+        return latent_image * ((sigma ** 2 + self.model_sampling.sigma_data ** 2) ** 0.5)
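Editor's note: CosmosVideo.scale_latent_inpaint runs the latent through the sampler's calculate_input and then rescales by sqrt(sigma^2 + sigma_data^2). A sketch of the arithmetic, assuming the usual EDM preconditioning where calculate_input multiplies by c_in = 1 / sqrt(sigma^2 + sigma_data^2):

import torch

sigma_data = 0.5
sigma = torch.tensor([3.0]).reshape(1, 1, 1, 1, 1)
latent = torch.randn(1, 16, 4, 8, 8)

c_in = 1.0 / (0.0 ** 2 + sigma_data ** 2) ** 0.5          # sigma_noise_augmentation == 0
scaled = latent * c_in * (sigma ** 2 + sigma_data ** 2) ** 0.5
print(scaled.shape)                                        # torch.Size([1, 16, 4, 8, 8])
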
comfy/model_detection.py
@@ -239,6 +239,51 @@ def detect_unet_config(state_dict, key_prefix):
         dit_config["micro_condition"] = False
         return dit_config

+    if '{}blocks.block0.blocks.0.block.attn.to_q.0.weight'.format(key_prefix) in state_dict_keys:
+        dit_config = {}
+        dit_config["image_model"] = "cosmos"
+        dit_config["max_img_h"] = 240
+        dit_config["max_img_w"] = 240
+        dit_config["max_frames"] = 128
+        concat_padding_mask = True
+        dit_config["in_channels"] = (state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[1] // 4) - int(concat_padding_mask)
+        dit_config["out_channels"] = 16
+        dit_config["patch_spatial"] = 2
+        dit_config["patch_temporal"] = 1
+        dit_config["model_channels"] = state_dict['{}blocks.block0.blocks.0.block.attn.to_q.0.weight'.format(key_prefix)].shape[0]
+        dit_config["block_config"] = "FA-CA-MLP"
+        dit_config["concat_padding_mask"] = concat_padding_mask
+        dit_config["pos_emb_cls"] = "rope3d"
+        dit_config["pos_emb_learnable"] = False
+        dit_config["pos_emb_interpolation"] = "crop"
+        dit_config["block_x_format"] = "THWBD"
+        dit_config["affline_emb_norm"] = True
+        dit_config["use_adaln_lora"] = True
+        dit_config["adaln_lora_dim"] = 256
+
+        if dit_config["model_channels"] == 4096:
+            # 7B
+            dit_config["num_blocks"] = 28
+            dit_config["num_heads"] = 32
+            dit_config["extra_per_block_abs_pos_emb"] = True
+            dit_config["rope_h_extrapolation_ratio"] = 1.0
+            dit_config["rope_w_extrapolation_ratio"] = 1.0
+            dit_config["rope_t_extrapolation_ratio"] = 2.0
+            dit_config["extra_per_block_abs_pos_emb_type"] = "learnable"
+        else:  # 5120
+            # 14B
+            dit_config["num_blocks"] = 36
+            dit_config["num_heads"] = 40
+            dit_config["extra_per_block_abs_pos_emb"] = True
+            dit_config["rope_h_extrapolation_ratio"] = 2.0
+            dit_config["rope_w_extrapolation_ratio"] = 2.0
+            dit_config["rope_t_extrapolation_ratio"] = 2.0
+            dit_config["extra_h_extrapolation_ratio"] = 2.0
+            dit_config["extra_w_extrapolation_ratio"] = 2.0
+            dit_config["extra_t_extrapolation_ratio"] = 2.0
+            dit_config["extra_per_block_abs_pos_emb_type"] = "learnable"
+        return dit_config
+
 if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
     return None
@@ -393,6 +438,7 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False):
 def unet_prefix_from_state_dict(state_dict):
     candidates = ["model.diffusion_model.", #ldm/sgm models
                   "model.model.", #audio models
+                  "net.", #cosmos
                   ]
     counts = {k: 0 for k in candidates}
     for k in state_dict:
comfy/model_management.py
@@ -86,6 +86,13 @@ try:
 except:
     pass

+try:
+    import torch_npu  # noqa: F401
+    _ = torch.npu.device_count()
+    npu_available = torch.npu.is_available()
+except:
+    npu_available = False
+
 if args.cpu:
     cpu_state = CPUState.CPU

@@ -97,6 +104,12 @@ def is_intel_xpu():
         return True
     return False

+def is_ascend_npu():
+    global npu_available
+    if npu_available:
+        return True
+    return False
+
 def get_torch_device():
     global directml_enabled
     global cpu_state

@@ -110,6 +123,8 @@ def get_torch_device():
     else:
         if is_intel_xpu():
             return torch.device("xpu", torch.xpu.current_device())
+        elif is_ascend_npu():
+            return torch.device("npu", torch.npu.current_device())
         else:
             return torch.device(torch.cuda.current_device())

@@ -130,6 +145,12 @@ def get_total_memory(dev=None, torch_total_too=False):
             mem_reserved = stats['reserved_bytes.all.current']
             mem_total_torch = mem_reserved
             mem_total = torch.xpu.get_device_properties(dev).total_memory
+        elif is_ascend_npu():
+            stats = torch.npu.memory_stats(dev)
+            mem_reserved = stats['reserved_bytes.all.current']
+            _, mem_total_npu = torch.npu.mem_get_info(dev)
+            mem_total_torch = mem_reserved
+            mem_total = mem_total_npu
         else:
             stats = torch.cuda.memory_stats(dev)
             mem_reserved = stats['reserved_bytes.all.current']

@@ -209,7 +230,7 @@ try:
     if int(torch_version[0]) >= 2:
         if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
             ENABLE_PYTORCH_ATTENTION = True
-    if is_intel_xpu():
+    if is_intel_xpu() or is_ascend_npu():
         if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
             ENABLE_PYTORCH_ATTENTION = True
 except:

@@ -274,6 +295,8 @@ def get_torch_device_name(device):
         return "{}".format(device.type)
     elif is_intel_xpu():
         return "{} {}".format(device, torch.xpu.get_device_name(device))
+    elif is_ascend_npu():
+        return "{} {}".format(device, torch.npu.get_device_name(device))
     else:
         return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))

@@ -860,6 +883,8 @@ def xformers_enabled():
         return False
     if is_intel_xpu():
         return False
+    if is_ascend_npu():
+        return False
     if directml_enabled:
         return False
     return XFORMERS_IS_AVAILABLE

@@ -884,6 +909,8 @@ def pytorch_attention_flash_attention():
             return True
         if is_intel_xpu():
             return True
+        if is_ascend_npu():
+            return True
     return False

 def mac_version():

@@ -923,6 +950,13 @@ def get_free_memory(dev=None, torch_free_too=False):
             mem_free_torch = mem_reserved - mem_active
             mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved
             mem_free_total = mem_free_xpu + mem_free_torch
+        elif is_ascend_npu():
+            stats = torch.npu.memory_stats(dev)
+            mem_active = stats['active_bytes.all.current']
+            mem_reserved = stats['reserved_bytes.all.current']
+            mem_free_npu, _ = torch.npu.mem_get_info(dev)
+            mem_free_torch = mem_reserved - mem_active
+            mem_free_total = mem_free_npu + mem_free_torch
         else:
             stats = torch.cuda.memory_stats(dev)
             mem_active = stats['active_bytes.all.current']

@@ -984,6 +1018,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
     if is_intel_xpu():
         return True

+    if is_ascend_npu():
+        return True
+
     if torch.version.hip:
         return True

@@ -1081,19 +1118,16 @@ def soft_empty_cache(force=False):
         torch.mps.empty_cache()
     elif is_intel_xpu():
         torch.xpu.empty_cache()
+    elif is_ascend_npu():
+        torch.npu.empty_cache()
     elif torch.cuda.is_available():
-        if force or is_nvidia(): #This seems to make things worse on ROCm so I only do it for cuda
-            torch.cuda.empty_cache()
-            torch.cuda.ipc_collect()
+        torch.cuda.empty_cache()
+        torch.cuda.ipc_collect()

 def unload_all_models():
     free_memory(1e30, get_torch_device())


-def resolve_lowvram_weight(weight, model, key): #TODO: remove
-    logging.warning("The comfy.model_management.resolve_lowvram_weight function will be removed soon, please stop using it.")
-    return weight
-
 #TODO: might be cleaner to put this somewhere else
 import threading
comfy/model_patcher.py
@@ -210,7 +210,7 @@ class ModelPatcher:
         self.injections: dict[str, list[PatcherInjection]] = {}

         self.hook_patches: dict[comfy.hooks._HookRef] = {}
-        self.hook_patches_backup: dict[comfy.hooks._HookRef] = {}
+        self.hook_patches_backup: dict[comfy.hooks._HookRef] = None
         self.hook_backup: dict[str, tuple[torch.Tensor, torch.device]] = {}
         self.cached_hook_patches: dict[comfy.hooks.HookGroup, dict[str, torch.Tensor]] = {}
         self.current_hooks: Optional[comfy.hooks.HookGroup] = None

@@ -282,7 +282,7 @@ class ModelPatcher:
             n.injections[k] = i.copy()
         # hooks
         n.hook_patches = create_hook_patches_clone(self.hook_patches)
-        n.hook_patches_backup = create_hook_patches_clone(self.hook_patches_backup)
+        n.hook_patches_backup = create_hook_patches_clone(self.hook_patches_backup) if self.hook_patches_backup else self.hook_patches_backup
         for group in self.cached_hook_patches:
             n.cached_hook_patches[group] = {}
             for k in self.cached_hook_patches[group]:

@@ -402,7 +402,20 @@ class ModelPatcher:
     def add_object_patch(self, name, obj):
         self.object_patches[name] = obj

-    def get_model_object(self, name):
+    def get_model_object(self, name: str) -> torch.nn.Module:
+        """Retrieves a nested attribute from an object using dot notation considering
+        object patches.
+
+        Args:
+            name (str): The attribute path using dot notation (e.g. "model.layer.weight")
+
+        Returns:
+            The value of the requested attribute
+
+        Example:
+            patcher = ModelPatcher()
+            weight = patcher.get_model_object("layer1.conv.weight")
+        """
         if name in self.object_patches:
             return self.object_patches[name]
         else:

@@ -842,6 +855,9 @@ class ModelPatcher:
         if key in self.injections:
             self.injections.pop(key)

+    def get_injections(self, key: str):
+        return self.injections.get(key, None)
+
     def set_additional_models(self, key: str, models: list['ModelPatcher']):
         self.additional_models[key] = models

@@ -912,18 +928,19 @@ class ModelPatcher:
             callback(self, timestep)

     def restore_hook_patches(self):
-        if len(self.hook_patches_backup) > 0:
+        if self.hook_patches_backup is not None:
             self.hook_patches = self.hook_patches_backup
-            self.hook_patches_backup = {}
+            self.hook_patches_backup = None

     def set_hook_mode(self, hook_mode: comfy.hooks.EnumHookMode):
         self.hook_mode = hook_mode

-    def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup):
+    def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup, model_options: dict[str]):
         curr_t = t[0]
         reset_current_hooks = False
+        transformer_options = model_options.get("transformer_options", {})
         for hook in hook_group.hooks:
-            changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t)
+            changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t, transformer_options=transformer_options)
             # if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref;
             # this will cause the weights to be recalculated when sampling
             if changed:

@@ -939,25 +956,26 @@ class ModelPatcher:
         if reset_current_hooks:
             self.patch_hooks(None)

-    def register_all_hook_patches(self, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]], target: comfy.hooks.EnumWeightTarget, model_options: dict=None):
+    def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None,
+                                  registered: comfy.hooks.HookGroup = None):
         self.restore_hook_patches()
-        registered_hooks: list[comfy.hooks.Hook] = []
-        # handle WrapperHooks, if model_options provided
-        if model_options is not None:
-            for hook in hooks_dict.get(comfy.hooks.EnumHookType.Wrappers, {}):
-                hook.add_hook_patches(self, model_options, target, registered_hooks)
+        if registered is None:
+            registered = comfy.hooks.HookGroup()
         # handle WeightHooks
         weight_hooks_to_register: list[comfy.hooks.WeightHook] = []
-        for hook in hooks_dict.get(comfy.hooks.EnumHookType.Weight, {}):
+        for hook in hooks.get_type(comfy.hooks.EnumHookType.Weight):
             if hook.hook_ref not in self.hook_patches:
                 weight_hooks_to_register.append(hook)
+            else:
+                registered.add(hook)
         if len(weight_hooks_to_register) > 0:
             # clone hook_patches to become backup so that any non-dynamic hooks will return to their original state
             self.hook_patches_backup = create_hook_patches_clone(self.hook_patches)
             for hook in weight_hooks_to_register:
-                hook.add_hook_patches(self, model_options, target, registered_hooks)
+                hook.add_hook_patches(self, model_options, target_dict, registered)
         for callback in self.get_all_callbacks(CallbacksMP.ON_REGISTER_ALL_HOOK_PATCHES):
-            callback(self, hooks_dict, target)
+            callback(self, hooks, target_dict, model_options, registered)
+        return registered

@@ -1008,11 +1026,11 @@ class ModelPatcher:
     def apply_hooks(self, hooks: comfy.hooks.HookGroup, transformer_options: dict=None, force_apply=False):
         # TODO: return transformer_options dict with any additions from hooks
         if self.current_hooks == hooks and (not force_apply or (not self.is_clip and hooks is None)):
-            return {}
+            return comfy.hooks.create_transformer_options_from_hooks(self, hooks, transformer_options)
         self.patch_hooks(hooks=hooks)
         for callback in self.get_all_callbacks(CallbacksMP.ON_APPLY_HOOKS):
             callback(self, hooks)
-        return {}
+        return comfy.hooks.create_transformer_options_from_hooks(self, hooks, transformer_options)

     def patch_hooks(self, hooks: comfy.hooks.HookGroup):
         with self.use_ejected():
comfy/sample.py
@@ -25,9 +25,11 @@ def prepare_noise(latent_image, seed, noise_inds=None):
     return noises

 def fix_empty_latent_channels(model, latent_image):
-    latent_channels = model.get_model_object("latent_format").latent_channels #Resize the empty latent image so it has the right number of channels
-    if latent_channels != latent_image.shape[1] and torch.count_nonzero(latent_image) == 0:
-        latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_channels, dim=1)
+    latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels
+    if latent_format.latent_channels != latent_image.shape[1] and torch.count_nonzero(latent_image) == 0:
+        latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)
+    if latent_format.latent_dimensions == 3 and latent_image.ndim == 4:
+        latent_image = latent_image.unsqueeze(2)
     return latent_image

 def prepare_sampling(model, noise_shape, positive, negative, noise_mask):
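Editor's note: the new branch covers video latent formats; when the format declares 3 latent dimensions (T, H, W) but the incoming empty latent is a 4-D image tensor, a singleton temporal axis is inserted. A standalone illustration:

import torch

latent_image = torch.zeros(1, 16, 64, 64)    # (B, C, H, W) empty image latent
latent_image = latent_image.unsqueeze(2)     # what the new branch does
print(latent_image.shape)                    # torch.Size([1, 16, 1, 64, 64])
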
@@ -24,15 +24,13 @@ def get_models_from_cond(cond, model_type):
|
|||||||
models += [c[model_type]]
|
models += [c[model_type]]
|
||||||
return models
|
return models
|
||||||
|
|
||||||
def get_hooks_from_cond(cond, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]]):
|
def get_hooks_from_cond(cond, full_hooks: comfy.hooks.HookGroup):
|
||||||
# get hooks from conds, and collect cnets so they can be checked for extra_hooks
|
# get hooks from conds, and collect cnets so they can be checked for extra_hooks
|
||||||
cnets: list[ControlBase] = []
|
cnets: list[ControlBase] = []
|
||||||
for c in cond:
|
for c in cond:
|
||||||
if 'hooks' in c:
|
if 'hooks' in c:
|
||||||
for hook in c['hooks'].hooks:
|
for hook in c['hooks'].hooks:
|
||||||
hook: comfy.hooks.Hook
|
full_hooks.add(hook)
|
||||||
with_type = hooks_dict.setdefault(hook.hook_type, {})
|
|
||||||
with_type[hook] = None
|
|
||||||
if 'control' in c:
|
if 'control' in c:
|
||||||
cnets.append(c['control'])
|
cnets.append(c['control'])
|
||||||
|
|
||||||
@@ -50,10 +48,9 @@ def get_hooks_from_cond(cond, hooks_dict: dict[comfy.hooks.EnumHookType, dict[co
|
|||||||
extra_hooks = comfy.hooks.HookGroup.combine_all_hooks(hooks_list)
|
extra_hooks = comfy.hooks.HookGroup.combine_all_hooks(hooks_list)
|
||||||
if extra_hooks is not None:
|
if extra_hooks is not None:
|
||||||
for hook in extra_hooks.hooks:
|
for hook in extra_hooks.hooks:
|
||||||
with_type = hooks_dict.setdefault(hook.hook_type, {})
|
full_hooks.add(hook)
|
||||||
with_type[hook] = None
|
|
||||||
|
|
||||||
return hooks_dict
|
return full_hooks
|
||||||
|
|
||||||
def convert_cond(cond):
|
def convert_cond(cond):
|
||||||
out = []
|
out = []
|
||||||
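The refactor above replaces the ad-hoc `dict[EnumHookType, dict[Hook, None]]` accumulator with a `HookGroup`, so callers now pass one group object that deduplicates hooks itself. A minimal sketch of the new calling convention (assumes a ComfyUI runtime is importable; the cond contents here are placeholders):

    import comfy.hooks
    import comfy.sampler_helpers

    conds = {"positive": [], "negative": []}  # normally filled with cond dicts that may carry 'hooks'
    full_hooks = comfy.hooks.HookGroup()
    for k in conds:
        comfy.sampler_helpers.get_hooks_from_cond(conds[k], full_hooks)
    # full_hooks now holds every hook found across all conds, each added at most once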
@@ -61,7 +58,6 @@ def convert_cond(cond):
         temp = c[1].copy()
         model_conds = temp.get("model_conds", {})
         if c[0] is not None:
-            model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0]) #TODO: remove
             temp["cross_attn"] = c[0]
         temp["model_conds"] = model_conds
         temp["uuid"] = uuid.uuid4()
@@ -73,13 +69,11 @@ def get_additional_models(conds, dtype):
     cnets: list[ControlBase] = []
     gligen = []
     add_models = []
-    hooks: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]] = {}

     for k in conds:
         cnets += get_models_from_cond(conds[k], "control")
         gligen += get_models_from_cond(conds[k], "gligen")
         add_models += get_models_from_cond(conds[k], "additional_models")
-        get_hooks_from_cond(conds[k], hooks)

     control_nets = set(cnets)

@@ -90,11 +84,20 @@ def get_additional_models(conds, dtype):
         inference_memory += m.inference_memory_requirements(dtype)

     gligen = [x[1] for x in gligen]
-    hook_models = [x.model for x in hooks.get(comfy.hooks.EnumHookType.AddModels, {}).keys()]
-    models = control_models + gligen + add_models + hook_models
+    models = control_models + gligen + add_models

     return models, inference_memory

+def get_additional_models_from_model_options(model_options: dict[str]=None):
+    """loads additional models from registered AddModels hooks"""
+    models = []
+    if model_options is not None and "registered_hooks" in model_options:
+        registered: comfy.hooks.HookGroup = model_options["registered_hooks"]
+        for hook in registered.get_type(comfy.hooks.EnumHookType.AdditionalModels):
+            hook: comfy.hooks.AdditionalModelsHook
+            models.extend(hook.models)
+    return models
+
 def cleanup_additional_models(models):
     """cleanup additional models that were loaded"""
     for m in models:
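Because `get_additional_models` no longer walks hooks itself, models contributed by `AdditionalModels` hooks now arrive through `model_options["registered_hooks"]`, which `prepare_model_patcher` populates. A quick sketch of the consumer side (assumes a ComfyUI runtime; an empty options dict simply yields no extra models):

    import comfy.sampler_helpers

    model_options = {}  # normally populated by prepare_model_patcher()
    extra = comfy.sampler_helpers.get_additional_models_from_model_options(model_options)
    assert extra == []  # nothing registered, so nothing extra to load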
@@ -102,9 +105,10 @@ def cleanup_additional_models(models):
             m.cleanup()


-def prepare_sampling(model: 'ModelPatcher', noise_shape, conds):
-    real_model: 'BaseModel' = None
+def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
+    real_model: BaseModel = None
     models, inference_memory = get_additional_models(conds, model.model_dtype())
+    models += get_additional_models_from_model_options(model_options)
     models += model.get_nested_additional_models()  # TODO: does this require inference_memory update?
     memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory
     minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory
@@ -123,12 +127,35 @@ def cleanup_models(conds, models):
     cleanup_additional_models(set(control_cleanup))

 def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
+    '''
+    Registers hooks from conds.
+    '''
     # check for hooks in conds - if not registered, see if can be applied
-    hooks = {}
+    hooks = comfy.hooks.HookGroup()
     for k in conds:
         get_hooks_from_cond(conds[k], hooks)
     # add wrappers and callbacks from ModelPatcher to transformer_options
     model_options["transformer_options"]["wrappers"] = comfy.patcher_extension.copy_nested_dicts(model.wrappers)
     model_options["transformer_options"]["callbacks"] = comfy.patcher_extension.copy_nested_dicts(model.callbacks)
-    # register hooks on model/model_options
-    model.register_all_hook_patches(hooks, comfy.hooks.EnumWeightTarget.Model, model_options)
+    # begin registering hooks
+    registered = comfy.hooks.HookGroup()
+    target_dict = comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Model)
+    # handle all TransformerOptionsHooks
+    for hook in hooks.get_type(comfy.hooks.EnumHookType.TransformerOptions):
+        hook: comfy.hooks.TransformerOptionsHook
+        hook.add_hook_patches(model, model_options, target_dict, registered)
+    # handle all AddModelsHooks
+    for hook in hooks.get_type(comfy.hooks.EnumHookType.AdditionalModels):
+        hook: comfy.hooks.AdditionalModelsHook
+        hook.add_hook_patches(model, model_options, target_dict, registered)
+    # handle all WeightHooks by registering on ModelPatcher
+    model.register_all_hook_patches(hooks, target_dict, model_options, registered)
+    # add registered_hooks onto model_options for further reference
+    if len(registered) > 0:
+        model_options["registered_hooks"] = registered
+    # merge original wrappers and callbacks with hooked wrappers and callbacks
+    to_load_options: dict[str] = model_options.setdefault("to_load_options", {})
+    for wc_name in ["wrappers", "callbacks"]:
+        comfy.patcher_extension.merge_nested_dicts(to_load_options.setdefault(wc_name, {}), model_options["transformer_options"][wc_name],
+                                                   copy_dict1=False)
+    return to_load_options

@@ -1,17 +1,17 @@
 from __future__ import annotations
 from .k_diffusion import sampling as k_diffusion_sampling
 from .extra_samplers import uni_pc
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Callable, NamedTuple
 if TYPE_CHECKING:
     from comfy.model_patcher import ModelPatcher
     from comfy.model_base import BaseModel
     from comfy.controlnet import ControlBase
 import torch
+from functools import partial
 import collections
 from comfy import model_management
 import math
 import logging
-import comfy.samplers
 import comfy.sampler_helpers
 import comfy.model_patcher
 import comfy.patcher_extension
@@ -144,7 +144,7 @@ def cond_cat(c_list):

     return out

-def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]], default_conds: list[list[dict]], x_in, timestep):
+def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]], default_conds: list[list[dict]], x_in, timestep, model_options):
     # need to figure out remaining unmasked area for conds
     default_mults = []
     for _ in default_conds:
@@ -177,13 +177,13 @@ def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.H
         cond = default_conds[i]
         for x in cond:
             # do get_area_and_mult to get all the expected values
-            p = comfy.samplers.get_area_and_mult(x, x_in, timestep)
+            p = get_area_and_mult(x, x_in, timestep)
             if p is None:
                 continue
             # replace p's mult with calculated mult
             p = p._replace(mult=mult)
             if p.hooks is not None:
-                model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks)
+                model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options)
             hooked_to_run.setdefault(p.hooks, list())
             hooked_to_run[p.hooks] += [(p, i)]

@@ -214,17 +214,17 @@ def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Te
                 default_c.append(x)
                 has_default_conds = True
                 continue
-            p = comfy.samplers.get_area_and_mult(x, x_in, timestep)
+            p = get_area_and_mult(x, x_in, timestep)
             if p is None:
                 continue
             if p.hooks is not None:
-                model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks)
+                model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options)
             hooked_to_run.setdefault(p.hooks, list())
             hooked_to_run[p.hooks] += [(p, i)]
         default_conds.append(default_c)

     if has_default_conds:
-        finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep)
+        finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)

     model.current_patcher.prepare_state(timestep)

@@ -375,7 +375,7 @@ class KSamplerX0Inpaint:
             if "denoise_mask_function" in model_options:
                 denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask, extra_options={"model": self.inner_model, "sigmas": self.sigmas})
             latent_mask = 1. - denoise_mask
-            x = x * denoise_mask + self.inner_model.inner_model.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1)), self.noise, self.latent_image) * latent_mask
+            x = x * denoise_mask + self.inner_model.inner_model.scale_latent_inpaint(x=x, sigma=sigma, noise=self.noise, latent_image=self.latent_image) * latent_mask
         out = self.inner_model(x, sigma, model_options=model_options, seed=seed)
         if denoise_mask is not None:
             out = out * denoise_mask + self.latent_image * latent_mask
@@ -467,6 +467,13 @@ def linear_quadratic_schedule(model_sampling, steps, threshold_noise=0.025, line
     sigma_schedule = [1.0 - x for x in sigma_schedule]
     return torch.FloatTensor(sigma_schedule) * model_sampling.sigma_max.cpu()

+# Referenced from https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15608
+def kl_optimal_scheduler(n: int, sigma_min: float, sigma_max: float) -> torch.Tensor:
+    adj_idxs = torch.arange(n, dtype=torch.float).div_(n - 1)
+    sigmas = adj_idxs.new_zeros(n + 1)
+    sigmas[:-1] = (adj_idxs * math.atan(sigma_min) + (1 - adj_idxs) * math.atan(sigma_max)).tan_()
+    return sigmas
+
 def get_mask_aabb(masks):
     if masks.numel() == 0:
         return torch.zeros((0, 4), device=masks.device, dtype=torch.int)
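The new `kl_optimal` schedule interpolates linearly in arctan space between `sigma_max` and `sigma_min`, then maps back with `tan` and appends a trailing zero for the samplers. An illustrative check (self-contained copy of the function above):

    import math
    import torch

    def kl_optimal_scheduler(n, sigma_min, sigma_max):
        adj_idxs = torch.arange(n, dtype=torch.float).div_(n - 1)
        sigmas = adj_idxs.new_zeros(n + 1)
        sigmas[:-1] = (adj_idxs * math.atan(sigma_min) + (1 - adj_idxs) * math.atan(sigma_max)).tan_()
        return sigmas

    s = kl_optimal_scheduler(10, 0.002, 80.0)
    assert torch.isclose(s[0], torch.tensor(80.0), rtol=1e-2)    # starts at sigma_max
    assert torch.isclose(s[-2], torch.tensor(0.002), rtol=1e-2)  # ends at sigma_min
    assert s[-1] == 0.0                                          # trailing zero for the sampler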
@@ -679,7 +686,7 @@ class Sampler:
 KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
                   "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
                   "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
-                  "ipndm", "ipndm_v", "deis"]
+                  "ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "gradient_estimation"]

 class KSAMPLER(Sampler):
     def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
@@ -802,6 +809,33 @@ def preprocess_conds_hooks(conds: dict[str, list[dict[str]]]):
         for cond in conds_to_modify:
             cond['hooks'] = hooks

+def filter_registered_hooks_on_conds(conds: dict[str, list[dict[str]]], model_options: dict[str]):
+    '''Modify 'hooks' on conds so that only hooks that were registered remain. Properly accounts for
+    HookGroups that have the same reference.'''
+    registered: comfy.hooks.HookGroup = model_options.get('registered_hooks', None)
+    # if None were registered, make sure all hooks are cleaned from conds
+    if registered is None:
+        for k in conds:
+            for kk in conds[k]:
+                kk.pop('hooks', None)
+        return
+    # find conds that contain hooks to be replaced - group by common HookGroup refs
+    hook_replacement: dict[comfy.hooks.HookGroup, list[dict]] = {}
+    for k in conds:
+        for kk in conds[k]:
+            hooks: comfy.hooks.HookGroup = kk.get('hooks', None)
+            if hooks is not None:
+                if not hooks.is_subset_of(registered):
+                    to_replace = hook_replacement.setdefault(hooks, [])
+                    to_replace.append(kk)
+    # for each hook to replace, create a new proper HookGroup and assign to all common conds
+    for hooks, conds_to_modify in hook_replacement.items():
+        new_hooks = hooks.new_with_common_hooks(registered)
+        if len(new_hooks) == 0:
+            new_hooks = None
+        for kk in conds_to_modify:
+            kk['hooks'] = new_hooks
+
+
 def get_total_hook_groups_in_conds(conds: dict[str, list[dict[str]]]):
     hooks_set = set()
@@ -811,9 +845,58 @@ def get_total_hook_groups_in_conds(conds: dict[str, list[dict[str]]]):
     return len(hooks_set)


+def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
+    '''
+    If any patches from hooks, wrappers, or callbacks have .to to be called, call it.
+    '''
+    if model_options is None:
+        return
+    to_load_options = model_options.get("to_load_options", None)
+    if to_load_options is None:
+        return
+
+    casts = []
+    if device is not None:
+        casts.append(device)
+    if dtype is not None:
+        casts.append(dtype)
+    # if nothing to apply, do nothing
+    if len(casts) == 0:
+        return
+
+    # try to call .to on patches
+    if "patches" in to_load_options:
+        patches = to_load_options["patches"]
+        for name in patches:
+            patch_list = patches[name]
+            for i in range(len(patch_list)):
+                if hasattr(patch_list[i], "to"):
+                    for cast in casts:
+                        patch_list[i] = patch_list[i].to(cast)
+    if "patches_replace" in to_load_options:
+        patches = to_load_options["patches_replace"]
+        for name in patches:
+            patch_list = patches[name]
+            for k in patch_list:
+                if hasattr(patch_list[k], "to"):
+                    for cast in casts:
+                        patch_list[k] = patch_list[k].to(cast)
+    # try to call .to on any wrappers/callbacks
+    wrappers_and_callbacks = ["wrappers", "callbacks"]
+    for wc_name in wrappers_and_callbacks:
+        if wc_name in to_load_options:
+            wc: dict[str, list] = to_load_options[wc_name]
+            for wc_dict in wc.values():
+                for wc_list in wc_dict.values():
+                    for i in range(len(wc_list)):
+                        if hasattr(wc_list[i], "to"):
+                            for cast in casts:
+                                wc_list[i] = wc_list[i].to(cast)
+
+
 class CFGGuider:
-    def __init__(self, model_patcher):
-        self.model_patcher: 'ModelPatcher' = model_patcher
+    def __init__(self, model_patcher: ModelPatcher):
+        self.model_patcher = model_patcher
         self.model_options = model_patcher.model_options
         self.original_conds = {}
         self.cfg = 1.0
@@ -840,7 +923,9 @@ class CFGGuider:

         self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed)

-        extra_args = {"model_options": comfy.model_patcher.create_model_options_clone(self.model_options), "seed": seed}
+        extra_model_options = comfy.model_patcher.create_model_options_clone(self.model_options)
+        extra_model_options.setdefault("transformer_options", {})["sample_sigmas"] = sigmas
+        extra_args = {"model_options": extra_model_options, "seed": seed}

         executor = comfy.patcher_extension.WrapperExecutor.new_class_executor(
             sampler.sample,
@@ -851,7 +936,7 @@ class CFGGuider:
         return self.inner_model.process_latent_out(samples.to(torch.float32))

     def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
-        self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds)
+        self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
         device = self.model_patcher.load_device

         if denoise_mask is not None:
@@ -860,6 +945,7 @@ class CFGGuider:
         noise = noise.to(device)
         latent_image = latent_image.to(device)
         sigmas = sigmas.to(device)
+        cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())

         try:
             self.model_patcher.pre_run()
@@ -889,6 +975,7 @@ class CFGGuider:
             if get_total_hook_groups_in_conds(self.conds) <= 1:
                 self.model_patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
             comfy.sampler_helpers.prepare_model_patcher(self.model_patcher, self.conds, self.model_options)
+            filter_registered_hooks_on_conds(self.conds, self.model_options)
             executor = comfy.patcher_extension.WrapperExecutor.new_class_executor(
                 self.outer_sample,
                 self,
@@ -896,6 +983,7 @@ class CFGGuider:
             )
             output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
         finally:
+            cast_to_load_options(self.model_options, device=self.model_patcher.offload_device)
             self.model_options = orig_model_options
             self.model_patcher.hook_mode = orig_hook_mode
             self.model_patcher.restore_hook_patches()
@@ -911,29 +999,37 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
     return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)


-SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "beta", "linear_quadratic"]
 SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]

-def calculate_sigmas(model_sampling, scheduler_name, steps):
-    if scheduler_name == "karras":
-        sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))
-    elif scheduler_name == "exponential":
-        sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))
-    elif scheduler_name == "normal":
-        sigmas = normal_scheduler(model_sampling, steps)
-    elif scheduler_name == "simple":
-        sigmas = simple_scheduler(model_sampling, steps)
-    elif scheduler_name == "ddim_uniform":
-        sigmas = ddim_scheduler(model_sampling, steps)
-    elif scheduler_name == "sgm_uniform":
-        sigmas = normal_scheduler(model_sampling, steps, sgm=True)
-    elif scheduler_name == "beta":
-        sigmas = beta_scheduler(model_sampling, steps)
-    elif scheduler_name == "linear_quadratic":
-        sigmas = linear_quadratic_schedule(model_sampling, steps)
-    else:
-        logging.error("error invalid scheduler {}".format(scheduler_name))
-    return sigmas
+class SchedulerHandler(NamedTuple):
+    handler: Callable[..., torch.Tensor]
+    # Boolean indicates whether to call the handler like:
+    #  scheduler_function(model_sampling, steps) or
+    #  scheduler_function(n, sigma_min: float, sigma_max: float)
+    use_ms: bool = True
+
+SCHEDULER_HANDLERS = {
+    "normal": SchedulerHandler(normal_scheduler),
+    "karras": SchedulerHandler(k_diffusion_sampling.get_sigmas_karras, use_ms=False),
+    "exponential": SchedulerHandler(k_diffusion_sampling.get_sigmas_exponential, use_ms=False),
+    "sgm_uniform": SchedulerHandler(partial(normal_scheduler, sgm=True)),
+    "simple": SchedulerHandler(simple_scheduler),
+    "ddim_uniform": SchedulerHandler(ddim_scheduler),
+    "beta": SchedulerHandler(beta_scheduler),
+    "linear_quadratic": SchedulerHandler(linear_quadratic_schedule),
+    "kl_optimal": SchedulerHandler(kl_optimal_scheduler, use_ms=False),
+}
+SCHEDULER_NAMES = list(SCHEDULER_HANDLERS)
+
+def calculate_sigmas(model_sampling: object, scheduler_name: str, steps: int) -> torch.Tensor:
+    handler = SCHEDULER_HANDLERS.get(scheduler_name)
+    if handler is None:
+        err = f"error invalid scheduler {scheduler_name}"
+        logging.error(err)
+        raise ValueError(err)
+    if handler.use_ms:
+        return handler.handler(model_sampling, steps)
+    return handler.handler(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))

 def sampler_object(name):
     if name == "uni_pc":
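With the dispatch table in place, third-party code can register a scheduler by adding a `SchedulerHandler` entry instead of patching an if/elif chain. A hedged sketch (assumes a ComfyUI runtime; the scheduler itself is a toy):

    import torch
    import comfy.samplers

    def halving_scheduler(model_sampling, steps):
        # toy schedule: halve sigma_max at each step, then append the final 0.0
        sigmas = [float(model_sampling.sigma_max) / (2 ** i) for i in range(steps)]
        return torch.FloatTensor(sigmas + [0.0])

    comfy.samplers.SCHEDULER_HANDLERS["halving"] = comfy.samplers.SchedulerHandler(halving_scheduler)
    # calculate_sigmas(model_sampling, "halving", steps) now dispatches to it

Note that `SCHEDULER_NAMES` is materialized once at import time, so a late registration is reachable through `calculate_sigmas` but will not appear in dropdowns built from that list.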
31 comfy/sd.py
@@ -11,6 +11,7 @@ from .ldm.cascade.stage_c_coder import StageC_coder
 from .ldm.audio.autoencoder import AudioOobleckVAE
 import comfy.ldm.genmo.vae.model
 import comfy.ldm.lightricks.vae.causal_video_autoencoder
+import comfy.ldm.cosmos.vae
 import yaml
 import math

@@ -34,6 +35,7 @@ import comfy.text_encoders.long_clipl
 import comfy.text_encoders.genmo
 import comfy.text_encoders.lt
 import comfy.text_encoders.hunyuan_video
+import comfy.text_encoders.cosmos

 import comfy.model_patcher
 import comfy.lora
@@ -111,7 +113,7 @@ class CLIP:
             model_management.load_models_gpu([self.patcher], force_full_load=True)
         self.layer_idx = None
         self.use_clip_schedule = False
-        logging.info("CLIP model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype))
+        logging.info("CLIP/text encoder model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype))

     def clone(self):
         n = CLIP(no_init=True)
@@ -376,6 +378,19 @@ class VAE:
             self.memory_used_decode = lambda shape, dtype: (1500 * shape[2] * shape[3] * shape[4] * (4 * 8 * 8)) * model_management.dtype_size(dtype)
             self.memory_used_encode = lambda shape, dtype: (900 * max(shape[2], 2) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
             self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
+        elif "decoder.unpatcher3d.wavelets" in sd:
+            self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 8, 8)
+            self.upscale_index_formula = (8, 8, 8)
+            self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 8, 8)
+            self.downscale_index_formula = (8, 8, 8)
+            self.latent_dim = 3
+            self.latent_channels = 16
+            ddconfig = {'z_channels': 16, 'latent_channels': self.latent_channels, 'z_factor': 1, 'resolution': 1024, 'in_channels': 3, 'out_channels': 3, 'channels': 128, 'channels_mult': [2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [32], 'dropout': 0.0, 'patch_size': 4, 'num_groups': 1, 'temporal_compression': 8, 'spacial_compression': 8}
+            self.first_stage_model = comfy.ldm.cosmos.vae.CausalContinuousVideoTokenizer(**ddconfig)
+            #TODO: these values are a bit off because this is not a standard VAE
+            self.memory_used_decode = lambda shape, dtype: (50 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
+            self.memory_used_encode = lambda shape, dtype: (50 * (round((shape[2] + 7) / 8) * 8) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
+            self.working_dtypes = [torch.bfloat16, torch.float32]
         else:
             logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
             self.first_stage_model = None
@@ -519,7 +534,7 @@ class VAE:
     def encode(self, pixel_samples):
         pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
         pixel_samples = pixel_samples.movedim(-1, 1)
-        if self.latent_dim == 3:
+        if self.latent_dim == 3 and pixel_samples.ndim < 5:
             pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
         try:
             memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
@@ -641,6 +656,7 @@ class CLIPType(Enum):
     LTXV = 8
     HUNYUAN_VIDEO = 9
     PIXART = 10
+    COSMOS = 11


 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
@@ -658,6 +674,7 @@ class TEModel(Enum):
     T5_XL = 5
     T5_BASE = 6
     LLAMA3_8 = 7
+    T5_XXL_OLD = 8

 def detect_te_model(sd):
     if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
@@ -672,6 +689,8 @@ def detect_te_model(sd):
             return TEModel.T5_XXL
         elif weight.shape[-1] == 2048:
             return TEModel.T5_XL
+    if 'encoder.block.23.layer.1.DenseReluDense.wi.weight' in sd:
+        return TEModel.T5_XXL_OLD
     if "encoder.block.0.layer.0.SelfAttention.k.weight" in sd:
         return TEModel.T5_BASE
     if "model.layers.0.post_attention_layernorm.weight" in sd:
@@ -681,9 +700,10 @@ def detect_te_model(sd):

 def t5xxl_detect(clip_data):
     weight_name = "encoder.block.23.layer.1.DenseReluDense.wi_1.weight"
+    weight_name_old = "encoder.block.23.layer.1.DenseReluDense.wi.weight"

     for sd in clip_data:
-        if weight_name in sd:
+        if weight_name in sd or weight_name_old in sd:
             return comfy.text_encoders.sd3_clip.t5_xxl_detect(sd)

     return {}
@@ -740,6 +760,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
             else: #CLIPType.MOCHI
                 clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
                 clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer
+        elif te_model == TEModel.T5_XXL_OLD:
+            clip_target.clip = comfy.text_encoders.cosmos.te(**t5xxl_detect(clip_data))
+            clip_target.tokenizer = comfy.text_encoders.cosmos.CosmosT5Tokenizer
         elif te_model == TEModel.T5_XL:
             clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model
             clip_target.tokenizer = comfy.text_encoders.aura_t5.AuraT5Tokenizer
@@ -898,7 +921,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
     if output_model:
         model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
         if inital_load_device != torch.device("cpu"):
-            logging.info("loaded straight to GPU")
+            logging.info("loaded diffusion model directly to GPU")
             model_management.load_models_gpu([model_patcher], force_full_load=True)

     return (model_patcher, clip, vae, clipvision)
@@ -388,13 +388,10 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
             import safetensors.torch
             embed = safetensors.torch.load_file(embed_path, device="cpu")
         else:
-            if 'weights_only' in torch.load.__code__.co_varnames:
-                try:
-                    embed = torch.load(embed_path, weights_only=True, map_location="cpu")
-                except:
-                    embed_out = safe_load_embed_zip(embed_path)
-            else:
-                embed = torch.load(embed_path, map_location="cpu")
+            try:
+                embed = torch.load(embed_path, weights_only=True, map_location="cpu")
+            except:
+                embed_out = safe_load_embed_zip(embed_path)
     except Exception:
         logging.warning("{}\n\nerror loading embedding, skipping loading: {}".format(traceback.format_exc(), embedding_name))
         return None
@@ -14,6 +14,7 @@ import comfy.text_encoders.flux
 import comfy.text_encoders.genmo
 import comfy.text_encoders.lt
 import comfy.text_encoders.hunyuan_video
+import comfy.text_encoders.cosmos

 from . import supported_models_base
 from . import latent_formats
@@ -608,6 +609,8 @@ class PixArtAlpha(supported_models_base.BASE):
     unet_extra_config = {}
     latent_format = latent_formats.SD15

+    memory_usage_factor = 0.5
+
     vae_key_prefix = ["vae."]
     text_encoder_key_prefix = ["text_encoders."]

@@ -640,6 +643,8 @@ class HunyuanDiT(supported_models_base.BASE):

     latent_format = latent_formats.SDXL

+    memory_usage_factor = 1.3
+
     vae_key_prefix = ["vae."]
     text_encoder_key_prefix = ["text_encoders."]

@@ -783,7 +788,7 @@ class HunyuanVideo(supported_models_base.BASE):
     unet_extra_config = {}
     latent_format = latent_formats.HunyuanVideo

-    memory_usage_factor = 2.0 #TODO
+    memory_usage_factor = 1.8 #TODO

     supported_inference_dtypes = [torch.bfloat16, torch.float32]

@@ -819,6 +824,47 @@ class HunyuanVideo(supported_models_base.BASE):
         hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}llama.transformer.".format(pref))
         return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer, comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**hunyuan_detect))

-models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo]
+class CosmosT2V(supported_models_base.BASE):
+    unet_config = {
+        "image_model": "cosmos",
+        "in_channels": 16,
+    }
+
+    sampling_settings = {
+        "sigma_data": 0.5,
+        "sigma_max": 80.0,
+        "sigma_min": 0.002,
+    }
+
+    unet_extra_config = {}
+    latent_format = latent_formats.Cosmos1CV8x8x8
+
+    memory_usage_factor = 1.6 #TODO
+
+    supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] #TODO
+
+    vae_key_prefix = ["vae."]
+    text_encoder_key_prefix = ["text_encoders."]
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.CosmosVideo(self, device=device)
+        return out
+
+    def clip_target(self, state_dict={}):
+        pref = self.text_encoder_key_prefix[0]
+        t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
+        return supported_models_base.ClipTarget(comfy.text_encoders.cosmos.CosmosT5Tokenizer, comfy.text_encoders.cosmos.te(**t5_detect))
+
+class CosmosI2V(CosmosT2V):
+    unet_config = {
+        "image_model": "cosmos",
+        "in_channels": 17,
+    }
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.CosmosVideo(self, image_to_video=True, device=device)
+        return out
+
+models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo, CosmosT2V, CosmosI2V]

 models += [SVD_img2vid]
42 comfy/text_encoders/cosmos.py Normal file
@@ -0,0 +1,42 @@
+from comfy import sd1_clip
+import comfy.text_encoders.t5
+import os
+from transformers import T5TokenizerFast
+
+
+class T5XXLModel(sd1_clip.SDClipModel):
+    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
+        textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_old_config_xxl.json")
+        t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None)
+        if t5xxl_scaled_fp8 is not None:
+            model_options = model_options.copy()
+            model_options["scaled_fp8"] = t5xxl_scaled_fp8
+
+        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, zero_out_masked=attention_mask, model_options=model_options)
+
+class CosmosT5XXL(sd1_clip.SD1ClipModel):
+    def __init__(self, device="cpu", dtype=None, model_options={}):
+        super().__init__(device=device, dtype=dtype, name="t5xxl", clip_model=T5XXLModel, model_options=model_options)
+
+
+class T5XXLTokenizer(sd1_clip.SDTokenizer):
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
+        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=1024, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512)
+
+
+class CosmosT5Tokenizer(sd1_clip.SD1Tokenizer):
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
+
+
+def te(dtype_t5=None, t5xxl_scaled_fp8=None):
+    class CosmosTEModel_(CosmosT5XXL):
+        def __init__(self, device="cpu", dtype=None, model_options={}):
+            if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
+                model_options = model_options.copy()
+                model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
+            if dtype is None:
+                dtype = dtype_t5
+            super().__init__(device=device, dtype=dtype, model_options=model_options)
+    return CosmosTEModel_
@@ -227,8 +227,9 @@ class T5(torch.nn.Module):
         super().__init__()
         self.num_layers = config_dict["num_layers"]
         model_dim = config_dict["d_model"]
+        inner_dim = config_dict["d_kv"] * config_dict["num_heads"]

-        self.encoder = T5Stack(self.num_layers, model_dim, model_dim, config_dict["d_ff"], config_dict["dense_act_fn"], config_dict["is_gated_act"], config_dict["num_heads"], config_dict["model_type"] != "umt5", dtype, device, operations)
+        self.encoder = T5Stack(self.num_layers, model_dim, inner_dim, config_dict["d_ff"], config_dict["dense_act_fn"], config_dict["is_gated_act"], config_dict["num_heads"], config_dict["model_type"] != "umt5", dtype, device, operations)
         self.dtype = dtype
         self.shared = operations.Embedding(config_dict["vocab_size"], model_dim, device=device, dtype=dtype)

22 comfy/text_encoders/t5_old_config_xxl.json Normal file
@@ -0,0 +1,22 @@
+{
+    "d_ff": 65536,
+    "d_kv": 128,
+    "d_model": 1024,
+    "decoder_start_token_id": 0,
+    "dropout_rate": 0.1,
+    "eos_token_id": 1,
+    "dense_act_fn": "relu",
+    "initializer_factor": 1.0,
+    "is_encoder_decoder": true,
+    "is_gated_act": false,
+    "layer_norm_epsilon": 1e-06,
+    "model_type": "t5",
+    "num_decoder_layers": 24,
+    "num_heads": 128,
+    "num_layers": 24,
+    "output_past": true,
+    "pad_token_id": 0,
+    "relative_attention_num_buckets": 32,
+    "tie_word_embeddings": false,
+    "vocab_size": 32128
+}
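This old-style config shows why the `inner_dim` fix in t5.py above matters: here `d_model` is 1024 while the attention width is `d_kv * num_heads = 128 * 128 = 16384`, so the two can no longer be conflated. An illustrative check against the values above:

    config = {"d_model": 1024, "d_kv": 128, "num_heads": 128}
    inner_dim = config["d_kv"] * config["num_heads"]  # 16384
    assert inner_dim != config["d_model"]  # the previous code passed d_model where inner_dim belongs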
@@ -29,17 +29,30 @@ import itertools
 from torch.nn.functional import interpolate
 from einops import rearrange

+ALWAYS_SAFE_LOAD = False
+if hasattr(torch.serialization, "add_safe_globals"):  # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated
+    class ModelCheckpoint:
+        pass
+    ModelCheckpoint.__module__ = "pytorch_lightning.callbacks.model_checkpoint"
+
+    from numpy.core.multiarray import scalar
+    from numpy import dtype
+    from numpy.dtypes import Float64DType
+    from _codecs import encode
+
+    torch.serialization.add_safe_globals([ModelCheckpoint, scalar, dtype, Float64DType, encode])
+    ALWAYS_SAFE_LOAD = True
+    logging.info("Checkpoint files will always be loaded safely.")
+else:
+    logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")
+
 def load_torch_file(ckpt, safe_load=False, device=None):
     if device is None:
         device = torch.device("cpu")
     if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
         sd = safetensors.torch.load_file(ckpt, device=device.type)
     else:
-        if safe_load:
-            if not 'weights_only' in torch.load.__code__.co_varnames:
-                logging.warning("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
-                safe_load = False
-        if safe_load:
+        if safe_load or ALWAYS_SAFE_LOAD:
             pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
         else:
             pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
@@ -693,7 +706,25 @@ def copy_to_param(obj, attr, value):
     prev = getattr(obj, attrs[-1])
     prev.data.copy_(value)

-def get_attr(obj, attr):
+def get_attr(obj, attr: str):
+    """Retrieves a nested attribute from an object using dot notation.
+
+    Args:
+        obj: The object to get the attribute from
+        attr (str): The attribute path using dot notation (e.g. "model.layer.weight")
+
+    Returns:
+        The value of the requested attribute
+
+    Example:
+        model = MyModel()
+        weight = get_attr(model, "layer1.conv.weight")
+        # Equivalent to: model.layer1.conv.weight
+
+    Important:
+        Always prefer `comfy.model_patcher.ModelPatcher.get_model_object` when
+        accessing nested model objects under `ModelPatcher.model`.
+    """
     attrs = attr.split(".")
     for name in attrs:
         obj = getattr(obj, name)
@@ -893,7 +924,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
     out = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
     out_div = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)

-    positions = [range(0, s.shape[d+2], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]
+    positions = [range(0, s.shape[d+2] - overlap[d], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]

     for it in itertools.product(*positions):
         s_in = s
|||||||
@@ -54,8 +54,8 @@ class DynamicPrompt:
|
|||||||
def get_original_prompt(self):
|
def get_original_prompt(self):
|
||||||
return self.original_prompt
|
return self.original_prompt
|
||||||
|
|
||||||
def get_input_info(class_def, input_name):
|
def get_input_info(class_def, input_name, valid_inputs=None):
|
||||||
valid_inputs = class_def.INPUT_TYPES()
|
valid_inputs = valid_inputs or class_def.INPUT_TYPES()
|
||||||
input_info = None
|
input_info = None
|
||||||
input_category = None
|
input_category = None
|
||||||
if "required" in valid_inputs and input_name in valid_inputs["required"]:
|
if "required" in valid_inputs and input_name in valid_inputs["required"]:
|
||||||
|
|||||||
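The new `valid_inputs` parameter lets callers that inspect many inputs call `INPUT_TYPES()` once instead of once per lookup. A hedged sketch (`MyNode` is a hypothetical node class):

    valid = MyNode.INPUT_TYPES()  # potentially expensive; now computed once
    for name in ("image", "strength"):
        result = get_input_info(MyNode, name, valid_inputs=valid)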
82 comfy_extras/nodes_cosmos.py Normal file
@@ -0,0 +1,82 @@
+import nodes
+import torch
+import comfy.model_management
+import comfy.utils
+
+
+class EmptyCosmosLatentVideo:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "width": ("INT", {"default": 1280, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
+                              "height": ("INT", {"default": 704, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
+                              "length": ("INT", {"default": 121, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 8}),
+                              "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
+    RETURN_TYPES = ("LATENT",)
+    FUNCTION = "generate"
+
+    CATEGORY = "latent/video"
+
+    def generate(self, width, height, length, batch_size=1):
+        latent = torch.zeros([batch_size, 16, ((length - 1) // 8) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
+        return ({"samples": latent}, )
+
+
+def vae_encode_with_padding(vae, image, width, height, length, padding=0):
+    pixels = comfy.utils.common_upscale(image[..., :3].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
+    pixel_len = min(pixels.shape[0], length)
+    padded_length = min(length, (((pixel_len - 1) // 8) + 1 + padding) * 8 - 7)
+    padded_pixels = torch.ones((padded_length, height, width, 3)) * 0.5
+    padded_pixels[:pixel_len] = pixels[:pixel_len]
+    latent_len = ((pixel_len - 1) // 8) + 1
+    latent_temp = vae.encode(padded_pixels)
+    return latent_temp[:, :, :latent_len]
+
+
+class CosmosImageToVideoLatent:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"vae": ("VAE", ),
+                             "width": ("INT", {"default": 1280, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
+                             "height": ("INT", {"default": 704, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
+                             "length": ("INT", {"default": 121, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 8}),
+                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
+                            },
+                "optional": {"start_image": ("IMAGE", ),
+                             "end_image": ("IMAGE", ),
+                            }}
+
+    RETURN_TYPES = ("LATENT",)
+    FUNCTION = "encode"
+
+    CATEGORY = "conditioning/inpaint"
+
+    def encode(self, vae, width, height, length, batch_size, start_image=None, end_image=None):
+        latent = torch.zeros([1, 16, ((length - 1) // 8) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
+        if start_image is None and end_image is None:
+            out_latent = {}
+            out_latent["samples"] = latent
+            return (out_latent,)
+
+        mask = torch.ones([latent.shape[0], 1, ((length - 1) // 8) + 1, latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device())
+
+        if start_image is not None:
+            latent_temp = vae_encode_with_padding(vae, start_image, width, height, length, padding=1)
+            latent[:, :, :latent_temp.shape[-3]] = latent_temp
+            mask[:, :, :latent_temp.shape[-3]] *= 0.0
+
+        if end_image is not None:
+            latent_temp = vae_encode_with_padding(vae, end_image, width, height, length, padding=0)
+            latent[:, :, -latent_temp.shape[-3]:] = latent_temp
+            mask[:, :, -latent_temp.shape[-3]:] *= 0.0
+
+        out_latent = {}
+        out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1))
+        out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
+        return (out_latent,)
+
+
+NODE_CLASS_MAPPINGS = {
+    "EmptyCosmosLatentVideo": EmptyCosmosLatentVideo,
+    "CosmosImageToVideoLatent": CosmosImageToVideoLatent,
+}
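For orientation, the Cosmos VAE compresses time 8x with one extra leading frame, which is where the repeated `((length - 1) // 8) + 1` comes from. Illustrative arithmetic for the defaults above:

    length = 121
    latent_frames = ((length - 1) // 8) + 1  # 16 latent frames for 121 pixel frames
    # vae_encode_with_padding with a single start image and padding=1:
    pixel_len = 1
    padded_length = min(length, (((pixel_len - 1) // 8) + 1 + 1) * 8 - 7)  # = 9 pixel frames encoded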
@@ -231,6 +231,24 @@ class FlipSigmas:
             sigmas[0] = 0.0001
         return (sigmas,)

+class SetFirstSigma:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required":
+                    {"sigmas": ("SIGMAS", ),
+                     "sigma": ("FLOAT", {"default": 136.0, "min": 0.0, "max": 20000.0, "step": 0.001, "round": False}),
+                    }
+               }
+    RETURN_TYPES = ("SIGMAS",)
+    CATEGORY = "sampling/custom_sampling/sigmas"
+
+    FUNCTION = "set_first_sigma"
+
+    def set_first_sigma(self, sigmas, sigma):
+        sigmas = sigmas.clone()
+        sigmas[0] = sigma
+        return (sigmas, )
+
 class KSamplerSelect:
     @classmethod
     def INPUT_TYPES(s):
@@ -710,6 +728,7 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"SplitSigmas": SplitSigmas,
|
"SplitSigmas": SplitSigmas,
|
||||||
"SplitSigmasDenoise": SplitSigmasDenoise,
|
"SplitSigmasDenoise": SplitSigmasDenoise,
|
||||||
"FlipSigmas": FlipSigmas,
|
"FlipSigmas": FlipSigmas,
|
||||||
|
"SetFirstSigma": SetFirstSigma,
|
||||||
|
|
||||||
"CFGGuider": CFGGuider,
|
"CFGGuider": CFGGuider,
|
||||||
"DualCFGGuider": DualCFGGuider,
|
"DualCFGGuider": DualCFGGuider,
|
||||||
|
|||||||
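SetFirstSigma simply replaces the first (largest) value of a sigma schedule and leaves the rest untouched. A hypothetical use outside the node wrapper (the schedule values here are made up):

import torch

sigmas = torch.tensor([14.6, 7.1, 3.4, 1.2, 0.0])  # example schedule
patched = sigmas.clone()   # clone first, as the node does, to avoid mutating the input
patched[0] = 136.0         # the node's default replacement sigma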
@@ -38,7 +38,26 @@ class FluxGuidance:
         return (c, )


+class FluxDisableGuidance:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {
+                    "conditioning": ("CONDITIONING", ),
+                    }}
+
+    RETURN_TYPES = ("CONDITIONING",)
+    FUNCTION = "append"
+
+    CATEGORY = "advanced/conditioning/flux"
+    DESCRIPTION = "This node completely disables the guidance embed on Flux and Flux like models"
+
+    def append(self, conditioning):
+        c = node_helpers.conditioning_set_values(conditioning, {"guidance": None})
+        return (c, )
+
+
 NODE_CLASS_MAPPINGS = {
     "CLIPTextEncodeFlux": CLIPTextEncodeFlux,
     "FluxGuidance": FluxGuidance,
+    "FluxDisableGuidance": FluxDisableGuidance,
 }
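FluxGuidance writes a numeric "guidance" value into the conditioning; the new node writes None instead, which is what disables the guidance embed downstream. A simplified sketch of what a conditioning_set_values-style helper does (not the actual node_helpers implementation):

def conditioning_set_values_sketch(conditioning, values):
    # Conditioning is a list of [tensor, options_dict] pairs; copy each
    # options dict and overwrite the given keys on every entry.
    out = []
    for tensor, opts in conditioning:
        new_opts = dict(opts)
        new_opts.update(values)
        out.append([tensor, new_opts])
    return out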
@@ -246,7 +246,7 @@ class SetClipHooks:
     CATEGORY = "advanced/hooks/clip"
     FUNCTION = "apply_hooks"

-    def apply_hooks(self, clip: 'CLIP', schedule_clip: bool, apply_to_conds: bool, hooks: comfy.hooks.HookGroup=None):
+    def apply_hooks(self, clip: CLIP, schedule_clip: bool, apply_to_conds: bool, hooks: comfy.hooks.HookGroup=None):
         if hooks is not None:
             clip = clip.clone()
             if apply_to_conds:
@@ -255,7 +255,7 @@ class SetClipHooks:
             clip.use_clip_schedule = schedule_clip
             if not clip.use_clip_schedule:
                 clip.patcher.forced_hooks.set_keyframes_on_hooks(None)
-            clip.patcher.register_all_hook_patches(hooks.get_dict_repr(), comfy.hooks.EnumWeightTarget.Clip)
+            clip.patcher.register_all_hook_patches(hooks, comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Clip))
         return (clip,)

class ConditioningTimestepsRange:
@@ -2,10 +2,14 @@ import comfy.utils
 import comfy_extras.nodes_post_processing
 import torch

-def reshape_latent_to(target_shape, latent):
+def reshape_latent_to(target_shape, latent, repeat_batch=True):
     if latent.shape[1:] != target_shape[1:]:
-        latent = comfy.utils.common_upscale(latent, target_shape[3], target_shape[2], "bilinear", "center")
+        latent = comfy.utils.common_upscale(latent, target_shape[-1], target_shape[-2], "bilinear", "center")
-    return comfy.utils.repeat_to_batch_size(latent, target_shape[0])
+    if repeat_batch:
+        return comfy.utils.repeat_to_batch_size(latent, target_shape[0])
+    else:
+        return latent


 class LatentAdd:
@@ -116,8 +120,7 @@ class LatentBatch:
         s1 = samples1["samples"]
         s2 = samples2["samples"]

-        if s1.shape[1:] != s2.shape[1:]:
-            s2 = comfy.utils.common_upscale(s2, s1.shape[3], s1.shape[2], "bilinear", "center")
+        s2 = reshape_latent_to(s1.shape, s2, repeat_batch=False)

         s = torch.cat((s1, s2), dim=0)
         samples_out["samples"] = s
         samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])])
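Switching from `target_shape[3]`/`[2]` to negative indices is what lets `reshape_latent_to` serve both image and video latents. Example shapes only:

image_latent = (1, 4, 64, 64)        # NCHW:  [-1] = W = 64,  [-2] = H = 64
video_latent = (1, 16, 16, 88, 160)  # NCTHW: [-1] = W = 160, [-2] = H = 88
# With hard-coded [3]/[2], the 5-D case would have read W=88 and H=16.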
@@ -19,13 +19,11 @@ class Load3D():
             "image": ("LOAD_3D", {}),
             "width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
             "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
-            "show_grid": ([True, False],),
-            "camera_type": (["perspective", "orthographic"],),
-            "view": (["front", "right", "top", "isometric"],),
             "material": (["original", "normal", "wireframe", "depth"],),
             "bg_color": ("STRING", {"default": "#000000", "multiline": False}),
             "light_intensity": ("INT", {"default": 10, "min": 1, "max": 20, "step": 1}),
             "up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
+            "fov": ("INT", {"default": 75, "min": 10, "max": 150, "step": 1}),
         }}

     RETURN_TYPES = ("IMAGE", "MASK", "STRING")
@@ -37,13 +35,22 @@ class Load3D():
     CATEGORY = "3d"

     def process(self, model_file, image, **kwargs):
-        imagepath = folder_paths.get_annotated_filepath(image)
-
-        load_image_node = nodes.LoadImage()
-
-        output_image, output_mask = load_image_node.load_image(image=imagepath)
-
-        return output_image, output_mask, model_file,
+        if isinstance(image, dict):
+            image_path = folder_paths.get_annotated_filepath(image['image'])
+            mask_path = folder_paths.get_annotated_filepath(image['mask'])
+
+            load_image_node = nodes.LoadImage()
+            output_image, ignore_mask = load_image_node.load_image(image=image_path)
+            ignore_image, output_mask = load_image_node.load_image(image=mask_path)
+
+            return output_image, output_mask, model_file,
+        else:
+            # Fallback for a non-dict "image": the frontend may not yet be
+            # compatible with the new format; this double-check can be removed
+            # once the FE changes are merged into core.
+            image_path = folder_paths.get_annotated_filepath(image)
+            load_image_node = nodes.LoadImage()
+            output_image, output_mask = load_image_node.load_image(image=image_path)
+            return output_image, output_mask, model_file,

class Load3DAnimation():
    @classmethod
@@ -59,14 +66,12 @@ class Load3DAnimation():
             "image": ("LOAD_3D_ANIMATION", {}),
             "width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
             "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
-            "show_grid": ([True, False],),
-            "camera_type": (["perspective", "orthographic"],),
-            "view": (["front", "right", "top", "isometric"],),
             "material": (["original", "normal", "wireframe", "depth"],),
             "bg_color": ("STRING", {"default": "#000000", "multiline": False}),
             "light_intensity": ("INT", {"default": 10, "min": 1, "max": 20, "step": 1}),
             "up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
             "animation_speed": (["0.1", "0.5", "1", "1.5", "2"], {"default": "1"}),
+            "fov": ("INT", {"default": 75, "min": 10, "max": 150, "step": 1}),
         }}

     RETURN_TYPES = ("IMAGE", "MASK", "STRING")
@@ -78,26 +83,31 @@ class Load3DAnimation():
     CATEGORY = "3d"

     def process(self, model_file, image, **kwargs):
-        imagepath = folder_paths.get_annotated_filepath(image)
-
-        load_image_node = nodes.LoadImage()
-
-        output_image, output_mask = load_image_node.load_image(image=imagepath)
-
-        return output_image, output_mask, model_file,
+        if isinstance(image, dict):
+            image_path = folder_paths.get_annotated_filepath(image['image'])
+            mask_path = folder_paths.get_annotated_filepath(image['mask'])
+
+            load_image_node = nodes.LoadImage()
+            output_image, ignore_mask = load_image_node.load_image(image=image_path)
+            ignore_image, output_mask = load_image_node.load_image(image=mask_path)
+
+            return output_image, output_mask, model_file,
+        else:
+            image_path = folder_paths.get_annotated_filepath(image)
+            load_image_node = nodes.LoadImage()
+            output_image, output_mask = load_image_node.load_image(image=image_path)
+            return output_image, output_mask, model_file,

class Preview3D():
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
             "model_file": ("STRING", {"default": "", "multiline": False}),
-            "show_grid": ([True, False],),
-            "camera_type": (["perspective", "orthographic"],),
-            "view": (["front", "right", "top", "isometric"],),
             "material": (["original", "normal", "wireframe", "depth"],),
             "bg_color": ("STRING", {"default": "#000000", "multiline": False}),
             "light_intensity": ("INT", {"default": 10, "min": 1, "max": 20, "step": 1}),
             "up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
+            "fov": ("INT", {"default": 75, "min": 10, "max": 150, "step": 1}),
         }}

    OUTPUT_NODE = True
@@ -189,7 +189,7 @@ class ModelSamplingContinuousEDM:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "model": ("MODEL",),
-                             "sampling": (["v_prediction", "edm_playground_v2.5", "eps"],),
+                             "sampling": (["v_prediction", "edm", "edm_playground_v2.5", "eps"],),
                              "sigma_max": ("FLOAT", {"default": 120.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
                              "sigma_min": ("FLOAT", {"default": 0.002, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
                            }}
@@ -206,6 +206,9 @@ class ModelSamplingContinuousEDM:
        sigma_data = 1.0
        if sampling == "eps":
            sampling_type = comfy.model_sampling.EPS
+       elif sampling == "edm":
+           sampling_type = comfy.model_sampling.EDM
+           sigma_data = 0.5
        elif sampling == "v_prediction":
            sampling_type = comfy.model_sampling.V_PREDICTION
        elif sampling == "edm_playground_v2.5":
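For reference, sigma_data feeds the EDM preconditioning of Karras et al. (2022); a sketch of the standard scalings, not ComfyUI's own model_sampling code:

import math

def edm_scalings(sigma: float, sigma_data: float = 0.5):
    # sigma_data = 0.5 is what the new "edm" branch sets above.
    c_skip = sigma_data ** 2 / (sigma ** 2 + sigma_data ** 2)
    c_out = sigma * sigma_data / math.sqrt(sigma ** 2 + sigma_data ** 2)
    c_in = 1.0 / math.sqrt(sigma ** 2 + sigma_data ** 2)
    return c_skip, c_out, c_in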
@@ -42,7 +42,6 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
     gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather

     with torch.no_grad():
-
         hsy, wsx = h // sy, w // sx

         # For each sy by sx kernel, randomly assign one token to be dst and the rest src
3  comfyui_version.py  Normal file
@@ -0,0 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.3.12"
@@ -93,7 +93,7 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
     missing_keys = {}
     for x in inputs:
         input_data = inputs[x]
-        input_type, input_category, input_info = get_input_info(class_def, x)
+        input_type, input_category, input_info = get_input_info(class_def, x, valid_inputs)
         def mark_missing():
             missing_keys[x] = True
             input_data_all[x] = (None,)
@@ -196,7 +196,6 @@ def merge_result_data(results, obj):
     return output

def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb=None):
-
    results = []
    uis = []
    subgraph_results = []
@@ -556,7 +555,7 @@ def validate_inputs(prompt, item, validated):
    received_types = {}

    for x in valid_inputs:
-       type_input, input_category, extra_info = get_input_info(obj_class, x)
+       type_input, input_category, extra_info = get_input_info(obj_class, x, class_inputs)
        assert extra_info is not None
        if x not in inputs:
            if input_category == "required":
@@ -11,7 +11,8 @@ supported_pt_extensions: set[str] = {'.ckpt', '.pt', '.bin', '.pth', '.safetenso

folder_names_and_paths: dict[str, tuple[list[str], set[str]]] = {}

-base_path = os.path.dirname(os.path.realpath(__file__))
+env_base_path = os.environ.get("COMFYUI_FOLDERS_BASE_PATH")
+base_path = os.path.dirname(os.path.realpath(__file__)) if env_base_path is None else env_base_path
models_dir = os.path.join(base_path, "models")
folder_names_and_paths["checkpoints"] = ([os.path.join(models_dir, "checkpoints")], supported_pt_extensions)
folder_names_and_paths["configs"] = ([os.path.join(models_dir, "configs")], [".yaml"])
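The new environment variable relocates the whole folder tree without an extra YAML config. A hypothetical use (the path is an example); note it must be set before folder_paths is first imported, since base_path is computed at import time:

import os
os.environ["COMFYUI_FOLDERS_BASE_PATH"] = "/data/comfyui"  # example location

import folder_paths  # models_dir now resolves to /data/comfyui/models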
6  main.py
@@ -17,7 +17,7 @@ if __name__ == "__main__":
    os.environ['DO_NOT_TRACK'] = '1'


-   setup_logger(log_level=args.verbose)
+   setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)

def apply_custom_paths():
    # extra model paths
@@ -211,7 +211,9 @@ async def run(server_instance, address='', port=8188, verbose=True, call_on_star
    addresses = []
    for addr in address.split(","):
        addresses.append((addr, port))
-   await asyncio.gather(server_instance.start_multi_address(addresses, call_on_start), server_instance.publish_loop())
+   await asyncio.gather(
+       server_instance.start_multi_address(addresses, call_on_start, verbose), server_instance.publish_loop()
+   )


def hijack_progress(server_instance):
36  nodes.py
@@ -309,7 +309,7 @@ class VAEDecodeTiled:
        temporal_compression = vae.temporal_compression_decode()
        if temporal_compression is not None:
            temporal_size = max(2, temporal_size // temporal_compression)
-           temporal_overlap = min(1, temporal_size // 2, temporal_overlap // temporal_compression)
+           temporal_overlap = max(1, min(temporal_size // 2, temporal_overlap // temporal_compression))
        else:
            temporal_size = None
            temporal_overlap = None
@@ -912,16 +912,19 @@ class CLIPLoader:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
-                             "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart"], ),
+                             "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos"], ),
+                             },
+               "optional": {
+                             "device": (["default", "cpu"], {"advanced": True}),
                            }}
    RETURN_TYPES = ("CLIP",)
    FUNCTION = "load_clip"

    CATEGORY = "advanced/loaders"

-   DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 / clip-g / clip-l\nstable_audio: t5\nmochi: t5"
+   DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 / clip-g / clip-l\nstable_audio: t5\nmochi: t5\ncosmos: old t5 xxl"

-   def load_clip(self, clip_name, type="stable_diffusion"):
+   def load_clip(self, clip_name, type="stable_diffusion", device="default"):
        if type == "stable_cascade":
            clip_type = comfy.sd.CLIPType.STABLE_CASCADE
        elif type == "sd3":
@@ -934,11 +937,17 @@ class CLIPLoader:
            clip_type = comfy.sd.CLIPType.LTXV
        elif type == "pixart":
            clip_type = comfy.sd.CLIPType.PIXART
+       elif type == "cosmos":
+           clip_type = comfy.sd.CLIPType.COSMOS
        else:
            clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION

+       model_options = {}
+       if device == "cpu":
+           model_options["load_device"] = model_options["offload_device"] = torch.device("cpu")
+
        clip_path = folder_paths.get_full_path_or_raise("text_encoders", clip_name)
-       clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
+       clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options)
        return (clip,)

class DualCLIPLoader:
@@ -947,6 +956,9 @@ class DualCLIPLoader:
        return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
                              "clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
                              "type": (["sdxl", "sd3", "flux", "hunyuan_video"], ),
+                             },
+               "optional": {
+                             "device": (["default", "cpu"], {"advanced": True}),
                            }}
    RETURN_TYPES = ("CLIP",)
    FUNCTION = "load_clip"
@@ -955,7 +967,7 @@ class DualCLIPLoader:

    DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5"

-   def load_clip(self, clip_name1, clip_name2, type):
+   def load_clip(self, clip_name1, clip_name2, type, device="default"):
        clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
        clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
        if type == "sdxl":
@@ -967,7 +979,11 @@ class DualCLIPLoader:
        elif type == "hunyuan_video":
            clip_type = comfy.sd.CLIPType.HUNYUAN_VIDEO

-       clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
+       model_options = {}
+       if device == "cpu":
+           model_options["load_device"] = model_options["offload_device"] = torch.device("cpu")
+
+       clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options)
        return (clip,)

class CLIPVisionLoader:
@@ -2047,6 +2063,9 @@ NODE_DISPLAY_NAME_MAPPINGS = {

EXTENSION_WEB_DIRS = {}

+# Dictionary of successfully loaded module names and associated directories.
+LOADED_MODULE_DIRS = {}
+

def get_module_name(module_path: str) -> str:
    """
@@ -2088,6 +2107,8 @@ def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes
        sys.modules[module_name] = module
        module_spec.loader.exec_module(module)

+       LOADED_MODULE_DIRS[module_name] = os.path.abspath(module_dir)
+
        if hasattr(module, "WEB_DIRECTORY") and getattr(module, "WEB_DIRECTORY") is not None:
            web_dir = os.path.abspath(os.path.join(module_dir, getattr(module, "WEB_DIRECTORY")))
            if os.path.isdir(web_dir):
@@ -2206,6 +2227,7 @@ def init_builtin_extra_nodes():
        "nodes_lt.py",
        "nodes_hooks.py",
        "nodes_load_3d.py",
+       "nodes_cosmos.py",
    ]

    import_failed = []
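The optional device input threads a model_options dict through to comfy.sd.load_clip, pinning both the load and offload device. A hypothetical call (the checkpoint name is made up):

loader = CLIPLoader()
(clip,) = loader.load_clip("t5xxl.safetensors", type="cosmos", device="cpu")
# With device="cpu" the text encoder stays on the CPU for its whole
# lifetime, trading encode speed for freed-up VRAM.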
23  pyproject.toml  Normal file
@@ -0,0 +1,23 @@
[project]
name = "ComfyUI"
version = "0.3.12"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"

[project.urls]
homepage = "https://www.comfy.org/"
repository = "https://github.com/comfyanonymous/ComfyUI"
documentation = "https://docs.comfy.org/"

[tool.ruff]
lint.select = [
  "S307", # suspicious-eval-usage
  "S102", # exec
  "T",    # print-usage
  "W",
  # The "F" series in Ruff stands for "Pyflakes" rules, which catch various Python syntax errors and undefined names.
  # See all rules here: https://docs.astral.sh/ruff/rules/#pyflakes-f
  "F",
]
exclude = ["*.ipynb"]
@@ -2,6 +2,7 @@ torch
 torchsde
 torchvision
 torchaudio
+numpy>=1.25.0
 einops
 transformers>=4.28.1
 tokenizers>=0.13.3
13  ruff.toml
@@ -1,13 +0,0 @@
-# Disable all rules by default
-lint.ignore = ["ALL"]
-
-# Enable specific rules
-lint.select = [
-  "S307", # suspicious-eval-usage
-  "T201", # print-usage
-  # The "F" series in Ruff stands for "Pyflakes" rules, which catch various Python syntax errors and undefined names.
-  # See all rules here: https://docs.astral.sh/ruff/rules/#pyflakes-f
-  "F",
-]
-
-exclude = ["*.ipynb"]
36  server.py
@@ -27,9 +27,11 @@ from comfy.cli_args import args
import comfy.utils
import comfy.model_management
import node_helpers
+from comfyui_version import __version__
from app.frontend_management import FrontendManager
from app.user_manager import UserManager
from app.model_manager import ModelFileManager
+from app.custom_node_manager import CustomNodeManager
from typing import Optional
from api_server.routes.internal.internal_routes import InternalRoutes

@@ -43,21 +45,6 @@ async def send_socket_catch_exception(function, message):
    except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError, BrokenPipeError, ConnectionError) as err:
        logging.warning("send error: {}".format(err))

-def get_comfyui_version():
-    comfyui_version = "unknown"
-    repo_path = os.path.dirname(os.path.realpath(__file__))
-    try:
-        import pygit2
-        repo = pygit2.Repository(repo_path)
-        comfyui_version = repo.describe(describe_strategy=pygit2.GIT_DESCRIBE_TAGS)
-    except Exception:
-        try:
-            import subprocess
-            comfyui_version = subprocess.check_output(["git", "describe", "--tags"], cwd=repo_path).decode('utf-8')
-        except Exception as e:
-            logging.warning(f"Failed to get ComfyUI version: {e}")
-    return comfyui_version.strip()
-
@web.middleware
async def cache_control(request: web.Request, handler):
    response: web.Response = await handler(request)
@@ -153,6 +140,7 @@ class PromptServer():

        self.user_manager = UserManager()
        self.model_file_manager = ModelFileManager()
+       self.custom_node_manager = CustomNodeManager()
        self.internal_routes = InternalRoutes(self)
        self.supports = ["custom_nodes_from_web"]
        self.prompt_queue = None
@@ -341,6 +329,9 @@ class PromptServer():
            original_ref = json.loads(post.get("original_ref"))
            filename, output_dir = folder_paths.annotated_filepath(original_ref['filename'])

+           if not filename:
+               return web.Response(status=400)
+
            # validation for security: prevent accessing arbitrary path
            if filename[0] == '/' or '..' in filename:
                return web.Response(status=400)
@@ -382,6 +373,9 @@ class PromptServer():
            filename = request.rel_url.query["filename"]
            filename,output_dir = folder_paths.annotated_filepath(filename)

+           if not filename:
+               return web.Response(status=400)
+
            # validation for security: prevent accessing arbitrary path
            if filename[0] == '/' or '..' in filename:
                return web.Response(status=400)
@@ -516,7 +510,7 @@ class PromptServer():
                "os": os.name,
                "ram_total": ram_total,
                "ram_free": ram_free,
-               "comfyui_version": get_comfyui_version(),
+               "comfyui_version": __version__,
                "python_version": sys.version,
                "pytorch_version": comfy.model_management.torch_version,
                "embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded",
@@ -697,6 +691,7 @@ class PromptServer():
    def add_routes(self):
        self.user_manager.add_routes(self.routes)
        self.model_file_manager.add_routes(self.routes)
+       self.custom_node_manager.add_routes(self.routes, self.app, nodes.LOADED_MODULE_DIRS.items())
        self.app.add_subapp('/internal', self.internal_routes.get_app())

        # Prefix every route with /api for easier matching for delegation.
@@ -713,6 +708,7 @@ class PromptServer():
        self.app.add_routes(api_routes)
        self.app.add_routes(self.routes)

+       # Add routes from web extensions.
        for name, dir in nodes.EXTENSION_WEB_DIRS.items():
            self.app.add_routes([web.static('/extensions/' + name, dir)])

@@ -803,7 +799,7 @@ class PromptServer():
    async def start(self, address, port, verbose=True, call_on_start=None):
        await self.start_multi_address([(address, port)], call_on_start=call_on_start)

-   async def start_multi_address(self, addresses, call_on_start=None):
+   async def start_multi_address(self, addresses, call_on_start=None, verbose=True):
        runner = web.AppRunner(self.app, access_log=None)
        await runner.setup()
        ssl_ctx = None
@@ -814,7 +810,8 @@ class PromptServer():
                                    keyfile=args.tls_keyfile)
            scheme = "https"

-       logging.info("Starting server\n")
+       if verbose:
+           logging.info("Starting server\n")
        for addr in addresses:
            address = addr[0]
            port = addr[1]
@@ -830,7 +827,8 @@ class PromptServer():
        else:
            address_print = address

-       logging.info("To see the GUI go to: {}://{}:{}".format(scheme, address_print, port))
+       if verbose:
+           logging.info("To see the GUI go to: {}://{}:{}".format(scheme, address_print, port))

        if call_on_start is not None:
            call_on_start(scheme, self.address, self.port)
147  tests-unit/app_test/custom_node_manager_test.py  Normal file
@@ -0,0 +1,147 @@
import pytest
from aiohttp import web
from unittest.mock import patch
from app.custom_node_manager import CustomNodeManager
import json

pytestmark = (
    pytest.mark.asyncio
)  # This applies the asyncio mark to all test functions in the module


@pytest.fixture
def custom_node_manager():
    return CustomNodeManager()


@pytest.fixture
def app(custom_node_manager):
    app = web.Application()
    routes = web.RouteTableDef()
    custom_node_manager.add_routes(
        routes, app, [("ComfyUI-TestExtension1", "ComfyUI-TestExtension1")]
    )
    app.add_routes(routes)
    return app


async def test_get_workflow_templates(aiohttp_client, app, tmp_path):
    client = await aiohttp_client(app)
    # Setup temporary custom nodes file structure with 1 workflow file
    custom_nodes_dir = tmp_path / "custom_nodes"
    example_workflows_dir = (
        custom_nodes_dir / "ComfyUI-TestExtension1" / "example_workflows"
    )
    example_workflows_dir.mkdir(parents=True)
    template_file = example_workflows_dir / "workflow1.json"
    template_file.write_text("")

    with patch(
        "folder_paths.folder_names_and_paths",
        {"custom_nodes": ([str(custom_nodes_dir)], None)},
    ):
        response = await client.get("/workflow_templates")
        assert response.status == 200
        workflows_dict = await response.json()
        assert isinstance(workflows_dict, dict)
        assert "ComfyUI-TestExtension1" in workflows_dict
        assert isinstance(workflows_dict["ComfyUI-TestExtension1"], list)
        assert workflows_dict["ComfyUI-TestExtension1"][0] == "workflow1"


async def test_build_translations_empty_when_no_locales(custom_node_manager, tmp_path):
    custom_nodes_dir = tmp_path / "custom_nodes"
    custom_nodes_dir.mkdir(parents=True)

    with patch("folder_paths.get_folder_paths", return_value=[str(custom_nodes_dir)]):
        translations = custom_node_manager.build_translations()
        assert translations == {}


async def test_build_translations_loads_all_files(custom_node_manager, tmp_path):
    # Setup test directory structure
    custom_nodes_dir = tmp_path / "custom_nodes" / "test-extension"
    locales_dir = custom_nodes_dir / "locales" / "en"
    locales_dir.mkdir(parents=True)

    # Create test translation files
    main_content = {"title": "Test Extension"}
    (locales_dir / "main.json").write_text(json.dumps(main_content))

    node_defs = {"node1": "Node 1"}
    (locales_dir / "nodeDefs.json").write_text(json.dumps(node_defs))

    commands = {"cmd1": "Command 1"}
    (locales_dir / "commands.json").write_text(json.dumps(commands))

    settings = {"setting1": "Setting 1"}
    (locales_dir / "settings.json").write_text(json.dumps(settings))

    with patch(
        "folder_paths.get_folder_paths", return_value=[tmp_path / "custom_nodes"]
    ):
        translations = custom_node_manager.build_translations()

        assert translations == {
            "en": {
                "title": "Test Extension",
                "nodeDefs": {"node1": "Node 1"},
                "commands": {"cmd1": "Command 1"},
                "settings": {"setting1": "Setting 1"},
            }
        }


async def test_build_translations_handles_invalid_json(custom_node_manager, tmp_path):
    # Setup test directory structure
    custom_nodes_dir = tmp_path / "custom_nodes" / "test-extension"
    locales_dir = custom_nodes_dir / "locales" / "en"
    locales_dir.mkdir(parents=True)

    # Create valid main.json
    main_content = {"title": "Test Extension"}
    (locales_dir / "main.json").write_text(json.dumps(main_content))

    # Create invalid JSON file
    (locales_dir / "nodeDefs.json").write_text("invalid json{")

    with patch(
        "folder_paths.get_folder_paths", return_value=[tmp_path / "custom_nodes"]
    ):
        translations = custom_node_manager.build_translations()

        assert translations == {
            "en": {
                "title": "Test Extension",
            }
        }


async def test_build_translations_merges_multiple_extensions(
    custom_node_manager, tmp_path
):
    # Setup test directory structure for two extensions
    custom_nodes_dir = tmp_path / "custom_nodes"
    ext1_dir = custom_nodes_dir / "extension1" / "locales" / "en"
    ext2_dir = custom_nodes_dir / "extension2" / "locales" / "en"
    ext1_dir.mkdir(parents=True)
    ext2_dir.mkdir(parents=True)

    # Create translation files for extension 1
    ext1_main = {"title": "Extension 1", "shared": "Original"}
    (ext1_dir / "main.json").write_text(json.dumps(ext1_main))

    # Create translation files for extension 2
    ext2_main = {"description": "Extension 2", "shared": "Override"}
    (ext2_dir / "main.json").write_text(json.dumps(ext2_main))

    with patch("folder_paths.get_folder_paths", return_value=[str(custom_nodes_dir)]):
        translations = custom_node_manager.build_translations()

        assert translations == {
            "en": {
                "title": "Extension 1",
                "description": "Extension 2",
                "shared": "Override",  # Second extension should override first
            }
        }
@@ -1,11 +1,22 @@
 import pytest
 import yaml
 import os
+import sys
 from unittest.mock import Mock, patch, mock_open
+
 from utils.extra_config import load_extra_path_config
 import folder_paths
+
+
+@pytest.fixture()
+def clear_folder_paths():
+    # Clear the global dictionary before each test to ensure isolation
+    original = folder_paths.folder_names_and_paths.copy()
+    folder_paths.folder_names_and_paths.clear()
+    yield
+    folder_paths.folder_names_and_paths = original
+
+
 @pytest.fixture
 def mock_yaml_content():
     return {
@@ -15,10 +26,12 @@ def mock_yaml_content():
         }
     }

+
 @pytest.fixture
 def mock_expanded_home():
     return '/home/user'

+
 @pytest.fixture
 def yaml_config_with_appdata():
     return """
@@ -27,20 +40,33 @@ def yaml_config_with_appdata():
         checkpoints: 'models/checkpoints'
     """

+
 @pytest.fixture
 def mock_yaml_content_appdata(yaml_config_with_appdata):
     return yaml.safe_load(yaml_config_with_appdata)

+
 @pytest.fixture
 def mock_expandvars_appdata():
     mock = Mock()
-    mock.side_effect = lambda path: path.replace('%APPDATA%', 'C:/Users/TestUser/AppData/Roaming')
+
+    def expandvars(path):
+        if '%APPDATA%' in path:
+            if sys.platform == 'win32':
+                return path.replace('%APPDATA%', 'C:/Users/TestUser/AppData/Roaming')
+            else:
+                return path.replace('%APPDATA%', '/Users/TestUser/AppData/Roaming')
+        return path
+
+    mock.side_effect = expandvars
     return mock

+
 @pytest.fixture
 def mock_add_model_folder_path():
     return Mock()

+
 @pytest.fixture
 def mock_expanduser(mock_expanded_home):
     def _expanduser(path):
@@ -49,10 +75,12 @@ def mock_expanduser(mock_expanded_home):
         return path
     return _expanduser

+
 @pytest.fixture
 def mock_yaml_safe_load(mock_yaml_content):
     return Mock(return_value=mock_yaml_content)

+
 @patch('builtins.open', new_callable=mock_open, read_data="dummy file content")
 def test_load_extra_model_paths_expands_userpath(
     mock_file,
@@ -88,6 +116,7 @@ def test_load_extra_model_paths_expands_userpath(
     # Check if open was called with the correct file path
     mock_file.assert_called_once_with(dummy_yaml_file_name, 'r')

+
 @patch('builtins.open', new_callable=mock_open)
 def test_load_extra_model_paths_expands_appdata(
     mock_file,
@@ -111,7 +140,10 @@ def test_load_extra_model_paths_expands_appdata(
     dummy_yaml_file_name = 'dummy_path.yaml'
     load_extra_path_config(dummy_yaml_file_name)

-    expected_base_path = 'C:/Users/TestUser/AppData/Roaming/ComfyUI'
+    if sys.platform == "win32":
+        expected_base_path = 'C:/Users/TestUser/AppData/Roaming/ComfyUI'
+    else:
+        expected_base_path = '/Users/TestUser/AppData/Roaming/ComfyUI'
     expected_calls = [
         ('checkpoints', os.path.join(expected_base_path, 'models/checkpoints'), False),
     ]
@@ -124,3 +156,148 @@ def test_load_extra_model_paths_expands_appdata(

     # Verify that expandvars was called
     assert mock_expandvars_appdata.called
+
+
+@patch("builtins.open", new_callable=mock_open, read_data="dummy yaml content")
+@patch("yaml.safe_load")
+def test_load_extra_path_config_relative_base_path(
+    mock_yaml_load, _mock_file, clear_folder_paths, monkeypatch, tmp_path
+):
+    """
+    Test that when 'base_path' is a relative path in the YAML, it is joined to the YAML file directory, and then
+    the items in the config are correctly converted to absolute paths.
+    """
+    sub_folder = "./my_rel_base"
+    config_data = {
+        "some_model_folder": {
+            "base_path": sub_folder,
+            "is_default": True,
+            "checkpoints": "checkpoints",
+            "some_key": "some_value"
+        }
+    }
+    mock_yaml_load.return_value = config_data
+
+    dummy_yaml_name = "dummy_file.yaml"
+
+    def fake_abspath(path):
+        if path == dummy_yaml_name:
+            # If it's the YAML path, treat it like it lives in tmp_path
+            return os.path.join(str(tmp_path), dummy_yaml_name)
+        return os.path.join(str(tmp_path), path)  # Otherwise, do a normal join relative to tmp_path
+
+    def fake_dirname(path):
+        # We expect path to be the result of fake_abspath(dummy_yaml_name)
+        if path.endswith(dummy_yaml_name):
+            return str(tmp_path)
+        return os.path.dirname(path)
+
+    monkeypatch.setattr(os.path, "abspath", fake_abspath)
+    monkeypatch.setattr(os.path, "dirname", fake_dirname)
+
+    load_extra_path_config(dummy_yaml_name)
+
+    expected_checkpoints = os.path.abspath(os.path.join(str(tmp_path), sub_folder, "checkpoints"))
+    expected_some_value = os.path.abspath(os.path.join(str(tmp_path), sub_folder, "some_value"))
+
+    actual_paths = folder_paths.folder_names_and_paths["checkpoints"][0]
+    assert len(actual_paths) == 1, "Should have one path added for 'checkpoints'."
+    assert actual_paths[0] == expected_checkpoints
+
+    actual_paths = folder_paths.folder_names_and_paths["some_key"][0]
+    assert len(actual_paths) == 1, "Should have one path added for 'some_key'."
+    assert actual_paths[0] == expected_some_value
+
+
+@patch("builtins.open", new_callable=mock_open, read_data="dummy yaml content")
+@patch("yaml.safe_load")
+def test_load_extra_path_config_absolute_base_path(
+    mock_yaml_load, _mock_file, clear_folder_paths, monkeypatch, tmp_path
+):
+    """
+    Test that when 'base_path' is an absolute path, each subdirectory is joined with that absolute path,
+    rather than being relative to the YAML's directory.
+    """
+    abs_base = os.path.join(str(tmp_path), "abs_base")
+    config_data = {
+        "some_absolute_folder": {
+            "base_path": abs_base,  # <-- absolute
+            "is_default": True,
+            "loras": "loras_folder",
+            "embeddings": "embeddings_folder"
+        }
+    }
+    mock_yaml_load.return_value = config_data
+
+    dummy_yaml_name = "dummy_abs.yaml"
+
+    def fake_abspath(path):
+        if path == dummy_yaml_name:
+            # If it's the YAML path, treat it like it is in tmp_path
+            return os.path.join(str(tmp_path), dummy_yaml_name)
+        return path  # For absolute base, we just return path directly
+
+    def fake_dirname(path):
+        return str(tmp_path) if path.endswith(dummy_yaml_name) else os.path.dirname(path)
+
+    monkeypatch.setattr(os.path, "abspath", fake_abspath)
+    monkeypatch.setattr(os.path, "dirname", fake_dirname)
+
+    load_extra_path_config(dummy_yaml_name)
+
+    # Expect the final paths to be <abs_base>/loras_folder and <abs_base>/embeddings_folder
+    expected_loras = os.path.join(abs_base, "loras_folder")
+    expected_embeddings = os.path.join(abs_base, "embeddings_folder")
+
+    actual_loras = folder_paths.folder_names_and_paths["loras"][0]
+    assert len(actual_loras) == 1, "Should have one path for 'loras'."
+    assert actual_loras[0] == os.path.abspath(expected_loras)
+
+    actual_embeddings = folder_paths.folder_names_and_paths["embeddings"][0]
+    assert len(actual_embeddings) == 1, "Should have one path for 'embeddings'."
+    assert actual_embeddings[0] == os.path.abspath(expected_embeddings)
+
+
+@patch("builtins.open", new_callable=mock_open, read_data="dummy yaml content")
+@patch("yaml.safe_load")
+def test_load_extra_path_config_no_base_path(
+    mock_yaml_load, _mock_file, clear_folder_paths, monkeypatch, tmp_path
+):
+    """
+    Test that if 'base_path' is not present, each path is joined
+    with the directory of the YAML file (unless it's already absolute).
+    """
+    config_data = {
+        "some_folder_without_base": {
+            "is_default": True,
+            "text_encoders": "clip",
+            "diffusion_models": "unet"
+        }
+    }
+    mock_yaml_load.return_value = config_data
+
+    dummy_yaml_name = "dummy_no_base.yaml"
+
+    def fake_abspath(path):
+        if path == dummy_yaml_name:
+            return os.path.join(str(tmp_path), dummy_yaml_name)
+        return os.path.join(str(tmp_path), path)
+
+    def fake_dirname(path):
+        return str(tmp_path) if path.endswith(dummy_yaml_name) else os.path.dirname(path)
+
+    monkeypatch.setattr(os.path, "abspath", fake_abspath)
+    monkeypatch.setattr(os.path, "dirname", fake_dirname)
+
+    load_extra_path_config(dummy_yaml_name)
+
+    expected_clip = os.path.join(str(tmp_path), "clip")
+    expected_unet = os.path.join(str(tmp_path), "unet")
+
+    actual_text_encoders = folder_paths.folder_names_and_paths["text_encoders"][0]
+    assert len(actual_text_encoders) == 1, "Should have one path for 'text_encoders'."
+    assert actual_text_encoders[0] == os.path.abspath(expected_clip)
+
+    actual_diffusion = folder_paths.folder_names_and_paths["diffusion_models"][0]
+    assert len(actual_diffusion) == 1, "Should have one path for 'diffusion_models'."
+    assert actual_diffusion[0] == os.path.abspath(expected_unet)
71  tests-unit/utils/json_util_test.py  Normal file
@@ -0,0 +1,71 @@
from utils.json_util import merge_json_recursive


def test_merge_simple_dicts():
    base = {"a": 1, "b": 2}
    update = {"b": 3, "c": 4}
    expected = {"a": 1, "b": 3, "c": 4}
    assert merge_json_recursive(base, update) == expected


def test_merge_nested_dicts():
    base = {"a": {"x": 1, "y": 2}, "b": 3}
    update = {"a": {"y": 4, "z": 5}}
    expected = {"a": {"x": 1, "y": 4, "z": 5}, "b": 3}
    assert merge_json_recursive(base, update) == expected


def test_merge_lists():
    base = {"a": [1, 2], "b": 3}
    update = {"a": [3, 4]}
    expected = {"a": [1, 2, 3, 4], "b": 3}
    assert merge_json_recursive(base, update) == expected


def test_merge_nested_lists():
    base = {"a": {"x": [1, 2]}}
    update = {"a": {"x": [3, 4]}}
    expected = {"a": {"x": [1, 2, 3, 4]}}
    assert merge_json_recursive(base, update) == expected


def test_merge_mixed_types():
    base = {"a": [1, 2], "b": {"x": 1}}
    update = {"a": [3], "b": {"y": 2}}
    expected = {"a": [1, 2, 3], "b": {"x": 1, "y": 2}}
    assert merge_json_recursive(base, update) == expected


def test_merge_overwrite_non_dict():
    base = {"a": 1}
    update = {"a": {"x": 2}}
    expected = {"a": {"x": 2}}
    assert merge_json_recursive(base, update) == expected


def test_merge_empty_dicts():
    base = {}
    update = {"a": 1}
    expected = {"a": 1}
    assert merge_json_recursive(base, update) == expected


def test_merge_none_values():
    base = {"a": None}
    update = {"a": {"x": 1}}
    expected = {"a": {"x": 1}}
    assert merge_json_recursive(base, update) == expected


def test_merge_different_types():
    base = {"a": [1, 2]}
    update = {"a": "string"}
    expected = {"a": "string"}
    assert merge_json_recursive(base, update) == expected


def test_merge_complex_nested():
    base = {"a": [1, 2], "b": {"x": [3, 4], "y": {"p": 1}}}
    update = {"a": [5], "b": {"x": [6], "y": {"q": 2}}}
    expected = {"a": [1, 2, 5], "b": {"x": [3, 4, 6], "y": {"p": 1, "q": 2}}}
    assert merge_json_recursive(base, update) == expected
@@ -6,6 +6,7 @@ import logging
 def load_extra_path_config(yaml_path):
     with open(yaml_path, 'r') as stream:
         config = yaml.safe_load(stream)
+    yaml_dir = os.path.dirname(os.path.abspath(yaml_path))
     for c in config:
         conf = config[c]
         if conf is None:
@@ -14,6 +15,8 @@ def load_extra_path_config(yaml_path):
         if "base_path" in conf:
             base_path = conf.pop("base_path")
             base_path = os.path.expandvars(os.path.expanduser(base_path))
+            if not os.path.isabs(base_path):
+                base_path = os.path.abspath(os.path.join(yaml_dir, base_path))
         is_default = False
         if "is_default" in conf:
             is_default = conf.pop("is_default")
@@ -22,10 +25,9 @@ def load_extra_path_config(yaml_path):
                 if len(y) == 0:
                     continue
                 full_path = y
-                if base_path is not None:
+                if base_path:
                     full_path = os.path.join(base_path, full_path)
                 elif not os.path.isabs(full_path):
-                    yaml_dir = os.path.dirname(os.path.abspath(yaml_path))
                     full_path = os.path.abspath(os.path.join(yaml_dir, y))
                 logging.info("Adding extra search path {} {}".format(x, full_path))
                 folder_paths.add_model_folder_path(x, full_path, is_default)
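The net effect of this change: a relative base_path in an extra-paths YAML now resolves against the directory containing the YAML file itself (computed once as yaml_dir), rather than the process working directory. A minimal sketch of the new behavior — the comfyui/checkpoints config shape and the utils.extra_config import path are assumptions for illustration, and the call only works inside a ComfyUI checkout since it registers paths with folder_paths:

import os
import tempfile

from utils.extra_config import load_extra_path_config  # assumed import path

# Hypothetical config: the relative base_path should now resolve against
# the directory containing the YAML file, not the current working directory.
config_dir = tempfile.mkdtemp()
yaml_path = os.path.join(config_dir, "extra_model_paths.yaml")
with open(yaml_path, "w") as f:
    f.write("comfyui:\n    base_path: ../shared_models\n    checkpoints: checkpoints/\n")

# Expected registered search path after the change:
#   os.path.abspath(os.path.join(config_dir, "../shared_models", "checkpoints/"))
load_extra_path_config(yaml_path)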
26
utils/json_util.py
Normal file
@@ -0,0 +1,26 @@
def merge_json_recursive(base, update):
    """Recursively merge two JSON-like objects.
    - Dictionaries are merged recursively
    - Lists are concatenated
    - Other types are overwritten by the update value

    Args:
        base: Base JSON-like object
        update: Update JSON-like object to merge into base

    Returns:
        Merged JSON-like object
    """
    if not isinstance(base, dict) or not isinstance(update, dict):
        if isinstance(base, list) and isinstance(update, list):
            return base + update
        return update

    merged = base.copy()
    for key, value in update.items():
        if key in merged:
            merged[key] = merge_json_recursive(merged[key], value)
        else:
            merged[key] = value

    return merged
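A quick usage sketch of the new helper — nested dicts merge key-by-key, lists concatenate, and anything else resolves in favor of the update value (pure standard-library Python; only the import path is taken from the file above):

from utils.json_util import merge_json_recursive

defaults = {"window": {"w": 800, "h": 600}, "recent": ["a.json"]}
user = {"window": {"h": 900}, "recent": ["b.json"]}

merged = merge_json_recursive(defaults, user)
# nested dicts merge key-by-key, lists are concatenated
assert merged == {"window": {"w": 800, "h": 900}, "recent": ["a.json", "b.json"]}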
54
web/assets/BaseViewTemplate-BhQMaVFP.js
generated
vendored
Normal file
@@ -0,0 +1,54 @@
import { d as defineComponent, ad as ref, t as onMounted, bT as isElectron, bV as electronAPI, af as nextTick, o as openBlock, f as createElementBlock, i as withDirectives, v as vShow, m as createBaseVNode, M as renderSlot, V as normalizeClass } from "./index-QvfM__ze.js";
const _hoisted_1 = { class: "flex-grow w-full flex items-center justify-center overflow-auto" };
const _sfc_main = /* @__PURE__ */ defineComponent({
  __name: "BaseViewTemplate",
  props: {
    dark: { type: Boolean, default: false }
  },
  setup(__props) {
    const props = __props;
    const darkTheme = {
      color: "rgba(0, 0, 0, 0)",
      symbolColor: "#d4d4d4"
    };
    const lightTheme = {
      color: "rgba(0, 0, 0, 0)",
      symbolColor: "#171717"
    };
    const topMenuRef = ref(null);
    const isNativeWindow = ref(false);
    onMounted(async () => {
      if (isElectron()) {
        const windowStyle = await electronAPI().Config.getWindowStyle();
        isNativeWindow.value = windowStyle === "custom";
        await nextTick();
        electronAPI().changeTheme({
          ...props.dark ? darkTheme : lightTheme,
          height: topMenuRef.value.getBoundingClientRect().height
        });
      }
    });
    return (_ctx, _cache) => {
      return openBlock(), createElementBlock("div", {
        class: normalizeClass(["font-sans w-screen h-screen flex flex-col pointer-events-auto", [
          props.dark ? "text-neutral-300 bg-neutral-900 dark-theme" : "text-neutral-900 bg-neutral-300"
        ]])
      }, [
        withDirectives(createBaseVNode("div", {
          ref_key: "topMenuRef",
          ref: topMenuRef,
          class: "app-drag w-full h-[var(--comfy-topbar-height)]"
        }, null, 512), [
          [vShow, isNativeWindow.value]
        ]),
        createBaseVNode("div", _hoisted_1, [
          renderSlot(_ctx.$slots, "default")
        ])
      ], 2);
    };
  }
});
export {
  _sfc_main as _
};
//# sourceMappingURL=BaseViewTemplate-BhQMaVFP.js.map
22
web/assets/DesktopStartView-le6AjGZr.js
generated
vendored
Normal file
@@ -0,0 +1,22 @@
import { d as defineComponent, o as openBlock, J as createBlock, P as withCtx, m as createBaseVNode, k as createVNode, j as unref, ch as script } from "./index-QvfM__ze.js";
import { _ as _sfc_main$1 } from "./BaseViewTemplate-BhQMaVFP.js";
const _hoisted_1 = { class: "max-w-screen-sm w-screen p-8" };
const _sfc_main = /* @__PURE__ */ defineComponent({
  __name: "DesktopStartView",
  setup(__props) {
    return (_ctx, _cache) => {
      return openBlock(), createBlock(_sfc_main$1, { dark: "" }, {
        default: withCtx(() => [
          createBaseVNode("div", _hoisted_1, [
            createVNode(unref(script), { mode: "indeterminate" })
          ])
        ]),
        _: 1
      });
    };
  }
});
export {
  _sfc_main as default
};
//# sourceMappingURL=DesktopStartView-le6AjGZr.js.map
1
web/assets/DownloadGitView-B3f7KHY3.js.map
generated
vendored
@@ -1 +0,0 @@
{"version":3,"file":"DownloadGitView-B3f7KHY3.js","sources":["../../src/views/DownloadGitView.vue"],"sourcesContent":["<template>\n <div\n class=\"font-sans w-screen h-screen mx-0 grid place-items-center justify-center items-center text-neutral-900 bg-neutral-300 pointer-events-auto\"\n >\n <div\n class=\"col-start-1 h-screen row-start-1 place-content-center mx-auto overflow-y-auto\"\n >\n <div\n class=\"max-w-screen-sm flex flex-col gap-8 p-8 bg-[url('/assets/images/Git-Logo-White.svg')] bg-no-repeat bg-right-top bg-origin-padding\"\n >\n <!-- Header -->\n <h1 class=\"mt-24 text-4xl font-bold text-red-500\">\n {{ $t('downloadGit.title') }}\n </h1>\n\n <!-- Message -->\n <div class=\"space-y-4\">\n <p class=\"text-xl\">\n {{ $t('downloadGit.message') }}\n </p>\n <p class=\"text-xl\">\n {{ $t('downloadGit.instructions') }}\n </p>\n <p class=\"text-m\">\n {{ $t('downloadGit.warning') }}\n </p>\n </div>\n\n <!-- Actions -->\n <div class=\"flex gap-4 flex-row-reverse\">\n <Button\n :label=\"$t('downloadGit.gitWebsite')\"\n icon=\"pi pi-external-link\"\n icon-pos=\"right\"\n @click=\"openGitDownloads\"\n severity=\"primary\"\n />\n <Button\n :label=\"$t('downloadGit.skip')\"\n icon=\"pi pi-exclamation-triangle\"\n @click=\"skipGit\"\n severity=\"secondary\"\n />\n </div>\n </div>\n </div>\n </div>\n</template>\n\n<script setup lang=\"ts\">\nimport Button from 'primevue/button'\nimport { useRouter } from 'vue-router'\n\nconst openGitDownloads = () => {\n window.open('https://git-scm.com/downloads/', '_blank')\n}\n\nconst skipGit = () => {\n console.warn('pushing')\n const router = useRouter()\n router.push('install')\n}\n</script>\n"],"names":[],"mappings":";;;;;;;;;;;;;;;AAqDA,UAAM,mBAAmB,6BAAM;AACtB,aAAA,KAAK,kCAAkC,QAAQ;AAAA,IAAA,GAD/B;AAIzB,UAAM,UAAU,6BAAM;AACpB,cAAQ,KAAK,SAAS;AACtB,YAAM,SAAS;AACf,aAAO,KAAK,SAAS;AAAA,IAAA,GAHP;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;"}
44
web/assets/DownloadGitView-B3f7KHY3.js → web/assets/DownloadGitView-rPK_vYgU.js
generated
vendored
@@ -1,15 +1,14 @@
 var __defProp = Object.defineProperty;
 var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
-import { a as defineComponent, f as openBlock, g as createElementBlock, A as createBaseVNode, a8 as toDisplayString, h as createVNode, z as unref, D as script, bU as useRouter } from "./index-DIU5yZe9.js";
+import { d as defineComponent, o as openBlock, J as createBlock, P as withCtx, m as createBaseVNode, Z as toDisplayString, k as createVNode, j as unref, l as script, c2 as useRouter } from "./index-QvfM__ze.js";
+import { _ as _sfc_main$1 } from "./BaseViewTemplate-BhQMaVFP.js";
-const _hoisted_1 = { class: "font-sans w-screen h-screen mx-0 grid place-items-center justify-center items-center text-neutral-900 bg-neutral-300 pointer-events-auto" };
-const _hoisted_2 = { class: "col-start-1 h-screen row-start-1 place-content-center mx-auto overflow-y-auto" };
-const _hoisted_3 = { class: "max-w-screen-sm flex flex-col gap-8 p-8 bg-[url('/assets/images/Git-Logo-White.svg')] bg-no-repeat bg-right-top bg-origin-padding" };
-const _hoisted_4 = { class: "mt-24 text-4xl font-bold text-red-500" };
-const _hoisted_5 = { class: "space-y-4" };
-const _hoisted_6 = { class: "text-xl" };
-const _hoisted_7 = { class: "text-xl" };
-const _hoisted_8 = { class: "text-m" };
-const _hoisted_9 = { class: "flex gap-4 flex-row-reverse" };
+const _hoisted_1 = { class: "max-w-screen-sm flex flex-col gap-8 p-8 bg-[url('/assets/images/Git-Logo-White.svg')] bg-no-repeat bg-right-top bg-origin-padding" };
+const _hoisted_2 = { class: "mt-24 text-4xl font-bold text-red-500" };
+const _hoisted_3 = { class: "space-y-4" };
+const _hoisted_4 = { class: "text-xl" };
+const _hoisted_5 = { class: "text-xl" };
+const _hoisted_6 = { class: "text-m" };
+const _hoisted_7 = { class: "flex gap-4 flex-row-reverse" };
 const _sfc_main = /* @__PURE__ */ defineComponent({
   __name: "DownloadGitView",
   setup(__props) {
@@ -22,16 +21,16 @@ const _sfc_main = /* @__PURE__ */ defineComponent({
       router.push("install");
     }, "skipGit");
     return (_ctx, _cache) => {
-      return openBlock(), createElementBlock("div", _hoisted_1, [
-        createBaseVNode("div", _hoisted_2, [
-          createBaseVNode("div", _hoisted_3, [
-            createBaseVNode("h1", _hoisted_4, toDisplayString(_ctx.$t("downloadGit.title")), 1),
-            createBaseVNode("div", _hoisted_5, [
-              createBaseVNode("p", _hoisted_6, toDisplayString(_ctx.$t("downloadGit.message")), 1),
-              createBaseVNode("p", _hoisted_7, toDisplayString(_ctx.$t("downloadGit.instructions")), 1),
-              createBaseVNode("p", _hoisted_8, toDisplayString(_ctx.$t("downloadGit.warning")), 1)
+      return openBlock(), createBlock(_sfc_main$1, null, {
+        default: withCtx(() => [
+          createBaseVNode("div", _hoisted_1, [
+            createBaseVNode("h1", _hoisted_2, toDisplayString(_ctx.$t("downloadGit.title")), 1),
+            createBaseVNode("div", _hoisted_3, [
+              createBaseVNode("p", _hoisted_4, toDisplayString(_ctx.$t("downloadGit.message")), 1),
+              createBaseVNode("p", _hoisted_5, toDisplayString(_ctx.$t("downloadGit.instructions")), 1),
+              createBaseVNode("p", _hoisted_6, toDisplayString(_ctx.$t("downloadGit.warning")), 1)
             ]),
-            createBaseVNode("div", _hoisted_9, [
+            createBaseVNode("div", _hoisted_7, [
              createVNode(unref(script), {
                label: _ctx.$t("downloadGit.gitWebsite"),
                icon: "pi pi-external-link",
@@ -47,12 +46,13 @@ const _sfc_main = /* @__PURE__ */ defineComponent({
              }, null, 8, ["label"])
            ])
          ])
-        ])
-      ]);
+        ]),
+        _: 1
+      });
     };
   }
 });
 export {
   _sfc_main as default
 };
-//# sourceMappingURL=DownloadGitView-B3f7KHY3.js.map
+//# sourceMappingURL=DownloadGitView-rPK_vYgU.js.map
87
web/assets/ExtensionPanel-ByeZ01RF.js → web/assets/ExtensionPanel-3jWrm6Zi.js
generated
vendored
@@ -1,8 +1,8 @@
 var __defProp = Object.defineProperty;
 var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
-import { a as defineComponent, r as ref, ck as FilterMatchMode, co as useExtensionStore, u as useSettingStore, o as onMounted, q as computed, f as openBlock, x as createBlock, y as withCtx, h as createVNode, cl as SearchBox, z as unref, bW as script, A as createBaseVNode, g as createElementBlock, Q as renderList, a8 as toDisplayString, ay as createTextVNode, P as Fragment, D as script$1, i as createCommentVNode, c5 as script$3, cm as _sfc_main$1 } from "./index-DIU5yZe9.js";
+import { d as defineComponent, ad as ref, cu as FilterMatchMode, cz as useExtensionStore, a as useSettingStore, t as onMounted, c as computed, o as openBlock, J as createBlock, P as withCtx, k as createVNode, cv as SearchBox, j as unref, c6 as script, m as createBaseVNode, f as createElementBlock, I as renderList, Z as toDisplayString, aG as createTextVNode, H as Fragment, l as script$1, L as createCommentVNode, aK as script$3, b8 as script$4, cc as script$5, cw as _sfc_main$1 } from "./index-QvfM__ze.js";
-import { s as script$2, a as script$4 } from "./index-D3u7l7ha.js";
+import { s as script$2, a as script$6 } from "./index-DpF-ptbJ.js";
-import "./index-d698Brhb.js";
+import "./index-Q1cQr26V.js";
 const _hoisted_1 = { class: "flex justify-end" };
 const _sfc_main = /* @__PURE__ */ defineComponent({
   __name: "ExtensionPanel",
@@ -35,9 +35,49 @@ const _sfc_main = /* @__PURE__ */ defineComponent({
         ...editingDisabledExtensionNames
       ]);
     }, "updateExtensionStatus");
+    const enableAllExtensions = /* @__PURE__ */ __name(() => {
+      extensionStore.extensions.forEach((ext) => {
+        if (extensionStore.isExtensionReadOnly(ext.name)) return;
+        editingEnabledExtensions.value[ext.name] = true;
+      });
+      updateExtensionStatus();
+    }, "enableAllExtensions");
+    const disableAllExtensions = /* @__PURE__ */ __name(() => {
+      extensionStore.extensions.forEach((ext) => {
+        if (extensionStore.isExtensionReadOnly(ext.name)) return;
+        editingEnabledExtensions.value[ext.name] = false;
+      });
+      updateExtensionStatus();
+    }, "disableAllExtensions");
+    const disableThirdPartyExtensions = /* @__PURE__ */ __name(() => {
+      extensionStore.extensions.forEach((ext) => {
+        if (extensionStore.isCoreExtension(ext.name)) return;
+        editingEnabledExtensions.value[ext.name] = false;
+      });
+      updateExtensionStatus();
+    }, "disableThirdPartyExtensions");
     const applyChanges = /* @__PURE__ */ __name(() => {
       window.location.reload();
     }, "applyChanges");
+    const menu = ref();
+    const contextMenuItems = [
+      {
+        label: "Enable All",
+        icon: "pi pi-check",
+        command: enableAllExtensions
+      },
+      {
+        label: "Disable All",
+        icon: "pi pi-times",
+        command: disableAllExtensions
+      },
+      {
+        label: "Disable 3rd Party",
+        icon: "pi pi-times",
+        command: disableThirdPartyExtensions,
+        disabled: !extensionStore.hasThirdPartyExtensions
+      }
+    ];
     return (_ctx, _cache) => {
       return openBlock(), createBlock(_sfc_main$1, {
         value: "Extension",
@@ -52,7 +92,8 @@ const _sfc_main = /* @__PURE__ */ defineComponent({
           hasChanges.value ? (openBlock(), createBlock(unref(script), {
             key: 0,
             severity: "info",
-            "pt:text": "w-full"
+            "pt:text": "w-full",
+            class: "max-h-96 overflow-y-auto"
           }, {
             default: withCtx(() => [
               createBaseVNode("ul", null, [
@@ -78,7 +119,7 @@ const _sfc_main = /* @__PURE__ */ defineComponent({
               })) : createCommentVNode("", true)
             ]),
             default: withCtx(() => [
-              createVNode(unref(script$4), {
+              createVNode(unref(script$6), {
                 value: unref(extensionStore).extensions,
                 stripedRows: "",
                 size: "small",
@@ -86,19 +127,43 @@ const _sfc_main = /* @__PURE__ */ defineComponent({
               }, {
                 default: withCtx(() => [
                   createVNode(unref(script$2), {
-                    field: "name",
                     header: _ctx.$t("g.extensionName"),
-                    sortable: ""
+                    sortable: "",
-                  }, null, 8, ["header"]),
+                    field: "name"
+                  }, {
+                    body: withCtx((slotProps) => [
+                      createTextVNode(toDisplayString(slotProps.data.name) + " ", 1),
+                      unref(extensionStore).isCoreExtension(slotProps.data.name) ? (openBlock(), createBlock(unref(script$3), {
+                        key: 0,
+                        value: "Core"
+                      })) : createCommentVNode("", true)
+                    ]),
+                    _: 1
+                  }, 8, ["header"]),
                   createVNode(unref(script$2), { pt: {
+                    headerCell: "flex items-center justify-end",
                     bodyCell: "flex items-center justify-end"
                   } }, {
+                    header: withCtx(() => [
+                      createVNode(unref(script$1), {
+                        icon: "pi pi-ellipsis-h",
+                        text: "",
+                        severity: "secondary",
+                        onClick: _cache[1] || (_cache[1] = ($event) => menu.value.show($event))
+                      }),
+                      createVNode(unref(script$4), {
+                        ref_key: "menu",
+                        ref: menu,
+                        model: contextMenuItems
+                      }, null, 512)
+                    ]),
                     body: withCtx((slotProps) => [
-                      createVNode(unref(script$3), {
+                      createVNode(unref(script$5), {
+                        disabled: unref(extensionStore).isExtensionReadOnly(slotProps.data.name),
                         modelValue: editingEnabledExtensions.value[slotProps.data.name],
                         "onUpdate:modelValue": /* @__PURE__ */ __name(($event) => editingEnabledExtensions.value[slotProps.data.name] = $event, "onUpdate:modelValue"),
                         onChange: updateExtensionStatus
-                      }, null, 8, ["modelValue", "onUpdate:modelValue"])
+                      }, null, 8, ["disabled", "modelValue", "onUpdate:modelValue"])
                     ]),
                     _: 1
                   })
@@ -114,4 +179,4 @@ const _sfc_main = /* @__PURE__ */ defineComponent({
 export {
   _sfc_main as default
 };
-//# sourceMappingURL=ExtensionPanel-ByeZ01RF.js.map
+//# sourceMappingURL=ExtensionPanel-3jWrm6Zi.js.map
1
web/assets/ExtensionPanel-ByeZ01RF.js.map
generated
vendored
@@ -1 +0,0 @@
{"version":3,"file":"ExtensionPanel-ByeZ01RF.js","sources":["../../src/components/dialog/content/setting/ExtensionPanel.vue"],"sourcesContent":["<template>\n <PanelTemplate value=\"Extension\" class=\"extension-panel\">\n <template #header>\n <SearchBox\n v-model=\"filters['global'].value\"\n :placeholder=\"$t('g.searchExtensions') + '...'\"\n />\n <Message v-if=\"hasChanges\" severity=\"info\" pt:text=\"w-full\">\n <ul>\n <li v-for=\"ext in changedExtensions\" :key=\"ext.name\">\n <span>\n {{ extensionStore.isExtensionEnabled(ext.name) ? '[-]' : '[+]' }}\n </span>\n {{ ext.name }}\n </li>\n </ul>\n <div class=\"flex justify-end\">\n <Button\n :label=\"$t('g.reloadToApplyChanges')\"\n @click=\"applyChanges\"\n outlined\n severity=\"danger\"\n />\n </div>\n </Message>\n </template>\n <DataTable\n :value=\"extensionStore.extensions\"\n stripedRows\n size=\"small\"\n :filters=\"filters\"\n >\n <Column field=\"name\" :header=\"$t('g.extensionName')\" sortable></Column>\n <Column\n :pt=\"{\n bodyCell: 'flex items-center justify-end'\n }\"\n >\n <template #body=\"slotProps\">\n <ToggleSwitch\n v-model=\"editingEnabledExtensions[slotProps.data.name]\"\n @change=\"updateExtensionStatus\"\n />\n </template>\n </Column>\n </DataTable>\n </PanelTemplate>\n</template>\n\n<script setup lang=\"ts\">\nimport { ref, computed, onMounted } from 'vue'\nimport { useExtensionStore } from '@/stores/extensionStore'\nimport { useSettingStore } from '@/stores/settingStore'\nimport DataTable from 'primevue/datatable'\nimport Column from 'primevue/column'\nimport ToggleSwitch from 'primevue/toggleswitch'\nimport Button from 'primevue/button'\nimport Message from 'primevue/message'\nimport { FilterMatchMode } from '@primevue/core/api'\nimport PanelTemplate from './PanelTemplate.vue'\nimport SearchBox from '@/components/common/SearchBox.vue'\n\nconst filters = ref({\n global: { value: '', matchMode: FilterMatchMode.CONTAINS }\n})\n\nconst extensionStore = useExtensionStore()\nconst settingStore = useSettingStore()\n\nconst editingEnabledExtensions = ref<Record<string, boolean>>({})\n\nonMounted(() => {\n extensionStore.extensions.forEach((ext) => {\n editingEnabledExtensions.value[ext.name] =\n extensionStore.isExtensionEnabled(ext.name)\n })\n})\n\nconst changedExtensions = computed(() => {\n return extensionStore.extensions.filter(\n (ext) =>\n editingEnabledExtensions.value[ext.name] !==\n extensionStore.isExtensionEnabled(ext.name)\n )\n})\n\nconst hasChanges = computed(() => {\n return changedExtensions.value.length > 0\n})\n\nconst updateExtensionStatus = () => {\n const editingDisabledExtensionNames = Object.entries(\n editingEnabledExtensions.value\n )\n .filter(([_, enabled]) => !enabled)\n .map(([name]) => name)\n\n settingStore.set('Comfy.Extension.Disabled', [\n ...extensionStore.inactiveDisabledExtensionNames,\n ...editingDisabledExtensionNames\n ])\n}\n\nconst applyChanges = () => {\n // Refresh the page to apply changes\n 
window.location.reload()\n}\n</script>\n"],"names":[],"mappings":";;;;;;;;;AA8DA,UAAM,UAAU,IAAI;AAAA,MAClB,QAAQ,EAAE,OAAO,IAAI,WAAW,gBAAgB,SAAS;AAAA,IAAA,CAC1D;AAED,UAAM,iBAAiB;AACvB,UAAM,eAAe;AAEf,UAAA,2BAA2B,IAA6B,CAAA,CAAE;AAEhE,cAAU,MAAM;AACC,qBAAA,WAAW,QAAQ,CAAC,QAAQ;AACzC,iCAAyB,MAAM,IAAI,IAAI,IACrC,eAAe,mBAAmB,IAAI,IAAI;AAAA,MAAA,CAC7C;AAAA,IAAA,CACF;AAEK,UAAA,oBAAoB,SAAS,MAAM;AACvC,aAAO,eAAe,WAAW;AAAA,QAC/B,CAAC,QACC,yBAAyB,MAAM,IAAI,IAAI,MACvC,eAAe,mBAAmB,IAAI,IAAI;AAAA,MAAA;AAAA,IAC9C,CACD;AAEK,UAAA,aAAa,SAAS,MAAM;AACzB,aAAA,kBAAkB,MAAM,SAAS;AAAA,IAAA,CACzC;AAED,UAAM,wBAAwB,6BAAM;AAClC,YAAM,gCAAgC,OAAO;AAAA,QAC3C,yBAAyB;AAAA,MAExB,EAAA,OAAO,CAAC,CAAC,GAAG,OAAO,MAAM,CAAC,OAAO,EACjC,IAAI,CAAC,CAAC,IAAI,MAAM,IAAI;AAEvB,mBAAa,IAAI,4BAA4B;AAAA,QAC3C,GAAG,eAAe;AAAA,QAClB,GAAG;AAAA,MAAA,CACJ;AAAA,IAAA,GAV2B;AAa9B,UAAM,eAAe,6BAAM;AAEzB,aAAO,SAAS;IAAO,GAFJ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;"}
1
web/assets/GraphView-BWxgNrh6.js.map
generated
vendored
File diff suppressed because one or more lines are too long
7293
web/assets/GraphView-BWxgNrh6.js → web/assets/GraphView-CDDCHVO0.js
generated
vendored
File diff suppressed because one or more lines are too long
321
web/assets/GraphView-B3TpSwhZ.css → web/assets/GraphView-CqZ3opAX.css
generated
vendored
@@ -1,90 +1,33 @@
 
-.group-title-editor.node-title-editor[data-v-8a100d5a] {
+.comfy-menu-hamburger[data-v-7ed57d1a] {
-  z-index: 9999;
+  pointer-events: auto;
-  padding: 0.25rem;
+  position: fixed;
-}
+  z-index: 9999;
-[data-v-8a100d5a] .editable-text {
+  display: flex;
-  width: 100%;
+  flex-direction: row
-  height: 100%;
-}
-[data-v-8a100d5a] .editable-text input {
-  width: 100%;
-  height: 100%;
-  /* Override the default font size */
-  font-size: inherit;
 }
 
-.side-bar-button-icon {
+[data-v-e50caa15] .p-splitter-gutter {
-  font-size: var(--sidebar-icon-size) !important;
-}
-.side-bar-button-selected .side-bar-button-icon {
-  font-size: var(--sidebar-icon-size) !important;
-  font-weight: bold;
-}
-
-.side-bar-button[data-v-caa3ee9c] {
-  width: var(--sidebar-width);
-  height: var(--sidebar-width);
-  border-radius: 0;
-}
-.comfyui-body-left .side-bar-button.side-bar-button-selected[data-v-caa3ee9c],
-.comfyui-body-left .side-bar-button.side-bar-button-selected[data-v-caa3ee9c]:hover {
-  border-left: 4px solid var(--p-button-text-primary-color);
-}
-.comfyui-body-right .side-bar-button.side-bar-button-selected[data-v-caa3ee9c],
-.comfyui-body-right .side-bar-button.side-bar-button-selected[data-v-caa3ee9c]:hover {
-  border-right: 4px solid var(--p-button-text-primary-color);
-}
-
-:root {
-  --sidebar-width: 64px;
-  --sidebar-icon-size: 1.5rem;
-}
-:root .small-sidebar {
-  --sidebar-width: 40px;
-  --sidebar-icon-size: 1rem;
-}
-
-.side-tool-bar-container[data-v-7851c166] {
-  display: flex;
-  flex-direction: column;
-  align-items: center;
-
-  pointer-events: auto;
-
-  width: var(--sidebar-width);
-  height: 100%;
-
-  background-color: var(--comfy-menu-secondary-bg);
-  color: var(--fg-color);
-  box-shadow: var(--bar-shadow);
-}
-.side-tool-bar-end[data-v-7851c166] {
-  align-self: flex-end;
-  margin-top: auto;
-}
-
-[data-v-7c3279c1] .p-splitter-gutter {
   pointer-events: auto;
 }
-[data-v-7c3279c1] .p-splitter-gutter:hover,[data-v-7c3279c1] .p-splitter-gutter[data-p-gutter-resizing='true'] {
+[data-v-e50caa15] .p-splitter-gutter:hover,[data-v-e50caa15] .p-splitter-gutter[data-p-gutter-resizing='true'] {
   transition: background-color 0.2s ease 300ms;
   background-color: var(--p-primary-color);
 }
-.side-bar-panel[data-v-7c3279c1] {
+.side-bar-panel[data-v-e50caa15] {
   background-color: var(--bg-color);
   pointer-events: auto;
 }
-.bottom-panel[data-v-7c3279c1] {
+.bottom-panel[data-v-e50caa15] {
   background-color: var(--bg-color);
   pointer-events: auto;
 }
-.splitter-overlay[data-v-7c3279c1] {
+.splitter-overlay[data-v-e50caa15] {
   pointer-events: none;
   border-style: none;
   background-color: transparent;
 }
-.splitter-overlay-root[data-v-7c3279c1] {
+.splitter-overlay-root[data-v-e50caa15] {
   position: absolute;
   top: 0px;
   left: 0px;
@@ -98,7 +41,50 @@
   z-index: 999;
 }
 
-[data-v-d7cc0bce] .highlight {
+.p-buttongroup-vertical[data-v-cb8f9a1a] {
+  display: flex;
+  flex-direction: column;
+  border-radius: var(--p-button-border-radius);
+  overflow: hidden;
+  border: 1px solid var(--p-panel-border-color);
+}
+.p-buttongroup-vertical .p-button[data-v-cb8f9a1a] {
+  margin: 0;
+  border-radius: 0;
+}
+
+.node-tooltip[data-v-46859edf] {
+  background: var(--comfy-input-bg);
+  border-radius: 5px;
+  box-shadow: 0 0 5px rgba(0, 0, 0, 0.4);
+  color: var(--input-text);
+  font-family: sans-serif;
+  left: 0;
+  max-width: 30vw;
+  padding: 4px 8px;
+  position: absolute;
+  top: 0;
+  transform: translate(5px, calc(-100% - 5px));
+  white-space: pre-wrap;
+  z-index: 99999;
+}
+
+.group-title-editor.node-title-editor[data-v-12d3fd12] {
+  z-index: 9999;
+  padding: 0.25rem;
+}
+[data-v-12d3fd12] .editable-text {
+  width: 100%;
+  height: 100%;
+}
+[data-v-12d3fd12] .editable-text input {
+  width: 100%;
+  height: 100%;
+  /* Override the default font size */
+  font-size: inherit;
+}
+
+[data-v-fd0a74bd] .highlight {
   background-color: var(--p-primary-color);
   color: var(--p-primary-contrast-color);
   font-weight: bold;
@@ -125,58 +111,55 @@
   align-items: flex-start !important;
 }
 
-.node-tooltip[data-v-9ecc8adc] {
+.side-bar-button-icon {
-  background: var(--comfy-input-bg);
+  font-size: var(--sidebar-icon-size) !important;
-  border-radius: 5px;
+}
-  box-shadow: 0 0 5px rgba(0, 0, 0, 0.4);
+.side-bar-button-selected .side-bar-button-icon {
-  color: var(--input-text);
+  font-size: var(--sidebar-icon-size) !important;
-  font-family: sans-serif;
+  font-weight: bold;
-  left: 0;
-  max-width: 30vw;
-  padding: 4px 8px;
-  position: absolute;
-  top: 0;
-  transform: translate(5px, calc(-100% - 5px));
-  white-space: pre-wrap;
-  z-index: 99999;
 }
 
-.p-buttongroup-vertical[data-v-94481f39] {
+.side-bar-button[data-v-6ab4daa6] {
-  display: flex;
+  width: var(--sidebar-width);
-  flex-direction: column;
+  height: var(--sidebar-width);
-  border-radius: var(--p-button-border-radius);
-  overflow: hidden;
-  border: 1px solid var(--p-panel-border-color);
-}
-.p-buttongroup-vertical .p-button[data-v-94481f39] {
-  margin: 0;
   border-radius: 0;
 }
+.comfyui-body-left .side-bar-button.side-bar-button-selected[data-v-6ab4daa6],
+.comfyui-body-left .side-bar-button.side-bar-button-selected[data-v-6ab4daa6]:hover {
+  border-left: 4px solid var(--p-button-text-primary-color);
+}
+.comfyui-body-right .side-bar-button.side-bar-button-selected[data-v-6ab4daa6],
+.comfyui-body-right .side-bar-button.side-bar-button-selected[data-v-6ab4daa6]:hover {
+  border-right: 4px solid var(--p-button-text-primary-color);
+}
+
+.side-tool-bar-container[data-v-33cac83a] {
+  display: flex;
+  flex-direction: column;
+  align-items: center;
 
-.comfy-menu-hamburger[data-v-962c4073] {
   pointer-events: auto;
-  position: fixed;
-  z-index: 9999;
+  width: var(--sidebar-width);
+  height: 100%;
+
+  background-color: var(--comfy-menu-secondary-bg);
+  color: var(--fg-color);
+  box-shadow: var(--bar-shadow);
+
+  --sidebar-width: 4rem;
+  --sidebar-icon-size: 1.5rem;
+}
+.side-tool-bar-container.small-sidebar[data-v-33cac83a] {
+  --sidebar-width: 2.5rem;
+  --sidebar-icon-size: 1rem;
+}
+.side-tool-bar-end[data-v-33cac83a] {
+  align-self: flex-end;
+  margin-top: auto;
 }
 
-[data-v-4cb762cb] .p-togglebutton::before {
+.status-indicator[data-v-8d011a31] {
-  display: none
-}
-[data-v-4cb762cb] .p-togglebutton {
-  position: relative;
-  flex-shrink: 0;
-  border-radius: 0px;
-  background-color: transparent;
-  padding: 0px
-}
-[data-v-4cb762cb] .p-togglebutton.p-togglebutton-checked {
-  border-bottom-width: 2px;
-  border-bottom-color: var(--p-button-text-primary-color)
-}
-[data-v-4cb762cb] .p-togglebutton-checked .close-button,[data-v-4cb762cb] .p-togglebutton:hover .close-button {
-  visibility: visible
-}
-.status-indicator[data-v-4cb762cb] {
   position: absolute;
   font-weight: 700;
   font-size: 1.5rem;
@@ -184,62 +167,117 @@
   left: 50%;
   transform: translate(-50%, -50%)
 }
-[data-v-4cb762cb] .p-togglebutton:hover .status-indicator {
+
+[data-v-54fadc45] .p-togglebutton {
+  position: relative;
+  flex-shrink: 0;
+  border-radius: 0px;
+  border-width: 0px;
+  border-right-width: 1px;
+  border-style: solid;
+  background-color: transparent;
+  padding: 0px;
+  border-right-color: var(--border-color)
+}
+[data-v-54fadc45] .p-togglebutton::before {
   display: none
 }
-[data-v-4cb762cb] .p-togglebutton .close-button {
+[data-v-54fadc45] .p-togglebutton:first-child {
+  border-left-width: 1px;
+  border-style: solid;
+  border-left-color: var(--border-color)
+}
+[data-v-54fadc45] .p-togglebutton:not(:first-child) {
+  border-left-width: 0px
+}
+[data-v-54fadc45] .p-togglebutton.p-togglebutton-checked {
+  height: 100%;
+  border-bottom-width: 1px;
+  border-style: solid;
+  border-bottom-color: var(--p-button-text-primary-color)
+}
+[data-v-54fadc45] .p-togglebutton:not(.p-togglebutton-checked) {
+  opacity: 0.75
+}
+[data-v-54fadc45] .p-togglebutton-checked .close-button,[data-v-54fadc45] .p-togglebutton:hover .close-button {
+  visibility: visible
+}
+[data-v-54fadc45] .p-togglebutton:hover .status-indicator {
+  display: none
+}
+[data-v-54fadc45] .p-togglebutton .close-button {
   visibility: hidden
 }
+[data-v-54fadc45] .p-scrollpanel-content {
+  height: 100%
+}
-.top-menubar[data-v-a2b12676] .p-menubar-item-link svg {
-  display: none;
-}
-[data-v-a2b12676] .p-menubar-submenu.dropdown-direction-up {
-  top: auto;
-  bottom: 100%;
-  flex-direction: column-reverse;
-}
-.keybinding-tag[data-v-a2b12676] {
-  background: var(--p-content-hover-background);
-  border-color: var(--p-content-border-color);
-  border-style: solid;
-}
 
-[data-v-713442be] .p-inputtext {
+/* Scrollbar half opacity to avoid blocking the active tab bottom border */
+[data-v-54fadc45] .p-scrollpanel:hover .p-scrollpanel-bar,[data-v-54fadc45] .p-scrollpanel:active .p-scrollpanel-bar {
+  opacity: 0.5
+}
+[data-v-54fadc45] .p-selectbutton {
+  height: 100%;
+  border-radius: 0px
+}
+
+[data-v-38831d8e] .workflow-tabs {
+  background-color: var(--comfy-menu-bg);
+}
+
+[data-v-26957f1f] .p-inputtext {
   border-top-left-radius: 0;
   border-bottom-left-radius: 0;
 }
 
-.comfyui-queue-button[data-v-d3897845] .p-splitbutton-dropdown {
+.comfyui-queue-button[data-v-e9044686] .p-splitbutton-dropdown {
   border-top-right-radius: 0;
   border-bottom-right-radius: 0;
 }
 
-.actionbar[data-v-542a7001] {
+.actionbar[data-v-915e5456] {
   pointer-events: all;
   position: fixed;
   z-index: 1000;
 }
-.actionbar.is-docked[data-v-542a7001] {
+.actionbar.is-docked[data-v-915e5456] {
   position: static;
   border-style: none;
   background-color: transparent;
   padding: 0px;
 }
-.actionbar.is-dragging[data-v-542a7001] {
+.actionbar.is-dragging[data-v-915e5456] {
   -webkit-user-select: none;
   -moz-user-select: none;
   user-select: none;
 }
-[data-v-542a7001] .p-panel-content {
+[data-v-915e5456] .p-panel-content {
   padding: 0.25rem;
 }
-[data-v-542a7001] .p-panel-header {
+.is-docked[data-v-915e5456] .p-panel-content {
+  padding: 0px;
+}
+[data-v-915e5456] .p-panel-header {
   display: none;
 }
 
-.comfyui-menu[data-v-d792da31] {
+.top-menubar[data-v-56df69d2] .p-menubar-item-link svg {
+  display: none;
+}
+[data-v-56df69d2] .p-menubar-submenu.dropdown-direction-up {
+  top: auto;
+  bottom: 100%;
+  flex-direction: column-reverse;
+}
+.keybinding-tag[data-v-56df69d2] {
+  background: var(--p-content-hover-background);
+  border-color: var(--p-content-border-color);
+  border-style: solid;
+}
+
+.comfyui-menu[data-v-6e35440f] {
   width: 100vw;
+  height: var(--comfy-topbar-height);
   background: var(--comfy-menu-bg);
   color: var(--fg-color);
   box-shadow: var(--bar-shadow);
@@ -249,18 +287,17 @@
   z-index: 1000;
   order: 0;
   grid-column: 1/-1;
-  max-height: 90vh;
 }
-.comfyui-menu.dropzone[data-v-d792da31] {
+.comfyui-menu.dropzone[data-v-6e35440f] {
   background: var(--p-highlight-background);
 }
-.comfyui-menu.dropzone-active[data-v-d792da31] {
+.comfyui-menu.dropzone-active[data-v-6e35440f] {
   background: var(--p-highlight-background-focus);
 }
-[data-v-d792da31] .p-menubar-item-label {
+[data-v-6e35440f] .p-menubar-item-label {
   line-height: revert;
 }
-.comfyui-logo[data-v-d792da31] {
+.comfyui-logo[data-v-6e35440f] {
   font-size: 1.2em;
   -webkit-user-select: none;
   -moz-user-select: none;
4
web/assets/InstallView-8N2LdZUx.css
generated
vendored
@@ -1,4 +0,0 @@
-
-[data-v-7ef01cf2] .p-steppanel {
-  background-color: transparent
-}
1266
web/assets/InstallView-DbHtR5YG.js → web/assets/InstallView-By3hC1fC.js
generated
vendored
File diff suppressed because one or more lines are too long
79
web/assets/InstallView-CxhfFC8Y.css
generated
vendored
Normal file
@@ -0,0 +1,79 @@

.p-tag[data-v-79125ff6] {
  --p-tag-gap: 0.5rem;
}
.hover-brighten[data-v-79125ff6] {
  transition-property: color, background-color, border-color, text-decoration-color, fill, stroke;
  transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1);
  transition-duration: 150ms;
  transition-property: filter, box-shadow;
  &[data-v-79125ff6]:hover {
    filter: brightness(107%) contrast(105%);
    box-shadow: 0 0 0.25rem #ffffff79;
  }
}
.p-accordioncontent-content[data-v-79125ff6] {
  border-radius: 0.5rem;
  --tw-bg-opacity: 1;
  background-color: rgb(23 23 23 / var(--tw-bg-opacity));
  transition-property: color, background-color, border-color, text-decoration-color, fill, stroke;
  transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1);
  transition-duration: 150ms;
}
div.selected[data-v-79125ff6] {
  .gpu-button[data-v-79125ff6]:not(.selected) {
    opacity: 0.5;
  }
  .gpu-button[data-v-79125ff6]:not(.selected):hover {
    opacity: 1;
  }
}
.gpu-button[data-v-79125ff6] {
  margin: 0px;
  display: flex;
  width: 50%;
  cursor: pointer;
  flex-direction: column;
  align-items: center;
  justify-content: space-around;
  border-radius: 0.5rem;
  background-color: rgb(38 38 38 / var(--tw-bg-opacity));
  --tw-bg-opacity: 0.5;
  transition-property: color, background-color, border-color, text-decoration-color, fill, stroke;
  transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1);
  transition-duration: 150ms;
}
.gpu-button[data-v-79125ff6]:hover {
  --tw-bg-opacity: 0.75;
}
.gpu-button[data-v-79125ff6] {
  &.selected[data-v-79125ff6] {
    --tw-bg-opacity: 1;
    background-color: rgb(64 64 64 / var(--tw-bg-opacity));
  }
  &.selected[data-v-79125ff6] {
    --tw-bg-opacity: 0.5;
  }
  &.selected[data-v-79125ff6] {
    opacity: 1;
  }
  &.selected[data-v-79125ff6]:hover {
    --tw-bg-opacity: 0.6;
  }
}
.disabled[data-v-79125ff6] {
  pointer-events: none;
  opacity: 0.4;
}
.p-card-header[data-v-79125ff6] {
  flex-grow: 1;
  text-align: center;
}
.p-card-body[data-v-79125ff6] {
  padding-top: 0px;
  text-align: center;
}

[data-v-0a97b0ae] .p-steppanel {
  background-color: transparent
}
1
web/assets/InstallView-DbHtR5YG.js.map
generated
vendored
File diff suppressed because one or more lines are too long
8
web/assets/KeybindingPanel-C3wT8hYZ.css
generated
vendored
@@ -1,8 +0,0 @@
-
-[data-v-c20ad403] .p-datatable-tbody > tr > td {
-  padding: 0.25rem;
-  min-height: 2rem
-}
-[data-v-c20ad403] .p-datatable-row-selected .actions,[data-v-c20ad403] .p-datatable-selectable-row:hover .actions {
-  visibility: visible
-}
28
web/assets/KeybindingPanel-DC2AxNNa.js → web/assets/KeybindingPanel-D6O16W_1.js
generated
vendored
@@ -1,8 +1,9 @@
 var __defProp = Object.defineProperty;
 var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
-import { a as defineComponent, q as computed, f as openBlock, g as createElementBlock, P as Fragment, Q as renderList, h as createVNode, y as withCtx, ay as createTextVNode, a8 as toDisplayString, z as unref, aC as script, i as createCommentVNode, r as ref, ck as FilterMatchMode, O as useKeybindingStore, F as useCommandStore, I as useI18n, aS as normalizeI18nKey, aL as watchEffect, bn as useToast, t as resolveDirective, x as createBlock, cl as SearchBox, A as createBaseVNode, D as script$2, aq as script$4, br as withModifiers, bW as script$5, aI as script$6, v as withDirectives, cm as _sfc_main$2, R as pushScopeId, U as popScopeId, ce as KeyComboImpl, cn as KeybindingImpl, _ as _export_sfc } from "./index-DIU5yZe9.js";
+import { d as defineComponent, c as computed, o as openBlock, f as createElementBlock, H as Fragment, I as renderList, k as createVNode, P as withCtx, aG as createTextVNode, Z as toDisplayString, j as unref, aK as script, L as createCommentVNode, ad as ref, cu as FilterMatchMode, a$ as useKeybindingStore, a4 as useCommandStore, a3 as useI18n, ah as normalizeI18nKey, w as watchEffect, bz as useToast, r as resolveDirective, J as createBlock, cv as SearchBox, m as createBaseVNode, l as script$2, ax as script$4, b3 as withModifiers, c6 as script$5, aP as script$6, i as withDirectives, cw as _sfc_main$2, p as pushScopeId, q as popScopeId, cx as KeyComboImpl, cy as KeybindingImpl, _ as _export_sfc } from "./index-QvfM__ze.js";
-import { s as script$1, a as script$3 } from "./index-D3u7l7ha.js";
+import { s as script$1, a as script$3 } from "./index-DpF-ptbJ.js";
-import "./index-d698Brhb.js";
+import { u as useKeybindingService } from "./keybindingService-Cak1En5n.js";
+import "./index-Q1cQr26V.js";
 const _hoisted_1$1 = {
   key: 0,
   class: "px-2"
@@ -35,7 +36,7 @@ const _sfc_main$1 = /* @__PURE__ */ defineComponent({
     };
   }
 });
-const _withScopeId = /* @__PURE__ */ __name((n) => (pushScopeId("data-v-c20ad403"), n = n(), popScopeId(), n), "_withScopeId");
+const _withScopeId = /* @__PURE__ */ __name((n) => (pushScopeId("data-v-2554ab36"), n = n(), popScopeId(), n), "_withScopeId");
 const _hoisted_1 = { class: "actions invisible flex flex-row" };
 const _hoisted_2 = ["title"];
 const _hoisted_3 = { key: 1 };
@@ -46,6 +47,7 @@ const _sfc_main = /* @__PURE__ */ defineComponent({
     global: { value: "", matchMode: FilterMatchMode.CONTAINS }
   });
   const keybindingStore = useKeybindingStore();
+  const keybindingService = useKeybindingService();
   const commandStore = useCommandStore();
   const { t } = useI18n();
   const commandsData = computed(() => {
@@ -90,7 +92,7 @@ const _sfc_main = /* @__PURE__ */ defineComponent({
   function removeKeybinding(commandData) {
     if (commandData.keybinding) {
       keybindingStore.unsetKeybinding(commandData.keybinding);
-      keybindingStore.persistUserKeybindings();
+      keybindingService.persistUserKeybindings();
     }
   }
   __name(removeKeybinding, "removeKeybinding");
@@ -114,7 +116,7 @@ const _sfc_main = /* @__PURE__ */ defineComponent({
       })
     );
     if (updated) {
-      keybindingStore.persistUserKeybindings();
+      keybindingService.persistUserKeybindings();
     }
   }
   cancelEdit();
@@ -123,7 +125,7 @@ const _sfc_main = /* @__PURE__ */ defineComponent({
   const toast = useToast();
   async function resetKeybindings() {
     keybindingStore.resetKeybindings();
-    await keybindingStore.persistUserKeybindings();
+    await keybindingService.persistUserKeybindings();
     toast.add({
       severity: "info",
       summary: "Info",
@@ -182,7 +184,7 @@ const _sfc_main = /* @__PURE__ */ defineComponent({
         }),
         createVNode(unref(script$1), {
           field: "id",
-          header: "Command ID",
+          header: _ctx.$t("g.command"),
          sortable: "",
          class: "max-w-64 2xl:max-w-full"
        }, {
@@ -193,10 +195,10 @@ const _sfc_main = /* @__PURE__ */ defineComponent({
          }, toDisplayString(slotProps.data.label), 9, _hoisted_2)
        ]),
        _: 1
-        }),
+        }, 8, ["header"]),
        createVNode(unref(script$1), {
          field: "keybinding",
-          header: "Keybinding"
+          header: _ctx.$t("g.keybinding")
        }, {
          body: withCtx((slotProps) => [
            slotProps.data.keybinding ? (openBlock(), createBlock(_sfc_main$1, {
@@ -206,7 +208,7 @@ const _sfc_main = /* @__PURE__ */ defineComponent({
          }, null, 8, ["keyCombo", "isModified"])) : (openBlock(), createElementBlock("span", _hoisted_3, "-"))
        ]),
        _: 1
-        })
+        }, 8, ["header"])
      ]),
      _: 1
    }, 8, ["value", "selection", "filters"]),
@@ -274,8 +276,8 @@ const _sfc_main = /* @__PURE__ */ defineComponent({
     };
   }
 });
-const KeybindingPanel = /* @__PURE__ */ _export_sfc(_sfc_main, [["__scopeId", "data-v-c20ad403"]]);
+const KeybindingPanel = /* @__PURE__ */ _export_sfc(_sfc_main, [["__scopeId", "data-v-2554ab36"]]);
 export {
   KeybindingPanel as default
 };
-//# sourceMappingURL=KeybindingPanel-DC2AxNNa.js.map
+//# sourceMappingURL=KeybindingPanel-D6O16W_1.js.map
1
web/assets/KeybindingPanel-DC2AxNNa.js.map
generated
vendored
File diff suppressed because one or more lines are too long
8
web/assets/KeybindingPanel-DvrUYZ4S.css
generated
vendored
Normal file
@@ -0,0 +1,8 @@

[data-v-2554ab36] .p-datatable-tbody > tr > td {
  padding: 0.25rem;
  min-height: 2rem
}
[data-v-2554ab36] .p-datatable-row-selected .actions,[data-v-2554ab36] .p-datatable-selectable-row:hover .actions {
  visibility: visible
}
7
web/assets/ManualConfigurationView-CsirlNfV.css
generated
vendored
Normal file
@@ -0,0 +1,7 @@

.p-tag[data-v-dc169863] {
  --p-tag-gap: 0.5rem;
}
.comfy-installer[data-v-dc169863] {
  margin-top: max(1rem, max(0px, calc((100vh - 42rem) * 0.5)));
}
75 web/assets/ManualConfigurationView-enyqGo0M.js generated vendored Normal file
@@ -0,0 +1,75 @@
+var __defProp = Object.defineProperty;
+var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
+import { d as defineComponent, a3 as useI18n, ad as ref, t as onMounted, o as openBlock, J as createBlock, P as withCtx, m as createBaseVNode, Z as toDisplayString, k as createVNode, j as unref, aK as script, bN as script$1, l as script$2, p as pushScopeId, q as popScopeId, bV as electronAPI, _ as _export_sfc } from "./index-QvfM__ze.js";
+import { _ as _sfc_main$1 } from "./BaseViewTemplate-BhQMaVFP.js";
+const _withScopeId = /* @__PURE__ */ __name((n) => (pushScopeId("data-v-dc169863"), n = n(), popScopeId(), n), "_withScopeId");
+const _hoisted_1 = { class: "comfy-installer grow flex flex-col gap-4 text-neutral-300 max-w-110" };
+const _hoisted_2 = { class: "text-2xl font-semibold text-neutral-100" };
+const _hoisted_3 = { class: "m-1 text-neutral-300" };
+const _hoisted_4 = { class: "ml-2" };
+const _hoisted_5 = { class: "m-1 mb-4" };
+const _hoisted_6 = { class: "m-0" };
+const _hoisted_7 = { class: "m-1" };
+const _hoisted_8 = { class: "font-mono" };
+const _hoisted_9 = { class: "m-1" };
+const _sfc_main = /* @__PURE__ */ defineComponent({
+  __name: "ManualConfigurationView",
+  setup(__props) {
+    const { t } = useI18n();
+    const electron = electronAPI();
+    const basePath = ref(null);
+    const sep = ref("/");
+    const restartApp = /* @__PURE__ */ __name((message) => electron.restartApp(message), "restartApp");
+    onMounted(async () => {
+      basePath.value = await electron.getBasePath();
+      if (basePath.value.indexOf("/") === -1) sep.value = "\\";
+    });
+    return (_ctx, _cache) => {
+      return openBlock(), createBlock(_sfc_main$1, { dark: "" }, {
+        default: withCtx(() => [
+          createBaseVNode("div", _hoisted_1, [
+            createBaseVNode("h2", _hoisted_2, toDisplayString(_ctx.$t("install.manualConfiguration.title")), 1),
+            createBaseVNode("p", _hoisted_3, [
+              createVNode(unref(script), {
+                icon: "pi pi-exclamation-triangle",
+                severity: "warn",
+                value: unref(t)("icon.exclamation-triangle")
+              }, null, 8, ["value"]),
+              createBaseVNode("strong", _hoisted_4, toDisplayString(_ctx.$t("install.gpuSelection.customComfyNeedsPython")), 1)
+            ]),
+            createBaseVNode("div", null, [
+              createBaseVNode("p", _hoisted_5, toDisplayString(_ctx.$t("install.manualConfiguration.requirements")) + ": ", 1),
+              createBaseVNode("ul", _hoisted_6, [
+                createBaseVNode("li", null, toDisplayString(_ctx.$t("install.gpuSelection.customManualVenv")), 1),
+                createBaseVNode("li", null, toDisplayString(_ctx.$t("install.gpuSelection.customInstallRequirements")), 1)
+              ])
+            ]),
+            createBaseVNode("p", _hoisted_7, toDisplayString(_ctx.$t("install.manualConfiguration.createVenv")) + ":", 1),
+            createVNode(unref(script$1), {
+              header: unref(t)("install.manualConfiguration.virtualEnvironmentPath")
+            }, {
+              default: withCtx(() => [
+                createBaseVNode("span", _hoisted_8, toDisplayString(`${basePath.value}${sep.value}.venv${sep.value}`), 1)
+              ]),
+              _: 1
+            }, 8, ["header"]),
+            createBaseVNode("p", _hoisted_9, toDisplayString(_ctx.$t("install.manualConfiguration.restartWhenFinished")), 1),
+            createVNode(unref(script$2), {
+              class: "place-self-end",
+              label: unref(t)("menuLabels.Restart"),
+              severity: "warn",
+              icon: "pi pi-refresh",
+              onClick: _cache[0] || (_cache[0] = ($event) => restartApp("Manual configuration complete"))
+            }, null, 8, ["label"])
+          ])
+        ]),
+        _: 1
+      });
+    };
+  }
+});
+const ManualConfigurationView = /* @__PURE__ */ _export_sfc(_sfc_main, [["__scopeId", "data-v-dc169863"]]);
+export {
+  ManualConfigurationView as default
+};
+//# sourceMappingURL=ManualConfigurationView-enyqGo0M.js.map
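A rough TypeScript sketch of the setup() logic compiled above, assuming it originated from a Vue 3 composition-API component. `electronAPI`, `getBasePath`, and `restartApp` appear in the bundle itself; the `@/utils/envUtil` import path is borrowed from a source map further down in this diff, and the typings are assumptions.

// Sketch only; not the shipped source.
import { ref, onMounted } from "vue";
import { electronAPI } from "@/utils/envUtil"; // path assumed from a source map in this diff

const electron = electronAPI();
const basePath = ref<string | null>(null);
const sep = ref<"/" | "\\">("/");

onMounted(async () => {
  // Fetch the install base path from the Electron main process, then infer the
  // path separator: a path containing no forward slash is treated as Windows.
  basePath.value = await electron.getBasePath();
  if (basePath.value !== null && basePath.value.indexOf("/") === -1) {
    sep.value = "\\";
  }
});

// The template then renders the venv location as `${basePath.value}${sep.value}.venv${sep.value}`
// and a Restart button wired to electron.restartApp("Manual configuration complete").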
86 web/assets/MetricsConsentView-lSfLu4nr.js generated vendored Normal file
@@ -0,0 +1,86 @@
+var __defProp = Object.defineProperty;
+var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
+import { _ as _sfc_main$1 } from "./BaseViewTemplate-BhQMaVFP.js";
+import { d as defineComponent, bz as useToast, a3 as useI18n, ad as ref, c2 as useRouter, o as openBlock, J as createBlock, P as withCtx, m as createBaseVNode, Z as toDisplayString, aG as createTextVNode, k as createVNode, j as unref, cc as script, l as script$1, bV as electronAPI } from "./index-QvfM__ze.js";
+const _hoisted_1 = { class: "h-full p-8 2xl:p-16 flex flex-col items-center justify-center" };
+const _hoisted_2 = { class: "bg-neutral-800 rounded-lg shadow-lg p-6 w-full max-w-[600px] flex flex-col gap-6" };
+const _hoisted_3 = { class: "text-3xl font-semibold text-neutral-100" };
+const _hoisted_4 = { class: "text-neutral-400" };
+const _hoisted_5 = { class: "text-neutral-400" };
+const _hoisted_6 = {
+  href: "https://comfy.org/privacy",
+  target: "_blank",
+  class: "text-blue-400 hover:text-blue-300 underline"
+};
+const _hoisted_7 = { class: "flex items-center gap-4" };
+const _hoisted_8 = {
+  id: "metricsDescription",
+  class: "text-neutral-100"
+};
+const _hoisted_9 = { class: "flex pt-6 justify-end" };
+const _sfc_main = /* @__PURE__ */ defineComponent({
+  __name: "MetricsConsentView",
+  setup(__props) {
+    const toast = useToast();
+    const { t } = useI18n();
+    const allowMetrics = ref(true);
+    const router = useRouter();
+    const isUpdating = ref(false);
+    const updateConsent = /* @__PURE__ */ __name(async () => {
+      isUpdating.value = true;
+      try {
+        await electronAPI().setMetricsConsent(allowMetrics.value);
+      } catch (error) {
+        toast.add({
+          severity: "error",
+          summary: t("install.errorUpdatingConsent"),
+          detail: t("install.errorUpdatingConsentDetail"),
+          life: 3e3
+        });
+      } finally {
+        isUpdating.value = false;
+      }
+      router.push("/");
+    }, "updateConsent");
+    return (_ctx, _cache) => {
+      const _component_BaseViewTemplate = _sfc_main$1;
+      return openBlock(), createBlock(_component_BaseViewTemplate, { dark: "" }, {
+        default: withCtx(() => [
+          createBaseVNode("div", _hoisted_1, [
+            createBaseVNode("div", _hoisted_2, [
+              createBaseVNode("h2", _hoisted_3, toDisplayString(_ctx.$t("install.helpImprove")), 1),
+              createBaseVNode("p", _hoisted_4, toDisplayString(_ctx.$t("install.updateConsent")), 1),
+              createBaseVNode("p", _hoisted_5, [
+                createTextVNode(toDisplayString(_ctx.$t("install.moreInfo")) + " ", 1),
+                createBaseVNode("a", _hoisted_6, toDisplayString(_ctx.$t("install.privacyPolicy")), 1),
+                createTextVNode(". ")
+              ]),
+              createBaseVNode("div", _hoisted_7, [
+                createVNode(unref(script), {
+                  modelValue: allowMetrics.value,
+                  "onUpdate:modelValue": _cache[0] || (_cache[0] = ($event) => allowMetrics.value = $event),
+                  "aria-describedby": "metricsDescription"
+                }, null, 8, ["modelValue"]),
+                createBaseVNode("span", _hoisted_8, toDisplayString(allowMetrics.value ? _ctx.$t("install.metricsEnabled") : _ctx.$t("install.metricsDisabled")), 1)
+              ]),
+              createBaseVNode("div", _hoisted_9, [
+                createVNode(unref(script$1), {
+                  label: _ctx.$t("g.ok"),
+                  icon: "pi pi-check",
+                  loading: isUpdating.value,
+                  iconPos: "right",
+                  onClick: updateConsent
+                }, null, 8, ["label", "loading"])
+              ])
+            ])
+          ])
+        ]),
+        _: 1
+      });
+    };
+  }
+});
+export {
+  _sfc_main as default
+};
+//# sourceMappingURL=MetricsConsentView-lSfLu4nr.js.map
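The consent-update flow in the bundle above, restated as plain TypeScript for readability. Names such as `setMetricsConsent`, the i18n keys, and the route come from the compiled code; the surrounding declarations and typings are assumptions.

// Sketch only; reconstructed from the compiled setup() above.
const allowMetrics = ref(true);   // toggle state, defaults to opt-in
const isUpdating = ref(false);    // drives the OK button's loading spinner

const updateConsent = async (): Promise<void> => {
  isUpdating.value = true;
  try {
    // Hand the user's choice to the Electron main process over IPC.
    await electronAPI().setMetricsConsent(allowMetrics.value);
  } catch {
    // On failure, show a toast for 3 seconds (3e3 in the bundle is 3000 ms).
    toast.add({
      severity: "error",
      summary: t("install.errorUpdatingConsent"),
      detail: t("install.errorUpdatingConsentDetail"),
      life: 3000
    });
  } finally {
    isUpdating.value = false;
  }
  // Navigation to "/" happens whether or not the IPC call succeeded.
  router.push("/");
};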
82 web/assets/NotSupportedView-C8O1Ed5c.js generated vendored
@@ -1,82 +0,0 @@
-var __defProp = Object.defineProperty;
-var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
-import { a as defineComponent, bU as useRouter, t as resolveDirective, f as openBlock, g as createElementBlock, A as createBaseVNode, a8 as toDisplayString, h as createVNode, z as unref, D as script, v as withDirectives } from "./index-DIU5yZe9.js";
-const _imports_0 = "" + new URL("images/sad_girl.png", import.meta.url).href;
-const _hoisted_1 = { class: "font-sans w-screen h-screen flex items-center m-0 text-neutral-900 bg-neutral-300 pointer-events-auto" };
-const _hoisted_2 = { class: "flex-grow flex items-center justify-center" };
-const _hoisted_3 = { class: "flex flex-col gap-8 p-8" };
-const _hoisted_4 = { class: "text-4xl font-bold text-red-500" };
-const _hoisted_5 = { class: "space-y-4" };
-const _hoisted_6 = { class: "text-xl" };
-const _hoisted_7 = { class: "list-disc list-inside space-y-1 text-neutral-800" };
-const _hoisted_8 = { class: "flex gap-4" };
-const _hoisted_9 = /* @__PURE__ */ createBaseVNode("div", { class: "h-screen flex-grow-0" }, [
-  /* @__PURE__ */ createBaseVNode("img", {
-    src: _imports_0,
-    alt: "Sad girl illustration",
-    class: "h-full object-cover"
-  })
-], -1);
-const _sfc_main = /* @__PURE__ */ defineComponent({
-  __name: "NotSupportedView",
-  setup(__props) {
-    const openDocs = /* @__PURE__ */ __name(() => {
-      window.open(
-        "https://github.com/Comfy-Org/desktop#currently-supported-platforms",
-        "_blank"
-      );
-    }, "openDocs");
-    const reportIssue = /* @__PURE__ */ __name(() => {
-      window.open("https://forum.comfy.org/c/v1-feedback/", "_blank");
-    }, "reportIssue");
-    const router = useRouter();
-    const continueToInstall = /* @__PURE__ */ __name(() => {
-      router.push("/install");
-    }, "continueToInstall");
-    return (_ctx, _cache) => {
-      const _directive_tooltip = resolveDirective("tooltip");
-      return openBlock(), createElementBlock("div", _hoisted_1, [
-        createBaseVNode("div", _hoisted_2, [
-          createBaseVNode("div", _hoisted_3, [
-            createBaseVNode("h1", _hoisted_4, toDisplayString(_ctx.$t("notSupported.title")), 1),
-            createBaseVNode("div", _hoisted_5, [
-              createBaseVNode("p", _hoisted_6, toDisplayString(_ctx.$t("notSupported.message")), 1),
-              createBaseVNode("ul", _hoisted_7, [
-                createBaseVNode("li", null, toDisplayString(_ctx.$t("notSupported.supportedDevices.macos")), 1),
-                createBaseVNode("li", null, toDisplayString(_ctx.$t("notSupported.supportedDevices.windows")), 1)
-              ])
-            ]),
-            createBaseVNode("div", _hoisted_8, [
-              createVNode(unref(script), {
-                label: _ctx.$t("notSupported.learnMore"),
-                icon: "pi pi-github",
-                onClick: openDocs,
-                severity: "secondary"
-              }, null, 8, ["label"]),
-              createVNode(unref(script), {
-                label: _ctx.$t("notSupported.reportIssue"),
-                icon: "pi pi-flag",
-                onClick: reportIssue,
-                severity: "secondary"
-              }, null, 8, ["label"]),
-              withDirectives(createVNode(unref(script), {
-                label: _ctx.$t("notSupported.continue"),
-                icon: "pi pi-arrow-right",
-                iconPos: "right",
-                onClick: continueToInstall,
-                severity: "danger"
-              }, null, 8, ["label"]), [
-                [_directive_tooltip, _ctx.$t("notSupported.continueTooltip")]
-              ])
-            ])
-          ])
-        ]),
-        _hoisted_9
-      ]);
-    };
-  }
-});
-export {
-  _sfc_main as default
-};
-//# sourceMappingURL=NotSupportedView-C8O1Ed5c.js.map
1 web/assets/NotSupportedView-C8O1Ed5c.js.map generated vendored
@@ -1 +0,0 @@
-{"version":3,"file":"NotSupportedView-C8O1Ed5c.js","sources":["../../../../../../../assets/images/sad_girl.png","../../src/views/NotSupportedView.vue"],"sourcesContent":["export default \"__VITE_PUBLIC_ASSET__b82952e7__\"","<template>\n  <div\n    class=\"font-sans w-screen h-screen flex items-center m-0 text-neutral-900 bg-neutral-300 pointer-events-auto\"\n  >\n    <div class=\"flex-grow flex items-center justify-center\">\n      <div class=\"flex flex-col gap-8 p-8\">\n        <!-- Header -->\n        <h1 class=\"text-4xl font-bold text-red-500\">\n          {{ $t('notSupported.title') }}\n        </h1>\n\n        <!-- Message -->\n        <div class=\"space-y-4\">\n          <p class=\"text-xl\">\n            {{ $t('notSupported.message') }}\n          </p>\n          <ul class=\"list-disc list-inside space-y-1 text-neutral-800\">\n            <li>{{ $t('notSupported.supportedDevices.macos') }}</li>\n            <li>{{ $t('notSupported.supportedDevices.windows') }}</li>\n          </ul>\n        </div>\n\n        <!-- Actions -->\n        <div class=\"flex gap-4\">\n          <Button\n            :label=\"$t('notSupported.learnMore')\"\n            icon=\"pi pi-github\"\n            @click=\"openDocs\"\n            severity=\"secondary\"\n          />\n          <Button\n            :label=\"$t('notSupported.reportIssue')\"\n            icon=\"pi pi-flag\"\n            @click=\"reportIssue\"\n            severity=\"secondary\"\n          />\n          <Button\n            :label=\"$t('notSupported.continue')\"\n            icon=\"pi pi-arrow-right\"\n            iconPos=\"right\"\n            @click=\"continueToInstall\"\n            severity=\"danger\"\n            v-tooltip=\"$t('notSupported.continueTooltip')\"\n          />\n        </div>\n      </div>\n    </div>\n\n    <!-- Right side image -->\n    <div class=\"h-screen flex-grow-0\">\n      <img\n        src=\"/assets/images/sad_girl.png\"\n        alt=\"Sad girl illustration\"\n        class=\"h-full object-cover\"\n      />\n    </div>\n  </div>\n</template>\n\n<script setup lang=\"ts\">\nimport Button from 'primevue/button'\nimport { useRouter } from 'vue-router'\n\nconst openDocs = () => {\n  window.open(\n    'https://github.com/Comfy-Org/desktop#currently-supported-platforms',\n    '_blank'\n  )\n}\n\nconst reportIssue = () => {\n  window.open('https://forum.comfy.org/c/v1-feedback/', '_blank')\n}\n\nconst router = useRouter()\nconst continueToInstall = () => {\n  router.push('/install')\n}\n</script>\n"],"names":[],"mappings":";;;AAAA,MAAe,aAAA,KAAA,IAAA,IAAA,uBAAA,YAAA,GAAA,EAAA;;;;;;;;;;;;;;;;;;;AC+Df,UAAM,WAAW,6BAAM;AACd,aAAA;AAAA,QACL;AAAA,QACA;AAAA,MAAA;AAAA,IACF,GAJe;AAOjB,UAAM,cAAc,6BAAM;AACjB,aAAA,KAAK,0CAA0C,QAAQ;AAAA,IAAA,GAD5C;AAIpB,UAAM,SAAS;AACf,UAAM,oBAAoB,6BAAM;AAC9B,aAAO,KAAK,UAAU;AAAA,IAAA,GADE;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;"}
17 web/assets/NotSupportedView-DQerxQzi.css generated vendored Normal file
@@ -0,0 +1,17 @@
+
+.sad-container[data-v-ebb20958] {
+  display: grid;
+  align-items: center;
+  justify-content: space-evenly;
+  grid-template-columns: 25rem 1fr;
+  &[data-v-ebb20958] > * {
+    grid-row: 1;
+  }
+}
+.sad-text[data-v-ebb20958] {
+  grid-column: 1/3;
+}
+.sad-girl[data-v-ebb20958] {
+  grid-column: 2/3;
+  width: min(75vw, 100vh);
+}
88 web/assets/NotSupportedView-Vc8_xWgH.js generated vendored Normal file
@@ -0,0 +1,88 @@
+var __defProp = Object.defineProperty;
+var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
+import { d as defineComponent, c2 as useRouter, r as resolveDirective, o as openBlock, J as createBlock, P as withCtx, m as createBaseVNode, Z as toDisplayString, k as createVNode, j as unref, l as script, i as withDirectives, p as pushScopeId, q as popScopeId, _ as _export_sfc } from "./index-QvfM__ze.js";
+import { _ as _sfc_main$1 } from "./BaseViewTemplate-BhQMaVFP.js";
+const _imports_0 = "" + new URL("images/sad_girl.png", import.meta.url).href;
+const _withScopeId = /* @__PURE__ */ __name((n) => (pushScopeId("data-v-ebb20958"), n = n(), popScopeId(), n), "_withScopeId");
+const _hoisted_1 = { class: "sad-container" };
+const _hoisted_2 = /* @__PURE__ */ _withScopeId(() => /* @__PURE__ */ createBaseVNode("img", {
+  class: "sad-girl",
+  src: _imports_0,
+  alt: "Sad girl illustration"
+}, null, -1));
+const _hoisted_3 = { class: "no-drag sad-text flex items-center" };
+const _hoisted_4 = { class: "flex flex-col gap-8 p-8 min-w-110" };
+const _hoisted_5 = { class: "text-4xl font-bold text-red-500" };
+const _hoisted_6 = { class: "space-y-4" };
+const _hoisted_7 = { class: "text-xl" };
+const _hoisted_8 = { class: "list-disc list-inside space-y-1 text-neutral-800" };
+const _hoisted_9 = { class: "flex gap-4" };
+const _sfc_main = /* @__PURE__ */ defineComponent({
+  __name: "NotSupportedView",
+  setup(__props) {
+    const openDocs = /* @__PURE__ */ __name(() => {
+      window.open(
+        "https://github.com/Comfy-Org/desktop#currently-supported-platforms",
+        "_blank"
+      );
+    }, "openDocs");
+    const reportIssue = /* @__PURE__ */ __name(() => {
+      window.open("https://forum.comfy.org/c/v1-feedback/", "_blank");
+    }, "reportIssue");
+    const router = useRouter();
+    const continueToInstall = /* @__PURE__ */ __name(() => {
+      router.push("/install");
+    }, "continueToInstall");
+    return (_ctx, _cache) => {
+      const _directive_tooltip = resolveDirective("tooltip");
+      return openBlock(), createBlock(_sfc_main$1, null, {
+        default: withCtx(() => [
+          createBaseVNode("div", _hoisted_1, [
+            _hoisted_2,
+            createBaseVNode("div", _hoisted_3, [
+              createBaseVNode("div", _hoisted_4, [
+                createBaseVNode("h1", _hoisted_5, toDisplayString(_ctx.$t("notSupported.title")), 1),
+                createBaseVNode("div", _hoisted_6, [
+                  createBaseVNode("p", _hoisted_7, toDisplayString(_ctx.$t("notSupported.message")), 1),
+                  createBaseVNode("ul", _hoisted_8, [
+                    createBaseVNode("li", null, toDisplayString(_ctx.$t("notSupported.supportedDevices.macos")), 1),
+                    createBaseVNode("li", null, toDisplayString(_ctx.$t("notSupported.supportedDevices.windows")), 1)
+                  ])
+                ]),
+                createBaseVNode("div", _hoisted_9, [
+                  createVNode(unref(script), {
+                    label: _ctx.$t("notSupported.learnMore"),
+                    icon: "pi pi-github",
+                    onClick: openDocs,
+                    severity: "secondary"
+                  }, null, 8, ["label"]),
+                  createVNode(unref(script), {
+                    label: _ctx.$t("notSupported.reportIssue"),
+                    icon: "pi pi-flag",
+                    onClick: reportIssue,
+                    severity: "secondary"
+                  }, null, 8, ["label"]),
+                  withDirectives(createVNode(unref(script), {
+                    label: _ctx.$t("notSupported.continue"),
+                    icon: "pi pi-arrow-right",
+                    iconPos: "right",
+                    onClick: continueToInstall,
+                    severity: "danger"
+                  }, null, 8, ["label"]), [
+                    [_directive_tooltip, _ctx.$t("notSupported.continueTooltip")]
+                  ])
+                ])
+              ])
+            ])
+          ])
+        ]),
+        _: 1
+      });
+    };
+  }
+});
+const NotSupportedView = /* @__PURE__ */ _export_sfc(_sfc_main, [["__scopeId", "data-v-ebb20958"]]);
+export {
+  NotSupportedView as default
+};
+//# sourceMappingURL=NotSupportedView-Vc8_xWgH.js.map
8 web/assets/ServerConfigPanel-CvXC1Xmx.js → web/assets/ServerConfigPanel-B-w0HFlz.js generated vendored
@@ -1,7 +1,7 @@
 var __defProp = Object.defineProperty;
 var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
-import { A as createBaseVNode, f as openBlock, g as createElementBlock, aZ as markRaw, a as defineComponent, u as useSettingStore, aK as storeToRefs, w as watch, cL as useCopyToClipboard, I as useI18n, x as createBlock, y as withCtx, z as unref, bW as script, a8 as toDisplayString, Q as renderList, P as Fragment, h as createVNode, D as script$1, i as createCommentVNode, bN as script$2, cM as FormItem, cm as _sfc_main$1, bZ as electronAPI } from "./index-DIU5yZe9.js";
+import { m as createBaseVNode, o as openBlock, f as createElementBlock, a0 as markRaw, d as defineComponent, a as useSettingStore, aS as storeToRefs, a7 as watch, cW as useCopyToClipboard, a3 as useI18n, J as createBlock, P as withCtx, j as unref, c6 as script, Z as toDisplayString, I as renderList, H as Fragment, k as createVNode, l as script$1, L as createCommentVNode, c4 as script$2, cX as FormItem, cw as _sfc_main$1, bV as electronAPI } from "./index-QvfM__ze.js";
-import { u as useServerConfigStore } from "./serverConfigStore-DYv7_Nld.js";
+import { u as useServerConfigStore } from "./serverConfigStore-DCme3xlV.js";
 const _hoisted_1$1 = {
 viewBox: "0 0 24 24",
 width: "1.2em",
@@ -131,7 +131,7 @@ const _sfc_main = /* @__PURE__ */ defineComponent({
 (openBlock(true), createElementBlock(Fragment, null, renderList(items, (item) => {
 return openBlock(), createElementBlock("div", {
 key: item.name,
-class: "flex items-center mb-4"
+class: "mb-4"
 }, [
 createVNode(FormItem, {
 item: translateItem(item),
@@ -155,4 +155,4 @@ const _sfc_main = /* @__PURE__ */ defineComponent({
 export {
 _sfc_main as default
 };
-//# sourceMappingURL=ServerConfigPanel-CvXC1Xmx.js.map
+//# sourceMappingURL=ServerConfigPanel-B-w0HFlz.js.map
1 web/assets/ServerConfigPanel-CvXC1Xmx.js.map generated vendored
@@ -1 +0,0 @@
-{"version":3,"file":"ServerConfigPanel-CvXC1Xmx.js","sources":["../../src/components/dialog/content/setting/ServerConfigPanel.vue"],"sourcesContent":["<template>\n  <PanelTemplate value=\"Server-Config\" class=\"server-config-panel\">\n    <template #header>\n      <div class=\"flex flex-col gap-2\">\n        <Message\n          v-if=\"modifiedConfigs.length > 0\"\n          severity=\"info\"\n          pt:text=\"w-full\"\n        >\n          <p>\n            {{ $t('serverConfig.modifiedConfigs') }}\n          </p>\n          <ul>\n            <li v-for=\"config in modifiedConfigs\" :key=\"config.id\">\n              {{ config.name }}: {{ config.initialValue }} → {{ config.value }}\n            </li>\n          </ul>\n          <div class=\"flex justify-end gap-2\">\n            <Button\n              :label=\"$t('serverConfig.revertChanges')\"\n              @click=\"revertChanges\"\n              outlined\n            />\n            <Button\n              :label=\"$t('serverConfig.restart')\"\n              @click=\"restartApp\"\n              outlined\n              severity=\"danger\"\n            />\n          </div>\n        </Message>\n        <Message v-if=\"commandLineArgs\" severity=\"secondary\" pt:text=\"w-full\">\n          <template #icon>\n            <i-lucide:terminal class=\"text-xl font-bold\" />\n          </template>\n          <div class=\"flex items-center justify-between\">\n            <p>{{ commandLineArgs }}</p>\n            <Button\n              icon=\"pi pi-clipboard\"\n              @click=\"copyCommandLineArgs\"\n              severity=\"secondary\"\n              text\n            />\n          </div>\n        </Message>\n      </div>\n    </template>\n    <div\n      v-for=\"([label, items], i) in Object.entries(serverConfigsByCategory)\"\n      :key=\"label\"\n    >\n      <Divider v-if=\"i > 0\" />\n      <h3>{{ $t(`serverConfigCategories.${label}`, label) }}</h3>\n      <div\n        v-for=\"item in items\"\n        :key=\"item.name\"\n        class=\"flex items-center mb-4\"\n      >\n        <FormItem\n          :item=\"translateItem(item)\"\n          v-model:formValue=\"item.value\"\n          :id=\"item.id\"\n          :labelClass=\"{\n            'text-highlight': item.initialValue !== item.value\n          }\"\n        />\n      </div>\n    </div>\n  </PanelTemplate>\n</template>\n\n<script setup lang=\"ts\">\nimport Button from 'primevue/button'\nimport Message from 'primevue/message'\nimport Divider from 'primevue/divider'\nimport FormItem from '@/components/common/FormItem.vue'\nimport PanelTemplate from './PanelTemplate.vue'\nimport { useServerConfigStore } from '@/stores/serverConfigStore'\nimport { storeToRefs } from 'pinia'\nimport { electronAPI } from '@/utils/envUtil'\nimport { useSettingStore } from '@/stores/settingStore'\nimport { watch } from 'vue'\nimport { useCopyToClipboard } from '@/hooks/clipboardHooks'\nimport type { FormItem as FormItemType } from '@/types/settingTypes'\nimport type { ServerConfig } from '@/constants/serverConfig'\nimport { useI18n } from 'vue-i18n'\n\nconst settingStore = useSettingStore()\nconst serverConfigStore = useServerConfigStore()\nconst {\n  serverConfigsByCategory,\n  serverConfigValues,\n  launchArgs,\n  commandLineArgs,\n  modifiedConfigs\n} = storeToRefs(serverConfigStore)\n\nconst revertChanges = () => {\n  serverConfigStore.revertChanges()\n}\n\nconst restartApp = () => {\n  electronAPI().restartApp()\n}\n\nwatch(launchArgs, (newVal) => {\n  settingStore.set('Comfy.Server.LaunchArgs', newVal)\n})\n\nwatch(serverConfigValues, (newVal) => {\n  settingStore.set('Comfy.Server.ServerConfigValues', newVal)\n})\n\nconst { copyToClipboard } = useCopyToClipboard()\nconst copyCommandLineArgs = async () => {\n  await copyToClipboard(commandLineArgs.value)\n}\n\nconst { t } = useI18n()\nconst translateItem = (item: ServerConfig<any>): FormItemType => {\n  return {\n    ...item,\n    name: t(`serverConfigItems.${item.id}.name`, item.name),\n    tooltip: item.tooltip\n      ? t(`serverConfigItems.${item.id}.tooltip`, item.tooltip)\n      : undefined\n  }\n}\n</script>\n"],"names":[],"mappings":";;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;AAuFA,UAAM,eAAe;AACrB,UAAM,oBAAoB;AACpB,UAAA;AAAA,MACJ;AAAA,MACA;AAAA,MACA;AAAA,MACA;AAAA,MACA;AAAA,IAAA,IACE,YAAY,iBAAiB;AAEjC,UAAM,gBAAgB,6BAAM;AAC1B,wBAAkB,cAAc;AAAA,IAAA,GADZ;AAItB,UAAM,aAAa,6BAAM;AACvB,kBAAA,EAAc;IAAW,GADR;AAIb,UAAA,YAAY,CAAC,WAAW;AACf,mBAAA,IAAI,2BAA2B,MAAM;AAAA,IAAA,CACnD;AAEK,UAAA,oBAAoB,CAAC,WAAW;AACvB,mBAAA,IAAI,mCAAmC,MAAM;AAAA,IAAA,CAC3D;AAEK,UAAA,EAAE,oBAAoB;AAC5B,UAAM,sBAAsB,mCAAY;AAChC,YAAA,gBAAgB,gBAAgB,KAAK;AAAA,IAAA,GADjB;AAItB,UAAA,EAAE,MAAM;AACR,UAAA,gBAAgB,wBAAC,SAA0C;AACxD,aAAA;AAAA,QACL,GAAG;AAAA,QACH,MAAM,EAAE,qBAAqB,KAAK,EAAE,SAAS,KAAK,IAAI;AAAA,QACtD,SAAS,KAAK,UACV,EAAE,qBAAqB,KAAK,EAAE,YAAY,KAAK,OAAO,IACtD;AAAA,MAAA;AAAA,IACN,GAPoB;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;"}
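The deleted source map above still embeds the original ServerConfigPanel.vue source. Its translateItem helper, restated below from the map's sourcesContent (reformatted only), shows how each server-config entry is localized before being passed to FormItem:

// From the sourcesContent embedded in the deleted map above.
const translateItem = (item: ServerConfig<any>): FormItemType => {
  return {
    ...item,
    // Look up a localized name/tooltip by item id, falling back to the raw values.
    name: t(`serverConfigItems.${item.id}.name`, item.name),
    tooltip: item.tooltip
      ? t(`serverConfigItems.${item.id}.tooltip`, item.tooltip)
      : undefined
  };
};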
101 web/assets/ServerStartView-48wfE1MS.js generated vendored Normal file
@@ -0,0 +1,101 @@
+var __defProp = Object.defineProperty;
+var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
+import { d as defineComponent, a3 as useI18n, ad as ref, c7 as ProgressStatus, t as onMounted, o as openBlock, J as createBlock, P as withCtx, m as createBaseVNode, aG as createTextVNode, Z as toDisplayString, j as unref, f as createElementBlock, L as createCommentVNode, k as createVNode, l as script, i as withDirectives, v as vShow, c8 as BaseTerminal, p as pushScopeId, q as popScopeId, bV as electronAPI, _ as _export_sfc } from "./index-QvfM__ze.js";
+import { _ as _sfc_main$1 } from "./BaseViewTemplate-BhQMaVFP.js";
+const _withScopeId = /* @__PURE__ */ __name((n) => (pushScopeId("data-v-4140d62b"), n = n(), popScopeId(), n), "_withScopeId");
+const _hoisted_1 = { class: "flex flex-col w-full h-full items-center" };
+const _hoisted_2 = { class: "text-2xl font-bold" };
+const _hoisted_3 = { key: 0 };
+const _hoisted_4 = {
+  key: 0,
+  class: "flex flex-col items-center gap-4"
+};
+const _hoisted_5 = { class: "flex items-center my-4 gap-2" };
+const _sfc_main = /* @__PURE__ */ defineComponent({
+  __name: "ServerStartView",
+  setup(__props) {
+    const electron = electronAPI();
+    const { t } = useI18n();
+    const status = ref(ProgressStatus.INITIAL_STATE);
+    const electronVersion = ref("");
+    let xterm;
+    const terminalVisible = ref(true);
+    const updateProgress = /* @__PURE__ */ __name(({ status: newStatus }) => {
+      status.value = newStatus;
+      if (newStatus === ProgressStatus.ERROR) terminalVisible.value = false;
+      else xterm?.clear();
+    }, "updateProgress");
+    const terminalCreated = /* @__PURE__ */ __name(({ terminal, useAutoSize }, root) => {
+      xterm = terminal;
+      useAutoSize({ root, autoRows: true, autoCols: true });
+      electron.onLogMessage((message) => {
+        terminal.write(message);
+      });
+      terminal.options.cursorBlink = false;
+      terminal.options.disableStdin = true;
+      terminal.options.cursorInactiveStyle = "block";
+    }, "terminalCreated");
+    const reinstall = /* @__PURE__ */ __name(() => electron.reinstall(), "reinstall");
+    const reportIssue = /* @__PURE__ */ __name(() => {
+      window.open("https://forum.comfy.org/c/v1-feedback/", "_blank");
+    }, "reportIssue");
+    const openLogs = /* @__PURE__ */ __name(() => electron.openLogsFolder(), "openLogs");
+    onMounted(async () => {
+      electron.sendReady();
+      electron.onProgressUpdate(updateProgress);
+      electronVersion.value = await electron.getElectronVersion();
+    });
+    return (_ctx, _cache) => {
+      return openBlock(), createBlock(_sfc_main$1, {
+        dark: "",
+        class: "flex-col"
+      }, {
+        default: withCtx(() => [
+          createBaseVNode("div", _hoisted_1, [
+            createBaseVNode("h2", _hoisted_2, [
+              createTextVNode(toDisplayString(unref(t)(`serverStart.process.${status.value}`)) + " ", 1),
+              status.value === unref(ProgressStatus).ERROR ? (openBlock(), createElementBlock("span", _hoisted_3, " v" + toDisplayString(electronVersion.value), 1)) : createCommentVNode("", true)
+            ]),
+            status.value === unref(ProgressStatus).ERROR ? (openBlock(), createElementBlock("div", _hoisted_4, [
+              createBaseVNode("div", _hoisted_5, [
+                createVNode(unref(script), {
+                  icon: "pi pi-flag",
+                  severity: "secondary",
+                  label: unref(t)("serverStart.reportIssue"),
+                  onClick: reportIssue
+                }, null, 8, ["label"]),
+                createVNode(unref(script), {
+                  icon: "pi pi-file",
+                  severity: "secondary",
+                  label: unref(t)("serverStart.openLogs"),
+                  onClick: openLogs
+                }, null, 8, ["label"]),
+                createVNode(unref(script), {
+                  icon: "pi pi-refresh",
+                  label: unref(t)("serverStart.reinstall"),
+                  onClick: reinstall
+                }, null, 8, ["label"])
+              ]),
+              !terminalVisible.value ? (openBlock(), createBlock(unref(script), {
+                key: 0,
+                icon: "pi pi-search",
+                severity: "secondary",
+                label: unref(t)("serverStart.showTerminal"),
+                onClick: _cache[0] || (_cache[0] = ($event) => terminalVisible.value = true)
+              }, null, 8, ["label"])) : createCommentVNode("", true)
+            ])) : createCommentVNode("", true),
+            withDirectives(createVNode(BaseTerminal, { onCreated: terminalCreated }, null, 512), [
+              [vShow, terminalVisible.value]
+            ])
+          ])
+        ]),
+        _: 1
+      });
+    };
+  }
+});
+const ServerStartView = /* @__PURE__ */ _export_sfc(_sfc_main, [["__scopeId", "data-v-4140d62b"]]);
+export {
+  ServerStartView as default
+};
+//# sourceMappingURL=ServerStartView-48wfE1MS.js.map
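The progress handling in ServerStartView is easier to follow outside the compiled render function. A sketch, using only names present in the bundle above (the parameter typing is an assumption):

// Sketch only; reconstructed from the compiled setup() above.
const updateProgress = ({ status: newStatus }: { status: ProgressStatus }): void => {
  status.value = newStatus;
  if (newStatus === ProgressStatus.ERROR) {
    // On error, hide the terminal so the recovery buttons
    // (report issue, open logs, reinstall) become the focus of the view.
    terminalVisible.value = false;
  } else {
    // Between normal stages, just clear the log output.
    xterm?.clear();
  }
};

// Wiring done on mount: tell the main process the renderer is ready,
// subscribe to progress updates, and fetch the Electron version for display.
onMounted(async () => {
  electron.sendReady();
  electron.onProgressUpdate(updateProgress);
  electronVersion.value = await electron.getElectronVersion();
});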
Some files were not shown because too many files have changed in this diff.