Compare commits
72 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
619b8cde74 | ||
|
|
31831e6ef1 | ||
|
|
88ceb28e20 | ||
|
|
23289a6a5c | ||
|
|
9d8b6c1f46 | ||
|
|
6320d05696 | ||
|
|
25683b5b02 | ||
|
|
4758fb64b9 | ||
|
|
008761166f | ||
|
|
bfd5dfd611 | ||
|
|
55ade36d01 | ||
|
|
2e20e399ea | ||
|
|
3baf92d120 | ||
|
|
1709a8441e | ||
|
|
cba58fff0b | ||
|
|
2feb8d0b77 | ||
|
|
5b657f8c15 | ||
|
|
2cdbaf5169 | ||
|
|
c78a45685d | ||
|
|
3aaabb12d4 | ||
|
|
1f1c7b7b56 | ||
|
|
90f349f93d | ||
|
|
b9d9bcba14 | ||
|
|
42086af123 | ||
|
|
6c9bd11fa3 | ||
|
|
ee8a7ab69d | ||
|
|
9c773a241b | ||
|
|
adea2beb5c | ||
|
|
2ff3104f70 | ||
|
|
129d8908f7 | ||
|
|
ff838657fa | ||
|
|
2307ff6746 | ||
|
|
d0f3752e33 | ||
|
|
c515bdf371 | ||
|
|
4209edf48d | ||
|
|
d055325783 | ||
|
|
eeab420c70 | ||
|
|
916d1e14a9 | ||
|
|
c496e53519 | ||
|
|
7da85fac3f | ||
|
|
b65b83af6f | ||
|
|
c8a3492c22 | ||
|
|
5cbf79787f | ||
|
|
d45ebb63f6 | ||
|
|
caa6476a69 | ||
|
|
45671cda0b | ||
|
|
8f29664057 | ||
|
|
0b9839ef43 | ||
|
|
953693b137 | ||
|
|
a39ea87bca | ||
|
|
9e9c8a1c64 | ||
|
|
0f11d60afb | ||
|
|
79eea51a1d | ||
|
|
c0338a46a4 | ||
|
|
1c99734e5a | ||
|
|
67758f50f3 | ||
|
|
02eef72bf5 | ||
|
|
b7572b2f87 | ||
|
|
a90aafafc1 | ||
|
|
d9b7cfac7e | ||
|
|
3507870535 | ||
|
|
82ecb02c1e | ||
|
|
a618f768e0 | ||
|
|
e1dec3c792 | ||
|
|
96697c4bc5 | ||
|
|
b504bd606d | ||
|
|
d170292594 | ||
|
|
9cfd185676 | ||
|
|
4b5bcd8ac4 | ||
|
|
ceb50b2cbf | ||
|
|
160ca08138 | ||
|
|
c4bfdba330 |
2
.github/workflows/stable-release.yml
vendored
2
.github/workflows/stable-release.yml
vendored
@@ -22,7 +22,7 @@ on:
|
||||
description: 'Python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "7"
|
||||
default: "8"
|
||||
|
||||
|
||||
jobs:
|
||||
|
||||
4
.github/workflows/test-build.yml
vendored
4
.github/workflows/test-build.yml
vendored
@@ -18,7 +18,7 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.8", "3.9", "3.10", "3.11"]
|
||||
python-version: ["3.9", "3.10", "3.11", "3.12"]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
@@ -28,4 +28,4 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r requirements.txt
|
||||
pip install -r requirements.txt
|
||||
|
||||
2
.github/workflows/test-launch.yml
vendored
2
.github/workflows/test-launch.yml
vendored
@@ -17,7 +17,7 @@ jobs:
|
||||
path: "ComfyUI"
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.8'
|
||||
python-version: '3.9'
|
||||
- name: Install requirements
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
|
||||
58
.github/workflows/update-frontend.yml
vendored
Normal file
58
.github/workflows/update-frontend.yml
vendored
Normal file
@@ -0,0 +1,58 @@
|
||||
name: Update Frontend Release
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
version:
|
||||
description: "Frontend version to update to (e.g., 1.0.0)"
|
||||
required: true
|
||||
type: string
|
||||
|
||||
jobs:
|
||||
update-frontend:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
|
||||
steps:
|
||||
- name: Checkout ComfyUI
|
||||
uses: actions/checkout@v4
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.10'
|
||||
- name: Install requirements
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
||||
pip install -r requirements.txt
|
||||
pip install wait-for-it
|
||||
# Frontend asset will be downloaded to ComfyUI/web_custom_versions/Comfy-Org_ComfyUI_frontend/{version}
|
||||
- name: Start ComfyUI server
|
||||
run: |
|
||||
python main.py --cpu --front-end-version Comfy-Org/ComfyUI_frontend@${{ github.event.inputs.version }} 2>&1 | tee console_output.log &
|
||||
wait-for-it --service 127.0.0.1:8188 -t 30
|
||||
- name: Configure Git
|
||||
run: |
|
||||
git config --global user.name "GitHub Action"
|
||||
git config --global user.email "action@github.com"
|
||||
# Replace existing frontend content with the new version and remove .js.map files
|
||||
# See https://github.com/Comfy-Org/ComfyUI_frontend/issues/2145 for why we remove .js.map files
|
||||
- name: Update frontend content
|
||||
run: |
|
||||
rm -rf web/
|
||||
cp -r web_custom_versions/Comfy-Org_ComfyUI_frontend/${{ github.event.inputs.version }} web/
|
||||
rm web/**/*.js.map
|
||||
- name: Create Pull Request
|
||||
uses: peter-evans/create-pull-request@v7
|
||||
with:
|
||||
token: ${{ secrets.PR_BOT_PAT }}
|
||||
commit-message: "Update frontend to v${{ github.event.inputs.version }}"
|
||||
title: "Frontend Update: v${{ github.event.inputs.version }}"
|
||||
body: |
|
||||
Automated PR to update frontend content to version ${{ github.event.inputs.version }}
|
||||
|
||||
This PR was created automatically by the frontend update workflow.
|
||||
branch: release-${{ github.event.inputs.version }}
|
||||
base: master
|
||||
labels: Frontend,dependencies
|
||||
58
.github/workflows/update-version.yml
vendored
Normal file
58
.github/workflows/update-version.yml
vendored
Normal file
@@ -0,0 +1,58 @@
|
||||
name: Update Version File
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- "pyproject.toml"
|
||||
branches:
|
||||
- master
|
||||
|
||||
jobs:
|
||||
update-version:
|
||||
runs-on: ubuntu-latest
|
||||
# Don't run on fork PRs
|
||||
if: github.event.pull_request.head.repo.full_name == github.repository
|
||||
permissions:
|
||||
pull-requests: write
|
||||
contents: write
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.11"
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
|
||||
- name: Update comfyui_version.py
|
||||
run: |
|
||||
# Read version from pyproject.toml and update comfyui_version.py
|
||||
python -c '
|
||||
import tomllib
|
||||
|
||||
# Read version from pyproject.toml
|
||||
with open("pyproject.toml", "rb") as f:
|
||||
config = tomllib.load(f)
|
||||
version = config["project"]["version"]
|
||||
|
||||
# Write version to comfyui_version.py
|
||||
with open("comfyui_version.py", "w") as f:
|
||||
f.write("# This file is automatically generated by the build process when version is\n")
|
||||
f.write("# updated in pyproject.toml.\n")
|
||||
f.write(f"__version__ = \"{version}\"\n")
|
||||
'
|
||||
|
||||
- name: Commit changes
|
||||
run: |
|
||||
git config --local user.name "github-actions"
|
||||
git config --local user.email "github-actions@github.com"
|
||||
git fetch origin ${{ github.head_ref }}
|
||||
git checkout -B ${{ github.head_ref }} origin/${{ github.head_ref }}
|
||||
git add comfyui_version.py
|
||||
git diff --quiet && git diff --staged --quiet || git commit -m "chore: Update comfyui_version.py to match pyproject.toml"
|
||||
git push origin HEAD:${{ github.head_ref }}
|
||||
@@ -29,7 +29,7 @@ on:
|
||||
description: 'python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "7"
|
||||
default: "8"
|
||||
# push:
|
||||
# branches:
|
||||
# - master
|
||||
|
||||
@@ -7,19 +7,19 @@ on:
|
||||
description: 'cuda version'
|
||||
required: true
|
||||
type: string
|
||||
default: "124"
|
||||
default: "126"
|
||||
|
||||
python_minor:
|
||||
description: 'python minor version'
|
||||
required: true
|
||||
type: string
|
||||
default: "12"
|
||||
default: "13"
|
||||
|
||||
python_patch:
|
||||
description: 'python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "4"
|
||||
default: "1"
|
||||
# push:
|
||||
# branches:
|
||||
# - master
|
||||
|
||||
@@ -19,7 +19,7 @@ on:
|
||||
description: 'python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "7"
|
||||
default: "8"
|
||||
# push:
|
||||
# branches:
|
||||
# - master
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
|
||||
|
||||
# Frontend assets
|
||||
/web/ @huchenlei @webfiltered @pythongosssss
|
||||
/web/ @huchenlei @webfiltered @pythongosssss @yoland68 @robinjhuang
|
||||
|
||||
# Extra nodes
|
||||
/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink
|
||||
|
||||
10
README.md
10
README.md
@@ -224,6 +224,16 @@ You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS ve
|
||||
|
||||
```pip install torch-directml``` Then you can launch ComfyUI with: ```python main.py --directml```
|
||||
|
||||
#### Ascend NPUs
|
||||
|
||||
For models compatible with Ascend Extension for PyTorch (torch_npu). To get started, ensure your environment meets the prerequisites outlined on the [installation](https://ascend.github.io/docs/sources/ascend/quick_install.html) page. Here's a step-by-step guide tailored to your platform and installation method:
|
||||
|
||||
1. Begin by installing the recommended or newer kernel version for Linux as specified in the Installation page of torch-npu, if necessary.
|
||||
2. Proceed with the installation of Ascend Basekit, which includes the driver, firmware, and CANN, following the instructions provided for your specific platform.
|
||||
3. Next, install the necessary packages for torch-npu by adhering to the platform-specific instructions on the [Installation](https://ascend.github.io/docs/sources/pytorch/install.html#pytorch) page.
|
||||
4. Finally, adhere to the [ComfyUI manual installation](#manual-install-windows-linux) guide for Linux. Once all components are installed, you can run ComfyUI as described earlier.
|
||||
|
||||
|
||||
# Running
|
||||
|
||||
```python main.py```
|
||||
|
||||
@@ -10,4 +10,4 @@ class FileService:
|
||||
if directory_key not in self.allowed_directories:
|
||||
raise ValueError("Invalid directory key")
|
||||
directory_path: str = self.allowed_directories[directory_key]
|
||||
return self.file_system_ops.walk_directory(directory_path)
|
||||
return self.file_system_ops.walk_directory(directory_path)
|
||||
|
||||
@@ -25,10 +25,10 @@ class TerminalService:
|
||||
def update_size(self):
|
||||
columns, lines = self.get_terminal_size()
|
||||
changed = False
|
||||
|
||||
|
||||
if columns != self.cols:
|
||||
self.cols = columns
|
||||
changed = True
|
||||
changed = True
|
||||
|
||||
if lines != self.rows:
|
||||
self.rows = lines
|
||||
@@ -48,9 +48,9 @@ class TerminalService:
|
||||
def send_messages(self, entries):
|
||||
if not len(entries) or not len(self.subscriptions):
|
||||
return
|
||||
|
||||
|
||||
new_size = self.update_size()
|
||||
|
||||
|
||||
for client_id in self.subscriptions.copy(): # prevent: Set changed size during iteration
|
||||
if client_id not in self.server.sockets:
|
||||
# Automatically unsub if the socket has disconnected
|
||||
|
||||
@@ -39,4 +39,4 @@ class FileSystemOperations:
|
||||
"path": relative_path,
|
||||
"type": "directory"
|
||||
})
|
||||
return file_list
|
||||
return file_list
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import os
|
||||
import json
|
||||
from aiohttp import web
|
||||
import logging
|
||||
|
||||
|
||||
class AppSettings():
|
||||
@@ -11,8 +12,12 @@ class AppSettings():
|
||||
file = self.user_manager.get_request_user_filepath(
|
||||
request, "comfy.settings.json")
|
||||
if os.path.isfile(file):
|
||||
with open(file) as f:
|
||||
return json.load(f)
|
||||
try:
|
||||
with open(file) as f:
|
||||
return json.load(f)
|
||||
except:
|
||||
logging.error(f"The user settings file is corrupted: {file}")
|
||||
return {}
|
||||
else:
|
||||
return {}
|
||||
|
||||
@@ -51,4 +56,4 @@ class AppSettings():
|
||||
settings = self.get_settings(request)
|
||||
settings[setting_id] = await request.json()
|
||||
self.save_settings(request, settings)
|
||||
return web.Response(status=200)
|
||||
return web.Response(status=200)
|
||||
|
||||
34
app/custom_node_manager.py
Normal file
34
app/custom_node_manager.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import folder_paths
|
||||
import glob
|
||||
from aiohttp import web
|
||||
|
||||
class CustomNodeManager:
|
||||
"""
|
||||
Placeholder to refactor the custom node management features from ComfyUI-Manager.
|
||||
Currently it only contains the custom workflow templates feature.
|
||||
"""
|
||||
def add_routes(self, routes, webapp, loadedModules):
|
||||
|
||||
@routes.get("/workflow_templates")
|
||||
async def get_workflow_templates(request):
|
||||
"""Returns a web response that contains the map of custom_nodes names and their associated workflow templates. The ones without templates are omitted."""
|
||||
files = [
|
||||
file
|
||||
for folder in folder_paths.get_folder_paths("custom_nodes")
|
||||
for file in glob.glob(os.path.join(folder, '*/example_workflows/*.json'))
|
||||
]
|
||||
workflow_templates_dict = {} # custom_nodes folder name -> example workflow names
|
||||
for file in files:
|
||||
custom_nodes_name = os.path.basename(os.path.dirname(os.path.dirname(file)))
|
||||
workflow_name = os.path.splitext(os.path.basename(file))[0]
|
||||
workflow_templates_dict.setdefault(custom_nodes_name, []).append(workflow_name)
|
||||
return web.json_response(workflow_templates_dict)
|
||||
|
||||
# Serve workflow templates from custom nodes.
|
||||
for module_name, module_dir in loadedModules:
|
||||
workflows_dir = os.path.join(module_dir, 'example_workflows')
|
||||
if os.path.exists(workflows_dir):
|
||||
webapp.add_routes([web.static('/api/workflow_templates/' + module_name, workflows_dir)])
|
||||
@@ -51,7 +51,7 @@ def on_flush(callback):
|
||||
if stderr_interceptor is not None:
|
||||
stderr_interceptor.on_flush(callback)
|
||||
|
||||
def setup_logger(log_level: str = 'INFO', capacity: int = 300):
|
||||
def setup_logger(log_level: str = 'INFO', capacity: int = 300, use_stdout: bool = False):
|
||||
global logs
|
||||
if logs:
|
||||
return
|
||||
@@ -70,4 +70,15 @@ def setup_logger(log_level: str = 'INFO', capacity: int = 300):
|
||||
|
||||
stream_handler = logging.StreamHandler()
|
||||
stream_handler.setFormatter(logging.Formatter("%(message)s"))
|
||||
|
||||
if use_stdout:
|
||||
# Only errors and critical to stderr
|
||||
stream_handler.addFilter(lambda record: not record.levelno < logging.ERROR)
|
||||
|
||||
# Lesser to stdout
|
||||
stdout_handler = logging.StreamHandler(sys.stdout)
|
||||
stdout_handler.setFormatter(logging.Formatter("%(message)s"))
|
||||
stdout_handler.addFilter(lambda record: record.levelno < logging.ERROR)
|
||||
logger.addHandler(stdout_handler)
|
||||
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
@@ -177,7 +177,7 @@ class ModelFileManager:
|
||||
safetensors_images = json.loads(safetensors_images)
|
||||
for image in safetensors_images:
|
||||
result.append(BytesIO(base64.b64decode(image)))
|
||||
|
||||
|
||||
return result
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
|
||||
@@ -122,7 +122,7 @@ vram_group.add_argument("--lowvram", action="store_true", help="Split the unet i
|
||||
vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
|
||||
vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
|
||||
|
||||
parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reverved depending on your OS.")
|
||||
parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.")
|
||||
|
||||
|
||||
parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
|
||||
@@ -141,6 +141,7 @@ parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Dis
|
||||
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
|
||||
|
||||
parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
|
||||
parser.add_argument("--log-stdout", action="store_true", help="Send normal process output to stdout instead of stderr (default).")
|
||||
|
||||
# The default built-in provider hosted under web/
|
||||
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
|
||||
|
||||
@@ -5,7 +5,7 @@ This module provides type hinting and concrete convenience types for node develo
|
||||
If cloned to the custom_nodes directory of ComfyUI, types can be imported using:
|
||||
|
||||
```python
|
||||
from comfy_types import IO, ComfyNodeABC, CheckLazyMixin
|
||||
from comfy.comfy_types import IO, ComfyNodeABC, CheckLazyMixin
|
||||
|
||||
class ExampleNode(ComfyNodeABC):
|
||||
@classmethod
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
from comfy_types import IO, ComfyNodeABC, InputTypeDict
|
||||
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
|
||||
from inspect import cleandoc
|
||||
|
||||
|
||||
class ExampleNode(ComfyNodeABC):
|
||||
"""An example node that just adds 1 to an input integer.
|
||||
|
||||
* Requires an IDE configured with analysis paths etc to be worth looking at.
|
||||
* Not intended for use in ComfyUI.
|
||||
* Requires a modern IDE to provide any benefit (detail: an IDE configured with analysis paths etc).
|
||||
* This node is intended as an example for developers only.
|
||||
"""
|
||||
|
||||
DESCRIPTION = cleandoc(__doc__)
|
||||
|
||||
@@ -120,7 +120,7 @@ class ControlBase:
|
||||
if self.previous_controlnet is not None:
|
||||
out += self.previous_controlnet.get_models()
|
||||
return out
|
||||
|
||||
|
||||
def get_extra_hooks(self):
|
||||
out = []
|
||||
if self.extra_hooks is not None:
|
||||
|
||||
@@ -80,7 +80,7 @@ class NoiseScheduleVP:
|
||||
'linear' or 'cosine' for continuous-time DPMs.
|
||||
Returns:
|
||||
A wrapper object of the forward SDE (VP type).
|
||||
|
||||
|
||||
===============================================================
|
||||
|
||||
Example:
|
||||
@@ -208,7 +208,7 @@ def model_wrapper(
|
||||
arXiv preprint arXiv:2202.00512 (2022).
|
||||
[2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
|
||||
arXiv preprint arXiv:2210.02303 (2022).
|
||||
|
||||
|
||||
4. "score": marginal score function. (Trained by denoising score matching).
|
||||
Note that the score function and the noise prediction model follows a simple relationship:
|
||||
```
|
||||
@@ -226,7 +226,7 @@ def model_wrapper(
|
||||
The input `model` has the following format:
|
||||
``
|
||||
model(x, t_input, **model_kwargs) -> noise | x_start | v | score
|
||||
``
|
||||
``
|
||||
|
||||
The input `classifier_fn` has the following format:
|
||||
``
|
||||
@@ -240,12 +240,12 @@ def model_wrapper(
|
||||
The input `model` has the following format:
|
||||
``
|
||||
model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
|
||||
``
|
||||
``
|
||||
And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
|
||||
|
||||
[4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
|
||||
arXiv preprint arXiv:2207.12598 (2022).
|
||||
|
||||
|
||||
|
||||
The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
|
||||
or continuous-time labels (i.e. epsilon to T).
|
||||
@@ -254,7 +254,7 @@ def model_wrapper(
|
||||
``
|
||||
def model_fn(x, t_continuous) -> noise:
|
||||
t_input = get_model_input_time(t_continuous)
|
||||
return noise_pred(model, x, t_input, **model_kwargs)
|
||||
return noise_pred(model, x, t_input, **model_kwargs)
|
||||
``
|
||||
where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
|
||||
|
||||
@@ -359,7 +359,7 @@ class UniPC:
|
||||
max_val=1.,
|
||||
variant='bh1',
|
||||
):
|
||||
"""Construct a UniPC.
|
||||
"""Construct a UniPC.
|
||||
|
||||
We support both data_prediction and noise_prediction.
|
||||
"""
|
||||
@@ -372,7 +372,7 @@ class UniPC:
|
||||
|
||||
def dynamic_thresholding_fn(self, x0, t=None):
|
||||
"""
|
||||
The dynamic thresholding method.
|
||||
The dynamic thresholding method.
|
||||
"""
|
||||
dims = x0.dim()
|
||||
p = self.dynamic_thresholding_ratio
|
||||
@@ -404,7 +404,7 @@ class UniPC:
|
||||
|
||||
def model_fn(self, x, t):
|
||||
"""
|
||||
Convert the model to the noise prediction model or the data prediction model.
|
||||
Convert the model to the noise prediction model or the data prediction model.
|
||||
"""
|
||||
if self.predict_x0:
|
||||
return self.data_prediction_fn(x, t)
|
||||
@@ -461,7 +461,7 @@ class UniPC:
|
||||
|
||||
def denoise_to_zero_fn(self, x, s):
|
||||
"""
|
||||
Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization.
|
||||
Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization.
|
||||
"""
|
||||
return self.data_prediction_fn(x, s)
|
||||
|
||||
@@ -510,7 +510,7 @@ class UniPC:
|
||||
col = torch.ones_like(rks)
|
||||
for k in range(1, K + 1):
|
||||
C.append(col)
|
||||
col = col * rks / (k + 1)
|
||||
col = col * rks / (k + 1)
|
||||
C = torch.stack(C, dim=1)
|
||||
|
||||
if len(D1s) > 0:
|
||||
@@ -621,12 +621,12 @@ class UniPC:
|
||||
B_h = torch.expm1(hh)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
for i in range(1, order + 1):
|
||||
R.append(torch.pow(rks, i - 1))
|
||||
b.append(h_phi_k * factorial_i / B_h)
|
||||
factorial_i *= (i + 1)
|
||||
h_phi_k = h_phi_k / hh - 1 / factorial_i
|
||||
h_phi_k = h_phi_k / hh - 1 / factorial_i
|
||||
|
||||
R = torch.stack(R)
|
||||
b = torch.tensor(b, device=x.device)
|
||||
@@ -870,4 +870,4 @@ def sample_unipc(model, noise, sigmas, extra_args=None, callback=None, disable=F
|
||||
return x
|
||||
|
||||
def sample_unipc_bh2(model, noise, sigmas, extra_args=None, callback=None, disable=False):
|
||||
return sample_unipc(model, noise, sigmas, extra_args, callback, disable, variant='bh2')
|
||||
return sample_unipc(model, noise, sigmas, extra_args, callback, disable, variant='bh2')
|
||||
|
||||
446
comfy/hooks.py
446
comfy/hooks.py
@@ -16,130 +16,171 @@ import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
from node_helpers import conditioning_set_values
|
||||
|
||||
# #######################################################################################################
|
||||
# Hooks explanation
|
||||
# -------------------
|
||||
# The purpose of hooks is to allow conds to influence sampling without the need for ComfyUI core code to
|
||||
# make explicit special cases like it does for ControlNet and GLIGEN.
|
||||
#
|
||||
# This is necessary for nodes/features that are intended for use with masked or scheduled conds, or those
|
||||
# that should run special code when a 'marked' cond is used in sampling.
|
||||
# #######################################################################################################
|
||||
|
||||
class EnumHookMode(enum.Enum):
|
||||
'''
|
||||
Priority of hook memory optimization vs. speed, mostly related to WeightHooks.
|
||||
|
||||
MinVram: No caching will occur for any operations related to hooks.
|
||||
MaxSpeed: Excess VRAM (and RAM, once VRAM is sufficiently depleted) will be used to cache hook weights when switching hook groups.
|
||||
'''
|
||||
MinVram = "minvram"
|
||||
MaxSpeed = "maxspeed"
|
||||
|
||||
class EnumHookType(enum.Enum):
|
||||
'''
|
||||
Hook types, each of which has different expected behavior.
|
||||
'''
|
||||
Weight = "weight"
|
||||
Patch = "patch"
|
||||
ObjectPatch = "object_patch"
|
||||
AddModels = "add_models"
|
||||
Callbacks = "callbacks"
|
||||
Wrappers = "wrappers"
|
||||
SetInjections = "add_injections"
|
||||
AdditionalModels = "add_models"
|
||||
TransformerOptions = "transformer_options"
|
||||
Injections = "add_injections"
|
||||
|
||||
class EnumWeightTarget(enum.Enum):
|
||||
Model = "model"
|
||||
Clip = "clip"
|
||||
|
||||
class EnumHookScope(enum.Enum):
|
||||
'''
|
||||
Determines if hook should be limited in its influence over sampling.
|
||||
|
||||
AllConditioning: hook will affect all conds used in sampling.
|
||||
HookedOnly: hook will only affect the conds it was attached to.
|
||||
'''
|
||||
AllConditioning = "all_conditioning"
|
||||
HookedOnly = "hooked_only"
|
||||
|
||||
|
||||
class _HookRef:
|
||||
pass
|
||||
|
||||
# NOTE: this is an example of how the should_register function should look
|
||||
def default_should_register(hook: 'Hook', model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
|
||||
|
||||
def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||
'''Example for how custom_should_register function can look like.'''
|
||||
return True
|
||||
|
||||
|
||||
def create_target_dict(target: EnumWeightTarget=None, **kwargs) -> dict[str]:
|
||||
'''Creates base dictionary for use with Hooks' target param.'''
|
||||
d = {}
|
||||
if target is not None:
|
||||
d['target'] = target
|
||||
d.update(kwargs)
|
||||
return d
|
||||
|
||||
|
||||
class Hook:
|
||||
def __init__(self, hook_type: EnumHookType=None, hook_ref: _HookRef=None, hook_id: str=None,
|
||||
hook_keyframe: 'HookKeyframeGroup'=None):
|
||||
hook_keyframe: HookKeyframeGroup=None, hook_scope=EnumHookScope.AllConditioning):
|
||||
self.hook_type = hook_type
|
||||
'''Enum identifying the general class of this hook.'''
|
||||
self.hook_ref = hook_ref if hook_ref else _HookRef()
|
||||
'''Reference shared between hook clones that have the same value. Should NOT be modified.'''
|
||||
self.hook_id = hook_id
|
||||
'''Optional string ID to identify hook; useful if need to consolidate duplicates at registration time.'''
|
||||
self.hook_keyframe = hook_keyframe if hook_keyframe else HookKeyframeGroup()
|
||||
'''Keyframe storage that can be referenced to get strength for current sampling step.'''
|
||||
self.hook_scope = hook_scope
|
||||
'''Scope of where this hook should apply in terms of the conds used in sampling run.'''
|
||||
self.custom_should_register = default_should_register
|
||||
self.auto_apply_to_nonpositive = False
|
||||
'''Can be overriden with a compatible function to decide if this hook should be registered without the need to override .should_register'''
|
||||
|
||||
@property
|
||||
def strength(self):
|
||||
return self.hook_keyframe.strength
|
||||
|
||||
def initialize_timesteps(self, model: 'BaseModel'):
|
||||
def initialize_timesteps(self, model: BaseModel):
|
||||
self.reset()
|
||||
self.hook_keyframe.initialize_timesteps(model)
|
||||
|
||||
def reset(self):
|
||||
self.hook_keyframe.reset()
|
||||
|
||||
def clone(self, subtype: Callable=None):
|
||||
if subtype is None:
|
||||
subtype = type(self)
|
||||
c: Hook = subtype()
|
||||
def clone(self):
|
||||
c: Hook = self.__class__()
|
||||
c.hook_type = self.hook_type
|
||||
c.hook_ref = self.hook_ref
|
||||
c.hook_id = self.hook_id
|
||||
c.hook_keyframe = self.hook_keyframe
|
||||
c.hook_scope = self.hook_scope
|
||||
c.custom_should_register = self.custom_should_register
|
||||
# TODO: make this do something
|
||||
c.auto_apply_to_nonpositive = self.auto_apply_to_nonpositive
|
||||
return c
|
||||
|
||||
def should_register(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
|
||||
return self.custom_should_register(self, model, model_options, target, registered)
|
||||
def should_register(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||
return self.custom_should_register(self, model, model_options, target_dict, registered)
|
||||
|
||||
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
|
||||
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||
raise NotImplementedError("add_hook_patches should be defined for Hook subclasses")
|
||||
|
||||
def on_apply(self, model: 'ModelPatcher', transformer_options: dict[str]):
|
||||
pass
|
||||
|
||||
def on_unapply(self, model: 'ModelPatcher', transformer_options: dict[str]):
|
||||
pass
|
||||
|
||||
def __eq__(self, other: 'Hook'):
|
||||
def __eq__(self, other: Hook):
|
||||
return self.__class__ == other.__class__ and self.hook_ref == other.hook_ref
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.hook_ref)
|
||||
|
||||
class WeightHook(Hook):
|
||||
'''
|
||||
Hook responsible for tracking weights to be applied to some model/clip.
|
||||
|
||||
Note, value of hook_scope is ignored and is treated as HookedOnly.
|
||||
'''
|
||||
def __init__(self, strength_model=1.0, strength_clip=1.0):
|
||||
super().__init__(hook_type=EnumHookType.Weight)
|
||||
super().__init__(hook_type=EnumHookType.Weight, hook_scope=EnumHookScope.HookedOnly)
|
||||
self.weights: dict = None
|
||||
self.weights_clip: dict = None
|
||||
self.need_weight_init = True
|
||||
self._strength_model = strength_model
|
||||
self._strength_clip = strength_clip
|
||||
|
||||
self.hook_scope = EnumHookScope.HookedOnly # this value does not matter for WeightHooks, just for docs
|
||||
|
||||
@property
|
||||
def strength_model(self):
|
||||
return self._strength_model * self.strength
|
||||
|
||||
|
||||
@property
|
||||
def strength_clip(self):
|
||||
return self._strength_clip * self.strength
|
||||
|
||||
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
|
||||
if not self.should_register(model, model_options, target, registered):
|
||||
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||
if not self.should_register(model, model_options, target_dict, registered):
|
||||
return False
|
||||
weights = None
|
||||
if target == EnumWeightTarget.Model:
|
||||
strength = self._strength_model
|
||||
else:
|
||||
|
||||
target = target_dict.get('target', None)
|
||||
if target == EnumWeightTarget.Clip:
|
||||
strength = self._strength_clip
|
||||
|
||||
else:
|
||||
strength = self._strength_model
|
||||
|
||||
if self.need_weight_init:
|
||||
key_map = {}
|
||||
if target == EnumWeightTarget.Model:
|
||||
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
|
||||
else:
|
||||
if target == EnumWeightTarget.Clip:
|
||||
key_map = comfy.lora.model_lora_keys_clip(model.model, key_map)
|
||||
else:
|
||||
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
|
||||
weights = comfy.lora.load_lora(self.weights, key_map, log_missing=False)
|
||||
else:
|
||||
if target == EnumWeightTarget.Model:
|
||||
weights = self.weights
|
||||
else:
|
||||
if target == EnumWeightTarget.Clip:
|
||||
weights = self.weights_clip
|
||||
else:
|
||||
weights = self.weights
|
||||
model.add_hook_patches(hook=self, patches=weights, strength_patch=strength)
|
||||
registered.append(self)
|
||||
registered.add(self)
|
||||
return True
|
||||
# TODO: add logs about any keys that were not applied
|
||||
|
||||
def clone(self, subtype: Callable=None):
|
||||
if subtype is None:
|
||||
subtype = type(self)
|
||||
c: WeightHook = super().clone(subtype)
|
||||
def clone(self):
|
||||
c: WeightHook = super().clone()
|
||||
c.weights = self.weights
|
||||
c.weights_clip = self.weights_clip
|
||||
c.need_weight_init = self.need_weight_init
|
||||
@@ -147,127 +188,158 @@ class WeightHook(Hook):
|
||||
c._strength_clip = self._strength_clip
|
||||
return c
|
||||
|
||||
class PatchHook(Hook):
|
||||
def __init__(self):
|
||||
super().__init__(hook_type=EnumHookType.Patch)
|
||||
self.patches: dict = None
|
||||
|
||||
def clone(self, subtype: Callable=None):
|
||||
if subtype is None:
|
||||
subtype = type(self)
|
||||
c: PatchHook = super().clone(subtype)
|
||||
c.patches = self.patches
|
||||
return c
|
||||
# TODO: add functionality
|
||||
|
||||
class ObjectPatchHook(Hook):
|
||||
def __init__(self):
|
||||
def __init__(self, object_patches: dict[str]=None,
|
||||
hook_scope=EnumHookScope.AllConditioning):
|
||||
super().__init__(hook_type=EnumHookType.ObjectPatch)
|
||||
self.object_patches: dict = None
|
||||
|
||||
def clone(self, subtype: Callable=None):
|
||||
if subtype is None:
|
||||
subtype = type(self)
|
||||
c: ObjectPatchHook = super().clone(subtype)
|
||||
self.object_patches = object_patches
|
||||
self.hook_scope = hook_scope
|
||||
|
||||
def clone(self):
|
||||
c: ObjectPatchHook = super().clone()
|
||||
c.object_patches = self.object_patches
|
||||
return c
|
||||
# TODO: add functionality
|
||||
|
||||
class AddModelsHook(Hook):
|
||||
def __init__(self, key: str=None, models: list['ModelPatcher']=None):
|
||||
super().__init__(hook_type=EnumHookType.AddModels)
|
||||
self.key = key
|
||||
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||
raise NotImplementedError("ObjectPatchHook is not supported yet in ComfyUI.")
|
||||
|
||||
class AdditionalModelsHook(Hook):
|
||||
'''
|
||||
Hook responsible for telling model management any additional models that should be loaded.
|
||||
|
||||
Note, value of hook_scope is ignored and is treated as AllConditioning.
|
||||
'''
|
||||
def __init__(self, models: list[ModelPatcher]=None, key: str=None):
|
||||
super().__init__(hook_type=EnumHookType.AdditionalModels)
|
||||
self.models = models
|
||||
self.append_when_same = True
|
||||
|
||||
def clone(self, subtype: Callable=None):
|
||||
if subtype is None:
|
||||
subtype = type(self)
|
||||
c: AddModelsHook = super().clone(subtype)
|
||||
c.key = self.key
|
||||
c.models = self.models.copy() if self.models else self.models
|
||||
c.append_when_same = self.append_when_same
|
||||
return c
|
||||
# TODO: add functionality
|
||||
|
||||
class CallbackHook(Hook):
|
||||
def __init__(self, key: str=None, callback: Callable=None):
|
||||
super().__init__(hook_type=EnumHookType.Callbacks)
|
||||
self.key = key
|
||||
self.callback = callback
|
||||
|
||||
def clone(self, subtype: Callable=None):
|
||||
if subtype is None:
|
||||
subtype = type(self)
|
||||
c: CallbackHook = super().clone(subtype)
|
||||
def clone(self):
|
||||
c: AdditionalModelsHook = super().clone()
|
||||
c.models = self.models.copy() if self.models else self.models
|
||||
c.key = self.key
|
||||
c.callback = self.callback
|
||||
return c
|
||||
# TODO: add functionality
|
||||
|
||||
class WrapperHook(Hook):
|
||||
def __init__(self, wrappers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None):
|
||||
super().__init__(hook_type=EnumHookType.Wrappers)
|
||||
self.wrappers_dict = wrappers_dict
|
||||
|
||||
def clone(self, subtype: Callable=None):
|
||||
if subtype is None:
|
||||
subtype = type(self)
|
||||
c: WrapperHook = super().clone(subtype)
|
||||
c.wrappers_dict = self.wrappers_dict
|
||||
return c
|
||||
|
||||
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
|
||||
if not self.should_register(model, model_options, target, registered):
|
||||
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||
if not self.should_register(model, model_options, target_dict, registered):
|
||||
return False
|
||||
add_model_options = {"transformer_options": self.wrappers_dict}
|
||||
comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
|
||||
registered.append(self)
|
||||
registered.add(self)
|
||||
return True
|
||||
|
||||
class SetInjectionsHook(Hook):
|
||||
def __init__(self, key: str=None, injections: list['PatcherInjection']=None):
|
||||
super().__init__(hook_type=EnumHookType.SetInjections)
|
||||
class TransformerOptionsHook(Hook):
|
||||
'''
|
||||
Hook responsible for adding wrappers, callbacks, patches, or anything else related to transformer_options.
|
||||
'''
|
||||
def __init__(self, transformers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None,
|
||||
hook_scope=EnumHookScope.AllConditioning):
|
||||
super().__init__(hook_type=EnumHookType.TransformerOptions)
|
||||
self.transformers_dict = transformers_dict
|
||||
self.hook_scope = hook_scope
|
||||
self._skip_adding = False
|
||||
'''Internal value used to avoid double load of transformer_options when hook_scope is AllConditioning.'''
|
||||
|
||||
def clone(self):
|
||||
c: TransformerOptionsHook = super().clone()
|
||||
c.transformers_dict = self.transformers_dict
|
||||
c._skip_adding = self._skip_adding
|
||||
return c
|
||||
|
||||
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||
if not self.should_register(model, model_options, target_dict, registered):
|
||||
return False
|
||||
# NOTE: to_load_options will be used to manually load patches/wrappers/callbacks from hooks
|
||||
self._skip_adding = False
|
||||
if self.hook_scope == EnumHookScope.AllConditioning:
|
||||
add_model_options = {"transformer_options": self.transformers_dict,
|
||||
"to_load_options": self.transformers_dict}
|
||||
# skip_adding if included in AllConditioning to avoid double loading
|
||||
self._skip_adding = True
|
||||
else:
|
||||
add_model_options = {"to_load_options": self.transformers_dict}
|
||||
registered.add(self)
|
||||
comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
|
||||
return True
|
||||
|
||||
def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]):
|
||||
if not self._skip_adding:
|
||||
comfy.patcher_extension.merge_nested_dicts(transformer_options, self.transformers_dict, copy_dict1=False)
|
||||
|
||||
WrapperHook = TransformerOptionsHook
|
||||
'''Only here for backwards compatibility, WrapperHook is identical to TransformerOptionsHook.'''
|
||||
|
||||
class InjectionsHook(Hook):
|
||||
def __init__(self, key: str=None, injections: list[PatcherInjection]=None,
|
||||
hook_scope=EnumHookScope.AllConditioning):
|
||||
super().__init__(hook_type=EnumHookType.Injections)
|
||||
self.key = key
|
||||
self.injections = injections
|
||||
|
||||
def clone(self, subtype: Callable=None):
|
||||
if subtype is None:
|
||||
subtype = type(self)
|
||||
c: SetInjectionsHook = super().clone(subtype)
|
||||
self.hook_scope = hook_scope
|
||||
|
||||
def clone(self):
|
||||
c: InjectionsHook = super().clone()
|
||||
c.key = self.key
|
||||
c.injections = self.injections.copy() if self.injections else self.injections
|
||||
return c
|
||||
|
||||
def add_hook_injections(self, model: 'ModelPatcher'):
|
||||
# TODO: add functionality
|
||||
pass
|
||||
|
||||
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||
raise NotImplementedError("InjectionsHook is not supported yet in ComfyUI.")
|
||||
|
||||
class HookGroup:
|
||||
'''
|
||||
Stores groups of hooks, and allows them to be queried by type.
|
||||
|
||||
To prevent breaking their functionality, never modify the underlying self.hooks or self._hook_dict vars directly;
|
||||
always use the provided functions on HookGroup.
|
||||
'''
|
||||
def __init__(self):
|
||||
self.hooks: list[Hook] = []
|
||||
self._hook_dict: dict[EnumHookType, list[Hook]] = {}
|
||||
|
||||
def __len__(self):
|
||||
return len(self.hooks)
|
||||
|
||||
def add(self, hook: Hook):
|
||||
if hook not in self.hooks:
|
||||
self.hooks.append(hook)
|
||||
|
||||
self._hook_dict.setdefault(hook.hook_type, []).append(hook)
|
||||
|
||||
def remove(self, hook: Hook):
|
||||
if hook in self.hooks:
|
||||
self.hooks.remove(hook)
|
||||
self._hook_dict[hook.hook_type].remove(hook)
|
||||
|
||||
def get_type(self, hook_type: EnumHookType):
|
||||
return self._hook_dict.get(hook_type, [])
|
||||
|
||||
def contains(self, hook: Hook):
|
||||
return hook in self.hooks
|
||||
|
||||
|
||||
def is_subset_of(self, other: HookGroup):
|
||||
self_hooks = set(self.hooks)
|
||||
other_hooks = set(other.hooks)
|
||||
return self_hooks.issubset(other_hooks)
|
||||
|
||||
def new_with_common_hooks(self, other: HookGroup):
|
||||
c = HookGroup()
|
||||
for hook in self.hooks:
|
||||
if other.contains(hook):
|
||||
c.add(hook.clone())
|
||||
return c
|
||||
|
||||
def clone(self):
|
||||
c = HookGroup()
|
||||
for hook in self.hooks:
|
||||
c.add(hook.clone())
|
||||
return c
|
||||
|
||||
def clone_and_combine(self, other: 'HookGroup'):
|
||||
def clone_and_combine(self, other: HookGroup):
|
||||
c = self.clone()
|
||||
if other is not None:
|
||||
for hook in other.hooks:
|
||||
c.add(hook.clone())
|
||||
return c
|
||||
|
||||
def set_keyframes_on_hooks(self, hook_kf: 'HookKeyframeGroup'):
|
||||
|
||||
def set_keyframes_on_hooks(self, hook_kf: HookKeyframeGroup):
|
||||
if hook_kf is None:
|
||||
hook_kf = HookKeyframeGroup()
|
||||
else:
|
||||
@@ -275,36 +347,29 @@ class HookGroup:
|
||||
for hook in self.hooks:
|
||||
hook.hook_keyframe = hook_kf
|
||||
|
||||
def get_dict_repr(self):
|
||||
d: dict[EnumHookType, dict[Hook, None]] = {}
|
||||
for hook in self.hooks:
|
||||
with_type = d.setdefault(hook.hook_type, {})
|
||||
with_type[hook] = None
|
||||
return d
|
||||
|
||||
def get_hooks_for_clip_schedule(self):
|
||||
scheduled_hooks: dict[WeightHook, list[tuple[tuple[float,float], HookKeyframe]]] = {}
|
||||
for hook in self.hooks:
|
||||
# only care about WeightHooks, for now
|
||||
if hook.hook_type == EnumHookType.Weight:
|
||||
hook_schedule = []
|
||||
# if no hook keyframes, assign default value
|
||||
if len(hook.hook_keyframe.keyframes) == 0:
|
||||
hook_schedule.append(((0.0, 1.0), None))
|
||||
scheduled_hooks[hook] = hook_schedule
|
||||
continue
|
||||
# find ranges of values
|
||||
prev_keyframe = hook.hook_keyframe.keyframes[0]
|
||||
for keyframe in hook.hook_keyframe.keyframes:
|
||||
if keyframe.start_percent > prev_keyframe.start_percent and not math.isclose(keyframe.strength, prev_keyframe.strength):
|
||||
hook_schedule.append(((prev_keyframe.start_percent, keyframe.start_percent), prev_keyframe))
|
||||
prev_keyframe = keyframe
|
||||
elif keyframe.start_percent == prev_keyframe.start_percent:
|
||||
prev_keyframe = keyframe
|
||||
# create final range, assuming last start_percent was not 1.0
|
||||
if not math.isclose(prev_keyframe.start_percent, 1.0):
|
||||
hook_schedule.append(((prev_keyframe.start_percent, 1.0), prev_keyframe))
|
||||
# only care about WeightHooks, for now
|
||||
for hook in self.get_type(EnumHookType.Weight):
|
||||
hook: WeightHook
|
||||
hook_schedule = []
|
||||
# if no hook keyframes, assign default value
|
||||
if len(hook.hook_keyframe.keyframes) == 0:
|
||||
hook_schedule.append(((0.0, 1.0), None))
|
||||
scheduled_hooks[hook] = hook_schedule
|
||||
continue
|
||||
# find ranges of values
|
||||
prev_keyframe = hook.hook_keyframe.keyframes[0]
|
||||
for keyframe in hook.hook_keyframe.keyframes:
|
||||
if keyframe.start_percent > prev_keyframe.start_percent and not math.isclose(keyframe.strength, prev_keyframe.strength):
|
||||
hook_schedule.append(((prev_keyframe.start_percent, keyframe.start_percent), prev_keyframe))
|
||||
prev_keyframe = keyframe
|
||||
elif keyframe.start_percent == prev_keyframe.start_percent:
|
||||
prev_keyframe = keyframe
|
||||
# create final range, assuming last start_percent was not 1.0
|
||||
if not math.isclose(prev_keyframe.start_percent, 1.0):
|
||||
hook_schedule.append(((prev_keyframe.start_percent, 1.0), prev_keyframe))
|
||||
scheduled_hooks[hook] = hook_schedule
|
||||
# hooks should not have their schedules in a list of tuples
|
||||
all_ranges: list[tuple[float, float]] = []
|
||||
for range_kfs in scheduled_hooks.values():
|
||||
@@ -336,7 +401,7 @@ class HookGroup:
|
||||
hook.reset()
|
||||
|
||||
@staticmethod
|
||||
def combine_all_hooks(hooks_list: list['HookGroup'], require_count=0) -> 'HookGroup':
|
||||
def combine_all_hooks(hooks_list: list[HookGroup], require_count=0) -> HookGroup:
|
||||
actual: list[HookGroup] = []
|
||||
for group in hooks_list:
|
||||
if group is not None:
|
||||
@@ -365,10 +430,16 @@ class HookKeyframe:
|
||||
self.start_percent = float(start_percent)
|
||||
self.start_t = 999999999.9
|
||||
self.guarantee_steps = guarantee_steps
|
||||
|
||||
|
||||
def get_effective_guarantee_steps(self, max_sigma: torch.Tensor):
|
||||
'''If keyframe starts before current sampling range (max_sigma), treat as 0.'''
|
||||
if self.start_t > max_sigma:
|
||||
return 0
|
||||
return self.guarantee_steps
|
||||
|
||||
def clone(self):
|
||||
c = HookKeyframe(strength=self.strength,
|
||||
start_percent=self.start_percent, guarantee_steps=self.guarantee_steps)
|
||||
start_percent=self.start_percent, guarantee_steps=self.guarantee_steps)
|
||||
c.start_t = self.start_t
|
||||
return c
|
||||
|
||||
@@ -395,7 +466,7 @@ class HookKeyframeGroup:
|
||||
self._current_strength = None
|
||||
self.curr_t = -1.
|
||||
self._set_first_as_current()
|
||||
|
||||
|
||||
def add(self, keyframe: HookKeyframe):
|
||||
# add to end of list, then sort
|
||||
self.keyframes.append(keyframe)
|
||||
@@ -407,33 +478,40 @@ class HookKeyframeGroup:
|
||||
self._current_keyframe = self.keyframes[0]
|
||||
else:
|
||||
self._current_keyframe = None
|
||||
|
||||
|
||||
def has_guarantee_steps(self):
|
||||
for kf in self.keyframes:
|
||||
if kf.guarantee_steps > 0:
|
||||
return True
|
||||
return False
|
||||
|
||||
def has_index(self, index: int):
|
||||
return index >= 0 and index < len(self.keyframes)
|
||||
|
||||
def is_empty(self):
|
||||
return len(self.keyframes) == 0
|
||||
|
||||
|
||||
def clone(self):
|
||||
c = HookKeyframeGroup()
|
||||
for keyframe in self.keyframes:
|
||||
c.keyframes.append(keyframe.clone())
|
||||
c._set_first_as_current()
|
||||
return c
|
||||
|
||||
def initialize_timesteps(self, model: 'BaseModel'):
|
||||
|
||||
def initialize_timesteps(self, model: BaseModel):
|
||||
for keyframe in self.keyframes:
|
||||
keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent)
|
||||
|
||||
def prepare_current_keyframe(self, curr_t: float) -> bool:
|
||||
def prepare_current_keyframe(self, curr_t: float, transformer_options: dict[str, torch.Tensor]) -> bool:
|
||||
if self.is_empty():
|
||||
return False
|
||||
if curr_t == self._curr_t:
|
||||
return False
|
||||
max_sigma = torch.max(transformer_options["sample_sigmas"])
|
||||
prev_index = self._current_index
|
||||
prev_strength = self._current_strength
|
||||
# if met guaranteed steps, look for next keyframe in case need to switch
|
||||
if self._current_used_steps >= self._current_keyframe.guarantee_steps:
|
||||
if self._current_used_steps >= self._current_keyframe.get_effective_guarantee_steps(max_sigma):
|
||||
# if has next index, loop through and see if need to switch
|
||||
if self.has_index(self._current_index+1):
|
||||
for i in range(self._current_index+1, len(self.keyframes)):
|
||||
@@ -446,7 +524,7 @@ class HookKeyframeGroup:
|
||||
self._current_keyframe = eval_c
|
||||
self._current_used_steps = 0
|
||||
# if guarantee_steps greater than zero, stop searching for other keyframes
|
||||
if self._current_keyframe.guarantee_steps > 0:
|
||||
if self._current_keyframe.get_effective_guarantee_steps(max_sigma) > 0:
|
||||
break
|
||||
# if eval_c is outside the percent range, stop looking further
|
||||
else: break
|
||||
@@ -509,6 +587,17 @@ def get_sorted_list_via_attr(objects: list, attr: str) -> list:
|
||||
sorted_list.extend(object_list)
|
||||
return sorted_list
|
||||
|
||||
def create_transformer_options_from_hooks(model: ModelPatcher, hooks: HookGroup, transformer_options: dict[str]=None):
|
||||
# if no hooks or is not a ModelPatcher for sampling, return empty dict
|
||||
if hooks is None or model.is_clip:
|
||||
return {}
|
||||
if transformer_options is None:
|
||||
transformer_options = {}
|
||||
for hook in hooks.get_type(EnumHookType.TransformerOptions):
|
||||
hook: TransformerOptionsHook
|
||||
hook.on_apply_hooks(model, transformer_options)
|
||||
return transformer_options
|
||||
|
||||
def create_hook_lora(lora: dict[str, torch.Tensor], strength_model: float, strength_clip: float):
|
||||
hook_group = HookGroup()
|
||||
hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip)
|
||||
@@ -535,7 +624,7 @@ def create_hook_model_as_lora(weights_model, weights_clip, strength_model: float
|
||||
hook.need_weight_init = False
|
||||
return hook_group
|
||||
|
||||
def get_patch_weights_from_model(model: 'ModelPatcher', discard_model_sampling=True):
|
||||
def get_patch_weights_from_model(model: ModelPatcher, discard_model_sampling=True):
|
||||
if model is None:
|
||||
return None
|
||||
patches_model: dict[str, torch.Tensor] = model.model.state_dict()
|
||||
@@ -547,7 +636,7 @@ def get_patch_weights_from_model(model: 'ModelPatcher', discard_model_sampling=T
|
||||
return patches_model
|
||||
|
||||
# NOTE: this function shows how to register weight hooks directly on the ModelPatchers
|
||||
def load_hook_lora_for_models(model: 'ModelPatcher', clip: 'CLIP', lora: dict[str, torch.Tensor],
|
||||
def load_hook_lora_for_models(model: ModelPatcher, clip: CLIP, lora: dict[str, torch.Tensor],
|
||||
strength_model: float, strength_clip: float):
|
||||
key_map = {}
|
||||
if model is not None:
|
||||
@@ -565,7 +654,7 @@ def load_hook_lora_for_models(model: 'ModelPatcher', clip: 'CLIP', lora: dict[st
|
||||
else:
|
||||
k = ()
|
||||
new_modelpatcher = None
|
||||
|
||||
|
||||
if clip is not None:
|
||||
new_clip = clip.clone()
|
||||
k1 = new_clip.patcher.add_hook_patches(hook=hook, patches=loaded, strength_patch=strength_clip)
|
||||
@@ -599,24 +688,26 @@ def _combine_hooks_from_values(c_dict: dict[str, HookGroup], values: dict[str, H
|
||||
else:
|
||||
c_dict[hooks_key] = cache[hooks_tuple]
|
||||
|
||||
def conditioning_set_values_with_hooks(conditioning, values={}, append_hooks=True):
|
||||
def conditioning_set_values_with_hooks(conditioning, values={}, append_hooks=True,
|
||||
cache: dict[tuple[HookGroup, HookGroup], HookGroup]=None):
|
||||
c = []
|
||||
hooks_combine_cache: dict[tuple[HookGroup, HookGroup], HookGroup] = {}
|
||||
if cache is None:
|
||||
cache = {}
|
||||
for t in conditioning:
|
||||
n = [t[0], t[1].copy()]
|
||||
for k in values:
|
||||
if append_hooks and k == 'hooks':
|
||||
_combine_hooks_from_values(n[1], values, hooks_combine_cache)
|
||||
_combine_hooks_from_values(n[1], values, cache)
|
||||
else:
|
||||
n[1][k] = values[k]
|
||||
c.append(n)
|
||||
|
||||
return c
|
||||
|
||||
def set_hooks_for_conditioning(cond, hooks: HookGroup, append_hooks=True):
|
||||
def set_hooks_for_conditioning(cond, hooks: HookGroup, append_hooks=True, cache: dict[tuple[HookGroup, HookGroup], HookGroup]=None):
|
||||
if hooks is None:
|
||||
return cond
|
||||
return conditioning_set_values_with_hooks(cond, {'hooks': hooks}, append_hooks=append_hooks)
|
||||
return conditioning_set_values_with_hooks(cond, {'hooks': hooks}, append_hooks=append_hooks, cache=cache)
|
||||
|
||||
def set_timesteps_for_conditioning(cond, timestep_range: tuple[float,float]):
|
||||
if timestep_range is None:
|
||||
@@ -651,9 +742,10 @@ def combine_with_new_conds(conds: list, new_conds: list):
|
||||
def set_conds_props(conds: list, strength: float, set_cond_area: str,
|
||||
mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
|
||||
final_conds = []
|
||||
cache = {}
|
||||
for c in conds:
|
||||
# first, apply lora_hook to conditioning, if provided
|
||||
c = set_hooks_for_conditioning(c, hooks, append_hooks=append_hooks)
|
||||
c = set_hooks_for_conditioning(c, hooks, append_hooks=append_hooks, cache=cache)
|
||||
# next, apply mask to conditioning
|
||||
c = set_mask_for_conditioning(cond=c, mask=mask, strength=strength, set_cond_area=set_cond_area)
|
||||
# apply timesteps, if present
|
||||
@@ -665,9 +757,10 @@ def set_conds_props(conds: list, strength: float, set_cond_area: str,
|
||||
def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1.0, set_cond_area: str="default",
|
||||
mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
|
||||
combined_conds = []
|
||||
cache = {}
|
||||
for c, masked_c in zip(conds, new_conds):
|
||||
# first, apply lora_hook to new conditioning, if provided
|
||||
masked_c = set_hooks_for_conditioning(masked_c, hooks, append_hooks=append_hooks)
|
||||
masked_c = set_hooks_for_conditioning(masked_c, hooks, append_hooks=append_hooks, cache=cache)
|
||||
# next, apply mask to new conditioning, if provided
|
||||
masked_c = set_mask_for_conditioning(cond=masked_c, mask=mask, set_cond_area=set_cond_area, strength=strength)
|
||||
# apply timesteps, if present
|
||||
@@ -679,9 +772,10 @@ def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1.
|
||||
def set_default_conds_and_combine(conds: list, new_conds: list,
|
||||
hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
|
||||
combined_conds = []
|
||||
cache = {}
|
||||
for c, new_c in zip(conds, new_conds):
|
||||
# first, apply lora_hook to new conditioning, if provided
|
||||
new_c = set_hooks_for_conditioning(new_c, hooks, append_hooks=append_hooks)
|
||||
new_c = set_hooks_for_conditioning(new_c, hooks, append_hooks=append_hooks, cache=cache)
|
||||
# next, add default_cond key to cond so that during sampling, it can be identified
|
||||
new_c = conditioning_set_values(new_c, {'default': True})
|
||||
# apply timesteps, if present
|
||||
|
||||
@@ -70,8 +70,14 @@ def get_ancestral_step(sigma_from, sigma_to, eta=1.):
|
||||
return sigma_down, sigma_up
|
||||
|
||||
|
||||
def default_noise_sampler(x):
|
||||
return lambda sigma, sigma_next: torch.randn_like(x)
|
||||
def default_noise_sampler(x, seed=None):
|
||||
if seed is not None:
|
||||
generator = torch.Generator(device=x.device)
|
||||
generator.manual_seed(seed)
|
||||
else:
|
||||
generator = None
|
||||
|
||||
return lambda sigma, sigma_next: torch.randn(x.size(), dtype=x.dtype, layout=x.layout, device=x.device, generator=generator)
|
||||
|
||||
|
||||
class BatchedBrownianTree:
|
||||
@@ -168,7 +174,8 @@ def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
|
||||
return sample_euler_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
|
||||
"""Ancestral sampling with Euler method steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
@@ -189,7 +196,8 @@ def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
|
||||
def sample_euler_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1., noise_sampler=None):
|
||||
"""Ancestral sampling with Euler method steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
@@ -290,7 +298,8 @@ def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
|
||||
|
||||
"""Ancestral sampling with DPM-Solver second-order steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
@@ -318,7 +327,8 @@ def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
|
||||
def sample_dpm_2_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
"""Ancestral sampling with DPM-Solver second-order steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
@@ -465,7 +475,7 @@ class DPMSolver(nn.Module):
|
||||
return x_3, eps_cache
|
||||
|
||||
def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0., s_noise=1., noise_sampler=None):
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
noise_sampler = default_noise_sampler(x, seed=self.extra_args.get("seed", None)) if noise_sampler is None else noise_sampler
|
||||
if not t_end > t_start and eta:
|
||||
raise ValueError('eta must be 0 for reverse sampling')
|
||||
|
||||
@@ -504,7 +514,7 @@ class DPMSolver(nn.Module):
|
||||
return x
|
||||
|
||||
def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None):
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
noise_sampler = default_noise_sampler(x, seed=self.extra_args.get("seed", None)) if noise_sampler is None else noise_sampler
|
||||
if order not in {2, 3}:
|
||||
raise ValueError('order should be 2 or 3')
|
||||
forward = t_end > t_start
|
||||
@@ -591,7 +601,8 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None,
|
||||
|
||||
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
sigma_fn = lambda t: t.neg().exp()
|
||||
t_fn = lambda sigma: sigma.log().neg()
|
||||
@@ -625,7 +636,8 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None,
|
||||
def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
sigma_fn = lambda lbda: (lbda.exp() + 1) ** -1
|
||||
lambda_fn = lambda sigma: ((1-sigma)/sigma).log()
|
||||
@@ -882,7 +894,8 @@ def DDPMSampler_step(x, sigma, sigma_prev, noise, noise_sampler):
|
||||
|
||||
def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, step_function=None):
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
@@ -902,7 +915,8 @@ def sample_ddpm(model, x, sigmas, extra_args=None, callback=None, disable=None,
|
||||
@torch.no_grad()
|
||||
def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
@@ -1153,7 +1167,8 @@ def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disabl
|
||||
def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
"""Ancestral sampling with Euler method steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
|
||||
temp = [0]
|
||||
def post_cfg_function(args):
|
||||
@@ -1179,7 +1194,8 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No
|
||||
def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
|
||||
temp = [0]
|
||||
def post_cfg_function(args):
|
||||
@@ -1230,7 +1246,7 @@ def sample_dpmpp_2m_cfg_pp(model, x, sigmas, extra_args=None, callback=None, dis
|
||||
nonlocal uncond_denoised
|
||||
uncond_denoised = args["uncond_denoised"]
|
||||
return args["denoised"]
|
||||
|
||||
|
||||
model_options = extra_args.get("model_options", {}).copy()
|
||||
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
|
||||
|
||||
@@ -1249,3 +1265,74 @@ def sample_dpmpp_2m_cfg_pp(model, x, sigmas, extra_args=None, callback=None, dis
|
||||
x = denoised + denoised_mix + torch.exp(-h) * x
|
||||
old_uncond_denoised = uncond_denoised
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None, cfg_pp=False):
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
sigma_fn = lambda t: t.neg().exp()
|
||||
t_fn = lambda sigma: sigma.log().neg()
|
||||
phi1_fn = lambda t: torch.expm1(t) / t
|
||||
phi2_fn = lambda t: (phi1_fn(t) - 1.0) / t
|
||||
|
||||
old_denoised = None
|
||||
uncond_denoised = None
|
||||
def post_cfg_function(args):
|
||||
nonlocal uncond_denoised
|
||||
uncond_denoised = args["uncond_denoised"]
|
||||
return args["denoised"]
|
||||
|
||||
if cfg_pp:
|
||||
model_options = extra_args.get("model_options", {}).copy()
|
||||
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
if s_churn > 0:
|
||||
gamma = min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.0
|
||||
sigma_hat = sigmas[i] * (gamma + 1)
|
||||
else:
|
||||
gamma = 0
|
||||
sigma_hat = sigmas[i]
|
||||
|
||||
if gamma > 0:
|
||||
eps = torch.randn_like(x) * s_noise
|
||||
x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5
|
||||
denoised = model(x, sigma_hat * s_in, **extra_args)
|
||||
if callback is not None:
|
||||
callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigma_hat, "denoised": denoised})
|
||||
if sigmas[i + 1] == 0 or old_denoised is None:
|
||||
# Euler method
|
||||
if cfg_pp:
|
||||
d = to_d(x, sigma_hat, uncond_denoised)
|
||||
x = denoised + d * sigmas[i + 1]
|
||||
else:
|
||||
d = to_d(x, sigma_hat, denoised)
|
||||
dt = sigmas[i + 1] - sigma_hat
|
||||
x = x + d * dt
|
||||
else:
|
||||
# Second order multistep method in https://arxiv.org/pdf/2308.02157
|
||||
t, t_next, t_prev = t_fn(sigmas[i]), t_fn(sigmas[i + 1]), t_fn(sigmas[i - 1])
|
||||
h = t_next - t
|
||||
c2 = (t_prev - t) / h
|
||||
|
||||
phi1_val, phi2_val = phi1_fn(-h), phi2_fn(-h)
|
||||
b1 = torch.nan_to_num(phi1_val - 1.0 / c2 * phi2_val, nan=0.0)
|
||||
b2 = torch.nan_to_num(1.0 / c2 * phi2_val, nan=0.0)
|
||||
|
||||
if cfg_pp:
|
||||
x = x + (denoised - uncond_denoised)
|
||||
|
||||
x = (sigma_fn(t_next) / sigma_fn(t)) * x + h * (b1 * denoised + b2 * old_denoised)
|
||||
|
||||
old_denoised = denoised
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
|
||||
return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise, noise_sampler=noise_sampler, cfg_pp=False)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_res_multistep_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
|
||||
return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise, noise_sampler=noise_sampler, cfg_pp=True)
|
||||
|
||||
@@ -3,6 +3,7 @@ import torch
|
||||
class LatentFormat:
|
||||
scale_factor = 1.0
|
||||
latent_channels = 4
|
||||
latent_dimensions = 2
|
||||
latent_rgb_factors = None
|
||||
latent_rgb_factors_bias = None
|
||||
taesd_decoder_name = None
|
||||
@@ -143,6 +144,7 @@ class SD3(LatentFormat):
|
||||
|
||||
class StableAudio1(LatentFormat):
|
||||
latent_channels = 64
|
||||
latent_dimensions = 1
|
||||
|
||||
class Flux(SD3):
|
||||
latent_channels = 16
|
||||
@@ -178,6 +180,7 @@ class Flux(SD3):
|
||||
|
||||
class Mochi(LatentFormat):
|
||||
latent_channels = 12
|
||||
latent_dimensions = 3
|
||||
|
||||
def __init__(self):
|
||||
self.scale_factor = 1.0
|
||||
@@ -219,6 +222,8 @@ class Mochi(LatentFormat):
|
||||
|
||||
class LTXV(LatentFormat):
|
||||
latent_channels = 128
|
||||
latent_dimensions = 3
|
||||
|
||||
def __init__(self):
|
||||
self.latent_rgb_factors = [
|
||||
[ 1.1202e-02, -6.3815e-04, -1.0021e-02],
|
||||
@@ -355,6 +360,7 @@ class LTXV(LatentFormat):
|
||||
|
||||
class HunyuanVideo(LatentFormat):
|
||||
latent_channels = 16
|
||||
latent_dimensions = 3
|
||||
scale_factor = 0.476986
|
||||
latent_rgb_factors = [
|
||||
[-0.0395, -0.0331, 0.0445],
|
||||
@@ -376,3 +382,28 @@ class HunyuanVideo(LatentFormat):
|
||||
]
|
||||
|
||||
latent_rgb_factors_bias = [ 0.0259, -0.0192, -0.0761]
|
||||
|
||||
class Cosmos1CV8x8x8(LatentFormat):
|
||||
latent_channels = 16
|
||||
latent_dimensions = 3
|
||||
|
||||
latent_rgb_factors = [
|
||||
[ 0.1817, 0.2284, 0.2423],
|
||||
[-0.0586, -0.0862, -0.3108],
|
||||
[-0.4703, -0.4255, -0.3995],
|
||||
[ 0.0803, 0.1963, 0.1001],
|
||||
[-0.0820, -0.1050, 0.0400],
|
||||
[ 0.2511, 0.3098, 0.2787],
|
||||
[-0.1830, -0.2117, -0.0040],
|
||||
[-0.0621, -0.2187, -0.0939],
|
||||
[ 0.3619, 0.1082, 0.1455],
|
||||
[ 0.3164, 0.3922, 0.2575],
|
||||
[ 0.1152, 0.0231, -0.0462],
|
||||
[-0.1434, -0.3609, -0.3665],
|
||||
[ 0.0635, 0.1471, 0.1680],
|
||||
[-0.3635, -0.1963, -0.3248],
|
||||
[-0.1865, 0.0365, 0.2346],
|
||||
[ 0.0447, 0.0994, 0.0881]
|
||||
]
|
||||
|
||||
latent_rgb_factors_bias = [-0.1223, -0.1889, -0.1976]
|
||||
|
||||
@@ -138,7 +138,7 @@ class StageB(nn.Module):
|
||||
# nn.init.normal_(self.pixels_mapper[2].weight, std=0.02) # conditionings
|
||||
# torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs
|
||||
# nn.init.constant_(self.clf[1].weight, 0) # outputs
|
||||
#
|
||||
#
|
||||
# # blocks
|
||||
# for level_block in self.down_blocks + self.up_blocks:
|
||||
# for block in level_block:
|
||||
@@ -148,7 +148,7 @@ class StageB(nn.Module):
|
||||
# for layer in block.modules():
|
||||
# if isinstance(layer, nn.Linear):
|
||||
# nn.init.constant_(layer.weight, 0)
|
||||
#
|
||||
#
|
||||
# def _init_weights(self, m):
|
||||
# if isinstance(m, (nn.Conv2d, nn.Linear)):
|
||||
# torch.nn.init.xavier_uniform_(m.weight)
|
||||
|
||||
@@ -142,7 +142,7 @@ class StageC(nn.Module):
|
||||
# nn.init.normal_(self.clip_img_mapper.weight, std=0.02) # conditionings
|
||||
# torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs
|
||||
# nn.init.constant_(self.clf[1].weight, 0) # outputs
|
||||
#
|
||||
#
|
||||
# # blocks
|
||||
# for level_block in self.down_blocks + self.up_blocks:
|
||||
# for block in level_block:
|
||||
@@ -152,7 +152,7 @@ class StageC(nn.Module):
|
||||
# for layer in block.modules():
|
||||
# if isinstance(layer, nn.Linear):
|
||||
# nn.init.constant_(layer.weight, 0)
|
||||
#
|
||||
#
|
||||
# def _init_weights(self, m):
|
||||
# if isinstance(m, (nn.Conv2d, nn.Linear)):
|
||||
# torch.nn.init.xavier_uniform_(m.weight)
|
||||
|
||||
808
comfy/ldm/cosmos/blocks.py
Normal file
808
comfy/ldm/cosmos/blocks.py
Normal file
@@ -0,0 +1,808 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
from torch import nn
|
||||
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(
|
||||
t: torch.Tensor,
|
||||
freqs: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
t_ = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).float()
|
||||
t_out = freqs[..., 0] * t_[..., 0] + freqs[..., 1] * t_[..., 1]
|
||||
t_out = t_out.movedim(-1, -2).reshape(*t.shape).type_as(t)
|
||||
return t_out
|
||||
|
||||
|
||||
def get_normalization(name: str, channels: int, weight_args={}):
|
||||
if name == "I":
|
||||
return nn.Identity()
|
||||
elif name == "R":
|
||||
return RMSNorm(channels, elementwise_affine=True, eps=1e-6, **weight_args)
|
||||
else:
|
||||
raise ValueError(f"Normalization {name} not found")
|
||||
|
||||
|
||||
class BaseAttentionOp(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
"""
|
||||
Generalized attention impl.
|
||||
|
||||
Allowing for both self-attention and cross-attention configurations depending on whether a `context_dim` is provided.
|
||||
If `context_dim` is None, self-attention is assumed.
|
||||
|
||||
Parameters:
|
||||
query_dim (int): Dimension of each query vector.
|
||||
context_dim (int, optional): Dimension of each context vector. If None, self-attention is assumed.
|
||||
heads (int, optional): Number of attention heads. Defaults to 8.
|
||||
dim_head (int, optional): Dimension of each head. Defaults to 64.
|
||||
dropout (float, optional): Dropout rate applied to the output of the attention block. Defaults to 0.0.
|
||||
attn_op (BaseAttentionOp, optional): Custom attention operation to be used instead of the default.
|
||||
qkv_bias (bool, optional): If True, adds a learnable bias to query, key, and value projections. Defaults to False.
|
||||
out_bias (bool, optional): If True, adds a learnable bias to the output projection. Defaults to False.
|
||||
qkv_norm (str, optional): A string representing normalization strategies for query, key, and value projections.
|
||||
Defaults to "SSI".
|
||||
qkv_norm_mode (str, optional): A string representing normalization mode for query, key, and value projections.
|
||||
Defaults to 'per_head'. Only support 'per_head'.
|
||||
|
||||
Examples:
|
||||
>>> attn = Attention(query_dim=128, context_dim=256, heads=4, dim_head=32, dropout=0.1)
|
||||
>>> query = torch.randn(10, 128) # Batch size of 10
|
||||
>>> context = torch.randn(10, 256) # Batch size of 10
|
||||
>>> output = attn(query, context) # Perform the attention operation
|
||||
|
||||
Note:
|
||||
https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
query_dim: int,
|
||||
context_dim=None,
|
||||
heads=8,
|
||||
dim_head=64,
|
||||
dropout=0.0,
|
||||
attn_op: Optional[BaseAttentionOp] = None,
|
||||
qkv_bias: bool = False,
|
||||
out_bias: bool = False,
|
||||
qkv_norm: str = "SSI",
|
||||
qkv_norm_mode: str = "per_head",
|
||||
backend: str = "transformer_engine",
|
||||
qkv_format: str = "bshd",
|
||||
weight_args={},
|
||||
operations=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.is_selfattn = context_dim is None # self attention
|
||||
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = query_dim if context_dim is None else context_dim
|
||||
|
||||
self.heads = heads
|
||||
self.dim_head = dim_head
|
||||
self.qkv_norm_mode = qkv_norm_mode
|
||||
self.qkv_format = qkv_format
|
||||
|
||||
if self.qkv_norm_mode == "per_head":
|
||||
norm_dim = dim_head
|
||||
else:
|
||||
raise ValueError(f"Normalization mode {self.qkv_norm_mode} not found, only support 'per_head'")
|
||||
|
||||
self.backend = backend
|
||||
|
||||
self.to_q = nn.Sequential(
|
||||
operations.Linear(query_dim, inner_dim, bias=qkv_bias, **weight_args),
|
||||
get_normalization(qkv_norm[0], norm_dim),
|
||||
)
|
||||
self.to_k = nn.Sequential(
|
||||
operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args),
|
||||
get_normalization(qkv_norm[1], norm_dim),
|
||||
)
|
||||
self.to_v = nn.Sequential(
|
||||
operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args),
|
||||
get_normalization(qkv_norm[2], norm_dim),
|
||||
)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
operations.Linear(inner_dim, query_dim, bias=out_bias, **weight_args),
|
||||
nn.Dropout(dropout),
|
||||
)
|
||||
|
||||
def cal_qkv(
|
||||
self, x, context=None, mask=None, rope_emb=None, **kwargs
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
del kwargs
|
||||
|
||||
|
||||
"""
|
||||
self.to_q, self.to_k, self.to_v are nn.Sequential with projection + normalization layers.
|
||||
Before 07/24/2024, these modules normalize across all heads.
|
||||
After 07/24/2024, to support tensor parallelism and follow the common practice in the community,
|
||||
we support to normalize per head.
|
||||
To keep the checkpoint copatibility with the previous code,
|
||||
we keep the nn.Sequential but call the projection and the normalization layers separately.
|
||||
We use a flag `self.qkv_norm_mode` to control the normalization behavior.
|
||||
The default value of `self.qkv_norm_mode` is "per_head", which means we normalize per head.
|
||||
"""
|
||||
if self.qkv_norm_mode == "per_head":
|
||||
q = self.to_q[0](x)
|
||||
context = x if context is None else context
|
||||
k = self.to_k[0](context)
|
||||
v = self.to_v[0](context)
|
||||
q, k, v = map(
|
||||
lambda t: rearrange(t, "s b (n c) -> b n s c", n=self.heads, c=self.dim_head),
|
||||
(q, k, v),
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Normalization mode {self.qkv_norm_mode} not found, only support 'per_head'")
|
||||
|
||||
q = self.to_q[1](q)
|
||||
k = self.to_k[1](k)
|
||||
v = self.to_v[1](v)
|
||||
if self.is_selfattn and rope_emb is not None: # only apply to self-attention!
|
||||
# apply_rotary_pos_emb inlined
|
||||
q_shape = q.shape
|
||||
q = q.reshape(*q.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2)
|
||||
q = rope_emb[..., 0] * q[..., 0] + rope_emb[..., 1] * q[..., 1]
|
||||
q = q.movedim(-1, -2).reshape(*q_shape).to(x.dtype)
|
||||
|
||||
# apply_rotary_pos_emb inlined
|
||||
k_shape = k.shape
|
||||
k = k.reshape(*k.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2)
|
||||
k = rope_emb[..., 0] * k[..., 0] + rope_emb[..., 1] * k[..., 1]
|
||||
k = k.movedim(-1, -2).reshape(*k_shape).to(x.dtype)
|
||||
return q, k, v
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
context=None,
|
||||
mask=None,
|
||||
rope_emb=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): The query tensor of shape [B, Mq, K]
|
||||
context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
|
||||
"""
|
||||
q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs)
|
||||
out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True)
|
||||
del q, k, v
|
||||
out = rearrange(out, " b n s c -> s b (n c)")
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
"""
|
||||
Transformer FFN with optional gating
|
||||
|
||||
Parameters:
|
||||
d_model (int): Dimensionality of input features.
|
||||
d_ff (int): Dimensionality of the hidden layer.
|
||||
dropout (float, optional): Dropout rate applied after the activation function. Defaults to 0.1.
|
||||
activation (callable, optional): The activation function applied after the first linear layer.
|
||||
Defaults to nn.ReLU().
|
||||
is_gated (bool, optional): If set to True, incorporates gating mechanism to the feed-forward layer.
|
||||
Defaults to False.
|
||||
bias (bool, optional): If set to True, adds a bias to the linear layers. Defaults to True.
|
||||
|
||||
Example:
|
||||
>>> ff = FeedForward(d_model=512, d_ff=2048)
|
||||
>>> x = torch.randn(64, 10, 512) # Example input tensor
|
||||
>>> output = ff(x)
|
||||
>>> print(output.shape) # Expected shape: (64, 10, 512)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int,
|
||||
d_ff: int,
|
||||
dropout: float = 0.1,
|
||||
activation=nn.ReLU(),
|
||||
is_gated: bool = False,
|
||||
bias: bool = False,
|
||||
weight_args={},
|
||||
operations=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.layer1 = operations.Linear(d_model, d_ff, bias=bias, **weight_args)
|
||||
self.layer2 = operations.Linear(d_ff, d_model, bias=bias, **weight_args)
|
||||
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.activation = activation
|
||||
self.is_gated = is_gated
|
||||
if is_gated:
|
||||
self.linear_gate = operations.Linear(d_model, d_ff, bias=False, **weight_args)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
g = self.activation(self.layer1(x))
|
||||
if self.is_gated:
|
||||
x = g * self.linear_gate(x)
|
||||
else:
|
||||
x = g
|
||||
assert self.dropout.p == 0.0, "we skip dropout"
|
||||
return self.layer2(x)
|
||||
|
||||
|
||||
class GPT2FeedForward(FeedForward):
|
||||
def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1, bias: bool = False, weight_args={}, operations=None):
|
||||
super().__init__(
|
||||
d_model=d_model,
|
||||
d_ff=d_ff,
|
||||
dropout=dropout,
|
||||
activation=nn.GELU(),
|
||||
is_gated=False,
|
||||
bias=bias,
|
||||
weight_args=weight_args,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
assert self.dropout.p == 0.0, "we skip dropout"
|
||||
|
||||
x = self.layer1(x)
|
||||
x = self.activation(x)
|
||||
x = self.layer2(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def modulate(x, shift, scale):
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
|
||||
|
||||
class Timesteps(nn.Module):
|
||||
def __init__(self, num_channels):
|
||||
super().__init__()
|
||||
self.num_channels = num_channels
|
||||
|
||||
def forward(self, timesteps):
|
||||
half_dim = self.num_channels // 2
|
||||
exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
|
||||
exponent = exponent / (half_dim - 0.0)
|
||||
|
||||
emb = torch.exp(exponent)
|
||||
emb = timesteps[:, None].float() * emb[None, :]
|
||||
|
||||
sin_emb = torch.sin(emb)
|
||||
cos_emb = torch.cos(emb)
|
||||
emb = torch.cat([cos_emb, sin_emb], dim=-1)
|
||||
|
||||
return emb
|
||||
|
||||
|
||||
class TimestepEmbedding(nn.Module):
|
||||
def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False, weight_args={}, operations=None):
|
||||
super().__init__()
|
||||
logging.debug(
|
||||
f"Using AdaLN LoRA Flag: {use_adaln_lora}. We enable bias if no AdaLN LoRA for backward compatibility."
|
||||
)
|
||||
self.linear_1 = operations.Linear(in_features, out_features, bias=not use_adaln_lora, **weight_args)
|
||||
self.activation = nn.SiLU()
|
||||
self.use_adaln_lora = use_adaln_lora
|
||||
if use_adaln_lora:
|
||||
self.linear_2 = operations.Linear(out_features, 3 * out_features, bias=False, **weight_args)
|
||||
else:
|
||||
self.linear_2 = operations.Linear(out_features, out_features, bias=True, **weight_args)
|
||||
|
||||
def forward(self, sample: torch.Tensor) -> torch.Tensor:
|
||||
emb = self.linear_1(sample)
|
||||
emb = self.activation(emb)
|
||||
emb = self.linear_2(emb)
|
||||
|
||||
if self.use_adaln_lora:
|
||||
adaln_lora_B_3D = emb
|
||||
emb_B_D = sample
|
||||
else:
|
||||
emb_B_D = emb
|
||||
adaln_lora_B_3D = None
|
||||
|
||||
return emb_B_D, adaln_lora_B_3D
|
||||
|
||||
|
||||
class FourierFeatures(nn.Module):
|
||||
"""
|
||||
Implements a layer that generates Fourier features from input tensors, based on randomly sampled
|
||||
frequencies and phases. This can help in learning high-frequency functions in low-dimensional problems.
|
||||
|
||||
[B] -> [B, D]
|
||||
|
||||
Parameters:
|
||||
num_channels (int): The number of Fourier features to generate.
|
||||
bandwidth (float, optional): The scaling factor for the frequency of the Fourier features. Defaults to 1.
|
||||
normalize (bool, optional): If set to True, the outputs are scaled by sqrt(2), usually to normalize
|
||||
the variance of the features. Defaults to False.
|
||||
|
||||
Example:
|
||||
>>> layer = FourierFeatures(num_channels=256, bandwidth=0.5, normalize=True)
|
||||
>>> x = torch.randn(10, 256) # Example input tensor
|
||||
>>> output = layer(x)
|
||||
>>> print(output.shape) # Expected shape: (10, 256)
|
||||
"""
|
||||
|
||||
def __init__(self, num_channels, bandwidth=1, normalize=False):
|
||||
super().__init__()
|
||||
self.register_buffer("freqs", 2 * np.pi * bandwidth * torch.randn(num_channels), persistent=True)
|
||||
self.register_buffer("phases", 2 * np.pi * torch.rand(num_channels), persistent=True)
|
||||
self.gain = np.sqrt(2) if normalize else 1
|
||||
|
||||
def forward(self, x, gain: float = 1.0):
|
||||
"""
|
||||
Apply the Fourier feature transformation to the input tensor.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor.
|
||||
gain (float, optional): An additional gain factor applied during the forward pass. Defaults to 1.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The transformed tensor, with Fourier features applied.
|
||||
"""
|
||||
in_dtype = x.dtype
|
||||
x = x.to(torch.float32).ger(self.freqs.to(torch.float32)).add(self.phases.to(torch.float32))
|
||||
x = x.cos().mul(self.gain * gain).to(in_dtype)
|
||||
return x
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
"""
|
||||
PatchEmbed is a module for embedding patches from an input tensor by applying either 3D or 2D convolutional layers,
|
||||
depending on the . This module can process inputs with temporal (video) and spatial (image) dimensions,
|
||||
making it suitable for video and image processing tasks. It supports dividing the input into patches
|
||||
and embedding each patch into a vector of size `out_channels`.
|
||||
|
||||
Parameters:
|
||||
- spatial_patch_size (int): The size of each spatial patch.
|
||||
- temporal_patch_size (int): The size of each temporal patch.
|
||||
- in_channels (int): Number of input channels. Default: 3.
|
||||
- out_channels (int): The dimension of the embedding vector for each patch. Default: 768.
|
||||
- bias (bool): If True, adds a learnable bias to the output of the convolutional layers. Default: True.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
spatial_patch_size,
|
||||
temporal_patch_size,
|
||||
in_channels=3,
|
||||
out_channels=768,
|
||||
bias=True,
|
||||
weight_args={},
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.spatial_patch_size = spatial_patch_size
|
||||
self.temporal_patch_size = temporal_patch_size
|
||||
|
||||
self.proj = nn.Sequential(
|
||||
Rearrange(
|
||||
"b c (t r) (h m) (w n) -> b t h w (c r m n)",
|
||||
r=temporal_patch_size,
|
||||
m=spatial_patch_size,
|
||||
n=spatial_patch_size,
|
||||
),
|
||||
operations.Linear(
|
||||
in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=bias, **weight_args
|
||||
),
|
||||
)
|
||||
self.out = nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward pass of the PatchEmbed module.
|
||||
|
||||
Parameters:
|
||||
- x (torch.Tensor): The input tensor of shape (B, C, T, H, W) where
|
||||
B is the batch size,
|
||||
C is the number of channels,
|
||||
T is the temporal dimension,
|
||||
H is the height, and
|
||||
W is the width of the input.
|
||||
|
||||
Returns:
|
||||
- torch.Tensor: The embedded patches as a tensor, with shape b t h w c.
|
||||
"""
|
||||
assert x.dim() == 5
|
||||
_, _, T, H, W = x.shape
|
||||
assert H % self.spatial_patch_size == 0 and W % self.spatial_patch_size == 0
|
||||
assert T % self.temporal_patch_size == 0
|
||||
x = self.proj(x)
|
||||
return self.out(x)
|
||||
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
"""
|
||||
The final layer of video DiT.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
spatial_patch_size,
|
||||
temporal_patch_size,
|
||||
out_channels,
|
||||
use_adaln_lora: bool = False,
|
||||
adaln_lora_dim: int = 256,
|
||||
weight_args={},
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **weight_args)
|
||||
self.linear = operations.Linear(
|
||||
hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, **weight_args
|
||||
)
|
||||
self.hidden_size = hidden_size
|
||||
self.n_adaln_chunks = 2
|
||||
self.use_adaln_lora = use_adaln_lora
|
||||
if use_adaln_lora:
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(hidden_size, adaln_lora_dim, bias=False, **weight_args),
|
||||
operations.Linear(adaln_lora_dim, self.n_adaln_chunks * hidden_size, bias=False, **weight_args),
|
||||
)
|
||||
else:
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(), operations.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False, **weight_args)
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x_BT_HW_D,
|
||||
emb_B_D,
|
||||
adaln_lora_B_3D: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if self.use_adaln_lora:
|
||||
assert adaln_lora_B_3D is not None
|
||||
shift_B_D, scale_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D[:, : 2 * self.hidden_size]).chunk(
|
||||
2, dim=1
|
||||
)
|
||||
else:
|
||||
shift_B_D, scale_B_D = self.adaLN_modulation(emb_B_D).chunk(2, dim=1)
|
||||
|
||||
B = emb_B_D.shape[0]
|
||||
T = x_BT_HW_D.shape[0] // B
|
||||
shift_BT_D, scale_BT_D = repeat(shift_B_D, "b d -> (b t) d", t=T), repeat(scale_B_D, "b d -> (b t) d", t=T)
|
||||
x_BT_HW_D = modulate(self.norm_final(x_BT_HW_D), shift_BT_D, scale_BT_D)
|
||||
|
||||
x_BT_HW_D = self.linear(x_BT_HW_D)
|
||||
return x_BT_HW_D
|
||||
|
||||
|
||||
class VideoAttn(nn.Module):
|
||||
"""
|
||||
Implements video attention with optional cross-attention capabilities.
|
||||
|
||||
This module processes video features while maintaining their spatio-temporal structure. It can perform
|
||||
self-attention within the video features or cross-attention with external context features.
|
||||
|
||||
Parameters:
|
||||
x_dim (int): Dimension of input feature vectors
|
||||
context_dim (Optional[int]): Dimension of context features for cross-attention. None for self-attention
|
||||
num_heads (int): Number of attention heads
|
||||
bias (bool): Whether to include bias in attention projections. Default: False
|
||||
qkv_norm_mode (str): Normalization mode for query/key/value projections. Must be "per_head". Default: "per_head"
|
||||
x_format (str): Format of input tensor. Must be "BTHWD". Default: "BTHWD"
|
||||
|
||||
Input shape:
|
||||
- x: (T, H, W, B, D) video features
|
||||
- context (optional): (M, B, D) context features for cross-attention
|
||||
where:
|
||||
T: temporal dimension
|
||||
H: height
|
||||
W: width
|
||||
B: batch size
|
||||
D: feature dimension
|
||||
M: context sequence length
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
x_dim: int,
|
||||
context_dim: Optional[int],
|
||||
num_heads: int,
|
||||
bias: bool = False,
|
||||
qkv_norm_mode: str = "per_head",
|
||||
x_format: str = "BTHWD",
|
||||
weight_args={},
|
||||
operations=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.x_format = x_format
|
||||
|
||||
self.attn = Attention(
|
||||
x_dim,
|
||||
context_dim,
|
||||
num_heads,
|
||||
x_dim // num_heads,
|
||||
qkv_bias=bias,
|
||||
qkv_norm="RRI",
|
||||
out_bias=bias,
|
||||
qkv_norm_mode=qkv_norm_mode,
|
||||
qkv_format="sbhd",
|
||||
weight_args=weight_args,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
crossattn_mask: Optional[torch.Tensor] = None,
|
||||
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass for video attention.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input tensor of shape (B, T, H, W, D) or (T, H, W, B, D) representing batches of video data.
|
||||
context (Tensor): Context tensor of shape (B, M, D) or (M, B, D),
|
||||
where M is the sequence length of the context.
|
||||
crossattn_mask (Optional[Tensor]): An optional mask for cross-attention mechanisms.
|
||||
rope_emb_L_1_1_D (Optional[Tensor]):
|
||||
Rotary positional embedding tensor of shape (L, 1, 1, D). L == THW for current video training.
|
||||
|
||||
Returns:
|
||||
Tensor: The output tensor with applied attention, maintaining the input shape.
|
||||
"""
|
||||
|
||||
x_T_H_W_B_D = x
|
||||
context_M_B_D = context
|
||||
T, H, W, B, D = x_T_H_W_B_D.shape
|
||||
x_THW_B_D = rearrange(x_T_H_W_B_D, "t h w b d -> (t h w) b d")
|
||||
x_THW_B_D = self.attn(
|
||||
x_THW_B_D,
|
||||
context_M_B_D,
|
||||
crossattn_mask,
|
||||
rope_emb=rope_emb_L_1_1_D,
|
||||
)
|
||||
x_T_H_W_B_D = rearrange(x_THW_B_D, "(t h w) b d -> t h w b d", h=H, w=W)
|
||||
return x_T_H_W_B_D
|
||||
|
||||
|
||||
def adaln_norm_state(norm_state, x, scale, shift):
|
||||
normalized = norm_state(x)
|
||||
return normalized * (1 + scale) + shift
|
||||
|
||||
|
||||
class DITBuildingBlock(nn.Module):
|
||||
"""
|
||||
A building block for the DiT (Diffusion Transformer) architecture that supports different types of
|
||||
attention and MLP operations with adaptive layer normalization.
|
||||
|
||||
Parameters:
|
||||
block_type (str): Type of block - one of:
|
||||
- "cross_attn"/"ca": Cross-attention
|
||||
- "full_attn"/"fa": Full self-attention
|
||||
- "mlp"/"ff": MLP/feedforward block
|
||||
x_dim (int): Dimension of input features
|
||||
context_dim (Optional[int]): Dimension of context features for cross-attention
|
||||
num_heads (int): Number of attention heads
|
||||
mlp_ratio (float): MLP hidden dimension multiplier. Default: 4.0
|
||||
bias (bool): Whether to use bias in layers. Default: False
|
||||
mlp_dropout (float): Dropout rate for MLP. Default: 0.0
|
||||
qkv_norm_mode (str): QKV normalization mode. Default: "per_head"
|
||||
x_format (str): Input tensor format. Default: "BTHWD"
|
||||
use_adaln_lora (bool): Whether to use AdaLN-LoRA. Default: False
|
||||
adaln_lora_dim (int): Dimension for AdaLN-LoRA. Default: 256
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
block_type: str,
|
||||
x_dim: int,
|
||||
context_dim: Optional[int],
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
bias: bool = False,
|
||||
mlp_dropout: float = 0.0,
|
||||
qkv_norm_mode: str = "per_head",
|
||||
x_format: str = "BTHWD",
|
||||
use_adaln_lora: bool = False,
|
||||
adaln_lora_dim: int = 256,
|
||||
weight_args={},
|
||||
operations=None
|
||||
) -> None:
|
||||
block_type = block_type.lower()
|
||||
|
||||
super().__init__()
|
||||
self.x_format = x_format
|
||||
if block_type in ["cross_attn", "ca"]:
|
||||
self.block = VideoAttn(
|
||||
x_dim,
|
||||
context_dim,
|
||||
num_heads,
|
||||
bias=bias,
|
||||
qkv_norm_mode=qkv_norm_mode,
|
||||
x_format=self.x_format,
|
||||
weight_args=weight_args,
|
||||
operations=operations,
|
||||
)
|
||||
elif block_type in ["full_attn", "fa"]:
|
||||
self.block = VideoAttn(
|
||||
x_dim, None, num_heads, bias=bias, qkv_norm_mode=qkv_norm_mode, x_format=self.x_format, weight_args=weight_args, operations=operations
|
||||
)
|
||||
elif block_type in ["mlp", "ff"]:
|
||||
self.block = GPT2FeedForward(x_dim, int(x_dim * mlp_ratio), dropout=mlp_dropout, bias=bias, weight_args=weight_args, operations=operations)
|
||||
else:
|
||||
raise ValueError(f"Unknown block type: {block_type}")
|
||||
|
||||
self.block_type = block_type
|
||||
self.use_adaln_lora = use_adaln_lora
|
||||
|
||||
self.norm_state = nn.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6)
|
||||
self.n_adaln_chunks = 3
|
||||
if use_adaln_lora:
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(x_dim, adaln_lora_dim, bias=False, **weight_args),
|
||||
operations.Linear(adaln_lora_dim, self.n_adaln_chunks * x_dim, bias=False, **weight_args),
|
||||
)
|
||||
else:
|
||||
self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, self.n_adaln_chunks * x_dim, bias=False, **weight_args))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
emb_B_D: torch.Tensor,
|
||||
crossattn_emb: torch.Tensor,
|
||||
crossattn_mask: Optional[torch.Tensor] = None,
|
||||
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
|
||||
adaln_lora_B_3D: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass for dynamically configured blocks with adaptive normalization.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input tensor of shape (B, T, H, W, D) or (T, H, W, B, D).
|
||||
emb_B_D (Tensor): Embedding tensor for adaptive layer normalization modulation.
|
||||
crossattn_emb (Tensor): Tensor for cross-attention blocks.
|
||||
crossattn_mask (Optional[Tensor]): Optional mask for cross-attention.
|
||||
rope_emb_L_1_1_D (Optional[Tensor]):
|
||||
Rotary positional embedding tensor of shape (L, 1, 1, D). L == THW for current video training.
|
||||
|
||||
Returns:
|
||||
Tensor: The output tensor after processing through the configured block and adaptive normalization.
|
||||
"""
|
||||
if self.use_adaln_lora:
|
||||
shift_B_D, scale_B_D, gate_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D).chunk(
|
||||
self.n_adaln_chunks, dim=1
|
||||
)
|
||||
else:
|
||||
shift_B_D, scale_B_D, gate_B_D = self.adaLN_modulation(emb_B_D).chunk(self.n_adaln_chunks, dim=1)
|
||||
|
||||
shift_1_1_1_B_D, scale_1_1_1_B_D, gate_1_1_1_B_D = (
|
||||
shift_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0),
|
||||
scale_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0),
|
||||
gate_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0),
|
||||
)
|
||||
|
||||
if self.block_type in ["mlp", "ff"]:
|
||||
x = x + gate_1_1_1_B_D * self.block(
|
||||
adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D),
|
||||
)
|
||||
elif self.block_type in ["full_attn", "fa"]:
|
||||
x = x + gate_1_1_1_B_D * self.block(
|
||||
adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D),
|
||||
context=None,
|
||||
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
|
||||
)
|
||||
elif self.block_type in ["cross_attn", "ca"]:
|
||||
x = x + gate_1_1_1_B_D * self.block(
|
||||
adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D),
|
||||
context=crossattn_emb,
|
||||
crossattn_mask=crossattn_mask,
|
||||
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown block type: {self.block_type}")
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class GeneralDITTransformerBlock(nn.Module):
|
||||
"""
|
||||
A wrapper module that manages a sequence of DITBuildingBlocks to form a complete transformer layer.
|
||||
Each block in the sequence is specified by a block configuration string.
|
||||
|
||||
Parameters:
|
||||
x_dim (int): Dimension of input features
|
||||
context_dim (int): Dimension of context features for cross-attention blocks
|
||||
num_heads (int): Number of attention heads
|
||||
block_config (str): String specifying block sequence (e.g. "ca-fa-mlp" for cross-attention,
|
||||
full-attention, then MLP)
|
||||
mlp_ratio (float): MLP hidden dimension multiplier. Default: 4.0
|
||||
x_format (str): Input tensor format. Default: "BTHWD"
|
||||
use_adaln_lora (bool): Whether to use AdaLN-LoRA. Default: False
|
||||
adaln_lora_dim (int): Dimension for AdaLN-LoRA. Default: 256
|
||||
|
||||
The block_config string uses "-" to separate block types:
|
||||
- "ca"/"cross_attn": Cross-attention block
|
||||
- "fa"/"full_attn": Full self-attention block
|
||||
- "mlp"/"ff": MLP/feedforward block
|
||||
|
||||
Example:
|
||||
block_config = "ca-fa-mlp" creates a sequence of:
|
||||
1. Cross-attention block
|
||||
2. Full self-attention block
|
||||
3. MLP block
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
x_dim: int,
|
||||
context_dim: int,
|
||||
num_heads: int,
|
||||
block_config: str,
|
||||
mlp_ratio: float = 4.0,
|
||||
x_format: str = "BTHWD",
|
||||
use_adaln_lora: bool = False,
|
||||
adaln_lora_dim: int = 256,
|
||||
weight_args={},
|
||||
operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.blocks = nn.ModuleList()
|
||||
self.x_format = x_format
|
||||
for block_type in block_config.split("-"):
|
||||
self.blocks.append(
|
||||
DITBuildingBlock(
|
||||
block_type,
|
||||
x_dim,
|
||||
context_dim,
|
||||
num_heads,
|
||||
mlp_ratio,
|
||||
x_format=self.x_format,
|
||||
use_adaln_lora=use_adaln_lora,
|
||||
adaln_lora_dim=adaln_lora_dim,
|
||||
weight_args=weight_args,
|
||||
operations=operations,
|
||||
)
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
emb_B_D: torch.Tensor,
|
||||
crossattn_emb: torch.Tensor,
|
||||
crossattn_mask: Optional[torch.Tensor] = None,
|
||||
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
|
||||
adaln_lora_B_3D: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
for block in self.blocks:
|
||||
x = block(
|
||||
x,
|
||||
emb_B_D,
|
||||
crossattn_emb,
|
||||
crossattn_mask,
|
||||
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
|
||||
adaln_lora_B_3D=adaln_lora_B_3D,
|
||||
)
|
||||
return x
|
||||
1041
comfy/ldm/cosmos/cosmos_tokenizer/layers3d.py
Normal file
1041
comfy/ldm/cosmos/cosmos_tokenizer/layers3d.py
Normal file
File diff suppressed because it is too large
Load Diff
377
comfy/ldm/cosmos/cosmos_tokenizer/patching.py
Normal file
377
comfy/ldm/cosmos/cosmos_tokenizer/patching.py
Normal file
@@ -0,0 +1,377 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""The patcher and unpatcher implementation for 2D and 3D data.
|
||||
|
||||
The idea of Haar wavelet is to compute LL, LH, HL, HH component as two 1D convolutions.
|
||||
One on the rows and one on the columns.
|
||||
For example, in 1D signal, we have [a, b], then the low-freq compoenent is [a + b] / 2 and high-freq is [a - b] / 2.
|
||||
We can use a 1D convolution with kernel [1, 1] and stride 2 to represent the L component.
|
||||
For H component, we can use a 1D convolution with kernel [1, -1] and stride 2.
|
||||
Although in principle, we typically only do additional Haar wavelet over the LL component. But here we do it for all
|
||||
as we need to support downsampling for more than 2x.
|
||||
For example, 4x downsampling can be done by 2x Haar and additional 2x Haar, and the shape would be.
|
||||
[3, 256, 256] -> [12, 128, 128] -> [48, 64, 64]
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
_WAVELETS = {
|
||||
"haar": torch.tensor([0.7071067811865476, 0.7071067811865476]),
|
||||
"rearrange": torch.tensor([1.0, 1.0]),
|
||||
}
|
||||
_PERSISTENT = False
|
||||
|
||||
|
||||
class Patcher(torch.nn.Module):
|
||||
"""A module to convert image tensors into patches using torch operations.
|
||||
|
||||
The main difference from `class Patching` is that this module implements
|
||||
all operations using torch, rather than python or numpy, for efficiency purpose.
|
||||
|
||||
It's bit-wise identical to the Patching module outputs, with the added
|
||||
benefit of being torch.jit scriptable.
|
||||
"""
|
||||
|
||||
def __init__(self, patch_size=1, patch_method="haar"):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.patch_method = patch_method
|
||||
self.register_buffer(
|
||||
"wavelets", _WAVELETS[patch_method], persistent=_PERSISTENT
|
||||
)
|
||||
self.range = range(int(torch.log2(torch.tensor(self.patch_size)).item()))
|
||||
self.register_buffer(
|
||||
"_arange",
|
||||
torch.arange(_WAVELETS[patch_method].shape[0]),
|
||||
persistent=_PERSISTENT,
|
||||
)
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, x):
|
||||
if self.patch_method == "haar":
|
||||
return self._haar(x)
|
||||
elif self.patch_method == "rearrange":
|
||||
return self._arrange(x)
|
||||
else:
|
||||
raise ValueError("Unknown patch method: " + self.patch_method)
|
||||
|
||||
def _dwt(self, x, mode="reflect", rescale=False):
|
||||
dtype = x.dtype
|
||||
h = self.wavelets.to(device=x.device)
|
||||
|
||||
n = h.shape[0]
|
||||
g = x.shape[1]
|
||||
hl = h.flip(0).reshape(1, 1, -1).repeat(g, 1, 1)
|
||||
hh = (h * ((-1) ** self._arange.to(device=x.device))).reshape(1, 1, -1).repeat(g, 1, 1)
|
||||
hh = hh.to(dtype=dtype)
|
||||
hl = hl.to(dtype=dtype)
|
||||
|
||||
x = F.pad(x, pad=(n - 2, n - 1, n - 2, n - 1), mode=mode).to(dtype)
|
||||
xl = F.conv2d(x, hl.unsqueeze(2), groups=g, stride=(1, 2))
|
||||
xh = F.conv2d(x, hh.unsqueeze(2), groups=g, stride=(1, 2))
|
||||
xll = F.conv2d(xl, hl.unsqueeze(3), groups=g, stride=(2, 1))
|
||||
xlh = F.conv2d(xl, hh.unsqueeze(3), groups=g, stride=(2, 1))
|
||||
xhl = F.conv2d(xh, hl.unsqueeze(3), groups=g, stride=(2, 1))
|
||||
xhh = F.conv2d(xh, hh.unsqueeze(3), groups=g, stride=(2, 1))
|
||||
|
||||
out = torch.cat([xll, xlh, xhl, xhh], dim=1)
|
||||
if rescale:
|
||||
out = out / 2
|
||||
return out
|
||||
|
||||
def _haar(self, x):
|
||||
for _ in self.range:
|
||||
x = self._dwt(x, rescale=True)
|
||||
return x
|
||||
|
||||
def _arrange(self, x):
|
||||
x = rearrange(
|
||||
x,
|
||||
"b c (h p1) (w p2) -> b (c p1 p2) h w",
|
||||
p1=self.patch_size,
|
||||
p2=self.patch_size,
|
||||
).contiguous()
|
||||
return x
|
||||
|
||||
|
||||
class Patcher3D(Patcher):
|
||||
"""A 3D discrete wavelet transform for video data, expects 5D tensor, i.e. a batch of videos."""
|
||||
|
||||
def __init__(self, patch_size=1, patch_method="haar"):
|
||||
super().__init__(patch_method=patch_method, patch_size=patch_size)
|
||||
self.register_buffer(
|
||||
"patch_size_buffer",
|
||||
patch_size * torch.ones([1], dtype=torch.int32),
|
||||
persistent=_PERSISTENT,
|
||||
)
|
||||
|
||||
def _dwt(self, x, wavelet, mode="reflect", rescale=False):
|
||||
dtype = x.dtype
|
||||
h = self.wavelets.to(device=x.device)
|
||||
|
||||
n = h.shape[0]
|
||||
g = x.shape[1]
|
||||
hl = h.flip(0).reshape(1, 1, -1).repeat(g, 1, 1)
|
||||
hh = (h * ((-1) ** self._arange.to(device=x.device))).reshape(1, 1, -1).repeat(g, 1, 1)
|
||||
hh = hh.to(dtype=dtype)
|
||||
hl = hl.to(dtype=dtype)
|
||||
|
||||
# Handles temporal axis.
|
||||
x = F.pad(
|
||||
x, pad=(max(0, n - 2), n - 1, n - 2, n - 1, n - 2, n - 1), mode=mode
|
||||
).to(dtype)
|
||||
xl = F.conv3d(x, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1))
|
||||
xh = F.conv3d(x, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1))
|
||||
|
||||
# Handles spatial axes.
|
||||
xll = F.conv3d(xl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
|
||||
xlh = F.conv3d(xl, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
|
||||
xhl = F.conv3d(xh, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
|
||||
xhh = F.conv3d(xh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
|
||||
|
||||
xlll = F.conv3d(xll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
|
||||
xllh = F.conv3d(xll, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
|
||||
xlhl = F.conv3d(xlh, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
|
||||
xlhh = F.conv3d(xlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
|
||||
xhll = F.conv3d(xhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
|
||||
xhlh = F.conv3d(xhl, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
|
||||
xhhl = F.conv3d(xhh, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
|
||||
xhhh = F.conv3d(xhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
|
||||
|
||||
out = torch.cat([xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh], dim=1)
|
||||
if rescale:
|
||||
out = out / (2 * torch.sqrt(torch.tensor(2.0)))
|
||||
return out
|
||||
|
||||
def _haar(self, x):
|
||||
xi, xv = torch.split(x, [1, x.shape[2] - 1], dim=2)
|
||||
x = torch.cat([xi.repeat_interleave(self.patch_size, dim=2), xv], dim=2)
|
||||
for _ in self.range:
|
||||
x = self._dwt(x, "haar", rescale=True)
|
||||
return x
|
||||
|
||||
def _arrange(self, x):
|
||||
xi, xv = torch.split(x, [1, x.shape[2] - 1], dim=2)
|
||||
x = torch.cat([xi.repeat_interleave(self.patch_size, dim=2), xv], dim=2)
|
||||
x = rearrange(
|
||||
x,
|
||||
"b c (t p1) (h p2) (w p3) -> b (c p1 p2 p3) t h w",
|
||||
p1=self.patch_size,
|
||||
p2=self.patch_size,
|
||||
p3=self.patch_size,
|
||||
).contiguous()
|
||||
return x
|
||||
|
||||
|
||||
class UnPatcher(torch.nn.Module):
|
||||
"""A module to convert patches into image tensorsusing torch operations.
|
||||
|
||||
The main difference from `class Unpatching` is that this module implements
|
||||
all operations using torch, rather than python or numpy, for efficiency purpose.
|
||||
|
||||
It's bit-wise identical to the Unpatching module outputs, with the added
|
||||
benefit of being torch.jit scriptable.
|
||||
"""
|
||||
|
||||
def __init__(self, patch_size=1, patch_method="haar"):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.patch_method = patch_method
|
||||
self.register_buffer(
|
||||
"wavelets", _WAVELETS[patch_method], persistent=_PERSISTENT
|
||||
)
|
||||
self.range = range(int(torch.log2(torch.tensor(self.patch_size)).item()))
|
||||
self.register_buffer(
|
||||
"_arange",
|
||||
torch.arange(_WAVELETS[patch_method].shape[0]),
|
||||
persistent=_PERSISTENT,
|
||||
)
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, x):
|
||||
if self.patch_method == "haar":
|
||||
return self._ihaar(x)
|
||||
elif self.patch_method == "rearrange":
|
||||
return self._iarrange(x)
|
||||
else:
|
||||
raise ValueError("Unknown patch method: " + self.patch_method)
|
||||
|
||||
def _idwt(self, x, wavelet="haar", mode="reflect", rescale=False):
|
||||
dtype = x.dtype
|
||||
h = self.wavelets.to(device=x.device)
|
||||
n = h.shape[0]
|
||||
|
||||
g = x.shape[1] // 4
|
||||
hl = h.flip([0]).reshape(1, 1, -1).repeat([g, 1, 1])
|
||||
hh = (h * ((-1) ** self._arange.to(device=x.device))).reshape(1, 1, -1).repeat(g, 1, 1)
|
||||
hh = hh.to(dtype=dtype)
|
||||
hl = hl.to(dtype=dtype)
|
||||
|
||||
xll, xlh, xhl, xhh = torch.chunk(x.to(dtype), 4, dim=1)
|
||||
|
||||
# Inverse transform.
|
||||
yl = torch.nn.functional.conv_transpose2d(
|
||||
xll, hl.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)
|
||||
)
|
||||
yl += torch.nn.functional.conv_transpose2d(
|
||||
xlh, hh.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)
|
||||
)
|
||||
yh = torch.nn.functional.conv_transpose2d(
|
||||
xhl, hl.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)
|
||||
)
|
||||
yh += torch.nn.functional.conv_transpose2d(
|
||||
xhh, hh.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)
|
||||
)
|
||||
y = torch.nn.functional.conv_transpose2d(
|
||||
yl, hl.unsqueeze(2), groups=g, stride=(1, 2), padding=(0, n - 2)
|
||||
)
|
||||
y += torch.nn.functional.conv_transpose2d(
|
||||
yh, hh.unsqueeze(2), groups=g, stride=(1, 2), padding=(0, n - 2)
|
||||
)
|
||||
|
||||
if rescale:
|
||||
y = y * 2
|
||||
return y
|
||||
|
||||
def _ihaar(self, x):
|
||||
for _ in self.range:
|
||||
x = self._idwt(x, "haar", rescale=True)
|
||||
return x
|
||||
|
||||
def _iarrange(self, x):
|
||||
x = rearrange(
|
||||
x,
|
||||
"b (c p1 p2) h w -> b c (h p1) (w p2)",
|
||||
p1=self.patch_size,
|
||||
p2=self.patch_size,
|
||||
)
|
||||
return x
|
||||
|
||||
|
||||
class UnPatcher3D(UnPatcher):
|
||||
"""A 3D inverse discrete wavelet transform for video wavelet decompositions."""
|
||||
|
||||
def __init__(self, patch_size=1, patch_method="haar"):
|
||||
super().__init__(patch_method=patch_method, patch_size=patch_size)
|
||||
|
||||
def _idwt(self, x, wavelet="haar", mode="reflect", rescale=False):
|
||||
dtype = x.dtype
|
||||
h = self.wavelets.to(device=x.device)
|
||||
|
||||
g = x.shape[1] // 8 # split into 8 spatio-temporal filtered tesnors.
|
||||
hl = h.flip([0]).reshape(1, 1, -1).repeat([g, 1, 1])
|
||||
hh = (h * ((-1) ** self._arange.to(device=x.device))).reshape(1, 1, -1).repeat(g, 1, 1)
|
||||
hl = hl.to(dtype=dtype)
|
||||
hh = hh.to(dtype=dtype)
|
||||
|
||||
xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh = torch.chunk(x, 8, dim=1)
|
||||
del x
|
||||
|
||||
# Height height transposed convolutions.
|
||||
xll = F.conv_transpose3d(
|
||||
xlll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
|
||||
)
|
||||
del xlll
|
||||
|
||||
xll += F.conv_transpose3d(
|
||||
xllh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
|
||||
)
|
||||
del xllh
|
||||
|
||||
xlh = F.conv_transpose3d(
|
||||
xlhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
|
||||
)
|
||||
del xlhl
|
||||
|
||||
xlh += F.conv_transpose3d(
|
||||
xlhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
|
||||
)
|
||||
del xlhh
|
||||
|
||||
xhl = F.conv_transpose3d(
|
||||
xhll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
|
||||
)
|
||||
del xhll
|
||||
|
||||
xhl += F.conv_transpose3d(
|
||||
xhlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
|
||||
)
|
||||
del xhlh
|
||||
|
||||
xhh = F.conv_transpose3d(
|
||||
xhhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
|
||||
)
|
||||
del xhhl
|
||||
|
||||
xhh += F.conv_transpose3d(
|
||||
xhhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
|
||||
)
|
||||
del xhhh
|
||||
|
||||
# Handles width transposed convolutions.
|
||||
xl = F.conv_transpose3d(
|
||||
xll, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
|
||||
)
|
||||
del xll
|
||||
|
||||
xl += F.conv_transpose3d(
|
||||
xlh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
|
||||
)
|
||||
del xlh
|
||||
|
||||
xh = F.conv_transpose3d(
|
||||
xhl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
|
||||
)
|
||||
del xhl
|
||||
|
||||
xh += F.conv_transpose3d(
|
||||
xhh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
|
||||
)
|
||||
del xhh
|
||||
|
||||
# Handles time axis transposed convolutions.
|
||||
x = F.conv_transpose3d(
|
||||
xl, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)
|
||||
)
|
||||
del xl
|
||||
|
||||
x += F.conv_transpose3d(
|
||||
xh, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)
|
||||
)
|
||||
|
||||
if rescale:
|
||||
x = x * (2 * torch.sqrt(torch.tensor(2.0)))
|
||||
return x
|
||||
|
||||
def _ihaar(self, x):
|
||||
for _ in self.range:
|
||||
x = self._idwt(x, "haar", rescale=True)
|
||||
x = x[:, :, self.patch_size - 1 :, ...]
|
||||
return x
|
||||
|
||||
def _iarrange(self, x):
|
||||
x = rearrange(
|
||||
x,
|
||||
"b (c p1 p2 p3) t h w -> b c (t p1) (h p2) (w p3)",
|
||||
p1=self.patch_size,
|
||||
p2=self.patch_size,
|
||||
p3=self.patch_size,
|
||||
)
|
||||
x = x[:, :, self.patch_size - 1 :, ...]
|
||||
return x
|
||||
112
comfy/ldm/cosmos/cosmos_tokenizer/utils.py
Normal file
112
comfy/ldm/cosmos/cosmos_tokenizer/utils.py
Normal file
@@ -0,0 +1,112 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Shared utilities for the networks module."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
def time2batch(x: torch.Tensor) -> tuple[torch.Tensor, int]:
|
||||
batch_size = x.shape[0]
|
||||
return rearrange(x, "b c t h w -> (b t) c h w"), batch_size
|
||||
|
||||
|
||||
def batch2time(x: torch.Tensor, batch_size: int) -> torch.Tensor:
|
||||
return rearrange(x, "(b t) c h w -> b c t h w", b=batch_size)
|
||||
|
||||
|
||||
def space2batch(x: torch.Tensor) -> tuple[torch.Tensor, int]:
|
||||
batch_size, height = x.shape[0], x.shape[-2]
|
||||
return rearrange(x, "b c t h w -> (b h w) c t"), batch_size, height
|
||||
|
||||
|
||||
def batch2space(x: torch.Tensor, batch_size: int, height: int) -> torch.Tensor:
|
||||
return rearrange(x, "(b h w) c t -> b c t h w", b=batch_size, h=height)
|
||||
|
||||
|
||||
def cast_tuple(t: Any, length: int = 1) -> Any:
|
||||
return t if isinstance(t, tuple) else ((t,) * length)
|
||||
|
||||
|
||||
def replication_pad(x):
|
||||
return torch.cat([x[:, :, :1, ...], x], dim=2)
|
||||
|
||||
|
||||
def divisible_by(num: int, den: int) -> bool:
|
||||
return (num % den) == 0
|
||||
|
||||
|
||||
def is_odd(n: int) -> bool:
|
||||
return not divisible_by(n, 2)
|
||||
|
||||
|
||||
def nonlinearity(x):
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
def Normalize(in_channels, num_groups=32):
|
||||
return ops.GroupNorm(
|
||||
num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
|
||||
)
|
||||
|
||||
|
||||
class CausalNormalize(torch.nn.Module):
|
||||
def __init__(self, in_channels, num_groups=1):
|
||||
super().__init__()
|
||||
self.norm = ops.GroupNorm(
|
||||
num_groups=num_groups,
|
||||
num_channels=in_channels,
|
||||
eps=1e-6,
|
||||
affine=True,
|
||||
)
|
||||
self.num_groups = num_groups
|
||||
|
||||
def forward(self, x):
|
||||
# if num_groups !=1, we apply a spatio-temporal groupnorm for backward compatibility purpose.
|
||||
# All new models should use num_groups=1, otherwise causality is not guaranteed.
|
||||
if self.num_groups == 1:
|
||||
x, batch_size = time2batch(x)
|
||||
return batch2time(self.norm(x), batch_size)
|
||||
return self.norm(x)
|
||||
|
||||
|
||||
def exists(v):
|
||||
return v is not None
|
||||
|
||||
|
||||
def default(*args):
|
||||
for arg in args:
|
||||
if exists(arg):
|
||||
return arg
|
||||
return None
|
||||
|
||||
|
||||
def round_ste(z: torch.Tensor) -> torch.Tensor:
|
||||
"""Round with straight through gradients."""
|
||||
zhat = z.round()
|
||||
return z + (zhat - z).detach()
|
||||
|
||||
|
||||
def log(t, eps=1e-5):
|
||||
return t.clamp(min=eps).log()
|
||||
|
||||
|
||||
def entropy(prob):
|
||||
return (-prob * log(prob)).sum(dim=-1)
|
||||
514
comfy/ldm/cosmos/model.py
Normal file
514
comfy/ldm/cosmos/model.py
Normal file
@@ -0,0 +1,514 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing.
|
||||
"""
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from torch import nn
|
||||
from torchvision import transforms
|
||||
|
||||
from enum import Enum
|
||||
import logging
|
||||
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
|
||||
|
||||
from .blocks import (
|
||||
FinalLayer,
|
||||
GeneralDITTransformerBlock,
|
||||
PatchEmbed,
|
||||
TimestepEmbedding,
|
||||
Timesteps,
|
||||
)
|
||||
|
||||
from .position_embedding import LearnablePosEmbAxis, VideoRopePosition3DEmb
|
||||
|
||||
|
||||
class DataType(Enum):
|
||||
IMAGE = "image"
|
||||
VIDEO = "video"
|
||||
|
||||
|
||||
class GeneralDIT(nn.Module):
|
||||
"""
|
||||
A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing.
|
||||
|
||||
Args:
|
||||
max_img_h (int): Maximum height of the input images.
|
||||
max_img_w (int): Maximum width of the input images.
|
||||
max_frames (int): Maximum number of frames in the video sequence.
|
||||
in_channels (int): Number of input channels (e.g., RGB channels for color images).
|
||||
out_channels (int): Number of output channels.
|
||||
patch_spatial (tuple): Spatial resolution of patches for input processing.
|
||||
patch_temporal (int): Temporal resolution of patches for input processing.
|
||||
concat_padding_mask (bool): If True, includes a mask channel in the input to handle padding.
|
||||
block_config (str): Configuration of the transformer block. See Notes for supported block types.
|
||||
model_channels (int): Base number of channels used throughout the model.
|
||||
num_blocks (int): Number of transformer blocks.
|
||||
num_heads (int): Number of heads in the multi-head attention layers.
|
||||
mlp_ratio (float): Expansion ratio for MLP blocks.
|
||||
block_x_format (str): Format of input tensor for transformer blocks ('BTHWD' or 'THWBD').
|
||||
crossattn_emb_channels (int): Number of embedding channels for cross-attention.
|
||||
use_cross_attn_mask (bool): Whether to use mask in cross-attention.
|
||||
pos_emb_cls (str): Type of positional embeddings.
|
||||
pos_emb_learnable (bool): Whether positional embeddings are learnable.
|
||||
pos_emb_interpolation (str): Method for interpolating positional embeddings.
|
||||
affline_emb_norm (bool): Whether to normalize affine embeddings.
|
||||
use_adaln_lora (bool): Whether to use AdaLN-LoRA.
|
||||
adaln_lora_dim (int): Dimension for AdaLN-LoRA.
|
||||
rope_h_extrapolation_ratio (float): Height extrapolation ratio for RoPE.
|
||||
rope_w_extrapolation_ratio (float): Width extrapolation ratio for RoPE.
|
||||
rope_t_extrapolation_ratio (float): Temporal extrapolation ratio for RoPE.
|
||||
extra_per_block_abs_pos_emb (bool): Whether to use extra per-block absolute positional embeddings.
|
||||
extra_per_block_abs_pos_emb_type (str): Type of extra per-block positional embeddings.
|
||||
extra_h_extrapolation_ratio (float): Height extrapolation ratio for extra embeddings.
|
||||
extra_w_extrapolation_ratio (float): Width extrapolation ratio for extra embeddings.
|
||||
extra_t_extrapolation_ratio (float): Temporal extrapolation ratio for extra embeddings.
|
||||
|
||||
Notes:
|
||||
Supported block types in block_config:
|
||||
* cross_attn, ca: Cross attention
|
||||
* full_attn: Full attention on all flattened tokens
|
||||
* mlp, ff: Feed forward block
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_img_h: int,
|
||||
max_img_w: int,
|
||||
max_frames: int,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
patch_spatial: tuple,
|
||||
patch_temporal: int,
|
||||
concat_padding_mask: bool = True,
|
||||
# attention settings
|
||||
block_config: str = "FA-CA-MLP",
|
||||
model_channels: int = 768,
|
||||
num_blocks: int = 10,
|
||||
num_heads: int = 16,
|
||||
mlp_ratio: float = 4.0,
|
||||
block_x_format: str = "BTHWD",
|
||||
# cross attention settings
|
||||
crossattn_emb_channels: int = 1024,
|
||||
use_cross_attn_mask: bool = False,
|
||||
# positional embedding settings
|
||||
pos_emb_cls: str = "sincos",
|
||||
pos_emb_learnable: bool = False,
|
||||
pos_emb_interpolation: str = "crop",
|
||||
affline_emb_norm: bool = False, # whether or not to normalize the affine embedding
|
||||
use_adaln_lora: bool = False,
|
||||
adaln_lora_dim: int = 256,
|
||||
rope_h_extrapolation_ratio: float = 1.0,
|
||||
rope_w_extrapolation_ratio: float = 1.0,
|
||||
rope_t_extrapolation_ratio: float = 1.0,
|
||||
extra_per_block_abs_pos_emb: bool = False,
|
||||
extra_per_block_abs_pos_emb_type: str = "sincos",
|
||||
extra_h_extrapolation_ratio: float = 1.0,
|
||||
extra_w_extrapolation_ratio: float = 1.0,
|
||||
extra_t_extrapolation_ratio: float = 1.0,
|
||||
image_model=None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.max_img_h = max_img_h
|
||||
self.max_img_w = max_img_w
|
||||
self.max_frames = max_frames
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.patch_spatial = patch_spatial
|
||||
self.patch_temporal = patch_temporal
|
||||
self.num_heads = num_heads
|
||||
self.num_blocks = num_blocks
|
||||
self.model_channels = model_channels
|
||||
self.use_cross_attn_mask = use_cross_attn_mask
|
||||
self.concat_padding_mask = concat_padding_mask
|
||||
# positional embedding settings
|
||||
self.pos_emb_cls = pos_emb_cls
|
||||
self.pos_emb_learnable = pos_emb_learnable
|
||||
self.pos_emb_interpolation = pos_emb_interpolation
|
||||
self.affline_emb_norm = affline_emb_norm
|
||||
self.rope_h_extrapolation_ratio = rope_h_extrapolation_ratio
|
||||
self.rope_w_extrapolation_ratio = rope_w_extrapolation_ratio
|
||||
self.rope_t_extrapolation_ratio = rope_t_extrapolation_ratio
|
||||
self.extra_per_block_abs_pos_emb = extra_per_block_abs_pos_emb
|
||||
self.extra_per_block_abs_pos_emb_type = extra_per_block_abs_pos_emb_type.lower()
|
||||
self.extra_h_extrapolation_ratio = extra_h_extrapolation_ratio
|
||||
self.extra_w_extrapolation_ratio = extra_w_extrapolation_ratio
|
||||
self.extra_t_extrapolation_ratio = extra_t_extrapolation_ratio
|
||||
self.dtype = dtype
|
||||
weight_args = {"device": device, "dtype": dtype}
|
||||
|
||||
in_channels = in_channels + 1 if concat_padding_mask else in_channels
|
||||
self.x_embedder = PatchEmbed(
|
||||
spatial_patch_size=patch_spatial,
|
||||
temporal_patch_size=patch_temporal,
|
||||
in_channels=in_channels,
|
||||
out_channels=model_channels,
|
||||
bias=False,
|
||||
weight_args=weight_args,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
self.build_pos_embed(device=device, dtype=dtype)
|
||||
self.block_x_format = block_x_format
|
||||
self.use_adaln_lora = use_adaln_lora
|
||||
self.adaln_lora_dim = adaln_lora_dim
|
||||
self.t_embedder = nn.ModuleList(
|
||||
[Timesteps(model_channels),
|
||||
TimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora, weight_args=weight_args, operations=operations),]
|
||||
)
|
||||
|
||||
self.blocks = nn.ModuleDict()
|
||||
|
||||
for idx in range(num_blocks):
|
||||
self.blocks[f"block{idx}"] = GeneralDITTransformerBlock(
|
||||
x_dim=model_channels,
|
||||
context_dim=crossattn_emb_channels,
|
||||
num_heads=num_heads,
|
||||
block_config=block_config,
|
||||
mlp_ratio=mlp_ratio,
|
||||
x_format=self.block_x_format,
|
||||
use_adaln_lora=use_adaln_lora,
|
||||
adaln_lora_dim=adaln_lora_dim,
|
||||
weight_args=weight_args,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
if self.affline_emb_norm:
|
||||
logging.debug("Building affine embedding normalization layer")
|
||||
self.affline_norm = RMSNorm(model_channels, elementwise_affine=True, eps=1e-6)
|
||||
else:
|
||||
self.affline_norm = nn.Identity()
|
||||
|
||||
self.final_layer = FinalLayer(
|
||||
hidden_size=self.model_channels,
|
||||
spatial_patch_size=self.patch_spatial,
|
||||
temporal_patch_size=self.patch_temporal,
|
||||
out_channels=self.out_channels,
|
||||
use_adaln_lora=self.use_adaln_lora,
|
||||
adaln_lora_dim=self.adaln_lora_dim,
|
||||
weight_args=weight_args,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
def build_pos_embed(self, device=None, dtype=None):
|
||||
if self.pos_emb_cls == "rope3d":
|
||||
cls_type = VideoRopePosition3DEmb
|
||||
else:
|
||||
raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}")
|
||||
|
||||
logging.debug(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}")
|
||||
kwargs = dict(
|
||||
model_channels=self.model_channels,
|
||||
len_h=self.max_img_h // self.patch_spatial,
|
||||
len_w=self.max_img_w // self.patch_spatial,
|
||||
len_t=self.max_frames // self.patch_temporal,
|
||||
is_learnable=self.pos_emb_learnable,
|
||||
interpolation=self.pos_emb_interpolation,
|
||||
head_dim=self.model_channels // self.num_heads,
|
||||
h_extrapolation_ratio=self.rope_h_extrapolation_ratio,
|
||||
w_extrapolation_ratio=self.rope_w_extrapolation_ratio,
|
||||
t_extrapolation_ratio=self.rope_t_extrapolation_ratio,
|
||||
device=device,
|
||||
)
|
||||
self.pos_embedder = cls_type(
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if self.extra_per_block_abs_pos_emb:
|
||||
assert self.extra_per_block_abs_pos_emb_type in [
|
||||
"learnable",
|
||||
], f"Unknown extra_per_block_abs_pos_emb_type {self.extra_per_block_abs_pos_emb_type}"
|
||||
kwargs["h_extrapolation_ratio"] = self.extra_h_extrapolation_ratio
|
||||
kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio
|
||||
kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio
|
||||
kwargs["device"] = device
|
||||
kwargs["dtype"] = dtype
|
||||
self.extra_pos_embedder = LearnablePosEmbAxis(
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def prepare_embedded_sequence(
|
||||
self,
|
||||
x_B_C_T_H_W: torch.Tensor,
|
||||
fps: Optional[torch.Tensor] = None,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
latent_condition: Optional[torch.Tensor] = None,
|
||||
latent_condition_sigma: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""
|
||||
Prepares an embedded sequence tensor by applying positional embeddings and handling padding masks.
|
||||
|
||||
Args:
|
||||
x_B_C_T_H_W (torch.Tensor): video
|
||||
fps (Optional[torch.Tensor]): Frames per second tensor to be used for positional embedding when required.
|
||||
If None, a default value (`self.base_fps`) will be used.
|
||||
padding_mask (Optional[torch.Tensor]): current it is not used
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
- A tensor of shape (B, T, H, W, D) with the embedded sequence.
|
||||
- An optional positional embedding tensor, returned only if the positional embedding class
|
||||
(`self.pos_emb_cls`) includes 'rope'. Otherwise, None.
|
||||
|
||||
Notes:
|
||||
- If `self.concat_padding_mask` is True, a padding mask channel is concatenated to the input tensor.
|
||||
- The method of applying positional embeddings depends on the value of `self.pos_emb_cls`.
|
||||
- If 'rope' is in `self.pos_emb_cls` (case insensitive), the positional embeddings are generated using
|
||||
the `self.pos_embedder` with the shape [T, H, W].
|
||||
- If "fps_aware" is in `self.pos_emb_cls`, the positional embeddings are generated using the
|
||||
`self.pos_embedder` with the fps tensor.
|
||||
- Otherwise, the positional embeddings are generated without considering fps.
|
||||
"""
|
||||
if self.concat_padding_mask:
|
||||
if padding_mask is not None:
|
||||
padding_mask = transforms.functional.resize(
|
||||
padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
|
||||
)
|
||||
else:
|
||||
padding_mask = torch.zeros((x_B_C_T_H_W.shape[0], 1, x_B_C_T_H_W.shape[-2], x_B_C_T_H_W.shape[-1]), dtype=x_B_C_T_H_W.dtype, device=x_B_C_T_H_W.device)
|
||||
|
||||
x_B_C_T_H_W = torch.cat(
|
||||
[x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1
|
||||
)
|
||||
x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W)
|
||||
|
||||
if self.extra_per_block_abs_pos_emb:
|
||||
extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device)
|
||||
else:
|
||||
extra_pos_emb = None
|
||||
|
||||
if "rope" in self.pos_emb_cls.lower():
|
||||
return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device), extra_pos_emb
|
||||
|
||||
if "fps_aware" in self.pos_emb_cls:
|
||||
x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device) # [B, T, H, W, D]
|
||||
else:
|
||||
x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, device=x_B_C_T_H_W.device) # [B, T, H, W, D]
|
||||
|
||||
return x_B_T_H_W_D, None, extra_pos_emb
|
||||
|
||||
def decoder_head(
|
||||
self,
|
||||
x_B_T_H_W_D: torch.Tensor,
|
||||
emb_B_D: torch.Tensor,
|
||||
crossattn_emb: torch.Tensor,
|
||||
origin_shape: Tuple[int, int, int, int, int], # [B, C, T, H, W]
|
||||
crossattn_mask: Optional[torch.Tensor] = None,
|
||||
adaln_lora_B_3D: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
del crossattn_emb, crossattn_mask
|
||||
B, C, T_before_patchify, H_before_patchify, W_before_patchify = origin_shape
|
||||
x_BT_HW_D = rearrange(x_B_T_H_W_D, "B T H W D -> (B T) (H W) D")
|
||||
x_BT_HW_D = self.final_layer(x_BT_HW_D, emb_B_D, adaln_lora_B_3D=adaln_lora_B_3D)
|
||||
# This is to ensure x_BT_HW_D has the correct shape because
|
||||
# when we merge T, H, W into one dimension, x_BT_HW_D has shape (B * T * H * W, 1*1, D).
|
||||
x_BT_HW_D = x_BT_HW_D.view(
|
||||
B * T_before_patchify // self.patch_temporal,
|
||||
H_before_patchify // self.patch_spatial * W_before_patchify // self.patch_spatial,
|
||||
-1,
|
||||
)
|
||||
x_B_D_T_H_W = rearrange(
|
||||
x_BT_HW_D,
|
||||
"(B T) (H W) (p1 p2 t C) -> B C (T t) (H p1) (W p2)",
|
||||
p1=self.patch_spatial,
|
||||
p2=self.patch_spatial,
|
||||
H=H_before_patchify // self.patch_spatial,
|
||||
W=W_before_patchify // self.patch_spatial,
|
||||
t=self.patch_temporal,
|
||||
B=B,
|
||||
)
|
||||
return x_B_D_T_H_W
|
||||
|
||||
def forward_before_blocks(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
timesteps: torch.Tensor,
|
||||
crossattn_emb: torch.Tensor,
|
||||
crossattn_mask: Optional[torch.Tensor] = None,
|
||||
fps: Optional[torch.Tensor] = None,
|
||||
image_size: Optional[torch.Tensor] = None,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
scalar_feature: Optional[torch.Tensor] = None,
|
||||
data_type: Optional[DataType] = DataType.VIDEO,
|
||||
latent_condition: Optional[torch.Tensor] = None,
|
||||
latent_condition_sigma: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x: (B, C, T, H, W) tensor of spatial-temp inputs
|
||||
timesteps: (B, ) tensor of timesteps
|
||||
crossattn_emb: (B, N, D) tensor of cross-attention embeddings
|
||||
crossattn_mask: (B, N) tensor of cross-attention masks
|
||||
"""
|
||||
del kwargs
|
||||
assert isinstance(
|
||||
data_type, DataType
|
||||
), f"Expected DataType, got {type(data_type)}. We need discuss this flag later."
|
||||
original_shape = x.shape
|
||||
x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence(
|
||||
x,
|
||||
fps=fps,
|
||||
padding_mask=padding_mask,
|
||||
latent_condition=latent_condition,
|
||||
latent_condition_sigma=latent_condition_sigma,
|
||||
)
|
||||
# logging affline scale information
|
||||
affline_scale_log_info = {}
|
||||
|
||||
timesteps_B_D, adaln_lora_B_3D = self.t_embedder[1](self.t_embedder[0](timesteps.flatten()).to(x.dtype))
|
||||
affline_emb_B_D = timesteps_B_D
|
||||
affline_scale_log_info["timesteps_B_D"] = timesteps_B_D.detach()
|
||||
|
||||
if scalar_feature is not None:
|
||||
raise NotImplementedError("Scalar feature is not implemented yet.")
|
||||
|
||||
affline_scale_log_info["affline_emb_B_D"] = affline_emb_B_D.detach()
|
||||
affline_emb_B_D = self.affline_norm(affline_emb_B_D)
|
||||
|
||||
if self.use_cross_attn_mask:
|
||||
if crossattn_mask is not None and not torch.is_floating_point(crossattn_mask):
|
||||
crossattn_mask = (crossattn_mask - 1).to(x.dtype) * torch.finfo(x.dtype).max
|
||||
crossattn_mask = crossattn_mask[:, None, None, :] # .to(dtype=torch.bool) # [B, 1, 1, length]
|
||||
else:
|
||||
crossattn_mask = None
|
||||
|
||||
if self.blocks["block0"].x_format == "THWBD":
|
||||
x = rearrange(x_B_T_H_W_D, "B T H W D -> T H W B D")
|
||||
if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
|
||||
extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = rearrange(
|
||||
extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "B T H W D -> T H W B D"
|
||||
)
|
||||
crossattn_emb = rearrange(crossattn_emb, "B M D -> M B D")
|
||||
|
||||
if crossattn_mask:
|
||||
crossattn_mask = rearrange(crossattn_mask, "B M -> M B")
|
||||
|
||||
elif self.blocks["block0"].x_format == "BTHWD":
|
||||
x = x_B_T_H_W_D
|
||||
else:
|
||||
raise ValueError(f"Unknown x_format {self.blocks[0].x_format}")
|
||||
output = {
|
||||
"x": x,
|
||||
"affline_emb_B_D": affline_emb_B_D,
|
||||
"crossattn_emb": crossattn_emb,
|
||||
"crossattn_mask": crossattn_mask,
|
||||
"rope_emb_L_1_1_D": rope_emb_L_1_1_D,
|
||||
"adaln_lora_B_3D": adaln_lora_B_3D,
|
||||
"original_shape": original_shape,
|
||||
"extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
|
||||
}
|
||||
return output
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
timesteps: torch.Tensor,
|
||||
context: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
# crossattn_emb: torch.Tensor,
|
||||
# crossattn_mask: Optional[torch.Tensor] = None,
|
||||
fps: Optional[torch.Tensor] = None,
|
||||
image_size: Optional[torch.Tensor] = None,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
scalar_feature: Optional[torch.Tensor] = None,
|
||||
data_type: Optional[DataType] = DataType.VIDEO,
|
||||
latent_condition: Optional[torch.Tensor] = None,
|
||||
latent_condition_sigma: Optional[torch.Tensor] = None,
|
||||
condition_video_augment_sigma: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
x: (B, C, T, H, W) tensor of spatial-temp inputs
|
||||
timesteps: (B, ) tensor of timesteps
|
||||
crossattn_emb: (B, N, D) tensor of cross-attention embeddings
|
||||
crossattn_mask: (B, N) tensor of cross-attention masks
|
||||
condition_video_augment_sigma: (B,) used in lvg(long video generation), we add noise with this sigma to
|
||||
augment condition input, the lvg model will condition on the condition_video_augment_sigma value;
|
||||
we need forward_before_blocks pass to the forward_before_blocks function.
|
||||
"""
|
||||
|
||||
crossattn_emb = context
|
||||
crossattn_mask = attention_mask
|
||||
|
||||
inputs = self.forward_before_blocks(
|
||||
x=x,
|
||||
timesteps=timesteps,
|
||||
crossattn_emb=crossattn_emb,
|
||||
crossattn_mask=crossattn_mask,
|
||||
fps=fps,
|
||||
image_size=image_size,
|
||||
padding_mask=padding_mask,
|
||||
scalar_feature=scalar_feature,
|
||||
data_type=data_type,
|
||||
latent_condition=latent_condition,
|
||||
latent_condition_sigma=latent_condition_sigma,
|
||||
condition_video_augment_sigma=condition_video_augment_sigma,
|
||||
**kwargs,
|
||||
)
|
||||
x, affline_emb_B_D, crossattn_emb, crossattn_mask, rope_emb_L_1_1_D, adaln_lora_B_3D, original_shape = (
|
||||
inputs["x"],
|
||||
inputs["affline_emb_B_D"],
|
||||
inputs["crossattn_emb"],
|
||||
inputs["crossattn_mask"],
|
||||
inputs["rope_emb_L_1_1_D"],
|
||||
inputs["adaln_lora_B_3D"],
|
||||
inputs["original_shape"],
|
||||
)
|
||||
extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = inputs["extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D"].to(x.dtype)
|
||||
del inputs
|
||||
|
||||
if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
|
||||
assert (
|
||||
x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
|
||||
), f"{x.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape} {original_shape}"
|
||||
|
||||
for _, block in self.blocks.items():
|
||||
assert (
|
||||
self.blocks["block0"].x_format == block.x_format
|
||||
), f"First block has x_format {self.blocks[0].x_format}, got {block.x_format}"
|
||||
|
||||
if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
|
||||
x += extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D
|
||||
x = block(
|
||||
x,
|
||||
affline_emb_B_D,
|
||||
crossattn_emb,
|
||||
crossattn_mask,
|
||||
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
|
||||
adaln_lora_B_3D=adaln_lora_B_3D,
|
||||
)
|
||||
|
||||
x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D")
|
||||
|
||||
x_B_D_T_H_W = self.decoder_head(
|
||||
x_B_T_H_W_D=x_B_T_H_W_D,
|
||||
emb_B_D=affline_emb_B_D,
|
||||
crossattn_emb=None,
|
||||
origin_shape=original_shape,
|
||||
crossattn_mask=None,
|
||||
adaln_lora_B_3D=adaln_lora_B_3D,
|
||||
)
|
||||
|
||||
return x_B_D_T_H_W
|
||||
208
comfy/ldm/cosmos/position_embedding.py
Normal file
208
comfy/ldm/cosmos/position_embedding.py
Normal file
@@ -0,0 +1,208 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
from torch import nn
|
||||
import math
|
||||
|
||||
|
||||
def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 0) -> torch.Tensor:
|
||||
"""
|
||||
Normalizes the input tensor along specified dimensions such that the average square norm of elements is adjusted.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor to normalize.
|
||||
dim (list, optional): The dimensions over which to normalize. If None, normalizes over all dimensions except the first.
|
||||
eps (float, optional): A small constant to ensure numerical stability during division.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The normalized tensor.
|
||||
"""
|
||||
if dim is None:
|
||||
dim = list(range(1, x.ndim))
|
||||
norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
|
||||
norm = torch.add(eps, norm, alpha=math.sqrt(norm.numel() / x.numel()))
|
||||
return x / norm.to(x.dtype)
|
||||
|
||||
|
||||
class VideoPositionEmb(nn.Module):
|
||||
def forward(self, x_B_T_H_W_C: torch.Tensor, fps=Optional[torch.Tensor], device=None) -> torch.Tensor:
|
||||
"""
|
||||
It delegates the embedding generation to generate_embeddings function.
|
||||
"""
|
||||
B_T_H_W_C = x_B_T_H_W_C.shape
|
||||
embeddings = self.generate_embeddings(B_T_H_W_C, fps=fps, device=device)
|
||||
|
||||
return embeddings
|
||||
|
||||
def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor], device=None):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class VideoRopePosition3DEmb(VideoPositionEmb):
|
||||
def __init__(
|
||||
self,
|
||||
*, # enforce keyword arguments
|
||||
head_dim: int,
|
||||
len_h: int,
|
||||
len_w: int,
|
||||
len_t: int,
|
||||
base_fps: int = 24,
|
||||
h_extrapolation_ratio: float = 1.0,
|
||||
w_extrapolation_ratio: float = 1.0,
|
||||
t_extrapolation_ratio: float = 1.0,
|
||||
device=None,
|
||||
**kwargs, # used for compatibility with other positional embeddings; unused in this class
|
||||
):
|
||||
del kwargs
|
||||
super().__init__()
|
||||
self.register_buffer("seq", torch.arange(max(len_h, len_w, len_t), dtype=torch.float, device=device))
|
||||
self.base_fps = base_fps
|
||||
self.max_h = len_h
|
||||
self.max_w = len_w
|
||||
|
||||
dim = head_dim
|
||||
dim_h = dim // 6 * 2
|
||||
dim_w = dim_h
|
||||
dim_t = dim - 2 * dim_h
|
||||
assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}"
|
||||
self.register_buffer(
|
||||
"dim_spatial_range",
|
||||
torch.arange(0, dim_h, 2, device=device)[: (dim_h // 2)].float() / dim_h,
|
||||
persistent=False,
|
||||
)
|
||||
self.register_buffer(
|
||||
"dim_temporal_range",
|
||||
torch.arange(0, dim_t, 2, device=device)[: (dim_t // 2)].float() / dim_t,
|
||||
persistent=False,
|
||||
)
|
||||
|
||||
self.h_ntk_factor = h_extrapolation_ratio ** (dim_h / (dim_h - 2))
|
||||
self.w_ntk_factor = w_extrapolation_ratio ** (dim_w / (dim_w - 2))
|
||||
self.t_ntk_factor = t_extrapolation_ratio ** (dim_t / (dim_t - 2))
|
||||
|
||||
def generate_embeddings(
|
||||
self,
|
||||
B_T_H_W_C: torch.Size,
|
||||
fps: Optional[torch.Tensor] = None,
|
||||
h_ntk_factor: Optional[float] = None,
|
||||
w_ntk_factor: Optional[float] = None,
|
||||
t_ntk_factor: Optional[float] = None,
|
||||
device=None,
|
||||
):
|
||||
"""
|
||||
Generate embeddings for the given input size.
|
||||
|
||||
Args:
|
||||
B_T_H_W_C (torch.Size): Input tensor size (Batch, Time, Height, Width, Channels).
|
||||
fps (Optional[torch.Tensor], optional): Frames per second. Defaults to None.
|
||||
h_ntk_factor (Optional[float], optional): Height NTK factor. If None, uses self.h_ntk_factor.
|
||||
w_ntk_factor (Optional[float], optional): Width NTK factor. If None, uses self.w_ntk_factor.
|
||||
t_ntk_factor (Optional[float], optional): Time NTK factor. If None, uses self.t_ntk_factor.
|
||||
|
||||
Returns:
|
||||
Not specified in the original code snippet.
|
||||
"""
|
||||
h_ntk_factor = h_ntk_factor if h_ntk_factor is not None else self.h_ntk_factor
|
||||
w_ntk_factor = w_ntk_factor if w_ntk_factor is not None else self.w_ntk_factor
|
||||
t_ntk_factor = t_ntk_factor if t_ntk_factor is not None else self.t_ntk_factor
|
||||
|
||||
h_theta = 10000.0 * h_ntk_factor
|
||||
w_theta = 10000.0 * w_ntk_factor
|
||||
t_theta = 10000.0 * t_ntk_factor
|
||||
|
||||
h_spatial_freqs = 1.0 / (h_theta**self.dim_spatial_range.to(device=device))
|
||||
w_spatial_freqs = 1.0 / (w_theta**self.dim_spatial_range.to(device=device))
|
||||
temporal_freqs = 1.0 / (t_theta**self.dim_temporal_range.to(device=device))
|
||||
|
||||
B, T, H, W, _ = B_T_H_W_C
|
||||
uniform_fps = (fps is None) or isinstance(fps, (int, float)) or (fps.min() == fps.max())
|
||||
assert (
|
||||
uniform_fps or B == 1 or T == 1
|
||||
), "For video batch, batch size should be 1 for non-uniform fps. For image batch, T should be 1"
|
||||
assert (
|
||||
H <= self.max_h and W <= self.max_w
|
||||
), f"Input dimensions (H={H}, W={W}) exceed the maximum dimensions (max_h={self.max_h}, max_w={self.max_w})"
|
||||
half_emb_h = torch.outer(self.seq[:H].to(device=device), h_spatial_freqs)
|
||||
half_emb_w = torch.outer(self.seq[:W].to(device=device), w_spatial_freqs)
|
||||
|
||||
# apply sequence scaling in temporal dimension
|
||||
if fps is None: # image case
|
||||
half_emb_t = torch.outer(self.seq[:T].to(device=device), temporal_freqs)
|
||||
else:
|
||||
half_emb_t = torch.outer(self.seq[:T].to(device=device) / fps * self.base_fps, temporal_freqs)
|
||||
|
||||
half_emb_h = torch.stack([torch.cos(half_emb_h), -torch.sin(half_emb_h), torch.sin(half_emb_h), torch.cos(half_emb_h)], dim=-1)
|
||||
half_emb_w = torch.stack([torch.cos(half_emb_w), -torch.sin(half_emb_w), torch.sin(half_emb_w), torch.cos(half_emb_w)], dim=-1)
|
||||
half_emb_t = torch.stack([torch.cos(half_emb_t), -torch.sin(half_emb_t), torch.sin(half_emb_t), torch.cos(half_emb_t)], dim=-1)
|
||||
|
||||
em_T_H_W_D = torch.cat(
|
||||
[
|
||||
repeat(half_emb_t, "t d x -> t h w d x", h=H, w=W),
|
||||
repeat(half_emb_h, "h d x -> t h w d x", t=T, w=W),
|
||||
repeat(half_emb_w, "w d x -> t h w d x", t=T, h=H),
|
||||
]
|
||||
, dim=-2,
|
||||
)
|
||||
|
||||
return rearrange(em_T_H_W_D, "t h w d (i j) -> (t h w) d i j", i=2, j=2).float()
|
||||
|
||||
|
||||
class LearnablePosEmbAxis(VideoPositionEmb):
|
||||
def __init__(
|
||||
self,
|
||||
*, # enforce keyword arguments
|
||||
interpolation: str,
|
||||
model_channels: int,
|
||||
len_h: int,
|
||||
len_w: int,
|
||||
len_t: int,
|
||||
device=None,
|
||||
dtype=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
interpolation (str): we curretly only support "crop", ideally when we need extrapolation capacity, we should adjust frequency or other more advanced methods. they are not implemented yet.
|
||||
"""
|
||||
del kwargs # unused
|
||||
super().__init__()
|
||||
self.interpolation = interpolation
|
||||
assert self.interpolation in ["crop"], f"Unknown interpolation method {self.interpolation}"
|
||||
|
||||
self.pos_emb_h = nn.Parameter(torch.empty(len_h, model_channels, device=device, dtype=dtype))
|
||||
self.pos_emb_w = nn.Parameter(torch.empty(len_w, model_channels, device=device, dtype=dtype))
|
||||
self.pos_emb_t = nn.Parameter(torch.empty(len_t, model_channels, device=device, dtype=dtype))
|
||||
|
||||
|
||||
def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor], device=None) -> torch.Tensor:
|
||||
B, T, H, W, _ = B_T_H_W_C
|
||||
if self.interpolation == "crop":
|
||||
emb_h_H = self.pos_emb_h[:H].to(device=device)
|
||||
emb_w_W = self.pos_emb_w[:W].to(device=device)
|
||||
emb_t_T = self.pos_emb_t[:T].to(device=device)
|
||||
emb = (
|
||||
repeat(emb_t_T, "t d-> b t h w d", b=B, h=H, w=W)
|
||||
+ repeat(emb_h_H, "h d-> b t h w d", b=B, t=T, w=W)
|
||||
+ repeat(emb_w_W, "w d-> b t h w d", b=B, t=T, h=H)
|
||||
)
|
||||
assert list(emb.shape)[:4] == [B, T, H, W], f"bad shape: {list(emb.shape)[:4]} != {B, T, H, W}"
|
||||
else:
|
||||
raise ValueError(f"Unknown interpolation method {self.interpolation}")
|
||||
|
||||
return normalize(emb, dim=-1, eps=1e-6)
|
||||
124
comfy/ldm/cosmos/vae.py
Normal file
124
comfy/ldm/cosmos/vae.py
Normal file
@@ -0,0 +1,124 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""The causal continuous video tokenizer with VAE or AE formulation for 3D data.."""
|
||||
|
||||
import logging
|
||||
import torch
|
||||
from torch import nn
|
||||
from enum import Enum
|
||||
|
||||
from .cosmos_tokenizer.layers3d import (
|
||||
EncoderFactorized,
|
||||
DecoderFactorized,
|
||||
CausalConv3d,
|
||||
)
|
||||
|
||||
|
||||
class IdentityDistribution(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, parameters):
|
||||
return parameters, (torch.tensor([0.0]), torch.tensor([0.0]))
|
||||
|
||||
|
||||
class GaussianDistribution(torch.nn.Module):
|
||||
def __init__(self, min_logvar: float = -30.0, max_logvar: float = 20.0):
|
||||
super().__init__()
|
||||
self.min_logvar = min_logvar
|
||||
self.max_logvar = max_logvar
|
||||
|
||||
def sample(self, mean, logvar):
|
||||
std = torch.exp(0.5 * logvar)
|
||||
return mean + std * torch.randn_like(mean)
|
||||
|
||||
def forward(self, parameters):
|
||||
mean, logvar = torch.chunk(parameters, 2, dim=1)
|
||||
logvar = torch.clamp(logvar, self.min_logvar, self.max_logvar)
|
||||
return self.sample(mean, logvar), (mean, logvar)
|
||||
|
||||
|
||||
class ContinuousFormulation(Enum):
|
||||
VAE = GaussianDistribution
|
||||
AE = IdentityDistribution
|
||||
|
||||
|
||||
class CausalContinuousVideoTokenizer(nn.Module):
|
||||
def __init__(
|
||||
self, z_channels: int, z_factor: int, latent_channels: int, **kwargs
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.name = kwargs.get("name", "CausalContinuousVideoTokenizer")
|
||||
self.latent_channels = latent_channels
|
||||
self.sigma_data = 0.5
|
||||
|
||||
# encoder_name = kwargs.get("encoder", Encoder3DType.BASE.name)
|
||||
self.encoder = EncoderFactorized(
|
||||
z_channels=z_factor * z_channels, **kwargs
|
||||
)
|
||||
if kwargs.get("temporal_compression", 4) == 4:
|
||||
kwargs["channels_mult"] = [2, 4]
|
||||
# decoder_name = kwargs.get("decoder", Decoder3DType.BASE.name)
|
||||
self.decoder = DecoderFactorized(
|
||||
z_channels=z_channels, **kwargs
|
||||
)
|
||||
|
||||
self.quant_conv = CausalConv3d(
|
||||
z_factor * z_channels,
|
||||
z_factor * latent_channels,
|
||||
kernel_size=1,
|
||||
padding=0,
|
||||
)
|
||||
self.post_quant_conv = CausalConv3d(
|
||||
latent_channels, z_channels, kernel_size=1, padding=0
|
||||
)
|
||||
|
||||
# formulation_name = kwargs.get("formulation", ContinuousFormulation.AE.name)
|
||||
self.distribution = IdentityDistribution() # ContinuousFormulation[formulation_name].value()
|
||||
|
||||
num_parameters = sum(param.numel() for param in self.parameters())
|
||||
logging.debug(f"model={self.name}, num_parameters={num_parameters:,}")
|
||||
logging.debug(
|
||||
f"z_channels={z_channels}, latent_channels={self.latent_channels}."
|
||||
)
|
||||
|
||||
latent_temporal_chunk = 16
|
||||
self.latent_mean = nn.Parameter(torch.zeros([self.latent_channels * latent_temporal_chunk], dtype=torch.float32))
|
||||
self.latent_std = nn.Parameter(torch.ones([self.latent_channels * latent_temporal_chunk], dtype=torch.float32))
|
||||
|
||||
|
||||
def encode(self, x):
|
||||
h = self.encoder(x)
|
||||
moments = self.quant_conv(h)
|
||||
z, posteriors = self.distribution(moments)
|
||||
latent_ch = z.shape[1]
|
||||
latent_t = z.shape[2]
|
||||
dtype = z.dtype
|
||||
mean = self.latent_mean.view(latent_ch, -1)[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=dtype, device=z.device)
|
||||
std = self.latent_std.view(latent_ch, -1)[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=dtype, device=z.device)
|
||||
return ((z - mean) / std) * self.sigma_data
|
||||
|
||||
def decode(self, z):
|
||||
in_dtype = z.dtype
|
||||
latent_ch = z.shape[1]
|
||||
latent_t = z.shape[2]
|
||||
mean = self.latent_mean.view(latent_ch, -1)[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
|
||||
std = self.latent_std.view(latent_ch, -1)[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
|
||||
|
||||
z = z / self.sigma_data
|
||||
z = z * std + mean
|
||||
z = self.post_quant_conv(z)
|
||||
return self.decoder(z)
|
||||
|
||||
@@ -230,8 +230,7 @@ class SingleStreamBlock(nn.Module):
|
||||
|
||||
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None) -> Tensor:
|
||||
mod, _ = self.modulation(vec)
|
||||
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
|
||||
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||
qkv, mlp = torch.split(self.linear1((1 + mod.scale) * self.pre_norm(x) + mod.shift), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||
|
||||
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
q, k = self.norm(q, k, v)
|
||||
|
||||
@@ -5,8 +5,15 @@ from torch import Tensor
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
import comfy.model_management
|
||||
|
||||
|
||||
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
|
||||
q, k = apply_rope(q, k, pe)
|
||||
q_shape = q.shape
|
||||
k_shape = k.shape
|
||||
|
||||
q = q.float().reshape(*q.shape[:-1], -1, 1, 2)
|
||||
k = k.float().reshape(*k.shape[:-1], -1, 1, 2)
|
||||
q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
|
||||
k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
|
||||
|
||||
heads = q.shape[1]
|
||||
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
|
||||
|
||||
@@ -168,7 +168,7 @@ class Flux(nn.Module):
|
||||
out = blocks_replace[("single_block", i)]({"img": img,
|
||||
"vec": vec,
|
||||
"pe": pe,
|
||||
"attn_mask": attn_mask},
|
||||
"attn_mask": attn_mask},
|
||||
{"original_block": block_wrap})
|
||||
img = out["img"]
|
||||
else:
|
||||
|
||||
@@ -159,7 +159,7 @@ class CrossAttention(nn.Module):
|
||||
|
||||
q = q.transpose(-2, -3).contiguous() # q -> B, L1, H, C - B, H, L1, C
|
||||
k = k.transpose(-2, -3).contiguous() # k -> B, L2, H, C - B, H, C, L2
|
||||
v = v.transpose(-2, -3).contiguous()
|
||||
v = v.transpose(-2, -3).contiguous()
|
||||
|
||||
context = optimized_attention(q, k, v, self.num_heads, skip_reshape=True, attn_precision=self.attn_precision)
|
||||
|
||||
|
||||
@@ -456,9 +456,8 @@ class LTXVModel(torch.nn.Module):
|
||||
x = self.patchify_proj(x)
|
||||
timestep = timestep * 1000.0
|
||||
|
||||
attention_mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1]))
|
||||
attention_mask = attention_mask.masked_fill(attention_mask.to(torch.bool), float("-inf")) # not sure about this
|
||||
# attention_mask = (context != 0).any(dim=2).to(dtype=x.dtype)
|
||||
if attention_mask is not None and not torch.is_floating_point(attention_mask):
|
||||
attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max
|
||||
|
||||
pe = precompute_freqs_cis(indices_grid, dim=self.inner_dim, out_dtype=x.dtype)
|
||||
|
||||
|
||||
@@ -89,7 +89,7 @@ class FeedForward(nn.Module):
|
||||
def Normalize(in_channels, dtype=None, device=None):
|
||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
|
||||
|
||||
def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
|
||||
def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||
attn_precision = get_attn_precision(attn_precision)
|
||||
|
||||
if skip_reshape:
|
||||
@@ -142,16 +142,23 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
||||
sim = sim.softmax(dim=-1)
|
||||
|
||||
out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
|
||||
out = (
|
||||
out.unsqueeze(0)
|
||||
.reshape(b, heads, -1, dim_head)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(b, -1, heads * dim_head)
|
||||
)
|
||||
|
||||
if skip_output_reshape:
|
||||
out = (
|
||||
out.unsqueeze(0)
|
||||
.reshape(b, heads, -1, dim_head)
|
||||
)
|
||||
else:
|
||||
out = (
|
||||
out.unsqueeze(0)
|
||||
.reshape(b, heads, -1, dim_head)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(b, -1, heads * dim_head)
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False):
|
||||
def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||
attn_precision = get_attn_precision(attn_precision)
|
||||
|
||||
if skip_reshape:
|
||||
@@ -215,11 +222,13 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.to(dtype)
|
||||
|
||||
hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
|
||||
if skip_output_reshape:
|
||||
hidden_states = hidden_states.unflatten(0, (-1, heads))
|
||||
else:
|
||||
hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
|
||||
return hidden_states
|
||||
|
||||
def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
|
||||
def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||
attn_precision = get_attn_precision(attn_precision)
|
||||
|
||||
if skip_reshape:
|
||||
@@ -326,12 +335,18 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
||||
|
||||
del q, k, v
|
||||
|
||||
r1 = (
|
||||
r1.unsqueeze(0)
|
||||
.reshape(b, heads, -1, dim_head)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(b, -1, heads * dim_head)
|
||||
)
|
||||
if skip_output_reshape:
|
||||
r1 = (
|
||||
r1.unsqueeze(0)
|
||||
.reshape(b, heads, -1, dim_head)
|
||||
)
|
||||
else:
|
||||
r1 = (
|
||||
r1.unsqueeze(0)
|
||||
.reshape(b, heads, -1, dim_head)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(b, -1, heads * dim_head)
|
||||
)
|
||||
return r1
|
||||
|
||||
BROKEN_XFORMERS = False
|
||||
@@ -342,7 +357,7 @@ try:
|
||||
except:
|
||||
pass
|
||||
|
||||
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
|
||||
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||
b = q.shape[0]
|
||||
dim_head = q.shape[-1]
|
||||
# check to make sure xformers isn't broken
|
||||
@@ -395,9 +410,12 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
|
||||
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
|
||||
|
||||
out = (
|
||||
out.reshape(b, -1, heads * dim_head)
|
||||
)
|
||||
if skip_output_reshape:
|
||||
out = out.permute(0, 2, 1, 3)
|
||||
else:
|
||||
out = (
|
||||
out.reshape(b, -1, heads * dim_head)
|
||||
)
|
||||
|
||||
return out
|
||||
|
||||
@@ -408,7 +426,7 @@ else:
|
||||
SDP_BATCH_LIMIT = 2**31
|
||||
|
||||
|
||||
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
|
||||
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||
if skip_reshape:
|
||||
b, _, _, dim_head = q.shape
|
||||
else:
|
||||
@@ -429,9 +447,10 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
||||
|
||||
if SDP_BATCH_LIMIT >= b:
|
||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||
out = (
|
||||
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||
)
|
||||
if not skip_output_reshape:
|
||||
out = (
|
||||
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||
)
|
||||
else:
|
||||
out = torch.empty((b, q.shape[2], heads * dim_head), dtype=q.dtype, layout=q.layout, device=q.device)
|
||||
for i in range(0, b, SDP_BATCH_LIMIT):
|
||||
@@ -450,7 +469,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
||||
return out
|
||||
|
||||
|
||||
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
|
||||
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||
if skip_reshape:
|
||||
b, _, _, dim_head = q.shape
|
||||
tensor_layout="HND"
|
||||
@@ -473,11 +492,15 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
|
||||
|
||||
out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
|
||||
if tensor_layout == "HND":
|
||||
out = (
|
||||
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||
)
|
||||
if not skip_output_reshape:
|
||||
out = (
|
||||
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||
)
|
||||
else:
|
||||
out = out.reshape(b, -1, heads * dim_head)
|
||||
if skip_output_reshape:
|
||||
out = out.transpose(1, 2)
|
||||
else:
|
||||
out = out.reshape(b, -1, heads * dim_head)
|
||||
return out
|
||||
|
||||
|
||||
|
||||
@@ -293,6 +293,17 @@ def pytorch_attention(q, k, v):
|
||||
return out
|
||||
|
||||
|
||||
def vae_attention():
|
||||
if model_management.xformers_enabled_vae():
|
||||
logging.info("Using xformers attention in VAE")
|
||||
return xformers_attention
|
||||
elif model_management.pytorch_attention_enabled():
|
||||
logging.info("Using pytorch attention in VAE")
|
||||
return pytorch_attention
|
||||
else:
|
||||
logging.info("Using split attention in VAE")
|
||||
return normal_attention
|
||||
|
||||
class AttnBlock(nn.Module):
|
||||
def __init__(self, in_channels, conv_op=ops.Conv2d):
|
||||
super().__init__()
|
||||
@@ -320,15 +331,7 @@ class AttnBlock(nn.Module):
|
||||
stride=1,
|
||||
padding=0)
|
||||
|
||||
if model_management.xformers_enabled_vae():
|
||||
logging.info("Using xformers attention in VAE")
|
||||
self.optimized_attention = xformers_attention
|
||||
elif model_management.pytorch_attention_enabled():
|
||||
logging.info("Using pytorch attention in VAE")
|
||||
self.optimized_attention = pytorch_attention
|
||||
else:
|
||||
logging.info("Using split attention in VAE")
|
||||
self.optimized_attention = normal_attention
|
||||
self.optimized_attention = vae_attention()
|
||||
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
|
||||
@@ -17,10 +17,10 @@ import math
|
||||
import logging
|
||||
|
||||
try:
|
||||
from typing import Optional, NamedTuple, List, Protocol
|
||||
from typing import Optional, NamedTuple, List, Protocol
|
||||
except ImportError:
|
||||
from typing import Optional, NamedTuple, List
|
||||
from typing_extensions import Protocol
|
||||
from typing import Optional, NamedTuple, List
|
||||
from typing_extensions import Protocol
|
||||
|
||||
from typing import List
|
||||
|
||||
@@ -261,7 +261,7 @@ def efficient_dot_product_attention(
|
||||
value=value,
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
|
||||
# TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
|
||||
# and pass slices to be mutated, instead of torch.cat()ing the returned slices
|
||||
res = torch.cat([
|
||||
|
||||
@@ -223,7 +223,7 @@ class PixArtMS(nn.Module):
|
||||
if self.micro_conditioning:
|
||||
if c_size is None:
|
||||
c_size = torch.tensor([H*8, W*8], dtype=x.dtype, device=x.device).repeat(B, 1)
|
||||
|
||||
|
||||
if c_ar is None:
|
||||
c_ar = torch.tensor([H/W], dtype=x.dtype, device=x.device).repeat(B, 1)
|
||||
|
||||
|
||||
@@ -194,4 +194,4 @@ class AdamWwithEMAandWings(optim.Optimizer):
|
||||
for param, ema_param in zip(params_with_grad, ema_params_with_grad):
|
||||
ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)
|
||||
|
||||
return loss
|
||||
return loss
|
||||
|
||||
@@ -33,6 +33,7 @@ import comfy.ldm.audio.embedders
|
||||
import comfy.ldm.flux.model
|
||||
import comfy.ldm.lightricks.model
|
||||
import comfy.ldm.hunyuan_video.model
|
||||
import comfy.ldm.cosmos.model
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
@@ -188,9 +189,10 @@ class BaseModel(torch.nn.Module):
|
||||
|
||||
if denoise_mask is not None:
|
||||
if len(denoise_mask.shape) == len(noise.shape):
|
||||
denoise_mask = denoise_mask[:,:1]
|
||||
denoise_mask = denoise_mask[:, :1]
|
||||
|
||||
denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1]))
|
||||
num_dim = noise.ndim - 2
|
||||
denoise_mask = denoise_mask.reshape((-1, 1) + tuple(denoise_mask.shape[-num_dim:]))
|
||||
if denoise_mask.shape[-2:] != noise.shape[-2:]:
|
||||
denoise_mask = utils.common_upscale(denoise_mask, noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
denoise_mask = utils.resize_to_batch_size(denoise_mask.round(), noise.shape[0])
|
||||
@@ -200,12 +202,16 @@ class BaseModel(torch.nn.Module):
|
||||
if ck == "mask":
|
||||
cond_concat.append(denoise_mask.to(device))
|
||||
elif ck == "masked_image":
|
||||
cond_concat.append(concat_latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space
|
||||
cond_concat.append(concat_latent_image.to(device)) # NOTE: the latent_image should be masked by the mask in pixel space
|
||||
elif ck == "mask_inverted":
|
||||
cond_concat.append(1.0 - denoise_mask.to(device))
|
||||
else:
|
||||
if ck == "mask":
|
||||
cond_concat.append(torch.ones_like(noise)[:,:1])
|
||||
cond_concat.append(torch.ones_like(noise)[:, :1])
|
||||
elif ck == "masked_image":
|
||||
cond_concat.append(self.blank_inpaint_image_like(noise))
|
||||
elif ck == "mask_inverted":
|
||||
cond_concat.append(torch.zeros_like(noise)[:, :1])
|
||||
data = torch.cat(cond_concat, dim=1)
|
||||
return data
|
||||
return None
|
||||
@@ -293,6 +299,9 @@ class BaseModel(torch.nn.Module):
|
||||
return blank_image
|
||||
self.blank_inpaint_image_like = blank_inpaint_image_like
|
||||
|
||||
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
|
||||
return self.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(noise.shape) - 1)), noise, latent_image)
|
||||
|
||||
def memory_required(self, input_shape):
|
||||
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
|
||||
dtype = self.get_dtype()
|
||||
@@ -787,7 +796,7 @@ class Flux(BaseModel):
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
# upscale the attention mask, since now we
|
||||
# upscale the attention mask, since now we
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
if attention_mask is not None:
|
||||
shape = kwargs["noise"].shape
|
||||
@@ -856,3 +865,30 @@ class HunyuanVideo(BaseModel):
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 6.0)]))
|
||||
return out
|
||||
|
||||
class CosmosVideo(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.EDM, image_to_video=False, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.cosmos.model.GeneralDIT)
|
||||
self.image_to_video = image_to_video
|
||||
if self.image_to_video:
|
||||
self.concat_keys = ("mask_inverted",)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
if attention_mask is not None:
|
||||
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
|
||||
out['fps'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", None))
|
||||
return out
|
||||
|
||||
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
|
||||
sigma = sigma.reshape([sigma.shape[0]] + [1] * (len(noise.shape) - 1))
|
||||
sigma_noise_augmentation = 0 #TODO
|
||||
if sigma_noise_augmentation != 0:
|
||||
latent_image = latent_image + noise
|
||||
latent_image = self.model_sampling.calculate_input(torch.tensor([sigma_noise_augmentation], device=latent_image.device, dtype=latent_image.dtype), latent_image)
|
||||
return latent_image * ((sigma ** 2 + self.model_sampling.sigma_data ** 2) ** 0.5)
|
||||
|
||||
@@ -229,7 +229,7 @@ def detect_unet_config(state_dict, key_prefix):
|
||||
if pe_key in state_dict_keys:
|
||||
dit_config["input_size"] = int(math.sqrt(state_dict[pe_key].shape[1])) * patch_size
|
||||
dit_config["pe_interpolation"] = dit_config["input_size"] // (512//8) # guess
|
||||
|
||||
|
||||
ar_key = "{}ar_embedder.mlp.0.weight".format(key_prefix)
|
||||
if ar_key in state_dict_keys:
|
||||
dit_config["image_model"] = "pixart_alpha"
|
||||
@@ -239,6 +239,51 @@ def detect_unet_config(state_dict, key_prefix):
|
||||
dit_config["micro_condition"] = False
|
||||
return dit_config
|
||||
|
||||
if '{}blocks.block0.blocks.0.block.attn.to_q.0.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "cosmos"
|
||||
dit_config["max_img_h"] = 240
|
||||
dit_config["max_img_w"] = 240
|
||||
dit_config["max_frames"] = 128
|
||||
concat_padding_mask = True
|
||||
dit_config["in_channels"] = (state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[1] // 4) - int(concat_padding_mask)
|
||||
dit_config["out_channels"] = 16
|
||||
dit_config["patch_spatial"] = 2
|
||||
dit_config["patch_temporal"] = 1
|
||||
dit_config["model_channels"] = state_dict['{}blocks.block0.blocks.0.block.attn.to_q.0.weight'.format(key_prefix)].shape[0]
|
||||
dit_config["block_config"] = "FA-CA-MLP"
|
||||
dit_config["concat_padding_mask"] = concat_padding_mask
|
||||
dit_config["pos_emb_cls"] = "rope3d"
|
||||
dit_config["pos_emb_learnable"] = False
|
||||
dit_config["pos_emb_interpolation"] = "crop"
|
||||
dit_config["block_x_format"] = "THWBD"
|
||||
dit_config["affline_emb_norm"] = True
|
||||
dit_config["use_adaln_lora"] = True
|
||||
dit_config["adaln_lora_dim"] = 256
|
||||
|
||||
if dit_config["model_channels"] == 4096:
|
||||
# 7B
|
||||
dit_config["num_blocks"] = 28
|
||||
dit_config["num_heads"] = 32
|
||||
dit_config["extra_per_block_abs_pos_emb"] = True
|
||||
dit_config["rope_h_extrapolation_ratio"] = 1.0
|
||||
dit_config["rope_w_extrapolation_ratio"] = 1.0
|
||||
dit_config["rope_t_extrapolation_ratio"] = 2.0
|
||||
dit_config["extra_per_block_abs_pos_emb_type"] = "learnable"
|
||||
else: # 5120
|
||||
# 14B
|
||||
dit_config["num_blocks"] = 36
|
||||
dit_config["num_heads"] = 40
|
||||
dit_config["extra_per_block_abs_pos_emb"] = True
|
||||
dit_config["rope_h_extrapolation_ratio"] = 2.0
|
||||
dit_config["rope_w_extrapolation_ratio"] = 2.0
|
||||
dit_config["rope_t_extrapolation_ratio"] = 2.0
|
||||
dit_config["extra_h_extrapolation_ratio"] = 2.0
|
||||
dit_config["extra_w_extrapolation_ratio"] = 2.0
|
||||
dit_config["extra_t_extrapolation_ratio"] = 2.0
|
||||
dit_config["extra_per_block_abs_pos_emb_type"] = "learnable"
|
||||
return dit_config
|
||||
|
||||
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
|
||||
return None
|
||||
|
||||
@@ -393,6 +438,7 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
|
||||
def unet_prefix_from_state_dict(state_dict):
|
||||
candidates = ["model.diffusion_model.", #ldm/sgm models
|
||||
"model.model.", #audio models
|
||||
"net.", #cosmos
|
||||
]
|
||||
counts = {k: 0 for k in candidates}
|
||||
for k in state_dict:
|
||||
@@ -571,12 +617,12 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
|
||||
'transformer_depth': [0, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': False,
|
||||
'context_dim': 768, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 1, 1, 1, 1],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
|
||||
SD15_diffusers_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': None,
|
||||
'dtype': dtype, 'in_channels': 9, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0],
|
||||
'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, 'num_heads': 8,
|
||||
'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
|
||||
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p, SD15_diffusers_inpaint]
|
||||
|
||||
@@ -86,6 +86,13 @@ try:
|
||||
except:
|
||||
pass
|
||||
|
||||
try:
|
||||
import torch_npu # noqa: F401
|
||||
_ = torch.npu.device_count()
|
||||
npu_available = torch.npu.is_available()
|
||||
except:
|
||||
npu_available = False
|
||||
|
||||
if args.cpu:
|
||||
cpu_state = CPUState.CPU
|
||||
|
||||
@@ -97,6 +104,12 @@ def is_intel_xpu():
|
||||
return True
|
||||
return False
|
||||
|
||||
def is_ascend_npu():
|
||||
global npu_available
|
||||
if npu_available:
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_torch_device():
|
||||
global directml_enabled
|
||||
global cpu_state
|
||||
@@ -110,6 +123,8 @@ def get_torch_device():
|
||||
else:
|
||||
if is_intel_xpu():
|
||||
return torch.device("xpu", torch.xpu.current_device())
|
||||
elif is_ascend_npu():
|
||||
return torch.device("npu", torch.npu.current_device())
|
||||
else:
|
||||
return torch.device(torch.cuda.current_device())
|
||||
|
||||
@@ -130,6 +145,12 @@ def get_total_memory(dev=None, torch_total_too=False):
|
||||
mem_reserved = stats['reserved_bytes.all.current']
|
||||
mem_total_torch = mem_reserved
|
||||
mem_total = torch.xpu.get_device_properties(dev).total_memory
|
||||
elif is_ascend_npu():
|
||||
stats = torch.npu.memory_stats(dev)
|
||||
mem_reserved = stats['reserved_bytes.all.current']
|
||||
_, mem_total_npu = torch.npu.mem_get_info(dev)
|
||||
mem_total_torch = mem_reserved
|
||||
mem_total = mem_total_npu
|
||||
else:
|
||||
stats = torch.cuda.memory_stats(dev)
|
||||
mem_reserved = stats['reserved_bytes.all.current']
|
||||
@@ -209,7 +230,7 @@ try:
|
||||
if int(torch_version[0]) >= 2:
|
||||
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
if is_intel_xpu():
|
||||
if is_intel_xpu() or is_ascend_npu():
|
||||
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
except:
|
||||
@@ -274,6 +295,8 @@ def get_torch_device_name(device):
|
||||
return "{}".format(device.type)
|
||||
elif is_intel_xpu():
|
||||
return "{} {}".format(device, torch.xpu.get_device_name(device))
|
||||
elif is_ascend_npu():
|
||||
return "{} {}".format(device, torch.npu.get_device_name(device))
|
||||
else:
|
||||
return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
|
||||
|
||||
@@ -860,6 +883,8 @@ def xformers_enabled():
|
||||
return False
|
||||
if is_intel_xpu():
|
||||
return False
|
||||
if is_ascend_npu():
|
||||
return False
|
||||
if directml_enabled:
|
||||
return False
|
||||
return XFORMERS_IS_AVAILABLE
|
||||
@@ -884,6 +909,8 @@ def pytorch_attention_flash_attention():
|
||||
return True
|
||||
if is_intel_xpu():
|
||||
return True
|
||||
if is_ascend_npu():
|
||||
return True
|
||||
return False
|
||||
|
||||
def mac_version():
|
||||
@@ -923,6 +950,13 @@ def get_free_memory(dev=None, torch_free_too=False):
|
||||
mem_free_torch = mem_reserved - mem_active
|
||||
mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved
|
||||
mem_free_total = mem_free_xpu + mem_free_torch
|
||||
elif is_ascend_npu():
|
||||
stats = torch.npu.memory_stats(dev)
|
||||
mem_active = stats['active_bytes.all.current']
|
||||
mem_reserved = stats['reserved_bytes.all.current']
|
||||
mem_free_npu, _ = torch.npu.mem_get_info(dev)
|
||||
mem_free_torch = mem_reserved - mem_active
|
||||
mem_free_total = mem_free_npu + mem_free_torch
|
||||
else:
|
||||
stats = torch.cuda.memory_stats(dev)
|
||||
mem_active = stats['active_bytes.all.current']
|
||||
@@ -984,6 +1018,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
if is_intel_xpu():
|
||||
return True
|
||||
|
||||
if is_ascend_npu():
|
||||
return True
|
||||
|
||||
if torch.version.hip:
|
||||
return True
|
||||
|
||||
@@ -1081,19 +1118,16 @@ def soft_empty_cache(force=False):
|
||||
torch.mps.empty_cache()
|
||||
elif is_intel_xpu():
|
||||
torch.xpu.empty_cache()
|
||||
elif is_ascend_npu():
|
||||
torch.npu.empty_cache()
|
||||
elif torch.cuda.is_available():
|
||||
if force or is_nvidia(): #This seems to make things worse on ROCm so I only do it for cuda
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
def unload_all_models():
|
||||
free_memory(1e30, get_torch_device())
|
||||
|
||||
|
||||
def resolve_lowvram_weight(weight, model, key): #TODO: remove
|
||||
logging.warning("The comfy.model_management.resolve_lowvram_weight function will be removed soon, please stop using it.")
|
||||
return weight
|
||||
|
||||
#TODO: might be cleaner to put this somewhere else
|
||||
import threading
|
||||
|
||||
|
||||
@@ -83,7 +83,7 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
|
||||
|
||||
def create_model_options_clone(orig_model_options: dict):
|
||||
return comfy.patcher_extension.copy_nested_dicts(orig_model_options)
|
||||
|
||||
|
||||
def create_hook_patches_clone(orig_hook_patches):
|
||||
new_hook_patches = {}
|
||||
for hook_ref in orig_hook_patches:
|
||||
@@ -141,7 +141,7 @@ class AutoPatcherEjector:
|
||||
self.was_injected = False
|
||||
self.prev_skip_injection = False
|
||||
self.skip_and_inject_on_exit_only = skip_and_inject_on_exit_only
|
||||
|
||||
|
||||
def __enter__(self):
|
||||
self.was_injected = False
|
||||
self.prev_skip_injection = self.model.skip_injection
|
||||
@@ -164,7 +164,7 @@ class MemoryCounter:
|
||||
self.value = initial
|
||||
self.minimum = minimum
|
||||
# TODO: add a safe limit besides 0
|
||||
|
||||
|
||||
def use(self, weight: torch.Tensor):
|
||||
weight_size = weight.nelement() * weight.element_size()
|
||||
if self.is_useable(weight_size):
|
||||
@@ -210,7 +210,7 @@ class ModelPatcher:
|
||||
self.injections: dict[str, list[PatcherInjection]] = {}
|
||||
|
||||
self.hook_patches: dict[comfy.hooks._HookRef] = {}
|
||||
self.hook_patches_backup: dict[comfy.hooks._HookRef] = {}
|
||||
self.hook_patches_backup: dict[comfy.hooks._HookRef] = None
|
||||
self.hook_backup: dict[str, tuple[torch.Tensor, torch.device]] = {}
|
||||
self.cached_hook_patches: dict[comfy.hooks.HookGroup, dict[str, torch.Tensor]] = {}
|
||||
self.current_hooks: Optional[comfy.hooks.HookGroup] = None
|
||||
@@ -282,7 +282,7 @@ class ModelPatcher:
|
||||
n.injections[k] = i.copy()
|
||||
# hooks
|
||||
n.hook_patches = create_hook_patches_clone(self.hook_patches)
|
||||
n.hook_patches_backup = create_hook_patches_clone(self.hook_patches_backup)
|
||||
n.hook_patches_backup = create_hook_patches_clone(self.hook_patches_backup) if self.hook_patches_backup else self.hook_patches_backup
|
||||
for group in self.cached_hook_patches:
|
||||
n.cached_hook_patches[group] = {}
|
||||
for k in self.cached_hook_patches[group]:
|
||||
@@ -402,7 +402,20 @@ class ModelPatcher:
|
||||
def add_object_patch(self, name, obj):
|
||||
self.object_patches[name] = obj
|
||||
|
||||
def get_model_object(self, name):
|
||||
def get_model_object(self, name: str) -> torch.nn.Module:
|
||||
"""Retrieves a nested attribute from an object using dot notation considering
|
||||
object patches.
|
||||
|
||||
Args:
|
||||
name (str): The attribute path using dot notation (e.g. "model.layer.weight")
|
||||
|
||||
Returns:
|
||||
The value of the requested attribute
|
||||
|
||||
Example:
|
||||
patcher = ModelPatcher()
|
||||
weight = patcher.get_model_object("layer1.conv.weight")
|
||||
"""
|
||||
if name in self.object_patches:
|
||||
return self.object_patches[name]
|
||||
else:
|
||||
@@ -711,7 +724,7 @@ class ModelPatcher:
|
||||
else:
|
||||
comfy.utils.set_attr_param(self.model, key, bk.weight)
|
||||
self.backup.pop(key)
|
||||
|
||||
|
||||
weight_key = "{}.weight".format(n)
|
||||
bias_key = "{}.bias".format(n)
|
||||
if move_weight:
|
||||
@@ -789,7 +802,7 @@ class ModelPatcher:
|
||||
def add_callback_with_key(self, call_type: str, key: str, callback: Callable):
|
||||
c = self.callbacks.setdefault(call_type, {}).setdefault(key, [])
|
||||
c.append(callback)
|
||||
|
||||
|
||||
def remove_callbacks_with_key(self, call_type: str, key: str):
|
||||
c = self.callbacks.get(call_type, {})
|
||||
if key in c:
|
||||
@@ -797,7 +810,7 @@ class ModelPatcher:
|
||||
|
||||
def get_callbacks(self, call_type: str, key: str):
|
||||
return self.callbacks.get(call_type, {}).get(key, [])
|
||||
|
||||
|
||||
def get_all_callbacks(self, call_type: str):
|
||||
c_list = []
|
||||
for c in self.callbacks.get(call_type, {}).values():
|
||||
@@ -810,7 +823,7 @@ class ModelPatcher:
|
||||
def add_wrapper_with_key(self, wrapper_type: str, key: str, wrapper: Callable):
|
||||
w = self.wrappers.setdefault(wrapper_type, {}).setdefault(key, [])
|
||||
w.append(wrapper)
|
||||
|
||||
|
||||
def remove_wrappers_with_key(self, wrapper_type: str, key: str):
|
||||
w = self.wrappers.get(wrapper_type, {})
|
||||
if key in w:
|
||||
@@ -831,7 +844,7 @@ class ModelPatcher:
|
||||
def remove_attachments(self, key: str):
|
||||
if key in self.attachments:
|
||||
self.attachments.pop(key)
|
||||
|
||||
|
||||
def get_attachment(self, key: str):
|
||||
return self.attachments.get(key, None)
|
||||
|
||||
@@ -842,6 +855,9 @@ class ModelPatcher:
|
||||
if key in self.injections:
|
||||
self.injections.pop(key)
|
||||
|
||||
def get_injections(self, key: str):
|
||||
return self.injections.get(key, None)
|
||||
|
||||
def set_additional_models(self, key: str, models: list['ModelPatcher']):
|
||||
self.additional_models[key] = models
|
||||
|
||||
@@ -851,7 +867,7 @@ class ModelPatcher:
|
||||
|
||||
def get_additional_models_with_key(self, key: str):
|
||||
return self.additional_models.get(key, [])
|
||||
|
||||
|
||||
def get_additional_models(self):
|
||||
all_models = []
|
||||
for models in self.additional_models.values():
|
||||
@@ -906,24 +922,25 @@ class ModelPatcher:
|
||||
self.model.current_patcher = self
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN):
|
||||
callback(self)
|
||||
|
||||
|
||||
def prepare_state(self, timestep):
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE):
|
||||
callback(self, timestep)
|
||||
|
||||
def restore_hook_patches(self):
|
||||
if len(self.hook_patches_backup) > 0:
|
||||
if self.hook_patches_backup is not None:
|
||||
self.hook_patches = self.hook_patches_backup
|
||||
self.hook_patches_backup = {}
|
||||
self.hook_patches_backup = None
|
||||
|
||||
def set_hook_mode(self, hook_mode: comfy.hooks.EnumHookMode):
|
||||
self.hook_mode = hook_mode
|
||||
|
||||
def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup):
|
||||
|
||||
def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup, model_options: dict[str]):
|
||||
curr_t = t[0]
|
||||
reset_current_hooks = False
|
||||
transformer_options = model_options.get("transformer_options", {})
|
||||
for hook in hook_group.hooks:
|
||||
changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t)
|
||||
changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t, transformer_options=transformer_options)
|
||||
# if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref;
|
||||
# this will cause the weights to be recalculated when sampling
|
||||
if changed:
|
||||
@@ -939,25 +956,26 @@ class ModelPatcher:
|
||||
if reset_current_hooks:
|
||||
self.patch_hooks(None)
|
||||
|
||||
def register_all_hook_patches(self, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]], target: comfy.hooks.EnumWeightTarget, model_options: dict=None):
|
||||
def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None,
|
||||
registered: comfy.hooks.HookGroup = None):
|
||||
self.restore_hook_patches()
|
||||
registered_hooks: list[comfy.hooks.Hook] = []
|
||||
# handle WrapperHooks, if model_options provided
|
||||
if model_options is not None:
|
||||
for hook in hooks_dict.get(comfy.hooks.EnumHookType.Wrappers, {}):
|
||||
hook.add_hook_patches(self, model_options, target, registered_hooks)
|
||||
if registered is None:
|
||||
registered = comfy.hooks.HookGroup()
|
||||
# handle WeightHooks
|
||||
weight_hooks_to_register: list[comfy.hooks.WeightHook] = []
|
||||
for hook in hooks_dict.get(comfy.hooks.EnumHookType.Weight, {}):
|
||||
for hook in hooks.get_type(comfy.hooks.EnumHookType.Weight):
|
||||
if hook.hook_ref not in self.hook_patches:
|
||||
weight_hooks_to_register.append(hook)
|
||||
else:
|
||||
registered.add(hook)
|
||||
if len(weight_hooks_to_register) > 0:
|
||||
# clone hook_patches to become backup so that any non-dynamic hooks will return to their original state
|
||||
self.hook_patches_backup = create_hook_patches_clone(self.hook_patches)
|
||||
for hook in weight_hooks_to_register:
|
||||
hook.add_hook_patches(self, model_options, target, registered_hooks)
|
||||
hook.add_hook_patches(self, model_options, target_dict, registered)
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_REGISTER_ALL_HOOK_PATCHES):
|
||||
callback(self, hooks_dict, target)
|
||||
callback(self, hooks, target_dict, model_options, registered)
|
||||
return registered
|
||||
|
||||
def add_hook_patches(self, hook: comfy.hooks.WeightHook, patches, strength_patch=1.0, strength_model=1.0):
|
||||
with self.use_ejected():
|
||||
@@ -975,7 +993,7 @@ class ModelPatcher:
|
||||
key = k[0]
|
||||
if len(k) > 2:
|
||||
function = k[2]
|
||||
|
||||
|
||||
if key in model_sd:
|
||||
p.add(k)
|
||||
current_patches: list[tuple] = current_hook_patches.get(key, [])
|
||||
@@ -1008,11 +1026,11 @@ class ModelPatcher:
|
||||
def apply_hooks(self, hooks: comfy.hooks.HookGroup, transformer_options: dict=None, force_apply=False):
|
||||
# TODO: return transformer_options dict with any additions from hooks
|
||||
if self.current_hooks == hooks and (not force_apply or (not self.is_clip and hooks is None)):
|
||||
return {}
|
||||
return comfy.hooks.create_transformer_options_from_hooks(self, hooks, transformer_options)
|
||||
self.patch_hooks(hooks=hooks)
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_APPLY_HOOKS):
|
||||
callback(self, hooks)
|
||||
return {}
|
||||
return comfy.hooks.create_transformer_options_from_hooks(self, hooks, transformer_options)
|
||||
|
||||
def patch_hooks(self, hooks: comfy.hooks.HookGroup):
|
||||
with self.use_ejected():
|
||||
@@ -1063,7 +1081,7 @@ class ModelPatcher:
|
||||
def patch_hook_weight_to_device(self, hooks: comfy.hooks.HookGroup, combined_patches: dict, key: str, original_weights: dict, memory_counter: MemoryCounter):
|
||||
if key not in combined_patches:
|
||||
return
|
||||
|
||||
|
||||
weight, set_func, convert_func = get_key_weight(self.model, key)
|
||||
weight: torch.Tensor
|
||||
if key not in self.hook_backup:
|
||||
@@ -1098,7 +1116,7 @@ class ModelPatcher:
|
||||
del temp_weight
|
||||
del out_weight
|
||||
del weight
|
||||
|
||||
|
||||
def unpatch_hooks(self) -> None:
|
||||
with self.use_ejected():
|
||||
if len(self.hook_backup) == 0:
|
||||
@@ -1107,7 +1125,7 @@ class ModelPatcher:
|
||||
keys = list(self.hook_backup.keys())
|
||||
for k in keys:
|
||||
comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
|
||||
|
||||
|
||||
self.hook_backup.clear()
|
||||
self.current_hooks = None
|
||||
|
||||
|
||||
@@ -96,12 +96,12 @@ class WrapperExecutor:
|
||||
self.wrappers = wrappers.copy()
|
||||
self.idx = idx
|
||||
self.is_last = idx == len(wrappers)
|
||||
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
"""Calls the next wrapper or original function, whichever is appropriate."""
|
||||
new_executor = self._create_next_executor()
|
||||
return new_executor.execute(*args, **kwargs)
|
||||
|
||||
|
||||
def execute(self, *args, **kwargs):
|
||||
"""Used to initiate executor internally - DO NOT use this if you received executor in wrapper."""
|
||||
args = list(args)
|
||||
@@ -121,7 +121,7 @@ class WrapperExecutor:
|
||||
@classmethod
|
||||
def new_executor(cls, original: Callable, wrappers: list[Callable], idx=0):
|
||||
return cls(original, class_obj=None, wrappers=wrappers, idx=idx)
|
||||
|
||||
|
||||
@classmethod
|
||||
def new_class_executor(cls, original: Callable, class_obj: object, wrappers: list[Callable], idx=0):
|
||||
return cls(original, class_obj, wrappers, idx=idx)
|
||||
|
||||
@@ -13,7 +13,7 @@ def prepare_noise(latent_image, seed, noise_inds=None):
|
||||
generator = torch.manual_seed(seed)
|
||||
if noise_inds is None:
|
||||
return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
|
||||
|
||||
|
||||
unique_inds, inverse = np.unique(noise_inds, return_inverse=True)
|
||||
noises = []
|
||||
for i in range(unique_inds[-1]+1):
|
||||
@@ -25,9 +25,11 @@ def prepare_noise(latent_image, seed, noise_inds=None):
|
||||
return noises
|
||||
|
||||
def fix_empty_latent_channels(model, latent_image):
|
||||
latent_channels = model.get_model_object("latent_format").latent_channels #Resize the empty latent image so it has the right number of channels
|
||||
if latent_channels != latent_image.shape[1] and torch.count_nonzero(latent_image) == 0:
|
||||
latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_channels, dim=1)
|
||||
latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels
|
||||
if latent_format.latent_channels != latent_image.shape[1] and torch.count_nonzero(latent_image) == 0:
|
||||
latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)
|
||||
if latent_format.latent_dimensions == 3 and latent_image.ndim == 4:
|
||||
latent_image = latent_image.unsqueeze(2)
|
||||
return latent_image
|
||||
|
||||
def prepare_sampling(model, noise_shape, positive, negative, noise_mask):
|
||||
|
||||
@@ -24,15 +24,13 @@ def get_models_from_cond(cond, model_type):
|
||||
models += [c[model_type]]
|
||||
return models
|
||||
|
||||
def get_hooks_from_cond(cond, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]]):
|
||||
def get_hooks_from_cond(cond, full_hooks: comfy.hooks.HookGroup):
|
||||
# get hooks from conds, and collect cnets so they can be checked for extra_hooks
|
||||
cnets: list[ControlBase] = []
|
||||
for c in cond:
|
||||
if 'hooks' in c:
|
||||
for hook in c['hooks'].hooks:
|
||||
hook: comfy.hooks.Hook
|
||||
with_type = hooks_dict.setdefault(hook.hook_type, {})
|
||||
with_type[hook] = None
|
||||
full_hooks.add(hook)
|
||||
if 'control' in c:
|
||||
cnets.append(c['control'])
|
||||
|
||||
@@ -42,7 +40,7 @@ def get_hooks_from_cond(cond, hooks_dict: dict[comfy.hooks.EnumHookType, dict[co
|
||||
if cnet.previous_controlnet is None:
|
||||
return _list
|
||||
return get_extra_hooks_from_cnet(cnet.previous_controlnet, _list)
|
||||
|
||||
|
||||
hooks_list = []
|
||||
cnets = set(cnets)
|
||||
for base_cnet in cnets:
|
||||
@@ -50,10 +48,9 @@ def get_hooks_from_cond(cond, hooks_dict: dict[comfy.hooks.EnumHookType, dict[co
|
||||
extra_hooks = comfy.hooks.HookGroup.combine_all_hooks(hooks_list)
|
||||
if extra_hooks is not None:
|
||||
for hook in extra_hooks.hooks:
|
||||
with_type = hooks_dict.setdefault(hook.hook_type, {})
|
||||
with_type[hook] = None
|
||||
full_hooks.add(hook)
|
||||
|
||||
return hooks_dict
|
||||
return full_hooks
|
||||
|
||||
def convert_cond(cond):
|
||||
out = []
|
||||
@@ -73,13 +70,11 @@ def get_additional_models(conds, dtype):
|
||||
cnets: list[ControlBase] = []
|
||||
gligen = []
|
||||
add_models = []
|
||||
hooks: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]] = {}
|
||||
|
||||
for k in conds:
|
||||
cnets += get_models_from_cond(conds[k], "control")
|
||||
gligen += get_models_from_cond(conds[k], "gligen")
|
||||
add_models += get_models_from_cond(conds[k], "additional_models")
|
||||
get_hooks_from_cond(conds[k], hooks)
|
||||
|
||||
control_nets = set(cnets)
|
||||
|
||||
@@ -90,11 +85,20 @@ def get_additional_models(conds, dtype):
|
||||
inference_memory += m.inference_memory_requirements(dtype)
|
||||
|
||||
gligen = [x[1] for x in gligen]
|
||||
hook_models = [x.model for x in hooks.get(comfy.hooks.EnumHookType.AddModels, {}).keys()]
|
||||
models = control_models + gligen + add_models + hook_models
|
||||
models = control_models + gligen + add_models
|
||||
|
||||
return models, inference_memory
|
||||
|
||||
def get_additional_models_from_model_options(model_options: dict[str]=None):
|
||||
"""loads additional models from registered AddModels hooks"""
|
||||
models = []
|
||||
if model_options is not None and "registered_hooks" in model_options:
|
||||
registered: comfy.hooks.HookGroup = model_options["registered_hooks"]
|
||||
for hook in registered.get_type(comfy.hooks.EnumHookType.AdditionalModels):
|
||||
hook: comfy.hooks.AdditionalModelsHook
|
||||
models.extend(hook.models)
|
||||
return models
|
||||
|
||||
def cleanup_additional_models(models):
|
||||
"""cleanup additional models that were loaded"""
|
||||
for m in models:
|
||||
@@ -102,9 +106,10 @@ def cleanup_additional_models(models):
|
||||
m.cleanup()
|
||||
|
||||
|
||||
def prepare_sampling(model: 'ModelPatcher', noise_shape, conds):
|
||||
real_model: 'BaseModel' = None
|
||||
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
|
||||
real_model: BaseModel = None
|
||||
models, inference_memory = get_additional_models(conds, model.model_dtype())
|
||||
models += get_additional_models_from_model_options(model_options)
|
||||
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
|
||||
memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory
|
||||
minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory
|
||||
@@ -123,12 +128,35 @@ def cleanup_models(conds, models):
|
||||
cleanup_additional_models(set(control_cleanup))
|
||||
|
||||
def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
|
||||
'''
|
||||
Registers hooks from conds.
|
||||
'''
|
||||
# check for hooks in conds - if not registered, see if can be applied
|
||||
hooks = {}
|
||||
hooks = comfy.hooks.HookGroup()
|
||||
for k in conds:
|
||||
get_hooks_from_cond(conds[k], hooks)
|
||||
# add wrappers and callbacks from ModelPatcher to transformer_options
|
||||
model_options["transformer_options"]["wrappers"] = comfy.patcher_extension.copy_nested_dicts(model.wrappers)
|
||||
model_options["transformer_options"]["callbacks"] = comfy.patcher_extension.copy_nested_dicts(model.callbacks)
|
||||
# register hooks on model/model_options
|
||||
model.register_all_hook_patches(hooks, comfy.hooks.EnumWeightTarget.Model, model_options)
|
||||
# begin registering hooks
|
||||
registered = comfy.hooks.HookGroup()
|
||||
target_dict = comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Model)
|
||||
# handle all TransformerOptionsHooks
|
||||
for hook in hooks.get_type(comfy.hooks.EnumHookType.TransformerOptions):
|
||||
hook: comfy.hooks.TransformerOptionsHook
|
||||
hook.add_hook_patches(model, model_options, target_dict, registered)
|
||||
# handle all AddModelsHooks
|
||||
for hook in hooks.get_type(comfy.hooks.EnumHookType.AdditionalModels):
|
||||
hook: comfy.hooks.AdditionalModelsHook
|
||||
hook.add_hook_patches(model, model_options, target_dict, registered)
|
||||
# handle all WeightHooks by registering on ModelPatcher
|
||||
model.register_all_hook_patches(hooks, target_dict, model_options, registered)
|
||||
# add registered_hooks onto model_options for further reference
|
||||
if len(registered) > 0:
|
||||
model_options["registered_hooks"] = registered
|
||||
# merge original wrappers and callbacks with hooked wrappers and callbacks
|
||||
to_load_options: dict[str] = model_options.setdefault("to_load_options", {})
|
||||
for wc_name in ["wrappers", "callbacks"]:
|
||||
comfy.patcher_extension.merge_nested_dicts(to_load_options.setdefault(wc_name, {}), model_options["transformer_options"][wc_name],
|
||||
copy_dict1=False)
|
||||
return to_load_options
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
from __future__ import annotations
|
||||
from .k_diffusion import sampling as k_diffusion_sampling
|
||||
from .extra_samplers import uni_pc
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Callable, NamedTuple
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
from comfy.model_base import BaseModel
|
||||
from comfy.controlnet import ControlBase
|
||||
import torch
|
||||
from functools import partial
|
||||
import collections
|
||||
from comfy import model_management
|
||||
import math
|
||||
@@ -144,7 +145,7 @@ def cond_cat(c_list):
|
||||
|
||||
return out
|
||||
|
||||
def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]], default_conds: list[list[dict]], x_in, timestep):
|
||||
def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]], default_conds: list[list[dict]], x_in, timestep, model_options):
|
||||
# need to figure out remaining unmasked area for conds
|
||||
default_mults = []
|
||||
for _ in default_conds:
|
||||
@@ -183,7 +184,7 @@ def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.H
|
||||
# replace p's mult with calculated mult
|
||||
p = p._replace(mult=mult)
|
||||
if p.hooks is not None:
|
||||
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks)
|
||||
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options)
|
||||
hooked_to_run.setdefault(p.hooks, list())
|
||||
hooked_to_run[p.hooks] += [(p, i)]
|
||||
|
||||
@@ -218,13 +219,13 @@ def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Te
|
||||
if p is None:
|
||||
continue
|
||||
if p.hooks is not None:
|
||||
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks)
|
||||
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options)
|
||||
hooked_to_run.setdefault(p.hooks, list())
|
||||
hooked_to_run[p.hooks] += [(p, i)]
|
||||
default_conds.append(default_c)
|
||||
|
||||
if has_default_conds:
|
||||
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep)
|
||||
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)
|
||||
|
||||
model.current_patcher.prepare_state(timestep)
|
||||
|
||||
@@ -375,7 +376,7 @@ class KSamplerX0Inpaint:
|
||||
if "denoise_mask_function" in model_options:
|
||||
denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask, extra_options={"model": self.inner_model, "sigmas": self.sigmas})
|
||||
latent_mask = 1. - denoise_mask
|
||||
x = x * denoise_mask + self.inner_model.inner_model.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1)), self.noise, self.latent_image) * latent_mask
|
||||
x = x * denoise_mask + self.inner_model.inner_model.scale_latent_inpaint(x=x, sigma=sigma, noise=self.noise, latent_image=self.latent_image) * latent_mask
|
||||
out = self.inner_model(x, sigma, model_options=model_options, seed=seed)
|
||||
if denoise_mask is not None:
|
||||
out = out * denoise_mask + self.latent_image * latent_mask
|
||||
@@ -467,6 +468,13 @@ def linear_quadratic_schedule(model_sampling, steps, threshold_noise=0.025, line
|
||||
sigma_schedule = [1.0 - x for x in sigma_schedule]
|
||||
return torch.FloatTensor(sigma_schedule) * model_sampling.sigma_max.cpu()
|
||||
|
||||
# Referenced from https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15608
|
||||
def kl_optimal_scheduler(n: int, sigma_min: float, sigma_max: float) -> torch.Tensor:
|
||||
adj_idxs = torch.arange(n, dtype=torch.float).div_(n - 1)
|
||||
sigmas = adj_idxs.new_zeros(n + 1)
|
||||
sigmas[:-1] = (adj_idxs * math.atan(sigma_min) + (1 - adj_idxs) * math.atan(sigma_max)).tan_()
|
||||
return sigmas
|
||||
|
||||
def get_mask_aabb(masks):
|
||||
if masks.numel() == 0:
|
||||
return torch.zeros((0, 4), device=masks.device, dtype=torch.int)
|
||||
@@ -679,7 +687,7 @@ class Sampler:
|
||||
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
|
||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
||||
"ipndm", "ipndm_v", "deis"]
|
||||
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp"]
|
||||
|
||||
class KSAMPLER(Sampler):
|
||||
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
|
||||
@@ -802,6 +810,33 @@ def preprocess_conds_hooks(conds: dict[str, list[dict[str]]]):
|
||||
for cond in conds_to_modify:
|
||||
cond['hooks'] = hooks
|
||||
|
||||
def filter_registered_hooks_on_conds(conds: dict[str, list[dict[str]]], model_options: dict[str]):
|
||||
'''Modify 'hooks' on conds so that only hooks that were registered remain. Properly accounts for
|
||||
HookGroups that have the same reference.'''
|
||||
registered: comfy.hooks.HookGroup = model_options.get('registered_hooks', None)
|
||||
# if None were registered, make sure all hooks are cleaned from conds
|
||||
if registered is None:
|
||||
for k in conds:
|
||||
for kk in conds[k]:
|
||||
kk.pop('hooks', None)
|
||||
return
|
||||
# find conds that contain hooks to be replaced - group by common HookGroup refs
|
||||
hook_replacement: dict[comfy.hooks.HookGroup, list[dict]] = {}
|
||||
for k in conds:
|
||||
for kk in conds[k]:
|
||||
hooks: comfy.hooks.HookGroup = kk.get('hooks', None)
|
||||
if hooks is not None:
|
||||
if not hooks.is_subset_of(registered):
|
||||
to_replace = hook_replacement.setdefault(hooks, [])
|
||||
to_replace.append(kk)
|
||||
# for each hook to replace, create a new proper HookGroup and assign to all common conds
|
||||
for hooks, conds_to_modify in hook_replacement.items():
|
||||
new_hooks = hooks.new_with_common_hooks(registered)
|
||||
if len(new_hooks) == 0:
|
||||
new_hooks = None
|
||||
for kk in conds_to_modify:
|
||||
kk['hooks'] = new_hooks
|
||||
|
||||
|
||||
def get_total_hook_groups_in_conds(conds: dict[str, list[dict[str]]]):
|
||||
hooks_set = set()
|
||||
@@ -811,9 +846,58 @@ def get_total_hook_groups_in_conds(conds: dict[str, list[dict[str]]]):
|
||||
return len(hooks_set)
|
||||
|
||||
|
||||
def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
|
||||
'''
|
||||
If any patches from hooks, wrappers, or callbacks have .to to be called, call it.
|
||||
'''
|
||||
if model_options is None:
|
||||
return
|
||||
to_load_options = model_options.get("to_load_options", None)
|
||||
if to_load_options is None:
|
||||
return
|
||||
|
||||
casts = []
|
||||
if device is not None:
|
||||
casts.append(device)
|
||||
if dtype is not None:
|
||||
casts.append(dtype)
|
||||
# if nothing to apply, do nothing
|
||||
if len(casts) == 0:
|
||||
return
|
||||
|
||||
# try to call .to on patches
|
||||
if "patches" in to_load_options:
|
||||
patches = to_load_options["patches"]
|
||||
for name in patches:
|
||||
patch_list = patches[name]
|
||||
for i in range(len(patch_list)):
|
||||
if hasattr(patch_list[i], "to"):
|
||||
for cast in casts:
|
||||
patch_list[i] = patch_list[i].to(cast)
|
||||
if "patches_replace" in to_load_options:
|
||||
patches = to_load_options["patches_replace"]
|
||||
for name in patches:
|
||||
patch_list = patches[name]
|
||||
for k in patch_list:
|
||||
if hasattr(patch_list[k], "to"):
|
||||
for cast in casts:
|
||||
patch_list[k] = patch_list[k].to(cast)
|
||||
# try to call .to on any wrappers/callbacks
|
||||
wrappers_and_callbacks = ["wrappers", "callbacks"]
|
||||
for wc_name in wrappers_and_callbacks:
|
||||
if wc_name in to_load_options:
|
||||
wc: dict[str, list] = to_load_options[wc_name]
|
||||
for wc_dict in wc.values():
|
||||
for wc_list in wc_dict.values():
|
||||
for i in range(len(wc_list)):
|
||||
if hasattr(wc_list[i], "to"):
|
||||
for cast in casts:
|
||||
wc_list[i] = wc_list[i].to(cast)
|
||||
|
||||
|
||||
class CFGGuider:
|
||||
def __init__(self, model_patcher):
|
||||
self.model_patcher: 'ModelPatcher' = model_patcher
|
||||
def __init__(self, model_patcher: ModelPatcher):
|
||||
self.model_patcher = model_patcher
|
||||
self.model_options = model_patcher.model_options
|
||||
self.original_conds = {}
|
||||
self.cfg = 1.0
|
||||
@@ -840,7 +924,9 @@ class CFGGuider:
|
||||
|
||||
self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed)
|
||||
|
||||
extra_args = {"model_options": comfy.model_patcher.create_model_options_clone(self.model_options), "seed": seed}
|
||||
extra_model_options = comfy.model_patcher.create_model_options_clone(self.model_options)
|
||||
extra_model_options.setdefault("transformer_options", {})["sample_sigmas"] = sigmas
|
||||
extra_args = {"model_options": extra_model_options, "seed": seed}
|
||||
|
||||
executor = comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
sampler.sample,
|
||||
@@ -851,7 +937,7 @@ class CFGGuider:
|
||||
return self.inner_model.process_latent_out(samples.to(torch.float32))
|
||||
|
||||
def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds)
|
||||
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
|
||||
device = self.model_patcher.load_device
|
||||
|
||||
if denoise_mask is not None:
|
||||
@@ -860,6 +946,7 @@ class CFGGuider:
|
||||
noise = noise.to(device)
|
||||
latent_image = latent_image.to(device)
|
||||
sigmas = sigmas.to(device)
|
||||
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
|
||||
|
||||
try:
|
||||
self.model_patcher.pre_run()
|
||||
@@ -889,6 +976,7 @@ class CFGGuider:
|
||||
if get_total_hook_groups_in_conds(self.conds) <= 1:
|
||||
self.model_patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
|
||||
comfy.sampler_helpers.prepare_model_patcher(self.model_patcher, self.conds, self.model_options)
|
||||
filter_registered_hooks_on_conds(self.conds, self.model_options)
|
||||
executor = comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self.outer_sample,
|
||||
self,
|
||||
@@ -896,6 +984,7 @@ class CFGGuider:
|
||||
)
|
||||
output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
||||
finally:
|
||||
cast_to_load_options(self.model_options, device=self.model_patcher.offload_device)
|
||||
self.model_options = orig_model_options
|
||||
self.model_patcher.hook_mode = orig_hook_mode
|
||||
self.model_patcher.restore_hook_patches()
|
||||
@@ -911,29 +1000,37 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
|
||||
return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
||||
|
||||
|
||||
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "beta", "linear_quadratic"]
|
||||
SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]
|
||||
|
||||
def calculate_sigmas(model_sampling, scheduler_name, steps):
|
||||
if scheduler_name == "karras":
|
||||
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))
|
||||
elif scheduler_name == "exponential":
|
||||
sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))
|
||||
elif scheduler_name == "normal":
|
||||
sigmas = normal_scheduler(model_sampling, steps)
|
||||
elif scheduler_name == "simple":
|
||||
sigmas = simple_scheduler(model_sampling, steps)
|
||||
elif scheduler_name == "ddim_uniform":
|
||||
sigmas = ddim_scheduler(model_sampling, steps)
|
||||
elif scheduler_name == "sgm_uniform":
|
||||
sigmas = normal_scheduler(model_sampling, steps, sgm=True)
|
||||
elif scheduler_name == "beta":
|
||||
sigmas = beta_scheduler(model_sampling, steps)
|
||||
elif scheduler_name == "linear_quadratic":
|
||||
sigmas = linear_quadratic_schedule(model_sampling, steps)
|
||||
else:
|
||||
logging.error("error invalid scheduler {}".format(scheduler_name))
|
||||
return sigmas
|
||||
class SchedulerHandler(NamedTuple):
|
||||
handler: Callable[..., torch.Tensor]
|
||||
# Boolean indicates whether to call the handler like:
|
||||
# scheduler_function(model_sampling, steps) or
|
||||
# scheduler_function(n, sigma_min: float, sigma_max: float)
|
||||
use_ms: bool = True
|
||||
|
||||
SCHEDULER_HANDLERS = {
|
||||
"normal": SchedulerHandler(normal_scheduler),
|
||||
"karras": SchedulerHandler(k_diffusion_sampling.get_sigmas_karras, use_ms=False),
|
||||
"exponential": SchedulerHandler(k_diffusion_sampling.get_sigmas_exponential, use_ms=False),
|
||||
"sgm_uniform": SchedulerHandler(partial(normal_scheduler, sgm=True)),
|
||||
"simple": SchedulerHandler(simple_scheduler),
|
||||
"ddim_uniform": SchedulerHandler(ddim_scheduler),
|
||||
"beta": SchedulerHandler(beta_scheduler),
|
||||
"linear_quadratic": SchedulerHandler(linear_quadratic_schedule),
|
||||
"kl_optimal": SchedulerHandler(kl_optimal_scheduler, use_ms=False),
|
||||
}
|
||||
SCHEDULER_NAMES = list(SCHEDULER_HANDLERS)
|
||||
|
||||
def calculate_sigmas(model_sampling: object, scheduler_name: str, steps: int) -> torch.Tensor:
|
||||
handler = SCHEDULER_HANDLERS.get(scheduler_name)
|
||||
if handler is None:
|
||||
err = f"error invalid scheduler {scheduler_name}"
|
||||
logging.error(err)
|
||||
raise ValueError(err)
|
||||
if handler.use_ms:
|
||||
return handler.handler(model_sampling, steps)
|
||||
return handler.handler(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))
|
||||
|
||||
def sampler_object(name):
|
||||
if name == "uni_pc":
|
||||
|
||||
31
comfy/sd.py
31
comfy/sd.py
@@ -11,6 +11,7 @@ from .ldm.cascade.stage_c_coder import StageC_coder
|
||||
from .ldm.audio.autoencoder import AudioOobleckVAE
|
||||
import comfy.ldm.genmo.vae.model
|
||||
import comfy.ldm.lightricks.vae.causal_video_autoencoder
|
||||
import comfy.ldm.cosmos.vae
|
||||
import yaml
|
||||
import math
|
||||
|
||||
@@ -34,6 +35,7 @@ import comfy.text_encoders.long_clipl
|
||||
import comfy.text_encoders.genmo
|
||||
import comfy.text_encoders.lt
|
||||
import comfy.text_encoders.hunyuan_video
|
||||
import comfy.text_encoders.cosmos
|
||||
|
||||
import comfy.model_patcher
|
||||
import comfy.lora
|
||||
@@ -111,7 +113,7 @@ class CLIP:
|
||||
model_management.load_models_gpu([self.patcher], force_full_load=True)
|
||||
self.layer_idx = None
|
||||
self.use_clip_schedule = False
|
||||
logging.info("CLIP model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype))
|
||||
logging.info("CLIP/text encoder model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype))
|
||||
|
||||
def clone(self):
|
||||
n = CLIP(no_init=True)
|
||||
@@ -376,6 +378,19 @@ class VAE:
|
||||
self.memory_used_decode = lambda shape, dtype: (1500 * shape[2] * shape[3] * shape[4] * (4 * 8 * 8)) * model_management.dtype_size(dtype)
|
||||
self.memory_used_encode = lambda shape, dtype: (900 * max(shape[2], 2) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
|
||||
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
elif "decoder.unpatcher3d.wavelets" in sd:
|
||||
self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 8, 8)
|
||||
self.upscale_index_formula = (8, 8, 8)
|
||||
self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 8, 8)
|
||||
self.downscale_index_formula = (8, 8, 8)
|
||||
self.latent_dim = 3
|
||||
self.latent_channels = 16
|
||||
ddconfig = {'z_channels': 16, 'latent_channels': self.latent_channels, 'z_factor': 1, 'resolution': 1024, 'in_channels': 3, 'out_channels': 3, 'channels': 128, 'channels_mult': [2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [32], 'dropout': 0.0, 'patch_size': 4, 'num_groups': 1, 'temporal_compression': 8, 'spacial_compression': 8}
|
||||
self.first_stage_model = comfy.ldm.cosmos.vae.CausalContinuousVideoTokenizer(**ddconfig)
|
||||
#TODO: these values are a bit off because this is not a standard VAE
|
||||
self.memory_used_decode = lambda shape, dtype: (50 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
|
||||
self.memory_used_encode = lambda shape, dtype: (50 * (round((shape[2] + 7) / 8) * 8) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
|
||||
self.working_dtypes = [torch.bfloat16, torch.float32]
|
||||
else:
|
||||
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
|
||||
self.first_stage_model = None
|
||||
@@ -519,7 +534,7 @@ class VAE:
|
||||
def encode(self, pixel_samples):
|
||||
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
|
||||
pixel_samples = pixel_samples.movedim(-1, 1)
|
||||
if self.latent_dim == 3:
|
||||
if self.latent_dim == 3 and pixel_samples.ndim < 5:
|
||||
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
|
||||
try:
|
||||
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
|
||||
@@ -641,6 +656,7 @@ class CLIPType(Enum):
|
||||
LTXV = 8
|
||||
HUNYUAN_VIDEO = 9
|
||||
PIXART = 10
|
||||
COSMOS = 11
|
||||
|
||||
|
||||
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
||||
@@ -658,6 +674,7 @@ class TEModel(Enum):
|
||||
T5_XL = 5
|
||||
T5_BASE = 6
|
||||
LLAMA3_8 = 7
|
||||
T5_XXL_OLD = 8
|
||||
|
||||
def detect_te_model(sd):
|
||||
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
|
||||
@@ -672,6 +689,8 @@ def detect_te_model(sd):
|
||||
return TEModel.T5_XXL
|
||||
elif weight.shape[-1] == 2048:
|
||||
return TEModel.T5_XL
|
||||
if 'encoder.block.23.layer.1.DenseReluDense.wi.weight' in sd:
|
||||
return TEModel.T5_XXL_OLD
|
||||
if "encoder.block.0.layer.0.SelfAttention.k.weight" in sd:
|
||||
return TEModel.T5_BASE
|
||||
if "model.layers.0.post_attention_layernorm.weight" in sd:
|
||||
@@ -681,9 +700,10 @@ def detect_te_model(sd):
|
||||
|
||||
def t5xxl_detect(clip_data):
|
||||
weight_name = "encoder.block.23.layer.1.DenseReluDense.wi_1.weight"
|
||||
weight_name_old = "encoder.block.23.layer.1.DenseReluDense.wi.weight"
|
||||
|
||||
for sd in clip_data:
|
||||
if weight_name in sd:
|
||||
if weight_name in sd or weight_name_old in sd:
|
||||
return comfy.text_encoders.sd3_clip.t5_xxl_detect(sd)
|
||||
|
||||
return {}
|
||||
@@ -740,6 +760,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
else: #CLIPType.MOCHI
|
||||
clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer
|
||||
elif te_model == TEModel.T5_XXL_OLD:
|
||||
clip_target.clip = comfy.text_encoders.cosmos.te(**t5xxl_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.cosmos.CosmosT5Tokenizer
|
||||
elif te_model == TEModel.T5_XL:
|
||||
clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model
|
||||
clip_target.tokenizer = comfy.text_encoders.aura_t5.AuraT5Tokenizer
|
||||
@@ -898,7 +921,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
||||
if output_model:
|
||||
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
|
||||
if inital_load_device != torch.device("cpu"):
|
||||
logging.info("loaded straight to GPU")
|
||||
logging.info("loaded diffusion model directly to GPU")
|
||||
model_management.load_models_gpu([model_patcher], force_full_load=True)
|
||||
|
||||
return (model_patcher, clip, vae, clipvision)
|
||||
|
||||
@@ -388,13 +388,10 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
|
||||
import safetensors.torch
|
||||
embed = safetensors.torch.load_file(embed_path, device="cpu")
|
||||
else:
|
||||
if 'weights_only' in torch.load.__code__.co_varnames:
|
||||
try:
|
||||
embed = torch.load(embed_path, weights_only=True, map_location="cpu")
|
||||
except:
|
||||
embed_out = safe_load_embed_zip(embed_path)
|
||||
else:
|
||||
embed = torch.load(embed_path, map_location="cpu")
|
||||
try:
|
||||
embed = torch.load(embed_path, weights_only=True, map_location="cpu")
|
||||
except:
|
||||
embed_out = safe_load_embed_zip(embed_path)
|
||||
except Exception:
|
||||
logging.warning("{}\n\nerror loading embedding, skipping loading: {}".format(traceback.format_exc(), embedding_name))
|
||||
return None
|
||||
|
||||
@@ -14,6 +14,7 @@ import comfy.text_encoders.flux
|
||||
import comfy.text_encoders.genmo
|
||||
import comfy.text_encoders.lt
|
||||
import comfy.text_encoders.hunyuan_video
|
||||
import comfy.text_encoders.cosmos
|
||||
|
||||
from . import supported_models_base
|
||||
from . import latent_formats
|
||||
@@ -608,6 +609,8 @@ class PixArtAlpha(supported_models_base.BASE):
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.SD15
|
||||
|
||||
memory_usage_factor = 0.5
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
@@ -640,6 +643,8 @@ class HunyuanDiT(supported_models_base.BASE):
|
||||
|
||||
latent_format = latent_formats.SDXL
|
||||
|
||||
memory_usage_factor = 1.3
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
@@ -783,7 +788,7 @@ class HunyuanVideo(supported_models_base.BASE):
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.HunyuanVideo
|
||||
|
||||
memory_usage_factor = 2.0 #TODO
|
||||
memory_usage_factor = 1.8 #TODO
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
@@ -819,6 +824,47 @@ class HunyuanVideo(supported_models_base.BASE):
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}llama.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer, comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**hunyuan_detect))
|
||||
|
||||
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo]
|
||||
class CosmosT2V(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "cosmos",
|
||||
"in_channels": 16,
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"sigma_data": 0.5,
|
||||
"sigma_max": 80.0,
|
||||
"sigma_min": 0.002,
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.Cosmos1CV8x8x8
|
||||
|
||||
memory_usage_factor = 1.6 #TODO
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] #TODO
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.CosmosVideo(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.cosmos.CosmosT5Tokenizer, comfy.text_encoders.cosmos.te(**t5_detect))
|
||||
|
||||
class CosmosI2V(CosmosT2V):
|
||||
unet_config = {
|
||||
"image_model": "cosmos",
|
||||
"in_channels": 17,
|
||||
}
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.CosmosVideo(self, image_to_video=True, device=device)
|
||||
return out
|
||||
|
||||
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo, CosmosT2V, CosmosI2V]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
||||
42
comfy/text_encoders/cosmos.py
Normal file
42
comfy/text_encoders/cosmos.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from comfy import sd1_clip
|
||||
import comfy.text_encoders.t5
|
||||
import os
|
||||
from transformers import T5TokenizerFast
|
||||
|
||||
|
||||
class T5XXLModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_old_config_xxl.json")
|
||||
t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None)
|
||||
if t5xxl_scaled_fp8 is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["scaled_fp8"] = t5xxl_scaled_fp8
|
||||
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, zero_out_masked=attention_mask, model_options=model_options)
|
||||
|
||||
class CosmosT5XXL(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__(device=device, dtype=dtype, name="t5xxl", clip_model=T5XXLModel, model_options=model_options)
|
||||
|
||||
|
||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
|
||||
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=1024, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512)
|
||||
|
||||
|
||||
class CosmosT5Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
|
||||
|
||||
|
||||
def te(dtype_t5=None, t5xxl_scaled_fp8=None):
|
||||
class CosmosTEModel_(CosmosT5XXL):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
|
||||
model_options = model_options.copy()
|
||||
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
|
||||
if dtype is None:
|
||||
dtype = dtype_t5
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return CosmosTEModel_
|
||||
@@ -227,8 +227,9 @@ class T5(torch.nn.Module):
|
||||
super().__init__()
|
||||
self.num_layers = config_dict["num_layers"]
|
||||
model_dim = config_dict["d_model"]
|
||||
inner_dim = config_dict["d_kv"] * config_dict["num_heads"]
|
||||
|
||||
self.encoder = T5Stack(self.num_layers, model_dim, model_dim, config_dict["d_ff"], config_dict["dense_act_fn"], config_dict["is_gated_act"], config_dict["num_heads"], config_dict["model_type"] != "umt5", dtype, device, operations)
|
||||
self.encoder = T5Stack(self.num_layers, model_dim, inner_dim, config_dict["d_ff"], config_dict["dense_act_fn"], config_dict["is_gated_act"], config_dict["num_heads"], config_dict["model_type"] != "umt5", dtype, device, operations)
|
||||
self.dtype = dtype
|
||||
self.shared = operations.Embedding(config_dict["vocab_size"], model_dim, device=device, dtype=dtype)
|
||||
|
||||
|
||||
22
comfy/text_encoders/t5_old_config_xxl.json
Normal file
22
comfy/text_encoders/t5_old_config_xxl.json
Normal file
@@ -0,0 +1,22 @@
|
||||
{
|
||||
"d_ff": 65536,
|
||||
"d_kv": 128,
|
||||
"d_model": 1024,
|
||||
"decoder_start_token_id": 0,
|
||||
"dropout_rate": 0.1,
|
||||
"eos_token_id": 1,
|
||||
"dense_act_fn": "relu",
|
||||
"initializer_factor": 1.0,
|
||||
"is_encoder_decoder": true,
|
||||
"is_gated_act": false,
|
||||
"layer_norm_epsilon": 1e-06,
|
||||
"model_type": "t5",
|
||||
"num_decoder_layers": 24,
|
||||
"num_heads": 128,
|
||||
"num_layers": 24,
|
||||
"output_past": true,
|
||||
"pad_token_id": 0,
|
||||
"relative_attention_num_buckets": 32,
|
||||
"tie_word_embeddings": false,
|
||||
"vocab_size": 32128
|
||||
}
|
||||
@@ -29,17 +29,29 @@ import itertools
|
||||
from torch.nn.functional import interpolate
|
||||
from einops import rearrange
|
||||
|
||||
ALWAYS_SAFE_LOAD = False
|
||||
if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated
|
||||
class ModelCheckpoint:
|
||||
pass
|
||||
ModelCheckpoint.__module__ = "pytorch_lightning.callbacks.model_checkpoint"
|
||||
|
||||
from numpy.core.multiarray import scalar
|
||||
from numpy import dtype
|
||||
from numpy.dtypes import Float64DType
|
||||
from _codecs import encode
|
||||
|
||||
torch.serialization.add_safe_globals([ModelCheckpoint, scalar, dtype, Float64DType, encode])
|
||||
ALWAYS_SAFE_LOAD = True
|
||||
logging.info("Checkpoint files will always be loaded safely.")
|
||||
|
||||
|
||||
def load_torch_file(ckpt, safe_load=False, device=None):
|
||||
if device is None:
|
||||
device = torch.device("cpu")
|
||||
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
|
||||
sd = safetensors.torch.load_file(ckpt, device=device.type)
|
||||
else:
|
||||
if safe_load:
|
||||
if not 'weights_only' in torch.load.__code__.co_varnames:
|
||||
logging.warning("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
|
||||
safe_load = False
|
||||
if safe_load:
|
||||
if safe_load or ALWAYS_SAFE_LOAD:
|
||||
pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
|
||||
else:
|
||||
pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
|
||||
@@ -455,7 +467,7 @@ def pixart_to_diffusers(mmdit_config, output_prefix=""):
|
||||
|
||||
for k in PIXART_MAP_BASIC:
|
||||
key_map[k[1]] = "{}{}".format(output_prefix, k[0])
|
||||
|
||||
|
||||
return key_map
|
||||
|
||||
def auraflow_to_diffusers(mmdit_config, output_prefix=""):
|
||||
@@ -693,7 +705,25 @@ def copy_to_param(obj, attr, value):
|
||||
prev = getattr(obj, attrs[-1])
|
||||
prev.data.copy_(value)
|
||||
|
||||
def get_attr(obj, attr):
|
||||
def get_attr(obj, attr: str):
|
||||
"""Retrieves a nested attribute from an object using dot notation.
|
||||
|
||||
Args:
|
||||
obj: The object to get the attribute from
|
||||
attr (str): The attribute path using dot notation (e.g. "model.layer.weight")
|
||||
|
||||
Returns:
|
||||
The value of the requested attribute
|
||||
|
||||
Example:
|
||||
model = MyModel()
|
||||
weight = get_attr(model, "layer1.conv.weight")
|
||||
# Equivalent to: model.layer1.conv.weight
|
||||
|
||||
Important:
|
||||
Always prefer `comfy.model_patcher.ModelPatcher.get_model_object` when
|
||||
accessing nested model objects under `ModelPatcher.model`.
|
||||
"""
|
||||
attrs = attr.split(".")
|
||||
for name in attrs:
|
||||
obj = getattr(obj, name)
|
||||
@@ -702,7 +732,7 @@ def get_attr(obj, attr):
|
||||
def bislerp(samples, width, height):
|
||||
def slerp(b1, b2, r):
|
||||
'''slerps batches b1, b2 according to ratio r, batches should be flat e.g. NxC'''
|
||||
|
||||
|
||||
c = b1.shape[-1]
|
||||
|
||||
#norms
|
||||
@@ -727,16 +757,16 @@ def bislerp(samples, width, height):
|
||||
res *= (b1_norms * (1.0-r) + b2_norms * r).expand(-1,c)
|
||||
|
||||
#edge cases for same or polar opposites
|
||||
res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5]
|
||||
res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5]
|
||||
res[dot < 1e-5 - 1] = (b1 * (1.0-r) + b2 * r)[dot < 1e-5 - 1]
|
||||
return res
|
||||
|
||||
|
||||
def generate_bilinear_data(length_old, length_new, device):
|
||||
coords_1 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1,1,1,-1))
|
||||
coords_1 = torch.nn.functional.interpolate(coords_1, size=(1, length_new), mode="bilinear")
|
||||
ratios = coords_1 - coords_1.floor()
|
||||
coords_1 = coords_1.to(torch.int64)
|
||||
|
||||
|
||||
coords_2 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1,1,1,-1)) + 1
|
||||
coords_2[:,:,:,-1] -= 1
|
||||
coords_2 = torch.nn.functional.interpolate(coords_2, size=(1, length_new), mode="bilinear")
|
||||
@@ -747,7 +777,7 @@ def bislerp(samples, width, height):
|
||||
samples = samples.float()
|
||||
n,c,h,w = samples.shape
|
||||
h_new, w_new = (height, width)
|
||||
|
||||
|
||||
#linear w
|
||||
ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new, samples.device)
|
||||
coords_1 = coords_1.expand((n, c, h, -1))
|
||||
@@ -893,7 +923,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
|
||||
out = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
|
||||
out_div = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
|
||||
|
||||
positions = [range(0, s.shape[d+2], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]
|
||||
positions = [range(0, s.shape[d+2] - overlap[d], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]
|
||||
|
||||
for it in itertools.product(*positions):
|
||||
s_in = s
|
||||
|
||||
@@ -54,8 +54,8 @@ class DynamicPrompt:
|
||||
def get_original_prompt(self):
|
||||
return self.original_prompt
|
||||
|
||||
def get_input_info(class_def, input_name):
|
||||
valid_inputs = class_def.INPUT_TYPES()
|
||||
def get_input_info(class_def, input_name, valid_inputs=None):
|
||||
valid_inputs = valid_inputs or class_def.INPUT_TYPES()
|
||||
input_info = None
|
||||
input_category = None
|
||||
if "required" in valid_inputs and input_name in valid_inputs["required"]:
|
||||
@@ -131,7 +131,7 @@ class TopologicalSort:
|
||||
if (include_lazy or not is_lazy) and not self.is_cached(from_node_id):
|
||||
node_ids.append(from_node_id)
|
||||
links.append((from_node_id, from_socket, unique_id))
|
||||
|
||||
|
||||
for link in links:
|
||||
self.add_strong_link(*link)
|
||||
|
||||
|
||||
82
comfy_extras/nodes_cosmos.py
Normal file
82
comfy_extras/nodes_cosmos.py
Normal file
@@ -0,0 +1,82 @@
|
||||
import nodes
|
||||
import torch
|
||||
import comfy.model_management
|
||||
import comfy.utils
|
||||
|
||||
|
||||
class EmptyCosmosLatentVideo:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "width": ("INT", {"default": 1280, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"height": ("INT", {"default": 704, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"length": ("INT", {"default": 121, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 8}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "generate"
|
||||
|
||||
CATEGORY = "latent/video"
|
||||
|
||||
def generate(self, width, height, length, batch_size=1):
|
||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 8) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
return ({"samples": latent}, )
|
||||
|
||||
|
||||
def vae_encode_with_padding(vae, image, width, height, length, padding=0):
|
||||
pixels = comfy.utils.common_upscale(image[..., :3].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
pixel_len = min(pixels.shape[0], length)
|
||||
padded_length = min(length, (((pixel_len - 1) // 8) + 1 + padding) * 8 - 7)
|
||||
padded_pixels = torch.ones((padded_length, height, width, 3)) * 0.5
|
||||
padded_pixels[:pixel_len] = pixels[:pixel_len]
|
||||
latent_len = ((pixel_len - 1) // 8) + 1
|
||||
latent_temp = vae.encode(padded_pixels)
|
||||
return latent_temp[:, :, :latent_len]
|
||||
|
||||
|
||||
class CosmosImageToVideoLatent:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"vae": ("VAE", ),
|
||||
"width": ("INT", {"default": 1280, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"height": ("INT", {"default": 704, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"length": ("INT", {"default": 121, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 8}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||
},
|
||||
"optional": {"start_image": ("IMAGE", ),
|
||||
"end_image": ("IMAGE", ),
|
||||
}}
|
||||
|
||||
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "conditioning/inpaint"
|
||||
|
||||
def encode(self, vae, width, height, length, batch_size, start_image=None, end_image=None):
|
||||
latent = torch.zeros([1, 16, ((length - 1) // 8) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
if start_image is None and end_image is None:
|
||||
out_latent = {}
|
||||
out_latent["samples"] = latent
|
||||
return (out_latent,)
|
||||
|
||||
mask = torch.ones([latent.shape[0], 1, ((length - 1) // 8) + 1, latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device())
|
||||
|
||||
if start_image is not None:
|
||||
latent_temp = vae_encode_with_padding(vae, start_image, width, height, length, padding=1)
|
||||
latent[:, :, :latent_temp.shape[-3]] = latent_temp
|
||||
mask[:, :, :latent_temp.shape[-3]] *= 0.0
|
||||
|
||||
if end_image is not None:
|
||||
latent_temp = vae_encode_with_padding(vae, end_image, width, height, length, padding=0)
|
||||
latent[:, :, -latent_temp.shape[-3]:] = latent_temp
|
||||
mask[:, :, -latent_temp.shape[-3]:] *= 0.0
|
||||
|
||||
out_latent = {}
|
||||
out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1))
|
||||
out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
|
||||
return (out_latent,)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"EmptyCosmosLatentVideo": EmptyCosmosLatentVideo,
|
||||
"CosmosImageToVideoLatent": CosmosImageToVideoLatent,
|
||||
}
|
||||
@@ -231,6 +231,24 @@ class FlipSigmas:
|
||||
sigmas[0] = 0.0001
|
||||
return (sigmas,)
|
||||
|
||||
class SetFirstSigma:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
{"sigmas": ("SIGMAS", ),
|
||||
"sigma": ("FLOAT", {"default": 136.0, "min": 0.0, "max": 20000.0, "step": 0.001, "round": False}),
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("SIGMAS",)
|
||||
CATEGORY = "sampling/custom_sampling/sigmas"
|
||||
|
||||
FUNCTION = "set_first_sigma"
|
||||
|
||||
def set_first_sigma(self, sigmas, sigma):
|
||||
sigmas = sigmas.clone()
|
||||
sigmas[0] = sigma
|
||||
return (sigmas, )
|
||||
|
||||
class KSamplerSelect:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@@ -710,6 +728,7 @@ NODE_CLASS_MAPPINGS = {
|
||||
"SplitSigmas": SplitSigmas,
|
||||
"SplitSigmasDenoise": SplitSigmasDenoise,
|
||||
"FlipSigmas": FlipSigmas,
|
||||
"SetFirstSigma": SetFirstSigma,
|
||||
|
||||
"CFGGuider": CFGGuider,
|
||||
"DualCFGGuider": DualCFGGuider,
|
||||
|
||||
@@ -162,7 +162,7 @@ NOISE_LEVELS = {
|
||||
[14.61464119, 7.49001646, 5.85520077, 4.45427561, 3.46139455, 2.84484982, 2.19988537, 1.72759056, 1.36964464, 1.08895338, 0.86115354, 0.69515091, 0.54755926, 0.43325692, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
|
||||
[14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.45427561, 3.46139455, 2.84484982, 2.19988537, 1.72759056, 1.36964464, 1.08895338, 0.86115354, 0.69515091, 0.54755926, 0.43325692, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
|
||||
[14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.45427561, 3.46139455, 2.84484982, 2.19988537, 1.72759056, 1.36964464, 1.08895338, 0.89115214, 0.72133851, 0.59516323, 0.4783645, 0.38853383, 0.29807833, 0.22545385, 0.17026083, 0.09824532, 0.02916753],
|
||||
],
|
||||
],
|
||||
1.15: [
|
||||
[14.61464119, 0.83188516, 0.02916753],
|
||||
[14.61464119, 1.84880662, 0.59516323, 0.02916753],
|
||||
@@ -246,7 +246,7 @@ NOISE_LEVELS = {
|
||||
[14.61464119, 5.85520077, 2.84484982, 1.72759056, 1.162866, 0.83188516, 0.64427125, 0.52423614, 0.43325692, 0.36617002, 0.32104823, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 5.85520077, 2.84484982, 1.78698075, 1.24153244, 0.92192322, 0.72133851, 0.57119018, 0.45573691, 0.38853383, 0.34370604, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 5.85520077, 2.84484982, 1.78698075, 1.24153244, 0.92192322, 0.72133851, 0.57119018, 0.4783645, 0.41087446, 0.36617002, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
],
|
||||
],
|
||||
1.35: [
|
||||
[14.61464119, 0.69515091, 0.02916753],
|
||||
[14.61464119, 0.95350921, 0.34370604, 0.02916753],
|
||||
|
||||
@@ -33,7 +33,7 @@ class PairConditioningSetProperties:
|
||||
"timesteps": ("TIMESTEPS_RANGE",),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
EXPERIMENTAL = True
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
|
||||
RETURN_NAMES = ("positive", "negative")
|
||||
@@ -47,7 +47,7 @@ class PairConditioningSetProperties:
|
||||
strength=strength, set_cond_area=set_cond_area,
|
||||
mask=mask, hooks=hooks, timesteps_range=timesteps)
|
||||
return (final_positive, final_negative)
|
||||
|
||||
|
||||
class PairConditioningSetPropertiesAndCombine:
|
||||
NodeId = 'PairConditioningSetPropertiesAndCombine'
|
||||
NodeName = 'Cond Pair Set Props Combine'
|
||||
@@ -68,7 +68,7 @@ class PairConditioningSetPropertiesAndCombine:
|
||||
"timesteps": ("TIMESTEPS_RANGE",),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
EXPERIMENTAL = True
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
|
||||
RETURN_NAMES = ("positive", "negative")
|
||||
@@ -159,7 +159,7 @@ class PairConditioningCombine:
|
||||
"negative_B": ("CONDITIONING",),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
EXPERIMENTAL = True
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
|
||||
RETURN_NAMES = ("positive", "negative")
|
||||
@@ -186,7 +186,7 @@ class PairConditioningSetDefaultAndCombine:
|
||||
"hooks": ("HOOKS",),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
EXPERIMENTAL = True
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
|
||||
RETURN_NAMES = ("positive", "negative")
|
||||
@@ -198,7 +198,7 @@ class PairConditioningSetDefaultAndCombine:
|
||||
final_positive, final_negative = comfy.hooks.set_default_conds_and_combine(conds=[positive, negative], new_conds=[positive_DEFAULT, negative_DEFAULT],
|
||||
hooks=hooks)
|
||||
return (final_positive, final_negative)
|
||||
|
||||
|
||||
class ConditioningSetDefaultAndCombine:
|
||||
NodeId = 'ConditioningSetDefaultCombine'
|
||||
NodeName = 'Cond Set Default Combine'
|
||||
@@ -224,7 +224,7 @@ class ConditioningSetDefaultAndCombine:
|
||||
(final_conditioning,) = comfy.hooks.set_default_conds_and_combine(conds=[cond], new_conds=[cond_DEFAULT],
|
||||
hooks=hooks)
|
||||
return (final_conditioning,)
|
||||
|
||||
|
||||
class SetClipHooks:
|
||||
NodeId = 'SetClipHooks'
|
||||
NodeName = 'Set CLIP Hooks'
|
||||
@@ -240,13 +240,13 @@ class SetClipHooks:
|
||||
"hooks": ("HOOKS",)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
EXPERIMENTAL = True
|
||||
RETURN_TYPES = ("CLIP",)
|
||||
CATEGORY = "advanced/hooks/clip"
|
||||
FUNCTION = "apply_hooks"
|
||||
|
||||
def apply_hooks(self, clip: 'CLIP', schedule_clip: bool, apply_to_conds: bool, hooks: comfy.hooks.HookGroup=None):
|
||||
def apply_hooks(self, clip: CLIP, schedule_clip: bool, apply_to_conds: bool, hooks: comfy.hooks.HookGroup=None):
|
||||
if hooks is not None:
|
||||
clip = clip.clone()
|
||||
if apply_to_conds:
|
||||
@@ -255,7 +255,7 @@ class SetClipHooks:
|
||||
clip.use_clip_schedule = schedule_clip
|
||||
if not clip.use_clip_schedule:
|
||||
clip.patcher.forced_hooks.set_keyframes_on_hooks(None)
|
||||
clip.patcher.register_all_hook_patches(hooks.get_dict_repr(), comfy.hooks.EnumWeightTarget.Clip)
|
||||
clip.patcher.register_all_hook_patches(hooks, comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Clip))
|
||||
return (clip,)
|
||||
|
||||
class ConditioningTimestepsRange:
|
||||
@@ -269,7 +269,7 @@ class ConditioningTimestepsRange:
|
||||
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
EXPERIMENTAL = True
|
||||
RETURN_TYPES = ("TIMESTEPS_RANGE", "TIMESTEPS_RANGE", "TIMESTEPS_RANGE")
|
||||
RETURN_NAMES = ("TIMESTEPS_RANGE", "BEFORE_RANGE", "AFTER_RANGE")
|
||||
@@ -290,7 +290,7 @@ class CreateHookLora:
|
||||
NodeName = 'Create Hook LoRA'
|
||||
def __init__(self):
|
||||
self.loaded_lora = None
|
||||
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
@@ -303,7 +303,7 @@ class CreateHookLora:
|
||||
"prev_hooks": ("HOOKS",)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
EXPERIMENTAL = True
|
||||
RETURN_TYPES = ("HOOKS",)
|
||||
CATEGORY = "advanced/hooks/create"
|
||||
@@ -316,7 +316,7 @@ class CreateHookLora:
|
||||
|
||||
if strength_model == 0 and strength_clip == 0:
|
||||
return (prev_hooks,)
|
||||
|
||||
|
||||
lora_path = folder_paths.get_full_path("loras", lora_name)
|
||||
lora = None
|
||||
if self.loaded_lora is not None:
|
||||
@@ -326,7 +326,7 @@ class CreateHookLora:
|
||||
temp = self.loaded_lora
|
||||
self.loaded_lora = None
|
||||
del temp
|
||||
|
||||
|
||||
if lora is None:
|
||||
lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
|
||||
self.loaded_lora = (lora_path, lora)
|
||||
@@ -348,7 +348,7 @@ class CreateHookLoraModelOnly(CreateHookLora):
|
||||
"prev_hooks": ("HOOKS",)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
EXPERIMENTAL = True
|
||||
RETURN_TYPES = ("HOOKS",)
|
||||
CATEGORY = "advanced/hooks/create"
|
||||
@@ -378,7 +378,7 @@ class CreateHookModelAsLora:
|
||||
"prev_hooks": ("HOOKS",)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
EXPERIMENTAL = True
|
||||
RETURN_TYPES = ("HOOKS",)
|
||||
CATEGORY = "advanced/hooks/create"
|
||||
@@ -401,7 +401,7 @@ class CreateHookModelAsLora:
|
||||
temp = self.loaded_weights
|
||||
self.loaded_weights = None
|
||||
del temp
|
||||
|
||||
|
||||
if weights_model is None:
|
||||
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||
weights_model = comfy.hooks.get_patch_weights_from_model(out[0])
|
||||
@@ -426,7 +426,7 @@ class CreateHookModelAsLoraModelOnly(CreateHookModelAsLora):
|
||||
"prev_hooks": ("HOOKS",)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
EXPERIMENTAL = True
|
||||
RETURN_TYPES = ("HOOKS",)
|
||||
CATEGORY = "advanced/hooks/create"
|
||||
@@ -455,7 +455,7 @@ class SetHookKeyframes:
|
||||
"hook_kf": ("HOOK_KEYFRAMES",),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
EXPERIMENTAL = True
|
||||
RETURN_TYPES = ("HOOKS",)
|
||||
CATEGORY = "advanced/hooks/scheduling"
|
||||
@@ -481,7 +481,7 @@ class CreateHookKeyframe:
|
||||
"prev_hook_kf": ("HOOK_KEYFRAMES",),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
EXPERIMENTAL = True
|
||||
RETURN_TYPES = ("HOOK_KEYFRAMES",)
|
||||
RETURN_NAMES = ("HOOK_KF",)
|
||||
@@ -515,7 +515,7 @@ class CreateHookKeyframesInterpolated:
|
||||
"prev_hook_kf": ("HOOK_KEYFRAMES",),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
EXPERIMENTAL = True
|
||||
RETURN_TYPES = ("HOOK_KEYFRAMES",)
|
||||
RETURN_NAMES = ("HOOK_KF",)
|
||||
@@ -559,7 +559,7 @@ class CreateHookKeyframesFromFloats:
|
||||
"prev_hook_kf": ("HOOK_KEYFRAMES",),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
EXPERIMENTAL = True
|
||||
RETURN_TYPES = ("HOOK_KEYFRAMES",)
|
||||
RETURN_NAMES = ("HOOK_KF",)
|
||||
@@ -580,7 +580,7 @@ class CreateHookKeyframesFromFloats:
|
||||
raise Exception(f"floats_strength must be either an iterable input or a float, but was{type(floats_strength).__repr__}.")
|
||||
percents = comfy.hooks.InterpolationMethod.get_weights(num_from=start_percent, num_to=end_percent, length=len(floats_strength),
|
||||
method=comfy.hooks.InterpolationMethod.LINEAR)
|
||||
|
||||
|
||||
is_first = True
|
||||
for percent, strength in zip(percents, floats_strength):
|
||||
guarantee_steps = 0
|
||||
@@ -604,7 +604,7 @@ class SetModelHooksOnCond:
|
||||
"hooks": ("HOOKS",),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
EXPERIMENTAL = True
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
CATEGORY = "advanced/hooks/manual"
|
||||
@@ -630,7 +630,7 @@ class CombineHooks:
|
||||
"hooks_B": ("HOOKS",),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
EXPERIMENTAL = True
|
||||
RETURN_TYPES = ("HOOKS",)
|
||||
CATEGORY = "advanced/hooks/combine"
|
||||
@@ -657,7 +657,7 @@ class CombineHooksFour:
|
||||
"hooks_D": ("HOOKS",),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
EXPERIMENTAL = True
|
||||
RETURN_TYPES = ("HOOKS",)
|
||||
CATEGORY = "advanced/hooks/combine"
|
||||
@@ -690,7 +690,7 @@ class CombineHooksEight:
|
||||
"hooks_H": ("HOOKS",),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
EXPERIMENTAL = True
|
||||
RETURN_TYPES = ("HOOKS",)
|
||||
CATEGORY = "advanced/hooks/combine"
|
||||
|
||||
@@ -26,6 +26,7 @@ class Load3D():
|
||||
"bg_color": ("STRING", {"default": "#000000", "multiline": False}),
|
||||
"light_intensity": ("INT", {"default": 10, "min": 1, "max": 20, "step": 1}),
|
||||
"up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
|
||||
"fov": ("INT", {"default": 75, "min": 10, "max": 150, "step": 1}),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("IMAGE", "MASK", "STRING")
|
||||
@@ -37,13 +38,22 @@ class Load3D():
|
||||
CATEGORY = "3d"
|
||||
|
||||
def process(self, model_file, image, **kwargs):
|
||||
imagepath = folder_paths.get_annotated_filepath(image)
|
||||
if isinstance(image, dict):
|
||||
image_path = folder_paths.get_annotated_filepath(image['image'])
|
||||
mask_path = folder_paths.get_annotated_filepath(image['mask'])
|
||||
|
||||
load_image_node = nodes.LoadImage()
|
||||
load_image_node = nodes.LoadImage()
|
||||
output_image, ignore_mask = load_image_node.load_image(image=image_path)
|
||||
ignore_image, output_mask = load_image_node.load_image(image=mask_path)
|
||||
|
||||
output_image, output_mask = load_image_node.load_image(image=imagepath)
|
||||
|
||||
return output_image, output_mask, model_file,
|
||||
return output_image, output_mask, model_file,
|
||||
else:
|
||||
# to avoid the format is not dict which will happen the FE code is not compatibility to core,
|
||||
# we need to this to double-check, it can be removed after merged FE into the core
|
||||
image_path = folder_paths.get_annotated_filepath(image)
|
||||
load_image_node = nodes.LoadImage()
|
||||
output_image, output_mask = load_image_node.load_image(image=image_path)
|
||||
return output_image, output_mask, model_file,
|
||||
|
||||
class Load3DAnimation():
|
||||
@classmethod
|
||||
@@ -67,6 +77,7 @@ class Load3DAnimation():
|
||||
"light_intensity": ("INT", {"default": 10, "min": 1, "max": 20, "step": 1}),
|
||||
"up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
|
||||
"animation_speed": (["0.1", "0.5", "1", "1.5", "2"], {"default": "1"}),
|
||||
"fov": ("INT", {"default": 75, "min": 10, "max": 150, "step": 1}),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("IMAGE", "MASK", "STRING")
|
||||
@@ -78,13 +89,20 @@ class Load3DAnimation():
|
||||
CATEGORY = "3d"
|
||||
|
||||
def process(self, model_file, image, **kwargs):
|
||||
imagepath = folder_paths.get_annotated_filepath(image)
|
||||
if isinstance(image, dict):
|
||||
image_path = folder_paths.get_annotated_filepath(image['image'])
|
||||
mask_path = folder_paths.get_annotated_filepath(image['mask'])
|
||||
|
||||
load_image_node = nodes.LoadImage()
|
||||
load_image_node = nodes.LoadImage()
|
||||
output_image, ignore_mask = load_image_node.load_image(image=image_path)
|
||||
ignore_image, output_mask = load_image_node.load_image(image=mask_path)
|
||||
|
||||
output_image, output_mask = load_image_node.load_image(image=imagepath)
|
||||
|
||||
return output_image, output_mask, model_file,
|
||||
return output_image, output_mask, model_file,
|
||||
else:
|
||||
image_path = folder_paths.get_annotated_filepath(image)
|
||||
load_image_node = nodes.LoadImage()
|
||||
output_image, output_mask = load_image_node.load_image(image=image_path)
|
||||
return output_image, output_mask, model_file,
|
||||
|
||||
class Preview3D():
|
||||
@classmethod
|
||||
@@ -98,6 +116,7 @@ class Preview3D():
|
||||
"bg_color": ("STRING", {"default": "#000000", "multiline": False}),
|
||||
"light_intensity": ("INT", {"default": 10, "min": 1, "max": 20, "step": 1}),
|
||||
"up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
|
||||
"fov": ("INT", {"default": 75, "min": 10, "max": 150, "step": 1}),
|
||||
}}
|
||||
|
||||
OUTPUT_NODE = True
|
||||
@@ -121,4 +140,4 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"Load3D": "Load 3D",
|
||||
"Load3DAnimation": "Load 3D - Animation",
|
||||
"Preview3D": "Preview 3D"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -305,7 +305,7 @@ class FeatherMask:
|
||||
output[:, -y, :] *= feather_rate
|
||||
|
||||
return (output,)
|
||||
|
||||
|
||||
class GrowMask:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
@@ -316,7 +316,7 @@ class GrowMask:
|
||||
"tapered_corners": ("BOOLEAN", {"default": True}),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
CATEGORY = "mask"
|
||||
|
||||
RETURN_TYPES = ("MASK",)
|
||||
|
||||
@@ -189,7 +189,7 @@ class ModelSamplingContinuousEDM:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"sampling": (["v_prediction", "edm_playground_v2.5", "eps"],),
|
||||
"sampling": (["v_prediction", "edm", "edm_playground_v2.5", "eps"],),
|
||||
"sigma_max": ("FLOAT", {"default": 120.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
|
||||
"sigma_min": ("FLOAT", {"default": 0.002, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
|
||||
}}
|
||||
@@ -206,6 +206,9 @@ class ModelSamplingContinuousEDM:
|
||||
sigma_data = 1.0
|
||||
if sampling == "eps":
|
||||
sampling_type = comfy.model_sampling.EPS
|
||||
elif sampling == "edm":
|
||||
sampling_type = comfy.model_sampling.EDM
|
||||
sigma_data = 0.5
|
||||
elif sampling == "v_prediction":
|
||||
sampling_type = comfy.model_sampling.V_PREDICTION
|
||||
elif sampling == "edm_playground_v2.5":
|
||||
|
||||
@@ -46,4 +46,4 @@ NODE_CLASS_MAPPINGS = {
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"Morphology": "ImageMorphology",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -64,7 +64,7 @@ class Guider_PerpNeg(comfy.samplers.CFGGuider):
|
||||
def predict_noise(self, x, timestep, model_options={}, seed=None):
|
||||
# in CFGGuider.predict_noise, we call sampling_function(), which uses cfg_function() to compute pos & neg
|
||||
# but we'd rather do a single batch of sampling pos, neg, and empty, so we call calc_cond_batch([pos,neg,empty]) directly
|
||||
|
||||
|
||||
positive_cond = self.conds.get("positive", None)
|
||||
negative_cond = self.conds.get("negative", None)
|
||||
empty_cond = self.conds.get("empty_negative_prompt", None)
|
||||
@@ -73,7 +73,7 @@ class Guider_PerpNeg(comfy.samplers.CFGGuider):
|
||||
comfy.samplers.calc_cond_batch(self.inner_model, [positive_cond, negative_cond, empty_cond], x, timestep, model_options)
|
||||
cfg_result = perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_empty, self.neg_scale, self.cfg)
|
||||
|
||||
# normally this would be done in cfg_function, but we skipped
|
||||
# normally this would be done in cfg_function, but we skipped
|
||||
# that for efficiency: we can compute the noise predictions in
|
||||
# a single call to calc_cond_batch() (rather than two)
|
||||
# so we replicate the hook here
|
||||
|
||||
@@ -40,7 +40,7 @@ class LatentRebatch:
|
||||
return slices, indexable[num * batch_size:]
|
||||
else:
|
||||
return slices, None
|
||||
|
||||
|
||||
@staticmethod
|
||||
def slice_batch(batch, num, batch_size):
|
||||
result = [LatentRebatch.get_slices(x, num, batch_size) for x in batch]
|
||||
@@ -81,7 +81,7 @@ class LatentRebatch:
|
||||
if current_batch[0].shape[0] > batch_size:
|
||||
num = current_batch[0].shape[0] // batch_size
|
||||
sliced, remainder = self.slice_batch(current_batch, num, batch_size)
|
||||
|
||||
|
||||
for i in range(num):
|
||||
output_list.append({'samples': sliced[0][i], 'noise_mask': sliced[1][i], 'batch_index': sliced[2][i]})
|
||||
|
||||
|
||||
@@ -40,9 +40,8 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
|
||||
return do_nothing, do_nothing
|
||||
|
||||
gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather
|
||||
|
||||
|
||||
with torch.no_grad():
|
||||
|
||||
hsy, wsx = h // sy, w // sx
|
||||
|
||||
# For each sy by sx kernel, randomly assign one token to be dst and the rest src
|
||||
@@ -50,7 +49,7 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
|
||||
rand_idx = torch.zeros(hsy, wsx, 1, device=metric.device, dtype=torch.int64)
|
||||
else:
|
||||
rand_idx = torch.randint(sy*sx, size=(hsy, wsx, 1), device=metric.device)
|
||||
|
||||
|
||||
# The image might not divide sx and sy, so we need to work on a view of the top left if the idx buffer instead
|
||||
idx_buffer_view = torch.zeros(hsy, wsx, sy*sx, device=metric.device, dtype=torch.int64)
|
||||
idx_buffer_view.scatter_(dim=2, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=rand_idx.dtype))
|
||||
@@ -99,7 +98,7 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
|
||||
def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
|
||||
src, dst = split(x)
|
||||
n, t1, c = src.shape
|
||||
|
||||
|
||||
unm = gather(src, dim=-2, index=unm_idx.expand(n, t1 - r, c))
|
||||
src = gather(src, dim=-2, index=src_idx.expand(n, r, c))
|
||||
dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)
|
||||
|
||||
@@ -30,4 +30,4 @@ NODE_CLASS_MAPPINGS = {
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"WebcamCapture": "Webcam Capture",
|
||||
}
|
||||
}
|
||||
|
||||
3
comfyui_version.py
Normal file
3
comfyui_version.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# This file is automatically generated by the build process when version is
|
||||
# updated in pyproject.toml.
|
||||
__version__ = "0.3.11"
|
||||
17
execution.py
17
execution.py
@@ -62,7 +62,7 @@ class IsChangedCache:
|
||||
class CacheSet:
|
||||
def __init__(self, lru_size=None):
|
||||
if lru_size is None or lru_size == 0:
|
||||
self.init_classic_cache()
|
||||
self.init_classic_cache()
|
||||
else:
|
||||
self.init_lru_cache(lru_size)
|
||||
self.all = [self.outputs, self.ui, self.objects]
|
||||
@@ -93,7 +93,7 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
|
||||
missing_keys = {}
|
||||
for x in inputs:
|
||||
input_data = inputs[x]
|
||||
input_type, input_category, input_info = get_input_info(class_def, x)
|
||||
input_type, input_category, input_info = get_input_info(class_def, x, valid_inputs)
|
||||
def mark_missing():
|
||||
missing_keys[x] = True
|
||||
input_data_all[x] = (None,)
|
||||
@@ -138,11 +138,11 @@ def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execut
|
||||
max_len_input = 0
|
||||
else:
|
||||
max_len_input = max(len(x) for x in input_data_all.values())
|
||||
|
||||
|
||||
# get a slice of inputs, repeat last input when list isn't long enough
|
||||
def slice_dict(d, i):
|
||||
return {k: v[i if len(v) > i else -1] for k, v in d.items()}
|
||||
|
||||
|
||||
results = []
|
||||
def process_inputs(inputs, index=None, input_is_list=False):
|
||||
if allow_interrupt:
|
||||
@@ -168,7 +168,7 @@ def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execut
|
||||
process_inputs(input_data_all, 0, input_is_list=input_is_list)
|
||||
elif max_len_input == 0:
|
||||
process_inputs({})
|
||||
else:
|
||||
else:
|
||||
for i in range(max_len_input):
|
||||
input_dict = slice_dict(input_data_all, i)
|
||||
process_inputs(input_dict, i)
|
||||
@@ -196,7 +196,6 @@ def merge_result_data(results, obj):
|
||||
return output
|
||||
|
||||
def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb=None):
|
||||
|
||||
results = []
|
||||
uis = []
|
||||
subgraph_results = []
|
||||
@@ -226,14 +225,14 @@ def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb
|
||||
r = tuple([r] * len(obj.RETURN_TYPES))
|
||||
results.append(r)
|
||||
subgraph_results.append((None, r))
|
||||
|
||||
|
||||
if has_subgraph:
|
||||
output = subgraph_results
|
||||
elif len(results) > 0:
|
||||
output = merge_result_data(results, obj)
|
||||
else:
|
||||
output = []
|
||||
ui = dict()
|
||||
ui = dict()
|
||||
if len(uis) > 0:
|
||||
ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()}
|
||||
return output, ui, has_subgraph
|
||||
@@ -556,7 +555,7 @@ def validate_inputs(prompt, item, validated):
|
||||
received_types = {}
|
||||
|
||||
for x in valid_inputs:
|
||||
type_input, input_category, extra_info = get_input_info(obj_class, x)
|
||||
type_input, input_category, extra_info = get_input_info(obj_class, x, class_inputs)
|
||||
assert extra_info is not None
|
||||
if x not in inputs:
|
||||
if input_category == "required":
|
||||
|
||||
@@ -58,7 +58,7 @@ class CacheHelper:
|
||||
if not self.active:
|
||||
return default
|
||||
return self.cache.get(key, default)
|
||||
|
||||
|
||||
def set(self, key: str, value: tuple[list[str], dict[str, float], float]) -> None:
|
||||
if self.active:
|
||||
self.cache[key] = value
|
||||
@@ -305,7 +305,7 @@ def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float]
|
||||
strong_cache = cache_helper.get(folder_name)
|
||||
if strong_cache is not None:
|
||||
return strong_cache
|
||||
|
||||
|
||||
global filename_list_cache
|
||||
global folder_names_and_paths
|
||||
folder_name = map_legacy(folder_name)
|
||||
|
||||
6
main.py
6
main.py
@@ -17,7 +17,7 @@ if __name__ == "__main__":
|
||||
os.environ['DO_NOT_TRACK'] = '1'
|
||||
|
||||
|
||||
setup_logger(log_level=args.verbose)
|
||||
setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
|
||||
|
||||
def apply_custom_paths():
|
||||
# extra model paths
|
||||
@@ -211,7 +211,9 @@ async def run(server_instance, address='', port=8188, verbose=True, call_on_star
|
||||
addresses = []
|
||||
for addr in address.split(","):
|
||||
addresses.append((addr, port))
|
||||
await asyncio.gather(server_instance.start_multi_address(addresses, call_on_start), server_instance.publish_loop())
|
||||
await asyncio.gather(
|
||||
server_instance.start_multi_address(addresses, call_on_start, verbose), server_instance.publish_loop()
|
||||
)
|
||||
|
||||
|
||||
def hijack_progress(server_instance):
|
||||
|
||||
70
nodes.py
70
nodes.py
@@ -51,7 +51,7 @@ class CLIPTextEncode(ComfyNodeABC):
|
||||
def INPUT_TYPES(s) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"text": (IO.STRING, {"multiline": True, "dynamicPrompts": True, "tooltip": "The text to be encoded."}),
|
||||
"text": (IO.STRING, {"multiline": True, "dynamicPrompts": True, "tooltip": "The text to be encoded."}),
|
||||
"clip": (IO.CLIP, {"tooltip": "The CLIP model used for encoding the text."})
|
||||
}
|
||||
}
|
||||
@@ -65,7 +65,7 @@ class CLIPTextEncode(ComfyNodeABC):
|
||||
def encode(self, clip, text):
|
||||
tokens = clip.tokenize(text)
|
||||
return (clip.encode_from_tokens_scheduled(tokens), )
|
||||
|
||||
|
||||
|
||||
class ConditioningCombine:
|
||||
@classmethod
|
||||
@@ -269,8 +269,8 @@ class VAEDecode:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"samples": ("LATENT", {"tooltip": "The latent to be decoded."}),
|
||||
"required": {
|
||||
"samples": ("LATENT", {"tooltip": "The latent to be decoded."}),
|
||||
"vae": ("VAE", {"tooltip": "The VAE model used for decoding the latent."})
|
||||
}
|
||||
}
|
||||
@@ -309,7 +309,7 @@ class VAEDecodeTiled:
|
||||
temporal_compression = vae.temporal_compression_decode()
|
||||
if temporal_compression is not None:
|
||||
temporal_size = max(2, temporal_size // temporal_compression)
|
||||
temporal_overlap = min(1, temporal_size // 2, temporal_overlap // temporal_compression)
|
||||
temporal_overlap = max(1, min(temporal_size // 2, temporal_overlap // temporal_compression))
|
||||
else:
|
||||
temporal_size = None
|
||||
temporal_overlap = None
|
||||
@@ -550,13 +550,13 @@ class CheckpointLoaderSimple:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"required": {
|
||||
"ckpt_name": (folder_paths.get_filename_list("checkpoints"), {"tooltip": "The name of the checkpoint (model) to load."}),
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
|
||||
OUTPUT_TOOLTIPS = ("The model used for denoising latents.",
|
||||
"The CLIP model used for encoding text prompts.",
|
||||
OUTPUT_TOOLTIPS = ("The model used for denoising latents.",
|
||||
"The CLIP model used for encoding text prompts.",
|
||||
"The VAE model used for encoding and decoding images to and from latent space.")
|
||||
FUNCTION = "load_checkpoint"
|
||||
|
||||
@@ -633,7 +633,7 @@ class LoraLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"required": {
|
||||
"model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}),
|
||||
"clip": ("CLIP", {"tooltip": "The CLIP model the LoRA will be applied to."}),
|
||||
"lora_name": (folder_paths.get_filename_list("loras"), {"tooltip": "The name of the LoRA."}),
|
||||
@@ -641,7 +641,7 @@ class LoraLoader:
|
||||
"strength_clip": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the CLIP model. This value can be negative."}),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
RETURN_TYPES = ("MODEL", "CLIP")
|
||||
OUTPUT_TOOLTIPS = ("The modified diffusion model.", "The modified CLIP model.")
|
||||
FUNCTION = "load_lora"
|
||||
@@ -912,16 +912,19 @@ class CLIPLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
|
||||
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart"], ),
|
||||
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos"], ),
|
||||
},
|
||||
"optional": {
|
||||
"device": (["default", "cpu"], {"advanced": True}),
|
||||
}}
|
||||
RETURN_TYPES = ("CLIP",)
|
||||
FUNCTION = "load_clip"
|
||||
|
||||
CATEGORY = "advanced/loaders"
|
||||
|
||||
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 / clip-g / clip-l\nstable_audio: t5\nmochi: t5"
|
||||
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 / clip-g / clip-l\nstable_audio: t5\nmochi: t5\ncosmos: old t5 xxl"
|
||||
|
||||
def load_clip(self, clip_name, type="stable_diffusion"):
|
||||
def load_clip(self, clip_name, type="stable_diffusion", device="default"):
|
||||
if type == "stable_cascade":
|
||||
clip_type = comfy.sd.CLIPType.STABLE_CASCADE
|
||||
elif type == "sd3":
|
||||
@@ -937,8 +940,12 @@ class CLIPLoader:
|
||||
else:
|
||||
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
|
||||
|
||||
model_options = {}
|
||||
if device == "cpu":
|
||||
model_options["load_device"] = model_options["offload_device"] = torch.device("cpu")
|
||||
|
||||
clip_path = folder_paths.get_full_path_or_raise("text_encoders", clip_name)
|
||||
clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
|
||||
clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options)
|
||||
return (clip,)
|
||||
|
||||
class DualCLIPLoader:
|
||||
@@ -947,6 +954,9 @@ class DualCLIPLoader:
|
||||
return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
|
||||
"clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
|
||||
"type": (["sdxl", "sd3", "flux", "hunyuan_video"], ),
|
||||
},
|
||||
"optional": {
|
||||
"device": (["default", "cpu"], {"advanced": True}),
|
||||
}}
|
||||
RETURN_TYPES = ("CLIP",)
|
||||
FUNCTION = "load_clip"
|
||||
@@ -955,7 +965,7 @@ class DualCLIPLoader:
|
||||
|
||||
DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5"
|
||||
|
||||
def load_clip(self, clip_name1, clip_name2, type):
|
||||
def load_clip(self, clip_name1, clip_name2, type, device="default"):
|
||||
clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
|
||||
clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
|
||||
if type == "sdxl":
|
||||
@@ -967,7 +977,11 @@ class DualCLIPLoader:
|
||||
elif type == "hunyuan_video":
|
||||
clip_type = comfy.sd.CLIPType.HUNYUAN_VIDEO
|
||||
|
||||
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
|
||||
model_options = {}
|
||||
if device == "cpu":
|
||||
model_options["load_device"] = model_options["offload_device"] = torch.device("cpu")
|
||||
|
||||
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options)
|
||||
return (clip,)
|
||||
|
||||
class CLIPVisionLoader:
|
||||
@@ -1162,7 +1176,7 @@ class EmptyLatentImage:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"required": {
|
||||
"width": ("INT", {"default": 512, "min": 16, "max": MAX_RESOLUTION, "step": 8, "tooltip": "The width of the latent images in pixels."}),
|
||||
"height": ("INT", {"default": 512, "min": 16, "max": MAX_RESOLUTION, "step": 8, "tooltip": "The height of the latent images in pixels."}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."})
|
||||
@@ -1211,7 +1225,7 @@ class LatentFromBatch:
|
||||
else:
|
||||
s["batch_index"] = samples["batch_index"][batch_index:batch_index + length]
|
||||
return (s,)
|
||||
|
||||
|
||||
class RepeatLatentBatch:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@@ -1226,7 +1240,7 @@ class RepeatLatentBatch:
|
||||
def repeat(self, samples, amount):
|
||||
s = samples.copy()
|
||||
s_in = samples["samples"]
|
||||
|
||||
|
||||
s["samples"] = s_in.repeat((amount, 1,1,1))
|
||||
if "noise_mask" in samples and samples["noise_mask"].shape[0] > 1:
|
||||
masks = samples["noise_mask"]
|
||||
@@ -1636,15 +1650,15 @@ class LoadImage:
|
||||
FUNCTION = "load_image"
|
||||
def load_image(self, image):
|
||||
image_path = folder_paths.get_annotated_filepath(image)
|
||||
|
||||
|
||||
img = node_helpers.pillow(Image.open, image_path)
|
||||
|
||||
|
||||
output_images = []
|
||||
output_masks = []
|
||||
w, h = None, None
|
||||
|
||||
excluded_formats = ['MPO']
|
||||
|
||||
|
||||
for i in ImageSequence.Iterator(img):
|
||||
i = node_helpers.pillow(ImageOps.exif_transpose, i)
|
||||
|
||||
@@ -1655,10 +1669,10 @@ class LoadImage:
|
||||
if len(output_images) == 0:
|
||||
w = image.size[0]
|
||||
h = image.size[1]
|
||||
|
||||
|
||||
if image.size[0] != w or image.size[1] != h:
|
||||
continue
|
||||
|
||||
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = torch.from_numpy(image)[None,]
|
||||
if 'A' in i.getbands():
|
||||
@@ -2047,6 +2061,9 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
|
||||
EXTENSION_WEB_DIRS = {}
|
||||
|
||||
# Dictionary of successfully loaded module names and associated directories.
|
||||
LOADED_MODULE_DIRS = {}
|
||||
|
||||
|
||||
def get_module_name(module_path: str) -> str:
|
||||
"""
|
||||
@@ -2088,6 +2105,8 @@ def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes
|
||||
sys.modules[module_name] = module
|
||||
module_spec.loader.exec_module(module)
|
||||
|
||||
LOADED_MODULE_DIRS[module_name] = os.path.abspath(module_dir)
|
||||
|
||||
if hasattr(module, "WEB_DIRECTORY") and getattr(module, "WEB_DIRECTORY") is not None:
|
||||
web_dir = os.path.abspath(os.path.join(module_dir, getattr(module, "WEB_DIRECTORY")))
|
||||
if os.path.isdir(web_dir):
|
||||
@@ -2206,6 +2225,7 @@ def init_builtin_extra_nodes():
|
||||
"nodes_lt.py",
|
||||
"nodes_hooks.py",
|
||||
"nodes_load_3d.py",
|
||||
"nodes_cosmos.py",
|
||||
]
|
||||
|
||||
import_failed = []
|
||||
@@ -2234,5 +2254,5 @@ def init_extra_nodes(init_custom_nodes=True):
|
||||
else:
|
||||
logging.warning("Please do a: pip install -r requirements.txt")
|
||||
logging.warning("")
|
||||
|
||||
|
||||
return import_failed
|
||||
|
||||
23
pyproject.toml
Normal file
23
pyproject.toml
Normal file
@@ -0,0 +1,23 @@
|
||||
[project]
|
||||
name = "ComfyUI"
|
||||
version = "0.3.11"
|
||||
readme = "README.md"
|
||||
license = { file = "LICENSE" }
|
||||
requires-python = ">=3.9"
|
||||
|
||||
[project.urls]
|
||||
homepage = "https://www.comfy.org/"
|
||||
repository = "https://github.com/comfyanonymous/ComfyUI"
|
||||
documentation = "https://docs.comfy.org/"
|
||||
|
||||
[tool.ruff]
|
||||
lint.select = [
|
||||
"S307", # suspicious-eval-usage
|
||||
"S102", # exec
|
||||
"T", # print-usage
|
||||
"W",
|
||||
# The "F" series in Ruff stands for "Pyflakes" rules, which catch various Python syntax errors and undefined names.
|
||||
# See all rules here: https://docs.astral.sh/ruff/rules/#pyflakes-f
|
||||
"F",
|
||||
]
|
||||
exclude = ["*.ipynb"]
|
||||
@@ -2,6 +2,7 @@ torch
|
||||
torchsde
|
||||
torchvision
|
||||
torchaudio
|
||||
numpy>=1.25.0
|
||||
einops
|
||||
transformers>=4.28.1
|
||||
tokenizers>=0.13.3
|
||||
|
||||
13
ruff.toml
13
ruff.toml
@@ -1,13 +0,0 @@
|
||||
# Disable all rules by default
|
||||
lint.ignore = ["ALL"]
|
||||
|
||||
# Enable specific rules
|
||||
lint.select = [
|
||||
"S307", # suspicious-eval-usage
|
||||
"T201", # print-usage
|
||||
# The "F" series in Ruff stands for "Pyflakes" rules, which catch various Python syntax errors and undefined names.
|
||||
# See all rules here: https://docs.astral.sh/ruff/rules/#pyflakes-f
|
||||
"F",
|
||||
]
|
||||
|
||||
exclude = ["*.ipynb"]
|
||||
32
server.py
32
server.py
@@ -27,9 +27,11 @@ from comfy.cli_args import args
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
import node_helpers
|
||||
from comfyui_version import __version__
|
||||
from app.frontend_management import FrontendManager
|
||||
from app.user_manager import UserManager
|
||||
from app.model_manager import ModelFileManager
|
||||
from app.custom_node_manager import CustomNodeManager
|
||||
from typing import Optional
|
||||
from api_server.routes.internal.internal_routes import InternalRoutes
|
||||
|
||||
@@ -43,21 +45,6 @@ async def send_socket_catch_exception(function, message):
|
||||
except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError, BrokenPipeError, ConnectionError) as err:
|
||||
logging.warning("send error: {}".format(err))
|
||||
|
||||
def get_comfyui_version():
|
||||
comfyui_version = "unknown"
|
||||
repo_path = os.path.dirname(os.path.realpath(__file__))
|
||||
try:
|
||||
import pygit2
|
||||
repo = pygit2.Repository(repo_path)
|
||||
comfyui_version = repo.describe(describe_strategy=pygit2.GIT_DESCRIBE_TAGS)
|
||||
except Exception:
|
||||
try:
|
||||
import subprocess
|
||||
comfyui_version = subprocess.check_output(["git", "describe", "--tags"], cwd=repo_path).decode('utf-8')
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to get ComfyUI version: {e}")
|
||||
return comfyui_version.strip()
|
||||
|
||||
@web.middleware
|
||||
async def cache_control(request: web.Request, handler):
|
||||
response: web.Response = await handler(request)
|
||||
@@ -153,6 +140,7 @@ class PromptServer():
|
||||
|
||||
self.user_manager = UserManager()
|
||||
self.model_file_manager = ModelFileManager()
|
||||
self.custom_node_manager = CustomNodeManager()
|
||||
self.internal_routes = InternalRoutes(self)
|
||||
self.supports = ["custom_nodes_from_web"]
|
||||
self.prompt_queue = None
|
||||
@@ -266,7 +254,7 @@ class PromptServer():
|
||||
|
||||
def compare_image_hash(filepath, image):
|
||||
hasher = node_helpers.hasher()
|
||||
|
||||
|
||||
# function to compare hashes of two images to see if it already exists, fix to #3465
|
||||
if os.path.exists(filepath):
|
||||
a = hasher()
|
||||
@@ -516,7 +504,7 @@ class PromptServer():
|
||||
"os": os.name,
|
||||
"ram_total": ram_total,
|
||||
"ram_free": ram_free,
|
||||
"comfyui_version": get_comfyui_version(),
|
||||
"comfyui_version": __version__,
|
||||
"python_version": sys.version,
|
||||
"pytorch_version": comfy.model_management.torch_version,
|
||||
"embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded",
|
||||
@@ -697,6 +685,7 @@ class PromptServer():
|
||||
def add_routes(self):
|
||||
self.user_manager.add_routes(self.routes)
|
||||
self.model_file_manager.add_routes(self.routes)
|
||||
self.custom_node_manager.add_routes(self.routes, self.app, nodes.LOADED_MODULE_DIRS.items())
|
||||
self.app.add_subapp('/internal', self.internal_routes.get_app())
|
||||
|
||||
# Prefix every route with /api for easier matching for delegation.
|
||||
@@ -713,6 +702,7 @@ class PromptServer():
|
||||
self.app.add_routes(api_routes)
|
||||
self.app.add_routes(self.routes)
|
||||
|
||||
# Add routes from web extensions.
|
||||
for name, dir in nodes.EXTENSION_WEB_DIRS.items():
|
||||
self.app.add_routes([web.static('/extensions/' + name, dir)])
|
||||
|
||||
@@ -803,7 +793,7 @@ class PromptServer():
|
||||
async def start(self, address, port, verbose=True, call_on_start=None):
|
||||
await self.start_multi_address([(address, port)], call_on_start=call_on_start)
|
||||
|
||||
async def start_multi_address(self, addresses, call_on_start=None):
|
||||
async def start_multi_address(self, addresses, call_on_start=None, verbose=True):
|
||||
runner = web.AppRunner(self.app, access_log=None)
|
||||
await runner.setup()
|
||||
ssl_ctx = None
|
||||
@@ -814,7 +804,8 @@ class PromptServer():
|
||||
keyfile=args.tls_keyfile)
|
||||
scheme = "https"
|
||||
|
||||
logging.info("Starting server\n")
|
||||
if verbose:
|
||||
logging.info("Starting server\n")
|
||||
for addr in addresses:
|
||||
address = addr[0]
|
||||
port = addr[1]
|
||||
@@ -830,7 +821,8 @@ class PromptServer():
|
||||
else:
|
||||
address_print = address
|
||||
|
||||
logging.info("To see the GUI go to: {}://{}:{}".format(scheme, address_print, port))
|
||||
if verbose:
|
||||
logging.info("To see the GUI go to: {}://{}:{}".format(scheme, address_print, port))
|
||||
|
||||
if call_on_start is not None:
|
||||
call_on_start(scheme, self.address, self.port)
|
||||
|
||||
40
tests-unit/app_test/custom_node_manager_test.py
Normal file
40
tests-unit/app_test/custom_node_manager_test.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
from unittest.mock import patch
|
||||
from app.custom_node_manager import CustomNodeManager
|
||||
|
||||
pytestmark = (
|
||||
pytest.mark.asyncio
|
||||
) # This applies the asyncio mark to all test functions in the module
|
||||
|
||||
@pytest.fixture
|
||||
def custom_node_manager():
|
||||
return CustomNodeManager()
|
||||
|
||||
@pytest.fixture
|
||||
def app(custom_node_manager):
|
||||
app = web.Application()
|
||||
routes = web.RouteTableDef()
|
||||
custom_node_manager.add_routes(routes, app, [("ComfyUI-TestExtension1", "ComfyUI-TestExtension1")])
|
||||
app.add_routes(routes)
|
||||
return app
|
||||
|
||||
async def test_get_workflow_templates(aiohttp_client, app, tmp_path):
|
||||
client = await aiohttp_client(app)
|
||||
# Setup temporary custom nodes file structure with 1 workflow file
|
||||
custom_nodes_dir = tmp_path / "custom_nodes"
|
||||
example_workflows_dir = custom_nodes_dir / "ComfyUI-TestExtension1" / "example_workflows"
|
||||
example_workflows_dir.mkdir(parents=True)
|
||||
template_file = example_workflows_dir / "workflow1.json"
|
||||
template_file.write_text('')
|
||||
|
||||
with patch('folder_paths.folder_names_and_paths', {
|
||||
'custom_nodes': ([str(custom_nodes_dir)], None)
|
||||
}):
|
||||
response = await client.get('/workflow_templates')
|
||||
assert response.status == 200
|
||||
workflows_dict = await response.json()
|
||||
assert isinstance(workflows_dict, dict)
|
||||
assert "ComfyUI-TestExtension1" in workflows_dict
|
||||
assert isinstance(workflows_dict["ComfyUI-TestExtension1"], list)
|
||||
assert workflows_dict["ComfyUI-TestExtension1"][0] == "workflow1"
|
||||
@@ -95,4 +95,4 @@ def test_get_save_image_path(temp_dir):
|
||||
assert filename == "test"
|
||||
assert counter == 1
|
||||
assert subfolder == ""
|
||||
assert filename_prefix == "test"
|
||||
assert filename_prefix == "test"
|
||||
|
||||
@@ -6,8 +6,8 @@ from folder_paths import filter_files_content_types
|
||||
@pytest.fixture(scope="module")
|
||||
def file_extensions():
|
||||
return {
|
||||
'image': ['gif', 'heif', 'ico', 'jpeg', 'jpg', 'png', 'pnm', 'ppm', 'svg', 'tiff', 'webp', 'xbm', 'xpm'],
|
||||
'audio': ['aif', 'aifc', 'aiff', 'au', 'flac', 'm4a', 'mp2', 'mp3', 'ogg', 'snd', 'wav'],
|
||||
'image': ['gif', 'heif', 'ico', 'jpeg', 'jpg', 'png', 'pnm', 'ppm', 'svg', 'tiff', 'webp', 'xbm', 'xpm'],
|
||||
'audio': ['aif', 'aifc', 'aiff', 'au', 'flac', 'm4a', 'mp2', 'mp3', 'ogg', 'snd', 'wav'],
|
||||
'video': ['avi', 'm2v', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ogv', 'qt', 'webm', 'wmv']
|
||||
}
|
||||
|
||||
@@ -49,4 +49,4 @@ def test_handles_no_extension():
|
||||
|
||||
def test_handles_no_files():
|
||||
files = []
|
||||
assert filter_files_content_types(files, ["image", "audio", "video"]) == []
|
||||
assert filter_files_content_types(files, ["image", "audio", "video"]) == []
|
||||
|
||||
@@ -20,9 +20,9 @@ def test_list_files_valid_directory(file_service, mock_file_system_ops):
|
||||
{"name": "file1.txt", "path": "file1.txt", "type": "file", "size": 100},
|
||||
{"name": "dir1", "path": "dir1", "type": "directory"}
|
||||
]
|
||||
|
||||
|
||||
result = file_service.list_files("models")
|
||||
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0]["name"] == "file1.txt"
|
||||
assert result[1]["name"] == "dir1"
|
||||
@@ -35,9 +35,9 @@ def test_list_files_invalid_directory(file_service):
|
||||
|
||||
def test_list_files_empty_directory(file_service, mock_file_system_ops):
|
||||
mock_file_system_ops.walk_directory.return_value = []
|
||||
|
||||
|
||||
result = file_service.list_files("models")
|
||||
|
||||
|
||||
assert len(result) == 0
|
||||
mock_file_system_ops.walk_directory.assert_called_once_with("/path/to/models")
|
||||
|
||||
@@ -46,9 +46,9 @@ def test_list_files_all_allowed_directories(file_service, mock_file_system_ops,
|
||||
mock_file_system_ops.walk_directory.return_value = [
|
||||
{"name": f"file_{directory_key}.txt", "path": f"file_{directory_key}.txt", "type": "file", "size": 100}
|
||||
]
|
||||
|
||||
|
||||
result = file_service.list_files(directory_key)
|
||||
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["name"] == f"file_{directory_key}.txt"
|
||||
mock_file_system_ops.walk_directory.assert_called_once_with(f"/path/to/{directory_key}")
|
||||
mock_file_system_ops.walk_directory.assert_called_once_with(f"/path/to/{directory_key}")
|
||||
|
||||
@@ -16,18 +16,18 @@ def temp_directory(tmp_path):
|
||||
|
||||
def test_walk_directory(temp_directory):
|
||||
result: List[FileSystemItem] = FileSystemOperations.walk_directory(str(temp_directory))
|
||||
|
||||
|
||||
assert len(result) == 5 # 2 directories and 3 files
|
||||
|
||||
|
||||
files = [item for item in result if item['type'] == 'file']
|
||||
dirs = [item for item in result if item['type'] == 'directory']
|
||||
|
||||
|
||||
assert len(files) == 3
|
||||
assert len(dirs) == 2
|
||||
|
||||
|
||||
file_names = {file['name'] for file in files}
|
||||
assert file_names == {'file1.txt', 'file2.txt', 'file3.txt'}
|
||||
|
||||
|
||||
dir_names = {dir['name'] for dir in dirs}
|
||||
assert dir_names == {'dir1', 'dir2'}
|
||||
|
||||
|
||||
@@ -1,11 +1,22 @@
|
||||
import pytest
|
||||
import yaml
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import Mock, patch, mock_open
|
||||
|
||||
from utils.extra_config import load_extra_path_config
|
||||
import folder_paths
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def clear_folder_paths():
|
||||
# Clear the global dictionary before each test to ensure isolation
|
||||
original = folder_paths.folder_names_and_paths.copy()
|
||||
folder_paths.folder_names_and_paths.clear()
|
||||
yield
|
||||
folder_paths.folder_names_and_paths = original
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_yaml_content():
|
||||
return {
|
||||
@@ -15,10 +26,12 @@ def mock_yaml_content():
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_expanded_home():
|
||||
return '/home/user'
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def yaml_config_with_appdata():
|
||||
return """
|
||||
@@ -27,20 +40,33 @@ def yaml_config_with_appdata():
|
||||
checkpoints: 'models/checkpoints'
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_yaml_content_appdata(yaml_config_with_appdata):
|
||||
return yaml.safe_load(yaml_config_with_appdata)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_expandvars_appdata():
|
||||
mock = Mock()
|
||||
mock.side_effect = lambda path: path.replace('%APPDATA%', 'C:/Users/TestUser/AppData/Roaming')
|
||||
|
||||
def expandvars(path):
|
||||
if '%APPDATA%' in path:
|
||||
if sys.platform == 'win32':
|
||||
return path.replace('%APPDATA%', 'C:/Users/TestUser/AppData/Roaming')
|
||||
else:
|
||||
return path.replace('%APPDATA%', '/Users/TestUser/AppData/Roaming')
|
||||
return path
|
||||
|
||||
mock.side_effect = expandvars
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_add_model_folder_path():
|
||||
return Mock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_expanduser(mock_expanded_home):
|
||||
def _expanduser(path):
|
||||
@@ -49,16 +75,18 @@ def mock_expanduser(mock_expanded_home):
|
||||
return path
|
||||
return _expanduser
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_yaml_safe_load(mock_yaml_content):
|
||||
return Mock(return_value=mock_yaml_content)
|
||||
|
||||
|
||||
@patch('builtins.open', new_callable=mock_open, read_data="dummy file content")
|
||||
def test_load_extra_model_paths_expands_userpath(
|
||||
mock_file,
|
||||
monkeypatch,
|
||||
mock_add_model_folder_path,
|
||||
mock_expanduser,
|
||||
mock_add_model_folder_path,
|
||||
mock_expanduser,
|
||||
mock_yaml_safe_load,
|
||||
mock_expanded_home
|
||||
):
|
||||
@@ -75,10 +103,10 @@ def test_load_extra_model_paths_expands_userpath(
|
||||
]
|
||||
|
||||
assert mock_add_model_folder_path.call_count == len(expected_calls)
|
||||
|
||||
|
||||
# Check if add_model_folder_path was called with the correct arguments
|
||||
for actual_call, expected_call in zip(mock_add_model_folder_path.call_args_list, expected_calls):
|
||||
assert actual_call.args[0] == expected_call[0]
|
||||
assert actual_call.args[0] == expected_call[0]
|
||||
assert os.path.normpath(actual_call.args[1]) == os.path.normpath(expected_call[1]) # Normalize and check the path to check on multiple OS.
|
||||
assert actual_call.args[2] == expected_call[2]
|
||||
|
||||
@@ -88,6 +116,7 @@ def test_load_extra_model_paths_expands_userpath(
|
||||
# Check if open was called with the correct file path
|
||||
mock_file.assert_called_once_with(dummy_yaml_file_name, 'r')
|
||||
|
||||
|
||||
@patch('builtins.open', new_callable=mock_open)
|
||||
def test_load_extra_model_paths_expands_appdata(
|
||||
mock_file,
|
||||
@@ -97,7 +126,7 @@ def test_load_extra_model_paths_expands_appdata(
|
||||
yaml_config_with_appdata,
|
||||
mock_yaml_content_appdata
|
||||
):
|
||||
# Set the mock_file to return yaml with appdata as a variable
|
||||
# Set the mock_file to return yaml with appdata as a variable
|
||||
mock_file.return_value.read.return_value = yaml_config_with_appdata
|
||||
|
||||
# Attach mocks
|
||||
@@ -111,16 +140,164 @@ def test_load_extra_model_paths_expands_appdata(
|
||||
dummy_yaml_file_name = 'dummy_path.yaml'
|
||||
load_extra_path_config(dummy_yaml_file_name)
|
||||
|
||||
expected_base_path = 'C:/Users/TestUser/AppData/Roaming/ComfyUI'
|
||||
if sys.platform == "win32":
|
||||
expected_base_path = 'C:/Users/TestUser/AppData/Roaming/ComfyUI'
|
||||
else:
|
||||
expected_base_path = '/Users/TestUser/AppData/Roaming/ComfyUI'
|
||||
expected_calls = [
|
||||
('checkpoints', os.path.join(expected_base_path, 'models/checkpoints'), False),
|
||||
]
|
||||
|
||||
assert mock_add_model_folder_path.call_count == len(expected_calls)
|
||||
|
||||
|
||||
# Check the base path variable was expanded
|
||||
for actual_call, expected_call in zip(mock_add_model_folder_path.call_args_list, expected_calls):
|
||||
assert actual_call.args == expected_call
|
||||
|
||||
# Verify that expandvars was called
|
||||
assert mock_expandvars_appdata.called
|
||||
|
||||
|
||||
@patch("builtins.open", new_callable=mock_open, read_data="dummy yaml content")
|
||||
@patch("yaml.safe_load")
|
||||
def test_load_extra_path_config_relative_base_path(
|
||||
mock_yaml_load, _mock_file, clear_folder_paths, monkeypatch, tmp_path
|
||||
):
|
||||
"""
|
||||
Test that when 'base_path' is a relative path in the YAML, it is joined to the YAML file directory, and then
|
||||
the items in the config are correctly converted to absolute paths.
|
||||
"""
|
||||
sub_folder = "./my_rel_base"
|
||||
config_data = {
|
||||
"some_model_folder": {
|
||||
"base_path": sub_folder,
|
||||
"is_default": True,
|
||||
"checkpoints": "checkpoints",
|
||||
"some_key": "some_value"
|
||||
}
|
||||
}
|
||||
mock_yaml_load.return_value = config_data
|
||||
|
||||
dummy_yaml_name = "dummy_file.yaml"
|
||||
|
||||
def fake_abspath(path):
|
||||
if path == dummy_yaml_name:
|
||||
# If it's the YAML path, treat it like it lives in tmp_path
|
||||
return os.path.join(str(tmp_path), dummy_yaml_name)
|
||||
return os.path.join(str(tmp_path), path) # Otherwise, do a normal join relative to tmp_path
|
||||
|
||||
def fake_dirname(path):
|
||||
# We expect path to be the result of fake_abspath(dummy_yaml_name)
|
||||
if path.endswith(dummy_yaml_name):
|
||||
return str(tmp_path)
|
||||
return os.path.dirname(path)
|
||||
|
||||
monkeypatch.setattr(os.path, "abspath", fake_abspath)
|
||||
monkeypatch.setattr(os.path, "dirname", fake_dirname)
|
||||
|
||||
load_extra_path_config(dummy_yaml_name)
|
||||
|
||||
expected_checkpoints = os.path.abspath(os.path.join(str(tmp_path), sub_folder, "checkpoints"))
|
||||
expected_some_value = os.path.abspath(os.path.join(str(tmp_path), sub_folder, "some_value"))
|
||||
|
||||
actual_paths = folder_paths.folder_names_and_paths["checkpoints"][0]
|
||||
assert len(actual_paths) == 1, "Should have one path added for 'checkpoints'."
|
||||
assert actual_paths[0] == expected_checkpoints
|
||||
|
||||
actual_paths = folder_paths.folder_names_and_paths["some_key"][0]
|
||||
assert len(actual_paths) == 1, "Should have one path added for 'some_key'."
|
||||
assert actual_paths[0] == expected_some_value
|
||||
|
||||
|
||||
@patch("builtins.open", new_callable=mock_open, read_data="dummy yaml content")
|
||||
@patch("yaml.safe_load")
|
||||
def test_load_extra_path_config_absolute_base_path(
|
||||
mock_yaml_load, _mock_file, clear_folder_paths, monkeypatch, tmp_path
|
||||
):
|
||||
"""
|
||||
Test that when 'base_path' is an absolute path, each subdirectory is joined with that absolute path,
|
||||
rather than being relative to the YAML's directory.
|
||||
"""
|
||||
abs_base = os.path.join(str(tmp_path), "abs_base")
|
||||
config_data = {
|
||||
"some_absolute_folder": {
|
||||
"base_path": abs_base, # <-- absolute
|
||||
"is_default": True,
|
||||
"loras": "loras_folder",
|
||||
"embeddings": "embeddings_folder"
|
||||
}
|
||||
}
|
||||
mock_yaml_load.return_value = config_data
|
||||
|
||||
dummy_yaml_name = "dummy_abs.yaml"
|
||||
|
||||
def fake_abspath(path):
|
||||
if path == dummy_yaml_name:
|
||||
# If it's the YAML path, treat it like it is in tmp_path
|
||||
return os.path.join(str(tmp_path), dummy_yaml_name)
|
||||
return path # For absolute base, we just return path directly
|
||||
|
||||
def fake_dirname(path):
|
||||
return str(tmp_path) if path.endswith(dummy_yaml_name) else os.path.dirname(path)
|
||||
|
||||
monkeypatch.setattr(os.path, "abspath", fake_abspath)
|
||||
monkeypatch.setattr(os.path, "dirname", fake_dirname)
|
||||
|
||||
load_extra_path_config(dummy_yaml_name)
|
||||
|
||||
# Expect the final paths to be <abs_base>/loras_folder and <abs_base>/embeddings_folder
|
||||
expected_loras = os.path.join(abs_base, "loras_folder")
|
||||
expected_embeddings = os.path.join(abs_base, "embeddings_folder")
|
||||
|
||||
actual_loras = folder_paths.folder_names_and_paths["loras"][0]
|
||||
assert len(actual_loras) == 1, "Should have one path for 'loras'."
|
||||
assert actual_loras[0] == os.path.abspath(expected_loras)
|
||||
|
||||
actual_embeddings = folder_paths.folder_names_and_paths["embeddings"][0]
|
||||
assert len(actual_embeddings) == 1, "Should have one path for 'embeddings'."
|
||||
assert actual_embeddings[0] == os.path.abspath(expected_embeddings)
|
||||
|
||||
|
||||
@patch("builtins.open", new_callable=mock_open, read_data="dummy yaml content")
|
||||
@patch("yaml.safe_load")
|
||||
def test_load_extra_path_config_no_base_path(
|
||||
mock_yaml_load, _mock_file, clear_folder_paths, monkeypatch, tmp_path
|
||||
):
|
||||
"""
|
||||
Test that if 'base_path' is not present, each path is joined
|
||||
with the directory of the YAML file (unless it's already absolute).
|
||||
"""
|
||||
config_data = {
|
||||
"some_folder_without_base": {
|
||||
"is_default": True,
|
||||
"text_encoders": "clip",
|
||||
"diffusion_models": "unet"
|
||||
}
|
||||
}
|
||||
mock_yaml_load.return_value = config_data
|
||||
|
||||
dummy_yaml_name = "dummy_no_base.yaml"
|
||||
|
||||
def fake_abspath(path):
|
||||
if path == dummy_yaml_name:
|
||||
return os.path.join(str(tmp_path), dummy_yaml_name)
|
||||
return os.path.join(str(tmp_path), path)
|
||||
|
||||
def fake_dirname(path):
|
||||
return str(tmp_path) if path.endswith(dummy_yaml_name) else os.path.dirname(path)
|
||||
|
||||
monkeypatch.setattr(os.path, "abspath", fake_abspath)
|
||||
monkeypatch.setattr(os.path, "dirname", fake_dirname)
|
||||
|
||||
load_extra_path_config(dummy_yaml_name)
|
||||
|
||||
expected_clip = os.path.join(str(tmp_path), "clip")
|
||||
expected_unet = os.path.join(str(tmp_path), "unet")
|
||||
|
||||
actual_text_encoders = folder_paths.folder_names_and_paths["text_encoders"][0]
|
||||
assert len(actual_text_encoders) == 1, "Should have one path for 'text_encoders'."
|
||||
assert actual_text_encoders[0] == os.path.abspath(expected_clip)
|
||||
|
||||
actual_diffusion = folder_paths.folder_names_and_paths["diffusion_models"][0]
|
||||
assert len(actual_diffusion) == 1, "Should have one path for 'diffusion_models'."
|
||||
assert actual_diffusion[0] == os.path.abspath(expected_unet)
|
||||
|
||||
@@ -22,7 +22,7 @@ def ssim_score(img0: np.ndarray, img1: np.ndarray) -> Tuple[float, np.ndarray]:
|
||||
# rescale the difference image to 0-255 range
|
||||
diff = (diff * 255).astype("uint8")
|
||||
return score, diff
|
||||
|
||||
|
||||
# Metrics must return a tuple of (score, diff_image)
|
||||
METRICS = {"ssim": ssim_score}
|
||||
METRICS_PASS_THRESHOLD = {"ssim": 0.95}
|
||||
@@ -32,7 +32,7 @@ class TestCompareImageMetrics:
|
||||
@fixture(scope="class")
|
||||
def test_file_names(self, args_pytest):
|
||||
test_dir = args_pytest['test_dir']
|
||||
fnames = self.gather_file_basenames(test_dir)
|
||||
fnames = self.gather_file_basenames(test_dir)
|
||||
yield fnames
|
||||
del fnames
|
||||
|
||||
@@ -64,13 +64,13 @@ class TestCompareImageMetrics:
|
||||
image_list = [[Image.open(file) for file in files] for files in image_file_list]
|
||||
grid = self.image_grid(image_list)
|
||||
grid.save(os.path.join(grid_dir, f"{metric_dir}_{score:.3f}_{file}"))
|
||||
|
||||
|
||||
# Tests run for each baseline file name
|
||||
@fixture()
|
||||
def fname(self, baseline_fname):
|
||||
yield baseline_fname
|
||||
del baseline_fname
|
||||
|
||||
|
||||
def test_directories_not_empty(self, args_pytest):
|
||||
baseline_dir = args_pytest['baseline_dir']
|
||||
test_dir = args_pytest['test_dir']
|
||||
@@ -84,7 +84,7 @@ class TestCompareImageMetrics:
|
||||
file_match = self.find_file_match(baseline_file_path, file_paths)
|
||||
assert file_match is not None, f"Could not find a file in {args_pytest['test_dir']} with matching metadata to {baseline_file_path}"
|
||||
|
||||
# For a baseline image file, finds the corresponding file name in test_dir and
|
||||
# For a baseline image file, finds the corresponding file name in test_dir and
|
||||
# compares the images using the metrics in METRICS
|
||||
@pytest.mark.parametrize("metric", METRICS.keys())
|
||||
def test_pipeline_compare(
|
||||
@@ -98,7 +98,7 @@ class TestCompareImageMetrics:
|
||||
test_dir = args_pytest['test_dir']
|
||||
metrics_output_file = args_pytest['metrics_file']
|
||||
img_output_dir = args_pytest['img_output_dir']
|
||||
|
||||
|
||||
baseline_file_path = os.path.join(baseline_dir, fname)
|
||||
|
||||
# Find file match
|
||||
@@ -108,7 +108,7 @@ class TestCompareImageMetrics:
|
||||
# Run metrics
|
||||
sample_baseline = self.read_img(baseline_file_path)
|
||||
sample_secondary = self.read_img(test_file)
|
||||
|
||||
|
||||
score, metric_img = METRICS[metric](sample_baseline, sample_secondary)
|
||||
metric_status = score > METRICS_PASS_THRESHOLD[metric]
|
||||
|
||||
@@ -140,7 +140,7 @@ class TestCompareImageMetrics:
|
||||
|
||||
w, h = img_list[0][0].size
|
||||
grid = Image.new('RGB', size=(cols*w, rows*h))
|
||||
|
||||
|
||||
for i, row in enumerate(img_list):
|
||||
for j, img in enumerate(row):
|
||||
grid.paste(img, box=(j*w, i*h))
|
||||
@@ -170,7 +170,7 @@ class TestCompareImageMetrics:
|
||||
img = Image.open(fname)
|
||||
img.load()
|
||||
return img.info['prompt']
|
||||
|
||||
|
||||
def find_file_match(self, baseline_file: str, file_paths: List[str]):
|
||||
# Find a file in file_paths with matching metadata to baseline_file
|
||||
baseline_prompt = self.read_file_prompt(baseline_file)
|
||||
@@ -181,7 +181,7 @@ class TestCompareImageMetrics:
|
||||
|
||||
# Find file match
|
||||
# Reorder test_file_names so that the file with matching name is first
|
||||
# This is an optimization because matching file names are more likely
|
||||
# This is an optimization because matching file names are more likely
|
||||
# to have matching metadata if they were generated with the same script
|
||||
basename = os.path.basename(baseline_file)
|
||||
file_path_basenames = [os.path.basename(f) for f in file_paths]
|
||||
@@ -192,4 +192,4 @@ class TestCompareImageMetrics:
|
||||
for f in file_paths:
|
||||
test_file_prompt = self.read_file_prompt(f)
|
||||
if baseline_prompt == test_file_prompt:
|
||||
return f
|
||||
return f
|
||||
|
||||
@@ -21,7 +21,7 @@ def args_pytest(pytestconfig):
|
||||
|
||||
def pytest_collection_modifyitems(items):
|
||||
# Modifies items so tests run in the correct order
|
||||
|
||||
|
||||
LAST_TESTS = ['test_quality']
|
||||
|
||||
# Move the last items to the end
|
||||
|
||||
@@ -40,8 +40,8 @@ class ComfyClient:
|
||||
def __init__(self):
|
||||
self.test_name = ""
|
||||
|
||||
def connect(self,
|
||||
listen:str = '127.0.0.1',
|
||||
def connect(self,
|
||||
listen:str = '127.0.0.1',
|
||||
port:Union[str,int] = 8188,
|
||||
client_id: str = str(uuid.uuid4())
|
||||
):
|
||||
@@ -125,7 +125,7 @@ class TestExecution:
|
||||
def _server(self, args_pytest, request):
|
||||
# Start server
|
||||
pargs = [
|
||||
'python','main.py',
|
||||
'python','main.py',
|
||||
'--output-directory', args_pytest["output_dir"],
|
||||
'--listen', args_pytest["listen"],
|
||||
'--port', str(args_pytest["port"]),
|
||||
@@ -368,7 +368,7 @@ class TestExecution:
|
||||
g.node("SaveImage", images=mix1.out(0))
|
||||
g.node("SaveImage", images=mix2.out(0))
|
||||
g.remove_node("removeme")
|
||||
|
||||
|
||||
client.run(g)
|
||||
|
||||
# Add back in the missing node to make sure the error doesn't break the server
|
||||
@@ -512,7 +512,7 @@ class TestExecution:
|
||||
int_list = g.node("TestMakeListNode", value1=int1.out(0), value2=int2.out(0), value3=int3.out(0))
|
||||
compare = g.node("TestIntConditions", a=int_list.out(0), b=2, operation="==")
|
||||
blocker = g.node("TestExecutionBlocker", input=image_list.out(0), block=compare.out(0), verbose=False)
|
||||
|
||||
|
||||
list_output = g.node("TestMakeListNode", value1=blocker.out(0))
|
||||
output = g.node("PreviewImage", images=list_output.out(0))
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ These tests generate and save images through a range of parameters
|
||||
"""
|
||||
|
||||
class ComfyGraph:
|
||||
def __init__(self,
|
||||
def __init__(self,
|
||||
graph: dict,
|
||||
sampler_nodes: list[str],
|
||||
):
|
||||
@@ -43,12 +43,12 @@ class ComfyGraph:
|
||||
# sets the sampler name for the sampler nodes (eg. base and refiner)
|
||||
for node in self.sampler_nodes:
|
||||
self.graph[node]['inputs']['sampler_name'] = sampler_name
|
||||
|
||||
|
||||
def set_scheduler(self, scheduler:str):
|
||||
# sets the sampler name for the sampler nodes (eg. base and refiner)
|
||||
for node in self.sampler_nodes:
|
||||
self.graph[node]['inputs']['scheduler'] = scheduler
|
||||
|
||||
|
||||
def set_filename_prefix(self, prefix:str):
|
||||
# sets the filename prefix for the save nodes
|
||||
for node in self.graph:
|
||||
@@ -59,8 +59,8 @@ class ComfyGraph:
|
||||
class ComfyClient:
|
||||
# From examples/websockets_api_example.py
|
||||
|
||||
def connect(self,
|
||||
listen:str = '127.0.0.1',
|
||||
def connect(self,
|
||||
listen:str = '127.0.0.1',
|
||||
port:Union[str,int] = 8188,
|
||||
client_id: str = str(uuid.uuid4())
|
||||
):
|
||||
@@ -152,7 +152,7 @@ class TestInference:
|
||||
def _server(self, args_pytest):
|
||||
# Start server
|
||||
p = subprocess.Popen([
|
||||
'python','main.py',
|
||||
'python','main.py',
|
||||
'--output-directory', args_pytest["output_dir"],
|
||||
'--listen', args_pytest["listen"],
|
||||
'--port', str(args_pytest["port"]),
|
||||
@@ -185,7 +185,7 @@ class TestInference:
|
||||
@fixture(scope="class", params=comfy_graph_list, ids=comfy_graph_ids, autouse=True)
|
||||
def _client_graph(self, request, args_pytest, _server) -> (ComfyClient, ComfyGraph):
|
||||
comfy_graph = request.param
|
||||
|
||||
|
||||
# Start client
|
||||
comfy_client = self.start_client(args_pytest["listen"], args_pytest["port"])
|
||||
|
||||
@@ -201,7 +201,7 @@ class TestInference:
|
||||
def client(self, _client_graph):
|
||||
client = _client_graph[0]
|
||||
yield client
|
||||
|
||||
|
||||
@fixture
|
||||
def comfy_graph(self, _client_graph):
|
||||
# avoid mutating the graph
|
||||
|
||||
@@ -87,7 +87,7 @@ class TestCustomIsChanged:
|
||||
|
||||
def custom_is_changed(self, image, should_change=False):
|
||||
return (image,)
|
||||
|
||||
|
||||
@classmethod
|
||||
def IS_CHANGED(cls, should_change=False, *args, **kwargs):
|
||||
if should_change:
|
||||
@@ -112,7 +112,7 @@ class TestIsChangedWithConstants:
|
||||
|
||||
def custom_is_changed(self, image, value):
|
||||
return (image * value,)
|
||||
|
||||
|
||||
@classmethod
|
||||
def IS_CHANGED(cls, image, value):
|
||||
if image is None:
|
||||
|
||||
@@ -145,7 +145,7 @@ class TestAccumulationGetLengthNode:
|
||||
|
||||
def accumlength(self, accumulation):
|
||||
return (len(accumulation['accum']),)
|
||||
|
||||
|
||||
@VariantSupport()
|
||||
class TestAccumulationGetItemNode:
|
||||
def __init__(self):
|
||||
@@ -168,7 +168,7 @@ class TestAccumulationGetItemNode:
|
||||
|
||||
def get_item(self, accumulation, index):
|
||||
return (accumulation['accum'][index],)
|
||||
|
||||
|
||||
@VariantSupport()
|
||||
class TestAccumulationSetItemNode:
|
||||
def __init__(self):
|
||||
|
||||
@@ -6,6 +6,7 @@ import logging
|
||||
def load_extra_path_config(yaml_path):
|
||||
with open(yaml_path, 'r') as stream:
|
||||
config = yaml.safe_load(stream)
|
||||
yaml_dir = os.path.dirname(os.path.abspath(yaml_path))
|
||||
for c in config:
|
||||
conf = config[c]
|
||||
if conf is None:
|
||||
@@ -14,6 +15,8 @@ def load_extra_path_config(yaml_path):
|
||||
if "base_path" in conf:
|
||||
base_path = conf.pop("base_path")
|
||||
base_path = os.path.expandvars(os.path.expanduser(base_path))
|
||||
if not os.path.isabs(base_path):
|
||||
base_path = os.path.abspath(os.path.join(yaml_dir, base_path))
|
||||
is_default = False
|
||||
if "is_default" in conf:
|
||||
is_default = conf.pop("is_default")
|
||||
@@ -22,10 +25,9 @@ def load_extra_path_config(yaml_path):
|
||||
if len(y) == 0:
|
||||
continue
|
||||
full_path = y
|
||||
if base_path is not None:
|
||||
if base_path:
|
||||
full_path = os.path.join(base_path, full_path)
|
||||
elif not os.path.isabs(full_path):
|
||||
yaml_dir = os.path.dirname(os.path.abspath(yaml_path))
|
||||
full_path = os.path.abspath(os.path.join(yaml_dir, y))
|
||||
logging.info("Adding extra search path {} {}".format(x, full_path))
|
||||
folder_paths.add_model_folder_path(x, full_path, is_default)
|
||||
|
||||
23
web/assets/BaseViewTemplate-BNGF4K22.js
generated
vendored
Normal file
23
web/assets/BaseViewTemplate-BNGF4K22.js
generated
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
import { d as defineComponent, o as openBlock, f as createElementBlock, J as renderSlot, T as normalizeClass } from "./index-DjNHn37O.js";
|
||||
const _sfc_main = /* @__PURE__ */ defineComponent({
|
||||
__name: "BaseViewTemplate",
|
||||
props: {
|
||||
dark: { type: Boolean, default: false }
|
||||
},
|
||||
setup(__props) {
|
||||
const props = __props;
|
||||
return (_ctx, _cache) => {
|
||||
return openBlock(), createElementBlock("div", {
|
||||
class: normalizeClass(["font-sans w-screen h-screen flex items-center justify-center pointer-events-auto overflow-auto", [
|
||||
props.dark ? "text-neutral-300 bg-neutral-900 dark-theme" : "text-neutral-900 bg-neutral-300"
|
||||
]])
|
||||
}, [
|
||||
renderSlot(_ctx.$slots, "default")
|
||||
], 2);
|
||||
};
|
||||
}
|
||||
});
|
||||
export {
|
||||
_sfc_main as _
|
||||
};
|
||||
//# sourceMappingURL=BaseViewTemplate-BNGF4K22.js.map
|
||||
1
web/assets/DownloadGitView-B3f7KHY3.js.map
generated
vendored
1
web/assets/DownloadGitView-B3f7KHY3.js.map
generated
vendored
@@ -1 +0,0 @@
|
||||
{"version":3,"file":"DownloadGitView-B3f7KHY3.js","sources":["../../src/views/DownloadGitView.vue"],"sourcesContent":["<template>\n <div\n class=\"font-sans w-screen h-screen mx-0 grid place-items-center justify-center items-center text-neutral-900 bg-neutral-300 pointer-events-auto\"\n >\n <div\n class=\"col-start-1 h-screen row-start-1 place-content-center mx-auto overflow-y-auto\"\n >\n <div\n class=\"max-w-screen-sm flex flex-col gap-8 p-8 bg-[url('/assets/images/Git-Logo-White.svg')] bg-no-repeat bg-right-top bg-origin-padding\"\n >\n <!-- Header -->\n <h1 class=\"mt-24 text-4xl font-bold text-red-500\">\n {{ $t('downloadGit.title') }}\n </h1>\n\n <!-- Message -->\n <div class=\"space-y-4\">\n <p class=\"text-xl\">\n {{ $t('downloadGit.message') }}\n </p>\n <p class=\"text-xl\">\n {{ $t('downloadGit.instructions') }}\n </p>\n <p class=\"text-m\">\n {{ $t('downloadGit.warning') }}\n </p>\n </div>\n\n <!-- Actions -->\n <div class=\"flex gap-4 flex-row-reverse\">\n <Button\n :label=\"$t('downloadGit.gitWebsite')\"\n icon=\"pi pi-external-link\"\n icon-pos=\"right\"\n @click=\"openGitDownloads\"\n severity=\"primary\"\n />\n <Button\n :label=\"$t('downloadGit.skip')\"\n icon=\"pi pi-exclamation-triangle\"\n @click=\"skipGit\"\n severity=\"secondary\"\n />\n </div>\n </div>\n </div>\n </div>\n</template>\n\n<script setup lang=\"ts\">\nimport Button from 'primevue/button'\nimport { useRouter } from 'vue-router'\n\nconst openGitDownloads = () => {\n window.open('https://git-scm.com/downloads/', '_blank')\n}\n\nconst skipGit = () => {\n console.warn('pushing')\n const router = useRouter()\n router.push('install')\n}\n</script>\n"],"names":[],"mappings":";;;;;;;;;;;;;;;AAqDA,UAAM,mBAAmB,6BAAM;AACtB,aAAA,KAAK,kCAAkC,QAAQ;AAAA,IAAA,GAD/B;AAIzB,UAAM,UAAU,6BAAM;AACpB,cAAQ,KAAK,SAAS;AACtB,YAAM,SAAS;AACf,aAAO,KAAK,SAAS;AAAA,IAAA,GAHP;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;"}
|
||||
44
web/assets/DownloadGitView-B3f7KHY3.js → web/assets/DownloadGitView-DeC7MBzG.js
generated
vendored
44
web/assets/DownloadGitView-B3f7KHY3.js → web/assets/DownloadGitView-DeC7MBzG.js
generated
vendored
@@ -1,15 +1,14 @@
|
||||
var __defProp = Object.defineProperty;
|
||||
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
|
||||
import { a as defineComponent, f as openBlock, g as createElementBlock, A as createBaseVNode, a8 as toDisplayString, h as createVNode, z as unref, D as script, bU as useRouter } from "./index-DIU5yZe9.js";
|
||||
const _hoisted_1 = { class: "font-sans w-screen h-screen mx-0 grid place-items-center justify-center items-center text-neutral-900 bg-neutral-300 pointer-events-auto" };
|
||||
const _hoisted_2 = { class: "col-start-1 h-screen row-start-1 place-content-center mx-auto overflow-y-auto" };
|
||||
const _hoisted_3 = { class: "max-w-screen-sm flex flex-col gap-8 p-8 bg-[url('/assets/images/Git-Logo-White.svg')] bg-no-repeat bg-right-top bg-origin-padding" };
|
||||
const _hoisted_4 = { class: "mt-24 text-4xl font-bold text-red-500" };
|
||||
const _hoisted_5 = { class: "space-y-4" };
|
||||
const _hoisted_6 = { class: "text-xl" };
|
||||
const _hoisted_7 = { class: "text-xl" };
|
||||
const _hoisted_8 = { class: "text-m" };
|
||||
const _hoisted_9 = { class: "flex gap-4 flex-row-reverse" };
|
||||
import { d as defineComponent, o as openBlock, k as createBlock, M as withCtx, H as createBaseVNode, X as toDisplayString, N as createVNode, j as unref, l as script, bW as useRouter } from "./index-DjNHn37O.js";
|
||||
import { _ as _sfc_main$1 } from "./BaseViewTemplate-BNGF4K22.js";
|
||||
const _hoisted_1 = { class: "max-w-screen-sm flex flex-col gap-8 p-8 bg-[url('/assets/images/Git-Logo-White.svg')] bg-no-repeat bg-right-top bg-origin-padding" };
|
||||
const _hoisted_2 = { class: "mt-24 text-4xl font-bold text-red-500" };
|
||||
const _hoisted_3 = { class: "space-y-4" };
|
||||
const _hoisted_4 = { class: "text-xl" };
|
||||
const _hoisted_5 = { class: "text-xl" };
|
||||
const _hoisted_6 = { class: "text-m" };
|
||||
const _hoisted_7 = { class: "flex gap-4 flex-row-reverse" };
|
||||
const _sfc_main = /* @__PURE__ */ defineComponent({
|
||||
__name: "DownloadGitView",
|
||||
setup(__props) {
|
||||
@@ -22,16 +21,16 @@ const _sfc_main = /* @__PURE__ */ defineComponent({
|
||||
router.push("install");
|
||||
}, "skipGit");
|
||||
return (_ctx, _cache) => {
|
||||
return openBlock(), createElementBlock("div", _hoisted_1, [
|
||||
createBaseVNode("div", _hoisted_2, [
|
||||
createBaseVNode("div", _hoisted_3, [
|
||||
createBaseVNode("h1", _hoisted_4, toDisplayString(_ctx.$t("downloadGit.title")), 1),
|
||||
createBaseVNode("div", _hoisted_5, [
|
||||
createBaseVNode("p", _hoisted_6, toDisplayString(_ctx.$t("downloadGit.message")), 1),
|
||||
createBaseVNode("p", _hoisted_7, toDisplayString(_ctx.$t("downloadGit.instructions")), 1),
|
||||
createBaseVNode("p", _hoisted_8, toDisplayString(_ctx.$t("downloadGit.warning")), 1)
|
||||
return openBlock(), createBlock(_sfc_main$1, null, {
|
||||
default: withCtx(() => [
|
||||
createBaseVNode("div", _hoisted_1, [
|
||||
createBaseVNode("h1", _hoisted_2, toDisplayString(_ctx.$t("downloadGit.title")), 1),
|
||||
createBaseVNode("div", _hoisted_3, [
|
||||
createBaseVNode("p", _hoisted_4, toDisplayString(_ctx.$t("downloadGit.message")), 1),
|
||||
createBaseVNode("p", _hoisted_5, toDisplayString(_ctx.$t("downloadGit.instructions")), 1),
|
||||
createBaseVNode("p", _hoisted_6, toDisplayString(_ctx.$t("downloadGit.warning")), 1)
|
||||
]),
|
||||
createBaseVNode("div", _hoisted_9, [
|
||||
createBaseVNode("div", _hoisted_7, [
|
||||
createVNode(unref(script), {
|
||||
label: _ctx.$t("downloadGit.gitWebsite"),
|
||||
icon: "pi pi-external-link",
|
||||
@@ -47,12 +46,13 @@ const _sfc_main = /* @__PURE__ */ defineComponent({
|
||||
}, null, 8, ["label"])
|
||||
])
|
||||
])
|
||||
])
|
||||
]);
|
||||
]),
|
||||
_: 1
|
||||
});
|
||||
};
|
||||
}
|
||||
});
|
||||
export {
|
||||
_sfc_main as default
|
||||
};
|
||||
//# sourceMappingURL=DownloadGitView-B3f7KHY3.js.map
|
||||
//# sourceMappingURL=DownloadGitView-DeC7MBzG.js.map
|
||||
1
web/assets/ExtensionPanel-ByeZ01RF.js.map
generated
vendored
1
web/assets/ExtensionPanel-ByeZ01RF.js.map
generated
vendored
@@ -1 +0,0 @@
|
||||
{"version":3,"file":"ExtensionPanel-ByeZ01RF.js","sources":["../../src/components/dialog/content/setting/ExtensionPanel.vue"],"sourcesContent":["<template>\n <PanelTemplate value=\"Extension\" class=\"extension-panel\">\n <template #header>\n <SearchBox\n v-model=\"filters['global'].value\"\n :placeholder=\"$t('g.searchExtensions') + '...'\"\n />\n <Message v-if=\"hasChanges\" severity=\"info\" pt:text=\"w-full\">\n <ul>\n <li v-for=\"ext in changedExtensions\" :key=\"ext.name\">\n <span>\n {{ extensionStore.isExtensionEnabled(ext.name) ? '[-]' : '[+]' }}\n </span>\n {{ ext.name }}\n </li>\n </ul>\n <div class=\"flex justify-end\">\n <Button\n :label=\"$t('g.reloadToApplyChanges')\"\n @click=\"applyChanges\"\n outlined\n severity=\"danger\"\n />\n </div>\n </Message>\n </template>\n <DataTable\n :value=\"extensionStore.extensions\"\n stripedRows\n size=\"small\"\n :filters=\"filters\"\n >\n <Column field=\"name\" :header=\"$t('g.extensionName')\" sortable></Column>\n <Column\n :pt=\"{\n bodyCell: 'flex items-center justify-end'\n }\"\n >\n <template #body=\"slotProps\">\n <ToggleSwitch\n v-model=\"editingEnabledExtensions[slotProps.data.name]\"\n @change=\"updateExtensionStatus\"\n />\n </template>\n </Column>\n </DataTable>\n </PanelTemplate>\n</template>\n\n<script setup lang=\"ts\">\nimport { ref, computed, onMounted } from 'vue'\nimport { useExtensionStore } from '@/stores/extensionStore'\nimport { useSettingStore } from '@/stores/settingStore'\nimport DataTable from 'primevue/datatable'\nimport Column from 'primevue/column'\nimport ToggleSwitch from 'primevue/toggleswitch'\nimport Button from 'primevue/button'\nimport Message from 'primevue/message'\nimport { FilterMatchMode } from '@primevue/core/api'\nimport PanelTemplate from './PanelTemplate.vue'\nimport SearchBox from '@/components/common/SearchBox.vue'\n\nconst filters = ref({\n global: { value: '', matchMode: FilterMatchMode.CONTAINS }\n})\n\nconst extensionStore = useExtensionStore()\nconst settingStore = useSettingStore()\n\nconst editingEnabledExtensions = ref<Record<string, boolean>>({})\n\nonMounted(() => {\n extensionStore.extensions.forEach((ext) => {\n editingEnabledExtensions.value[ext.name] =\n extensionStore.isExtensionEnabled(ext.name)\n })\n})\n\nconst changedExtensions = computed(() => {\n return extensionStore.extensions.filter(\n (ext) =>\n editingEnabledExtensions.value[ext.name] !==\n extensionStore.isExtensionEnabled(ext.name)\n )\n})\n\nconst hasChanges = computed(() => {\n return changedExtensions.value.length > 0\n})\n\nconst updateExtensionStatus = () => {\n const editingDisabledExtensionNames = Object.entries(\n editingEnabledExtensions.value\n )\n .filter(([_, enabled]) => !enabled)\n .map(([name]) => name)\n\n settingStore.set('Comfy.Extension.Disabled', [\n ...extensionStore.inactiveDisabledExtensionNames,\n ...editingDisabledExtensionNames\n ])\n}\n\nconst applyChanges = () => {\n // Refresh the page to apply changes\n window.location.reload()\n}\n</script>\n"],"names":[],"mappings":";;;;;;;;;AA8DA,UAAM,UAAU,IAAI;AAAA,MAClB,QAAQ,EAAE,OAAO,IAAI,WAAW,gBAAgB,SAAS;AAAA,IAAA,CAC1D;AAED,UAAM,iBAAiB;AACvB,UAAM,eAAe;AAEf,UAAA,2BAA2B,IAA6B,CAAA,CAAE;AAEhE,cAAU,MAAM;AACC,qBAAA,WAAW,QAAQ,CAAC,QAAQ;AACzC,iCAAyB,MAAM,IAAI,IAAI,IACrC,eAAe,mBAAmB,IAAI,IAAI;AAAA,MAAA,CAC7C;AAAA,IAAA,CACF;AAEK,UAAA,oBAAoB,SAAS,MAAM;AACvC,aAAO,eAAe,WAAW;AAAA,QAC/B,CAAC,QACC,yBAAyB,MAAM,IAAI,IAAI,MACvC,eAAe,mBAAmB,IAAI,IAAI;AAAA,MAAA;AAAA,IAC9C,CACD;AAEK,UAAA,aAAa,SAAS,MAAM;AACzB,aAAA,kBAAkB,MAAM,SAAS;AAAA,IAAA,CACzC;AAED,UAAM,wBAAwB,6BAAM;AAClC,YAAM,gCAAgC,OAAO;AAAA,QAC3C,yBAAyB;AAAA,MAExB,EAAA,OAAO,CAAC,CAAC,GAAG,OAAO,MAAM,CAAC,OAAO,EACjC,IAAI,CAAC,CAAC,IAAI,MAAM,IAAI;AAEvB,mBAAa,IAAI,4BAA4B;AAAA,QAC3C,GAAG,eAAe;AAAA,QAClB,GAAG;AAAA,MAAA,CACJ;AAAA,IAAA,GAV2B;AAa9B,UAAM,eAAe,6BAAM;AAEzB,aAAO,SAAS;IAAO,GAFJ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;"}
|
||||
88
web/assets/ExtensionPanel-ByeZ01RF.js → web/assets/ExtensionPanel-D4Phn0Zr.js
generated
vendored
88
web/assets/ExtensionPanel-ByeZ01RF.js → web/assets/ExtensionPanel-D4Phn0Zr.js
generated
vendored
@@ -1,8 +1,9 @@
|
||||
var __defProp = Object.defineProperty;
|
||||
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
|
||||
import { a as defineComponent, r as ref, ck as FilterMatchMode, co as useExtensionStore, u as useSettingStore, o as onMounted, q as computed, f as openBlock, x as createBlock, y as withCtx, h as createVNode, cl as SearchBox, z as unref, bW as script, A as createBaseVNode, g as createElementBlock, Q as renderList, a8 as toDisplayString, ay as createTextVNode, P as Fragment, D as script$1, i as createCommentVNode, c5 as script$3, cm as _sfc_main$1 } from "./index-DIU5yZe9.js";
|
||||
import { s as script$2, a as script$4 } from "./index-D3u7l7ha.js";
|
||||
import "./index-d698Brhb.js";
|
||||
import { d as defineComponent, ab as ref, cn as FilterMatchMode, cs as useExtensionStore, a as useSettingStore, m as onMounted, c as computed, o as openBlock, k as createBlock, M as withCtx, N as createVNode, co as SearchBox, j as unref, bZ as script, H as createBaseVNode, f as createElementBlock, E as renderList, X as toDisplayString, aE as createTextVNode, F as Fragment, l as script$1, I as createCommentVNode, aI as script$3, bO as script$4, c4 as script$5, cp as _sfc_main$1 } from "./index-DjNHn37O.js";
|
||||
import { s as script$2, a as script$6 } from "./index-B5F0uxTQ.js";
|
||||
import "./index-B-aVupP5.js";
|
||||
import "./index-5HFeZax4.js";
|
||||
const _hoisted_1 = { class: "flex justify-end" };
|
||||
const _sfc_main = /* @__PURE__ */ defineComponent({
|
||||
__name: "ExtensionPanel",
|
||||
@@ -35,9 +36,49 @@ const _sfc_main = /* @__PURE__ */ defineComponent({
|
||||
...editingDisabledExtensionNames
|
||||
]);
|
||||
}, "updateExtensionStatus");
|
||||
const enableAllExtensions = /* @__PURE__ */ __name(() => {
|
||||
extensionStore.extensions.forEach((ext) => {
|
||||
if (extensionStore.isExtensionReadOnly(ext.name)) return;
|
||||
editingEnabledExtensions.value[ext.name] = true;
|
||||
});
|
||||
updateExtensionStatus();
|
||||
}, "enableAllExtensions");
|
||||
const disableAllExtensions = /* @__PURE__ */ __name(() => {
|
||||
extensionStore.extensions.forEach((ext) => {
|
||||
if (extensionStore.isExtensionReadOnly(ext.name)) return;
|
||||
editingEnabledExtensions.value[ext.name] = false;
|
||||
});
|
||||
updateExtensionStatus();
|
||||
}, "disableAllExtensions");
|
||||
const disableThirdPartyExtensions = /* @__PURE__ */ __name(() => {
|
||||
extensionStore.extensions.forEach((ext) => {
|
||||
if (extensionStore.isCoreExtension(ext.name)) return;
|
||||
editingEnabledExtensions.value[ext.name] = false;
|
||||
});
|
||||
updateExtensionStatus();
|
||||
}, "disableThirdPartyExtensions");
|
||||
const applyChanges = /* @__PURE__ */ __name(() => {
|
||||
window.location.reload();
|
||||
}, "applyChanges");
|
||||
const menu = ref();
|
||||
const contextMenuItems = [
|
||||
{
|
||||
label: "Enable All",
|
||||
icon: "pi pi-check",
|
||||
command: enableAllExtensions
|
||||
},
|
||||
{
|
||||
label: "Disable All",
|
||||
icon: "pi pi-times",
|
||||
command: disableAllExtensions
|
||||
},
|
||||
{
|
||||
label: "Disable 3rd Party",
|
||||
icon: "pi pi-times",
|
||||
command: disableThirdPartyExtensions,
|
||||
disabled: !extensionStore.hasThirdPartyExtensions
|
||||
}
|
||||
];
|
||||
return (_ctx, _cache) => {
|
||||
return openBlock(), createBlock(_sfc_main$1, {
|
||||
value: "Extension",
|
||||
@@ -52,7 +93,8 @@ const _sfc_main = /* @__PURE__ */ defineComponent({
|
||||
hasChanges.value ? (openBlock(), createBlock(unref(script), {
|
||||
key: 0,
|
||||
severity: "info",
|
||||
"pt:text": "w-full"
|
||||
"pt:text": "w-full",
|
||||
class: "max-h-96 overflow-y-auto"
|
||||
}, {
|
||||
default: withCtx(() => [
|
||||
createBaseVNode("ul", null, [
|
||||
@@ -78,7 +120,7 @@ const _sfc_main = /* @__PURE__ */ defineComponent({
|
||||
})) : createCommentVNode("", true)
|
||||
]),
|
||||
default: withCtx(() => [
|
||||
createVNode(unref(script$4), {
|
||||
createVNode(unref(script$6), {
|
||||
value: unref(extensionStore).extensions,
|
||||
stripedRows: "",
|
||||
size: "small",
|
||||
@@ -86,19 +128,43 @@ const _sfc_main = /* @__PURE__ */ defineComponent({
|
||||
}, {
|
||||
default: withCtx(() => [
|
||||
createVNode(unref(script$2), {
|
||||
field: "name",
|
||||
header: _ctx.$t("g.extensionName"),
|
||||
sortable: ""
|
||||
}, null, 8, ["header"]),
|
||||
sortable: "",
|
||||
field: "name"
|
||||
}, {
|
||||
body: withCtx((slotProps) => [
|
||||
createTextVNode(toDisplayString(slotProps.data.name) + " ", 1),
|
||||
unref(extensionStore).isCoreExtension(slotProps.data.name) ? (openBlock(), createBlock(unref(script$3), {
|
||||
key: 0,
|
||||
value: "Core"
|
||||
})) : createCommentVNode("", true)
|
||||
]),
|
||||
_: 1
|
||||
}, 8, ["header"]),
|
||||
createVNode(unref(script$2), { pt: {
|
||||
headerCell: "flex items-center justify-end",
|
||||
bodyCell: "flex items-center justify-end"
|
||||
} }, {
|
||||
header: withCtx(() => [
|
||||
createVNode(unref(script$1), {
|
||||
icon: "pi pi-ellipsis-h",
|
||||
text: "",
|
||||
severity: "secondary",
|
||||
onClick: _cache[1] || (_cache[1] = ($event) => menu.value.show($event))
|
||||
}),
|
||||
createVNode(unref(script$4), {
|
||||
ref_key: "menu",
|
||||
ref: menu,
|
||||
model: contextMenuItems
|
||||
}, null, 512)
|
||||
]),
|
||||
body: withCtx((slotProps) => [
|
||||
createVNode(unref(script$3), {
|
||||
createVNode(unref(script$5), {
|
||||
disabled: unref(extensionStore).isExtensionReadOnly(slotProps.data.name),
|
||||
modelValue: editingEnabledExtensions.value[slotProps.data.name],
|
||||
"onUpdate:modelValue": /* @__PURE__ */ __name(($event) => editingEnabledExtensions.value[slotProps.data.name] = $event, "onUpdate:modelValue"),
|
||||
onChange: updateExtensionStatus
|
||||
}, null, 8, ["modelValue", "onUpdate:modelValue"])
|
||||
}, null, 8, ["disabled", "modelValue", "onUpdate:modelValue"])
|
||||
]),
|
||||
_: 1
|
||||
})
|
||||
@@ -114,4 +180,4 @@ const _sfc_main = /* @__PURE__ */ defineComponent({
|
||||
export {
|
||||
_sfc_main as default
|
||||
};
|
||||
//# sourceMappingURL=ExtensionPanel-ByeZ01RF.js.map
|
||||
//# sourceMappingURL=ExtensionPanel-D4Phn0Zr.js.map
|
||||
1
web/assets/GraphView-BWxgNrh6.js.map
generated
vendored
1
web/assets/GraphView-BWxgNrh6.js.map
generated
vendored
File diff suppressed because one or more lines are too long
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user