Compare commits
169 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8ce2a1052c | ||
|
|
f82314fcfc | ||
|
|
0075c6d096 | ||
|
|
83ca891118 | ||
|
|
f9f9faface | ||
|
|
471cd3eace | ||
|
|
a68bbafddb | ||
|
|
73e3a9e676 | ||
|
|
518c0dc2fe | ||
|
|
ce0542e10b | ||
|
|
8473019d40 | ||
|
|
89f15894dd | ||
|
|
67158994a4 | ||
|
|
7390ff3b1e | ||
|
|
0bedfb26af | ||
|
|
f71cfd2687 | ||
|
|
c695c4af7f | ||
|
|
0dbba9f751 | ||
|
|
f584758271 | ||
|
|
95b7cf9bbe | ||
|
|
191a0d56b4 | ||
|
|
3c60ecd7a8 | ||
|
|
7ae6626723 | ||
|
|
6632365e16 | ||
|
|
ad07796777 | ||
|
|
1b80895285 | ||
|
|
5f9d5a244b | ||
|
|
14eba07acd | ||
|
|
4b2f0d9413 | ||
|
|
25eac1d780 | ||
|
|
e38c94228b | ||
|
|
203942c8b2 | ||
|
|
3c72c89a52 | ||
|
|
614377abd6 | ||
|
|
8dfa0cc552 | ||
|
|
e5ecdfdd2d | ||
|
|
7d29fbf74b | ||
|
|
2c641e64ad | ||
|
|
7d2467e830 | ||
|
|
6f021d8aa0 | ||
|
|
d854ed0bcf | ||
|
|
abcd006b8c | ||
|
|
d985d1d7dc | ||
|
|
d1cdf51e1b | ||
|
|
b4626ab93e | ||
|
|
a9e459c2a4 | ||
|
|
3bb4dec720 | ||
|
|
8733191563 | ||
|
|
83b01f960a | ||
|
|
d72e871cfa | ||
|
|
037c3159b6 | ||
|
|
bdd4a22a2e | ||
|
|
fdf37566ef | ||
|
|
08c8968482 | ||
|
|
479a427a48 | ||
|
|
3a0eeee320 | ||
|
|
447da7ea86 | ||
|
|
9c41bc8d10 | ||
|
|
6ad0ddbae4 | ||
|
|
a55142f904 | ||
|
|
5718ef69bb | ||
|
|
13ecf10a92 | ||
|
|
7a415f47a9 | ||
|
|
89fa2fca24 | ||
|
|
364b69e931 | ||
|
|
dc96a1ae19 | ||
|
|
2d810b081e | ||
|
|
9f7e9f0547 | ||
|
|
a355f38ecc | ||
|
|
38c69080c7 | ||
|
|
70a708d726 | ||
|
|
e7d4782736 | ||
|
|
3326bdfd4e | ||
|
|
68bb885d22 | ||
|
|
ad66f7c7d8 | ||
|
|
de8e8e3b0d | ||
|
|
a1e71cfad1 | ||
|
|
0bfc7cc998 | ||
|
|
7183fd1665 | ||
|
|
254838f23c | ||
|
|
0b7dfa986d | ||
|
|
d514bb38ee | ||
|
|
0849c80e2a | ||
|
|
56e8f5e4fd | ||
|
|
e813abbb2c | ||
|
|
5e68a4ce67 | ||
|
|
ca08597670 | ||
|
|
f48e390032 | ||
|
|
369a6dd2c4 | ||
|
|
b3ce8fb9fd | ||
|
|
cf80d28689 | ||
|
|
6fb44c4b7c | ||
|
|
d2247c1e61 | ||
|
|
cb12ad7049 | ||
|
|
f6b7194f64 | ||
|
|
7c6eb4fb29 | ||
|
|
b962db9952 | ||
|
|
d0b7ab88ba | ||
|
|
405b529545 | ||
|
|
9d720187f1 | ||
|
|
d247bc5a9c | ||
|
|
9f4daca9d9 | ||
|
|
b5d0f2a908 | ||
|
|
e760bf5c40 | ||
|
|
36c83cdbba | ||
|
|
81778a7feb | ||
|
|
bc94662b31 | ||
|
|
9fa8faa44a | ||
|
|
9a7444e39f | ||
|
|
54fca4a218 | ||
|
|
cd4955367e | ||
|
|
8354203d95 | ||
|
|
e0b41243b4 | ||
|
|
619263d4a6 | ||
|
|
e3b0402bb7 | ||
|
|
967867d48c | ||
|
|
cbaac71bf5 | ||
|
|
3ab3516e46 | ||
|
|
9c5fca75f4 | ||
|
|
a5da4d0b3e | ||
|
|
32a60a7bac | ||
|
|
bb52934ba4 | ||
|
|
8aabd7c8c0 | ||
|
|
a09b29ca11 | ||
|
|
9bfee68773 | ||
|
|
ea77750759 | ||
|
|
c27ebeb1c2 | ||
|
|
0c7c98a965 | ||
|
|
dc2eb75b85 | ||
|
|
fa34efe3bd | ||
|
|
5cbaa9e07c | ||
|
|
c7427375ee | ||
|
|
22d1241a50 | ||
|
|
f04229b84d | ||
|
|
f067ad15d1 | ||
|
|
483004dd1d | ||
|
|
00a5d08103 | ||
|
|
d043997d30 | ||
|
|
f1c2301697 | ||
|
|
8d31a6632f | ||
|
|
b643eae08b | ||
|
|
baa6b4dc36 | ||
|
|
d4aeefc297 | ||
|
|
587e7ca654 | ||
|
|
c90459eba0 | ||
|
|
04278afb10 | ||
|
|
935ae153e1 | ||
|
|
e91662e784 | ||
|
|
63fafaef45 | ||
|
|
ec28cd9136 | ||
|
|
6eb5d64522 | ||
|
|
10a79e9898 | ||
|
|
ea3f39bd69 | ||
|
|
b33cd61070 | ||
|
|
34eda0f853 | ||
|
|
d31e226650 | ||
|
|
b79fd7d92c | ||
|
|
38c22e631a | ||
|
|
6bbdcd28ae | ||
|
|
ab130001a8 | ||
|
|
ca4b8f30e0 | ||
|
|
70b84058c1 | ||
|
|
2ca8f6e23d | ||
|
|
7985ff88b9 | ||
|
|
c6812947e9 | ||
|
|
9230f65823 | ||
|
|
6ab1e6fd4a | ||
|
|
07dcbc3a3e | ||
|
|
8ae23d8e80 |
@@ -14,7 +14,7 @@ run_cpu.bat
|
|||||||
|
|
||||||
IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints
|
IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints
|
||||||
|
|
||||||
You can download the stable diffusion 1.5 one from: https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt
|
You can download the stable diffusion 1.5 one from: https://huggingface.co/Comfy-Org/stable-diffusion-v1-5-archive/blob/main/v1-5-pruned-emaonly-fp16.safetensors
|
||||||
|
|
||||||
|
|
||||||
RECOMMENDED WAY TO UPDATE:
|
RECOMMENDED WAY TO UPDATE:
|
||||||
|
|||||||
2
.github/workflows/pullrequest-ci-run.yml
vendored
2
.github/workflows/pullrequest-ci-run.yml
vendored
@@ -23,7 +23,7 @@ jobs:
|
|||||||
runner_label: [self-hosted, Linux]
|
runner_label: [self-hosted, Linux]
|
||||||
flags: ""
|
flags: ""
|
||||||
- os: windows
|
- os: windows
|
||||||
runner_label: [self-hosted, win]
|
runner_label: [self-hosted, Windows]
|
||||||
flags: ""
|
flags: ""
|
||||||
runs-on: ${{ matrix.runner_label }}
|
runs-on: ${{ matrix.runner_label }}
|
||||||
steps:
|
steps:
|
||||||
|
|||||||
6
.github/workflows/stable-release.yml
vendored
6
.github/workflows/stable-release.yml
vendored
@@ -12,17 +12,17 @@ on:
|
|||||||
description: 'CUDA version'
|
description: 'CUDA version'
|
||||||
required: true
|
required: true
|
||||||
type: string
|
type: string
|
||||||
default: "121"
|
default: "124"
|
||||||
python_minor:
|
python_minor:
|
||||||
description: 'Python minor version'
|
description: 'Python minor version'
|
||||||
required: true
|
required: true
|
||||||
type: string
|
type: string
|
||||||
default: "11"
|
default: "12"
|
||||||
python_patch:
|
python_patch:
|
||||||
description: 'Python patch version'
|
description: 'Python patch version'
|
||||||
required: true
|
required: true
|
||||||
type: string
|
type: string
|
||||||
default: "9"
|
default: "7"
|
||||||
|
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
|
|||||||
21
.github/workflows/stale-issues.yml
vendored
Normal file
21
.github/workflows/stale-issues.yml
vendored
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
name: 'Close stale issues'
|
||||||
|
on:
|
||||||
|
schedule:
|
||||||
|
# Run daily at 430 am PT
|
||||||
|
- cron: '30 11 * * *'
|
||||||
|
permissions:
|
||||||
|
issues: write
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
stale:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/stale@v9
|
||||||
|
with:
|
||||||
|
stale-issue-message: "This issue is being marked stale because it has not had any activity for 30 days. Reply below within 7 days if your issue still isn't solved, and it will be left open. Otherwise, the issue will be closed automatically."
|
||||||
|
days-before-stale: 30
|
||||||
|
days-before-close: 7
|
||||||
|
stale-issue-label: 'Stale'
|
||||||
|
only-labels: 'User Support'
|
||||||
|
exempt-all-assignees: true
|
||||||
|
exempt-all-milestones: true
|
||||||
4
.github/workflows/test-ci.yml
vendored
4
.github/workflows/test-ci.yml
vendored
@@ -32,7 +32,7 @@ jobs:
|
|||||||
runner_label: [self-hosted, Linux]
|
runner_label: [self-hosted, Linux]
|
||||||
flags: ""
|
flags: ""
|
||||||
- os: windows
|
- os: windows
|
||||||
runner_label: [self-hosted, win]
|
runner_label: [self-hosted, Windows]
|
||||||
flags: ""
|
flags: ""
|
||||||
runs-on: ${{ matrix.runner_label }}
|
runs-on: ${{ matrix.runner_label }}
|
||||||
steps:
|
steps:
|
||||||
@@ -55,7 +55,7 @@ jobs:
|
|||||||
torch_version: ["nightly"]
|
torch_version: ["nightly"]
|
||||||
include:
|
include:
|
||||||
- os: windows
|
- os: windows
|
||||||
runner_label: [self-hosted, win]
|
runner_label: [self-hosted, Windows]
|
||||||
flags: ""
|
flags: ""
|
||||||
runs-on: ${{ matrix.runner_label }}
|
runs-on: ${{ matrix.runner_label }}
|
||||||
steps:
|
steps:
|
||||||
|
|||||||
30
.github/workflows/test-unit.yml
vendored
Normal file
30
.github/workflows/test-unit.yml
vendored
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
name: Unit Tests
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [ main, master ]
|
||||||
|
pull_request:
|
||||||
|
branches: [ main, master ]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
test:
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
os: [ubuntu-latest, windows-latest, macos-latest]
|
||||||
|
runs-on: ${{ matrix.os }}
|
||||||
|
continue-on-error: true
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v4
|
||||||
|
with:
|
||||||
|
python-version: '3.10'
|
||||||
|
- name: Install requirements
|
||||||
|
run: |
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
||||||
|
pip install -r requirements.txt
|
||||||
|
- name: Run Unit Tests
|
||||||
|
run: |
|
||||||
|
pip install -r tests-unit/requirements.txt
|
||||||
|
python -m pytest tests-unit
|
||||||
@@ -12,7 +12,7 @@ on:
|
|||||||
description: 'extra dependencies'
|
description: 'extra dependencies'
|
||||||
required: false
|
required: false
|
||||||
type: string
|
type: string
|
||||||
default: "\"numpy<2\""
|
default: ""
|
||||||
cu:
|
cu:
|
||||||
description: 'cuda version'
|
description: 'cuda version'
|
||||||
required: true
|
required: true
|
||||||
@@ -23,13 +23,13 @@ on:
|
|||||||
description: 'python minor version'
|
description: 'python minor version'
|
||||||
required: true
|
required: true
|
||||||
type: string
|
type: string
|
||||||
default: "11"
|
default: "12"
|
||||||
|
|
||||||
python_patch:
|
python_patch:
|
||||||
description: 'python patch version'
|
description: 'python patch version'
|
||||||
required: true
|
required: true
|
||||||
type: string
|
type: string
|
||||||
default: "9"
|
default: "7"
|
||||||
# push:
|
# push:
|
||||||
# branches:
|
# branches:
|
||||||
# - master
|
# - master
|
||||||
|
|||||||
@@ -13,13 +13,13 @@ on:
|
|||||||
description: 'python minor version'
|
description: 'python minor version'
|
||||||
required: true
|
required: true
|
||||||
type: string
|
type: string
|
||||||
default: "11"
|
default: "12"
|
||||||
|
|
||||||
python_patch:
|
python_patch:
|
||||||
description: 'python patch version'
|
description: 'python patch version'
|
||||||
required: true
|
required: true
|
||||||
type: string
|
type: string
|
||||||
default: "9"
|
default: "7"
|
||||||
# push:
|
# push:
|
||||||
# branches:
|
# branches:
|
||||||
# - master
|
# - master
|
||||||
|
|||||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -12,6 +12,7 @@ extra_model_paths.yaml
|
|||||||
.vscode/
|
.vscode/
|
||||||
.idea/
|
.idea/
|
||||||
venv/
|
venv/
|
||||||
|
.venv/
|
||||||
/web/extensions/*
|
/web/extensions/*
|
||||||
!/web/extensions/logging.js.example
|
!/web/extensions/logging.js.example
|
||||||
!/web/extensions/core/
|
!/web/extensions/core/
|
||||||
|
|||||||
16
README.md
16
README.md
@@ -1,7 +1,7 @@
|
|||||||
<div align="center">
|
<div align="center">
|
||||||
|
|
||||||
# ComfyUI
|
# ComfyUI
|
||||||
**The most powerful and modular stable diffusion GUI and backend.**
|
**The most powerful and modular diffusion model GUI and backend.**
|
||||||
|
|
||||||
|
|
||||||
[![Website][website-shield]][website-url]
|
[![Website][website-shield]][website-url]
|
||||||
@@ -94,6 +94,8 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
|
|||||||
| Alt + `+` | Canvas Zoom in |
|
| Alt + `+` | Canvas Zoom in |
|
||||||
| Alt + `-` | Canvas Zoom out |
|
| Alt + `-` | Canvas Zoom out |
|
||||||
| Ctrl + Shift + LMB + Vertical drag | Canvas Zoom in/out |
|
| Ctrl + Shift + LMB + Vertical drag | Canvas Zoom in/out |
|
||||||
|
| P | Pin/Unpin selected nodes |
|
||||||
|
| Ctrl + G | Group selected nodes |
|
||||||
| Q | Toggle visibility of the queue |
|
| Q | Toggle visibility of the queue |
|
||||||
| H | Toggle visibility of history |
|
| H | Toggle visibility of history |
|
||||||
| R | Refresh graph |
|
| R | Refresh graph |
|
||||||
@@ -125,6 +127,8 @@ To run it on services like paperspace, kaggle or colab you can use my [Jupyter N
|
|||||||
|
|
||||||
## Manual Install (Windows, Linux)
|
## Manual Install (Windows, Linux)
|
||||||
|
|
||||||
|
Note that some dependencies do not yet support python 3.13 so using 3.12 is recommended.
|
||||||
|
|
||||||
Git clone this repo.
|
Git clone this repo.
|
||||||
|
|
||||||
Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints
|
Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints
|
||||||
@@ -135,17 +139,17 @@ Put your VAE in: models/vae
|
|||||||
### AMD GPUs (Linux only)
|
### AMD GPUs (Linux only)
|
||||||
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
|
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
|
||||||
|
|
||||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.0```
|
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.1```
|
||||||
|
|
||||||
This is the command to install the nightly with ROCm 6.0 which might have some performance improvements:
|
This is the command to install the nightly with ROCm 6.2 which might have some performance improvements:
|
||||||
|
|
||||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.1```
|
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.2```
|
||||||
|
|
||||||
### NVIDIA
|
### NVIDIA
|
||||||
|
|
||||||
Nvidia users should install stable pytorch using this command:
|
Nvidia users should install stable pytorch using this command:
|
||||||
|
|
||||||
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu121```
|
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu124```
|
||||||
|
|
||||||
This is the command to install pytorch nightly instead which might have performance improvements:
|
This is the command to install pytorch nightly instead which might have performance improvements:
|
||||||
|
|
||||||
@@ -230,7 +234,7 @@ To use a textual inversion concepts/embeddings in a text prompt put them in the
|
|||||||
|
|
||||||
Use ```--preview-method auto``` to enable previews.
|
Use ```--preview-method auto``` to enable previews.
|
||||||
|
|
||||||
The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth) (for SD1.x and SD2.x) and [taesdxl_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesdxl_decoder.pth) (for SDXL) models and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI to enable high-quality previews.
|
The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_decoder.pth, taesdxl_decoder.pth, taesd3_decoder.pth and taef1_decoder.pth](https://github.com/madebyollin/taesd/) and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI and launch it with `--preview-method taesd` to enable high-quality previews.
|
||||||
|
|
||||||
## How to use TLS/SSL?
|
## How to use TLS/SSL?
|
||||||
Generate a self-signed certificate (not appropriate for shared/production use) and key by running the command: `openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem -sha256 -days 3650 -nodes -subj "/C=XX/ST=StateName/L=CityName/O=CompanyName/OU=CompanySectionName/CN=CommonNameOrHostname"`
|
Generate a self-signed certificate (not appropriate for shared/production use) and key by running the command: `openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem -sha256 -days 3650 -nodes -subj "/C=XX/ST=StateName/L=CityName/O=CompanyName/OU=CompanySectionName/CN=CommonNameOrHostname"`
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from folder_paths import models_dir, user_directory, output_directory
|
from folder_paths import models_dir, user_directory, output_directory, folder_names_and_paths
|
||||||
from api_server.services.file_service import FileService
|
from api_server.services.file_service import FileService
|
||||||
|
import app.logger
|
||||||
|
|
||||||
class InternalRoutes:
|
class InternalRoutes:
|
||||||
'''
|
'''
|
||||||
@@ -31,6 +32,16 @@ class InternalRoutes:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
return web.json_response({"error": str(e)}, status=500)
|
return web.json_response({"error": str(e)}, status=500)
|
||||||
|
|
||||||
|
@self.routes.get('/logs')
|
||||||
|
async def get_logs(request):
|
||||||
|
return web.json_response(app.logger.get_logs())
|
||||||
|
|
||||||
|
@self.routes.get('/folder_paths')
|
||||||
|
async def get_folder_paths(request):
|
||||||
|
response = {}
|
||||||
|
for key in folder_names_and_paths:
|
||||||
|
response[key] = folder_names_and_paths[key][0]
|
||||||
|
return web.json_response(response)
|
||||||
|
|
||||||
def get_app(self):
|
def get_app(self):
|
||||||
if self._app is None:
|
if self._app is None:
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import zipfile
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TypedDict
|
from typing import TypedDict, Optional
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from typing_extensions import NotRequired
|
from typing_extensions import NotRequired
|
||||||
@@ -132,12 +132,13 @@ class FrontendManager:
|
|||||||
return match_result.group(1), match_result.group(2), match_result.group(3)
|
return match_result.group(1), match_result.group(2), match_result.group(3)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def init_frontend_unsafe(cls, version_string: str) -> str:
|
def init_frontend_unsafe(cls, version_string: str, provider: Optional[FrontEndProvider] = None) -> str:
|
||||||
"""
|
"""
|
||||||
Initializes the frontend for the specified version.
|
Initializes the frontend for the specified version.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
version_string (str): The version string.
|
version_string (str): The version string.
|
||||||
|
provider (FrontEndProvider, optional): The provider to use. Defaults to None.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: The path to the initialized frontend.
|
str: The path to the initialized frontend.
|
||||||
@@ -150,7 +151,16 @@ class FrontendManager:
|
|||||||
return cls.DEFAULT_FRONTEND_PATH
|
return cls.DEFAULT_FRONTEND_PATH
|
||||||
|
|
||||||
repo_owner, repo_name, version = cls.parse_version_string(version_string)
|
repo_owner, repo_name, version = cls.parse_version_string(version_string)
|
||||||
provider = FrontEndProvider(repo_owner, repo_name)
|
|
||||||
|
if version.startswith("v"):
|
||||||
|
expected_path = str(Path(cls.CUSTOM_FRONTENDS_ROOT) / f"{repo_owner}_{repo_name}" / version.lstrip("v"))
|
||||||
|
if os.path.exists(expected_path):
|
||||||
|
logging.info(f"Using existing copy of specific frontend version tag: {repo_owner}/{repo_name}@{version}")
|
||||||
|
return expected_path
|
||||||
|
|
||||||
|
logging.info(f"Initializing frontend: {repo_owner}/{repo_name}@{version}, requesting version details from GitHub...")
|
||||||
|
|
||||||
|
provider = provider or FrontEndProvider(repo_owner, repo_name)
|
||||||
release = provider.get_release(version)
|
release = provider.get_release(version)
|
||||||
|
|
||||||
semantic_version = release["tag_name"].lstrip("v")
|
semantic_version = release["tag_name"].lstrip("v")
|
||||||
@@ -158,15 +168,25 @@ class FrontendManager:
|
|||||||
Path(cls.CUSTOM_FRONTENDS_ROOT) / provider.folder_name / semantic_version
|
Path(cls.CUSTOM_FRONTENDS_ROOT) / provider.folder_name / semantic_version
|
||||||
)
|
)
|
||||||
if not os.path.exists(web_root):
|
if not os.path.exists(web_root):
|
||||||
os.makedirs(web_root, exist_ok=True)
|
# Use tmp path until complete to avoid path exists check passing from interrupted downloads
|
||||||
logging.info(
|
tmp_path = web_root + ".tmp"
|
||||||
"Downloading frontend(%s) version(%s) to (%s)",
|
try:
|
||||||
provider.folder_name,
|
os.makedirs(tmp_path, exist_ok=True)
|
||||||
semantic_version,
|
logging.info(
|
||||||
web_root,
|
"Downloading frontend(%s) version(%s) to (%s)",
|
||||||
)
|
provider.folder_name,
|
||||||
logging.debug(release)
|
semantic_version,
|
||||||
download_release_asset_zip(release, destination_path=web_root)
|
tmp_path,
|
||||||
|
)
|
||||||
|
logging.debug(release)
|
||||||
|
download_release_asset_zip(release, destination_path=tmp_path)
|
||||||
|
if os.listdir(tmp_path):
|
||||||
|
os.rename(tmp_path, web_root)
|
||||||
|
finally:
|
||||||
|
# Clean up the directory if it is empty, i.e. the download failed
|
||||||
|
if not os.listdir(web_root):
|
||||||
|
os.rmdir(web_root)
|
||||||
|
|
||||||
return web_root
|
return web_root
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
31
app/logger.py
Normal file
31
app/logger.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
import logging
|
||||||
|
from logging.handlers import MemoryHandler
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
|
logs = None
|
||||||
|
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||||
|
|
||||||
|
|
||||||
|
def get_logs():
|
||||||
|
return "\n".join([formatter.format(x) for x in logs])
|
||||||
|
|
||||||
|
|
||||||
|
def setup_logger(log_level: str = 'INFO', capacity: int = 300):
|
||||||
|
global logs
|
||||||
|
if logs:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Setup default global logger
|
||||||
|
logger = logging.getLogger()
|
||||||
|
logger.setLevel(log_level)
|
||||||
|
|
||||||
|
stream_handler = logging.StreamHandler()
|
||||||
|
stream_handler.setFormatter(logging.Formatter("%(message)s"))
|
||||||
|
logger.addHandler(stream_handler)
|
||||||
|
|
||||||
|
# Create a memory handler with a deque as its buffer
|
||||||
|
logs = deque(maxlen=capacity)
|
||||||
|
memory_handler = MemoryHandler(capacity, flushLevel=logging.INFO)
|
||||||
|
memory_handler.buffer = logs
|
||||||
|
memory_handler.setFormatter(formatter)
|
||||||
|
logger.addHandler(memory_handler)
|
||||||
@@ -5,17 +5,17 @@ import uuid
|
|||||||
import glob
|
import glob
|
||||||
import shutil
|
import shutil
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
from urllib import parse
|
||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
from folder_paths import user_directory
|
import folder_paths
|
||||||
from .app_settings import AppSettings
|
from .app_settings import AppSettings
|
||||||
|
|
||||||
default_user = "default"
|
default_user = "default"
|
||||||
users_file = os.path.join(user_directory, "users.json")
|
|
||||||
|
|
||||||
|
|
||||||
class UserManager():
|
class UserManager():
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
global user_directory
|
user_directory = folder_paths.get_user_directory()
|
||||||
|
|
||||||
self.settings = AppSettings(self)
|
self.settings = AppSettings(self)
|
||||||
if not os.path.exists(user_directory):
|
if not os.path.exists(user_directory):
|
||||||
@@ -25,14 +25,17 @@ class UserManager():
|
|||||||
print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")
|
print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")
|
||||||
|
|
||||||
if args.multi_user:
|
if args.multi_user:
|
||||||
if os.path.isfile(users_file):
|
if os.path.isfile(self.get_users_file()):
|
||||||
with open(users_file) as f:
|
with open(self.get_users_file()) as f:
|
||||||
self.users = json.load(f)
|
self.users = json.load(f)
|
||||||
else:
|
else:
|
||||||
self.users = {}
|
self.users = {}
|
||||||
else:
|
else:
|
||||||
self.users = {"default": "default"}
|
self.users = {"default": "default"}
|
||||||
|
|
||||||
|
def get_users_file(self):
|
||||||
|
return os.path.join(folder_paths.get_user_directory(), "users.json")
|
||||||
|
|
||||||
def get_request_user_id(self, request):
|
def get_request_user_id(self, request):
|
||||||
user = "default"
|
user = "default"
|
||||||
if args.multi_user and "comfy-user" in request.headers:
|
if args.multi_user and "comfy-user" in request.headers:
|
||||||
@@ -44,7 +47,7 @@ class UserManager():
|
|||||||
return user
|
return user
|
||||||
|
|
||||||
def get_request_user_filepath(self, request, file, type="userdata", create_dir=True):
|
def get_request_user_filepath(self, request, file, type="userdata", create_dir=True):
|
||||||
global user_directory
|
user_directory = folder_paths.get_user_directory()
|
||||||
|
|
||||||
if type == "userdata":
|
if type == "userdata":
|
||||||
root_dir = user_directory
|
root_dir = user_directory
|
||||||
@@ -59,6 +62,10 @@ class UserManager():
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
if file is not None:
|
if file is not None:
|
||||||
|
# Check if filename is url encoded
|
||||||
|
if "%" in file:
|
||||||
|
file = parse.unquote(file)
|
||||||
|
|
||||||
# prevent leaving /{type}/{user}
|
# prevent leaving /{type}/{user}
|
||||||
path = os.path.abspath(os.path.join(user_root, file))
|
path = os.path.abspath(os.path.join(user_root, file))
|
||||||
if os.path.commonpath((user_root, path)) != user_root:
|
if os.path.commonpath((user_root, path)) != user_root:
|
||||||
@@ -80,8 +87,7 @@ class UserManager():
|
|||||||
|
|
||||||
self.users[user_id] = name
|
self.users[user_id] = name
|
||||||
|
|
||||||
global users_file
|
with open(self.get_users_file(), "w") as f:
|
||||||
with open(users_file, "w") as f:
|
|
||||||
json.dump(self.users, f)
|
json.dump(self.users, f)
|
||||||
|
|
||||||
return user_id
|
return user_id
|
||||||
@@ -112,25 +118,69 @@ class UserManager():
|
|||||||
|
|
||||||
@routes.get("/userdata")
|
@routes.get("/userdata")
|
||||||
async def listuserdata(request):
|
async def listuserdata(request):
|
||||||
|
"""
|
||||||
|
List user data files in a specified directory.
|
||||||
|
|
||||||
|
This endpoint allows listing files in a user's data directory, with options for recursion,
|
||||||
|
full file information, and path splitting.
|
||||||
|
|
||||||
|
Query Parameters:
|
||||||
|
- dir (required): The directory to list files from.
|
||||||
|
- recurse (optional): If "true", recursively list files in subdirectories.
|
||||||
|
- full_info (optional): If "true", return detailed file information (path, size, modified time).
|
||||||
|
- split (optional): If "true", split file paths into components (only applies when full_info is false).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- 400: If 'dir' parameter is missing.
|
||||||
|
- 403: If the requested path is not allowed.
|
||||||
|
- 404: If the requested directory does not exist.
|
||||||
|
- 200: JSON response with the list of files or file information.
|
||||||
|
|
||||||
|
The response format depends on the query parameters:
|
||||||
|
- Default: List of relative file paths.
|
||||||
|
- full_info=true: List of dictionaries with file details.
|
||||||
|
- split=true (and full_info=false): List of lists, each containing path components.
|
||||||
|
"""
|
||||||
directory = request.rel_url.query.get('dir', '')
|
directory = request.rel_url.query.get('dir', '')
|
||||||
if not directory:
|
if not directory:
|
||||||
return web.Response(status=400)
|
return web.Response(status=400, text="Directory not provided")
|
||||||
|
|
||||||
path = self.get_request_user_filepath(request, directory)
|
path = self.get_request_user_filepath(request, directory)
|
||||||
if not path:
|
if not path:
|
||||||
return web.Response(status=403)
|
return web.Response(status=403, text="Invalid directory")
|
||||||
|
|
||||||
if not os.path.exists(path):
|
if not os.path.exists(path):
|
||||||
return web.Response(status=404)
|
return web.Response(status=404, text="Directory not found")
|
||||||
|
|
||||||
recurse = request.rel_url.query.get('recurse', '').lower() == "true"
|
recurse = request.rel_url.query.get('recurse', '').lower() == "true"
|
||||||
results = glob.glob(os.path.join(
|
full_info = request.rel_url.query.get('full_info', '').lower() == "true"
|
||||||
glob.escape(path), '**/*'), recursive=recurse)
|
|
||||||
results = [os.path.relpath(x, path) for x in results if os.path.isfile(x)]
|
# Use different patterns based on whether we're recursing or not
|
||||||
|
if recurse:
|
||||||
|
pattern = os.path.join(glob.escape(path), '**', '*')
|
||||||
|
else:
|
||||||
|
pattern = os.path.join(glob.escape(path), '*')
|
||||||
|
|
||||||
|
results = glob.glob(pattern, recursive=recurse)
|
||||||
|
|
||||||
|
if full_info:
|
||||||
|
results = [
|
||||||
|
{
|
||||||
|
'path': os.path.relpath(x, path).replace(os.sep, '/'),
|
||||||
|
'size': os.path.getsize(x),
|
||||||
|
'modified': os.path.getmtime(x)
|
||||||
|
} for x in results if os.path.isfile(x)
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
results = [
|
||||||
|
os.path.relpath(x, path).replace(os.sep, '/')
|
||||||
|
for x in results
|
||||||
|
if os.path.isfile(x)
|
||||||
|
]
|
||||||
|
|
||||||
split_path = request.rel_url.query.get('split', '').lower() == "true"
|
split_path = request.rel_url.query.get('split', '').lower() == "true"
|
||||||
if split_path:
|
if split_path and not full_info:
|
||||||
results = [[x] + x.split(os.sep) for x in results]
|
results = [[x] + x.split('/') for x in results]
|
||||||
|
|
||||||
return web.json_response(results)
|
return web.json_response(results)
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
num_blocks = None,
|
num_blocks = None,
|
||||||
|
control_latent_channels = None,
|
||||||
dtype = None,
|
dtype = None,
|
||||||
device = None,
|
device = None,
|
||||||
operations = None,
|
operations = None,
|
||||||
@@ -17,10 +18,13 @@ class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
|
|||||||
for _ in range(len(self.joint_blocks)):
|
for _ in range(len(self.joint_blocks)):
|
||||||
self.controlnet_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype))
|
self.controlnet_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype))
|
||||||
|
|
||||||
|
if control_latent_channels is None:
|
||||||
|
control_latent_channels = self.in_channels
|
||||||
|
|
||||||
self.pos_embed_input = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(
|
self.pos_embed_input = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(
|
||||||
None,
|
None,
|
||||||
self.patch_size,
|
self.patch_size,
|
||||||
self.in_channels,
|
control_latent_channels,
|
||||||
self.hidden_size,
|
self.hidden_size,
|
||||||
bias=True,
|
bias=True,
|
||||||
strict_img_size=False,
|
strict_img_size=False,
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ class EnumAction(argparse.Action):
|
|||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
|
parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0,::", help="Specify the IP address to listen on (default: 127.0.0.1). You can give a list of ip addresses by separating them with a comma like: 127.2.2.2,127.3.3.3 If --listen is provided without an argument, it defaults to 0.0.0.0,:: (listens on all ipv4 and ipv6)")
|
||||||
parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
|
parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
|
||||||
parser.add_argument("--tls-keyfile", type=str, help="Path to TLS (SSL) key file. Enables TLS, makes app accessible at https://... requires --tls-certfile to function")
|
parser.add_argument("--tls-keyfile", type=str, help="Path to TLS (SSL) key file. Enables TLS, makes app accessible at https://... requires --tls-certfile to function")
|
||||||
parser.add_argument("--tls-certfile", type=str, help="Path to TLS (SSL) certificate file. Enables TLS, makes app accessible at https://... requires --tls-keyfile to function")
|
parser.add_argument("--tls-certfile", type=str, help="Path to TLS (SSL) certificate file. Enables TLS, makes app accessible at https://... requires --tls-keyfile to function")
|
||||||
@@ -92,6 +92,8 @@ class LatentPreviewMethod(enum.Enum):
|
|||||||
|
|
||||||
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
|
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
|
||||||
|
|
||||||
|
parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
|
||||||
|
|
||||||
cache_group = parser.add_mutually_exclusive_group()
|
cache_group = parser.add_mutually_exclusive_group()
|
||||||
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
|
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
|
||||||
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
|
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
|
||||||
@@ -134,7 +136,7 @@ parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Dis
|
|||||||
|
|
||||||
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
|
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
|
||||||
|
|
||||||
parser.add_argument("--verbose", action="store_true", help="Enables more debug prints.")
|
parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
|
||||||
|
|
||||||
# The default built-in provider hosted under web/
|
# The default built-in provider hosted under web/
|
||||||
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
|
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
|
||||||
@@ -169,6 +171,8 @@ parser.add_argument(
|
|||||||
help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.",
|
help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument("--user-directory", type=is_valid_directory, default=None, help="Set the ComfyUI user directory with an absolute path.")
|
||||||
|
|
||||||
if comfy.options.args_parsing:
|
if comfy.options.args_parsing:
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
else:
|
else:
|
||||||
@@ -179,10 +183,3 @@ if args.windows_standalone_build:
|
|||||||
|
|
||||||
if args.disable_auto_launch:
|
if args.disable_auto_launch:
|
||||||
args.auto_launch = False
|
args.auto_launch = False
|
||||||
|
|
||||||
import logging
|
|
||||||
logging_level = logging.INFO
|
|
||||||
if args.verbose:
|
|
||||||
logging_level = logging.DEBUG
|
|
||||||
|
|
||||||
logging.basicConfig(format="%(message)s", level=logging_level)
|
|
||||||
|
|||||||
@@ -109,8 +109,7 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
|
|||||||
keys = list(sd.keys())
|
keys = list(sd.keys())
|
||||||
for k in keys:
|
for k in keys:
|
||||||
if k not in u:
|
if k not in u:
|
||||||
t = sd.pop(k)
|
sd.pop(k)
|
||||||
del t
|
|
||||||
return clip
|
return clip
|
||||||
|
|
||||||
def load(ckpt_path):
|
def load(ckpt_path):
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ import comfy.t2i_adapter.adapter
|
|||||||
import comfy.ldm.cascade.controlnet
|
import comfy.ldm.cascade.controlnet
|
||||||
import comfy.cldm.mmdit
|
import comfy.cldm.mmdit
|
||||||
import comfy.ldm.hydit.controlnet
|
import comfy.ldm.hydit.controlnet
|
||||||
import comfy.ldm.flux.controlnet_xlabs
|
import comfy.ldm.flux.controlnet
|
||||||
|
|
||||||
|
|
||||||
def broadcast_image_to(tensor, target_batch_size, batched_number):
|
def broadcast_image_to(tensor, target_batch_size, batched_number):
|
||||||
@@ -79,13 +79,21 @@ class ControlBase:
|
|||||||
self.previous_controlnet = None
|
self.previous_controlnet = None
|
||||||
self.extra_conds = []
|
self.extra_conds = []
|
||||||
self.strength_type = StrengthType.CONSTANT
|
self.strength_type = StrengthType.CONSTANT
|
||||||
|
self.concat_mask = False
|
||||||
|
self.extra_concat_orig = []
|
||||||
|
self.extra_concat = None
|
||||||
|
|
||||||
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None):
|
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
|
||||||
self.cond_hint_original = cond_hint
|
self.cond_hint_original = cond_hint
|
||||||
self.strength = strength
|
self.strength = strength
|
||||||
self.timestep_percent_range = timestep_percent_range
|
self.timestep_percent_range = timestep_percent_range
|
||||||
if self.latent_format is not None:
|
if self.latent_format is not None:
|
||||||
|
if vae is None:
|
||||||
|
logging.warning("WARNING: no VAE provided to the controlnet apply node when this controlnet requires one.")
|
||||||
self.vae = vae
|
self.vae = vae
|
||||||
|
self.extra_concat_orig = extra_concat.copy()
|
||||||
|
if self.concat_mask and len(self.extra_concat_orig) == 0:
|
||||||
|
self.extra_concat_orig.append(torch.tensor([[[[1.0]]]]))
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def pre_run(self, model, percent_to_timestep_function):
|
def pre_run(self, model, percent_to_timestep_function):
|
||||||
@@ -100,9 +108,9 @@ class ControlBase:
|
|||||||
def cleanup(self):
|
def cleanup(self):
|
||||||
if self.previous_controlnet is not None:
|
if self.previous_controlnet is not None:
|
||||||
self.previous_controlnet.cleanup()
|
self.previous_controlnet.cleanup()
|
||||||
if self.cond_hint is not None:
|
|
||||||
del self.cond_hint
|
self.cond_hint = None
|
||||||
self.cond_hint = None
|
self.extra_concat = None
|
||||||
self.timestep_range = None
|
self.timestep_range = None
|
||||||
|
|
||||||
def get_models(self):
|
def get_models(self):
|
||||||
@@ -123,6 +131,8 @@ class ControlBase:
|
|||||||
c.vae = self.vae
|
c.vae = self.vae
|
||||||
c.extra_conds = self.extra_conds.copy()
|
c.extra_conds = self.extra_conds.copy()
|
||||||
c.strength_type = self.strength_type
|
c.strength_type = self.strength_type
|
||||||
|
c.concat_mask = self.concat_mask
|
||||||
|
c.extra_concat_orig = self.extra_concat_orig.copy()
|
||||||
|
|
||||||
def inference_memory_requirements(self, dtype):
|
def inference_memory_requirements(self, dtype):
|
||||||
if self.previous_controlnet is not None:
|
if self.previous_controlnet is not None:
|
||||||
@@ -148,7 +158,7 @@ class ControlBase:
|
|||||||
elif self.strength_type == StrengthType.LINEAR_UP:
|
elif self.strength_type == StrengthType.LINEAR_UP:
|
||||||
x *= (self.strength ** float(len(control_output) - i))
|
x *= (self.strength ** float(len(control_output) - i))
|
||||||
|
|
||||||
if x.dtype != output_dtype:
|
if output_dtype is not None and x.dtype != output_dtype:
|
||||||
x = x.to(output_dtype)
|
x = x.to(output_dtype)
|
||||||
|
|
||||||
out[key].append(x)
|
out[key].append(x)
|
||||||
@@ -175,7 +185,7 @@ class ControlBase:
|
|||||||
|
|
||||||
|
|
||||||
class ControlNet(ControlBase):
|
class ControlNet(ControlBase):
|
||||||
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT):
|
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False):
|
||||||
super().__init__(device)
|
super().__init__(device)
|
||||||
self.control_model = control_model
|
self.control_model = control_model
|
||||||
self.load_device = load_device
|
self.load_device = load_device
|
||||||
@@ -189,6 +199,7 @@ class ControlNet(ControlBase):
|
|||||||
self.latent_format = latent_format
|
self.latent_format = latent_format
|
||||||
self.extra_conds += extra_conds
|
self.extra_conds += extra_conds
|
||||||
self.strength_type = strength_type
|
self.strength_type = strength_type
|
||||||
|
self.concat_mask = concat_mask
|
||||||
|
|
||||||
def get_control(self, x_noisy, t, cond, batched_number):
|
def get_control(self, x_noisy, t, cond, batched_number):
|
||||||
control_prev = None
|
control_prev = None
|
||||||
@@ -206,7 +217,6 @@ class ControlNet(ControlBase):
|
|||||||
if self.manual_cast_dtype is not None:
|
if self.manual_cast_dtype is not None:
|
||||||
dtype = self.manual_cast_dtype
|
dtype = self.manual_cast_dtype
|
||||||
|
|
||||||
output_dtype = x_noisy.dtype
|
|
||||||
if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
|
if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
|
||||||
if self.cond_hint is not None:
|
if self.cond_hint is not None:
|
||||||
del self.cond_hint
|
del self.cond_hint
|
||||||
@@ -214,6 +224,9 @@ class ControlNet(ControlBase):
|
|||||||
compression_ratio = self.compression_ratio
|
compression_ratio = self.compression_ratio
|
||||||
if self.vae is not None:
|
if self.vae is not None:
|
||||||
compression_ratio *= self.vae.downscale_ratio
|
compression_ratio *= self.vae.downscale_ratio
|
||||||
|
else:
|
||||||
|
if self.latent_format is not None:
|
||||||
|
raise ValueError("This Controlnet needs a VAE but none was provided, please use a ControlNetApply node with a VAE input and connect it.")
|
||||||
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
|
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
|
||||||
if self.vae is not None:
|
if self.vae is not None:
|
||||||
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
|
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
|
||||||
@@ -221,6 +234,14 @@ class ControlNet(ControlBase):
|
|||||||
comfy.model_management.load_models_gpu(loaded_models)
|
comfy.model_management.load_models_gpu(loaded_models)
|
||||||
if self.latent_format is not None:
|
if self.latent_format is not None:
|
||||||
self.cond_hint = self.latent_format.process_in(self.cond_hint)
|
self.cond_hint = self.latent_format.process_in(self.cond_hint)
|
||||||
|
if len(self.extra_concat_orig) > 0:
|
||||||
|
to_concat = []
|
||||||
|
for c in self.extra_concat_orig:
|
||||||
|
c = c.to(self.cond_hint.device)
|
||||||
|
c = comfy.utils.common_upscale(c, self.cond_hint.shape[3], self.cond_hint.shape[2], self.upscale_algorithm, "center")
|
||||||
|
to_concat.append(comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[0]))
|
||||||
|
self.cond_hint = torch.cat([self.cond_hint] + to_concat, dim=1)
|
||||||
|
|
||||||
self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype)
|
self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype)
|
||||||
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
||||||
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
||||||
@@ -236,7 +257,7 @@ class ControlNet(ControlBase):
|
|||||||
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
|
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
|
||||||
|
|
||||||
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra)
|
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra)
|
||||||
return self.control_merge(control, control_prev, output_dtype)
|
return self.control_merge(control, control_prev, output_dtype=None)
|
||||||
|
|
||||||
def copy(self):
|
def copy(self):
|
||||||
c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
|
c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
|
||||||
@@ -320,7 +341,7 @@ class ControlLoraOps:
|
|||||||
|
|
||||||
|
|
||||||
class ControlLora(ControlNet):
|
class ControlLora(ControlNet):
|
||||||
def __init__(self, control_weights, global_average_pooling=False, device=None):
|
def __init__(self, control_weights, global_average_pooling=False, device=None, model_options={}): #TODO? model_options
|
||||||
ControlBase.__init__(self, device)
|
ControlBase.__init__(self, device)
|
||||||
self.control_weights = control_weights
|
self.control_weights = control_weights
|
||||||
self.global_average_pooling = global_average_pooling
|
self.global_average_pooling = global_average_pooling
|
||||||
@@ -377,21 +398,28 @@ class ControlLora(ControlNet):
|
|||||||
def inference_memory_requirements(self, dtype):
|
def inference_memory_requirements(self, dtype):
|
||||||
return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
|
return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
|
||||||
|
|
||||||
def controlnet_config(sd):
|
def controlnet_config(sd, model_options={}):
|
||||||
model_config = comfy.model_detection.model_config_from_unet(sd, "", True)
|
model_config = comfy.model_detection.model_config_from_unet(sd, "", True)
|
||||||
|
|
||||||
supported_inference_dtypes = model_config.supported_inference_dtypes
|
unet_dtype = model_options.get("dtype", None)
|
||||||
|
if unet_dtype is None:
|
||||||
|
weight_dtype = comfy.utils.weight_dtype(sd)
|
||||||
|
|
||||||
|
supported_inference_dtypes = list(model_config.supported_inference_dtypes)
|
||||||
|
if weight_dtype is not None:
|
||||||
|
supported_inference_dtypes.append(weight_dtype)
|
||||||
|
|
||||||
|
unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
|
||||||
|
|
||||||
controlnet_config = model_config.unet_config
|
|
||||||
unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
|
|
||||||
load_device = comfy.model_management.get_torch_device()
|
load_device = comfy.model_management.get_torch_device()
|
||||||
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
||||||
if manual_cast_dtype is not None:
|
|
||||||
operations = comfy.ops.manual_cast
|
|
||||||
else:
|
|
||||||
operations = comfy.ops.disable_weight_init
|
|
||||||
|
|
||||||
return model_config, operations, load_device, unet_dtype, manual_cast_dtype
|
operations = model_options.get("custom_operations", None)
|
||||||
|
if operations is None:
|
||||||
|
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
|
||||||
|
|
||||||
|
offload_device = comfy.model_management.unet_offload_device()
|
||||||
|
return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device
|
||||||
|
|
||||||
def controlnet_load_state_dict(control_model, sd):
|
def controlnet_load_state_dict(control_model, sd):
|
||||||
missing, unexpected = control_model.load_state_dict(sd, strict=False)
|
missing, unexpected = control_model.load_state_dict(sd, strict=False)
|
||||||
@@ -403,26 +431,31 @@ def controlnet_load_state_dict(control_model, sd):
|
|||||||
logging.debug("unexpected controlnet keys: {}".format(unexpected))
|
logging.debug("unexpected controlnet keys: {}".format(unexpected))
|
||||||
return control_model
|
return control_model
|
||||||
|
|
||||||
def load_controlnet_mmdit(sd):
|
def load_controlnet_mmdit(sd, model_options={}):
|
||||||
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
|
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
|
||||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(new_sd)
|
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options)
|
||||||
num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
|
num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
|
||||||
for k in sd:
|
for k in sd:
|
||||||
new_sd[k] = sd[k]
|
new_sd[k] = sd[k]
|
||||||
|
|
||||||
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config)
|
concat_mask = False
|
||||||
|
control_latent_channels = new_sd.get("pos_embed_input.proj.weight").shape[1]
|
||||||
|
if control_latent_channels == 17: #inpaint controlnet
|
||||||
|
concat_mask = True
|
||||||
|
|
||||||
|
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, control_latent_channels=control_latent_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||||
control_model = controlnet_load_state_dict(control_model, new_sd)
|
control_model = controlnet_load_state_dict(control_model, new_sd)
|
||||||
|
|
||||||
latent_format = comfy.latent_formats.SD3()
|
latent_format = comfy.latent_formats.SD3()
|
||||||
latent_format.shift_factor = 0 #SD3 controlnet weirdness
|
latent_format.shift_factor = 0 #SD3 controlnet weirdness
|
||||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
||||||
return control
|
return control
|
||||||
|
|
||||||
|
|
||||||
def load_controlnet_hunyuandit(controlnet_data):
|
def load_controlnet_hunyuandit(controlnet_data, model_options={}):
|
||||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(controlnet_data)
|
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data, model_options=model_options)
|
||||||
|
|
||||||
control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=load_device, dtype=unet_dtype)
|
control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=offload_device, dtype=unet_dtype)
|
||||||
control_model = controlnet_load_state_dict(control_model, controlnet_data)
|
control_model = controlnet_load_state_dict(control_model, controlnet_data)
|
||||||
|
|
||||||
latent_format = comfy.latent_formats.SDXL()
|
latent_format = comfy.latent_formats.SDXL()
|
||||||
@@ -430,22 +463,49 @@ def load_controlnet_hunyuandit(controlnet_data):
|
|||||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT)
|
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT)
|
||||||
return control
|
return control
|
||||||
|
|
||||||
def load_controlnet_flux_xlabs(sd):
|
def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False, model_options={}):
|
||||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(sd)
|
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
|
||||||
control_model = comfy.ldm.flux.controlnet_xlabs.ControlNetFlux(operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config)
|
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||||
control_model = controlnet_load_state_dict(control_model, sd)
|
control_model = controlnet_load_state_dict(control_model, sd)
|
||||||
extra_conds = ['y', 'guidance']
|
extra_conds = ['y', 'guidance']
|
||||||
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||||
return control
|
return control
|
||||||
|
|
||||||
|
def load_controlnet_flux_instantx(sd, model_options={}):
|
||||||
|
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
|
||||||
|
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options)
|
||||||
|
for k in sd:
|
||||||
|
new_sd[k] = sd[k]
|
||||||
|
|
||||||
def load_controlnet(ckpt_path, model=None):
|
num_union_modes = 0
|
||||||
controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
|
union_cnet = "controlnet_mode_embedder.weight"
|
||||||
|
if union_cnet in new_sd:
|
||||||
|
num_union_modes = new_sd[union_cnet].shape[0]
|
||||||
|
|
||||||
|
control_latent_channels = new_sd.get("pos_embed_input.weight").shape[1] // 4
|
||||||
|
concat_mask = False
|
||||||
|
if control_latent_channels == 17:
|
||||||
|
concat_mask = True
|
||||||
|
|
||||||
|
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(latent_input=True, num_union_modes=num_union_modes, control_latent_channels=control_latent_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||||
|
control_model = controlnet_load_state_dict(control_model, new_sd)
|
||||||
|
|
||||||
|
latent_format = comfy.latent_formats.Flux()
|
||||||
|
extra_conds = ['y', 'guidance']
|
||||||
|
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||||
|
return control
|
||||||
|
|
||||||
|
def convert_mistoline(sd):
|
||||||
|
return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
|
||||||
|
|
||||||
|
|
||||||
|
def load_controlnet_state_dict(state_dict, model=None, model_options={}):
|
||||||
|
controlnet_data = state_dict
|
||||||
if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT
|
if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT
|
||||||
return load_controlnet_hunyuandit(controlnet_data)
|
return load_controlnet_hunyuandit(controlnet_data, model_options=model_options)
|
||||||
|
|
||||||
if "lora_controlnet" in controlnet_data:
|
if "lora_controlnet" in controlnet_data:
|
||||||
return ControlLora(controlnet_data)
|
return ControlLora(controlnet_data, model_options=model_options)
|
||||||
|
|
||||||
controlnet_config = None
|
controlnet_config = None
|
||||||
supported_inference_dtypes = None
|
supported_inference_dtypes = None
|
||||||
@@ -500,11 +560,15 @@ def load_controlnet(ckpt_path, model=None):
|
|||||||
if len(leftover_keys) > 0:
|
if len(leftover_keys) > 0:
|
||||||
logging.warning("leftover keys: {}".format(leftover_keys))
|
logging.warning("leftover keys: {}".format(leftover_keys))
|
||||||
controlnet_data = new_sd
|
controlnet_data = new_sd
|
||||||
elif "controlnet_blocks.0.weight" in controlnet_data: #SD3 diffusers format
|
elif "controlnet_blocks.0.weight" in controlnet_data:
|
||||||
if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
|
if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
|
||||||
return load_controlnet_flux_xlabs(controlnet_data)
|
return load_controlnet_flux_xlabs_mistoline(controlnet_data, model_options=model_options)
|
||||||
else:
|
elif "pos_embed_input.proj.weight" in controlnet_data:
|
||||||
return load_controlnet_mmdit(controlnet_data)
|
return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet
|
||||||
|
elif "controlnet_x_embedder.weight" in controlnet_data:
|
||||||
|
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
|
||||||
|
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
|
||||||
|
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)
|
||||||
|
|
||||||
pth_key = 'control_model.zero_convs.0.0.weight'
|
pth_key = 'control_model.zero_convs.0.0.weight'
|
||||||
pth = False
|
pth = False
|
||||||
@@ -516,26 +580,38 @@ def load_controlnet(ckpt_path, model=None):
|
|||||||
elif key in controlnet_data:
|
elif key in controlnet_data:
|
||||||
prefix = ""
|
prefix = ""
|
||||||
else:
|
else:
|
||||||
net = load_t2i_adapter(controlnet_data)
|
net = load_t2i_adapter(controlnet_data, model_options=model_options)
|
||||||
if net is None:
|
if net is None:
|
||||||
logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
|
logging.error("error could not detect control model type.")
|
||||||
return net
|
return net
|
||||||
|
|
||||||
if controlnet_config is None:
|
if controlnet_config is None:
|
||||||
model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True)
|
model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True)
|
||||||
supported_inference_dtypes = model_config.supported_inference_dtypes
|
supported_inference_dtypes = list(model_config.supported_inference_dtypes)
|
||||||
controlnet_config = model_config.unet_config
|
controlnet_config = model_config.unet_config
|
||||||
|
|
||||||
|
unet_dtype = model_options.get("dtype", None)
|
||||||
|
if unet_dtype is None:
|
||||||
|
weight_dtype = comfy.utils.weight_dtype(controlnet_data)
|
||||||
|
|
||||||
|
if supported_inference_dtypes is None:
|
||||||
|
supported_inference_dtypes = [comfy.model_management.unet_dtype()]
|
||||||
|
|
||||||
|
if weight_dtype is not None:
|
||||||
|
supported_inference_dtypes.append(weight_dtype)
|
||||||
|
|
||||||
|
unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
|
||||||
|
|
||||||
load_device = comfy.model_management.get_torch_device()
|
load_device = comfy.model_management.get_torch_device()
|
||||||
if supported_inference_dtypes is None:
|
|
||||||
unet_dtype = comfy.model_management.unet_dtype()
|
|
||||||
else:
|
|
||||||
unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
|
|
||||||
|
|
||||||
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
||||||
if manual_cast_dtype is not None:
|
operations = model_options.get("custom_operations", None)
|
||||||
controlnet_config["operations"] = comfy.ops.manual_cast
|
if operations is None:
|
||||||
|
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype)
|
||||||
|
|
||||||
|
controlnet_config["operations"] = operations
|
||||||
controlnet_config["dtype"] = unet_dtype
|
controlnet_config["dtype"] = unet_dtype
|
||||||
|
controlnet_config["device"] = comfy.model_management.unet_offload_device()
|
||||||
controlnet_config.pop("out_channels")
|
controlnet_config.pop("out_channels")
|
||||||
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
|
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
|
||||||
control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
|
control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
|
||||||
@@ -569,14 +645,21 @@ def load_controlnet(ckpt_path, model=None):
|
|||||||
if len(unexpected) > 0:
|
if len(unexpected) > 0:
|
||||||
logging.debug("unexpected controlnet keys: {}".format(unexpected))
|
logging.debug("unexpected controlnet keys: {}".format(unexpected))
|
||||||
|
|
||||||
global_average_pooling = False
|
global_average_pooling = model_options.get("global_average_pooling", False)
|
||||||
filename = os.path.splitext(ckpt_path)[0]
|
|
||||||
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
|
|
||||||
global_average_pooling = True
|
|
||||||
|
|
||||||
control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
||||||
return control
|
return control
|
||||||
|
|
||||||
|
def load_controlnet(ckpt_path, model=None, model_options={}):
|
||||||
|
if "global_average_pooling" not in model_options:
|
||||||
|
filename = os.path.splitext(ckpt_path)[0]
|
||||||
|
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
|
||||||
|
model_options["global_average_pooling"] = True
|
||||||
|
|
||||||
|
cnet = load_controlnet_state_dict(comfy.utils.load_torch_file(ckpt_path, safe_load=True), model=model, model_options=model_options)
|
||||||
|
if cnet is None:
|
||||||
|
logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
|
||||||
|
return cnet
|
||||||
|
|
||||||
class T2IAdapter(ControlBase):
|
class T2IAdapter(ControlBase):
|
||||||
def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None):
|
def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None):
|
||||||
super().__init__(device)
|
super().__init__(device)
|
||||||
@@ -632,7 +715,7 @@ class T2IAdapter(ControlBase):
|
|||||||
self.copy_to(c)
|
self.copy_to(c)
|
||||||
return c
|
return c
|
||||||
|
|
||||||
def load_t2i_adapter(t2i_data):
|
def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
|
||||||
compression_ratio = 8
|
compression_ratio = 8
|
||||||
upscale_algorithm = 'nearest-exact'
|
upscale_algorithm = 'nearest-exact'
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,17 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None):
|
||||||
|
mantissa_scaled = torch.where(
|
||||||
|
normal_mask,
|
||||||
|
(abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0) * (2**MANTISSA_BITS),
|
||||||
|
(abs_x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS)))
|
||||||
|
)
|
||||||
|
|
||||||
|
mantissa_scaled += torch.rand(mantissa_scaled.size(), dtype=mantissa_scaled.dtype, layout=mantissa_scaled.layout, device=mantissa_scaled.device, generator=generator)
|
||||||
|
return mantissa_scaled.floor() / (2**MANTISSA_BITS)
|
||||||
|
|
||||||
#Not 100% sure about this
|
#Not 100% sure about this
|
||||||
def manual_stochastic_round_to_float8(x, dtype):
|
def manual_stochastic_round_to_float8(x, dtype, generator=None):
|
||||||
if dtype == torch.float8_e4m3fn:
|
if dtype == torch.float8_e4m3fn:
|
||||||
EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 4, 3, 7
|
EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 4, 3, 7
|
||||||
elif dtype == torch.float8_e5m2:
|
elif dtype == torch.float8_e5m2:
|
||||||
@@ -9,44 +19,35 @@ def manual_stochastic_round_to_float8(x, dtype):
|
|||||||
else:
|
else:
|
||||||
raise ValueError("Unsupported dtype")
|
raise ValueError("Unsupported dtype")
|
||||||
|
|
||||||
|
x = x.half()
|
||||||
sign = torch.sign(x)
|
sign = torch.sign(x)
|
||||||
abs_x = x.abs()
|
abs_x = x.abs()
|
||||||
|
sign = torch.where(abs_x == 0, 0, sign)
|
||||||
|
|
||||||
# Combine exponent calculation and clamping
|
# Combine exponent calculation and clamping
|
||||||
exponent = torch.clamp(
|
exponent = torch.clamp(
|
||||||
torch.floor(torch.log2(abs_x)).to(torch.int32) + EXPONENT_BIAS,
|
torch.floor(torch.log2(abs_x)) + EXPONENT_BIAS,
|
||||||
0, 2**EXPONENT_BITS - 1
|
0, 2**EXPONENT_BITS - 1
|
||||||
)
|
)
|
||||||
|
|
||||||
# Combine mantissa calculation and rounding
|
# Combine mantissa calculation and rounding
|
||||||
# min_normal = 2.0 ** (-EXPONENT_BIAS + 1)
|
|
||||||
# zero_mask = (abs_x == 0)
|
|
||||||
# subnormal_mask = (exponent == 0) & (abs_x != 0)
|
|
||||||
normal_mask = ~(exponent == 0)
|
normal_mask = ~(exponent == 0)
|
||||||
|
|
||||||
mantissa_scaled = torch.where(
|
abs_x[:] = calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=generator)
|
||||||
|
|
||||||
|
sign *= torch.where(
|
||||||
normal_mask,
|
normal_mask,
|
||||||
(abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0) * (2**MANTISSA_BITS),
|
(2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + abs_x),
|
||||||
(abs_x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS)))
|
(2.0 ** (-EXPONENT_BIAS + 1)) * abs_x
|
||||||
)
|
|
||||||
mantissa_floor = mantissa_scaled.floor()
|
|
||||||
mantissa = torch.where(
|
|
||||||
torch.rand_like(mantissa_scaled) < (mantissa_scaled - mantissa_floor),
|
|
||||||
(mantissa_floor + 1) / (2**MANTISSA_BITS),
|
|
||||||
mantissa_floor / (2**MANTISSA_BITS)
|
|
||||||
)
|
|
||||||
result = torch.where(
|
|
||||||
normal_mask,
|
|
||||||
sign * (2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + mantissa),
|
|
||||||
sign * (2.0 ** (-EXPONENT_BIAS + 1)) * mantissa
|
|
||||||
)
|
)
|
||||||
|
|
||||||
result = torch.where(abs_x == 0, 0, result)
|
inf = torch.finfo(dtype)
|
||||||
return result.to(dtype=dtype)
|
torch.clamp(sign, min=inf.min, max=inf.max, out=sign)
|
||||||
|
return sign
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def stochastic_rounding(value, dtype):
|
def stochastic_rounding(value, dtype, seed=0):
|
||||||
if dtype == torch.float32:
|
if dtype == torch.float32:
|
||||||
return value.to(dtype=torch.float32)
|
return value.to(dtype=torch.float32)
|
||||||
if dtype == torch.float16:
|
if dtype == torch.float16:
|
||||||
@@ -54,6 +55,13 @@ def stochastic_rounding(value, dtype):
|
|||||||
if dtype == torch.bfloat16:
|
if dtype == torch.bfloat16:
|
||||||
return value.to(dtype=torch.bfloat16)
|
return value.to(dtype=torch.bfloat16)
|
||||||
if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
|
if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
|
||||||
return manual_stochastic_round_to_float8(value, dtype)
|
generator = torch.Generator(device=value.device)
|
||||||
|
generator.manual_seed(seed)
|
||||||
|
output = torch.empty_like(value, dtype=dtype)
|
||||||
|
num_slices = max(1, (value.numel() / (4096 * 4096)))
|
||||||
|
slice_size = max(1, round(value.shape[0] / num_slices))
|
||||||
|
for i in range(0, value.shape[0], slice_size):
|
||||||
|
output[i:i+slice_size].copy_(manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator))
|
||||||
|
return output
|
||||||
|
|
||||||
return value.to(dtype=dtype)
|
return value.to(dtype=dtype)
|
||||||
|
|||||||
@@ -44,6 +44,17 @@ def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
|
|||||||
return append_zero(sigmas)
|
return append_zero(sigmas)
|
||||||
|
|
||||||
|
|
||||||
|
def get_sigmas_laplace(n, sigma_min, sigma_max, mu=0., beta=0.5, device='cpu'):
|
||||||
|
"""Constructs the noise schedule proposed by Tiankai et al. (2024). """
|
||||||
|
epsilon = 1e-5 # avoid log(0)
|
||||||
|
x = torch.linspace(0, 1, n, device=device)
|
||||||
|
clamp = lambda x: torch.clamp(x, min=sigma_min, max=sigma_max)
|
||||||
|
lmb = mu - beta * torch.sign(0.5-x) * torch.log(1 - 2 * torch.abs(0.5-x) + epsilon)
|
||||||
|
sigmas = clamp(torch.exp(lmb))
|
||||||
|
return sigmas
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def to_d(x, sigma, denoised):
|
def to_d(x, sigma, denoised):
|
||||||
"""Converts a denoiser output to a Karras ODE derivative."""
|
"""Converts a denoiser output to a Karras ODE derivative."""
|
||||||
return (x - denoised) / utils.append_dims(sigma, x.ndim)
|
return (x - denoised) / utils.append_dims(sigma, x.ndim)
|
||||||
@@ -1069,7 +1080,6 @@ def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disabl
|
|||||||
d = to_d(x, sigma_hat, temp[0])
|
d = to_d(x, sigma_hat, temp[0])
|
||||||
if callback is not None:
|
if callback is not None:
|
||||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
|
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
|
||||||
dt = sigmas[i + 1] - sigma_hat
|
|
||||||
# Euler method
|
# Euler method
|
||||||
x = denoised + d * sigmas[i + 1]
|
x = denoised + d * sigmas[i + 1]
|
||||||
return x
|
return x
|
||||||
@@ -1096,8 +1106,81 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No
|
|||||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||||
d = to_d(x, sigmas[i], temp[0])
|
d = to_d(x, sigmas[i], temp[0])
|
||||||
# Euler method
|
# Euler method
|
||||||
dt = sigma_down - sigmas[i]
|
|
||||||
x = denoised + d * sigma_down
|
x = denoised + d * sigma_down
|
||||||
if sigmas[i + 1] > 0:
|
if sigmas[i + 1] > 0:
|
||||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
||||||
return x
|
return x
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||||
|
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
|
||||||
|
extra_args = {} if extra_args is None else extra_args
|
||||||
|
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||||
|
|
||||||
|
temp = [0]
|
||||||
|
def post_cfg_function(args):
|
||||||
|
temp[0] = args["uncond_denoised"]
|
||||||
|
return args["denoised"]
|
||||||
|
|
||||||
|
model_options = extra_args.get("model_options", {}).copy()
|
||||||
|
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
|
||||||
|
|
||||||
|
s_in = x.new_ones([x.shape[0]])
|
||||||
|
sigma_fn = lambda t: t.neg().exp()
|
||||||
|
t_fn = lambda sigma: sigma.log().neg()
|
||||||
|
|
||||||
|
for i in trange(len(sigmas) - 1, disable=disable):
|
||||||
|
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||||
|
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
|
||||||
|
if callback is not None:
|
||||||
|
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||||
|
if sigma_down == 0:
|
||||||
|
# Euler method
|
||||||
|
d = to_d(x, sigmas[i], temp[0])
|
||||||
|
x = denoised + d * sigma_down
|
||||||
|
else:
|
||||||
|
# DPM-Solver++(2S)
|
||||||
|
t, t_next = t_fn(sigmas[i]), t_fn(sigma_down)
|
||||||
|
# r = torch.sinh(1 + (2 - eta) * (t_next - t) / (t - t_fn(sigma_up))) works only on non-cfgpp, weird
|
||||||
|
r = 1 / 2
|
||||||
|
h = t_next - t
|
||||||
|
s = t + r * h
|
||||||
|
x_2 = (sigma_fn(s) / sigma_fn(t)) * (x + (denoised - temp[0])) - (-h * r).expm1() * denoised
|
||||||
|
denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
|
||||||
|
x = (sigma_fn(t_next) / sigma_fn(t)) * (x + (denoised - temp[0])) - (-h).expm1() * denoised_2
|
||||||
|
# Noise addition
|
||||||
|
if sigmas[i + 1] > 0:
|
||||||
|
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
||||||
|
return x
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample_dpmpp_2m_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
|
||||||
|
"""DPM-Solver++(2M)."""
|
||||||
|
extra_args = {} if extra_args is None else extra_args
|
||||||
|
s_in = x.new_ones([x.shape[0]])
|
||||||
|
t_fn = lambda sigma: sigma.log().neg()
|
||||||
|
|
||||||
|
old_uncond_denoised = None
|
||||||
|
uncond_denoised = None
|
||||||
|
def post_cfg_function(args):
|
||||||
|
nonlocal uncond_denoised
|
||||||
|
uncond_denoised = args["uncond_denoised"]
|
||||||
|
return args["denoised"]
|
||||||
|
|
||||||
|
model_options = extra_args.get("model_options", {}).copy()
|
||||||
|
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
|
||||||
|
|
||||||
|
for i in trange(len(sigmas) - 1, disable=disable):
|
||||||
|
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||||
|
if callback is not None:
|
||||||
|
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||||
|
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
|
||||||
|
h = t_next - t
|
||||||
|
if old_uncond_denoised is None or sigmas[i + 1] == 0:
|
||||||
|
denoised_mix = -torch.exp(-h) * uncond_denoised
|
||||||
|
else:
|
||||||
|
h_last = t - t_fn(sigmas[i - 1])
|
||||||
|
r = h_last / h
|
||||||
|
denoised_mix = -torch.exp(-h) * uncond_denoised - torch.expm1(-h) * (1 / (2 * r)) * (denoised - old_uncond_denoised)
|
||||||
|
x = denoised + denoised_mix + torch.exp(-h) * x
|
||||||
|
old_uncond_denoised = uncond_denoised
|
||||||
|
return x
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ class LatentFormat:
|
|||||||
scale_factor = 1.0
|
scale_factor = 1.0
|
||||||
latent_channels = 4
|
latent_channels = 4
|
||||||
latent_rgb_factors = None
|
latent_rgb_factors = None
|
||||||
|
latent_rgb_factors_bias = None
|
||||||
taesd_decoder_name = None
|
taesd_decoder_name = None
|
||||||
|
|
||||||
def process_in(self, latent):
|
def process_in(self, latent):
|
||||||
@@ -30,11 +31,13 @@ class SDXL(LatentFormat):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.latent_rgb_factors = [
|
self.latent_rgb_factors = [
|
||||||
# R G B
|
# R G B
|
||||||
[ 0.3920, 0.4054, 0.4549],
|
[ 0.3651, 0.4232, 0.4341],
|
||||||
[-0.2634, -0.0196, 0.0653],
|
[-0.2533, -0.0042, 0.1068],
|
||||||
[ 0.0568, 0.1687, -0.0755],
|
[ 0.1076, 0.1111, -0.0362],
|
||||||
[-0.3112, -0.2359, -0.2076]
|
[-0.3165, -0.2492, -0.2188]
|
||||||
]
|
]
|
||||||
|
self.latent_rgb_factors_bias = [ 0.1084, -0.0175, -0.0011]
|
||||||
|
|
||||||
self.taesd_decoder_name = "taesdxl_decoder"
|
self.taesd_decoder_name = "taesdxl_decoder"
|
||||||
|
|
||||||
class SDXL_Playground_2_5(LatentFormat):
|
class SDXL_Playground_2_5(LatentFormat):
|
||||||
@@ -112,23 +115,24 @@ class SD3(LatentFormat):
|
|||||||
self.scale_factor = 1.5305
|
self.scale_factor = 1.5305
|
||||||
self.shift_factor = 0.0609
|
self.shift_factor = 0.0609
|
||||||
self.latent_rgb_factors = [
|
self.latent_rgb_factors = [
|
||||||
[-0.0645, 0.0177, 0.1052],
|
[-0.0922, -0.0175, 0.0749],
|
||||||
[ 0.0028, 0.0312, 0.0650],
|
[ 0.0311, 0.0633, 0.0954],
|
||||||
[ 0.1848, 0.0762, 0.0360],
|
[ 0.1994, 0.0927, 0.0458],
|
||||||
[ 0.0944, 0.0360, 0.0889],
|
[ 0.0856, 0.0339, 0.0902],
|
||||||
[ 0.0897, 0.0506, -0.0364],
|
[ 0.0587, 0.0272, -0.0496],
|
||||||
[-0.0020, 0.1203, 0.0284],
|
[-0.0006, 0.1104, 0.0309],
|
||||||
[ 0.0855, 0.0118, 0.0283],
|
[ 0.0978, 0.0306, 0.0427],
|
||||||
[-0.0539, 0.0658, 0.1047],
|
[-0.0042, 0.1038, 0.1358],
|
||||||
[-0.0057, 0.0116, 0.0700],
|
[-0.0194, 0.0020, 0.0669],
|
||||||
[-0.0412, 0.0281, -0.0039],
|
[-0.0488, 0.0130, -0.0268],
|
||||||
[ 0.1106, 0.1171, 0.1220],
|
[ 0.0922, 0.0988, 0.0951],
|
||||||
[-0.0248, 0.0682, -0.0481],
|
[-0.0278, 0.0524, -0.0542],
|
||||||
[ 0.0815, 0.0846, 0.1207],
|
[ 0.0332, 0.0456, 0.0895],
|
||||||
[-0.0120, -0.0055, -0.0867],
|
[-0.0069, -0.0030, -0.0810],
|
||||||
[-0.0749, -0.0634, -0.0456],
|
[-0.0596, -0.0465, -0.0293],
|
||||||
[-0.1418, -0.1457, -0.1259]
|
[-0.1448, -0.1463, -0.1189]
|
||||||
]
|
]
|
||||||
|
self.latent_rgb_factors_bias = [0.2394, 0.2135, 0.1925]
|
||||||
self.taesd_decoder_name = "taesd3_decoder"
|
self.taesd_decoder_name = "taesd3_decoder"
|
||||||
|
|
||||||
def process_in(self, latent):
|
def process_in(self, latent):
|
||||||
@@ -146,23 +150,24 @@ class Flux(SD3):
|
|||||||
self.scale_factor = 0.3611
|
self.scale_factor = 0.3611
|
||||||
self.shift_factor = 0.1159
|
self.shift_factor = 0.1159
|
||||||
self.latent_rgb_factors =[
|
self.latent_rgb_factors =[
|
||||||
[-0.0404, 0.0159, 0.0609],
|
[-0.0346, 0.0244, 0.0681],
|
||||||
[ 0.0043, 0.0298, 0.0850],
|
[ 0.0034, 0.0210, 0.0687],
|
||||||
[ 0.0328, -0.0749, -0.0503],
|
[ 0.0275, -0.0668, -0.0433],
|
||||||
[-0.0245, 0.0085, 0.0549],
|
[-0.0174, 0.0160, 0.0617],
|
||||||
[ 0.0966, 0.0894, 0.0530],
|
[ 0.0859, 0.0721, 0.0329],
|
||||||
[ 0.0035, 0.0399, 0.0123],
|
[ 0.0004, 0.0383, 0.0115],
|
||||||
[ 0.0583, 0.1184, 0.1262],
|
[ 0.0405, 0.0861, 0.0915],
|
||||||
[-0.0191, -0.0206, -0.0306],
|
[-0.0236, -0.0185, -0.0259],
|
||||||
[-0.0324, 0.0055, 0.1001],
|
[-0.0245, 0.0250, 0.1180],
|
||||||
[ 0.0955, 0.0659, -0.0545],
|
[ 0.1008, 0.0755, -0.0421],
|
||||||
[-0.0504, 0.0231, -0.0013],
|
[-0.0515, 0.0201, 0.0011],
|
||||||
[ 0.0500, -0.0008, -0.0088],
|
[ 0.0428, -0.0012, -0.0036],
|
||||||
[ 0.0982, 0.0941, 0.0976],
|
[ 0.0817, 0.0765, 0.0749],
|
||||||
[-0.1233, -0.0280, -0.0897],
|
[-0.1264, -0.0522, -0.1103],
|
||||||
[-0.0005, -0.0530, -0.0020],
|
[-0.0280, -0.0881, -0.0499],
|
||||||
[-0.1273, -0.0932, -0.0680]
|
[-0.1262, -0.0982, -0.0778]
|
||||||
]
|
]
|
||||||
|
self.latent_rgb_factors_bias = [-0.0329, -0.0718, -0.0851]
|
||||||
self.taesd_decoder_name = "taef1_decoder"
|
self.taesd_decoder_name = "taef1_decoder"
|
||||||
|
|
||||||
def process_in(self, latent):
|
def process_in(self, latent):
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import torch
|
import torch
|
||||||
|
import comfy.ops
|
||||||
|
|
||||||
def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
|
def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
|
||||||
if padding_mode == "circular" and torch.jit.is_tracing() or torch.jit.is_scripting():
|
if padding_mode == "circular" and torch.jit.is_tracing() or torch.jit.is_scripting():
|
||||||
@@ -6,3 +7,15 @@ def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
|
|||||||
pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0]
|
pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0]
|
||||||
pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1]
|
pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1]
|
||||||
return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode)
|
return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode)
|
||||||
|
|
||||||
|
try:
|
||||||
|
rms_norm_torch = torch.nn.functional.rms_norm
|
||||||
|
except:
|
||||||
|
rms_norm_torch = None
|
||||||
|
|
||||||
|
def rms_norm(x, weight, eps=1e-6):
|
||||||
|
if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
|
||||||
|
return rms_norm_torch(x, weight.shape, weight=comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
|
||||||
|
else:
|
||||||
|
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
|
||||||
|
return (x * rrms) * comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device)
|
||||||
|
|||||||
205
comfy/ldm/flux/controlnet.py
Normal file
205
comfy/ldm/flux/controlnet.py
Normal file
@@ -0,0 +1,205 @@
|
|||||||
|
#Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
|
||||||
|
#modified to support different types of flux controlnets
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import math
|
||||||
|
from torch import Tensor, nn
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
|
||||||
|
from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
|
||||||
|
MLPEmbedder, SingleStreamBlock,
|
||||||
|
timestep_embedding)
|
||||||
|
|
||||||
|
from .model import Flux
|
||||||
|
import comfy.ldm.common_dit
|
||||||
|
|
||||||
|
class MistolineCondDownsamplBlock(nn.Module):
|
||||||
|
def __init__(self, dtype=None, device=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
self.encoder = nn.Sequential(
|
||||||
|
operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 1, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 1, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.encoder(x)
|
||||||
|
|
||||||
|
class MistolineControlnetBlock(nn.Module):
|
||||||
|
def __init__(self, hidden_size, dtype=None, device=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
self.linear = operations.Linear(hidden_size, hidden_size, dtype=dtype, device=device)
|
||||||
|
self.act = nn.SiLU()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.act(self.linear(x))
|
||||||
|
|
||||||
|
|
||||||
|
class ControlNetFlux(Flux):
|
||||||
|
def __init__(self, latent_input=False, num_union_modes=0, mistoline=False, control_latent_channels=None, image_model=None, dtype=None, device=None, operations=None, **kwargs):
|
||||||
|
super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
|
||||||
|
|
||||||
|
self.main_model_double = 19
|
||||||
|
self.main_model_single = 38
|
||||||
|
|
||||||
|
self.mistoline = mistoline
|
||||||
|
# add ControlNet blocks
|
||||||
|
if self.mistoline:
|
||||||
|
control_block = lambda : MistolineControlnetBlock(self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||||
|
else:
|
||||||
|
control_block = lambda : operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
self.controlnet_blocks = nn.ModuleList([])
|
||||||
|
for _ in range(self.params.depth):
|
||||||
|
self.controlnet_blocks.append(control_block())
|
||||||
|
|
||||||
|
self.controlnet_single_blocks = nn.ModuleList([])
|
||||||
|
for _ in range(self.params.depth_single_blocks):
|
||||||
|
self.controlnet_single_blocks.append(control_block())
|
||||||
|
|
||||||
|
self.num_union_modes = num_union_modes
|
||||||
|
self.controlnet_mode_embedder = None
|
||||||
|
if self.num_union_modes > 0:
|
||||||
|
self.controlnet_mode_embedder = operations.Embedding(self.num_union_modes, self.hidden_size, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
self.gradient_checkpointing = False
|
||||||
|
self.latent_input = latent_input
|
||||||
|
if control_latent_channels is None:
|
||||||
|
control_latent_channels = self.in_channels
|
||||||
|
else:
|
||||||
|
control_latent_channels *= 2 * 2 #patch size
|
||||||
|
|
||||||
|
self.pos_embed_input = operations.Linear(control_latent_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
|
||||||
|
if not self.latent_input:
|
||||||
|
if self.mistoline:
|
||||||
|
self.input_cond_block = MistolineCondDownsamplBlock(dtype=dtype, device=device, operations=operations)
|
||||||
|
else:
|
||||||
|
self.input_hint_block = nn.Sequential(
|
||||||
|
operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward_orig(
|
||||||
|
self,
|
||||||
|
img: Tensor,
|
||||||
|
img_ids: Tensor,
|
||||||
|
controlnet_cond: Tensor,
|
||||||
|
txt: Tensor,
|
||||||
|
txt_ids: Tensor,
|
||||||
|
timesteps: Tensor,
|
||||||
|
y: Tensor,
|
||||||
|
guidance: Tensor = None,
|
||||||
|
control_type: Tensor = None,
|
||||||
|
) -> Tensor:
|
||||||
|
if img.ndim != 3 or txt.ndim != 3:
|
||||||
|
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||||
|
|
||||||
|
# running on sequences img
|
||||||
|
img = self.img_in(img)
|
||||||
|
|
||||||
|
controlnet_cond = self.pos_embed_input(controlnet_cond)
|
||||||
|
img = img + controlnet_cond
|
||||||
|
vec = self.time_in(timestep_embedding(timesteps, 256))
|
||||||
|
if self.params.guidance_embed:
|
||||||
|
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
|
||||||
|
vec = vec + self.vector_in(y)
|
||||||
|
txt = self.txt_in(txt)
|
||||||
|
|
||||||
|
if self.controlnet_mode_embedder is not None and len(control_type) > 0:
|
||||||
|
control_cond = self.controlnet_mode_embedder(torch.tensor(control_type, device=img.device), out_dtype=img.dtype).unsqueeze(0).repeat((txt.shape[0], 1, 1))
|
||||||
|
txt = torch.cat([control_cond, txt], dim=1)
|
||||||
|
txt_ids = torch.cat([txt_ids[:,:1], txt_ids], dim=1)
|
||||||
|
|
||||||
|
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||||
|
pe = self.pe_embedder(ids)
|
||||||
|
|
||||||
|
controlnet_double = ()
|
||||||
|
|
||||||
|
for i in range(len(self.double_blocks)):
|
||||||
|
img, txt = self.double_blocks[i](img=img, txt=txt, vec=vec, pe=pe)
|
||||||
|
controlnet_double = controlnet_double + (self.controlnet_blocks[i](img),)
|
||||||
|
|
||||||
|
img = torch.cat((txt, img), 1)
|
||||||
|
|
||||||
|
controlnet_single = ()
|
||||||
|
|
||||||
|
for i in range(len(self.single_blocks)):
|
||||||
|
img = self.single_blocks[i](img, vec=vec, pe=pe)
|
||||||
|
controlnet_single = controlnet_single + (self.controlnet_single_blocks[i](img[:, txt.shape[1] :, ...]),)
|
||||||
|
|
||||||
|
repeat = math.ceil(self.main_model_double / len(controlnet_double))
|
||||||
|
if self.latent_input:
|
||||||
|
out_input = ()
|
||||||
|
for x in controlnet_double:
|
||||||
|
out_input += (x,) * repeat
|
||||||
|
else:
|
||||||
|
out_input = (controlnet_double * repeat)
|
||||||
|
|
||||||
|
out = {"input": out_input[:self.main_model_double]}
|
||||||
|
if len(controlnet_single) > 0:
|
||||||
|
repeat = math.ceil(self.main_model_single / len(controlnet_single))
|
||||||
|
out_output = ()
|
||||||
|
if self.latent_input:
|
||||||
|
for x in controlnet_single:
|
||||||
|
out_output += (x,) * repeat
|
||||||
|
else:
|
||||||
|
out_output = (controlnet_single * repeat)
|
||||||
|
out["output"] = out_output[:self.main_model_single]
|
||||||
|
return out
|
||||||
|
|
||||||
|
def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
|
||||||
|
patch_size = 2
|
||||||
|
if self.latent_input:
|
||||||
|
hint = comfy.ldm.common_dit.pad_to_patch_size(hint, (patch_size, patch_size))
|
||||||
|
elif self.mistoline:
|
||||||
|
hint = hint * 2.0 - 1.0
|
||||||
|
hint = self.input_cond_block(hint)
|
||||||
|
else:
|
||||||
|
hint = hint * 2.0 - 1.0
|
||||||
|
hint = self.input_hint_block(hint)
|
||||||
|
|
||||||
|
hint = rearrange(hint, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
||||||
|
|
||||||
|
bs, c, h, w = x.shape
|
||||||
|
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
||||||
|
|
||||||
|
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
||||||
|
|
||||||
|
h_len = ((h + (patch_size // 2)) // patch_size)
|
||||||
|
w_len = ((w + (patch_size // 2)) // patch_size)
|
||||||
|
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||||
|
img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
|
||||||
|
img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
|
||||||
|
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||||
|
|
||||||
|
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||||
|
return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance, control_type=kwargs.get("control_type", []))
|
||||||
@@ -1,104 +0,0 @@
|
|||||||
#Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch import Tensor, nn
|
|
||||||
from einops import rearrange, repeat
|
|
||||||
|
|
||||||
from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
|
|
||||||
MLPEmbedder, SingleStreamBlock,
|
|
||||||
timestep_embedding)
|
|
||||||
|
|
||||||
from .model import Flux
|
|
||||||
import comfy.ldm.common_dit
|
|
||||||
|
|
||||||
|
|
||||||
class ControlNetFlux(Flux):
|
|
||||||
def __init__(self, image_model=None, dtype=None, device=None, operations=None, **kwargs):
|
|
||||||
super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
|
|
||||||
|
|
||||||
# add ControlNet blocks
|
|
||||||
self.controlnet_blocks = nn.ModuleList([])
|
|
||||||
for _ in range(self.params.depth):
|
|
||||||
controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
|
|
||||||
# controlnet_block = zero_module(controlnet_block)
|
|
||||||
self.controlnet_blocks.append(controlnet_block)
|
|
||||||
self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
|
|
||||||
self.gradient_checkpointing = False
|
|
||||||
self.input_hint_block = nn.Sequential(
|
|
||||||
operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
|
|
||||||
nn.SiLU(),
|
|
||||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
|
||||||
nn.SiLU(),
|
|
||||||
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
|
||||||
nn.SiLU(),
|
|
||||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
|
||||||
nn.SiLU(),
|
|
||||||
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
|
||||||
nn.SiLU(),
|
|
||||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
|
||||||
nn.SiLU(),
|
|
||||||
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
|
||||||
nn.SiLU(),
|
|
||||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward_orig(
|
|
||||||
self,
|
|
||||||
img: Tensor,
|
|
||||||
img_ids: Tensor,
|
|
||||||
controlnet_cond: Tensor,
|
|
||||||
txt: Tensor,
|
|
||||||
txt_ids: Tensor,
|
|
||||||
timesteps: Tensor,
|
|
||||||
y: Tensor,
|
|
||||||
guidance: Tensor = None,
|
|
||||||
) -> Tensor:
|
|
||||||
if img.ndim != 3 or txt.ndim != 3:
|
|
||||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
|
||||||
|
|
||||||
# running on sequences img
|
|
||||||
img = self.img_in(img)
|
|
||||||
controlnet_cond = self.input_hint_block(controlnet_cond)
|
|
||||||
controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
|
||||||
controlnet_cond = self.pos_embed_input(controlnet_cond)
|
|
||||||
img = img + controlnet_cond
|
|
||||||
vec = self.time_in(timestep_embedding(timesteps, 256))
|
|
||||||
if self.params.guidance_embed:
|
|
||||||
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
|
|
||||||
vec = vec + self.vector_in(y)
|
|
||||||
txt = self.txt_in(txt)
|
|
||||||
|
|
||||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
|
||||||
pe = self.pe_embedder(ids)
|
|
||||||
|
|
||||||
block_res_samples = ()
|
|
||||||
|
|
||||||
for block in self.double_blocks:
|
|
||||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
|
||||||
block_res_samples = block_res_samples + (img,)
|
|
||||||
|
|
||||||
controlnet_block_res_samples = ()
|
|
||||||
for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
|
|
||||||
block_res_sample = controlnet_block(block_res_sample)
|
|
||||||
controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
|
|
||||||
|
|
||||||
return {"input": (controlnet_block_res_samples * 10)[:19]}
|
|
||||||
|
|
||||||
def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
|
|
||||||
hint = hint * 2.0 - 1.0
|
|
||||||
|
|
||||||
bs, c, h, w = x.shape
|
|
||||||
patch_size = 2
|
|
||||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
|
||||||
|
|
||||||
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
|
||||||
|
|
||||||
h_len = ((h + (patch_size // 2)) // patch_size)
|
|
||||||
w_len = ((w + (patch_size // 2)) // patch_size)
|
|
||||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
|
||||||
img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
|
|
||||||
img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
|
|
||||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
|
||||||
|
|
||||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
|
||||||
return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance)
|
|
||||||
@@ -6,6 +6,7 @@ from torch import Tensor, nn
|
|||||||
|
|
||||||
from .math import attention, rope
|
from .math import attention, rope
|
||||||
import comfy.ops
|
import comfy.ops
|
||||||
|
import comfy.ldm.common_dit
|
||||||
|
|
||||||
|
|
||||||
class EmbedND(nn.Module):
|
class EmbedND(nn.Module):
|
||||||
@@ -63,10 +64,7 @@ class RMSNorm(torch.nn.Module):
|
|||||||
self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
|
self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
|
||||||
|
|
||||||
def forward(self, x: Tensor):
|
def forward(self, x: Tensor):
|
||||||
x_dtype = x.dtype
|
return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6)
|
||||||
x = x.float()
|
|
||||||
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
|
|
||||||
return (x * rrms).to(dtype=x_dtype) * comfy.ops.cast_to(self.scale, dtype=x_dtype, device=x.device)
|
|
||||||
|
|
||||||
|
|
||||||
class QKNorm(torch.nn.Module):
|
class QKNorm(torch.nn.Module):
|
||||||
|
|||||||
@@ -108,7 +108,7 @@ class Flux(nn.Module):
|
|||||||
raise ValueError("Didn't get guidance strength for guidance distilled model.")
|
raise ValueError("Didn't get guidance strength for guidance distilled model.")
|
||||||
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
|
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
|
||||||
|
|
||||||
vec = vec + self.vector_in(y)
|
vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
|
||||||
txt = self.txt_in(txt)
|
txt = self.txt_in(txt)
|
||||||
|
|
||||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||||
@@ -151,8 +151,8 @@ class Flux(nn.Module):
|
|||||||
h_len = ((h + (patch_size // 2)) // patch_size)
|
h_len = ((h + (patch_size // 2)) // patch_size)
|
||||||
w_len = ((w + (patch_size // 2)) // patch_size)
|
w_len = ((w + (patch_size // 2)) // patch_size)
|
||||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||||
img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
|
img_ids[:, :, 1] = torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
|
||||||
img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
|
img_ids[:, :, 2] = torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||||
|
|
||||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||||
|
|||||||
@@ -372,7 +372,7 @@ class HunYuanDiT(nn.Module):
|
|||||||
for layer, block in enumerate(self.blocks):
|
for layer, block in enumerate(self.blocks):
|
||||||
if layer > self.depth // 2:
|
if layer > self.depth // 2:
|
||||||
if controls is not None:
|
if controls is not None:
|
||||||
skip = skips.pop() + controls.pop()
|
skip = skips.pop() + controls.pop().to(dtype=x.dtype)
|
||||||
else:
|
else:
|
||||||
skip = skips.pop()
|
skip = skips.pop()
|
||||||
x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)
|
x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)
|
||||||
|
|||||||
@@ -355,29 +355,9 @@ class RMSNorm(torch.nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.register_parameter("weight", None)
|
self.register_parameter("weight", None)
|
||||||
|
|
||||||
def _norm(self, x):
|
|
||||||
"""
|
|
||||||
Apply the RMSNorm normalization to the input tensor.
|
|
||||||
Args:
|
|
||||||
x (torch.Tensor): The input tensor.
|
|
||||||
Returns:
|
|
||||||
torch.Tensor: The normalized tensor.
|
|
||||||
"""
|
|
||||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
"""
|
return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)
|
||||||
Forward pass through the RMSNorm layer.
|
|
||||||
Args:
|
|
||||||
x (torch.Tensor): The input tensor.
|
|
||||||
Returns:
|
|
||||||
torch.Tensor: The output tensor after applying RMSNorm.
|
|
||||||
"""
|
|
||||||
x = self._norm(x)
|
|
||||||
if self.learnable_scale:
|
|
||||||
return x * self.weight.to(device=x.device, dtype=x.dtype)
|
|
||||||
else:
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class SwiGLUFeedForward(nn.Module):
|
class SwiGLUFeedForward(nn.Module):
|
||||||
|
|||||||
@@ -842,6 +842,11 @@ class UNetModel(nn.Module):
|
|||||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
|
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
|
||||||
emb = self.time_embed(t_emb)
|
emb = self.time_embed(t_emb)
|
||||||
|
|
||||||
|
if "emb_patch" in transformer_patches:
|
||||||
|
patch = transformer_patches["emb_patch"]
|
||||||
|
for p in patch:
|
||||||
|
emb = p(emb, self.model_channels, transformer_options)
|
||||||
|
|
||||||
if self.num_classes is not None:
|
if self.num_classes is not None:
|
||||||
assert y.shape[0] == x.shape[0]
|
assert y.shape[0] == x.shape[0]
|
||||||
emb = emb + self.label_emb(y)
|
emb = emb + self.label_emb(y)
|
||||||
|
|||||||
115
comfy/lora.py
115
comfy/lora.py
@@ -16,6 +16,7 @@
|
|||||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.model_base
|
import comfy.model_base
|
||||||
@@ -200,9 +201,13 @@ def load_lora(lora, to_load):
|
|||||||
|
|
||||||
def model_lora_keys_clip(model, key_map={}):
|
def model_lora_keys_clip(model, key_map={}):
|
||||||
sdk = model.state_dict().keys()
|
sdk = model.state_dict().keys()
|
||||||
|
for k in sdk:
|
||||||
|
if k.endswith(".weight"):
|
||||||
|
key_map["text_encoders.{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
|
||||||
|
|
||||||
text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
|
text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
|
||||||
clip_l_present = False
|
clip_l_present = False
|
||||||
|
clip_g_present = False
|
||||||
for b in range(32): #TODO: clean up
|
for b in range(32): #TODO: clean up
|
||||||
for c in LORA_CLIP_MAP:
|
for c in LORA_CLIP_MAP:
|
||||||
k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
|
k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
|
||||||
@@ -226,6 +231,7 @@ def model_lora_keys_clip(model, key_map={}):
|
|||||||
|
|
||||||
k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
|
k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
|
||||||
if k in sdk:
|
if k in sdk:
|
||||||
|
clip_g_present = True
|
||||||
if clip_l_present:
|
if clip_l_present:
|
||||||
lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
|
lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
|
||||||
key_map[lora_key] = k
|
key_map[lora_key] = k
|
||||||
@@ -241,10 +247,18 @@ def model_lora_keys_clip(model, key_map={}):
|
|||||||
|
|
||||||
for k in sdk:
|
for k in sdk:
|
||||||
if k.endswith(".weight"):
|
if k.endswith(".weight"):
|
||||||
if k.startswith("t5xxl.transformer."):#OneTrainer SD3 lora
|
if k.startswith("t5xxl.transformer."):#OneTrainer SD3 and Flux lora
|
||||||
l_key = k[len("t5xxl.transformer."):-len(".weight")]
|
l_key = k[len("t5xxl.transformer."):-len(".weight")]
|
||||||
lora_key = "lora_te3_{}".format(l_key.replace(".", "_"))
|
t5_index = 1
|
||||||
key_map[lora_key] = k
|
if clip_g_present:
|
||||||
|
t5_index += 1
|
||||||
|
if clip_l_present:
|
||||||
|
t5_index += 1
|
||||||
|
if t5_index == 2:
|
||||||
|
key_map["lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))] = k #OneTrainer Flux
|
||||||
|
t5_index += 1
|
||||||
|
|
||||||
|
key_map["lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))] = k
|
||||||
elif k.startswith("hydit_clip.transformer.bert."): #HunyuanDiT Lora
|
elif k.startswith("hydit_clip.transformer.bert."): #HunyuanDiT Lora
|
||||||
l_key = k[len("hydit_clip.transformer.bert."):-len(".weight")]
|
l_key = k[len("hydit_clip.transformer.bert."):-len(".weight")]
|
||||||
lora_key = "lora_te1_{}".format(l_key.replace(".", "_"))
|
lora_key = "lora_te1_{}".format(l_key.replace(".", "_"))
|
||||||
@@ -280,6 +294,7 @@ def model_lora_keys_unet(model, key_map={}):
|
|||||||
unet_key = "diffusion_model.{}".format(diffusers_keys[k])
|
unet_key = "diffusion_model.{}".format(diffusers_keys[k])
|
||||||
key_lora = k[:-len(".weight")].replace(".", "_")
|
key_lora = k[:-len(".weight")].replace(".", "_")
|
||||||
key_map["lora_unet_{}".format(key_lora)] = unet_key
|
key_map["lora_unet_{}".format(key_lora)] = unet_key
|
||||||
|
key_map["lycoris_{}".format(key_lora)] = unet_key #simpletuner lycoris format
|
||||||
|
|
||||||
diffusers_lora_prefix = ["", "unet."]
|
diffusers_lora_prefix = ["", "unet."]
|
||||||
for p in diffusers_lora_prefix:
|
for p in diffusers_lora_prefix:
|
||||||
@@ -323,14 +338,15 @@ def model_lora_keys_unet(model, key_map={}):
|
|||||||
to = diffusers_keys[k]
|
to = diffusers_keys[k]
|
||||||
key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format
|
key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format
|
||||||
key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
|
key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
|
||||||
|
key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer
|
||||||
|
|
||||||
return key_map
|
return key_map
|
||||||
|
|
||||||
|
|
||||||
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype):
|
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function):
|
||||||
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
|
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
|
||||||
lora_diff *= alpha
|
lora_diff *= alpha
|
||||||
weight_calc = weight + lora_diff.type(weight.dtype)
|
weight_calc = weight + function(lora_diff).type(weight.dtype)
|
||||||
weight_norm = (
|
weight_norm = (
|
||||||
weight_calc.transpose(0, 1)
|
weight_calc.transpose(0, 1)
|
||||||
.reshape(weight_calc.shape[1], -1)
|
.reshape(weight_calc.shape[1], -1)
|
||||||
@@ -347,6 +363,39 @@ def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediat
|
|||||||
weight[:] = weight_calc
|
weight[:] = weight_calc
|
||||||
return weight
|
return weight
|
||||||
|
|
||||||
|
def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Pad a tensor to a new shape with zeros.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tensor (torch.Tensor): The original tensor to be padded.
|
||||||
|
new_shape (List[int]): The desired shape of the padded tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: A new tensor padded with zeros to the specified shape.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
If the new shape is smaller than the original tensor in any dimension,
|
||||||
|
the original tensor will be truncated in that dimension.
|
||||||
|
"""
|
||||||
|
if any([new_shape[i] < tensor.shape[i] for i in range(len(new_shape))]):
|
||||||
|
raise ValueError("The new shape must be larger than the original tensor in all dimensions")
|
||||||
|
|
||||||
|
if len(new_shape) != len(tensor.shape):
|
||||||
|
raise ValueError("The new shape must have the same number of dimensions as the original tensor")
|
||||||
|
|
||||||
|
# Create a new tensor filled with zeros
|
||||||
|
padded_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device)
|
||||||
|
|
||||||
|
# Create slicing tuples for both tensors
|
||||||
|
orig_slices = tuple(slice(0, dim) for dim in tensor.shape)
|
||||||
|
new_slices = tuple(slice(0, dim) for dim in tensor.shape)
|
||||||
|
|
||||||
|
# Copy the original tensor into the new tensor
|
||||||
|
padded_tensor[new_slices] = tensor[orig_slices]
|
||||||
|
|
||||||
|
return padded_tensor
|
||||||
|
|
||||||
def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
|
def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
|
||||||
for p in patches:
|
for p in patches:
|
||||||
strength = p[0]
|
strength = p[0]
|
||||||
@@ -366,7 +415,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
|
|||||||
weight *= strength_model
|
weight *= strength_model
|
||||||
|
|
||||||
if isinstance(v, list):
|
if isinstance(v, list):
|
||||||
v = (calculate_weight(v[1:], v[0].clone(), key, intermediate_dtype=intermediate_dtype), )
|
v = (calculate_weight(v[1:], v[0][1](comfy.model_management.cast_to_device(v[0][0], weight.device, intermediate_dtype, copy=True), inplace=True), key, intermediate_dtype=intermediate_dtype), )
|
||||||
|
|
||||||
if len(v) == 1:
|
if len(v) == 1:
|
||||||
patch_type = "diff"
|
patch_type = "diff"
|
||||||
@@ -375,12 +424,18 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
|
|||||||
v = v[1]
|
v = v[1]
|
||||||
|
|
||||||
if patch_type == "diff":
|
if patch_type == "diff":
|
||||||
w1 = v[0]
|
diff: torch.Tensor = v[0]
|
||||||
|
# An extra flag to pad the weight if the diff's shape is larger than the weight
|
||||||
|
do_pad_weight = len(v) > 1 and v[1]['pad_weight']
|
||||||
|
if do_pad_weight and diff.shape != weight.shape:
|
||||||
|
logging.info("Pad weight {} from {} to shape: {}".format(key, weight.shape, diff.shape))
|
||||||
|
weight = pad_tensor_to_shape(weight, diff.shape)
|
||||||
|
|
||||||
if strength != 0.0:
|
if strength != 0.0:
|
||||||
if w1.shape != weight.shape:
|
if diff.shape != weight.shape:
|
||||||
logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
|
logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, diff.shape, weight.shape))
|
||||||
else:
|
else:
|
||||||
weight += function(strength * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype))
|
weight += function(strength * comfy.model_management.cast_to_device(diff, weight.device, weight.dtype))
|
||||||
elif patch_type == "lora": #lora/locon
|
elif patch_type == "lora": #lora/locon
|
||||||
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
|
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
|
||||||
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype)
|
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype)
|
||||||
@@ -398,7 +453,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
|
|||||||
try:
|
try:
|
||||||
lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
|
lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
|
||||||
if dora_scale is not None:
|
if dora_scale is not None:
|
||||||
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
|
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||||
else:
|
else:
|
||||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -444,7 +499,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
|
|||||||
try:
|
try:
|
||||||
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
|
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
|
||||||
if dora_scale is not None:
|
if dora_scale is not None:
|
||||||
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
|
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||||
else:
|
else:
|
||||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -481,28 +536,48 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
|
|||||||
try:
|
try:
|
||||||
lora_diff = (m1 * m2).reshape(weight.shape)
|
lora_diff = (m1 * m2).reshape(weight.shape)
|
||||||
if dora_scale is not None:
|
if dora_scale is not None:
|
||||||
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
|
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||||
else:
|
else:
|
||||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||||
elif patch_type == "glora":
|
elif patch_type == "glora":
|
||||||
if v[4] is not None:
|
|
||||||
alpha = v[4] / v[0].shape[0]
|
|
||||||
else:
|
|
||||||
alpha = 1.0
|
|
||||||
|
|
||||||
dora_scale = v[5]
|
dora_scale = v[5]
|
||||||
|
|
||||||
|
old_glora = False
|
||||||
|
if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]:
|
||||||
|
rank = v[0].shape[0]
|
||||||
|
old_glora = True
|
||||||
|
|
||||||
|
if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]:
|
||||||
|
if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
old_glora = False
|
||||||
|
rank = v[1].shape[0]
|
||||||
|
|
||||||
a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
|
a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||||
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
|
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||||
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
|
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||||
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
|
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||||
|
|
||||||
|
if v[4] is not None:
|
||||||
|
alpha = v[4] / rank
|
||||||
|
else:
|
||||||
|
alpha = 1.0
|
||||||
|
|
||||||
try:
|
try:
|
||||||
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape)
|
if old_glora:
|
||||||
|
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora
|
||||||
|
else:
|
||||||
|
if weight.dim() > 2:
|
||||||
|
lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
|
||||||
|
else:
|
||||||
|
lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
|
||||||
|
lora_diff += torch.mm(b1, b2).reshape(weight.shape)
|
||||||
|
|
||||||
if dora_scale is not None:
|
if dora_scale is not None:
|
||||||
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
|
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||||
else:
|
else:
|
||||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -96,7 +96,8 @@ class BaseModel(torch.nn.Module):
|
|||||||
|
|
||||||
if not unet_config.get("disable_unet_model_creation", False):
|
if not unet_config.get("disable_unet_model_creation", False):
|
||||||
if model_config.custom_operations is None:
|
if model_config.custom_operations is None:
|
||||||
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype)
|
fp8 = model_config.optimizations.get("fp8", model_config.scaled_fp8 is not None)
|
||||||
|
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
|
||||||
else:
|
else:
|
||||||
operations = model_config.custom_operations
|
operations = model_config.custom_operations
|
||||||
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
|
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
|
||||||
@@ -244,6 +245,10 @@ class BaseModel(torch.nn.Module):
|
|||||||
extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))
|
extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))
|
||||||
|
|
||||||
unet_state_dict = self.diffusion_model.state_dict()
|
unet_state_dict = self.diffusion_model.state_dict()
|
||||||
|
|
||||||
|
if self.model_config.scaled_fp8 is not None:
|
||||||
|
unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8)
|
||||||
|
|
||||||
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
|
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
|
||||||
|
|
||||||
if self.model_type == ModelType.V_PREDICTION:
|
if self.model_type == ModelType.V_PREDICTION:
|
||||||
|
|||||||
@@ -286,9 +286,15 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
|
|||||||
return None
|
return None
|
||||||
model_config = model_config_from_unet_config(unet_config, state_dict)
|
model_config = model_config_from_unet_config(unet_config, state_dict)
|
||||||
if model_config is None and use_base_if_no_match:
|
if model_config is None and use_base_if_no_match:
|
||||||
return comfy.supported_models_base.BASE(unet_config)
|
model_config = comfy.supported_models_base.BASE(unet_config)
|
||||||
else:
|
|
||||||
return model_config
|
scaled_fp8_weight = state_dict.get("{}scaled_fp8".format(unet_key_prefix), None)
|
||||||
|
if scaled_fp8_weight is not None:
|
||||||
|
model_config.scaled_fp8 = scaled_fp8_weight.dtype
|
||||||
|
if model_config.scaled_fp8 == torch.float32:
|
||||||
|
model_config.scaled_fp8 = torch.float8_e4m3fn
|
||||||
|
|
||||||
|
return model_config
|
||||||
|
|
||||||
def unet_prefix_from_state_dict(state_dict):
|
def unet_prefix_from_state_dict(state_dict):
|
||||||
candidates = ["model.diffusion_model.", #ldm/sgm models
|
candidates = ["model.diffusion_model.", #ldm/sgm models
|
||||||
|
|||||||
@@ -45,6 +45,7 @@ cpu_state = CPUState.GPU
|
|||||||
total_vram = 0
|
total_vram = 0
|
||||||
|
|
||||||
xpu_available = False
|
xpu_available = False
|
||||||
|
torch_version = ""
|
||||||
try:
|
try:
|
||||||
torch_version = torch.version.__version__
|
torch_version = torch.version.__version__
|
||||||
xpu_available = (int(torch_version[0]) < 2 or (int(torch_version[0]) == 2 and int(torch_version[2]) <= 4)) and torch.xpu.is_available()
|
xpu_available = (int(torch_version[0]) < 2 or (int(torch_version[0]) == 2 and int(torch_version[2]) <= 4)) and torch.xpu.is_available()
|
||||||
@@ -144,7 +145,7 @@ total_ram = psutil.virtual_memory().total / (1024 * 1024)
|
|||||||
logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
|
logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logging.info("pytorch version: {}".format(torch.version.__version__))
|
logging.info("pytorch version: {}".format(torch_version))
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -325,7 +326,7 @@ class LoadedModel:
|
|||||||
self.model_unload()
|
self.model_unload()
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
if is_intel_xpu() and not args.disable_ipex_optimize and self.real_model is not None:
|
if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and self.real_model is not None:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
self.real_model = ipex.optimize(self.real_model.eval(), inplace=True, graph_mode=True, concat_linear=True)
|
self.real_model = ipex.optimize(self.real_model.eval(), inplace=True, graph_mode=True, concat_linear=True)
|
||||||
|
|
||||||
@@ -369,12 +370,11 @@ def offloaded_memory(loaded_models, device):
|
|||||||
offloaded_mem += m.model_offloaded_memory()
|
offloaded_mem += m.model_offloaded_memory()
|
||||||
return offloaded_mem
|
return offloaded_mem
|
||||||
|
|
||||||
def minimum_inference_memory():
|
WINDOWS = any(platform.win32_ver())
|
||||||
return (1024 * 1024 * 1024) * 1.2
|
|
||||||
|
|
||||||
EXTRA_RESERVED_VRAM = 200 * 1024 * 1024
|
EXTRA_RESERVED_VRAM = 400 * 1024 * 1024
|
||||||
if any(platform.win32_ver()):
|
if WINDOWS:
|
||||||
EXTRA_RESERVED_VRAM = 500 * 1024 * 1024 #Windows is higher because of the shared vram issue
|
EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue
|
||||||
|
|
||||||
if args.reserve_vram is not None:
|
if args.reserve_vram is not None:
|
||||||
EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024
|
EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024
|
||||||
@@ -383,6 +383,9 @@ if args.reserve_vram is not None:
|
|||||||
def extra_reserved_memory():
|
def extra_reserved_memory():
|
||||||
return EXTRA_RESERVED_VRAM
|
return EXTRA_RESERVED_VRAM
|
||||||
|
|
||||||
|
def minimum_inference_memory():
|
||||||
|
return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
|
||||||
|
|
||||||
def unload_model_clones(model, unload_weights_only=True, force_unload=True):
|
def unload_model_clones(model, unload_weights_only=True, force_unload=True):
|
||||||
to_unload = []
|
to_unload = []
|
||||||
for i in range(len(current_loaded_models)):
|
for i in range(len(current_loaded_models)):
|
||||||
@@ -405,6 +408,8 @@ def unload_model_clones(model, unload_weights_only=True, force_unload=True):
|
|||||||
if not force_unload:
|
if not force_unload:
|
||||||
if unload_weights_only and unload_weight == False:
|
if unload_weights_only and unload_weight == False:
|
||||||
return None
|
return None
|
||||||
|
else:
|
||||||
|
unload_weight = True
|
||||||
|
|
||||||
for i in to_unload:
|
for i in to_unload:
|
||||||
logging.debug("unload clone {} {}".format(i, unload_weight))
|
logging.debug("unload clone {} {}".format(i, unload_weight))
|
||||||
@@ -421,7 +426,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
|
|||||||
shift_model = current_loaded_models[i]
|
shift_model = current_loaded_models[i]
|
||||||
if shift_model.device == device:
|
if shift_model.device == device:
|
||||||
if shift_model not in keep_loaded:
|
if shift_model not in keep_loaded:
|
||||||
can_unload.append((sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
|
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
|
||||||
shift_model.currently_used = False
|
shift_model.currently_used = False
|
||||||
|
|
||||||
for x in sorted(can_unload):
|
for x in sorted(can_unload):
|
||||||
@@ -621,6 +626,8 @@ def maximum_vram_for_weights(device=None):
|
|||||||
return (get_total_memory(device) * 0.88 - minimum_inference_memory())
|
return (get_total_memory(device) * 0.88 - minimum_inference_memory())
|
||||||
|
|
||||||
def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
|
def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
|
||||||
|
if model_params < 0:
|
||||||
|
model_params = 1000000000000000000000
|
||||||
if args.bf16_unet:
|
if args.bf16_unet:
|
||||||
return torch.bfloat16
|
return torch.bfloat16
|
||||||
if args.fp16_unet:
|
if args.fp16_unet:
|
||||||
@@ -640,6 +647,9 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
if fp8_dtype is not None:
|
if fp8_dtype is not None:
|
||||||
|
if supports_fp8_compute(device): #if fp8 compute is supported the casting is most likely not expensive
|
||||||
|
return fp8_dtype
|
||||||
|
|
||||||
free_model_memory = maximum_vram_for_weights(device)
|
free_model_memory = maximum_vram_for_weights(device)
|
||||||
if model_params * 2 > free_model_memory:
|
if model_params * 2 > free_model_memory:
|
||||||
return fp8_dtype
|
return fp8_dtype
|
||||||
@@ -833,27 +843,21 @@ def force_channels_last():
|
|||||||
#TODO
|
#TODO
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
|
||||||
|
if device is None or weight.device == device:
|
||||||
|
if not copy:
|
||||||
|
if dtype is None or weight.dtype == dtype:
|
||||||
|
return weight
|
||||||
|
return weight.to(dtype=dtype, copy=copy)
|
||||||
|
|
||||||
|
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||||
|
r.copy_(weight, non_blocking=non_blocking)
|
||||||
|
return r
|
||||||
|
|
||||||
def cast_to_device(tensor, device, dtype, copy=False):
|
def cast_to_device(tensor, device, dtype, copy=False):
|
||||||
device_supports_cast = False
|
non_blocking = device_supports_non_blocking(device)
|
||||||
if tensor.dtype == torch.float32 or tensor.dtype == torch.float16:
|
return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy)
|
||||||
device_supports_cast = True
|
|
||||||
elif tensor.dtype == torch.bfloat16:
|
|
||||||
if hasattr(device, 'type') and device.type.startswith("cuda"):
|
|
||||||
device_supports_cast = True
|
|
||||||
elif is_intel_xpu():
|
|
||||||
device_supports_cast = True
|
|
||||||
|
|
||||||
non_blocking = device_should_use_non_blocking(device)
|
|
||||||
|
|
||||||
if device_supports_cast:
|
|
||||||
if copy:
|
|
||||||
if tensor.device == device:
|
|
||||||
return tensor.to(dtype, copy=copy, non_blocking=non_blocking)
|
|
||||||
return tensor.to(device, copy=copy, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
|
|
||||||
else:
|
|
||||||
return tensor.to(device, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
|
|
||||||
else:
|
|
||||||
return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking)
|
|
||||||
|
|
||||||
def xformers_enabled():
|
def xformers_enabled():
|
||||||
global directml_enabled
|
global directml_enabled
|
||||||
@@ -892,7 +896,7 @@ def force_upcast_attention_dtype():
|
|||||||
upcast = args.force_upcast_attention
|
upcast = args.force_upcast_attention
|
||||||
try:
|
try:
|
||||||
macos_version = tuple(int(n) for n in platform.mac_ver()[0].split("."))
|
macos_version = tuple(int(n) for n in platform.mac_ver()[0].split("."))
|
||||||
if (14, 5) <= macos_version < (14, 7): # black image bug on recent versions of MacOS
|
if (14, 5) <= macos_version <= (15, 0, 1): # black image bug on recent versions of macOS
|
||||||
upcast = True
|
upcast = True
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
@@ -999,7 +1003,10 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
|
|||||||
nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050", "p40", "p100", "p6", "p4"]
|
nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050", "p40", "p100", "p6", "p4"]
|
||||||
for x in nvidia_10_series:
|
for x in nvidia_10_series:
|
||||||
if x in props.name.lower():
|
if x in props.name.lower():
|
||||||
return True
|
if WINDOWS or manual_cast:
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False #weird linux behavior where fp32 is faster
|
||||||
|
|
||||||
if manual_cast:
|
if manual_cast:
|
||||||
free_model_memory = maximum_vram_for_weights(device)
|
free_model_memory = maximum_vram_for_weights(device)
|
||||||
@@ -1055,6 +1062,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def supports_fp8_compute(device=None):
|
def supports_fp8_compute(device=None):
|
||||||
|
if not is_nvidia():
|
||||||
|
return False
|
||||||
|
|
||||||
props = torch.cuda.get_device_properties(device)
|
props = torch.cuda.get_device_properties(device)
|
||||||
if props.major >= 9:
|
if props.major >= 9:
|
||||||
return True
|
return True
|
||||||
@@ -1062,6 +1072,14 @@ def supports_fp8_compute(device=None):
|
|||||||
return False
|
return False
|
||||||
if props.minor < 9:
|
if props.minor < 9:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
if int(torch_version[0]) < 2 or (int(torch_version[0]) == 2 and int(torch_version[2]) < 3):
|
||||||
|
return False
|
||||||
|
|
||||||
|
if WINDOWS:
|
||||||
|
if (int(torch_version[0]) == 2 and int(torch_version[2]) < 4):
|
||||||
|
return False
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def soft_empty_cache(force=False):
|
def soft_empty_cache(force=False):
|
||||||
|
|||||||
@@ -28,8 +28,20 @@ import comfy.utils
|
|||||||
import comfy.float
|
import comfy.float
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.lora
|
import comfy.lora
|
||||||
from comfy.types import UnetWrapperFunction
|
from comfy.comfy_types import UnetWrapperFunction
|
||||||
|
|
||||||
|
def string_to_seed(data):
|
||||||
|
crc = 0xFFFFFFFF
|
||||||
|
for byte in data:
|
||||||
|
if isinstance(byte, str):
|
||||||
|
byte = ord(byte)
|
||||||
|
crc ^= byte
|
||||||
|
for _ in range(8):
|
||||||
|
if crc & 1:
|
||||||
|
crc = (crc >> 1) ^ 0xEDB88320
|
||||||
|
else:
|
||||||
|
crc >>= 1
|
||||||
|
return crc ^ 0xFFFFFFFF
|
||||||
|
|
||||||
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
|
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
|
||||||
to = model_options["transformer_options"].copy()
|
to = model_options["transformer_options"].copy()
|
||||||
@@ -76,7 +88,36 @@ class LowVramPatch:
|
|||||||
self.key = key
|
self.key = key
|
||||||
self.patches = patches
|
self.patches = patches
|
||||||
def __call__(self, weight):
|
def __call__(self, weight):
|
||||||
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)
|
intermediate_dtype = weight.dtype
|
||||||
|
if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
|
||||||
|
intermediate_dtype = torch.float32
|
||||||
|
return comfy.float.stochastic_rounding(comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype), weight.dtype, seed=string_to_seed(self.key))
|
||||||
|
|
||||||
|
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
|
||||||
|
|
||||||
|
def get_key_weight(model, key):
|
||||||
|
set_func = None
|
||||||
|
convert_func = None
|
||||||
|
op_keys = key.rsplit('.', 1)
|
||||||
|
if len(op_keys) < 2:
|
||||||
|
weight = comfy.utils.get_attr(model, key)
|
||||||
|
else:
|
||||||
|
op = comfy.utils.get_attr(model, op_keys[0])
|
||||||
|
try:
|
||||||
|
set_func = getattr(op, "set_{}".format(op_keys[1]))
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
convert_func = getattr(op, "convert_{}".format(op_keys[1]))
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
weight = getattr(op, op_keys[1])
|
||||||
|
if convert_func is not None:
|
||||||
|
weight = comfy.utils.get_attr(model, key)
|
||||||
|
|
||||||
|
return weight, set_func, convert_func
|
||||||
|
|
||||||
class ModelPatcher:
|
class ModelPatcher:
|
||||||
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
|
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
|
||||||
@@ -271,17 +312,23 @@ class ModelPatcher:
|
|||||||
return list(p)
|
return list(p)
|
||||||
|
|
||||||
def get_key_patches(self, filter_prefix=None):
|
def get_key_patches(self, filter_prefix=None):
|
||||||
comfy.model_management.unload_model_clones(self)
|
|
||||||
model_sd = self.model_state_dict()
|
model_sd = self.model_state_dict()
|
||||||
p = {}
|
p = {}
|
||||||
for k in model_sd:
|
for k in model_sd:
|
||||||
if filter_prefix is not None:
|
if filter_prefix is not None:
|
||||||
if not k.startswith(filter_prefix):
|
if not k.startswith(filter_prefix):
|
||||||
continue
|
continue
|
||||||
|
bk = self.backup.get(k, None)
|
||||||
|
weight, set_func, convert_func = get_key_weight(self.model, k)
|
||||||
|
if bk is not None:
|
||||||
|
weight = bk.weight
|
||||||
|
if convert_func is None:
|
||||||
|
convert_func = lambda a, **kwargs: a
|
||||||
|
|
||||||
if k in self.patches:
|
if k in self.patches:
|
||||||
p[k] = [model_sd[k]] + self.patches[k]
|
p[k] = [(weight, convert_func)] + self.patches[k]
|
||||||
else:
|
else:
|
||||||
p[k] = (model_sd[k],)
|
p[k] = [(weight, convert_func)]
|
||||||
return p
|
return p
|
||||||
|
|
||||||
def model_state_dict(self, filter_prefix=None):
|
def model_state_dict(self, filter_prefix=None):
|
||||||
@@ -297,8 +344,7 @@ class ModelPatcher:
|
|||||||
if key not in self.patches:
|
if key not in self.patches:
|
||||||
return
|
return
|
||||||
|
|
||||||
weight = comfy.utils.get_attr(self.model, key)
|
weight, set_func, convert_func = get_key_weight(self.model, key)
|
||||||
|
|
||||||
inplace_update = self.weight_inplace_update or inplace_update
|
inplace_update = self.weight_inplace_update or inplace_update
|
||||||
|
|
||||||
if key not in self.backup:
|
if key not in self.backup:
|
||||||
@@ -308,23 +354,38 @@ class ModelPatcher:
|
|||||||
temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
|
temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
|
||||||
else:
|
else:
|
||||||
temp_weight = weight.to(torch.float32, copy=True)
|
temp_weight = weight.to(torch.float32, copy=True)
|
||||||
|
if convert_func is not None:
|
||||||
|
temp_weight = convert_func(temp_weight, inplace=True)
|
||||||
|
|
||||||
out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key)
|
out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key)
|
||||||
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype)
|
if set_func is None:
|
||||||
if inplace_update:
|
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
|
||||||
comfy.utils.copy_to_param(self.model, key, out_weight)
|
if inplace_update:
|
||||||
|
comfy.utils.copy_to_param(self.model, key, out_weight)
|
||||||
|
else:
|
||||||
|
comfy.utils.set_attr_param(self.model, key, out_weight)
|
||||||
else:
|
else:
|
||||||
comfy.utils.set_attr_param(self.model, key, out_weight)
|
set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
|
||||||
|
|
||||||
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
|
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
|
||||||
mem_counter = 0
|
mem_counter = 0
|
||||||
patch_counter = 0
|
patch_counter = 0
|
||||||
lowvram_counter = 0
|
lowvram_counter = 0
|
||||||
load_completely = []
|
loading = []
|
||||||
for n, m in self.model.named_modules():
|
for n, m in self.model.named_modules():
|
||||||
|
if hasattr(m, "comfy_cast_weights") or hasattr(m, "weight"):
|
||||||
|
loading.append((comfy.model_management.module_size(m), n, m))
|
||||||
|
|
||||||
|
load_completely = []
|
||||||
|
loading.sort(reverse=True)
|
||||||
|
for x in loading:
|
||||||
|
n = x[1]
|
||||||
|
m = x[2]
|
||||||
|
module_mem = x[0]
|
||||||
|
|
||||||
lowvram_weight = False
|
lowvram_weight = False
|
||||||
|
|
||||||
if not full_load and hasattr(m, "comfy_cast_weights"):
|
if not full_load and hasattr(m, "comfy_cast_weights"):
|
||||||
module_mem = comfy.model_management.module_size(m)
|
|
||||||
if mem_counter + module_mem >= lowvram_model_memory:
|
if mem_counter + module_mem >= lowvram_model_memory:
|
||||||
lowvram_weight = True
|
lowvram_weight = True
|
||||||
lowvram_counter += 1
|
lowvram_counter += 1
|
||||||
@@ -356,9 +417,8 @@ class ModelPatcher:
|
|||||||
wipe_lowvram_weight(m)
|
wipe_lowvram_weight(m)
|
||||||
|
|
||||||
if hasattr(m, "weight"):
|
if hasattr(m, "weight"):
|
||||||
mem_used = comfy.model_management.module_size(m)
|
mem_counter += module_mem
|
||||||
mem_counter += mem_used
|
load_completely.append((module_mem, n, m))
|
||||||
load_completely.append((mem_used, n, m))
|
|
||||||
|
|
||||||
load_completely.sort(reverse=True)
|
load_completely.sort(reverse=True)
|
||||||
for x in load_completely:
|
for x in load_completely:
|
||||||
|
|||||||
95
comfy/ops.py
95
comfy/ops.py
@@ -19,16 +19,12 @@
|
|||||||
import torch
|
import torch
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
|
import comfy.float
|
||||||
|
|
||||||
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=True):
|
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
|
||||||
if not copy and (dtype is None or weight.dtype == dtype) and (device is None or weight.device == device):
|
|
||||||
return weight
|
|
||||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
|
||||||
r.copy_(weight, non_blocking=non_blocking)
|
|
||||||
return r
|
|
||||||
|
|
||||||
def cast_to_input(weight, input, non_blocking=False, copy=True):
|
def cast_to_input(weight, input, non_blocking=False, copy=True):
|
||||||
return cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
|
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
|
||||||
|
|
||||||
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
|
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
|
||||||
if input is not None:
|
if input is not None:
|
||||||
@@ -43,12 +39,12 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
|
|||||||
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
||||||
if s.bias is not None:
|
if s.bias is not None:
|
||||||
has_function = s.bias_function is not None
|
has_function = s.bias_function is not None
|
||||||
bias = cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function)
|
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function)
|
||||||
if has_function:
|
if has_function:
|
||||||
bias = s.bias_function(bias)
|
bias = s.bias_function(bias)
|
||||||
|
|
||||||
has_function = s.weight_function is not None
|
has_function = s.weight_function is not None
|
||||||
weight = cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function)
|
weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function)
|
||||||
if has_function:
|
if has_function:
|
||||||
weight = s.weight_function(weight)
|
weight = s.weight_function(weight)
|
||||||
return weight, bias
|
return weight, bias
|
||||||
@@ -254,20 +250,25 @@ def fp8_linear(self, input):
|
|||||||
if dtype not in [torch.float8_e4m3fn]:
|
if dtype not in [torch.float8_e4m3fn]:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
tensor_2d = False
|
||||||
|
if len(input.shape) == 2:
|
||||||
|
tensor_2d = True
|
||||||
|
input = input.unsqueeze(1)
|
||||||
|
|
||||||
|
|
||||||
if len(input.shape) == 3:
|
if len(input.shape) == 3:
|
||||||
inn = input.reshape(-1, input.shape[2]).to(dtype)
|
|
||||||
non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
|
|
||||||
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype)
|
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype)
|
||||||
w = w.t()
|
w = w.t()
|
||||||
|
|
||||||
scale_weight = self.scale_weight
|
scale_weight = self.scale_weight
|
||||||
scale_input = self.scale_input
|
scale_input = self.scale_input
|
||||||
if scale_weight is None:
|
if scale_weight is None:
|
||||||
scale_weight = torch.ones((1), device=input.device, dtype=torch.float32)
|
scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
|
||||||
if scale_input is None:
|
|
||||||
scale_input = scale_weight
|
|
||||||
if scale_input is None:
|
if scale_input is None:
|
||||||
scale_input = torch.ones((1), device=input.device, dtype=torch.float32)
|
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
|
||||||
|
inn = input.reshape(-1, input.shape[2]).to(dtype)
|
||||||
|
else:
|
||||||
|
inn = (input * (1.0 / scale_input).to(input.dtype)).reshape(-1, input.shape[2]).to(dtype)
|
||||||
|
|
||||||
if bias is not None:
|
if bias is not None:
|
||||||
o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
|
o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
|
||||||
@@ -277,7 +278,11 @@ def fp8_linear(self, input):
|
|||||||
if isinstance(o, tuple):
|
if isinstance(o, tuple):
|
||||||
o = o[0]
|
o = o[0]
|
||||||
|
|
||||||
|
if tensor_2d:
|
||||||
|
return o.reshape(input.shape[0], -1)
|
||||||
|
|
||||||
return o.reshape((-1, input.shape[1], self.weight.shape[0]))
|
return o.reshape((-1, input.shape[1], self.weight.shape[0]))
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
class fp8_ops(manual_cast):
|
class fp8_ops(manual_cast):
|
||||||
@@ -295,11 +300,63 @@ class fp8_ops(manual_cast):
|
|||||||
weight, bias = cast_bias_weight(self, input)
|
weight, bias = cast_bias_weight(self, input)
|
||||||
return torch.nn.functional.linear(input, weight, bias)
|
return torch.nn.functional.linear(input, weight, bias)
|
||||||
|
|
||||||
|
def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
|
||||||
|
class scaled_fp8_op(manual_cast):
|
||||||
|
class Linear(manual_cast.Linear):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
if override_dtype is not None:
|
||||||
|
kwargs['dtype'] = override_dtype
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def reset_parameters(self):
|
||||||
|
if not hasattr(self, 'scale_weight'):
|
||||||
|
self.scale_weight = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
|
||||||
|
|
||||||
|
if not scale_input:
|
||||||
|
self.scale_input = None
|
||||||
|
|
||||||
|
if not hasattr(self, 'scale_input'):
|
||||||
|
self.scale_input = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def forward_comfy_cast_weights(self, input):
|
||||||
|
if fp8_matrix_mult:
|
||||||
|
out = fp8_linear(self, input)
|
||||||
|
if out is not None:
|
||||||
|
return out
|
||||||
|
|
||||||
|
weight, bias = cast_bias_weight(self, input)
|
||||||
|
|
||||||
|
if weight.numel() < input.numel(): #TODO: optimize
|
||||||
|
return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
|
||||||
|
else:
|
||||||
|
return torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
|
||||||
|
|
||||||
|
def convert_weight(self, weight, inplace=False, **kwargs):
|
||||||
|
if inplace:
|
||||||
|
weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype)
|
||||||
|
return weight
|
||||||
|
else:
|
||||||
|
return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype)
|
||||||
|
|
||||||
|
def set_weight(self, weight, inplace_update=False, seed=None, **kwargs):
|
||||||
|
weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
|
||||||
|
if inplace_update:
|
||||||
|
self.weight.data.copy_(weight)
|
||||||
|
else:
|
||||||
|
self.weight = torch.nn.Parameter(weight, requires_grad=False)
|
||||||
|
|
||||||
|
return scaled_fp8_op
|
||||||
|
|
||||||
|
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
|
||||||
|
fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
|
||||||
|
if scaled_fp8 is not None:
|
||||||
|
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute, scale_input=True, override_dtype=scaled_fp8)
|
||||||
|
|
||||||
|
if fp8_compute and (fp8_optimizations or args.fast) and not disable_fast_fp8:
|
||||||
|
return fp8_ops
|
||||||
|
|
||||||
def pick_operations(weight_dtype, compute_dtype, load_device=None):
|
|
||||||
if compute_dtype is None or weight_dtype == compute_dtype:
|
if compute_dtype is None or weight_dtype == compute_dtype:
|
||||||
return disable_weight_init
|
return disable_weight_init
|
||||||
if args.fast:
|
|
||||||
if comfy.model_management.supports_fp8_compute(load_device):
|
|
||||||
return fp8_ops
|
|
||||||
return manual_cast
|
return manual_cast
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from comfy import model_management
|
|||||||
import math
|
import math
|
||||||
import logging
|
import logging
|
||||||
import comfy.sampler_helpers
|
import comfy.sampler_helpers
|
||||||
import scipy
|
import scipy.stats
|
||||||
import numpy
|
import numpy
|
||||||
|
|
||||||
def get_area_and_mult(conds, x_in, timestep_in):
|
def get_area_and_mult(conds, x_in, timestep_in):
|
||||||
@@ -358,8 +358,11 @@ def beta_scheduler(model_sampling, steps, alpha=0.6, beta=0.6):
|
|||||||
ts = numpy.rint(scipy.stats.beta.ppf(ts, alpha, beta) * total_timesteps)
|
ts = numpy.rint(scipy.stats.beta.ppf(ts, alpha, beta) * total_timesteps)
|
||||||
|
|
||||||
sigs = []
|
sigs = []
|
||||||
|
last_t = -1
|
||||||
for t in ts:
|
for t in ts:
|
||||||
sigs += [float(model_sampling.sigmas[int(t)])]
|
if t != last_t:
|
||||||
|
sigs += [float(model_sampling.sigmas[int(t)])]
|
||||||
|
last_t = t
|
||||||
sigs += [0.0]
|
sigs += [0.0]
|
||||||
return torch.FloatTensor(sigs)
|
return torch.FloatTensor(sigs)
|
||||||
|
|
||||||
@@ -570,8 +573,8 @@ class Sampler:
|
|||||||
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
|
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
|
||||||
|
|
||||||
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
|
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
|
||||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
|
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||||
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
||||||
"ipndm", "ipndm_v", "deis"]
|
"ipndm", "ipndm_v", "deis"]
|
||||||
|
|
||||||
class KSAMPLER(Sampler):
|
class KSAMPLER(Sampler):
|
||||||
|
|||||||
115
comfy/sd.py
115
comfy/sd.py
@@ -29,7 +29,6 @@ import comfy.text_encoders.long_clipl
|
|||||||
import comfy.model_patcher
|
import comfy.model_patcher
|
||||||
import comfy.lora
|
import comfy.lora
|
||||||
import comfy.t2i_adapter.adapter
|
import comfy.t2i_adapter.adapter
|
||||||
import comfy.supported_models_base
|
|
||||||
import comfy.taesd.taesd
|
import comfy.taesd.taesd
|
||||||
|
|
||||||
def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
|
def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
|
||||||
@@ -70,14 +69,14 @@ class CLIP:
|
|||||||
clip = target.clip
|
clip = target.clip
|
||||||
tokenizer = target.tokenizer
|
tokenizer = target.tokenizer
|
||||||
|
|
||||||
load_device = model_management.text_encoder_device()
|
load_device = model_options.get("load_device", model_management.text_encoder_device())
|
||||||
offload_device = model_management.text_encoder_offload_device()
|
offload_device = model_options.get("offload_device", model_management.text_encoder_offload_device())
|
||||||
dtype = model_options.get("dtype", None)
|
dtype = model_options.get("dtype", None)
|
||||||
if dtype is None:
|
if dtype is None:
|
||||||
dtype = model_management.text_encoder_dtype(load_device)
|
dtype = model_management.text_encoder_dtype(load_device)
|
||||||
|
|
||||||
params['dtype'] = dtype
|
params['dtype'] = dtype
|
||||||
params['device'] = model_management.text_encoder_initial_device(load_device, offload_device, parameters * model_management.dtype_size(dtype))
|
params['device'] = model_options.get("initial_device", model_management.text_encoder_initial_device(load_device, offload_device, parameters * model_management.dtype_size(dtype)))
|
||||||
params['model_options'] = model_options
|
params['model_options'] = model_options
|
||||||
|
|
||||||
self.cond_stage_model = clip(**(params))
|
self.cond_stage_model = clip(**(params))
|
||||||
@@ -348,7 +347,7 @@ class VAE:
|
|||||||
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
|
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
|
||||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
|
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
|
||||||
free_memory = model_management.get_free_memory(self.device)
|
free_memory = model_management.get_free_memory(self.device)
|
||||||
batch_number = int(free_memory / memory_used)
|
batch_number = int(free_memory / max(1, memory_used))
|
||||||
batch_number = max(1, batch_number)
|
batch_number = max(1, batch_number)
|
||||||
samples = torch.empty((pixel_samples.shape[0], self.latent_channels) + tuple(map(lambda a: a // self.downscale_ratio, pixel_samples.shape[2:])), device=self.output_device)
|
samples = torch.empty((pixel_samples.shape[0], self.latent_channels) + tuple(map(lambda a: a // self.downscale_ratio, pixel_samples.shape[2:])), device=self.output_device)
|
||||||
for x in range(0, pixel_samples.shape[0], batch_number):
|
for x in range(0, pixel_samples.shape[0], batch_number):
|
||||||
@@ -406,8 +405,47 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
|
|||||||
clip_data.append(comfy.utils.load_torch_file(p, safe_load=True))
|
clip_data.append(comfy.utils.load_torch_file(p, safe_load=True))
|
||||||
return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)
|
return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)
|
||||||
|
|
||||||
|
|
||||||
|
class TEModel(Enum):
|
||||||
|
CLIP_L = 1
|
||||||
|
CLIP_H = 2
|
||||||
|
CLIP_G = 3
|
||||||
|
T5_XXL = 4
|
||||||
|
T5_XL = 5
|
||||||
|
T5_BASE = 6
|
||||||
|
|
||||||
|
def detect_te_model(sd):
|
||||||
|
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
|
||||||
|
return TEModel.CLIP_G
|
||||||
|
if "text_model.encoder.layers.22.mlp.fc1.weight" in sd:
|
||||||
|
return TEModel.CLIP_H
|
||||||
|
if "text_model.encoder.layers.0.mlp.fc1.weight" in sd:
|
||||||
|
return TEModel.CLIP_L
|
||||||
|
if "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd:
|
||||||
|
weight = sd["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
|
||||||
|
if weight.shape[-1] == 4096:
|
||||||
|
return TEModel.T5_XXL
|
||||||
|
elif weight.shape[-1] == 2048:
|
||||||
|
return TEModel.T5_XL
|
||||||
|
if "encoder.block.0.layer.0.SelfAttention.k.weight" in sd:
|
||||||
|
return TEModel.T5_BASE
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def t5xxl_detect(clip_data):
|
||||||
|
weight_name = "encoder.block.23.layer.1.DenseReluDense.wi_1.weight"
|
||||||
|
|
||||||
|
dtype_t5 = None
|
||||||
|
for sd in clip_data:
|
||||||
|
if weight_name in sd:
|
||||||
|
return comfy.text_encoders.sd3_clip.t5_xxl_detect(sd)
|
||||||
|
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
||||||
clip_data = state_dicts
|
clip_data = state_dicts
|
||||||
|
|
||||||
class EmptyClass:
|
class EmptyClass:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -421,64 +459,61 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
|||||||
clip_target = EmptyClass()
|
clip_target = EmptyClass()
|
||||||
clip_target.params = {}
|
clip_target.params = {}
|
||||||
if len(clip_data) == 1:
|
if len(clip_data) == 1:
|
||||||
if "text_model.encoder.layers.30.mlp.fc1.weight" in clip_data[0]:
|
te_model = detect_te_model(clip_data[0])
|
||||||
|
if te_model == TEModel.CLIP_G:
|
||||||
if clip_type == CLIPType.STABLE_CASCADE:
|
if clip_type == CLIPType.STABLE_CASCADE:
|
||||||
clip_target.clip = sdxl_clip.StableCascadeClipModel
|
clip_target.clip = sdxl_clip.StableCascadeClipModel
|
||||||
clip_target.tokenizer = sdxl_clip.StableCascadeTokenizer
|
clip_target.tokenizer = sdxl_clip.StableCascadeTokenizer
|
||||||
|
elif clip_type == CLIPType.SD3:
|
||||||
|
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=True, t5=False)
|
||||||
|
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||||
else:
|
else:
|
||||||
clip_target.clip = sdxl_clip.SDXLRefinerClipModel
|
clip_target.clip = sdxl_clip.SDXLRefinerClipModel
|
||||||
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
||||||
elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]:
|
elif te_model == TEModel.CLIP_H:
|
||||||
clip_target.clip = comfy.text_encoders.sd2_clip.SD2ClipModel
|
clip_target.clip = comfy.text_encoders.sd2_clip.SD2ClipModel
|
||||||
clip_target.tokenizer = comfy.text_encoders.sd2_clip.SD2Tokenizer
|
clip_target.tokenizer = comfy.text_encoders.sd2_clip.SD2Tokenizer
|
||||||
elif "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in clip_data[0]:
|
elif te_model == TEModel.T5_XXL:
|
||||||
weight = clip_data[0]["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
|
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, **t5xxl_detect(clip_data))
|
||||||
dtype_t5 = weight.dtype
|
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||||
if weight.shape[-1] == 4096:
|
elif te_model == TEModel.T5_XL:
|
||||||
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=dtype_t5)
|
clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model
|
||||||
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
clip_target.tokenizer = comfy.text_encoders.aura_t5.AuraT5Tokenizer
|
||||||
elif weight.shape[-1] == 2048:
|
elif te_model == TEModel.T5_BASE:
|
||||||
clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model
|
|
||||||
clip_target.tokenizer = comfy.text_encoders.aura_t5.AuraT5Tokenizer
|
|
||||||
elif "encoder.block.0.layer.0.SelfAttention.k.weight" in clip_data[0]:
|
|
||||||
clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
|
clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
|
||||||
clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
|
clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
|
||||||
else:
|
else:
|
||||||
w = clip_data[0].get("text_model.embeddings.position_embedding.weight", None)
|
if clip_type == CLIPType.SD3:
|
||||||
if w is not None and w.shape[0] == 248:
|
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=False, t5=False)
|
||||||
clip_target.clip = comfy.text_encoders.long_clipl.LongClipModel
|
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||||
clip_target.tokenizer = comfy.text_encoders.long_clipl.LongClipTokenizer
|
|
||||||
else:
|
else:
|
||||||
clip_target.clip = sd1_clip.SD1ClipModel
|
clip_target.clip = sd1_clip.SD1ClipModel
|
||||||
clip_target.tokenizer = sd1_clip.SD1Tokenizer
|
clip_target.tokenizer = sd1_clip.SD1Tokenizer
|
||||||
elif len(clip_data) == 2:
|
elif len(clip_data) == 2:
|
||||||
if clip_type == CLIPType.SD3:
|
if clip_type == CLIPType.SD3:
|
||||||
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=True, t5=False)
|
te_models = [detect_te_model(clip_data[0]), detect_te_model(clip_data[1])]
|
||||||
|
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=TEModel.CLIP_L in te_models, clip_g=TEModel.CLIP_G in te_models, t5=TEModel.T5_XXL in te_models, **t5xxl_detect(clip_data))
|
||||||
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||||
elif clip_type == CLIPType.HUNYUAN_DIT:
|
elif clip_type == CLIPType.HUNYUAN_DIT:
|
||||||
clip_target.clip = comfy.text_encoders.hydit.HyditModel
|
clip_target.clip = comfy.text_encoders.hydit.HyditModel
|
||||||
clip_target.tokenizer = comfy.text_encoders.hydit.HyditTokenizer
|
clip_target.tokenizer = comfy.text_encoders.hydit.HyditTokenizer
|
||||||
elif clip_type == CLIPType.FLUX:
|
elif clip_type == CLIPType.FLUX:
|
||||||
weight_name = "encoder.block.23.layer.1.DenseReluDense.wi_1.weight"
|
clip_target.clip = comfy.text_encoders.flux.flux_clip(**t5xxl_detect(clip_data))
|
||||||
weight = clip_data[0].get(weight_name, clip_data[1].get(weight_name, None))
|
|
||||||
dtype_t5 = None
|
|
||||||
if weight is not None:
|
|
||||||
dtype_t5 = weight.dtype
|
|
||||||
|
|
||||||
clip_target.clip = comfy.text_encoders.flux.flux_clip(dtype_t5=dtype_t5)
|
|
||||||
clip_target.tokenizer = comfy.text_encoders.flux.FluxTokenizer
|
clip_target.tokenizer = comfy.text_encoders.flux.FluxTokenizer
|
||||||
else:
|
else:
|
||||||
clip_target.clip = sdxl_clip.SDXLClipModel
|
clip_target.clip = sdxl_clip.SDXLClipModel
|
||||||
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
||||||
elif len(clip_data) == 3:
|
elif len(clip_data) == 3:
|
||||||
clip_target.clip = comfy.text_encoders.sd3_clip.SD3ClipModel
|
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(**t5xxl_detect(clip_data))
|
||||||
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||||
|
|
||||||
parameters = 0
|
parameters = 0
|
||||||
|
tokenizer_data = {}
|
||||||
for c in clip_data:
|
for c in clip_data:
|
||||||
parameters += comfy.utils.calculate_parameters(c)
|
parameters += comfy.utils.calculate_parameters(c)
|
||||||
|
tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)
|
||||||
|
|
||||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, model_options=model_options)
|
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, model_options=model_options)
|
||||||
for c in clip_data:
|
for c in clip_data:
|
||||||
m, u = clip.load_sd(c)
|
m, u = clip.load_sd(c)
|
||||||
if len(m) > 0:
|
if len(m) > 0:
|
||||||
@@ -544,11 +579,11 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
unet_weight_dtype = list(model_config.supported_inference_dtypes)
|
unet_weight_dtype = list(model_config.supported_inference_dtypes)
|
||||||
if weight_dtype is not None:
|
if weight_dtype is not None and model_config.scaled_fp8 is None:
|
||||||
unet_weight_dtype.append(weight_dtype)
|
unet_weight_dtype.append(weight_dtype)
|
||||||
|
|
||||||
model_config.custom_operations = model_options.get("custom_operations", None)
|
model_config.custom_operations = model_options.get("custom_operations", None)
|
||||||
unet_dtype = model_options.get("weight_dtype", None)
|
unet_dtype = model_options.get("dtype", model_options.get("weight_dtype", None))
|
||||||
|
|
||||||
if unet_dtype is None:
|
if unet_dtype is None:
|
||||||
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
|
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
|
||||||
@@ -562,7 +597,6 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
|||||||
|
|
||||||
if output_model:
|
if output_model:
|
||||||
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
|
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
|
||||||
offload_device = model_management.unet_offload_device()
|
|
||||||
model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
|
model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
|
||||||
model.load_model_weights(sd, diffusion_model_prefix)
|
model.load_model_weights(sd, diffusion_model_prefix)
|
||||||
|
|
||||||
@@ -614,6 +648,8 @@ def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffuse
|
|||||||
sd = temp_sd
|
sd = temp_sd
|
||||||
|
|
||||||
parameters = comfy.utils.calculate_parameters(sd)
|
parameters = comfy.utils.calculate_parameters(sd)
|
||||||
|
weight_dtype = comfy.utils.weight_dtype(sd)
|
||||||
|
|
||||||
load_device = model_management.get_torch_device()
|
load_device = model_management.get_torch_device()
|
||||||
model_config = model_detection.model_config_from_unet(sd, "")
|
model_config = model_detection.model_config_from_unet(sd, "")
|
||||||
|
|
||||||
@@ -640,14 +676,21 @@ def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffuse
|
|||||||
logging.warning("{} {}".format(diffusers_keys[k], k))
|
logging.warning("{} {}".format(diffusers_keys[k], k))
|
||||||
|
|
||||||
offload_device = model_management.unet_offload_device()
|
offload_device = model_management.unet_offload_device()
|
||||||
|
unet_weight_dtype = list(model_config.supported_inference_dtypes)
|
||||||
|
if weight_dtype is not None and model_config.scaled_fp8 is None:
|
||||||
|
unet_weight_dtype.append(weight_dtype)
|
||||||
|
|
||||||
if dtype is None:
|
if dtype is None:
|
||||||
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
|
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
|
||||||
else:
|
else:
|
||||||
unet_dtype = dtype
|
unet_dtype = dtype
|
||||||
|
|
||||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||||
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
||||||
model_config.custom_operations = model_options.get("custom_operations", None)
|
model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations)
|
||||||
|
if model_options.get("fp8_optimizations", False):
|
||||||
|
model_config.optimizations["fp8"] = True
|
||||||
|
|
||||||
model = model_config.get_model(new_sd, "")
|
model = model_config.get_model(new_sd, "")
|
||||||
model = model.to(offload_device)
|
model = model.to(offload_device)
|
||||||
model.load_model_weights(new_sd, "")
|
model.load_model_weights(new_sd, "")
|
||||||
|
|||||||
@@ -80,7 +80,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||||||
"pooled",
|
"pooled",
|
||||||
"hidden"
|
"hidden"
|
||||||
]
|
]
|
||||||
def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
|
def __init__(self, device="cpu", max_length=77,
|
||||||
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,
|
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,
|
||||||
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
|
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
|
||||||
return_projected_pooled=True, return_attention_masks=False, model_options={}): # clip-vit-base-patch32
|
return_projected_pooled=True, return_attention_masks=False, model_options={}): # clip-vit-base-patch32
|
||||||
@@ -94,11 +94,20 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||||||
config = json.load(f)
|
config = json.load(f)
|
||||||
|
|
||||||
operations = model_options.get("custom_operations", None)
|
operations = model_options.get("custom_operations", None)
|
||||||
|
scaled_fp8 = None
|
||||||
|
|
||||||
if operations is None:
|
if operations is None:
|
||||||
operations = comfy.ops.manual_cast
|
scaled_fp8 = model_options.get("scaled_fp8", None)
|
||||||
|
if scaled_fp8 is not None:
|
||||||
|
operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8)
|
||||||
|
else:
|
||||||
|
operations = comfy.ops.manual_cast
|
||||||
|
|
||||||
self.operations = operations
|
self.operations = operations
|
||||||
self.transformer = model_class(config, dtype, device, self.operations)
|
self.transformer = model_class(config, dtype, device, self.operations)
|
||||||
|
if scaled_fp8 is not None:
|
||||||
|
self.transformer.scaled_fp8 = torch.nn.Parameter(torch.tensor([], dtype=scaled_fp8))
|
||||||
|
|
||||||
self.num_layers = self.transformer.num_layers
|
self.num_layers = self.transformer.num_layers
|
||||||
|
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
@@ -542,6 +551,7 @@ class SD1Tokenizer:
|
|||||||
def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer):
|
def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer):
|
||||||
self.clip_name = clip_name
|
self.clip_name = clip_name
|
||||||
self.clip = "clip_{}".format(self.clip_name)
|
self.clip = "clip_{}".format(self.clip_name)
|
||||||
|
tokenizer = tokenizer_data.get("{}_tokenizer_class".format(self.clip), tokenizer)
|
||||||
setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data))
|
setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data))
|
||||||
|
|
||||||
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
||||||
@@ -570,6 +580,7 @@ class SD1ClipModel(torch.nn.Module):
|
|||||||
self.clip_name = clip_name
|
self.clip_name = clip_name
|
||||||
self.clip = "clip_{}".format(self.clip_name)
|
self.clip = "clip_{}".format(self.clip_name)
|
||||||
|
|
||||||
|
clip_model = model_options.get("{}_class".format(self.clip), clip_model)
|
||||||
setattr(self, self.clip, clip_model(device=device, dtype=dtype, model_options=model_options, **kwargs))
|
setattr(self, self.clip, clip_model(device=device, dtype=dtype, model_options=model_options, **kwargs))
|
||||||
|
|
||||||
self.dtypes = set()
|
self.dtypes = set()
|
||||||
|
|||||||
@@ -22,7 +22,8 @@ class SDXLClipGTokenizer(sd1_clip.SDTokenizer):
|
|||||||
|
|
||||||
class SDXLTokenizer:
|
class SDXLTokenizer:
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
|
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
|
||||||
|
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
|
||||||
self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory)
|
self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory)
|
||||||
|
|
||||||
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
||||||
@@ -40,7 +41,8 @@ class SDXLTokenizer:
|
|||||||
class SDXLClipModel(torch.nn.Module):
|
class SDXLClipModel(torch.nn.Module):
|
||||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options)
|
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
|
||||||
|
self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options)
|
||||||
self.clip_g = SDXLClipG(device=device, dtype=dtype, model_options=model_options)
|
self.clip_g = SDXLClipG(device=device, dtype=dtype, model_options=model_options)
|
||||||
self.dtypes = set([dtype])
|
self.dtypes = set([dtype])
|
||||||
|
|
||||||
@@ -57,7 +59,8 @@ class SDXLClipModel(torch.nn.Module):
|
|||||||
token_weight_pairs_l = token_weight_pairs["l"]
|
token_weight_pairs_l = token_weight_pairs["l"]
|
||||||
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
|
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
|
||||||
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
|
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
|
||||||
return torch.cat([l_out, g_out], dim=-1), g_pooled
|
cut_to = min(l_out.shape[1], g_out.shape[1])
|
||||||
|
return torch.cat([l_out[:,:cut_to], g_out[:,:cut_to]], dim=-1), g_pooled
|
||||||
|
|
||||||
def load_sd(self, sd):
|
def load_sd(self, sd):
|
||||||
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
|
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
|
||||||
|
|||||||
@@ -529,12 +529,11 @@ class SD3(supported_models_base.BASE):
|
|||||||
clip_l = True
|
clip_l = True
|
||||||
if "{}clip_g.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict:
|
if "{}clip_g.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict:
|
||||||
clip_g = True
|
clip_g = True
|
||||||
t5_key = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref)
|
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
||||||
if t5_key in state_dict:
|
if "dtype_t5" in t5_detect:
|
||||||
t5 = True
|
t5 = True
|
||||||
dtype_t5 = state_dict[t5_key].dtype
|
|
||||||
|
|
||||||
return supported_models_base.ClipTarget(comfy.text_encoders.sd3_clip.SD3Tokenizer, comfy.text_encoders.sd3_clip.sd3_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5))
|
return supported_models_base.ClipTarget(comfy.text_encoders.sd3_clip.SD3Tokenizer, comfy.text_encoders.sd3_clip.sd3_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, **t5_detect))
|
||||||
|
|
||||||
class StableAudio(supported_models_base.BASE):
|
class StableAudio(supported_models_base.BASE):
|
||||||
unet_config = {
|
unet_config = {
|
||||||
@@ -653,11 +652,8 @@ class Flux(supported_models_base.BASE):
|
|||||||
|
|
||||||
def clip_target(self, state_dict={}):
|
def clip_target(self, state_dict={}):
|
||||||
pref = self.text_encoder_key_prefix[0]
|
pref = self.text_encoder_key_prefix[0]
|
||||||
t5_key = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref)
|
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
||||||
dtype_t5 = None
|
return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(**t5_detect))
|
||||||
if t5_key in state_dict:
|
|
||||||
dtype_t5 = state_dict[t5_key].dtype
|
|
||||||
return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(dtype_t5=dtype_t5))
|
|
||||||
|
|
||||||
class FluxSchnell(Flux):
|
class FluxSchnell(Flux):
|
||||||
unet_config = {
|
unet_config = {
|
||||||
|
|||||||
@@ -49,6 +49,8 @@ class BASE:
|
|||||||
|
|
||||||
manual_cast_dtype = None
|
manual_cast_dtype = None
|
||||||
custom_operations = None
|
custom_operations = None
|
||||||
|
scaled_fp8 = None
|
||||||
|
optimizations = {"fp8": False}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def matches(s, unet_config, state_dict=None):
|
def matches(s, unet_config, state_dict=None):
|
||||||
@@ -71,6 +73,7 @@ class BASE:
|
|||||||
self.unet_config = unet_config.copy()
|
self.unet_config = unet_config.copy()
|
||||||
self.sampling_settings = self.sampling_settings.copy()
|
self.sampling_settings = self.sampling_settings.copy()
|
||||||
self.latent_format = self.latent_format()
|
self.latent_format = self.latent_format()
|
||||||
|
self.optimizations = self.optimizations.copy()
|
||||||
for x in self.unet_extra_config:
|
for x in self.unet_extra_config:
|
||||||
self.unet_config[x] = self.unet_extra_config[x]
|
self.unet_config[x] = self.unet_extra_config[x]
|
||||||
|
|
||||||
|
|||||||
@@ -1,24 +1,21 @@
|
|||||||
from comfy import sd1_clip
|
from comfy import sd1_clip
|
||||||
import comfy.text_encoders.t5
|
import comfy.text_encoders.t5
|
||||||
|
import comfy.text_encoders.sd3_clip
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
from transformers import T5TokenizerFast
|
from transformers import T5TokenizerFast
|
||||||
import torch
|
import torch
|
||||||
import os
|
import os
|
||||||
|
|
||||||
class T5XXLModel(sd1_clip.SDClipModel):
|
|
||||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
|
|
||||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
|
|
||||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, model_options=model_options)
|
|
||||||
|
|
||||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
|
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
|
||||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
|
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
|
||||||
|
|
||||||
|
|
||||||
class FluxTokenizer:
|
class FluxTokenizer:
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
|
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
|
||||||
|
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
|
||||||
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
|
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
|
||||||
|
|
||||||
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
||||||
@@ -38,8 +35,9 @@ class FluxClipModel(torch.nn.Module):
|
|||||||
def __init__(self, dtype_t5=None, device="cpu", dtype=None, model_options={}):
|
def __init__(self, dtype_t5=None, device="cpu", dtype=None, model_options={}):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
|
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
|
||||||
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
|
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
|
||||||
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
|
self.clip_l = clip_l_class(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
|
||||||
|
self.t5xxl = comfy.text_encoders.sd3_clip.T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
|
||||||
self.dtypes = set([dtype, dtype_t5])
|
self.dtypes = set([dtype, dtype_t5])
|
||||||
|
|
||||||
def set_clip_options(self, options):
|
def set_clip_options(self, options):
|
||||||
@@ -64,8 +62,11 @@ class FluxClipModel(torch.nn.Module):
|
|||||||
else:
|
else:
|
||||||
return self.t5xxl.load_sd(sd)
|
return self.t5xxl.load_sd(sd)
|
||||||
|
|
||||||
def flux_clip(dtype_t5=None):
|
def flux_clip(dtype_t5=None, t5xxl_scaled_fp8=None):
|
||||||
class FluxClipModel_(FluxClipModel):
|
class FluxClipModel_(FluxClipModel):
|
||||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||||
|
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
|
||||||
|
model_options = model_options.copy()
|
||||||
|
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
|
||||||
super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
|
super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
|
||||||
return FluxClipModel_
|
return FluxClipModel_
|
||||||
|
|||||||
@@ -6,9 +6,9 @@ class LongClipTokenizer_(sd1_clip.SDTokenizer):
|
|||||||
super().__init__(max_length=248, embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
super().__init__(max_length=248, embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||||
|
|
||||||
class LongClipModel_(sd1_clip.SDClipModel):
|
class LongClipModel_(sd1_clip.SDClipModel):
|
||||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
def __init__(self, *args, **kwargs):
|
||||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "long_clipl.json")
|
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "long_clipl.json")
|
||||||
super().__init__(device=device, textmodel_json_config=textmodel_json_config, return_projected_pooled=False, dtype=dtype, model_options=model_options)
|
super().__init__(*args, textmodel_json_config=textmodel_json_config, **kwargs)
|
||||||
|
|
||||||
class LongClipTokenizer(sd1_clip.SD1Tokenizer):
|
class LongClipTokenizer(sd1_clip.SD1Tokenizer):
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
@@ -17,3 +17,14 @@ class LongClipTokenizer(sd1_clip.SD1Tokenizer):
|
|||||||
class LongClipModel(sd1_clip.SD1ClipModel):
|
class LongClipModel(sd1_clip.SD1ClipModel):
|
||||||
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
|
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
|
||||||
super().__init__(device=device, dtype=dtype, model_options=model_options, clip_model=LongClipModel_, **kwargs)
|
super().__init__(device=device, dtype=dtype, model_options=model_options, clip_model=LongClipModel_, **kwargs)
|
||||||
|
|
||||||
|
def model_options_long_clip(sd, tokenizer_data, model_options):
|
||||||
|
w = sd.get("clip_l.text_model.embeddings.position_embedding.weight", None)
|
||||||
|
if w is None:
|
||||||
|
w = sd.get("text_model.embeddings.position_embedding.weight", None)
|
||||||
|
if w is not None and w.shape[0] == 248:
|
||||||
|
tokenizer_data = tokenizer_data.copy()
|
||||||
|
model_options = model_options.copy()
|
||||||
|
tokenizer_data["clip_l_tokenizer_class"] = LongClipTokenizer_
|
||||||
|
model_options["clip_l_class"] = LongClipModel_
|
||||||
|
return tokenizer_data, model_options
|
||||||
|
|||||||
@@ -8,9 +8,27 @@ import comfy.model_management
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
class T5XXLModel(sd1_clip.SDClipModel):
|
class T5XXLModel(sd1_clip.SDClipModel):
|
||||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
|
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=False, model_options={}):
|
||||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
|
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
|
||||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, model_options=model_options)
|
t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None)
|
||||||
|
if t5xxl_scaled_fp8 is not None:
|
||||||
|
model_options = model_options.copy()
|
||||||
|
model_options["scaled_fp8"] = t5xxl_scaled_fp8
|
||||||
|
|
||||||
|
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||||
|
|
||||||
|
|
||||||
|
def t5_xxl_detect(state_dict, prefix=""):
|
||||||
|
out = {}
|
||||||
|
t5_key = "{}encoder.final_layer_norm.weight".format(prefix)
|
||||||
|
if t5_key in state_dict:
|
||||||
|
out["dtype_t5"] = state_dict[t5_key].dtype
|
||||||
|
|
||||||
|
scaled_fp8_key = "{}scaled_fp8".format(prefix)
|
||||||
|
if scaled_fp8_key in state_dict:
|
||||||
|
out["t5xxl_scaled_fp8"] = state_dict[scaled_fp8_key].dtype
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
@@ -20,7 +38,8 @@ class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
|||||||
|
|
||||||
class SD3Tokenizer:
|
class SD3Tokenizer:
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
|
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
|
||||||
|
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
|
||||||
self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory)
|
self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory)
|
||||||
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
|
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
|
||||||
|
|
||||||
@@ -38,11 +57,12 @@ class SD3Tokenizer:
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
class SD3ClipModel(torch.nn.Module):
|
class SD3ClipModel(torch.nn.Module):
|
||||||
def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu", dtype=None, model_options={}):
|
def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5_attention_mask=False, device="cpu", dtype=None, model_options={}):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dtypes = set()
|
self.dtypes = set()
|
||||||
if clip_l:
|
if clip_l:
|
||||||
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options)
|
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
|
||||||
|
self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options)
|
||||||
self.dtypes.add(dtype)
|
self.dtypes.add(dtype)
|
||||||
else:
|
else:
|
||||||
self.clip_l = None
|
self.clip_l = None
|
||||||
@@ -55,7 +75,8 @@ class SD3ClipModel(torch.nn.Module):
|
|||||||
|
|
||||||
if t5:
|
if t5:
|
||||||
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
|
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
|
||||||
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
|
self.t5_attention_mask = t5_attention_mask
|
||||||
|
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options, attention_mask=self.t5_attention_mask)
|
||||||
self.dtypes.add(dtype_t5)
|
self.dtypes.add(dtype_t5)
|
||||||
else:
|
else:
|
||||||
self.t5xxl = None
|
self.t5xxl = None
|
||||||
@@ -85,6 +106,7 @@ class SD3ClipModel(torch.nn.Module):
|
|||||||
lg_out = None
|
lg_out = None
|
||||||
pooled = None
|
pooled = None
|
||||||
out = None
|
out = None
|
||||||
|
extra = {}
|
||||||
|
|
||||||
if len(token_weight_pairs_g) > 0 or len(token_weight_pairs_l) > 0:
|
if len(token_weight_pairs_g) > 0 or len(token_weight_pairs_l) > 0:
|
||||||
if self.clip_l is not None:
|
if self.clip_l is not None:
|
||||||
@@ -95,7 +117,8 @@ class SD3ClipModel(torch.nn.Module):
|
|||||||
if self.clip_g is not None:
|
if self.clip_g is not None:
|
||||||
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
|
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
|
||||||
if lg_out is not None:
|
if lg_out is not None:
|
||||||
lg_out = torch.cat([lg_out, g_out], dim=-1)
|
cut_to = min(lg_out.shape[1], g_out.shape[1])
|
||||||
|
lg_out = torch.cat([lg_out[:,:cut_to], g_out[:,:cut_to]], dim=-1)
|
||||||
else:
|
else:
|
||||||
lg_out = torch.nn.functional.pad(g_out, (768, 0))
|
lg_out = torch.nn.functional.pad(g_out, (768, 0))
|
||||||
else:
|
else:
|
||||||
@@ -108,7 +131,11 @@ class SD3ClipModel(torch.nn.Module):
|
|||||||
pooled = torch.cat((l_pooled, g_pooled), dim=-1)
|
pooled = torch.cat((l_pooled, g_pooled), dim=-1)
|
||||||
|
|
||||||
if self.t5xxl is not None:
|
if self.t5xxl is not None:
|
||||||
t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pairs_t5)
|
t5_output = self.t5xxl.encode_token_weights(token_weight_pairs_t5)
|
||||||
|
t5_out, t5_pooled = t5_output[:2]
|
||||||
|
if self.t5_attention_mask:
|
||||||
|
extra["attention_mask"] = t5_output[2]["attention_mask"]
|
||||||
|
|
||||||
if lg_out is not None:
|
if lg_out is not None:
|
||||||
out = torch.cat([lg_out, t5_out], dim=-2)
|
out = torch.cat([lg_out, t5_out], dim=-2)
|
||||||
else:
|
else:
|
||||||
@@ -120,7 +147,7 @@ class SD3ClipModel(torch.nn.Module):
|
|||||||
if pooled is None:
|
if pooled is None:
|
||||||
pooled = torch.zeros((1, 768 + 1280), device=comfy.model_management.intermediate_device())
|
pooled = torch.zeros((1, 768 + 1280), device=comfy.model_management.intermediate_device())
|
||||||
|
|
||||||
return out, pooled
|
return out, pooled, extra
|
||||||
|
|
||||||
def load_sd(self, sd):
|
def load_sd(self, sd):
|
||||||
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
|
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
|
||||||
@@ -130,8 +157,11 @@ class SD3ClipModel(torch.nn.Module):
|
|||||||
else:
|
else:
|
||||||
return self.t5xxl.load_sd(sd)
|
return self.t5xxl.load_sd(sd)
|
||||||
|
|
||||||
def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None):
|
def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5xxl_scaled_fp8=None, t5_attention_mask=False):
|
||||||
class SD3ClipModel_(SD3ClipModel):
|
class SD3ClipModel_(SD3ClipModel):
|
||||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||||
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
|
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
|
||||||
|
model_options = model_options.copy()
|
||||||
|
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
|
||||||
|
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, t5_attention_mask=t5_attention_mask, device=device, dtype=dtype, model_options=model_options)
|
||||||
return SD3ClipModel_
|
return SD3ClipModel_
|
||||||
|
|||||||
@@ -68,7 +68,7 @@ def weight_dtype(sd, prefix=""):
|
|||||||
for k in sd.keys():
|
for k in sd.keys():
|
||||||
if k.startswith(prefix):
|
if k.startswith(prefix):
|
||||||
w = sd[k]
|
w = sd[k]
|
||||||
dtypes[w.dtype] = dtypes.get(w.dtype, 0) + 1
|
dtypes[w.dtype] = dtypes.get(w.dtype, 0) + w.numel()
|
||||||
|
|
||||||
if len(dtypes) == 0:
|
if len(dtypes) == 0:
|
||||||
return None
|
return None
|
||||||
@@ -528,6 +528,8 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
|
|||||||
("guidance_in.out_layer.weight", "time_text_embed.guidance_embedder.linear_2.weight"),
|
("guidance_in.out_layer.weight", "time_text_embed.guidance_embedder.linear_2.weight"),
|
||||||
("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift),
|
("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift),
|
||||||
("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift),
|
("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift),
|
||||||
|
("pos_embed_input.bias", "controlnet_x_embedder.bias"),
|
||||||
|
("pos_embed_input.weight", "controlnet_x_embedder.weight"),
|
||||||
}
|
}
|
||||||
|
|
||||||
for k in MAP_BASIC:
|
for k in MAP_BASIC:
|
||||||
@@ -711,7 +713,9 @@ def common_upscale(samples, width, height, upscale_method, crop):
|
|||||||
return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
|
return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
|
||||||
|
|
||||||
def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
|
def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
|
||||||
return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap)))
|
rows = 1 if height <= tile_y else math.ceil((height - overlap) / (tile_y - overlap))
|
||||||
|
cols = 1 if width <= tile_x else math.ceil((width - overlap) / (tile_x - overlap))
|
||||||
|
return rows * cols
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
|
def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
|
||||||
@@ -720,10 +724,20 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_
|
|||||||
|
|
||||||
for b in range(samples.shape[0]):
|
for b in range(samples.shape[0]):
|
||||||
s = samples[b:b+1]
|
s = samples[b:b+1]
|
||||||
|
|
||||||
|
# handle entire input fitting in a single tile
|
||||||
|
if all(s.shape[d+2] <= tile[d] for d in range(dims)):
|
||||||
|
output[b:b+1] = function(s).to(output_device)
|
||||||
|
if pbar is not None:
|
||||||
|
pbar.update(1)
|
||||||
|
continue
|
||||||
|
|
||||||
out = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
|
out = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
|
||||||
out_div = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
|
out_div = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
|
||||||
|
|
||||||
for it in itertools.product(*map(lambda a: range(0, a[0], a[1] - overlap), zip(s.shape[2:], tile))):
|
positions = [range(0, s.shape[d+2], tile[d] - overlap) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]
|
||||||
|
|
||||||
|
for it in itertools.product(*positions):
|
||||||
s_in = s
|
s_in = s
|
||||||
upscaled = []
|
upscaled = []
|
||||||
|
|
||||||
@@ -732,15 +746,16 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_
|
|||||||
l = min(tile[d], s.shape[d + 2] - pos)
|
l = min(tile[d], s.shape[d + 2] - pos)
|
||||||
s_in = s_in.narrow(d + 2, pos, l)
|
s_in = s_in.narrow(d + 2, pos, l)
|
||||||
upscaled.append(round(pos * upscale_amount))
|
upscaled.append(round(pos * upscale_amount))
|
||||||
|
|
||||||
ps = function(s_in).to(output_device)
|
ps = function(s_in).to(output_device)
|
||||||
mask = torch.ones_like(ps)
|
mask = torch.ones_like(ps)
|
||||||
feather = round(overlap * upscale_amount)
|
feather = round(overlap * upscale_amount)
|
||||||
|
|
||||||
for t in range(feather):
|
for t in range(feather):
|
||||||
for d in range(2, dims + 2):
|
for d in range(2, dims + 2):
|
||||||
m = mask.narrow(d, t, 1)
|
a = (t + 1) / feather
|
||||||
m *= ((1.0/feather) * (t + 1))
|
mask.narrow(d, t, 1).mul_(a)
|
||||||
m = mask.narrow(d, mask.shape[d] -1 -t, 1)
|
mask.narrow(d, mask.shape[d] - 1 - t, 1).mul_(a)
|
||||||
m *= ((1.0/feather) * (t + 1))
|
|
||||||
|
|
||||||
o = out
|
o = out
|
||||||
o_d = out_div
|
o_d = out_div
|
||||||
@@ -748,8 +763,8 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_
|
|||||||
o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2])
|
o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2])
|
||||||
o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2])
|
o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2])
|
||||||
|
|
||||||
o += ps * mask
|
o.add_(ps * mask)
|
||||||
o_d += mask
|
o_d.add_(mask)
|
||||||
|
|
||||||
if pbar is not None:
|
if pbar is not None:
|
||||||
pbar.update(1)
|
pbar.update(1)
|
||||||
|
|||||||
@@ -1,11 +1,21 @@
|
|||||||
import itertools
|
import itertools
|
||||||
from typing import Sequence, Mapping
|
from typing import Sequence, Mapping, Dict
|
||||||
from comfy_execution.graph import DynamicPrompt
|
from comfy_execution.graph import DynamicPrompt
|
||||||
|
|
||||||
import nodes
|
import nodes
|
||||||
|
|
||||||
from comfy_execution.graph_utils import is_link
|
from comfy_execution.graph_utils import is_link
|
||||||
|
|
||||||
|
NODE_CLASS_CONTAINS_UNIQUE_ID: Dict[str, bool] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def include_unique_id_in_input(class_type: str) -> bool:
|
||||||
|
if class_type in NODE_CLASS_CONTAINS_UNIQUE_ID:
|
||||||
|
return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
|
||||||
|
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||||
|
NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] = "UNIQUE_ID" in class_def.INPUT_TYPES().get("hidden", {}).values()
|
||||||
|
return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
|
||||||
|
|
||||||
class CacheKeySet:
|
class CacheKeySet:
|
||||||
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
||||||
self.keys = {}
|
self.keys = {}
|
||||||
@@ -56,6 +66,8 @@ class CacheKeySetID(CacheKeySet):
|
|||||||
for node_id in node_ids:
|
for node_id in node_ids:
|
||||||
if node_id in self.keys:
|
if node_id in self.keys:
|
||||||
continue
|
continue
|
||||||
|
if not self.dynprompt.has_node(node_id):
|
||||||
|
continue
|
||||||
node = self.dynprompt.get_node(node_id)
|
node = self.dynprompt.get_node(node_id)
|
||||||
self.keys[node_id] = (node_id, node["class_type"])
|
self.keys[node_id] = (node_id, node["class_type"])
|
||||||
self.subcache_keys[node_id] = (node_id, node["class_type"])
|
self.subcache_keys[node_id] = (node_id, node["class_type"])
|
||||||
@@ -74,6 +86,8 @@ class CacheKeySetInputSignature(CacheKeySet):
|
|||||||
for node_id in node_ids:
|
for node_id in node_ids:
|
||||||
if node_id in self.keys:
|
if node_id in self.keys:
|
||||||
continue
|
continue
|
||||||
|
if not self.dynprompt.has_node(node_id):
|
||||||
|
continue
|
||||||
node = self.dynprompt.get_node(node_id)
|
node = self.dynprompt.get_node(node_id)
|
||||||
self.keys[node_id] = self.get_node_signature(self.dynprompt, node_id)
|
self.keys[node_id] = self.get_node_signature(self.dynprompt, node_id)
|
||||||
self.subcache_keys[node_id] = (node_id, node["class_type"])
|
self.subcache_keys[node_id] = (node_id, node["class_type"])
|
||||||
@@ -87,11 +101,14 @@ class CacheKeySetInputSignature(CacheKeySet):
|
|||||||
return to_hashable(signature)
|
return to_hashable(signature)
|
||||||
|
|
||||||
def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
|
def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
|
||||||
|
if not dynprompt.has_node(node_id):
|
||||||
|
# This node doesn't exist -- we can't cache it.
|
||||||
|
return [float("NaN")]
|
||||||
node = dynprompt.get_node(node_id)
|
node = dynprompt.get_node(node_id)
|
||||||
class_type = node["class_type"]
|
class_type = node["class_type"]
|
||||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||||
signature = [class_type, self.is_changed_cache.get(node_id)]
|
signature = [class_type, self.is_changed_cache.get(node_id)]
|
||||||
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT):
|
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
|
||||||
signature.append(node_id)
|
signature.append(node_id)
|
||||||
inputs = node["inputs"]
|
inputs = node["inputs"]
|
||||||
for key in sorted(inputs.keys()):
|
for key in sorted(inputs.keys()):
|
||||||
@@ -112,6 +129,8 @@ class CacheKeySetInputSignature(CacheKeySet):
|
|||||||
return ancestors, order_mapping
|
return ancestors, order_mapping
|
||||||
|
|
||||||
def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_mapping):
|
def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_mapping):
|
||||||
|
if not dynprompt.has_node(node_id):
|
||||||
|
return
|
||||||
inputs = dynprompt.get_node(node_id)["inputs"]
|
inputs = dynprompt.get_node(node_id)["inputs"]
|
||||||
input_keys = sorted(inputs.keys())
|
input_keys = sorted(inputs.keys())
|
||||||
for key in input_keys:
|
for key in input_keys:
|
||||||
|
|||||||
@@ -99,30 +99,44 @@ class TopologicalSort:
|
|||||||
self.add_strong_link(from_node_id, from_socket, to_node_id)
|
self.add_strong_link(from_node_id, from_socket, to_node_id)
|
||||||
|
|
||||||
def add_strong_link(self, from_node_id, from_socket, to_node_id):
|
def add_strong_link(self, from_node_id, from_socket, to_node_id):
|
||||||
self.add_node(from_node_id)
|
if not self.is_cached(from_node_id):
|
||||||
if to_node_id not in self.blocking[from_node_id]:
|
self.add_node(from_node_id)
|
||||||
self.blocking[from_node_id][to_node_id] = {}
|
if to_node_id not in self.blocking[from_node_id]:
|
||||||
self.blockCount[to_node_id] += 1
|
self.blocking[from_node_id][to_node_id] = {}
|
||||||
self.blocking[from_node_id][to_node_id][from_socket] = True
|
self.blockCount[to_node_id] += 1
|
||||||
|
self.blocking[from_node_id][to_node_id][from_socket] = True
|
||||||
|
|
||||||
def add_node(self, unique_id, include_lazy=False, subgraph_nodes=None):
|
def add_node(self, node_unique_id, include_lazy=False, subgraph_nodes=None):
|
||||||
if unique_id in self.pendingNodes:
|
node_ids = [node_unique_id]
|
||||||
return
|
links = []
|
||||||
self.pendingNodes[unique_id] = True
|
|
||||||
self.blockCount[unique_id] = 0
|
|
||||||
self.blocking[unique_id] = {}
|
|
||||||
|
|
||||||
inputs = self.dynprompt.get_node(unique_id)["inputs"]
|
while len(node_ids) > 0:
|
||||||
for input_name in inputs:
|
unique_id = node_ids.pop()
|
||||||
value = inputs[input_name]
|
if unique_id in self.pendingNodes:
|
||||||
if is_link(value):
|
continue
|
||||||
from_node_id, from_socket = value
|
|
||||||
if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
|
self.pendingNodes[unique_id] = True
|
||||||
continue
|
self.blockCount[unique_id] = 0
|
||||||
input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
|
self.blocking[unique_id] = {}
|
||||||
is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
|
|
||||||
if include_lazy or not is_lazy:
|
inputs = self.dynprompt.get_node(unique_id)["inputs"]
|
||||||
self.add_strong_link(from_node_id, from_socket, unique_id)
|
for input_name in inputs:
|
||||||
|
value = inputs[input_name]
|
||||||
|
if is_link(value):
|
||||||
|
from_node_id, from_socket = value
|
||||||
|
if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
|
||||||
|
continue
|
||||||
|
input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
|
||||||
|
is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
|
||||||
|
if (include_lazy or not is_lazy) and not self.is_cached(from_node_id):
|
||||||
|
node_ids.append(from_node_id)
|
||||||
|
links.append((from_node_id, from_socket, unique_id))
|
||||||
|
|
||||||
|
for link in links:
|
||||||
|
self.add_strong_link(*link)
|
||||||
|
|
||||||
|
def is_cached(self, node_id):
|
||||||
|
return False
|
||||||
|
|
||||||
def get_ready_nodes(self):
|
def get_ready_nodes(self):
|
||||||
return [node_id for node_id in self.pendingNodes if self.blockCount[node_id] == 0]
|
return [node_id for node_id in self.pendingNodes if self.blockCount[node_id] == 0]
|
||||||
@@ -146,11 +160,8 @@ class ExecutionList(TopologicalSort):
|
|||||||
self.output_cache = output_cache
|
self.output_cache = output_cache
|
||||||
self.staged_node_id = None
|
self.staged_node_id = None
|
||||||
|
|
||||||
def add_strong_link(self, from_node_id, from_socket, to_node_id):
|
def is_cached(self, node_id):
|
||||||
if self.output_cache.get(from_node_id) is not None:
|
return self.output_cache.get(node_id) is not None
|
||||||
# Nothing to do
|
|
||||||
return
|
|
||||||
super().add_strong_link(from_node_id, from_socket, to_node_id)
|
|
||||||
|
|
||||||
def stage_node_execution(self):
|
def stage_node_execution(self):
|
||||||
assert self.staged_node_id is None
|
assert self.staged_node_id is None
|
||||||
|
|||||||
@@ -16,14 +16,15 @@ class EmptyLatentAudio:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": {"seconds": ("FLOAT", {"default": 47.6, "min": 1.0, "max": 1000.0, "step": 0.1})}}
|
return {"required": {"seconds": ("FLOAT", {"default": 47.6, "min": 1.0, "max": 1000.0, "step": 0.1}),
|
||||||
|
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
|
||||||
|
}}
|
||||||
RETURN_TYPES = ("LATENT",)
|
RETURN_TYPES = ("LATENT",)
|
||||||
FUNCTION = "generate"
|
FUNCTION = "generate"
|
||||||
|
|
||||||
CATEGORY = "latent/audio"
|
CATEGORY = "latent/audio"
|
||||||
|
|
||||||
def generate(self, seconds):
|
def generate(self, seconds, batch_size):
|
||||||
batch_size = 1
|
|
||||||
length = round((seconds * 44100 / 2048) / 2) * 2
|
length = round((seconds * 44100 / 2048) / 2) * 2
|
||||||
latent = torch.zeros([batch_size, 64, length], device=self.device)
|
latent = torch.zeros([batch_size, 64, length], device=self.device)
|
||||||
return ({"samples":latent, "type": "audio"}, )
|
return ({"samples":latent, "type": "audio"}, )
|
||||||
@@ -58,6 +59,9 @@ class VAEDecodeAudio:
|
|||||||
|
|
||||||
def decode(self, vae, samples):
|
def decode(self, vae, samples):
|
||||||
audio = vae.decode(samples["samples"]).movedim(-1, 1)
|
audio = vae.decode(samples["samples"]).movedim(-1, 1)
|
||||||
|
std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0
|
||||||
|
std[std < 1.0] = 1.0
|
||||||
|
audio /= std
|
||||||
return ({"waveform": audio, "sample_rate": 44100}, )
|
return ({"waveform": audio, "sample_rate": 44100}, )
|
||||||
|
|
||||||
|
|
||||||
@@ -183,17 +187,10 @@ class PreviewAudio(SaveAudio):
|
|||||||
}
|
}
|
||||||
|
|
||||||
class LoadAudio:
|
class LoadAudio:
|
||||||
SUPPORTED_FORMATS = ('.wav', '.mp3', '.ogg', '.flac', '.aiff', '.aif')
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
input_dir = folder_paths.get_input_directory()
|
input_dir = folder_paths.get_input_directory()
|
||||||
files = [
|
files = folder_paths.filter_files_content_types(os.listdir(input_dir), ["audio", "video"])
|
||||||
f for f in os.listdir(input_dir)
|
|
||||||
if (os.path.isfile(os.path.join(input_dir, f))
|
|
||||||
and f.endswith(LoadAudio.SUPPORTED_FORMATS)
|
|
||||||
)
|
|
||||||
]
|
|
||||||
return {"required": {"audio": (sorted(files), {"audio_upload": True})}}
|
return {"required": {"audio": (sorted(files), {"audio_upload": True})}}
|
||||||
|
|
||||||
CATEGORY = "audio"
|
CATEGORY = "audio"
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
from comfy.cldm.control_types import UNION_CONTROLNET_TYPES
|
from comfy.cldm.control_types import UNION_CONTROLNET_TYPES
|
||||||
|
import nodes
|
||||||
|
import comfy.utils
|
||||||
|
|
||||||
class SetUnionControlNetType:
|
class SetUnionControlNetType:
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -22,6 +24,37 @@ class SetUnionControlNetType:
|
|||||||
|
|
||||||
return (control_net,)
|
return (control_net,)
|
||||||
|
|
||||||
|
class ControlNetInpaintingAliMamaApply(nodes.ControlNetApplyAdvanced):
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": {"positive": ("CONDITIONING", ),
|
||||||
|
"negative": ("CONDITIONING", ),
|
||||||
|
"control_net": ("CONTROL_NET", ),
|
||||||
|
"vae": ("VAE", ),
|
||||||
|
"image": ("IMAGE", ),
|
||||||
|
"mask": ("MASK", ),
|
||||||
|
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||||
|
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||||
|
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
|
||||||
|
}}
|
||||||
|
|
||||||
|
FUNCTION = "apply_inpaint_controlnet"
|
||||||
|
|
||||||
|
CATEGORY = "conditioning/controlnet"
|
||||||
|
|
||||||
|
def apply_inpaint_controlnet(self, positive, negative, control_net, vae, image, mask, strength, start_percent, end_percent):
|
||||||
|
extra_concat = []
|
||||||
|
if control_net.concat_mask:
|
||||||
|
mask = 1.0 - mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
|
||||||
|
mask_apply = comfy.utils.common_upscale(mask, image.shape[2], image.shape[1], "bilinear", "center").round()
|
||||||
|
image = image * mask_apply.movedim(1, -1).repeat(1, 1, 1, image.shape[3])
|
||||||
|
extra_concat = [mask]
|
||||||
|
|
||||||
|
return self.apply_controlnet(positive, negative, control_net, image, strength, start_percent, end_percent, vae=vae, extra_concat=extra_concat)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"SetUnionControlNetType": SetUnionControlNetType,
|
"SetUnionControlNetType": SetUnionControlNetType,
|
||||||
|
"ControlNetInpaintingAliMamaApply": ControlNetInpaintingAliMamaApply,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -90,6 +90,27 @@ class PolyexponentialScheduler:
|
|||||||
sigmas = k_diffusion_sampling.get_sigmas_polyexponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho)
|
sigmas = k_diffusion_sampling.get_sigmas_polyexponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho)
|
||||||
return (sigmas, )
|
return (sigmas, )
|
||||||
|
|
||||||
|
class LaplaceScheduler:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required":
|
||||||
|
{"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
|
||||||
|
"sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
|
||||||
|
"sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
|
||||||
|
"mu": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step":0.1, "round": False}),
|
||||||
|
"beta": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step":0.1, "round": False}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
RETURN_TYPES = ("SIGMAS",)
|
||||||
|
CATEGORY = "sampling/custom_sampling/schedulers"
|
||||||
|
|
||||||
|
FUNCTION = "get_sigmas"
|
||||||
|
|
||||||
|
def get_sigmas(self, steps, sigma_max, sigma_min, mu, beta):
|
||||||
|
sigmas = k_diffusion_sampling.get_sigmas_laplace(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, mu=mu, beta=beta)
|
||||||
|
return (sigmas, )
|
||||||
|
|
||||||
|
|
||||||
class SDTurboScheduler:
|
class SDTurboScheduler:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
@@ -673,6 +694,7 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"KarrasScheduler": KarrasScheduler,
|
"KarrasScheduler": KarrasScheduler,
|
||||||
"ExponentialScheduler": ExponentialScheduler,
|
"ExponentialScheduler": ExponentialScheduler,
|
||||||
"PolyexponentialScheduler": PolyexponentialScheduler,
|
"PolyexponentialScheduler": PolyexponentialScheduler,
|
||||||
|
"LaplaceScheduler": LaplaceScheduler,
|
||||||
"VPScheduler": VPScheduler,
|
"VPScheduler": VPScheduler,
|
||||||
"BetaSamplingScheduler": BetaSamplingScheduler,
|
"BetaSamplingScheduler": BetaSamplingScheduler,
|
||||||
"SDTurboScheduler": SDTurboScheduler,
|
"SDTurboScheduler": SDTurboScheduler,
|
||||||
|
|||||||
@@ -107,7 +107,7 @@ class HypernetworkLoader:
|
|||||||
CATEGORY = "loaders"
|
CATEGORY = "loaders"
|
||||||
|
|
||||||
def load_hypernetwork(self, model, hypernetwork_name, strength):
|
def load_hypernetwork(self, model, hypernetwork_name, strength):
|
||||||
hypernetwork_path = folder_paths.get_full_path("hypernetworks", hypernetwork_name)
|
hypernetwork_path = folder_paths.get_full_path_or_raise("hypernetworks", hypernetwork_name)
|
||||||
model_hypernetwork = model.clone()
|
model_hypernetwork = model.clone()
|
||||||
patch = load_hypernetwork_patch(hypernetwork_path, strength)
|
patch = load_hypernetwork_patch(hypernetwork_path, strength)
|
||||||
if patch is not None:
|
if patch is not None:
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import comfy.utils
|
import comfy.utils
|
||||||
|
import comfy_extras.nodes_post_processing
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
def reshape_latent_to(target_shape, latent):
|
def reshape_latent_to(target_shape, latent):
|
||||||
@@ -145,6 +146,131 @@ class LatentBatchSeedBehavior:
|
|||||||
|
|
||||||
return (samples_out,)
|
return (samples_out,)
|
||||||
|
|
||||||
|
class LatentApplyOperation:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "samples": ("LATENT",),
|
||||||
|
"operation": ("LATENT_OPERATION",),
|
||||||
|
}}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("LATENT",)
|
||||||
|
FUNCTION = "op"
|
||||||
|
|
||||||
|
CATEGORY = "latent/advanced/operations"
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
|
||||||
|
def op(self, samples, operation):
|
||||||
|
samples_out = samples.copy()
|
||||||
|
|
||||||
|
s1 = samples["samples"]
|
||||||
|
samples_out["samples"] = operation(latent=s1)
|
||||||
|
return (samples_out,)
|
||||||
|
|
||||||
|
class LatentApplyOperationCFG:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "model": ("MODEL",),
|
||||||
|
"operation": ("LATENT_OPERATION",),
|
||||||
|
}}
|
||||||
|
RETURN_TYPES = ("MODEL",)
|
||||||
|
FUNCTION = "patch"
|
||||||
|
|
||||||
|
CATEGORY = "latent/advanced/operations"
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
|
||||||
|
def patch(self, model, operation):
|
||||||
|
m = model.clone()
|
||||||
|
|
||||||
|
def pre_cfg_function(args):
|
||||||
|
conds_out = args["conds_out"]
|
||||||
|
if len(conds_out) == 2:
|
||||||
|
conds_out[0] = operation(latent=(conds_out[0] - conds_out[1])) + conds_out[1]
|
||||||
|
else:
|
||||||
|
conds_out[0] = operation(latent=conds_out[0])
|
||||||
|
return conds_out
|
||||||
|
|
||||||
|
m.set_model_sampler_pre_cfg_function(pre_cfg_function)
|
||||||
|
return (m, )
|
||||||
|
|
||||||
|
class LatentOperationTonemapReinhard:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}),
|
||||||
|
}}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("LATENT_OPERATION",)
|
||||||
|
FUNCTION = "op"
|
||||||
|
|
||||||
|
CATEGORY = "latent/advanced/operations"
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
|
||||||
|
def op(self, multiplier):
|
||||||
|
def tonemap_reinhard(latent, **kwargs):
|
||||||
|
latent_vector_magnitude = (torch.linalg.vector_norm(latent, dim=(1)) + 0.0000000001)[:,None]
|
||||||
|
normalized_latent = latent / latent_vector_magnitude
|
||||||
|
|
||||||
|
mean = torch.mean(latent_vector_magnitude, dim=(1,2,3), keepdim=True)
|
||||||
|
std = torch.std(latent_vector_magnitude, dim=(1,2,3), keepdim=True)
|
||||||
|
|
||||||
|
top = (std * 5 + mean) * multiplier
|
||||||
|
|
||||||
|
#reinhard
|
||||||
|
latent_vector_magnitude *= (1.0 / top)
|
||||||
|
new_magnitude = latent_vector_magnitude / (latent_vector_magnitude + 1.0)
|
||||||
|
new_magnitude *= top
|
||||||
|
|
||||||
|
return normalized_latent * new_magnitude
|
||||||
|
return (tonemap_reinhard,)
|
||||||
|
|
||||||
|
class LatentOperationSharpen:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": {
|
||||||
|
"sharpen_radius": ("INT", {
|
||||||
|
"default": 9,
|
||||||
|
"min": 1,
|
||||||
|
"max": 31,
|
||||||
|
"step": 1
|
||||||
|
}),
|
||||||
|
"sigma": ("FLOAT", {
|
||||||
|
"default": 1.0,
|
||||||
|
"min": 0.1,
|
||||||
|
"max": 10.0,
|
||||||
|
"step": 0.1
|
||||||
|
}),
|
||||||
|
"alpha": ("FLOAT", {
|
||||||
|
"default": 0.1,
|
||||||
|
"min": 0.0,
|
||||||
|
"max": 5.0,
|
||||||
|
"step": 0.01
|
||||||
|
}),
|
||||||
|
}}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("LATENT_OPERATION",)
|
||||||
|
FUNCTION = "op"
|
||||||
|
|
||||||
|
CATEGORY = "latent/advanced/operations"
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
|
||||||
|
def op(self, sharpen_radius, sigma, alpha):
|
||||||
|
def sharpen(latent, **kwargs):
|
||||||
|
luminance = (torch.linalg.vector_norm(latent, dim=(1)) + 1e-6)[:,None]
|
||||||
|
normalized_latent = latent / luminance
|
||||||
|
channels = latent.shape[1]
|
||||||
|
|
||||||
|
kernel_size = sharpen_radius * 2 + 1
|
||||||
|
kernel = comfy_extras.nodes_post_processing.gaussian_kernel(kernel_size, sigma, device=luminance.device)
|
||||||
|
center = kernel_size // 2
|
||||||
|
|
||||||
|
kernel *= alpha * -10
|
||||||
|
kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0
|
||||||
|
|
||||||
|
padded_image = torch.nn.functional.pad(normalized_latent, (sharpen_radius,sharpen_radius,sharpen_radius,sharpen_radius), 'reflect')
|
||||||
|
sharpened = torch.nn.functional.conv2d(padded_image, kernel.repeat(channels, 1, 1).unsqueeze(1), padding=kernel_size // 2, groups=channels)[:,:,sharpen_radius:-sharpen_radius, sharpen_radius:-sharpen_radius]
|
||||||
|
|
||||||
|
return luminance * sharpened
|
||||||
|
return (sharpen,)
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"LatentAdd": LatentAdd,
|
"LatentAdd": LatentAdd,
|
||||||
"LatentSubtract": LatentSubtract,
|
"LatentSubtract": LatentSubtract,
|
||||||
@@ -152,4 +278,8 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"LatentInterpolate": LatentInterpolate,
|
"LatentInterpolate": LatentInterpolate,
|
||||||
"LatentBatch": LatentBatch,
|
"LatentBatch": LatentBatch,
|
||||||
"LatentBatchSeedBehavior": LatentBatchSeedBehavior,
|
"LatentBatchSeedBehavior": LatentBatchSeedBehavior,
|
||||||
|
"LatentApplyOperation": LatentApplyOperation,
|
||||||
|
"LatentApplyOperationCFG": LatentApplyOperationCFG,
|
||||||
|
"LatentOperationTonemapReinhard": LatentOperationTonemapReinhard,
|
||||||
|
"LatentOperationSharpen": LatentOperationSharpen,
|
||||||
}
|
}
|
||||||
|
|||||||
119
comfy_extras/nodes_lora_extract.py
Normal file
119
comfy_extras/nodes_lora_extract.py
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
import torch
|
||||||
|
import comfy.model_management
|
||||||
|
import comfy.utils
|
||||||
|
import folder_paths
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
CLAMP_QUANTILE = 0.99
|
||||||
|
|
||||||
|
def extract_lora(diff, rank):
|
||||||
|
conv2d = (len(diff.shape) == 4)
|
||||||
|
kernel_size = None if not conv2d else diff.size()[2:4]
|
||||||
|
conv2d_3x3 = conv2d and kernel_size != (1, 1)
|
||||||
|
out_dim, in_dim = diff.size()[0:2]
|
||||||
|
rank = min(rank, in_dim, out_dim)
|
||||||
|
|
||||||
|
if conv2d:
|
||||||
|
if conv2d_3x3:
|
||||||
|
diff = diff.flatten(start_dim=1)
|
||||||
|
else:
|
||||||
|
diff = diff.squeeze()
|
||||||
|
|
||||||
|
|
||||||
|
U, S, Vh = torch.linalg.svd(diff.float())
|
||||||
|
U = U[:, :rank]
|
||||||
|
S = S[:rank]
|
||||||
|
U = U @ torch.diag(S)
|
||||||
|
Vh = Vh[:rank, :]
|
||||||
|
|
||||||
|
dist = torch.cat([U.flatten(), Vh.flatten()])
|
||||||
|
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
|
||||||
|
low_val = -hi_val
|
||||||
|
|
||||||
|
U = U.clamp(low_val, hi_val)
|
||||||
|
Vh = Vh.clamp(low_val, hi_val)
|
||||||
|
if conv2d:
|
||||||
|
U = U.reshape(out_dim, rank, 1, 1)
|
||||||
|
Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
|
||||||
|
return (U, Vh)
|
||||||
|
|
||||||
|
class LORAType(Enum):
|
||||||
|
STANDARD = 0
|
||||||
|
FULL_DIFF = 1
|
||||||
|
|
||||||
|
LORA_TYPES = {"standard": LORAType.STANDARD,
|
||||||
|
"full_diff": LORAType.FULL_DIFF}
|
||||||
|
|
||||||
|
def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora_type, bias_diff=False):
|
||||||
|
comfy.model_management.load_models_gpu([model_diff], force_patch_weights=True)
|
||||||
|
sd = model_diff.model_state_dict(filter_prefix=prefix_model)
|
||||||
|
|
||||||
|
for k in sd:
|
||||||
|
if k.endswith(".weight"):
|
||||||
|
weight_diff = sd[k]
|
||||||
|
if lora_type == LORAType.STANDARD:
|
||||||
|
if weight_diff.ndim < 2:
|
||||||
|
if bias_diff:
|
||||||
|
output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().half().cpu()
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
out = extract_lora(weight_diff, rank)
|
||||||
|
output_sd["{}{}.lora_up.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[0].contiguous().half().cpu()
|
||||||
|
output_sd["{}{}.lora_down.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[1].contiguous().half().cpu()
|
||||||
|
except:
|
||||||
|
logging.warning("Could not generate lora weights for key {}, is the weight difference a zero?".format(k))
|
||||||
|
elif lora_type == LORAType.FULL_DIFF:
|
||||||
|
output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().half().cpu()
|
||||||
|
|
||||||
|
elif bias_diff and k.endswith(".bias"):
|
||||||
|
output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().half().cpu()
|
||||||
|
return output_sd
|
||||||
|
|
||||||
|
class LoraSave:
|
||||||
|
def __init__(self):
|
||||||
|
self.output_dir = folder_paths.get_output_directory()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": {"filename_prefix": ("STRING", {"default": "loras/ComfyUI_extracted_lora"}),
|
||||||
|
"rank": ("INT", {"default": 8, "min": 1, "max": 4096, "step": 1}),
|
||||||
|
"lora_type": (tuple(LORA_TYPES.keys()),),
|
||||||
|
"bias_diff": ("BOOLEAN", {"default": True}),
|
||||||
|
},
|
||||||
|
"optional": {"model_diff": ("MODEL", {"tooltip": "The ModelSubtract output to be converted to a lora."}),
|
||||||
|
"text_encoder_diff": ("CLIP", {"tooltip": "The CLIPSubtract output to be converted to a lora."})},
|
||||||
|
}
|
||||||
|
RETURN_TYPES = ()
|
||||||
|
FUNCTION = "save"
|
||||||
|
OUTPUT_NODE = True
|
||||||
|
|
||||||
|
CATEGORY = "_for_testing"
|
||||||
|
|
||||||
|
def save(self, filename_prefix, rank, lora_type, bias_diff, model_diff=None, text_encoder_diff=None):
|
||||||
|
if model_diff is None and text_encoder_diff is None:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
lora_type = LORA_TYPES.get(lora_type)
|
||||||
|
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
|
||||||
|
|
||||||
|
output_sd = {}
|
||||||
|
if model_diff is not None:
|
||||||
|
output_sd = calc_lora_model(model_diff, rank, "diffusion_model.", "diffusion_model.", output_sd, lora_type, bias_diff=bias_diff)
|
||||||
|
if text_encoder_diff is not None:
|
||||||
|
output_sd = calc_lora_model(text_encoder_diff.patcher, rank, "", "text_encoders.", output_sd, lora_type, bias_diff=bias_diff)
|
||||||
|
|
||||||
|
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
|
||||||
|
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
|
||||||
|
|
||||||
|
comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=None)
|
||||||
|
return {}
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"LoraSave": LoraSave
|
||||||
|
}
|
||||||
|
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
|
"LoraSave": "Extract and Save Lora"
|
||||||
|
}
|
||||||
@@ -17,7 +17,7 @@ class PatchModelAddDownscale:
|
|||||||
RETURN_TYPES = ("MODEL",)
|
RETURN_TYPES = ("MODEL",)
|
||||||
FUNCTION = "patch"
|
FUNCTION = "patch"
|
||||||
|
|
||||||
CATEGORY = "_for_testing"
|
CATEGORY = "model_patches/unet"
|
||||||
|
|
||||||
def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method):
|
def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method):
|
||||||
model_sampling = model.get_model_object("model_sampling")
|
model_sampling = model.get_model_object("model_sampling")
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ class PerpNeg:
|
|||||||
FUNCTION = "patch"
|
FUNCTION = "patch"
|
||||||
|
|
||||||
CATEGORY = "_for_testing"
|
CATEGORY = "_for_testing"
|
||||||
|
DEPRECATED = True
|
||||||
|
|
||||||
def patch(self, model, empty_conditioning, neg_scale):
|
def patch(self, model, empty_conditioning, neg_scale):
|
||||||
m = model.clone()
|
m = model.clone()
|
||||||
|
|||||||
@@ -126,7 +126,7 @@ class PhotoMakerLoader:
|
|||||||
CATEGORY = "_for_testing/photomaker"
|
CATEGORY = "_for_testing/photomaker"
|
||||||
|
|
||||||
def load_photomaker_model(self, photomaker_model_name):
|
def load_photomaker_model(self, photomaker_model_name):
|
||||||
photomaker_model_path = folder_paths.get_full_path("photomaker", photomaker_model_name)
|
photomaker_model_path = folder_paths.get_full_path_or_raise("photomaker", photomaker_model_name)
|
||||||
photomaker_model = PhotoMakerIDEncoder()
|
photomaker_model = PhotoMakerIDEncoder()
|
||||||
data = comfy.utils.load_torch_file(photomaker_model_path, safe_load=True)
|
data = comfy.utils.load_torch_file(photomaker_model_path, safe_load=True)
|
||||||
if "id_encoder" in data:
|
if "id_encoder" in data:
|
||||||
|
|||||||
@@ -15,9 +15,9 @@ class TripleCLIPLoader:
|
|||||||
CATEGORY = "advanced/loaders"
|
CATEGORY = "advanced/loaders"
|
||||||
|
|
||||||
def load_clip(self, clip_name1, clip_name2, clip_name3):
|
def load_clip(self, clip_name1, clip_name2, clip_name3):
|
||||||
clip_path1 = folder_paths.get_full_path("clip", clip_name1)
|
clip_path1 = folder_paths.get_full_path_or_raise("clip", clip_name1)
|
||||||
clip_path2 = folder_paths.get_full_path("clip", clip_name2)
|
clip_path2 = folder_paths.get_full_path_or_raise("clip", clip_name2)
|
||||||
clip_path3 = folder_paths.get_full_path("clip", clip_name3)
|
clip_path3 = folder_paths.get_full_path_or_raise("clip", clip_name3)
|
||||||
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||||
return (clip,)
|
return (clip,)
|
||||||
|
|
||||||
@@ -36,7 +36,7 @@ class EmptySD3LatentImage:
|
|||||||
CATEGORY = "latent/sd3"
|
CATEGORY = "latent/sd3"
|
||||||
|
|
||||||
def generate(self, width, height, batch_size=1):
|
def generate(self, width, height, batch_size=1):
|
||||||
latent = torch.ones([batch_size, 16, height // 8, width // 8], device=self.device) * 0.0609
|
latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=self.device)
|
||||||
return ({"samples":latent}, )
|
return ({"samples":latent}, )
|
||||||
|
|
||||||
class CLIPTextEncodeSD3:
|
class CLIPTextEncodeSD3:
|
||||||
@@ -93,6 +93,7 @@ class ControlNetApplySD3(nodes.ControlNetApplyAdvanced):
|
|||||||
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
|
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
|
||||||
}}
|
}}
|
||||||
CATEGORY = "conditioning/controlnet"
|
CATEGORY = "conditioning/controlnet"
|
||||||
|
DEPRECATED = True
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"TripleCLIPLoader": TripleCLIPLoader,
|
"TripleCLIPLoader": TripleCLIPLoader,
|
||||||
@@ -103,5 +104,5 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
# Sampling
|
# Sampling
|
||||||
"ControlNetApplySD3": "ControlNetApply SD3 and HunyuanDiT",
|
"ControlNetApplySD3": "Apply Controlnet with VAE",
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -116,6 +116,7 @@ class StableCascade_SuperResolutionControlnet:
|
|||||||
RETURN_NAMES = ("controlnet_input", "stage_c", "stage_b")
|
RETURN_NAMES = ("controlnet_input", "stage_c", "stage_b")
|
||||||
FUNCTION = "generate"
|
FUNCTION = "generate"
|
||||||
|
|
||||||
|
EXPERIMENTAL = True
|
||||||
CATEGORY = "_for_testing/stable_cascade"
|
CATEGORY = "_for_testing/stable_cascade"
|
||||||
|
|
||||||
def generate(self, image, vae):
|
def generate(self, image, vae):
|
||||||
|
|||||||
@@ -154,7 +154,7 @@ class TomePatchModel:
|
|||||||
RETURN_TYPES = ("MODEL",)
|
RETURN_TYPES = ("MODEL",)
|
||||||
FUNCTION = "patch"
|
FUNCTION = "patch"
|
||||||
|
|
||||||
CATEGORY = "_for_testing"
|
CATEGORY = "model_patches/unet"
|
||||||
|
|
||||||
def patch(self, model, ratio):
|
def patch(self, model, ratio):
|
||||||
self.u = None
|
self.u = None
|
||||||
|
|||||||
22
comfy_extras/nodes_torch_compile.py
Normal file
22
comfy_extras/nodes_torch_compile.py
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
class TorchCompileModel:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "model": ("MODEL",),
|
||||||
|
"backend": (["inductor", "cudagraphs"],),
|
||||||
|
}}
|
||||||
|
RETURN_TYPES = ("MODEL",)
|
||||||
|
FUNCTION = "patch"
|
||||||
|
|
||||||
|
CATEGORY = "_for_testing"
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
|
||||||
|
def patch(self, model, backend):
|
||||||
|
m = model.clone()
|
||||||
|
m.add_object_patch("diffusion_model", torch.compile(model=m.get_model_object("diffusion_model"), backend=backend))
|
||||||
|
return (m, )
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"TorchCompileModel": TorchCompileModel,
|
||||||
|
}
|
||||||
@@ -25,7 +25,7 @@ class UpscaleModelLoader:
|
|||||||
CATEGORY = "loaders"
|
CATEGORY = "loaders"
|
||||||
|
|
||||||
def load_model(self, model_name):
|
def load_model(self, model_name):
|
||||||
model_path = folder_paths.get_full_path("upscale_models", model_name)
|
model_path = folder_paths.get_full_path_or_raise("upscale_models", model_name)
|
||||||
sd = comfy.utils.load_torch_file(model_path, safe_load=True)
|
sd = comfy.utils.load_torch_file(model_path, safe_load=True)
|
||||||
if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd:
|
if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd:
|
||||||
sd = comfy.utils.state_dict_prefix_replace(sd, {"module.":""})
|
sd = comfy.utils.state_dict_prefix_replace(sd, {"module.":""})
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ class ImageOnlyCheckpointLoader:
|
|||||||
CATEGORY = "loaders/video_models"
|
CATEGORY = "loaders/video_models"
|
||||||
|
|
||||||
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
|
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
|
||||||
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
|
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
|
||||||
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||||
return (out[0], out[3], out[2])
|
return (out[0], out[3], out[2])
|
||||||
|
|
||||||
@@ -107,7 +107,7 @@ class VideoTriangleCFGGuidance:
|
|||||||
return (m, )
|
return (m, )
|
||||||
|
|
||||||
class ImageOnlyCheckpointSave(comfy_extras.nodes_model_merging.CheckpointSave):
|
class ImageOnlyCheckpointSave(comfy_extras.nodes_model_merging.CheckpointSave):
|
||||||
CATEGORY = "_for_testing"
|
CATEGORY = "advanced/model_merging"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ class SaveImageWebsocket:
|
|||||||
|
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
def IS_CHANGED(s, images):
|
def IS_CHANGED(s, images):
|
||||||
return time.time()
|
return time.time()
|
||||||
|
|
||||||
|
|||||||
@@ -179,7 +179,13 @@ def merge_result_data(results, obj):
|
|||||||
# merge node execution results
|
# merge node execution results
|
||||||
for i, is_list in zip(range(len(results[0])), output_is_list):
|
for i, is_list in zip(range(len(results[0])), output_is_list):
|
||||||
if is_list:
|
if is_list:
|
||||||
output.append([x for o in results for x in o[i]])
|
value = []
|
||||||
|
for o in results:
|
||||||
|
if isinstance(o[i], ExecutionBlocker):
|
||||||
|
value.append(o[i])
|
||||||
|
else:
|
||||||
|
value.extend(o[i])
|
||||||
|
output.append(value)
|
||||||
else:
|
else:
|
||||||
output.append([o[i] for o in results])
|
output.append([o[i] for o in results])
|
||||||
return output
|
return output
|
||||||
|
|||||||
@@ -25,11 +25,16 @@ a111:
|
|||||||
|
|
||||||
#comfyui:
|
#comfyui:
|
||||||
# base_path: path/to/comfyui/
|
# base_path: path/to/comfyui/
|
||||||
|
# # You can use is_default to mark that these folders should be listed first, and used as the default dirs for eg downloads
|
||||||
|
# #is_default: true
|
||||||
# checkpoints: models/checkpoints/
|
# checkpoints: models/checkpoints/
|
||||||
# clip: models/clip/
|
# clip: models/clip/
|
||||||
# clip_vision: models/clip_vision/
|
# clip_vision: models/clip_vision/
|
||||||
# configs: models/configs/
|
# configs: models/configs/
|
||||||
# controlnet: models/controlnet/
|
# controlnet: models/controlnet/
|
||||||
|
# diffusion_models: |
|
||||||
|
# models/diffusion_models
|
||||||
|
# models/unet
|
||||||
# embeddings: models/embeddings/
|
# embeddings: models/embeddings/
|
||||||
# loras: models/loras/
|
# loras: models/loras/
|
||||||
# upscale_models: models/upscale_models/
|
# upscale_models: models/upscale_models/
|
||||||
|
|||||||
103
folder_paths.py
103
folder_paths.py
@@ -2,7 +2,9 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
import mimetypes
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Set, List, Dict, Tuple, Literal
|
||||||
from collections.abc import Collection
|
from collections.abc import Collection
|
||||||
|
|
||||||
supported_pt_extensions: set[str] = {'.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'}
|
supported_pt_extensions: set[str] = {'.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'}
|
||||||
@@ -44,6 +46,40 @@ user_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "user
|
|||||||
|
|
||||||
filename_list_cache: dict[str, tuple[list[str], dict[str, float], float]] = {}
|
filename_list_cache: dict[str, tuple[list[str], dict[str, float], float]] = {}
|
||||||
|
|
||||||
|
class CacheHelper:
|
||||||
|
"""
|
||||||
|
Helper class for managing file list cache data.
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
self.cache: dict[str, tuple[list[str], dict[str, float], float]] = {}
|
||||||
|
self.active = False
|
||||||
|
|
||||||
|
def get(self, key: str, default=None) -> tuple[list[str], dict[str, float], float]:
|
||||||
|
if not self.active:
|
||||||
|
return default
|
||||||
|
return self.cache.get(key, default)
|
||||||
|
|
||||||
|
def set(self, key: str, value: tuple[list[str], dict[str, float], float]) -> None:
|
||||||
|
if self.active:
|
||||||
|
self.cache[key] = value
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
self.cache.clear()
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self.active = True
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_value, traceback):
|
||||||
|
self.active = False
|
||||||
|
self.clear()
|
||||||
|
|
||||||
|
cache_helper = CacheHelper()
|
||||||
|
|
||||||
|
extension_mimetypes_cache = {
|
||||||
|
"webp" : "image",
|
||||||
|
}
|
||||||
|
|
||||||
def map_legacy(folder_name: str) -> str:
|
def map_legacy(folder_name: str) -> str:
|
||||||
legacy = {"unet": "diffusion_models"}
|
legacy = {"unet": "diffusion_models"}
|
||||||
return legacy.get(folder_name, folder_name)
|
return legacy.get(folder_name, folder_name)
|
||||||
@@ -78,6 +114,13 @@ def get_input_directory() -> str:
|
|||||||
global input_directory
|
global input_directory
|
||||||
return input_directory
|
return input_directory
|
||||||
|
|
||||||
|
def get_user_directory() -> str:
|
||||||
|
return user_directory
|
||||||
|
|
||||||
|
def set_user_directory(user_dir: str) -> None:
|
||||||
|
global user_directory
|
||||||
|
user_directory = user_dir
|
||||||
|
|
||||||
|
|
||||||
#NOTE: used in http server so don't put folders that should not be accessed remotely
|
#NOTE: used in http server so don't put folders that should not be accessed remotely
|
||||||
def get_directory_by_type(type_name: str) -> str | None:
|
def get_directory_by_type(type_name: str) -> str | None:
|
||||||
@@ -89,6 +132,28 @@ def get_directory_by_type(type_name: str) -> str | None:
|
|||||||
return get_input_directory()
|
return get_input_directory()
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def filter_files_content_types(files: List[str], content_types: Literal["image", "video", "audio"]) -> List[str]:
|
||||||
|
"""
|
||||||
|
Example:
|
||||||
|
files = os.listdir(folder_paths.get_input_directory())
|
||||||
|
filter_files_content_types(files, ["image", "audio", "video"])
|
||||||
|
"""
|
||||||
|
global extension_mimetypes_cache
|
||||||
|
result = []
|
||||||
|
for file in files:
|
||||||
|
extension = file.split('.')[-1]
|
||||||
|
if extension not in extension_mimetypes_cache:
|
||||||
|
mime_type, _ = mimetypes.guess_type(file, strict=False)
|
||||||
|
if not mime_type:
|
||||||
|
continue
|
||||||
|
content_type = mime_type.split('/')[0]
|
||||||
|
extension_mimetypes_cache[extension] = content_type
|
||||||
|
else:
|
||||||
|
content_type = extension_mimetypes_cache[extension]
|
||||||
|
|
||||||
|
if content_type in content_types:
|
||||||
|
result.append(file)
|
||||||
|
return result
|
||||||
|
|
||||||
# determine base_dir rely on annotation if name is 'filename.ext [annotation]' format
|
# determine base_dir rely on annotation if name is 'filename.ext [annotation]' format
|
||||||
# otherwise use default_path as base_dir
|
# otherwise use default_path as base_dir
|
||||||
@@ -130,11 +195,14 @@ def exists_annotated_filepath(name) -> bool:
|
|||||||
return os.path.exists(filepath)
|
return os.path.exists(filepath)
|
||||||
|
|
||||||
|
|
||||||
def add_model_folder_path(folder_name: str, full_folder_path: str) -> None:
|
def add_model_folder_path(folder_name: str, full_folder_path: str, is_default: bool = False) -> None:
|
||||||
global folder_names_and_paths
|
global folder_names_and_paths
|
||||||
folder_name = map_legacy(folder_name)
|
folder_name = map_legacy(folder_name)
|
||||||
if folder_name in folder_names_and_paths:
|
if folder_name in folder_names_and_paths:
|
||||||
folder_names_and_paths[folder_name][0].append(full_folder_path)
|
if is_default:
|
||||||
|
folder_names_and_paths[folder_name][0].insert(0, full_folder_path)
|
||||||
|
else:
|
||||||
|
folder_names_and_paths[folder_name][0].append(full_folder_path)
|
||||||
else:
|
else:
|
||||||
folder_names_and_paths[folder_name] = ([full_folder_path], set())
|
folder_names_and_paths[folder_name] = ([full_folder_path], set())
|
||||||
|
|
||||||
@@ -166,8 +234,12 @@ def recursive_search(directory: str, excluded_dir_names: list[str] | None=None)
|
|||||||
for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
|
for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
|
||||||
subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
|
subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
|
||||||
for file_name in filenames:
|
for file_name in filenames:
|
||||||
relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
|
try:
|
||||||
result.append(relative_path)
|
relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
|
||||||
|
result.append(relative_path)
|
||||||
|
except:
|
||||||
|
logging.warning(f"Warning: Unable to access {file_name}. Skipping this file.")
|
||||||
|
continue
|
||||||
|
|
||||||
for d in subdirs:
|
for d in subdirs:
|
||||||
path: str = os.path.join(dirpath, d)
|
path: str = os.path.join(dirpath, d)
|
||||||
@@ -200,6 +272,14 @@ def get_full_path(folder_name: str, filename: str) -> str | None:
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_full_path_or_raise(folder_name: str, filename: str) -> str:
|
||||||
|
full_path = get_full_path(folder_name, filename)
|
||||||
|
if full_path is None:
|
||||||
|
raise FileNotFoundError(f"Model in folder '{folder_name}' with filename '{filename}' not found.")
|
||||||
|
return full_path
|
||||||
|
|
||||||
|
|
||||||
def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]:
|
def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]:
|
||||||
folder_name = map_legacy(folder_name)
|
folder_name = map_legacy(folder_name)
|
||||||
global folder_names_and_paths
|
global folder_names_and_paths
|
||||||
@@ -214,6 +294,10 @@ def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], f
|
|||||||
return sorted(list(output_list)), output_folders, time.perf_counter()
|
return sorted(list(output_list)), output_folders, time.perf_counter()
|
||||||
|
|
||||||
def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float] | None:
|
def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float] | None:
|
||||||
|
strong_cache = cache_helper.get(folder_name)
|
||||||
|
if strong_cache is not None:
|
||||||
|
return strong_cache
|
||||||
|
|
||||||
global filename_list_cache
|
global filename_list_cache
|
||||||
global folder_names_and_paths
|
global folder_names_and_paths
|
||||||
folder_name = map_legacy(folder_name)
|
folder_name = map_legacy(folder_name)
|
||||||
@@ -242,6 +326,7 @@ def get_filename_list(folder_name: str) -> list[str]:
|
|||||||
out = get_filename_list_(folder_name)
|
out = get_filename_list_(folder_name)
|
||||||
global filename_list_cache
|
global filename_list_cache
|
||||||
filename_list_cache[folder_name] = out
|
filename_list_cache[folder_name] = out
|
||||||
|
cache_helper.set(folder_name, out)
|
||||||
return list(out[0])
|
return list(out[0])
|
||||||
|
|
||||||
def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, image_height=0) -> tuple[str, str, int, str, str]:
|
def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, image_height=0) -> tuple[str, str, int, str, str]:
|
||||||
@@ -257,9 +342,17 @@ def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, im
|
|||||||
def compute_vars(input: str, image_width: int, image_height: int) -> str:
|
def compute_vars(input: str, image_width: int, image_height: int) -> str:
|
||||||
input = input.replace("%width%", str(image_width))
|
input = input.replace("%width%", str(image_width))
|
||||||
input = input.replace("%height%", str(image_height))
|
input = input.replace("%height%", str(image_height))
|
||||||
|
now = time.localtime()
|
||||||
|
input = input.replace("%year%", str(now.tm_year))
|
||||||
|
input = input.replace("%month%", str(now.tm_mon).zfill(2))
|
||||||
|
input = input.replace("%day%", str(now.tm_mday).zfill(2))
|
||||||
|
input = input.replace("%hour%", str(now.tm_hour).zfill(2))
|
||||||
|
input = input.replace("%minute%", str(now.tm_min).zfill(2))
|
||||||
|
input = input.replace("%second%", str(now.tm_sec).zfill(2))
|
||||||
return input
|
return input
|
||||||
|
|
||||||
filename_prefix = compute_vars(filename_prefix, image_width, image_height)
|
if "%" in filename_prefix:
|
||||||
|
filename_prefix = compute_vars(filename_prefix, image_width, image_height)
|
||||||
|
|
||||||
subfolder = os.path.dirname(os.path.normpath(filename_prefix))
|
subfolder = os.path.dirname(os.path.normpath(filename_prefix))
|
||||||
filename = os.path.basename(os.path.normpath(filename_prefix))
|
filename = os.path.basename(os.path.normpath(filename_prefix))
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import folder_paths
|
|||||||
import comfy.utils
|
import comfy.utils
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
MAX_PREVIEW_RESOLUTION = 512
|
MAX_PREVIEW_RESOLUTION = args.preview_size
|
||||||
|
|
||||||
def preview_to_image(latent_image):
|
def preview_to_image(latent_image):
|
||||||
latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1
|
latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1
|
||||||
@@ -36,12 +36,20 @@ class TAESDPreviewerImpl(LatentPreviewer):
|
|||||||
|
|
||||||
|
|
||||||
class Latent2RGBPreviewer(LatentPreviewer):
|
class Latent2RGBPreviewer(LatentPreviewer):
|
||||||
def __init__(self, latent_rgb_factors):
|
def __init__(self, latent_rgb_factors, latent_rgb_factors_bias=None):
|
||||||
self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu")
|
self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu").transpose(0, 1)
|
||||||
|
self.latent_rgb_factors_bias = None
|
||||||
|
if latent_rgb_factors_bias is not None:
|
||||||
|
self.latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device="cpu")
|
||||||
|
|
||||||
def decode_latent_to_preview(self, x0):
|
def decode_latent_to_preview(self, x0):
|
||||||
self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device)
|
self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device)
|
||||||
latent_image = x0[0].permute(1, 2, 0) @ self.latent_rgb_factors
|
if self.latent_rgb_factors_bias is not None:
|
||||||
|
self.latent_rgb_factors_bias = self.latent_rgb_factors_bias.to(dtype=x0.dtype, device=x0.device)
|
||||||
|
|
||||||
|
latent_image = torch.nn.functional.linear(x0[0].permute(1, 2, 0), self.latent_rgb_factors, bias=self.latent_rgb_factors_bias)
|
||||||
|
# latent_image = x0[0].permute(1, 2, 0) @ self.latent_rgb_factors
|
||||||
|
|
||||||
return preview_to_image(latent_image)
|
return preview_to_image(latent_image)
|
||||||
|
|
||||||
|
|
||||||
@@ -71,7 +79,7 @@ def get_previewer(device, latent_format):
|
|||||||
|
|
||||||
if previewer is None:
|
if previewer is None:
|
||||||
if latent_format.latent_rgb_factors is not None:
|
if latent_format.latent_rgb_factors is not None:
|
||||||
previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors)
|
previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors, latent_format.latent_rgb_factors_bias)
|
||||||
return previewer
|
return previewer
|
||||||
|
|
||||||
def prepare_callback(model, steps, x0_output_dict=None):
|
def prepare_callback(model, steps, x0_output_dict=None):
|
||||||
|
|||||||
45
main.py
45
main.py
@@ -6,6 +6,10 @@ import importlib.util
|
|||||||
import folder_paths
|
import folder_paths
|
||||||
import time
|
import time
|
||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
|
from app.logger import setup_logger
|
||||||
|
|
||||||
|
|
||||||
|
setup_logger(log_level=args.verbose)
|
||||||
|
|
||||||
|
|
||||||
def execute_prestartup_script():
|
def execute_prestartup_script():
|
||||||
@@ -59,6 +63,7 @@ import threading
|
|||||||
import gc
|
import gc
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import utils.extra_config
|
||||||
|
|
||||||
if os.name == "nt":
|
if os.name == "nt":
|
||||||
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
|
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
|
||||||
@@ -81,7 +86,6 @@ if args.windows_standalone_build:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import yaml
|
|
||||||
|
|
||||||
import execution
|
import execution
|
||||||
import server
|
import server
|
||||||
@@ -156,7 +160,10 @@ def prompt_worker(q, server):
|
|||||||
need_gc = False
|
need_gc = False
|
||||||
|
|
||||||
async def run(server, address='', port=8188, verbose=True, call_on_start=None):
|
async def run(server, address='', port=8188, verbose=True, call_on_start=None):
|
||||||
await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())
|
addresses = []
|
||||||
|
for addr in address.split(","):
|
||||||
|
addresses.append((addr, port))
|
||||||
|
await asyncio.gather(server.start_multi_address(addresses, call_on_start), server.publish_loop())
|
||||||
|
|
||||||
|
|
||||||
def hijack_progress(server):
|
def hijack_progress(server):
|
||||||
@@ -176,27 +183,6 @@ def cleanup_temp():
|
|||||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||||
|
|
||||||
|
|
||||||
def load_extra_path_config(yaml_path):
|
|
||||||
with open(yaml_path, 'r') as stream:
|
|
||||||
config = yaml.safe_load(stream)
|
|
||||||
for c in config:
|
|
||||||
conf = config[c]
|
|
||||||
if conf is None:
|
|
||||||
continue
|
|
||||||
base_path = None
|
|
||||||
if "base_path" in conf:
|
|
||||||
base_path = conf.pop("base_path")
|
|
||||||
for x in conf:
|
|
||||||
for y in conf[x].split("\n"):
|
|
||||||
if len(y) == 0:
|
|
||||||
continue
|
|
||||||
full_path = y
|
|
||||||
if base_path is not None:
|
|
||||||
full_path = os.path.join(base_path, full_path)
|
|
||||||
logging.info("Adding extra search path {} {}".format(x, full_path))
|
|
||||||
folder_paths.add_model_folder_path(x, full_path)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
if args.temp_directory:
|
if args.temp_directory:
|
||||||
temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp")
|
temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp")
|
||||||
@@ -218,11 +204,11 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
|
extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
|
||||||
if os.path.isfile(extra_model_paths_config_path):
|
if os.path.isfile(extra_model_paths_config_path):
|
||||||
load_extra_path_config(extra_model_paths_config_path)
|
utils.extra_config.load_extra_path_config(extra_model_paths_config_path)
|
||||||
|
|
||||||
if args.extra_model_paths_config:
|
if args.extra_model_paths_config:
|
||||||
for config_path in itertools.chain(*args.extra_model_paths_config):
|
for config_path in itertools.chain(*args.extra_model_paths_config):
|
||||||
load_extra_path_config(config_path)
|
utils.extra_config.load_extra_path_config(config_path)
|
||||||
|
|
||||||
nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes)
|
nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes)
|
||||||
|
|
||||||
@@ -243,21 +229,30 @@ if __name__ == "__main__":
|
|||||||
folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip"))
|
folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip"))
|
||||||
folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae"))
|
folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae"))
|
||||||
folder_paths.add_model_folder_path("diffusion_models", os.path.join(folder_paths.get_output_directory(), "diffusion_models"))
|
folder_paths.add_model_folder_path("diffusion_models", os.path.join(folder_paths.get_output_directory(), "diffusion_models"))
|
||||||
|
folder_paths.add_model_folder_path("loras", os.path.join(folder_paths.get_output_directory(), "loras"))
|
||||||
|
|
||||||
if args.input_directory:
|
if args.input_directory:
|
||||||
input_dir = os.path.abspath(args.input_directory)
|
input_dir = os.path.abspath(args.input_directory)
|
||||||
logging.info(f"Setting input directory to: {input_dir}")
|
logging.info(f"Setting input directory to: {input_dir}")
|
||||||
folder_paths.set_input_directory(input_dir)
|
folder_paths.set_input_directory(input_dir)
|
||||||
|
|
||||||
|
if args.user_directory:
|
||||||
|
user_dir = os.path.abspath(args.user_directory)
|
||||||
|
logging.info(f"Setting user directory to: {user_dir}")
|
||||||
|
folder_paths.set_user_directory(user_dir)
|
||||||
|
|
||||||
if args.quick_test_for_ci:
|
if args.quick_test_for_ci:
|
||||||
exit(0)
|
exit(0)
|
||||||
|
|
||||||
|
os.makedirs(folder_paths.get_temp_directory(), exist_ok=True)
|
||||||
call_on_start = None
|
call_on_start = None
|
||||||
if args.auto_launch:
|
if args.auto_launch:
|
||||||
def startup_server(scheme, address, port):
|
def startup_server(scheme, address, port):
|
||||||
import webbrowser
|
import webbrowser
|
||||||
if os.name == 'nt' and address == '0.0.0.0':
|
if os.name == 'nt' and address == '0.0.0.0':
|
||||||
address = '127.0.0.1'
|
address = '127.0.0.1'
|
||||||
|
if ':' in address:
|
||||||
|
address = "[{}]".format(address)
|
||||||
webbrowser.open(f"{scheme}://{address}:{port}")
|
webbrowser.open(f"{scheme}://{address}:{port}")
|
||||||
call_on_start = startup_server
|
call_on_start = startup_server
|
||||||
|
|
||||||
|
|||||||
@@ -1,2 +1,2 @@
|
|||||||
# model_manager/__init__.py
|
# model_manager/__init__.py
|
||||||
from .download_models import download_model, DownloadModelStatus, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_model_subdirectory, validate_filename
|
from .download_models import download_model, DownloadModelStatus, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_filename
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
|
#NOTE: This was an experiment and WILL BE REMOVED
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import os
|
import os
|
||||||
import traceback
|
import traceback
|
||||||
import logging
|
import logging
|
||||||
from folder_paths import models_dir
|
from folder_paths import folder_names_and_paths, get_folder_paths
|
||||||
import re
|
import re
|
||||||
from typing import Callable, Any, Optional, Awaitable, Dict
|
from typing import Callable, Any, Optional, Awaitable, Dict
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
@@ -17,6 +18,7 @@ class DownloadStatusType(Enum):
|
|||||||
COMPLETED = "completed"
|
COMPLETED = "completed"
|
||||||
ERROR = "error"
|
ERROR = "error"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DownloadModelStatus():
|
class DownloadModelStatus():
|
||||||
status: str
|
status: str
|
||||||
@@ -38,10 +40,12 @@ class DownloadModelStatus():
|
|||||||
"already_existed": self.already_existed
|
"already_existed": self.already_existed
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
|
async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
|
||||||
model_name: str,
|
model_name: str,
|
||||||
model_url: str,
|
model_url: str,
|
||||||
model_sub_directory: str,
|
model_directory: str,
|
||||||
|
folder_path: str,
|
||||||
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
||||||
progress_interval: float = 1.0) -> DownloadModelStatus:
|
progress_interval: float = 1.0) -> DownloadModelStatus:
|
||||||
"""
|
"""
|
||||||
@@ -54,23 +58,17 @@ async def download_model(model_download_request: Callable[[str], Awaitable[aioht
|
|||||||
The name of the model file to be downloaded. This will be the filename on disk.
|
The name of the model file to be downloaded. This will be the filename on disk.
|
||||||
model_url (str):
|
model_url (str):
|
||||||
The URL from which to download the model.
|
The URL from which to download the model.
|
||||||
model_sub_directory (str):
|
model_directory (str):
|
||||||
The subdirectory within the main models directory where the model
|
The subdirectory within the main models directory where the model
|
||||||
should be saved (e.g., 'checkpoints', 'loras', etc.).
|
should be saved (e.g., 'checkpoints', 'loras', etc.).
|
||||||
progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]):
|
progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]):
|
||||||
An asynchronous function to call with progress updates.
|
An asynchronous function to call with progress updates.
|
||||||
|
folder_path (str);
|
||||||
|
Path to which model folder should be used as the root.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
DownloadModelStatus: The result of the download operation.
|
DownloadModelStatus: The result of the download operation.
|
||||||
"""
|
"""
|
||||||
if not validate_model_subdirectory(model_sub_directory):
|
|
||||||
return DownloadModelStatus(
|
|
||||||
DownloadStatusType.ERROR,
|
|
||||||
0,
|
|
||||||
"Invalid model subdirectory",
|
|
||||||
False
|
|
||||||
)
|
|
||||||
|
|
||||||
if not validate_filename(model_name):
|
if not validate_filename(model_name):
|
||||||
return DownloadModelStatus(
|
return DownloadModelStatus(
|
||||||
DownloadStatusType.ERROR,
|
DownloadStatusType.ERROR,
|
||||||
@@ -79,52 +77,67 @@ async def download_model(model_download_request: Callable[[str], Awaitable[aioht
|
|||||||
False
|
False
|
||||||
)
|
)
|
||||||
|
|
||||||
file_path, relative_path = create_model_path(model_name, model_sub_directory, models_dir)
|
if not model_directory in folder_names_and_paths:
|
||||||
existing_file = await check_file_exists(file_path, model_name, progress_callback, relative_path)
|
return DownloadModelStatus(
|
||||||
|
DownloadStatusType.ERROR,
|
||||||
|
0,
|
||||||
|
"Invalid or unrecognized model directory. model_directory must be a known model type (eg 'checkpoints'). If you are seeing this error for a custom model type, ensure the relevant custom nodes are installed and working.",
|
||||||
|
False
|
||||||
|
)
|
||||||
|
|
||||||
|
if not folder_path in get_folder_paths(model_directory):
|
||||||
|
return DownloadModelStatus(
|
||||||
|
DownloadStatusType.ERROR,
|
||||||
|
0,
|
||||||
|
f"Invalid folder path '{folder_path}', does not match the list of known directories ({get_folder_paths(model_directory)}). If you're seeing this in the downloader UI, you may need to refresh the page.",
|
||||||
|
False
|
||||||
|
)
|
||||||
|
|
||||||
|
file_path = create_model_path(model_name, folder_path)
|
||||||
|
existing_file = await check_file_exists(file_path, model_name, progress_callback)
|
||||||
if existing_file:
|
if existing_file:
|
||||||
return existing_file
|
return existing_file
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
logging.info(f"Downloading {model_name} from {model_url}")
|
||||||
status = DownloadModelStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}", False)
|
status = DownloadModelStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}", False)
|
||||||
await progress_callback(relative_path, status)
|
await progress_callback(model_name, status)
|
||||||
|
|
||||||
response = await model_download_request(model_url)
|
response = await model_download_request(model_url)
|
||||||
if response.status != 200:
|
if response.status != 200:
|
||||||
error_message = f"Failed to download {model_name}. Status code: {response.status}"
|
error_message = f"Failed to download {model_name}. Status code: {response.status}"
|
||||||
logging.error(error_message)
|
logging.error(error_message)
|
||||||
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
||||||
await progress_callback(relative_path, status)
|
await progress_callback(model_name, status)
|
||||||
return DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
return DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
||||||
|
|
||||||
return await track_download_progress(response, file_path, model_name, progress_callback, relative_path, progress_interval)
|
return await track_download_progress(response, file_path, model_name, progress_callback, progress_interval)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Error in downloading model: {e}")
|
logging.error(f"Error in downloading model: {e}")
|
||||||
return await handle_download_error(e, model_name, progress_callback, relative_path)
|
return await handle_download_error(e, model_name, progress_callback)
|
||||||
|
|
||||||
|
|
||||||
def create_model_path(model_name: str, model_directory: str, models_base_dir: str) -> tuple[str, str]:
|
def create_model_path(model_name: str, folder_path: str) -> tuple[str, str]:
|
||||||
full_model_dir = os.path.join(models_base_dir, model_directory)
|
os.makedirs(folder_path, exist_ok=True)
|
||||||
os.makedirs(full_model_dir, exist_ok=True)
|
file_path = os.path.join(folder_path, model_name)
|
||||||
file_path = os.path.join(full_model_dir, model_name)
|
|
||||||
|
|
||||||
# Ensure the resulting path is still within the base directory
|
# Ensure the resulting path is still within the base directory
|
||||||
abs_file_path = os.path.abspath(file_path)
|
abs_file_path = os.path.abspath(file_path)
|
||||||
abs_base_dir = os.path.abspath(str(models_base_dir))
|
abs_base_dir = os.path.abspath(folder_path)
|
||||||
if os.path.commonprefix([abs_file_path, abs_base_dir]) != abs_base_dir:
|
if os.path.commonprefix([abs_file_path, abs_base_dir]) != abs_base_dir:
|
||||||
raise Exception(f"Invalid model directory: {model_directory}/{model_name}")
|
raise Exception(f"Invalid model directory: {folder_path}/{model_name}")
|
||||||
|
|
||||||
|
return file_path
|
||||||
|
|
||||||
relative_path = '/'.join([model_directory, model_name])
|
|
||||||
return file_path, relative_path
|
|
||||||
|
|
||||||
async def check_file_exists(file_path: str,
|
async def check_file_exists(file_path: str,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]]
|
||||||
relative_path: str) -> Optional[DownloadModelStatus]:
|
) -> Optional[DownloadModelStatus]:
|
||||||
if os.path.exists(file_path):
|
if os.path.exists(file_path):
|
||||||
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True)
|
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True)
|
||||||
await progress_callback(relative_path, status)
|
await progress_callback(model_name, status)
|
||||||
return status
|
return status
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -133,7 +146,6 @@ async def track_download_progress(response: aiohttp.ClientResponse,
|
|||||||
file_path: str,
|
file_path: str,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
||||||
relative_path: str,
|
|
||||||
interval: float = 1.0) -> DownloadModelStatus:
|
interval: float = 1.0) -> DownloadModelStatus:
|
||||||
try:
|
try:
|
||||||
total_size = int(response.headers.get('Content-Length', 0))
|
total_size = int(response.headers.get('Content-Length', 0))
|
||||||
@@ -144,10 +156,11 @@ async def track_download_progress(response: aiohttp.ClientResponse,
|
|||||||
nonlocal last_update_time
|
nonlocal last_update_time
|
||||||
progress = (downloaded / total_size) * 100 if total_size > 0 else 0
|
progress = (downloaded / total_size) * 100 if total_size > 0 else 0
|
||||||
status = DownloadModelStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}", False)
|
status = DownloadModelStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}", False)
|
||||||
await progress_callback(relative_path, status)
|
await progress_callback(model_name, status)
|
||||||
last_update_time = time.time()
|
last_update_time = time.time()
|
||||||
|
|
||||||
with open(file_path, 'wb') as f:
|
temp_file_path = file_path + '.tmp'
|
||||||
|
with open(temp_file_path, 'wb') as f:
|
||||||
chunk_iterator = response.content.iter_chunked(8192)
|
chunk_iterator = response.content.iter_chunked(8192)
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
@@ -160,49 +173,30 @@ async def track_download_progress(response: aiohttp.ClientResponse,
|
|||||||
if time.time() - last_update_time >= interval:
|
if time.time() - last_update_time >= interval:
|
||||||
await update_progress()
|
await update_progress()
|
||||||
|
|
||||||
|
os.rename(temp_file_path, file_path)
|
||||||
|
|
||||||
await update_progress()
|
await update_progress()
|
||||||
|
|
||||||
logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}")
|
logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}")
|
||||||
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False)
|
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False)
|
||||||
await progress_callback(relative_path, status)
|
await progress_callback(model_name, status)
|
||||||
|
|
||||||
return status
|
return status
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Error in track_download_progress: {e}")
|
logging.error(f"Error in track_download_progress: {e}")
|
||||||
logging.error(traceback.format_exc())
|
logging.error(traceback.format_exc())
|
||||||
return await handle_download_error(e, model_name, progress_callback, relative_path)
|
return await handle_download_error(e, model_name, progress_callback)
|
||||||
|
|
||||||
|
|
||||||
async def handle_download_error(e: Exception,
|
async def handle_download_error(e: Exception,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
progress_callback: Callable[[str, DownloadModelStatus], Any],
|
progress_callback: Callable[[str, DownloadModelStatus], Any]
|
||||||
relative_path: str) -> DownloadModelStatus:
|
) -> DownloadModelStatus:
|
||||||
error_message = f"Error downloading {model_name}: {str(e)}"
|
error_message = f"Error downloading {model_name}: {str(e)}"
|
||||||
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
||||||
await progress_callback(relative_path, status)
|
await progress_callback(model_name, status)
|
||||||
return status
|
return status
|
||||||
|
|
||||||
def validate_model_subdirectory(model_subdirectory: str) -> bool:
|
|
||||||
"""
|
|
||||||
Validate that the model subdirectory is safe to install into.
|
|
||||||
Must not contain relative paths, nested paths or special characters
|
|
||||||
other than underscores and hyphens.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_subdirectory (str): The subdirectory for the specific model type.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if the subdirectory is safe, False otherwise.
|
|
||||||
"""
|
|
||||||
if len(model_subdirectory) > 50:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if '..' in model_subdirectory or '/' in model_subdirectory:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if not re.match(r'^[a-zA-Z0-9_-]+$', model_subdirectory):
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def validate_filename(filename: str)-> bool:
|
def validate_filename(filename: str)-> bool:
|
||||||
"""
|
"""
|
||||||
|
|||||||
56
nodes.py
56
nodes.py
@@ -511,10 +511,11 @@ class CheckpointLoader:
|
|||||||
FUNCTION = "load_checkpoint"
|
FUNCTION = "load_checkpoint"
|
||||||
|
|
||||||
CATEGORY = "advanced/loaders"
|
CATEGORY = "advanced/loaders"
|
||||||
|
DEPRECATED = True
|
||||||
|
|
||||||
def load_checkpoint(self, config_name, ckpt_name):
|
def load_checkpoint(self, config_name, ckpt_name):
|
||||||
config_path = folder_paths.get_full_path("configs", config_name)
|
config_path = folder_paths.get_full_path("configs", config_name)
|
||||||
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
|
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
|
||||||
return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||||
|
|
||||||
class CheckpointLoaderSimple:
|
class CheckpointLoaderSimple:
|
||||||
@@ -535,7 +536,7 @@ class CheckpointLoaderSimple:
|
|||||||
DESCRIPTION = "Loads a diffusion model checkpoint, diffusion models are used to denoise latents."
|
DESCRIPTION = "Loads a diffusion model checkpoint, diffusion models are used to denoise latents."
|
||||||
|
|
||||||
def load_checkpoint(self, ckpt_name):
|
def load_checkpoint(self, ckpt_name):
|
||||||
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
|
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
|
||||||
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||||
return out[:3]
|
return out[:3]
|
||||||
|
|
||||||
@@ -577,7 +578,7 @@ class unCLIPCheckpointLoader:
|
|||||||
CATEGORY = "loaders"
|
CATEGORY = "loaders"
|
||||||
|
|
||||||
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
|
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
|
||||||
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
|
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
|
||||||
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
@@ -624,7 +625,7 @@ class LoraLoader:
|
|||||||
if strength_model == 0 and strength_clip == 0:
|
if strength_model == 0 and strength_clip == 0:
|
||||||
return (model, clip)
|
return (model, clip)
|
||||||
|
|
||||||
lora_path = folder_paths.get_full_path("loras", lora_name)
|
lora_path = folder_paths.get_full_path_or_raise("loras", lora_name)
|
||||||
lora = None
|
lora = None
|
||||||
if self.loaded_lora is not None:
|
if self.loaded_lora is not None:
|
||||||
if self.loaded_lora[0] == lora_path:
|
if self.loaded_lora[0] == lora_path:
|
||||||
@@ -703,11 +704,11 @@ class VAELoader:
|
|||||||
encoder = next(filter(lambda a: a.startswith("{}_encoder.".format(name)), approx_vaes))
|
encoder = next(filter(lambda a: a.startswith("{}_encoder.".format(name)), approx_vaes))
|
||||||
decoder = next(filter(lambda a: a.startswith("{}_decoder.".format(name)), approx_vaes))
|
decoder = next(filter(lambda a: a.startswith("{}_decoder.".format(name)), approx_vaes))
|
||||||
|
|
||||||
enc = comfy.utils.load_torch_file(folder_paths.get_full_path("vae_approx", encoder))
|
enc = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("vae_approx", encoder))
|
||||||
for k in enc:
|
for k in enc:
|
||||||
sd["taesd_encoder.{}".format(k)] = enc[k]
|
sd["taesd_encoder.{}".format(k)] = enc[k]
|
||||||
|
|
||||||
dec = comfy.utils.load_torch_file(folder_paths.get_full_path("vae_approx", decoder))
|
dec = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("vae_approx", decoder))
|
||||||
for k in dec:
|
for k in dec:
|
||||||
sd["taesd_decoder.{}".format(k)] = dec[k]
|
sd["taesd_decoder.{}".format(k)] = dec[k]
|
||||||
|
|
||||||
@@ -738,7 +739,7 @@ class VAELoader:
|
|||||||
if vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]:
|
if vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]:
|
||||||
sd = self.load_taesd(vae_name)
|
sd = self.load_taesd(vae_name)
|
||||||
else:
|
else:
|
||||||
vae_path = folder_paths.get_full_path("vae", vae_name)
|
vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
|
||||||
sd = comfy.utils.load_torch_file(vae_path)
|
sd = comfy.utils.load_torch_file(vae_path)
|
||||||
vae = comfy.sd.VAE(sd=sd)
|
vae = comfy.sd.VAE(sd=sd)
|
||||||
return (vae,)
|
return (vae,)
|
||||||
@@ -754,7 +755,7 @@ class ControlNetLoader:
|
|||||||
CATEGORY = "loaders"
|
CATEGORY = "loaders"
|
||||||
|
|
||||||
def load_controlnet(self, control_net_name):
|
def load_controlnet(self, control_net_name):
|
||||||
controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
|
controlnet_path = folder_paths.get_full_path_or_raise("controlnet", control_net_name)
|
||||||
controlnet = comfy.controlnet.load_controlnet(controlnet_path)
|
controlnet = comfy.controlnet.load_controlnet(controlnet_path)
|
||||||
return (controlnet,)
|
return (controlnet,)
|
||||||
|
|
||||||
@@ -770,7 +771,7 @@ class DiffControlNetLoader:
|
|||||||
CATEGORY = "loaders"
|
CATEGORY = "loaders"
|
||||||
|
|
||||||
def load_controlnet(self, model, control_net_name):
|
def load_controlnet(self, model, control_net_name):
|
||||||
controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
|
controlnet_path = folder_paths.get_full_path_or_raise("controlnet", control_net_name)
|
||||||
controlnet = comfy.controlnet.load_controlnet(controlnet_path, model)
|
controlnet = comfy.controlnet.load_controlnet(controlnet_path, model)
|
||||||
return (controlnet,)
|
return (controlnet,)
|
||||||
|
|
||||||
@@ -786,6 +787,7 @@ class ControlNetApply:
|
|||||||
RETURN_TYPES = ("CONDITIONING",)
|
RETURN_TYPES = ("CONDITIONING",)
|
||||||
FUNCTION = "apply_controlnet"
|
FUNCTION = "apply_controlnet"
|
||||||
|
|
||||||
|
DEPRECATED = True
|
||||||
CATEGORY = "conditioning/controlnet"
|
CATEGORY = "conditioning/controlnet"
|
||||||
|
|
||||||
def apply_controlnet(self, conditioning, control_net, image, strength):
|
def apply_controlnet(self, conditioning, control_net, image, strength):
|
||||||
@@ -815,7 +817,10 @@ class ControlNetApplyAdvanced:
|
|||||||
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||||
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||||
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
|
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
|
||||||
}}
|
},
|
||||||
|
"optional": {"vae": ("VAE", ),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
RETURN_TYPES = ("CONDITIONING","CONDITIONING")
|
RETURN_TYPES = ("CONDITIONING","CONDITIONING")
|
||||||
RETURN_NAMES = ("positive", "negative")
|
RETURN_NAMES = ("positive", "negative")
|
||||||
@@ -823,7 +828,7 @@ class ControlNetApplyAdvanced:
|
|||||||
|
|
||||||
CATEGORY = "conditioning/controlnet"
|
CATEGORY = "conditioning/controlnet"
|
||||||
|
|
||||||
def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None):
|
def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None, extra_concat=[]):
|
||||||
if strength == 0:
|
if strength == 0:
|
||||||
return (positive, negative)
|
return (positive, negative)
|
||||||
|
|
||||||
@@ -840,7 +845,7 @@ class ControlNetApplyAdvanced:
|
|||||||
if prev_cnet in cnets:
|
if prev_cnet in cnets:
|
||||||
c_net = cnets[prev_cnet]
|
c_net = cnets[prev_cnet]
|
||||||
else:
|
else:
|
||||||
c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent), vae)
|
c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent), vae=vae, extra_concat=extra_concat)
|
||||||
c_net.set_previous_controlnet(prev_cnet)
|
c_net.set_previous_controlnet(prev_cnet)
|
||||||
cnets[prev_cnet] = c_net
|
cnets[prev_cnet] = c_net
|
||||||
|
|
||||||
@@ -856,7 +861,7 @@ class UNETLoader:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "unet_name": (folder_paths.get_filename_list("diffusion_models"), ),
|
return {"required": { "unet_name": (folder_paths.get_filename_list("diffusion_models"), ),
|
||||||
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e5m2"],)
|
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"],)
|
||||||
}}
|
}}
|
||||||
RETURN_TYPES = ("MODEL",)
|
RETURN_TYPES = ("MODEL",)
|
||||||
FUNCTION = "load_unet"
|
FUNCTION = "load_unet"
|
||||||
@@ -867,10 +872,13 @@ class UNETLoader:
|
|||||||
model_options = {}
|
model_options = {}
|
||||||
if weight_dtype == "fp8_e4m3fn":
|
if weight_dtype == "fp8_e4m3fn":
|
||||||
model_options["dtype"] = torch.float8_e4m3fn
|
model_options["dtype"] = torch.float8_e4m3fn
|
||||||
|
elif weight_dtype == "fp8_e4m3fn_fast":
|
||||||
|
model_options["dtype"] = torch.float8_e4m3fn
|
||||||
|
model_options["fp8_optimizations"] = True
|
||||||
elif weight_dtype == "fp8_e5m2":
|
elif weight_dtype == "fp8_e5m2":
|
||||||
model_options["dtype"] = torch.float8_e5m2
|
model_options["dtype"] = torch.float8_e5m2
|
||||||
|
|
||||||
unet_path = folder_paths.get_full_path("diffusion_models", unet_name)
|
unet_path = folder_paths.get_full_path_or_raise("diffusion_models", unet_name)
|
||||||
model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
|
model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
|
||||||
return (model,)
|
return (model,)
|
||||||
|
|
||||||
@@ -895,7 +903,7 @@ class CLIPLoader:
|
|||||||
else:
|
else:
|
||||||
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
|
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
|
||||||
|
|
||||||
clip_path = folder_paths.get_full_path("clip", clip_name)
|
clip_path = folder_paths.get_full_path_or_raise("clip", clip_name)
|
||||||
clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
|
clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
|
||||||
return (clip,)
|
return (clip,)
|
||||||
|
|
||||||
@@ -912,8 +920,8 @@ class DualCLIPLoader:
|
|||||||
CATEGORY = "advanced/loaders"
|
CATEGORY = "advanced/loaders"
|
||||||
|
|
||||||
def load_clip(self, clip_name1, clip_name2, type):
|
def load_clip(self, clip_name1, clip_name2, type):
|
||||||
clip_path1 = folder_paths.get_full_path("clip", clip_name1)
|
clip_path1 = folder_paths.get_full_path_or_raise("clip", clip_name1)
|
||||||
clip_path2 = folder_paths.get_full_path("clip", clip_name2)
|
clip_path2 = folder_paths.get_full_path_or_raise("clip", clip_name2)
|
||||||
if type == "sdxl":
|
if type == "sdxl":
|
||||||
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
|
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
|
||||||
elif type == "sd3":
|
elif type == "sd3":
|
||||||
@@ -935,7 +943,7 @@ class CLIPVisionLoader:
|
|||||||
CATEGORY = "loaders"
|
CATEGORY = "loaders"
|
||||||
|
|
||||||
def load_clip(self, clip_name):
|
def load_clip(self, clip_name):
|
||||||
clip_path = folder_paths.get_full_path("clip_vision", clip_name)
|
clip_path = folder_paths.get_full_path_or_raise("clip_vision", clip_name)
|
||||||
clip_vision = comfy.clip_vision.load(clip_path)
|
clip_vision = comfy.clip_vision.load(clip_path)
|
||||||
return (clip_vision,)
|
return (clip_vision,)
|
||||||
|
|
||||||
@@ -965,7 +973,7 @@ class StyleModelLoader:
|
|||||||
CATEGORY = "loaders"
|
CATEGORY = "loaders"
|
||||||
|
|
||||||
def load_style_model(self, style_model_name):
|
def load_style_model(self, style_model_name):
|
||||||
style_model_path = folder_paths.get_full_path("style_models", style_model_name)
|
style_model_path = folder_paths.get_full_path_or_raise("style_models", style_model_name)
|
||||||
style_model = comfy.sd.load_style_model(style_model_path)
|
style_model = comfy.sd.load_style_model(style_model_path)
|
||||||
return (style_model,)
|
return (style_model,)
|
||||||
|
|
||||||
@@ -1030,7 +1038,7 @@ class GLIGENLoader:
|
|||||||
CATEGORY = "loaders"
|
CATEGORY = "loaders"
|
||||||
|
|
||||||
def load_gligen(self, gligen_name):
|
def load_gligen(self, gligen_name):
|
||||||
gligen_path = folder_paths.get_full_path("gligen", gligen_name)
|
gligen_path = folder_paths.get_full_path_or_raise("gligen", gligen_name)
|
||||||
gligen = comfy.sd.load_gligen(gligen_path)
|
gligen = comfy.sd.load_gligen(gligen_path)
|
||||||
return (gligen,)
|
return (gligen,)
|
||||||
|
|
||||||
@@ -1916,8 +1924,8 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"ConditioningSetArea": "Conditioning (Set Area)",
|
"ConditioningSetArea": "Conditioning (Set Area)",
|
||||||
"ConditioningSetAreaPercentage": "Conditioning (Set Area with Percentage)",
|
"ConditioningSetAreaPercentage": "Conditioning (Set Area with Percentage)",
|
||||||
"ConditioningSetMask": "Conditioning (Set Mask)",
|
"ConditioningSetMask": "Conditioning (Set Mask)",
|
||||||
"ControlNetApply": "Apply ControlNet",
|
"ControlNetApply": "Apply ControlNet (OLD)",
|
||||||
"ControlNetApplyAdvanced": "Apply ControlNet (Advanced)",
|
"ControlNetApplyAdvanced": "Apply ControlNet",
|
||||||
# Latent
|
# Latent
|
||||||
"VAEEncodeForInpaint": "VAE Encode (for Inpainting)",
|
"VAEEncodeForInpaint": "VAE Encode (for Inpainting)",
|
||||||
"SetLatentNoiseMask": "Set Latent Noise Mask",
|
"SetLatentNoiseMask": "Set Latent Noise Mask",
|
||||||
@@ -2101,6 +2109,8 @@ def init_builtin_extra_nodes():
|
|||||||
"nodes_controlnet.py",
|
"nodes_controlnet.py",
|
||||||
"nodes_hunyuan.py",
|
"nodes_hunyuan.py",
|
||||||
"nodes_flux.py",
|
"nodes_flux.py",
|
||||||
|
"nodes_lora_extract.py",
|
||||||
|
"nodes_torch_compile.py",
|
||||||
]
|
]
|
||||||
|
|
||||||
import_failed = []
|
import_failed = []
|
||||||
@@ -2129,3 +2139,5 @@ def init_extra_nodes(init_custom_nodes=True):
|
|||||||
else:
|
else:
|
||||||
logging.warning("Please do a: pip install -r requirements.txt")
|
logging.warning("Please do a: pip install -r requirements.txt")
|
||||||
logging.warning("")
|
logging.warning("")
|
||||||
|
|
||||||
|
return import_failed
|
||||||
|
|||||||
@@ -79,7 +79,7 @@
|
|||||||
"#!wget -c https://huggingface.co/comfyanonymous/clip_vision_g/resolve/main/clip_vision_g.safetensors -P ./models/clip_vision/\n",
|
"#!wget -c https://huggingface.co/comfyanonymous/clip_vision_g/resolve/main/clip_vision_g.safetensors -P ./models/clip_vision/\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# SD1.5\n",
|
"# SD1.5\n",
|
||||||
"!wget -c https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt -P ./models/checkpoints/\n",
|
"!wget -c https://huggingface.co/Comfy-Org/stable-diffusion-v1-5-archive/resolve/main/v1-5-pruned-emaonly-fp16.safetensors -P ./models/checkpoints/\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# SD2\n",
|
"# SD2\n",
|
||||||
"#!wget -c https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.safetensors -P ./models/checkpoints/\n",
|
"#!wget -c https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.safetensors -P ./models/checkpoints/\n",
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ prompt_text = """
|
|||||||
"4": {
|
"4": {
|
||||||
"class_type": "CheckpointLoaderSimple",
|
"class_type": "CheckpointLoaderSimple",
|
||||||
"inputs": {
|
"inputs": {
|
||||||
"ckpt_name": "v1-5-pruned-emaonly.ckpt"
|
"ckpt_name": "v1-5-pruned-emaonly.safetensors"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"5": {
|
"5": {
|
||||||
|
|||||||
@@ -38,18 +38,20 @@ def get_images(ws, prompt):
|
|||||||
if data['node'] is None and data['prompt_id'] == prompt_id:
|
if data['node'] is None and data['prompt_id'] == prompt_id:
|
||||||
break #Execution is done
|
break #Execution is done
|
||||||
else:
|
else:
|
||||||
|
# If you want to be able to decode the binary stream for latent previews, here is how you can do it:
|
||||||
|
# bytesIO = BytesIO(out[8:])
|
||||||
|
# preview_image = Image.open(bytesIO) # This is your preview in PIL image format, store it in a global
|
||||||
continue #previews are binary data
|
continue #previews are binary data
|
||||||
|
|
||||||
history = get_history(prompt_id)[prompt_id]
|
history = get_history(prompt_id)[prompt_id]
|
||||||
for o in history['outputs']:
|
for node_id in history['outputs']:
|
||||||
for node_id in history['outputs']:
|
node_output = history['outputs'][node_id]
|
||||||
node_output = history['outputs'][node_id]
|
images_output = []
|
||||||
if 'images' in node_output:
|
if 'images' in node_output:
|
||||||
images_output = []
|
for image in node_output['images']:
|
||||||
for image in node_output['images']:
|
image_data = get_image(image['filename'], image['subfolder'], image['type'])
|
||||||
image_data = get_image(image['filename'], image['subfolder'], image['type'])
|
images_output.append(image_data)
|
||||||
images_output.append(image_data)
|
output_images[node_id] = images_output
|
||||||
output_images[node_id] = images_output
|
|
||||||
|
|
||||||
return output_images
|
return output_images
|
||||||
|
|
||||||
@@ -85,7 +87,7 @@ prompt_text = """
|
|||||||
"4": {
|
"4": {
|
||||||
"class_type": "CheckpointLoaderSimple",
|
"class_type": "CheckpointLoaderSimple",
|
||||||
"inputs": {
|
"inputs": {
|
||||||
"ckpt_name": "v1-5-pruned-emaonly.ckpt"
|
"ckpt_name": "v1-5-pruned-emaonly.safetensors"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"5": {
|
"5": {
|
||||||
@@ -152,7 +154,7 @@ prompt["3"]["inputs"]["seed"] = 5
|
|||||||
ws = websocket.WebSocket()
|
ws = websocket.WebSocket()
|
||||||
ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
|
ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
|
||||||
images = get_images(ws, prompt)
|
images = get_images(ws, prompt)
|
||||||
|
ws.close() # for in case this example is used in an environment where it will be repeatedly called, like in a Gradio app. otherwise, you'll randomly receive connection timeouts
|
||||||
#Commented out code to display the output images:
|
#Commented out code to display the output images:
|
||||||
|
|
||||||
# for node_id in images:
|
# for node_id in images:
|
||||||
|
|||||||
@@ -81,7 +81,7 @@ prompt_text = """
|
|||||||
"4": {
|
"4": {
|
||||||
"class_type": "CheckpointLoaderSimple",
|
"class_type": "CheckpointLoaderSimple",
|
||||||
"inputs": {
|
"inputs": {
|
||||||
"ckpt_name": "v1-5-pruned-emaonly.ckpt"
|
"ckpt_name": "v1-5-pruned-emaonly.safetensors"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"5": {
|
"5": {
|
||||||
@@ -147,7 +147,7 @@ prompt["3"]["inputs"]["seed"] = 5
|
|||||||
ws = websocket.WebSocket()
|
ws = websocket.WebSocket()
|
||||||
ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
|
ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
|
||||||
images = get_images(ws, prompt)
|
images = get_images(ws, prompt)
|
||||||
|
ws.close() # for in case this example is used in an environment where it will be repeatedly called, like in a Gradio app. otherwise, you'll randomly receive connection timeouts
|
||||||
#Commented out code to display the output images:
|
#Commented out code to display the output images:
|
||||||
|
|
||||||
# for node_id in images:
|
# for node_id in images:
|
||||||
|
|||||||
158
server.py
158
server.py
@@ -12,6 +12,8 @@ import json
|
|||||||
import glob
|
import glob
|
||||||
import struct
|
import struct
|
||||||
import ssl
|
import ssl
|
||||||
|
import socket
|
||||||
|
import ipaddress
|
||||||
from PIL import Image, ImageOps
|
from PIL import Image, ImageOps
|
||||||
from PIL.PngImagePlugin import PngInfo
|
from PIL.PngImagePlugin import PngInfo
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
@@ -31,7 +33,6 @@ from model_filemanager import download_model, DownloadModelStatus
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
from api_server.routes.internal.internal_routes import InternalRoutes
|
from api_server.routes.internal.internal_routes import InternalRoutes
|
||||||
|
|
||||||
|
|
||||||
class BinaryEventTypes:
|
class BinaryEventTypes:
|
||||||
PREVIEW_IMAGE = 1
|
PREVIEW_IMAGE = 1
|
||||||
UNENCODED_PREVIEW_IMAGE = 2
|
UNENCODED_PREVIEW_IMAGE = 2
|
||||||
@@ -39,9 +40,24 @@ class BinaryEventTypes:
|
|||||||
async def send_socket_catch_exception(function, message):
|
async def send_socket_catch_exception(function, message):
|
||||||
try:
|
try:
|
||||||
await function(message)
|
await function(message)
|
||||||
except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError) as err:
|
except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError, BrokenPipeError, ConnectionError) as err:
|
||||||
logging.warning("send error: {}".format(err))
|
logging.warning("send error: {}".format(err))
|
||||||
|
|
||||||
|
def get_comfyui_version():
|
||||||
|
comfyui_version = "unknown"
|
||||||
|
repo_path = os.path.dirname(os.path.realpath(__file__))
|
||||||
|
try:
|
||||||
|
import pygit2
|
||||||
|
repo = pygit2.Repository(repo_path)
|
||||||
|
comfyui_version = repo.describe(describe_strategy=pygit2.GIT_DESCRIBE_TAGS)
|
||||||
|
except Exception:
|
||||||
|
try:
|
||||||
|
import subprocess
|
||||||
|
comfyui_version = subprocess.check_output(["git", "describe", "--tags"], cwd=repo_path).decode('utf-8')
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning(f"Failed to get ComfyUI version: {e}")
|
||||||
|
return comfyui_version.strip()
|
||||||
|
|
||||||
@web.middleware
|
@web.middleware
|
||||||
async def cache_control(request: web.Request, handler):
|
async def cache_control(request: web.Request, handler):
|
||||||
response: web.Response = await handler(request)
|
response: web.Response = await handler(request)
|
||||||
@@ -66,6 +82,68 @@ def create_cors_middleware(allowed_origin: str):
|
|||||||
|
|
||||||
return cors_middleware
|
return cors_middleware
|
||||||
|
|
||||||
|
def is_loopback(host):
|
||||||
|
if host is None:
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
if ipaddress.ip_address(host).is_loopback:
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
loopback = False
|
||||||
|
for family in (socket.AF_INET, socket.AF_INET6):
|
||||||
|
try:
|
||||||
|
r = socket.getaddrinfo(host, None, family, socket.SOCK_STREAM)
|
||||||
|
for family, _, _, _, sockaddr in r:
|
||||||
|
if not ipaddress.ip_address(sockaddr[0]).is_loopback:
|
||||||
|
return loopback
|
||||||
|
else:
|
||||||
|
loopback = True
|
||||||
|
except socket.gaierror:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return loopback
|
||||||
|
|
||||||
|
|
||||||
|
def create_origin_only_middleware():
|
||||||
|
@web.middleware
|
||||||
|
async def origin_only_middleware(request: web.Request, handler):
|
||||||
|
#this code is used to prevent the case where a random website can queue comfy workflows by making a POST to 127.0.0.1 which browsers don't prevent for some dumb reason.
|
||||||
|
#in that case the Host and Origin hostnames won't match
|
||||||
|
#I know the proper fix would be to add a cookie but this should take care of the problem in the meantime
|
||||||
|
if 'Host' in request.headers and 'Origin' in request.headers:
|
||||||
|
host = request.headers['Host']
|
||||||
|
origin = request.headers['Origin']
|
||||||
|
host_domain = host.lower()
|
||||||
|
parsed = urllib.parse.urlparse(origin)
|
||||||
|
origin_domain = parsed.netloc.lower()
|
||||||
|
host_domain_parsed = urllib.parse.urlsplit('//' + host_domain)
|
||||||
|
|
||||||
|
#limit the check to when the host domain is localhost, this makes it slightly less safe but should still prevent the exploit
|
||||||
|
loopback = is_loopback(host_domain_parsed.hostname)
|
||||||
|
|
||||||
|
if parsed.port is None: #if origin doesn't have a port strip it from the host to handle weird browsers, same for host
|
||||||
|
host_domain = host_domain_parsed.hostname
|
||||||
|
if host_domain_parsed.port is None:
|
||||||
|
origin_domain = parsed.hostname
|
||||||
|
|
||||||
|
if loopback and host_domain is not None and origin_domain is not None and len(host_domain) > 0 and len(origin_domain) > 0:
|
||||||
|
if host_domain != origin_domain:
|
||||||
|
logging.warning("WARNING: request with non matching host and origin {} != {}, returning 403".format(host_domain, origin_domain))
|
||||||
|
return web.Response(status=403)
|
||||||
|
|
||||||
|
if request.method == "OPTIONS":
|
||||||
|
response = web.Response()
|
||||||
|
else:
|
||||||
|
response = await handler(request)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
return origin_only_middleware
|
||||||
|
|
||||||
class PromptServer():
|
class PromptServer():
|
||||||
def __init__(self, loop):
|
def __init__(self, loop):
|
||||||
PromptServer.instance = self
|
PromptServer.instance = self
|
||||||
@@ -85,6 +163,8 @@ class PromptServer():
|
|||||||
middlewares = [cache_control]
|
middlewares = [cache_control]
|
||||||
if args.enable_cors_header:
|
if args.enable_cors_header:
|
||||||
middlewares.append(create_cors_middleware(args.enable_cors_header))
|
middlewares.append(create_cors_middleware(args.enable_cors_header))
|
||||||
|
else:
|
||||||
|
middlewares.append(create_origin_only_middleware())
|
||||||
|
|
||||||
max_upload_size = round(args.max_upload_size * 1024 * 1024)
|
max_upload_size = round(args.max_upload_size * 1024 * 1024)
|
||||||
self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares)
|
self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares)
|
||||||
@@ -142,6 +222,12 @@ class PromptServer():
|
|||||||
embeddings = folder_paths.get_filename_list("embeddings")
|
embeddings = folder_paths.get_filename_list("embeddings")
|
||||||
return web.json_response(list(map(lambda a: os.path.splitext(a)[0], embeddings)))
|
return web.json_response(list(map(lambda a: os.path.splitext(a)[0], embeddings)))
|
||||||
|
|
||||||
|
@routes.get("/models")
|
||||||
|
def list_model_types(request):
|
||||||
|
model_types = list(folder_paths.folder_names_and_paths.keys())
|
||||||
|
|
||||||
|
return web.json_response(model_types)
|
||||||
|
|
||||||
@routes.get("/models/{folder}")
|
@routes.get("/models/{folder}")
|
||||||
async def get_models(request):
|
async def get_models(request):
|
||||||
folder = request.match_info.get("folder", None)
|
folder = request.match_info.get("folder", None)
|
||||||
@@ -401,16 +487,25 @@ class PromptServer():
|
|||||||
return web.json_response(dt["__metadata__"])
|
return web.json_response(dt["__metadata__"])
|
||||||
|
|
||||||
@routes.get("/system_stats")
|
@routes.get("/system_stats")
|
||||||
async def get_queue(request):
|
async def system_stats(request):
|
||||||
device = comfy.model_management.get_torch_device()
|
device = comfy.model_management.get_torch_device()
|
||||||
device_name = comfy.model_management.get_torch_device_name(device)
|
device_name = comfy.model_management.get_torch_device_name(device)
|
||||||
|
cpu_device = comfy.model_management.torch.device("cpu")
|
||||||
|
ram_total = comfy.model_management.get_total_memory(cpu_device)
|
||||||
|
ram_free = comfy.model_management.get_free_memory(cpu_device)
|
||||||
vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
|
vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
|
||||||
vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
|
vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
|
||||||
|
|
||||||
system_stats = {
|
system_stats = {
|
||||||
"system": {
|
"system": {
|
||||||
"os": os.name,
|
"os": os.name,
|
||||||
|
"ram_total": ram_total,
|
||||||
|
"ram_free": ram_free,
|
||||||
|
"comfyui_version": get_comfyui_version(),
|
||||||
"python_version": sys.version,
|
"python_version": sys.version,
|
||||||
"embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded"
|
"pytorch_version": comfy.model_management.torch_version,
|
||||||
|
"embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded",
|
||||||
|
"argv": sys.argv
|
||||||
},
|
},
|
||||||
"devices": [
|
"devices": [
|
||||||
{
|
{
|
||||||
@@ -462,14 +557,15 @@ class PromptServer():
|
|||||||
|
|
||||||
@routes.get("/object_info")
|
@routes.get("/object_info")
|
||||||
async def get_object_info(request):
|
async def get_object_info(request):
|
||||||
out = {}
|
with folder_paths.cache_helper:
|
||||||
for x in nodes.NODE_CLASS_MAPPINGS:
|
out = {}
|
||||||
try:
|
for x in nodes.NODE_CLASS_MAPPINGS:
|
||||||
out[x] = node_info(x)
|
try:
|
||||||
except Exception as e:
|
out[x] = node_info(x)
|
||||||
logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.")
|
except Exception as e:
|
||||||
logging.error(traceback.format_exc())
|
logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.")
|
||||||
return web.json_response(out)
|
logging.error(traceback.format_exc())
|
||||||
|
return web.json_response(out)
|
||||||
|
|
||||||
@routes.get("/object_info/{node_class}")
|
@routes.get("/object_info/{node_class}")
|
||||||
async def get_object_info_node(request):
|
async def get_object_info_node(request):
|
||||||
@@ -583,18 +679,22 @@ class PromptServer():
|
|||||||
|
|
||||||
# Internal route. Should not be depended upon and is subject to change at any time.
|
# Internal route. Should not be depended upon and is subject to change at any time.
|
||||||
# TODO(robinhuang): Move to internal route table class once we refactor PromptServer to pass around Websocket.
|
# TODO(robinhuang): Move to internal route table class once we refactor PromptServer to pass around Websocket.
|
||||||
|
# NOTE: This was an experiment and WILL BE REMOVED
|
||||||
@routes.post("/internal/models/download")
|
@routes.post("/internal/models/download")
|
||||||
async def download_handler(request):
|
async def download_handler(request):
|
||||||
async def report_progress(filename: str, status: DownloadModelStatus):
|
async def report_progress(filename: str, status: DownloadModelStatus):
|
||||||
await self.send_json("download_progress", status.to_dict())
|
payload = status.to_dict()
|
||||||
|
payload['download_path'] = filename
|
||||||
|
await self.send_json("download_progress", payload)
|
||||||
|
|
||||||
data = await request.json()
|
data = await request.json()
|
||||||
url = data.get('url')
|
url = data.get('url')
|
||||||
model_directory = data.get('model_directory')
|
model_directory = data.get('model_directory')
|
||||||
|
folder_path = data.get('folder_path')
|
||||||
model_filename = data.get('model_filename')
|
model_filename = data.get('model_filename')
|
||||||
progress_interval = data.get('progress_interval', 1.0) # In seconds, how often to report download progress.
|
progress_interval = data.get('progress_interval', 1.0) # In seconds, how often to report download progress.
|
||||||
|
|
||||||
if not url or not model_directory or not model_filename:
|
if not url or not model_directory or not model_filename or not folder_path:
|
||||||
return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400)
|
return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400)
|
||||||
|
|
||||||
session = self.client_session
|
session = self.client_session
|
||||||
@@ -602,7 +702,7 @@ class PromptServer():
|
|||||||
logging.error("Client session is not initialized")
|
logging.error("Client session is not initialized")
|
||||||
return web.Response(status=500)
|
return web.Response(status=500)
|
||||||
|
|
||||||
task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, report_progress, progress_interval))
|
task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, folder_path, report_progress, progress_interval))
|
||||||
await task
|
await task
|
||||||
|
|
||||||
return web.json_response(task.result().to_dict())
|
return web.json_response(task.result().to_dict())
|
||||||
@@ -719,6 +819,9 @@ class PromptServer():
|
|||||||
await self.send(*msg)
|
await self.send(*msg)
|
||||||
|
|
||||||
async def start(self, address, port, verbose=True, call_on_start=None):
|
async def start(self, address, port, verbose=True, call_on_start=None):
|
||||||
|
await self.start_multi_address([(address, port)], call_on_start=call_on_start)
|
||||||
|
|
||||||
|
async def start_multi_address(self, addresses, call_on_start=None):
|
||||||
runner = web.AppRunner(self.app, access_log=None)
|
runner = web.AppRunner(self.app, access_log=None)
|
||||||
await runner.setup()
|
await runner.setup()
|
||||||
ssl_ctx = None
|
ssl_ctx = None
|
||||||
@@ -729,17 +832,26 @@ class PromptServer():
|
|||||||
keyfile=args.tls_keyfile)
|
keyfile=args.tls_keyfile)
|
||||||
scheme = "https"
|
scheme = "https"
|
||||||
|
|
||||||
site = web.TCPSite(runner, address, port, ssl_context=ssl_ctx)
|
logging.info("Starting server\n")
|
||||||
await site.start()
|
for addr in addresses:
|
||||||
|
address = addr[0]
|
||||||
|
port = addr[1]
|
||||||
|
site = web.TCPSite(runner, address, port, ssl_context=ssl_ctx)
|
||||||
|
await site.start()
|
||||||
|
|
||||||
self.address = address
|
if not hasattr(self, 'address'):
|
||||||
self.port = port
|
self.address = address #TODO: remove this
|
||||||
|
self.port = port
|
||||||
|
|
||||||
|
if ':' in address:
|
||||||
|
address_print = "[{}]".format(address)
|
||||||
|
else:
|
||||||
|
address_print = address
|
||||||
|
|
||||||
|
logging.info("To see the GUI go to: {}://{}:{}".format(scheme, address_print, port))
|
||||||
|
|
||||||
if verbose:
|
|
||||||
logging.info("Starting server\n")
|
|
||||||
logging.info("To see the GUI go to: {}://{}:{}".format(scheme, address, port))
|
|
||||||
if call_on_start is not None:
|
if call_on_start is not None:
|
||||||
call_on_start(scheme, address, port)
|
call_on_start(scheme, self.address, self.port)
|
||||||
|
|
||||||
def add_on_prompt_handler(self, handler):
|
def add_on_prompt_handler(self, handler):
|
||||||
self.on_prompt_handlers.append(handler)
|
self.on_prompt_handlers.append(handler)
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
## Install test dependencies
|
## Install test dependencies
|
||||||
|
|
||||||
`pip install -r tests-units/requirements.txt`
|
`pip install -r tests-unit/requirements.txt`
|
||||||
|
|
||||||
## Run tests
|
## Run tests
|
||||||
`pytest tests-units/`
|
`pytest tests-unit/`
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import pytest
|
import pytest
|
||||||
from requests.exceptions import HTTPError
|
from requests.exceptions import HTTPError
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
from app.frontend_management import (
|
from app.frontend_management import (
|
||||||
FrontendManager,
|
FrontendManager,
|
||||||
@@ -83,6 +84,35 @@ def test_init_frontend_invalid_provider():
|
|||||||
with pytest.raises(HTTPError):
|
with pytest.raises(HTTPError):
|
||||||
FrontendManager.init_frontend_unsafe(version_string)
|
FrontendManager.init_frontend_unsafe(version_string)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_os_functions():
|
||||||
|
with patch('app.frontend_management.os.makedirs') as mock_makedirs, \
|
||||||
|
patch('app.frontend_management.os.listdir') as mock_listdir, \
|
||||||
|
patch('app.frontend_management.os.rmdir') as mock_rmdir:
|
||||||
|
mock_listdir.return_value = [] # Simulate empty directory
|
||||||
|
yield mock_makedirs, mock_listdir, mock_rmdir
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_download():
|
||||||
|
with patch('app.frontend_management.download_release_asset_zip') as mock:
|
||||||
|
mock.side_effect = Exception("Download failed") # Simulate download failure
|
||||||
|
yield mock
|
||||||
|
|
||||||
|
def test_finally_block(mock_os_functions, mock_download, mock_provider):
|
||||||
|
# Arrange
|
||||||
|
mock_makedirs, mock_listdir, mock_rmdir = mock_os_functions
|
||||||
|
version_string = 'test-owner/test-repo@1.0.0'
|
||||||
|
|
||||||
|
# Act & Assert
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
FrontendManager.init_frontend_unsafe(version_string, mock_provider)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
mock_makedirs.assert_called_once()
|
||||||
|
mock_download.assert_called_once()
|
||||||
|
mock_listdir.assert_called_once()
|
||||||
|
mock_rmdir.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
def test_parse_version_string():
|
def test_parse_version_string():
|
||||||
version_string = "owner/repo@1.0.0"
|
version_string = "owner/repo@1.0.0"
|
||||||
|
|||||||
66
tests-unit/comfy_test/folder_path_test.py
Normal file
66
tests-unit/comfy_test/folder_path_test.py
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
### 🗻 This file is created through the spirit of Mount Fuji at its peak
|
||||||
|
# TODO(yoland): clean up this after I get back down
|
||||||
|
import pytest
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import folder_paths
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_dir():
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
yield tmpdirname
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_directory_by_type():
|
||||||
|
test_dir = "/test/dir"
|
||||||
|
folder_paths.set_output_directory(test_dir)
|
||||||
|
assert folder_paths.get_directory_by_type("output") == test_dir
|
||||||
|
assert folder_paths.get_directory_by_type("invalid") is None
|
||||||
|
|
||||||
|
def test_annotated_filepath():
|
||||||
|
assert folder_paths.annotated_filepath("test.txt") == ("test.txt", None)
|
||||||
|
assert folder_paths.annotated_filepath("test.txt [output]") == ("test.txt", folder_paths.get_output_directory())
|
||||||
|
assert folder_paths.annotated_filepath("test.txt [input]") == ("test.txt", folder_paths.get_input_directory())
|
||||||
|
assert folder_paths.annotated_filepath("test.txt [temp]") == ("test.txt", folder_paths.get_temp_directory())
|
||||||
|
|
||||||
|
def test_get_annotated_filepath():
|
||||||
|
default_dir = "/default/dir"
|
||||||
|
assert folder_paths.get_annotated_filepath("test.txt", default_dir) == os.path.join(default_dir, "test.txt")
|
||||||
|
assert folder_paths.get_annotated_filepath("test.txt [output]") == os.path.join(folder_paths.get_output_directory(), "test.txt")
|
||||||
|
|
||||||
|
def test_add_model_folder_path():
|
||||||
|
folder_paths.add_model_folder_path("test_folder", "/test/path")
|
||||||
|
assert "/test/path" in folder_paths.get_folder_paths("test_folder")
|
||||||
|
|
||||||
|
def test_recursive_search(temp_dir):
|
||||||
|
os.makedirs(os.path.join(temp_dir, "subdir"))
|
||||||
|
open(os.path.join(temp_dir, "file1.txt"), "w").close()
|
||||||
|
open(os.path.join(temp_dir, "subdir", "file2.txt"), "w").close()
|
||||||
|
|
||||||
|
files, dirs = folder_paths.recursive_search(temp_dir)
|
||||||
|
assert set(files) == {"file1.txt", os.path.join("subdir", "file2.txt")}
|
||||||
|
assert len(dirs) == 2 # temp_dir and subdir
|
||||||
|
|
||||||
|
def test_filter_files_extensions():
|
||||||
|
files = ["file1.txt", "file2.jpg", "file3.png", "file4.txt"]
|
||||||
|
assert folder_paths.filter_files_extensions(files, [".txt"]) == ["file1.txt", "file4.txt"]
|
||||||
|
assert folder_paths.filter_files_extensions(files, [".jpg", ".png"]) == ["file2.jpg", "file3.png"]
|
||||||
|
assert folder_paths.filter_files_extensions(files, []) == files
|
||||||
|
|
||||||
|
@patch("folder_paths.recursive_search")
|
||||||
|
@patch("folder_paths.folder_names_and_paths")
|
||||||
|
def test_get_filename_list(mock_folder_names_and_paths, mock_recursive_search):
|
||||||
|
mock_folder_names_and_paths.__getitem__.return_value = (["/test/path"], {".txt"})
|
||||||
|
mock_recursive_search.return_value = (["file1.txt", "file2.jpg"], {})
|
||||||
|
assert folder_paths.get_filename_list("test_folder") == ["file1.txt"]
|
||||||
|
|
||||||
|
def test_get_save_image_path(temp_dir):
|
||||||
|
with patch("folder_paths.output_directory", temp_dir):
|
||||||
|
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path("test", temp_dir, 100, 100)
|
||||||
|
assert os.path.samefile(full_output_folder, temp_dir)
|
||||||
|
assert filename == "test"
|
||||||
|
assert counter == 1
|
||||||
|
assert subfolder == ""
|
||||||
|
assert filename_prefix == "test"
|
||||||
0
tests-unit/folder_paths_test/__init__.py
Normal file
0
tests-unit/folder_paths_test/__init__.py
Normal file
52
tests-unit/folder_paths_test/filter_by_content_types_test.py
Normal file
52
tests-unit/folder_paths_test/filter_by_content_types_test.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
import pytest
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from folder_paths import filter_files_content_types
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def file_extensions():
|
||||||
|
return {
|
||||||
|
'image': ['gif', 'heif', 'ico', 'jpeg', 'jpg', 'png', 'pnm', 'ppm', 'svg', 'tiff', 'webp', 'xbm', 'xpm'],
|
||||||
|
'audio': ['aif', 'aifc', 'aiff', 'au', 'flac', 'm4a', 'mp2', 'mp3', 'ogg', 'snd', 'wav'],
|
||||||
|
'video': ['avi', 'm2v', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ogv', 'qt', 'webm', 'wmv']
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def mock_dir(file_extensions):
|
||||||
|
with tempfile.TemporaryDirectory() as directory:
|
||||||
|
for content_type, extensions in file_extensions.items():
|
||||||
|
for extension in extensions:
|
||||||
|
with open(f"{directory}/sample_{content_type}.{extension}", "w") as f:
|
||||||
|
f.write(f"Sample {content_type} file in {extension} format")
|
||||||
|
yield directory
|
||||||
|
|
||||||
|
|
||||||
|
def test_categorizes_all_correctly(mock_dir, file_extensions):
|
||||||
|
files = os.listdir(mock_dir)
|
||||||
|
for content_type, extensions in file_extensions.items():
|
||||||
|
filtered_files = filter_files_content_types(files, [content_type])
|
||||||
|
for extension in extensions:
|
||||||
|
assert f"sample_{content_type}.{extension}" in filtered_files
|
||||||
|
|
||||||
|
|
||||||
|
def test_categorizes_all_uniquely(mock_dir, file_extensions):
|
||||||
|
files = os.listdir(mock_dir)
|
||||||
|
for content_type, extensions in file_extensions.items():
|
||||||
|
filtered_files = filter_files_content_types(files, [content_type])
|
||||||
|
assert len(filtered_files) == len(extensions)
|
||||||
|
|
||||||
|
|
||||||
|
def test_handles_bad_extensions():
|
||||||
|
files = ["file1.txt", "file2.py", "file3.example", "file4.pdf", "file5.ini", "file6.doc", "file7.md"]
|
||||||
|
assert filter_files_content_types(files, ["image", "audio", "video"]) == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_handles_no_extension():
|
||||||
|
files = ["file1", "file2", "file3", "file4", "file5", "file6", "file7"]
|
||||||
|
assert filter_files_content_types(files, ["image", "audio", "video"]) == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_handles_no_files():
|
||||||
|
files = []
|
||||||
|
assert filter_files_content_types(files, ["image", "audio", "video"]) == []
|
||||||
@@ -1,10 +1,17 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
import tempfile
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from aiohttp import ClientResponse
|
from aiohttp import ClientResponse
|
||||||
import itertools
|
import itertools
|
||||||
import os
|
import os
|
||||||
from unittest.mock import AsyncMock, patch, MagicMock
|
from unittest.mock import AsyncMock, patch, MagicMock
|
||||||
from model_filemanager import download_model, validate_model_subdirectory, track_download_progress, create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus, validate_filename
|
from model_filemanager import download_model, track_download_progress, create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus, validate_filename
|
||||||
|
import folder_paths
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_dir():
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
yield tmpdirname
|
||||||
|
|
||||||
class AsyncIteratorMock:
|
class AsyncIteratorMock:
|
||||||
"""
|
"""
|
||||||
@@ -42,7 +49,7 @@ class ContentMock:
|
|||||||
return AsyncIteratorMock(self.chunks)
|
return AsyncIteratorMock(self.chunks)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_download_model_success():
|
async def test_download_model_success(temp_dir):
|
||||||
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
|
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
|
||||||
mock_response.status = 200
|
mock_response.status = 200
|
||||||
mock_response.headers = {'Content-Length': '1000'}
|
mock_response.headers = {'Content-Length': '1000'}
|
||||||
@@ -53,15 +60,13 @@ async def test_download_model_success():
|
|||||||
mock_make_request = AsyncMock(return_value=mock_response)
|
mock_make_request = AsyncMock(return_value=mock_response)
|
||||||
mock_progress_callback = AsyncMock()
|
mock_progress_callback = AsyncMock()
|
||||||
|
|
||||||
# Mock file operations
|
|
||||||
mock_open = MagicMock()
|
|
||||||
mock_file = MagicMock()
|
|
||||||
mock_open.return_value.__enter__.return_value = mock_file
|
|
||||||
time_values = itertools.count(0, 0.1)
|
time_values = itertools.count(0, 0.1)
|
||||||
|
|
||||||
with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'checkpoints/model.sft')), \
|
fake_paths = {'checkpoints': ([temp_dir], folder_paths.supported_pt_extensions)}
|
||||||
|
|
||||||
|
with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'model.sft')), \
|
||||||
patch('model_filemanager.check_file_exists', return_value=None), \
|
patch('model_filemanager.check_file_exists', return_value=None), \
|
||||||
patch('builtins.open', mock_open), \
|
patch('folder_paths.folder_names_and_paths', fake_paths), \
|
||||||
patch('time.time', side_effect=time_values): # Simulate time passing
|
patch('time.time', side_effect=time_values): # Simulate time passing
|
||||||
|
|
||||||
result = await download_model(
|
result = await download_model(
|
||||||
@@ -69,6 +74,7 @@ async def test_download_model_success():
|
|||||||
'model.sft',
|
'model.sft',
|
||||||
'http://example.com/model.sft',
|
'http://example.com/model.sft',
|
||||||
'checkpoints',
|
'checkpoints',
|
||||||
|
temp_dir,
|
||||||
mock_progress_callback
|
mock_progress_callback
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -83,44 +89,48 @@ async def test_download_model_success():
|
|||||||
|
|
||||||
# Check initial call
|
# Check initial call
|
||||||
mock_progress_callback.assert_any_call(
|
mock_progress_callback.assert_any_call(
|
||||||
'checkpoints/model.sft',
|
'model.sft',
|
||||||
DownloadModelStatus(DownloadStatusType.PENDING, 0, "Starting download of model.sft", False)
|
DownloadModelStatus(DownloadStatusType.PENDING, 0, "Starting download of model.sft", False)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check final call
|
# Check final call
|
||||||
mock_progress_callback.assert_any_call(
|
mock_progress_callback.assert_any_call(
|
||||||
'checkpoints/model.sft',
|
'model.sft',
|
||||||
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.sft", False)
|
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.sft", False)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify file writing
|
mock_file_path = os.path.join(temp_dir, 'model.sft')
|
||||||
mock_file.write.assert_any_call(b'a' * 500)
|
assert os.path.exists(mock_file_path)
|
||||||
mock_file.write.assert_any_call(b'b' * 300)
|
with open(mock_file_path, 'rb') as mock_file:
|
||||||
mock_file.write.assert_any_call(b'c' * 200)
|
assert mock_file.read() == b''.join(chunks)
|
||||||
|
os.remove(mock_file_path)
|
||||||
|
|
||||||
# Verify request was made
|
# Verify request was made
|
||||||
mock_make_request.assert_called_once_with('http://example.com/model.sft')
|
mock_make_request.assert_called_once_with('http://example.com/model.sft')
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_download_model_url_request_failure():
|
async def test_download_model_url_request_failure(temp_dir):
|
||||||
# Mock dependencies
|
# Mock dependencies
|
||||||
mock_response = AsyncMock(spec=ClientResponse)
|
mock_response = AsyncMock(spec=ClientResponse)
|
||||||
mock_response.status = 404 # Simulate a "Not Found" error
|
mock_response.status = 404 # Simulate a "Not Found" error
|
||||||
mock_get = AsyncMock(return_value=mock_response)
|
mock_get = AsyncMock(return_value=mock_response)
|
||||||
mock_progress_callback = AsyncMock()
|
mock_progress_callback = AsyncMock()
|
||||||
|
|
||||||
|
fake_paths = {'checkpoints': ([temp_dir], folder_paths.supported_pt_extensions)}
|
||||||
|
|
||||||
# Mock the create_model_path function
|
# Mock the create_model_path function
|
||||||
with patch('model_filemanager.create_model_path', return_value=('/mock/path/model.safetensors', 'mock/path/model.safetensors')):
|
with patch('model_filemanager.create_model_path', return_value='/mock/path/model.safetensors'), \
|
||||||
# Mock the check_file_exists function to return None (file doesn't exist)
|
patch('model_filemanager.check_file_exists', return_value=None), \
|
||||||
with patch('model_filemanager.check_file_exists', return_value=None):
|
patch('folder_paths.folder_names_and_paths', fake_paths):
|
||||||
# Call the function
|
# Call the function
|
||||||
result = await download_model(
|
result = await download_model(
|
||||||
mock_get,
|
mock_get,
|
||||||
'model.safetensors',
|
'model.safetensors',
|
||||||
'http://example.com/model.safetensors',
|
'http://example.com/model.safetensors',
|
||||||
'mock_directory',
|
'checkpoints',
|
||||||
mock_progress_callback
|
temp_dir,
|
||||||
)
|
mock_progress_callback
|
||||||
|
)
|
||||||
|
|
||||||
# Assert the expected behavior
|
# Assert the expected behavior
|
||||||
assert isinstance(result, DownloadModelStatus)
|
assert isinstance(result, DownloadModelStatus)
|
||||||
@@ -130,7 +140,7 @@ async def test_download_model_url_request_failure():
|
|||||||
|
|
||||||
# Check that progress_callback was called with the correct arguments
|
# Check that progress_callback was called with the correct arguments
|
||||||
mock_progress_callback.assert_any_call(
|
mock_progress_callback.assert_any_call(
|
||||||
'mock_directory/model.safetensors',
|
'model.safetensors',
|
||||||
DownloadModelStatus(
|
DownloadModelStatus(
|
||||||
status=DownloadStatusType.PENDING,
|
status=DownloadStatusType.PENDING,
|
||||||
progress_percentage=0,
|
progress_percentage=0,
|
||||||
@@ -139,7 +149,7 @@ async def test_download_model_url_request_failure():
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
mock_progress_callback.assert_called_with(
|
mock_progress_callback.assert_called_with(
|
||||||
'mock_directory/model.safetensors',
|
'model.safetensors',
|
||||||
DownloadModelStatus(
|
DownloadModelStatus(
|
||||||
status=DownloadStatusType.ERROR,
|
status=DownloadStatusType.ERROR,
|
||||||
progress_percentage=0,
|
progress_percentage=0,
|
||||||
@@ -153,40 +163,59 @@ async def test_download_model_url_request_failure():
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_download_model_invalid_model_subdirectory():
|
async def test_download_model_invalid_model_subdirectory():
|
||||||
|
|
||||||
mock_make_request = AsyncMock()
|
mock_make_request = AsyncMock()
|
||||||
mock_progress_callback = AsyncMock()
|
mock_progress_callback = AsyncMock()
|
||||||
|
|
||||||
|
|
||||||
result = await download_model(
|
result = await download_model(
|
||||||
mock_make_request,
|
mock_make_request,
|
||||||
'model.sft',
|
'model.sft',
|
||||||
'http://example.com/model.sft',
|
'http://example.com/model.sft',
|
||||||
'../bad_path',
|
'../bad_path',
|
||||||
|
'../bad_path',
|
||||||
mock_progress_callback
|
mock_progress_callback
|
||||||
)
|
)
|
||||||
|
|
||||||
# Assert the result
|
# Assert the result
|
||||||
assert isinstance(result, DownloadModelStatus)
|
assert isinstance(result, DownloadModelStatus)
|
||||||
assert result.message == 'Invalid model subdirectory'
|
assert result.message.startswith('Invalid or unrecognized model directory')
|
||||||
assert result.status == 'error'
|
assert result.status == 'error'
|
||||||
assert result.already_existed is False
|
assert result.already_existed is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_download_model_invalid_folder_path():
|
||||||
|
mock_make_request = AsyncMock()
|
||||||
|
mock_progress_callback = AsyncMock()
|
||||||
|
|
||||||
|
result = await download_model(
|
||||||
|
mock_make_request,
|
||||||
|
'model.sft',
|
||||||
|
'http://example.com/model.sft',
|
||||||
|
'checkpoints',
|
||||||
|
'invalid_path',
|
||||||
|
mock_progress_callback
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert the result
|
||||||
|
assert isinstance(result, DownloadModelStatus)
|
||||||
|
assert result.message.startswith("Invalid folder path")
|
||||||
|
assert result.status == 'error'
|
||||||
|
assert result.already_existed is False
|
||||||
|
|
||||||
# For create_model_path function
|
|
||||||
def test_create_model_path(tmp_path, monkeypatch):
|
def test_create_model_path(tmp_path, monkeypatch):
|
||||||
mock_models_dir = tmp_path / "models"
|
model_name = "model.safetensors"
|
||||||
monkeypatch.setattr('folder_paths.models_dir', str(mock_models_dir))
|
folder_path = os.path.join(tmp_path, "mock_dir")
|
||||||
|
|
||||||
model_name = "test_model.sft"
|
file_path = create_model_path(model_name, folder_path)
|
||||||
model_directory = "test_dir"
|
|
||||||
|
|
||||||
file_path, relative_path = create_model_path(model_name, model_directory, mock_models_dir)
|
assert file_path == os.path.join(folder_path, "model.safetensors")
|
||||||
|
|
||||||
assert file_path == str(mock_models_dir / model_directory / model_name)
|
|
||||||
assert relative_path == f"{model_directory}/{model_name}"
|
|
||||||
assert os.path.exists(os.path.dirname(file_path))
|
assert os.path.exists(os.path.dirname(file_path))
|
||||||
|
|
||||||
|
with pytest.raises(Exception, match="Invalid model directory"):
|
||||||
|
create_model_path("../path_traversal.safetensors", folder_path)
|
||||||
|
|
||||||
|
with pytest.raises(Exception, match="Invalid model directory"):
|
||||||
|
create_model_path("/etc/some_root_path", folder_path)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_check_file_exists_when_file_exists(tmp_path):
|
async def test_check_file_exists_when_file_exists(tmp_path):
|
||||||
@@ -195,7 +224,7 @@ async def test_check_file_exists_when_file_exists(tmp_path):
|
|||||||
|
|
||||||
mock_callback = AsyncMock()
|
mock_callback = AsyncMock()
|
||||||
|
|
||||||
result = await check_file_exists(str(file_path), "existing_model.sft", mock_callback, "test/existing_model.sft")
|
result = await check_file_exists(str(file_path), "existing_model.sft", mock_callback)
|
||||||
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result.status == "completed"
|
assert result.status == "completed"
|
||||||
@@ -203,7 +232,7 @@ async def test_check_file_exists_when_file_exists(tmp_path):
|
|||||||
assert result.already_existed is True
|
assert result.already_existed is True
|
||||||
|
|
||||||
mock_callback.assert_called_once_with(
|
mock_callback.assert_called_once_with(
|
||||||
"test/existing_model.sft",
|
"existing_model.sft",
|
||||||
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "existing_model.sft already exists", already_existed=True)
|
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "existing_model.sft already exists", already_existed=True)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -213,38 +242,46 @@ async def test_check_file_exists_when_file_does_not_exist(tmp_path):
|
|||||||
|
|
||||||
mock_callback = AsyncMock()
|
mock_callback = AsyncMock()
|
||||||
|
|
||||||
result = await check_file_exists(str(file_path), "non_existing_model.sft", mock_callback, "test/non_existing_model.sft")
|
result = await check_file_exists(str(file_path), "non_existing_model.sft", mock_callback)
|
||||||
|
|
||||||
assert result is None
|
assert result is None
|
||||||
mock_callback.assert_not_called()
|
mock_callback.assert_not_called()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_track_download_progress_no_content_length():
|
async def test_track_download_progress_no_content_length(temp_dir):
|
||||||
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
|
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
|
||||||
mock_response.headers = {} # No Content-Length header
|
mock_response.headers = {} # No Content-Length header
|
||||||
mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 500, b'b' * 500])
|
chunks = [b'a' * 500, b'b' * 500]
|
||||||
|
mock_response.content.iter_chunked.return_value = AsyncIteratorMock(chunks)
|
||||||
|
|
||||||
mock_callback = AsyncMock()
|
mock_callback = AsyncMock()
|
||||||
mock_open = MagicMock(return_value=MagicMock())
|
|
||||||
|
|
||||||
with patch('builtins.open', mock_open):
|
full_path = os.path.join(temp_dir, 'model.sft')
|
||||||
result = await track_download_progress(
|
|
||||||
mock_response, '/mock/path/model.sft', 'model.sft',
|
result = await track_download_progress(
|
||||||
mock_callback, 'models/model.sft', interval=0.1
|
mock_response, full_path, 'model.sft',
|
||||||
)
|
mock_callback, interval=0.1
|
||||||
|
)
|
||||||
|
|
||||||
assert result.status == "completed"
|
assert result.status == "completed"
|
||||||
|
|
||||||
|
assert os.path.exists(full_path)
|
||||||
|
with open(full_path, 'rb') as f:
|
||||||
|
assert f.read() == b''.join(chunks)
|
||||||
|
os.remove(full_path)
|
||||||
|
|
||||||
# Check that progress was reported even without knowing the total size
|
# Check that progress was reported even without knowing the total size
|
||||||
mock_callback.assert_any_call(
|
mock_callback.assert_any_call(
|
||||||
'models/model.sft',
|
'model.sft',
|
||||||
DownloadModelStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.sft", already_existed=False)
|
DownloadModelStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.sft", already_existed=False)
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_track_download_progress_interval():
|
async def test_track_download_progress_interval(temp_dir):
|
||||||
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
|
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
|
||||||
mock_response.headers = {'Content-Length': '1000'}
|
mock_response.headers = {'Content-Length': '1000'}
|
||||||
mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 100] * 10)
|
chunks = [b'a' * 100] * 10
|
||||||
|
mock_response.content.iter_chunked.return_value = AsyncIteratorMock(chunks)
|
||||||
|
|
||||||
mock_callback = AsyncMock()
|
mock_callback = AsyncMock()
|
||||||
mock_open = MagicMock(return_value=MagicMock())
|
mock_open = MagicMock(return_value=MagicMock())
|
||||||
@@ -253,18 +290,18 @@ async def test_track_download_progress_interval():
|
|||||||
mock_time = MagicMock()
|
mock_time = MagicMock()
|
||||||
mock_time.side_effect = [i * 0.5 for i in range(30)] # This should be enough for 10 chunks
|
mock_time.side_effect = [i * 0.5 for i in range(30)] # This should be enough for 10 chunks
|
||||||
|
|
||||||
with patch('builtins.open', mock_open), \
|
full_path = os.path.join(temp_dir, 'model.sft')
|
||||||
patch('time.time', mock_time):
|
|
||||||
|
with patch('time.time', mock_time):
|
||||||
await track_download_progress(
|
await track_download_progress(
|
||||||
mock_response, '/mock/path/model.sft', 'model.sft',
|
mock_response, full_path, 'model.sft',
|
||||||
mock_callback, 'models/model.sft', interval=1.0
|
mock_callback, interval=1.0
|
||||||
)
|
)
|
||||||
|
|
||||||
# Print out the actual call count and the arguments of each call for debugging
|
assert os.path.exists(full_path)
|
||||||
print(f"mock_callback was called {mock_callback.call_count} times")
|
with open(full_path, 'rb') as f:
|
||||||
for i, call in enumerate(mock_callback.call_args_list):
|
assert f.read() == b''.join(chunks)
|
||||||
args, kwargs = call
|
os.remove(full_path)
|
||||||
print(f"Call {i + 1}: {args[1].status}, Progress: {args[1].progress_percentage:.2f}%")
|
|
||||||
|
|
||||||
# Assert that progress was updated at least 3 times (start, at least one interval, and end)
|
# Assert that progress was updated at least 3 times (start, at least one interval, and end)
|
||||||
assert mock_callback.call_count >= 3, f"Expected at least 3 calls, but got {mock_callback.call_count}"
|
assert mock_callback.call_count >= 3, f"Expected at least 3 calls, but got {mock_callback.call_count}"
|
||||||
@@ -279,27 +316,6 @@ async def test_track_download_progress_interval():
|
|||||||
assert last_call[0][1].status == "completed"
|
assert last_call[0][1].status == "completed"
|
||||||
assert last_call[0][1].progress_percentage == 100
|
assert last_call[0][1].progress_percentage == 100
|
||||||
|
|
||||||
def test_valid_subdirectory():
|
|
||||||
assert validate_model_subdirectory("valid-model123") is True
|
|
||||||
|
|
||||||
def test_subdirectory_too_long():
|
|
||||||
assert validate_model_subdirectory("a" * 51) is False
|
|
||||||
|
|
||||||
def test_subdirectory_with_double_dots():
|
|
||||||
assert validate_model_subdirectory("model/../unsafe") is False
|
|
||||||
|
|
||||||
def test_subdirectory_with_slash():
|
|
||||||
assert validate_model_subdirectory("model/unsafe") is False
|
|
||||||
|
|
||||||
def test_subdirectory_with_special_characters():
|
|
||||||
assert validate_model_subdirectory("model@unsafe") is False
|
|
||||||
|
|
||||||
def test_subdirectory_with_underscore_and_dash():
|
|
||||||
assert validate_model_subdirectory("valid_model-name") is True
|
|
||||||
|
|
||||||
def test_empty_subdirectory():
|
|
||||||
assert validate_model_subdirectory("") is False
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("filename, expected", [
|
@pytest.mark.parametrize("filename, expected", [
|
||||||
("valid_model.safetensors", True),
|
("valid_model.safetensors", True),
|
||||||
("valid_model.sft", True),
|
("valid_model.sft", True),
|
||||||
|
|||||||
120
tests-unit/prompt_server_test/user_manager_test.py
Normal file
120
tests-unit/prompt_server_test/user_manager_test.py
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
import pytest
|
||||||
|
import os
|
||||||
|
from aiohttp import web
|
||||||
|
from app.user_manager import UserManager
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
pytestmark = (
|
||||||
|
pytest.mark.asyncio
|
||||||
|
) # This applies the asyncio mark to all test functions in the module
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def user_manager(tmp_path):
|
||||||
|
um = UserManager()
|
||||||
|
um.get_request_user_filepath = lambda req, file, **kwargs: os.path.join(
|
||||||
|
tmp_path, file
|
||||||
|
)
|
||||||
|
return um
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def app(user_manager):
|
||||||
|
app = web.Application()
|
||||||
|
routes = web.RouteTableDef()
|
||||||
|
user_manager.add_routes(routes)
|
||||||
|
app.add_routes(routes)
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
async def test_listuserdata_empty_directory(aiohttp_client, app, tmp_path):
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.get("/userdata?dir=test_dir")
|
||||||
|
assert resp.status == 404
|
||||||
|
|
||||||
|
|
||||||
|
async def test_listuserdata_with_files(aiohttp_client, app, tmp_path):
|
||||||
|
os.makedirs(tmp_path / "test_dir")
|
||||||
|
with open(tmp_path / "test_dir" / "file1.txt", "w") as f:
|
||||||
|
f.write("test content")
|
||||||
|
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.get("/userdata?dir=test_dir")
|
||||||
|
assert resp.status == 200
|
||||||
|
assert await resp.json() == ["file1.txt"]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_listuserdata_recursive(aiohttp_client, app, tmp_path):
|
||||||
|
os.makedirs(tmp_path / "test_dir" / "subdir")
|
||||||
|
with open(tmp_path / "test_dir" / "file1.txt", "w") as f:
|
||||||
|
f.write("test content")
|
||||||
|
with open(tmp_path / "test_dir" / "subdir" / "file2.txt", "w") as f:
|
||||||
|
f.write("test content")
|
||||||
|
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.get("/userdata?dir=test_dir&recurse=true")
|
||||||
|
assert resp.status == 200
|
||||||
|
assert set(await resp.json()) == {"file1.txt", "subdir/file2.txt"}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_listuserdata_full_info(aiohttp_client, app, tmp_path):
|
||||||
|
os.makedirs(tmp_path / "test_dir")
|
||||||
|
with open(tmp_path / "test_dir" / "file1.txt", "w") as f:
|
||||||
|
f.write("test content")
|
||||||
|
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.get("/userdata?dir=test_dir&full_info=true")
|
||||||
|
assert resp.status == 200
|
||||||
|
result = await resp.json()
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0]["path"] == "file1.txt"
|
||||||
|
assert "size" in result[0]
|
||||||
|
assert "modified" in result[0]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_listuserdata_split_path(aiohttp_client, app, tmp_path):
|
||||||
|
os.makedirs(tmp_path / "test_dir" / "subdir")
|
||||||
|
with open(tmp_path / "test_dir" / "subdir" / "file1.txt", "w") as f:
|
||||||
|
f.write("test content")
|
||||||
|
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.get("/userdata?dir=test_dir&recurse=true&split=true")
|
||||||
|
assert resp.status == 200
|
||||||
|
assert await resp.json() == [
|
||||||
|
["subdir/file1.txt", "subdir", "file1.txt"]
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_listuserdata_invalid_directory(aiohttp_client, app):
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.get("/userdata?dir=")
|
||||||
|
assert resp.status == 400
|
||||||
|
|
||||||
|
|
||||||
|
async def test_listuserdata_normalized_separator(aiohttp_client, app, tmp_path):
|
||||||
|
os_sep = "\\"
|
||||||
|
with patch("os.sep", os_sep):
|
||||||
|
with patch("os.path.sep", os_sep):
|
||||||
|
os.makedirs(tmp_path / "test_dir" / "subdir")
|
||||||
|
with open(tmp_path / "test_dir" / "subdir" / "file1.txt", "w") as f:
|
||||||
|
f.write("test content")
|
||||||
|
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.get("/userdata?dir=test_dir&recurse=true")
|
||||||
|
assert resp.status == 200
|
||||||
|
result = await resp.json()
|
||||||
|
assert len(result) == 1
|
||||||
|
assert "/" in result[0] # Ensure forward slash is used
|
||||||
|
assert "\\" not in result[0] # Ensure backslash is not present
|
||||||
|
assert result[0] == "subdir/file1.txt"
|
||||||
|
|
||||||
|
# Test with full_info
|
||||||
|
resp = await client.get(
|
||||||
|
"/userdata?dir=test_dir&recurse=true&full_info=true"
|
||||||
|
)
|
||||||
|
assert resp.status == 200
|
||||||
|
result = await resp.json()
|
||||||
|
assert len(result) == 1
|
||||||
|
assert "/" in result[0]["path"] # Ensure forward slash is used
|
||||||
|
assert "\\" not in result[0]["path"] # Ensure backslash is not present
|
||||||
|
assert result[0]["path"] == "subdir/file1.txt"
|
||||||
126
tests-unit/utils/extra_config_test.py
Normal file
126
tests-unit/utils/extra_config_test.py
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
import pytest
|
||||||
|
import yaml
|
||||||
|
import os
|
||||||
|
from unittest.mock import Mock, patch, mock_open
|
||||||
|
|
||||||
|
from utils.extra_config import load_extra_path_config
|
||||||
|
import folder_paths
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_yaml_content():
|
||||||
|
return {
|
||||||
|
'test_config': {
|
||||||
|
'base_path': '~/App/',
|
||||||
|
'checkpoints': 'subfolder1',
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_expanded_home():
|
||||||
|
return '/home/user'
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def yaml_config_with_appdata():
|
||||||
|
return """
|
||||||
|
test_config:
|
||||||
|
base_path: '%APPDATA%/ComfyUI'
|
||||||
|
checkpoints: 'models/checkpoints'
|
||||||
|
"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_yaml_content_appdata(yaml_config_with_appdata):
|
||||||
|
return yaml.safe_load(yaml_config_with_appdata)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_expandvars_appdata():
|
||||||
|
mock = Mock()
|
||||||
|
mock.side_effect = lambda path: path.replace('%APPDATA%', 'C:/Users/TestUser/AppData/Roaming')
|
||||||
|
return mock
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_add_model_folder_path():
|
||||||
|
return Mock()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_expanduser(mock_expanded_home):
|
||||||
|
def _expanduser(path):
|
||||||
|
if path.startswith('~/'):
|
||||||
|
return os.path.join(mock_expanded_home, path[2:])
|
||||||
|
return path
|
||||||
|
return _expanduser
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_yaml_safe_load(mock_yaml_content):
|
||||||
|
return Mock(return_value=mock_yaml_content)
|
||||||
|
|
||||||
|
@patch('builtins.open', new_callable=mock_open, read_data="dummy file content")
|
||||||
|
def test_load_extra_model_paths_expands_userpath(
|
||||||
|
mock_file,
|
||||||
|
monkeypatch,
|
||||||
|
mock_add_model_folder_path,
|
||||||
|
mock_expanduser,
|
||||||
|
mock_yaml_safe_load,
|
||||||
|
mock_expanded_home
|
||||||
|
):
|
||||||
|
# Attach mocks used by load_extra_path_config
|
||||||
|
monkeypatch.setattr(folder_paths, 'add_model_folder_path', mock_add_model_folder_path)
|
||||||
|
monkeypatch.setattr(os.path, 'expanduser', mock_expanduser)
|
||||||
|
monkeypatch.setattr(yaml, 'safe_load', mock_yaml_safe_load)
|
||||||
|
|
||||||
|
dummy_yaml_file_name = 'dummy_path.yaml'
|
||||||
|
load_extra_path_config(dummy_yaml_file_name)
|
||||||
|
|
||||||
|
expected_calls = [
|
||||||
|
('checkpoints', os.path.join(mock_expanded_home, 'App', 'subfolder1'), False),
|
||||||
|
]
|
||||||
|
|
||||||
|
assert mock_add_model_folder_path.call_count == len(expected_calls)
|
||||||
|
|
||||||
|
# Check if add_model_folder_path was called with the correct arguments
|
||||||
|
for actual_call, expected_call in zip(mock_add_model_folder_path.call_args_list, expected_calls):
|
||||||
|
assert actual_call.args[0] == expected_call[0]
|
||||||
|
assert os.path.normpath(actual_call.args[1]) == os.path.normpath(expected_call[1]) # Normalize and check the path to check on multiple OS.
|
||||||
|
assert actual_call.args[2] == expected_call[2]
|
||||||
|
|
||||||
|
# Check if yaml.safe_load was called
|
||||||
|
mock_yaml_safe_load.assert_called_once()
|
||||||
|
|
||||||
|
# Check if open was called with the correct file path
|
||||||
|
mock_file.assert_called_once_with(dummy_yaml_file_name, 'r')
|
||||||
|
|
||||||
|
@patch('builtins.open', new_callable=mock_open)
|
||||||
|
def test_load_extra_model_paths_expands_appdata(
|
||||||
|
mock_file,
|
||||||
|
monkeypatch,
|
||||||
|
mock_add_model_folder_path,
|
||||||
|
mock_expandvars_appdata,
|
||||||
|
yaml_config_with_appdata,
|
||||||
|
mock_yaml_content_appdata
|
||||||
|
):
|
||||||
|
# Set the mock_file to return yaml with appdata as a variable
|
||||||
|
mock_file.return_value.read.return_value = yaml_config_with_appdata
|
||||||
|
|
||||||
|
# Attach mocks
|
||||||
|
monkeypatch.setattr(folder_paths, 'add_model_folder_path', mock_add_model_folder_path)
|
||||||
|
monkeypatch.setattr(os.path, 'expandvars', mock_expandvars_appdata)
|
||||||
|
monkeypatch.setattr(yaml, 'safe_load', Mock(return_value=mock_yaml_content_appdata))
|
||||||
|
|
||||||
|
# Mock expanduser to do nothing (since we're not testing it here)
|
||||||
|
monkeypatch.setattr(os.path, 'expanduser', lambda x: x)
|
||||||
|
|
||||||
|
dummy_yaml_file_name = 'dummy_path.yaml'
|
||||||
|
load_extra_path_config(dummy_yaml_file_name)
|
||||||
|
|
||||||
|
expected_base_path = 'C:/Users/TestUser/AppData/Roaming/ComfyUI'
|
||||||
|
expected_calls = [
|
||||||
|
('checkpoints', os.path.join(expected_base_path, 'models/checkpoints'), False),
|
||||||
|
]
|
||||||
|
|
||||||
|
assert mock_add_model_folder_path.call_count == len(expected_calls)
|
||||||
|
|
||||||
|
# Check the base path variable was expanded
|
||||||
|
for actual_call, expected_call in zip(mock_add_model_folder_path.call_args_list, expected_calls):
|
||||||
|
assert actual_call.args == expected_call
|
||||||
|
|
||||||
|
# Verify that expandvars was called
|
||||||
|
assert mock_expandvars_appdata.called
|
||||||
@@ -95,17 +95,16 @@ class ComfyClient:
|
|||||||
pass # Probably want to store this off for testing
|
pass # Probably want to store this off for testing
|
||||||
|
|
||||||
history = self.get_history(prompt_id)[prompt_id]
|
history = self.get_history(prompt_id)[prompt_id]
|
||||||
for o in history['outputs']:
|
for node_id in history['outputs']:
|
||||||
for node_id in history['outputs']:
|
node_output = history['outputs'][node_id]
|
||||||
node_output = history['outputs'][node_id]
|
result.outputs[node_id] = node_output
|
||||||
result.outputs[node_id] = node_output
|
images_output = []
|
||||||
if 'images' in node_output:
|
if 'images' in node_output:
|
||||||
images_output = []
|
for image in node_output['images']:
|
||||||
for image in node_output['images']:
|
image_data = self.get_image(image['filename'], image['subfolder'], image['type'])
|
||||||
image_data = self.get_image(image['filename'], image['subfolder'], image['type'])
|
image_obj = Image.open(BytesIO(image_data))
|
||||||
image_obj = Image.open(BytesIO(image_data))
|
images_output.append(image_obj)
|
||||||
images_output.append(image_obj)
|
node_output['image_objects'] = images_output
|
||||||
node_output['image_objects'] = images_output
|
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@@ -357,6 +356,25 @@ class TestExecution:
|
|||||||
assert 'prompt_id' in e.args[0], f"Did not get back a proper error message: {e}"
|
assert 'prompt_id' in e.args[0], f"Did not get back a proper error message: {e}"
|
||||||
assert e.args[0]['node_id'] == generator.id, "Error should have been on the generator node"
|
assert e.args[0]['node_id'] == generator.id, "Error should have been on the generator node"
|
||||||
|
|
||||||
|
def test_missing_node_error(self, client: ComfyClient, builder: GraphBuilder):
|
||||||
|
g = builder
|
||||||
|
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||||
|
input2 = g.node("StubImage", id="removeme", content="WHITE", height=512, width=512, batch_size=1)
|
||||||
|
input3 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||||
|
mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1)
|
||||||
|
mix1 = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0))
|
||||||
|
mix2 = g.node("TestLazyMixImages", image1=input1.out(0), image2=input3.out(0), mask=mask.out(0))
|
||||||
|
# We have multiple outputs. The first is invalid, but the second is valid
|
||||||
|
g.node("SaveImage", images=mix1.out(0))
|
||||||
|
g.node("SaveImage", images=mix2.out(0))
|
||||||
|
g.remove_node("removeme")
|
||||||
|
|
||||||
|
client.run(g)
|
||||||
|
|
||||||
|
# Add back in the missing node to make sure the error doesn't break the server
|
||||||
|
input2 = g.node("StubImage", id="removeme", content="WHITE", height=512, width=512, batch_size=1)
|
||||||
|
client.run(g)
|
||||||
|
|
||||||
def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder):
|
def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder):
|
||||||
g = builder
|
g = builder
|
||||||
# Creating the nodes in this specific order previously caused a bug
|
# Creating the nodes in this specific order previously caused a bug
|
||||||
@@ -450,8 +468,8 @@ class TestExecution:
|
|||||||
g = builder
|
g = builder
|
||||||
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||||
|
|
||||||
output1 = g.node("PreviewImage", images=input1.out(0))
|
output1 = g.node("SaveImage", images=input1.out(0))
|
||||||
output2 = g.node("PreviewImage", images=input1.out(0))
|
output2 = g.node("SaveImage", images=input1.out(0))
|
||||||
|
|
||||||
result = client.run(g)
|
result = client.run(g)
|
||||||
images1 = result.get_images(output1)
|
images1 = result.get_images(output1)
|
||||||
@@ -478,3 +496,29 @@ class TestExecution:
|
|||||||
assert len(images) == 1, "Should have 1 image"
|
assert len(images) == 1, "Should have 1 image"
|
||||||
assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"
|
assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"
|
||||||
assert not result.did_run(test_node), "The execution should have been cached"
|
assert not result.did_run(test_node), "The execution should have been cached"
|
||||||
|
|
||||||
|
# This tests that nodes with OUTPUT_IS_LIST function correctly when they receive an ExecutionBlocker
|
||||||
|
# as input. We also test that when that list (containing an ExecutionBlocker) is passed to a node,
|
||||||
|
# only that one entry in the list is blocked.
|
||||||
|
def test_execution_block_list_output(self, client: ComfyClient, builder: GraphBuilder):
|
||||||
|
g = builder
|
||||||
|
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||||
|
image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||||
|
image3 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||||
|
image_list = g.node("TestMakeListNode", value1=image1.out(0), value2=image2.out(0), value3=image3.out(0))
|
||||||
|
int1 = g.node("StubInt", value=1)
|
||||||
|
int2 = g.node("StubInt", value=2)
|
||||||
|
int3 = g.node("StubInt", value=3)
|
||||||
|
int_list = g.node("TestMakeListNode", value1=int1.out(0), value2=int2.out(0), value3=int3.out(0))
|
||||||
|
compare = g.node("TestIntConditions", a=int_list.out(0), b=2, operation="==")
|
||||||
|
blocker = g.node("TestExecutionBlocker", input=image_list.out(0), block=compare.out(0), verbose=False)
|
||||||
|
|
||||||
|
list_output = g.node("TestMakeListNode", value1=blocker.out(0))
|
||||||
|
output = g.node("PreviewImage", images=list_output.out(0))
|
||||||
|
|
||||||
|
result = client.run(g)
|
||||||
|
assert result.did_run(output), "The execution should have run"
|
||||||
|
images = result.get_images(output)
|
||||||
|
assert len(images) == 2, "Should have 2 images"
|
||||||
|
assert numpy.array(images[0]).min() == 0 and numpy.array(images[0]).max() == 0, "First image should be black"
|
||||||
|
assert numpy.array(images[1]).min() == 0 and numpy.array(images[1]).max() == 0, "Second image should also be black"
|
||||||
|
|||||||
@@ -109,15 +109,14 @@ class ComfyClient:
|
|||||||
continue #previews are binary data
|
continue #previews are binary data
|
||||||
|
|
||||||
history = self.get_history(prompt_id)[prompt_id]
|
history = self.get_history(prompt_id)[prompt_id]
|
||||||
for o in history['outputs']:
|
for node_id in history['outputs']:
|
||||||
for node_id in history['outputs']:
|
node_output = history['outputs'][node_id]
|
||||||
node_output = history['outputs'][node_id]
|
images_output = []
|
||||||
if 'images' in node_output:
|
if 'images' in node_output:
|
||||||
images_output = []
|
for image in node_output['images']:
|
||||||
for image in node_output['images']:
|
image_data = self.get_image(image['filename'], image['subfolder'], image['type'])
|
||||||
image_data = self.get_image(image['filename'], image['subfolder'], image['type'])
|
images_output.append(image_data)
|
||||||
images_output.append(image_data)
|
output_images[node_id] = images_output
|
||||||
output_images[node_id] = images_output
|
|
||||||
|
|
||||||
return output_images
|
return output_images
|
||||||
|
|
||||||
|
|||||||
0
utils/__init__.py
Normal file
0
utils/__init__.py
Normal file
28
utils/extra_config.py
Normal file
28
utils/extra_config.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
import os
|
||||||
|
import yaml
|
||||||
|
import folder_paths
|
||||||
|
import logging
|
||||||
|
|
||||||
|
def load_extra_path_config(yaml_path):
|
||||||
|
with open(yaml_path, 'r') as stream:
|
||||||
|
config = yaml.safe_load(stream)
|
||||||
|
for c in config:
|
||||||
|
conf = config[c]
|
||||||
|
if conf is None:
|
||||||
|
continue
|
||||||
|
base_path = None
|
||||||
|
if "base_path" in conf:
|
||||||
|
base_path = conf.pop("base_path")
|
||||||
|
base_path = os.path.expandvars(os.path.expanduser(base_path))
|
||||||
|
is_default = False
|
||||||
|
if "is_default" in conf:
|
||||||
|
is_default = conf.pop("is_default")
|
||||||
|
for x in conf:
|
||||||
|
for y in conf[x].split("\n"):
|
||||||
|
if len(y) == 0:
|
||||||
|
continue
|
||||||
|
full_path = y
|
||||||
|
if base_path is not None:
|
||||||
|
full_path = os.path.join(base_path, full_path)
|
||||||
|
logging.info("Adding extra search path {} {}".format(x, full_path))
|
||||||
|
folder_paths.add_model_folder_path(x, full_path, is_default)
|
||||||
1
web/assets/CREDIT.txt
generated
vendored
Normal file
1
web/assets/CREDIT.txt
generated
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
Thanks to OpenArt (https://openart.ai) for providing the sorted-custom-node-map data, captured in September 2024.
|
||||||
792
web/assets/GraphView-BGt8GmeB.css
generated
vendored
Normal file
792
web/assets/GraphView-BGt8GmeB.css
generated
vendored
Normal file
@@ -0,0 +1,792 @@
|
|||||||
|
|
||||||
|
.editable-text[data-v-54da6fc9] {
|
||||||
|
display: inline;
|
||||||
|
}
|
||||||
|
.editable-text input[data-v-54da6fc9] {
|
||||||
|
width: 100%;
|
||||||
|
box-sizing: border-box;
|
||||||
|
}
|
||||||
|
|
||||||
|
.group-title-editor.node-title-editor[data-v-fc3f26e3] {
|
||||||
|
z-index: 9999;
|
||||||
|
padding: 0.25rem;
|
||||||
|
}
|
||||||
|
[data-v-fc3f26e3] .editable-text {
|
||||||
|
width: 100%;
|
||||||
|
height: 100%;
|
||||||
|
}
|
||||||
|
[data-v-fc3f26e3] .editable-text input {
|
||||||
|
width: 100%;
|
||||||
|
height: 100%;
|
||||||
|
/* Override the default font size */
|
||||||
|
font-size: inherit;
|
||||||
|
}
|
||||||
|
|
||||||
|
.side-bar-button-icon {
|
||||||
|
font-size: var(--sidebar-icon-size) !important;
|
||||||
|
}
|
||||||
|
.side-bar-button-selected .side-bar-button-icon {
|
||||||
|
font-size: var(--sidebar-icon-size) !important;
|
||||||
|
font-weight: bold;
|
||||||
|
}
|
||||||
|
|
||||||
|
.side-bar-button[data-v-caa3ee9c] {
|
||||||
|
width: var(--sidebar-width);
|
||||||
|
height: var(--sidebar-width);
|
||||||
|
border-radius: 0;
|
||||||
|
}
|
||||||
|
.comfyui-body-left .side-bar-button.side-bar-button-selected[data-v-caa3ee9c],
|
||||||
|
.comfyui-body-left .side-bar-button.side-bar-button-selected[data-v-caa3ee9c]:hover {
|
||||||
|
border-left: 4px solid var(--p-button-text-primary-color);
|
||||||
|
}
|
||||||
|
.comfyui-body-right .side-bar-button.side-bar-button-selected[data-v-caa3ee9c],
|
||||||
|
.comfyui-body-right .side-bar-button.side-bar-button-selected[data-v-caa3ee9c]:hover {
|
||||||
|
border-right: 4px solid var(--p-button-text-primary-color);
|
||||||
|
}
|
||||||
|
|
||||||
|
:root {
|
||||||
|
--sidebar-width: 64px;
|
||||||
|
--sidebar-icon-size: 1.5rem;
|
||||||
|
}
|
||||||
|
:root .small-sidebar {
|
||||||
|
--sidebar-width: 40px;
|
||||||
|
--sidebar-icon-size: 1rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.side-tool-bar-container[data-v-4da64512] {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
align-items: center;
|
||||||
|
|
||||||
|
pointer-events: auto;
|
||||||
|
|
||||||
|
width: var(--sidebar-width);
|
||||||
|
height: 100%;
|
||||||
|
|
||||||
|
background-color: var(--comfy-menu-bg);
|
||||||
|
color: var(--fg-color);
|
||||||
|
}
|
||||||
|
.side-tool-bar-end[data-v-4da64512] {
|
||||||
|
align-self: flex-end;
|
||||||
|
margin-top: auto;
|
||||||
|
}
|
||||||
|
.sidebar-content-container[data-v-4da64512] {
|
||||||
|
height: 100%;
|
||||||
|
overflow-y: auto;
|
||||||
|
}
|
||||||
|
|
||||||
|
.p-splitter-gutter {
|
||||||
|
pointer-events: auto;
|
||||||
|
}
|
||||||
|
.gutter-hidden {
|
||||||
|
display: none !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
.side-bar-panel[data-v-b9df3042] {
|
||||||
|
background-color: var(--bg-color);
|
||||||
|
pointer-events: auto;
|
||||||
|
}
|
||||||
|
.splitter-overlay[data-v-b9df3042] {
|
||||||
|
width: 100%;
|
||||||
|
height: 100%;
|
||||||
|
position: absolute;
|
||||||
|
top: 0;
|
||||||
|
left: 0;
|
||||||
|
background-color: transparent;
|
||||||
|
pointer-events: none;
|
||||||
|
/* Set it the same as the ComfyUI menu */
|
||||||
|
/* Note: Lite-graph DOM widgets have the same z-index as the node id, so
|
||||||
|
999 should be sufficient to make sure splitter overlays on node's DOM
|
||||||
|
widgets */
|
||||||
|
z-index: 999;
|
||||||
|
border: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
._content[data-v-e7b35fd9] {
|
||||||
|
|
||||||
|
display: flex;
|
||||||
|
|
||||||
|
flex-direction: column
|
||||||
|
}
|
||||||
|
._content[data-v-e7b35fd9] > :not([hidden]) ~ :not([hidden]) {
|
||||||
|
|
||||||
|
--tw-space-y-reverse: 0;
|
||||||
|
|
||||||
|
margin-top: calc(0.5rem * calc(1 - var(--tw-space-y-reverse)));
|
||||||
|
|
||||||
|
margin-bottom: calc(0.5rem * var(--tw-space-y-reverse))
|
||||||
|
}
|
||||||
|
._footer[data-v-e7b35fd9] {
|
||||||
|
|
||||||
|
display: flex;
|
||||||
|
|
||||||
|
flex-direction: column;
|
||||||
|
|
||||||
|
align-items: flex-end;
|
||||||
|
|
||||||
|
padding-top: 1rem
|
||||||
|
}
|
||||||
|
|
||||||
|
[data-v-37f672ab] .highlight {
|
||||||
|
background-color: var(--p-primary-color);
|
||||||
|
color: var(--p-primary-contrast-color);
|
||||||
|
font-weight: bold;
|
||||||
|
border-radius: 0.25rem;
|
||||||
|
padding: 0rem 0.125rem;
|
||||||
|
margin: -0.125rem 0.125rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.slot_row[data-v-ff07c900] {
|
||||||
|
padding: 2px;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Original N-Sidebar styles */
|
||||||
|
._sb_dot[data-v-ff07c900] {
|
||||||
|
width: 8px;
|
||||||
|
height: 8px;
|
||||||
|
border-radius: 50%;
|
||||||
|
background-color: grey;
|
||||||
|
}
|
||||||
|
.node_header[data-v-ff07c900] {
|
||||||
|
line-height: 1;
|
||||||
|
padding: 8px 13px 7px;
|
||||||
|
margin-bottom: 5px;
|
||||||
|
font-size: 15px;
|
||||||
|
text-wrap: nowrap;
|
||||||
|
overflow: hidden;
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
}
|
||||||
|
.headdot[data-v-ff07c900] {
|
||||||
|
width: 10px;
|
||||||
|
height: 10px;
|
||||||
|
float: inline-start;
|
||||||
|
margin-right: 8px;
|
||||||
|
}
|
||||||
|
.IMAGE[data-v-ff07c900] {
|
||||||
|
background-color: #64b5f6;
|
||||||
|
}
|
||||||
|
.VAE[data-v-ff07c900] {
|
||||||
|
background-color: #ff6e6e;
|
||||||
|
}
|
||||||
|
.LATENT[data-v-ff07c900] {
|
||||||
|
background-color: #ff9cf9;
|
||||||
|
}
|
||||||
|
.MASK[data-v-ff07c900] {
|
||||||
|
background-color: #81c784;
|
||||||
|
}
|
||||||
|
.CONDITIONING[data-v-ff07c900] {
|
||||||
|
background-color: #ffa931;
|
||||||
|
}
|
||||||
|
.CLIP[data-v-ff07c900] {
|
||||||
|
background-color: #ffd500;
|
||||||
|
}
|
||||||
|
.MODEL[data-v-ff07c900] {
|
||||||
|
background-color: #b39ddb;
|
||||||
|
}
|
||||||
|
.CONTROL_NET[data-v-ff07c900] {
|
||||||
|
background-color: #a5d6a7;
|
||||||
|
}
|
||||||
|
._sb_node_preview[data-v-ff07c900] {
|
||||||
|
background-color: var(--comfy-menu-bg);
|
||||||
|
font-family: 'Open Sans', sans-serif;
|
||||||
|
font-size: small;
|
||||||
|
color: var(--descrip-text);
|
||||||
|
border: 1px solid var(--descrip-text);
|
||||||
|
min-width: 300px;
|
||||||
|
width: -moz-min-content;
|
||||||
|
width: min-content;
|
||||||
|
height: -moz-fit-content;
|
||||||
|
height: fit-content;
|
||||||
|
z-index: 9999;
|
||||||
|
border-radius: 12px;
|
||||||
|
overflow: hidden;
|
||||||
|
font-size: 12px;
|
||||||
|
padding-bottom: 10px;
|
||||||
|
}
|
||||||
|
._sb_node_preview ._sb_description[data-v-ff07c900] {
|
||||||
|
margin: 10px;
|
||||||
|
padding: 6px;
|
||||||
|
background: var(--border-color);
|
||||||
|
border-radius: 5px;
|
||||||
|
font-style: italic;
|
||||||
|
font-weight: 500;
|
||||||
|
font-size: 0.9rem;
|
||||||
|
word-break: break-word;
|
||||||
|
}
|
||||||
|
._sb_table[data-v-ff07c900] {
|
||||||
|
display: grid;
|
||||||
|
|
||||||
|
grid-column-gap: 10px;
|
||||||
|
/* Spazio tra le colonne */
|
||||||
|
width: 100%;
|
||||||
|
/* Imposta la larghezza della tabella al 100% del contenitore */
|
||||||
|
}
|
||||||
|
._sb_row[data-v-ff07c900] {
|
||||||
|
display: grid;
|
||||||
|
grid-template-columns: 10px 1fr 1fr 1fr 10px;
|
||||||
|
grid-column-gap: 10px;
|
||||||
|
align-items: center;
|
||||||
|
padding-left: 9px;
|
||||||
|
padding-right: 9px;
|
||||||
|
}
|
||||||
|
._sb_row_string[data-v-ff07c900] {
|
||||||
|
grid-template-columns: 10px 1fr 1fr 10fr 1fr;
|
||||||
|
}
|
||||||
|
._sb_col[data-v-ff07c900] {
|
||||||
|
border: 0px solid #000;
|
||||||
|
display: flex;
|
||||||
|
align-items: flex-end;
|
||||||
|
flex-direction: row-reverse;
|
||||||
|
flex-wrap: nowrap;
|
||||||
|
align-content: flex-start;
|
||||||
|
justify-content: flex-end;
|
||||||
|
}
|
||||||
|
._sb_inherit[data-v-ff07c900] {
|
||||||
|
display: inherit;
|
||||||
|
}
|
||||||
|
._long_field[data-v-ff07c900] {
|
||||||
|
background: var(--bg-color);
|
||||||
|
border: 2px solid var(--border-color);
|
||||||
|
margin: 5px 5px 0 5px;
|
||||||
|
border-radius: 10px;
|
||||||
|
line-height: 1.7;
|
||||||
|
text-wrap: nowrap;
|
||||||
|
}
|
||||||
|
._sb_arrow[data-v-ff07c900] {
|
||||||
|
color: var(--fg-color);
|
||||||
|
}
|
||||||
|
._sb_preview_badge[data-v-ff07c900] {
|
||||||
|
text-align: center;
|
||||||
|
background: var(--comfy-input-bg);
|
||||||
|
font-weight: bold;
|
||||||
|
color: var(--error-text);
|
||||||
|
}
|
||||||
|
|
||||||
|
.comfy-vue-node-search-container[data-v-2d409367] {
|
||||||
|
display: flex;
|
||||||
|
width: 100%;
|
||||||
|
min-width: 26rem;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: center;
|
||||||
|
}
|
||||||
|
.comfy-vue-node-search-container[data-v-2d409367] * {
|
||||||
|
pointer-events: auto;
|
||||||
|
}
|
||||||
|
.comfy-vue-node-preview-container[data-v-2d409367] {
|
||||||
|
position: absolute;
|
||||||
|
left: -350px;
|
||||||
|
top: 50px;
|
||||||
|
}
|
||||||
|
.comfy-vue-node-search-box[data-v-2d409367] {
|
||||||
|
z-index: 10;
|
||||||
|
flex-grow: 1;
|
||||||
|
}
|
||||||
|
._filter-button[data-v-2d409367] {
|
||||||
|
z-index: 10;
|
||||||
|
}
|
||||||
|
._dialog[data-v-2d409367] {
|
||||||
|
min-width: 26rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.invisible-dialog-root {
|
||||||
|
width: 60%;
|
||||||
|
min-width: 24rem;
|
||||||
|
max-width: 48rem;
|
||||||
|
border: 0 !important;
|
||||||
|
background-color: transparent !important;
|
||||||
|
margin-top: 25vh;
|
||||||
|
margin-left: 400px;
|
||||||
|
}
|
||||||
|
@media all and (max-width: 768px) {
|
||||||
|
.invisible-dialog-root {
|
||||||
|
margin-left: 0px;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.node-search-box-dialog-mask {
|
||||||
|
align-items: flex-start !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
.node-tooltip[data-v-0a4402f9] {
|
||||||
|
background: var(--comfy-input-bg);
|
||||||
|
border-radius: 5px;
|
||||||
|
box-shadow: 0 0 5px rgba(0, 0, 0, 0.4);
|
||||||
|
color: var(--input-text);
|
||||||
|
font-family: sans-serif;
|
||||||
|
left: 0;
|
||||||
|
max-width: 30vw;
|
||||||
|
padding: 4px 8px;
|
||||||
|
position: absolute;
|
||||||
|
top: 0;
|
||||||
|
transform: translate(5px, calc(-100% - 5px));
|
||||||
|
white-space: pre-wrap;
|
||||||
|
z-index: 99999;
|
||||||
|
}
|
||||||
|
|
||||||
|
.p-buttongroup-vertical[data-v-ce8bd6ac] {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
border-radius: var(--p-button-border-radius);
|
||||||
|
overflow: hidden;
|
||||||
|
border: 1px solid var(--p-panel-border-color);
|
||||||
|
}
|
||||||
|
.p-buttongroup-vertical .p-button[data-v-ce8bd6ac] {
|
||||||
|
margin: 0;
|
||||||
|
border-radius: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.comfy-image-wrap[data-v-9bc23daf] {
|
||||||
|
display: contents;
|
||||||
|
}
|
||||||
|
.comfy-image-blur[data-v-9bc23daf] {
|
||||||
|
position: absolute;
|
||||||
|
top: 0;
|
||||||
|
left: 0;
|
||||||
|
width: 100%;
|
||||||
|
height: 100%;
|
||||||
|
-o-object-fit: cover;
|
||||||
|
object-fit: cover;
|
||||||
|
}
|
||||||
|
.comfy-image-main[data-v-9bc23daf] {
|
||||||
|
width: 100%;
|
||||||
|
height: 100%;
|
||||||
|
-o-object-fit: cover;
|
||||||
|
object-fit: cover;
|
||||||
|
-o-object-position: center;
|
||||||
|
object-position: center;
|
||||||
|
z-index: 1;
|
||||||
|
}
|
||||||
|
.contain .comfy-image-wrap[data-v-9bc23daf] {
|
||||||
|
position: relative;
|
||||||
|
width: 100%;
|
||||||
|
height: 100%;
|
||||||
|
}
|
||||||
|
.contain .comfy-image-main[data-v-9bc23daf] {
|
||||||
|
-o-object-fit: contain;
|
||||||
|
object-fit: contain;
|
||||||
|
-webkit-backdrop-filter: blur(10px);
|
||||||
|
backdrop-filter: blur(10px);
|
||||||
|
position: absolute;
|
||||||
|
}
|
||||||
|
.broken-image-placeholder[data-v-9bc23daf] {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: center;
|
||||||
|
width: 100%;
|
||||||
|
height: 100%;
|
||||||
|
margin: 2rem;
|
||||||
|
}
|
||||||
|
.broken-image-placeholder i[data-v-9bc23daf] {
|
||||||
|
font-size: 3rem;
|
||||||
|
margin-bottom: 0.5rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.result-container[data-v-d9c060ae] {
|
||||||
|
width: 100%;
|
||||||
|
height: 100%;
|
||||||
|
aspect-ratio: 1 / 1;
|
||||||
|
overflow: hidden;
|
||||||
|
position: relative;
|
||||||
|
display: flex;
|
||||||
|
justify-content: center;
|
||||||
|
align-items: center;
|
||||||
|
}
|
||||||
|
.image-preview-mask[data-v-d9c060ae] {
|
||||||
|
position: absolute;
|
||||||
|
left: 50%;
|
||||||
|
top: 50%;
|
||||||
|
transform: translate(-50%, -50%);
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: center;
|
||||||
|
opacity: 0;
|
||||||
|
transition: opacity 0.3s ease;
|
||||||
|
z-index: 1;
|
||||||
|
}
|
||||||
|
.result-container:hover .image-preview-mask[data-v-d9c060ae] {
|
||||||
|
opacity: 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
.task-result-preview[data-v-d4c8a1fe] {
|
||||||
|
aspect-ratio: 1 / 1;
|
||||||
|
overflow: hidden;
|
||||||
|
display: flex;
|
||||||
|
justify-content: center;
|
||||||
|
align-items: center;
|
||||||
|
width: 100%;
|
||||||
|
height: 100%;
|
||||||
|
}
|
||||||
|
.task-result-preview i[data-v-d4c8a1fe],
|
||||||
|
.task-result-preview span[data-v-d4c8a1fe] {
|
||||||
|
font-size: 2rem;
|
||||||
|
}
|
||||||
|
.task-item[data-v-d4c8a1fe] {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
border-radius: 4px;
|
||||||
|
overflow: hidden;
|
||||||
|
position: relative;
|
||||||
|
}
|
||||||
|
.task-item-details[data-v-d4c8a1fe] {
|
||||||
|
position: absolute;
|
||||||
|
bottom: 0;
|
||||||
|
padding: 0.6rem;
|
||||||
|
display: flex;
|
||||||
|
justify-content: space-between;
|
||||||
|
align-items: center;
|
||||||
|
width: 100%;
|
||||||
|
z-index: 1;
|
||||||
|
}
|
||||||
|
.task-node-link[data-v-d4c8a1fe] {
|
||||||
|
padding: 2px;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* In dark mode, transparent background color for tags is not ideal for tags that
|
||||||
|
are floating on top of images. */
|
||||||
|
.tag-wrapper[data-v-d4c8a1fe] {
|
||||||
|
background-color: var(--p-primary-contrast-color);
|
||||||
|
border-radius: 6px;
|
||||||
|
display: inline-flex;
|
||||||
|
}
|
||||||
|
.node-name-tag[data-v-d4c8a1fe] {
|
||||||
|
word-break: break-all;
|
||||||
|
}
|
||||||
|
.status-tag-group[data-v-d4c8a1fe] {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
}
|
||||||
|
.progress-preview-img[data-v-d4c8a1fe] {
|
||||||
|
width: 100%;
|
||||||
|
height: 100%;
|
||||||
|
-o-object-fit: cover;
|
||||||
|
object-fit: cover;
|
||||||
|
-o-object-position: center;
|
||||||
|
object-position: center;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* PrimeVue's galleria teleports the fullscreen gallery out of subtree so we
|
||||||
|
cannot use scoped style here. */
|
||||||
|
img.galleria-image {
|
||||||
|
max-width: 100vw;
|
||||||
|
max-height: 100vh;
|
||||||
|
-o-object-fit: contain;
|
||||||
|
object-fit: contain;
|
||||||
|
}
|
||||||
|
.p-galleria-close-button {
|
||||||
|
/* Set z-index so the close button doesn't get hidden behind the image when image is large */
|
||||||
|
z-index: 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
.comfy-vue-side-bar-container[data-v-1b0a8fe3] {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
height: 100%;
|
||||||
|
overflow: hidden;
|
||||||
|
}
|
||||||
|
.comfy-vue-side-bar-header[data-v-1b0a8fe3] {
|
||||||
|
flex-shrink: 0;
|
||||||
|
border-left: none;
|
||||||
|
border-right: none;
|
||||||
|
border-top: none;
|
||||||
|
border-radius: 0;
|
||||||
|
padding: 0.25rem 1rem;
|
||||||
|
min-height: 2.5rem;
|
||||||
|
}
|
||||||
|
.comfy-vue-side-bar-header-span[data-v-1b0a8fe3] {
|
||||||
|
font-size: small;
|
||||||
|
}
|
||||||
|
.comfy-vue-side-bar-body[data-v-1b0a8fe3] {
|
||||||
|
flex-grow: 1;
|
||||||
|
overflow: auto;
|
||||||
|
scrollbar-width: thin;
|
||||||
|
scrollbar-color: transparent transparent;
|
||||||
|
}
|
||||||
|
.comfy-vue-side-bar-body[data-v-1b0a8fe3]::-webkit-scrollbar {
|
||||||
|
width: 1px;
|
||||||
|
}
|
||||||
|
.comfy-vue-side-bar-body[data-v-1b0a8fe3]::-webkit-scrollbar-thumb {
|
||||||
|
background-color: transparent;
|
||||||
|
}
|
||||||
|
|
||||||
|
.scroll-container[data-v-08fa89b1] {
|
||||||
|
height: 100%;
|
||||||
|
overflow-y: auto;
|
||||||
|
}
|
||||||
|
.queue-grid[data-v-08fa89b1] {
|
||||||
|
display: grid;
|
||||||
|
grid-template-columns: repeat(auto-fill, minmax(200px, 1fr));
|
||||||
|
padding: 0.5rem;
|
||||||
|
gap: 0.5rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.tree-node[data-v-633e27ab] {
|
||||||
|
width: 100%;
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: space-between;
|
||||||
|
}
|
||||||
|
.leaf-count-badge[data-v-633e27ab] {
|
||||||
|
margin-left: 0.5rem;
|
||||||
|
}
|
||||||
|
.node-content[data-v-633e27ab] {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
flex-grow: 1;
|
||||||
|
}
|
||||||
|
.leaf-label[data-v-633e27ab] {
|
||||||
|
margin-left: 0.5rem;
|
||||||
|
}
|
||||||
|
[data-v-633e27ab] .editable-text span {
|
||||||
|
word-break: break-all;
|
||||||
|
}
|
||||||
|
|
||||||
|
[data-v-bd7bae90] .tree-explorer-node-label {
|
||||||
|
width: 100%;
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
margin-left: var(--p-tree-node-gap);
|
||||||
|
flex-grow: 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* The following styles are necessary to avoid layout shift when dragging nodes over folders.
|
||||||
|
* By setting the position to relative on the parent and using an absolutely positioned pseudo-element,
|
||||||
|
* we can create a visual indicator for the drop target without affecting the layout of other elements.
|
||||||
|
*/
|
||||||
|
[data-v-bd7bae90] .p-tree-node-content:has(.tree-folder) {
|
||||||
|
position: relative;
|
||||||
|
}
|
||||||
|
[data-v-bd7bae90] .p-tree-node-content:has(.tree-folder.can-drop)::after {
|
||||||
|
content: '';
|
||||||
|
position: absolute;
|
||||||
|
top: 0;
|
||||||
|
left: 0;
|
||||||
|
right: 0;
|
||||||
|
bottom: 0;
|
||||||
|
border: 1px solid var(--p-content-color);
|
||||||
|
pointer-events: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
.node-lib-node-container[data-v-90dfee08] {
|
||||||
|
height: 100%;
|
||||||
|
width: 100%
|
||||||
|
}
|
||||||
|
|
||||||
|
.p-selectbutton .p-button[data-v-91077f2a] {
|
||||||
|
padding: 0.5rem;
|
||||||
|
}
|
||||||
|
.p-selectbutton .p-button .pi[data-v-91077f2a] {
|
||||||
|
font-size: 1.5rem;
|
||||||
|
}
|
||||||
|
.field[data-v-91077f2a] {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: 0.5rem;
|
||||||
|
}
|
||||||
|
.color-picker-container[data-v-91077f2a] {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
gap: 0.5rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.node-lib-filter-popup {
|
||||||
|
margin-left: -13px;
|
||||||
|
}
|
||||||
|
|
||||||
|
[data-v-f6a7371a] .comfy-vue-side-bar-body {
|
||||||
|
background: var(--p-tree-background);
|
||||||
|
}
|
||||||
|
[data-v-f6a7371a] .node-lib-bookmark-tree-explorer {
|
||||||
|
padding-bottom: 2px;
|
||||||
|
}
|
||||||
|
[data-v-f6a7371a] .p-divider {
|
||||||
|
margin: var(--comfy-tree-explorer-item-padding) 0px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.model_preview[data-v-32e6c4d9] {
|
||||||
|
background-color: var(--comfy-menu-bg);
|
||||||
|
font-family: 'Open Sans', sans-serif;
|
||||||
|
color: var(--descrip-text);
|
||||||
|
border: 1px solid var(--descrip-text);
|
||||||
|
min-width: 300px;
|
||||||
|
max-width: 500px;
|
||||||
|
width: -moz-fit-content;
|
||||||
|
width: fit-content;
|
||||||
|
height: -moz-fit-content;
|
||||||
|
height: fit-content;
|
||||||
|
z-index: 9999;
|
||||||
|
border-radius: 12px;
|
||||||
|
overflow: hidden;
|
||||||
|
font-size: 12px;
|
||||||
|
padding: 10px;
|
||||||
|
}
|
||||||
|
.model_preview_image[data-v-32e6c4d9] {
|
||||||
|
margin: auto;
|
||||||
|
width: -moz-fit-content;
|
||||||
|
width: fit-content;
|
||||||
|
}
|
||||||
|
.model_preview_image img[data-v-32e6c4d9] {
|
||||||
|
max-width: 100%;
|
||||||
|
max-height: 150px;
|
||||||
|
-o-object-fit: contain;
|
||||||
|
object-fit: contain;
|
||||||
|
}
|
||||||
|
.model_preview_title[data-v-32e6c4d9] {
|
||||||
|
font-weight: bold;
|
||||||
|
text-align: center;
|
||||||
|
font-size: 14px;
|
||||||
|
}
|
||||||
|
.model_preview_top_container[data-v-32e6c4d9] {
|
||||||
|
text-align: center;
|
||||||
|
line-height: 0.5;
|
||||||
|
}
|
||||||
|
.model_preview_filename[data-v-32e6c4d9],
|
||||||
|
.model_preview_author[data-v-32e6c4d9],
|
||||||
|
.model_preview_architecture[data-v-32e6c4d9] {
|
||||||
|
display: inline-block;
|
||||||
|
text-align: center;
|
||||||
|
margin: 5px;
|
||||||
|
font-size: 10px;
|
||||||
|
}
|
||||||
|
.model_preview_prefix[data-v-32e6c4d9] {
|
||||||
|
font-weight: bold;
|
||||||
|
}
|
||||||
|
|
||||||
|
.model-lib-model-icon-container[data-v-70b69131] {
|
||||||
|
display: inline-block;
|
||||||
|
position: relative;
|
||||||
|
left: 0;
|
||||||
|
height: 1.5rem;
|
||||||
|
vertical-align: top;
|
||||||
|
width: 0px;
|
||||||
|
}
|
||||||
|
.model-lib-model-icon[data-v-70b69131] {
|
||||||
|
background-size: cover;
|
||||||
|
background-position: center;
|
||||||
|
display: inline-block;
|
||||||
|
position: relative;
|
||||||
|
left: -2.5rem;
|
||||||
|
height: 2rem;
|
||||||
|
width: 2rem;
|
||||||
|
vertical-align: top;
|
||||||
|
}
|
||||||
|
|
||||||
|
.pi-fake-spacer {
|
||||||
|
height: 1px;
|
||||||
|
width: 16px;
|
||||||
|
}
|
||||||
|
|
||||||
|
[data-v-74b01bce] .comfy-vue-side-bar-body {
|
||||||
|
background: var(--p-tree-background);
|
||||||
|
}
|
||||||
|
|
||||||
|
[data-v-d2d58252] .comfy-vue-side-bar-body {
|
||||||
|
background: var(--p-tree-background);
|
||||||
|
}
|
||||||
|
|
||||||
|
[data-v-84e785b8] .p-togglebutton::before {
|
||||||
|
display: none
|
||||||
|
}
|
||||||
|
[data-v-84e785b8] .p-togglebutton {
|
||||||
|
position: relative;
|
||||||
|
flex-shrink: 0;
|
||||||
|
border-radius: 0px;
|
||||||
|
background-color: transparent;
|
||||||
|
padding-left: 0.5rem;
|
||||||
|
padding-right: 0.5rem
|
||||||
|
}
|
||||||
|
[data-v-84e785b8] .p-togglebutton.p-togglebutton-checked {
|
||||||
|
border-bottom-width: 2px;
|
||||||
|
border-bottom-color: var(--p-button-text-primary-color)
|
||||||
|
}
|
||||||
|
[data-v-84e785b8] .p-togglebutton-checked .close-button,[data-v-84e785b8] .p-togglebutton:hover .close-button {
|
||||||
|
visibility: visible
|
||||||
|
}
|
||||||
|
.status-indicator[data-v-84e785b8] {
|
||||||
|
position: absolute;
|
||||||
|
font-weight: 700;
|
||||||
|
font-size: 1.5rem;
|
||||||
|
top: 50%;
|
||||||
|
left: 50%;
|
||||||
|
transform: translate(-50%, -50%)
|
||||||
|
}
|
||||||
|
[data-v-84e785b8] .p-togglebutton:hover .status-indicator {
|
||||||
|
display: none
|
||||||
|
}
|
||||||
|
[data-v-84e785b8] .p-togglebutton .close-button {
|
||||||
|
visibility: hidden
|
||||||
|
}
|
||||||
|
|
||||||
|
.top-menubar[data-v-2ec1b620] .p-menubar-item-link svg {
|
||||||
|
display: none;
|
||||||
|
}
|
||||||
|
[data-v-2ec1b620] .p-menubar-submenu.dropdown-direction-up {
|
||||||
|
top: auto;
|
||||||
|
bottom: 100%;
|
||||||
|
flex-direction: column-reverse;
|
||||||
|
}
|
||||||
|
.keybinding-tag[data-v-2ec1b620] {
|
||||||
|
background: var(--p-content-hover-background);
|
||||||
|
border-color: var(--p-content-border-color);
|
||||||
|
border-style: solid;
|
||||||
|
}
|
||||||
|
|
||||||
|
[data-v-713442be] .p-inputtext {
|
||||||
|
border-top-left-radius: 0;
|
||||||
|
border-bottom-left-radius: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.comfyui-queue-button[data-v-fcd3efcd] .p-splitbutton-dropdown {
|
||||||
|
border-top-right-radius: 0;
|
||||||
|
border-bottom-right-radius: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.actionbar[data-v-bc6c78dd] {
|
||||||
|
pointer-events: all;
|
||||||
|
position: fixed;
|
||||||
|
z-index: 1000;
|
||||||
|
}
|
||||||
|
.actionbar.is-docked[data-v-bc6c78dd] {
|
||||||
|
position: static;
|
||||||
|
border-style: none;
|
||||||
|
background-color: transparent;
|
||||||
|
padding: 0px;
|
||||||
|
}
|
||||||
|
.actionbar.is-dragging[data-v-bc6c78dd] {
|
||||||
|
-webkit-user-select: none;
|
||||||
|
-moz-user-select: none;
|
||||||
|
user-select: none;
|
||||||
|
}
|
||||||
|
[data-v-bc6c78dd] .p-panel-content {
|
||||||
|
padding: 0.25rem;
|
||||||
|
}
|
||||||
|
[data-v-bc6c78dd] .p-panel-header {
|
||||||
|
display: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
.comfyui-menu[data-v-b13fdc92] {
|
||||||
|
width: 100vw;
|
||||||
|
background: var(--comfy-menu-bg);
|
||||||
|
color: var(--fg-color);
|
||||||
|
font-family: Arial, Helvetica, sans-serif;
|
||||||
|
font-size: 0.8em;
|
||||||
|
box-sizing: border-box;
|
||||||
|
z-index: 1000;
|
||||||
|
order: 0;
|
||||||
|
grid-column: 1/-1;
|
||||||
|
max-height: 90vh;
|
||||||
|
}
|
||||||
|
.comfyui-menu.dropzone[data-v-b13fdc92] {
|
||||||
|
background: var(--p-highlight-background);
|
||||||
|
}
|
||||||
|
.comfyui-menu.dropzone-active[data-v-b13fdc92] {
|
||||||
|
background: var(--p-highlight-background-focus);
|
||||||
|
}
|
||||||
|
.comfyui-logo[data-v-b13fdc92] {
|
||||||
|
font-size: 1.2em;
|
||||||
|
-webkit-user-select: none;
|
||||||
|
-moz-user-select: none;
|
||||||
|
user-select: none;
|
||||||
|
cursor: default;
|
||||||
|
}
|
||||||
17465
web/assets/GraphView-CVV2XJjS.js
generated
vendored
Normal file
17465
web/assets/GraphView-CVV2XJjS.js
generated
vendored
Normal file
File diff suppressed because one or more lines are too long
1
web/assets/GraphView-CVV2XJjS.js.map
generated
vendored
Normal file
1
web/assets/GraphView-CVV2XJjS.js.map
generated
vendored
Normal file
File diff suppressed because one or more lines are too long
865
web/assets/colorPalette-D5oi2-2V.js
generated
vendored
Normal file
865
web/assets/colorPalette-D5oi2-2V.js
generated
vendored
Normal file
@@ -0,0 +1,865 @@
|
|||||||
|
var __defProp = Object.defineProperty;
|
||||||
|
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
|
||||||
|
import { k as app, aP as LGraphCanvas, bO as useToastStore, ca as $el, z as LiteGraph } from "./index-DGAbdBYF.js";
|
||||||
|
const colorPalettes = {
|
||||||
|
dark: {
|
||||||
|
id: "dark",
|
||||||
|
name: "Dark (Default)",
|
||||||
|
colors: {
|
||||||
|
node_slot: {
|
||||||
|
CLIP: "#FFD500",
|
||||||
|
// bright yellow
|
||||||
|
CLIP_VISION: "#A8DADC",
|
||||||
|
// light blue-gray
|
||||||
|
CLIP_VISION_OUTPUT: "#ad7452",
|
||||||
|
// rusty brown-orange
|
||||||
|
CONDITIONING: "#FFA931",
|
||||||
|
// vibrant orange-yellow
|
||||||
|
CONTROL_NET: "#6EE7B7",
|
||||||
|
// soft mint green
|
||||||
|
IMAGE: "#64B5F6",
|
||||||
|
// bright sky blue
|
||||||
|
LATENT: "#FF9CF9",
|
||||||
|
// light pink-purple
|
||||||
|
MASK: "#81C784",
|
||||||
|
// muted green
|
||||||
|
MODEL: "#B39DDB",
|
||||||
|
// light lavender-purple
|
||||||
|
STYLE_MODEL: "#C2FFAE",
|
||||||
|
// light green-yellow
|
||||||
|
VAE: "#FF6E6E",
|
||||||
|
// bright red
|
||||||
|
NOISE: "#B0B0B0",
|
||||||
|
// gray
|
||||||
|
GUIDER: "#66FFFF",
|
||||||
|
// cyan
|
||||||
|
SAMPLER: "#ECB4B4",
|
||||||
|
// very soft red
|
||||||
|
SIGMAS: "#CDFFCD",
|
||||||
|
// soft lime green
|
||||||
|
TAESD: "#DCC274"
|
||||||
|
// cheesecake
|
||||||
|
},
|
||||||
|
litegraph_base: {
|
||||||
|
BACKGROUND_IMAGE: "",
|
||||||
|
CLEAR_BACKGROUND_COLOR: "#222",
|
||||||
|
NODE_TITLE_COLOR: "#999",
|
||||||
|
NODE_SELECTED_TITLE_COLOR: "#FFF",
|
||||||
|
NODE_TEXT_SIZE: 14,
|
||||||
|
NODE_TEXT_COLOR: "#AAA",
|
||||||
|
NODE_SUBTEXT_SIZE: 12,
|
||||||
|
NODE_DEFAULT_COLOR: "#333",
|
||||||
|
NODE_DEFAULT_BGCOLOR: "#353535",
|
||||||
|
NODE_DEFAULT_BOXCOLOR: "#666",
|
||||||
|
NODE_DEFAULT_SHAPE: "box",
|
||||||
|
NODE_BOX_OUTLINE_COLOR: "#FFF",
|
||||||
|
NODE_BYPASS_BGCOLOR: "#FF00FF",
|
||||||
|
DEFAULT_SHADOW_COLOR: "rgba(0,0,0,0.5)",
|
||||||
|
DEFAULT_GROUP_FONT: 24,
|
||||||
|
WIDGET_BGCOLOR: "#222",
|
||||||
|
WIDGET_OUTLINE_COLOR: "#666",
|
||||||
|
WIDGET_TEXT_COLOR: "#DDD",
|
||||||
|
WIDGET_SECONDARY_TEXT_COLOR: "#999",
|
||||||
|
LINK_COLOR: "#9A9",
|
||||||
|
EVENT_LINK_COLOR: "#A86",
|
||||||
|
CONNECTING_LINK_COLOR: "#AFA",
|
||||||
|
BADGE_FG_COLOR: "#FFF",
|
||||||
|
BADGE_BG_COLOR: "#0F1F0F"
|
||||||
|
},
|
||||||
|
comfy_base: {
|
||||||
|
"fg-color": "#fff",
|
||||||
|
"bg-color": "#202020",
|
||||||
|
"comfy-menu-bg": "#353535",
|
||||||
|
"comfy-input-bg": "#222",
|
||||||
|
"input-text": "#ddd",
|
||||||
|
"descrip-text": "#999",
|
||||||
|
"drag-text": "#ccc",
|
||||||
|
"error-text": "#ff4444",
|
||||||
|
"border-color": "#4e4e4e",
|
||||||
|
"tr-even-bg-color": "#222",
|
||||||
|
"tr-odd-bg-color": "#353535",
|
||||||
|
"content-bg": "#4e4e4e",
|
||||||
|
"content-fg": "#fff",
|
||||||
|
"content-hover-bg": "#222",
|
||||||
|
"content-hover-fg": "#fff"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
light: {
|
||||||
|
id: "light",
|
||||||
|
name: "Light",
|
||||||
|
colors: {
|
||||||
|
node_slot: {
|
||||||
|
CLIP: "#FFA726",
|
||||||
|
// orange
|
||||||
|
CLIP_VISION: "#5C6BC0",
|
||||||
|
// indigo
|
||||||
|
CLIP_VISION_OUTPUT: "#8D6E63",
|
||||||
|
// brown
|
||||||
|
CONDITIONING: "#EF5350",
|
||||||
|
// red
|
||||||
|
CONTROL_NET: "#66BB6A",
|
||||||
|
// green
|
||||||
|
IMAGE: "#42A5F5",
|
||||||
|
// blue
|
||||||
|
LATENT: "#AB47BC",
|
||||||
|
// purple
|
||||||
|
MASK: "#9CCC65",
|
||||||
|
// light green
|
||||||
|
MODEL: "#7E57C2",
|
||||||
|
// deep purple
|
||||||
|
STYLE_MODEL: "#D4E157",
|
||||||
|
// lime
|
||||||
|
VAE: "#FF7043"
|
||||||
|
// deep orange
|
||||||
|
},
|
||||||
|
litegraph_base: {
|
||||||
|
BACKGROUND_IMAGE: "",
|
||||||
|
CLEAR_BACKGROUND_COLOR: "lightgray",
|
||||||
|
NODE_TITLE_COLOR: "#222",
|
||||||
|
NODE_SELECTED_TITLE_COLOR: "#000",
|
||||||
|
NODE_TEXT_SIZE: 14,
|
||||||
|
NODE_TEXT_COLOR: "#444",
|
||||||
|
NODE_SUBTEXT_SIZE: 12,
|
||||||
|
NODE_DEFAULT_COLOR: "#F7F7F7",
|
||||||
|
NODE_DEFAULT_BGCOLOR: "#F5F5F5",
|
||||||
|
NODE_DEFAULT_BOXCOLOR: "#CCC",
|
||||||
|
NODE_DEFAULT_SHAPE: "box",
|
||||||
|
NODE_BOX_OUTLINE_COLOR: "#000",
|
||||||
|
NODE_BYPASS_BGCOLOR: "#FF00FF",
|
||||||
|
DEFAULT_SHADOW_COLOR: "rgba(0,0,0,0.1)",
|
||||||
|
DEFAULT_GROUP_FONT: 24,
|
||||||
|
WIDGET_BGCOLOR: "#D4D4D4",
|
||||||
|
WIDGET_OUTLINE_COLOR: "#999",
|
||||||
|
WIDGET_TEXT_COLOR: "#222",
|
||||||
|
WIDGET_SECONDARY_TEXT_COLOR: "#555",
|
||||||
|
LINK_COLOR: "#4CAF50",
|
||||||
|
EVENT_LINK_COLOR: "#FF9800",
|
||||||
|
CONNECTING_LINK_COLOR: "#2196F3",
|
||||||
|
BADGE_FG_COLOR: "#000",
|
||||||
|
BADGE_BG_COLOR: "#FFF"
|
||||||
|
},
|
||||||
|
comfy_base: {
|
||||||
|
"fg-color": "#222",
|
||||||
|
"bg-color": "#DDD",
|
||||||
|
"comfy-menu-bg": "#F5F5F5",
|
||||||
|
"comfy-input-bg": "#C9C9C9",
|
||||||
|
"input-text": "#222",
|
||||||
|
"descrip-text": "#444",
|
||||||
|
"drag-text": "#555",
|
||||||
|
"error-text": "#F44336",
|
||||||
|
"border-color": "#888",
|
||||||
|
"tr-even-bg-color": "#f9f9f9",
|
||||||
|
"tr-odd-bg-color": "#fff",
|
||||||
|
"content-bg": "#e0e0e0",
|
||||||
|
"content-fg": "#222",
|
||||||
|
"content-hover-bg": "#adadad",
|
||||||
|
"content-hover-fg": "#222"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
solarized: {
|
||||||
|
id: "solarized",
|
||||||
|
name: "Solarized",
|
||||||
|
colors: {
|
||||||
|
node_slot: {
|
||||||
|
CLIP: "#2AB7CA",
|
||||||
|
// light blue
|
||||||
|
CLIP_VISION: "#6c71c4",
|
||||||
|
// blue violet
|
||||||
|
CLIP_VISION_OUTPUT: "#859900",
|
||||||
|
// olive green
|
||||||
|
CONDITIONING: "#d33682",
|
||||||
|
// magenta
|
||||||
|
CONTROL_NET: "#d1ffd7",
|
||||||
|
// light mint green
|
||||||
|
IMAGE: "#5940bb",
|
||||||
|
// deep blue violet
|
||||||
|
LATENT: "#268bd2",
|
||||||
|
// blue
|
||||||
|
MASK: "#CCC9E7",
|
||||||
|
// light purple-gray
|
||||||
|
MODEL: "#dc322f",
|
||||||
|
// red
|
||||||
|
STYLE_MODEL: "#1a998a",
|
||||||
|
// teal
|
||||||
|
UPSCALE_MODEL: "#054A29",
|
||||||
|
// dark green
|
||||||
|
VAE: "#facfad"
|
||||||
|
// light pink-orange
|
||||||
|
},
|
||||||
|
litegraph_base: {
|
||||||
|
NODE_TITLE_COLOR: "#fdf6e3",
|
||||||
|
// Base3
|
||||||
|
NODE_SELECTED_TITLE_COLOR: "#A9D400",
|
||||||
|
NODE_TEXT_SIZE: 14,
|
||||||
|
NODE_TEXT_COLOR: "#657b83",
|
||||||
|
// Base00
|
||||||
|
NODE_SUBTEXT_SIZE: 12,
|
||||||
|
NODE_DEFAULT_COLOR: "#094656",
|
||||||
|
NODE_DEFAULT_BGCOLOR: "#073642",
|
||||||
|
// Base02
|
||||||
|
NODE_DEFAULT_BOXCOLOR: "#839496",
|
||||||
|
// Base0
|
||||||
|
NODE_DEFAULT_SHAPE: "box",
|
||||||
|
NODE_BOX_OUTLINE_COLOR: "#fdf6e3",
|
||||||
|
// Base3
|
||||||
|
NODE_BYPASS_BGCOLOR: "#FF00FF",
|
||||||
|
DEFAULT_SHADOW_COLOR: "rgba(0,0,0,0.5)",
|
||||||
|
DEFAULT_GROUP_FONT: 24,
|
||||||
|
WIDGET_BGCOLOR: "#002b36",
|
||||||
|
// Base03
|
||||||
|
WIDGET_OUTLINE_COLOR: "#839496",
|
||||||
|
// Base0
|
||||||
|
WIDGET_TEXT_COLOR: "#fdf6e3",
|
||||||
|
// Base3
|
||||||
|
WIDGET_SECONDARY_TEXT_COLOR: "#93a1a1",
|
||||||
|
// Base1
|
||||||
|
LINK_COLOR: "#2aa198",
|
||||||
|
// Solarized Cyan
|
||||||
|
EVENT_LINK_COLOR: "#268bd2",
|
||||||
|
// Solarized Blue
|
||||||
|
CONNECTING_LINK_COLOR: "#859900"
|
||||||
|
// Solarized Green
|
||||||
|
},
|
||||||
|
comfy_base: {
|
||||||
|
"fg-color": "#fdf6e3",
|
||||||
|
// Base3
|
||||||
|
"bg-color": "#002b36",
|
||||||
|
// Base03
|
||||||
|
"comfy-menu-bg": "#073642",
|
||||||
|
// Base02
|
||||||
|
"comfy-input-bg": "#002b36",
|
||||||
|
// Base03
|
||||||
|
"input-text": "#93a1a1",
|
||||||
|
// Base1
|
||||||
|
"descrip-text": "#586e75",
|
||||||
|
// Base01
|
||||||
|
"drag-text": "#839496",
|
||||||
|
// Base0
|
||||||
|
"error-text": "#dc322f",
|
||||||
|
// Solarized Red
|
||||||
|
"border-color": "#657b83",
|
||||||
|
// Base00
|
||||||
|
"tr-even-bg-color": "#002b36",
|
||||||
|
"tr-odd-bg-color": "#073642",
|
||||||
|
"content-bg": "#657b83",
|
||||||
|
"content-fg": "#fdf6e3",
|
||||||
|
"content-hover-bg": "#002b36",
|
||||||
|
"content-hover-fg": "#fdf6e3"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
arc: {
|
||||||
|
id: "arc",
|
||||||
|
name: "Arc",
|
||||||
|
colors: {
|
||||||
|
node_slot: {
|
||||||
|
BOOLEAN: "",
|
||||||
|
CLIP: "#eacb8b",
|
||||||
|
CLIP_VISION: "#A8DADC",
|
||||||
|
CLIP_VISION_OUTPUT: "#ad7452",
|
||||||
|
CONDITIONING: "#cf876f",
|
||||||
|
CONTROL_NET: "#00d78d",
|
||||||
|
CONTROL_NET_WEIGHTS: "",
|
||||||
|
FLOAT: "",
|
||||||
|
GLIGEN: "",
|
||||||
|
IMAGE: "#80a1c0",
|
||||||
|
IMAGEUPLOAD: "",
|
||||||
|
INT: "",
|
||||||
|
LATENT: "#b38ead",
|
||||||
|
LATENT_KEYFRAME: "",
|
||||||
|
MASK: "#a3bd8d",
|
||||||
|
MODEL: "#8978a7",
|
||||||
|
SAMPLER: "",
|
||||||
|
SIGMAS: "",
|
||||||
|
STRING: "",
|
||||||
|
STYLE_MODEL: "#C2FFAE",
|
||||||
|
T2I_ADAPTER_WEIGHTS: "",
|
||||||
|
TAESD: "#DCC274",
|
||||||
|
TIMESTEP_KEYFRAME: "",
|
||||||
|
UPSCALE_MODEL: "",
|
||||||
|
VAE: "#be616b"
|
||||||
|
},
|
||||||
|
litegraph_base: {
|
||||||
|
BACKGROUND_IMAGE: "",
|
||||||
|
CLEAR_BACKGROUND_COLOR: "#2b2f38",
|
||||||
|
NODE_TITLE_COLOR: "#b2b7bd",
|
||||||
|
NODE_SELECTED_TITLE_COLOR: "#FFF",
|
||||||
|
NODE_TEXT_SIZE: 14,
|
||||||
|
NODE_TEXT_COLOR: "#AAA",
|
||||||
|
NODE_SUBTEXT_SIZE: 12,
|
||||||
|
NODE_DEFAULT_COLOR: "#2b2f38",
|
||||||
|
NODE_DEFAULT_BGCOLOR: "#242730",
|
||||||
|
NODE_DEFAULT_BOXCOLOR: "#6e7581",
|
||||||
|
NODE_DEFAULT_SHAPE: "box",
|
||||||
|
NODE_BOX_OUTLINE_COLOR: "#FFF",
|
||||||
|
NODE_BYPASS_BGCOLOR: "#FF00FF",
|
||||||
|
DEFAULT_SHADOW_COLOR: "rgba(0,0,0,0.5)",
|
||||||
|
DEFAULT_GROUP_FONT: 22,
|
||||||
|
WIDGET_BGCOLOR: "#2b2f38",
|
||||||
|
WIDGET_OUTLINE_COLOR: "#6e7581",
|
||||||
|
WIDGET_TEXT_COLOR: "#DDD",
|
||||||
|
WIDGET_SECONDARY_TEXT_COLOR: "#b2b7bd",
|
||||||
|
LINK_COLOR: "#9A9",
|
||||||
|
EVENT_LINK_COLOR: "#A86",
|
||||||
|
CONNECTING_LINK_COLOR: "#AFA"
|
||||||
|
},
|
||||||
|
comfy_base: {
|
||||||
|
"fg-color": "#fff",
|
||||||
|
"bg-color": "#2b2f38",
|
||||||
|
"comfy-menu-bg": "#242730",
|
||||||
|
"comfy-input-bg": "#2b2f38",
|
||||||
|
"input-text": "#ddd",
|
||||||
|
"descrip-text": "#b2b7bd",
|
||||||
|
"drag-text": "#ccc",
|
||||||
|
"error-text": "#ff4444",
|
||||||
|
"border-color": "#6e7581",
|
||||||
|
"tr-even-bg-color": "#2b2f38",
|
||||||
|
"tr-odd-bg-color": "#242730",
|
||||||
|
"content-bg": "#6e7581",
|
||||||
|
"content-fg": "#fff",
|
||||||
|
"content-hover-bg": "#2b2f38",
|
||||||
|
"content-hover-fg": "#fff"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
nord: {
|
||||||
|
id: "nord",
|
||||||
|
name: "Nord",
|
||||||
|
colors: {
|
||||||
|
node_slot: {
|
||||||
|
BOOLEAN: "",
|
||||||
|
CLIP: "#eacb8b",
|
||||||
|
CLIP_VISION: "#A8DADC",
|
||||||
|
CLIP_VISION_OUTPUT: "#ad7452",
|
||||||
|
CONDITIONING: "#cf876f",
|
||||||
|
CONTROL_NET: "#00d78d",
|
||||||
|
CONTROL_NET_WEIGHTS: "",
|
||||||
|
FLOAT: "",
|
||||||
|
GLIGEN: "",
|
||||||
|
IMAGE: "#80a1c0",
|
||||||
|
IMAGEUPLOAD: "",
|
||||||
|
INT: "",
|
||||||
|
LATENT: "#b38ead",
|
||||||
|
LATENT_KEYFRAME: "",
|
||||||
|
MASK: "#a3bd8d",
|
||||||
|
MODEL: "#8978a7",
|
||||||
|
SAMPLER: "",
|
||||||
|
SIGMAS: "",
|
||||||
|
STRING: "",
|
||||||
|
STYLE_MODEL: "#C2FFAE",
|
||||||
|
T2I_ADAPTER_WEIGHTS: "",
|
||||||
|
TAESD: "#DCC274",
|
||||||
|
TIMESTEP_KEYFRAME: "",
|
||||||
|
UPSCALE_MODEL: "",
|
||||||
|
VAE: "#be616b"
|
||||||
|
},
|
||||||
|
litegraph_base: {
|
||||||
|
BACKGROUND_IMAGE: "",
|
||||||
|
CLEAR_BACKGROUND_COLOR: "#212732",
|
||||||
|
NODE_TITLE_COLOR: "#999",
|
||||||
|
NODE_SELECTED_TITLE_COLOR: "#e5eaf0",
|
||||||
|
NODE_TEXT_SIZE: 14,
|
||||||
|
NODE_TEXT_COLOR: "#bcc2c8",
|
||||||
|
NODE_SUBTEXT_SIZE: 12,
|
||||||
|
NODE_DEFAULT_COLOR: "#2e3440",
|
||||||
|
NODE_DEFAULT_BGCOLOR: "#161b22",
|
||||||
|
NODE_DEFAULT_BOXCOLOR: "#545d70",
|
||||||
|
NODE_DEFAULT_SHAPE: "box",
|
||||||
|
NODE_BOX_OUTLINE_COLOR: "#e5eaf0",
|
||||||
|
NODE_BYPASS_BGCOLOR: "#FF00FF",
|
||||||
|
DEFAULT_SHADOW_COLOR: "rgba(0,0,0,0.5)",
|
||||||
|
DEFAULT_GROUP_FONT: 24,
|
||||||
|
WIDGET_BGCOLOR: "#2e3440",
|
||||||
|
WIDGET_OUTLINE_COLOR: "#545d70",
|
||||||
|
WIDGET_TEXT_COLOR: "#bcc2c8",
|
||||||
|
WIDGET_SECONDARY_TEXT_COLOR: "#999",
|
||||||
|
LINK_COLOR: "#9A9",
|
||||||
|
EVENT_LINK_COLOR: "#A86",
|
||||||
|
CONNECTING_LINK_COLOR: "#AFA"
|
||||||
|
},
|
||||||
|
comfy_base: {
|
||||||
|
"fg-color": "#e5eaf0",
|
||||||
|
"bg-color": "#2e3440",
|
||||||
|
"comfy-menu-bg": "#161b22",
|
||||||
|
"comfy-input-bg": "#2e3440",
|
||||||
|
"input-text": "#bcc2c8",
|
||||||
|
"descrip-text": "#999",
|
||||||
|
"drag-text": "#ccc",
|
||||||
|
"error-text": "#ff4444",
|
||||||
|
"border-color": "#545d70",
|
||||||
|
"tr-even-bg-color": "#2e3440",
|
||||||
|
"tr-odd-bg-color": "#161b22",
|
||||||
|
"content-bg": "#545d70",
|
||||||
|
"content-fg": "#e5eaf0",
|
||||||
|
"content-hover-bg": "#2e3440",
|
||||||
|
"content-hover-fg": "#e5eaf0"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
github: {
|
||||||
|
id: "github",
|
||||||
|
name: "Github",
|
||||||
|
colors: {
|
||||||
|
node_slot: {
|
||||||
|
BOOLEAN: "",
|
||||||
|
CLIP: "#eacb8b",
|
||||||
|
CLIP_VISION: "#A8DADC",
|
||||||
|
CLIP_VISION_OUTPUT: "#ad7452",
|
||||||
|
CONDITIONING: "#cf876f",
|
||||||
|
CONTROL_NET: "#00d78d",
|
||||||
|
CONTROL_NET_WEIGHTS: "",
|
||||||
|
FLOAT: "",
|
||||||
|
GLIGEN: "",
|
||||||
|
IMAGE: "#80a1c0",
|
||||||
|
IMAGEUPLOAD: "",
|
||||||
|
INT: "",
|
||||||
|
LATENT: "#b38ead",
|
||||||
|
LATENT_KEYFRAME: "",
|
||||||
|
MASK: "#a3bd8d",
|
||||||
|
MODEL: "#8978a7",
|
||||||
|
SAMPLER: "",
|
||||||
|
SIGMAS: "",
|
||||||
|
STRING: "",
|
||||||
|
STYLE_MODEL: "#C2FFAE",
|
||||||
|
T2I_ADAPTER_WEIGHTS: "",
|
||||||
|
TAESD: "#DCC274",
|
||||||
|
TIMESTEP_KEYFRAME: "",
|
||||||
|
UPSCALE_MODEL: "",
|
||||||
|
VAE: "#be616b"
|
||||||
|
},
|
||||||
|
litegraph_base: {
|
||||||
|
BACKGROUND_IMAGE: "",
|
||||||
|
CLEAR_BACKGROUND_COLOR: "#040506",
|
||||||
|
NODE_TITLE_COLOR: "#999",
|
||||||
|
NODE_SELECTED_TITLE_COLOR: "#e5eaf0",
|
||||||
|
NODE_TEXT_SIZE: 14,
|
||||||
|
NODE_TEXT_COLOR: "#bcc2c8",
|
||||||
|
NODE_SUBTEXT_SIZE: 12,
|
||||||
|
NODE_DEFAULT_COLOR: "#161b22",
|
||||||
|
NODE_DEFAULT_BGCOLOR: "#13171d",
|
||||||
|
NODE_DEFAULT_BOXCOLOR: "#30363d",
|
||||||
|
NODE_DEFAULT_SHAPE: "box",
|
||||||
|
NODE_BOX_OUTLINE_COLOR: "#e5eaf0",
|
||||||
|
NODE_BYPASS_BGCOLOR: "#FF00FF",
|
||||||
|
DEFAULT_SHADOW_COLOR: "rgba(0,0,0,0.5)",
|
||||||
|
DEFAULT_GROUP_FONT: 24,
|
||||||
|
WIDGET_BGCOLOR: "#161b22",
|
||||||
|
WIDGET_OUTLINE_COLOR: "#30363d",
|
||||||
|
WIDGET_TEXT_COLOR: "#bcc2c8",
|
||||||
|
WIDGET_SECONDARY_TEXT_COLOR: "#999",
|
||||||
|
LINK_COLOR: "#9A9",
|
||||||
|
EVENT_LINK_COLOR: "#A86",
|
||||||
|
CONNECTING_LINK_COLOR: "#AFA"
|
||||||
|
},
|
||||||
|
comfy_base: {
|
||||||
|
"fg-color": "#e5eaf0",
|
||||||
|
"bg-color": "#161b22",
|
||||||
|
"comfy-menu-bg": "#13171d",
|
||||||
|
"comfy-input-bg": "#161b22",
|
||||||
|
"input-text": "#bcc2c8",
|
||||||
|
"descrip-text": "#999",
|
||||||
|
"drag-text": "#ccc",
|
||||||
|
"error-text": "#ff4444",
|
||||||
|
"border-color": "#30363d",
|
||||||
|
"tr-even-bg-color": "#161b22",
|
||||||
|
"tr-odd-bg-color": "#13171d",
|
||||||
|
"content-bg": "#30363d",
|
||||||
|
"content-fg": "#e5eaf0",
|
||||||
|
"content-hover-bg": "#161b22",
|
||||||
|
"content-hover-fg": "#e5eaf0"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
const id = "Comfy.ColorPalette";
|
||||||
|
const idCustomColorPalettes = "Comfy.CustomColorPalettes";
|
||||||
|
const defaultColorPaletteId = "dark";
|
||||||
|
const els = {
|
||||||
|
select: null
|
||||||
|
};
|
||||||
|
const getCustomColorPalettes = /* @__PURE__ */ __name(() => {
|
||||||
|
return app.ui.settings.getSettingValue(idCustomColorPalettes, {});
|
||||||
|
}, "getCustomColorPalettes");
|
||||||
|
const setCustomColorPalettes = /* @__PURE__ */ __name((customColorPalettes) => {
|
||||||
|
return app.ui.settings.setSettingValue(
|
||||||
|
idCustomColorPalettes,
|
||||||
|
customColorPalettes
|
||||||
|
);
|
||||||
|
}, "setCustomColorPalettes");
|
||||||
|
const defaultColorPalette = colorPalettes[defaultColorPaletteId];
|
||||||
|
const getColorPalette = /* @__PURE__ */ __name((colorPaletteId) => {
|
||||||
|
if (!colorPaletteId) {
|
||||||
|
colorPaletteId = app.ui.settings.getSettingValue(id, defaultColorPaletteId);
|
||||||
|
}
|
||||||
|
if (colorPaletteId.startsWith("custom_")) {
|
||||||
|
colorPaletteId = colorPaletteId.substr(7);
|
||||||
|
let customColorPalettes = getCustomColorPalettes();
|
||||||
|
if (customColorPalettes[colorPaletteId]) {
|
||||||
|
return customColorPalettes[colorPaletteId];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return colorPalettes[colorPaletteId];
|
||||||
|
}, "getColorPalette");
|
||||||
|
const setColorPalette = /* @__PURE__ */ __name((colorPaletteId) => {
|
||||||
|
app.ui.settings.setSettingValue(id, colorPaletteId);
|
||||||
|
}, "setColorPalette");
|
||||||
|
app.registerExtension({
|
||||||
|
name: id,
|
||||||
|
init() {
|
||||||
|
LGraphCanvas.prototype.updateBackground = function(image, clearBackgroundColor) {
|
||||||
|
this._bg_img = new Image();
|
||||||
|
this._bg_img.name = image;
|
||||||
|
this._bg_img.src = image;
|
||||||
|
this._bg_img.onload = () => {
|
||||||
|
this.draw(true, true);
|
||||||
|
};
|
||||||
|
this.background_image = image;
|
||||||
|
this.clear_background = true;
|
||||||
|
this.clear_background_color = clearBackgroundColor;
|
||||||
|
this._pattern = null;
|
||||||
|
};
|
||||||
|
},
|
||||||
|
addCustomNodeDefs(node_defs) {
|
||||||
|
const sortObjectKeys = /* @__PURE__ */ __name((unordered) => {
|
||||||
|
return Object.keys(unordered).sort().reduce((obj, key) => {
|
||||||
|
obj[key] = unordered[key];
|
||||||
|
return obj;
|
||||||
|
}, {});
|
||||||
|
}, "sortObjectKeys");
|
||||||
|
function getSlotTypes() {
|
||||||
|
var types = [];
|
||||||
|
const defs = node_defs;
|
||||||
|
for (const nodeId in defs) {
|
||||||
|
const nodeData = defs[nodeId];
|
||||||
|
var inputs = nodeData["input"]["required"];
|
||||||
|
if (nodeData["input"]["optional"] !== void 0) {
|
||||||
|
inputs = Object.assign(
|
||||||
|
{},
|
||||||
|
nodeData["input"]["required"],
|
||||||
|
nodeData["input"]["optional"]
|
||||||
|
);
|
||||||
|
}
|
||||||
|
for (const inputName in inputs) {
|
||||||
|
const inputData = inputs[inputName];
|
||||||
|
const type = inputData[0];
|
||||||
|
if (!Array.isArray(type)) {
|
||||||
|
types.push(type);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (const o in nodeData["output"]) {
|
||||||
|
const output = nodeData["output"][o];
|
||||||
|
types.push(output);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return types;
|
||||||
|
}
|
||||||
|
__name(getSlotTypes, "getSlotTypes");
|
||||||
|
function completeColorPalette(colorPalette) {
|
||||||
|
var types = getSlotTypes();
|
||||||
|
for (const type of types) {
|
||||||
|
if (!colorPalette.colors.node_slot[type]) {
|
||||||
|
colorPalette.colors.node_slot[type] = "";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
colorPalette.colors.node_slot = sortObjectKeys(
|
||||||
|
colorPalette.colors.node_slot
|
||||||
|
);
|
||||||
|
return colorPalette;
|
||||||
|
}
|
||||||
|
__name(completeColorPalette, "completeColorPalette");
|
||||||
|
const getColorPaletteTemplate = /* @__PURE__ */ __name(async () => {
|
||||||
|
let colorPalette = {
|
||||||
|
id: "my_color_palette_unique_id",
|
||||||
|
name: "My Color Palette",
|
||||||
|
colors: {
|
||||||
|
node_slot: {},
|
||||||
|
litegraph_base: {},
|
||||||
|
comfy_base: {}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
const defaultColorPalette2 = colorPalettes[defaultColorPaletteId];
|
||||||
|
for (const key in defaultColorPalette2.colors.litegraph_base) {
|
||||||
|
if (!colorPalette.colors.litegraph_base[key]) {
|
||||||
|
colorPalette.colors.litegraph_base[key] = "";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (const key in defaultColorPalette2.colors.comfy_base) {
|
||||||
|
if (!colorPalette.colors.comfy_base[key]) {
|
||||||
|
colorPalette.colors.comfy_base[key] = "";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return completeColorPalette(colorPalette);
|
||||||
|
}, "getColorPaletteTemplate");
|
||||||
|
const addCustomColorPalette = /* @__PURE__ */ __name(async (colorPalette) => {
|
||||||
|
if (typeof colorPalette !== "object") {
|
||||||
|
useToastStore().addAlert("Invalid color palette.");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (!colorPalette.id) {
|
||||||
|
useToastStore().addAlert("Color palette missing id.");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (!colorPalette.name) {
|
||||||
|
useToastStore().addAlert("Color palette missing name.");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (!colorPalette.colors) {
|
||||||
|
useToastStore().addAlert("Color palette missing colors.");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (colorPalette.colors.node_slot && typeof colorPalette.colors.node_slot !== "object") {
|
||||||
|
useToastStore().addAlert("Invalid color palette colors.node_slot.");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const customColorPalettes = getCustomColorPalettes();
|
||||||
|
customColorPalettes[colorPalette.id] = colorPalette;
|
||||||
|
setCustomColorPalettes(customColorPalettes);
|
||||||
|
for (const option of els.select.childNodes) {
|
||||||
|
if (option.value === "custom_" + colorPalette.id) {
|
||||||
|
els.select.removeChild(option);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
els.select.append(
|
||||||
|
$el("option", {
|
||||||
|
textContent: colorPalette.name + " (custom)",
|
||||||
|
value: "custom_" + colorPalette.id,
|
||||||
|
selected: true
|
||||||
|
})
|
||||||
|
);
|
||||||
|
setColorPalette("custom_" + colorPalette.id);
|
||||||
|
await loadColorPalette(colorPalette);
|
||||||
|
}, "addCustomColorPalette");
|
||||||
|
const deleteCustomColorPalette = /* @__PURE__ */ __name(async (colorPaletteId) => {
|
||||||
|
const customColorPalettes = getCustomColorPalettes();
|
||||||
|
delete customColorPalettes[colorPaletteId];
|
||||||
|
setCustomColorPalettes(customColorPalettes);
|
||||||
|
for (const opt of els.select.childNodes) {
|
||||||
|
const option = opt;
|
||||||
|
if (option.value === defaultColorPaletteId) {
|
||||||
|
option.selected = true;
|
||||||
|
}
|
||||||
|
if (option.value === "custom_" + colorPaletteId) {
|
||||||
|
els.select.removeChild(option);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
setColorPalette(defaultColorPaletteId);
|
||||||
|
await loadColorPalette(getColorPalette());
|
||||||
|
}, "deleteCustomColorPalette");
|
||||||
|
const loadColorPalette = /* @__PURE__ */ __name(async (colorPalette) => {
|
||||||
|
colorPalette = await completeColorPalette(colorPalette);
|
||||||
|
if (colorPalette.colors) {
|
||||||
|
if (colorPalette.colors.node_slot) {
|
||||||
|
Object.assign(
|
||||||
|
app.canvas.default_connection_color_byType,
|
||||||
|
colorPalette.colors.node_slot
|
||||||
|
);
|
||||||
|
Object.assign(
|
||||||
|
LGraphCanvas.link_type_colors,
|
||||||
|
colorPalette.colors.node_slot
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if (colorPalette.colors.litegraph_base) {
|
||||||
|
app.canvas.node_title_color = colorPalette.colors.litegraph_base.NODE_TITLE_COLOR;
|
||||||
|
app.canvas.default_link_color = colorPalette.colors.litegraph_base.LINK_COLOR;
|
||||||
|
for (const key in colorPalette.colors.litegraph_base) {
|
||||||
|
if (colorPalette.colors.litegraph_base.hasOwnProperty(key) && LiteGraph.hasOwnProperty(key)) {
|
||||||
|
LiteGraph[key] = colorPalette.colors.litegraph_base[key];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (colorPalette.colors.comfy_base) {
|
||||||
|
const rootStyle = document.documentElement.style;
|
||||||
|
for (const key in colorPalette.colors.comfy_base) {
|
||||||
|
rootStyle.setProperty(
|
||||||
|
"--" + key,
|
||||||
|
colorPalette.colors.comfy_base[key]
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (colorPalette.colors.litegraph_base.NODE_BYPASS_BGCOLOR) {
|
||||||
|
app.bypassBgColor = colorPalette.colors.litegraph_base.NODE_BYPASS_BGCOLOR;
|
||||||
|
}
|
||||||
|
app.canvas.draw(true, true);
|
||||||
|
}
|
||||||
|
}, "loadColorPalette");
|
||||||
|
const fileInput = $el("input", {
|
||||||
|
type: "file",
|
||||||
|
accept: ".json",
|
||||||
|
style: { display: "none" },
|
||||||
|
parent: document.body,
|
||||||
|
onchange: /* @__PURE__ */ __name(() => {
|
||||||
|
const file = fileInput.files[0];
|
||||||
|
if (file.type === "application/json" || file.name.endsWith(".json")) {
|
||||||
|
const reader = new FileReader();
|
||||||
|
reader.onload = async () => {
|
||||||
|
await addCustomColorPalette(JSON.parse(reader.result));
|
||||||
|
};
|
||||||
|
reader.readAsText(file);
|
||||||
|
}
|
||||||
|
}, "onchange")
|
||||||
|
});
|
||||||
|
app.ui.settings.addSetting({
|
||||||
|
id,
|
||||||
|
category: ["Comfy", "ColorPalette"],
|
||||||
|
name: "Color Palette",
|
||||||
|
type: /* @__PURE__ */ __name((name, setter, value) => {
|
||||||
|
const options = [
|
||||||
|
...Object.values(colorPalettes).map(
|
||||||
|
(c) => $el("option", {
|
||||||
|
textContent: c.name,
|
||||||
|
value: c.id,
|
||||||
|
selected: c.id === value
|
||||||
|
})
|
||||||
|
),
|
||||||
|
...Object.values(getCustomColorPalettes()).map(
|
||||||
|
(c) => $el("option", {
|
||||||
|
textContent: `${c.name} (custom)`,
|
||||||
|
value: `custom_${c.id}`,
|
||||||
|
selected: `custom_${c.id}` === value
|
||||||
|
})
|
||||||
|
)
|
||||||
|
];
|
||||||
|
els.select = $el(
|
||||||
|
"select",
|
||||||
|
{
|
||||||
|
style: {
|
||||||
|
marginBottom: "0.15rem",
|
||||||
|
width: "100%"
|
||||||
|
},
|
||||||
|
onchange: /* @__PURE__ */ __name((e) => {
|
||||||
|
setter(e.target.value);
|
||||||
|
}, "onchange")
|
||||||
|
},
|
||||||
|
options
|
||||||
|
);
|
||||||
|
return $el("tr", [
|
||||||
|
$el("td", [
|
||||||
|
els.select,
|
||||||
|
$el(
|
||||||
|
"div",
|
||||||
|
{
|
||||||
|
style: {
|
||||||
|
display: "grid",
|
||||||
|
gap: "4px",
|
||||||
|
gridAutoFlow: "column"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
[
|
||||||
|
$el("input", {
|
||||||
|
type: "button",
|
||||||
|
value: "Export",
|
||||||
|
onclick: /* @__PURE__ */ __name(async () => {
|
||||||
|
const colorPaletteId = app.ui.settings.getSettingValue(
|
||||||
|
id,
|
||||||
|
defaultColorPaletteId
|
||||||
|
);
|
||||||
|
const colorPalette = await completeColorPalette(
|
||||||
|
getColorPalette(colorPaletteId)
|
||||||
|
);
|
||||||
|
const json = JSON.stringify(colorPalette, null, 2);
|
||||||
|
const blob = new Blob([json], { type: "application/json" });
|
||||||
|
const url = URL.createObjectURL(blob);
|
||||||
|
const a = $el("a", {
|
||||||
|
href: url,
|
||||||
|
download: colorPaletteId + ".json",
|
||||||
|
style: { display: "none" },
|
||||||
|
parent: document.body
|
||||||
|
});
|
||||||
|
a.click();
|
||||||
|
setTimeout(function() {
|
||||||
|
a.remove();
|
||||||
|
window.URL.revokeObjectURL(url);
|
||||||
|
}, 0);
|
||||||
|
}, "onclick")
|
||||||
|
}),
|
||||||
|
$el("input", {
|
||||||
|
type: "button",
|
||||||
|
value: "Import",
|
||||||
|
onclick: /* @__PURE__ */ __name(() => {
|
||||||
|
fileInput.click();
|
||||||
|
}, "onclick")
|
||||||
|
}),
|
||||||
|
$el("input", {
|
||||||
|
type: "button",
|
||||||
|
value: "Template",
|
||||||
|
onclick: /* @__PURE__ */ __name(async () => {
|
||||||
|
const colorPalette = await getColorPaletteTemplate();
|
||||||
|
const json = JSON.stringify(colorPalette, null, 2);
|
||||||
|
const blob = new Blob([json], { type: "application/json" });
|
||||||
|
const url = URL.createObjectURL(blob);
|
||||||
|
const a = $el("a", {
|
||||||
|
href: url,
|
||||||
|
download: "color_palette.json",
|
||||||
|
style: { display: "none" },
|
||||||
|
parent: document.body
|
||||||
|
});
|
||||||
|
a.click();
|
||||||
|
setTimeout(function() {
|
||||||
|
a.remove();
|
||||||
|
window.URL.revokeObjectURL(url);
|
||||||
|
}, 0);
|
||||||
|
}, "onclick")
|
||||||
|
}),
|
||||||
|
$el("input", {
|
||||||
|
type: "button",
|
||||||
|
value: "Delete",
|
||||||
|
onclick: /* @__PURE__ */ __name(async () => {
|
||||||
|
let colorPaletteId = app.ui.settings.getSettingValue(
|
||||||
|
id,
|
||||||
|
defaultColorPaletteId
|
||||||
|
);
|
||||||
|
if (colorPalettes[colorPaletteId]) {
|
||||||
|
useToastStore().addAlert(
|
||||||
|
"You cannot delete a built-in color palette."
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (colorPaletteId.startsWith("custom_")) {
|
||||||
|
colorPaletteId = colorPaletteId.substr(7);
|
||||||
|
}
|
||||||
|
await deleteCustomColorPalette(colorPaletteId);
|
||||||
|
}, "onclick")
|
||||||
|
})
|
||||||
|
]
|
||||||
|
)
|
||||||
|
])
|
||||||
|
]);
|
||||||
|
}, "type"),
|
||||||
|
defaultValue: defaultColorPaletteId,
|
||||||
|
async onChange(value) {
|
||||||
|
if (!value) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
let palette = colorPalettes[value];
|
||||||
|
if (palette) {
|
||||||
|
await loadColorPalette(palette);
|
||||||
|
} else if (value.startsWith("custom_")) {
|
||||||
|
value = value.substr(7);
|
||||||
|
let customColorPalettes = getCustomColorPalettes();
|
||||||
|
if (customColorPalettes[value]) {
|
||||||
|
palette = customColorPalettes[value];
|
||||||
|
await loadColorPalette(customColorPalettes[value]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let { BACKGROUND_IMAGE, CLEAR_BACKGROUND_COLOR } = palette.colors.litegraph_base;
|
||||||
|
if (BACKGROUND_IMAGE === void 0 || CLEAR_BACKGROUND_COLOR === void 0) {
|
||||||
|
const base = colorPalettes["dark"].colors.litegraph_base;
|
||||||
|
BACKGROUND_IMAGE = base.BACKGROUND_IMAGE;
|
||||||
|
CLEAR_BACKGROUND_COLOR = base.CLEAR_BACKGROUND_COLOR;
|
||||||
|
}
|
||||||
|
app.canvas.updateBackground(BACKGROUND_IMAGE, CLEAR_BACKGROUND_COLOR);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
});
|
||||||
|
window.comfyAPI = window.comfyAPI || {};
|
||||||
|
window.comfyAPI.colorPalette = window.comfyAPI.colorPalette || {};
|
||||||
|
window.comfyAPI.colorPalette.defaultColorPalette = defaultColorPalette;
|
||||||
|
window.comfyAPI.colorPalette.getColorPalette = getColorPalette;
|
||||||
|
export {
|
||||||
|
defaultColorPalette as d,
|
||||||
|
getColorPalette as g
|
||||||
|
};
|
||||||
|
//# sourceMappingURL=colorPalette-D5oi2-2V.js.map
|
||||||
1
web/assets/colorPalette-D5oi2-2V.js.map
generated
vendored
Normal file
1
web/assets/colorPalette-D5oi2-2V.js.map
generated
vendored
Normal file
File diff suppressed because one or more lines are too long
5372
web/assets/index-DAK31IJJ.css → web/assets/index-BHJGjcJh.css
generated
vendored
5372
web/assets/index-DAK31IJJ.css → web/assets/index-BHJGjcJh.css
generated
vendored
File diff suppressed because it is too large
Load Diff
3716
web/assets/index-DkvOTKox.js → web/assets/index-BMC1ey-i.js
generated
vendored
3716
web/assets/index-DkvOTKox.js → web/assets/index-BMC1ey-i.js
generated
vendored
File diff suppressed because it is too large
Load Diff
1
web/assets/index-BMC1ey-i.js.map
generated
vendored
Normal file
1
web/assets/index-BMC1ey-i.js.map
generated
vendored
Normal file
File diff suppressed because one or more lines are too long
0
web/assets/index-DjWyclij.css → web/assets/index-BRhY6FpL.css
generated
vendored
0
web/assets/index-DjWyclij.css → web/assets/index-BRhY6FpL.css
generated
vendored
1
web/assets/index-CaD4RONs.js.map
generated
vendored
1
web/assets/index-CaD4RONs.js.map
generated
vendored
File diff suppressed because one or more lines are too long
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user