Compare commits

128 Commits

| Author | SHA1 | Date |
|---|---|---|
| | b07f116dea | |
| | 714f728820 | |
| | 92d8d15300 | |
| | 89253e9fe5 | |
| | 3ea3bc8546 | |
| | 8e69e2ddfd | |
| | 0270a0b41c | |
| | 26c7baf789 | |
| | c37f15f98e | |
| | 4bca7367f3 | |
| | b6fefe686b | |
| | fa62287f1f | |
| | 0844998db3 | |
| | 4ced06b879 | |
| | cb06e9669b | |
| | 0c32f82298 | |
| | 189da3726d | |
| | 9a66bb972d | |
| | ea0f939df3 | |
| | f37551c1d2 | |
| | 63023011b9 | |
| | f40076096e | |
| | 96d891cb94 | |
| | 4553891bbd | |
| | ace899e71a | |
| | aff16532d4 | |
| | b50ab153f9 | |
| | 072db3bea6 | |
| | a6deca6d9a | |
| | 41c30e92e7 | |
| | f579a740dd | |
| | d37272532c | |
| | 12da6ef581 | |
| | 29d4384a75 | |
| | c5be423d6b | |
| | b4d3652d88 | |
| | 5715be2ca9 | |
| | 0d4d9222c6 | |
| | afc85cdeb6 | |
| | acc152b674 | |
| | b07258cef2 | |
| | 31e54b7052 | |
| | 8c0bae50c3 | |
| | 530412cb9d | |
| | 61c8c70c6e | |
| | d0399f4343 | |
| | e2919d38b4 | |
| | 93c8607d51 | |
| | b3d6ae15b3 | |
| | 2e21122aab | |
| | 1cd6cd6080 | |
| | d7b4bf21a2 | |
| | 042a905c37 | |
| | 019c7029ea | |
| | 8773ccf74d | |
| | 1d5d6586f3 | |
| | 35740259de | |
| | ab888e1e0b | |
| | d9f0fcdb0c | |
| | b124256817 | |
| | af4b7c91be | |
| | e57d2282d1 | |
| | 4027466c80 | |
| | 095d867147 | |
| | caeb27c3a5 | |
| | 3d06e1c555 | |
| | 43a74c0de1 | |
| | af93c8d1ee | |
| | 832e3f5ca3 | |
| | 079eccc92a | |
| | b6951768c4 | |
| | fca304debf | |
| | 14880e6dba | |
| | f1059b0b82 | |
| | debabccb84 | |
| | 37cd448529 | |
| | 94f21f9301 | |
| | 60653004e5 | |
| | a57d635c5f | |
| | 016b219dcc | |
| | 8ac2dddeed | |
| | 3e880ac709 | |
| | e5ea112a90 | |
| | 8d88bfaff9 | |
| | ed4d92b721 | |
| | 932ae8d9ca | |
| | 44e19a28d3 | |
| | 0a0df5f136 | |
| | 24d6871e47 | |
| | 9e1d301129 | |
| | 768e035868 | |
| | 669e0497ea | |
| | 541dc08547 | |
| | 8d8dc9a262 | |
| | 2f98c24360 | |
| | ef85058e97 | |
| | f9230bd357 | |
| | 537c27cbf3 | |
| | 6ff2e4d550 | |
| | 222f48c0f2 | |
| | 13fd4d6e45 | |
| | 1210d094c7 | |
| | 255edf2246 | |
| | 4f011b9a00 | |
| | 67feb05299 | |
| | 6d21740346 | |
| | 7fbf4b72fe | |
| | 14ca5f5a10 | |
| | ce557cfb88 | |
| | 96e2a45193 | |
| | dfa2b6d129 | |
| | f3566f0894 | |
| | ca69b41cee | |
| | a058f52090 | |
| | d6bbe8c40f | |
| | a7fe0a94de | |
| | e857dd48b8 | |
| | d303cb5341 | |
| | fb2ad645a3 | |
| | d8a7a32779 | |
| | a00e1489d2 | |
| | ebf038d4fa | |
| | b4de04a1c1 | |
| | b1a02131c9 | |
| | 3a3910f91d | |
| | 507199d9a8 | |
| | 2f3ab40b62 | |
| | 7fc3ccdcc2 | |
**.github/workflows/stable-release.yml** (2 changes, vendored)

```diff
@@ -12,7 +12,7 @@ on:
         description: 'CUDA version'
         required: true
         type: string
-        default: "124"
+        default: "126"
       python_minor:
         description: 'Python minor version'
         required: true
```
**.github/workflows/test-build.yml** (2 changes, vendored)

```diff
@@ -18,7 +18,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python-version: ["3.9", "3.10", "3.11", "3.12"]
+        python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
     steps:
     - uses: actions/checkout@v4
     - name: Set up Python ${{ matrix.python-version }}
```
**.github/workflows/test-unit.yml** (2 changes, vendored)

```diff
@@ -18,7 +18,7 @@ jobs:
       - name: Set up Python
         uses: actions/setup-python@v4
         with:
-          python-version: '3.10'
+          python-version: '3.12'
       - name: Install requirements
        run: |
          python -m pip install --upgrade pip
```
```diff
@@ -17,7 +17,7 @@ on:
       description: 'cuda version'
       required: true
       type: string
-      default: "124"
+      default: "126"
 
     python_minor:
       description: 'python minor version'
```
```diff
@@ -7,7 +7,7 @@ on:
       description: 'cuda version'
       required: true
      type: string
-      default: "124"
+      default: "126"
 
     python_minor:
       description: 'python minor version'
```
```diff
@@ -15,6 +15,7 @@
 # Python web server
 /api_server/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
 /app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
+/utils/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
 
 # Frontend assets
 /web/ @huchenlei @webfiltered @pythongosssss @yoland68 @robinjhuang
```
**README.md** (61 changes)
```diff
@@ -1,7 +1,7 @@
 <div align="center">
 
 # ComfyUI
-**The most powerful and modular diffusion model GUI and backend.**
+**The most powerful and modular visual AI engine and application.**
 
 
 [![Website][website-shield]][website-url]
@@ -31,10 +31,24 @@
 
 </div>
 
-This ui will let you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. For some workflow examples and see what ComfyUI can do you can check out:
-### [ComfyUI Examples](https://comfyanonymous.github.io/ComfyUI_examples/)
-
-### [Installing ComfyUI](#installing)
+ComfyUI lets you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. Available on Windows, Linux, and macOS.
+
+## Get Started
+
+#### [Desktop Application](https://www.comfy.org/download)
+- The easiest way to get started.
+- Available on Windows & macOS.
+
+#### [Windows Portable Package](#installing)
+- Get the latest commits and completely portable.
+- Available on Windows.
+
+#### [Manual Install](#manual-install-windows-linux)
+Supports all operating systems and GPU types (NVIDIA, AMD, Intel, Apple Silicon, Ascend).
+
+## [Examples](https://comfyanonymous.github.io/ComfyUI_examples/)
+See what ComfyUI can do with the [example workflows](https://comfyanonymous.github.io/ComfyUI_examples/).
+
 
 ## Features
 - Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
@@ -47,11 +61,14 @@
 - [AuraFlow](https://comfyanonymous.github.io/ComfyUI_examples/aura_flow/)
 - [HunyuanDiT](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_dit/)
 - [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
+- [Lumina Image 2.0](https://comfyanonymous.github.io/ComfyUI_examples/lumina2/)
 - Video Models
 - [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
 - [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
 - [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/)
 - [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
+- [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/)
+- [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
 - [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
 - Asynchronous Queue system
 - Many optimizations: Only re-executes the parts of the workflow that changes between executions.
@@ -119,7 +136,7 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
 
 # Installing
 
-## Windows
+## Windows Portable
 
 There is a portable standalone build for Windows that should work for running on Nvidia GPUs or for running on your CPU only on the [releases page](https://github.com/comfyanonymous/ComfyUI/releases).
 
@@ -129,6 +146,8 @@ Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you
 
 If you have trouble extracting it, right click the file -> properties -> unblock
 
+If you have a 50 series Blackwell card like a 5090 or 5080 see [this discussion thread](https://github.com/comfyanonymous/ComfyUI/discussions/6643)
+
 #### How do I share models between another UI and ComfyUI?
 
 See the [Config file](extra_model_paths.yaml.example) to set the search paths for models. In the standalone windows build you can find this file in the ComfyUI directory. Rename this file to extra_model_paths.yaml and edit it with your favorite text editor.
```
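For reference, the `extra_model_paths.yaml.example` file mentioned in the hunk above follows roughly this shape; the paths below are hypothetical stand-ins, and the authoritative set of keys is in the example file itself:

```yaml
# Hypothetical sketch of extra_model_paths.yaml: point ComfyUI at an
# existing A1111-style install. Paths are placeholders, not defaults.
a111:
    base_path: C:/stable-diffusion-webui/
    checkpoints: models/Stable-diffusion
    vae: models/VAE
    loras: models/Lora
```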
````diff
@@ -137,9 +156,18 @@ See the [Config file](extra_model_paths.yaml.example) to set the search paths fo
 
 To run it on services like paperspace, kaggle or colab you can use my [Jupyter Notebook](notebooks/comfyui_colab.ipynb)
 
+
+## [comfy-cli](https://docs.comfy.org/comfy-cli/getting-started)
+
+You can install and start ComfyUI using comfy-cli:
+```bash
+pip install comfy-cli
+comfy install
+```
+
 ## Manual Install (Windows, Linux)
 
-Note that some dependencies do not yet support python 3.13 so using 3.12 is recommended.
+python 3.13 is supported but using 3.12 is recommended because some custom nodes and their dependencies might not support it yet.
 
 Git clone this repo.
@@ -151,11 +179,11 @@ Put your VAE in: models/vae
 ### AMD GPUs (Linux only)
 AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
 
-```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.2```
+```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.2.4```
 
-This is the command to install the nightly with ROCm 6.2 which might have some performance improvements:
+This is the command to install the nightly with ROCm 6.3 which might have some performance improvements:
 
-```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.2.4```
+```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.3```
 
 ### Intel GPUs (Windows and Linux)
 
@@ -185,7 +213,7 @@ Additional discussion and help can be found [here](https://github.com/comfyanony
 
 Nvidia users should install stable pytorch using this command:
 
-```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu124```
+```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu126```
 
 This is the command to install pytorch nightly instead which might have performance improvements:
 
@@ -233,6 +261,13 @@ For models compatible with Ascend Extension for PyTorch (torch_npu). To get star
 3. Next, install the necessary packages for torch-npu by adhering to the platform-specific instructions on the [Installation](https://ascend.github.io/docs/sources/pytorch/install.html#pytorch) page.
 4. Finally, adhere to the [ComfyUI manual installation](#manual-install-windows-linux) guide for Linux. Once all components are installed, you can run ComfyUI as described earlier.
 
+#### Cambricon MLUs
+
+For models compatible with Cambricon Extension for PyTorch (torch_mlu). Here's a step-by-step guide tailored to your platform and installation method:
+
+1. Install the Cambricon CNToolkit by adhering to the platform-specific instructions on the [Installation](https://www.cambricon.com/docs/sdk_1.15.0/cntoolkit_3.7.2/cntoolkit_install_3.7.2/index.html)
+2. Next, install the PyTorch(torch_mlu) following the instructions on the [Installation](https://www.cambricon.com/docs/sdk_1.15.0/cambricon_pytorch_1.17.0/user_guide_1.9/index.html)
+3. Launch ComfyUI by running `python main.py`
+
 # Running
 
@@ -289,6 +324,8 @@ Use `--tls-keyfile key.pem --tls-certfile cert.pem` to enable TLS/SSL, the app w
 
 ## Support and dev channel
 
+[Discord](https://comfy.org/discord): Try the #help or #feedback channels.
+
 [Matrix space: #comfyui_space:matrix.org](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) (it's like discord but open source).
 
 See also: [https://www.comfy.org/](https://www.comfy.org/)
@@ -305,7 +342,7 @@ For any bugs, issues, or feature requests related to the frontend, please use th
 
 The new frontend is now the default for ComfyUI. However, please note:
 
-1. The frontend in the main ComfyUI repository is updated weekly.
+1. The frontend in the main ComfyUI repository is updated fortnightly.
 2. Daily releases are available in the separate frontend repository.
 
 To use the most up-to-date frontend version:
@@ -322,7 +359,7 @@ To use the most up-to-date frontend version:
 --front-end-version Comfy-Org/ComfyUI_frontend@1.2.2
 ```
 
-This approach allows you to easily switch between the stable weekly release and the cutting-edge daily updates, or even specific versions for testing purposes.
+This approach allows you to easily switch between the stable fortnightly release and the cutting-edge daily updates, or even specific versions for testing purposes.
 
 ### Accessing the Legacy Frontend
````
```diff
@@ -1,9 +1,9 @@
 from aiohttp import web
 from typing import Optional
-from folder_paths import models_dir, user_directory, output_directory, folder_names_and_paths
-from api_server.services.file_service import FileService
+from folder_paths import folder_names_and_paths, get_directory_by_type
 from api_server.services.terminal_service import TerminalService
 import app.logger
+import os
 
 class InternalRoutes:
     '''
```
```diff
@@ -15,26 +15,10 @@ class InternalRoutes:
     def __init__(self, prompt_server):
         self.routes: web.RouteTableDef = web.RouteTableDef()
         self._app: Optional[web.Application] = None
-        self.file_service = FileService({
-            "models": models_dir,
-            "user": user_directory,
-            "output": output_directory
-        })
         self.prompt_server = prompt_server
         self.terminal_service = TerminalService(prompt_server)
 
     def setup_routes(self):
-        @self.routes.get('/files')
-        async def list_files(request):
-            directory_key = request.query.get('directory', '')
-            try:
-                file_list = self.file_service.list_files(directory_key)
-                return web.json_response({"files": file_list})
-            except ValueError as e:
-                return web.json_response({"error": str(e)}, status=400)
-            except Exception as e:
-                return web.json_response({"error": str(e)}, status=500)
-
         @self.routes.get('/logs')
         async def get_logs(request):
             return web.json_response("".join([(l["t"] + " - " + l["m"]) for l in app.logger.get_logs()]))
```
```diff
@@ -67,6 +51,20 @@ class InternalRoutes:
             response[key] = folder_names_and_paths[key][0]
         return web.json_response(response)
 
+        @self.routes.get('/files/{directory_type}')
+        async def get_files(request: web.Request) -> web.Response:
+            directory_type = request.match_info['directory_type']
+            if directory_type not in ("output", "input", "temp"):
+                return web.json_response({"error": "Invalid directory type"}, status=400)
+
+            directory = get_directory_by_type(directory_type)
+            sorted_files = sorted(
+                (entry for entry in os.scandir(directory) if entry.is_file()),
+                key=lambda entry: -entry.stat().st_mtime
+            )
+            return web.json_response([entry.name for entry in sorted_files], status=200)
+
+
     def get_app(self):
         if self._app is None:
             self._app = web.Application()
```
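Assuming the server runs locally on the default port 8188 and these routes are mounted under an `/internal` prefix (an assumption; check where `InternalRoutes.get_app()` is attached), the new endpoint added above can be exercised like this:

```python
# Hypothetical usage sketch for the new /files/{directory_type} route.
# Host, port, and the /internal prefix are assumptions, not from the diff.
import json
import urllib.request

url = "http://127.0.0.1:8188/internal/files/output"
with urllib.request.urlopen(url) as resp:
    print(json.load(resp))  # file names sorted newest-first, per the sort key

# Any directory type other than "output", "input", or "temp" returns HTTP 400.
```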
```diff
@@ -1,13 +0,0 @@
-from typing import Dict, List, Optional
-from api_server.utils.file_operations import FileSystemOperations, FileSystemItem
-
-
-class FileService:
-    def __init__(self, allowed_directories: Dict[str, str], file_system_ops: Optional[FileSystemOperations] = None):
-        self.allowed_directories: Dict[str, str] = allowed_directories
-        self.file_system_ops: FileSystemOperations = file_system_ops or FileSystemOperations()
-
-    def list_files(self, directory_key: str) -> List[FileSystemItem]:
-        if directory_key not in self.allowed_directories:
-            raise ValueError("Invalid directory key")
-        directory_path: str = self.allowed_directories[directory_key]
-        return self.file_system_ops.walk_directory(directory_path)
```
```diff
@@ -4,12 +4,93 @@ import os
 import folder_paths
 import glob
 from aiohttp import web
+import json
+import logging
+from functools import lru_cache
+
+from utils.json_util import merge_json_recursive
+
+
+# Extra locale files to load into main.json
+EXTRA_LOCALE_FILES = [
+    "nodeDefs.json",
+    "commands.json",
+    "settings.json",
+]
+
+
+def safe_load_json_file(file_path: str) -> dict:
+    if not os.path.exists(file_path):
+        return {}
+
+    try:
+        with open(file_path, "r", encoding="utf-8") as f:
+            return json.load(f)
+    except json.JSONDecodeError:
+        logging.error(f"Error loading {file_path}")
+        return {}
+
 
 class CustomNodeManager:
     """
     Placeholder to refactor the custom node management features from ComfyUI-Manager.
     Currently it only contains the custom workflow templates feature.
     """
+    @lru_cache(maxsize=1)
+    def build_translations(self):
+        """Load all custom nodes translations during initialization. Translations are
+        expected to be loaded from `locales/` folder.
+
+        The folder structure is expected to be the following:
+        - custom_nodes/
+            - custom_node_1/
+                - locales/
+                    - en/
+                        - main.json
+                        - commands.json
+                        - settings.json
+
+        returned translations are expected to be in the following format:
+        {
+            "en": {
+                "nodeDefs": {...},
+                "commands": {...},
+                "settings": {...},
+                ...{other main.json keys}
+            }
+        }
+        """
+
+        translations = {}
+
+        for folder in folder_paths.get_folder_paths("custom_nodes"):
+            # Sort glob results for deterministic ordering
+            for custom_node_dir in sorted(glob.glob(os.path.join(folder, "*/"))):
+                locales_dir = os.path.join(custom_node_dir, "locales")
+                if not os.path.exists(locales_dir):
+                    continue
+
+                for lang_dir in glob.glob(os.path.join(locales_dir, "*/")):
+                    lang_code = os.path.basename(os.path.dirname(lang_dir))
+
+                    if lang_code not in translations:
+                        translations[lang_code] = {}
+
+                    # Load main.json
+                    main_file = os.path.join(lang_dir, "main.json")
+                    node_translations = safe_load_json_file(main_file)
+
+                    # Load extra locale files
+                    for extra_file in EXTRA_LOCALE_FILES:
+                        extra_file_path = os.path.join(lang_dir, extra_file)
+                        key = extra_file.split(".")[0]
+                        json_data = safe_load_json_file(extra_file_path)
+                        if json_data:
+                            node_translations[key] = json_data
+
+                    if node_translations:
+                        translations[lang_code] = merge_json_recursive(
+                            translations[lang_code], node_translations
+                        )
+
+        return translations
+
     def add_routes(self, routes, webapp, loadedModules):
 
         @routes.get("/workflow_templates")
```
```diff
@@ -18,17 +99,36 @@ class CustomNodeManager:
         files = [
             file
             for folder in folder_paths.get_folder_paths("custom_nodes")
-            for file in glob.glob(os.path.join(folder, '*/example_workflows/*.json'))
+            for file in glob.glob(
+                os.path.join(folder, "*/example_workflows/*.json")
+            )
         ]
-        workflow_templates_dict = {} # custom_nodes folder name -> example workflow names
+        workflow_templates_dict = (
+            {}
+        )  # custom_nodes folder name -> example workflow names
         for file in files:
-            custom_nodes_name = os.path.basename(os.path.dirname(os.path.dirname(file)))
+            custom_nodes_name = os.path.basename(
+                os.path.dirname(os.path.dirname(file))
+            )
             workflow_name = os.path.splitext(os.path.basename(file))[0]
-            workflow_templates_dict.setdefault(custom_nodes_name, []).append(workflow_name)
+            workflow_templates_dict.setdefault(custom_nodes_name, []).append(
+                workflow_name
+            )
         return web.json_response(workflow_templates_dict)
 
         # Serve workflow templates from custom nodes.
         for module_name, module_dir in loadedModules:
-            workflows_dir = os.path.join(module_dir, 'example_workflows')
+            workflows_dir = os.path.join(module_dir, "example_workflows")
             if os.path.exists(workflows_dir):
-                webapp.add_routes([web.static('/api/workflow_templates/' + module_name, workflows_dir)])
+                webapp.add_routes(
+                    [
+                        web.static(
+                            "/api/workflow_templates/" + module_name, workflows_dir
+                        )
+                    ]
+                )
+
+        @routes.get("/i18n")
+        async def get_i18n(request):
+            """Returns translations from all custom nodes' locales folders."""
+            return web.json_response(self.build_translations())
```
```diff
@@ -43,10 +43,11 @@ parser.add_argument("--tls-certfile", type=str, help="Path to TLS (SSL) certific
 parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.")
 parser.add_argument("--max-upload-size", type=float, default=100, help="Set the maximum upload size in MB.")
 
+parser.add_argument("--base-directory", type=str, default=None, help="Set the ComfyUI base directory for models, custom_nodes, input, output, temp, and user directories.")
 parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.")
-parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.")
-parser.add_argument("--temp-directory", type=str, default=None, help="Set the ComfyUI temp directory (default is in the ComfyUI directory).")
-parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory.")
+parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory. Overrides --base-directory.")
+parser.add_argument("--temp-directory", type=str, default=None, help="Set the ComfyUI temp directory (default is in the ComfyUI directory). Overrides --base-directory.")
+parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.")
 parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
 parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
 parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
@@ -176,7 +177,9 @@ parser.add_argument(
     help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.",
 )
 
-parser.add_argument("--user-directory", type=is_valid_directory, default=None, help="Set the ComfyUI user directory with an absolute path.")
+parser.add_argument("--user-directory", type=is_valid_directory, default=None, help="Set the ComfyUI user directory with an absolute path. Overrides --base-directory.")
+
+parser.add_argument("--enable-compress-response-body", action="store_true", help="Enable compressing response body.")
 
 if comfy.options.args_parsing:
     args = parser.parse_args()
@@ -188,3 +191,6 @@ if args.windows_standalone_build:
 
 if args.disable_auto_launch:
     args.auto_launch = False
+
+if args.force_fp16:
+    args.fp16_unet = True
```
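As a usage sketch of how the new flag composes with the existing ones (paths are hypothetical): `--base-directory` relocates all the standard directories at once, while the per-directory flags still win for the directories they name, per the updated help strings above:

```bash
# Hypothetical invocation: everything lives under /data/comfy, except
# outputs, which --output-directory overrides per the help text above.
python main.py --base-directory /data/comfy --output-directory /mnt/outputs
```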
```diff
@@ -102,9 +102,10 @@ class CLIPTextModel_(torch.nn.Module):
         mask = None
         if attention_mask is not None:
             mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
-            mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
+            mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max)
+
+        causal_mask = torch.full((x.shape[1], x.shape[1]), -torch.finfo(x.dtype).max, dtype=x.dtype, device=x.device).triu_(1)
 
-        causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
         if mask is not None:
             mask += causal_mask
         else:
```
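One plausible reading of this change: a softmax row that is entirely masked with `float("-inf")` evaluates to 0/0 and produces NaNs (which is especially painful in fp16), while the largest finite negative value degrades to a uniform row instead. A minimal sketch of the difference:

```python
import torch

# A fully masked attention row: with -inf the softmax is 0/0 -> NaN;
# with the largest finite negative value it degrades to a uniform row.
row_inf = torch.full((4,), float("-inf"))
row_max = torch.full((4,), -torch.finfo(torch.float32).max)

print(torch.softmax(row_inf, dim=-1))  # tensor([nan, nan, nan, nan])
print(torch.softmax(row_max, dim=-1))  # tensor([0.2500, 0.2500, 0.2500, 0.2500])
```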
```diff
@@ -66,13 +66,26 @@ class IO(StrEnum):
         b = frozenset(value.split(","))
         return not (b.issubset(a) or a.issubset(b))
 
+class RemoteInputOptions(TypedDict):
+    route: str
+    """The route to the remote source."""
+    refresh_button: bool
+    """Specifies whether to show a refresh button in the UI below the widget."""
+    control_after_refresh: Literal["first", "last"]
+    """Specifies the control after the refresh button is clicked. If "first", the first item will be automatically selected, and so on."""
+    timeout: int
+    """The maximum amount of time to wait for a response from the remote source in milliseconds."""
+    max_retries: int
+    """The maximum number of retries before aborting the request."""
+    refresh: int
+    """The TTL of the remote input's value in milliseconds. Specifies the interval at which the remote input's value is refreshed."""
+
 class InputTypeOptions(TypedDict):
     """Provides type hinting for the return type of the INPUT_TYPES node function.
 
     Due to IDE limitations with unions, for now all options are available for all types (e.g. `label_on` is hinted even when the type is not `IO.BOOLEAN`).
 
-    Comfy Docs: https://docs.comfy.org/essentials/custom_node_datatypes
+    Comfy Docs: https://docs.comfy.org/custom-nodes/backend/datatypes
     """
 
     default: bool | str | float | int | list | tuple
```
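A hypothetical `INPUT_TYPES` showing where these options plug in; the node, the route, and every value below are illustrative assumptions, not taken from the source:

```python
# Hypothetical node input using the new RemoteInputOptions fields.
# The route and all values below are illustrative assumptions.
class RemoteListNode:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "item": ("COMBO", {
                    "remote": {
                        "route": "/internal/files/output",
                        "refresh_button": True,
                        "control_after_refresh": "first",
                        "timeout": 5000,     # ms
                        "max_retries": 3,
                        "refresh": 10000,    # ms, value TTL
                    },
                }),
            }
        }
```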
```diff
@@ -113,6 +126,14 @@ class InputTypeOptions(TypedDict):
     # defaultVal: str
     dynamicPrompts: bool
     """Causes the front-end to evaluate dynamic prompts (``STRING``)"""
     # class InputTypeCombo(InputTypeOptions):
+    image_upload: bool
+    """Specifies whether the input should have an image upload button and image preview attached to it. Requires that the input's name is `image`."""
+    image_folder: Literal["input", "output", "temp"]
+    """Specifies which folder to get preview images from if the input has the ``image_upload`` flag.
+    """
+    remote: RemoteInputOptions
+    """Specifies the configuration for a remote input."""
 
 
 class HiddenInputTypeDict(TypedDict):
```
```diff
@@ -133,7 +154,7 @@
 class InputTypeDict(TypedDict):
     """Provides type hinting for node INPUT_TYPES.
 
-    Comfy Docs: https://docs.comfy.org/essentials/custom_node_more_on_inputs
+    Comfy Docs: https://docs.comfy.org/custom-nodes/backend/more_on_inputs
     """
 
     required: dict[str, tuple[IO, InputTypeOptions]]
@@ -143,14 +164,14 @@ class InputTypeDict(TypedDict):
     hidden: HiddenInputTypeDict
     """Offers advanced functionality and server-client communication.
 
-    Comfy Docs: https://docs.comfy.org/essentials/custom_node_more_on_inputs#hidden-inputs
+    Comfy Docs: https://docs.comfy.org/custom-nodes/backend/more_on_inputs#hidden-inputs
     """
 
 
 class ComfyNodeABC(ABC):
     """Abstract base class for Comfy nodes. Includes the names and expected types of attributes.
 
-    Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview
+    Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview
     """
 
     DESCRIPTION: str
@@ -167,7 +188,7 @@ class ComfyNodeABC(ABC):
     CATEGORY: str
     """The category of the node, as per the "Add Node" menu.
 
-    Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#category
+    Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#category
     """
     EXPERIMENTAL: bool
     """Flags a node as experimental, informing users that it may change or not work as expected."""
@@ -181,9 +202,9 @@ class ComfyNodeABC(ABC):
 
         * Must include the ``required`` key, which describes all inputs that must be connected for the node to execute.
         * The ``optional`` key can be added to describe inputs which do not need to be connected.
-        * The ``hidden`` key offers some advanced functionality. More info at: https://docs.comfy.org/essentials/custom_node_more_on_inputs#hidden-inputs
+        * The ``hidden`` key offers some advanced functionality. More info at: https://docs.comfy.org/custom-nodes/backend/more_on_inputs#hidden-inputs
 
-        Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#input-types
+        Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#input-types
         """
         return {"required": {}}
 
@@ -198,7 +219,7 @@ class ComfyNodeABC(ABC):
 
     By default, a node is not considered an output. Set ``OUTPUT_NODE = True`` to specify that it is.
 
-    Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#output-node
+    Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#output-node
     """
     INPUT_IS_LIST: bool
     """A flag indicating if this node implements the additional code necessary to deal with OUTPUT_IS_LIST nodes.
@@ -209,7 +230,7 @@ class ComfyNodeABC(ABC):
 
     A node can also override the default input behaviour and receive the whole list in a single call. This is done by setting a class attribute `INPUT_IS_LIST` to ``True``.
 
-    Comfy Docs: https://docs.comfy.org/essentials/custom_node_lists#list-processing
+    Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lists#list-processing
     """
     OUTPUT_IS_LIST: tuple[bool]
     """A tuple indicating which node outputs are lists, but will be connected to nodes that expect individual items.
@@ -227,7 +248,7 @@ class ComfyNodeABC(ABC):
     the node should provide a class attribute `OUTPUT_IS_LIST`, which is a ``tuple[bool]``, of the same length as `RETURN_TYPES`,
     specifying which outputs which should be so treated.
 
-    Comfy Docs: https://docs.comfy.org/essentials/custom_node_lists#list-processing
+    Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lists#list-processing
     """
 
     RETURN_TYPES: tuple[IO]
@@ -237,19 +258,19 @@ class ComfyNodeABC(ABC):
 
         RETURN_TYPES = (IO.INT, "INT", "CUSTOM_TYPE")
 
-    Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#return-types
+    Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#return-types
     """
     RETURN_NAMES: tuple[str]
     """The output slot names for each item in `RETURN_TYPES`, e.g. ``RETURN_NAMES = ("count", "filter_string")``
 
-    Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#return-names
+    Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#return-names
    """
     OUTPUT_TOOLTIPS: tuple[str]
     """A tuple of strings to use as tooltips for node outputs, one for each item in `RETURN_TYPES`."""
     FUNCTION: str
     """The name of the function to execute as a literal string, e.g. `FUNCTION = "execute"`
 
-    Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#function
+    Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#function
     """
@@ -267,7 +288,7 @@ class CheckLazyMixin:
     Params should match the nodes execution ``FUNCTION`` (self, and all inputs by name).
     Will be executed repeatedly until it returns an empty list, or all requested items were already evaluated (and sent as params).
 
-    Comfy Docs: https://docs.comfy.org/essentials/custom_node_lazy_evaluation#defining-check-lazy-status
+    Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lazy_evaluation#defining-check-lazy-status
     """
 
     need = [name for name in kwargs if kwargs[name] is None]
```
```diff
@@ -3,9 +3,6 @@ import math
 import comfy.utils
 
 
-def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
-    return abs(a*b) // math.gcd(a, b)
-
 class CONDRegular:
     def __init__(self, cond):
         self.cond = cond
@@ -46,7 +43,7 @@ class CONDCrossAttn(CONDRegular):
         if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen
             return False
 
-        mult_min = lcm(s1[1], s2[1])
+        mult_min = math.lcm(s1[1], s2[1])
         diff = mult_min // min(s1[1], s2[1])
         if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
             return False
@@ -57,7 +54,7 @@ class CONDCrossAttn(CONDRegular):
         crossattn_max_len = self.cond.shape[1]
         for x in others:
             c = x.cond
-            crossattn_max_len = lcm(crossattn_max_len, c.shape[1])
+            crossattn_max_len = math.lcm(crossattn_max_len, c.shape[1])
             conds.append(c)
 
         out = []
```
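`math.lcm` has been in the standard library since Python 3.9, the minimum version in the test matrix above, so the local helper was redundant; unlike the helper, it is also variadic:

```python
import math

# math.lcm (Python 3.9+) matches the removed helper and accepts any arity.
assert math.lcm(77, 154) == abs(77 * 154) // math.gcd(77, 154)
print(math.lcm(4, 6, 10))  # 60
```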
```diff
@@ -4,105 +4,6 @@ import logging
 
 # conversion code from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py
 
-# =================#
-# UNet Conversion #
-# =================#
-
-unet_conversion_map = [
-    # (stable-diffusion, HF Diffusers)
-    ("time_embed.0.weight", "time_embedding.linear_1.weight"),
-    ("time_embed.0.bias", "time_embedding.linear_1.bias"),
-    ("time_embed.2.weight", "time_embedding.linear_2.weight"),
-    ("time_embed.2.bias", "time_embedding.linear_2.bias"),
-    ("input_blocks.0.0.weight", "conv_in.weight"),
-    ("input_blocks.0.0.bias", "conv_in.bias"),
-    ("out.0.weight", "conv_norm_out.weight"),
-    ("out.0.bias", "conv_norm_out.bias"),
-    ("out.2.weight", "conv_out.weight"),
-    ("out.2.bias", "conv_out.bias"),
-]
-
-unet_conversion_map_resnet = [
-    # (stable-diffusion, HF Diffusers)
-    ("in_layers.0", "norm1"),
-    ("in_layers.2", "conv1"),
-    ("out_layers.0", "norm2"),
-    ("out_layers.3", "conv2"),
-    ("emb_layers.1", "time_emb_proj"),
-    ("skip_connection", "conv_shortcut"),
-]
-
-unet_conversion_map_layer = []
-# hardcoded number of downblocks and resnets/attentions...
-# would need smarter logic for other networks.
-for i in range(4):
-    # loop over downblocks/upblocks
-
-    for j in range(2):
-        # loop over resnets/attentions for downblocks
-        hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
-        sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
-        unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
-
-        if i < 3:
-            # no attention layers in down_blocks.3
-            hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
-            sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
-            unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
-
-    for j in range(3):
-        # loop over resnets/attentions for upblocks
-        hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
-        sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
-        unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
-
-        if i > 0:
-            # no attention layers in up_blocks.0
-            hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
-            sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
-            unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
-
-    if i < 3:
-        # no downsample in down_blocks.3
-        hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
-        sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
-        unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
-
-        # no upsample in up_blocks.3
-        hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
-        sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}."
-        unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
-
-hf_mid_atn_prefix = "mid_block.attentions.0."
-sd_mid_atn_prefix = "middle_block.1."
-unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
-
-for j in range(2):
-    hf_mid_res_prefix = f"mid_block.resnets.{j}."
-    sd_mid_res_prefix = f"middle_block.{2 * j}."
-    unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
-
-
-def convert_unet_state_dict(unet_state_dict):
-    # buyer beware: this is a *brittle* function,
-    # and correct output requires that all of these pieces interact in
-    # the exact order in which I have arranged them.
-    mapping = {k: k for k in unet_state_dict.keys()}
-    for sd_name, hf_name in unet_conversion_map:
-        mapping[hf_name] = sd_name
-    for k, v in mapping.items():
-        if "resnets" in k:
-            for sd_part, hf_part in unet_conversion_map_resnet:
-                v = v.replace(hf_part, sd_part)
-            mapping[k] = v
-    for k, v in mapping.items():
-        for sd_part, hf_part in unet_conversion_map_layer:
-            v = v.replace(hf_part, sd_part)
-        mapping[k] = v
-    new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
-    return new_state_dict
-
-
 # ================#
 # VAE Conversion #
 # ================#
@@ -213,6 +114,7 @@ textenc_pattern = re.compile("|".join(protected.keys()))
 # Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
 code2idx = {"q": 0, "k": 1, "v": 2}
 
+
 # This function exists because at the time of writing torch.cat can't do fp8 with cuda
 def cat_tensors(tensors):
     x = 0
@@ -229,6 +131,7 @@ def cat_tensors(tensors):
 
     return out
 
+
 def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
     new_state_dict = {}
     capture_qkv_weight = {}
@@ -284,5 +187,3 @@ def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
 
 def convert_text_enc_state_dict(text_enc_dict):
     return text_enc_dict
-
-
```
```diff
@@ -661,7 +661,7 @@ class UniPC:
 
         if x_t is None:
             if use_predictor:
-                pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
+                pred_res = torch.tensordot(D1s, rhos_p, dims=([1], [0])) # torch.einsum('k,bkchw->bchw', rhos_p, D1s)
             else:
                 pred_res = 0
             x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * pred_res
@@ -669,7 +669,7 @@ class UniPC:
         if use_corrector:
             model_t = self.model_fn(x_t, t)
             if D1s is not None:
-                corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
+                corr_res = torch.tensordot(D1s, rhos_c[:-1], dims=([1], [0])) # torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
             else:
                 corr_res = 0
             D1_t = (model_t - model_prev_0)
```
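`torch.tensordot(D1s, rhos, dims=([1], [0]))` contracts the `k` axis of a `(b, k, c, h, w)` tensor against a length-`k` vector, which is the same contraction the replaced einsum spelled out; a self-contained check with illustrative shapes:

```python
import torch

# Illustrative shapes: batch 2, k=3 history terms, 4-channel 8x8 latents.
D1s = torch.randn(2, 3, 4, 8, 8)
rhos = torch.randn(3)

a = torch.einsum('k,bkchw->bchw', rhos, D1s)
b = torch.tensordot(D1s, rhos, dims=([1], [0]))
print(torch.allclose(a, b, atol=1e-6))  # True
```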
```diff
@@ -40,7 +40,7 @@ def get_sigmas_polyexponential(n, sigma_min, sigma_max, rho=1., device='cpu'):
 def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
     """Constructs a continuous VP noise schedule."""
     t = torch.linspace(1, eps_s, n, device=device)
-    sigmas = torch.sqrt(torch.exp(beta_d * t ** 2 / 2 + beta_min * t) - 1)
+    sigmas = torch.sqrt(torch.special.expm1(beta_d * t ** 2 / 2 + beta_min * t))
     return append_zero(sigmas)
```
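`torch.special.expm1(x)` computes `exp(x) - 1` without the cancellation the naive form suffers for small `x`, which is the regime at the low-noise end of the schedule where `t` approaches `eps_s`:

```python
import torch

# For tiny exponents, exp(x) - 1 rounds to zero; expm1 keeps the value.
x = torch.tensor(1e-10)
print(torch.exp(x) - 1)        # tensor(0.)
print(torch.special.expm1(x))  # tensor(1.0000e-10)
```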
```diff
@@ -1267,7 +1267,7 @@ def sample_dpmpp_2m_cfg_pp(model, x, sigmas, extra_args=None, callback=None, dis
     return x
 
 @torch.no_grad()
-def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None, cfg_pp=False):
+def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None, eta=1., cfg_pp=False):
     extra_args = {} if extra_args is None else extra_args
     seed = extra_args.get("seed", None)
     noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
@@ -1289,50 +1289,80 @@ def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None
         extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
 
     for i in trange(len(sigmas) - 1, disable=disable):
-        if s_churn > 0:
-            gamma = min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.0
-            sigma_hat = sigmas[i] * (gamma + 1)
-        else:
-            gamma = 0
-            sigma_hat = sigmas[i]
-
-        if gamma > 0:
-            eps = torch.randn_like(x) * s_noise
-            x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5
-        denoised = model(x, sigma_hat * s_in, **extra_args)
+        denoised = model(x, sigmas[i] * s_in, **extra_args)
+        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
         if callback is not None:
-            callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigma_hat, "denoised": denoised})
-        if sigmas[i + 1] == 0 or old_denoised is None:
+            callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised})
+        if sigma_down == 0 or old_denoised is None:
             # Euler method
             if cfg_pp:
-                d = to_d(x, sigma_hat, uncond_denoised)
-                x = denoised + d * sigmas[i + 1]
+                d = to_d(x, sigmas[i], uncond_denoised)
+                x = denoised + d * sigma_down
             else:
-                d = to_d(x, sigma_hat, denoised)
-                dt = sigmas[i + 1] - sigma_hat
+                d = to_d(x, sigmas[i], denoised)
+                dt = sigma_down - sigmas[i]
                 x = x + d * dt
         else:
             # Second order multistep method in https://arxiv.org/pdf/2308.02157
-            t, t_next, t_prev = t_fn(sigmas[i]), t_fn(sigmas[i + 1]), t_fn(sigmas[i - 1])
+            t, t_next, t_prev = t_fn(sigmas[i]), t_fn(sigma_down), t_fn(sigmas[i - 1])
             h = t_next - t
             c2 = (t_prev - t) / h
 
             phi1_val, phi2_val = phi1_fn(-h), phi2_fn(-h)
-            b1 = torch.nan_to_num(phi1_val - 1.0 / c2 * phi2_val, nan=0.0)
-            b2 = torch.nan_to_num(1.0 / c2 * phi2_val, nan=0.0)
+            b1 = torch.nan_to_num(phi1_val - phi2_val / c2, nan=0.0)
+            b2 = torch.nan_to_num(phi2_val / c2, nan=0.0)
 
             if cfg_pp:
                 x = x + (denoised - uncond_denoised)
-                x = sigma_fn(h) * x + h * (b1 * uncond_denoised + b2 * old_denoised)
-            else:
-                x = sigma_fn(h) * x + h * (b1 * denoised + b2 * old_denoised)
-
+
+            x = (sigma_fn(t_next) / sigma_fn(t)) * x + h * (b1 * denoised + b2 * old_denoised)
+
+        # Noise addition
+        if sigmas[i + 1] > 0:
+            x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
 
-        old_denoised = denoised
+        if cfg_pp:
+            old_denoised = uncond_denoised
+        else:
+            old_denoised = denoised
     return x
 
 @torch.no_grad()
-def sample_res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
-    return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise, noise_sampler=noise_sampler, cfg_pp=False)
+def sample_res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None):
+    return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=0., cfg_pp=False)
 
 @torch.no_grad()
-def sample_res_multistep_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
-    return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise, noise_sampler=noise_sampler, cfg_pp=True)
+def sample_res_multistep_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None):
+    return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=0., cfg_pp=True)
+
+@torch.no_grad()
+def sample_res_multistep_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
+    return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=eta, cfg_pp=False)
+
+@torch.no_grad()
+def sample_res_multistep_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
+    return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=eta, cfg_pp=True)
+
+@torch.no_grad()
+def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2.):
+    """Gradient-estimation sampler. Paper: https://openreview.net/pdf?id=o2ND9v0CeK"""
+    extra_args = {} if extra_args is None else extra_args
+    s_in = x.new_ones([x.shape[0]])
+    old_d = None
+
+    for i in trange(len(sigmas) - 1, disable=disable):
+        denoised = model(x, sigmas[i] * s_in, **extra_args)
+        d = to_d(x, sigmas[i], denoised)
+        if callback is not None:
+            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+        dt = sigmas[i + 1] - sigmas[i]
+        if i == 0:
+            # Euler method
+            x = x + d * dt
+        else:
+            # Gradient estimation
+            d_bar = ge_gamma * d + (1 - ge_gamma) * old_d
+            x = x + d_bar * dt
+        old_d = d
    return x
```
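The rewrite above delegates the stochastic split to `get_ancestral_step`, which the other ancestral samplers in this file already use: a step from `sigmas[i]` to `sigmas[i + 1]` becomes a deterministic move down to `sigma_down` plus re-injected noise of size `sigma_up`, scaled by `eta`. With `eta=0.` it reduces to `sigma_down = sigmas[i + 1]`, `sigma_up = 0`, which is how the non-ancestral wrappers keep their previous behaviour. A paraphrase of the standard k-diffusion formula (a sketch, not this file's exact code):

```python
import math

def ancestral_step_sketch(sigma_from, sigma_to, eta=1.0):
    # Paraphrase of k-diffusion's get_ancestral_step: split a step into a
    # deterministic part (sigma_down) and re-injected noise (sigma_up).
    if not eta:
        return sigma_to, 0.0
    sigma_up = min(sigma_to, eta * math.sqrt(sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2))
    sigma_down = math.sqrt(sigma_to**2 - sigma_up**2)
    return sigma_down, sigma_up

print(ancestral_step_sketch(1.0, 0.5))           # eta=1: noise re-injected
print(ancestral_step_sketch(1.0, 0.5, eta=0.0))  # (0.5, 0.0): plain step
```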
```diff
@@ -407,3 +407,52 @@ class Cosmos1CV8x8x8(LatentFormat):
     ]
 
     latent_rgb_factors_bias = [-0.1223, -0.1889, -0.1976]
+
+class Wan21(LatentFormat):
+    latent_channels = 16
+    latent_dimensions = 3
+
+    latent_rgb_factors = [
+            [-0.1299, -0.1692,  0.2932],
+            [ 0.0671,  0.0406,  0.0442],
+            [ 0.3568,  0.2548,  0.1747],
+            [ 0.0372,  0.2344,  0.1420],
+            [ 0.0313,  0.0189, -0.0328],
+            [ 0.0296, -0.0956, -0.0665],
+            [-0.3477, -0.4059, -0.2925],
+            [ 0.0166,  0.1902,  0.1975],
+            [-0.0412,  0.0267, -0.1364],
+            [-0.1293,  0.0740,  0.1636],
+            [ 0.0680,  0.3019,  0.1128],
+            [ 0.0032,  0.0581,  0.0639],
+            [-0.1251,  0.0927,  0.1699],
+            [ 0.0060, -0.0633,  0.0005],
+            [ 0.3477,  0.2275,  0.2950],
+            [ 0.1984,  0.0913,  0.1861]
+        ]
+
+    latent_rgb_factors_bias = [-0.1835, -0.0868, -0.3360]
+
+    def __init__(self):
+        self.scale_factor = 1.0
+        self.latents_mean = torch.tensor([
+            -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
+            0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
+        ]).view(1, self.latent_channels, 1, 1, 1)
+        self.latents_std = torch.tensor([
+            2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
+            3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
+        ]).view(1, self.latent_channels, 1, 1, 1)
+
+        self.taesd_decoder_name = None #TODO
+
+    def process_in(self, latent):
+        latents_mean = self.latents_mean.to(latent.device, latent.dtype)
+        latents_std = self.latents_std.to(latent.device, latent.dtype)
+        return (latent - latents_mean) * self.scale_factor / latents_std
+
+    def process_out(self, latent):
+        latents_mean = self.latents_mean.to(latent.device, latent.dtype)
+        latents_std = self.latents_std.to(latent.device, latent.dtype)
+        return latent * latents_std / self.scale_factor + latents_mean
```
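`process_in` and `process_out` are exact inverses (per-channel standardization and its undo), which a standalone round-trip check makes clear; the statistics below are random stand-ins for `latents_mean`/`latents_std`, not the Wan21 values:

```python
import torch

latent_channels = 16
mean = torch.randn(1, latent_channels, 1, 1, 1)       # stand-in latents_mean
std = torch.rand(1, latent_channels, 1, 1, 1) + 1.0   # stand-in latents_std

latent = torch.randn(2, latent_channels, 4, 8, 8)     # (batch, C, T, H, W)
normalized = (latent - mean) / std                    # process_in, scale 1.0
restored = normalized * std + mean                    # process_out undoes it
print(torch.allclose(latent, restored, atol=1e-5))    # True
```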
```diff
@@ -22,7 +22,7 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
 
 def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
     assert dim % 2 == 0
-    if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu():
+    if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu() or comfy.model_management.is_directml_enabled():
         device = torch.device("cpu")
     else:
         device = pos.device
```
```diff
@@ -109,9 +109,8 @@ class Flux(nn.Module):
         img = self.img_in(img)
         vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype))
         if self.params.guidance_embed:
-            if guidance is None:
-                raise ValueError("Didn't get guidance strength for guidance distilled model.")
-            vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
+            if guidance is not None:
+                vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
 
         vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
         txt = self.txt_in(txt)
@@ -186,7 +185,7 @@ class Flux(nn.Module):
         img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
         return img
 
-    def forward(self, x, timestep, context, y, guidance, control=None, transformer_options={}, **kwargs):
+    def forward(self, x, timestep, context, y, guidance=None, control=None, transformer_options={}, **kwargs):
         bs, c, h, w = x.shape
         patch_size = self.patch_size
         x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
```
```diff
@@ -240,9 +240,8 @@ class HunyuanVideo(nn.Module):
         vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
 
         if self.params.guidance_embed:
-            if guidance is None:
-                raise ValueError("Didn't get guidance strength for guidance distilled model.")
-            vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
+            if guidance is not None:
+                vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
 
         if txt_mask is not None and not torch.is_floating_point(txt_mask):
             txt_mask = (txt_mask - 1).to(img.dtype) * torch.finfo(img.dtype).max
@@ -311,10 +310,10 @@ class HunyuanVideo(nn.Module):
             shape[i] = shape[i] // self.patch_size[i]
         img = img.reshape([img.shape[0]] + shape + [self.out_channels] + self.patch_size)
         img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
-        img = img.reshape(initial_shape)
+        img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
         return img
 
-    def forward(self, x, timestep, context, y, guidance, attention_mask=None, control=None, transformer_options={}, **kwargs):
+    def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, control=None, transformer_options={}, **kwargs):
         bs, c, t, h, w = x.shape
         patch_size = self.patch_size
         t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
```
**comfy/ldm/lumina/model.py** (new file, 622 lines)
@@ -0,0 +1,622 @@
|
||||
# Code from: https://github.com/Alpha-VLLM/Lumina-Image-2.0/blob/main/models/model.py
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, RMSNorm
|
||||
from comfy.ldm.modules.attention import optimized_attention_masked
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
|
||||
|
||||
def modulate(x, scale):
|
||||
return x * (1 + scale.unsqueeze(1))
|
||||
|
||||
#############################################################################
|
||||
# Core NextDiT Model #
|
||||
#############################################################################
|
||||
|
||||
|
||||
class JointAttention(nn.Module):
    """Multi-head attention module."""

    def __init__(
        self,
        dim: int,
        n_heads: int,
        n_kv_heads: Optional[int],
        qk_norm: bool,
        operation_settings={},
    ):
        """
        Initialize the Attention module.

        Args:
            dim (int): Number of input dimensions.
            n_heads (int): Number of heads.
            n_kv_heads (Optional[int]): Number of kv heads, if using GQA.

        """
        super().__init__()
        self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads
        self.n_local_heads = n_heads
        self.n_local_kv_heads = self.n_kv_heads
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = dim // n_heads

        self.qkv = operation_settings.get("operations").Linear(
            dim,
            (n_heads + self.n_kv_heads + self.n_kv_heads) * self.head_dim,
            bias=False,
            device=operation_settings.get("device"),
            dtype=operation_settings.get("dtype"),
        )
        self.out = operation_settings.get("operations").Linear(
            n_heads * self.head_dim,
            dim,
            bias=False,
            device=operation_settings.get("device"),
            dtype=operation_settings.get("dtype"),
        )

        if qk_norm:
            self.q_norm = RMSNorm(self.head_dim, elementwise_affine=True, **operation_settings)
            self.k_norm = RMSNorm(self.head_dim, elementwise_affine=True, **operation_settings)
        else:
            self.q_norm = self.k_norm = nn.Identity()

    @staticmethod
    def apply_rotary_emb(
        x_in: torch.Tensor,
        freqs_cis: torch.Tensor,
    ) -> torch.Tensor:
        """
        Apply rotary embeddings to the input tensor using the given frequency
        tensor.

        This function applies rotary embeddings to the given tensor 'x_in'
        (a query or key) using the provided frequency tensor 'freqs_cis'. The
        last dimension of the input is viewed as consecutive (real, imaginary)
        pairs, each pair is rotated by the matching frequency entry, and the
        result is returned as a real tensor with the original shape.

        Args:
            x_in (torch.Tensor): Query or key tensor to apply rotary embeddings to.
            freqs_cis (torch.Tensor): Precomputed frequency tensor for complex
                exponentials.

        Returns:
            torch.Tensor: The input tensor with rotary embeddings applied.
        """
        t_ = x_in.reshape(*x_in.shape[:-1], -1, 1, 2)
        t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
        return t_out.reshape(*x_in.shape)

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        freqs_cis: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): Input features of shape (batch, seq_len, dim).
            x_mask (torch.Tensor): Attention mask.
            freqs_cis (torch.Tensor): Precomputed rotary frequencies.

        Returns:
            torch.Tensor: Attention output of shape (batch, seq_len, dim).
        """
        bsz, seqlen, _ = x.shape

        xq, xk, xv = torch.split(
            self.qkv(x),
            [
                self.n_local_heads * self.head_dim,
                self.n_local_kv_heads * self.head_dim,
                self.n_local_kv_heads * self.head_dim,
            ],
            dim=-1,
        )
        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        xq = self.q_norm(xq)
        xk = self.k_norm(xk)

        xq = JointAttention.apply_rotary_emb(xq, freqs_cis=freqs_cis)
        xk = JointAttention.apply_rotary_emb(xk, freqs_cis=freqs_cis)

        n_rep = self.n_local_heads // self.n_local_kv_heads
        if n_rep >= 1:
            xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
            xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
        output = optimized_attention_masked(xq.movedim(1, 2), xk.movedim(1, 2), xv.movedim(1, 2), self.n_local_heads, x_mask, skip_reshape=True)

        return self.out(output)

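For reference, a minimal sketch of how apply_rotary_emb consumes freqs_cis, assuming EmbedND's layout of one 2x2 rotation matrix per (position, frequency) pair; the helper and shapes below are illustrative, not the model's real configuration:

import torch

def make_freqs_cis(seq_len, head_dim, theta=10000.0):
    # hypothetical helper: one [[cos, -sin], [sin, cos]] matrix per (position, frequency)
    freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = torch.outer(torch.arange(seq_len).float(), freqs)   # (seq_len, head_dim/2)
    cos, sin = angles.cos(), angles.sin()
    row0 = torch.stack([cos, -sin], dim=-1)
    row1 = torch.stack([sin, cos], dim=-1)
    return torch.stack([row0, row1], dim=-2)                     # (seq_len, head_dim/2, 2, 2)

x = torch.randn(1, 4, 1, 8)                                      # (batch, seq, heads, head_dim)
fc = make_freqs_cis(4, 8).view(1, 4, 1, 4, 2, 2)                 # broadcast over batch/heads
rotated = JointAttention.apply_rotary_emb(x, fc)
assert rotated.shape == x.shape
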
class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
        operation_settings={},
    ):
        """
        Initialize the FeedForward module.

        Args:
            dim (int): Input dimension.
            hidden_dim (int): Hidden dimension of the feedforward layer.
            multiple_of (int): Value to ensure the hidden dimension is a
                multiple of this value.
            ffn_dim_multiplier (float, optional): Custom multiplier for the
                hidden dimension. Defaults to None.

        """
        super().__init__()
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = operation_settings.get("operations").Linear(
            dim,
            hidden_dim,
            bias=False,
            device=operation_settings.get("device"),
            dtype=operation_settings.get("dtype"),
        )
        self.w2 = operation_settings.get("operations").Linear(
            hidden_dim,
            dim,
            bias=False,
            device=operation_settings.get("device"),
            dtype=operation_settings.get("dtype"),
        )
        self.w3 = operation_settings.get("operations").Linear(
            dim,
            hidden_dim,
            bias=False,
            device=operation_settings.get("device"),
            dtype=operation_settings.get("dtype"),
        )

    # @torch.compile
    def _forward_silu_gating(self, x1, x3):
        return F.silu(x1) * x3

    def forward(self, x):
        return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))

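The hidden dimension above is rounded up to the nearest multiple of multiple_of after the optional multiplier; a quick check of the rounding arithmetic with illustrative numbers:

multiple_of = 256
for hidden_dim in (9216, 9000):                 # illustrative values
    rounded = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
    print(hidden_dim, "->", rounded)            # 9216 -> 9216, 9000 -> 9216
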
class JointTransformerBlock(nn.Module):
    def __init__(
        self,
        layer_id: int,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        multiple_of: int,
        ffn_dim_multiplier: float,
        norm_eps: float,
        qk_norm: bool,
        modulation=True,
        operation_settings={},
    ) -> None:
        """
        Initialize a TransformerBlock.

        Args:
            layer_id (int): Identifier for the layer.
            dim (int): Embedding dimension of the input features.
            n_heads (int): Number of attention heads.
            n_kv_heads (Optional[int]): Number of attention heads in key and
                value features (if using GQA), or None for the same number as
                the query heads.
            multiple_of (int): Rounding multiple for the feedforward hidden dimension.
            ffn_dim_multiplier (float): Custom multiplier for the feedforward hidden dimension.
            norm_eps (float): Epsilon for the RMSNorm layers.

        """
        super().__init__()
        self.dim = dim
        self.head_dim = dim // n_heads
        self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, operation_settings=operation_settings)
        self.feed_forward = FeedForward(
            dim=dim,
            hidden_dim=4 * dim,
            multiple_of=multiple_of,
            ffn_dim_multiplier=ffn_dim_multiplier,
            operation_settings=operation_settings,
        )
        self.layer_id = layer_id
        self.attention_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
        self.ffn_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)

        self.attention_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
        self.ffn_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)

        self.modulation = modulation
        if modulation:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                operation_settings.get("operations").Linear(
                    min(dim, 1024),
                    4 * dim,
                    bias=True,
                    device=operation_settings.get("device"),
                    dtype=operation_settings.get("dtype"),
                ),
            )

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        freqs_cis: torch.Tensor,
        adaln_input: Optional[torch.Tensor] = None,
    ):
        """
        Perform a forward pass through the TransformerBlock.

        Args:
            x (torch.Tensor): Input tensor.
            x_mask (torch.Tensor): Attention mask.
            freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
            adaln_input (torch.Tensor, optional): Conditioning vector for adaLN
                modulation; required when the block was built with modulation=True.

        Returns:
            torch.Tensor: Output tensor after applying attention and
                feedforward layers.

        """
        if self.modulation:
            assert adaln_input is not None
            scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)

            x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2(
                self.attention(
                    modulate(self.attention_norm1(x), scale_msa),
                    x_mask,
                    freqs_cis,
                )
            )
            x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
                self.feed_forward(
                    modulate(self.ffn_norm1(x), scale_mlp),
                )
            )
        else:
            assert adaln_input is None
            x = x + self.attention_norm2(
                self.attention(
                    self.attention_norm1(x),
                    x_mask,
                    freqs_cis,
                )
            )
            x = x + self.ffn_norm2(
                self.feed_forward(
                    self.ffn_norm1(x),
                )
            )
        return x

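The modulated branch computes one (scale, gate) pair each for attention and for the MLP from a single adaLN projection; a shape-level sketch under assumed sizes:

import torch

dim, bsz = 2304, 2                          # illustrative sizes
adaln_out = torch.randn(bsz, 4 * dim)       # stands in for adaLN_modulation(adaln_input)
scale_msa, gate_msa, scale_mlp, gate_mlp = adaln_out.chunk(4, dim=1)
assert scale_msa.shape == (bsz, dim)

x = torch.randn(bsz, 16, dim)
x_mod = x * (1 + scale_msa.unsqueeze(1))    # equivalent to modulate(x, scale_msa)
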
class FinalLayer(nn.Module):
    """
    The final layer of NextDiT.
    """

    def __init__(self, hidden_size, patch_size, out_channels, operation_settings={}):
        super().__init__()
        self.norm_final = operation_settings.get("operations").LayerNorm(
            hidden_size,
            elementwise_affine=False,
            eps=1e-6,
            device=operation_settings.get("device"),
            dtype=operation_settings.get("dtype"),
        )
        self.linear = operation_settings.get("operations").Linear(
            hidden_size,
            patch_size * patch_size * out_channels,
            bias=True,
            device=operation_settings.get("device"),
            dtype=operation_settings.get("dtype"),
        )

        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            operation_settings.get("operations").Linear(
                min(hidden_size, 1024),
                hidden_size,
                bias=True,
                device=operation_settings.get("device"),
                dtype=operation_settings.get("dtype"),
            ),
        )

    def forward(self, x, c):
        scale = self.adaLN_modulation(c)
        x = modulate(self.norm_final(x), scale)
        x = self.linear(x)
        return x

class NextDiT(nn.Module):
    """
    Diffusion model with a Transformer backbone.
    """

    def __init__(
        self,
        patch_size: int = 2,
        in_channels: int = 4,
        dim: int = 4096,
        n_layers: int = 32,
        n_refiner_layers: int = 2,
        n_heads: int = 32,
        n_kv_heads: Optional[int] = None,
        multiple_of: int = 256,
        ffn_dim_multiplier: Optional[float] = None,
        norm_eps: float = 1e-5,
        qk_norm: bool = False,
        cap_feat_dim: int = 5120,
        axes_dims: List[int] = (16, 56, 56),
        axes_lens: List[int] = (1, 512, 512),
        image_model=None,
        device=None,
        dtype=None,
        operations=None,
    ) -> None:
        super().__init__()
        self.dtype = dtype
        operation_settings = {"operations": operations, "device": device, "dtype": dtype}
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.patch_size = patch_size

        self.x_embedder = operation_settings.get("operations").Linear(
            in_features=patch_size * patch_size * in_channels,
            out_features=dim,
            bias=True,
            device=operation_settings.get("device"),
            dtype=operation_settings.get("dtype"),
        )

        self.noise_refiner = nn.ModuleList(
            [
                JointTransformerBlock(
                    layer_id,
                    dim,
                    n_heads,
                    n_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    qk_norm,
                    modulation=True,
                    operation_settings=operation_settings,
                )
                for layer_id in range(n_refiner_layers)
            ]
        )
        self.context_refiner = nn.ModuleList(
            [
                JointTransformerBlock(
                    layer_id,
                    dim,
                    n_heads,
                    n_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    qk_norm,
                    modulation=False,
                    operation_settings=operation_settings,
                )
                for layer_id in range(n_refiner_layers)
            ]
        )

        self.t_embedder = TimestepEmbedder(min(dim, 1024), **operation_settings)
        self.cap_embedder = nn.Sequential(
            RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, **operation_settings),
            operation_settings.get("operations").Linear(
                cap_feat_dim,
                dim,
                bias=True,
                device=operation_settings.get("device"),
                dtype=operation_settings.get("dtype"),
            ),
        )

        self.layers = nn.ModuleList(
            [
                JointTransformerBlock(
                    layer_id,
                    dim,
                    n_heads,
                    n_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    qk_norm,
                    operation_settings=operation_settings,
                )
                for layer_id in range(n_layers)
            ]
        )
        self.norm_final = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
        self.final_layer = FinalLayer(dim, patch_size, self.out_channels, operation_settings=operation_settings)

        assert (dim // n_heads) == sum(axes_dims)
        self.axes_dims = axes_dims
        self.axes_lens = axes_lens
        self.rope_embedder = EmbedND(dim=dim // n_heads, theta=10000.0, axes_dim=axes_dims)
        self.dim = dim
        self.n_heads = n_heads

    def unpatchify(
        self, x: torch.Tensor, img_size: List[Tuple[int, int]], cap_size: List[int], return_tensor=False
    ) -> List[torch.Tensor]:
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, C, H, W)
        """
        pH = pW = self.patch_size
        imgs = []
        for i in range(x.size(0)):
            H, W = img_size[i]
            begin = cap_size[i]
            end = begin + (H // pH) * (W // pW)
            imgs.append(
                x[i][begin:end]
                .view(H // pH, W // pW, pH, pW, self.out_channels)
                .permute(4, 0, 2, 1, 3)
                .flatten(3, 4)
                .flatten(1, 2)
            )

        if return_tensor:
            imgs = torch.stack(imgs, dim=0)
        return imgs

    def patchify_and_embed(
        self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens
    ) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]:
        bsz = len(x)
        pH = pW = self.patch_size
        device = x[0].device
        dtype = x[0].dtype

        if cap_mask is not None:
            l_effective_cap_len = cap_mask.sum(dim=1).tolist()
        else:
            l_effective_cap_len = [num_tokens] * bsz

        if cap_mask is not None and not torch.is_floating_point(cap_mask):
            cap_mask = (cap_mask - 1).to(dtype) * torch.finfo(dtype).max

        img_sizes = [(img.size(1), img.size(2)) for img in x]
        l_effective_img_len = [(H // pH) * (W // pW) for (H, W) in img_sizes]

        max_seq_len = max(
            (cap_len + img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len))
        )
        max_cap_len = max(l_effective_cap_len)
        max_img_len = max(l_effective_img_len)

        position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.int32, device=device)

        for i in range(bsz):
            cap_len = l_effective_cap_len[i]
            img_len = l_effective_img_len[i]
            H, W = img_sizes[i]
            H_tokens, W_tokens = H // pH, W // pW
            assert H_tokens * W_tokens == img_len

            position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device)
            position_ids[i, cap_len:cap_len + img_len, 0] = cap_len
            row_ids = torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
            col_ids = torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
            position_ids[i, cap_len:cap_len + img_len, 1] = row_ids
            position_ids[i, cap_len:cap_len + img_len, 2] = col_ids

        freqs_cis = self.rope_embedder(position_ids).movedim(1, 2).to(dtype)

        # build freqs_cis for cap and image individually
        cap_freqs_cis_shape = list(freqs_cis.shape)
        # cap_freqs_cis_shape[1] = max_cap_len
        cap_freqs_cis_shape[1] = cap_feats.shape[1]
        cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)

        img_freqs_cis_shape = list(freqs_cis.shape)
        img_freqs_cis_shape[1] = max_img_len
        img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)

        for i in range(bsz):
            cap_len = l_effective_cap_len[i]
            img_len = l_effective_img_len[i]
            cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
            img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len:cap_len + img_len]

        # refine context
        for layer in self.context_refiner:
            cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis)

        # refine image
        flat_x = []
        for i in range(bsz):
            img = x[i]
            C, H, W = img.size()
            img = img.view(C, H // pH, pH, W // pW, pW).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1)
            flat_x.append(img)
        x = flat_x
        padded_img_embed = torch.zeros(bsz, max_img_len, x[0].shape[-1], device=device, dtype=x[0].dtype)
        padded_img_mask = torch.zeros(bsz, max_img_len, dtype=dtype, device=device)
        for i in range(bsz):
            padded_img_embed[i, :l_effective_img_len[i]] = x[i]
            padded_img_mask[i, l_effective_img_len[i]:] = -torch.finfo(dtype).max

        padded_img_embed = self.x_embedder(padded_img_embed)
        padded_img_mask = padded_img_mask.unsqueeze(1)
        for layer in self.noise_refiner:
            padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t)

        if cap_mask is not None:
            mask = torch.zeros(bsz, max_seq_len, dtype=dtype, device=device)
            mask[:, :max_cap_len] = cap_mask[:, :max_cap_len]
        else:
            mask = None

        padded_full_embed = torch.zeros(bsz, max_seq_len, self.dim, device=device, dtype=x[0].dtype)
        for i in range(bsz):
            cap_len = l_effective_cap_len[i]
            img_len = l_effective_img_len[i]

            padded_full_embed[i, :cap_len] = cap_feats[i, :cap_len]
            padded_full_embed[i, cap_len:cap_len + img_len] = padded_img_embed[i, :img_len]

        return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis

    # def forward(self, x, t, cap_feats, cap_mask):
    def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
        """
        Forward pass of NextDiT.

        x: (N, C, H, W) tensor of latent inputs
        timesteps: (N,) tensor of diffusion timesteps
        context: (N, L, cap_feat_dim) tensor of caption features
        """
        t = 1.0 - timesteps
        cap_feats = context
        cap_mask = attention_mask
        bs, c, h, w = x.shape
        x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))

        t = self.t_embedder(t, dtype=x.dtype)  # (N, D)
        adaln_input = t

        cap_feats = self.cap_embedder(cap_feats)  # (N, L, D)  # todo check if able to batchify w.o. redundant compute

        x_is_tensor = isinstance(x, torch.Tensor)
        x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens)
        freqs_cis = freqs_cis.to(x.device)

        for layer in self.layers:
            x = layer(x, mask, freqs_cis, adaln_input)

        x = self.final_layer(x, adaln_input)
        x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w]

        return -x

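patchify_and_embed lays each sequence out as [caption tokens | image tokens] and gives them 3-axis position ids (caption index, row, column); a standalone sketch of that layout with illustrative sizes:

import torch

cap_len, H_tokens, W_tokens = 6, 4, 4       # illustrative sizes
img_len = H_tokens * W_tokens
position_ids = torch.zeros(cap_len + img_len, 3, dtype=torch.int32)
position_ids[:cap_len, 0] = torch.arange(cap_len)      # caption: axis 0 counts tokens
position_ids[cap_len:, 0] = cap_len                    # image: axis 0 frozen at cap_len
position_ids[cap_len:, 1] = torch.arange(H_tokens).repeat_interleave(W_tokens)  # rows
position_ids[cap_len:, 2] = torch.arange(W_tokens).repeat(H_tokens)             # columns
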
@@ -1,4 +1,6 @@
import math
import sys

import torch
import torch.nn.functional as F
from torch import nn, einsum
@@ -16,7 +18,11 @@ if model_management.xformers_enabled():
    import xformers.ops

if model_management.sage_attention_enabled():
    from sageattention import sageattn
    try:
        from sageattention import sageattn
    except ModuleNotFoundError:
        logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
        exit(-1)

from comfy.cli_args import args
import comfy.ops
@@ -24,38 +30,24 @@ ops = comfy.ops.disable_weight_init

FORCE_UPCAST_ATTENTION_DTYPE = model_management.force_upcast_attention_dtype()

def get_attn_precision(attn_precision):
def get_attn_precision(attn_precision, current_dtype):
    if args.dont_upcast_attention:
        return None
    if FORCE_UPCAST_ATTENTION_DTYPE is not None:
        return FORCE_UPCAST_ATTENTION_DTYPE

    if FORCE_UPCAST_ATTENTION_DTYPE is not None and current_dtype in FORCE_UPCAST_ATTENTION_DTYPE:
        return FORCE_UPCAST_ATTENTION_DTYPE[current_dtype]
    return attn_precision

def exists(val):
    return val is not None


def uniq(arr):
    return {el: True for el in arr}.keys()


def default(val, d):
    if exists(val):
        return val
    return d


def max_neg_value(t):
    return -torch.finfo(t.dtype).max


def init_(tensor):
    dim = tensor.shape[-1]
    std = 1 / math.sqrt(dim)
    tensor.uniform_(-std, std)
    return tensor


# feedforward
class GEGLU(nn.Module):
    def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=ops):
@@ -90,7 +82,7 @@ def Normalize(in_channels, dtype=None, device=None):
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)

def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
    attn_precision = get_attn_precision(attn_precision)
    attn_precision = get_attn_precision(attn_precision, q.dtype)

    if skip_reshape:
        b, _, _, dim_head = q.shape
@@ -159,7 +151,7 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape


def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
    attn_precision = get_attn_precision(attn_precision)
    attn_precision = get_attn_precision(attn_precision, query.dtype)

    if skip_reshape:
        b, _, _, dim_head = query.shape
@@ -229,7 +221,7 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
    return hidden_states

def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
    attn_precision = get_attn_precision(attn_precision)
    attn_precision = get_attn_precision(attn_precision, q.dtype)

    if skip_reshape:
        b, _, _, dim_head = q.shape

@@ -321,7 +321,7 @@ class SelfAttention(nn.Module):

class RMSNorm(torch.nn.Module):
    def __init__(
        self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6, device=None, dtype=None
        self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6, device=None, dtype=None, **kwargs
    ):
        """
        Initialize the RMSNorm normalization layer.

@@ -297,7 +297,7 @@ def vae_attention():
    if model_management.xformers_enabled_vae():
        logging.info("Using xformers attention in VAE")
        return xformers_attention
    elif model_management.pytorch_attention_enabled():
    elif model_management.pytorch_attention_enabled_vae():
        logging.info("Using pytorch attention in VAE")
        return pytorch_attention
    else:
@@ -702,9 +702,6 @@ class Decoder(nn.Module):
                                        padding=1)

    def forward(self, z, **kwargs):
        #assert z.shape[1:] == self.z_shape[1:]
        self.last_z_shape = z.shape

        # timestep embedding
        temb = None

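The upcast hook above changes from a single override dtype to a per-dtype mapping; a sketch of the new lookup, with a hypothetical mapping (the real one comes from model_management.force_upcast_attention_dtype()) and the CLI-flag check omitted:

import torch

FORCE_UPCAST_ATTENTION_DTYPE = {torch.bfloat16: torch.float32}   # hypothetical mapping

def get_attn_precision(attn_precision, current_dtype):
    if FORCE_UPCAST_ATTENTION_DTYPE is not None and current_dtype in FORCE_UPCAST_ATTENTION_DTYPE:
        return FORCE_UPCAST_ATTENTION_DTYPE[current_dtype]
    return attn_precision

assert get_attn_precision(None, torch.bfloat16) is torch.float32
assert get_attn_precision(None, torch.float16) is None
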
485
comfy/ldm/wan/model.py
Normal file
@@ -0,0 +1,485 @@
# original version: https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/model.py
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import math

import torch
import torch.nn as nn
from einops import repeat

from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.flux.math import apply_rope
from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
import comfy.ldm.common_dit
import comfy.model_management


def sinusoidal_embedding_1d(dim, position):
    # preprocess
    assert dim % 2 == 0
    half = dim // 2
    position = position.type(torch.float32)

    # calculation
    sinusoid = torch.outer(
        position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
    x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
    return x

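A quick shape check of the embedding above (illustrative sizes): the cosine half and the sine half are concatenated along dim=1.

import torch

emb = sinusoidal_embedding_1d(256, torch.arange(10))
assert emb.shape == (10, 256)
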
class WanSelfAttention(nn.Module):

    def __init__(self,
                 dim,
                 num_heads,
                 window_size=(-1, -1),
                 qk_norm=True,
                 eps=1e-6, operation_settings={}):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.eps = eps

        # layers
        self.q = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.k = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.v = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.o = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.norm_q = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
        self.norm_k = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()

    def forward(self, x, freqs):
        r"""
        Args:
            x(Tensor): Shape [B, L, num_heads, C / num_heads]
            freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
        """
        b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim

        # query, key, value function
        def qkv_fn(x):
            q = self.norm_q(self.q(x)).view(b, s, n, d)
            k = self.norm_k(self.k(x)).view(b, s, n, d)
            v = self.v(x).view(b, s, n * d)
            return q, k, v

        q, k, v = qkv_fn(x)
        q, k = apply_rope(q, k, freqs)

        x = optimized_attention(
            q.view(b, s, n * d),
            k.view(b, s, n * d),
            v,
            heads=self.num_heads,
        )

        x = self.o(x)
        return x


class WanT2VCrossAttention(WanSelfAttention):

    def forward(self, x, context):
        r"""
        Args:
            x(Tensor): Shape [B, L1, C]
            context(Tensor): Shape [B, L2, C]
        """
        # compute query, key, value
        q = self.norm_q(self.q(x))
        k = self.norm_k(self.k(context))
        v = self.v(context)

        # compute attention
        x = optimized_attention(q, k, v, heads=self.num_heads)

        x = self.o(x)
        return x


class WanI2VCrossAttention(WanSelfAttention):

    def __init__(self,
                 dim,
                 num_heads,
                 window_size=(-1, -1),
                 qk_norm=True,
                 eps=1e-6, operation_settings={}):
        super().__init__(dim, num_heads, window_size, qk_norm, eps, operation_settings=operation_settings)

        self.k_img = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.v_img = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        # self.alpha = nn.Parameter(torch.zeros((1, )))
        self.norm_k_img = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()

    def forward(self, x, context):
        r"""
        Args:
            x(Tensor): Shape [B, L1, C]
            context(Tensor): Shape [B, L2, C]
        """
        context_img = context[:, :257]
        context = context[:, 257:]

        # compute query, key, value
        q = self.norm_q(self.q(x))
        k = self.norm_k(self.k(context))
        v = self.v(context)
        k_img = self.norm_k_img(self.k_img(context_img))
        v_img = self.v_img(context_img)
        img_x = optimized_attention(q, k_img, v_img, heads=self.num_heads)
        # compute attention
        x = optimized_attention(q, k, v, heads=self.num_heads)

        # output
        x = x + img_x
        x = self.o(x)
        return x


WAN_CROSSATTENTION_CLASSES = {
    't2v_cross_attn': WanT2VCrossAttention,
    'i2v_cross_attn': WanI2VCrossAttention,
}

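In the i2v cross-attention, the first 257 context tokens are taken to be the projected CLIP image embedding (presumably one class token plus a 16x16 patch grid) and the remainder text; a shape-level sketch with illustrative sizes:

import torch

context = torch.randn(2, 257 + 512, 2048)              # illustrative: CLIP tokens then text
context_img, context_txt = context[:, :257], context[:, 257:]
assert context_img.shape[1] == 257 and context_txt.shape[1] == 512
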
class WanAttentionBlock(nn.Module):

    def __init__(self,
                 cross_attn_type,
                 dim,
                 ffn_dim,
                 num_heads,
                 window_size=(-1, -1),
                 qk_norm=True,
                 cross_attn_norm=False,
                 eps=1e-6, operation_settings={}):
        super().__init__()
        self.dim = dim
        self.ffn_dim = ffn_dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.cross_attn_norm = cross_attn_norm
        self.eps = eps

        # layers
        self.norm1 = operation_settings.get("operations").LayerNorm(dim, eps, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
                                          eps, operation_settings=operation_settings)
        self.norm3 = operation_settings.get("operations").LayerNorm(
            dim, eps,
            elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if cross_attn_norm else nn.Identity()
        self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim,
                                                                      num_heads,
                                                                      (-1, -1),
                                                                      qk_norm,
                                                                      eps, operation_settings=operation_settings)
        self.norm2 = operation_settings.get("operations").LayerNorm(dim, eps, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.ffn = nn.Sequential(
            operation_settings.get("operations").Linear(dim, ffn_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), nn.GELU(approximate='tanh'),
            operation_settings.get("operations").Linear(ffn_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))

        # modulation
        self.modulation = nn.Parameter(torch.empty(1, 6, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))

    def forward(
        self,
        x,
        e,
        freqs,
        context,
    ):
        r"""
        Args:
            x(Tensor): Shape [B, L, C]
            e(Tensor): Shape [B, 6, C]
            freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
        """
        # assert e.dtype == torch.float32

        e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
        # assert e[0].dtype == torch.float32

        # self-attention
        y = self.self_attn(
            self.norm1(x) * (1 + e[1]) + e[0],
            freqs)

        x = x + y * e[2]

        # cross-attention & ffn function
        def cross_attn_ffn(x, context, e):
            x = x + self.cross_attn(self.norm3(x), context)
            y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3])
            x = x + y * e[5]
            return x

        x = cross_attn_ffn(x, context, e)
        return x

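The block consumes six modulation tensors: shift/scale/gate for self-attention, then shift/scale/gate for the FFN; a shape-level sketch under assumed sizes:

import torch

dim = 2048                                  # illustrative
modulation = torch.zeros(1, 6, dim)         # stands in for the learned parameter
e0 = torch.randn(2, 6, dim)                 # time_projection output, unflattened to (B, 6, dim)
e = (modulation + e0).chunk(6, dim=1)
# e[0]/e[1]: shift/scale before self-attention, e[2]: its gate
# e[3]/e[4]: shift/scale before the FFN,        e[5]: its gate
assert len(e) == 6 and e[0].shape == (2, 1, dim)
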
class Head(nn.Module):

    def __init__(self, dim, out_dim, patch_size, eps=1e-6, operation_settings={}):
        super().__init__()
        self.dim = dim
        self.out_dim = out_dim
        self.patch_size = patch_size
        self.eps = eps

        # layers
        out_dim = math.prod(patch_size) * out_dim
        self.norm = operation_settings.get("operations").LayerNorm(dim, eps, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.head = operation_settings.get("operations").Linear(dim, out_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))

        # modulation
        self.modulation = nn.Parameter(torch.empty(1, 2, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))

    def forward(self, x, e):
        r"""
        Args:
            x(Tensor): Shape [B, L1, C]
            e(Tensor): Shape [B, C]
        """
        # assert e.dtype == torch.float32
        e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e.unsqueeze(1)).chunk(2, dim=1)
        x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
        return x

class MLPProj(torch.nn.Module):

    def __init__(self, in_dim, out_dim, operation_settings={}):
        super().__init__()

        self.proj = torch.nn.Sequential(
            operation_settings.get("operations").LayerNorm(in_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), operation_settings.get("operations").Linear(in_dim, in_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
            torch.nn.GELU(), operation_settings.get("operations").Linear(in_dim, out_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
            operation_settings.get("operations").LayerNorm(out_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))

    def forward(self, image_embeds):
        clip_extra_context_tokens = self.proj(image_embeds)
        return clip_extra_context_tokens

class WanModel(torch.nn.Module):
    r"""
    Wan diffusion backbone supporting both text-to-video and image-to-video.
    """

    def __init__(self,
                 model_type='t2v',
                 patch_size=(1, 2, 2),
                 text_len=512,
                 in_dim=16,
                 dim=2048,
                 ffn_dim=8192,
                 freq_dim=256,
                 text_dim=4096,
                 out_dim=16,
                 num_heads=16,
                 num_layers=32,
                 window_size=(-1, -1),
                 qk_norm=True,
                 cross_attn_norm=True,
                 eps=1e-6,
                 image_model=None,
                 device=None,
                 dtype=None,
                 operations=None,
                 ):
        r"""
        Initialize the diffusion model backbone.

        Args:
            model_type (`str`, *optional*, defaults to 't2v'):
                Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
            patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
                3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
            text_len (`int`, *optional*, defaults to 512):
                Fixed length for text embeddings
            in_dim (`int`, *optional*, defaults to 16):
                Input video channels (C_in)
            dim (`int`, *optional*, defaults to 2048):
                Hidden dimension of the transformer
            ffn_dim (`int`, *optional*, defaults to 8192):
                Intermediate dimension in feed-forward network
            freq_dim (`int`, *optional*, defaults to 256):
                Dimension for sinusoidal time embeddings
            text_dim (`int`, *optional*, defaults to 4096):
                Input dimension for text embeddings
            out_dim (`int`, *optional*, defaults to 16):
                Output video channels (C_out)
            num_heads (`int`, *optional*, defaults to 16):
                Number of attention heads
            num_layers (`int`, *optional*, defaults to 32):
                Number of transformer blocks
            window_size (`tuple`, *optional*, defaults to (-1, -1)):
                Window size for local attention (-1 indicates global attention)
            qk_norm (`bool`, *optional*, defaults to True):
                Enable query/key normalization
            cross_attn_norm (`bool`, *optional*, defaults to True):
                Enable cross-attention normalization
            eps (`float`, *optional*, defaults to 1e-6):
                Epsilon value for normalization layers
        """

        super().__init__()
        self.dtype = dtype
        operation_settings = {"operations": operations, "device": device, "dtype": dtype}

        assert model_type in ['t2v', 'i2v']
        self.model_type = model_type

        self.patch_size = patch_size
        self.text_len = text_len
        self.in_dim = in_dim
        self.dim = dim
        self.ffn_dim = ffn_dim
        self.freq_dim = freq_dim
        self.text_dim = text_dim
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.cross_attn_norm = cross_attn_norm
        self.eps = eps

        # embeddings
        self.patch_embedding = operations.Conv3d(
            in_dim, dim, kernel_size=patch_size, stride=patch_size, device=operation_settings.get("device"), dtype=torch.float32)
        self.text_embedding = nn.Sequential(
            operations.Linear(text_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), nn.GELU(approximate='tanh'),
            operations.Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))

        self.time_embedding = nn.Sequential(
            operations.Linear(freq_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), nn.SiLU(), operations.Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
        self.time_projection = nn.Sequential(nn.SiLU(), operations.Linear(dim, dim * 6, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))

        # blocks
        cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
        self.blocks = nn.ModuleList([
            WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
                              window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
            for _ in range(num_layers)
        ])

        # head
        self.head = Head(dim, out_dim, patch_size, eps, operation_settings=operation_settings)

        d = dim // num_heads
        self.rope_embedder = EmbedND(dim=d, theta=10000.0, axes_dim=[d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)])

        if model_type == 'i2v':
            self.img_emb = MLPProj(1280, dim, operation_settings=operation_settings)
        else:
            self.img_emb = None

    def forward_orig(
        self,
        x,
        t,
        context,
        clip_fea=None,
        freqs=None,
    ):
        r"""
        Forward pass through the diffusion model.

        Args:
            x (Tensor):
                Input video tensor with shape [B, C_in, F, H, W]
            t (Tensor):
                Diffusion timesteps tensor of shape [B]
            context (Tensor):
                Text embeddings with shape [B, L, C]
            clip_fea (Tensor, *optional*):
                CLIP image features for image-to-video mode
            freqs (Tensor, *optional*):
                Precomputed rotary frequencies for positional encoding

        Returns:
            Tensor:
                Denoised video tensor with shape [B, C_out, F, H, W]
        """
        # embeddings
        x = self.patch_embedding(x.float()).to(x.dtype)
        grid_sizes = x.shape[2:]
        x = x.flatten(2).transpose(1, 2)

        # time embeddings
        e = self.time_embedding(
            sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype))
        e0 = self.time_projection(e).unflatten(1, (6, self.dim))

        # context
        context = self.text_embedding(context)

        if clip_fea is not None and self.img_emb is not None:
            context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
            context = torch.concat([context_clip, context], dim=1)

        # arguments
        kwargs = dict(
            e=e0,
            freqs=freqs,
            context=context)

        for block in self.blocks:
            x = block(x, **kwargs)

        # head
        x = self.head(x, e)

        # unpatchify
        x = self.unpatchify(x, grid_sizes)
        return x
        # return [u.float() for u in x]

    def forward(self, x, timestep, context, clip_fea=None, **kwargs):
        bs, c, t, h, w = x.shape
        x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
        patch_size = self.patch_size
        t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
        h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
        w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
        img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
        img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
        img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
        img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
        img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)

        freqs = self.rope_embedder(img_ids).movedim(1, 2)
        return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs)[:, :, :t, :h, :w]

    def unpatchify(self, x, grid_sizes):
        r"""
        Reconstruct video tensors from patch embeddings.

        Args:
            x (Tensor):
                Patchified features with shape [B, L, C_out * prod(patch_size)]
            grid_sizes (tuple):
                Original spatial-temporal grid dimensions before patching
                (F_patches, H_patches, W_patches)

        Returns:
            Tensor:
                Reconstructed video tensor with shape [B, C_out, F, H, W]
        """
        c = self.out_dim
        u = x
        b = u.shape[0]
        u = u[:, :math.prod(grid_sizes)].view(b, *grid_sizes, *self.patch_size, c)
        u = torch.einsum('bfhwpqrc->bcfphqwr', u)
        u = u.reshape(b, c, *[i * j for i, j in zip(grid_sizes, self.patch_size)])
        return u

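With the defaults above, the per-head rotary dimension d = dim // num_heads is split across (time, height, width) axes; the arithmetic:

dim, num_heads = 2048, 16
d = dim // num_heads                                        # 128 rotary dims per head
axes_dim = [d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)]
assert axes_dim == [44, 42, 42] and sum(axes_dim) == d
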
567
comfy/ldm/wan/vae.py
Normal file
@@ -0,0 +1,567 @@
# original version: https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/vae.py
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from comfy.ldm.modules.diffusionmodules.model import vae_attention

import comfy.ops
ops = comfy.ops.disable_weight_init

CACHE_T = 2


class CausalConv3d(ops.Conv3d):
    """
    Causal 3D convolution.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._padding = (self.padding[2], self.padding[2], self.padding[1],
                         self.padding[1], 2 * self.padding[0], 0)
        self.padding = (0, 0, 0)

    def forward(self, x, cache_x=None):
        padding = list(self._padding)
        if cache_x is not None and self._padding[4] > 0:
            cache_x = cache_x.to(x.device)
            x = torch.cat([cache_x, x], dim=2)
            padding[4] -= cache_x.shape[2]
        x = F.pad(x, padding)

        return super().forward(x)

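CausalConv3d pads only on the past side of the time axis (2 * pad_t frames before, none after), so no output frame depends on future frames; a sketch of the padding order F.pad expects:

import torch
import torch.nn.functional as F

pad_t = 1                                    # e.g. a kernel of size 3 on the time axis
x = torch.randn(1, 8, 5, 4, 4)               # (B, C, T, H, W)
x = F.pad(x, (1, 1, 1, 1, 2 * pad_t, 0))     # (w_l, w_r, h_t, h_b, t_before, t_after)
assert x.shape[2] == 5 + 2 * pad_t           # all temporal padding sits in the past
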
class RMS_norm(nn.Module):

    def __init__(self, dim, channel_first=True, images=True, bias=False):
        super().__init__()
        broadcastable_dims = (1, 1, 1) if not images else (1, 1)
        shape = (dim, *broadcastable_dims) if channel_first else (dim,)

        self.channel_first = channel_first
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(shape))
        self.bias = nn.Parameter(torch.zeros(shape)) if bias else None

    def forward(self, x):
        return F.normalize(
            x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma.to(x) + (self.bias.to(x) if self.bias is not None else 0)


class Upsample(nn.Upsample):

    def forward(self, x):
        """
        Fix bfloat16 support for nearest neighbor interpolation.
        """
        return super().forward(x.float()).type_as(x)


class Resample(nn.Module):

    def __init__(self, dim, mode):
        assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
                        'downsample3d')
        super().__init__()
        self.dim = dim
        self.mode = mode

        # layers
        if mode == 'upsample2d':
            self.resample = nn.Sequential(
                Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
                ops.Conv2d(dim, dim // 2, 3, padding=1))
        elif mode == 'upsample3d':
            self.resample = nn.Sequential(
                Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
                ops.Conv2d(dim, dim // 2, 3, padding=1))
            self.time_conv = CausalConv3d(
                dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))

        elif mode == 'downsample2d':
            self.resample = nn.Sequential(
                nn.ZeroPad2d((0, 1, 0, 1)),
                ops.Conv2d(dim, dim, 3, stride=(2, 2)))
        elif mode == 'downsample3d':
            self.resample = nn.Sequential(
                nn.ZeroPad2d((0, 1, 0, 1)),
                ops.Conv2d(dim, dim, 3, stride=(2, 2)))
            self.time_conv = CausalConv3d(
                dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))

        else:
            self.resample = nn.Identity()

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        b, c, t, h, w = x.size()
        if self.mode == 'upsample3d':
            if feat_cache is not None:
                idx = feat_idx[0]
                if feat_cache[idx] is None:
                    feat_cache[idx] = 'Rep'
                    feat_idx[0] += 1
                else:
                    cache_x = x[:, :, -CACHE_T:, :, :].clone()
                    if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != 'Rep':
                        # cache the last frame of the previous chunk
                        cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
                    if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == 'Rep':
                        cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2)
                    if feat_cache[idx] == 'Rep':
                        x = self.time_conv(x)
                    else:
                        x = self.time_conv(x, feat_cache[idx])
                    feat_cache[idx] = cache_x
                    feat_idx[0] += 1

                    x = x.reshape(b, 2, c, t, h, w)
                    x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
                    x = x.reshape(b, c, t * 2, h, w)
        t = x.shape[2]
        x = rearrange(x, 'b c t h w -> (b t) c h w')
        x = self.resample(x)
        x = rearrange(x, '(b t) c h w -> b c t h w', t=t)

        if self.mode == 'downsample3d':
            if feat_cache is not None:
                idx = feat_idx[0]
                if feat_cache[idx] is None:
                    feat_cache[idx] = x.clone()
                    feat_idx[0] += 1
                else:
                    cache_x = x[:, :, -1:, :, :].clone()
                    # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != 'Rep':
                    #     # cache the last frame of the previous chunk
                    #     cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)

                    x = self.time_conv(
                        torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
                    feat_cache[idx] = cache_x
                    feat_idx[0] += 1
        return x

    def init_weight(self, conv):
        conv_weight = conv.weight
        nn.init.zeros_(conv_weight)
        c1, c2, t, h, w = conv_weight.size()
        one_matrix = torch.eye(c1, c2)
        init_matrix = one_matrix
        nn.init.zeros_(conv_weight)
        #conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
        conv_weight.data[:, :, 1, 0, 0] = init_matrix  #* 0.5
        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)

    def init_weight2(self, conv):
        conv_weight = conv.weight.data
        nn.init.zeros_(conv_weight)
        c1, c2, t, h, w = conv_weight.size()
        init_matrix = torch.eye(c1 // 2, c2)
        #init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
        conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
        conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)

class ResidualBlock(nn.Module):

    def __init__(self, in_dim, out_dim, dropout=0.0):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        # layers
        self.residual = nn.Sequential(
            RMS_norm(in_dim, images=False), nn.SiLU(),
            CausalConv3d(in_dim, out_dim, 3, padding=1),
            RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
            CausalConv3d(out_dim, out_dim, 3, padding=1))
        self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
            if in_dim != out_dim else nn.Identity()

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        h = self.shortcut(x)
        for layer in self.residual:
            if isinstance(layer, CausalConv3d) and feat_cache is not None:
                idx = feat_idx[0]
                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    # cache the last frame of the previous chunk
                    cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)
        return x + h


class AttentionBlock(nn.Module):
    """
    Causal self-attention with a single head.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

        # layers
        self.norm = RMS_norm(dim)
        self.to_qkv = ops.Conv2d(dim, dim * 3, 1)
        self.proj = ops.Conv2d(dim, dim, 1)
        self.optimized_attention = vae_attention()

    def forward(self, x):
        identity = x
        b, c, t, h, w = x.size()
        x = rearrange(x, 'b c t h w -> (b t) c h w')
        x = self.norm(x)
        # compute query, key, value
        q, k, v = self.to_qkv(x).chunk(3, dim=1)
        x = self.optimized_attention(q, k, v)

        # output
        x = self.proj(x)
        x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
        return x + identity

class Encoder3d(nn.Module):

    def __init__(self,
                 dim=128,
                 z_dim=4,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[],
                 temperal_downsample=[True, True, False],
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample

        # dimensions
        dims = [dim * u for u in [1] + dim_mult]
        scale = 1.0

        # init block
        self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)

        # downsample blocks
        downsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            # residual (+attention) blocks
            for _ in range(num_res_blocks):
                downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
                if scale in attn_scales:
                    downsamples.append(AttentionBlock(out_dim))
                in_dim = out_dim

            # downsample block
            if i != len(dim_mult) - 1:
                mode = 'downsample3d' if temperal_downsample[i] else 'downsample2d'
                downsamples.append(Resample(out_dim, mode=mode))
                scale /= 2.0
        self.downsamples = nn.Sequential(*downsamples)

        # middle blocks
        self.middle = nn.Sequential(
            ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim),
            ResidualBlock(out_dim, out_dim, dropout))

        # output blocks
        self.head = nn.Sequential(
            RMS_norm(out_dim, images=False), nn.SiLU(),
            CausalConv3d(out_dim, z_dim, 3, padding=1))

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # cache the last frame of the previous chunk
                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
            x = self.conv1(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv1(x)

        ## downsamples
        for layer in self.downsamples:
            if feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## middle
        for layer in self.middle:
            if isinstance(layer, ResidualBlock) and feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## head
        for layer in self.head:
            if isinstance(layer, CausalConv3d) and feat_cache is not None:
                idx = feat_idx[0]
                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    # cache the last frame of the previous chunk
                    cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)
        return x

class Decoder3d(nn.Module):

    def __init__(self,
                 dim=128,
                 z_dim=4,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[],
                 temperal_upsample=[False, True, True],
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_upsample = temperal_upsample

        # dimensions
        dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
        scale = 1.0 / 2**(len(dim_mult) - 2)

        # init block
        self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)

        # middle blocks
        self.middle = nn.Sequential(
            ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
            ResidualBlock(dims[0], dims[0], dropout))

        # upsample blocks
        upsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            # residual (+attention) blocks
            if i == 1 or i == 2 or i == 3:
                in_dim = in_dim // 2
            for _ in range(num_res_blocks + 1):
                upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
                if scale in attn_scales:
                    upsamples.append(AttentionBlock(out_dim))
                in_dim = out_dim

            # upsample block
            if i != len(dim_mult) - 1:
                mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
                upsamples.append(Resample(out_dim, mode=mode))
                scale *= 2.0
        self.upsamples = nn.Sequential(*upsamples)

        # output blocks
        self.head = nn.Sequential(
            RMS_norm(out_dim, images=False), nn.SiLU(),
            CausalConv3d(out_dim, 3, 3, padding=1))

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        ## conv1
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # cache the last frame of the previous chunk
                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
            x = self.conv1(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv1(x)

        ## middle
        for layer in self.middle:
            if isinstance(layer, ResidualBlock) and feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## upsamples
        for layer in self.upsamples:
            if feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## head
        for layer in self.head:
            if isinstance(layer, CausalConv3d) and feat_cache is not None:
                idx = feat_idx[0]
                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    # cache the last frame of the previous chunk
                    cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)
        return x


def count_conv3d(model):
    count = 0
    for m in model.modules():
        if isinstance(m, CausalConv3d):
            count += 1
    return count

class WanVAE(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim=128,
|
||||
z_dim=4,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
temperal_downsample=[True, True, False],
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.z_dim = z_dim
|
||||
self.dim_mult = dim_mult
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attn_scales = attn_scales
|
||||
self.temperal_downsample = temperal_downsample
|
||||
self.temperal_upsample = temperal_downsample[::-1]
|
||||
|
||||
# modules
|
||||
self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
|
||||
attn_scales, self.temperal_downsample, dropout)
|
||||
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
|
||||
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
|
||||
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
|
||||
attn_scales, self.temperal_upsample, dropout)
|
||||
|
||||
def forward(self, x):
|
||||
mu, log_var = self.encode(x)
|
||||
z = self.reparameterize(mu, log_var)
|
||||
x_recon = self.decode(z)
|
||||
return x_recon, mu, log_var
|
||||
|
||||
def encode(self, x):
|
||||
self.clear_cache()
|
||||
## cache
|
||||
t = x.shape[2]
|
||||
iter_ = 1 + (t - 1) // 4
|
||||
## 对encode输入的x,按时间拆分为1、4、4、4....
|
||||
for i in range(iter_):
|
||||
self._enc_conv_idx = [0]
|
||||
if i == 0:
|
||||
out = self.encoder(
|
||||
x[:, :, :1, :, :],
|
||||
feat_cache=self._enc_feat_map,
|
||||
feat_idx=self._enc_conv_idx)
|
||||
else:
|
||||
out_ = self.encoder(
|
||||
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
|
||||
feat_cache=self._enc_feat_map,
|
||||
feat_idx=self._enc_conv_idx)
|
||||
out = torch.cat([out, out_], 2)
|
||||
mu, log_var = self.conv1(out).chunk(2, dim=1)
|
||||
self.clear_cache()
|
||||
return mu
|
||||
|
||||
def decode(self, z):
|
||||
self.clear_cache()
|
||||
# z: [b,c,t,h,w]
|
||||
|
||||
iter_ = z.shape[2]
|
||||
x = self.conv2(z)
|
||||
for i in range(iter_):
|
||||
self._conv_idx = [0]
|
||||
if i == 0:
|
||||
out = self.decoder(
|
||||
x[:, :, i:i + 1, :, :],
|
||||
feat_cache=self._feat_map,
|
||||
feat_idx=self._conv_idx)
|
||||
else:
|
||||
out_ = self.decoder(
|
||||
x[:, :, i:i + 1, :, :],
|
||||
feat_cache=self._feat_map,
|
||||
feat_idx=self._conv_idx)
|
||||
out = torch.cat([out, out_], 2)
|
||||
self.clear_cache()
|
||||
return out
|
||||
|
||||
def reparameterize(self, mu, log_var):
|
||||
std = torch.exp(0.5 * log_var)
|
||||
eps = torch.randn_like(std)
|
||||
return eps * std + mu
|
||||
|
||||
def sample(self, imgs, deterministic=False):
|
||||
mu, log_var = self.encode(imgs)
|
||||
if deterministic:
|
||||
return mu
|
||||
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
|
||||
return mu + std * torch.randn_like(std)
|
||||
|
||||
def clear_cache(self):
|
||||
self._conv_num = count_conv3d(self.decoder)
|
||||
self._conv_idx = [0]
|
||||
self._feat_map = [None] * self._conv_num
|
||||
#cache encode
|
||||
self._enc_conv_num = count_conv3d(self.encoder)
|
||||
self._enc_conv_idx = [0]
|
||||
self._enc_feat_map = [None] * self._enc_conv_num
|
||||
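A hedged usage sketch of the chunked encode/decode above, assuming the module's other classes (Encoder3d, Decoder3d, CausalConv3d) are in scope; constructor values mirror the ddconfig used later in comfy/sd.py, and the tensor shapes are illustrative assumptions, not values from this diff:

# encode() consumes frames in chunks of 1, 4, 4, ..., so 1 + 4 * k input
# frames map to 1 + k latent frames, and decode() walks back one latent
# frame at a time through the same feat_cache mechanism.
import torch

vae = WanVAE(dim=96, z_dim=16, temperal_downsample=[False, True, True])
video = torch.randn(1, 3, 17, 64, 64)   # [b, c, t, h, w], t = 1 + 4 * 4
mu = vae.encode(video)                   # assumed -> [1, 16, 5, 8, 8]
recon = vae.decode(mu)                   # assumed -> 17 decoded frames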
@@ -307,7 +307,6 @@ def model_lora_keys_unet(model, key_map={}):
         if k.endswith(".weight"):
             key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
             key_map["lora_unet_{}".format(key_lora)] = k
-            key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config
             key_map["{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
         else:
             key_map["{}".format(k)] = k #generic lora format for not .weight without any weird key names
@@ -327,6 +326,13 @@ def model_lora_keys_unet(model, key_map={}):
             diffusers_lora_key = diffusers_lora_key[:-2]
         key_map[diffusers_lora_key] = unet_key

+    if isinstance(model, comfy.model_base.StableCascade_C):
+        for k in sdk:
+            if k.startswith("diffusion_model."):
+                if k.endswith(".weight"):
+                    key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
+                    key_map["lora_prior_unet_{}".format(key_lora)] = k
+
     if isinstance(model, comfy.model_base.SD3): #Diffusers lora SD3
         diffusers_keys = comfy.utils.mmdit_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
         for k in diffusers_keys:
@@ -34,6 +34,8 @@ import comfy.ldm.flux.model
 import comfy.ldm.lightricks.model
 import comfy.ldm.hunyuan_video.model
 import comfy.ldm.cosmos.model
+import comfy.ldm.lumina.model
+import comfy.ldm.wan.model

 import comfy.model_management
 import comfy.patcher_extension
@@ -148,7 +150,9 @@ class BaseModel(torch.nn.Module):

         xc = xc.to(dtype)
         t = self.model_sampling.timestep(t).float()
-        context = context.to(dtype)
+        if context is not None:
+            context = context.to(dtype)

         extra_conds = {}
         for o in kwargs:
             extra = kwargs[o]
@@ -163,9 +167,6 @@ class BaseModel(torch.nn.Module):
     def get_dtype(self):
         return self.diffusion_model.dtype

-    def is_adm(self):
-        return self.adm_channels > 0
-
     def encode_adm(self, **kwargs):
         return None
@@ -549,6 +550,10 @@ class SD_X4Upscaler(BaseModel):

         out['c_concat'] = comfy.conds.CONDNoiseShape(image)
         out['y'] = comfy.conds.CONDRegular(noise_level)
+
+        cross_attn = kwargs.get("cross_attn", None)
+        if cross_attn is not None:
+            out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn)
         return out

 class IP2P:
@@ -806,7 +811,10 @@ class Flux(BaseModel):
             (h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size))
             attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok))
             out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
-        out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 3.5)]))
+
+        guidance = kwargs.get("guidance", 3.5)
+        if guidance is not None:
+            out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
         return out

 class GenmoMochi(BaseModel):
@@ -863,7 +871,19 @@ class HunyuanVideo(BaseModel):
         cross_attn = kwargs.get("cross_attn", None)
         if cross_attn is not None:
             out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
-        out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 6.0)]))
+
+        image = kwargs.get("concat_latent_image", None)
+        noise = kwargs.get("noise", None)
+
+        if image is not None:
+            padding_shape = (noise.shape[0], 16, noise.shape[2] - 1, noise.shape[3], noise.shape[4])
+            latent_padding = torch.zeros(padding_shape, device=noise.device, dtype=noise.dtype)
+            image_latents = torch.cat([image.to(noise), latent_padding], dim=2)
+            out['c_concat'] = comfy.conds.CONDNoiseShape(self.process_latent_in(image_latents))
+
+        guidance = kwargs.get("guidance", 6.0)
+        if guidance is not None:
+            out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
         return out

 class CosmosVideo(BaseModel):
@@ -892,3 +912,63 @@ class CosmosVideo(BaseModel):
         latent_image = latent_image + noise
         latent_image = self.model_sampling.calculate_input(torch.tensor([sigma_noise_augmentation], device=latent_image.device, dtype=latent_image.dtype), latent_image)
         return latent_image * ((sigma ** 2 + self.model_sampling.sigma_data ** 2) ** 0.5)
+
+class Lumina2(BaseModel):
+    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiT)
+
+    def extra_conds(self, **kwargs):
+        out = super().extra_conds(**kwargs)
+        attention_mask = kwargs.get("attention_mask", None)
+        if attention_mask is not None:
+            if torch.numel(attention_mask) != attention_mask.sum():
+                out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
+            out['num_tokens'] = comfy.conds.CONDConstant(max(1, torch.sum(attention_mask).item()))
+        cross_attn = kwargs.get("cross_attn", None)
+        if cross_attn is not None:
+            out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+        return out
+
+class WAN21(BaseModel):
+    def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
+        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
+        self.image_to_video = image_to_video
+
+    def concat_cond(self, **kwargs):
+        if not self.image_to_video:
+            return None
+
+        image = kwargs.get("concat_latent_image", None)
+        noise = kwargs.get("noise", None)
+        device = kwargs["device"]
+
+        if image is None:
+            image = torch.zeros_like(noise)
+
+        image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
+        image = self.process_latent_in(image)
+        image = utils.resize_to_batch_size(image, noise.shape[0])
+
+        mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
+        if mask is None:
+            mask = torch.zeros_like(noise)[:, :4]
+        else:
+            mask = 1.0 - torch.mean(mask, dim=1, keepdim=True)
+            mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
+            if mask.shape[-3] < noise.shape[-3]:
+                mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0)
+            mask = mask.repeat(1, 4, 1, 1, 1)
+            mask = utils.resize_to_batch_size(mask, noise.shape[0])
+
+        return torch.cat((mask, image), dim=1)
+
+    def extra_conds(self, **kwargs):
+        out = super().extra_conds(**kwargs)
+        cross_attn = kwargs.get("cross_attn", None)
+        if cross_attn is not None:
+            out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+
+        clip_vision_output = kwargs.get("clip_vision_output", None)
+        if clip_vision_output is not None:
+            out['clip_fea'] = comfy.conds.CONDRegular(clip_vision_output.penultimate_hidden_states)
+        return out
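The i2v conditioning above concatenates a 4-channel mask with the 16-channel reference latent along the channel axis; a hedged shape sketch (all sizes are illustrative assumptions, not values from this diff):

# noise:  [1, 16, 21, 60, 104]  -> assumed Wan 2.1 video latent
# mask:   [1, 4, 21, 60, 104]   -> zeros when no denoise/concat mask is given
# image:  [1, 16, 21, 60, 104]  -> reference latent, bilinearly resized to match
# concat_cond returns torch.cat((mask, image), dim=1) -> [1, 20, 21, 60, 104]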
@@ -136,7 +136,7 @@ def detect_unet_config(state_dict, key_prefix):
     if '{}txt_in.individual_token_refiner.blocks.0.norm1.weight'.format(key_prefix) in state_dict_keys: #Hunyuan Video
         dit_config = {}
         dit_config["image_model"] = "hunyuan_video"
-        dit_config["in_channels"] = 16
+        dit_config["in_channels"] = state_dict['{}img_in.proj.weight'.format(key_prefix)].shape[1] #SkyReels img2video has 32 input channels
         dit_config["patch_size"] = [1, 2, 2]
         dit_config["out_channels"] = 16
         dit_config["vec_in_dim"] = 768
@@ -239,7 +239,7 @@ def detect_unet_config(state_dict, key_prefix):
         dit_config["micro_condition"] = False
         return dit_config

-    if '{}blocks.block0.blocks.0.block.attn.to_q.0.weight'.format(key_prefix) in state_dict_keys:
+    if '{}blocks.block0.blocks.0.block.attn.to_q.0.weight'.format(key_prefix) in state_dict_keys: # Cosmos
         dit_config = {}
         dit_config["image_model"] = "cosmos"
         dit_config["max_img_h"] = 240
@@ -284,6 +284,42 @@ def detect_unet_config(state_dict, key_prefix):
         dit_config["extra_per_block_abs_pos_emb_type"] = "learnable"
         return dit_config

+    if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys: # Lumina 2
+        dit_config = {}
+        dit_config["image_model"] = "lumina2"
+        dit_config["patch_size"] = 2
+        dit_config["in_channels"] = 16
+        dit_config["dim"] = 2304
+        dit_config["cap_feat_dim"] = 2304
+        dit_config["n_layers"] = 26
+        dit_config["n_heads"] = 24
+        dit_config["n_kv_heads"] = 8
+        dit_config["qk_norm"] = True
+        dit_config["axes_dims"] = [32, 32, 32]
+        dit_config["axes_lens"] = [300, 512, 512]
+        return dit_config
+
+    if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
+        dit_config = {}
+        dit_config["image_model"] = "wan2.1"
+        dim = state_dict['{}head.modulation'.format(key_prefix)].shape[-1]
+        dit_config["dim"] = dim
+        dit_config["num_heads"] = dim // 128
+        dit_config["ffn_dim"] = state_dict['{}blocks.0.ffn.0.weight'.format(key_prefix)].shape[0]
+        dit_config["num_layers"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')
+        dit_config["patch_size"] = (1, 2, 2)
+        dit_config["freq_dim"] = 256
+        dit_config["window_size"] = (-1, -1)
+        dit_config["qk_norm"] = True
+        dit_config["cross_attn_norm"] = True
+        dit_config["eps"] = 1e-6
+        dit_config["in_dim"] = state_dict['{}patch_embedding.weight'.format(key_prefix)].shape[1]
+        if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
+            dit_config["model_type"] = "i2v"
+        else:
+            dit_config["model_type"] = "t2v"
+        return dit_config
+
     if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
         return None
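The Wan 2.1 branch infers the whole architecture from tensor shapes rather than a config file; a hedged sketch of the idea, where the key names come from the hunk above but the example shapes are assumptions:

# Toy state dict for a hypothetical Wan 2.1 checkpoint (shapes illustrative only):
import torch

sd = {
    "head.modulation": torch.zeros(1, 2, 1536),        # last dim -> model width
    "blocks.0.ffn.0.weight": torch.zeros(8960, 1536),  # rows -> FFN width
    "patch_embedding.weight": torch.zeros(1536, 16, 1, 2, 2),  # in_dim = 16
}
dim = sd["head.modulation"].shape[-1]   # 1536
num_heads = dim // 128                  # 12, assuming 128-dim attention heads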
@@ -50,7 +50,9 @@ xpu_available = False
 torch_version = ""
 try:
     torch_version = torch.version.__version__
-    xpu_available = (int(torch_version[0]) < 2 or (int(torch_version[0]) == 2 and int(torch_version[2]) <= 4)) and torch.xpu.is_available()
+    temp = torch_version.split(".")
+    torch_version_numeric = (int(temp[0]), int(temp[1]))
+    xpu_available = (torch_version_numeric[0] < 2 or (torch_version_numeric[0] == 2 and torch_version_numeric[1] <= 4)) and torch.xpu.is_available()
 except:
     pass
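The switch to a parsed (major, minor) tuple fixes a latent bug: indexing the version string reads single characters, which silently breaks once the minor version has two digits. A quick illustration:

torch_version = "2.10.0"
int(torch_version[2])                    # -> 1, wrong: first digit of "10" only
temp = torch_version.split(".")
(int(temp[0]), int(temp[1]))             # -> (2, 10), compares correctly as a tuple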
@@ -93,6 +95,13 @@ try:
 except:
     npu_available = False

+try:
+    import torch_mlu  # noqa: F401
+    _ = torch.mlu.device_count()
+    mlu_available = torch.mlu.is_available()
+except:
+    mlu_available = False
+
 if args.cpu:
     cpu_state = CPUState.CPU
@@ -110,6 +119,12 @@ def is_ascend_npu():
         return True
     return False

+def is_mlu():
+    global mlu_available
+    if mlu_available:
+        return True
+    return False
+
 def get_torch_device():
     global directml_enabled
     global cpu_state
@@ -125,6 +140,8 @@ def get_torch_device():
             return torch.device("xpu", torch.xpu.current_device())
         elif is_ascend_npu():
             return torch.device("npu", torch.npu.current_device())
+        elif is_mlu():
+            return torch.device("mlu", torch.mlu.current_device())
         else:
             return torch.device(torch.cuda.current_device())
@@ -151,6 +168,12 @@ def get_total_memory(dev=None, torch_total_too=False):
            _, mem_total_npu = torch.npu.mem_get_info(dev)
            mem_total_torch = mem_reserved
            mem_total = mem_total_npu
+        elif is_mlu():
+            stats = torch.mlu.memory_stats(dev)
+            mem_reserved = stats['reserved_bytes.all.current']
+            _, mem_total_mlu = torch.mlu.mem_get_info(dev)
+            mem_total_torch = mem_reserved
+            mem_total = mem_total_mlu
        else:
            stats = torch.cuda.memory_stats(dev)
            mem_reserved = stats['reserved_bytes.all.current']
@@ -218,7 +241,7 @@ def is_amd():

 MIN_WEIGHT_MEMORY_RATIO = 0.4
 if is_nvidia():
-    MIN_WEIGHT_MEMORY_RATIO = 0.2
+    MIN_WEIGHT_MEMORY_RATIO = 0.0

 ENABLE_PYTORCH_ATTENTION = False
 if args.use_pytorch_cross_attention:
@@ -227,22 +250,44 @@ if args.use_pytorch_cross_attention:

 try:
     if is_nvidia():
-        if int(torch_version[0]) >= 2:
+        if torch_version_numeric[0] >= 2:
             if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
                 ENABLE_PYTORCH_ATTENTION = True
-    if is_intel_xpu() or is_ascend_npu():
+    if is_intel_xpu() or is_ascend_npu() or is_mlu():
         if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
             ENABLE_PYTORCH_ATTENTION = True
 except:
     pass

+try:
+    if is_amd():
+        arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
+        logging.info("AMD arch: {}".format(arch))
+        if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
+            if torch_version_numeric[0] >= 2 and torch_version_numeric[1] >= 7: # works on 2.6 but doesn't actually seem to improve much
+                if any((a in arch) for a in ["gfx1100", "gfx1101"]): # TODO: more arches
+                    ENABLE_PYTORCH_ATTENTION = True
+except:
+    pass
+
 if ENABLE_PYTORCH_ATTENTION:
     torch.backends.cuda.enable_math_sdp(True)
     torch.backends.cuda.enable_flash_sdp(True)
     torch.backends.cuda.enable_mem_efficient_sdp(True)

+PRIORITIZE_FP16 = False # TODO: remove and replace with something that shows exactly which dtype is faster than the other
 try:
-    if int(torch_version[0]) == 2 and int(torch_version[2]) >= 5:
+    if is_nvidia() and args.fast:
+        torch.backends.cuda.matmul.allow_fp16_accumulation = True
+        PRIORITIZE_FP16 = True # TODO: limit to cards where it actually boosts performance
+except:
+    pass
+
+try:
+    if torch_version_numeric[0] == 2 and torch_version_numeric[1] >= 5:
         torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
 except:
     logging.warning("Warning, could not set allow_fp16_bf16_reduction_math_sdp")
@@ -256,15 +301,10 @@ elif args.highvram or args.gpu_only:
     vram_state = VRAMState.HIGH_VRAM

 FORCE_FP32 = False
-FORCE_FP16 = False
 if args.force_fp32:
     logging.info("Forcing FP32, if this improves things please report it.")
     FORCE_FP32 = True

-if args.force_fp16:
-    logging.info("Forcing FP16.")
-    FORCE_FP16 = True
-
 if lowvram_available:
     if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
         vram_state = set_vram_to
@@ -297,6 +337,8 @@ def get_torch_device_name(device):
         return "{} {}".format(device, torch.xpu.get_device_name(device))
     elif is_ascend_npu():
         return "{} {}".format(device, torch.npu.get_device_name(device))
+    elif is_mlu():
+        return "{} {}".format(device, torch.mlu.get_device_name(device))
     else:
         return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
@@ -535,14 +577,11 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
         vram_set_state = vram_state
         lowvram_model_memory = 0
         if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM) and not force_full_load:
-            model_size = loaded_model.model_memory_required(torch_dev)
             loaded_memory = loaded_model.model_loaded_memory()
             current_free_mem = get_free_memory(torch_dev) + loaded_memory

             lowvram_model_memory = max(64 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
             lowvram_model_memory = max(0.1, lowvram_model_memory - loaded_memory)
-            if model_size <= lowvram_model_memory: #only switch to lowvram if really necessary
-                lowvram_model_memory = 0

         if vram_set_state == VRAMState.NO_VRAM:
             lowvram_model_memory = 0.1
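A worked example of the new weight budget; every number below is an assumption for illustration, not a value from this diff:

# Hedged arithmetic sketch of the lowvram_model_memory formula above:
GiB = 1024 ** 3
current_free_mem = 8 * GiB                   # assumed free VRAM
minimum_memory_required = 1 * GiB            # assumed inference reserve
ratio_term = min(current_free_mem * 0.0,     # MIN_WEIGHT_MEMORY_RATIO on NVIDIA
                 current_free_mem - 1.2 * GiB)  # assumed minimum_inference_memory()
budget = max(64 * 1024 * 1024, current_free_mem - minimum_memory_required, ratio_term)
# -> 7 GiB of weights may stay resident; anything beyond that is streamed (lowvram)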
@@ -668,6 +707,10 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
             if model_params * 2 > free_model_memory:
                 return fp8_dtype

+    if PRIORITIZE_FP16:
+        if torch.float16 in supported_dtypes and should_use_fp16(device=device, model_params=model_params):
+            return torch.float16
+
     for dt in supported_dtypes:
         if dt == torch.float16 and should_use_fp16(device=device, model_params=model_params):
             if torch.float16 in supported_dtypes:
@@ -885,6 +928,8 @@ def xformers_enabled():
         return False
     if is_ascend_npu():
         return False
+    if is_mlu():
+        return False
     if directml_enabled:
         return False
     return XFORMERS_IS_AVAILABLE
@@ -901,6 +946,11 @@ def pytorch_attention_enabled():
     global ENABLE_PYTORCH_ATTENTION
     return ENABLE_PYTORCH_ATTENTION

+def pytorch_attention_enabled_vae():
+    if is_amd():
+        return False # enabling pytorch attention on AMD currently causes crash when doing high res
+    return pytorch_attention_enabled()
+
 def pytorch_attention_flash_attention():
     global ENABLE_PYTORCH_ATTENTION
     if ENABLE_PYTORCH_ATTENTION:
@@ -911,6 +961,10 @@ def pytorch_attention_flash_attention():
             return True
         if is_ascend_npu():
             return True
+        if is_mlu():
+            return True
+        if is_amd():
+            return True #if you have pytorch attention enabled on AMD it probably supports at least mem efficient attention
     return False

 def mac_version():
@@ -923,11 +977,11 @@ def force_upcast_attention_dtype():
     upcast = args.force_upcast_attention

     macos_version = mac_version()
-    if macos_version is not None and ((14, 5) <= macos_version <= (15, 2)): # black image bug on recent versions of macOS
+    if macos_version is not None and ((14, 5) <= macos_version < (16,)): # black image bug on recent versions of macOS
         upcast = True

     if upcast:
-        return torch.float32
+        return {torch.float16: torch.float32}
     else:
         return None
@@ -957,6 +1011,13 @@ def get_free_memory(dev=None, torch_free_too=False):
            mem_free_npu, _ = torch.npu.mem_get_info(dev)
            mem_free_torch = mem_reserved - mem_active
            mem_free_total = mem_free_npu + mem_free_torch
+        elif is_mlu():
+            stats = torch.mlu.memory_stats(dev)
+            mem_active = stats['active_bytes.all.current']
+            mem_reserved = stats['reserved_bytes.all.current']
+            mem_free_mlu, _ = torch.mlu.mem_get_info(dev)
+            mem_free_torch = mem_reserved - mem_active
+            mem_free_total = mem_free_mlu + mem_free_torch
        else:
            stats = torch.cuda.memory_stats(dev)
            mem_active = stats['active_bytes.all.current']
@@ -993,21 +1054,26 @@ def is_device_mps(device):
 def is_device_cuda(device):
     return is_device_type(device, 'cuda')

-def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
+def is_directml_enabled():
+    global directml_enabled
+    if directml_enabled:
+        return True
+
+    return False
+
+def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
     if device is not None:
         if is_device_cpu(device):
             return False

-    if FORCE_FP16:
+    if args.force_fp16:
         return True

     if FORCE_FP32:
         return False

-    if directml_enabled:
-        return False
+    if is_directml_enabled():
+        return True

     if (device is not None and is_device_mps(device)) or mps_mode():
         return True
@@ -1021,6 +1087,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
     if is_ascend_npu():
         return True

+    if is_mlu():
+        return True
+
     if torch.version.hip:
         return True
@@ -1078,13 +1147,28 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
     if is_intel_xpu():
         return True

     if is_ascend_npu():
         return True

+    if is_amd():
+        arch = torch.cuda.get_device_properties(device).gcnArchName
+        if any((a in arch) for a in ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]): # RDNA2 and older don't support bf16
+            if manual_cast:
+                return True
+            return False
+
     props = torch.cuda.get_device_properties(device)
+
+    if is_mlu():
+        if props.major > 3:
+            return True
+
     if props.major >= 8:
         return True

     bf16_works = torch.cuda.is_bf16_supported()

-    if bf16_works or manual_cast:
+    if bf16_works and manual_cast:
         free_model_memory = maximum_vram_for_weights(device)
         if (not prioritize_performance) or model_params * 4 > free_model_memory:
             return True
@@ -1103,11 +1187,11 @@ def supports_fp8_compute(device=None):
         if props.minor < 9:
             return False

-    if int(torch_version[0]) < 2 or (int(torch_version[0]) == 2 and int(torch_version[2]) < 3):
+    if torch_version_numeric[0] < 2 or (torch_version_numeric[0] == 2 and torch_version_numeric[1] < 3):
         return False

     if WINDOWS:
-        if (int(torch_version[0]) == 2 and int(torch_version[2]) < 4):
+        if (torch_version_numeric[0] == 2 and torch_version_numeric[1] < 4):
             return False

     return True
@@ -96,8 +96,28 @@ def wipe_lowvram_weight(m):
     if hasattr(m, "prev_comfy_cast_weights"):
         m.comfy_cast_weights = m.prev_comfy_cast_weights
         del m.prev_comfy_cast_weights
-    m.weight_function = None
-    m.bias_function = None
+
+    if hasattr(m, "weight_function"):
+        m.weight_function = []
+
+    if hasattr(m, "bias_function"):
+        m.bias_function = []
+
+def move_weight_functions(m, device):
+    if device is None:
+        return 0
+
+    memory = 0
+    if hasattr(m, "weight_function"):
+        for f in m.weight_function:
+            if hasattr(f, "move_to"):
+                memory += f.move_to(device=device)
+
+    if hasattr(m, "bias_function"):
+        for f in m.bias_function:
+            if hasattr(f, "move_to"):
+                memory += f.move_to(device=device)
+    return memory

 class LowVramPatch:
     def __init__(self, key, patches):
@@ -192,11 +212,13 @@ class ModelPatcher:
         self.backup = {}
         self.object_patches = {}
         self.object_patches_backup = {}
+        self.weight_wrapper_patches = {}
         self.model_options = {"transformer_options":{}}
         self.model_size()
         self.load_device = load_device
         self.offload_device = offload_device
         self.weight_inplace_update = weight_inplace_update
+        self.force_cast_weights = False
         self.patches_uuid = uuid.uuid4()
         self.parent = None
@@ -250,11 +272,14 @@ class ModelPatcher:
         n.patches_uuid = self.patches_uuid

         n.object_patches = self.object_patches.copy()
+        n.weight_wrapper_patches = self.weight_wrapper_patches.copy()
         n.model_options = copy.deepcopy(self.model_options)
         n.backup = self.backup
         n.object_patches_backup = self.object_patches_backup
         n.parent = self

+        n.force_cast_weights = self.force_cast_weights
+
         # attachments
         n.attachments = {}
         for k in self.attachments:
@@ -402,6 +427,16 @@ class ModelPatcher:
     def add_object_patch(self, name, obj):
         self.object_patches[name] = obj

+    def set_model_compute_dtype(self, dtype):
+        self.add_object_patch("manual_cast_dtype", dtype)
+        if dtype is not None:
+            self.force_cast_weights = True
+        self.patches_uuid = uuid.uuid4() #TODO: optimize by preventing a full model reload for this
+
+    def add_weight_wrapper(self, name, function):
+        self.weight_wrapper_patches[name] = self.weight_wrapper_patches.get(name, []) + [function]
+        self.patches_uuid = uuid.uuid4()
+
     def get_model_object(self, name: str) -> torch.nn.Module:
         """Retrieves a nested attribute from an object using dot notation considering
         object patches.
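A hedged usage sketch of the new wrapper hook; only add_weight_wrapper and the "<module name>.weight" key format come from the hunk above, the wrapper and key are hypothetical:

def half_strength(weight):   # hypothetical function applied at cast time
    return weight * 0.5

# assuming `patcher` is an existing ModelPatcher instance:
patcher.add_weight_wrapper("diffusion_model.out.0.weight", half_strength)
# during load, the function is extended onto that module's weight_function list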
@@ -566,6 +601,9 @@ class ModelPatcher:

             lowvram_weight = False

+            weight_key = "{}.weight".format(n)
+            bias_key = "{}.bias".format(n)
+
             if not full_load and hasattr(m, "comfy_cast_weights"):
                 if mem_counter + module_mem >= lowvram_model_memory:
                     lowvram_weight = True
@@ -573,34 +611,46 @@ class ModelPatcher:
                     if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
                         continue

-            weight_key = "{}.weight".format(n)
-            bias_key = "{}.bias".format(n)
-
+            cast_weight = self.force_cast_weights
             if lowvram_weight:
+                if hasattr(m, "comfy_cast_weights"):
+                    m.weight_function = []
+                    m.bias_function = []
+
                 if weight_key in self.patches:
                     if force_patch_weights:
                         self.patch_weight_to_device(weight_key)
                     else:
-                        m.weight_function = LowVramPatch(weight_key, self.patches)
+                        m.weight_function = [LowVramPatch(weight_key, self.patches)]
                         patch_counter += 1
                 if bias_key in self.patches:
                     if force_patch_weights:
                         self.patch_weight_to_device(bias_key)
                     else:
-                        m.bias_function = LowVramPatch(bias_key, self.patches)
+                        m.bias_function = [LowVramPatch(bias_key, self.patches)]
                         patch_counter += 1

-                m.prev_comfy_cast_weights = m.comfy_cast_weights
-                m.comfy_cast_weights = True
+                cast_weight = True
             else:
-                if hasattr(m, "comfy_cast_weights"):
-                    if m.comfy_cast_weights:
-                        wipe_lowvram_weight(m)
+                wipe_lowvram_weight(m)

                 if full_load or mem_counter + module_mem < lowvram_model_memory:
                     mem_counter += module_mem
                     load_completely.append((module_mem, n, m, params))

+            if cast_weight and hasattr(m, "comfy_cast_weights"):
+                m.prev_comfy_cast_weights = m.comfy_cast_weights
+                m.comfy_cast_weights = True
+
+            if weight_key in self.weight_wrapper_patches:
+                m.weight_function.extend(self.weight_wrapper_patches[weight_key])
+
+            if bias_key in self.weight_wrapper_patches:
+                m.bias_function.extend(self.weight_wrapper_patches[bias_key])
+
+            mem_counter += move_weight_functions(m, device_to)
+
         load_completely.sort(reverse=True)
         for x in load_completely:
             n = x[1]
@@ -662,6 +712,7 @@ class ModelPatcher:
         self.unpatch_hooks()
         if self.model.model_lowvram:
             for m in self.model.modules():
+                move_weight_functions(m, device_to)
                 wipe_lowvram_weight(m)

             self.model.model_lowvram = False
@@ -728,15 +779,19 @@ class ModelPatcher:
                 weight_key = "{}.weight".format(n)
                 bias_key = "{}.bias".format(n)
                 if move_weight:
+                    cast_weight = self.force_cast_weights
                     m.to(device_to)
+                    module_mem += move_weight_functions(m, device_to)
                     if lowvram_possible:
                         if weight_key in self.patches:
-                            m.weight_function = LowVramPatch(weight_key, self.patches)
+                            m.weight_function.append(LowVramPatch(weight_key, self.patches))
                             patch_counter += 1
                         if bias_key in self.patches:
-                            m.bias_function = LowVramPatch(bias_key, self.patches)
+                            m.bias_function.append(LowVramPatch(bias_key, self.patches))
                             patch_counter += 1
+                        cast_weight = True

+                    if cast_weight:
                         m.prev_comfy_cast_weights = m.comfy_cast_weights
                         m.comfy_cast_weights = True
                     m.comfy_patched_weights = False
@@ -31,6 +31,7 @@ class EPS:
         return model_input - model_output * sigma

     def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
+        sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
         if max_denoise:
             noise = noise * torch.sqrt(1.0 + sigma ** 2.0)
         else:
@@ -61,9 +62,11 @@ class CONST:
         return model_input - model_output * sigma

     def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
+        sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
         return sigma * noise + (1.0 - sigma) * latent_image

     def inverse_noise_scaling(self, sigma, latent):
+        sigma = sigma.view(sigma.shape[:1] + (1,) * (latent.ndim - 1))
         return latent / (1.0 - sigma)

 class ModelSamplingDiscrete(torch.nn.Module):
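The broadcast fix matters because sigma arrives with one value per batch element; the flow objective itself is unchanged: x_t = sigma * noise + (1 - sigma) * x0, undone by latent / (1 - sigma). A sketch of the reshape, with an assumed latent shape:

import torch
sigma = torch.tensor([0.25, 0.75])          # one sigma per batch element
noise = torch.randn(2, 16, 21, 60, 104)     # assumed video latent [b,c,t,h,w]
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))  # -> [2,1,1,1,1]
x_t = sigma * noise + (1.0 - sigma) * torch.zeros_like(noise)  # broadcasts cleanly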
33
comfy/ops.py
@@ -38,21 +38,23 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
     bias = None
     non_blocking = comfy.model_management.device_supports_non_blocking(device)
     if s.bias is not None:
-        has_function = s.bias_function is not None
+        has_function = len(s.bias_function) > 0
         bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function)
         if has_function:
-            bias = s.bias_function(bias)
+            for f in s.bias_function:
+                bias = f(bias)

-    has_function = s.weight_function is not None
+    has_function = len(s.weight_function) > 0
     weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function)
     if has_function:
-        weight = s.weight_function(weight)
+        for f in s.weight_function:
+            weight = f(weight)
     return weight, bias

 class CastWeightBiasOp:
     comfy_cast_weights = False
-    weight_function = None
-    bias_function = None
+    weight_function = []
+    bias_function = []

 class disable_weight_init:
     class Linear(torch.nn.Linear, CastWeightBiasOp):
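Since weight_function and bias_function are now lists, several transforms can be chained on a cast weight. A minimal runnable sketch of the same left-to-right composition (the two wrappers are hypothetical):

import torch
weight = torch.ones(4)
for f in (lambda w: w * 2.0, lambda w: w + 1.0):  # hypothetical wrappers
    weight = f(weight)
# weight is now tensor([3., 3., 3., 3.]): (w * 2) + 1, same loop as cast_bias_weight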
@@ -64,7 +66,7 @@ class disable_weight_init:
             return torch.nn.functional.linear(input, weight, bias)

         def forward(self, *args, **kwargs):
-            if self.comfy_cast_weights:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)
@@ -78,7 +80,7 @@ class disable_weight_init:
             return self._conv_forward(input, weight, bias)

         def forward(self, *args, **kwargs):
-            if self.comfy_cast_weights:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)
@@ -92,7 +94,7 @@ class disable_weight_init:
             return self._conv_forward(input, weight, bias)

         def forward(self, *args, **kwargs):
-            if self.comfy_cast_weights:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)
@@ -106,7 +108,7 @@ class disable_weight_init:
             return self._conv_forward(input, weight, bias)

         def forward(self, *args, **kwargs):
-            if self.comfy_cast_weights:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)
@@ -120,12 +122,11 @@ class disable_weight_init:
             return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)

         def forward(self, *args, **kwargs):
-            if self.comfy_cast_weights:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)

-
     class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp):
         def reset_parameters(self):
             return None
@@ -139,7 +140,7 @@ class disable_weight_init:
             return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)

         def forward(self, *args, **kwargs):
-            if self.comfy_cast_weights:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)
@@ -160,7 +161,7 @@ class disable_weight_init:
                 output_padding, self.groups, self.dilation)

         def forward(self, *args, **kwargs):
-            if self.comfy_cast_weights:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)
@@ -181,7 +182,7 @@ class disable_weight_init:
                 output_padding, self.groups, self.dilation)

         def forward(self, *args, **kwargs):
-            if self.comfy_cast_weights:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)
@@ -199,7 +200,7 @@ class disable_weight_init:
             return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)

         def forward(self, *args, **kwargs):
-            if self.comfy_cast_weights:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 if "out_dtype" in kwargs:
@@ -58,7 +58,6 @@ def convert_cond(cond):
         temp = c[1].copy()
         model_conds = temp.get("model_conds", {})
         if c[0] is not None:
-            model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0]) #TODO: remove
             temp["cross_attn"] = c[0]
         temp["model_conds"] = model_conds
         temp["uuid"] = uuid.uuid4()
@@ -12,7 +12,6 @@ import collections
 from comfy import model_management
 import math
 import logging
-import comfy.samplers
 import comfy.sampler_helpers
 import comfy.model_patcher
 import comfy.patcher_extension
@@ -178,7 +177,7 @@ def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.H
             cond = default_conds[i]
             for x in cond:
                 # do get_area_and_mult to get all the expected values
-                p = comfy.samplers.get_area_and_mult(x, x_in, timestep)
+                p = get_area_and_mult(x, x_in, timestep)
                 if p is None:
                     continue
                 # replace p's mult with calculated mult
@@ -215,7 +214,7 @@ def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Te
                 default_c.append(x)
                 has_default_conds = True
                 continue
-            p = comfy.samplers.get_area_and_mult(x, x_in, timestep)
+            p = get_area_and_mult(x, x_in, timestep)
             if p is None:
                 continue
             if p.hooks is not None:
@@ -687,7 +686,8 @@ class Sampler:

 KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
                   "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
                   "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
-                  "ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp"]
+                  "ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
+                  "gradient_estimation"]

 class KSAMPLER(Sampler):
     def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
30
comfy/sd.py
@@ -12,6 +12,7 @@ from .ldm.audio.autoencoder import AudioOobleckVAE
 import comfy.ldm.genmo.vae.model
 import comfy.ldm.lightricks.vae.causal_video_autoencoder
 import comfy.ldm.cosmos.vae
+import comfy.ldm.wan.vae
 import yaml
 import math

@@ -36,6 +37,8 @@ import comfy.text_encoders.genmo
 import comfy.text_encoders.lt
 import comfy.text_encoders.hunyuan_video
 import comfy.text_encoders.cosmos
+import comfy.text_encoders.lumina2
+import comfy.text_encoders.wan

 import comfy.model_patcher
 import comfy.lora
@@ -391,6 +394,18 @@ class VAE:
                 self.memory_used_decode = lambda shape, dtype: (50 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
                 self.memory_used_encode = lambda shape, dtype: (50 * (round((shape[2] + 7) / 8) * 8) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
                 self.working_dtypes = [torch.bfloat16, torch.float32]
+            elif "decoder.middle.0.residual.0.gamma" in sd:
+                self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
+                self.upscale_index_formula = (4, 8, 8)
+                self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
+                self.downscale_index_formula = (4, 8, 8)
+                self.latent_dim = 3
+                self.latent_channels = 16
+                ddconfig = {"dim": 96, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
+                self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
+                self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
+                self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype)
+                self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype)
             else:
                 logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
                 self.first_stage_model = None
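The memory lambdas are empirical per-pixel budgets over the latent shape; a worked example with an assumed input size:

# Assumed 832x480 video -> 104x60 latent after the 8x spatial downscale, fp16 (2 bytes):
# memory_used_decode = 7000 * 60 * 104 * (8 * 8) * 2
7000 * 60 * 104 * 64 * 2 / 1e9   # -> ~5.59, i.e. roughly 5.6 GB reserved for decode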
@@ -657,6 +672,8 @@ class CLIPType(Enum):
     HUNYUAN_VIDEO = 9
     PIXART = 10
     COSMOS = 11
+    LUMINA2 = 12
+    WAN = 13


 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
@@ -675,6 +692,7 @@ class TEModel(Enum):
     T5_BASE = 6
     LLAMA3_8 = 7
     T5_XXL_OLD = 8
+    GEMMA_2_2B = 9

 def detect_te_model(sd):
     if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
@@ -693,6 +711,8 @@ def detect_te_model(sd):
         return TEModel.T5_XXL_OLD
     if "encoder.block.0.layer.0.SelfAttention.k.weight" in sd:
         return TEModel.T5_BASE
+    if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
+        return TEModel.GEMMA_2_2B
     if "model.layers.0.post_attention_layernorm.weight" in sd:
         return TEModel.LLAMA3_8
     return None
@@ -730,6 +750,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
         if "text_projection" in clip_data[i]:
             clip_data[i]["text_projection.weight"] = clip_data[i]["text_projection"].transpose(0, 1) #old models saved with the CLIPSave node

+    tokenizer_data = {}
     clip_target = EmptyClass()
     clip_target.params = {}
     if len(clip_data) == 1:
@@ -757,6 +778,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
         elif clip_type == CLIPType.PIXART:
             clip_target.clip = comfy.text_encoders.pixart_t5.pixart_te(**t5xxl_detect(clip_data))
             clip_target.tokenizer = comfy.text_encoders.pixart_t5.PixArtTokenizer
+        elif clip_type == CLIPType.WAN:
+            clip_target.clip = comfy.text_encoders.wan.te(**t5xxl_detect(clip_data))
+            clip_target.tokenizer = comfy.text_encoders.wan.WanT5Tokenizer
+            tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
         else: #CLIPType.MOCHI
             clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
             clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer
@@ -769,6 +794,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
         elif te_model == TEModel.T5_BASE:
             clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
             clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
+        elif te_model == TEModel.GEMMA_2_2B:
+            clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
+            clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
+            tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
         else:
             if clip_type == CLIPType.SD3:
                 clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=False, t5=False)
@@ -798,7 +827,6 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
         clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer

     parameters = 0
-    tokenizer_data = {}
     for c in clip_data:
         parameters += comfy.utils.calculate_parameters(c)
         tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)
@@ -421,10 +421,10 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
     return embed_out

 class SDTokenizer:
-    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, tokenizer_data={}):
+    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, tokenizer_data={}, tokenizer_args={}):
         if tokenizer_path is None:
             tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
-        self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path)
+        self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
         self.max_length = max_length
         self.min_length = min_length
         self.end_token = None
@@ -585,9 +585,14 @@ class SDTokenizer:
         return {}

 class SD1Tokenizer:
-    def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer):
-        self.clip_name = clip_name
-        self.clip = "clip_{}".format(self.clip_name)
+    def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer, name=None):
+        if name is not None:
+            self.clip_name = name
+            self.clip = "{}".format(self.clip_name)
+        else:
+            self.clip_name = clip_name
+            self.clip = "clip_{}".format(self.clip_name)

         tokenizer = tokenizer_data.get("{}_tokenizer_class".format(self.clip), tokenizer)
         setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data))
@@ -600,7 +605,7 @@ class SD1Tokenizer:
         return getattr(self, self.clip).untokenize(token_weight_pair)

     def state_dict(self):
-        return {}
+        return getattr(self, self.clip).state_dict()

 class SD1CheckpointClipModel(SDClipModel):
     def __init__(self, device="cpu", dtype=None, model_options={}):
@@ -15,6 +15,8 @@ import comfy.text_encoders.genmo
 import comfy.text_encoders.lt
 import comfy.text_encoders.hunyuan_video
 import comfy.text_encoders.cosmos
+import comfy.text_encoders.lumina2
+import comfy.text_encoders.wan

 from . import supported_models_base
 from . import latent_formats
@@ -865,6 +867,78 @@ class CosmosI2V(CosmosT2V):
         out = model_base.CosmosVideo(self, image_to_video=True, device=device)
         return out

-models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo, CosmosT2V, CosmosI2V]
+class Lumina2(supported_models_base.BASE):
+    unet_config = {
+        "image_model": "lumina2",
+    }
+
+    sampling_settings = {
+        "multiplier": 1.0,
+        "shift": 6.0,
+    }
+
+    memory_usage_factor = 1.2
+
+    unet_extra_config = {}
+    latent_format = latent_formats.Flux
+
+    supported_inference_dtypes = [torch.bfloat16, torch.float32]
+
+    vae_key_prefix = ["vae."]
+    text_encoder_key_prefix = ["text_encoders."]
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.Lumina2(self, device=device)
+        return out
+
+    def clip_target(self, state_dict={}):
+        pref = self.text_encoder_key_prefix[0]
+        hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}gemma2_2b.transformer.".format(pref))
+        return supported_models_base.ClipTarget(comfy.text_encoders.lumina2.LuminaTokenizer, comfy.text_encoders.lumina2.te(**hunyuan_detect))
+
+class WAN21_T2V(supported_models_base.BASE):
+    unet_config = {
+        "image_model": "wan2.1",
+        "model_type": "t2v",
+    }
+
+    sampling_settings = {
+        "shift": 8.0,
+    }
+
+    unet_extra_config = {}
+    latent_format = latent_formats.Wan21
+
+    memory_usage_factor = 1.0
+
+    supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
+
+    vae_key_prefix = ["vae."]
+    text_encoder_key_prefix = ["text_encoders."]
+
+    def __init__(self, unet_config):
+        super().__init__(unet_config)
+        self.memory_usage_factor = self.unet_config.get("dim", 2000) / 2000
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.WAN21(self, device=device)
+        return out
+
+    def clip_target(self, state_dict={}):
+        pref = self.text_encoder_key_prefix[0]
+        t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}umt5xxl.transformer.".format(pref))
+        return supported_models_base.ClipTarget(comfy.text_encoders.wan.WanT5Tokenizer, comfy.text_encoders.wan.te(**t5_detect))
+
+class WAN21_I2V(WAN21_T2V):
+    unet_config = {
+        "image_model": "wan2.1",
+        "model_type": "i2v",
+    }
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.WAN21(self, image_to_video=True, device=device)
+        return out
+
+models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V]

 models += [SVD_img2vid]
@@ -118,7 +118,7 @@ class BertModel_(torch.nn.Module):
         mask = None
         if attention_mask is not None:
             mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
-            mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
+            mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max)

         x, i = self.encoder(x, mask, intermediate_output)
         return x, i
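Replacing float("-inf") with -torch.finfo(x.dtype).max avoids NaNs when an attention row is fully masked: softmax over a row of -inf is 0/0, while the finite minimum keeps the softmax well defined. A small repro:

import torch
row = torch.full((4,), float("-inf"))
torch.softmax(row, dim=-1)                        # -> tensor([nan, nan, nan, nan])
row = torch.full((4,), -torch.finfo(torch.float32).max)
torch.softmax(row, dim=-1)                        # -> tensor([0.25, 0.25, 0.25, 0.25])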
@@ -1,6 +1,5 @@
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from dataclasses import dataclass
 from typing import Optional, Any

@@ -21,15 +20,41 @@ class Llama2Config:
     max_position_embeddings: int = 8192
     rms_norm_eps: float = 1e-5
     rope_theta: float = 500000.0
+    transformer_type: str = "llama"
+    head_dim = 128
+    rms_norm_add = False
+    mlp_activation = "silu"
+
+@dataclass
+class Gemma2_2B_Config:
+    vocab_size: int = 256000
+    hidden_size: int = 2304
+    intermediate_size: int = 9216
+    num_hidden_layers: int = 26
+    num_attention_heads: int = 8
+    num_key_value_heads: int = 4
+    max_position_embeddings: int = 8192
+    rms_norm_eps: float = 1e-6
+    rope_theta: float = 10000.0
+    transformer_type: str = "gemma2"
+    head_dim = 256
+    rms_norm_add = True
+    mlp_activation = "gelu_pytorch_tanh"

 class RMSNorm(nn.Module):
-    def __init__(self, dim: int, eps: float = 1e-5, device=None, dtype=None):
+    def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
         super().__init__()
         self.eps = eps
         self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
+        self.add = add

     def forward(self, x: torch.Tensor):
-        return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)
+        w = self.weight
+        if self.add:
+            w = w + 1.0
+
+        return comfy.ldm.common_dit.rms_norm(x, w, self.eps)


 def rotate_half(x):
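Gemma checkpoints store the RMSNorm gain as an offset from 1, so the effective weight is (1 + w) and freshly zero-initialized weights act as the identity. A quick check of what the add=True path computes, assuming the usual RMSNorm formula:

import torch
x = torch.randn(2, 8)
w = torch.zeros(8)                   # Gemma-style gain stored as offset from 1
rms = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
out = rms * (w + 1.0)                # add=True: zero weights leave rms unchanged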
@@ -68,13 +93,15 @@ class Attention(nn.Module):
         self.num_heads = config.num_attention_heads
         self.num_kv_heads = config.num_key_value_heads
         self.hidden_size = config.hidden_size
-        self.head_dim = self.hidden_size // self.num_heads
+
+        self.head_dim = config.head_dim
+        self.inner_size = self.num_heads * self.head_dim

         ops = ops or nn
-        self.q_proj = ops.Linear(config.hidden_size, config.hidden_size, bias=False, device=device, dtype=dtype)
+        self.q_proj = ops.Linear(config.hidden_size, self.inner_size, bias=False, device=device, dtype=dtype)
         self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False, device=device, dtype=dtype)
         self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False, device=device, dtype=dtype)
-        self.o_proj = ops.Linear(config.hidden_size, config.hidden_size, bias=False, device=device, dtype=dtype)
+        self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype)

     def forward(
         self,
@@ -84,7 +111,6 @@ class Attention(nn.Module):
         optimized_attention=None,
     ):
         batch_size, seq_length, _ = hidden_states.shape
-
         xq = self.q_proj(hidden_states)
         xk = self.k_proj(hidden_states)
         xv = self.v_proj(hidden_states)
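The decoupled head_dim is what lets Gemma 2 2B reuse this attention module: its attention width differs from the hidden size, and K/V use fewer heads (grouped-query attention). With the Gemma2_2B_Config values above:

# hidden_size=2304, num_attention_heads=8, num_key_value_heads=4, head_dim=256
inner_size = 8 * 256   # 2048 -> q_proj maps 2304 -> 2048, not 2304 -> 2304
kv_size = 4 * 256      # 1024 -> k_proj / v_proj map 2304 -> 1024
# o_proj maps inner_size (2048) back to hidden_size (2304)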
@@ -108,9 +134,13 @@ class MLP(nn.Module):
         self.gate_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype)
         self.up_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype)
         self.down_proj = ops.Linear(config.intermediate_size, config.hidden_size, bias=False, device=device, dtype=dtype)
+        if config.mlp_activation == "silu":
+            self.activation = torch.nn.functional.silu
+        elif config.mlp_activation == "gelu_pytorch_tanh":
+            self.activation = lambda a: torch.nn.functional.gelu(a, approximate="tanh")

     def forward(self, x):
-        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
+        return self.down_proj(self.activation(self.gate_proj(x)) * self.up_proj(x))

 class TransformerBlock(nn.Module):
     def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
@@ -146,6 +176,45 @@ class TransformerBlock(nn.Module):

         return x

+class TransformerBlockGemma2(nn.Module):
+    def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
+        super().__init__()
+        self.self_attn = Attention(config, device=device, dtype=dtype, ops=ops)
+        self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
+        self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
+        self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        freqs_cis: Optional[torch.Tensor] = None,
+        optimized_attention=None,
+    ):
+        # Self Attention
+        residual = x
+        x = self.input_layernorm(x)
+        x = self.self_attn(
+            hidden_states=x,
+            attention_mask=attention_mask,
+            freqs_cis=freqs_cis,
+            optimized_attention=optimized_attention,
+        )
+
+        x = self.post_attention_layernorm(x)
+        x = residual + x
+
+        # MLP
+        residual = x
+        x = self.pre_feedforward_layernorm(x)
+        x = self.mlp(x)
+        x = self.post_feedforward_layernorm(x)
+        x = residual + x
+
+        return x
+
 class Llama2_(nn.Module):
     def __init__(self, config, device=None, dtype=None, ops=None):
         super().__init__()
@@ -158,17 +227,27 @@ class Llama2_(nn.Module):
|
||||
device=device,
|
||||
dtype=dtype
|
||||
)
|
||||
if self.config.transformer_type == "gemma2":
|
||||
transformer = TransformerBlockGemma2
|
||||
self.normalize_in = True
|
||||
else:
|
||||
transformer = TransformerBlock
|
||||
self.normalize_in = False
|
||||
|
||||
self.layers = nn.ModuleList([
|
||||
TransformerBlock(config, device=device, dtype=dtype, ops=ops)
|
||||
transformer(config, device=device, dtype=dtype, ops=ops)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
||||
# self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
|
||||
x = self.embed_tokens(x, out_dtype=dtype)
|
||||
|
||||
freqs_cis = precompute_freqs_cis(self.config.hidden_size // self.config.num_attention_heads,
|
||||
if self.normalize_in:
|
||||
x *= self.config.hidden_size ** 0.5
|
||||
|
||||
freqs_cis = precompute_freqs_cis(self.config.head_dim,
|
||||
x.shape[1],
|
||||
self.config.rope_theta,
|
||||
device=x.device)
|
||||
@@ -206,16 +285,7 @@ class Llama2_(nn.Module):
|
||||
|
||||
return x, intermediate
|
||||
|
||||
|
||||
class Llama2(torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Llama2Config(**config_dict)
|
||||
self.num_layers = config.num_hidden_layers
|
||||
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class BaseLlama:
|
||||
def get_input_embeddings(self):
|
||||
return self.model.embed_tokens
|
||||
|
||||
@@ -224,3 +294,23 @@ class Llama2(torch.nn.Module):
|
||||
|
||||
def forward(self, input_ids, *args, **kwargs):
|
||||
return self.model(input_ids, *args, **kwargs)
|
||||
|
||||
|
||||
class Llama2(BaseLlama, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Llama2Config(**config_dict)
|
||||
self.num_layers = config.num_hidden_layers
|
||||
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
|
||||
class Gemma2_2B(BaseLlama, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Gemma2_2B_Config(**config_dict)
|
||||
self.num_layers = config.num_hidden_layers
|
||||
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
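Unlike the pre-norm TransformerBlock, the Gemma 2 block added above wraps each sub-layer in normalization on both sides ("sandwich" normalization) and scales token embeddings by hidden_size ** 0.5 before the first layer. A minimal sketch of the residual pattern, with RMSNormSketch as a simplified stand-in for the repo's ops-aware RMSNorm:

```python
import torch
import torch.nn as nn

class RMSNormSketch(nn.Module):
    # Simplified stand-in: the real RMSNorm also takes eps/add/device/dtype.
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight

def sandwich(x, norm_in, sublayer, norm_out):
    # Gemma 2 pattern: normalize, run the sub-layer, normalize its output,
    # and only then add the residual -- used for both attention and the MLP.
    return x + norm_out(sublayer(norm_in(x)))
```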
39 comfy/text_encoders/lumina2.py Normal file
@@ -0,0 +1,39 @@
from comfy import sd1_clip
from .spiece_tokenizer import SPieceTokenizer
import comfy.text_encoders.llama


class Gemma2BTokenizer(sd1_clip.SDTokenizer):
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        tokenizer = tokenizer_data.get("spiece_model", None)
        super().__init__(tokenizer, pad_with_end=False, embedding_size=2304, embedding_key='gemma2_2b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False})

    def state_dict(self):
        return {"spiece_model": self.tokenizer.serialize_model()}


class LuminaTokenizer(sd1_clip.SD1Tokenizer):
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma2_2b", tokenizer=Gemma2BTokenizer)


class Gemma2_2BModel(sd1_clip.SDClipModel):
    def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma2_2B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)


class LuminaModel(sd1_clip.SD1ClipModel):
    def __init__(self, device="cpu", dtype=None, model_options={}):
        super().__init__(device=device, dtype=dtype, name="gemma2_2b", clip_model=Gemma2_2BModel, model_options=model_options)


def te(dtype_llama=None, llama_scaled_fp8=None):
    class LuminaTEModel_(LuminaModel):
        def __init__(self, device="cpu", dtype=None, model_options={}):
            if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
                model_options = model_options.copy()
                model_options["scaled_fp8"] = llama_scaled_fp8
            if dtype_llama is not None:
                dtype = dtype_llama
            super().__init__(device=device, dtype=dtype, model_options=model_options)
    return LuminaTEModel_
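As with the other text-encoder factories in this codebase, te() closes over the dtype and scaled-fp8 overrides and returns a class for a loader to instantiate later. A hypothetical usage sketch:

```python
import torch

# Hypothetical: force fp16 for the Gemma weights, then instantiate the class
# the way a checkpoint loader would.
LuminaTE = te(dtype_llama=torch.float16)
encoder = LuminaTE(device="cpu", model_options={})
```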
@@ -1,21 +1,21 @@
import torch

class SPieceTokenizer:
-    add_eos = True
-
    @staticmethod
-    def from_pretrained(path):
-        return SPieceTokenizer(path)
+    def from_pretrained(path, **kwargs):
+        return SPieceTokenizer(path, **kwargs)

-    def __init__(self, tokenizer_path):
+    def __init__(self, tokenizer_path, add_bos=False, add_eos=True):
+        self.add_bos = add_bos
+        self.add_eos = add_eos
        import sentencepiece
        if torch.is_tensor(tokenizer_path):
            tokenizer_path = tokenizer_path.numpy().tobytes()

        if isinstance(tokenizer_path, bytes):
-            self.tokenizer = sentencepiece.SentencePieceProcessor(model_proto=tokenizer_path, add_eos=self.add_eos)
+            self.tokenizer = sentencepiece.SentencePieceProcessor(model_proto=tokenizer_path, add_bos=self.add_bos, add_eos=self.add_eos)
        else:
-            self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=tokenizer_path, add_eos=self.add_eos)
+            self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=tokenizer_path, add_bos=self.add_bos, add_eos=self.add_eos)

    def get_vocab(self):
        out = {}
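Gemma-style prompts want a BOS token but no EOS, which is what the new add_bos/add_eos plumbing carries through to SentencePiece. A usage sketch (the model path is hypothetical):

```python
# Hypothetical tokenizer path; mirrors the tokenizer_args used by Gemma2BTokenizer.
tok = SPieceTokenizer.from_pretrained("models/gemma2_tokenizer.model", add_bos=True, add_eos=False)
ids = tok.tokenizer.encode("a photo of a cat")  # SentencePiece prepends the BOS id
```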
@@ -203,7 +203,7 @@ class T5Stack(torch.nn.Module):
        mask = None
        if attention_mask is not None:
            mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
-            mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
+            mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max)

        intermediate = None
        optimized_attention = optimized_attention_for_device(x.device, mask=attention_mask is not None, small_input=True)
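The switch from float("-inf") to -torch.finfo(x.dtype).max matters when an attention row is fully masked: softmax over a row of -inf produces NaNs, while a large finite negative degrades gracefully to a uniform distribution. A quick check:

```python
import torch

all_inf = torch.full((4,), float("-inf"))
print(torch.softmax(all_inf, dim=-1))     # tensor([nan, nan, nan, nan])

all_finite = torch.full((4,), -torch.finfo(torch.float32).max)
print(torch.softmax(all_finite, dim=-1))  # tensor([0.2500, 0.2500, 0.2500, 0.2500])
```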
22 comfy/text_encoders/umt5_config_xxl.json Normal file
@@ -0,0 +1,22 @@
{
    "d_ff": 10240,
    "d_kv": 64,
    "d_model": 4096,
    "decoder_start_token_id": 0,
    "dropout_rate": 0.1,
    "eos_token_id": 1,
    "dense_act_fn": "gelu_pytorch_tanh",
    "initializer_factor": 1.0,
    "is_encoder_decoder": true,
    "is_gated_act": true,
    "layer_norm_epsilon": 1e-06,
    "model_type": "umt5",
    "num_decoder_layers": 24,
    "num_heads": 64,
    "num_layers": 24,
    "output_past": true,
    "pad_token_id": 0,
    "relative_attention_num_buckets": 32,
    "tie_word_embeddings": false,
    "vocab_size": 256384
}
37 comfy/text_encoders/wan.py Normal file
@@ -0,0 +1,37 @@
from comfy import sd1_clip
from .spiece_tokenizer import SPieceTokenizer
import comfy.text_encoders.t5
import os

class UMT5XXlModel(sd1_clip.SDClipModel):
    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
        textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "umt5_config_xxl.json")
        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True, model_options=model_options)

class UMT5XXlTokenizer(sd1_clip.SDTokenizer):
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        tokenizer = tokenizer_data.get("spiece_model", None)
        super().__init__(tokenizer, pad_with_end=False, embedding_size=4096, embedding_key='umt5xxl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=0)

    def state_dict(self):
        return {"spiece_model": self.tokenizer.serialize_model()}


class WanT5Tokenizer(sd1_clip.SD1Tokenizer):
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="umt5xxl", tokenizer=UMT5XXlTokenizer)

class WanT5Model(sd1_clip.SD1ClipModel):
    def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
        super().__init__(device=device, dtype=dtype, model_options=model_options, name="umt5xxl", clip_model=UMT5XXlModel, **kwargs)

def te(dtype_t5=None, t5xxl_scaled_fp8=None):
    class WanTEModel(WanT5Model):
        def __init__(self, device="cpu", dtype=None, model_options={}):
            if t5xxl_scaled_fp8 is not None and "scaled_fp8" not in model_options:
                model_options = model_options.copy()
                model_options["scaled_fp8"] = t5xxl_scaled_fp8
            if dtype_t5 is not None:
                dtype = dtype_t5
            super().__init__(device=device, dtype=dtype, model_options=model_options)
    return WanTEModel
@@ -43,13 +43,23 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in
    torch.serialization.add_safe_globals([ModelCheckpoint, scalar, dtype, Float64DType, encode])
    ALWAYS_SAFE_LOAD = True
    logging.info("Checkpoint files will always be loaded safely.")

else:
    logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")

def load_torch_file(ckpt, safe_load=False, device=None):
    if device is None:
        device = torch.device("cpu")
    if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
-        sd = safetensors.torch.load_file(ckpt, device=device.type)
+        try:
+            sd = safetensors.torch.load_file(ckpt, device=device.type)
+        except Exception as e:
+            if len(e.args) > 0:
+                message = e.args[0]
+                if "HeaderTooLarge" in message:
+                    raise ValueError("{}\n\nFile path: {}\n\nThe safetensors file is corrupt or invalid. Make sure this is actually a safetensors file and not a ckpt or pt or other filetype.".format(message, ckpt))
+                if "MetadataIncompleteBuffer" in message:
+                    raise ValueError("{}\n\nFile path: {}\n\nThe safetensors file is corrupt/incomplete. Check the file size and make sure you have copied/downloaded it correctly.".format(message, ckpt))
+            raise e
    else:
        if safe_load or ALWAYS_SAFE_LOAD:
            pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
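With this change a truncated or mislabeled file surfaces as a ValueError carrying the file path and a hint, rather than a bare safetensors exception. A sketch of how calling code might handle it (the path is hypothetical):

```python
try:
    sd = load_torch_file("models/checkpoints/model.safetensors")
except ValueError as err:
    # The message includes the offending path plus corrupt/wrong-filetype hints.
    print(f"Could not load checkpoint: {err}")
```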
@@ -38,7 +38,26 @@ class FluxGuidance:
        return (c, )


+class FluxDisableGuidance:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {
+                    "conditioning": ("CONDITIONING", ),
+                    }}
+
+    RETURN_TYPES = ("CONDITIONING",)
+    FUNCTION = "append"
+
+    CATEGORY = "advanced/conditioning/flux"
+    DESCRIPTION = "This node completely disables the guidance embed on Flux and Flux like models"
+
+    def append(self, conditioning):
+        c = node_helpers.conditioning_set_values(conditioning, {"guidance": None})
+        return (c, )
+
+
NODE_CLASS_MAPPINGS = {
    "CLIPTextEncodeFlux": CLIPTextEncodeFlux,
    "FluxGuidance": FluxGuidance,
+    "FluxDisableGuidance": FluxDisableGuidance,
}
@@ -2,10 +2,14 @@ import comfy.utils
import comfy_extras.nodes_post_processing
import torch

-def reshape_latent_to(target_shape, latent):
+
+def reshape_latent_to(target_shape, latent, repeat_batch=True):
    if latent.shape[1:] != target_shape[1:]:
-        latent = comfy.utils.common_upscale(latent, target_shape[3], target_shape[2], "bilinear", "center")
-    return comfy.utils.repeat_to_batch_size(latent, target_shape[0])
+        latent = comfy.utils.common_upscale(latent, target_shape[-1], target_shape[-2], "bilinear", "center")
+    if repeat_batch:
+        return comfy.utils.repeat_to_batch_size(latent, target_shape[0])
+    else:
+        return latent


class LatentAdd:
@@ -116,8 +120,7 @@ class LatentBatch:
        s1 = samples1["samples"]
        s2 = samples2["samples"]

-        if s1.shape[1:] != s2.shape[1:]:
-            s2 = comfy.utils.common_upscale(s2, s1.shape[3], s1.shape[2], "bilinear", "center")
+        s2 = reshape_latent_to(s1.shape, s2, repeat_batch=False)
        s = torch.cat((s1, s2), dim=0)
        samples_out["samples"] = s
        samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])])
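Indexing the spatial dimensions from the end (target_shape[-1], target_shape[-2]) lets the same helper serve 4D image latents and 5D video latents. A minimal sketch with F.interpolate standing in for comfy.utils.common_upscale (4D case only):

```python
import torch
import torch.nn.functional as F

def reshape_latent_sketch(target_shape, latent):
    # Negative indices pick width/height whether the latent is NCHW or NCTHW;
    # the real common_upscale handles both, this sketch covers the 4D case.
    if latent.shape[1:] != tuple(target_shape[1:]):
        latent = F.interpolate(latent, size=(target_shape[-2], target_shape[-1]), mode="bilinear")
    return latent

img = torch.randn(1, 4, 32, 32)
print(reshape_latent_sketch((1, 4, 64, 64), img).shape)  # torch.Size([1, 4, 64, 64])
```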
@@ -19,14 +19,8 @@ class Load3D():
                "image": ("LOAD_3D", {}),
                "width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
                "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
-                "show_grid": ([True, False],),
-                "camera_type": (["perspective", "orthographic"],),
-                "view": (["front", "right", "top", "isometric"],),
                "material": (["original", "normal", "wireframe", "depth"],),
-                "bg_color": ("STRING", {"default": "#000000", "multiline": False}),
-                "light_intensity": ("INT", {"default": 10, "min": 1, "max": 20, "step": 1}),
                "up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
-                "fov": ("INT", {"default": 75, "min": 10, "max": 150, "step": 1}),
            }}

    RETURN_TYPES = ("IMAGE", "MASK", "STRING")
@@ -38,22 +32,14 @@ class Load3D():

    CATEGORY = "3d"

    def process(self, model_file, image, **kwargs):
-        if isinstance(image, dict):
-            image_path = folder_paths.get_annotated_filepath(image['image'])
-            mask_path = folder_paths.get_annotated_filepath(image['mask'])
+        image_path = folder_paths.get_annotated_filepath(image['image'])
+        mask_path = folder_paths.get_annotated_filepath(image['mask'])

-            load_image_node = nodes.LoadImage()
-            output_image, ignore_mask = load_image_node.load_image(image=image_path)
-            ignore_image, output_mask = load_image_node.load_image(image=mask_path)
+        load_image_node = nodes.LoadImage()
+        output_image, ignore_mask = load_image_node.load_image(image=image_path)
+        ignore_image, output_mask = load_image_node.load_image(image=mask_path)

-            return output_image, output_mask, model_file,
-        else:
-            # to avoid the format is not dict which will happen the FE code is not compatibility to core,
-            # we need to this to double-check, it can be removed after merged FE into the core
-            image_path = folder_paths.get_annotated_filepath(image)
-            load_image_node = nodes.LoadImage()
-            output_image, output_mask = load_image_node.load_image(image=image_path)
-            return output_image, output_mask, model_file,
+        return output_image, output_mask, model_file,

class Load3DAnimation():
    @classmethod
@@ -69,15 +55,8 @@ class Load3DAnimation():
                "image": ("LOAD_3D_ANIMATION", {}),
                "width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
                "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
-                "show_grid": ([True, False],),
-                "camera_type": (["perspective", "orthographic"],),
-                "view": (["front", "right", "top", "isometric"],),
                "material": (["original", "normal", "wireframe", "depth"],),
-                "bg_color": ("STRING", {"default": "#000000", "multiline": False}),
-                "light_intensity": ("INT", {"default": 10, "min": 1, "max": 20, "step": 1}),
                "up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
-                "animation_speed": (["0.1", "0.5", "1", "1.5", "2"], {"default": "1"}),
-                "fov": ("INT", {"default": 75, "min": 10, "max": 150, "step": 1}),
            }}

    RETURN_TYPES = ("IMAGE", "MASK", "STRING")
@@ -89,34 +68,42 @@ class Load3DAnimation():

    CATEGORY = "3d"

    def process(self, model_file, image, **kwargs):
-        if isinstance(image, dict):
-            image_path = folder_paths.get_annotated_filepath(image['image'])
-            mask_path = folder_paths.get_annotated_filepath(image['mask'])
+        image_path = folder_paths.get_annotated_filepath(image['image'])
+        mask_path = folder_paths.get_annotated_filepath(image['mask'])

-            load_image_node = nodes.LoadImage()
-            output_image, ignore_mask = load_image_node.load_image(image=image_path)
-            ignore_image, output_mask = load_image_node.load_image(image=mask_path)
+        load_image_node = nodes.LoadImage()
+        output_image, ignore_mask = load_image_node.load_image(image=image_path)
+        ignore_image, output_mask = load_image_node.load_image(image=mask_path)

-            return output_image, output_mask, model_file,
-        else:
-            image_path = folder_paths.get_annotated_filepath(image)
-            load_image_node = nodes.LoadImage()
-            output_image, output_mask = load_image_node.load_image(image=image_path)
-            return output_image, output_mask, model_file,
+        return output_image, output_mask, model_file,

class Preview3D():
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "model_file": ("STRING", {"default": "", "multiline": False}),
-            "show_grid": ([True, False],),
-            "camera_type": (["perspective", "orthographic"],),
-            "view": (["front", "right", "top", "isometric"],),
            "material": (["original", "normal", "wireframe", "depth"],),
-            "bg_color": ("STRING", {"default": "#000000", "multiline": False}),
-            "light_intensity": ("INT", {"default": 10, "min": 1, "max": 20, "step": 1}),
            "up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
-            "fov": ("INT", {"default": 75, "min": 10, "max": 150, "step": 1}),
        }}

    OUTPUT_NODE = True
    RETURN_TYPES = ()

    CATEGORY = "3d"

    FUNCTION = "process"
    EXPERIMENTAL = True

    def process(self, model_file, **kwargs):
        return {"ui": {"model_file": [model_file]}, "result": ()}

+class Preview3DAnimation():
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {
+            "model_file": ("STRING", {"default": "", "multiline": False}),
+            "material": (["original", "normal", "wireframe", "depth"],),
+            "up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
+        }}
+
+    OUTPUT_NODE = True
@@ -133,11 +120,13 @@ class Preview3D():
NODE_CLASS_MAPPINGS = {
    "Load3D": Load3D,
    "Load3DAnimation": Load3DAnimation,
-    "Preview3D": Preview3D
+    "Preview3D": Preview3D,
+    "Preview3DAnimation": Preview3DAnimation
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "Load3D": "Load 3D",
    "Load3DAnimation": "Load 3D - Animation",
-    "Preview3D": "Preview 3D"
+    "Preview3D": "Preview 3D",
+    "Preview3DAnimation": "Preview 3D - Animation"
}
104 comfy_extras/nodes_lumina2.py Normal file
@@ -0,0 +1,104 @@
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
import torch


class RenormCFG:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "model": ("MODEL",),
                              "cfg_trunc": ("FLOAT", {"default": 100, "min": 0.0, "max": 100.0, "step": 0.01}),
                              "renorm_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}),
                              }}
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "advanced/model"

    def patch(self, model, cfg_trunc, renorm_cfg):
        def renorm_cfg_func(args):
            cond_denoised = args["cond_denoised"]
            uncond_denoised = args["uncond_denoised"]
            cond_scale = args["cond_scale"]
            timestep = args["timestep"]
            x_orig = args["input"]
            in_channels = model.model.diffusion_model.in_channels

            if timestep[0] < cfg_trunc:
                cond_eps, uncond_eps = cond_denoised[:, :in_channels], uncond_denoised[:, :in_channels]
                cond_rest, _ = cond_denoised[:, in_channels:], uncond_denoised[:, in_channels:]
                half_eps = uncond_eps + cond_scale * (cond_eps - uncond_eps)
                half_rest = cond_rest

                if float(renorm_cfg) > 0.0:
                    ori_pos_norm = torch.linalg.vector_norm(cond_eps, dim=tuple(range(1, len(cond_eps.shape))), keepdim=True)
                    max_new_norm = ori_pos_norm * float(renorm_cfg)
                    new_pos_norm = torch.linalg.vector_norm(half_eps, dim=tuple(range(1, len(half_eps.shape))), keepdim=True)
                    if new_pos_norm >= max_new_norm:
                        half_eps = half_eps * (max_new_norm / new_pos_norm)
            else:
                cond_eps, uncond_eps = cond_denoised[:, :in_channels], uncond_denoised[:, :in_channels]
                cond_rest, _ = cond_denoised[:, in_channels:], uncond_denoised[:, in_channels:]
                half_eps = cond_eps
                half_rest = cond_rest

            cfg_result = torch.cat([half_eps, half_rest], dim=1)

            # cfg_result = uncond_denoised + (cond_denoised - uncond_denoised) * cond_scale

            return x_orig - cfg_result

        m = model.clone()
        m.set_model_sampler_cfg_function(renorm_cfg_func)
        return (m, )


class CLIPTextEncodeLumina2(ComfyNodeABC):
    SYSTEM_PROMPT = {
        "superior": "You are an assistant designed to generate superior images with the superior "
                    "degree of image-text alignment based on textual prompts or user prompts.",
        "alignment": "You are an assistant designed to generate high-quality images with the "
                     "highest degree of image-text alignment based on textual prompts."
    }
    SYSTEM_PROMPT_TIP = "Lumina2 provide two types of system prompts:" \
        "Superior: You are an assistant designed to generate superior images with the superior " \
        "degree of image-text alignment based on textual prompts or user prompts. " \
        "Alignment: You are an assistant designed to generate high-quality images with the highest " \
        "degree of image-text alignment based on textual prompts."

    @classmethod
    def INPUT_TYPES(s) -> InputTypeDict:
        return {
            "required": {
                "system_prompt": (list(CLIPTextEncodeLumina2.SYSTEM_PROMPT.keys()), {"tooltip": CLIPTextEncodeLumina2.SYSTEM_PROMPT_TIP}),
                "user_prompt": (IO.STRING, {"multiline": True, "dynamicPrompts": True, "tooltip": "The text to be encoded."}),
                "clip": (IO.CLIP, {"tooltip": "The CLIP model used for encoding the text."})
            }
        }
    RETURN_TYPES = (IO.CONDITIONING,)
    OUTPUT_TOOLTIPS = ("A conditioning containing the embedded text used to guide the diffusion model.",)
    FUNCTION = "encode"

    CATEGORY = "conditioning"
    DESCRIPTION = "Encodes a system prompt and a user prompt using a CLIP model into an embedding that can be used to guide the diffusion model towards generating specific images."

    def encode(self, clip, user_prompt, system_prompt):
        if clip is None:
            raise RuntimeError("ERROR: clip input is invalid: None\n\nIf the clip is from a checkpoint loader node your checkpoint does not contain a valid clip or text encoder model.")
        system_prompt = CLIPTextEncodeLumina2.SYSTEM_PROMPT[system_prompt]
        prompt = f'{system_prompt} <Prompt Start> {user_prompt}'
        tokens = clip.tokenize(prompt)
        return (clip.encode_from_tokens_scheduled(tokens), )


NODE_CLASS_MAPPINGS = {
    "CLIPTextEncodeLumina2": CLIPTextEncodeLumina2,
    "RenormCFG": RenormCFG
}


NODE_DISPLAY_NAME_MAPPINGS = {
    "CLIPTextEncodeLumina2": "CLIP Text Encode for Lumina2",
}
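RenormCFG caps the norm of the guided prediction at renorm_cfg times the norm of the conditional prediction, which tames the over-saturation that large CFG scales cause. The core step restated as a standalone function; note this is a batched generalization, while the node compares scalar norms (batch size 1):

```python
import torch

def renorm(half_eps: torch.Tensor, cond_eps: torch.Tensor, renorm_cfg: float) -> torch.Tensor:
    dims = tuple(range(1, half_eps.ndim))
    max_norm = torch.linalg.vector_norm(cond_eps, dim=dims, keepdim=True) * renorm_cfg
    new_norm = torch.linalg.vector_norm(half_eps, dim=dims, keepdim=True)
    # Rescale only where the guided prediction grew past the allowed norm.
    return torch.where(new_norm >= max_norm, half_eps * (max_norm / new_norm), half_eps)
```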
@@ -3,6 +3,8 @@ import comfy.model_sampling
import comfy.latent_formats
import nodes
+import torch
+import node_helpers


class LCM(comfy.model_sampling.EPS):
    def calculate_denoised(self, sigma, model_output, model_input):
@@ -294,6 +296,24 @@ class RescaleCFG:
        m.set_model_sampler_cfg_function(rescale_cfg)
        return (m, )

+class ModelComputeDtype:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "model": ("MODEL",),
+                              "dtype": (["default", "fp32", "fp16", "bf16"],),
+                              }}
+
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "patch"
+
+    CATEGORY = "advanced/debug/model"
+
+    def patch(self, model, dtype):
+        m = model.clone()
+        m.set_model_compute_dtype(node_helpers.string_to_torch_dtype(dtype))
+        return (m, )
+

NODE_CLASS_MAPPINGS = {
    "ModelSamplingDiscrete": ModelSamplingDiscrete,
    "ModelSamplingContinuousEDM": ModelSamplingContinuousEDM,
@@ -303,4 +323,5 @@ NODE_CLASS_MAPPINGS = {
    "ModelSamplingAuraFlow": ModelSamplingAuraFlow,
    "ModelSamplingFlux": ModelSamplingFlux,
    "RescaleCFG": RescaleCFG,
+    "ModelComputeDtype": ModelComputeDtype,
}
@@ -196,6 +196,54 @@ class ModelMergeLTXV(comfy_extras.nodes_model_merging.ModelMergeBlocks):

        return {"required": arg_dict}

+class ModelMergeCosmos7B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
+    CATEGORY = "advanced/model_merging/model_specific"
+
+    @classmethod
+    def INPUT_TYPES(s):
+        arg_dict = { "model1": ("MODEL",),
+                     "model2": ("MODEL",)}
+
+        argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
+
+        arg_dict["pos_embedder."] = argument
+        arg_dict["extra_pos_embedder."] = argument
+        arg_dict["x_embedder."] = argument
+        arg_dict["t_embedder."] = argument
+        arg_dict["affline_norm."] = argument
+
+        for i in range(28):
+            arg_dict["blocks.block{}.".format(i)] = argument
+
+        arg_dict["final_layer."] = argument
+
+        return {"required": arg_dict}
+
+class ModelMergeCosmos14B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
+    CATEGORY = "advanced/model_merging/model_specific"
+
+    @classmethod
+    def INPUT_TYPES(s):
+        arg_dict = { "model1": ("MODEL",),
+                     "model2": ("MODEL",)}
+
+        argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
+
+        arg_dict["pos_embedder."] = argument
+        arg_dict["extra_pos_embedder."] = argument
+        arg_dict["x_embedder."] = argument
+        arg_dict["t_embedder."] = argument
+        arg_dict["affline_norm."] = argument
+
+        for i in range(36):
+            arg_dict["blocks.block{}.".format(i)] = argument
+
+        arg_dict["final_layer."] = argument
+
+        return {"required": arg_dict}
+
NODE_CLASS_MAPPINGS = {
    "ModelMergeSD1": ModelMergeSD1,
    "ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks
@@ -206,4 +254,6 @@ NODE_CLASS_MAPPINGS = {
    "ModelMergeSD35_Large": ModelMergeSD35_Large,
    "ModelMergeMochiPreview": ModelMergeMochiPreview,
    "ModelMergeLTXV": ModelMergeLTXV,
+    "ModelMergeCosmos7B": ModelMergeCosmos7B,
+    "ModelMergeCosmos14B": ModelMergeCosmos14B,
}
76 comfy_extras/nodes_video.py Normal file
@@ -0,0 +1,76 @@
import os
import av
import torch
import folder_paths
import json
from fractions import Fraction


class SaveWEBM:
    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()
        self.type = "output"
        self.prefix_append = ""

    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {"images": ("IMAGE", ),
                     "filename_prefix": ("STRING", {"default": "ComfyUI"}),
                     "codec": (["vp9", "av1"],),
                     "fps": ("FLOAT", {"default": 24.0, "min": 0.01, "max": 1000.0, "step": 0.01}),
                     "crf": ("FLOAT", {"default": 32.0, "min": 0, "max": 63.0, "step": 1, "tooltip": "Higher crf means lower quality with a smaller file size, lower crf means higher quality higher filesize."}),
                     },
                "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
                }

    RETURN_TYPES = ()
    FUNCTION = "save_images"

    OUTPUT_NODE = True

    CATEGORY = "image/video"

    EXPERIMENTAL = True

    def save_images(self, images, codec, fps, filename_prefix, crf, prompt=None, extra_pnginfo=None):
        filename_prefix += self.prefix_append
        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])

        file = f"{filename}_{counter:05}_.webm"
        container = av.open(os.path.join(full_output_folder, file), mode="w")

        if prompt is not None:
            container.metadata["prompt"] = json.dumps(prompt)

        if extra_pnginfo is not None:
            for x in extra_pnginfo:
                container.metadata[x] = json.dumps(extra_pnginfo[x])

        codec_map = {"vp9": "libvpx-vp9", "av1": "libaom-av1"}
        stream = container.add_stream(codec_map[codec], rate=Fraction(round(fps * 1000), 1000))
        stream.width = images.shape[-2]
        stream.height = images.shape[-3]
        stream.pix_fmt = "yuv420p"
        stream.bit_rate = 0
        stream.options = {'crf': str(crf)}

        for frame in images:
            frame = av.VideoFrame.from_ndarray(torch.clamp(frame[..., :3] * 255, min=0, max=255).to(device=torch.device("cpu"), dtype=torch.uint8).numpy(), format="rgb24")
            for packet in stream.encode(frame):
                container.mux(packet)
        container.mux(stream.encode())
        container.close()

        results = [{
            "filename": file,
            "subfolder": subfolder,
            "type": self.type
        }]

        return {"ui": {"images": results, "animated": (True,)}}  # TODO: frontend side


NODE_CLASS_MAPPINGS = {
    "SaveWEBM": SaveWEBM,
}
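PyAV wants a rational frame rate for the stream; rounding fps to the nearest thousandth before building the Fraction keeps NTSC-style rates exact instead of leaking float error. A quick illustration:

```python
from fractions import Fraction

rate = Fraction(round(23.976 * 1000), 1000)
print(rate, float(rate))  # 2997/125 23.976 -- exact, no float drift
```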
54 comfy_extras/nodes_wan.py Normal file
@@ -0,0 +1,54 @@
import nodes
import node_helpers
import torch
import comfy.model_management
import comfy.utils


class WanImageToVideo:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"positive": ("CONDITIONING", ),
                             "negative": ("CONDITIONING", ),
                             "vae": ("VAE", ),
                             "width": ("INT", {"default": 1280, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
                             "height": ("INT", {"default": 720, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
                             "length": ("INT", {"default": 121, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
                             },
                "optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
                             "start_image": ("IMAGE", ),
                             }}

    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
    RETURN_NAMES = ("positive", "negative", "latent")
    FUNCTION = "encode"

    CATEGORY = "conditioning/video_models"

    def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None):
        latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
        if start_image is not None:
            start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
            image = torch.ones((length, height, width, start_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) * 0.5
            image[:start_image.shape[0]] = start_image

            concat_latent_image = vae.encode(image[:, :, :, :3])
            mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
            mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0

            positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
            negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})

        if clip_vision_output is not None:
            positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
            negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})

        out_latent = {}
        out_latent["samples"] = latent
        return (positive, negative, out_latent)


NODE_CLASS_MAPPINGS = {
    "WanImageToVideo": WanImageToVideo,
}
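The latent is sized with the 4x temporal compression the formula encodes: ((length - 1) // 4) + 1 latent frames, and the same formula decides how many latent frames the mask pins to the start image. A quick check of the defaults:

```python
length = 121                            # default requested frame count
latent_frames = ((length - 1) // 4) + 1
print(latent_frames)                    # 31 latent frames for a 121-frame video

start_frames = 1                        # a single start image
print(((start_frames - 1) // 4) + 1)    # 1 -> only the first latent frame is masked to 0 (kept)
```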
@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
-__version__ = "0.3.12"
+__version__ = "0.3.18"
@@ -7,11 +7,18 @@ import logging
from typing import Literal
from collections.abc import Collection

+from comfy.cli_args import args
+
-supported_pt_extensions: set[str] = {'.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'}
+supported_pt_extensions: set[str] = {'.ckpt', '.pt', '.pt2', '.bin', '.pth', '.safetensors', '.pkl', '.sft'}

folder_names_and_paths: dict[str, tuple[list[str], set[str]]] = {}

-base_path = os.path.dirname(os.path.realpath(__file__))
+# --base-directory - Resets all default paths configured in folder_paths with a new base path
+if args.base_directory:
+    base_path = os.path.abspath(args.base_directory)
+else:
+    base_path = os.path.dirname(os.path.realpath(__file__))

models_dir = os.path.join(base_path, "models")
folder_names_and_paths["checkpoints"] = ([os.path.join(models_dir, "checkpoints")], supported_pt_extensions)
folder_names_and_paths["configs"] = ([os.path.join(models_dir, "configs")], [".yaml"])
@@ -39,10 +46,10 @@ folder_names_and_paths["photomaker"] = ([os.path.join(models_dir, "photomaker")]

folder_names_and_paths["classifiers"] = ([os.path.join(models_dir, "classifiers")], {""})

-output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output")
-temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
-input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
-user_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "user")
+output_directory = os.path.join(base_path, "output")
+temp_directory = os.path.join(base_path, "temp")
+input_directory = os.path.join(base_path, "input")
+user_directory = os.path.join(base_path, "user")

filename_list_cache: dict[str, tuple[list[str], dict[str, float], float]] = {}
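The practical effect is that one flag re-roots every default path. A sketch of the resulting layout, assuming a launch like `python main.py --base-directory /data/comfy`:

```python
import os

base_path = "/data/comfy"  # hypothetical --base-directory value
print(os.path.join(base_path, "models", "checkpoints"))  # /data/comfy/models/checkpoints
print(os.path.join(base_path, "output"))                 # /data/comfy/output
print(os.path.join(base_path, "user"))                   # /data/comfy/user
```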
@@ -12,7 +12,10 @@ MAX_PREVIEW_RESOLUTION = args.preview_size
def preview_to_image(latent_image):
        latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1)  # change scale from -1..1 to 0..1
                            .mul(0xFF)  # to 0..255
-                            ).to(device="cpu", dtype=torch.uint8, non_blocking=comfy.model_management.device_supports_non_blocking(latent_image.device))
+                            )
+        if comfy.model_management.directml_enabled:
+            latents_ubyte = latents_ubyte.to(dtype=torch.uint8)
+        latents_ubyte = latents_ubyte.to(device="cpu", dtype=torch.uint8, non_blocking=comfy.model_management.device_supports_non_blocking(latent_image.device))

        return Image.fromarray(latents_ubyte.numpy())
3 main.py
@@ -138,6 +138,8 @@ import server
from server import BinaryEventTypes
import nodes
import comfy.model_management
+import comfyui_version


def cuda_malloc_warning():
    device = comfy.model_management.get_torch_device()
@@ -292,6 +294,7 @@ def start_comfyui(asyncio_loop=None):

if __name__ == "__main__":
    # Running directly, just start ComfyUI.
+    logging.info("ComfyUI version: {}".format(comfyui_version.__version__))
    event_loop, _, start_all_func = start_comfyui()
    try:
        event_loop.run_until_complete(start_all_func())
@@ -1,4 +1,5 @@
import hashlib
+import torch

from comfy.cli_args import args

@@ -35,3 +36,11 @@ def hasher():
        "sha512": hashlib.sha512
    }
    return hashfuncs[args.default_hashing_function]
+
+def string_to_torch_dtype(string):
+    if string == "fp32":
+        return torch.float32
+    if string == "fp16":
+        return torch.float16
+    if string == "bf16":
+        return torch.bfloat16
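Anything that does not match, including "default", falls off the end and returns None; ModelComputeDtype above appears to rely on that to mean "keep the model's current compute dtype". A quick usage sketch:

```python
import torch

print(string_to_torch_dtype("bf16"))     # torch.bfloat16
print(string_to_torch_dtype("default"))  # None (implicit) -> leave the dtype unchanged
```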
52 nodes.py
@@ -63,6 +63,8 @@ class CLIPTextEncode(ComfyNodeABC):
    DESCRIPTION = "Encodes a text prompt using a CLIP model into an embedding that can be used to guide the diffusion model towards generating specific images."

    def encode(self, clip, text):
+        if clip is None:
+            raise RuntimeError("ERROR: clip input is invalid: None\n\nIf the clip is from a checkpoint loader node your checkpoint does not contain a valid clip or text encoder model.")
        tokens = clip.tokenize(text)
        return (clip.encode_from_tokens_scheduled(tokens), )

@@ -912,7 +914,7 @@ class CLIPLoader:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
-                              "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos"], ),
+                              "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan"], ),
                              },
                "optional": {
                              "device": (["default", "cpu"], {"advanced": True}),
@@ -922,7 +924,7 @@ class CLIPLoader:

    CATEGORY = "advanced/loaders"

-    DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 / clip-g / clip-l\nstable_audio: t5\nmochi: t5\ncosmos: old t5 xxl"
+    DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl"

    def load_clip(self, clip_name, type="stable_diffusion", device="default"):
        if type == "stable_cascade":
@@ -937,6 +939,12 @@ class CLIPLoader:
            clip_type = comfy.sd.CLIPType.LTXV
        elif type == "pixart":
            clip_type = comfy.sd.CLIPType.PIXART
        elif type == "cosmos":
            clip_type = comfy.sd.CLIPType.COSMOS
+        elif type == "lumina2":
+            clip_type = comfy.sd.CLIPType.LUMINA2
+        elif type == "wan":
+            clip_type = comfy.sd.CLIPType.WAN
        else:
            clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION

@@ -1058,10 +1066,11 @@ class StyleModelApply:
        for t in conditioning:
            (txt, keys) = t
            keys = keys.copy()
-            if strength_type == "attn_bias" and strength != 1.0:
+            # even if the strength is 1.0 (i.e, no change), if there's already a mask, we have to add to it
+            if "attention_mask" in keys or (strength_type == "attn_bias" and strength != 1.0):
                # math.log raises an error if the argument is zero
                # torch.log returns -inf, which is what we want
-                attn_bias = torch.log(torch.Tensor([strength]))
+                attn_bias = torch.log(torch.Tensor([strength if strength_type == "attn_bias" else 1.0]))
                # get the size of the mask image
                mask_ref_size = keys.get("attention_mask_img_shape", (1, 1))
                n_ref = mask_ref_size[0] * mask_ref_size[1]
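Because the bias is added to pre-softmax logits, adding log(strength) multiplies a key's unnormalized attention weight by exactly strength, and log(1.0) = 0 leaves an existing mask untouched. A small demonstration:

```python
import torch

logits = torch.tensor([1.0, 1.0, 1.0])
biased = logits.clone()
biased[0] += torch.log(torch.tensor(2.0))  # the attn_bias for strength = 2.0

print(torch.softmax(logits, dim=-1))  # tensor([0.3333, 0.3333, 0.3333])
print(torch.softmax(biased, dim=-1))  # tensor([0.5000, 0.2500, 0.2500])
```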
@@ -1756,6 +1765,36 @@ class LoadImageMask:

        return True

+
+class LoadImageOutput(LoadImage):
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "image": ("COMBO", {
+                    "image_upload": True,
+                    "image_folder": "output",
+                    "remote": {
+                        "route": "/internal/files/output",
+                        "refresh_button": True,
+                        "control_after_refresh": "first",
+                    },
+                }),
+            }
+        }
+
+    DESCRIPTION = "Load an image from the output folder. When the refresh button is clicked, the node will update the image list and automatically select the first image, allowing for easy iteration."
+    EXPERIMENTAL = True
+    FUNCTION = "load_image_output"
+
+    def load_image_output(self, image):
+        return self.load_image(f"{image} [output]")
+
+    @classmethod
+    def VALIDATE_INPUTS(s, image):
+        return True
+
+
class ImageScale:
    upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
    crop_methods = ["disabled", "center"]
@@ -1942,6 +1981,7 @@ NODE_CLASS_MAPPINGS = {
    "PreviewImage": PreviewImage,
    "LoadImage": LoadImage,
    "LoadImageMask": LoadImageMask,
+    "LoadImageOutput": LoadImageOutput,
    "ImageScale": ImageScale,
    "ImageScaleBy": ImageScaleBy,
    "ImageInvert": ImageInvert,
@@ -2042,6 +2082,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
    "PreviewImage": "Preview Image",
    "LoadImage": "Load Image",
    "LoadImageMask": "Load Image (as Mask)",
+    "LoadImageOutput": "Load Image (from Outputs)",
    "ImageScale": "Upscale Image",
    "ImageScaleBy": "Upscale Image By",
    "ImageUpscaleWithModel": "Upscale Image (using Model)",
@@ -2226,6 +2267,9 @@ def init_builtin_extra_nodes():
        "nodes_hooks.py",
        "nodes_load_3d.py",
        "nodes_cosmos.py",
+        "nodes_video.py",
+        "nodes_lumina2.py",
+        "nodes_wan.py",
    ]

    import_failed = []
@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
-version = "0.3.12"
+version = "0.3.18"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"
@@ -8,7 +8,8 @@ transformers>=4.28.1
tokenizers>=0.13.3
sentencepiece
safetensors>=0.4.2
-aiohttp
+aiohttp>=3.11.8
+yarl>=1.18.0
pyyaml
Pillow
scipy
@@ -19,3 +20,4 @@ psutil
kornia>=0.7.1
spandrel
soundfile
+av
26 server.py
@@ -52,6 +52,20 @@ async def cache_control(request: web.Request, handler):
        response.headers.setdefault('Cache-Control', 'no-cache')
        return response


+@web.middleware
+async def compress_body(request: web.Request, handler):
+    accept_encoding = request.headers.get("Accept-Encoding", "")
+    response: web.Response = await handler(request)
+    if not isinstance(response, web.Response):
+        return response
+    if response.content_type not in ["application/json", "text/plain"]:
+        return response
+    if response.body and "gzip" in accept_encoding:
+        response.enable_compression()
+    return response
+
+
def create_cors_middleware(allowed_origin: str):
    @web.middleware
    async def cors_middleware(request: web.Request, handler):
@@ -136,7 +150,8 @@ class PromptServer():
        PromptServer.instance = self

        mimetypes.init()
-        mimetypes.types_map['.js'] = 'application/javascript; charset=utf-8'
+        mimetypes.add_type('application/javascript; charset=utf-8', '.js')
+        mimetypes.add_type('image/webp', '.webp')

        self.user_manager = UserManager()
        self.model_file_manager = ModelFileManager()
@@ -150,6 +165,9 @@ class PromptServer():
        self.number = 0

        middlewares = [cache_control]
+        if args.enable_compress_response_body:
+            middlewares.append(compress_body)
+
        if args.enable_cors_header:
            middlewares.append(create_cors_middleware(args.enable_cors_header))
        else:
@@ -329,6 +347,9 @@ class PromptServer():
            original_ref = json.loads(post.get("original_ref"))
            filename, output_dir = folder_paths.annotated_filepath(original_ref['filename'])

+            if not filename:
+                return web.Response(status=400)
+
            # validation for security: prevent accessing arbitrary path
            if filename[0] == '/' or '..' in filename:
                return web.Response(status=400)
@@ -370,6 +391,9 @@ class PromptServer():
            filename = request.rel_url.query["filename"]
            filename,output_dir = folder_paths.annotated_filepath(filename)

+            if not filename:
+                return web.Response(status=400)
+
            # validation for security: prevent accessing arbitrary path
            if filename[0] == '/' or '..' in filename:
                return web.Response(status=400)
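When the server is launched with the flag behind args.enable_compress_response_body, JSON and plain-text bodies are gzipped for clients that advertise support. A client-side sketch (host and route assumed):

```python
import asyncio
import aiohttp

async def fetch():
    async with aiohttp.ClientSession() as session:
        # aiohttp sends Accept-Encoding: gzip by default and
        # decompresses the response body transparently.
        async with session.get("http://127.0.0.1:8188/object_info") as resp:
            return await resp.json()

asyncio.run(fetch())
```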
@@ -2,39 +2,146 @@ import pytest
from aiohttp import web
from unittest.mock import patch
from app.custom_node_manager import CustomNodeManager
+import json

pytestmark = (
    pytest.mark.asyncio
)  # This applies the asyncio mark to all test functions in the module


@pytest.fixture
def custom_node_manager():
    return CustomNodeManager()


@pytest.fixture
def app(custom_node_manager):
    app = web.Application()
    routes = web.RouteTableDef()
-    custom_node_manager.add_routes(routes, app, [("ComfyUI-TestExtension1", "ComfyUI-TestExtension1")])
+    custom_node_manager.add_routes(
+        routes, app, [("ComfyUI-TestExtension1", "ComfyUI-TestExtension1")]
+    )
    app.add_routes(routes)
    return app


async def test_get_workflow_templates(aiohttp_client, app, tmp_path):
    client = await aiohttp_client(app)
    # Setup temporary custom nodes file structure with 1 workflow file
    custom_nodes_dir = tmp_path / "custom_nodes"
-    example_workflows_dir = custom_nodes_dir / "ComfyUI-TestExtension1" / "example_workflows"
+    example_workflows_dir = (
+        custom_nodes_dir / "ComfyUI-TestExtension1" / "example_workflows"
+    )
    example_workflows_dir.mkdir(parents=True)
    template_file = example_workflows_dir / "workflow1.json"
-    template_file.write_text('')
+    template_file.write_text("")

-    with patch('folder_paths.folder_names_and_paths', {
-        'custom_nodes': ([str(custom_nodes_dir)], None)
-    }):
-        response = await client.get('/workflow_templates')
+    with patch(
+        "folder_paths.folder_names_and_paths",
+        {"custom_nodes": ([str(custom_nodes_dir)], None)},
+    ):
+        response = await client.get("/workflow_templates")
    assert response.status == 200
    workflows_dict = await response.json()
    assert isinstance(workflows_dict, dict)
    assert "ComfyUI-TestExtension1" in workflows_dict
    assert isinstance(workflows_dict["ComfyUI-TestExtension1"], list)
    assert workflows_dict["ComfyUI-TestExtension1"][0] == "workflow1"
+
+
+async def test_build_translations_empty_when_no_locales(custom_node_manager, tmp_path):
+    custom_nodes_dir = tmp_path / "custom_nodes"
+    custom_nodes_dir.mkdir(parents=True)
+
+    with patch("folder_paths.get_folder_paths", return_value=[str(custom_nodes_dir)]):
+        translations = custom_node_manager.build_translations()
+        assert translations == {}
+
+
+async def test_build_translations_loads_all_files(custom_node_manager, tmp_path):
+    # Setup test directory structure
+    custom_nodes_dir = tmp_path / "custom_nodes" / "test-extension"
+    locales_dir = custom_nodes_dir / "locales" / "en"
+    locales_dir.mkdir(parents=True)
+
+    # Create test translation files
+    main_content = {"title": "Test Extension"}
+    (locales_dir / "main.json").write_text(json.dumps(main_content))
+
+    node_defs = {"node1": "Node 1"}
+    (locales_dir / "nodeDefs.json").write_text(json.dumps(node_defs))
+
+    commands = {"cmd1": "Command 1"}
+    (locales_dir / "commands.json").write_text(json.dumps(commands))
+
+    settings = {"setting1": "Setting 1"}
+    (locales_dir / "settings.json").write_text(json.dumps(settings))
+
+    with patch(
+        "folder_paths.get_folder_paths", return_value=[tmp_path / "custom_nodes"]
+    ):
+        translations = custom_node_manager.build_translations()
+
+        assert translations == {
+            "en": {
+                "title": "Test Extension",
+                "nodeDefs": {"node1": "Node 1"},
+                "commands": {"cmd1": "Command 1"},
+                "settings": {"setting1": "Setting 1"},
+            }
+        }
+
+
+async def test_build_translations_handles_invalid_json(custom_node_manager, tmp_path):
+    # Setup test directory structure
+    custom_nodes_dir = tmp_path / "custom_nodes" / "test-extension"
+    locales_dir = custom_nodes_dir / "locales" / "en"
+    locales_dir.mkdir(parents=True)
+
+    # Create valid main.json
+    main_content = {"title": "Test Extension"}
+    (locales_dir / "main.json").write_text(json.dumps(main_content))
+
+    # Create invalid JSON file
+    (locales_dir / "nodeDefs.json").write_text("invalid json{")
+
+    with patch(
+        "folder_paths.get_folder_paths", return_value=[tmp_path / "custom_nodes"]
+    ):
+        translations = custom_node_manager.build_translations()
+
+        assert translations == {
+            "en": {
+                "title": "Test Extension",
+            }
+        }
+
+
+async def test_build_translations_merges_multiple_extensions(
+    custom_node_manager, tmp_path
+):
+    # Setup test directory structure for two extensions
+    custom_nodes_dir = tmp_path / "custom_nodes"
+    ext1_dir = custom_nodes_dir / "extension1" / "locales" / "en"
+    ext2_dir = custom_nodes_dir / "extension2" / "locales" / "en"
+    ext1_dir.mkdir(parents=True)
+    ext2_dir.mkdir(parents=True)
+
+    # Create translation files for extension 1
+    ext1_main = {"title": "Extension 1", "shared": "Original"}
+    (ext1_dir / "main.json").write_text(json.dumps(ext1_main))
+
+    # Create translation files for extension 2
+    ext2_main = {"description": "Extension 2", "shared": "Override"}
+    (ext2_dir / "main.json").write_text(json.dumps(ext2_main))
+
+    with patch("folder_paths.get_folder_paths", return_value=[str(custom_nodes_dir)]):
+        translations = custom_node_manager.build_translations()
+
+        assert translations == {
+            "en": {
+                "title": "Extension 1",
+                "description": "Extension 2",
+                "shared": "Override",  # Second extension should override first
+            }
+        }
@@ -1,19 +1,23 @@
### 🗻 This file is created through the spirit of Mount Fuji at its peak
# TODO(yoland): clean up this after I get back down
+import sys
import pytest
import os
import tempfile
+from unittest.mock import patch
+from importlib import reload

import folder_paths
+import comfy.cli_args
from comfy.options import enable_args_parsing
enable_args_parsing()


@pytest.fixture()
def clear_folder_paths():
-    # Clear the global dictionary before each test to ensure isolation
-    original = folder_paths.folder_names_and_paths.copy()
-    folder_paths.folder_names_and_paths.clear()
+    # Reload the module after each test to ensure isolation
    yield
-    folder_paths.folder_names_and_paths = original
+    reload(folder_paths)

@pytest.fixture
def temp_dir():
@@ -21,7 +25,21 @@ def temp_dir():
    yield tmpdirname


-def test_get_directory_by_type():
+@pytest.fixture
+def set_base_dir():
+    def _set_base_dir(base_dir):
+        # Mock CLI args
+        with patch.object(sys, 'argv', ["main.py", "--base-directory", base_dir]):
+            reload(comfy.cli_args)
+            reload(folder_paths)
+    yield _set_base_dir
+    # Reload the modules after each test to ensure isolation
+    with patch.object(sys, 'argv', ["main.py"]):
+        reload(comfy.cli_args)
+        reload(folder_paths)
+
+
+def test_get_directory_by_type(clear_folder_paths):
    test_dir = "/test/dir"
    folder_paths.set_output_directory(test_dir)
    assert folder_paths.get_directory_by_type("output") == test_dir
@@ -96,3 +114,49 @@ def test_get_save_image_path(temp_dir):
    assert counter == 1
    assert subfolder == ""
    assert filename_prefix == "test"
+
+
+def test_base_path_changes(set_base_dir):
+    test_dir = os.path.abspath("/test/dir")
+    set_base_dir(test_dir)
+
+    assert folder_paths.base_path == test_dir
+    assert folder_paths.models_dir == os.path.join(test_dir, "models")
+    assert folder_paths.input_directory == os.path.join(test_dir, "input")
+    assert folder_paths.output_directory == os.path.join(test_dir, "output")
+    assert folder_paths.temp_directory == os.path.join(test_dir, "temp")
+    assert folder_paths.user_directory == os.path.join(test_dir, "user")
+
+    assert os.path.join(test_dir, "custom_nodes") in folder_paths.get_folder_paths("custom_nodes")
+
+    for name in ["checkpoints", "loras", "vae", "configs", "embeddings", "controlnet", "classifiers"]:
+        assert folder_paths.get_folder_paths(name)[0] == os.path.join(test_dir, "models", name)
+
+
+def test_base_path_change_clears_old(set_base_dir):
+    test_dir = os.path.abspath("/test/dir")
+    set_base_dir(test_dir)
+
+    assert len(folder_paths.get_folder_paths("custom_nodes")) == 1
+
+    single_model_paths = [
+        "checkpoints",
+        "loras",
+        "vae",
+        "configs",
+        "clip_vision",
+        "style_models",
+        "diffusers",
+        "vae_approx",
+        "gligen",
+        "upscale_models",
+        "embeddings",
+        "hypernetworks",
+        "photomaker",
+        "classifiers",
+    ]
+    for name in single_model_paths:
+        assert len(folder_paths.get_folder_paths(name)) == 1
+
+    for name in ["controlnet", "diffusion_models", "text_encoders"]:
+        assert len(folder_paths.get_folder_paths(name)) == 2
@@ -1,115 +0,0 @@
import pytest
from aiohttp import web
from unittest.mock import MagicMock, patch
from api_server.routes.internal.internal_routes import InternalRoutes
from api_server.services.file_service import FileService
from folder_paths import models_dir, user_directory, output_directory


@pytest.fixture
def internal_routes():
    return InternalRoutes(None)

@pytest.fixture
def aiohttp_client_factory(aiohttp_client, internal_routes):
    async def _get_client():
        app = internal_routes.get_app()
        return await aiohttp_client(app)
    return _get_client

@pytest.mark.asyncio
async def test_list_files_valid_directory(aiohttp_client_factory, internal_routes):
    mock_file_list = [
        {"name": "file1.txt", "path": "file1.txt", "type": "file", "size": 100},
        {"name": "dir1", "path": "dir1", "type": "directory"}
    ]
    internal_routes.file_service.list_files = MagicMock(return_value=mock_file_list)
    client = await aiohttp_client_factory()
    resp = await client.get('/files?directory=models')
    assert resp.status == 200
    data = await resp.json()
    assert 'files' in data
    assert len(data['files']) == 2
    assert data['files'] == mock_file_list

    # Check other valid directories
    resp = await client.get('/files?directory=user')
    assert resp.status == 200
    resp = await client.get('/files?directory=output')
    assert resp.status == 200

@pytest.mark.asyncio
async def test_list_files_invalid_directory(aiohttp_client_factory, internal_routes):
    internal_routes.file_service.list_files = MagicMock(side_effect=ValueError("Invalid directory key"))
    client = await aiohttp_client_factory()
    resp = await client.get('/files?directory=invalid')
    assert resp.status == 400
    data = await resp.json()
    assert 'error' in data
    assert data['error'] == "Invalid directory key"

@pytest.mark.asyncio
async def test_list_files_exception(aiohttp_client_factory, internal_routes):
    internal_routes.file_service.list_files = MagicMock(side_effect=Exception("Unexpected error"))
    client = await aiohttp_client_factory()
    resp = await client.get('/files?directory=models')
    assert resp.status == 500
    data = await resp.json()
    assert 'error' in data
    assert data['error'] == "Unexpected error"

@pytest.mark.asyncio
async def test_list_files_no_directory_param(aiohttp_client_factory, internal_routes):
    mock_file_list = []
    internal_routes.file_service.list_files = MagicMock(return_value=mock_file_list)
    client = await aiohttp_client_factory()
    resp = await client.get('/files')
    assert resp.status == 200
    data = await resp.json()
    assert 'files' in data
    assert len(data['files']) == 0

def test_setup_routes(internal_routes):
    internal_routes.setup_routes()
    routes = internal_routes.routes
    assert any(route.method == 'GET' and str(route.path) == '/files' for route in routes)

def test_get_app(internal_routes):
    app = internal_routes.get_app()
    assert isinstance(app, web.Application)
    assert internal_routes._app is not None

def test_get_app_reuse(internal_routes):
    app1 = internal_routes.get_app()
    app2 = internal_routes.get_app()
    assert app1 is app2

@pytest.mark.asyncio
async def test_routes_added_to_app(aiohttp_client_factory, internal_routes):
    client = await aiohttp_client_factory()
    try:
        resp = await client.get('/files')
        print(f"Response received: status {resp.status}")  # noqa: T201
    except Exception as e:
        print(f"Exception occurred during GET request: {e}")  # noqa: T201
        raise

    assert resp.status != 404, "Route /files does not exist"

@pytest.mark.asyncio
async def test_file_service_initialization():
    with patch('api_server.routes.internal.internal_routes.FileService') as MockFileService:
        # Create a mock instance
        mock_file_service_instance = MagicMock(spec=FileService)
        MockFileService.return_value = mock_file_service_instance
        internal_routes = InternalRoutes(None)

        # Check if FileService was initialized with the correct parameters
        MockFileService.assert_called_once_with({
            "models": models_dir,
            "user": user_directory,
            "output": output_directory
        })

        # Verify that the file_service attribute of InternalRoutes is set
        assert internal_routes.file_service == mock_file_service_instance
@@ -1,54 +0,0 @@
import pytest
from unittest.mock import MagicMock
from api_server.services.file_service import FileService

@pytest.fixture
def mock_file_system_ops():
    return MagicMock()

@pytest.fixture
def file_service(mock_file_system_ops):
    allowed_directories = {
        "models": "/path/to/models",
        "user": "/path/to/user",
        "output": "/path/to/output"
    }
    return FileService(allowed_directories, file_system_ops=mock_file_system_ops)

def test_list_files_valid_directory(file_service, mock_file_system_ops):
    mock_file_system_ops.walk_directory.return_value = [
        {"name": "file1.txt", "path": "file1.txt", "type": "file", "size": 100},
        {"name": "dir1", "path": "dir1", "type": "directory"}
    ]

    result = file_service.list_files("models")

    assert len(result) == 2
    assert result[0]["name"] == "file1.txt"
    assert result[1]["name"] == "dir1"
    mock_file_system_ops.walk_directory.assert_called_once_with("/path/to/models")

def test_list_files_invalid_directory(file_service):
    # Does not support walking directories outside of the allowed directories
    with pytest.raises(ValueError, match="Invalid directory key"):
        file_service.list_files("invalid_key")

def test_list_files_empty_directory(file_service, mock_file_system_ops):
    mock_file_system_ops.walk_directory.return_value = []

    result = file_service.list_files("models")

    assert len(result) == 0
    mock_file_system_ops.walk_directory.assert_called_once_with("/path/to/models")

@pytest.mark.parametrize("directory_key", ["models", "user", "output"])
def test_list_files_all_allowed_directories(file_service, mock_file_system_ops, directory_key):
    mock_file_system_ops.walk_directory.return_value = [
        {"name": f"file_{directory_key}.txt", "path": f"file_{directory_key}.txt", "type": "file", "size": 100}
    ]

    result = file_service.list_files(directory_key)

    assert len(result) == 1
    assert result[0]["name"] == f"file_{directory_key}.txt"
    mock_file_system_ops.walk_directory.assert_called_once_with(f"/path/to/{directory_key}")
@@ -114,7 +114,7 @@ def test_load_extra_model_paths_expands_userpath(
     mock_yaml_safe_load.assert_called_once()

     # Check if open was called with the correct file path
-    mock_file.assert_called_once_with(dummy_yaml_file_name, 'r')
+    mock_file.assert_called_once_with(dummy_yaml_file_name, 'r', encoding='utf-8')


 @patch('builtins.open', new_callable=mock_open)
@@ -145,7 +145,7 @@ def test_load_extra_model_paths_expands_appdata(
     else:
         expected_base_path = '/Users/TestUser/AppData/Roaming/ComfyUI'
     expected_calls = [
-        ('checkpoints', os.path.join(expected_base_path, 'models/checkpoints'), False),
+        ('checkpoints', os.path.normpath(os.path.join(expected_base_path, 'models/checkpoints')), False),
     ]

     assert mock_add_model_folder_path.call_count == len(expected_calls)
@@ -197,8 +197,8 @@ def test_load_extra_path_config_relative_base_path(

     load_extra_path_config(dummy_yaml_name)

-    expected_checkpoints = os.path.abspath(os.path.join(str(tmp_path), sub_folder, "checkpoints"))
-    expected_some_value = os.path.abspath(os.path.join(str(tmp_path), sub_folder, "some_value"))
+    expected_checkpoints = os.path.abspath(os.path.join(str(tmp_path), "my_rel_base", "checkpoints"))
+    expected_some_value = os.path.abspath(os.path.join(str(tmp_path), "my_rel_base", "some_value"))

     actual_paths = folder_paths.folder_names_and_paths["checkpoints"][0]
     assert len(actual_paths) == 1, "Should have one path added for 'checkpoints'."
71 tests-unit/utils/json_util_test.py Normal file
@@ -0,0 +1,71 @@
from utils.json_util import merge_json_recursive


def test_merge_simple_dicts():
    base = {"a": 1, "b": 2}
    update = {"b": 3, "c": 4}
    expected = {"a": 1, "b": 3, "c": 4}
    assert merge_json_recursive(base, update) == expected


def test_merge_nested_dicts():
    base = {"a": {"x": 1, "y": 2}, "b": 3}
    update = {"a": {"y": 4, "z": 5}}
    expected = {"a": {"x": 1, "y": 4, "z": 5}, "b": 3}
    assert merge_json_recursive(base, update) == expected


def test_merge_lists():
    base = {"a": [1, 2], "b": 3}
    update = {"a": [3, 4]}
    expected = {"a": [1, 2, 3, 4], "b": 3}
    assert merge_json_recursive(base, update) == expected


def test_merge_nested_lists():
    base = {"a": {"x": [1, 2]}}
    update = {"a": {"x": [3, 4]}}
    expected = {"a": {"x": [1, 2, 3, 4]}}
    assert merge_json_recursive(base, update) == expected


def test_merge_mixed_types():
    base = {"a": [1, 2], "b": {"x": 1}}
    update = {"a": [3], "b": {"y": 2}}
    expected = {"a": [1, 2, 3], "b": {"x": 1, "y": 2}}
    assert merge_json_recursive(base, update) == expected


def test_merge_overwrite_non_dict():
    base = {"a": 1}
    update = {"a": {"x": 2}}
    expected = {"a": {"x": 2}}
    assert merge_json_recursive(base, update) == expected


def test_merge_empty_dicts():
    base = {}
    update = {"a": 1}
    expected = {"a": 1}
    assert merge_json_recursive(base, update) == expected


def test_merge_none_values():
    base = {"a": None}
    update = {"a": {"x": 1}}
    expected = {"a": {"x": 1}}
    assert merge_json_recursive(base, update) == expected


def test_merge_different_types():
    base = {"a": [1, 2]}
    update = {"a": "string"}
    expected = {"a": "string"}
    assert merge_json_recursive(base, update) == expected


def test_merge_complex_nested():
    base = {"a": [1, 2], "b": {"x": [3, 4], "y": {"p": 1}}}
    update = {"a": [5], "b": {"x": [6], "y": {"q": 2}}}
    expected = {"a": [1, 2, 5], "b": {"x": [3, 4, 6], "y": {"p": 1, "q": 2}}}
    assert merge_json_recursive(base, update) == expected
@@ -4,7 +4,7 @@ import folder_paths
 import logging

 def load_extra_path_config(yaml_path):
-    with open(yaml_path, 'r') as stream:
+    with open(yaml_path, 'r', encoding='utf-8') as stream:
         config = yaml.safe_load(stream)
     yaml_dir = os.path.dirname(os.path.abspath(yaml_path))
     for c in config:
@@ -29,5 +29,6 @@ def load_extra_path_config(yaml_path):
                     full_path = os.path.join(base_path, full_path)
                 elif not os.path.isabs(full_path):
                     full_path = os.path.abspath(os.path.join(yaml_dir, y))
-                logging.info("Adding extra search path {} {}".format(x, full_path))
-                folder_paths.add_model_folder_path(x, full_path, is_default)
+                normalized_path = os.path.normpath(full_path)
+                logging.info("Adding extra search path {} {}".format(x, normalized_path))
+                folder_paths.add_model_folder_path(x, normalized_path, is_default)
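A minimal sketch of why the second hunk normalizes paths before registering them (illustrative values, not taken from the diff); the first hunk's encoding='utf-8' change similarly makes the YAML read independent of the platform's default locale encoding:

    import os

    # os.path.normpath collapses redundant separators and up-level references,
    # so the same directory expressed two different ways registers as one
    # search path instead of two.
    a = os.path.normpath("/data/models/checkpoints/../checkpoints")
    b = os.path.normpath("/data/models//checkpoints/")
    assert a == b == os.path.normpath("/data/models/checkpoints")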
26 utils/json_util.py Normal file
@@ -0,0 +1,26 @@
def merge_json_recursive(base, update):
    """Recursively merge two JSON-like objects.
    - Dictionaries are merged recursively
    - Lists are concatenated
    - Other types are overwritten by the update value

    Args:
        base: Base JSON-like object
        update: Update JSON-like object to merge into base

    Returns:
        Merged JSON-like object
    """
    if not isinstance(base, dict) or not isinstance(update, dict):
        if isinstance(base, list) and isinstance(update, list):
            return base + update
        return update

    merged = base.copy()
    for key, value in update.items():
        if key in merged:
            merged[key] = merge_json_recursive(merged[key], value)
        else:
            merged[key] = value

    return merged
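A quick usage sketch of merge_json_recursive, following the semantics stated in its docstring (the values are illustrative):

    from utils.json_util import merge_json_recursive

    base = {"settings": {"theme": "dark", "tags": ["a"]}}
    update = {"settings": {"scale": 2, "tags": ["b"]}}

    # dicts merge key-by-key, lists concatenate, scalars take the update value
    merged = merge_json_recursive(base, update)
    assert merged == {"settings": {"theme": "dark", "scale": 2, "tags": ["a", "b"]}}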
23 web/assets/BaseViewTemplate-BNGF4K22.js generated vendored
@@ -1,23 +0,0 @@
import { d as defineComponent, o as openBlock, f as createElementBlock, J as renderSlot, T as normalizeClass } from "./index-DjNHn37O.js";
const _sfc_main = /* @__PURE__ */ defineComponent({
  __name: "BaseViewTemplate",
  props: {
    dark: { type: Boolean, default: false }
  },
  setup(__props) {
    const props = __props;
    return (_ctx, _cache) => {
      return openBlock(), createElementBlock("div", {
        class: normalizeClass(["font-sans w-screen h-screen flex items-center justify-center pointer-events-auto overflow-auto", [
          props.dark ? "text-neutral-300 bg-neutral-900 dark-theme" : "text-neutral-900 bg-neutral-300"
        ]])
      }, [
        renderSlot(_ctx.$slots, "default")
      ], 2);
    };
  }
});
export {
  _sfc_main as _
};
//# sourceMappingURL=BaseViewTemplate-BNGF4K22.js.map
51 web/assets/BaseViewTemplate-BTbuZf5t.js generated vendored Normal file
@@ -0,0 +1,51 @@
import { d as defineComponent, T as ref, p as onMounted, b8 as isElectron, V as nextTick, b9 as electronAPI, o as openBlock, f as createElementBlock, i as withDirectives, v as vShow, j as unref, ba as isNativeWindow, m as createBaseVNode, A as renderSlot, aj as normalizeClass } from "./index-Bv0b06LE.js";
const _hoisted_1 = { class: "flex-grow w-full flex items-center justify-center overflow-auto" };
const _sfc_main = /* @__PURE__ */ defineComponent({
  __name: "BaseViewTemplate",
  props: {
    dark: { type: Boolean, default: false }
  },
  setup(__props) {
    const props = __props;
    const darkTheme = {
      color: "rgba(0, 0, 0, 0)",
      symbolColor: "#d4d4d4"
    };
    const lightTheme = {
      color: "rgba(0, 0, 0, 0)",
      symbolColor: "#171717"
    };
    const topMenuRef = ref(null);
    onMounted(async () => {
      if (isElectron()) {
        await nextTick();
        electronAPI().changeTheme({
          ...props.dark ? darkTheme : lightTheme,
          height: topMenuRef.value.getBoundingClientRect().height
        });
      }
    });
    return (_ctx, _cache) => {
      return openBlock(), createElementBlock("div", {
        class: normalizeClass(["font-sans w-screen h-screen flex flex-col", [
          props.dark ? "text-neutral-300 bg-neutral-900 dark-theme" : "text-neutral-900 bg-neutral-300"
        ]])
      }, [
        withDirectives(createBaseVNode("div", {
          ref_key: "topMenuRef",
          ref: topMenuRef,
          class: "app-drag w-full h-[var(--comfy-topbar-height)]"
        }, null, 512), [
          [vShow, unref(isNativeWindow)()]
        ]),
        createBaseVNode("div", _hoisted_1, [
          renderSlot(_ctx.$slots, "default")
        ])
      ], 2);
    };
  }
});
export {
  _sfc_main as _
};
//# sourceMappingURL=BaseViewTemplate-BTbuZf5t.js.map
19 web/assets/DesktopStartView-D9r53Bue.js generated vendored Normal file
@@ -0,0 +1,19 @@
import { d as defineComponent, o as openBlock, y as createBlock, z as withCtx, k as createVNode, j as unref, bE as script } from "./index-Bv0b06LE.js";
import { _ as _sfc_main$1 } from "./BaseViewTemplate-BTbuZf5t.js";
const _sfc_main = /* @__PURE__ */ defineComponent({
  __name: "DesktopStartView",
  setup(__props) {
    return (_ctx, _cache) => {
      return openBlock(), createBlock(_sfc_main$1, { dark: "" }, {
        default: withCtx(() => [
          createVNode(unref(script), { class: "m-8 w-48 h-48" })
        ]),
        _: 1
      });
    };
  }
});
export {
  _sfc_main as default
};
//# sourceMappingURL=DesktopStartView-D9r53Bue.js.map
58 web/assets/DesktopUpdateView-C-R0415K.js generated vendored Normal file
@@ -0,0 +1,58 @@
var __defProp = Object.defineProperty;
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
import { d as defineComponent, T as ref, d8 as onUnmounted, o as openBlock, y as createBlock, z as withCtx, m as createBaseVNode, E as toDisplayString, j as unref, bg as t, k as createVNode, bE as script, l as script$1, b9 as electronAPI, _ as _export_sfc } from "./index-Bv0b06LE.js";
import { s as script$2 } from "./index-A_bXPJCN.js";
import { _ as _sfc_main$1 } from "./TerminalOutputDrawer-CKr7Br7O.js";
import { _ as _sfc_main$2 } from "./BaseViewTemplate-BTbuZf5t.js";
const _hoisted_1 = { class: "h-screen w-screen grid items-center justify-around overflow-y-auto" };
const _hoisted_2 = { class: "relative m-8 text-center" };
const _hoisted_3 = { class: "download-bg pi-download text-4xl font-bold" };
const _hoisted_4 = { class: "m-8" };
const _sfc_main = /* @__PURE__ */ defineComponent({
  __name: "DesktopUpdateView",
  setup(__props) {
    const electron = electronAPI();
    const terminalVisible = ref(false);
    const toggleConsoleDrawer = /* @__PURE__ */ __name(() => {
      terminalVisible.value = !terminalVisible.value;
    }, "toggleConsoleDrawer");
    onUnmounted(() => electron.Validation.dispose());
    return (_ctx, _cache) => {
      return openBlock(), createBlock(_sfc_main$2, { dark: "" }, {
        default: withCtx(() => [
          createBaseVNode("div", _hoisted_1, [
            createBaseVNode("div", _hoisted_2, [
              createBaseVNode("h1", _hoisted_3, toDisplayString(unref(t)("desktopUpdate.title")), 1),
              createBaseVNode("div", _hoisted_4, [
                createBaseVNode("span", null, toDisplayString(unref(t)("desktopUpdate.description")), 1)
              ]),
              createVNode(unref(script), { class: "m-8 w-48 h-48" }),
              createVNode(unref(script$1), {
                style: { "transform": "translateX(-50%)" },
                class: "fixed bottom-0 left-1/2 my-8",
                label: unref(t)("maintenance.consoleLogs"),
                icon: "pi pi-desktop",
                "icon-pos": "left",
                severity: "secondary",
                onClick: toggleConsoleDrawer
              }, null, 8, ["label"]),
              createVNode(_sfc_main$1, {
                modelValue: terminalVisible.value,
                "onUpdate:modelValue": _cache[0] || (_cache[0] = ($event) => terminalVisible.value = $event),
                header: unref(t)("g.terminal"),
                "default-message": unref(t)("desktopUpdate.terminalDefaultMessage")
              }, null, 8, ["modelValue", "header", "default-message"])
            ])
          ]),
          createVNode(unref(script$2))
        ]),
        _: 1
      });
    };
  }
});
const DesktopUpdateView = /* @__PURE__ */ _export_sfc(_sfc_main, [["__scopeId", "data-v-8d77828d"]]);
export {
  DesktopUpdateView as default
};
//# sourceMappingURL=DesktopUpdateView-C-R0415K.js.map
20 web/assets/DesktopUpdateView-CxchaIvw.css generated vendored Normal file
@@ -0,0 +1,20 @@

.download-bg[data-v-8d77828d]::before {
  position: absolute;
  margin: 0px;
  color: var(--p-text-muted-color);
  font-family: 'primeicons';
  top: -2rem;
  right: 2rem;
  speak: none;
  font-style: normal;
  font-weight: normal;
  font-variant: normal;
  text-transform: none;
  line-height: 1;
  display: inline-block;
  -webkit-font-smoothing: antialiased;
  opacity: 0.02;
  font-size: min(14rem, 90vw);
  z-index: 0
}
6 web/assets/DownloadGitView-DeC7MBzG.js → web/assets/DownloadGitView-PWqK5ke4.js generated vendored
@@ -1,7 +1,7 @@
 var __defProp = Object.defineProperty;
 var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
-import { d as defineComponent, o as openBlock, k as createBlock, M as withCtx, H as createBaseVNode, X as toDisplayString, N as createVNode, j as unref, l as script, bW as useRouter } from "./index-DjNHn37O.js";
-import { _ as _sfc_main$1 } from "./BaseViewTemplate-BNGF4K22.js";
+import { d as defineComponent, o as openBlock, y as createBlock, z as withCtx, m as createBaseVNode, E as toDisplayString, k as createVNode, j as unref, l as script, bi as useRouter } from "./index-Bv0b06LE.js";
+import { _ as _sfc_main$1 } from "./BaseViewTemplate-BTbuZf5t.js";
 const _hoisted_1 = { class: "max-w-screen-sm flex flex-col gap-8 p-8 bg-[url('/assets/images/Git-Logo-White.svg')] bg-no-repeat bg-right-top bg-origin-padding" };
 const _hoisted_2 = { class: "mt-24 text-4xl font-bold text-red-500" };
 const _hoisted_3 = { class: "space-y-4" };
@@ -55,4 +55,4 @@ const _sfc_main = /* @__PURE__ */ defineComponent({
 export {
   _sfc_main as default
 };
-//# sourceMappingURL=DownloadGitView-DeC7MBzG.js.map
+//# sourceMappingURL=DownloadGitView-PWqK5ke4.js.map
9 web/assets/ExtensionPanel-D4Phn0Zr.js → web/assets/ExtensionPanel-Ba57xrmg.js generated vendored
@@ -1,9 +1,8 @@
 var __defProp = Object.defineProperty;
 var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
-import { d as defineComponent, ab as ref, cn as FilterMatchMode, cs as useExtensionStore, a as useSettingStore, m as onMounted, c as computed, o as openBlock, k as createBlock, M as withCtx, N as createVNode, co as SearchBox, j as unref, bZ as script, H as createBaseVNode, f as createElementBlock, E as renderList, X as toDisplayString, aE as createTextVNode, F as Fragment, l as script$1, I as createCommentVNode, aI as script$3, bO as script$4, c4 as script$5, cp as _sfc_main$1 } from "./index-DjNHn37O.js";
-import { s as script$2, a as script$6 } from "./index-B5F0uxTQ.js";
-import "./index-B-aVupP5.js";
-import "./index-5HFeZax4.js";
+import { d as defineComponent, T as ref, dx as FilterMatchMode, dC as useExtensionStore, a as useSettingStore, p as onMounted, c as computed, o as openBlock, y as createBlock, z as withCtx, k as createVNode, dy as SearchBox, j as unref, bn as script, m as createBaseVNode, f as createElementBlock, D as renderList, E as toDisplayString, a8 as createTextVNode, F as Fragment, l as script$1, B as createCommentVNode, a5 as script$3, ay as script$4, br as script$5, dz as _sfc_main$1 } from "./index-Bv0b06LE.js";
+import { g as script$2, h as script$6 } from "./index-CgMyWf7n.js";
+import "./index-Dzu9WL4p.js";
 const _hoisted_1 = { class: "flex justify-end" };
 const _sfc_main = /* @__PURE__ */ defineComponent({
   __name: "ExtensionPanel",
@@ -180,4 +179,4 @@ const _sfc_main = /* @__PURE__ */ defineComponent({
 export {
   _sfc_main as default
 };
-//# sourceMappingURL=ExtensionPanel-D4Phn0Zr.js.map
+//# sourceMappingURL=ExtensionPanel-Ba57xrmg.js.map
4919 web/assets/GraphView-B_UDZi95.js generated vendored Normal file
File diff suppressed because it is too large
383 web/assets/GraphView-Bo28XDd0.css generated vendored Normal file
@@ -0,0 +1,383 @@

.comfy-menu-hamburger[data-v-82120b51] {
  position: fixed;
  z-index: 9999;
  display: flex;
  flex-direction: row
}

[data-v-e50caa15] .p-splitter-gutter {
  pointer-events: auto;
}
[data-v-e50caa15] .p-splitter-gutter:hover,[data-v-e50caa15] .p-splitter-gutter[data-p-gutter-resizing='true'] {
  transition: background-color 0.2s ease 300ms;
  background-color: var(--p-primary-color);
}
.side-bar-panel[data-v-e50caa15] {
  background-color: var(--bg-color);
  pointer-events: auto;
}
.bottom-panel[data-v-e50caa15] {
  background-color: var(--bg-color);
  pointer-events: auto;
}
.splitter-overlay[data-v-e50caa15] {
  pointer-events: none;
  border-style: none;
  background-color: transparent;
}
.splitter-overlay-root[data-v-e50caa15] {
  position: absolute;
  top: 0px;
  left: 0px;
  height: 100%;
  width: 100%;

  /* Set it the same as the ComfyUI menu */
  /* Note: Lite-graph DOM widgets have the same z-index as the node id, so
  999 should be sufficient to make sure splitter overlays on node's DOM
  widgets */
  z-index: 999;
}

.p-buttongroup-vertical[data-v-27a9500c] {
  display: flex;
  flex-direction: column;
  border-radius: var(--p-button-border-radius);
  overflow: hidden;
  border: 1px solid var(--p-panel-border-color);
}
.p-buttongroup-vertical .p-button[data-v-27a9500c] {
  margin: 0;
  border-radius: 0;
}

.node-tooltip[data-v-f03142eb] {
  background: var(--comfy-input-bg);
  border-radius: 5px;
  box-shadow: 0 0 5px rgba(0, 0, 0, 0.4);
  color: var(--input-text);
  font-family: sans-serif;
  left: 0;
  max-width: 30vw;
  padding: 4px 8px;
  position: absolute;
  top: 0;
  transform: translate(5px, calc(-100% - 5px));
  white-space: pre-wrap;
  z-index: 99999;
}

.group-title-editor.node-title-editor[data-v-12d3fd12] {
  z-index: 9999;
  padding: 0.25rem;
}
[data-v-12d3fd12] .editable-text {
  width: 100%;
  height: 100%;
}
[data-v-12d3fd12] .editable-text input {
  width: 100%;
  height: 100%;
  /* Override the default font size */
  font-size: inherit;
}

[data-v-fd0a74bd] .highlight {
  background-color: var(--p-primary-color);
  color: var(--p-primary-contrast-color);
  font-weight: bold;
  border-radius: 0.25rem;
  padding: 0rem 0.125rem;
  margin: -0.125rem 0.125rem;
}

.invisible-dialog-root {
  width: 60%;
  min-width: 24rem;
  max-width: 48rem;
  border: 0 !important;
  background-color: transparent !important;
  margin-top: 25vh;
  margin-left: 400px;
}
@media all and (max-width: 768px) {
  .invisible-dialog-root {
    margin-left: 0px;
  }
}
.node-search-box-dialog-mask {
  align-items: flex-start !important;
}

.side-bar-button-icon {
  font-size: var(--sidebar-icon-size) !important;
}
.side-bar-button-selected .side-bar-button-icon {
  font-size: var(--sidebar-icon-size) !important;
  font-weight: bold;
}

.side-bar-button[data-v-6ab4daa6] {
  width: var(--sidebar-width);
  height: var(--sidebar-width);
  border-radius: 0;
}
.comfyui-body-left .side-bar-button.side-bar-button-selected[data-v-6ab4daa6],
.comfyui-body-left .side-bar-button.side-bar-button-selected[data-v-6ab4daa6]:hover {
  border-left: 4px solid var(--p-button-text-primary-color);
}
.comfyui-body-right .side-bar-button.side-bar-button-selected[data-v-6ab4daa6],
.comfyui-body-right .side-bar-button.side-bar-button-selected[data-v-6ab4daa6]:hover {
  border-right: 4px solid var(--p-button-text-primary-color);
}

.side-tool-bar-container[data-v-04875455] {
  display: flex;
  flex-direction: column;
  align-items: center;

  width: var(--sidebar-width);
  height: 100%;

  background-color: var(--comfy-menu-secondary-bg);
  color: var(--fg-color);
  box-shadow: var(--bar-shadow);

  --sidebar-width: 4rem;
  --sidebar-icon-size: 1.5rem;
}
.side-tool-bar-container.small-sidebar[data-v-04875455] {
  --sidebar-width: 2.5rem;
  --sidebar-icon-size: 1rem;
}
.side-tool-bar-end[data-v-04875455] {
  align-self: flex-end;
  margin-top: auto;
}

.status-indicator[data-v-fd6ae3af] {
  position: absolute;
  font-weight: 700;
  font-size: 1.5rem;
  top: 50%;
  left: 50%;
  transform: translate(-50%, -50%)
}

[data-v-54fadc45] .p-togglebutton {
  position: relative;
  flex-shrink: 0;
  border-radius: 0px;
  border-width: 0px;
  border-right-width: 1px;
  border-style: solid;
  background-color: transparent;
  padding: 0px;
  border-right-color: var(--border-color)
}
[data-v-54fadc45] .p-togglebutton::before {
  display: none
}
[data-v-54fadc45] .p-togglebutton:first-child {
  border-left-width: 1px;
  border-style: solid;
  border-left-color: var(--border-color)
}
[data-v-54fadc45] .p-togglebutton:not(:first-child) {
  border-left-width: 0px
}
[data-v-54fadc45] .p-togglebutton.p-togglebutton-checked {
  height: 100%;
  border-bottom-width: 1px;
  border-style: solid;
  border-bottom-color: var(--p-button-text-primary-color)
}
[data-v-54fadc45] .p-togglebutton:not(.p-togglebutton-checked) {
  opacity: 0.75
}
[data-v-54fadc45] .p-togglebutton-checked .close-button,[data-v-54fadc45] .p-togglebutton:hover .close-button {
  visibility: visible
}
[data-v-54fadc45] .p-togglebutton:hover .status-indicator {
  display: none
}
[data-v-54fadc45] .p-togglebutton .close-button {
  visibility: hidden
}
[data-v-54fadc45] .p-scrollpanel-content {
  height: 100%
}

/* Scrollbar half opacity to avoid blocking the active tab bottom border */
[data-v-54fadc45] .p-scrollpanel:hover .p-scrollpanel-bar,[data-v-54fadc45] .p-scrollpanel:active .p-scrollpanel-bar {
  opacity: 0.5
}
[data-v-54fadc45] .p-selectbutton {
  height: 100%;
  border-radius: 0px
}

[data-v-6ab68035] .workflow-tabs {
  background-color: var(--comfy-menu-bg);
}

[data-v-26957f1f] .p-inputtext {
  border-top-left-radius: 0;
  border-bottom-left-radius: 0;
}

.comfyui-queue-button[data-v-91a628af] .p-splitbutton-dropdown {
  border-top-right-radius: 0;
  border-bottom-right-radius: 0;
}

.actionbar[data-v-ebd56d51] {
  pointer-events: all;
  position: fixed;
  z-index: 1000;
}
.actionbar.is-docked[data-v-ebd56d51] {
  position: static;
  border-style: none;
  background-color: transparent;
  padding: 0px;
}
.actionbar.is-dragging[data-v-ebd56d51] {
  -webkit-user-select: none;
  -moz-user-select: none;
  user-select: none;
}
[data-v-ebd56d51] .p-panel-content {
  padding: 0.25rem;
}
.is-docked[data-v-ebd56d51] .p-panel-content {
  padding: 0px;
}
[data-v-ebd56d51] .p-panel-header {
  display: none;
}
.drag-handle[data-v-ebd56d51] {
  height: -moz-max-content;
  height: max-content;
  width: 0.75rem;
}

.top-menubar[data-v-56df69d2] .p-menubar-item-link svg {
  display: none;
}
[data-v-56df69d2] .p-menubar-submenu.dropdown-direction-up {
  top: auto;
  bottom: 100%;
  flex-direction: column-reverse;
}
.keybinding-tag[data-v-56df69d2] {
  background: var(--p-content-hover-background);
  border-color: var(--p-content-border-color);
  border-style: solid;
}

.comfyui-menu[data-v-68d3b5b9] {
  width: 100vw;
  height: var(--comfy-topbar-height);
  background: var(--comfy-menu-bg);
  color: var(--fg-color);
  box-shadow: var(--bar-shadow);
  font-family: Arial, Helvetica, sans-serif;
  font-size: 0.8em;
  box-sizing: border-box;
  z-index: 1000;
  order: 0;
  grid-column: 1/-1;
}
.comfyui-menu.dropzone[data-v-68d3b5b9] {
  background: var(--p-highlight-background);
}
.comfyui-menu.dropzone-active[data-v-68d3b5b9] {
  background: var(--p-highlight-background-focus);
}
[data-v-68d3b5b9] .p-menubar-item-label {
  line-height: revert;
}
.comfyui-logo[data-v-68d3b5b9] {
  font-size: 1.2em;
  -webkit-user-select: none;
  -moz-user-select: none;
  user-select: none;
  cursor: default;
}

.comfyui-body[data-v-e89d9273] {
  grid-template-columns: auto 1fr auto;
  grid-template-rows: auto 1fr auto;
}

/**
  +------------------+------------------+------------------+
  |                                                        |
  |                     .comfyui-body-                     |
  |                          top                           |
  |                    (spans all cols)                    |
  |                                                        |
  +------------------+------------------+------------------+
  |                  |                  |                  |
  |  .comfyui-body-  |   #graph-canvas  |  .comfyui-body-  |
  |       left       |                  |       right      |
  |                  |                  |                  |
  |                  |                  |                  |
  +------------------+------------------+------------------+
  |                                                        |
  |                     .comfyui-body-                     |
  |                         bottom                         |
  |                    (spans all cols)                    |
  |                                                        |
  +------------------+------------------+------------------+
*/
.comfyui-body-top[data-v-e89d9273] {
  order: -5;
  /* Span across all columns */
  grid-column: 1/-1;
  /* Position at the first row */
  grid-row: 1;
  /* Top menu bar dropdown needs to be above of graph canvas splitter overlay which is z-index: 999 */
  /* Top menu bar z-index needs to be higher than bottom menu bar z-index as by default
  pysssss's image feed is located at body-bottom, and it can overlap with the queue button, which
  is located in body-top. */
  z-index: 1001;
  display: flex;
  flex-direction: column;
}
.comfyui-body-left[data-v-e89d9273] {
  order: -4;
  /* Position in the first column */
  grid-column: 1;
  /* Position below the top element */
  grid-row: 2;
  z-index: 10;
  display: flex;
}
.graph-canvas-container[data-v-e89d9273] {
  width: 100%;
  height: 100%;
  order: -3;
  grid-column: 2;
  grid-row: 2;
  position: relative;
  overflow: hidden;
}
.comfyui-body-right[data-v-e89d9273] {
  order: -2;
  z-index: 10;
  grid-column: 3;
  grid-row: 2;
}
.comfyui-body-bottom[data-v-e89d9273] {
  order: 4;
  /* Span across all columns */
  grid-column: 1/-1;
  grid-row: 3;
  /* Bottom menu bar dropdown needs to be above of graph canvas splitter overlay which is z-index: 999 */
  z-index: 1000;
  display: flex;
  flex-direction: column;
}
273 web/assets/GraphView-CIRWBKTm.css generated vendored
@@ -1,273 +0,0 @@

.comfy-menu-hamburger[data-v-5661bed0] {
  pointer-events: auto;
  position: fixed;
  z-index: 9999;
}

[data-v-e50caa15] .p-splitter-gutter {
  pointer-events: auto;
}
[data-v-e50caa15] .p-splitter-gutter:hover,[data-v-e50caa15] .p-splitter-gutter[data-p-gutter-resizing='true'] {
  transition: background-color 0.2s ease 300ms;
  background-color: var(--p-primary-color);
}
.side-bar-panel[data-v-e50caa15] {
  background-color: var(--bg-color);
  pointer-events: auto;
}
.bottom-panel[data-v-e50caa15] {
  background-color: var(--bg-color);
  pointer-events: auto;
}
.splitter-overlay[data-v-e50caa15] {
  pointer-events: none;
  border-style: none;
  background-color: transparent;
}
.splitter-overlay-root[data-v-e50caa15] {
  position: absolute;
  top: 0px;
  left: 0px;
  height: 100%;
  width: 100%;

  /* Set it the same as the ComfyUI menu */
  /* Note: Lite-graph DOM widgets have the same z-index as the node id, so
  999 should be sufficient to make sure splitter overlays on node's DOM
  widgets */
  z-index: 999;
}

.p-buttongroup-vertical[data-v-cf40dd39] {
  display: flex;
  flex-direction: column;
  border-radius: var(--p-button-border-radius);
  overflow: hidden;
  border: 1px solid var(--p-panel-border-color);
}
.p-buttongroup-vertical .p-button[data-v-cf40dd39] {
  margin: 0;
  border-radius: 0;
}

.node-tooltip[data-v-46859edf] {
  background: var(--comfy-input-bg);
  border-radius: 5px;
  box-shadow: 0 0 5px rgba(0, 0, 0, 0.4);
  color: var(--input-text);
  font-family: sans-serif;
  left: 0;
  max-width: 30vw;
  padding: 4px 8px;
  position: absolute;
  top: 0;
  transform: translate(5px, calc(-100% - 5px));
  white-space: pre-wrap;
  z-index: 99999;
}

.group-title-editor.node-title-editor[data-v-12d3fd12] {
  z-index: 9999;
  padding: 0.25rem;
}
[data-v-12d3fd12] .editable-text {
  width: 100%;
  height: 100%;
}
[data-v-12d3fd12] .editable-text input {
  width: 100%;
  height: 100%;
  /* Override the default font size */
  font-size: inherit;
}

[data-v-5741c9ae] .highlight {
  background-color: var(--p-primary-color);
  color: var(--p-primary-contrast-color);
  font-weight: bold;
  border-radius: 0.25rem;
  padding: 0rem 0.125rem;
  margin: -0.125rem 0.125rem;
}

.invisible-dialog-root {
  width: 60%;
  min-width: 24rem;
  max-width: 48rem;
  border: 0 !important;
  background-color: transparent !important;
  margin-top: 25vh;
  margin-left: 400px;
}
@media all and (max-width: 768px) {
  .invisible-dialog-root {
    margin-left: 0px;
  }
}
.node-search-box-dialog-mask {
  align-items: flex-start !important;
}

.side-bar-button-icon {
  font-size: var(--sidebar-icon-size) !important;
}
.side-bar-button-selected .side-bar-button-icon {
  font-size: var(--sidebar-icon-size) !important;
  font-weight: bold;
}

.side-bar-button[data-v-6ab4daa6] {
  width: var(--sidebar-width);
  height: var(--sidebar-width);
  border-radius: 0;
}
.comfyui-body-left .side-bar-button.side-bar-button-selected[data-v-6ab4daa6],
.comfyui-body-left .side-bar-button.side-bar-button-selected[data-v-6ab4daa6]:hover {
  border-left: 4px solid var(--p-button-text-primary-color);
}
.comfyui-body-right .side-bar-button.side-bar-button-selected[data-v-6ab4daa6],
.comfyui-body-right .side-bar-button.side-bar-button-selected[data-v-6ab4daa6]:hover {
  border-right: 4px solid var(--p-button-text-primary-color);
}

:root {
  --sidebar-width: 64px;
  --sidebar-icon-size: 1.5rem;
}
:root .small-sidebar {
  --sidebar-width: 40px;
  --sidebar-icon-size: 1rem;
}

.side-tool-bar-container[data-v-37d8d7b4] {
  display: flex;
  flex-direction: column;
  align-items: center;

  pointer-events: auto;

  width: var(--sidebar-width);
  height: 100%;

  background-color: var(--comfy-menu-secondary-bg);
  color: var(--fg-color);
  box-shadow: var(--bar-shadow);
}
.side-tool-bar-end[data-v-37d8d7b4] {
  align-self: flex-end;
  margin-top: auto;
}

[data-v-b9328350] .p-inputtext {
  border-top-left-radius: 0;
  border-bottom-left-radius: 0;
}

.comfyui-queue-button[data-v-7f4f551b] .p-splitbutton-dropdown {
  border-top-right-radius: 0;
  border-bottom-right-radius: 0;
}

.actionbar[data-v-915e5456] {
  pointer-events: all;
  position: fixed;
  z-index: 1000;
}
.actionbar.is-docked[data-v-915e5456] {
  position: static;
  border-style: none;
  background-color: transparent;
  padding: 0px;
}
.actionbar.is-dragging[data-v-915e5456] {
  -webkit-user-select: none;
  -moz-user-select: none;
  user-select: none;
}
[data-v-915e5456] .p-panel-content {
  padding: 0.25rem;
}
.is-docked[data-v-915e5456] .p-panel-content {
  padding: 0px;
}
[data-v-915e5456] .p-panel-header {
  display: none;
}

.top-menubar[data-v-6fecd137] .p-menubar-item-link svg {
  display: none;
}
[data-v-6fecd137] .p-menubar-submenu.dropdown-direction-up {
  top: auto;
  bottom: 100%;
  flex-direction: column-reverse;
}
.keybinding-tag[data-v-6fecd137] {
  background: var(--p-content-hover-background);
  border-color: var(--p-content-border-color);
  border-style: solid;
}

.status-indicator[data-v-8d011a31] {
  position: absolute;
  font-weight: 700;
  font-size: 1.5rem;
  top: 50%;
  left: 50%;
  transform: translate(-50%, -50%)
}

[data-v-d485c044] .p-togglebutton::before {
  display: none
}
[data-v-d485c044] .p-togglebutton {
  position: relative;
  flex-shrink: 0;
  border-radius: 0px;
  background-color: transparent;
  padding: 0px
}
[data-v-d485c044] .p-togglebutton.p-togglebutton-checked {
  border-bottom-width: 2px;
  border-bottom-color: var(--p-button-text-primary-color)
}
[data-v-d485c044] .p-togglebutton-checked .close-button,[data-v-d485c044] .p-togglebutton:hover .close-button {
  visibility: visible
}
[data-v-d485c044] .p-togglebutton:hover .status-indicator {
  display: none
}
[data-v-d485c044] .p-togglebutton .close-button {
  visibility: hidden
}

.comfyui-menu[data-v-878b63b8] {
  width: 100vw;
  background: var(--comfy-menu-bg);
  color: var(--fg-color);
  box-shadow: var(--bar-shadow);
  font-family: Arial, Helvetica, sans-serif;
  font-size: 0.8em;
  box-sizing: border-box;
  z-index: 1000;
  order: 0;
  grid-column: 1/-1;
  max-height: 90vh;
}
.comfyui-menu.dropzone[data-v-878b63b8] {
  background: var(--p-highlight-background);
}
.comfyui-menu.dropzone-active[data-v-878b63b8] {
  background: var(--p-highlight-background-focus);
}
[data-v-878b63b8] .p-menubar-item-label {
  line-height: revert;
}
.comfyui-logo[data-v-878b63b8] {
  font-size: 1.2em;
  -webkit-user-select: none;
  -moz-user-select: none;
  user-select: none;
  cursor: default;
}
9894 web/assets/GraphView-HVeNbkaW.js generated vendored
File diff suppressed because one or more lines are too long
1288 web/assets/InstallView-CAcYt0HL.js generated vendored
File diff suppressed because one or more lines are too long
971 web/assets/InstallView-DW9xwU_F.js generated vendored Normal file
@@ -0,0 +1,971 @@
|
||||
var __defProp = Object.defineProperty;
|
||||
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
|
||||
import { d as defineComponent, T as ref, bq as useModel, o as openBlock, f as createElementBlock, m as createBaseVNode, E as toDisplayString, k as createVNode, j as unref, br as script, bl as script$1, as as withModifiers, z as withCtx, ac as script$2, I as useI18n, c as computed, aj as normalizeClass, B as createCommentVNode, a5 as script$3, a8 as createTextVNode, b9 as electronAPI, _ as _export_sfc, p as onMounted, r as resolveDirective, bk as script$4, i as withDirectives, bs as script$5, bt as script$6, l as script$7, y as createBlock, bn as script$8, bu as MigrationItems, w as watchEffect, F as Fragment, D as renderList, bv as script$9, bw as mergeModels, bx as ValidationState, X as normalizeI18nKey, N as watch, by as checkMirrorReachable, bz as _sfc_main$7, bA as isInChina, bB as mergeValidationStates, bg as t, b3 as script$a, bC as CUDA_TORCH_URL, bD as NIGHTLY_CPU_TORCH_URL, bi as useRouter, ah as toRaw } from "./index-Bv0b06LE.js";
|
||||
import { s as script$b, a as script$c, b as script$d, c as script$e, d as script$f } from "./index-SeIZOWJp.js";
|
||||
import { P as PYTHON_MIRROR, a as PYPI_MIRROR } from "./uvMirrors-B-HKMf6X.js";
|
||||
import { _ as _sfc_main$8 } from "./BaseViewTemplate-BTbuZf5t.js";
|
||||
const _hoisted_1$5 = { class: "flex flex-col gap-6 w-[600px]" };
|
||||
const _hoisted_2$5 = { class: "flex flex-col gap-4" };
|
||||
const _hoisted_3$5 = { class: "text-2xl font-semibold text-neutral-100" };
|
||||
const _hoisted_4$5 = { class: "text-neutral-400 my-0" };
|
||||
const _hoisted_5$3 = { class: "flex flex-col bg-neutral-800 p-4 rounded-lg" };
|
||||
const _hoisted_6$3 = { class: "flex items-center gap-4" };
|
||||
const _hoisted_7$3 = { class: "flex-1" };
|
||||
const _hoisted_8$3 = { class: "text-lg font-medium text-neutral-100" };
|
||||
const _hoisted_9$3 = { class: "text-sm text-neutral-400 mt-1" };
|
||||
const _hoisted_10$3 = { class: "flex items-center gap-4" };
|
||||
const _hoisted_11$3 = { class: "flex-1" };
|
||||
const _hoisted_12$3 = { class: "text-lg font-medium text-neutral-100" };
|
||||
const _hoisted_13$1 = { class: "text-sm text-neutral-400 mt-1" };
|
||||
const _hoisted_14$1 = { class: "text-neutral-300" };
|
||||
const _hoisted_15 = { class: "font-medium mb-2" };
|
||||
const _hoisted_16 = { class: "list-disc pl-6 space-y-1" };
|
||||
const _hoisted_17 = { class: "font-medium mt-4 mb-2" };
|
||||
const _hoisted_18 = { class: "list-disc pl-6 space-y-1" };
|
||||
const _hoisted_19 = { class: "mt-4" };
|
||||
const _hoisted_20 = {
|
||||
href: "https://comfy.org/privacy",
|
||||
target: "_blank",
|
||||
class: "text-blue-400 hover:text-blue-300 underline"
|
||||
};
|
||||
const _sfc_main$6 = /* @__PURE__ */ defineComponent({
|
||||
__name: "DesktopSettingsConfiguration",
|
||||
props: {
|
||||
"autoUpdate": { type: Boolean, ...{ required: true } },
|
||||
"autoUpdateModifiers": {},
|
||||
"allowMetrics": { type: Boolean, ...{ required: true } },
|
||||
"allowMetricsModifiers": {}
|
||||
},
|
||||
emits: ["update:autoUpdate", "update:allowMetrics"],
|
||||
setup(__props) {
|
||||
const showDialog = ref(false);
|
||||
const autoUpdate = useModel(__props, "autoUpdate");
|
||||
const allowMetrics = useModel(__props, "allowMetrics");
|
||||
const showMetricsInfo = /* @__PURE__ */ __name(() => {
|
||||
showDialog.value = true;
|
||||
}, "showMetricsInfo");
|
||||
return (_ctx, _cache) => {
|
||||
return openBlock(), createElementBlock("div", _hoisted_1$5, [
|
||||
createBaseVNode("div", _hoisted_2$5, [
|
||||
createBaseVNode("h2", _hoisted_3$5, toDisplayString(_ctx.$t("install.desktopAppSettings")), 1),
|
||||
createBaseVNode("p", _hoisted_4$5, toDisplayString(_ctx.$t("install.desktopAppSettingsDescription")), 1)
|
||||
]),
|
||||
createBaseVNode("div", _hoisted_5$3, [
|
||||
createBaseVNode("div", _hoisted_6$3, [
|
||||
createBaseVNode("div", _hoisted_7$3, [
|
||||
createBaseVNode("h3", _hoisted_8$3, toDisplayString(_ctx.$t("install.settings.autoUpdate")), 1),
|
||||
createBaseVNode("p", _hoisted_9$3, toDisplayString(_ctx.$t("install.settings.autoUpdateDescription")), 1)
|
||||
]),
|
||||
createVNode(unref(script), {
|
||||
modelValue: autoUpdate.value,
|
||||
"onUpdate:modelValue": _cache[0] || (_cache[0] = ($event) => autoUpdate.value = $event)
|
||||
}, null, 8, ["modelValue"])
|
||||
]),
|
||||
createVNode(unref(script$1)),
|
||||
createBaseVNode("div", _hoisted_10$3, [
|
||||
createBaseVNode("div", _hoisted_11$3, [
|
||||
createBaseVNode("h3", _hoisted_12$3, toDisplayString(_ctx.$t("install.settings.allowMetrics")), 1),
|
||||
createBaseVNode("p", _hoisted_13$1, toDisplayString(_ctx.$t("install.settings.allowMetricsDescription")), 1),
|
||||
createBaseVNode("a", {
|
||||
href: "#",
|
||||
class: "text-sm text-blue-400 hover:text-blue-300 mt-1 inline-block",
|
||||
onClick: withModifiers(showMetricsInfo, ["prevent"])
|
||||
}, toDisplayString(_ctx.$t("install.settings.learnMoreAboutData")), 1)
|
||||
]),
|
||||
createVNode(unref(script), {
|
||||
modelValue: allowMetrics.value,
|
||||
"onUpdate:modelValue": _cache[1] || (_cache[1] = ($event) => allowMetrics.value = $event)
|
||||
}, null, 8, ["modelValue"])
|
||||
])
|
||||
]),
|
||||
createVNode(unref(script$2), {
|
||||
visible: showDialog.value,
|
||||
"onUpdate:visible": _cache[2] || (_cache[2] = ($event) => showDialog.value = $event),
|
||||
modal: "",
|
||||
header: _ctx.$t("install.settings.dataCollectionDialog.title")
|
||||
}, {
|
||||
default: withCtx(() => [
|
||||
createBaseVNode("div", _hoisted_14$1, [
|
||||
createBaseVNode("h4", _hoisted_15, toDisplayString(_ctx.$t("install.settings.dataCollectionDialog.whatWeCollect")), 1),
|
||||
createBaseVNode("ul", _hoisted_16, [
|
||||
createBaseVNode("li", null, toDisplayString(_ctx.$t("install.settings.dataCollectionDialog.collect.errorReports")), 1),
|
||||
createBaseVNode("li", null, toDisplayString(_ctx.$t("install.settings.dataCollectionDialog.collect.systemInfo")), 1),
|
||||
createBaseVNode("li", null, toDisplayString(_ctx.$t(
|
||||
"install.settings.dataCollectionDialog.collect.userJourneyEvents"
|
||||
)), 1)
|
||||
]),
|
||||
createBaseVNode("h4", _hoisted_17, toDisplayString(_ctx.$t("install.settings.dataCollectionDialog.whatWeDoNotCollect")), 1),
|
||||
createBaseVNode("ul", _hoisted_18, [
|
||||
createBaseVNode("li", null, toDisplayString(_ctx.$t(
|
||||
"install.settings.dataCollectionDialog.doNotCollect.personalInformation"
|
||||
)), 1),
|
||||
createBaseVNode("li", null, toDisplayString(_ctx.$t(
|
||||
"install.settings.dataCollectionDialog.doNotCollect.workflowContents"
|
||||
)), 1),
|
||||
createBaseVNode("li", null, toDisplayString(_ctx.$t(
|
||||
"install.settings.dataCollectionDialog.doNotCollect.fileSystemInformation"
|
||||
)), 1),
|
||||
createBaseVNode("li", null, toDisplayString(_ctx.$t(
|
||||
"install.settings.dataCollectionDialog.doNotCollect.customNodeConfigurations"
|
||||
)), 1)
|
||||
]),
|
||||
createBaseVNode("div", _hoisted_19, [
|
||||
createBaseVNode("a", _hoisted_20, toDisplayString(_ctx.$t("install.settings.dataCollectionDialog.viewFullPolicy")), 1)
|
||||
])
|
||||
])
|
||||
]),
|
||||
_: 1
|
||||
}, 8, ["visible", "header"])
|
||||
]);
|
||||
};
|
||||
}
|
||||
});
|
||||
const _imports_0 = "" + new URL("images/nvidia-logo.svg", import.meta.url).href;
|
||||
const _imports_1 = "" + new URL("images/apple-mps-logo.png", import.meta.url).href;
|
||||
const _imports_2 = "" + new URL("images/manual-configuration.svg", import.meta.url).href;
|
||||
const _hoisted_1$4 = { class: "flex flex-col gap-6 w-[600px] h-[30rem] select-none" };
|
||||
const _hoisted_2$4 = { class: "grow flex flex-col gap-4 text-neutral-300" };
|
||||
const _hoisted_3$4 = { class: "text-2xl font-semibold text-neutral-100" };
|
||||
const _hoisted_4$4 = { class: "m-1 text-neutral-400" };
|
||||
const _hoisted_5$2 = {
|
||||
key: 0,
|
||||
class: "m-1"
|
||||
};
|
||||
const _hoisted_6$2 = {
|
||||
key: 1,
|
||||
class: "m-1"
|
||||
};
|
||||
const _hoisted_7$2 = {
|
||||
key: 2,
|
||||
class: "text-neutral-300"
|
||||
};
|
||||
const _hoisted_8$2 = { class: "m-1" };
|
||||
const _hoisted_9$2 = { key: 3 };
|
||||
const _hoisted_10$2 = { class: "m-1" };
|
||||
const _hoisted_11$2 = { class: "m-1" };
|
||||
const _hoisted_12$2 = {
|
||||
for: "cpu-mode",
|
||||
class: "select-none"
|
||||
};
|
||||
const _sfc_main$5 = /* @__PURE__ */ defineComponent({
|
||||
__name: "GpuPicker",
|
||||
props: {
|
||||
"device": {
|
||||
required: true
|
||||
},
|
||||
"deviceModifiers": {}
|
||||
},
|
||||
emits: ["update:device"],
|
||||
setup(__props) {
|
||||
const { t: t2 } = useI18n();
|
||||
const cpuMode = computed({
|
||||
get: /* @__PURE__ */ __name(() => selected.value === "cpu", "get"),
|
||||
set: /* @__PURE__ */ __name((value) => {
|
||||
selected.value = value ? "cpu" : null;
|
||||
}, "set")
|
||||
});
|
||||
const selected = useModel(__props, "device");
|
||||
const electron = electronAPI();
const platform = electron.getPlatform();
const pickGpu = /* @__PURE__ */ __name((value) => {
const newValue = selected.value === value ? null : value;
selected.value = newValue;
}, "pickGpu");
return (_ctx, _cache) => {
return openBlock(), createElementBlock("div", _hoisted_1$4, [
createBaseVNode("div", _hoisted_2$4, [
createBaseVNode("h2", _hoisted_3$4, toDisplayString(_ctx.$t("install.gpuSelection.selectGpu")), 1),
createBaseVNode("p", _hoisted_4$4, toDisplayString(_ctx.$t("install.gpuSelection.selectGpuDescription")) + ": ", 1),
createBaseVNode("div", {
class: normalizeClass(["flex gap-2 text-center transition-opacity", { selected: selected.value }])
}, [
unref(platform) !== "darwin" ? (openBlock(), createElementBlock("div", {
key: 0,
class: normalizeClass(["gpu-button", { selected: selected.value === "nvidia" }]),
role: "button",
onClick: _cache[0] || (_cache[0] = ($event) => pickGpu("nvidia"))
}, _cache[4] || (_cache[4] = [
createBaseVNode("img", {
class: "m-12",
alt: "NVIDIA logo",
width: "196",
height: "32",
src: _imports_0
}, null, -1)
]), 2)) : createCommentVNode("", true),
unref(platform) === "darwin" ? (openBlock(), createElementBlock("div", {
key: 1,
class: normalizeClass(["gpu-button", { selected: selected.value === "mps" }]),
role: "button",
onClick: _cache[1] || (_cache[1] = ($event) => pickGpu("mps"))
}, _cache[5] || (_cache[5] = [
createBaseVNode("img", {
class: "rounded-lg hover-brighten",
alt: "Apple Metal Performance Shaders Logo",
width: "292",
ratio: "",
src: _imports_1
}, null, -1)
]), 2)) : createCommentVNode("", true),
createBaseVNode("div", {
class: normalizeClass(["gpu-button", { selected: selected.value === "unsupported" }]),
role: "button",
onClick: _cache[2] || (_cache[2] = ($event) => pickGpu("unsupported"))
}, _cache[6] || (_cache[6] = [
createBaseVNode("img", {
class: "m-12",
alt: "Manual configuration",
width: "196",
src: _imports_2
}, null, -1)
]), 2)
], 2),
selected.value === "nvidia" ? (openBlock(), createElementBlock("p", _hoisted_5$2, [
createVNode(unref(script$3), {
icon: "pi pi-check",
severity: "success",
value: "CUDA"
}),
createTextVNode(" " + toDisplayString(_ctx.$t("install.gpuSelection.nvidiaDescription")), 1)
])) : createCommentVNode("", true),
selected.value === "mps" ? (openBlock(), createElementBlock("p", _hoisted_6$2, [
createVNode(unref(script$3), {
icon: "pi pi-check",
severity: "success",
value: "MPS"
}),
createTextVNode(" " + toDisplayString(_ctx.$t("install.gpuSelection.mpsDescription")), 1)
])) : createCommentVNode("", true),
selected.value === "unsupported" ? (openBlock(), createElementBlock("div", _hoisted_7$2, [
createBaseVNode("p", _hoisted_8$2, [
createVNode(unref(script$3), {
icon: "pi pi-exclamation-triangle",
severity: "warn",
value: unref(t2)("icon.exclamation-triangle")
}, null, 8, ["value"]),
createTextVNode(" " + toDisplayString(_ctx.$t("install.gpuSelection.customSkipsPython")), 1)
]),
createBaseVNode("ul", null, [
createBaseVNode("li", null, [
createBaseVNode("strong", null, toDisplayString(_ctx.$t("install.gpuSelection.customComfyNeedsPython")), 1)
]),
createBaseVNode("li", null, toDisplayString(_ctx.$t("install.gpuSelection.customManualVenv")), 1),
createBaseVNode("li", null, toDisplayString(_ctx.$t("install.gpuSelection.customInstallRequirements")), 1),
createBaseVNode("li", null, toDisplayString(_ctx.$t("install.gpuSelection.customMayNotWork")), 1)
])
])) : createCommentVNode("", true),
selected.value === "cpu" ? (openBlock(), createElementBlock("div", _hoisted_9$2, [
createBaseVNode("p", _hoisted_10$2, [
createVNode(unref(script$3), {
icon: "pi pi-exclamation-triangle",
severity: "warn",
value: unref(t2)("icon.exclamation-triangle")
}, null, 8, ["value"]),
createTextVNode(" " + toDisplayString(_ctx.$t("install.gpuSelection.cpuModeDescription")), 1)
]),
createBaseVNode("p", _hoisted_11$2, toDisplayString(_ctx.$t("install.gpuSelection.cpuModeDescription2")), 1)
])) : createCommentVNode("", true)
]),
createBaseVNode("div", {
class: normalizeClass(["transition-opacity flex gap-3 h-0", {
"opacity-40": selected.value && selected.value !== "cpu"
}])
}, [
createVNode(unref(script), {
modelValue: cpuMode.value,
"onUpdate:modelValue": _cache[3] || (_cache[3] = ($event) => cpuMode.value = $event),
inputId: "cpu-mode",
class: "-translate-y-40"
}, null, 8, ["modelValue"]),
createBaseVNode("label", _hoisted_12$2, toDisplayString(_ctx.$t("install.gpuSelection.enableCpuMode")), 1)
], 2)
]);
};
}
});
const GpuPicker = /* @__PURE__ */ _export_sfc(_sfc_main$5, [["__scopeId", "data-v-79125ff6"]]);
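// InstallLocationPicker: install-path step; validates write access, free space, and parent directory, and shows the resolved App Data / App Path locations.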
const _hoisted_1$3 = { class: "flex flex-col gap-6 w-[600px]" };
const _hoisted_2$3 = { class: "flex flex-col gap-4" };
const _hoisted_3$3 = { class: "text-2xl font-semibold text-neutral-100" };
const _hoisted_4$3 = { class: "text-neutral-400 my-0" };
const _hoisted_5$1 = { class: "flex gap-2" };
const _hoisted_6$1 = { class: "bg-neutral-800 p-4 rounded-lg" };
const _hoisted_7$1 = { class: "text-lg font-medium mt-0 mb-3 text-neutral-100" };
const _hoisted_8$1 = { class: "flex flex-col gap-2" };
const _hoisted_9$1 = { class: "flex items-center gap-2" };
const _hoisted_10$1 = { class: "text-neutral-200" };
const _hoisted_11$1 = { class: "pi pi-info-circle" };
const _hoisted_12$1 = { class: "flex items-center gap-2" };
const _hoisted_13 = { class: "text-neutral-200" };
const _hoisted_14 = { class: "pi pi-info-circle" };
const _sfc_main$4 = /* @__PURE__ */ defineComponent({
__name: "InstallLocationPicker",
props: {
"installPath": { required: true },
"installPathModifiers": {},
"pathError": { required: true },
"pathErrorModifiers": {}
},
emits: ["update:installPath", "update:pathError"],
setup(__props) {
const { t: t2 } = useI18n();
const installPath = useModel(__props, "installPath");
const pathError = useModel(__props, "pathError");
const pathExists = ref(false);
const appData = ref("");
const appPath = ref("");
const inputTouched = ref(false);
const electron = electronAPI();
onMounted(async () => {
const paths = await electron.getSystemPaths();
appData.value = paths.appData;
appPath.value = paths.appPath;
installPath.value = paths.defaultInstallPath;
await validatePath(paths.defaultInstallPath);
});
const validatePath = /* @__PURE__ */ __name(async (path) => {
try {
pathError.value = "";
pathExists.value = false;
const validation = await electron.validateInstallPath(path);
if (!validation.isValid) {
const errors = [];
if (validation.cannotWrite) errors.push(t2("install.cannotWrite"));
if (validation.freeSpace < validation.requiredSpace) {
const requiredGB = validation.requiredSpace / 1024 / 1024 / 1024;
errors.push(`${t2("install.insufficientFreeSpace")}: ${requiredGB} GB`);
}
if (validation.parentMissing) errors.push(t2("install.parentMissing"));
if (validation.error)
errors.push(`${t2("install.unhandledError")}: ${validation.error}`);
pathError.value = errors.join("\n");
}
if (validation.exists) pathExists.value = true;
} catch (error) {
pathError.value = t2("install.pathValidationFailed");
}
}, "validatePath");
const browsePath = /* @__PURE__ */ __name(async () => {
try {
const result = await electron.showDirectoryPicker();
if (result) {
installPath.value = result;
await validatePath(result);
}
} catch (error) {
pathError.value = t2("install.failedToSelectDirectory");
}
}, "browsePath");
const onFocus = /* @__PURE__ */ __name(() => {
if (!inputTouched.value) {
inputTouched.value = true;
return;
}
validatePath(installPath.value);
}, "onFocus");
return (_ctx, _cache) => {
const _directive_tooltip = resolveDirective("tooltip");
return openBlock(), createElementBlock("div", _hoisted_1$3, [
createBaseVNode("div", _hoisted_2$3, [
createBaseVNode("h2", _hoisted_3$3, toDisplayString(_ctx.$t("install.chooseInstallationLocation")), 1),
createBaseVNode("p", _hoisted_4$3, toDisplayString(_ctx.$t("install.installLocationDescription")), 1),
createBaseVNode("div", _hoisted_5$1, [
createVNode(unref(script$6), { class: "flex-1" }, {
default: withCtx(() => [
createVNode(unref(script$4), {
modelValue: installPath.value,
"onUpdate:modelValue": [
_cache[0] || (_cache[0] = ($event) => installPath.value = $event),
validatePath
],
class: normalizeClass(["w-full", { "p-invalid": pathError.value }]),
onFocus
}, null, 8, ["modelValue", "class"]),
withDirectives(createVNode(unref(script$5), { class: "pi pi-info-circle" }, null, 512), [
[
_directive_tooltip,
_ctx.$t("install.installLocationTooltip"),
void 0,
{ top: true }
]
])
]),
_: 1
}),
createVNode(unref(script$7), {
icon: "pi pi-folder",
onClick: browsePath,
class: "w-12"
})
]),
pathError.value ? (openBlock(), createBlock(unref(script$8), {
key: 0,
severity: "error",
class: "whitespace-pre-line"
}, {
default: withCtx(() => [
createTextVNode(toDisplayString(pathError.value), 1)
]),
_: 1
})) : createCommentVNode("", true),
pathExists.value ? (openBlock(), createBlock(unref(script$8), {
key: 1,
severity: "warn"
}, {
default: withCtx(() => [
createTextVNode(toDisplayString(_ctx.$t("install.pathExists")), 1)
]),
_: 1
})) : createCommentVNode("", true)
]),
createBaseVNode("div", _hoisted_6$1, [
createBaseVNode("h3", _hoisted_7$1, toDisplayString(_ctx.$t("install.systemLocations")), 1),
createBaseVNode("div", _hoisted_8$1, [
createBaseVNode("div", _hoisted_9$1, [
_cache[1] || (_cache[1] = createBaseVNode("i", { class: "pi pi-folder text-neutral-400" }, null, -1)),
_cache[2] || (_cache[2] = createBaseVNode("span", { class: "text-neutral-400" }, "App Data:", -1)),
createBaseVNode("span", _hoisted_10$1, toDisplayString(appData.value), 1),
withDirectives(createBaseVNode("span", _hoisted_11$1, null, 512), [
[_directive_tooltip, _ctx.$t("install.appDataLocationTooltip")]
])
]),
createBaseVNode("div", _hoisted_12$1, [
_cache[3] || (_cache[3] = createBaseVNode("i", { class: "pi pi-desktop text-neutral-400" }, null, -1)),
_cache[4] || (_cache[4] = createBaseVNode("span", { class: "text-neutral-400" }, "App Path:", -1)),
createBaseVNode("span", _hoisted_13, toDisplayString(appPath.value), 1),
withDirectives(createBaseVNode("span", _hoisted_14, null, 512), [
[_directive_tooltip, _ctx.$t("install.appPathLocationTooltip")]
])
])
])
])
]);
};
}
});
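// MigrationPicker: optional migration step; validates an existing ComfyUI source path and keeps migrationItemIds in sync with the checked items.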
const _hoisted_1$2 = { class: "flex flex-col gap-6 w-[600px]" };
const _hoisted_2$2 = { class: "flex flex-col gap-4" };
const _hoisted_3$2 = { class: "text-2xl font-semibold text-neutral-100" };
const _hoisted_4$2 = { class: "text-neutral-400 my-0" };
const _hoisted_5 = { class: "flex gap-2" };
const _hoisted_6 = {
key: 0,
class: "flex flex-col gap-4 bg-neutral-800 p-4 rounded-lg"
};
const _hoisted_7 = { class: "text-lg mt-0 font-medium text-neutral-100" };
const _hoisted_8 = { class: "flex flex-col gap-3" };
const _hoisted_9 = ["onClick"];
const _hoisted_10 = ["for"];
const _hoisted_11 = { class: "text-sm text-neutral-400 my-1" };
const _hoisted_12 = {
key: 1,
class: "text-neutral-400 italic"
};
const _sfc_main$3 = /* @__PURE__ */ defineComponent({
__name: "MigrationPicker",
props: {
"sourcePath": { required: false },
"sourcePathModifiers": {},
"migrationItemIds": {
required: false
},
"migrationItemIdsModifiers": {}
},
emits: ["update:sourcePath", "update:migrationItemIds"],
setup(__props) {
const { t: t2 } = useI18n();
const electron = electronAPI();
const sourcePath = useModel(__props, "sourcePath");
const migrationItemIds = useModel(__props, "migrationItemIds");
const migrationItems = ref(
MigrationItems.map((item) => ({
...item,
selected: true
}))
);
const pathError = ref("");
const isValidSource = computed(
() => sourcePath.value !== "" && pathError.value === ""
);
const validateSource = /* @__PURE__ */ __name(async (sourcePath2) => {
if (!sourcePath2) {
pathError.value = "";
return;
}
try {
pathError.value = "";
const validation = await electron.validateComfyUISource(sourcePath2);
if (!validation.isValid) pathError.value = validation.error;
} catch (error) {
console.error(error);
pathError.value = t2("install.pathValidationFailed");
}
}, "validateSource");
const browsePath = /* @__PURE__ */ __name(async () => {
try {
const result = await electron.showDirectoryPicker();
if (result) {
sourcePath.value = result;
await validateSource(result);
}
} catch (error) {
console.error(error);
pathError.value = t2("install.failedToSelectDirectory");
}
}, "browsePath");
watchEffect(() => {
migrationItemIds.value = migrationItems.value.filter((item) => item.selected).map((item) => item.id);
});
return (_ctx, _cache) => {
return openBlock(), createElementBlock("div", _hoisted_1$2, [
createBaseVNode("div", _hoisted_2$2, [
createBaseVNode("h2", _hoisted_3$2, toDisplayString(_ctx.$t("install.migrateFromExistingInstallation")), 1),
createBaseVNode("p", _hoisted_4$2, toDisplayString(_ctx.$t("install.migrationSourcePathDescription")), 1),
createBaseVNode("div", _hoisted_5, [
createVNode(unref(script$4), {
modelValue: sourcePath.value,
"onUpdate:modelValue": [
_cache[0] || (_cache[0] = ($event) => sourcePath.value = $event),
validateSource
],
placeholder: "Select existing ComfyUI installation (optional)",
class: normalizeClass(["flex-1", { "p-invalid": pathError.value }])
}, null, 8, ["modelValue", "class"]),
createVNode(unref(script$7), {
icon: "pi pi-folder",
onClick: browsePath,
class: "w-12"
})
]),
pathError.value ? (openBlock(), createBlock(unref(script$8), {
key: 0,
severity: "error"
}, {
default: withCtx(() => [
createTextVNode(toDisplayString(pathError.value), 1)
]),
_: 1
})) : createCommentVNode("", true)
]),
isValidSource.value ? (openBlock(), createElementBlock("div", _hoisted_6, [
createBaseVNode("h3", _hoisted_7, toDisplayString(_ctx.$t("install.selectItemsToMigrate")), 1),
createBaseVNode("div", _hoisted_8, [
(openBlock(true), createElementBlock(Fragment, null, renderList(migrationItems.value, (item) => {
return openBlock(), createElementBlock("div", {
key: item.id,
class: "flex items-center gap-3 p-2 hover:bg-neutral-700 rounded",
onClick: /* @__PURE__ */ __name(($event) => item.selected = !item.selected, "onClick")
}, [
createVNode(unref(script$9), {
modelValue: item.selected,
"onUpdate:modelValue": /* @__PURE__ */ __name(($event) => item.selected = $event, "onUpdate:modelValue"),
inputId: item.id,
binary: true,
onClick: _cache[1] || (_cache[1] = withModifiers(() => {
}, ["stop"]))
}, null, 8, ["modelValue", "onUpdate:modelValue", "inputId"]),
createBaseVNode("div", null, [
createBaseVNode("label", {
for: item.id,
class: "text-neutral-200 font-medium"
}, toDisplayString(item.label), 9, _hoisted_10),
createBaseVNode("p", _hoisted_11, toDisplayString(item.description), 1)
])
], 8, _hoisted_9);
}), 128))
])
])) : (openBlock(), createElementBlock("div", _hoisted_12, toDisplayString(_ctx.$t("install.migrationOptional")), 1))
]);
};
}
});
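// MirrorItem: one download-mirror setting; checks reachability via checkMirrorReachable and falls back to fallbackMirror when validation fails.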
const _hoisted_1$1 = { class: "flex flex-col items-center gap-4" };
const _hoisted_2$1 = { class: "w-full" };
const _hoisted_3$1 = { class: "text-lg font-medium text-neutral-100" };
const _hoisted_4$1 = { class: "text-sm text-neutral-400 mt-1" };
const _sfc_main$2 = /* @__PURE__ */ defineComponent({
__name: "MirrorItem",
props: /* @__PURE__ */ mergeModels({
item: {}
}, {
"modelValue": { required: true },
"modelModifiers": {}
}),
emits: /* @__PURE__ */ mergeModels(["state-change"], ["update:modelValue"]),
setup(__props, { emit: __emit }) {
const emit = __emit;
const modelValue = useModel(__props, "modelValue");
const validationState = ref(ValidationState.IDLE);
const normalizedSettingId = computed(() => {
return normalizeI18nKey(__props.item.settingId);
});
onMounted(() => {
modelValue.value = __props.item.mirror;
});
watch(validationState, (newState) => {
emit("state-change", newState);
if (newState === ValidationState.INVALID && modelValue.value === __props.item.mirror) {
modelValue.value = __props.item.fallbackMirror;
}
});
return (_ctx, _cache) => {
return openBlock(), createElementBlock("div", _hoisted_1$1, [
createBaseVNode("div", _hoisted_2$1, [
createBaseVNode("h3", _hoisted_3$1, toDisplayString(_ctx.$t(`settings.${normalizedSettingId.value}.name`)), 1),
createBaseVNode("p", _hoisted_4$1, toDisplayString(_ctx.$t(`settings.${normalizedSettingId.value}.tooltip`)), 1)
]),
createVNode(_sfc_main$7, {
modelValue: modelValue.value,
"onUpdate:modelValue": _cache[0] || (_cache[0] = ($event) => modelValue.value = $event),
"validate-url-fn": /* @__PURE__ */ __name((mirror) => unref(checkMirrorReachable)(mirror + (_ctx.item.validationPathSuffix ?? "")), "validate-url-fn"),
onStateChange: _cache[1] || (_cache[1] = ($event) => validationState.value = $event)
}, null, 8, ["modelValue", "validate-url-fn"])
]);
};
}
});
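// MirrorsConfiguration: toggleable panel over the Python, PyPI, and torch mirrors; uses fallback mirrors when isInChina() resolves true and merges the per-item validation states.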
const _sfc_main$1 = /* @__PURE__ */ defineComponent({
__name: "MirrorsConfiguration",
props: /* @__PURE__ */ mergeModels({
device: {}
}, {
"pythonMirror": { required: true },
"pythonMirrorModifiers": {},
"pypiMirror": { required: true },
"pypiMirrorModifiers": {},
"torchMirror": { required: true },
"torchMirrorModifiers": {}
}),
emits: ["update:pythonMirror", "update:pypiMirror", "update:torchMirror"],
setup(__props) {
const showMirrorInputs = ref(false);
const pythonMirror = useModel(__props, "pythonMirror");
const pypiMirror = useModel(__props, "pypiMirror");
const torchMirror = useModel(__props, "torchMirror");
const getTorchMirrorItem = /* @__PURE__ */ __name((device) => {
const settingId = "Comfy-Desktop.UV.TorchInstallMirror";
switch (device) {
case "mps":
return {
settingId,
mirror: NIGHTLY_CPU_TORCH_URL,
fallbackMirror: NIGHTLY_CPU_TORCH_URL
};
case "nvidia":
return {
settingId,
mirror: CUDA_TORCH_URL,
fallbackMirror: CUDA_TORCH_URL
};
case "cpu":
default:
return {
settingId,
mirror: PYPI_MIRROR.mirror,
fallbackMirror: PYPI_MIRROR.fallbackMirror
};
}
}, "getTorchMirrorItem");
const userIsInChina = ref(false);
onMounted(async () => {
userIsInChina.value = await isInChina();
});
const useFallbackMirror = /* @__PURE__ */ __name((mirror) => ({
...mirror,
mirror: mirror.fallbackMirror
}), "useFallbackMirror");
const mirrors = computed(
() => [
[PYTHON_MIRROR, pythonMirror],
[PYPI_MIRROR, pypiMirror],
[getTorchMirrorItem(__props.device), torchMirror]
].map(([item, modelValue]) => [
userIsInChina.value ? useFallbackMirror(item) : item,
modelValue
])
);
const validationStates = ref(
mirrors.value.map(() => ValidationState.IDLE)
);
const validationState = computed(() => {
return mergeValidationStates(validationStates.value);
});
const validationStateTooltip = computed(() => {
switch (validationState.value) {
case ValidationState.INVALID:
return t("install.settings.mirrorsUnreachable");
case ValidationState.VALID:
return t("install.settings.mirrorsReachable");
default:
return t("install.settings.checkingMirrors");
}
});
return (_ctx, _cache) => {
const _directive_tooltip = resolveDirective("tooltip");
return openBlock(), createBlock(unref(script$a), {
header: _ctx.$t("install.settings.mirrorSettings"),
toggleable: "",
collapsed: !showMirrorInputs.value,
"pt:root": "bg-neutral-800 border-none w-[600px]"
}, {
icons: withCtx(() => [
withDirectives(createBaseVNode("i", {
class: normalizeClass({
"pi pi-spin pi-spinner text-neutral-400": validationState.value === unref(ValidationState).LOADING,
"pi pi-check text-green-500": validationState.value === unref(ValidationState).VALID,
"pi pi-times text-red-500": validationState.value === unref(ValidationState).INVALID
})
}, null, 2), [
[_directive_tooltip, validationStateTooltip.value]
])
]),
default: withCtx(() => [
(openBlock(true), createElementBlock(Fragment, null, renderList(mirrors.value, ([item, modelValue], index) => {
return openBlock(), createElementBlock(Fragment, {
key: item.settingId + item.mirror
}, [
index > 0 ? (openBlock(), createBlock(unref(script$1), { key: 0 })) : createCommentVNode("", true),
createVNode(_sfc_main$2, {
item,
modelValue: modelValue.value,
"onUpdate:modelValue": /* @__PURE__ */ __name(($event) => modelValue.value = $event, "onUpdate:modelValue"),
onStateChange: /* @__PURE__ */ __name(($event) => validationStates.value[index] = $event, "onStateChange")
}, null, 8, ["item", "modelValue", "onUpdate:modelValue", "onStateChange"])
], 64);
}), 128))
]),
_: 1
}, 8, ["header", "collapsed"]);
};
}
});
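// InstallView: four-step installer stepper (GPU, install location, migration, desktop settings); install() forwards the collected options to electron.installComfyUI and routes to the next page.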
const _hoisted_1 = { class: "flex pt-6 justify-end" };
const _hoisted_2 = { class: "flex pt-6 justify-between" };
const _hoisted_3 = { class: "flex pt-6 justify-between" };
const _hoisted_4 = { class: "flex mt-6 justify-between" };
const _sfc_main = /* @__PURE__ */ defineComponent({
__name: "InstallView",
setup(__props) {
const device = ref(null);
const installPath = ref("");
const pathError = ref("");
const migrationSourcePath = ref("");
const migrationItemIds = ref([]);
const autoUpdate = ref(true);
const allowMetrics = ref(true);
const pythonMirror = ref("");
const pypiMirror = ref("");
const torchMirror = ref("");
const highestStep = ref(0);
const handleStepChange = /* @__PURE__ */ __name((value) => {
setHighestStep(value);
electronAPI().Events.trackEvent("install_stepper_change", {
step: value
});
}, "handleStepChange");
const setHighestStep = /* @__PURE__ */ __name((value) => {
const int = typeof value === "number" ? value : parseInt(value, 10);
if (!isNaN(int) && int > highestStep.value) highestStep.value = int;
}, "setHighestStep");
const hasError = computed(() => pathError.value !== "");
const noGpu = computed(() => typeof device.value !== "string");
const electron = electronAPI();
const router = useRouter();
const install = /* @__PURE__ */ __name(() => {
const options = {
installPath: installPath.value,
autoUpdate: autoUpdate.value,
allowMetrics: allowMetrics.value,
migrationSourcePath: migrationSourcePath.value,
migrationItemIds: toRaw(migrationItemIds.value),
pythonMirror: pythonMirror.value,
pypiMirror: pypiMirror.value,
torchMirror: torchMirror.value,
device: device.value
};
electron.installComfyUI(options);
const nextPage = options.device === "unsupported" ? "/manual-configuration" : "/server-start";
router.push(nextPage);
}, "install");
onMounted(async () => {
if (!electron) return;
const detectedGpu = await electron.Config.getDetectedGpu();
if (detectedGpu === "mps" || detectedGpu === "nvidia") {
device.value = detectedGpu;
}
electronAPI().Events.trackEvent("install_stepper_change", {
step: "0",
gpu: detectedGpu
});
});
return (_ctx, _cache) => {
return openBlock(), createBlock(_sfc_main$8, { dark: "" }, {
default: withCtx(() => [
createVNode(unref(script$f), {
class: "h-full p-8 2xl:p-16",
value: "0",
"onUpdate:value": handleStepChange
}, {
default: withCtx(() => [
createVNode(unref(script$b), { class: "select-none" }, {
default: withCtx(() => [
createVNode(unref(script$c), { value: "0" }, {
default: withCtx(() => [
createTextVNode(toDisplayString(_ctx.$t("install.gpu")), 1)
]),
_: 1
}),
createVNode(unref(script$c), {
value: "1",
disabled: noGpu.value
}, {
default: withCtx(() => [
createTextVNode(toDisplayString(_ctx.$t("install.installLocation")), 1)
]),
_: 1
}, 8, ["disabled"]),
createVNode(unref(script$c), {
value: "2",
disabled: noGpu.value || hasError.value || highestStep.value < 1
}, {
default: withCtx(() => [
createTextVNode(toDisplayString(_ctx.$t("install.migration")), 1)
]),
_: 1
}, 8, ["disabled"]),
createVNode(unref(script$c), {
value: "3",
disabled: noGpu.value || hasError.value || highestStep.value < 2
}, {
default: withCtx(() => [
createTextVNode(toDisplayString(_ctx.$t("install.desktopSettings")), 1)
]),
_: 1
}, 8, ["disabled"])
]),
_: 1
}),
createVNode(unref(script$d), null, {
default: withCtx(() => [
createVNode(unref(script$e), { value: "0" }, {
default: withCtx(({ activateCallback }) => [
createVNode(GpuPicker, {
device: device.value,
"onUpdate:device": _cache[0] || (_cache[0] = ($event) => device.value = $event)
}, null, 8, ["device"]),
createBaseVNode("div", _hoisted_1, [
createVNode(unref(script$7), {
label: _ctx.$t("g.next"),
icon: "pi pi-arrow-right",
iconPos: "right",
onClick: /* @__PURE__ */ __name(($event) => activateCallback("1"), "onClick"),
disabled: typeof device.value !== "string"
}, null, 8, ["label", "onClick", "disabled"])
])
]),
_: 1
}),
createVNode(unref(script$e), { value: "1" }, {
default: withCtx(({ activateCallback }) => [
createVNode(_sfc_main$4, {
installPath: installPath.value,
"onUpdate:installPath": _cache[1] || (_cache[1] = ($event) => installPath.value = $event),
pathError: pathError.value,
"onUpdate:pathError": _cache[2] || (_cache[2] = ($event) => pathError.value = $event)
}, null, 8, ["installPath", "pathError"]),
createBaseVNode("div", _hoisted_2, [
createVNode(unref(script$7), {
label: _ctx.$t("g.back"),
severity: "secondary",
icon: "pi pi-arrow-left",
onClick: /* @__PURE__ */ __name(($event) => activateCallback("0"), "onClick")
}, null, 8, ["label", "onClick"]),
createVNode(unref(script$7), {
label: _ctx.$t("g.next"),
icon: "pi pi-arrow-right",
iconPos: "right",
onClick: /* @__PURE__ */ __name(($event) => activateCallback("2"), "onClick"),
disabled: pathError.value !== ""
}, null, 8, ["label", "onClick", "disabled"])
])
]),
_: 1
}),
createVNode(unref(script$e), { value: "2" }, {
default: withCtx(({ activateCallback }) => [
createVNode(_sfc_main$3, {
sourcePath: migrationSourcePath.value,
"onUpdate:sourcePath": _cache[3] || (_cache[3] = ($event) => migrationSourcePath.value = $event),
migrationItemIds: migrationItemIds.value,
"onUpdate:migrationItemIds": _cache[4] || (_cache[4] = ($event) => migrationItemIds.value = $event)
}, null, 8, ["sourcePath", "migrationItemIds"]),
createBaseVNode("div", _hoisted_3, [
createVNode(unref(script$7), {
label: _ctx.$t("g.back"),
severity: "secondary",
icon: "pi pi-arrow-left",
onClick: /* @__PURE__ */ __name(($event) => activateCallback("1"), "onClick")
}, null, 8, ["label", "onClick"]),
createVNode(unref(script$7), {
label: _ctx.$t("g.next"),
icon: "pi pi-arrow-right",
iconPos: "right",
onClick: /* @__PURE__ */ __name(($event) => activateCallback("3"), "onClick")
}, null, 8, ["label", "onClick"])
])
]),
_: 1
}),
createVNode(unref(script$e), { value: "3" }, {
default: withCtx(({ activateCallback }) => [
createVNode(_sfc_main$6, {
autoUpdate: autoUpdate.value,
"onUpdate:autoUpdate": _cache[5] || (_cache[5] = ($event) => autoUpdate.value = $event),
allowMetrics: allowMetrics.value,
"onUpdate:allowMetrics": _cache[6] || (_cache[6] = ($event) => allowMetrics.value = $event)
}, null, 8, ["autoUpdate", "allowMetrics"]),
createVNode(_sfc_main$1, {
device: device.value,
pythonMirror: pythonMirror.value,
"onUpdate:pythonMirror": _cache[7] || (_cache[7] = ($event) => pythonMirror.value = $event),
pypiMirror: pypiMirror.value,
"onUpdate:pypiMirror": _cache[8] || (_cache[8] = ($event) => pypiMirror.value = $event),
torchMirror: torchMirror.value,
"onUpdate:torchMirror": _cache[9] || (_cache[9] = ($event) => torchMirror.value = $event),
class: "mt-6"
}, null, 8, ["device", "pythonMirror", "pypiMirror", "torchMirror"]),
createBaseVNode("div", _hoisted_4, [
createVNode(unref(script$7), {
label: _ctx.$t("g.back"),
severity: "secondary",
icon: "pi pi-arrow-left",
onClick: /* @__PURE__ */ __name(($event) => activateCallback("2"), "onClick")
}, null, 8, ["label", "onClick"]),
createVNode(unref(script$7), {
label: _ctx.$t("g.install"),
icon: "pi pi-check",
iconPos: "right",
disabled: hasError.value,
onClick: _cache[10] || (_cache[10] = ($event) => install())
}, null, 8, ["label", "disabled"])
])
]),
_: 1
})
]),
_: 1
})
]),
_: 1
})
]),
_: 1
});
};
}
});
const InstallView = /* @__PURE__ */ _export_sfc(_sfc_main, [["__scopeId", "data-v-cd6731d2"]]);
export {
InstallView as default
};
//# sourceMappingURL=InstallView-DW9xwU_F.js.map
32 web/assets/InstallView-CwQdoH-C.css → web/assets/InstallView-DbJ2cGfL.css generated vendored
@@ -1,18 +1,20 @@

:root {
.p-tag[data-v-79125ff6] {
--p-tag-gap: 0.5rem;
}
.hover-brighten {
&[data-v-79125ff6] {
transition-property: color, background-color, border-color, text-decoration-color, fill, stroke;
transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1);
transition-duration: 150ms;
transition-property: filter, box-shadow;
&:hover {
}
&[data-v-79125ff6]:hover {
filter: brightness(107%) contrast(105%);
box-shadow: 0 0 0.25rem #ffffff79;
}
}
.p-accordioncontent-content {
.p-accordioncontent-content[data-v-79125ff6] {
border-radius: 0.5rem;
--tw-bg-opacity: 1;
background-color: rgb(23 23 23 / var(--tw-bg-opacity));
@@ -21,14 +23,14 @@
transition-duration: 150ms;
}
div.selected {
.gpu-button:not(.selected) {
.gpu-button[data-v-79125ff6]:not(.selected) {
opacity: 0.5;
}
.gpu-button:not(.selected):hover {
.gpu-button[data-v-79125ff6]:not(.selected):hover {
opacity: 1;
}
}
.gpu-button {
.gpu-button[data-v-79125ff6] {
margin: 0px;
display: flex;
width: 50%;
@@ -43,37 +45,37 @@ div.selected {
transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1);
transition-duration: 150ms;
}
.gpu-button:hover {
.gpu-button[data-v-79125ff6]:hover {
--tw-bg-opacity: 0.75;
}
.gpu-button {
&.selected {
&.selected[data-v-79125ff6] {
--tw-bg-opacity: 1;
background-color: rgb(64 64 64 / var(--tw-bg-opacity));
}
&.selected {
&.selected[data-v-79125ff6] {
--tw-bg-opacity: 0.5;
}
&.selected {
&.selected[data-v-79125ff6] {
opacity: 1;
}
&.selected:hover {
&.selected[data-v-79125ff6]:hover {
--tw-bg-opacity: 0.6;
}
}
.disabled {
.disabled[data-v-79125ff6] {
pointer-events: none;
opacity: 0.4;
}
.p-card-header {
.p-card-header[data-v-79125ff6] {
flex-grow: 1;
text-align: center;
}
.p-card-body {
.p-card-body[data-v-79125ff6] {
padding-top: 0px;
text-align: center;
}

[data-v-de33872d] .p-steppanel {
[data-v-cd6731d2] .p-steppanel {
background-color: transparent
}
8 web/assets/KeybindingPanel-CDYVPYDp.css generated vendored Normal file
@@ -0,0 +1,8 @@

[data-v-8454e24f] .p-datatable-tbody > tr > td {
padding: 0.25rem;
min-height: 2rem
}
[data-v-8454e24f] .p-datatable-row-selected .actions,[data-v-8454e24f] .p-datatable-selectable-row:hover .actions {
visibility: visible
}
8 web/assets/KeybindingPanel-DvrUYZ4S.css generated vendored
@@ -1,8 +0,0 @@

[data-v-2554ab36] .p-datatable-tbody > tr > td {
padding: 0.25rem;
min-height: 2rem
}
[data-v-2554ab36] .p-datatable-row-selected .actions,[data-v-2554ab36] .p-datatable-selectable-row:hover .actions {
visibility: visible
}
30 web/assets/KeybindingPanel-Dc3C4lG1.js → web/assets/KeybindingPanel-oavhFdkz.js generated vendored
@@ -1,10 +1,9 @@
var __defProp = Object.defineProperty;
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
import { d as defineComponent, c as computed, o as openBlock, f as createElementBlock, F as Fragment, E as renderList, N as createVNode, M as withCtx, aE as createTextVNode, X as toDisplayString, j as unref, aI as script, I as createCommentVNode, ab as ref, cn as FilterMatchMode, a$ as useKeybindingStore, a2 as useCommandStore, a1 as useI18n, af as normalizeI18nKey, w as watchEffect, bs as useToast, r as resolveDirective, k as createBlock, co as SearchBox, H as createBaseVNode, l as script$2, av as script$4, bM as withModifiers, bZ as script$5, aP as script$6, i as withDirectives, cp as _sfc_main$2, aL as pushScopeId, aM as popScopeId, cq as KeyComboImpl, cr as KeybindingImpl, _ as _export_sfc } from "./index-DjNHn37O.js";
import { s as script$1, a as script$3 } from "./index-B5F0uxTQ.js";
import { u as useKeybindingService } from "./keybindingService-Bx7YdkXn.js";
import "./index-B-aVupP5.js";
import "./index-5HFeZax4.js";
import { d as defineComponent, c as computed, o as openBlock, f as createElementBlock, F as Fragment, D as renderList, k as createVNode, z as withCtx, a8 as createTextVNode, E as toDisplayString, j as unref, a5 as script, B as createCommentVNode, T as ref, dx as FilterMatchMode, ao as useKeybindingStore, J as useCommandStore, I as useI18n, X as normalizeI18nKey, w as watchEffect, aV as useToast, r as resolveDirective, y as createBlock, dy as SearchBox, m as createBaseVNode, l as script$2, bk as script$4, as as withModifiers, bn as script$5, ac as script$6, i as withDirectives, dz as _sfc_main$2, dA as KeyComboImpl, dB as KeybindingImpl, _ as _export_sfc } from "./index-Bv0b06LE.js";
import { g as script$1, h as script$3 } from "./index-CgMyWf7n.js";
import { u as useKeybindingService } from "./keybindingService-DyjX-nxF.js";
import "./index-Dzu9WL4p.js";
const _hoisted_1$1 = {
key: 0,
class: "px-2"
@@ -37,7 +36,6 @@ const _sfc_main$1 = /* @__PURE__ */ defineComponent({
};
}
});
const _withScopeId = /* @__PURE__ */ __name((n) => (pushScopeId("data-v-2554ab36"), n = n(), popScopeId(), n), "_withScopeId");
const _hoisted_1 = { class: "actions invisible flex flex-row" };
const _hoisted_2 = ["title"];
const _hoisted_3 = { key: 1 };
@@ -98,6 +96,16 @@ const _sfc_main = /* @__PURE__ */ defineComponent({
}
__name(removeKeybinding, "removeKeybinding");
function captureKeybinding(event) {
if (!event.shiftKey && !event.altKey && !event.ctrlKey && !event.metaKey) {
switch (event.key) {
case "Escape":
cancelEdit();
return;
case "Enter":
saveKeybinding();
return;
}
}
const keyCombo = KeyComboImpl.fromEvent(event);
newBindingKeyCombo.value = keyCombo;
}
@@ -153,7 +161,7 @@ const _sfc_main = /* @__PURE__ */ defineComponent({
value: commandsData.value,
selection: selectedCommandData.value,
"onUpdate:selection": _cache[1] || (_cache[1] = ($event) => selectedCommandData.value = $event),
"global-filter-fields": ["id"],
"global-filter-fields": ["id", "label"],
filters: filters.value,
selectionMode: "single",
stripedRows: "",
@@ -218,7 +226,7 @@ const _sfc_main = /* @__PURE__ */ defineComponent({
visible: editDialogVisible.value,
"onUpdate:visible": _cache[2] || (_cache[2] = ($event) => editDialogVisible.value = $event),
modal: "",
header: currentEditingCommand.value?.id,
header: currentEditingCommand.value?.label,
onHide: cancelEdit
}, {
footer: withCtx(() => [
@@ -248,7 +256,7 @@ const _sfc_main = /* @__PURE__ */ defineComponent({
severity: "error"
}, {
default: withCtx(() => [
createTextVNode(" Keybinding already exists on "),
_cache[3] || (_cache[3] = createTextVNode(" Keybinding already exists on ")),
createVNode(unref(script), {
severity: "secondary",
value: existingKeybindingOnCombo.value.commandId
@@ -277,8 +285,8 @@ const _sfc_main = /* @__PURE__ */ defineComponent({
};
}
});
const KeybindingPanel = /* @__PURE__ */ _export_sfc(_sfc_main, [["__scopeId", "data-v-2554ab36"]]);
const KeybindingPanel = /* @__PURE__ */ _export_sfc(_sfc_main, [["__scopeId", "data-v-8454e24f"]]);
export {
KeybindingPanel as default
};
//# sourceMappingURL=KeybindingPanel-Dc3C4lG1.js.map
//# sourceMappingURL=KeybindingPanel-oavhFdkz.js.map
25635 web/assets/MaintenanceView-Bh8OZpgl.js generated vendored Normal file
File diff suppressed because one or more lines are too long
87 web/assets/MaintenanceView-DEJCj8SR.css generated vendored Normal file
@@ -0,0 +1,87 @@

.task-card-ok[data-v-c3bd7658] {

position: absolute;

right: -1rem;

bottom: -1rem;

grid-column: 1 / -1;

grid-row: 1 / -1;

--tw-text-opacity: 1;

color: rgb(150 206 76 / var(--tw-text-opacity));

opacity: 1;

transition-property: opacity;

transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1);

transition-duration: 150ms;

font-size: 4rem;
text-shadow: 0.25rem 0 0.5rem black;
z-index: 10;
}
.p-card {
&[data-v-c3bd7658] {

transition-property: opacity;

transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1);

transition-duration: 150ms;

--p-card-background: var(--p-button-secondary-background);
opacity: 0.9;
}
&.opacity-65[data-v-c3bd7658] {
opacity: 0.4;
}
&[data-v-c3bd7658]:hover {
opacity: 1;
}
}
[data-v-c3bd7658] .p-card-header {
z-index: 0;
}
[data-v-c3bd7658] .p-card-body {
z-index: 1;
flex-grow: 1;
justify-content: space-between;
}
.task-div {
> i[data-v-c3bd7658] {
pointer-events: none;
}
&:hover > i[data-v-c3bd7658] {
opacity: 0.2;
}
}

[data-v-dd50a7dd] .p-tag {
--p-tag-gap: 0.375rem;
}
.backspan[data-v-dd50a7dd]::before {
position: absolute;
margin: 0px;
color: var(--p-text-muted-color);
font-family: 'primeicons';
top: -2rem;
right: -2rem;
speak: none;
font-style: normal;
font-weight: normal;
font-variant: normal;
text-transform: none;
line-height: 1;
display: inline-block;
-webkit-font-smoothing: antialiased;
opacity: 0.02;
font-size: min(14rem, 90vw);
z-index: 0;
}
@@ -1,7 +1,7 @@

:root {
.p-tag[data-v-dc169863] {
--p-tag-gap: 0.5rem;
}
.comfy-installer {
.comfy-installer[data-v-dc169863] {
margin-top: max(1rem, max(0px, calc((100vh - 42rem) * 0.5)));
}
@@ -1,9 +1,7 @@
var __defProp = Object.defineProperty;
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
import { d as defineComponent, a1 as useI18n, ab as ref, m as onMounted, o as openBlock, k as createBlock, M as withCtx, H as createBaseVNode, X as toDisplayString, N as createVNode, j as unref, aI as script, l as script$2, c0 as electronAPI } from "./index-DjNHn37O.js";
import { s as script$1 } from "./index-jXPKy3pP.js";
import { _ as _sfc_main$1 } from "./BaseViewTemplate-BNGF4K22.js";
import "./index-5HFeZax4.js";
import { d as defineComponent, I as useI18n, T as ref, p as onMounted, o as openBlock, y as createBlock, z as withCtx, m as createBaseVNode, E as toDisplayString, k as createVNode, j as unref, a5 as script, b3 as script$1, l as script$2, b9 as electronAPI, _ as _export_sfc } from "./index-Bv0b06LE.js";
import { _ as _sfc_main$1 } from "./BaseViewTemplate-BTbuZf5t.js";
const _hoisted_1 = { class: "comfy-installer grow flex flex-col gap-4 text-neutral-300 max-w-110" };
const _hoisted_2 = { class: "text-2xl font-semibold text-neutral-100" };
const _hoisted_3 = { class: "m-1 text-neutral-300" };
@@ -69,7 +67,8 @@ const _sfc_main = /* @__PURE__ */ defineComponent({
};
}
});
const ManualConfigurationView = /* @__PURE__ */ _export_sfc(_sfc_main, [["__scopeId", "data-v-dc169863"]]);
export {
_sfc_main as default
ManualConfigurationView as default
};
//# sourceMappingURL=ManualConfigurationView-Bi_qHE-n.js.map
//# sourceMappingURL=ManualConfigurationView-DTLyJ3VG.js.map
86 web/assets/MetricsConsentView-C80fk2cl.js generated vendored Normal file
@@ -0,0 +1,86 @@
var __defProp = Object.defineProperty;
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
import { _ as _sfc_main$1 } from "./BaseViewTemplate-BTbuZf5t.js";
import { d as defineComponent, aV as useToast, I as useI18n, T as ref, bi as useRouter, o as openBlock, y as createBlock, z as withCtx, m as createBaseVNode, E as toDisplayString, a8 as createTextVNode, k as createVNode, j as unref, br as script, l as script$1, b9 as electronAPI } from "./index-Bv0b06LE.js";
const _hoisted_1 = { class: "h-full p-8 2xl:p-16 flex flex-col items-center justify-center" };
const _hoisted_2 = { class: "bg-neutral-800 rounded-lg shadow-lg p-6 w-full max-w-[600px] flex flex-col gap-6" };
const _hoisted_3 = { class: "text-3xl font-semibold text-neutral-100" };
const _hoisted_4 = { class: "text-neutral-400" };
const _hoisted_5 = { class: "text-neutral-400" };
const _hoisted_6 = {
href: "https://comfy.org/privacy",
target: "_blank",
class: "text-blue-400 hover:text-blue-300 underline"
};
const _hoisted_7 = { class: "flex items-center gap-4" };
const _hoisted_8 = {
id: "metricsDescription",
class: "text-neutral-100"
};
const _hoisted_9 = { class: "flex pt-6 justify-end" };
const _sfc_main = /* @__PURE__ */ defineComponent({
__name: "MetricsConsentView",
setup(__props) {
const toast = useToast();
const { t } = useI18n();
const allowMetrics = ref(true);
const router = useRouter();
const isUpdating = ref(false);
const updateConsent = /* @__PURE__ */ __name(async () => {
isUpdating.value = true;
try {
await electronAPI().setMetricsConsent(allowMetrics.value);
} catch (error) {
toast.add({
severity: "error",
summary: t("install.errorUpdatingConsent"),
detail: t("install.errorUpdatingConsentDetail"),
life: 3e3
});
} finally {
isUpdating.value = false;
}
router.push("/");
}, "updateConsent");
return (_ctx, _cache) => {
const _component_BaseViewTemplate = _sfc_main$1;
return openBlock(), createBlock(_component_BaseViewTemplate, { dark: "" }, {
default: withCtx(() => [
createBaseVNode("div", _hoisted_1, [
createBaseVNode("div", _hoisted_2, [
createBaseVNode("h2", _hoisted_3, toDisplayString(_ctx.$t("install.helpImprove")), 1),
createBaseVNode("p", _hoisted_4, toDisplayString(_ctx.$t("install.updateConsent")), 1),
createBaseVNode("p", _hoisted_5, [
createTextVNode(toDisplayString(_ctx.$t("install.moreInfo")) + " ", 1),
createBaseVNode("a", _hoisted_6, toDisplayString(_ctx.$t("install.privacyPolicy")), 1),
_cache[1] || (_cache[1] = createTextVNode(". "))
]),
createBaseVNode("div", _hoisted_7, [
createVNode(unref(script), {
modelValue: allowMetrics.value,
"onUpdate:modelValue": _cache[0] || (_cache[0] = ($event) => allowMetrics.value = $event),
"aria-describedby": "metricsDescription"
}, null, 8, ["modelValue"]),
createBaseVNode("span", _hoisted_8, toDisplayString(allowMetrics.value ? _ctx.$t("install.metricsEnabled") : _ctx.$t("install.metricsDisabled")), 1)
]),
createBaseVNode("div", _hoisted_9, [
createVNode(unref(script$1), {
label: _ctx.$t("g.ok"),
icon: "pi pi-check",
loading: isUpdating.value,
iconPos: "right",
onClick: updateConsent
}, null, 8, ["label", "loading"])
])
])
])
]),
_: 1
});
};
}
});
export {
_sfc_main as default
};
//# sourceMappingURL=MetricsConsentView-C80fk2cl.js.map
48 web/assets/NotSupportedView-Drz3x2d-.js → web/assets/NotSupportedView-B78ZVR9Z.js generated vendored
@@ -1,21 +1,16 @@
var __defProp = Object.defineProperty;
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
import { d as defineComponent, bW as useRouter, r as resolveDirective, o as openBlock, k as createBlock, M as withCtx, H as createBaseVNode, X as toDisplayString, N as createVNode, j as unref, l as script, i as withDirectives } from "./index-DjNHn37O.js";
import { _ as _sfc_main$1 } from "./BaseViewTemplate-BNGF4K22.js";
import { d as defineComponent, bi as useRouter, r as resolveDirective, o as openBlock, y as createBlock, z as withCtx, m as createBaseVNode, E as toDisplayString, k as createVNode, j as unref, l as script, i as withDirectives, _ as _export_sfc } from "./index-Bv0b06LE.js";
import { _ as _sfc_main$1 } from "./BaseViewTemplate-BTbuZf5t.js";
const _imports_0 = "" + new URL("images/sad_girl.png", import.meta.url).href;
const _hoisted_1 = { class: "sad-container" };
const _hoisted_2 = /* @__PURE__ */ createBaseVNode("img", {
class: "sad-girl",
src: _imports_0,
alt: "Sad girl illustration"
}, null, -1);
const _hoisted_3 = { class: "no-drag sad-text flex items-center" };
const _hoisted_4 = { class: "flex flex-col gap-8 p-8 min-w-110" };
const _hoisted_5 = { class: "text-4xl font-bold text-red-500" };
const _hoisted_6 = { class: "space-y-4" };
const _hoisted_7 = { class: "text-xl" };
const _hoisted_8 = { class: "list-disc list-inside space-y-1 text-neutral-800" };
const _hoisted_9 = { class: "flex gap-4" };
const _hoisted_2 = { class: "no-drag sad-text flex items-center" };
const _hoisted_3 = { class: "flex flex-col gap-8 p-8 min-w-110" };
const _hoisted_4 = { class: "text-4xl font-bold text-red-500" };
const _hoisted_5 = { class: "space-y-4" };
const _hoisted_6 = { class: "text-xl" };
const _hoisted_7 = { class: "list-disc list-inside space-y-1 text-neutral-800" };
const _hoisted_8 = { class: "flex gap-4" };
const _sfc_main = /* @__PURE__ */ defineComponent({
__name: "NotSupportedView",
setup(__props) {
@@ -37,18 +32,22 @@ const _sfc_main = /* @__PURE__ */ defineComponent({
return openBlock(), createBlock(_sfc_main$1, null, {
default: withCtx(() => [
createBaseVNode("div", _hoisted_1, [
_hoisted_2,
createBaseVNode("div", _hoisted_3, [
createBaseVNode("div", _hoisted_4, [
createBaseVNode("h1", _hoisted_5, toDisplayString(_ctx.$t("notSupported.title")), 1),
createBaseVNode("div", _hoisted_6, [
createBaseVNode("p", _hoisted_7, toDisplayString(_ctx.$t("notSupported.message")), 1),
createBaseVNode("ul", _hoisted_8, [
_cache[0] || (_cache[0] = createBaseVNode("img", {
class: "sad-girl",
src: _imports_0,
alt: "Sad girl illustration"
}, null, -1)),
createBaseVNode("div", _hoisted_2, [
createBaseVNode("div", _hoisted_3, [
createBaseVNode("h1", _hoisted_4, toDisplayString(_ctx.$t("notSupported.title")), 1),
createBaseVNode("div", _hoisted_5, [
createBaseVNode("p", _hoisted_6, toDisplayString(_ctx.$t("notSupported.message")), 1),
createBaseVNode("ul", _hoisted_7, [
createBaseVNode("li", null, toDisplayString(_ctx.$t("notSupported.supportedDevices.macos")), 1),
createBaseVNode("li", null, toDisplayString(_ctx.$t("notSupported.supportedDevices.windows")), 1)
])
]),
createBaseVNode("div", _hoisted_9, [
createBaseVNode("div", _hoisted_8, [
createVNode(unref(script), {
label: _ctx.$t("notSupported.learnMore"),
icon: "pi pi-github",
@@ -80,7 +79,8 @@ const _sfc_main = /* @__PURE__ */ defineComponent({
};
}
});
const NotSupportedView = /* @__PURE__ */ _export_sfc(_sfc_main, [["__scopeId", "data-v-ebb20958"]]);
export {
_sfc_main as default
NotSupportedView as default
};
//# sourceMappingURL=NotSupportedView-Drz3x2d-.js.map
//# sourceMappingURL=NotSupportedView-B78ZVR9Z.js.map
8 web/assets/NotSupportedView-bFzHmqNj.css → web/assets/NotSupportedView-RFx6eCkN.css generated vendored
@@ -1,17 +1,19 @@

.sad-container {
&[data-v-ebb20958] {
display: grid;
align-items: center;
justify-content: space-evenly;
grid-template-columns: 25rem 1fr;
& > * {
}
&[data-v-ebb20958] > * {
grid-row: 1;
}
}
.sad-text {
.sad-text[data-v-ebb20958] {
grid-column: 1/3;
}
.sad-girl {
.sad-girl[data-v-ebb20958] {
grid-column: 2/3;
width: min(75vw, 100vh);
}
28 web/assets/ServerConfigPanel-Be4StJmv.js → web/assets/ServerConfigPanel-BYrt6wyr.js generated vendored
@@ -1,25 +1,23 @@
var __defProp = Object.defineProperty;
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
import { H as createBaseVNode, o as openBlock, f as createElementBlock, Z as markRaw, d as defineComponent, a as useSettingStore, aS as storeToRefs, a5 as watch, cO as useCopyToClipboard, a1 as useI18n, k as createBlock, M as withCtx, j as unref, bZ as script, X as toDisplayString, E as renderList, F as Fragment, N as createVNode, l as script$1, I as createCommentVNode, bQ as script$2, cP as FormItem, cp as _sfc_main$1, c0 as electronAPI } from "./index-DjNHn37O.js";
import { u as useServerConfigStore } from "./serverConfigStore-CvyKFVuP.js";
import { o as openBlock, f as createElementBlock, m as createBaseVNode, H as markRaw, d as defineComponent, a as useSettingStore, af as storeToRefs, N as watch, dJ as useCopyToClipboard, I as useI18n, y as createBlock, z as withCtx, j as unref, bn as script, E as toDisplayString, D as renderList, F as Fragment, k as createVNode, l as script$1, B as createCommentVNode, bl as script$2, dK as FormItem, dz as _sfc_main$1, b9 as electronAPI } from "./index-Bv0b06LE.js";
import { u as useServerConfigStore } from "./serverConfigStore-D2Vr0L0h.js";
const _hoisted_1$1 = {
viewBox: "0 0 24 24",
width: "1.2em",
height: "1.2em"
};
const _hoisted_2$1 = /* @__PURE__ */ createBaseVNode("path", {
fill: "none",
stroke: "currentColor",
"stroke-linecap": "round",
"stroke-linejoin": "round",
"stroke-width": "2",
d: "m4 17l6-6l-6-6m8 14h8"
}, null, -1);
const _hoisted_3$1 = [
_hoisted_2$1
];
function render(_ctx, _cache) {
return openBlock(), createElementBlock("svg", _hoisted_1$1, [..._hoisted_3$1]);
return openBlock(), createElementBlock("svg", _hoisted_1$1, _cache[0] || (_cache[0] = [
createBaseVNode("path", {
fill: "none",
stroke: "currentColor",
"stroke-linecap": "round",
"stroke-linejoin": "round",
"stroke-width": "2",
d: "m4 17l6-6l-6-6m8 14h8"
}, null, -1)
]));
}
__name(render, "render");
const __unplugin_components_0 = markRaw({ name: "lucide-terminal", render });
@@ -155,4 +153,4 @@ const _sfc_main = /* @__PURE__ */ defineComponent({
export {
_sfc_main as default
};
//# sourceMappingURL=ServerConfigPanel-Be4StJmv.js.map
//# sourceMappingURL=ServerConfigPanel-BYrt6wyr.js.map
100 web/assets/ServerStartView-B7TlHxYo.js generated vendored Normal file
@@ -0,0 +1,100 @@
var __defProp = Object.defineProperty;
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
import { d as defineComponent, I as useI18n, T as ref, bo as ProgressStatus, p as onMounted, o as openBlock, y as createBlock, z as withCtx, m as createBaseVNode, a8 as createTextVNode, E as toDisplayString, j as unref, f as createElementBlock, B as createCommentVNode, k as createVNode, l as script, i as withDirectives, v as vShow, bp as BaseTerminal, b9 as electronAPI, _ as _export_sfc } from "./index-Bv0b06LE.js";
import { _ as _sfc_main$1 } from "./BaseViewTemplate-BTbuZf5t.js";
const _hoisted_1 = { class: "flex flex-col w-full h-full items-center" };
const _hoisted_2 = { class: "text-2xl font-bold" };
const _hoisted_3 = { key: 0 };
const _hoisted_4 = {
key: 0,
class: "flex flex-col items-center gap-4"
};
const _hoisted_5 = { class: "flex items-center my-4 gap-2" };
const _sfc_main = /* @__PURE__ */ defineComponent({
__name: "ServerStartView",
setup(__props) {
const electron = electronAPI();
const { t } = useI18n();
const status = ref(ProgressStatus.INITIAL_STATE);
const electronVersion = ref("");
let xterm;
const terminalVisible = ref(true);
const updateProgress = /* @__PURE__ */ __name(({ status: newStatus }) => {
status.value = newStatus;
if (newStatus === ProgressStatus.ERROR) terminalVisible.value = false;
else xterm?.clear();
}, "updateProgress");
const terminalCreated = /* @__PURE__ */ __name(({ terminal, useAutoSize }, root) => {
xterm = terminal;
useAutoSize({ root, autoRows: true, autoCols: true });
electron.onLogMessage((message) => {
terminal.write(message);
});
terminal.options.cursorBlink = false;
terminal.options.disableStdin = true;
terminal.options.cursorInactiveStyle = "block";
}, "terminalCreated");
const reinstall = /* @__PURE__ */ __name(() => electron.reinstall(), "reinstall");
const reportIssue = /* @__PURE__ */ __name(() => {
window.open("https://forum.comfy.org/c/v1-feedback/", "_blank");
}, "reportIssue");
const openLogs = /* @__PURE__ */ __name(() => electron.openLogsFolder(), "openLogs");
onMounted(async () => {
electron.sendReady();
electron.onProgressUpdate(updateProgress);
electronVersion.value = await electron.getElectronVersion();
});
return (_ctx, _cache) => {
return openBlock(), createBlock(_sfc_main$1, {
dark: "",
class: "flex-col"
}, {
default: withCtx(() => [
createBaseVNode("div", _hoisted_1, [
createBaseVNode("h2", _hoisted_2, [
createTextVNode(toDisplayString(unref(t)(`serverStart.process.${status.value}`)) + " ", 1),
status.value === unref(ProgressStatus).ERROR ? (openBlock(), createElementBlock("span", _hoisted_3, " v" + toDisplayString(electronVersion.value), 1)) : createCommentVNode("", true)
]),
status.value === unref(ProgressStatus).ERROR ? (openBlock(), createElementBlock("div", _hoisted_4, [
createBaseVNode("div", _hoisted_5, [
createVNode(unref(script), {
icon: "pi pi-flag",
severity: "secondary",
label: unref(t)("serverStart.reportIssue"),
onClick: reportIssue
}, null, 8, ["label"]),
createVNode(unref(script), {
icon: "pi pi-file",
severity: "secondary",
label: unref(t)("serverStart.openLogs"),
onClick: openLogs
}, null, 8, ["label"]),
createVNode(unref(script), {
icon: "pi pi-refresh",
label: unref(t)("serverStart.reinstall"),
onClick: reinstall
}, null, 8, ["label"])
]),
!terminalVisible.value ? (openBlock(), createBlock(unref(script), {
key: 0,
icon: "pi pi-search",
severity: "secondary",
label: unref(t)("serverStart.showTerminal"),
onClick: _cache[0] || (_cache[0] = ($event) => terminalVisible.value = true)
}, null, 8, ["label"])) : createCommentVNode("", true)
])) : createCommentVNode("", true),
withDirectives(createVNode(BaseTerminal, { onCreated: terminalCreated }, null, 512), [
[vShow, terminalVisible.value]
])
])
]),
_: 1
});
};
}
});
const ServerStartView = /* @__PURE__ */ _export_sfc(_sfc_main, [["__scopeId", "data-v-e6ba9633"]]);
export {
ServerStartView as default
};
//# sourceMappingURL=ServerStartView-B7TlHxYo.js.map
2
web/assets/ServerStartView-CnyN4Ib6.css → web/assets/ServerStartView-BZ7uhZHv.css
generated
vendored
@@ -1,5 +1,5 @@

[data-v-42c1131d] .xterm-helper-textarea {
[data-v-e6ba9633] .xterm-helper-textarea {
/* Hide this as it moves all over when uv is running */
display: none;
}
98
web/assets/ServerStartView-CIDTUh4x.js
generated
vendored
@@ -1,98 +0,0 @@
var __defProp = Object.defineProperty;
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
import { d as defineComponent, a1 as useI18n, ab as ref, b_ as ProgressStatus, m as onMounted, o as openBlock, k as createBlock, M as withCtx, H as createBaseVNode, aE as createTextVNode, X as toDisplayString, j as unref, f as createElementBlock, I as createCommentVNode, N as createVNode, l as script, i as withDirectives, v as vShow, b$ as BaseTerminal, aL as pushScopeId, aM as popScopeId, c0 as electronAPI, _ as _export_sfc } from "./index-DjNHn37O.js";
import { _ as _sfc_main$1 } from "./BaseViewTemplate-BNGF4K22.js";
const _withScopeId = /* @__PURE__ */ __name((n) => (pushScopeId("data-v-42c1131d"), n = n(), popScopeId(), n), "_withScopeId");
const _hoisted_1 = { class: "text-2xl font-bold" };
const _hoisted_2 = { key: 0 };
const _hoisted_3 = {
key: 0,
class: "flex flex-col items-center gap-4"
};
const _hoisted_4 = { class: "flex items-center my-4 gap-2" };
const _sfc_main = /* @__PURE__ */ defineComponent({
__name: "ServerStartView",
setup(__props) {
const electron = electronAPI();
const { t } = useI18n();
const status = ref(ProgressStatus.INITIAL_STATE);
const electronVersion = ref("");
let xterm;
const terminalVisible = ref(true);
const updateProgress = /* @__PURE__ */ __name(({ status: newStatus }) => {
status.value = newStatus;
if (newStatus === ProgressStatus.ERROR) terminalVisible.value = false;
else xterm?.clear();
}, "updateProgress");
const terminalCreated = /* @__PURE__ */ __name(({ terminal, useAutoSize }, root) => {
xterm = terminal;
useAutoSize(root, true, true);
electron.onLogMessage((message) => {
terminal.write(message);
});
terminal.options.cursorBlink = false;
terminal.options.disableStdin = true;
terminal.options.cursorInactiveStyle = "block";
}, "terminalCreated");
const reinstall = /* @__PURE__ */ __name(() => electron.reinstall(), "reinstall");
const reportIssue = /* @__PURE__ */ __name(() => {
window.open("https://forum.comfy.org/c/v1-feedback/", "_blank");
}, "reportIssue");
const openLogs = /* @__PURE__ */ __name(() => electron.openLogsFolder(), "openLogs");
onMounted(async () => {
electron.sendReady();
electron.onProgressUpdate(updateProgress);
electronVersion.value = await electron.getElectronVersion();
});
return (_ctx, _cache) => {
return openBlock(), createBlock(_sfc_main$1, {
dark: "",
class: "flex-col"
}, {
default: withCtx(() => [
createBaseVNode("h2", _hoisted_1, [
createTextVNode(toDisplayString(unref(t)(`serverStart.process.${status.value}`)) + " ", 1),
status.value === unref(ProgressStatus).ERROR ? (openBlock(), createElementBlock("span", _hoisted_2, " v" + toDisplayString(electronVersion.value), 1)) : createCommentVNode("", true)
]),
status.value === unref(ProgressStatus).ERROR ? (openBlock(), createElementBlock("div", _hoisted_3, [
createBaseVNode("div", _hoisted_4, [
createVNode(unref(script), {
icon: "pi pi-flag",
severity: "secondary",
label: unref(t)("serverStart.reportIssue"),
onClick: reportIssue
}, null, 8, ["label"]),
createVNode(unref(script), {
icon: "pi pi-file",
severity: "secondary",
label: unref(t)("serverStart.openLogs"),
onClick: openLogs
}, null, 8, ["label"]),
createVNode(unref(script), {
icon: "pi pi-refresh",
label: unref(t)("serverStart.reinstall"),
onClick: reinstall
}, null, 8, ["label"])
]),
!terminalVisible.value ? (openBlock(), createBlock(unref(script), {
key: 0,
icon: "pi pi-search",
severity: "secondary",
label: unref(t)("serverStart.showTerminal"),
onClick: _cache[0] || (_cache[0] = ($event) => terminalVisible.value = true)
}, null, 8, ["label"])) : createCommentVNode("", true)
])) : createCommentVNode("", true),
withDirectives(createVNode(BaseTerminal, { onCreated: terminalCreated }, null, 512), [
[vShow, terminalVisible.value]
])
]),
_: 1
});
};
}
});
const ServerStartView = /* @__PURE__ */ _export_sfc(_sfc_main, [["__scopeId", "data-v-42c1131d"]]);
export {
ServerStartView as default
};
//# sourceMappingURL=ServerStartView-CIDTUh4x.js.map
Some files were not shown because too many files have changed in this diff.