Compare commits
167 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
696672905f | ||
|
|
6c9dbde7de | ||
|
|
ee8abf0cff | ||
|
|
fabf449feb | ||
|
|
cc9cf6d1bd | ||
|
|
1c8286a44b | ||
|
|
1af4a47fd1 | ||
|
|
f2aaa0a475 | ||
|
|
daa1565b93 | ||
|
|
09fdb2b269 | ||
|
|
65a8659182 | ||
|
|
770ab200f2 | ||
|
|
954683d0db | ||
|
|
30c0c81351 | ||
|
|
13b0ff8a6f | ||
|
|
c320801187 | ||
|
|
c0b0cfaeec | ||
|
|
669d9e4c67 | ||
|
|
9ee0a6553a | ||
|
|
5cbb01bc2f | ||
|
|
c3ffbae067 | ||
|
|
d605677b33 | ||
|
|
ce759b7db6 | ||
|
|
52810907e2 | ||
|
|
af8cf79a2d | ||
|
|
66b0961a46 | ||
|
|
754597c8a9 | ||
|
|
915fdb5745 | ||
|
|
5a8a48931a | ||
|
|
8ce2a1052c | ||
|
|
f82314fcfc | ||
|
|
0075c6d096 | ||
|
|
83ca891118 | ||
|
|
f9f9faface | ||
|
|
471cd3eace | ||
|
|
a68bbafddb | ||
|
|
73e3a9e676 | ||
|
|
518c0dc2fe | ||
|
|
ce0542e10b | ||
|
|
8473019d40 | ||
|
|
89f15894dd | ||
|
|
67158994a4 | ||
|
|
7390ff3b1e | ||
|
|
0bedfb26af | ||
|
|
f71cfd2687 | ||
|
|
c695c4af7f | ||
|
|
0dbba9f751 | ||
|
|
f584758271 | ||
|
|
95b7cf9bbe | ||
|
|
191a0d56b4 | ||
|
|
3c60ecd7a8 | ||
|
|
7ae6626723 | ||
|
|
6632365e16 | ||
|
|
ad07796777 | ||
|
|
1b80895285 | ||
|
|
5f9d5a244b | ||
|
|
14eba07acd | ||
|
|
4b2f0d9413 | ||
|
|
25eac1d780 | ||
|
|
e38c94228b | ||
|
|
203942c8b2 | ||
|
|
3c72c89a52 | ||
|
|
614377abd6 | ||
|
|
8dfa0cc552 | ||
|
|
e5ecdfdd2d | ||
|
|
7d29fbf74b | ||
|
|
2c641e64ad | ||
|
|
7d2467e830 | ||
|
|
6f021d8aa0 | ||
|
|
d854ed0bcf | ||
|
|
abcd006b8c | ||
|
|
d985d1d7dc | ||
|
|
d1cdf51e1b | ||
|
|
b4626ab93e | ||
|
|
a9e459c2a4 | ||
|
|
3bb4dec720 | ||
|
|
8733191563 | ||
|
|
83b01f960a | ||
|
|
d72e871cfa | ||
|
|
037c3159b6 | ||
|
|
bdd4a22a2e | ||
|
|
fdf37566ef | ||
|
|
08c8968482 | ||
|
|
479a427a48 | ||
|
|
3a0eeee320 | ||
|
|
447da7ea86 | ||
|
|
9c41bc8d10 | ||
|
|
6ad0ddbae4 | ||
|
|
a55142f904 | ||
|
|
5718ef69bb | ||
|
|
13ecf10a92 | ||
|
|
7a415f47a9 | ||
|
|
89fa2fca24 | ||
|
|
364b69e931 | ||
|
|
dc96a1ae19 | ||
|
|
2d810b081e | ||
|
|
9f7e9f0547 | ||
|
|
a355f38ecc | ||
|
|
38c69080c7 | ||
|
|
70a708d726 | ||
|
|
e7d4782736 | ||
|
|
3326bdfd4e | ||
|
|
68bb885d22 | ||
|
|
ad66f7c7d8 | ||
|
|
de8e8e3b0d | ||
|
|
a1e71cfad1 | ||
|
|
0bfc7cc998 | ||
|
|
7183fd1665 | ||
|
|
254838f23c | ||
|
|
0b7dfa986d | ||
|
|
d514bb38ee | ||
|
|
0849c80e2a | ||
|
|
56e8f5e4fd | ||
|
|
e813abbb2c | ||
|
|
5e68a4ce67 | ||
|
|
ca08597670 | ||
|
|
f48e390032 | ||
|
|
369a6dd2c4 | ||
|
|
b3ce8fb9fd | ||
|
|
cf80d28689 | ||
|
|
6fb44c4b7c | ||
|
|
d2247c1e61 | ||
|
|
cb12ad7049 | ||
|
|
f6b7194f64 | ||
|
|
7c6eb4fb29 | ||
|
|
b962db9952 | ||
|
|
d0b7ab88ba | ||
|
|
405b529545 | ||
|
|
9d720187f1 | ||
|
|
d247bc5a9c | ||
|
|
9f4daca9d9 | ||
|
|
b5d0f2a908 | ||
|
|
e760bf5c40 | ||
|
|
36c83cdbba | ||
|
|
81778a7feb | ||
|
|
bc94662b31 | ||
|
|
9fa8faa44a | ||
|
|
9a7444e39f | ||
|
|
54fca4a218 | ||
|
|
cd4955367e | ||
|
|
8354203d95 | ||
|
|
e0b41243b4 | ||
|
|
619263d4a6 | ||
|
|
e3b0402bb7 | ||
|
|
967867d48c | ||
|
|
cbaac71bf5 | ||
|
|
3ab3516e46 | ||
|
|
9c5fca75f4 | ||
|
|
a5da4d0b3e | ||
|
|
32a60a7bac | ||
|
|
bb52934ba4 | ||
|
|
8aabd7c8c0 | ||
|
|
a09b29ca11 | ||
|
|
9bfee68773 | ||
|
|
ea77750759 | ||
|
|
c27ebeb1c2 | ||
|
|
0c7c98a965 | ||
|
|
dc2eb75b85 | ||
|
|
fa34efe3bd | ||
|
|
5cbaa9e07c | ||
|
|
c7427375ee | ||
|
|
22d1241a50 | ||
|
|
f04229b84d | ||
|
|
f067ad15d1 | ||
|
|
483004dd1d | ||
|
|
00a5d08103 | ||
|
|
d043997d30 |
2
.github/workflows/pullrequest-ci-run.yml
vendored
2
.github/workflows/pullrequest-ci-run.yml
vendored
@@ -23,7 +23,7 @@ jobs:
|
|||||||
runner_label: [self-hosted, Linux]
|
runner_label: [self-hosted, Linux]
|
||||||
flags: ""
|
flags: ""
|
||||||
- os: windows
|
- os: windows
|
||||||
runner_label: [self-hosted, win]
|
runner_label: [self-hosted, Windows]
|
||||||
flags: ""
|
flags: ""
|
||||||
runs-on: ${{ matrix.runner_label }}
|
runs-on: ${{ matrix.runner_label }}
|
||||||
steps:
|
steps:
|
||||||
|
|||||||
6
.github/workflows/stable-release.yml
vendored
6
.github/workflows/stable-release.yml
vendored
@@ -12,17 +12,17 @@ on:
|
|||||||
description: 'CUDA version'
|
description: 'CUDA version'
|
||||||
required: true
|
required: true
|
||||||
type: string
|
type: string
|
||||||
default: "121"
|
default: "124"
|
||||||
python_minor:
|
python_minor:
|
||||||
description: 'Python minor version'
|
description: 'Python minor version'
|
||||||
required: true
|
required: true
|
||||||
type: string
|
type: string
|
||||||
default: "11"
|
default: "12"
|
||||||
python_patch:
|
python_patch:
|
||||||
description: 'Python patch version'
|
description: 'Python patch version'
|
||||||
required: true
|
required: true
|
||||||
type: string
|
type: string
|
||||||
default: "9"
|
default: "7"
|
||||||
|
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
|
|||||||
4
.github/workflows/test-ci.yml
vendored
4
.github/workflows/test-ci.yml
vendored
@@ -32,7 +32,7 @@ jobs:
|
|||||||
runner_label: [self-hosted, Linux]
|
runner_label: [self-hosted, Linux]
|
||||||
flags: ""
|
flags: ""
|
||||||
- os: windows
|
- os: windows
|
||||||
runner_label: [self-hosted, win]
|
runner_label: [self-hosted, Windows]
|
||||||
flags: ""
|
flags: ""
|
||||||
runs-on: ${{ matrix.runner_label }}
|
runs-on: ${{ matrix.runner_label }}
|
||||||
steps:
|
steps:
|
||||||
@@ -55,7 +55,7 @@ jobs:
|
|||||||
torch_version: ["nightly"]
|
torch_version: ["nightly"]
|
||||||
include:
|
include:
|
||||||
- os: windows
|
- os: windows
|
||||||
runner_label: [self-hosted, win]
|
runner_label: [self-hosted, Windows]
|
||||||
flags: ""
|
flags: ""
|
||||||
runs-on: ${{ matrix.runner_label }}
|
runs-on: ${{ matrix.runner_label }}
|
||||||
steps:
|
steps:
|
||||||
|
|||||||
30
.github/workflows/test-unit.yml
vendored
Normal file
30
.github/workflows/test-unit.yml
vendored
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
name: Unit Tests
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [ main, master ]
|
||||||
|
pull_request:
|
||||||
|
branches: [ main, master ]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
test:
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
os: [ubuntu-latest, windows-latest, macos-latest]
|
||||||
|
runs-on: ${{ matrix.os }}
|
||||||
|
continue-on-error: true
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v4
|
||||||
|
with:
|
||||||
|
python-version: '3.10'
|
||||||
|
- name: Install requirements
|
||||||
|
run: |
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
||||||
|
pip install -r requirements.txt
|
||||||
|
- name: Run Unit Tests
|
||||||
|
run: |
|
||||||
|
pip install -r tests-unit/requirements.txt
|
||||||
|
python -m pytest tests-unit
|
||||||
@@ -12,7 +12,7 @@ on:
|
|||||||
description: 'extra dependencies'
|
description: 'extra dependencies'
|
||||||
required: false
|
required: false
|
||||||
type: string
|
type: string
|
||||||
default: "\"numpy<2\""
|
default: ""
|
||||||
cu:
|
cu:
|
||||||
description: 'cuda version'
|
description: 'cuda version'
|
||||||
required: true
|
required: true
|
||||||
@@ -23,13 +23,13 @@ on:
|
|||||||
description: 'python minor version'
|
description: 'python minor version'
|
||||||
required: true
|
required: true
|
||||||
type: string
|
type: string
|
||||||
default: "11"
|
default: "12"
|
||||||
|
|
||||||
python_patch:
|
python_patch:
|
||||||
description: 'python patch version'
|
description: 'python patch version'
|
||||||
required: true
|
required: true
|
||||||
type: string
|
type: string
|
||||||
default: "9"
|
default: "7"
|
||||||
# push:
|
# push:
|
||||||
# branches:
|
# branches:
|
||||||
# - master
|
# - master
|
||||||
|
|||||||
@@ -13,13 +13,13 @@ on:
|
|||||||
description: 'python minor version'
|
description: 'python minor version'
|
||||||
required: true
|
required: true
|
||||||
type: string
|
type: string
|
||||||
default: "11"
|
default: "12"
|
||||||
|
|
||||||
python_patch:
|
python_patch:
|
||||||
description: 'python patch version'
|
description: 'python patch version'
|
||||||
required: true
|
required: true
|
||||||
type: string
|
type: string
|
||||||
default: "9"
|
default: "7"
|
||||||
# push:
|
# push:
|
||||||
# branches:
|
# branches:
|
||||||
# - master
|
# - master
|
||||||
|
|||||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -12,6 +12,7 @@ extra_model_paths.yaml
|
|||||||
.vscode/
|
.vscode/
|
||||||
.idea/
|
.idea/
|
||||||
venv/
|
venv/
|
||||||
|
.venv/
|
||||||
/web/extensions/*
|
/web/extensions/*
|
||||||
!/web/extensions/logging.js.example
|
!/web/extensions/logging.js.example
|
||||||
!/web/extensions/core/
|
!/web/extensions/core/
|
||||||
|
|||||||
@@ -40,6 +40,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
|
|||||||
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
|
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
|
||||||
- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/), [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/), [SD3](https://comfyanonymous.github.io/ComfyUI_examples/sd3/) and [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
|
- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/), [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/), [SD3](https://comfyanonymous.github.io/ComfyUI_examples/sd3/) and [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
|
||||||
- [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
|
- [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
|
||||||
|
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
|
||||||
- Asynchronous Queue system
|
- Asynchronous Queue system
|
||||||
- Many optimizations: Only re-executes the parts of the workflow that changes between executions.
|
- Many optimizations: Only re-executes the parts of the workflow that changes between executions.
|
||||||
- Smart memory management: can automatically run models on GPUs with as low as 1GB vram.
|
- Smart memory management: can automatically run models on GPUs with as low as 1GB vram.
|
||||||
@@ -94,6 +95,8 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
|
|||||||
| Alt + `+` | Canvas Zoom in |
|
| Alt + `+` | Canvas Zoom in |
|
||||||
| Alt + `-` | Canvas Zoom out |
|
| Alt + `-` | Canvas Zoom out |
|
||||||
| Ctrl + Shift + LMB + Vertical drag | Canvas Zoom in/out |
|
| Ctrl + Shift + LMB + Vertical drag | Canvas Zoom in/out |
|
||||||
|
| P | Pin/Unpin selected nodes |
|
||||||
|
| Ctrl + G | Group selected nodes |
|
||||||
| Q | Toggle visibility of the queue |
|
| Q | Toggle visibility of the queue |
|
||||||
| H | Toggle visibility of history |
|
| H | Toggle visibility of history |
|
||||||
| R | Refresh graph |
|
| R | Refresh graph |
|
||||||
@@ -125,6 +128,8 @@ To run it on services like paperspace, kaggle or colab you can use my [Jupyter N
|
|||||||
|
|
||||||
## Manual Install (Windows, Linux)
|
## Manual Install (Windows, Linux)
|
||||||
|
|
||||||
|
Note that some dependencies do not yet support python 3.13 so using 3.12 is recommended.
|
||||||
|
|
||||||
Git clone this repo.
|
Git clone this repo.
|
||||||
|
|
||||||
Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints
|
Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from folder_paths import models_dir, user_directory, output_directory
|
from folder_paths import models_dir, user_directory, output_directory, folder_names_and_paths
|
||||||
from api_server.services.file_service import FileService
|
from api_server.services.file_service import FileService
|
||||||
import app.logger
|
import app.logger
|
||||||
|
|
||||||
@@ -36,6 +36,13 @@ class InternalRoutes:
|
|||||||
async def get_logs(request):
|
async def get_logs(request):
|
||||||
return web.json_response(app.logger.get_logs())
|
return web.json_response(app.logger.get_logs())
|
||||||
|
|
||||||
|
@self.routes.get('/folder_paths')
|
||||||
|
async def get_folder_paths(request):
|
||||||
|
response = {}
|
||||||
|
for key in folder_names_and_paths:
|
||||||
|
response[key] = folder_names_and_paths[key][0]
|
||||||
|
return web.json_response(response)
|
||||||
|
|
||||||
def get_app(self):
|
def get_app(self):
|
||||||
if self._app is None:
|
if self._app is None:
|
||||||
self._app = web.Application()
|
self._app = web.Application()
|
||||||
|
|||||||
@@ -151,6 +151,15 @@ class FrontendManager:
|
|||||||
return cls.DEFAULT_FRONTEND_PATH
|
return cls.DEFAULT_FRONTEND_PATH
|
||||||
|
|
||||||
repo_owner, repo_name, version = cls.parse_version_string(version_string)
|
repo_owner, repo_name, version = cls.parse_version_string(version_string)
|
||||||
|
|
||||||
|
if version.startswith("v"):
|
||||||
|
expected_path = str(Path(cls.CUSTOM_FRONTENDS_ROOT) / f"{repo_owner}_{repo_name}" / version.lstrip("v"))
|
||||||
|
if os.path.exists(expected_path):
|
||||||
|
logging.info(f"Using existing copy of specific frontend version tag: {repo_owner}/{repo_name}@{version}")
|
||||||
|
return expected_path
|
||||||
|
|
||||||
|
logging.info(f"Initializing frontend: {repo_owner}/{repo_name}@{version}, requesting version details from GitHub...")
|
||||||
|
|
||||||
provider = provider or FrontEndProvider(repo_owner, repo_name)
|
provider = provider or FrontEndProvider(repo_owner, repo_name)
|
||||||
release = provider.get_release(version)
|
release = provider.get_release(version)
|
||||||
|
|
||||||
|
|||||||
@@ -10,14 +10,14 @@ def get_logs():
|
|||||||
return "\n".join([formatter.format(x) for x in logs])
|
return "\n".join([formatter.format(x) for x in logs])
|
||||||
|
|
||||||
|
|
||||||
def setup_logger(verbose: bool = False, capacity: int = 300):
|
def setup_logger(log_level: str = 'INFO', capacity: int = 300):
|
||||||
global logs
|
global logs
|
||||||
if logs:
|
if logs:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Setup default global logger
|
# Setup default global logger
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
logger.setLevel(logging.DEBUG if verbose else logging.INFO)
|
logger.setLevel(log_level)
|
||||||
|
|
||||||
stream_handler = logging.StreamHandler()
|
stream_handler = logging.StreamHandler()
|
||||||
stream_handler.setFormatter(logging.Formatter("%(message)s"))
|
stream_handler.setFormatter(logging.Formatter("%(message)s"))
|
||||||
|
|||||||
@@ -5,17 +5,17 @@ import uuid
|
|||||||
import glob
|
import glob
|
||||||
import shutil
|
import shutil
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
from urllib import parse
|
||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
from folder_paths import user_directory
|
import folder_paths
|
||||||
from .app_settings import AppSettings
|
from .app_settings import AppSettings
|
||||||
|
|
||||||
default_user = "default"
|
default_user = "default"
|
||||||
users_file = os.path.join(user_directory, "users.json")
|
|
||||||
|
|
||||||
|
|
||||||
class UserManager():
|
class UserManager():
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
global user_directory
|
user_directory = folder_paths.get_user_directory()
|
||||||
|
|
||||||
self.settings = AppSettings(self)
|
self.settings = AppSettings(self)
|
||||||
if not os.path.exists(user_directory):
|
if not os.path.exists(user_directory):
|
||||||
@@ -25,14 +25,17 @@ class UserManager():
|
|||||||
print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")
|
print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")
|
||||||
|
|
||||||
if args.multi_user:
|
if args.multi_user:
|
||||||
if os.path.isfile(users_file):
|
if os.path.isfile(self.get_users_file()):
|
||||||
with open(users_file) as f:
|
with open(self.get_users_file()) as f:
|
||||||
self.users = json.load(f)
|
self.users = json.load(f)
|
||||||
else:
|
else:
|
||||||
self.users = {}
|
self.users = {}
|
||||||
else:
|
else:
|
||||||
self.users = {"default": "default"}
|
self.users = {"default": "default"}
|
||||||
|
|
||||||
|
def get_users_file(self):
|
||||||
|
return os.path.join(folder_paths.get_user_directory(), "users.json")
|
||||||
|
|
||||||
def get_request_user_id(self, request):
|
def get_request_user_id(self, request):
|
||||||
user = "default"
|
user = "default"
|
||||||
if args.multi_user and "comfy-user" in request.headers:
|
if args.multi_user and "comfy-user" in request.headers:
|
||||||
@@ -44,7 +47,7 @@ class UserManager():
|
|||||||
return user
|
return user
|
||||||
|
|
||||||
def get_request_user_filepath(self, request, file, type="userdata", create_dir=True):
|
def get_request_user_filepath(self, request, file, type="userdata", create_dir=True):
|
||||||
global user_directory
|
user_directory = folder_paths.get_user_directory()
|
||||||
|
|
||||||
if type == "userdata":
|
if type == "userdata":
|
||||||
root_dir = user_directory
|
root_dir = user_directory
|
||||||
@@ -59,6 +62,10 @@ class UserManager():
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
if file is not None:
|
if file is not None:
|
||||||
|
# Check if filename is url encoded
|
||||||
|
if "%" in file:
|
||||||
|
file = parse.unquote(file)
|
||||||
|
|
||||||
# prevent leaving /{type}/{user}
|
# prevent leaving /{type}/{user}
|
||||||
path = os.path.abspath(os.path.join(user_root, file))
|
path = os.path.abspath(os.path.join(user_root, file))
|
||||||
if os.path.commonpath((user_root, path)) != user_root:
|
if os.path.commonpath((user_root, path)) != user_root:
|
||||||
@@ -80,8 +87,7 @@ class UserManager():
|
|||||||
|
|
||||||
self.users[user_id] = name
|
self.users[user_id] = name
|
||||||
|
|
||||||
global users_file
|
with open(self.get_users_file(), "w") as f:
|
||||||
with open(users_file, "w") as f:
|
|
||||||
json.dump(self.users, f)
|
json.dump(self.users, f)
|
||||||
|
|
||||||
return user_id
|
return user_id
|
||||||
@@ -112,25 +118,69 @@ class UserManager():
|
|||||||
|
|
||||||
@routes.get("/userdata")
|
@routes.get("/userdata")
|
||||||
async def listuserdata(request):
|
async def listuserdata(request):
|
||||||
|
"""
|
||||||
|
List user data files in a specified directory.
|
||||||
|
|
||||||
|
This endpoint allows listing files in a user's data directory, with options for recursion,
|
||||||
|
full file information, and path splitting.
|
||||||
|
|
||||||
|
Query Parameters:
|
||||||
|
- dir (required): The directory to list files from.
|
||||||
|
- recurse (optional): If "true", recursively list files in subdirectories.
|
||||||
|
- full_info (optional): If "true", return detailed file information (path, size, modified time).
|
||||||
|
- split (optional): If "true", split file paths into components (only applies when full_info is false).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- 400: If 'dir' parameter is missing.
|
||||||
|
- 403: If the requested path is not allowed.
|
||||||
|
- 404: If the requested directory does not exist.
|
||||||
|
- 200: JSON response with the list of files or file information.
|
||||||
|
|
||||||
|
The response format depends on the query parameters:
|
||||||
|
- Default: List of relative file paths.
|
||||||
|
- full_info=true: List of dictionaries with file details.
|
||||||
|
- split=true (and full_info=false): List of lists, each containing path components.
|
||||||
|
"""
|
||||||
directory = request.rel_url.query.get('dir', '')
|
directory = request.rel_url.query.get('dir', '')
|
||||||
if not directory:
|
if not directory:
|
||||||
return web.Response(status=400)
|
return web.Response(status=400, text="Directory not provided")
|
||||||
|
|
||||||
path = self.get_request_user_filepath(request, directory)
|
path = self.get_request_user_filepath(request, directory)
|
||||||
if not path:
|
if not path:
|
||||||
return web.Response(status=403)
|
return web.Response(status=403, text="Invalid directory")
|
||||||
|
|
||||||
if not os.path.exists(path):
|
if not os.path.exists(path):
|
||||||
return web.Response(status=404)
|
return web.Response(status=404, text="Directory not found")
|
||||||
|
|
||||||
recurse = request.rel_url.query.get('recurse', '').lower() == "true"
|
recurse = request.rel_url.query.get('recurse', '').lower() == "true"
|
||||||
results = glob.glob(os.path.join(
|
full_info = request.rel_url.query.get('full_info', '').lower() == "true"
|
||||||
glob.escape(path), '**/*'), recursive=recurse)
|
|
||||||
results = [os.path.relpath(x, path) for x in results if os.path.isfile(x)]
|
# Use different patterns based on whether we're recursing or not
|
||||||
|
if recurse:
|
||||||
|
pattern = os.path.join(glob.escape(path), '**', '*')
|
||||||
|
else:
|
||||||
|
pattern = os.path.join(glob.escape(path), '*')
|
||||||
|
|
||||||
|
results = glob.glob(pattern, recursive=recurse)
|
||||||
|
|
||||||
|
if full_info:
|
||||||
|
results = [
|
||||||
|
{
|
||||||
|
'path': os.path.relpath(x, path).replace(os.sep, '/'),
|
||||||
|
'size': os.path.getsize(x),
|
||||||
|
'modified': os.path.getmtime(x)
|
||||||
|
} for x in results if os.path.isfile(x)
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
results = [
|
||||||
|
os.path.relpath(x, path).replace(os.sep, '/')
|
||||||
|
for x in results
|
||||||
|
if os.path.isfile(x)
|
||||||
|
]
|
||||||
|
|
||||||
split_path = request.rel_url.query.get('split', '').lower() == "true"
|
split_path = request.rel_url.query.get('split', '').lower() == "true"
|
||||||
if split_path:
|
if split_path and not full_info:
|
||||||
results = [[x] + x.split(os.sep) for x in results]
|
results = [[x] + x.split('/') for x in results]
|
||||||
|
|
||||||
return web.json_response(results)
|
return web.json_response(results)
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
num_blocks = None,
|
num_blocks = None,
|
||||||
|
control_latent_channels = None,
|
||||||
dtype = None,
|
dtype = None,
|
||||||
device = None,
|
device = None,
|
||||||
operations = None,
|
operations = None,
|
||||||
@@ -17,10 +18,13 @@ class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
|
|||||||
for _ in range(len(self.joint_blocks)):
|
for _ in range(len(self.joint_blocks)):
|
||||||
self.controlnet_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype))
|
self.controlnet_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype))
|
||||||
|
|
||||||
|
if control_latent_channels is None:
|
||||||
|
control_latent_channels = self.in_channels
|
||||||
|
|
||||||
self.pos_embed_input = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(
|
self.pos_embed_input = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(
|
||||||
None,
|
None,
|
||||||
self.patch_size,
|
self.patch_size,
|
||||||
self.in_channels,
|
control_latent_channels,
|
||||||
self.hidden_size,
|
self.hidden_size,
|
||||||
bias=True,
|
bias=True,
|
||||||
strict_img_size=False,
|
strict_img_size=False,
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ class EnumAction(argparse.Action):
|
|||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
|
parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0,::", help="Specify the IP address to listen on (default: 127.0.0.1). You can give a list of ip addresses by separating them with a comma like: 127.2.2.2,127.3.3.3 If --listen is provided without an argument, it defaults to 0.0.0.0,:: (listens on all ipv4 and ipv6)")
|
||||||
parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
|
parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
|
||||||
parser.add_argument("--tls-keyfile", type=str, help="Path to TLS (SSL) key file. Enables TLS, makes app accessible at https://... requires --tls-certfile to function")
|
parser.add_argument("--tls-keyfile", type=str, help="Path to TLS (SSL) key file. Enables TLS, makes app accessible at https://... requires --tls-certfile to function")
|
||||||
parser.add_argument("--tls-certfile", type=str, help="Path to TLS (SSL) certificate file. Enables TLS, makes app accessible at https://... requires --tls-keyfile to function")
|
parser.add_argument("--tls-certfile", type=str, help="Path to TLS (SSL) certificate file. Enables TLS, makes app accessible at https://... requires --tls-keyfile to function")
|
||||||
@@ -92,6 +92,8 @@ class LatentPreviewMethod(enum.Enum):
|
|||||||
|
|
||||||
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
|
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
|
||||||
|
|
||||||
|
parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
|
||||||
|
|
||||||
cache_group = parser.add_mutually_exclusive_group()
|
cache_group = parser.add_mutually_exclusive_group()
|
||||||
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
|
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
|
||||||
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
|
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
|
||||||
@@ -134,7 +136,7 @@ parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Dis
|
|||||||
|
|
||||||
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
|
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
|
||||||
|
|
||||||
parser.add_argument("--verbose", action="store_true", help="Enables more debug prints.")
|
parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
|
||||||
|
|
||||||
# The default built-in provider hosted under web/
|
# The default built-in provider hosted under web/
|
||||||
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
|
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
|
||||||
@@ -169,6 +171,8 @@ parser.add_argument(
|
|||||||
help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.",
|
help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument("--user-directory", type=is_valid_directory, default=None, help="Set the ComfyUI user directory with an absolute path.")
|
||||||
|
|
||||||
if comfy.options.args_parsing:
|
if comfy.options.args_parsing:
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -109,8 +109,7 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
|
|||||||
keys = list(sd.keys())
|
keys = list(sd.keys())
|
||||||
for k in keys:
|
for k in keys:
|
||||||
if k not in u:
|
if k not in u:
|
||||||
t = sd.pop(k)
|
sd.pop(k)
|
||||||
del t
|
|
||||||
return clip
|
return clip
|
||||||
|
|
||||||
def load(ckpt_path):
|
def load(ckpt_path):
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ class StrengthType(Enum):
|
|||||||
LINEAR_UP = 2
|
LINEAR_UP = 2
|
||||||
|
|
||||||
class ControlBase:
|
class ControlBase:
|
||||||
def __init__(self, device=None):
|
def __init__(self):
|
||||||
self.cond_hint_original = None
|
self.cond_hint_original = None
|
||||||
self.cond_hint = None
|
self.cond_hint = None
|
||||||
self.strength = 1.0
|
self.strength = 1.0
|
||||||
@@ -72,20 +72,24 @@ class ControlBase:
|
|||||||
self.compression_ratio = 8
|
self.compression_ratio = 8
|
||||||
self.upscale_algorithm = 'nearest-exact'
|
self.upscale_algorithm = 'nearest-exact'
|
||||||
self.extra_args = {}
|
self.extra_args = {}
|
||||||
|
|
||||||
if device is None:
|
|
||||||
device = comfy.model_management.get_torch_device()
|
|
||||||
self.device = device
|
|
||||||
self.previous_controlnet = None
|
self.previous_controlnet = None
|
||||||
self.extra_conds = []
|
self.extra_conds = []
|
||||||
self.strength_type = StrengthType.CONSTANT
|
self.strength_type = StrengthType.CONSTANT
|
||||||
|
self.concat_mask = False
|
||||||
|
self.extra_concat_orig = []
|
||||||
|
self.extra_concat = None
|
||||||
|
|
||||||
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None):
|
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
|
||||||
self.cond_hint_original = cond_hint
|
self.cond_hint_original = cond_hint
|
||||||
self.strength = strength
|
self.strength = strength
|
||||||
self.timestep_percent_range = timestep_percent_range
|
self.timestep_percent_range = timestep_percent_range
|
||||||
if self.latent_format is not None:
|
if self.latent_format is not None:
|
||||||
|
if vae is None:
|
||||||
|
logging.warning("WARNING: no VAE provided to the controlnet apply node when this controlnet requires one.")
|
||||||
self.vae = vae
|
self.vae = vae
|
||||||
|
self.extra_concat_orig = extra_concat.copy()
|
||||||
|
if self.concat_mask and len(self.extra_concat_orig) == 0:
|
||||||
|
self.extra_concat_orig.append(torch.tensor([[[[1.0]]]]))
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def pre_run(self, model, percent_to_timestep_function):
|
def pre_run(self, model, percent_to_timestep_function):
|
||||||
@@ -100,9 +104,9 @@ class ControlBase:
|
|||||||
def cleanup(self):
|
def cleanup(self):
|
||||||
if self.previous_controlnet is not None:
|
if self.previous_controlnet is not None:
|
||||||
self.previous_controlnet.cleanup()
|
self.previous_controlnet.cleanup()
|
||||||
if self.cond_hint is not None:
|
|
||||||
del self.cond_hint
|
self.cond_hint = None
|
||||||
self.cond_hint = None
|
self.extra_concat = None
|
||||||
self.timestep_range = None
|
self.timestep_range = None
|
||||||
|
|
||||||
def get_models(self):
|
def get_models(self):
|
||||||
@@ -123,6 +127,8 @@ class ControlBase:
|
|||||||
c.vae = self.vae
|
c.vae = self.vae
|
||||||
c.extra_conds = self.extra_conds.copy()
|
c.extra_conds = self.extra_conds.copy()
|
||||||
c.strength_type = self.strength_type
|
c.strength_type = self.strength_type
|
||||||
|
c.concat_mask = self.concat_mask
|
||||||
|
c.extra_concat_orig = self.extra_concat_orig.copy()
|
||||||
|
|
||||||
def inference_memory_requirements(self, dtype):
|
def inference_memory_requirements(self, dtype):
|
||||||
if self.previous_controlnet is not None:
|
if self.previous_controlnet is not None:
|
||||||
@@ -175,8 +181,8 @@ class ControlBase:
|
|||||||
|
|
||||||
|
|
||||||
class ControlNet(ControlBase):
|
class ControlNet(ControlBase):
|
||||||
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT):
|
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False):
|
||||||
super().__init__(device)
|
super().__init__()
|
||||||
self.control_model = control_model
|
self.control_model = control_model
|
||||||
self.load_device = load_device
|
self.load_device = load_device
|
||||||
if control_model is not None:
|
if control_model is not None:
|
||||||
@@ -189,6 +195,7 @@ class ControlNet(ControlBase):
|
|||||||
self.latent_format = latent_format
|
self.latent_format = latent_format
|
||||||
self.extra_conds += extra_conds
|
self.extra_conds += extra_conds
|
||||||
self.strength_type = strength_type
|
self.strength_type = strength_type
|
||||||
|
self.concat_mask = concat_mask
|
||||||
|
|
||||||
def get_control(self, x_noisy, t, cond, batched_number):
|
def get_control(self, x_noisy, t, cond, batched_number):
|
||||||
control_prev = None
|
control_prev = None
|
||||||
@@ -213,6 +220,9 @@ class ControlNet(ControlBase):
|
|||||||
compression_ratio = self.compression_ratio
|
compression_ratio = self.compression_ratio
|
||||||
if self.vae is not None:
|
if self.vae is not None:
|
||||||
compression_ratio *= self.vae.downscale_ratio
|
compression_ratio *= self.vae.downscale_ratio
|
||||||
|
else:
|
||||||
|
if self.latent_format is not None:
|
||||||
|
raise ValueError("This Controlnet needs a VAE but none was provided, please use a ControlNetApply node with a VAE input and connect it.")
|
||||||
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
|
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
|
||||||
if self.vae is not None:
|
if self.vae is not None:
|
||||||
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
|
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
|
||||||
@@ -220,7 +230,15 @@ class ControlNet(ControlBase):
|
|||||||
comfy.model_management.load_models_gpu(loaded_models)
|
comfy.model_management.load_models_gpu(loaded_models)
|
||||||
if self.latent_format is not None:
|
if self.latent_format is not None:
|
||||||
self.cond_hint = self.latent_format.process_in(self.cond_hint)
|
self.cond_hint = self.latent_format.process_in(self.cond_hint)
|
||||||
self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype)
|
if len(self.extra_concat_orig) > 0:
|
||||||
|
to_concat = []
|
||||||
|
for c in self.extra_concat_orig:
|
||||||
|
c = c.to(self.cond_hint.device)
|
||||||
|
c = comfy.utils.common_upscale(c, self.cond_hint.shape[3], self.cond_hint.shape[2], self.upscale_algorithm, "center")
|
||||||
|
to_concat.append(comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[0]))
|
||||||
|
self.cond_hint = torch.cat([self.cond_hint] + to_concat, dim=1)
|
||||||
|
|
||||||
|
self.cond_hint = self.cond_hint.to(device=x_noisy.device, dtype=dtype)
|
||||||
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
||||||
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
||||||
|
|
||||||
@@ -319,8 +337,8 @@ class ControlLoraOps:
|
|||||||
|
|
||||||
|
|
||||||
class ControlLora(ControlNet):
|
class ControlLora(ControlNet):
|
||||||
def __init__(self, control_weights, global_average_pooling=False, device=None):
|
def __init__(self, control_weights, global_average_pooling=False, model_options={}): #TODO? model_options
|
||||||
ControlBase.__init__(self, device)
|
ControlBase.__init__(self)
|
||||||
self.control_weights = control_weights
|
self.control_weights = control_weights
|
||||||
self.global_average_pooling = global_average_pooling
|
self.global_average_pooling = global_average_pooling
|
||||||
self.extra_conds += ["y"]
|
self.extra_conds += ["y"]
|
||||||
@@ -376,19 +394,25 @@ class ControlLora(ControlNet):
|
|||||||
def inference_memory_requirements(self, dtype):
|
def inference_memory_requirements(self, dtype):
|
||||||
return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
|
return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
|
||||||
|
|
||||||
def controlnet_config(sd):
|
def controlnet_config(sd, model_options={}):
|
||||||
model_config = comfy.model_detection.model_config_from_unet(sd, "", True)
|
model_config = comfy.model_detection.model_config_from_unet(sd, "", True)
|
||||||
|
|
||||||
supported_inference_dtypes = model_config.supported_inference_dtypes
|
unet_dtype = model_options.get("dtype", None)
|
||||||
|
if unet_dtype is None:
|
||||||
|
weight_dtype = comfy.utils.weight_dtype(sd)
|
||||||
|
|
||||||
|
supported_inference_dtypes = list(model_config.supported_inference_dtypes)
|
||||||
|
if weight_dtype is not None:
|
||||||
|
supported_inference_dtypes.append(weight_dtype)
|
||||||
|
|
||||||
|
unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
|
||||||
|
|
||||||
controlnet_config = model_config.unet_config
|
|
||||||
unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
|
|
||||||
load_device = comfy.model_management.get_torch_device()
|
load_device = comfy.model_management.get_torch_device()
|
||||||
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
||||||
if manual_cast_dtype is not None:
|
|
||||||
operations = comfy.ops.manual_cast
|
operations = model_options.get("custom_operations", None)
|
||||||
else:
|
if operations is None:
|
||||||
operations = comfy.ops.disable_weight_init
|
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
|
||||||
|
|
||||||
offload_device = comfy.model_management.unet_offload_device()
|
offload_device = comfy.model_management.unet_offload_device()
|
||||||
return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device
|
return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device
|
||||||
@@ -403,24 +427,29 @@ def controlnet_load_state_dict(control_model, sd):
|
|||||||
logging.debug("unexpected controlnet keys: {}".format(unexpected))
|
logging.debug("unexpected controlnet keys: {}".format(unexpected))
|
||||||
return control_model
|
return control_model
|
||||||
|
|
||||||
def load_controlnet_mmdit(sd):
|
def load_controlnet_mmdit(sd, model_options={}):
|
||||||
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
|
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
|
||||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd)
|
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options)
|
||||||
num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
|
num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
|
||||||
for k in sd:
|
for k in sd:
|
||||||
new_sd[k] = sd[k]
|
new_sd[k] = sd[k]
|
||||||
|
|
||||||
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
concat_mask = False
|
||||||
|
control_latent_channels = new_sd.get("pos_embed_input.proj.weight").shape[1]
|
||||||
|
if control_latent_channels == 17: #inpaint controlnet
|
||||||
|
concat_mask = True
|
||||||
|
|
||||||
|
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, control_latent_channels=control_latent_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||||
control_model = controlnet_load_state_dict(control_model, new_sd)
|
control_model = controlnet_load_state_dict(control_model, new_sd)
|
||||||
|
|
||||||
latent_format = comfy.latent_formats.SD3()
|
latent_format = comfy.latent_formats.SD3()
|
||||||
latent_format.shift_factor = 0 #SD3 controlnet weirdness
|
latent_format.shift_factor = 0 #SD3 controlnet weirdness
|
||||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
||||||
return control
|
return control
|
||||||
|
|
||||||
|
|
||||||
def load_controlnet_hunyuandit(controlnet_data):
|
def load_controlnet_hunyuandit(controlnet_data, model_options={}):
|
||||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data)
|
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data, model_options=model_options)
|
||||||
|
|
||||||
control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=offload_device, dtype=unet_dtype)
|
control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=offload_device, dtype=unet_dtype)
|
||||||
control_model = controlnet_load_state_dict(control_model, controlnet_data)
|
control_model = controlnet_load_state_dict(control_model, controlnet_data)
|
||||||
@@ -430,17 +459,17 @@ def load_controlnet_hunyuandit(controlnet_data):
|
|||||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT)
|
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT)
|
||||||
return control
|
return control
|
||||||
|
|
||||||
def load_controlnet_flux_xlabs(sd):
|
def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False, model_options={}):
|
||||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd)
|
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
|
||||||
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||||
control_model = controlnet_load_state_dict(control_model, sd)
|
control_model = controlnet_load_state_dict(control_model, sd)
|
||||||
extra_conds = ['y', 'guidance']
|
extra_conds = ['y', 'guidance']
|
||||||
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||||
return control
|
return control
|
||||||
|
|
||||||
def load_controlnet_flux_instantx(sd):
|
def load_controlnet_flux_instantx(sd, model_options={}):
|
||||||
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
|
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
|
||||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd)
|
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options)
|
||||||
for k in sd:
|
for k in sd:
|
||||||
new_sd[k] = sd[k]
|
new_sd[k] = sd[k]
|
||||||
|
|
||||||
@@ -449,21 +478,30 @@ def load_controlnet_flux_instantx(sd):
|
|||||||
if union_cnet in new_sd:
|
if union_cnet in new_sd:
|
||||||
num_union_modes = new_sd[union_cnet].shape[0]
|
num_union_modes = new_sd[union_cnet].shape[0]
|
||||||
|
|
||||||
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(latent_input=True, num_union_modes=num_union_modes, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
control_latent_channels = new_sd.get("pos_embed_input.weight").shape[1] // 4
|
||||||
|
concat_mask = False
|
||||||
|
if control_latent_channels == 17:
|
||||||
|
concat_mask = True
|
||||||
|
|
||||||
|
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(latent_input=True, num_union_modes=num_union_modes, control_latent_channels=control_latent_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||||
control_model = controlnet_load_state_dict(control_model, new_sd)
|
control_model = controlnet_load_state_dict(control_model, new_sd)
|
||||||
|
|
||||||
latent_format = comfy.latent_formats.Flux()
|
latent_format = comfy.latent_formats.Flux()
|
||||||
extra_conds = ['y', 'guidance']
|
extra_conds = ['y', 'guidance']
|
||||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||||
return control
|
return control
|
||||||
|
|
||||||
def load_controlnet(ckpt_path, model=None):
|
def convert_mistoline(sd):
|
||||||
controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
|
return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
|
||||||
|
|
||||||
|
|
||||||
|
def load_controlnet_state_dict(state_dict, model=None, model_options={}):
|
||||||
|
controlnet_data = state_dict
|
||||||
if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT
|
if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT
|
||||||
return load_controlnet_hunyuandit(controlnet_data)
|
return load_controlnet_hunyuandit(controlnet_data, model_options=model_options)
|
||||||
|
|
||||||
if "lora_controlnet" in controlnet_data:
|
if "lora_controlnet" in controlnet_data:
|
||||||
return ControlLora(controlnet_data)
|
return ControlLora(controlnet_data, model_options=model_options)
|
||||||
|
|
||||||
controlnet_config = None
|
controlnet_config = None
|
||||||
supported_inference_dtypes = None
|
supported_inference_dtypes = None
|
||||||
@@ -518,13 +556,15 @@ def load_controlnet(ckpt_path, model=None):
|
|||||||
if len(leftover_keys) > 0:
|
if len(leftover_keys) > 0:
|
||||||
logging.warning("leftover keys: {}".format(leftover_keys))
|
logging.warning("leftover keys: {}".format(leftover_keys))
|
||||||
controlnet_data = new_sd
|
controlnet_data = new_sd
|
||||||
elif "controlnet_blocks.0.weight" in controlnet_data: #SD3 diffusers format
|
elif "controlnet_blocks.0.weight" in controlnet_data:
|
||||||
if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
|
if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
|
||||||
return load_controlnet_flux_xlabs(controlnet_data)
|
return load_controlnet_flux_xlabs_mistoline(controlnet_data, model_options=model_options)
|
||||||
elif "pos_embed_input.proj.weight" in controlnet_data:
|
elif "pos_embed_input.proj.weight" in controlnet_data:
|
||||||
return load_controlnet_mmdit(controlnet_data)
|
return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet
|
||||||
elif "controlnet_x_embedder.weight" in controlnet_data:
|
elif "controlnet_x_embedder.weight" in controlnet_data:
|
||||||
return load_controlnet_flux_instantx(controlnet_data)
|
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
|
||||||
|
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
|
||||||
|
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)
|
||||||
|
|
||||||
pth_key = 'control_model.zero_convs.0.0.weight'
|
pth_key = 'control_model.zero_convs.0.0.weight'
|
||||||
pth = False
|
pth = False
|
||||||
@@ -536,25 +576,36 @@ def load_controlnet(ckpt_path, model=None):
|
|||||||
elif key in controlnet_data:
|
elif key in controlnet_data:
|
||||||
prefix = ""
|
prefix = ""
|
||||||
else:
|
else:
|
||||||
net = load_t2i_adapter(controlnet_data)
|
net = load_t2i_adapter(controlnet_data, model_options=model_options)
|
||||||
if net is None:
|
if net is None:
|
||||||
logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
|
logging.error("error could not detect control model type.")
|
||||||
return net
|
return net
|
||||||
|
|
||||||
if controlnet_config is None:
|
if controlnet_config is None:
|
||||||
model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True)
|
model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True)
|
||||||
supported_inference_dtypes = model_config.supported_inference_dtypes
|
supported_inference_dtypes = list(model_config.supported_inference_dtypes)
|
||||||
controlnet_config = model_config.unet_config
|
controlnet_config = model_config.unet_config
|
||||||
|
|
||||||
|
unet_dtype = model_options.get("dtype", None)
|
||||||
|
if unet_dtype is None:
|
||||||
|
weight_dtype = comfy.utils.weight_dtype(controlnet_data)
|
||||||
|
|
||||||
|
if supported_inference_dtypes is None:
|
||||||
|
supported_inference_dtypes = [comfy.model_management.unet_dtype()]
|
||||||
|
|
||||||
|
if weight_dtype is not None:
|
||||||
|
supported_inference_dtypes.append(weight_dtype)
|
||||||
|
|
||||||
|
unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
|
||||||
|
|
||||||
load_device = comfy.model_management.get_torch_device()
|
load_device = comfy.model_management.get_torch_device()
|
||||||
if supported_inference_dtypes is None:
|
|
||||||
unet_dtype = comfy.model_management.unet_dtype()
|
|
||||||
else:
|
|
||||||
unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
|
|
||||||
|
|
||||||
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
||||||
if manual_cast_dtype is not None:
|
operations = model_options.get("custom_operations", None)
|
||||||
controlnet_config["operations"] = comfy.ops.manual_cast
|
if operations is None:
|
||||||
|
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype)
|
||||||
|
|
||||||
|
controlnet_config["operations"] = operations
|
||||||
controlnet_config["dtype"] = unet_dtype
|
controlnet_config["dtype"] = unet_dtype
|
||||||
controlnet_config["device"] = comfy.model_management.unet_offload_device()
|
controlnet_config["device"] = comfy.model_management.unet_offload_device()
|
||||||
controlnet_config.pop("out_channels")
|
controlnet_config.pop("out_channels")
|
||||||
@@ -590,22 +641,32 @@ def load_controlnet(ckpt_path, model=None):
|
|||||||
if len(unexpected) > 0:
|
if len(unexpected) > 0:
|
||||||
logging.debug("unexpected controlnet keys: {}".format(unexpected))
|
logging.debug("unexpected controlnet keys: {}".format(unexpected))
|
||||||
|
|
||||||
global_average_pooling = False
|
global_average_pooling = model_options.get("global_average_pooling", False)
|
||||||
filename = os.path.splitext(ckpt_path)[0]
|
|
||||||
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
|
|
||||||
global_average_pooling = True
|
|
||||||
|
|
||||||
control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
||||||
return control
|
return control
|
||||||
|
|
||||||
|
def load_controlnet(ckpt_path, model=None, model_options={}):
|
||||||
|
if "global_average_pooling" not in model_options:
|
||||||
|
filename = os.path.splitext(ckpt_path)[0]
|
||||||
|
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
|
||||||
|
model_options["global_average_pooling"] = True
|
||||||
|
|
||||||
|
cnet = load_controlnet_state_dict(comfy.utils.load_torch_file(ckpt_path, safe_load=True), model=model, model_options=model_options)
|
||||||
|
if cnet is None:
|
||||||
|
logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
|
||||||
|
return cnet
|
||||||
|
|
||||||
class T2IAdapter(ControlBase):
|
class T2IAdapter(ControlBase):
|
||||||
def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None):
|
def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None):
|
||||||
super().__init__(device)
|
super().__init__()
|
||||||
self.t2i_model = t2i_model
|
self.t2i_model = t2i_model
|
||||||
self.channels_in = channels_in
|
self.channels_in = channels_in
|
||||||
self.control_input = None
|
self.control_input = None
|
||||||
self.compression_ratio = compression_ratio
|
self.compression_ratio = compression_ratio
|
||||||
self.upscale_algorithm = upscale_algorithm
|
self.upscale_algorithm = upscale_algorithm
|
||||||
|
if device is None:
|
||||||
|
device = comfy.model_management.get_torch_device()
|
||||||
|
self.device = device
|
||||||
|
|
||||||
def scale_image_to(self, width, height):
|
def scale_image_to(self, width, height):
|
||||||
unshuffle_amount = self.t2i_model.unshuffle_amount
|
unshuffle_amount = self.t2i_model.unshuffle_amount
|
||||||
@@ -653,7 +714,7 @@ class T2IAdapter(ControlBase):
|
|||||||
self.copy_to(c)
|
self.copy_to(c)
|
||||||
return c
|
return c
|
||||||
|
|
||||||
def load_t2i_adapter(t2i_data):
|
def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
|
||||||
compression_ratio = 8
|
compression_ratio = 8
|
||||||
upscale_algorithm = 'nearest-exact'
|
upscale_algorithm = 'nearest-exact'
|
||||||
|
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ class NoiseScheduleVP:
|
|||||||
continuous_beta_0=0.1,
|
continuous_beta_0=0.1,
|
||||||
continuous_beta_1=20.,
|
continuous_beta_1=20.,
|
||||||
):
|
):
|
||||||
"""Create a wrapper class for the forward SDE (VP type).
|
r"""Create a wrapper class for the forward SDE (VP type).
|
||||||
|
|
||||||
***
|
***
|
||||||
Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t.
|
Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t.
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
import torch
|
import torch
|
||||||
import math
|
|
||||||
|
|
||||||
def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None):
|
def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None):
|
||||||
mantissa_scaled = torch.where(
|
mantissa_scaled = torch.where(
|
||||||
@@ -41,9 +40,10 @@ def manual_stochastic_round_to_float8(x, dtype, generator=None):
|
|||||||
(2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + abs_x),
|
(2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + abs_x),
|
||||||
(2.0 ** (-EXPONENT_BIAS + 1)) * abs_x
|
(2.0 ** (-EXPONENT_BIAS + 1)) * abs_x
|
||||||
)
|
)
|
||||||
del abs_x
|
|
||||||
|
|
||||||
return sign.to(dtype=dtype)
|
inf = torch.finfo(dtype)
|
||||||
|
torch.clamp(sign, min=inf.min, max=inf.max, out=sign)
|
||||||
|
return sign
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -57,6 +57,11 @@ def stochastic_rounding(value, dtype, seed=0):
|
|||||||
if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
|
if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
|
||||||
generator = torch.Generator(device=value.device)
|
generator = torch.Generator(device=value.device)
|
||||||
generator.manual_seed(seed)
|
generator.manual_seed(seed)
|
||||||
return manual_stochastic_round_to_float8(value, dtype, generator=generator)
|
output = torch.empty_like(value, dtype=dtype)
|
||||||
|
num_slices = max(1, (value.numel() / (4096 * 4096)))
|
||||||
|
slice_size = max(1, round(value.shape[0] / num_slices))
|
||||||
|
for i in range(0, value.shape[0], slice_size):
|
||||||
|
output[i:i+slice_size].copy_(manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator))
|
||||||
|
return output
|
||||||
|
|
||||||
return value.to(dtype=dtype)
|
return value.to(dtype=dtype)
|
||||||
|
|||||||
@@ -44,6 +44,17 @@ def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
|
|||||||
return append_zero(sigmas)
|
return append_zero(sigmas)
|
||||||
|
|
||||||
|
|
||||||
|
def get_sigmas_laplace(n, sigma_min, sigma_max, mu=0., beta=0.5, device='cpu'):
|
||||||
|
"""Constructs the noise schedule proposed by Tiankai et al. (2024). """
|
||||||
|
epsilon = 1e-5 # avoid log(0)
|
||||||
|
x = torch.linspace(0, 1, n, device=device)
|
||||||
|
clamp = lambda x: torch.clamp(x, min=sigma_min, max=sigma_max)
|
||||||
|
lmb = mu - beta * torch.sign(0.5-x) * torch.log(1 - 2 * torch.abs(0.5-x) + epsilon)
|
||||||
|
sigmas = clamp(torch.exp(lmb))
|
||||||
|
return sigmas
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def to_d(x, sigma, denoised):
|
def to_d(x, sigma, denoised):
|
||||||
"""Converts a denoiser output to a Karras ODE derivative."""
|
"""Converts a denoiser output to a Karras ODE derivative."""
|
||||||
return (x - denoised) / utils.append_dims(sigma, x.ndim)
|
return (x - denoised) / utils.append_dims(sigma, x.ndim)
|
||||||
@@ -153,6 +164,8 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None,
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||||
|
if isinstance(model.inner_model.inner_model.model_sampling, comfy.model_sampling.CONST):
|
||||||
|
return sample_euler_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
|
||||||
"""Ancestral sampling with Euler method steps."""
|
"""Ancestral sampling with Euler method steps."""
|
||||||
extra_args = {} if extra_args is None else extra_args
|
extra_args = {} if extra_args is None else extra_args
|
||||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||||
@@ -170,6 +183,29 @@ def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
|
|||||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample_euler_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1., noise_sampler=None):
|
||||||
|
"""Ancestral sampling with Euler method steps."""
|
||||||
|
extra_args = {} if extra_args is None else extra_args
|
||||||
|
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||||
|
s_in = x.new_ones([x.shape[0]])
|
||||||
|
for i in trange(len(sigmas) - 1, disable=disable):
|
||||||
|
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||||
|
# sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
|
||||||
|
downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta
|
||||||
|
sigma_down = sigmas[i+1] * downstep_ratio
|
||||||
|
alpha_ip1 = 1 - sigmas[i+1]
|
||||||
|
alpha_down = 1 - sigma_down
|
||||||
|
renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5
|
||||||
|
if callback is not None:
|
||||||
|
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||||
|
|
||||||
|
# Euler method
|
||||||
|
sigma_down_i_ratio = sigma_down / sigmas[i]
|
||||||
|
x = sigma_down_i_ratio * x + (1 - sigma_down_i_ratio) * denoised
|
||||||
|
if sigmas[i + 1] > 0 and eta > 0:
|
||||||
|
x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
|
||||||
|
return x
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
|
def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
|
||||||
@@ -1069,7 +1105,6 @@ def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disabl
|
|||||||
d = to_d(x, sigma_hat, temp[0])
|
d = to_d(x, sigma_hat, temp[0])
|
||||||
if callback is not None:
|
if callback is not None:
|
||||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
|
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
|
||||||
dt = sigmas[i + 1] - sigma_hat
|
|
||||||
# Euler method
|
# Euler method
|
||||||
x = denoised + d * sigmas[i + 1]
|
x = denoised + d * sigmas[i + 1]
|
||||||
return x
|
return x
|
||||||
@@ -1096,8 +1131,81 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No
|
|||||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||||
d = to_d(x, sigmas[i], temp[0])
|
d = to_d(x, sigmas[i], temp[0])
|
||||||
# Euler method
|
# Euler method
|
||||||
dt = sigma_down - sigmas[i]
|
|
||||||
x = denoised + d * sigma_down
|
x = denoised + d * sigma_down
|
||||||
if sigmas[i + 1] > 0:
|
if sigmas[i + 1] > 0:
|
||||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
||||||
return x
|
return x
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||||
|
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
|
||||||
|
extra_args = {} if extra_args is None else extra_args
|
||||||
|
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||||
|
|
||||||
|
temp = [0]
|
||||||
|
def post_cfg_function(args):
|
||||||
|
temp[0] = args["uncond_denoised"]
|
||||||
|
return args["denoised"]
|
||||||
|
|
||||||
|
model_options = extra_args.get("model_options", {}).copy()
|
||||||
|
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
|
||||||
|
|
||||||
|
s_in = x.new_ones([x.shape[0]])
|
||||||
|
sigma_fn = lambda t: t.neg().exp()
|
||||||
|
t_fn = lambda sigma: sigma.log().neg()
|
||||||
|
|
||||||
|
for i in trange(len(sigmas) - 1, disable=disable):
|
||||||
|
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||||
|
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
|
||||||
|
if callback is not None:
|
||||||
|
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||||
|
if sigma_down == 0:
|
||||||
|
# Euler method
|
||||||
|
d = to_d(x, sigmas[i], temp[0])
|
||||||
|
x = denoised + d * sigma_down
|
||||||
|
else:
|
||||||
|
# DPM-Solver++(2S)
|
||||||
|
t, t_next = t_fn(sigmas[i]), t_fn(sigma_down)
|
||||||
|
# r = torch.sinh(1 + (2 - eta) * (t_next - t) / (t - t_fn(sigma_up))) works only on non-cfgpp, weird
|
||||||
|
r = 1 / 2
|
||||||
|
h = t_next - t
|
||||||
|
s = t + r * h
|
||||||
|
x_2 = (sigma_fn(s) / sigma_fn(t)) * (x + (denoised - temp[0])) - (-h * r).expm1() * denoised
|
||||||
|
denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
|
||||||
|
x = (sigma_fn(t_next) / sigma_fn(t)) * (x + (denoised - temp[0])) - (-h).expm1() * denoised_2
|
||||||
|
# Noise addition
|
||||||
|
if sigmas[i + 1] > 0:
|
||||||
|
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
||||||
|
return x
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample_dpmpp_2m_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
|
||||||
|
"""DPM-Solver++(2M)."""
|
||||||
|
extra_args = {} if extra_args is None else extra_args
|
||||||
|
s_in = x.new_ones([x.shape[0]])
|
||||||
|
t_fn = lambda sigma: sigma.log().neg()
|
||||||
|
|
||||||
|
old_uncond_denoised = None
|
||||||
|
uncond_denoised = None
|
||||||
|
def post_cfg_function(args):
|
||||||
|
nonlocal uncond_denoised
|
||||||
|
uncond_denoised = args["uncond_denoised"]
|
||||||
|
return args["denoised"]
|
||||||
|
|
||||||
|
model_options = extra_args.get("model_options", {}).copy()
|
||||||
|
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
|
||||||
|
|
||||||
|
for i in trange(len(sigmas) - 1, disable=disable):
|
||||||
|
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||||
|
if callback is not None:
|
||||||
|
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||||
|
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
|
||||||
|
h = t_next - t
|
||||||
|
if old_uncond_denoised is None or sigmas[i + 1] == 0:
|
||||||
|
denoised_mix = -torch.exp(-h) * uncond_denoised
|
||||||
|
else:
|
||||||
|
h_last = t - t_fn(sigmas[i - 1])
|
||||||
|
r = h_last / h
|
||||||
|
denoised_mix = -torch.exp(-h) * uncond_denoised - torch.expm1(-h) * (1 / (2 * r)) * (denoised - old_uncond_denoised)
|
||||||
|
x = denoised + denoised_mix + torch.exp(-h) * x
|
||||||
|
old_uncond_denoised = uncond_denoised
|
||||||
|
return x
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ class LatentFormat:
|
|||||||
scale_factor = 1.0
|
scale_factor = 1.0
|
||||||
latent_channels = 4
|
latent_channels = 4
|
||||||
latent_rgb_factors = None
|
latent_rgb_factors = None
|
||||||
|
latent_rgb_factors_bias = None
|
||||||
taesd_decoder_name = None
|
taesd_decoder_name = None
|
||||||
|
|
||||||
def process_in(self, latent):
|
def process_in(self, latent):
|
||||||
@@ -30,11 +31,13 @@ class SDXL(LatentFormat):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.latent_rgb_factors = [
|
self.latent_rgb_factors = [
|
||||||
# R G B
|
# R G B
|
||||||
[ 0.3920, 0.4054, 0.4549],
|
[ 0.3651, 0.4232, 0.4341],
|
||||||
[-0.2634, -0.0196, 0.0653],
|
[-0.2533, -0.0042, 0.1068],
|
||||||
[ 0.0568, 0.1687, -0.0755],
|
[ 0.1076, 0.1111, -0.0362],
|
||||||
[-0.3112, -0.2359, -0.2076]
|
[-0.3165, -0.2492, -0.2188]
|
||||||
]
|
]
|
||||||
|
self.latent_rgb_factors_bias = [ 0.1084, -0.0175, -0.0011]
|
||||||
|
|
||||||
self.taesd_decoder_name = "taesdxl_decoder"
|
self.taesd_decoder_name = "taesdxl_decoder"
|
||||||
|
|
||||||
class SDXL_Playground_2_5(LatentFormat):
|
class SDXL_Playground_2_5(LatentFormat):
|
||||||
@@ -112,23 +115,24 @@ class SD3(LatentFormat):
|
|||||||
self.scale_factor = 1.5305
|
self.scale_factor = 1.5305
|
||||||
self.shift_factor = 0.0609
|
self.shift_factor = 0.0609
|
||||||
self.latent_rgb_factors = [
|
self.latent_rgb_factors = [
|
||||||
[-0.0645, 0.0177, 0.1052],
|
[-0.0922, -0.0175, 0.0749],
|
||||||
[ 0.0028, 0.0312, 0.0650],
|
[ 0.0311, 0.0633, 0.0954],
|
||||||
[ 0.1848, 0.0762, 0.0360],
|
[ 0.1994, 0.0927, 0.0458],
|
||||||
[ 0.0944, 0.0360, 0.0889],
|
[ 0.0856, 0.0339, 0.0902],
|
||||||
[ 0.0897, 0.0506, -0.0364],
|
[ 0.0587, 0.0272, -0.0496],
|
||||||
[-0.0020, 0.1203, 0.0284],
|
[-0.0006, 0.1104, 0.0309],
|
||||||
[ 0.0855, 0.0118, 0.0283],
|
[ 0.0978, 0.0306, 0.0427],
|
||||||
[-0.0539, 0.0658, 0.1047],
|
[-0.0042, 0.1038, 0.1358],
|
||||||
[-0.0057, 0.0116, 0.0700],
|
[-0.0194, 0.0020, 0.0669],
|
||||||
[-0.0412, 0.0281, -0.0039],
|
[-0.0488, 0.0130, -0.0268],
|
||||||
[ 0.1106, 0.1171, 0.1220],
|
[ 0.0922, 0.0988, 0.0951],
|
||||||
[-0.0248, 0.0682, -0.0481],
|
[-0.0278, 0.0524, -0.0542],
|
||||||
[ 0.0815, 0.0846, 0.1207],
|
[ 0.0332, 0.0456, 0.0895],
|
||||||
[-0.0120, -0.0055, -0.0867],
|
[-0.0069, -0.0030, -0.0810],
|
||||||
[-0.0749, -0.0634, -0.0456],
|
[-0.0596, -0.0465, -0.0293],
|
||||||
[-0.1418, -0.1457, -0.1259]
|
[-0.1448, -0.1463, -0.1189]
|
||||||
]
|
]
|
||||||
|
self.latent_rgb_factors_bias = [0.2394, 0.2135, 0.1925]
|
||||||
self.taesd_decoder_name = "taesd3_decoder"
|
self.taesd_decoder_name = "taesd3_decoder"
|
||||||
|
|
||||||
def process_in(self, latent):
|
def process_in(self, latent):
|
||||||
@@ -146,23 +150,24 @@ class Flux(SD3):
|
|||||||
self.scale_factor = 0.3611
|
self.scale_factor = 0.3611
|
||||||
self.shift_factor = 0.1159
|
self.shift_factor = 0.1159
|
||||||
self.latent_rgb_factors =[
|
self.latent_rgb_factors =[
|
||||||
[-0.0404, 0.0159, 0.0609],
|
[-0.0346, 0.0244, 0.0681],
|
||||||
[ 0.0043, 0.0298, 0.0850],
|
[ 0.0034, 0.0210, 0.0687],
|
||||||
[ 0.0328, -0.0749, -0.0503],
|
[ 0.0275, -0.0668, -0.0433],
|
||||||
[-0.0245, 0.0085, 0.0549],
|
[-0.0174, 0.0160, 0.0617],
|
||||||
[ 0.0966, 0.0894, 0.0530],
|
[ 0.0859, 0.0721, 0.0329],
|
||||||
[ 0.0035, 0.0399, 0.0123],
|
[ 0.0004, 0.0383, 0.0115],
|
||||||
[ 0.0583, 0.1184, 0.1262],
|
[ 0.0405, 0.0861, 0.0915],
|
||||||
[-0.0191, -0.0206, -0.0306],
|
[-0.0236, -0.0185, -0.0259],
|
||||||
[-0.0324, 0.0055, 0.1001],
|
[-0.0245, 0.0250, 0.1180],
|
||||||
[ 0.0955, 0.0659, -0.0545],
|
[ 0.1008, 0.0755, -0.0421],
|
||||||
[-0.0504, 0.0231, -0.0013],
|
[-0.0515, 0.0201, 0.0011],
|
||||||
[ 0.0500, -0.0008, -0.0088],
|
[ 0.0428, -0.0012, -0.0036],
|
||||||
[ 0.0982, 0.0941, 0.0976],
|
[ 0.0817, 0.0765, 0.0749],
|
||||||
[-0.1233, -0.0280, -0.0897],
|
[-0.1264, -0.0522, -0.1103],
|
||||||
[-0.0005, -0.0530, -0.0020],
|
[-0.0280, -0.0881, -0.0499],
|
||||||
[-0.1273, -0.0932, -0.0680]
|
[-0.1262, -0.0982, -0.0778]
|
||||||
]
|
]
|
||||||
|
self.latent_rgb_factors_bias = [-0.0329, -0.0718, -0.0851]
|
||||||
self.taesd_decoder_name = "taef1_decoder"
|
self.taesd_decoder_name = "taef1_decoder"
|
||||||
|
|
||||||
def process_in(self, latent):
|
def process_in(self, latent):
|
||||||
@@ -170,3 +175,30 @@ class Flux(SD3):
|
|||||||
|
|
||||||
def process_out(self, latent):
|
def process_out(self, latent):
|
||||||
return (latent / self.scale_factor) + self.shift_factor
|
return (latent / self.scale_factor) + self.shift_factor
|
||||||
|
|
||||||
|
class Mochi(LatentFormat):
|
||||||
|
latent_channels = 12
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.scale_factor = 1.0
|
||||||
|
self.latents_mean = torch.tensor([-0.06730895953510081, -0.038011381506090416, -0.07477820912866141,
|
||||||
|
-0.05565264470995561, 0.012767231469026969, -0.04703542746246419,
|
||||||
|
0.043896967884726704, -0.09346305707025976, -0.09918314763016893,
|
||||||
|
-0.008729793427399178, -0.011931556316503654, -0.0321993391887285]).view(1, self.latent_channels, 1, 1, 1)
|
||||||
|
self.latents_std = torch.tensor([0.9263795028493863, 0.9248894543193766, 0.9393059390890617,
|
||||||
|
0.959253732819592, 0.8244560132752793, 0.917259975397747,
|
||||||
|
0.9294154431013696, 1.3720942357788521, 0.881393668867029,
|
||||||
|
0.9168315692124348, 0.9185249279345552, 0.9274757570805041]).view(1, self.latent_channels, 1, 1, 1)
|
||||||
|
|
||||||
|
self.latent_rgb_factors = None #TODO
|
||||||
|
self.taesd_decoder_name = None #TODO
|
||||||
|
|
||||||
|
def process_in(self, latent):
|
||||||
|
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
|
||||||
|
latents_std = self.latents_std.to(latent.device, latent.dtype)
|
||||||
|
return (latent - latents_mean) * self.scale_factor / latents_std
|
||||||
|
|
||||||
|
def process_out(self, latent):
|
||||||
|
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
|
||||||
|
latents_std = self.latents_std.to(latent.device, latent.dtype)
|
||||||
|
return latent * latents_std / self.scale_factor + latents_mean
|
||||||
|
|||||||
@@ -13,9 +13,15 @@ try:
|
|||||||
except:
|
except:
|
||||||
rms_norm_torch = None
|
rms_norm_torch = None
|
||||||
|
|
||||||
def rms_norm(x, weight, eps=1e-6):
|
def rms_norm(x, weight=None, eps=1e-6):
|
||||||
if rms_norm_torch is not None:
|
if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
|
||||||
return rms_norm_torch(x, weight.shape, weight=comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
|
if weight is None:
|
||||||
|
return rms_norm_torch(x, (x.shape[-1],), eps=eps)
|
||||||
|
else:
|
||||||
|
return rms_norm_torch(x, weight.shape, weight=comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
|
||||||
else:
|
else:
|
||||||
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
|
r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
|
||||||
return (x * rrms) * comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device)
|
if weight is None:
|
||||||
|
return r
|
||||||
|
else:
|
||||||
|
return r * comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device)
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
#Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
|
#Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
|
||||||
|
#modified to support different types of flux controlnets
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import math
|
import math
|
||||||
@@ -12,22 +13,65 @@ from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
|
|||||||
from .model import Flux
|
from .model import Flux
|
||||||
import comfy.ldm.common_dit
|
import comfy.ldm.common_dit
|
||||||
|
|
||||||
|
class MistolineCondDownsamplBlock(nn.Module):
|
||||||
|
def __init__(self, dtype=None, device=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
self.encoder = nn.Sequential(
|
||||||
|
operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 1, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 1, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.encoder(x)
|
||||||
|
|
||||||
|
class MistolineControlnetBlock(nn.Module):
|
||||||
|
def __init__(self, hidden_size, dtype=None, device=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
self.linear = operations.Linear(hidden_size, hidden_size, dtype=dtype, device=device)
|
||||||
|
self.act = nn.SiLU()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.act(self.linear(x))
|
||||||
|
|
||||||
|
|
||||||
class ControlNetFlux(Flux):
|
class ControlNetFlux(Flux):
|
||||||
def __init__(self, latent_input=False, num_union_modes=0, image_model=None, dtype=None, device=None, operations=None, **kwargs):
|
def __init__(self, latent_input=False, num_union_modes=0, mistoline=False, control_latent_channels=None, image_model=None, dtype=None, device=None, operations=None, **kwargs):
|
||||||
super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
|
super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
|
||||||
|
|
||||||
self.main_model_double = 19
|
self.main_model_double = 19
|
||||||
self.main_model_single = 38
|
self.main_model_single = 38
|
||||||
|
|
||||||
|
self.mistoline = mistoline
|
||||||
# add ControlNet blocks
|
# add ControlNet blocks
|
||||||
|
if self.mistoline:
|
||||||
|
control_block = lambda : MistolineControlnetBlock(self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||||
|
else:
|
||||||
|
control_block = lambda : operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
|
||||||
|
|
||||||
self.controlnet_blocks = nn.ModuleList([])
|
self.controlnet_blocks = nn.ModuleList([])
|
||||||
for _ in range(self.params.depth):
|
for _ in range(self.params.depth):
|
||||||
controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
|
self.controlnet_blocks.append(control_block())
|
||||||
self.controlnet_blocks.append(controlnet_block)
|
|
||||||
|
|
||||||
self.controlnet_single_blocks = nn.ModuleList([])
|
self.controlnet_single_blocks = nn.ModuleList([])
|
||||||
for _ in range(self.params.depth_single_blocks):
|
for _ in range(self.params.depth_single_blocks):
|
||||||
self.controlnet_single_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device))
|
self.controlnet_single_blocks.append(control_block())
|
||||||
|
|
||||||
self.num_union_modes = num_union_modes
|
self.num_union_modes = num_union_modes
|
||||||
self.controlnet_mode_embedder = None
|
self.controlnet_mode_embedder = None
|
||||||
@@ -36,25 +80,33 @@ class ControlNetFlux(Flux):
|
|||||||
|
|
||||||
self.gradient_checkpointing = False
|
self.gradient_checkpointing = False
|
||||||
self.latent_input = latent_input
|
self.latent_input = latent_input
|
||||||
self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
|
if control_latent_channels is None:
|
||||||
|
control_latent_channels = self.in_channels
|
||||||
|
else:
|
||||||
|
control_latent_channels *= 2 * 2 #patch size
|
||||||
|
|
||||||
|
self.pos_embed_input = operations.Linear(control_latent_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
|
||||||
if not self.latent_input:
|
if not self.latent_input:
|
||||||
self.input_hint_block = nn.Sequential(
|
if self.mistoline:
|
||||||
operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
|
self.input_cond_block = MistolineCondDownsamplBlock(dtype=dtype, device=device, operations=operations)
|
||||||
nn.SiLU(),
|
else:
|
||||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
self.input_hint_block = nn.Sequential(
|
||||||
nn.SiLU(),
|
operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
|
||||||
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
nn.SiLU(),
|
||||||
nn.SiLU(),
|
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
nn.SiLU(),
|
||||||
nn.SiLU(),
|
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||||
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
nn.SiLU(),
|
||||||
nn.SiLU(),
|
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
nn.SiLU(),
|
||||||
nn.SiLU(),
|
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||||
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
nn.SiLU(),
|
||||||
nn.SiLU(),
|
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
|
nn.SiLU(),
|
||||||
)
|
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
|
||||||
|
)
|
||||||
|
|
||||||
def forward_orig(
|
def forward_orig(
|
||||||
self,
|
self,
|
||||||
@@ -73,9 +125,6 @@ class ControlNetFlux(Flux):
|
|||||||
|
|
||||||
# running on sequences img
|
# running on sequences img
|
||||||
img = self.img_in(img)
|
img = self.img_in(img)
|
||||||
if not self.latent_input:
|
|
||||||
controlnet_cond = self.input_hint_block(controlnet_cond)
|
|
||||||
controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
|
||||||
|
|
||||||
controlnet_cond = self.pos_embed_input(controlnet_cond)
|
controlnet_cond = self.pos_embed_input(controlnet_cond)
|
||||||
img = img + controlnet_cond
|
img = img + controlnet_cond
|
||||||
@@ -131,9 +180,14 @@ class ControlNetFlux(Flux):
|
|||||||
patch_size = 2
|
patch_size = 2
|
||||||
if self.latent_input:
|
if self.latent_input:
|
||||||
hint = comfy.ldm.common_dit.pad_to_patch_size(hint, (patch_size, patch_size))
|
hint = comfy.ldm.common_dit.pad_to_patch_size(hint, (patch_size, patch_size))
|
||||||
hint = rearrange(hint, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
elif self.mistoline:
|
||||||
|
hint = hint * 2.0 - 1.0
|
||||||
|
hint = self.input_cond_block(hint)
|
||||||
else:
|
else:
|
||||||
hint = hint * 2.0 - 1.0
|
hint = hint * 2.0 - 1.0
|
||||||
|
hint = self.input_hint_block(hint)
|
||||||
|
|
||||||
|
hint = rearrange(hint, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
||||||
|
|
||||||
bs, c, h, w = x.shape
|
bs, c, h, w = x.shape
|
||||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
||||||
|
|||||||
@@ -108,7 +108,7 @@ class Flux(nn.Module):
|
|||||||
raise ValueError("Didn't get guidance strength for guidance distilled model.")
|
raise ValueError("Didn't get guidance strength for guidance distilled model.")
|
||||||
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
|
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
|
||||||
|
|
||||||
vec = vec + self.vector_in(y)
|
vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
|
||||||
txt = self.txt_in(txt)
|
txt = self.txt_in(txt)
|
||||||
|
|
||||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||||
@@ -151,8 +151,8 @@ class Flux(nn.Module):
|
|||||||
h_len = ((h + (patch_size // 2)) // patch_size)
|
h_len = ((h + (patch_size // 2)) // patch_size)
|
||||||
w_len = ((w + (patch_size // 2)) // patch_size)
|
w_len = ((w + (patch_size // 2)) // patch_size)
|
||||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||||
img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
|
img_ids[:, :, 1] = torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
|
||||||
img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
|
img_ids[:, :, 2] = torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||||
|
|
||||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||||
|
|||||||
541
comfy/ldm/genmo/joint_model/asymm_models_joint.py
Normal file
541
comfy/ldm/genmo/joint_model/asymm_models_joint.py
Normal file
@@ -0,0 +1,541 @@
|
|||||||
|
#original code from https://github.com/genmoai/models under apache 2.0 license
|
||||||
|
#adapted to ComfyUI
|
||||||
|
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from einops import rearrange
|
||||||
|
# from flash_attn import flash_attn_varlen_qkvpacked_func
|
||||||
|
from comfy.ldm.modules.attention import optimized_attention
|
||||||
|
|
||||||
|
from .layers import (
|
||||||
|
FeedForward,
|
||||||
|
PatchEmbed,
|
||||||
|
RMSNorm,
|
||||||
|
TimestepEmbedder,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .rope_mixed import (
|
||||||
|
compute_mixed_rotation,
|
||||||
|
create_position_matrix,
|
||||||
|
)
|
||||||
|
from .temporal_rope import apply_rotary_emb_qk_real
|
||||||
|
from .utils import (
|
||||||
|
AttentionPool,
|
||||||
|
modulate,
|
||||||
|
)
|
||||||
|
|
||||||
|
import comfy.ldm.common_dit
|
||||||
|
import comfy.ops
|
||||||
|
|
||||||
|
|
||||||
|
def modulated_rmsnorm(x, scale, eps=1e-6):
|
||||||
|
# Normalize and modulate
|
||||||
|
x_normed = comfy.ldm.common_dit.rms_norm(x, eps=eps)
|
||||||
|
x_modulated = x_normed * (1 + scale.unsqueeze(1))
|
||||||
|
|
||||||
|
return x_modulated
|
||||||
|
|
||||||
|
|
||||||
|
def residual_tanh_gated_rmsnorm(x, x_res, gate, eps=1e-6):
|
||||||
|
# Apply tanh to gate
|
||||||
|
tanh_gate = torch.tanh(gate).unsqueeze(1)
|
||||||
|
|
||||||
|
# Normalize and apply gated scaling
|
||||||
|
x_normed = comfy.ldm.common_dit.rms_norm(x_res, eps=eps) * tanh_gate
|
||||||
|
|
||||||
|
# Apply residual connection
|
||||||
|
output = x + x_normed
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
class AsymmetricAttention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim_x: int,
|
||||||
|
dim_y: int,
|
||||||
|
num_heads: int = 8,
|
||||||
|
qkv_bias: bool = True,
|
||||||
|
qk_norm: bool = False,
|
||||||
|
attn_drop: float = 0.0,
|
||||||
|
update_y: bool = True,
|
||||||
|
out_bias: bool = True,
|
||||||
|
attend_to_padding: bool = False,
|
||||||
|
softmax_scale: Optional[float] = None,
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
dtype=None,
|
||||||
|
operations=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.dim_x = dim_x
|
||||||
|
self.dim_y = dim_y
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_dim = dim_x // num_heads
|
||||||
|
self.attn_drop = attn_drop
|
||||||
|
self.update_y = update_y
|
||||||
|
self.attend_to_padding = attend_to_padding
|
||||||
|
self.softmax_scale = softmax_scale
|
||||||
|
if dim_x % num_heads != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"dim_x={dim_x} should be divisible by num_heads={num_heads}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Input layers.
|
||||||
|
self.qkv_bias = qkv_bias
|
||||||
|
self.qkv_x = operations.Linear(dim_x, 3 * dim_x, bias=qkv_bias, device=device, dtype=dtype)
|
||||||
|
# Project text features to match visual features (dim_y -> dim_x)
|
||||||
|
self.qkv_y = operations.Linear(dim_y, 3 * dim_x, bias=qkv_bias, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
# Query and key normalization for stability.
|
||||||
|
assert qk_norm
|
||||||
|
self.q_norm_x = RMSNorm(self.head_dim, device=device, dtype=dtype)
|
||||||
|
self.k_norm_x = RMSNorm(self.head_dim, device=device, dtype=dtype)
|
||||||
|
self.q_norm_y = RMSNorm(self.head_dim, device=device, dtype=dtype)
|
||||||
|
self.k_norm_y = RMSNorm(self.head_dim, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
# Output layers. y features go back down from dim_x -> dim_y.
|
||||||
|
self.proj_x = operations.Linear(dim_x, dim_x, bias=out_bias, device=device, dtype=dtype)
|
||||||
|
self.proj_y = (
|
||||||
|
operations.Linear(dim_x, dim_y, bias=out_bias, device=device, dtype=dtype)
|
||||||
|
if update_y
|
||||||
|
else nn.Identity()
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor, # (B, N, dim_x)
|
||||||
|
y: torch.Tensor, # (B, L, dim_y)
|
||||||
|
scale_x: torch.Tensor, # (B, dim_x), modulation for pre-RMSNorm.
|
||||||
|
scale_y: torch.Tensor, # (B, dim_y), modulation for pre-RMSNorm.
|
||||||
|
crop_y,
|
||||||
|
**rope_rotation,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
rope_cos = rope_rotation.get("rope_cos")
|
||||||
|
rope_sin = rope_rotation.get("rope_sin")
|
||||||
|
# Pre-norm for visual features
|
||||||
|
x = modulated_rmsnorm(x, scale_x) # (B, M, dim_x) where M = N / cp_group_size
|
||||||
|
|
||||||
|
# Process visual features
|
||||||
|
# qkv_x = self.qkv_x(x) # (B, M, 3 * dim_x)
|
||||||
|
# assert qkv_x.dtype == torch.bfloat16
|
||||||
|
# qkv_x = all_to_all_collect_tokens(
|
||||||
|
# qkv_x, self.num_heads
|
||||||
|
# ) # (3, B, N, local_h, head_dim)
|
||||||
|
|
||||||
|
# Process text features
|
||||||
|
y = modulated_rmsnorm(y, scale_y) # (B, L, dim_y)
|
||||||
|
q_y, k_y, v_y = self.qkv_y(y).view(y.shape[0], y.shape[1], 3, self.num_heads, -1).unbind(2) # (B, N, local_h, head_dim)
|
||||||
|
|
||||||
|
q_y = self.q_norm_y(q_y)
|
||||||
|
k_y = self.k_norm_y(k_y)
|
||||||
|
|
||||||
|
# Split qkv_x into q, k, v
|
||||||
|
q_x, k_x, v_x = self.qkv_x(x).view(x.shape[0], x.shape[1], 3, self.num_heads, -1).unbind(2) # (B, N, local_h, head_dim)
|
||||||
|
q_x = self.q_norm_x(q_x)
|
||||||
|
q_x = apply_rotary_emb_qk_real(q_x, rope_cos, rope_sin)
|
||||||
|
k_x = self.k_norm_x(k_x)
|
||||||
|
k_x = apply_rotary_emb_qk_real(k_x, rope_cos, rope_sin)
|
||||||
|
|
||||||
|
q = torch.cat([q_x, q_y[:, :crop_y]], dim=1).transpose(1, 2)
|
||||||
|
k = torch.cat([k_x, k_y[:, :crop_y]], dim=1).transpose(1, 2)
|
||||||
|
v = torch.cat([v_x, v_y[:, :crop_y]], dim=1).transpose(1, 2)
|
||||||
|
|
||||||
|
xy = optimized_attention(q,
|
||||||
|
k,
|
||||||
|
v, self.num_heads, skip_reshape=True)
|
||||||
|
|
||||||
|
x, y = torch.tensor_split(xy, (q_x.shape[1],), dim=1)
|
||||||
|
x = self.proj_x(x)
|
||||||
|
o = torch.zeros(y.shape[0], q_y.shape[1], y.shape[-1], device=y.device, dtype=y.dtype)
|
||||||
|
o[:, :y.shape[1]] = y
|
||||||
|
|
||||||
|
y = self.proj_y(o)
|
||||||
|
# print("ox", x)
|
||||||
|
# print("oy", y)
|
||||||
|
return x, y
|
||||||
|
|
||||||
|
|
||||||
|
class AsymmetricJointBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size_x: int,
|
||||||
|
hidden_size_y: int,
|
||||||
|
num_heads: int,
|
||||||
|
*,
|
||||||
|
mlp_ratio_x: float = 8.0, # Ratio of hidden size to d_model for MLP for visual tokens.
|
||||||
|
mlp_ratio_y: float = 4.0, # Ratio of hidden size to d_model for MLP for text tokens.
|
||||||
|
update_y: bool = True, # Whether to update text tokens in this block.
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
dtype=None,
|
||||||
|
operations=None,
|
||||||
|
**block_kwargs,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.update_y = update_y
|
||||||
|
self.hidden_size_x = hidden_size_x
|
||||||
|
self.hidden_size_y = hidden_size_y
|
||||||
|
self.mod_x = operations.Linear(hidden_size_x, 4 * hidden_size_x, device=device, dtype=dtype)
|
||||||
|
if self.update_y:
|
||||||
|
self.mod_y = operations.Linear(hidden_size_x, 4 * hidden_size_y, device=device, dtype=dtype)
|
||||||
|
else:
|
||||||
|
self.mod_y = operations.Linear(hidden_size_x, hidden_size_y, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
# Self-attention:
|
||||||
|
self.attn = AsymmetricAttention(
|
||||||
|
hidden_size_x,
|
||||||
|
hidden_size_y,
|
||||||
|
num_heads=num_heads,
|
||||||
|
update_y=update_y,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
operations=operations,
|
||||||
|
**block_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# MLP.
|
||||||
|
mlp_hidden_dim_x = int(hidden_size_x * mlp_ratio_x)
|
||||||
|
assert mlp_hidden_dim_x == int(1536 * 8)
|
||||||
|
self.mlp_x = FeedForward(
|
||||||
|
in_features=hidden_size_x,
|
||||||
|
hidden_size=mlp_hidden_dim_x,
|
||||||
|
multiple_of=256,
|
||||||
|
ffn_dim_multiplier=None,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
|
||||||
|
# MLP for text not needed in last block.
|
||||||
|
if self.update_y:
|
||||||
|
mlp_hidden_dim_y = int(hidden_size_y * mlp_ratio_y)
|
||||||
|
self.mlp_y = FeedForward(
|
||||||
|
in_features=hidden_size_y,
|
||||||
|
hidden_size=mlp_hidden_dim_y,
|
||||||
|
multiple_of=256,
|
||||||
|
ffn_dim_multiplier=None,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
c: torch.Tensor,
|
||||||
|
y: torch.Tensor,
|
||||||
|
**attn_kwargs,
|
||||||
|
):
|
||||||
|
"""Forward pass of a block.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: (B, N, dim) tensor of visual tokens
|
||||||
|
c: (B, dim) tensor of conditioned features
|
||||||
|
y: (B, L, dim) tensor of text tokens
|
||||||
|
num_frames: Number of frames in the video. N = num_frames * num_spatial_tokens
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
x: (B, N, dim) tensor of visual tokens after block
|
||||||
|
y: (B, L, dim) tensor of text tokens after block
|
||||||
|
"""
|
||||||
|
N = x.size(1)
|
||||||
|
|
||||||
|
c = F.silu(c)
|
||||||
|
mod_x = self.mod_x(c)
|
||||||
|
scale_msa_x, gate_msa_x, scale_mlp_x, gate_mlp_x = mod_x.chunk(4, dim=1)
|
||||||
|
|
||||||
|
mod_y = self.mod_y(c)
|
||||||
|
if self.update_y:
|
||||||
|
scale_msa_y, gate_msa_y, scale_mlp_y, gate_mlp_y = mod_y.chunk(4, dim=1)
|
||||||
|
else:
|
||||||
|
scale_msa_y = mod_y
|
||||||
|
|
||||||
|
# Self-attention block.
|
||||||
|
x_attn, y_attn = self.attn(
|
||||||
|
x,
|
||||||
|
y,
|
||||||
|
scale_x=scale_msa_x,
|
||||||
|
scale_y=scale_msa_y,
|
||||||
|
**attn_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert x_attn.size(1) == N
|
||||||
|
x = residual_tanh_gated_rmsnorm(x, x_attn, gate_msa_x)
|
||||||
|
if self.update_y:
|
||||||
|
y = residual_tanh_gated_rmsnorm(y, y_attn, gate_msa_y)
|
||||||
|
|
||||||
|
# MLP block.
|
||||||
|
x = self.ff_block_x(x, scale_mlp_x, gate_mlp_x)
|
||||||
|
if self.update_y:
|
||||||
|
y = self.ff_block_y(y, scale_mlp_y, gate_mlp_y)
|
||||||
|
|
||||||
|
return x, y
|
||||||
|
|
||||||
|
def ff_block_x(self, x, scale_x, gate_x):
|
||||||
|
x_mod = modulated_rmsnorm(x, scale_x)
|
||||||
|
x_res = self.mlp_x(x_mod)
|
||||||
|
x = residual_tanh_gated_rmsnorm(x, x_res, gate_x) # Sandwich norm
|
||||||
|
return x
|
||||||
|
|
||||||
|
def ff_block_y(self, y, scale_y, gate_y):
|
||||||
|
y_mod = modulated_rmsnorm(y, scale_y)
|
||||||
|
y_res = self.mlp_y(y_mod)
|
||||||
|
y = residual_tanh_gated_rmsnorm(y, y_res, gate_y) # Sandwich norm
|
||||||
|
return y
|
||||||
|
|
||||||
|
|
||||||
|
class FinalLayer(nn.Module):
|
||||||
|
"""
|
||||||
|
The final layer of DiT.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size,
|
||||||
|
patch_size,
|
||||||
|
out_channels,
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
dtype=None,
|
||||||
|
operations=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.norm_final = operations.LayerNorm(
|
||||||
|
hidden_size, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype
|
||||||
|
)
|
||||||
|
self.mod = operations.Linear(hidden_size, 2 * hidden_size, device=device, dtype=dtype)
|
||||||
|
self.linear = operations.Linear(
|
||||||
|
hidden_size, patch_size * patch_size * out_channels, device=device, dtype=dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x, c):
|
||||||
|
c = F.silu(c)
|
||||||
|
shift, scale = self.mod(c).chunk(2, dim=1)
|
||||||
|
x = modulate(self.norm_final(x), shift, scale)
|
||||||
|
x = self.linear(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class AsymmDiTJoint(nn.Module):
|
||||||
|
"""
|
||||||
|
Diffusion model with a Transformer backbone.
|
||||||
|
|
||||||
|
Ingests text embeddings instead of a label.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
patch_size=2,
|
||||||
|
in_channels=4,
|
||||||
|
hidden_size_x=1152,
|
||||||
|
hidden_size_y=1152,
|
||||||
|
depth=48,
|
||||||
|
num_heads=16,
|
||||||
|
mlp_ratio_x=8.0,
|
||||||
|
mlp_ratio_y=4.0,
|
||||||
|
use_t5: bool = False,
|
||||||
|
t5_feat_dim: int = 4096,
|
||||||
|
t5_token_length: int = 256,
|
||||||
|
learn_sigma=True,
|
||||||
|
patch_embed_bias: bool = True,
|
||||||
|
timestep_mlp_bias: bool = True,
|
||||||
|
attend_to_padding: bool = False,
|
||||||
|
timestep_scale: Optional[float] = None,
|
||||||
|
use_extended_posenc: bool = False,
|
||||||
|
posenc_preserve_area: bool = False,
|
||||||
|
rope_theta: float = 10000.0,
|
||||||
|
image_model=None,
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
dtype=None,
|
||||||
|
operations=None,
|
||||||
|
**block_kwargs,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.dtype = dtype
|
||||||
|
self.learn_sigma = learn_sigma
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = in_channels * 2 if learn_sigma else in_channels
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.hidden_size_x = hidden_size_x
|
||||||
|
self.hidden_size_y = hidden_size_y
|
||||||
|
self.head_dim = (
|
||||||
|
hidden_size_x // num_heads
|
||||||
|
) # Head dimension and count is determined by visual.
|
||||||
|
self.attend_to_padding = attend_to_padding
|
||||||
|
self.use_extended_posenc = use_extended_posenc
|
||||||
|
self.posenc_preserve_area = posenc_preserve_area
|
||||||
|
self.use_t5 = use_t5
|
||||||
|
self.t5_token_length = t5_token_length
|
||||||
|
self.t5_feat_dim = t5_feat_dim
|
||||||
|
self.rope_theta = (
|
||||||
|
rope_theta # Scaling factor for frequency computation for temporal RoPE.
|
||||||
|
)
|
||||||
|
|
||||||
|
self.x_embedder = PatchEmbed(
|
||||||
|
patch_size=patch_size,
|
||||||
|
in_chans=in_channels,
|
||||||
|
embed_dim=hidden_size_x,
|
||||||
|
bias=patch_embed_bias,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations
|
||||||
|
)
|
||||||
|
# Conditionings
|
||||||
|
# Timestep
|
||||||
|
self.t_embedder = TimestepEmbedder(
|
||||||
|
hidden_size_x, bias=timestep_mlp_bias, timestep_scale=timestep_scale, dtype=dtype, device=device, operations=operations
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.use_t5:
|
||||||
|
# Caption Pooling (T5)
|
||||||
|
self.t5_y_embedder = AttentionPool(
|
||||||
|
t5_feat_dim, num_heads=8, output_dim=hidden_size_x, dtype=dtype, device=device, operations=operations
|
||||||
|
)
|
||||||
|
|
||||||
|
# Dense Embedding Projection (T5)
|
||||||
|
self.t5_yproj = operations.Linear(
|
||||||
|
t5_feat_dim, hidden_size_y, bias=True, dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize pos_frequencies as an empty parameter.
|
||||||
|
self.pos_frequencies = nn.Parameter(
|
||||||
|
torch.empty(3, self.num_heads, self.head_dim // 2, dtype=dtype, device=device)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert not self.attend_to_padding
|
||||||
|
|
||||||
|
# for depth 48:
|
||||||
|
# b = 0: AsymmetricJointBlock, update_y=True
|
||||||
|
# b = 1: AsymmetricJointBlock, update_y=True
|
||||||
|
# ...
|
||||||
|
# b = 46: AsymmetricJointBlock, update_y=True
|
||||||
|
# b = 47: AsymmetricJointBlock, update_y=False. No need to update text features.
|
||||||
|
blocks = []
|
||||||
|
for b in range(depth):
|
||||||
|
# Joint multi-modal block
|
||||||
|
update_y = b < depth - 1
|
||||||
|
block = AsymmetricJointBlock(
|
||||||
|
hidden_size_x,
|
||||||
|
hidden_size_y,
|
||||||
|
num_heads,
|
||||||
|
mlp_ratio_x=mlp_ratio_x,
|
||||||
|
mlp_ratio_y=mlp_ratio_y,
|
||||||
|
update_y=update_y,
|
||||||
|
attend_to_padding=attend_to_padding,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
operations=operations,
|
||||||
|
**block_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
blocks.append(block)
|
||||||
|
self.blocks = nn.ModuleList(blocks)
|
||||||
|
|
||||||
|
self.final_layer = FinalLayer(
|
||||||
|
hidden_size_x, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations
|
||||||
|
)
|
||||||
|
|
||||||
|
def embed_x(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x: (B, C=12, T, H, W) tensor of visual tokens
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
x: (B, C=3072, N) tensor of visual tokens with positional embedding.
|
||||||
|
"""
|
||||||
|
return self.x_embedder(x) # Convert BcTHW to BCN
|
||||||
|
|
||||||
|
def prepare(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
sigma: torch.Tensor,
|
||||||
|
t5_feat: torch.Tensor,
|
||||||
|
t5_mask: torch.Tensor,
|
||||||
|
):
|
||||||
|
"""Prepare input and conditioning embeddings."""
|
||||||
|
# Visual patch embeddings with positional encoding.
|
||||||
|
T, H, W = x.shape[-3:]
|
||||||
|
pH, pW = H // self.patch_size, W // self.patch_size
|
||||||
|
x = self.embed_x(x) # (B, N, D), where N = T * H * W / patch_size ** 2
|
||||||
|
assert x.ndim == 3
|
||||||
|
B = x.size(0)
|
||||||
|
|
||||||
|
|
||||||
|
pH, pW = H // self.patch_size, W // self.patch_size
|
||||||
|
N = T * pH * pW
|
||||||
|
assert x.size(1) == N
|
||||||
|
pos = create_position_matrix(
|
||||||
|
T, pH=pH, pW=pW, device=x.device, dtype=torch.float32
|
||||||
|
) # (N, 3)
|
||||||
|
rope_cos, rope_sin = compute_mixed_rotation(
|
||||||
|
freqs=comfy.ops.cast_to(self.pos_frequencies, dtype=x.dtype, device=x.device), pos=pos
|
||||||
|
) # Each are (N, num_heads, dim // 2)
|
||||||
|
|
||||||
|
c_t = self.t_embedder(1 - sigma, out_dtype=x.dtype) # (B, D)
|
||||||
|
|
||||||
|
t5_y_pool = self.t5_y_embedder(t5_feat, t5_mask) # (B, D)
|
||||||
|
|
||||||
|
c = c_t + t5_y_pool
|
||||||
|
|
||||||
|
y_feat = self.t5_yproj(t5_feat) # (B, L, t5_feat_dim) --> (B, L, D)
|
||||||
|
|
||||||
|
return x, c, y_feat, rope_cos, rope_sin
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
timestep: torch.Tensor,
|
||||||
|
context: List[torch.Tensor],
|
||||||
|
attention_mask: List[torch.Tensor],
|
||||||
|
num_tokens=256,
|
||||||
|
packed_indices: Dict[str, torch.Tensor] = None,
|
||||||
|
rope_cos: torch.Tensor = None,
|
||||||
|
rope_sin: torch.Tensor = None,
|
||||||
|
control=None, **kwargs
|
||||||
|
):
|
||||||
|
y_feat = context
|
||||||
|
y_mask = attention_mask
|
||||||
|
sigma = timestep
|
||||||
|
"""Forward pass of DiT.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: (B, C, T, H, W) tensor of spatial inputs (images or latent representations of images)
|
||||||
|
sigma: (B,) tensor of noise standard deviations
|
||||||
|
y_feat: List((B, L, y_feat_dim) tensor of caption token features. For SDXL text encoders: L=77, y_feat_dim=2048)
|
||||||
|
y_mask: List((B, L) boolean tensor indicating which tokens are not padding)
|
||||||
|
packed_indices: Dict with keys for Flash Attention. Result of compute_packed_indices.
|
||||||
|
"""
|
||||||
|
B, _, T, H, W = x.shape
|
||||||
|
|
||||||
|
x, c, y_feat, rope_cos, rope_sin = self.prepare(
|
||||||
|
x, sigma, y_feat, y_mask
|
||||||
|
)
|
||||||
|
del y_mask
|
||||||
|
|
||||||
|
for i, block in enumerate(self.blocks):
|
||||||
|
x, y_feat = block(
|
||||||
|
x,
|
||||||
|
c,
|
||||||
|
y_feat,
|
||||||
|
rope_cos=rope_cos,
|
||||||
|
rope_sin=rope_sin,
|
||||||
|
crop_y=num_tokens,
|
||||||
|
) # (B, M, D), (B, L, D)
|
||||||
|
del y_feat # Final layers don't use dense text features.
|
||||||
|
|
||||||
|
x = self.final_layer(x, c) # (B, M, patch_size ** 2 * out_channels)
|
||||||
|
x = rearrange(
|
||||||
|
x,
|
||||||
|
"B (T hp wp) (p1 p2 c) -> B c T (hp p1) (wp p2)",
|
||||||
|
T=T,
|
||||||
|
hp=H // self.patch_size,
|
||||||
|
wp=W // self.patch_size,
|
||||||
|
p1=self.patch_size,
|
||||||
|
p2=self.patch_size,
|
||||||
|
c=self.out_channels,
|
||||||
|
)
|
||||||
|
|
||||||
|
return -x
|
||||||
164
comfy/ldm/genmo/joint_model/layers.py
Normal file
164
comfy/ldm/genmo/joint_model/layers.py
Normal file
@@ -0,0 +1,164 @@
|
|||||||
|
#original code from https://github.com/genmoai/models under apache 2.0 license
|
||||||
|
#adapted to ComfyUI
|
||||||
|
|
||||||
|
import collections.abc
|
||||||
|
import math
|
||||||
|
from itertools import repeat
|
||||||
|
from typing import Callable, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from einops import rearrange
|
||||||
|
import comfy.ldm.common_dit
|
||||||
|
|
||||||
|
|
||||||
|
# From PyTorch internals
|
||||||
|
def _ntuple(n):
|
||||||
|
def parse(x):
|
||||||
|
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
|
||||||
|
return tuple(x)
|
||||||
|
return tuple(repeat(x, n))
|
||||||
|
|
||||||
|
return parse
|
||||||
|
|
||||||
|
|
||||||
|
to_2tuple = _ntuple(2)
|
||||||
|
|
||||||
|
|
||||||
|
class TimestepEmbedder(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int,
|
||||||
|
frequency_embedding_size: int = 256,
|
||||||
|
*,
|
||||||
|
bias: bool = True,
|
||||||
|
timestep_scale: Optional[float] = None,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.mlp = nn.Sequential(
|
||||||
|
operations.Linear(frequency_embedding_size, hidden_size, bias=bias, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Linear(hidden_size, hidden_size, bias=bias, dtype=dtype, device=device),
|
||||||
|
)
|
||||||
|
self.frequency_embedding_size = frequency_embedding_size
|
||||||
|
self.timestep_scale = timestep_scale
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def timestep_embedding(t, dim, max_period=10000):
|
||||||
|
half = dim // 2
|
||||||
|
freqs = torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
|
||||||
|
freqs.mul_(-math.log(max_period) / half).exp_()
|
||||||
|
args = t[:, None].float() * freqs[None]
|
||||||
|
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||||
|
if dim % 2:
|
||||||
|
embedding = torch.cat(
|
||||||
|
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
|
||||||
|
)
|
||||||
|
return embedding
|
||||||
|
|
||||||
|
def forward(self, t, out_dtype):
|
||||||
|
if self.timestep_scale is not None:
|
||||||
|
t = t * self.timestep_scale
|
||||||
|
t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype=out_dtype)
|
||||||
|
t_emb = self.mlp(t_freq)
|
||||||
|
return t_emb
|
||||||
|
|
||||||
|
|
||||||
|
class FeedForward(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_features: int,
|
||||||
|
hidden_size: int,
|
||||||
|
multiple_of: int,
|
||||||
|
ffn_dim_multiplier: Optional[float],
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
dtype=None,
|
||||||
|
operations=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
# keep parameter count and computation constant compared to standard FFN
|
||||||
|
hidden_size = int(2 * hidden_size / 3)
|
||||||
|
# custom dim factor multiplier
|
||||||
|
if ffn_dim_multiplier is not None:
|
||||||
|
hidden_size = int(ffn_dim_multiplier * hidden_size)
|
||||||
|
hidden_size = multiple_of * ((hidden_size + multiple_of - 1) // multiple_of)
|
||||||
|
|
||||||
|
self.hidden_dim = hidden_size
|
||||||
|
self.w1 = operations.Linear(in_features, 2 * hidden_size, bias=False, device=device, dtype=dtype)
|
||||||
|
self.w2 = operations.Linear(hidden_size, in_features, bias=False, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x, gate = self.w1(x).chunk(2, dim=-1)
|
||||||
|
x = self.w2(F.silu(x) * gate)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class PatchEmbed(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
patch_size: int = 16,
|
||||||
|
in_chans: int = 3,
|
||||||
|
embed_dim: int = 768,
|
||||||
|
norm_layer: Optional[Callable] = None,
|
||||||
|
flatten: bool = True,
|
||||||
|
bias: bool = True,
|
||||||
|
dynamic_img_pad: bool = False,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.patch_size = to_2tuple(patch_size)
|
||||||
|
self.flatten = flatten
|
||||||
|
self.dynamic_img_pad = dynamic_img_pad
|
||||||
|
|
||||||
|
self.proj = operations.Conv2d(
|
||||||
|
in_chans,
|
||||||
|
embed_dim,
|
||||||
|
kernel_size=patch_size,
|
||||||
|
stride=patch_size,
|
||||||
|
bias=bias,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
assert norm_layer is None
|
||||||
|
self.norm = (
|
||||||
|
norm_layer(embed_dim, device=device) if norm_layer else nn.Identity()
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
B, _C, T, H, W = x.shape
|
||||||
|
if not self.dynamic_img_pad:
|
||||||
|
assert H % self.patch_size[0] == 0, f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})."
|
||||||
|
assert W % self.patch_size[1] == 0, f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
|
||||||
|
else:
|
||||||
|
pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
|
||||||
|
pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
|
||||||
|
x = F.pad(x, (0, pad_w, 0, pad_h))
|
||||||
|
|
||||||
|
x = rearrange(x, "B C T H W -> (B T) C H W", B=B, T=T)
|
||||||
|
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size, padding_mode='circular')
|
||||||
|
x = self.proj(x)
|
||||||
|
|
||||||
|
# Flatten temporal and spatial dimensions.
|
||||||
|
if not self.flatten:
|
||||||
|
raise NotImplementedError("Must flatten output.")
|
||||||
|
x = rearrange(x, "(B T) C H W -> B (T H W) C", B=B, T=T)
|
||||||
|
|
||||||
|
x = self.norm(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class RMSNorm(torch.nn.Module):
|
||||||
|
def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
|
||||||
|
super().__init__()
|
||||||
|
self.eps = eps
|
||||||
|
self.weight = torch.nn.Parameter(torch.empty(hidden_size, device=device, dtype=dtype))
|
||||||
|
self.register_parameter("bias", None)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)
|
||||||
88
comfy/ldm/genmo/joint_model/rope_mixed.py
Normal file
88
comfy/ldm/genmo/joint_model/rope_mixed.py
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
#original code from https://github.com/genmoai/models under apache 2.0 license
|
||||||
|
|
||||||
|
# import functools
|
||||||
|
import math
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def centers(start: float, stop, num, dtype=None, device=None):
|
||||||
|
"""linspace through bin centers.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
start (float): Start of the range.
|
||||||
|
stop (float): End of the range.
|
||||||
|
num (int): Number of points.
|
||||||
|
dtype (torch.dtype): Data type of the points.
|
||||||
|
device (torch.device): Device of the points.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
centers (Tensor): Centers of the bins. Shape: (num,).
|
||||||
|
"""
|
||||||
|
edges = torch.linspace(start, stop, num + 1, dtype=dtype, device=device)
|
||||||
|
return (edges[:-1] + edges[1:]) / 2
|
||||||
|
|
||||||
|
|
||||||
|
# @functools.lru_cache(maxsize=1)
|
||||||
|
def create_position_matrix(
|
||||||
|
T: int,
|
||||||
|
pH: int,
|
||||||
|
pW: int,
|
||||||
|
device: torch.device,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
*,
|
||||||
|
target_area: float = 36864,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
T: int - Temporal dimension
|
||||||
|
pH: int - Height dimension after patchify
|
||||||
|
pW: int - Width dimension after patchify
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
pos: [T * pH * pW, 3] - position matrix
|
||||||
|
"""
|
||||||
|
# Create 1D tensors for each dimension
|
||||||
|
t = torch.arange(T, dtype=dtype)
|
||||||
|
|
||||||
|
# Positionally interpolate to area 36864.
|
||||||
|
# (3072x3072 frame with 16x16 patches = 192x192 latents).
|
||||||
|
# This automatically scales rope positions when the resolution changes.
|
||||||
|
# We use a large target area so the model is more sensitive
|
||||||
|
# to changes in the learned pos_frequencies matrix.
|
||||||
|
scale = math.sqrt(target_area / (pW * pH))
|
||||||
|
w = centers(-pW * scale / 2, pW * scale / 2, pW)
|
||||||
|
h = centers(-pH * scale / 2, pH * scale / 2, pH)
|
||||||
|
|
||||||
|
# Use meshgrid to create 3D grids
|
||||||
|
grid_t, grid_h, grid_w = torch.meshgrid(t, h, w, indexing="ij")
|
||||||
|
|
||||||
|
# Stack and reshape the grids.
|
||||||
|
pos = torch.stack([grid_t, grid_h, grid_w], dim=-1) # [T, pH, pW, 3]
|
||||||
|
pos = pos.view(-1, 3) # [T * pH * pW, 3]
|
||||||
|
pos = pos.to(dtype=dtype, device=device)
|
||||||
|
|
||||||
|
return pos
|
||||||
|
|
||||||
|
|
||||||
|
def compute_mixed_rotation(
|
||||||
|
freqs: torch.Tensor,
|
||||||
|
pos: torch.Tensor,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Project each 3-dim position into per-head, per-head-dim 1D frequencies.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
freqs: [3, num_heads, num_freqs] - learned rotation frequency (for t, row, col) for each head position
|
||||||
|
pos: [N, 3] - position of each token
|
||||||
|
num_heads: int
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
freqs_cos: [N, num_heads, num_freqs] - cosine components
|
||||||
|
freqs_sin: [N, num_heads, num_freqs] - sine components
|
||||||
|
"""
|
||||||
|
assert freqs.ndim == 3
|
||||||
|
freqs_sum = torch.einsum("Nd,dhf->Nhf", pos.to(freqs), freqs)
|
||||||
|
freqs_cos = torch.cos(freqs_sum)
|
||||||
|
freqs_sin = torch.sin(freqs_sum)
|
||||||
|
return freqs_cos, freqs_sin
|
||||||
34
comfy/ldm/genmo/joint_model/temporal_rope.py
Normal file
34
comfy/ldm/genmo/joint_model/temporal_rope.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
#original code from https://github.com/genmoai/models under apache 2.0 license
|
||||||
|
|
||||||
|
# Based on Llama3 Implementation.
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def apply_rotary_emb_qk_real(
|
||||||
|
xqk: torch.Tensor,
|
||||||
|
freqs_cos: torch.Tensor,
|
||||||
|
freqs_sin: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Apply rotary embeddings to input tensors using the given frequency tensor without complex numbers.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
xqk (torch.Tensor): Query and/or Key tensors to apply rotary embeddings. Shape: (B, S, *, num_heads, D)
|
||||||
|
Can be either just query or just key, or both stacked along some batch or * dim.
|
||||||
|
freqs_cos (torch.Tensor): Precomputed cosine frequency tensor.
|
||||||
|
freqs_sin (torch.Tensor): Precomputed sine frequency tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: The input tensor with rotary embeddings applied.
|
||||||
|
"""
|
||||||
|
# Split the last dimension into even and odd parts
|
||||||
|
xqk_even = xqk[..., 0::2]
|
||||||
|
xqk_odd = xqk[..., 1::2]
|
||||||
|
|
||||||
|
# Apply rotation
|
||||||
|
cos_part = (xqk_even * freqs_cos - xqk_odd * freqs_sin).type_as(xqk)
|
||||||
|
sin_part = (xqk_even * freqs_sin + xqk_odd * freqs_cos).type_as(xqk)
|
||||||
|
|
||||||
|
# Interleave the results back into the original shape
|
||||||
|
out = torch.stack([cos_part, sin_part], dim=-1).flatten(-2)
|
||||||
|
return out
|
||||||
102
comfy/ldm/genmo/joint_model/utils.py
Normal file
102
comfy/ldm/genmo/joint_model/utils.py
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
#original code from https://github.com/genmoai/models under apache 2.0 license
|
||||||
|
#adapted to ComfyUI
|
||||||
|
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
def modulate(x, shift, scale):
|
||||||
|
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||||
|
|
||||||
|
|
||||||
|
def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Pool tokens in x using mask.
|
||||||
|
|
||||||
|
NOTE: We assume x does not require gradients.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: (B, L, D) tensor of tokens.
|
||||||
|
mask: (B, L) boolean tensor indicating which tokens are not padding.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
pooled: (B, D) tensor of pooled tokens.
|
||||||
|
"""
|
||||||
|
assert x.size(1) == mask.size(1) # Expected mask to have same length as tokens.
|
||||||
|
assert x.size(0) == mask.size(0) # Expected mask to have same batch size as tokens.
|
||||||
|
mask = mask[:, :, None].to(dtype=x.dtype)
|
||||||
|
mask = mask / mask.sum(dim=1, keepdim=True).clamp(min=1)
|
||||||
|
pooled = (x * mask).sum(dim=1, keepdim=keepdim)
|
||||||
|
return pooled
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionPool(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embed_dim: int,
|
||||||
|
num_heads: int,
|
||||||
|
output_dim: int = None,
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
dtype=None,
|
||||||
|
operations=None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
spatial_dim (int): Number of tokens in sequence length.
|
||||||
|
embed_dim (int): Dimensionality of input tokens.
|
||||||
|
num_heads (int): Number of attention heads.
|
||||||
|
output_dim (int): Dimensionality of output tokens. Defaults to embed_dim.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.to_kv = operations.Linear(embed_dim, 2 * embed_dim, device=device, dtype=dtype)
|
||||||
|
self.to_q = operations.Linear(embed_dim, embed_dim, device=device, dtype=dtype)
|
||||||
|
self.to_out = operations.Linear(embed_dim, output_dim or embed_dim, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
def forward(self, x, mask):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x (torch.Tensor): (B, L, D) tensor of input tokens.
|
||||||
|
mask (torch.Tensor): (B, L) boolean tensor indicating which tokens are not padding.
|
||||||
|
|
||||||
|
NOTE: We assume x does not require gradients.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
x (torch.Tensor): (B, D) tensor of pooled tokens.
|
||||||
|
"""
|
||||||
|
D = x.size(2)
|
||||||
|
|
||||||
|
# Construct attention mask, shape: (B, 1, num_queries=1, num_keys=1+L).
|
||||||
|
attn_mask = mask[:, None, None, :].bool() # (B, 1, 1, L).
|
||||||
|
attn_mask = F.pad(attn_mask, (1, 0), value=True) # (B, 1, 1, 1+L).
|
||||||
|
|
||||||
|
# Average non-padding token features. These will be used as the query.
|
||||||
|
x_pool = pool_tokens(x, mask, keepdim=True) # (B, 1, D)
|
||||||
|
|
||||||
|
# Concat pooled features to input sequence.
|
||||||
|
x = torch.cat([x_pool, x], dim=1) # (B, L+1, D)
|
||||||
|
|
||||||
|
# Compute queries, keys, values. Only the mean token is used to create a query.
|
||||||
|
kv = self.to_kv(x) # (B, L+1, 2 * D)
|
||||||
|
q = self.to_q(x[:, 0]) # (B, D)
|
||||||
|
|
||||||
|
# Extract heads.
|
||||||
|
head_dim = D // self.num_heads
|
||||||
|
kv = kv.unflatten(2, (2, self.num_heads, head_dim)) # (B, 1+L, 2, H, head_dim)
|
||||||
|
kv = kv.transpose(1, 3) # (B, H, 2, 1+L, head_dim)
|
||||||
|
k, v = kv.unbind(2) # (B, H, 1+L, head_dim)
|
||||||
|
q = q.unflatten(1, (self.num_heads, head_dim)) # (B, H, head_dim)
|
||||||
|
q = q.unsqueeze(2) # (B, H, 1, head_dim)
|
||||||
|
|
||||||
|
# Compute attention.
|
||||||
|
x = F.scaled_dot_product_attention(
|
||||||
|
q, k, v, attn_mask=attn_mask, dropout_p=0.0
|
||||||
|
) # (B, H, 1, head_dim)
|
||||||
|
|
||||||
|
# Concatenate heads and run output.
|
||||||
|
x = x.squeeze(2).flatten(1, 2) # (B, D = H * head_dim)
|
||||||
|
x = self.to_out(x)
|
||||||
|
return x
|
||||||
711
comfy/ldm/genmo/vae/model.py
Normal file
711
comfy/ldm/genmo/vae/model.py
Normal file
@@ -0,0 +1,711 @@
|
|||||||
|
#original code from https://github.com/genmoai/models under apache 2.0 license
|
||||||
|
#adapted to ComfyUI
|
||||||
|
|
||||||
|
from typing import Callable, List, Optional, Tuple, Union
|
||||||
|
from functools import partial
|
||||||
|
import math
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
from comfy.ldm.modules.attention import optimized_attention
|
||||||
|
|
||||||
|
import comfy.ops
|
||||||
|
ops = comfy.ops.disable_weight_init
|
||||||
|
|
||||||
|
# import mochi_preview.dit.joint_model.context_parallel as cp
|
||||||
|
# from mochi_preview.vae.cp_conv import cp_pass_frames, gather_all_frames
|
||||||
|
|
||||||
|
|
||||||
|
def cast_tuple(t, length=1):
|
||||||
|
return t if isinstance(t, tuple) else ((t,) * length)
|
||||||
|
|
||||||
|
|
||||||
|
class GroupNormSpatial(ops.GroupNorm):
|
||||||
|
"""
|
||||||
|
GroupNorm applied per-frame.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, *, chunk_size: int = 8):
|
||||||
|
B, C, T, H, W = x.shape
|
||||||
|
x = rearrange(x, "B C T H W -> (B T) C H W")
|
||||||
|
# Run group norm in chunks.
|
||||||
|
output = torch.empty_like(x)
|
||||||
|
for b in range(0, B * T, chunk_size):
|
||||||
|
output[b : b + chunk_size] = super().forward(x[b : b + chunk_size])
|
||||||
|
return rearrange(output, "(B T) C H W -> B C T H W", B=B, T=T)
|
||||||
|
|
||||||
|
class PConv3d(ops.Conv3d):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size: Union[int, Tuple[int, int, int]],
|
||||||
|
stride: Union[int, Tuple[int, int, int]],
|
||||||
|
causal: bool = True,
|
||||||
|
context_parallel: bool = True,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.causal = causal
|
||||||
|
self.context_parallel = context_parallel
|
||||||
|
kernel_size = cast_tuple(kernel_size, 3)
|
||||||
|
stride = cast_tuple(stride, 3)
|
||||||
|
height_pad = (kernel_size[1] - 1) // 2
|
||||||
|
width_pad = (kernel_size[2] - 1) // 2
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
dilation=(1, 1, 1),
|
||||||
|
padding=(0, height_pad, width_pad),
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
# Compute padding amounts.
|
||||||
|
context_size = self.kernel_size[0] - 1
|
||||||
|
if self.causal:
|
||||||
|
pad_front = context_size
|
||||||
|
pad_back = 0
|
||||||
|
else:
|
||||||
|
pad_front = context_size // 2
|
||||||
|
pad_back = context_size - pad_front
|
||||||
|
|
||||||
|
# Apply padding.
|
||||||
|
assert self.padding_mode == "replicate" # DEBUG
|
||||||
|
mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
|
||||||
|
x = F.pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode)
|
||||||
|
return super().forward(x)
|
||||||
|
|
||||||
|
|
||||||
|
class Conv1x1(ops.Linear):
|
||||||
|
"""*1x1 Conv implemented with a linear layer."""
|
||||||
|
|
||||||
|
def __init__(self, in_features: int, out_features: int, *args, **kwargs):
|
||||||
|
super().__init__(in_features, out_features, *args, **kwargs)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
"""Forward pass.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input tensor. Shape: [B, C, *] or [B, *, C].
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
x: Output tensor. Shape: [B, C', *] or [B, *, C'].
|
||||||
|
"""
|
||||||
|
x = x.movedim(1, -1)
|
||||||
|
x = super().forward(x)
|
||||||
|
x = x.movedim(-1, 1)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class DepthToSpaceTime(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
temporal_expansion: int,
|
||||||
|
spatial_expansion: int,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.temporal_expansion = temporal_expansion
|
||||||
|
self.spatial_expansion = spatial_expansion
|
||||||
|
|
||||||
|
# When printed, this module should show the temporal and spatial expansion factors.
|
||||||
|
def extra_repr(self):
|
||||||
|
return f"texp={self.temporal_expansion}, sexp={self.spatial_expansion}"
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
"""Forward pass.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input tensor. Shape: [B, C, T, H, W].
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
x: Rearranged tensor. Shape: [B, C/(st*s*s), T*st, H*s, W*s].
|
||||||
|
"""
|
||||||
|
x = rearrange(
|
||||||
|
x,
|
||||||
|
"B (C st sh sw) T H W -> B C (T st) (H sh) (W sw)",
|
||||||
|
st=self.temporal_expansion,
|
||||||
|
sh=self.spatial_expansion,
|
||||||
|
sw=self.spatial_expansion,
|
||||||
|
)
|
||||||
|
|
||||||
|
# cp_rank, _ = cp.get_cp_rank_size()
|
||||||
|
if self.temporal_expansion > 1: # and cp_rank == 0:
|
||||||
|
# Drop the first self.temporal_expansion - 1 frames.
|
||||||
|
# This is because we always want the 3x3x3 conv filter to only apply
|
||||||
|
# to the first frame, and the first frame doesn't need to be repeated.
|
||||||
|
assert all(x.shape)
|
||||||
|
x = x[:, :, self.temporal_expansion - 1 :]
|
||||||
|
assert all(x.shape)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def norm_fn(
|
||||||
|
in_channels: int,
|
||||||
|
affine: bool = True,
|
||||||
|
):
|
||||||
|
return GroupNormSpatial(affine=affine, num_groups=32, num_channels=in_channels)
|
||||||
|
|
||||||
|
|
||||||
|
class ResBlock(nn.Module):
|
||||||
|
"""Residual block that preserves the spatial dimensions."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
channels: int,
|
||||||
|
*,
|
||||||
|
affine: bool = True,
|
||||||
|
attn_block: Optional[nn.Module] = None,
|
||||||
|
causal: bool = True,
|
||||||
|
prune_bottleneck: bool = False,
|
||||||
|
padding_mode: str,
|
||||||
|
bias: bool = True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.channels = channels
|
||||||
|
|
||||||
|
assert causal
|
||||||
|
self.stack = nn.Sequential(
|
||||||
|
norm_fn(channels, affine=affine),
|
||||||
|
nn.SiLU(inplace=True),
|
||||||
|
PConv3d(
|
||||||
|
in_channels=channels,
|
||||||
|
out_channels=channels // 2 if prune_bottleneck else channels,
|
||||||
|
kernel_size=(3, 3, 3),
|
||||||
|
stride=(1, 1, 1),
|
||||||
|
padding_mode=padding_mode,
|
||||||
|
bias=bias,
|
||||||
|
causal=causal,
|
||||||
|
),
|
||||||
|
norm_fn(channels, affine=affine),
|
||||||
|
nn.SiLU(inplace=True),
|
||||||
|
PConv3d(
|
||||||
|
in_channels=channels // 2 if prune_bottleneck else channels,
|
||||||
|
out_channels=channels,
|
||||||
|
kernel_size=(3, 3, 3),
|
||||||
|
stride=(1, 1, 1),
|
||||||
|
padding_mode=padding_mode,
|
||||||
|
bias=bias,
|
||||||
|
causal=causal,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.attn_block = attn_block if attn_block else nn.Identity()
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
"""Forward pass.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input tensor. Shape: [B, C, T, H, W].
|
||||||
|
"""
|
||||||
|
residual = x
|
||||||
|
x = self.stack(x)
|
||||||
|
x = x + residual
|
||||||
|
del residual
|
||||||
|
|
||||||
|
return self.attn_block(x)
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
head_dim: int = 32,
|
||||||
|
qkv_bias: bool = False,
|
||||||
|
out_bias: bool = True,
|
||||||
|
qk_norm: bool = True,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.head_dim = head_dim
|
||||||
|
self.num_heads = dim // head_dim
|
||||||
|
self.qk_norm = qk_norm
|
||||||
|
|
||||||
|
self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)
|
||||||
|
self.out = nn.Linear(dim, dim, bias=out_bias)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Compute temporal self-attention.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input tensor. Shape: [B, C, T, H, W].
|
||||||
|
chunk_size: Chunk size for large tensors.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
x: Output tensor. Shape: [B, C, T, H, W].
|
||||||
|
"""
|
||||||
|
B, _, T, H, W = x.shape
|
||||||
|
|
||||||
|
if T == 1:
|
||||||
|
# No attention for single frame.
|
||||||
|
x = x.movedim(1, -1) # [B, C, T, H, W] -> [B, T, H, W, C]
|
||||||
|
qkv = self.qkv(x)
|
||||||
|
_, _, x = qkv.chunk(3, dim=-1) # Throw away queries and keys.
|
||||||
|
x = self.out(x)
|
||||||
|
return x.movedim(-1, 1) # [B, T, H, W, C] -> [B, C, T, H, W]
|
||||||
|
|
||||||
|
# 1D temporal attention.
|
||||||
|
x = rearrange(x, "B C t h w -> (B h w) t C")
|
||||||
|
qkv = self.qkv(x)
|
||||||
|
|
||||||
|
# Input: qkv with shape [B, t, 3 * num_heads * head_dim]
|
||||||
|
# Output: x with shape [B, num_heads, t, head_dim]
|
||||||
|
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, self.head_dim).transpose(1, 3).unbind(2)
|
||||||
|
|
||||||
|
if self.qk_norm:
|
||||||
|
q = F.normalize(q, p=2, dim=-1)
|
||||||
|
k = F.normalize(k, p=2, dim=-1)
|
||||||
|
|
||||||
|
x = optimized_attention(q, k, v, self.num_heads, skip_reshape=True)
|
||||||
|
|
||||||
|
assert x.size(0) == q.size(0)
|
||||||
|
|
||||||
|
x = self.out(x)
|
||||||
|
x = rearrange(x, "(B h w) t C -> B C t h w", B=B, h=H, w=W)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
**attn_kwargs,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.norm = norm_fn(dim)
|
||||||
|
self.attn = Attention(dim, **attn_kwargs)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
return x + self.attn(self.norm(x))
|
||||||
|
|
||||||
|
|
||||||
|
class CausalUpsampleBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
num_res_blocks: int,
|
||||||
|
*,
|
||||||
|
temporal_expansion: int = 2,
|
||||||
|
spatial_expansion: int = 2,
|
||||||
|
**block_kwargs,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
blocks = []
|
||||||
|
for _ in range(num_res_blocks):
|
||||||
|
blocks.append(block_fn(in_channels, **block_kwargs))
|
||||||
|
self.blocks = nn.Sequential(*blocks)
|
||||||
|
|
||||||
|
self.temporal_expansion = temporal_expansion
|
||||||
|
self.spatial_expansion = spatial_expansion
|
||||||
|
|
||||||
|
# Change channels in the final convolution layer.
|
||||||
|
self.proj = Conv1x1(
|
||||||
|
in_channels,
|
||||||
|
out_channels * temporal_expansion * (spatial_expansion**2),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.d2st = DepthToSpaceTime(
|
||||||
|
temporal_expansion=temporal_expansion, spatial_expansion=spatial_expansion
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.blocks(x)
|
||||||
|
x = self.proj(x)
|
||||||
|
x = self.d2st(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def block_fn(channels, *, affine: bool = True, has_attention: bool = False, **block_kwargs):
|
||||||
|
attn_block = AttentionBlock(channels) if has_attention else None
|
||||||
|
return ResBlock(channels, affine=affine, attn_block=attn_block, **block_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class DownsampleBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
num_res_blocks,
|
||||||
|
*,
|
||||||
|
temporal_reduction=2,
|
||||||
|
spatial_reduction=2,
|
||||||
|
**block_kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Downsample block for the VAE encoder.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels: Number of input channels.
|
||||||
|
out_channels: Number of output channels.
|
||||||
|
num_res_blocks: Number of residual blocks.
|
||||||
|
temporal_reduction: Temporal reduction factor.
|
||||||
|
spatial_reduction: Spatial reduction factor.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
layers = []
|
||||||
|
|
||||||
|
# Change the channel count in the strided convolution.
|
||||||
|
# This lets the ResBlock have uniform channel count,
|
||||||
|
# as in ConvNeXt.
|
||||||
|
assert in_channels != out_channels
|
||||||
|
layers.append(
|
||||||
|
PConv3d(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=(temporal_reduction, spatial_reduction, spatial_reduction),
|
||||||
|
stride=(temporal_reduction, spatial_reduction, spatial_reduction),
|
||||||
|
# First layer in each block always uses replicate padding
|
||||||
|
padding_mode="replicate",
|
||||||
|
bias=block_kwargs["bias"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
for _ in range(num_res_blocks):
|
||||||
|
layers.append(block_fn(out_channels, **block_kwargs))
|
||||||
|
|
||||||
|
self.layers = nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.layers(x)
|
||||||
|
|
||||||
|
|
||||||
|
def add_fourier_features(inputs: torch.Tensor, start=6, stop=8, step=1):
|
||||||
|
num_freqs = (stop - start) // step
|
||||||
|
assert inputs.ndim == 5
|
||||||
|
C = inputs.size(1)
|
||||||
|
|
||||||
|
# Create Base 2 Fourier features.
|
||||||
|
freqs = torch.arange(start, stop, step, dtype=inputs.dtype, device=inputs.device)
|
||||||
|
assert num_freqs == len(freqs)
|
||||||
|
w = torch.pow(2.0, freqs) * (2 * torch.pi) # [num_freqs]
|
||||||
|
C = inputs.shape[1]
|
||||||
|
w = w.repeat(C)[None, :, None, None, None] # [1, C * num_freqs, 1, 1, 1]
|
||||||
|
|
||||||
|
# Interleaved repeat of input channels to match w.
|
||||||
|
h = inputs.repeat_interleave(num_freqs, dim=1) # [B, C * num_freqs, T, H, W]
|
||||||
|
# Scale channels by frequency.
|
||||||
|
h = w * h
|
||||||
|
|
||||||
|
return torch.cat(
|
||||||
|
[
|
||||||
|
inputs,
|
||||||
|
torch.sin(h),
|
||||||
|
torch.cos(h),
|
||||||
|
],
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FourierFeatures(nn.Module):
|
||||||
|
def __init__(self, start: int = 6, stop: int = 8, step: int = 1):
|
||||||
|
super().__init__()
|
||||||
|
self.start = start
|
||||||
|
self.stop = stop
|
||||||
|
self.step = step
|
||||||
|
|
||||||
|
def forward(self, inputs):
|
||||||
|
"""Add Fourier features to inputs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs: Input tensor. Shape: [B, C, T, H, W]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
h: Output tensor. Shape: [B, (1 + 2 * num_freqs) * C, T, H, W]
|
||||||
|
"""
|
||||||
|
return add_fourier_features(inputs, self.start, self.stop, self.step)
|
||||||
|
|
||||||
|
|
||||||
|
class Decoder(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
out_channels: int = 3,
|
||||||
|
latent_dim: int,
|
||||||
|
base_channels: int,
|
||||||
|
channel_multipliers: List[int],
|
||||||
|
num_res_blocks: List[int],
|
||||||
|
temporal_expansions: Optional[List[int]] = None,
|
||||||
|
spatial_expansions: Optional[List[int]] = None,
|
||||||
|
has_attention: List[bool],
|
||||||
|
output_norm: bool = True,
|
||||||
|
nonlinearity: str = "silu",
|
||||||
|
output_nonlinearity: str = "silu",
|
||||||
|
causal: bool = True,
|
||||||
|
**block_kwargs,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.input_channels = latent_dim
|
||||||
|
self.base_channels = base_channels
|
||||||
|
self.channel_multipliers = channel_multipliers
|
||||||
|
self.num_res_blocks = num_res_blocks
|
||||||
|
self.output_nonlinearity = output_nonlinearity
|
||||||
|
assert nonlinearity == "silu"
|
||||||
|
assert causal
|
||||||
|
|
||||||
|
ch = [mult * base_channels for mult in channel_multipliers]
|
||||||
|
self.num_up_blocks = len(ch) - 1
|
||||||
|
assert len(num_res_blocks) == self.num_up_blocks + 2
|
||||||
|
|
||||||
|
blocks = []
|
||||||
|
|
||||||
|
first_block = [
|
||||||
|
ops.Conv3d(latent_dim, ch[-1], kernel_size=(1, 1, 1))
|
||||||
|
] # Input layer.
|
||||||
|
# First set of blocks preserve channel count.
|
||||||
|
for _ in range(num_res_blocks[-1]):
|
||||||
|
first_block.append(
|
||||||
|
block_fn(
|
||||||
|
ch[-1],
|
||||||
|
has_attention=has_attention[-1],
|
||||||
|
causal=causal,
|
||||||
|
**block_kwargs,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
blocks.append(nn.Sequential(*first_block))
|
||||||
|
|
||||||
|
assert len(temporal_expansions) == len(spatial_expansions) == self.num_up_blocks
|
||||||
|
assert len(num_res_blocks) == len(has_attention) == self.num_up_blocks + 2
|
||||||
|
|
||||||
|
upsample_block_fn = CausalUpsampleBlock
|
||||||
|
|
||||||
|
for i in range(self.num_up_blocks):
|
||||||
|
block = upsample_block_fn(
|
||||||
|
ch[-i - 1],
|
||||||
|
ch[-i - 2],
|
||||||
|
num_res_blocks=num_res_blocks[-i - 2],
|
||||||
|
has_attention=has_attention[-i - 2],
|
||||||
|
temporal_expansion=temporal_expansions[-i - 1],
|
||||||
|
spatial_expansion=spatial_expansions[-i - 1],
|
||||||
|
causal=causal,
|
||||||
|
**block_kwargs,
|
||||||
|
)
|
||||||
|
blocks.append(block)
|
||||||
|
|
||||||
|
assert not output_norm
|
||||||
|
|
||||||
|
# Last block. Preserve channel count.
|
||||||
|
last_block = []
|
||||||
|
for _ in range(num_res_blocks[0]):
|
||||||
|
last_block.append(
|
||||||
|
block_fn(
|
||||||
|
ch[0], has_attention=has_attention[0], causal=causal, **block_kwargs
|
||||||
|
)
|
||||||
|
)
|
||||||
|
blocks.append(nn.Sequential(*last_block))
|
||||||
|
|
||||||
|
self.blocks = nn.ModuleList(blocks)
|
||||||
|
self.output_proj = Conv1x1(ch[0], out_channels)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""Forward pass.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Latent tensor. Shape: [B, input_channels, t, h, w]. Scaled [-1, 1].
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
x: Reconstructed video tensor. Shape: [B, C, T, H, W]. Scaled to [-1, 1].
|
||||||
|
T + 1 = (t - 1) * 4.
|
||||||
|
H = h * 16, W = w * 16.
|
||||||
|
"""
|
||||||
|
for block in self.blocks:
|
||||||
|
x = block(x)
|
||||||
|
|
||||||
|
if self.output_nonlinearity == "silu":
|
||||||
|
x = F.silu(x, inplace=not self.training)
|
||||||
|
else:
|
||||||
|
assert (
|
||||||
|
not self.output_nonlinearity
|
||||||
|
) # StyleGAN3 omits the to-RGB nonlinearity.
|
||||||
|
|
||||||
|
return self.output_proj(x).contiguous()
|
||||||
|
|
||||||
|
class LatentDistribution:
|
||||||
|
def __init__(self, mean: torch.Tensor, logvar: torch.Tensor):
|
||||||
|
"""Initialize latent distribution.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mean: Mean of the distribution. Shape: [B, C, T, H, W].
|
||||||
|
logvar: Logarithm of variance of the distribution. Shape: [B, C, T, H, W].
|
||||||
|
"""
|
||||||
|
assert mean.shape == logvar.shape
|
||||||
|
self.mean = mean
|
||||||
|
self.logvar = logvar
|
||||||
|
|
||||||
|
def sample(self, temperature=1.0, generator: torch.Generator = None, noise=None):
|
||||||
|
if temperature == 0.0:
|
||||||
|
return self.mean
|
||||||
|
|
||||||
|
if noise is None:
|
||||||
|
noise = torch.randn(self.mean.shape, device=self.mean.device, dtype=self.mean.dtype, generator=generator)
|
||||||
|
else:
|
||||||
|
assert noise.device == self.mean.device
|
||||||
|
noise = noise.to(self.mean.dtype)
|
||||||
|
|
||||||
|
if temperature != 1.0:
|
||||||
|
raise NotImplementedError(f"Temperature {temperature} is not supported.")
|
||||||
|
|
||||||
|
# Just Gaussian sample with no scaling of variance.
|
||||||
|
return noise * torch.exp(self.logvar * 0.5) + self.mean
|
||||||
|
|
||||||
|
def mode(self):
|
||||||
|
return self.mean
|
||||||
|
|
||||||
|
class Encoder(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
in_channels: int,
|
||||||
|
base_channels: int,
|
||||||
|
channel_multipliers: List[int],
|
||||||
|
num_res_blocks: List[int],
|
||||||
|
latent_dim: int,
|
||||||
|
temporal_reductions: List[int],
|
||||||
|
spatial_reductions: List[int],
|
||||||
|
prune_bottlenecks: List[bool],
|
||||||
|
has_attentions: List[bool],
|
||||||
|
affine: bool = True,
|
||||||
|
bias: bool = True,
|
||||||
|
input_is_conv_1x1: bool = False,
|
||||||
|
padding_mode: str,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.temporal_reductions = temporal_reductions
|
||||||
|
self.spatial_reductions = spatial_reductions
|
||||||
|
self.base_channels = base_channels
|
||||||
|
self.channel_multipliers = channel_multipliers
|
||||||
|
self.num_res_blocks = num_res_blocks
|
||||||
|
self.latent_dim = latent_dim
|
||||||
|
|
||||||
|
self.fourier_features = FourierFeatures()
|
||||||
|
ch = [mult * base_channels for mult in channel_multipliers]
|
||||||
|
num_down_blocks = len(ch) - 1
|
||||||
|
assert len(num_res_blocks) == num_down_blocks + 2
|
||||||
|
|
||||||
|
layers = (
|
||||||
|
[ops.Conv3d(in_channels, ch[0], kernel_size=(1, 1, 1), bias=True)]
|
||||||
|
if not input_is_conv_1x1
|
||||||
|
else [Conv1x1(in_channels, ch[0])]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(prune_bottlenecks) == num_down_blocks + 2
|
||||||
|
assert len(has_attentions) == num_down_blocks + 2
|
||||||
|
block = partial(block_fn, padding_mode=padding_mode, affine=affine, bias=bias)
|
||||||
|
|
||||||
|
for _ in range(num_res_blocks[0]):
|
||||||
|
layers.append(block(ch[0], has_attention=has_attentions[0], prune_bottleneck=prune_bottlenecks[0]))
|
||||||
|
prune_bottlenecks = prune_bottlenecks[1:]
|
||||||
|
has_attentions = has_attentions[1:]
|
||||||
|
|
||||||
|
assert len(temporal_reductions) == len(spatial_reductions) == len(ch) - 1
|
||||||
|
for i in range(num_down_blocks):
|
||||||
|
layer = DownsampleBlock(
|
||||||
|
ch[i],
|
||||||
|
ch[i + 1],
|
||||||
|
num_res_blocks=num_res_blocks[i + 1],
|
||||||
|
temporal_reduction=temporal_reductions[i],
|
||||||
|
spatial_reduction=spatial_reductions[i],
|
||||||
|
prune_bottleneck=prune_bottlenecks[i],
|
||||||
|
has_attention=has_attentions[i],
|
||||||
|
affine=affine,
|
||||||
|
bias=bias,
|
||||||
|
padding_mode=padding_mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
layers.append(layer)
|
||||||
|
|
||||||
|
# Additional blocks.
|
||||||
|
for _ in range(num_res_blocks[-1]):
|
||||||
|
layers.append(block(ch[-1], has_attention=has_attentions[-1], prune_bottleneck=prune_bottlenecks[-1]))
|
||||||
|
|
||||||
|
self.layers = nn.Sequential(*layers)
|
||||||
|
|
||||||
|
# Output layers.
|
||||||
|
self.output_norm = norm_fn(ch[-1])
|
||||||
|
self.output_proj = Conv1x1(ch[-1], 2 * latent_dim, bias=False)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def temporal_downsample(self):
|
||||||
|
return math.prod(self.temporal_reductions)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def spatial_downsample(self):
|
||||||
|
return math.prod(self.spatial_reductions)
|
||||||
|
|
||||||
|
def forward(self, x) -> LatentDistribution:
|
||||||
|
"""Forward pass.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input video tensor. Shape: [B, C, T, H, W]. Scaled to [-1, 1]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
means: Latent tensor. Shape: [B, latent_dim, t, h, w]. Scaled [-1, 1].
|
||||||
|
h = H // 8, w = W // 8, t - 1 = (T - 1) // 6
|
||||||
|
logvar: Shape: [B, latent_dim, t, h, w].
|
||||||
|
"""
|
||||||
|
assert x.ndim == 5, f"Expected 5D input, got {x.shape}"
|
||||||
|
x = self.fourier_features(x)
|
||||||
|
|
||||||
|
x = self.layers(x)
|
||||||
|
|
||||||
|
x = self.output_norm(x)
|
||||||
|
x = F.silu(x, inplace=True)
|
||||||
|
x = self.output_proj(x)
|
||||||
|
|
||||||
|
means, logvar = torch.chunk(x, 2, dim=1)
|
||||||
|
|
||||||
|
assert means.ndim == 5
|
||||||
|
assert logvar.shape == means.shape
|
||||||
|
assert means.size(1) == self.latent_dim
|
||||||
|
|
||||||
|
return LatentDistribution(means, logvar)
|
||||||
|
|
||||||
|
|
||||||
|
class VideoVAE(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.encoder = Encoder(
|
||||||
|
in_channels=15,
|
||||||
|
base_channels=64,
|
||||||
|
channel_multipliers=[1, 2, 4, 6],
|
||||||
|
num_res_blocks=[3, 3, 4, 6, 3],
|
||||||
|
latent_dim=12,
|
||||||
|
temporal_reductions=[1, 2, 3],
|
||||||
|
spatial_reductions=[2, 2, 2],
|
||||||
|
prune_bottlenecks=[False, False, False, False, False],
|
||||||
|
has_attentions=[False, True, True, True, True],
|
||||||
|
affine=True,
|
||||||
|
bias=True,
|
||||||
|
input_is_conv_1x1=True,
|
||||||
|
padding_mode="replicate"
|
||||||
|
)
|
||||||
|
self.decoder = Decoder(
|
||||||
|
out_channels=3,
|
||||||
|
base_channels=128,
|
||||||
|
channel_multipliers=[1, 2, 4, 6],
|
||||||
|
temporal_expansions=[1, 2, 3],
|
||||||
|
spatial_expansions=[2, 2, 2],
|
||||||
|
num_res_blocks=[3, 3, 4, 6, 3],
|
||||||
|
latent_dim=12,
|
||||||
|
has_attention=[False, False, False, False, False],
|
||||||
|
padding_mode="replicate",
|
||||||
|
output_norm=False,
|
||||||
|
nonlinearity="silu",
|
||||||
|
output_nonlinearity="silu",
|
||||||
|
causal=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def encode(self, x):
|
||||||
|
return self.encoder(x).mode()
|
||||||
|
|
||||||
|
def decode(self, x):
|
||||||
|
return self.decoder(x)
|
||||||
@@ -393,6 +393,13 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
|
|||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
if model_management.is_nvidia(): #pytorch 2.3 and up seem to have this issue.
|
||||||
|
SDP_BATCH_LIMIT = 2**15
|
||||||
|
else:
|
||||||
|
#TODO: other GPUs ?
|
||||||
|
SDP_BATCH_LIMIT = 2**31
|
||||||
|
|
||||||
|
|
||||||
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
|
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
|
||||||
if skip_reshape:
|
if skip_reshape:
|
||||||
b, _, _, dim_head = q.shape
|
b, _, _, dim_head = q.shape
|
||||||
@@ -404,10 +411,15 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
|||||||
(q, k, v),
|
(q, k, v),
|
||||||
)
|
)
|
||||||
|
|
||||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
if SDP_BATCH_LIMIT >= q.shape[0]:
|
||||||
out = (
|
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||||
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
out = (
|
||||||
)
|
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
out = torch.empty((q.shape[0], q.shape[2], heads * dim_head), dtype=q.dtype, layout=q.layout, device=q.device)
|
||||||
|
for i in range(0, q.shape[0], SDP_BATCH_LIMIT):
|
||||||
|
out[i : i + SDP_BATCH_LIMIT] = torch.nn.functional.scaled_dot_product_attention(q[i : i + SDP_BATCH_LIMIT], k[i : i + SDP_BATCH_LIMIT], v[i : i + SDP_BATCH_LIMIT], attn_mask=mask, dropout_p=0.0, is_causal=False).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional, List
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from .. import attention
|
from ..attention import optimized_attention
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
from .util import timestep_embedding
|
from .util import timestep_embedding
|
||||||
import comfy.ops
|
import comfy.ops
|
||||||
@@ -97,7 +97,7 @@ class PatchEmbed(nn.Module):
|
|||||||
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
B, C, H, W = x.shape
|
# B, C, H, W = x.shape
|
||||||
# if self.img_size is not None:
|
# if self.img_size is not None:
|
||||||
# if self.strict_img_size:
|
# if self.strict_img_size:
|
||||||
# _assert(H == self.img_size[0], f"Input height ({H}) doesn't match model ({self.img_size[0]}).")
|
# _assert(H == self.img_size[0], f"Input height ({H}) doesn't match model ({self.img_size[0]}).")
|
||||||
@@ -266,8 +266,6 @@ def split_qkv(qkv, head_dim):
|
|||||||
qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, -1, head_dim).movedim(2, 0)
|
qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, -1, head_dim).movedim(2, 0)
|
||||||
return qkv[0], qkv[1], qkv[2]
|
return qkv[0], qkv[1], qkv[2]
|
||||||
|
|
||||||
def optimized_attention(qkv, num_heads):
|
|
||||||
return attention.optimized_attention(qkv[0], qkv[1], qkv[2], num_heads)
|
|
||||||
|
|
||||||
class SelfAttention(nn.Module):
|
class SelfAttention(nn.Module):
|
||||||
ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")
|
ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")
|
||||||
@@ -326,9 +324,9 @@ class SelfAttention(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
qkv = self.pre_attention(x)
|
q, k, v = self.pre_attention(x)
|
||||||
x = optimized_attention(
|
x = optimized_attention(
|
||||||
qkv, num_heads=self.num_heads
|
q, k, v, heads=self.num_heads
|
||||||
)
|
)
|
||||||
x = self.post_attention(x)
|
x = self.post_attention(x)
|
||||||
return x
|
return x
|
||||||
@@ -417,6 +415,7 @@ class DismantledBlock(nn.Module):
|
|||||||
scale_mod_only: bool = False,
|
scale_mod_only: bool = False,
|
||||||
swiglu: bool = False,
|
swiglu: bool = False,
|
||||||
qk_norm: Optional[str] = None,
|
qk_norm: Optional[str] = None,
|
||||||
|
x_block_self_attn: bool = False,
|
||||||
dtype=None,
|
dtype=None,
|
||||||
device=None,
|
device=None,
|
||||||
operations=None,
|
operations=None,
|
||||||
@@ -440,6 +439,24 @@ class DismantledBlock(nn.Module):
|
|||||||
device=device,
|
device=device,
|
||||||
operations=operations
|
operations=operations
|
||||||
)
|
)
|
||||||
|
if x_block_self_attn:
|
||||||
|
assert not pre_only
|
||||||
|
assert not scale_mod_only
|
||||||
|
self.x_block_self_attn = True
|
||||||
|
self.attn2 = SelfAttention(
|
||||||
|
dim=hidden_size,
|
||||||
|
num_heads=num_heads,
|
||||||
|
qkv_bias=qkv_bias,
|
||||||
|
attn_mode=attn_mode,
|
||||||
|
pre_only=False,
|
||||||
|
qk_norm=qk_norm,
|
||||||
|
rmsnorm=rmsnorm,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.x_block_self_attn = False
|
||||||
if not pre_only:
|
if not pre_only:
|
||||||
if not rmsnorm:
|
if not rmsnorm:
|
||||||
self.norm2 = operations.LayerNorm(
|
self.norm2 = operations.LayerNorm(
|
||||||
@@ -466,7 +483,11 @@ class DismantledBlock(nn.Module):
|
|||||||
multiple_of=256,
|
multiple_of=256,
|
||||||
)
|
)
|
||||||
self.scale_mod_only = scale_mod_only
|
self.scale_mod_only = scale_mod_only
|
||||||
if not scale_mod_only:
|
if x_block_self_attn:
|
||||||
|
assert not pre_only
|
||||||
|
assert not scale_mod_only
|
||||||
|
n_mods = 9
|
||||||
|
elif not scale_mod_only:
|
||||||
n_mods = 6 if not pre_only else 2
|
n_mods = 6 if not pre_only else 2
|
||||||
else:
|
else:
|
||||||
n_mods = 4 if not pre_only else 1
|
n_mods = 4 if not pre_only else 1
|
||||||
@@ -527,14 +548,64 @@ class DismantledBlock(nn.Module):
|
|||||||
)
|
)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
def pre_attention_x(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
||||||
|
assert self.x_block_self_attn
|
||||||
|
(
|
||||||
|
shift_msa,
|
||||||
|
scale_msa,
|
||||||
|
gate_msa,
|
||||||
|
shift_mlp,
|
||||||
|
scale_mlp,
|
||||||
|
gate_mlp,
|
||||||
|
shift_msa2,
|
||||||
|
scale_msa2,
|
||||||
|
gate_msa2,
|
||||||
|
) = self.adaLN_modulation(c).chunk(9, dim=1)
|
||||||
|
x_norm = self.norm1(x)
|
||||||
|
qkv = self.attn.pre_attention(modulate(x_norm, shift_msa, scale_msa))
|
||||||
|
qkv2 = self.attn2.pre_attention(modulate(x_norm, shift_msa2, scale_msa2))
|
||||||
|
return qkv, qkv2, (
|
||||||
|
x,
|
||||||
|
gate_msa,
|
||||||
|
shift_mlp,
|
||||||
|
scale_mlp,
|
||||||
|
gate_mlp,
|
||||||
|
gate_msa2,
|
||||||
|
)
|
||||||
|
|
||||||
|
def post_attention_x(self, attn, attn2, x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2):
|
||||||
|
assert not self.pre_only
|
||||||
|
attn1 = self.attn.post_attention(attn)
|
||||||
|
attn2 = self.attn2.post_attention(attn2)
|
||||||
|
out1 = gate_msa.unsqueeze(1) * attn1
|
||||||
|
out2 = gate_msa2.unsqueeze(1) * attn2
|
||||||
|
x = x + out1
|
||||||
|
x = x + out2
|
||||||
|
x = x + gate_mlp.unsqueeze(1) * self.mlp(
|
||||||
|
modulate(self.norm2(x), shift_mlp, scale_mlp)
|
||||||
|
)
|
||||||
|
return x
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
||||||
assert not self.pre_only
|
assert not self.pre_only
|
||||||
qkv, intermediates = self.pre_attention(x, c)
|
if self.x_block_self_attn:
|
||||||
attn = optimized_attention(
|
qkv, qkv2, intermediates = self.pre_attention_x(x, c)
|
||||||
qkv,
|
attn, _ = optimized_attention(
|
||||||
num_heads=self.attn.num_heads,
|
qkv[0], qkv[1], qkv[2],
|
||||||
)
|
num_heads=self.attn.num_heads,
|
||||||
return self.post_attention(attn, *intermediates)
|
)
|
||||||
|
attn2, _ = optimized_attention(
|
||||||
|
qkv2[0], qkv2[1], qkv2[2],
|
||||||
|
num_heads=self.attn2.num_heads,
|
||||||
|
)
|
||||||
|
return self.post_attention_x(attn, attn2, *intermediates)
|
||||||
|
else:
|
||||||
|
qkv, intermediates = self.pre_attention(x, c)
|
||||||
|
attn = optimized_attention(
|
||||||
|
qkv[0], qkv[1], qkv[2],
|
||||||
|
heads=self.attn.num_heads,
|
||||||
|
)
|
||||||
|
return self.post_attention(attn, *intermediates)
|
||||||
|
|
||||||
|
|
||||||
def block_mixing(*args, use_checkpoint=True, **kwargs):
|
def block_mixing(*args, use_checkpoint=True, **kwargs):
|
||||||
@@ -549,7 +620,10 @@ def block_mixing(*args, use_checkpoint=True, **kwargs):
|
|||||||
def _block_mixing(context, x, context_block, x_block, c):
|
def _block_mixing(context, x, context_block, x_block, c):
|
||||||
context_qkv, context_intermediates = context_block.pre_attention(context, c)
|
context_qkv, context_intermediates = context_block.pre_attention(context, c)
|
||||||
|
|
||||||
x_qkv, x_intermediates = x_block.pre_attention(x, c)
|
if x_block.x_block_self_attn:
|
||||||
|
x_qkv, x_qkv2, x_intermediates = x_block.pre_attention_x(x, c)
|
||||||
|
else:
|
||||||
|
x_qkv, x_intermediates = x_block.pre_attention(x, c)
|
||||||
|
|
||||||
o = []
|
o = []
|
||||||
for t in range(3):
|
for t in range(3):
|
||||||
@@ -557,8 +631,8 @@ def _block_mixing(context, x, context_block, x_block, c):
|
|||||||
qkv = tuple(o)
|
qkv = tuple(o)
|
||||||
|
|
||||||
attn = optimized_attention(
|
attn = optimized_attention(
|
||||||
qkv,
|
qkv[0], qkv[1], qkv[2],
|
||||||
num_heads=x_block.attn.num_heads,
|
heads=x_block.attn.num_heads,
|
||||||
)
|
)
|
||||||
context_attn, x_attn = (
|
context_attn, x_attn = (
|
||||||
attn[:, : context_qkv[0].shape[1]],
|
attn[:, : context_qkv[0].shape[1]],
|
||||||
@@ -570,7 +644,14 @@ def _block_mixing(context, x, context_block, x_block, c):
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
context = None
|
context = None
|
||||||
x = x_block.post_attention(x_attn, *x_intermediates)
|
if x_block.x_block_self_attn:
|
||||||
|
attn2 = optimized_attention(
|
||||||
|
x_qkv2[0], x_qkv2[1], x_qkv2[2],
|
||||||
|
heads=x_block.attn2.num_heads,
|
||||||
|
)
|
||||||
|
x = x_block.post_attention_x(x_attn, attn2, *x_intermediates)
|
||||||
|
else:
|
||||||
|
x = x_block.post_attention(x_attn, *x_intermediates)
|
||||||
return context, x
|
return context, x
|
||||||
|
|
||||||
|
|
||||||
@@ -585,8 +666,13 @@ class JointBlock(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
pre_only = kwargs.pop("pre_only")
|
pre_only = kwargs.pop("pre_only")
|
||||||
qk_norm = kwargs.pop("qk_norm", None)
|
qk_norm = kwargs.pop("qk_norm", None)
|
||||||
|
x_block_self_attn = kwargs.pop("x_block_self_attn", False)
|
||||||
self.context_block = DismantledBlock(*args, pre_only=pre_only, qk_norm=qk_norm, **kwargs)
|
self.context_block = DismantledBlock(*args, pre_only=pre_only, qk_norm=qk_norm, **kwargs)
|
||||||
self.x_block = DismantledBlock(*args, pre_only=False, qk_norm=qk_norm, **kwargs)
|
self.x_block = DismantledBlock(*args,
|
||||||
|
pre_only=False,
|
||||||
|
qk_norm=qk_norm,
|
||||||
|
x_block_self_attn=x_block_self_attn,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
return block_mixing(
|
return block_mixing(
|
||||||
@@ -642,7 +728,7 @@ class SelfAttentionContext(nn.Module):
|
|||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
qkv = self.qkv(x)
|
qkv = self.qkv(x)
|
||||||
q, k, v = split_qkv(qkv, self.dim_head)
|
q, k, v = split_qkv(qkv, self.dim_head)
|
||||||
x = optimized_attention((q.reshape(q.shape[0], q.shape[1], -1), k, v), self.heads)
|
x = optimized_attention(q.reshape(q.shape[0], q.shape[1], -1), k, v, heads=self.heads)
|
||||||
return self.proj(x)
|
return self.proj(x)
|
||||||
|
|
||||||
class ContextProcessorBlock(nn.Module):
|
class ContextProcessorBlock(nn.Module):
|
||||||
@@ -701,9 +787,12 @@ class MMDiT(nn.Module):
|
|||||||
qk_norm: Optional[str] = None,
|
qk_norm: Optional[str] = None,
|
||||||
qkv_bias: bool = True,
|
qkv_bias: bool = True,
|
||||||
context_processor_layers = None,
|
context_processor_layers = None,
|
||||||
|
x_block_self_attn: bool = False,
|
||||||
|
x_block_self_attn_layers: Optional[List[int]] = [],
|
||||||
context_size = 4096,
|
context_size = 4096,
|
||||||
num_blocks = None,
|
num_blocks = None,
|
||||||
final_layer = True,
|
final_layer = True,
|
||||||
|
skip_blocks = False,
|
||||||
dtype = None, #TODO
|
dtype = None, #TODO
|
||||||
device = None,
|
device = None,
|
||||||
operations = None,
|
operations = None,
|
||||||
@@ -718,6 +807,7 @@ class MMDiT(nn.Module):
|
|||||||
self.pos_embed_scaling_factor = pos_embed_scaling_factor
|
self.pos_embed_scaling_factor = pos_embed_scaling_factor
|
||||||
self.pos_embed_offset = pos_embed_offset
|
self.pos_embed_offset = pos_embed_offset
|
||||||
self.pos_embed_max_size = pos_embed_max_size
|
self.pos_embed_max_size = pos_embed_max_size
|
||||||
|
self.x_block_self_attn_layers = x_block_self_attn_layers
|
||||||
|
|
||||||
# hidden_size = default(hidden_size, 64 * depth)
|
# hidden_size = default(hidden_size, 64 * depth)
|
||||||
# num_heads = default(num_heads, hidden_size // 64)
|
# num_heads = default(num_heads, hidden_size // 64)
|
||||||
@@ -775,26 +865,28 @@ class MMDiT(nn.Module):
|
|||||||
self.pos_embed = None
|
self.pos_embed = None
|
||||||
|
|
||||||
self.use_checkpoint = use_checkpoint
|
self.use_checkpoint = use_checkpoint
|
||||||
self.joint_blocks = nn.ModuleList(
|
if not skip_blocks:
|
||||||
[
|
self.joint_blocks = nn.ModuleList(
|
||||||
JointBlock(
|
[
|
||||||
self.hidden_size,
|
JointBlock(
|
||||||
num_heads,
|
self.hidden_size,
|
||||||
mlp_ratio=mlp_ratio,
|
num_heads,
|
||||||
qkv_bias=qkv_bias,
|
mlp_ratio=mlp_ratio,
|
||||||
attn_mode=attn_mode,
|
qkv_bias=qkv_bias,
|
||||||
pre_only=(i == num_blocks - 1) and final_layer,
|
attn_mode=attn_mode,
|
||||||
rmsnorm=rmsnorm,
|
pre_only=(i == num_blocks - 1) and final_layer,
|
||||||
scale_mod_only=scale_mod_only,
|
rmsnorm=rmsnorm,
|
||||||
swiglu=swiglu,
|
scale_mod_only=scale_mod_only,
|
||||||
qk_norm=qk_norm,
|
swiglu=swiglu,
|
||||||
dtype=dtype,
|
qk_norm=qk_norm,
|
||||||
device=device,
|
x_block_self_attn=(i in self.x_block_self_attn_layers) or x_block_self_attn,
|
||||||
operations=operations
|
dtype=dtype,
|
||||||
)
|
device=device,
|
||||||
for i in range(num_blocks)
|
operations=operations,
|
||||||
]
|
)
|
||||||
)
|
for i in range(num_blocks)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
if final_layer:
|
if final_layer:
|
||||||
self.final_layer = FinalLayer(self.hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
|
self.final_layer = FinalLayer(self.hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||||
@@ -857,7 +949,9 @@ class MMDiT(nn.Module):
|
|||||||
c_mod: torch.Tensor,
|
c_mod: torch.Tensor,
|
||||||
context: Optional[torch.Tensor] = None,
|
context: Optional[torch.Tensor] = None,
|
||||||
control = None,
|
control = None,
|
||||||
|
transformer_options = {},
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
if self.register_length > 0:
|
if self.register_length > 0:
|
||||||
context = torch.cat(
|
context = torch.cat(
|
||||||
(
|
(
|
||||||
@@ -869,14 +963,25 @@ class MMDiT(nn.Module):
|
|||||||
|
|
||||||
# context is B, L', D
|
# context is B, L', D
|
||||||
# x is B, L, D
|
# x is B, L, D
|
||||||
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
blocks = len(self.joint_blocks)
|
blocks = len(self.joint_blocks)
|
||||||
for i in range(blocks):
|
for i in range(blocks):
|
||||||
context, x = self.joint_blocks[i](
|
if ("double_block", i) in blocks_replace:
|
||||||
context,
|
def block_wrap(args):
|
||||||
x,
|
out = {}
|
||||||
c=c_mod,
|
out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"])
|
||||||
use_checkpoint=self.use_checkpoint,
|
return out
|
||||||
)
|
|
||||||
|
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": c_mod}, {"original_block": block_wrap})
|
||||||
|
context = out["txt"]
|
||||||
|
x = out["img"]
|
||||||
|
else:
|
||||||
|
context, x = self.joint_blocks[i](
|
||||||
|
context,
|
||||||
|
x,
|
||||||
|
c=c_mod,
|
||||||
|
use_checkpoint=self.use_checkpoint,
|
||||||
|
)
|
||||||
if control is not None:
|
if control is not None:
|
||||||
control_o = control.get("output")
|
control_o = control.get("output")
|
||||||
if i < len(control_o):
|
if i < len(control_o):
|
||||||
@@ -894,6 +999,7 @@ class MMDiT(nn.Module):
|
|||||||
y: Optional[torch.Tensor] = None,
|
y: Optional[torch.Tensor] = None,
|
||||||
context: Optional[torch.Tensor] = None,
|
context: Optional[torch.Tensor] = None,
|
||||||
control = None,
|
control = None,
|
||||||
|
transformer_options = {},
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Forward pass of DiT.
|
Forward pass of DiT.
|
||||||
@@ -915,7 +1021,7 @@ class MMDiT(nn.Module):
|
|||||||
if context is not None:
|
if context is not None:
|
||||||
context = self.context_embedder(context)
|
context = self.context_embedder(context)
|
||||||
|
|
||||||
x = self.forward_core_with_concat(x, c, context, control)
|
x = self.forward_core_with_concat(x, c, context, control, transformer_options)
|
||||||
|
|
||||||
x = self.unpatchify(x, hw=hw) # (N, out_channels, H, W)
|
x = self.unpatchify(x, hw=hw) # (N, out_channels, H, W)
|
||||||
return x[:,:,:hw[-2],:hw[-1]]
|
return x[:,:,:hw[-2],:hw[-1]]
|
||||||
@@ -929,7 +1035,8 @@ class OpenAISignatureMMDITWrapper(MMDiT):
|
|||||||
context: Optional[torch.Tensor] = None,
|
context: Optional[torch.Tensor] = None,
|
||||||
y: Optional[torch.Tensor] = None,
|
y: Optional[torch.Tensor] = None,
|
||||||
control = None,
|
control = None,
|
||||||
|
transformer_options = {},
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
return super().forward(x, timesteps, context=context, y=y, control=control)
|
return super().forward(x, timesteps, context=context, y=y, control=control, transformer_options=transformer_options)
|
||||||
|
|
||||||
|
|||||||
@@ -842,6 +842,11 @@ class UNetModel(nn.Module):
|
|||||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
|
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
|
||||||
emb = self.time_embed(t_emb)
|
emb = self.time_embed(t_emb)
|
||||||
|
|
||||||
|
if "emb_patch" in transformer_patches:
|
||||||
|
patch = transformer_patches["emb_patch"]
|
||||||
|
for p in patch:
|
||||||
|
emb = p(emb, self.model_channels, transformer_options)
|
||||||
|
|
||||||
if self.num_classes is not None:
|
if self.num_classes is not None:
|
||||||
assert y.shape[0] == x.shape[0]
|
assert y.shape[0] == x.shape[0]
|
||||||
emb = emb + self.label_emb(y)
|
emb = emb + self.label_emb(y)
|
||||||
|
|||||||
@@ -201,9 +201,13 @@ def load_lora(lora, to_load):
|
|||||||
|
|
||||||
def model_lora_keys_clip(model, key_map={}):
|
def model_lora_keys_clip(model, key_map={}):
|
||||||
sdk = model.state_dict().keys()
|
sdk = model.state_dict().keys()
|
||||||
|
for k in sdk:
|
||||||
|
if k.endswith(".weight"):
|
||||||
|
key_map["text_encoders.{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
|
||||||
|
|
||||||
text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
|
text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
|
||||||
clip_l_present = False
|
clip_l_present = False
|
||||||
|
clip_g_present = False
|
||||||
for b in range(32): #TODO: clean up
|
for b in range(32): #TODO: clean up
|
||||||
for c in LORA_CLIP_MAP:
|
for c in LORA_CLIP_MAP:
|
||||||
k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
|
k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
|
||||||
@@ -227,6 +231,7 @@ def model_lora_keys_clip(model, key_map={}):
|
|||||||
|
|
||||||
k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
|
k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
|
||||||
if k in sdk:
|
if k in sdk:
|
||||||
|
clip_g_present = True
|
||||||
if clip_l_present:
|
if clip_l_present:
|
||||||
lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
|
lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
|
||||||
key_map[lora_key] = k
|
key_map[lora_key] = k
|
||||||
@@ -242,10 +247,18 @@ def model_lora_keys_clip(model, key_map={}):
|
|||||||
|
|
||||||
for k in sdk:
|
for k in sdk:
|
||||||
if k.endswith(".weight"):
|
if k.endswith(".weight"):
|
||||||
if k.startswith("t5xxl.transformer."):#OneTrainer SD3 lora
|
if k.startswith("t5xxl.transformer."):#OneTrainer SD3 and Flux lora
|
||||||
l_key = k[len("t5xxl.transformer."):-len(".weight")]
|
l_key = k[len("t5xxl.transformer."):-len(".weight")]
|
||||||
lora_key = "lora_te3_{}".format(l_key.replace(".", "_"))
|
t5_index = 1
|
||||||
key_map[lora_key] = k
|
if clip_g_present:
|
||||||
|
t5_index += 1
|
||||||
|
if clip_l_present:
|
||||||
|
t5_index += 1
|
||||||
|
if t5_index == 2:
|
||||||
|
key_map["lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))] = k #OneTrainer Flux
|
||||||
|
t5_index += 1
|
||||||
|
|
||||||
|
key_map["lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))] = k
|
||||||
elif k.startswith("hydit_clip.transformer.bert."): #HunyuanDiT Lora
|
elif k.startswith("hydit_clip.transformer.bert."): #HunyuanDiT Lora
|
||||||
l_key = k[len("hydit_clip.transformer.bert."):-len(".weight")]
|
l_key = k[len("hydit_clip.transformer.bert."):-len(".weight")]
|
||||||
lora_key = "lora_te1_{}".format(l_key.replace(".", "_"))
|
lora_key = "lora_te1_{}".format(l_key.replace(".", "_"))
|
||||||
@@ -281,6 +294,7 @@ def model_lora_keys_unet(model, key_map={}):
|
|||||||
unet_key = "diffusion_model.{}".format(diffusers_keys[k])
|
unet_key = "diffusion_model.{}".format(diffusers_keys[k])
|
||||||
key_lora = k[:-len(".weight")].replace(".", "_")
|
key_lora = k[:-len(".weight")].replace(".", "_")
|
||||||
key_map["lora_unet_{}".format(key_lora)] = unet_key
|
key_map["lora_unet_{}".format(key_lora)] = unet_key
|
||||||
|
key_map["lycoris_{}".format(key_lora)] = unet_key #simpletuner lycoris format
|
||||||
|
|
||||||
diffusers_lora_prefix = ["", "unet."]
|
diffusers_lora_prefix = ["", "unet."]
|
||||||
for p in diffusers_lora_prefix:
|
for p in diffusers_lora_prefix:
|
||||||
@@ -303,6 +317,10 @@ def model_lora_keys_unet(model, key_map={}):
|
|||||||
key_lora = "lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_")) #OneTrainer lora
|
key_lora = "lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_")) #OneTrainer lora
|
||||||
key_map[key_lora] = to
|
key_map[key_lora] = to
|
||||||
|
|
||||||
|
key_lora = "lycoris_{}".format(k[:-len(".weight")].replace(".", "_")) #simpletuner lycoris format
|
||||||
|
key_map[key_lora] = to
|
||||||
|
|
||||||
|
|
||||||
if isinstance(model, comfy.model_base.AuraFlow): #Diffusers lora AuraFlow
|
if isinstance(model, comfy.model_base.AuraFlow): #Diffusers lora AuraFlow
|
||||||
diffusers_keys = comfy.utils.auraflow_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
|
diffusers_keys = comfy.utils.auraflow_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
|
||||||
for k in diffusers_keys:
|
for k in diffusers_keys:
|
||||||
@@ -324,14 +342,15 @@ def model_lora_keys_unet(model, key_map={}):
|
|||||||
to = diffusers_keys[k]
|
to = diffusers_keys[k]
|
||||||
key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format
|
key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format
|
||||||
key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
|
key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
|
||||||
|
key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer
|
||||||
|
|
||||||
return key_map
|
return key_map
|
||||||
|
|
||||||
|
|
||||||
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype):
|
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function):
|
||||||
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
|
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
|
||||||
lora_diff *= alpha
|
lora_diff *= alpha
|
||||||
weight_calc = weight + lora_diff.type(weight.dtype)
|
weight_calc = weight + function(lora_diff).type(weight.dtype)
|
||||||
weight_norm = (
|
weight_norm = (
|
||||||
weight_calc.transpose(0, 1)
|
weight_calc.transpose(0, 1)
|
||||||
.reshape(weight_calc.shape[1], -1)
|
.reshape(weight_calc.shape[1], -1)
|
||||||
@@ -400,7 +419,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
|
|||||||
weight *= strength_model
|
weight *= strength_model
|
||||||
|
|
||||||
if isinstance(v, list):
|
if isinstance(v, list):
|
||||||
v = (calculate_weight(v[1:], v[0].clone(), key, intermediate_dtype=intermediate_dtype), )
|
v = (calculate_weight(v[1:], v[0][1](comfy.model_management.cast_to_device(v[0][0], weight.device, intermediate_dtype, copy=True), inplace=True), key, intermediate_dtype=intermediate_dtype), )
|
||||||
|
|
||||||
if len(v) == 1:
|
if len(v) == 1:
|
||||||
patch_type = "diff"
|
patch_type = "diff"
|
||||||
@@ -438,7 +457,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
|
|||||||
try:
|
try:
|
||||||
lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
|
lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
|
||||||
if dora_scale is not None:
|
if dora_scale is not None:
|
||||||
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
|
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||||
else:
|
else:
|
||||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -484,7 +503,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
|
|||||||
try:
|
try:
|
||||||
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
|
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
|
||||||
if dora_scale is not None:
|
if dora_scale is not None:
|
||||||
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
|
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||||
else:
|
else:
|
||||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -521,28 +540,48 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
|
|||||||
try:
|
try:
|
||||||
lora_diff = (m1 * m2).reshape(weight.shape)
|
lora_diff = (m1 * m2).reshape(weight.shape)
|
||||||
if dora_scale is not None:
|
if dora_scale is not None:
|
||||||
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
|
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||||
else:
|
else:
|
||||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||||
elif patch_type == "glora":
|
elif patch_type == "glora":
|
||||||
if v[4] is not None:
|
|
||||||
alpha = v[4] / v[0].shape[0]
|
|
||||||
else:
|
|
||||||
alpha = 1.0
|
|
||||||
|
|
||||||
dora_scale = v[5]
|
dora_scale = v[5]
|
||||||
|
|
||||||
|
old_glora = False
|
||||||
|
if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]:
|
||||||
|
rank = v[0].shape[0]
|
||||||
|
old_glora = True
|
||||||
|
|
||||||
|
if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]:
|
||||||
|
if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
old_glora = False
|
||||||
|
rank = v[1].shape[0]
|
||||||
|
|
||||||
a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
|
a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||||
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
|
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||||
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
|
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||||
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
|
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||||
|
|
||||||
|
if v[4] is not None:
|
||||||
|
alpha = v[4] / rank
|
||||||
|
else:
|
||||||
|
alpha = 1.0
|
||||||
|
|
||||||
try:
|
try:
|
||||||
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape)
|
if old_glora:
|
||||||
|
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora
|
||||||
|
else:
|
||||||
|
if weight.dim() > 2:
|
||||||
|
lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
|
||||||
|
else:
|
||||||
|
lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
|
||||||
|
lora_diff += torch.mm(b1, b2).reshape(weight.shape)
|
||||||
|
|
||||||
if dora_scale is not None:
|
if dora_scale is not None:
|
||||||
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
|
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||||
else:
|
else:
|
||||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from comfy.ldm.cascade.stage_b import StageB
|
|||||||
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
|
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
|
||||||
from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
|
from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
|
||||||
from comfy.ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
|
from comfy.ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
|
||||||
|
import comfy.ldm.genmo.joint_model.asymm_models_joint
|
||||||
import comfy.ldm.aura.mmdit
|
import comfy.ldm.aura.mmdit
|
||||||
import comfy.ldm.hydit.models
|
import comfy.ldm.hydit.models
|
||||||
import comfy.ldm.audio.dit
|
import comfy.ldm.audio.dit
|
||||||
@@ -96,7 +97,8 @@ class BaseModel(torch.nn.Module):
|
|||||||
|
|
||||||
if not unet_config.get("disable_unet_model_creation", False):
|
if not unet_config.get("disable_unet_model_creation", False):
|
||||||
if model_config.custom_operations is None:
|
if model_config.custom_operations is None:
|
||||||
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype)
|
fp8 = model_config.optimizations.get("fp8", model_config.scaled_fp8 is not None)
|
||||||
|
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
|
||||||
else:
|
else:
|
||||||
operations = model_config.custom_operations
|
operations = model_config.custom_operations
|
||||||
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
|
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
|
||||||
@@ -244,6 +246,10 @@ class BaseModel(torch.nn.Module):
|
|||||||
extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))
|
extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))
|
||||||
|
|
||||||
unet_state_dict = self.diffusion_model.state_dict()
|
unet_state_dict = self.diffusion_model.state_dict()
|
||||||
|
|
||||||
|
if self.model_config.scaled_fp8 is not None:
|
||||||
|
unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8)
|
||||||
|
|
||||||
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
|
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
|
||||||
|
|
||||||
if self.model_type == ModelType.V_PREDICTION:
|
if self.model_type == ModelType.V_PREDICTION:
|
||||||
@@ -713,3 +719,18 @@ class Flux(BaseModel):
|
|||||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||||
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 3.5)]))
|
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 3.5)]))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
class GenmoMochi(BaseModel):
|
||||||
|
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||||
|
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.genmo.joint_model.asymm_models_joint.AsymmDiTJoint)
|
||||||
|
|
||||||
|
def extra_conds(self, **kwargs):
|
||||||
|
out = super().extra_conds(**kwargs)
|
||||||
|
attention_mask = kwargs.get("attention_mask", None)
|
||||||
|
if attention_mask is not None:
|
||||||
|
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||||
|
out['num_tokens'] = comfy.conds.CONDConstant(max(1, torch.sum(attention_mask).item()))
|
||||||
|
cross_attn = kwargs.get("cross_attn", None)
|
||||||
|
if cross_attn is not None:
|
||||||
|
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||||
|
return out
|
||||||
|
|||||||
@@ -70,6 +70,11 @@ def detect_unet_config(state_dict, key_prefix):
|
|||||||
context_processor = '{}context_processor.layers.0.attn.qkv.weight'.format(key_prefix)
|
context_processor = '{}context_processor.layers.0.attn.qkv.weight'.format(key_prefix)
|
||||||
if context_processor in state_dict_keys:
|
if context_processor in state_dict_keys:
|
||||||
unet_config["context_processor_layers"] = count_blocks(state_dict_keys, '{}context_processor.layers.'.format(key_prefix) + '{}.')
|
unet_config["context_processor_layers"] = count_blocks(state_dict_keys, '{}context_processor.layers.'.format(key_prefix) + '{}.')
|
||||||
|
unet_config["x_block_self_attn_layers"] = []
|
||||||
|
for key in state_dict_keys:
|
||||||
|
if key.startswith('{}joint_blocks.'.format(key_prefix)) and key.endswith('.x_block.attn2.qkv.weight'):
|
||||||
|
layer = key[len('{}joint_blocks.'.format(key_prefix)):-len('.x_block.attn2.qkv.weight')]
|
||||||
|
unet_config["x_block_self_attn_layers"].append(int(layer))
|
||||||
return unet_config
|
return unet_config
|
||||||
|
|
||||||
if '{}clf.1.weight'.format(key_prefix) in state_dict_keys: #stable cascade
|
if '{}clf.1.weight'.format(key_prefix) in state_dict_keys: #stable cascade
|
||||||
@@ -145,6 +150,34 @@ def detect_unet_config(state_dict, key_prefix):
|
|||||||
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
|
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
|
||||||
return dit_config
|
return dit_config
|
||||||
|
|
||||||
|
if '{}t5_yproj.weight'.format(key_prefix) in state_dict_keys: #Genmo mochi preview
|
||||||
|
dit_config = {}
|
||||||
|
dit_config["image_model"] = "mochi_preview"
|
||||||
|
dit_config["depth"] = 48
|
||||||
|
dit_config["patch_size"] = 2
|
||||||
|
dit_config["num_heads"] = 24
|
||||||
|
dit_config["hidden_size_x"] = 3072
|
||||||
|
dit_config["hidden_size_y"] = 1536
|
||||||
|
dit_config["mlp_ratio_x"] = 4.0
|
||||||
|
dit_config["mlp_ratio_y"] = 4.0
|
||||||
|
dit_config["learn_sigma"] = False
|
||||||
|
dit_config["in_channels"] = 12
|
||||||
|
dit_config["qk_norm"] = True
|
||||||
|
dit_config["qkv_bias"] = False
|
||||||
|
dit_config["out_bias"] = True
|
||||||
|
dit_config["attn_drop"] = 0.0
|
||||||
|
dit_config["patch_embed_bias"] = True
|
||||||
|
dit_config["posenc_preserve_area"] = True
|
||||||
|
dit_config["timestep_mlp_bias"] = True
|
||||||
|
dit_config["attend_to_padding"] = False
|
||||||
|
dit_config["timestep_scale"] = 1000.0
|
||||||
|
dit_config["use_t5"] = True
|
||||||
|
dit_config["t5_feat_dim"] = 4096
|
||||||
|
dit_config["t5_token_length"] = 256
|
||||||
|
dit_config["rope_theta"] = 10000.0
|
||||||
|
return dit_config
|
||||||
|
|
||||||
|
|
||||||
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
|
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -286,9 +319,15 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
|
|||||||
return None
|
return None
|
||||||
model_config = model_config_from_unet_config(unet_config, state_dict)
|
model_config = model_config_from_unet_config(unet_config, state_dict)
|
||||||
if model_config is None and use_base_if_no_match:
|
if model_config is None and use_base_if_no_match:
|
||||||
return comfy.supported_models_base.BASE(unet_config)
|
model_config = comfy.supported_models_base.BASE(unet_config)
|
||||||
else:
|
|
||||||
return model_config
|
scaled_fp8_weight = state_dict.get("{}scaled_fp8".format(unet_key_prefix), None)
|
||||||
|
if scaled_fp8_weight is not None:
|
||||||
|
model_config.scaled_fp8 = scaled_fp8_weight.dtype
|
||||||
|
if model_config.scaled_fp8 == torch.float32:
|
||||||
|
model_config.scaled_fp8 = torch.float8_e4m3fn
|
||||||
|
|
||||||
|
return model_config
|
||||||
|
|
||||||
def unet_prefix_from_state_dict(state_dict):
|
def unet_prefix_from_state_dict(state_dict):
|
||||||
candidates = ["model.diffusion_model.", #ldm/sgm models
|
candidates = ["model.diffusion_model.", #ldm/sgm models
|
||||||
@@ -501,7 +540,11 @@ def model_config_from_diffusers_unet(state_dict):
|
|||||||
def convert_diffusers_mmdit(state_dict, output_prefix=""):
|
def convert_diffusers_mmdit(state_dict, output_prefix=""):
|
||||||
out_sd = {}
|
out_sd = {}
|
||||||
|
|
||||||
if 'transformer_blocks.0.attn.norm_added_k.weight' in state_dict: #Flux
|
if 'joint_transformer_blocks.0.attn.add_k_proj.weight' in state_dict: #AuraFlow
|
||||||
|
num_joint = count_blocks(state_dict, 'joint_transformer_blocks.{}.')
|
||||||
|
num_single = count_blocks(state_dict, 'single_transformer_blocks.{}.')
|
||||||
|
sd_map = comfy.utils.auraflow_to_diffusers({"n_double_layers": num_joint, "n_layers": num_joint + num_single}, output_prefix=output_prefix)
|
||||||
|
elif 'x_embedder.weight' in state_dict: #Flux
|
||||||
depth = count_blocks(state_dict, 'transformer_blocks.{}.')
|
depth = count_blocks(state_dict, 'transformer_blocks.{}.')
|
||||||
depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.')
|
depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.')
|
||||||
hidden_size = state_dict["x_embedder.bias"].shape[0]
|
hidden_size = state_dict["x_embedder.bias"].shape[0]
|
||||||
@@ -510,10 +553,6 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
|
|||||||
num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
|
num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
|
||||||
depth = state_dict["pos_embed.proj.weight"].shape[0] // 64
|
depth = state_dict["pos_embed.proj.weight"].shape[0] // 64
|
||||||
sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix)
|
sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix)
|
||||||
elif 'joint_transformer_blocks.0.attn.add_k_proj.weight' in state_dict: #AuraFlow
|
|
||||||
num_joint = count_blocks(state_dict, 'joint_transformer_blocks.{}.')
|
|
||||||
num_single = count_blocks(state_dict, 'single_transformer_blocks.{}.')
|
|
||||||
sd_map = comfy.utils.auraflow_to_diffusers({"n_double_layers": num_joint, "n_layers": num_joint + num_single}, output_prefix=output_prefix)
|
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@@ -145,7 +145,7 @@ total_ram = psutil.virtual_memory().total / (1024 * 1024)
|
|||||||
logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
|
logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logging.info("pytorch version: {}".format(torch.version.__version__))
|
logging.info("pytorch version: {}".format(torch_version))
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -326,7 +326,7 @@ class LoadedModel:
|
|||||||
self.model_unload()
|
self.model_unload()
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
if is_intel_xpu() and not args.disable_ipex_optimize and self.real_model is not None:
|
if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and self.real_model is not None:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
self.real_model = ipex.optimize(self.real_model.eval(), inplace=True, graph_mode=True, concat_linear=True)
|
self.real_model = ipex.optimize(self.real_model.eval(), inplace=True, graph_mode=True, concat_linear=True)
|
||||||
|
|
||||||
@@ -426,7 +426,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
|
|||||||
shift_model = current_loaded_models[i]
|
shift_model = current_loaded_models[i]
|
||||||
if shift_model.device == device:
|
if shift_model.device == device:
|
||||||
if shift_model not in keep_loaded:
|
if shift_model not in keep_loaded:
|
||||||
can_unload.append((sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
|
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
|
||||||
shift_model.currently_used = False
|
shift_model.currently_used = False
|
||||||
|
|
||||||
for x in sorted(can_unload):
|
for x in sorted(can_unload):
|
||||||
@@ -626,6 +626,8 @@ def maximum_vram_for_weights(device=None):
|
|||||||
return (get_total_memory(device) * 0.88 - minimum_inference_memory())
|
return (get_total_memory(device) * 0.88 - minimum_inference_memory())
|
||||||
|
|
||||||
def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
|
def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
|
||||||
|
if model_params < 0:
|
||||||
|
model_params = 1000000000000000000000
|
||||||
if args.bf16_unet:
|
if args.bf16_unet:
|
||||||
return torch.bfloat16
|
return torch.bfloat16
|
||||||
if args.fp16_unet:
|
if args.fp16_unet:
|
||||||
@@ -645,6 +647,9 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
if fp8_dtype is not None:
|
if fp8_dtype is not None:
|
||||||
|
if supports_fp8_compute(device): #if fp8 compute is supported the casting is most likely not expensive
|
||||||
|
return fp8_dtype
|
||||||
|
|
||||||
free_model_memory = maximum_vram_for_weights(device)
|
free_model_memory = maximum_vram_for_weights(device)
|
||||||
if model_params * 2 > free_model_memory:
|
if model_params * 2 > free_model_memory:
|
||||||
return fp8_dtype
|
return fp8_dtype
|
||||||
@@ -838,27 +843,21 @@ def force_channels_last():
|
|||||||
#TODO
|
#TODO
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
|
||||||
|
if device is None or weight.device == device:
|
||||||
|
if not copy:
|
||||||
|
if dtype is None or weight.dtype == dtype:
|
||||||
|
return weight
|
||||||
|
return weight.to(dtype=dtype, copy=copy)
|
||||||
|
|
||||||
|
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||||
|
r.copy_(weight, non_blocking=non_blocking)
|
||||||
|
return r
|
||||||
|
|
||||||
def cast_to_device(tensor, device, dtype, copy=False):
|
def cast_to_device(tensor, device, dtype, copy=False):
|
||||||
device_supports_cast = False
|
non_blocking = device_supports_non_blocking(device)
|
||||||
if tensor.dtype == torch.float32 or tensor.dtype == torch.float16:
|
return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy)
|
||||||
device_supports_cast = True
|
|
||||||
elif tensor.dtype == torch.bfloat16:
|
|
||||||
if hasattr(device, 'type') and device.type.startswith("cuda"):
|
|
||||||
device_supports_cast = True
|
|
||||||
elif is_intel_xpu():
|
|
||||||
device_supports_cast = True
|
|
||||||
|
|
||||||
non_blocking = device_should_use_non_blocking(device)
|
|
||||||
|
|
||||||
if device_supports_cast:
|
|
||||||
if copy:
|
|
||||||
if tensor.device == device:
|
|
||||||
return tensor.to(dtype, copy=copy, non_blocking=non_blocking)
|
|
||||||
return tensor.to(device, copy=copy, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
|
|
||||||
else:
|
|
||||||
return tensor.to(device, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
|
|
||||||
else:
|
|
||||||
return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking)
|
|
||||||
|
|
||||||
def xformers_enabled():
|
def xformers_enabled():
|
||||||
global directml_enabled
|
global directml_enabled
|
||||||
@@ -897,7 +896,7 @@ def force_upcast_attention_dtype():
|
|||||||
upcast = args.force_upcast_attention
|
upcast = args.force_upcast_attention
|
||||||
try:
|
try:
|
||||||
macos_version = tuple(int(n) for n in platform.mac_ver()[0].split("."))
|
macos_version = tuple(int(n) for n in platform.mac_ver()[0].split("."))
|
||||||
if (14, 5) <= macos_version < (14, 7): # black image bug on recent versions of MacOS
|
if (14, 5) <= macos_version <= (15, 2): # black image bug on recent versions of macOS
|
||||||
upcast = True
|
upcast = True
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
@@ -1063,6 +1062,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def supports_fp8_compute(device=None):
|
def supports_fp8_compute(device=None):
|
||||||
|
if not is_nvidia():
|
||||||
|
return False
|
||||||
|
|
||||||
props = torch.cuda.get_device_properties(device)
|
props = torch.cuda.get_device_properties(device)
|
||||||
if props.major >= 9:
|
if props.major >= 9:
|
||||||
return True
|
return True
|
||||||
@@ -1070,6 +1072,14 @@ def supports_fp8_compute(device=None):
|
|||||||
return False
|
return False
|
||||||
if props.minor < 9:
|
if props.minor < 9:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
if int(torch_version[0]) < 2 or (int(torch_version[0]) == 2 and int(torch_version[2]) < 3):
|
||||||
|
return False
|
||||||
|
|
||||||
|
if WINDOWS:
|
||||||
|
if (int(torch_version[0]) == 2 and int(torch_version[2]) < 4):
|
||||||
|
return False
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def soft_empty_cache(force=False):
|
def soft_empty_cache(force=False):
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ import comfy.utils
|
|||||||
import comfy.float
|
import comfy.float
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.lora
|
import comfy.lora
|
||||||
from comfy.types import UnetWrapperFunction
|
from comfy.comfy_types import UnetWrapperFunction
|
||||||
|
|
||||||
def string_to_seed(data):
|
def string_to_seed(data):
|
||||||
crc = 0xFFFFFFFF
|
crc = 0xFFFFFFFF
|
||||||
@@ -88,7 +88,36 @@ class LowVramPatch:
|
|||||||
self.key = key
|
self.key = key
|
||||||
self.patches = patches
|
self.patches = patches
|
||||||
def __call__(self, weight):
|
def __call__(self, weight):
|
||||||
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)
|
intermediate_dtype = weight.dtype
|
||||||
|
if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
|
||||||
|
intermediate_dtype = torch.float32
|
||||||
|
return comfy.float.stochastic_rounding(comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype), weight.dtype, seed=string_to_seed(self.key))
|
||||||
|
|
||||||
|
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
|
||||||
|
|
||||||
|
def get_key_weight(model, key):
|
||||||
|
set_func = None
|
||||||
|
convert_func = None
|
||||||
|
op_keys = key.rsplit('.', 1)
|
||||||
|
if len(op_keys) < 2:
|
||||||
|
weight = comfy.utils.get_attr(model, key)
|
||||||
|
else:
|
||||||
|
op = comfy.utils.get_attr(model, op_keys[0])
|
||||||
|
try:
|
||||||
|
set_func = getattr(op, "set_{}".format(op_keys[1]))
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
convert_func = getattr(op, "convert_{}".format(op_keys[1]))
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
weight = getattr(op, op_keys[1])
|
||||||
|
if convert_func is not None:
|
||||||
|
weight = comfy.utils.get_attr(model, key)
|
||||||
|
|
||||||
|
return weight, set_func, convert_func
|
||||||
|
|
||||||
class ModelPatcher:
|
class ModelPatcher:
|
||||||
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
|
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
|
||||||
@@ -283,17 +312,23 @@ class ModelPatcher:
|
|||||||
return list(p)
|
return list(p)
|
||||||
|
|
||||||
def get_key_patches(self, filter_prefix=None):
|
def get_key_patches(self, filter_prefix=None):
|
||||||
comfy.model_management.unload_model_clones(self)
|
|
||||||
model_sd = self.model_state_dict()
|
model_sd = self.model_state_dict()
|
||||||
p = {}
|
p = {}
|
||||||
for k in model_sd:
|
for k in model_sd:
|
||||||
if filter_prefix is not None:
|
if filter_prefix is not None:
|
||||||
if not k.startswith(filter_prefix):
|
if not k.startswith(filter_prefix):
|
||||||
continue
|
continue
|
||||||
|
bk = self.backup.get(k, None)
|
||||||
|
weight, set_func, convert_func = get_key_weight(self.model, k)
|
||||||
|
if bk is not None:
|
||||||
|
weight = bk.weight
|
||||||
|
if convert_func is None:
|
||||||
|
convert_func = lambda a, **kwargs: a
|
||||||
|
|
||||||
if k in self.patches:
|
if k in self.patches:
|
||||||
p[k] = [model_sd[k]] + self.patches[k]
|
p[k] = [(weight, convert_func)] + self.patches[k]
|
||||||
else:
|
else:
|
||||||
p[k] = (model_sd[k],)
|
p[k] = [(weight, convert_func)]
|
||||||
return p
|
return p
|
||||||
|
|
||||||
def model_state_dict(self, filter_prefix=None):
|
def model_state_dict(self, filter_prefix=None):
|
||||||
@@ -309,8 +344,7 @@ class ModelPatcher:
|
|||||||
if key not in self.patches:
|
if key not in self.patches:
|
||||||
return
|
return
|
||||||
|
|
||||||
weight = comfy.utils.get_attr(self.model, key)
|
weight, set_func, convert_func = get_key_weight(self.model, key)
|
||||||
|
|
||||||
inplace_update = self.weight_inplace_update or inplace_update
|
inplace_update = self.weight_inplace_update or inplace_update
|
||||||
|
|
||||||
if key not in self.backup:
|
if key not in self.backup:
|
||||||
@@ -320,12 +354,18 @@ class ModelPatcher:
|
|||||||
temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
|
temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
|
||||||
else:
|
else:
|
||||||
temp_weight = weight.to(torch.float32, copy=True)
|
temp_weight = weight.to(torch.float32, copy=True)
|
||||||
|
if convert_func is not None:
|
||||||
|
temp_weight = convert_func(temp_weight, inplace=True)
|
||||||
|
|
||||||
out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key)
|
out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key)
|
||||||
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
|
if set_func is None:
|
||||||
if inplace_update:
|
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
|
||||||
comfy.utils.copy_to_param(self.model, key, out_weight)
|
if inplace_update:
|
||||||
|
comfy.utils.copy_to_param(self.model, key, out_weight)
|
||||||
|
else:
|
||||||
|
comfy.utils.set_attr_param(self.model, key, out_weight)
|
||||||
else:
|
else:
|
||||||
comfy.utils.set_attr_param(self.model, key, out_weight)
|
set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
|
||||||
|
|
||||||
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
|
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
|
||||||
mem_counter = 0
|
mem_counter = 0
|
||||||
|
|||||||
103
comfy/ops.py
103
comfy/ops.py
@@ -19,20 +19,12 @@
|
|||||||
import torch
|
import torch
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
|
import comfy.float
|
||||||
|
|
||||||
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
|
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
|
||||||
if device is None or weight.device == device:
|
|
||||||
if not copy:
|
|
||||||
if dtype is None or weight.dtype == dtype:
|
|
||||||
return weight
|
|
||||||
return weight.to(dtype=dtype, copy=copy)
|
|
||||||
|
|
||||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
|
||||||
r.copy_(weight, non_blocking=non_blocking)
|
|
||||||
return r
|
|
||||||
|
|
||||||
def cast_to_input(weight, input, non_blocking=False, copy=True):
|
def cast_to_input(weight, input, non_blocking=False, copy=True):
|
||||||
return cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
|
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
|
||||||
|
|
||||||
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
|
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
|
||||||
if input is not None:
|
if input is not None:
|
||||||
@@ -47,12 +39,12 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
|
|||||||
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
||||||
if s.bias is not None:
|
if s.bias is not None:
|
||||||
has_function = s.bias_function is not None
|
has_function = s.bias_function is not None
|
||||||
bias = cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function)
|
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function)
|
||||||
if has_function:
|
if has_function:
|
||||||
bias = s.bias_function(bias)
|
bias = s.bias_function(bias)
|
||||||
|
|
||||||
has_function = s.weight_function is not None
|
has_function = s.weight_function is not None
|
||||||
weight = cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function)
|
weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function)
|
||||||
if has_function:
|
if has_function:
|
||||||
weight = s.weight_function(weight)
|
weight = s.weight_function(weight)
|
||||||
return weight, bias
|
return weight, bias
|
||||||
@@ -258,20 +250,29 @@ def fp8_linear(self, input):
|
|||||||
if dtype not in [torch.float8_e4m3fn]:
|
if dtype not in [torch.float8_e4m3fn]:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
tensor_2d = False
|
||||||
|
if len(input.shape) == 2:
|
||||||
|
tensor_2d = True
|
||||||
|
input = input.unsqueeze(1)
|
||||||
|
|
||||||
|
|
||||||
if len(input.shape) == 3:
|
if len(input.shape) == 3:
|
||||||
inn = input.reshape(-1, input.shape[2]).to(dtype)
|
|
||||||
non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
|
|
||||||
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype)
|
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype)
|
||||||
w = w.t()
|
w = w.t()
|
||||||
|
|
||||||
scale_weight = self.scale_weight
|
scale_weight = self.scale_weight
|
||||||
scale_input = self.scale_input
|
scale_input = self.scale_input
|
||||||
if scale_weight is None:
|
if scale_weight is None:
|
||||||
scale_weight = torch.ones((1), device=input.device, dtype=torch.float32)
|
scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
|
||||||
if scale_input is None:
|
else:
|
||||||
scale_input = scale_weight
|
scale_weight = scale_weight.to(input.device)
|
||||||
|
|
||||||
if scale_input is None:
|
if scale_input is None:
|
||||||
scale_input = torch.ones((1), device=input.device, dtype=torch.float32)
|
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
|
||||||
|
inn = input.reshape(-1, input.shape[2]).to(dtype)
|
||||||
|
else:
|
||||||
|
scale_input = scale_input.to(input.device)
|
||||||
|
inn = (input * (1.0 / scale_input).to(input.dtype)).reshape(-1, input.shape[2]).to(dtype)
|
||||||
|
|
||||||
if bias is not None:
|
if bias is not None:
|
||||||
o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
|
o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
|
||||||
@@ -281,7 +282,11 @@ def fp8_linear(self, input):
|
|||||||
if isinstance(o, tuple):
|
if isinstance(o, tuple):
|
||||||
o = o[0]
|
o = o[0]
|
||||||
|
|
||||||
|
if tensor_2d:
|
||||||
|
return o.reshape(input.shape[0], -1)
|
||||||
|
|
||||||
return o.reshape((-1, input.shape[1], self.weight.shape[0]))
|
return o.reshape((-1, input.shape[1], self.weight.shape[0]))
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
class fp8_ops(manual_cast):
|
class fp8_ops(manual_cast):
|
||||||
@@ -299,11 +304,63 @@ class fp8_ops(manual_cast):
|
|||||||
weight, bias = cast_bias_weight(self, input)
|
weight, bias = cast_bias_weight(self, input)
|
||||||
return torch.nn.functional.linear(input, weight, bias)
|
return torch.nn.functional.linear(input, weight, bias)
|
||||||
|
|
||||||
|
def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
|
||||||
|
class scaled_fp8_op(manual_cast):
|
||||||
|
class Linear(manual_cast.Linear):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
if override_dtype is not None:
|
||||||
|
kwargs['dtype'] = override_dtype
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def reset_parameters(self):
|
||||||
|
if not hasattr(self, 'scale_weight'):
|
||||||
|
self.scale_weight = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
|
||||||
|
|
||||||
|
if not scale_input:
|
||||||
|
self.scale_input = None
|
||||||
|
|
||||||
|
if not hasattr(self, 'scale_input'):
|
||||||
|
self.scale_input = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def forward_comfy_cast_weights(self, input):
|
||||||
|
if fp8_matrix_mult:
|
||||||
|
out = fp8_linear(self, input)
|
||||||
|
if out is not None:
|
||||||
|
return out
|
||||||
|
|
||||||
|
weight, bias = cast_bias_weight(self, input)
|
||||||
|
|
||||||
|
if weight.numel() < input.numel(): #TODO: optimize
|
||||||
|
return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
|
||||||
|
else:
|
||||||
|
return torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
|
||||||
|
|
||||||
|
def convert_weight(self, weight, inplace=False, **kwargs):
|
||||||
|
if inplace:
|
||||||
|
weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype)
|
||||||
|
return weight
|
||||||
|
else:
|
||||||
|
return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype)
|
||||||
|
|
||||||
|
def set_weight(self, weight, inplace_update=False, seed=None, **kwargs):
|
||||||
|
weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
|
||||||
|
if inplace_update:
|
||||||
|
self.weight.data.copy_(weight)
|
||||||
|
else:
|
||||||
|
self.weight = torch.nn.Parameter(weight, requires_grad=False)
|
||||||
|
|
||||||
|
return scaled_fp8_op
|
||||||
|
|
||||||
|
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
|
||||||
|
fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
|
||||||
|
if scaled_fp8 is not None:
|
||||||
|
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute, scale_input=True, override_dtype=scaled_fp8)
|
||||||
|
|
||||||
|
if fp8_compute and (fp8_optimizations or args.fast) and not disable_fast_fp8:
|
||||||
|
return fp8_ops
|
||||||
|
|
||||||
def pick_operations(weight_dtype, compute_dtype, load_device=None):
|
|
||||||
if compute_dtype is None or weight_dtype == compute_dtype:
|
if compute_dtype is None or weight_dtype == compute_dtype:
|
||||||
return disable_weight_init
|
return disable_weight_init
|
||||||
if args.fast:
|
|
||||||
if comfy.model_management.supports_fp8_compute(load_device):
|
|
||||||
return fp8_ops
|
|
||||||
return manual_cast
|
return manual_cast
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from comfy import model_management
|
|||||||
import math
|
import math
|
||||||
import logging
|
import logging
|
||||||
import comfy.sampler_helpers
|
import comfy.sampler_helpers
|
||||||
import scipy
|
import scipy.stats
|
||||||
import numpy
|
import numpy
|
||||||
|
|
||||||
def get_area_and_mult(conds, x_in, timestep_in):
|
def get_area_and_mult(conds, x_in, timestep_in):
|
||||||
@@ -358,11 +358,35 @@ def beta_scheduler(model_sampling, steps, alpha=0.6, beta=0.6):
|
|||||||
ts = numpy.rint(scipy.stats.beta.ppf(ts, alpha, beta) * total_timesteps)
|
ts = numpy.rint(scipy.stats.beta.ppf(ts, alpha, beta) * total_timesteps)
|
||||||
|
|
||||||
sigs = []
|
sigs = []
|
||||||
|
last_t = -1
|
||||||
for t in ts:
|
for t in ts:
|
||||||
sigs += [float(model_sampling.sigmas[int(t)])]
|
if t != last_t:
|
||||||
|
sigs += [float(model_sampling.sigmas[int(t)])]
|
||||||
|
last_t = t
|
||||||
sigs += [0.0]
|
sigs += [0.0]
|
||||||
return torch.FloatTensor(sigs)
|
return torch.FloatTensor(sigs)
|
||||||
|
|
||||||
|
# from: https://github.com/genmoai/models/blob/main/src/mochi_preview/infer.py#L41
|
||||||
|
def linear_quadratic_schedule(model_sampling, steps, threshold_noise=0.025, linear_steps=None):
|
||||||
|
if steps == 1:
|
||||||
|
sigma_schedule = [1.0, 0.0]
|
||||||
|
else:
|
||||||
|
if linear_steps is None:
|
||||||
|
linear_steps = steps // 2
|
||||||
|
linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)]
|
||||||
|
threshold_noise_step_diff = linear_steps - threshold_noise * steps
|
||||||
|
quadratic_steps = steps - linear_steps
|
||||||
|
quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps ** 2)
|
||||||
|
linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps ** 2)
|
||||||
|
const = quadratic_coef * (linear_steps ** 2)
|
||||||
|
quadratic_sigma_schedule = [
|
||||||
|
quadratic_coef * (i ** 2) + linear_coef * i + const
|
||||||
|
for i in range(linear_steps, steps)
|
||||||
|
]
|
||||||
|
sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0]
|
||||||
|
sigma_schedule = [1.0 - x for x in sigma_schedule]
|
||||||
|
return torch.FloatTensor(sigma_schedule) * model_sampling.sigma_max.cpu()
|
||||||
|
|
||||||
def get_mask_aabb(masks):
|
def get_mask_aabb(masks):
|
||||||
if masks.numel() == 0:
|
if masks.numel() == 0:
|
||||||
return torch.zeros((0, 4), device=masks.device, dtype=torch.int)
|
return torch.zeros((0, 4), device=masks.device, dtype=torch.int)
|
||||||
@@ -570,8 +594,8 @@ class Sampler:
|
|||||||
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
|
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
|
||||||
|
|
||||||
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
|
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
|
||||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
|
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||||
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
||||||
"ipndm", "ipndm_v", "deis"]
|
"ipndm", "ipndm_v", "deis"]
|
||||||
|
|
||||||
class KSAMPLER(Sampler):
|
class KSAMPLER(Sampler):
|
||||||
@@ -729,7 +753,7 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
|
|||||||
return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
||||||
|
|
||||||
|
|
||||||
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "beta"]
|
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "beta", "linear_quadratic"]
|
||||||
SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]
|
SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]
|
||||||
|
|
||||||
def calculate_sigmas(model_sampling, scheduler_name, steps):
|
def calculate_sigmas(model_sampling, scheduler_name, steps):
|
||||||
@@ -747,6 +771,8 @@ def calculate_sigmas(model_sampling, scheduler_name, steps):
|
|||||||
sigmas = normal_scheduler(model_sampling, steps, sgm=True)
|
sigmas = normal_scheduler(model_sampling, steps, sgm=True)
|
||||||
elif scheduler_name == "beta":
|
elif scheduler_name == "beta":
|
||||||
sigmas = beta_scheduler(model_sampling, steps)
|
sigmas = beta_scheduler(model_sampling, steps)
|
||||||
|
elif scheduler_name == "linear_quadratic":
|
||||||
|
sigmas = linear_quadratic_schedule(model_sampling, steps)
|
||||||
else:
|
else:
|
||||||
logging.error("error invalid scheduler {}".format(scheduler_name))
|
logging.error("error invalid scheduler {}".format(scheduler_name))
|
||||||
return sigmas
|
return sigmas
|
||||||
|
|||||||
164
comfy/sd.py
164
comfy/sd.py
@@ -7,6 +7,7 @@ from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
|
|||||||
from .ldm.cascade.stage_a import StageA
|
from .ldm.cascade.stage_a import StageA
|
||||||
from .ldm.cascade.stage_c_coder import StageC_coder
|
from .ldm.cascade.stage_c_coder import StageC_coder
|
||||||
from .ldm.audio.autoencoder import AudioOobleckVAE
|
from .ldm.audio.autoencoder import AudioOobleckVAE
|
||||||
|
import comfy.ldm.genmo.vae.model
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
@@ -25,11 +26,11 @@ import comfy.text_encoders.aura_t5
|
|||||||
import comfy.text_encoders.hydit
|
import comfy.text_encoders.hydit
|
||||||
import comfy.text_encoders.flux
|
import comfy.text_encoders.flux
|
||||||
import comfy.text_encoders.long_clipl
|
import comfy.text_encoders.long_clipl
|
||||||
|
import comfy.text_encoders.genmo
|
||||||
|
|
||||||
import comfy.model_patcher
|
import comfy.model_patcher
|
||||||
import comfy.lora
|
import comfy.lora
|
||||||
import comfy.t2i_adapter.adapter
|
import comfy.t2i_adapter.adapter
|
||||||
import comfy.supported_models_base
|
|
||||||
import comfy.taesd.taesd
|
import comfy.taesd.taesd
|
||||||
|
|
||||||
def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
|
def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
|
||||||
@@ -70,14 +71,14 @@ class CLIP:
|
|||||||
clip = target.clip
|
clip = target.clip
|
||||||
tokenizer = target.tokenizer
|
tokenizer = target.tokenizer
|
||||||
|
|
||||||
load_device = model_management.text_encoder_device()
|
load_device = model_options.get("load_device", model_management.text_encoder_device())
|
||||||
offload_device = model_management.text_encoder_offload_device()
|
offload_device = model_options.get("offload_device", model_management.text_encoder_offload_device())
|
||||||
dtype = model_options.get("dtype", None)
|
dtype = model_options.get("dtype", None)
|
||||||
if dtype is None:
|
if dtype is None:
|
||||||
dtype = model_management.text_encoder_dtype(load_device)
|
dtype = model_management.text_encoder_dtype(load_device)
|
||||||
|
|
||||||
params['dtype'] = dtype
|
params['dtype'] = dtype
|
||||||
params['device'] = model_management.text_encoder_initial_device(load_device, offload_device, parameters * model_management.dtype_size(dtype))
|
params['device'] = model_options.get("initial_device", model_management.text_encoder_initial_device(load_device, offload_device, parameters * model_management.dtype_size(dtype)))
|
||||||
params['model_options'] = model_options
|
params['model_options'] = model_options
|
||||||
|
|
||||||
self.cond_stage_model = clip(**(params))
|
self.cond_stage_model = clip(**(params))
|
||||||
@@ -170,6 +171,7 @@ class VAE:
|
|||||||
self.downscale_ratio = 8
|
self.downscale_ratio = 8
|
||||||
self.upscale_ratio = 8
|
self.upscale_ratio = 8
|
||||||
self.latent_channels = 4
|
self.latent_channels = 4
|
||||||
|
self.latent_dim = 2
|
||||||
self.output_channels = 3
|
self.output_channels = 3
|
||||||
self.process_input = lambda image: image * 2.0 - 1.0
|
self.process_input = lambda image: image * 2.0 - 1.0
|
||||||
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
|
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
@@ -239,9 +241,22 @@ class VAE:
|
|||||||
self.output_channels = 2
|
self.output_channels = 2
|
||||||
self.upscale_ratio = 2048
|
self.upscale_ratio = 2048
|
||||||
self.downscale_ratio = 2048
|
self.downscale_ratio = 2048
|
||||||
|
self.latent_dim = 1
|
||||||
self.process_output = lambda audio: audio
|
self.process_output = lambda audio: audio
|
||||||
self.process_input = lambda audio: audio
|
self.process_input = lambda audio: audio
|
||||||
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||||
|
elif "blocks.2.blocks.3.stack.5.weight" in sd or "decoder.blocks.2.blocks.3.stack.5.weight" in sd or "layers.4.layers.1.attn_block.attn.qkv.weight" in sd or "encoder.layers.4.layers.1.attn_block.attn.qkv.weight": #genmo mochi vae
|
||||||
|
if "blocks.2.blocks.3.stack.5.weight" in sd:
|
||||||
|
sd = comfy.utils.state_dict_prefix_replace(sd, {"": "decoder."})
|
||||||
|
if "layers.4.layers.1.attn_block.attn.qkv.weight" in sd:
|
||||||
|
sd = comfy.utils.state_dict_prefix_replace(sd, {"": "encoder."})
|
||||||
|
self.first_stage_model = comfy.ldm.genmo.vae.model.VideoVAE()
|
||||||
|
self.latent_channels = 12
|
||||||
|
self.latent_dim = 3
|
||||||
|
self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
|
||||||
|
self.memory_used_encode = lambda shape, dtype: (1.5 * max(shape[2], 7) * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
|
||||||
|
self.upscale_ratio = (lambda a: max(0, a * 6 - 5), 8, 8)
|
||||||
|
self.working_dtypes = [torch.float16, torch.float32]
|
||||||
else:
|
else:
|
||||||
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
|
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
|
||||||
self.first_stage_model = None
|
self.first_stage_model = None
|
||||||
@@ -297,6 +312,10 @@ class VAE:
|
|||||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
|
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
|
||||||
return comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device)
|
return comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device)
|
||||||
|
|
||||||
|
def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
|
||||||
|
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
|
||||||
|
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
|
||||||
|
|
||||||
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
|
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
|
||||||
steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
|
steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
|
||||||
steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
|
steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
|
||||||
@@ -315,6 +334,7 @@ class VAE:
|
|||||||
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device)
|
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device)
|
||||||
|
|
||||||
def decode(self, samples_in):
|
def decode(self, samples_in):
|
||||||
|
pixel_samples = None
|
||||||
try:
|
try:
|
||||||
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
|
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
|
||||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
|
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
|
||||||
@@ -322,16 +342,21 @@ class VAE:
|
|||||||
batch_number = int(free_memory / memory_used)
|
batch_number = int(free_memory / memory_used)
|
||||||
batch_number = max(1, batch_number)
|
batch_number = max(1, batch_number)
|
||||||
|
|
||||||
pixel_samples = torch.empty((samples_in.shape[0], self.output_channels) + tuple(map(lambda a: a * self.upscale_ratio, samples_in.shape[2:])), device=self.output_device)
|
|
||||||
for x in range(0, samples_in.shape[0], batch_number):
|
for x in range(0, samples_in.shape[0], batch_number):
|
||||||
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
|
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
|
||||||
pixel_samples[x:x+batch_number] = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
|
out = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
|
||||||
|
if pixel_samples is None:
|
||||||
|
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
|
||||||
|
pixel_samples[x:x+batch_number] = out
|
||||||
except model_management.OOM_EXCEPTION as e:
|
except model_management.OOM_EXCEPTION as e:
|
||||||
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
||||||
if len(samples_in.shape) == 3:
|
dims = samples_in.ndim - 2
|
||||||
|
if dims == 1:
|
||||||
pixel_samples = self.decode_tiled_1d(samples_in)
|
pixel_samples = self.decode_tiled_1d(samples_in)
|
||||||
else:
|
elif dims == 2:
|
||||||
pixel_samples = self.decode_tiled_(samples_in)
|
pixel_samples = self.decode_tiled_(samples_in)
|
||||||
|
elif dims == 3:
|
||||||
|
pixel_samples = self.decode_tiled_3d(samples_in)
|
||||||
|
|
||||||
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
|
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
|
||||||
return pixel_samples
|
return pixel_samples
|
||||||
@@ -343,17 +368,22 @@ class VAE:
|
|||||||
|
|
||||||
def encode(self, pixel_samples):
|
def encode(self, pixel_samples):
|
||||||
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
|
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
|
||||||
pixel_samples = pixel_samples.movedim(-1,1)
|
pixel_samples = pixel_samples.movedim(-1, 1)
|
||||||
|
if self.latent_dim == 3:
|
||||||
|
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
|
||||||
try:
|
try:
|
||||||
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
|
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
|
||||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
|
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
|
||||||
free_memory = model_management.get_free_memory(self.device)
|
free_memory = model_management.get_free_memory(self.device)
|
||||||
batch_number = int(free_memory / memory_used)
|
batch_number = int(free_memory / max(1, memory_used))
|
||||||
batch_number = max(1, batch_number)
|
batch_number = max(1, batch_number)
|
||||||
samples = torch.empty((pixel_samples.shape[0], self.latent_channels) + tuple(map(lambda a: a // self.downscale_ratio, pixel_samples.shape[2:])), device=self.output_device)
|
samples = None
|
||||||
for x in range(0, pixel_samples.shape[0], batch_number):
|
for x in range(0, pixel_samples.shape[0], batch_number):
|
||||||
pixels_in = self.process_input(pixel_samples[x:x+batch_number]).to(self.vae_dtype).to(self.device)
|
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype).to(self.device)
|
||||||
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
|
out = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
|
||||||
|
if samples is None:
|
||||||
|
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
|
||||||
|
samples[x:x + batch_number] = out
|
||||||
|
|
||||||
except model_management.OOM_EXCEPTION as e:
|
except model_management.OOM_EXCEPTION as e:
|
||||||
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
|
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
|
||||||
@@ -399,6 +429,7 @@ class CLIPType(Enum):
|
|||||||
STABLE_AUDIO = 4
|
STABLE_AUDIO = 4
|
||||||
HUNYUAN_DIT = 5
|
HUNYUAN_DIT = 5
|
||||||
FLUX = 6
|
FLUX = 6
|
||||||
|
MOCHI = 7
|
||||||
|
|
||||||
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
||||||
clip_data = []
|
clip_data = []
|
||||||
@@ -406,8 +437,46 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
|
|||||||
clip_data.append(comfy.utils.load_torch_file(p, safe_load=True))
|
clip_data.append(comfy.utils.load_torch_file(p, safe_load=True))
|
||||||
return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)
|
return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)
|
||||||
|
|
||||||
|
|
||||||
|
class TEModel(Enum):
|
||||||
|
CLIP_L = 1
|
||||||
|
CLIP_H = 2
|
||||||
|
CLIP_G = 3
|
||||||
|
T5_XXL = 4
|
||||||
|
T5_XL = 5
|
||||||
|
T5_BASE = 6
|
||||||
|
|
||||||
|
def detect_te_model(sd):
|
||||||
|
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
|
||||||
|
return TEModel.CLIP_G
|
||||||
|
if "text_model.encoder.layers.22.mlp.fc1.weight" in sd:
|
||||||
|
return TEModel.CLIP_H
|
||||||
|
if "text_model.encoder.layers.0.mlp.fc1.weight" in sd:
|
||||||
|
return TEModel.CLIP_L
|
||||||
|
if "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd:
|
||||||
|
weight = sd["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
|
||||||
|
if weight.shape[-1] == 4096:
|
||||||
|
return TEModel.T5_XXL
|
||||||
|
elif weight.shape[-1] == 2048:
|
||||||
|
return TEModel.T5_XL
|
||||||
|
if "encoder.block.0.layer.0.SelfAttention.k.weight" in sd:
|
||||||
|
return TEModel.T5_BASE
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def t5xxl_detect(clip_data):
|
||||||
|
weight_name = "encoder.block.23.layer.1.DenseReluDense.wi_1.weight"
|
||||||
|
|
||||||
|
for sd in clip_data:
|
||||||
|
if weight_name in sd:
|
||||||
|
return comfy.text_encoders.sd3_clip.t5_xxl_detect(sd)
|
||||||
|
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
||||||
clip_data = state_dicts
|
clip_data = state_dicts
|
||||||
|
|
||||||
class EmptyClass:
|
class EmptyClass:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -421,64 +490,65 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
|||||||
clip_target = EmptyClass()
|
clip_target = EmptyClass()
|
||||||
clip_target.params = {}
|
clip_target.params = {}
|
||||||
if len(clip_data) == 1:
|
if len(clip_data) == 1:
|
||||||
if "text_model.encoder.layers.30.mlp.fc1.weight" in clip_data[0]:
|
te_model = detect_te_model(clip_data[0])
|
||||||
|
if te_model == TEModel.CLIP_G:
|
||||||
if clip_type == CLIPType.STABLE_CASCADE:
|
if clip_type == CLIPType.STABLE_CASCADE:
|
||||||
clip_target.clip = sdxl_clip.StableCascadeClipModel
|
clip_target.clip = sdxl_clip.StableCascadeClipModel
|
||||||
clip_target.tokenizer = sdxl_clip.StableCascadeTokenizer
|
clip_target.tokenizer = sdxl_clip.StableCascadeTokenizer
|
||||||
|
elif clip_type == CLIPType.SD3:
|
||||||
|
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=True, t5=False)
|
||||||
|
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||||
else:
|
else:
|
||||||
clip_target.clip = sdxl_clip.SDXLRefinerClipModel
|
clip_target.clip = sdxl_clip.SDXLRefinerClipModel
|
||||||
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
||||||
elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]:
|
elif te_model == TEModel.CLIP_H:
|
||||||
clip_target.clip = comfy.text_encoders.sd2_clip.SD2ClipModel
|
clip_target.clip = comfy.text_encoders.sd2_clip.SD2ClipModel
|
||||||
clip_target.tokenizer = comfy.text_encoders.sd2_clip.SD2Tokenizer
|
clip_target.tokenizer = comfy.text_encoders.sd2_clip.SD2Tokenizer
|
||||||
elif "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in clip_data[0]:
|
elif te_model == TEModel.T5_XXL:
|
||||||
weight = clip_data[0]["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
|
if clip_type == CLIPType.SD3:
|
||||||
dtype_t5 = weight.dtype
|
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, **t5xxl_detect(clip_data))
|
||||||
if weight.shape[-1] == 4096:
|
|
||||||
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=dtype_t5)
|
|
||||||
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||||
elif weight.shape[-1] == 2048:
|
else: #CLIPType.MOCHI
|
||||||
clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model
|
clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
|
||||||
clip_target.tokenizer = comfy.text_encoders.aura_t5.AuraT5Tokenizer
|
clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer
|
||||||
elif "encoder.block.0.layer.0.SelfAttention.k.weight" in clip_data[0]:
|
elif te_model == TEModel.T5_XL:
|
||||||
|
clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model
|
||||||
|
clip_target.tokenizer = comfy.text_encoders.aura_t5.AuraT5Tokenizer
|
||||||
|
elif te_model == TEModel.T5_BASE:
|
||||||
clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
|
clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
|
||||||
clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
|
clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
|
||||||
else:
|
else:
|
||||||
w = clip_data[0].get("text_model.embeddings.position_embedding.weight", None)
|
if clip_type == CLIPType.SD3:
|
||||||
if w is not None and w.shape[0] == 248:
|
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=False, t5=False)
|
||||||
clip_target.clip = comfy.text_encoders.long_clipl.LongClipModel
|
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||||
clip_target.tokenizer = comfy.text_encoders.long_clipl.LongClipTokenizer
|
|
||||||
else:
|
else:
|
||||||
clip_target.clip = sd1_clip.SD1ClipModel
|
clip_target.clip = sd1_clip.SD1ClipModel
|
||||||
clip_target.tokenizer = sd1_clip.SD1Tokenizer
|
clip_target.tokenizer = sd1_clip.SD1Tokenizer
|
||||||
elif len(clip_data) == 2:
|
elif len(clip_data) == 2:
|
||||||
if clip_type == CLIPType.SD3:
|
if clip_type == CLIPType.SD3:
|
||||||
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=True, t5=False)
|
te_models = [detect_te_model(clip_data[0]), detect_te_model(clip_data[1])]
|
||||||
|
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=TEModel.CLIP_L in te_models, clip_g=TEModel.CLIP_G in te_models, t5=TEModel.T5_XXL in te_models, **t5xxl_detect(clip_data))
|
||||||
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||||
elif clip_type == CLIPType.HUNYUAN_DIT:
|
elif clip_type == CLIPType.HUNYUAN_DIT:
|
||||||
clip_target.clip = comfy.text_encoders.hydit.HyditModel
|
clip_target.clip = comfy.text_encoders.hydit.HyditModel
|
||||||
clip_target.tokenizer = comfy.text_encoders.hydit.HyditTokenizer
|
clip_target.tokenizer = comfy.text_encoders.hydit.HyditTokenizer
|
||||||
elif clip_type == CLIPType.FLUX:
|
elif clip_type == CLIPType.FLUX:
|
||||||
weight_name = "encoder.block.23.layer.1.DenseReluDense.wi_1.weight"
|
clip_target.clip = comfy.text_encoders.flux.flux_clip(**t5xxl_detect(clip_data))
|
||||||
weight = clip_data[0].get(weight_name, clip_data[1].get(weight_name, None))
|
|
||||||
dtype_t5 = None
|
|
||||||
if weight is not None:
|
|
||||||
dtype_t5 = weight.dtype
|
|
||||||
|
|
||||||
clip_target.clip = comfy.text_encoders.flux.flux_clip(dtype_t5=dtype_t5)
|
|
||||||
clip_target.tokenizer = comfy.text_encoders.flux.FluxTokenizer
|
clip_target.tokenizer = comfy.text_encoders.flux.FluxTokenizer
|
||||||
else:
|
else:
|
||||||
clip_target.clip = sdxl_clip.SDXLClipModel
|
clip_target.clip = sdxl_clip.SDXLClipModel
|
||||||
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
||||||
elif len(clip_data) == 3:
|
elif len(clip_data) == 3:
|
||||||
clip_target.clip = comfy.text_encoders.sd3_clip.SD3ClipModel
|
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(**t5xxl_detect(clip_data))
|
||||||
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||||
|
|
||||||
parameters = 0
|
parameters = 0
|
||||||
|
tokenizer_data = {}
|
||||||
for c in clip_data:
|
for c in clip_data:
|
||||||
parameters += comfy.utils.calculate_parameters(c)
|
parameters += comfy.utils.calculate_parameters(c)
|
||||||
|
tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)
|
||||||
|
|
||||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, model_options=model_options)
|
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, model_options=model_options)
|
||||||
for c in clip_data:
|
for c in clip_data:
|
||||||
m, u = clip.load_sd(c)
|
m, u = clip.load_sd(c)
|
||||||
if len(m) > 0:
|
if len(m) > 0:
|
||||||
@@ -544,11 +614,11 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
unet_weight_dtype = list(model_config.supported_inference_dtypes)
|
unet_weight_dtype = list(model_config.supported_inference_dtypes)
|
||||||
if weight_dtype is not None:
|
if weight_dtype is not None and model_config.scaled_fp8 is None:
|
||||||
unet_weight_dtype.append(weight_dtype)
|
unet_weight_dtype.append(weight_dtype)
|
||||||
|
|
||||||
model_config.custom_operations = model_options.get("custom_operations", None)
|
model_config.custom_operations = model_options.get("custom_operations", None)
|
||||||
unet_dtype = model_options.get("weight_dtype", None)
|
unet_dtype = model_options.get("dtype", model_options.get("weight_dtype", None))
|
||||||
|
|
||||||
if unet_dtype is None:
|
if unet_dtype is None:
|
||||||
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
|
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
|
||||||
@@ -562,7 +632,6 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
|||||||
|
|
||||||
if output_model:
|
if output_model:
|
||||||
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
|
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
|
||||||
offload_device = model_management.unet_offload_device()
|
|
||||||
model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
|
model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
|
||||||
model.load_model_weights(sd, diffusion_model_prefix)
|
model.load_model_weights(sd, diffusion_model_prefix)
|
||||||
|
|
||||||
@@ -614,6 +683,8 @@ def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffuse
|
|||||||
sd = temp_sd
|
sd = temp_sd
|
||||||
|
|
||||||
parameters = comfy.utils.calculate_parameters(sd)
|
parameters = comfy.utils.calculate_parameters(sd)
|
||||||
|
weight_dtype = comfy.utils.weight_dtype(sd)
|
||||||
|
|
||||||
load_device = model_management.get_torch_device()
|
load_device = model_management.get_torch_device()
|
||||||
model_config = model_detection.model_config_from_unet(sd, "")
|
model_config = model_detection.model_config_from_unet(sd, "")
|
||||||
|
|
||||||
@@ -640,14 +711,21 @@ def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffuse
|
|||||||
logging.warning("{} {}".format(diffusers_keys[k], k))
|
logging.warning("{} {}".format(diffusers_keys[k], k))
|
||||||
|
|
||||||
offload_device = model_management.unet_offload_device()
|
offload_device = model_management.unet_offload_device()
|
||||||
|
unet_weight_dtype = list(model_config.supported_inference_dtypes)
|
||||||
|
if weight_dtype is not None and model_config.scaled_fp8 is None:
|
||||||
|
unet_weight_dtype.append(weight_dtype)
|
||||||
|
|
||||||
if dtype is None:
|
if dtype is None:
|
||||||
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
|
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
|
||||||
else:
|
else:
|
||||||
unet_dtype = dtype
|
unet_dtype = dtype
|
||||||
|
|
||||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||||
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
||||||
model_config.custom_operations = model_options.get("custom_operations", None)
|
model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations)
|
||||||
|
if model_options.get("fp8_optimizations", False):
|
||||||
|
model_config.optimizations["fp8"] = True
|
||||||
|
|
||||||
model = model_config.get_model(new_sd, "")
|
model = model_config.get_model(new_sd, "")
|
||||||
model = model.to(offload_device)
|
model = model.to(offload_device)
|
||||||
model.load_model_weights(new_sd, "")
|
model.load_model_weights(new_sd, "")
|
||||||
|
|||||||
@@ -80,7 +80,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||||||
"pooled",
|
"pooled",
|
||||||
"hidden"
|
"hidden"
|
||||||
]
|
]
|
||||||
def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
|
def __init__(self, device="cpu", max_length=77,
|
||||||
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,
|
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,
|
||||||
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
|
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
|
||||||
return_projected_pooled=True, return_attention_masks=False, model_options={}): # clip-vit-base-patch32
|
return_projected_pooled=True, return_attention_masks=False, model_options={}): # clip-vit-base-patch32
|
||||||
@@ -94,11 +94,20 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||||||
config = json.load(f)
|
config = json.load(f)
|
||||||
|
|
||||||
operations = model_options.get("custom_operations", None)
|
operations = model_options.get("custom_operations", None)
|
||||||
|
scaled_fp8 = None
|
||||||
|
|
||||||
if operations is None:
|
if operations is None:
|
||||||
operations = comfy.ops.manual_cast
|
scaled_fp8 = model_options.get("scaled_fp8", None)
|
||||||
|
if scaled_fp8 is not None:
|
||||||
|
operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8)
|
||||||
|
else:
|
||||||
|
operations = comfy.ops.manual_cast
|
||||||
|
|
||||||
self.operations = operations
|
self.operations = operations
|
||||||
self.transformer = model_class(config, dtype, device, self.operations)
|
self.transformer = model_class(config, dtype, device, self.operations)
|
||||||
|
if scaled_fp8 is not None:
|
||||||
|
self.transformer.scaled_fp8 = torch.nn.Parameter(torch.tensor([], dtype=scaled_fp8))
|
||||||
|
|
||||||
self.num_layers = self.transformer.num_layers
|
self.num_layers = self.transformer.num_layers
|
||||||
|
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
@@ -542,6 +551,7 @@ class SD1Tokenizer:
|
|||||||
def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer):
|
def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer):
|
||||||
self.clip_name = clip_name
|
self.clip_name = clip_name
|
||||||
self.clip = "clip_{}".format(self.clip_name)
|
self.clip = "clip_{}".format(self.clip_name)
|
||||||
|
tokenizer = tokenizer_data.get("{}_tokenizer_class".format(self.clip), tokenizer)
|
||||||
setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data))
|
setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data))
|
||||||
|
|
||||||
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
||||||
@@ -570,6 +580,7 @@ class SD1ClipModel(torch.nn.Module):
|
|||||||
self.clip_name = clip_name
|
self.clip_name = clip_name
|
||||||
self.clip = "clip_{}".format(self.clip_name)
|
self.clip = "clip_{}".format(self.clip_name)
|
||||||
|
|
||||||
|
clip_model = model_options.get("{}_class".format(self.clip), clip_model)
|
||||||
setattr(self, self.clip, clip_model(device=device, dtype=dtype, model_options=model_options, **kwargs))
|
setattr(self, self.clip, clip_model(device=device, dtype=dtype, model_options=model_options, **kwargs))
|
||||||
|
|
||||||
self.dtypes = set()
|
self.dtypes = set()
|
||||||
|
|||||||
@@ -22,7 +22,8 @@ class SDXLClipGTokenizer(sd1_clip.SDTokenizer):
|
|||||||
|
|
||||||
class SDXLTokenizer:
|
class SDXLTokenizer:
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
|
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
|
||||||
|
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
|
||||||
self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory)
|
self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory)
|
||||||
|
|
||||||
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
||||||
@@ -40,7 +41,8 @@ class SDXLTokenizer:
|
|||||||
class SDXLClipModel(torch.nn.Module):
|
class SDXLClipModel(torch.nn.Module):
|
||||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options)
|
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
|
||||||
|
self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options)
|
||||||
self.clip_g = SDXLClipG(device=device, dtype=dtype, model_options=model_options)
|
self.clip_g = SDXLClipG(device=device, dtype=dtype, model_options=model_options)
|
||||||
self.dtypes = set([dtype])
|
self.dtypes = set([dtype])
|
||||||
|
|
||||||
@@ -57,7 +59,8 @@ class SDXLClipModel(torch.nn.Module):
|
|||||||
token_weight_pairs_l = token_weight_pairs["l"]
|
token_weight_pairs_l = token_weight_pairs["l"]
|
||||||
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
|
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
|
||||||
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
|
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
|
||||||
return torch.cat([l_out, g_out], dim=-1), g_pooled
|
cut_to = min(l_out.shape[1], g_out.shape[1])
|
||||||
|
return torch.cat([l_out[:,:cut_to], g_out[:,:cut_to]], dim=-1), g_pooled
|
||||||
|
|
||||||
def load_sd(self, sd):
|
def load_sd(self, sd):
|
||||||
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
|
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import comfy.text_encoders.sa_t5
|
|||||||
import comfy.text_encoders.aura_t5
|
import comfy.text_encoders.aura_t5
|
||||||
import comfy.text_encoders.hydit
|
import comfy.text_encoders.hydit
|
||||||
import comfy.text_encoders.flux
|
import comfy.text_encoders.flux
|
||||||
|
import comfy.text_encoders.genmo
|
||||||
|
|
||||||
from . import supported_models_base
|
from . import supported_models_base
|
||||||
from . import latent_formats
|
from . import latent_formats
|
||||||
@@ -529,12 +530,11 @@ class SD3(supported_models_base.BASE):
|
|||||||
clip_l = True
|
clip_l = True
|
||||||
if "{}clip_g.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict:
|
if "{}clip_g.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict:
|
||||||
clip_g = True
|
clip_g = True
|
||||||
t5_key = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref)
|
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
||||||
if t5_key in state_dict:
|
if "dtype_t5" in t5_detect:
|
||||||
t5 = True
|
t5 = True
|
||||||
dtype_t5 = state_dict[t5_key].dtype
|
|
||||||
|
|
||||||
return supported_models_base.ClipTarget(comfy.text_encoders.sd3_clip.SD3Tokenizer, comfy.text_encoders.sd3_clip.sd3_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5))
|
return supported_models_base.ClipTarget(comfy.text_encoders.sd3_clip.SD3Tokenizer, comfy.text_encoders.sd3_clip.sd3_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, **t5_detect))
|
||||||
|
|
||||||
class StableAudio(supported_models_base.BASE):
|
class StableAudio(supported_models_base.BASE):
|
||||||
unet_config = {
|
unet_config = {
|
||||||
@@ -653,11 +653,8 @@ class Flux(supported_models_base.BASE):
|
|||||||
|
|
||||||
def clip_target(self, state_dict={}):
|
def clip_target(self, state_dict={}):
|
||||||
pref = self.text_encoder_key_prefix[0]
|
pref = self.text_encoder_key_prefix[0]
|
||||||
t5_key = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref)
|
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
||||||
dtype_t5 = None
|
return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(**t5_detect))
|
||||||
if t5_key in state_dict:
|
|
||||||
dtype_t5 = state_dict[t5_key].dtype
|
|
||||||
return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(dtype_t5=dtype_t5))
|
|
||||||
|
|
||||||
class FluxSchnell(Flux):
|
class FluxSchnell(Flux):
|
||||||
unet_config = {
|
unet_config = {
|
||||||
@@ -674,7 +671,36 @@ class FluxSchnell(Flux):
|
|||||||
out = model_base.Flux(self, model_type=model_base.ModelType.FLOW, device=device)
|
out = model_base.Flux(self, model_type=model_base.ModelType.FLOW, device=device)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
class GenmoMochi(supported_models_base.BASE):
|
||||||
|
unet_config = {
|
||||||
|
"image_model": "mochi_preview",
|
||||||
|
}
|
||||||
|
|
||||||
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, Flux, FluxSchnell]
|
sampling_settings = {
|
||||||
|
"multiplier": 1.0,
|
||||||
|
"shift": 6.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
unet_extra_config = {}
|
||||||
|
latent_format = latent_formats.Mochi
|
||||||
|
|
||||||
|
memory_usage_factor = 2.0 #TODO
|
||||||
|
|
||||||
|
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||||
|
|
||||||
|
vae_key_prefix = ["vae."]
|
||||||
|
text_encoder_key_prefix = ["text_encoders."]
|
||||||
|
|
||||||
|
def get_model(self, state_dict, prefix="", device=None):
|
||||||
|
out = model_base.GenmoMochi(self, device=device)
|
||||||
|
return out
|
||||||
|
|
||||||
|
def clip_target(self, state_dict={}):
|
||||||
|
pref = self.text_encoder_key_prefix[0]
|
||||||
|
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
||||||
|
return supported_models_base.ClipTarget(comfy.text_encoders.genmo.MochiT5Tokenizer, comfy.text_encoders.genmo.mochi_te(**t5_detect))
|
||||||
|
|
||||||
|
|
||||||
|
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, Flux, FluxSchnell, GenmoMochi]
|
||||||
|
|
||||||
models += [SVD_img2vid]
|
models += [SVD_img2vid]
|
||||||
|
|||||||
@@ -49,6 +49,8 @@ class BASE:
|
|||||||
|
|
||||||
manual_cast_dtype = None
|
manual_cast_dtype = None
|
||||||
custom_operations = None
|
custom_operations = None
|
||||||
|
scaled_fp8 = None
|
||||||
|
optimizations = {"fp8": False}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def matches(s, unet_config, state_dict=None):
|
def matches(s, unet_config, state_dict=None):
|
||||||
@@ -71,6 +73,7 @@ class BASE:
|
|||||||
self.unet_config = unet_config.copy()
|
self.unet_config = unet_config.copy()
|
||||||
self.sampling_settings = self.sampling_settings.copy()
|
self.sampling_settings = self.sampling_settings.copy()
|
||||||
self.latent_format = self.latent_format()
|
self.latent_format = self.latent_format()
|
||||||
|
self.optimizations = self.optimizations.copy()
|
||||||
for x in self.unet_extra_config:
|
for x in self.unet_extra_config:
|
||||||
self.unet_config[x] = self.unet_extra_config[x]
|
self.unet_config[x] = self.unet_extra_config[x]
|
||||||
|
|
||||||
|
|||||||
@@ -1,24 +1,21 @@
|
|||||||
from comfy import sd1_clip
|
from comfy import sd1_clip
|
||||||
import comfy.text_encoders.t5
|
import comfy.text_encoders.t5
|
||||||
|
import comfy.text_encoders.sd3_clip
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
from transformers import T5TokenizerFast
|
from transformers import T5TokenizerFast
|
||||||
import torch
|
import torch
|
||||||
import os
|
import os
|
||||||
|
|
||||||
class T5XXLModel(sd1_clip.SDClipModel):
|
|
||||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
|
|
||||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
|
|
||||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, model_options=model_options)
|
|
||||||
|
|
||||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
|
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
|
||||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
|
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
|
||||||
|
|
||||||
|
|
||||||
class FluxTokenizer:
|
class FluxTokenizer:
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
|
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
|
||||||
|
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
|
||||||
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
|
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
|
||||||
|
|
||||||
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
||||||
@@ -38,8 +35,9 @@ class FluxClipModel(torch.nn.Module):
|
|||||||
def __init__(self, dtype_t5=None, device="cpu", dtype=None, model_options={}):
|
def __init__(self, dtype_t5=None, device="cpu", dtype=None, model_options={}):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
|
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
|
||||||
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
|
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
|
||||||
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
|
self.clip_l = clip_l_class(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
|
||||||
|
self.t5xxl = comfy.text_encoders.sd3_clip.T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
|
||||||
self.dtypes = set([dtype, dtype_t5])
|
self.dtypes = set([dtype, dtype_t5])
|
||||||
|
|
||||||
def set_clip_options(self, options):
|
def set_clip_options(self, options):
|
||||||
@@ -64,8 +62,11 @@ class FluxClipModel(torch.nn.Module):
|
|||||||
else:
|
else:
|
||||||
return self.t5xxl.load_sd(sd)
|
return self.t5xxl.load_sd(sd)
|
||||||
|
|
||||||
def flux_clip(dtype_t5=None):
|
def flux_clip(dtype_t5=None, t5xxl_scaled_fp8=None):
|
||||||
class FluxClipModel_(FluxClipModel):
|
class FluxClipModel_(FluxClipModel):
|
||||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||||
|
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
|
||||||
|
model_options = model_options.copy()
|
||||||
|
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
|
||||||
super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
|
super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
|
||||||
return FluxClipModel_
|
return FluxClipModel_
|
||||||
|
|||||||
38
comfy/text_encoders/genmo.py
Normal file
38
comfy/text_encoders/genmo.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
from comfy import sd1_clip
|
||||||
|
import comfy.text_encoders.sd3_clip
|
||||||
|
import os
|
||||||
|
from transformers import T5TokenizerFast
|
||||||
|
|
||||||
|
|
||||||
|
class T5XXLModel(comfy.text_encoders.sd3_clip.T5XXLModel):
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
kwargs["attention_mask"] = True
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class MochiT5XXL(sd1_clip.SD1ClipModel):
|
||||||
|
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||||
|
super().__init__(device=device, dtype=dtype, name="t5xxl", clip_model=T5XXLModel, model_options=model_options)
|
||||||
|
|
||||||
|
|
||||||
|
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||||
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
|
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
|
||||||
|
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
|
||||||
|
|
||||||
|
|
||||||
|
class MochiT5Tokenizer(sd1_clip.SD1Tokenizer):
|
||||||
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
|
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
|
||||||
|
|
||||||
|
|
||||||
|
def mochi_te(dtype_t5=None, t5xxl_scaled_fp8=None):
|
||||||
|
class MochiTEModel_(MochiT5XXL):
|
||||||
|
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||||
|
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
|
||||||
|
model_options = model_options.copy()
|
||||||
|
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
|
||||||
|
if dtype is None:
|
||||||
|
dtype = dtype_t5
|
||||||
|
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||||
|
return MochiTEModel_
|
||||||
@@ -6,9 +6,9 @@ class LongClipTokenizer_(sd1_clip.SDTokenizer):
|
|||||||
super().__init__(max_length=248, embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
super().__init__(max_length=248, embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||||
|
|
||||||
class LongClipModel_(sd1_clip.SDClipModel):
|
class LongClipModel_(sd1_clip.SDClipModel):
|
||||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
def __init__(self, *args, **kwargs):
|
||||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "long_clipl.json")
|
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "long_clipl.json")
|
||||||
super().__init__(device=device, textmodel_json_config=textmodel_json_config, return_projected_pooled=False, dtype=dtype, model_options=model_options)
|
super().__init__(*args, textmodel_json_config=textmodel_json_config, **kwargs)
|
||||||
|
|
||||||
class LongClipTokenizer(sd1_clip.SD1Tokenizer):
|
class LongClipTokenizer(sd1_clip.SD1Tokenizer):
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
@@ -17,3 +17,14 @@ class LongClipTokenizer(sd1_clip.SD1Tokenizer):
|
|||||||
class LongClipModel(sd1_clip.SD1ClipModel):
|
class LongClipModel(sd1_clip.SD1ClipModel):
|
||||||
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
|
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
|
||||||
super().__init__(device=device, dtype=dtype, model_options=model_options, clip_model=LongClipModel_, **kwargs)
|
super().__init__(device=device, dtype=dtype, model_options=model_options, clip_model=LongClipModel_, **kwargs)
|
||||||
|
|
||||||
|
def model_options_long_clip(sd, tokenizer_data, model_options):
|
||||||
|
w = sd.get("clip_l.text_model.embeddings.position_embedding.weight", None)
|
||||||
|
if w is None:
|
||||||
|
w = sd.get("text_model.embeddings.position_embedding.weight", None)
|
||||||
|
if w is not None and w.shape[0] == 248:
|
||||||
|
tokenizer_data = tokenizer_data.copy()
|
||||||
|
model_options = model_options.copy()
|
||||||
|
tokenizer_data["clip_l_tokenizer_class"] = LongClipTokenizer_
|
||||||
|
model_options["clip_l_class"] = LongClipModel_
|
||||||
|
return tokenizer_data, model_options
|
||||||
|
|||||||
@@ -8,9 +8,27 @@ import comfy.model_management
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
class T5XXLModel(sd1_clip.SDClipModel):
|
class T5XXLModel(sd1_clip.SDClipModel):
|
||||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
|
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=False, model_options={}):
|
||||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
|
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
|
||||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, model_options=model_options)
|
t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None)
|
||||||
|
if t5xxl_scaled_fp8 is not None:
|
||||||
|
model_options = model_options.copy()
|
||||||
|
model_options["scaled_fp8"] = t5xxl_scaled_fp8
|
||||||
|
|
||||||
|
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||||
|
|
||||||
|
|
||||||
|
def t5_xxl_detect(state_dict, prefix=""):
|
||||||
|
out = {}
|
||||||
|
t5_key = "{}encoder.final_layer_norm.weight".format(prefix)
|
||||||
|
if t5_key in state_dict:
|
||||||
|
out["dtype_t5"] = state_dict[t5_key].dtype
|
||||||
|
|
||||||
|
scaled_fp8_key = "{}scaled_fp8".format(prefix)
|
||||||
|
if scaled_fp8_key in state_dict:
|
||||||
|
out["t5xxl_scaled_fp8"] = state_dict[scaled_fp8_key].dtype
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
@@ -20,7 +38,8 @@ class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
|||||||
|
|
||||||
class SD3Tokenizer:
|
class SD3Tokenizer:
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
|
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
|
||||||
|
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
|
||||||
self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory)
|
self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory)
|
||||||
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
|
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
|
||||||
|
|
||||||
@@ -38,11 +57,12 @@ class SD3Tokenizer:
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
class SD3ClipModel(torch.nn.Module):
|
class SD3ClipModel(torch.nn.Module):
|
||||||
def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu", dtype=None, model_options={}):
|
def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5_attention_mask=False, device="cpu", dtype=None, model_options={}):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dtypes = set()
|
self.dtypes = set()
|
||||||
if clip_l:
|
if clip_l:
|
||||||
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options)
|
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
|
||||||
|
self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options)
|
||||||
self.dtypes.add(dtype)
|
self.dtypes.add(dtype)
|
||||||
else:
|
else:
|
||||||
self.clip_l = None
|
self.clip_l = None
|
||||||
@@ -55,7 +75,8 @@ class SD3ClipModel(torch.nn.Module):
|
|||||||
|
|
||||||
if t5:
|
if t5:
|
||||||
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
|
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
|
||||||
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
|
self.t5_attention_mask = t5_attention_mask
|
||||||
|
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options, attention_mask=self.t5_attention_mask)
|
||||||
self.dtypes.add(dtype_t5)
|
self.dtypes.add(dtype_t5)
|
||||||
else:
|
else:
|
||||||
self.t5xxl = None
|
self.t5xxl = None
|
||||||
@@ -85,6 +106,7 @@ class SD3ClipModel(torch.nn.Module):
|
|||||||
lg_out = None
|
lg_out = None
|
||||||
pooled = None
|
pooled = None
|
||||||
out = None
|
out = None
|
||||||
|
extra = {}
|
||||||
|
|
||||||
if len(token_weight_pairs_g) > 0 or len(token_weight_pairs_l) > 0:
|
if len(token_weight_pairs_g) > 0 or len(token_weight_pairs_l) > 0:
|
||||||
if self.clip_l is not None:
|
if self.clip_l is not None:
|
||||||
@@ -95,7 +117,8 @@ class SD3ClipModel(torch.nn.Module):
|
|||||||
if self.clip_g is not None:
|
if self.clip_g is not None:
|
||||||
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
|
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
|
||||||
if lg_out is not None:
|
if lg_out is not None:
|
||||||
lg_out = torch.cat([lg_out, g_out], dim=-1)
|
cut_to = min(lg_out.shape[1], g_out.shape[1])
|
||||||
|
lg_out = torch.cat([lg_out[:,:cut_to], g_out[:,:cut_to]], dim=-1)
|
||||||
else:
|
else:
|
||||||
lg_out = torch.nn.functional.pad(g_out, (768, 0))
|
lg_out = torch.nn.functional.pad(g_out, (768, 0))
|
||||||
else:
|
else:
|
||||||
@@ -108,7 +131,11 @@ class SD3ClipModel(torch.nn.Module):
|
|||||||
pooled = torch.cat((l_pooled, g_pooled), dim=-1)
|
pooled = torch.cat((l_pooled, g_pooled), dim=-1)
|
||||||
|
|
||||||
if self.t5xxl is not None:
|
if self.t5xxl is not None:
|
||||||
t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pairs_t5)
|
t5_output = self.t5xxl.encode_token_weights(token_weight_pairs_t5)
|
||||||
|
t5_out, t5_pooled = t5_output[:2]
|
||||||
|
if self.t5_attention_mask:
|
||||||
|
extra["attention_mask"] = t5_output[2]["attention_mask"]
|
||||||
|
|
||||||
if lg_out is not None:
|
if lg_out is not None:
|
||||||
out = torch.cat([lg_out, t5_out], dim=-2)
|
out = torch.cat([lg_out, t5_out], dim=-2)
|
||||||
else:
|
else:
|
||||||
@@ -120,7 +147,7 @@ class SD3ClipModel(torch.nn.Module):
|
|||||||
if pooled is None:
|
if pooled is None:
|
||||||
pooled = torch.zeros((1, 768 + 1280), device=comfy.model_management.intermediate_device())
|
pooled = torch.zeros((1, 768 + 1280), device=comfy.model_management.intermediate_device())
|
||||||
|
|
||||||
return out, pooled
|
return out, pooled, extra
|
||||||
|
|
||||||
def load_sd(self, sd):
|
def load_sd(self, sd):
|
||||||
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
|
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
|
||||||
@@ -130,8 +157,11 @@ class SD3ClipModel(torch.nn.Module):
|
|||||||
else:
|
else:
|
||||||
return self.t5xxl.load_sd(sd)
|
return self.t5xxl.load_sd(sd)
|
||||||
|
|
||||||
def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None):
|
def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5xxl_scaled_fp8=None, t5_attention_mask=False):
|
||||||
class SD3ClipModel_(SD3ClipModel):
|
class SD3ClipModel_(SD3ClipModel):
|
||||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||||
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
|
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
|
||||||
|
model_options = model_options.copy()
|
||||||
|
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
|
||||||
|
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, t5_attention_mask=t5_attention_mask, device=device, dtype=dtype, model_options=model_options)
|
||||||
return SD3ClipModel_
|
return SD3ClipModel_
|
||||||
|
|||||||
104
comfy/utils.py
104
comfy/utils.py
@@ -68,7 +68,7 @@ def weight_dtype(sd, prefix=""):
|
|||||||
for k in sd.keys():
|
for k in sd.keys():
|
||||||
if k.startswith(prefix):
|
if k.startswith(prefix):
|
||||||
w = sd[k]
|
w = sd[k]
|
||||||
dtypes[w.dtype] = dtypes.get(w.dtype, 0) + 1
|
dtypes[w.dtype] = dtypes.get(w.dtype, 0) + w.numel()
|
||||||
|
|
||||||
if len(dtypes) == 0:
|
if len(dtypes) == 0:
|
||||||
return None
|
return None
|
||||||
@@ -316,10 +316,18 @@ MMDIT_MAP_BLOCK = {
|
|||||||
("context_block.mlp.fc1.weight", "ff_context.net.0.proj.weight"),
|
("context_block.mlp.fc1.weight", "ff_context.net.0.proj.weight"),
|
||||||
("context_block.mlp.fc2.bias", "ff_context.net.2.bias"),
|
("context_block.mlp.fc2.bias", "ff_context.net.2.bias"),
|
||||||
("context_block.mlp.fc2.weight", "ff_context.net.2.weight"),
|
("context_block.mlp.fc2.weight", "ff_context.net.2.weight"),
|
||||||
|
("context_block.attn.ln_q.weight", "attn.norm_added_q.weight"),
|
||||||
|
("context_block.attn.ln_k.weight", "attn.norm_added_k.weight"),
|
||||||
("x_block.adaLN_modulation.1.bias", "norm1.linear.bias"),
|
("x_block.adaLN_modulation.1.bias", "norm1.linear.bias"),
|
||||||
("x_block.adaLN_modulation.1.weight", "norm1.linear.weight"),
|
("x_block.adaLN_modulation.1.weight", "norm1.linear.weight"),
|
||||||
("x_block.attn.proj.bias", "attn.to_out.0.bias"),
|
("x_block.attn.proj.bias", "attn.to_out.0.bias"),
|
||||||
("x_block.attn.proj.weight", "attn.to_out.0.weight"),
|
("x_block.attn.proj.weight", "attn.to_out.0.weight"),
|
||||||
|
("x_block.attn.ln_q.weight", "attn.norm_q.weight"),
|
||||||
|
("x_block.attn.ln_k.weight", "attn.norm_k.weight"),
|
||||||
|
("x_block.attn2.proj.bias", "attn2.to_out.0.bias"),
|
||||||
|
("x_block.attn2.proj.weight", "attn2.to_out.0.weight"),
|
||||||
|
("x_block.attn2.ln_q.weight", "attn2.norm_q.weight"),
|
||||||
|
("x_block.attn2.ln_k.weight", "attn2.norm_k.weight"),
|
||||||
("x_block.mlp.fc1.bias", "ff.net.0.proj.bias"),
|
("x_block.mlp.fc1.bias", "ff.net.0.proj.bias"),
|
||||||
("x_block.mlp.fc1.weight", "ff.net.0.proj.weight"),
|
("x_block.mlp.fc1.weight", "ff.net.0.proj.weight"),
|
||||||
("x_block.mlp.fc2.bias", "ff.net.2.bias"),
|
("x_block.mlp.fc2.bias", "ff.net.2.bias"),
|
||||||
@@ -349,6 +357,12 @@ def mmdit_to_diffusers(mmdit_config, output_prefix=""):
|
|||||||
key_map["{}add_k_proj.{}".format(k, end)] = (qkv, (0, offset, offset))
|
key_map["{}add_k_proj.{}".format(k, end)] = (qkv, (0, offset, offset))
|
||||||
key_map["{}add_v_proj.{}".format(k, end)] = (qkv, (0, offset * 2, offset))
|
key_map["{}add_v_proj.{}".format(k, end)] = (qkv, (0, offset * 2, offset))
|
||||||
|
|
||||||
|
k = "{}.attn2.".format(block_from)
|
||||||
|
qkv = "{}.x_block.attn2.qkv.{}".format(block_to, end)
|
||||||
|
key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, offset))
|
||||||
|
key_map["{}to_k.{}".format(k, end)] = (qkv, (0, offset, offset))
|
||||||
|
key_map["{}to_v.{}".format(k, end)] = (qkv, (0, offset * 2, offset))
|
||||||
|
|
||||||
for k in MMDIT_MAP_BLOCK:
|
for k in MMDIT_MAP_BLOCK:
|
||||||
key_map["{}.{}".format(block_from, k[1])] = "{}.{}".format(block_to, k[0])
|
key_map["{}.{}".format(block_from, k[1])] = "{}.{}".format(block_to, k[0])
|
||||||
|
|
||||||
@@ -690,9 +704,14 @@ def lanczos(samples, width, height):
|
|||||||
return result.to(samples.device, samples.dtype)
|
return result.to(samples.device, samples.dtype)
|
||||||
|
|
||||||
def common_upscale(samples, width, height, upscale_method, crop):
|
def common_upscale(samples, width, height, upscale_method, crop):
|
||||||
|
orig_shape = tuple(samples.shape)
|
||||||
|
if len(orig_shape) > 4:
|
||||||
|
samples = samples.reshape(samples.shape[0], samples.shape[1], -1, samples.shape[-2], samples.shape[-1])
|
||||||
|
samples = samples.movedim(2, 1)
|
||||||
|
samples = samples.reshape(-1, orig_shape[1], orig_shape[-2], orig_shape[-1])
|
||||||
if crop == "center":
|
if crop == "center":
|
||||||
old_width = samples.shape[3]
|
old_width = samples.shape[-1]
|
||||||
old_height = samples.shape[2]
|
old_height = samples.shape[-2]
|
||||||
old_aspect = old_width / old_height
|
old_aspect = old_width / old_height
|
||||||
new_aspect = width / height
|
new_aspect = width / height
|
||||||
x = 0
|
x = 0
|
||||||
@@ -701,48 +720,87 @@ def common_upscale(samples, width, height, upscale_method, crop):
|
|||||||
x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
|
x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
|
||||||
elif old_aspect < new_aspect:
|
elif old_aspect < new_aspect:
|
||||||
y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
|
y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
|
||||||
s = samples[:,:,y:old_height-y,x:old_width-x]
|
s = samples.narrow(-2, y, old_height - y * 2).narrow(-1, x, old_width - x * 2)
|
||||||
else:
|
else:
|
||||||
s = samples
|
s = samples
|
||||||
|
|
||||||
if upscale_method == "bislerp":
|
if upscale_method == "bislerp":
|
||||||
return bislerp(s, width, height)
|
out = bislerp(s, width, height)
|
||||||
elif upscale_method == "lanczos":
|
elif upscale_method == "lanczos":
|
||||||
return lanczos(s, width, height)
|
out = lanczos(s, width, height)
|
||||||
else:
|
else:
|
||||||
return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
|
out = torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
|
||||||
|
|
||||||
|
if len(orig_shape) == 4:
|
||||||
|
return out
|
||||||
|
|
||||||
|
out = out.reshape((orig_shape[0], -1, orig_shape[1]) + (height, width))
|
||||||
|
return out.movedim(2, 1).reshape(orig_shape[:-2] + (height, width))
|
||||||
|
|
||||||
def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
|
def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
|
||||||
return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap)))
|
rows = 1 if height <= tile_y else math.ceil((height - overlap) / (tile_y - overlap))
|
||||||
|
cols = 1 if width <= tile_x else math.ceil((width - overlap) / (tile_x - overlap))
|
||||||
|
return rows * cols
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
|
def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
|
||||||
dims = len(tile)
|
dims = len(tile)
|
||||||
output = torch.empty([samples.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), samples.shape[2:])), device=output_device)
|
|
||||||
|
if not (isinstance(upscale_amount, (tuple, list))):
|
||||||
|
upscale_amount = [upscale_amount] * dims
|
||||||
|
|
||||||
|
if not (isinstance(overlap, (tuple, list))):
|
||||||
|
overlap = [overlap] * dims
|
||||||
|
|
||||||
|
def get_upscale(dim, val):
|
||||||
|
up = upscale_amount[dim]
|
||||||
|
if callable(up):
|
||||||
|
return up(val)
|
||||||
|
else:
|
||||||
|
return up * val
|
||||||
|
|
||||||
|
def mult_list_upscale(a):
|
||||||
|
out = []
|
||||||
|
for i in range(len(a)):
|
||||||
|
out.append(round(get_upscale(i, a[i])))
|
||||||
|
return out
|
||||||
|
|
||||||
|
output = torch.empty([samples.shape[0], out_channels] + mult_list_upscale(samples.shape[2:]), device=output_device)
|
||||||
|
|
||||||
for b in range(samples.shape[0]):
|
for b in range(samples.shape[0]):
|
||||||
s = samples[b:b+1]
|
s = samples[b:b+1]
|
||||||
out = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
|
|
||||||
out_div = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
|
|
||||||
|
|
||||||
for it in itertools.product(*map(lambda a: range(0, a[0], a[1] - overlap), zip(s.shape[2:], tile))):
|
# handle entire input fitting in a single tile
|
||||||
|
if all(s.shape[d+2] <= tile[d] for d in range(dims)):
|
||||||
|
output[b:b+1] = function(s).to(output_device)
|
||||||
|
if pbar is not None:
|
||||||
|
pbar.update(1)
|
||||||
|
continue
|
||||||
|
|
||||||
|
out = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
|
||||||
|
out_div = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
|
||||||
|
|
||||||
|
positions = [range(0, s.shape[d+2], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]
|
||||||
|
|
||||||
|
for it in itertools.product(*positions):
|
||||||
s_in = s
|
s_in = s
|
||||||
upscaled = []
|
upscaled = []
|
||||||
|
|
||||||
for d in range(dims):
|
for d in range(dims):
|
||||||
pos = max(0, min(s.shape[d + 2] - overlap, it[d]))
|
pos = max(0, min(s.shape[d + 2] - (overlap[d] + 1), it[d]))
|
||||||
l = min(tile[d], s.shape[d + 2] - pos)
|
l = min(tile[d], s.shape[d + 2] - pos)
|
||||||
s_in = s_in.narrow(d + 2, pos, l)
|
s_in = s_in.narrow(d + 2, pos, l)
|
||||||
upscaled.append(round(pos * upscale_amount))
|
upscaled.append(round(get_upscale(d, pos)))
|
||||||
|
|
||||||
ps = function(s_in).to(output_device)
|
ps = function(s_in).to(output_device)
|
||||||
mask = torch.ones_like(ps)
|
mask = torch.ones_like(ps)
|
||||||
feather = round(overlap * upscale_amount)
|
|
||||||
for t in range(feather):
|
for d in range(2, dims + 2):
|
||||||
for d in range(2, dims + 2):
|
feather = round(get_upscale(d - 2, overlap[d - 2]))
|
||||||
m = mask.narrow(d, t, 1)
|
for t in range(feather):
|
||||||
m *= ((1.0/feather) * (t + 1))
|
a = (t + 1) / feather
|
||||||
m = mask.narrow(d, mask.shape[d] -1 -t, 1)
|
mask.narrow(d, t, 1).mul_(a)
|
||||||
m *= ((1.0/feather) * (t + 1))
|
mask.narrow(d, mask.shape[d] - 1 - t, 1).mul_(a)
|
||||||
|
|
||||||
o = out
|
o = out
|
||||||
o_d = out_div
|
o_d = out_div
|
||||||
@@ -750,8 +808,8 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_
|
|||||||
o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2])
|
o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2])
|
||||||
o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2])
|
o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2])
|
||||||
|
|
||||||
o += ps * mask
|
o.add_(ps * mask)
|
||||||
o_d += mask
|
o_d.add_(mask)
|
||||||
|
|
||||||
if pbar is not None:
|
if pbar is not None:
|
||||||
pbar.update(1)
|
pbar.update(1)
|
||||||
|
|||||||
@@ -1,11 +1,21 @@
|
|||||||
import itertools
|
import itertools
|
||||||
from typing import Sequence, Mapping
|
from typing import Sequence, Mapping, Dict
|
||||||
from comfy_execution.graph import DynamicPrompt
|
from comfy_execution.graph import DynamicPrompt
|
||||||
|
|
||||||
import nodes
|
import nodes
|
||||||
|
|
||||||
from comfy_execution.graph_utils import is_link
|
from comfy_execution.graph_utils import is_link
|
||||||
|
|
||||||
|
NODE_CLASS_CONTAINS_UNIQUE_ID: Dict[str, bool] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def include_unique_id_in_input(class_type: str) -> bool:
|
||||||
|
if class_type in NODE_CLASS_CONTAINS_UNIQUE_ID:
|
||||||
|
return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
|
||||||
|
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||||
|
NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] = "UNIQUE_ID" in class_def.INPUT_TYPES().get("hidden", {}).values()
|
||||||
|
return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
|
||||||
|
|
||||||
class CacheKeySet:
|
class CacheKeySet:
|
||||||
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
||||||
self.keys = {}
|
self.keys = {}
|
||||||
@@ -98,7 +108,7 @@ class CacheKeySetInputSignature(CacheKeySet):
|
|||||||
class_type = node["class_type"]
|
class_type = node["class_type"]
|
||||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||||
signature = [class_type, self.is_changed_cache.get(node_id)]
|
signature = [class_type, self.is_changed_cache.get(node_id)]
|
||||||
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT):
|
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
|
||||||
signature.append(node_id)
|
signature.append(node_id)
|
||||||
inputs = node["inputs"]
|
inputs = node["inputs"]
|
||||||
for key in sorted(inputs.keys()):
|
for key in sorted(inputs.keys()):
|
||||||
|
|||||||
@@ -99,30 +99,44 @@ class TopologicalSort:
|
|||||||
self.add_strong_link(from_node_id, from_socket, to_node_id)
|
self.add_strong_link(from_node_id, from_socket, to_node_id)
|
||||||
|
|
||||||
def add_strong_link(self, from_node_id, from_socket, to_node_id):
|
def add_strong_link(self, from_node_id, from_socket, to_node_id):
|
||||||
self.add_node(from_node_id)
|
if not self.is_cached(from_node_id):
|
||||||
if to_node_id not in self.blocking[from_node_id]:
|
self.add_node(from_node_id)
|
||||||
self.blocking[from_node_id][to_node_id] = {}
|
if to_node_id not in self.blocking[from_node_id]:
|
||||||
self.blockCount[to_node_id] += 1
|
self.blocking[from_node_id][to_node_id] = {}
|
||||||
self.blocking[from_node_id][to_node_id][from_socket] = True
|
self.blockCount[to_node_id] += 1
|
||||||
|
self.blocking[from_node_id][to_node_id][from_socket] = True
|
||||||
|
|
||||||
def add_node(self, unique_id, include_lazy=False, subgraph_nodes=None):
|
def add_node(self, node_unique_id, include_lazy=False, subgraph_nodes=None):
|
||||||
if unique_id in self.pendingNodes:
|
node_ids = [node_unique_id]
|
||||||
return
|
links = []
|
||||||
self.pendingNodes[unique_id] = True
|
|
||||||
self.blockCount[unique_id] = 0
|
|
||||||
self.blocking[unique_id] = {}
|
|
||||||
|
|
||||||
inputs = self.dynprompt.get_node(unique_id)["inputs"]
|
while len(node_ids) > 0:
|
||||||
for input_name in inputs:
|
unique_id = node_ids.pop()
|
||||||
value = inputs[input_name]
|
if unique_id in self.pendingNodes:
|
||||||
if is_link(value):
|
continue
|
||||||
from_node_id, from_socket = value
|
|
||||||
if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
|
self.pendingNodes[unique_id] = True
|
||||||
continue
|
self.blockCount[unique_id] = 0
|
||||||
input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
|
self.blocking[unique_id] = {}
|
||||||
is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
|
|
||||||
if include_lazy or not is_lazy:
|
inputs = self.dynprompt.get_node(unique_id)["inputs"]
|
||||||
self.add_strong_link(from_node_id, from_socket, unique_id)
|
for input_name in inputs:
|
||||||
|
value = inputs[input_name]
|
||||||
|
if is_link(value):
|
||||||
|
from_node_id, from_socket = value
|
||||||
|
if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
|
||||||
|
continue
|
||||||
|
input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
|
||||||
|
is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
|
||||||
|
if (include_lazy or not is_lazy) and not self.is_cached(from_node_id):
|
||||||
|
node_ids.append(from_node_id)
|
||||||
|
links.append((from_node_id, from_socket, unique_id))
|
||||||
|
|
||||||
|
for link in links:
|
||||||
|
self.add_strong_link(*link)
|
||||||
|
|
||||||
|
def is_cached(self, node_id):
|
||||||
|
return False
|
||||||
|
|
||||||
def get_ready_nodes(self):
|
def get_ready_nodes(self):
|
||||||
return [node_id for node_id in self.pendingNodes if self.blockCount[node_id] == 0]
|
return [node_id for node_id in self.pendingNodes if self.blockCount[node_id] == 0]
|
||||||
@@ -146,11 +160,8 @@ class ExecutionList(TopologicalSort):
|
|||||||
self.output_cache = output_cache
|
self.output_cache = output_cache
|
||||||
self.staged_node_id = None
|
self.staged_node_id = None
|
||||||
|
|
||||||
def add_strong_link(self, from_node_id, from_socket, to_node_id):
|
def is_cached(self, node_id):
|
||||||
if self.output_cache.get(from_node_id) is not None:
|
return self.output_cache.get(node_id) is not None
|
||||||
# Nothing to do
|
|
||||||
return
|
|
||||||
super().add_strong_link(from_node_id, from_socket, to_node_id)
|
|
||||||
|
|
||||||
def stage_node_execution(self):
|
def stage_node_execution(self):
|
||||||
assert self.staged_node_id is None
|
assert self.staged_node_id is None
|
||||||
|
|||||||
@@ -16,14 +16,15 @@ class EmptyLatentAudio:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": {"seconds": ("FLOAT", {"default": 47.6, "min": 1.0, "max": 1000.0, "step": 0.1})}}
|
return {"required": {"seconds": ("FLOAT", {"default": 47.6, "min": 1.0, "max": 1000.0, "step": 0.1}),
|
||||||
|
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
|
||||||
|
}}
|
||||||
RETURN_TYPES = ("LATENT",)
|
RETURN_TYPES = ("LATENT",)
|
||||||
FUNCTION = "generate"
|
FUNCTION = "generate"
|
||||||
|
|
||||||
CATEGORY = "latent/audio"
|
CATEGORY = "latent/audio"
|
||||||
|
|
||||||
def generate(self, seconds):
|
def generate(self, seconds, batch_size):
|
||||||
batch_size = 1
|
|
||||||
length = round((seconds * 44100 / 2048) / 2) * 2
|
length = round((seconds * 44100 / 2048) / 2) * 2
|
||||||
latent = torch.zeros([batch_size, 64, length], device=self.device)
|
latent = torch.zeros([batch_size, 64, length], device=self.device)
|
||||||
return ({"samples":latent, "type": "audio"}, )
|
return ({"samples":latent, "type": "audio"}, )
|
||||||
@@ -58,6 +59,9 @@ class VAEDecodeAudio:
|
|||||||
|
|
||||||
def decode(self, vae, samples):
|
def decode(self, vae, samples):
|
||||||
audio = vae.decode(samples["samples"]).movedim(-1, 1)
|
audio = vae.decode(samples["samples"]).movedim(-1, 1)
|
||||||
|
std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0
|
||||||
|
std[std < 1.0] = 1.0
|
||||||
|
audio /= std
|
||||||
return ({"waveform": audio, "sample_rate": 44100}, )
|
return ({"waveform": audio, "sample_rate": 44100}, )
|
||||||
|
|
||||||
|
|
||||||
@@ -183,17 +187,10 @@ class PreviewAudio(SaveAudio):
|
|||||||
}
|
}
|
||||||
|
|
||||||
class LoadAudio:
|
class LoadAudio:
|
||||||
SUPPORTED_FORMATS = ('.wav', '.mp3', '.ogg', '.flac', '.aiff', '.aif')
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
input_dir = folder_paths.get_input_directory()
|
input_dir = folder_paths.get_input_directory()
|
||||||
files = [
|
files = folder_paths.filter_files_content_types(os.listdir(input_dir), ["audio", "video"])
|
||||||
f for f in os.listdir(input_dir)
|
|
||||||
if (os.path.isfile(os.path.join(input_dir, f))
|
|
||||||
and f.endswith(LoadAudio.SUPPORTED_FORMATS)
|
|
||||||
)
|
|
||||||
]
|
|
||||||
return {"required": {"audio": (sorted(files), {"audio_upload": True})}}
|
return {"required": {"audio": (sorted(files), {"audio_upload": True})}}
|
||||||
|
|
||||||
CATEGORY = "audio"
|
CATEGORY = "audio"
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
from comfy.cldm.control_types import UNION_CONTROLNET_TYPES
|
from comfy.cldm.control_types import UNION_CONTROLNET_TYPES
|
||||||
|
import nodes
|
||||||
|
import comfy.utils
|
||||||
|
|
||||||
class SetUnionControlNetType:
|
class SetUnionControlNetType:
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -22,6 +24,37 @@ class SetUnionControlNetType:
|
|||||||
|
|
||||||
return (control_net,)
|
return (control_net,)
|
||||||
|
|
||||||
|
class ControlNetInpaintingAliMamaApply(nodes.ControlNetApplyAdvanced):
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": {"positive": ("CONDITIONING", ),
|
||||||
|
"negative": ("CONDITIONING", ),
|
||||||
|
"control_net": ("CONTROL_NET", ),
|
||||||
|
"vae": ("VAE", ),
|
||||||
|
"image": ("IMAGE", ),
|
||||||
|
"mask": ("MASK", ),
|
||||||
|
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||||
|
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||||
|
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
|
||||||
|
}}
|
||||||
|
|
||||||
|
FUNCTION = "apply_inpaint_controlnet"
|
||||||
|
|
||||||
|
CATEGORY = "conditioning/controlnet"
|
||||||
|
|
||||||
|
def apply_inpaint_controlnet(self, positive, negative, control_net, vae, image, mask, strength, start_percent, end_percent):
|
||||||
|
extra_concat = []
|
||||||
|
if control_net.concat_mask:
|
||||||
|
mask = 1.0 - mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
|
||||||
|
mask_apply = comfy.utils.common_upscale(mask, image.shape[2], image.shape[1], "bilinear", "center").round()
|
||||||
|
image = image * mask_apply.movedim(1, -1).repeat(1, 1, 1, image.shape[3])
|
||||||
|
extra_concat = [mask]
|
||||||
|
|
||||||
|
return self.apply_controlnet(positive, negative, control_net, image, strength, start_percent, end_percent, vae=vae, extra_concat=extra_concat)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"SetUnionControlNetType": SetUnionControlNetType,
|
"SetUnionControlNetType": SetUnionControlNetType,
|
||||||
|
"ControlNetInpaintingAliMamaApply": ControlNetInpaintingAliMamaApply,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -90,6 +90,27 @@ class PolyexponentialScheduler:
|
|||||||
sigmas = k_diffusion_sampling.get_sigmas_polyexponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho)
|
sigmas = k_diffusion_sampling.get_sigmas_polyexponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho)
|
||||||
return (sigmas, )
|
return (sigmas, )
|
||||||
|
|
||||||
|
class LaplaceScheduler:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required":
|
||||||
|
{"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
|
||||||
|
"sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
|
||||||
|
"sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
|
||||||
|
"mu": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step":0.1, "round": False}),
|
||||||
|
"beta": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step":0.1, "round": False}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
RETURN_TYPES = ("SIGMAS",)
|
||||||
|
CATEGORY = "sampling/custom_sampling/schedulers"
|
||||||
|
|
||||||
|
FUNCTION = "get_sigmas"
|
||||||
|
|
||||||
|
def get_sigmas(self, steps, sigma_max, sigma_min, mu, beta):
|
||||||
|
sigmas = k_diffusion_sampling.get_sigmas_laplace(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, mu=mu, beta=beta)
|
||||||
|
return (sigmas, )
|
||||||
|
|
||||||
|
|
||||||
class SDTurboScheduler:
|
class SDTurboScheduler:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
@@ -673,6 +694,7 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"KarrasScheduler": KarrasScheduler,
|
"KarrasScheduler": KarrasScheduler,
|
||||||
"ExponentialScheduler": ExponentialScheduler,
|
"ExponentialScheduler": ExponentialScheduler,
|
||||||
"PolyexponentialScheduler": PolyexponentialScheduler,
|
"PolyexponentialScheduler": PolyexponentialScheduler,
|
||||||
|
"LaplaceScheduler": LaplaceScheduler,
|
||||||
"VPScheduler": VPScheduler,
|
"VPScheduler": VPScheduler,
|
||||||
"BetaSamplingScheduler": BetaSamplingScheduler,
|
"BetaSamplingScheduler": BetaSamplingScheduler,
|
||||||
"SDTurboScheduler": SDTurboScheduler,
|
"SDTurboScheduler": SDTurboScheduler,
|
||||||
|
|||||||
@@ -107,7 +107,7 @@ class HypernetworkLoader:
|
|||||||
CATEGORY = "loaders"
|
CATEGORY = "loaders"
|
||||||
|
|
||||||
def load_hypernetwork(self, model, hypernetwork_name, strength):
|
def load_hypernetwork(self, model, hypernetwork_name, strength):
|
||||||
hypernetwork_path = folder_paths.get_full_path("hypernetworks", hypernetwork_name)
|
hypernetwork_path = folder_paths.get_full_path_or_raise("hypernetworks", hypernetwork_name)
|
||||||
model_hypernetwork = model.clone()
|
model_hypernetwork = model.clone()
|
||||||
patch = load_hypernetwork_patch(hypernetwork_path, strength)
|
patch = load_hypernetwork_patch(hypernetwork_path, strength)
|
||||||
if patch is not None:
|
if patch is not None:
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import comfy.utils
|
import comfy.utils
|
||||||
|
import comfy_extras.nodes_post_processing
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
def reshape_latent_to(target_shape, latent):
|
def reshape_latent_to(target_shape, latent):
|
||||||
@@ -145,6 +146,131 @@ class LatentBatchSeedBehavior:
|
|||||||
|
|
||||||
return (samples_out,)
|
return (samples_out,)
|
||||||
|
|
||||||
|
class LatentApplyOperation:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "samples": ("LATENT",),
|
||||||
|
"operation": ("LATENT_OPERATION",),
|
||||||
|
}}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("LATENT",)
|
||||||
|
FUNCTION = "op"
|
||||||
|
|
||||||
|
CATEGORY = "latent/advanced/operations"
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
|
||||||
|
def op(self, samples, operation):
|
||||||
|
samples_out = samples.copy()
|
||||||
|
|
||||||
|
s1 = samples["samples"]
|
||||||
|
samples_out["samples"] = operation(latent=s1)
|
||||||
|
return (samples_out,)
|
||||||
|
|
||||||
|
class LatentApplyOperationCFG:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "model": ("MODEL",),
|
||||||
|
"operation": ("LATENT_OPERATION",),
|
||||||
|
}}
|
||||||
|
RETURN_TYPES = ("MODEL",)
|
||||||
|
FUNCTION = "patch"
|
||||||
|
|
||||||
|
CATEGORY = "latent/advanced/operations"
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
|
||||||
|
def patch(self, model, operation):
|
||||||
|
m = model.clone()
|
||||||
|
|
||||||
|
def pre_cfg_function(args):
|
||||||
|
conds_out = args["conds_out"]
|
||||||
|
if len(conds_out) == 2:
|
||||||
|
conds_out[0] = operation(latent=(conds_out[0] - conds_out[1])) + conds_out[1]
|
||||||
|
else:
|
||||||
|
conds_out[0] = operation(latent=conds_out[0])
|
||||||
|
return conds_out
|
||||||
|
|
||||||
|
m.set_model_sampler_pre_cfg_function(pre_cfg_function)
|
||||||
|
return (m, )
|
||||||
|
|
||||||
|
class LatentOperationTonemapReinhard:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}),
|
||||||
|
}}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("LATENT_OPERATION",)
|
||||||
|
FUNCTION = "op"
|
||||||
|
|
||||||
|
CATEGORY = "latent/advanced/operations"
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
|
||||||
|
def op(self, multiplier):
|
||||||
|
def tonemap_reinhard(latent, **kwargs):
|
||||||
|
latent_vector_magnitude = (torch.linalg.vector_norm(latent, dim=(1)) + 0.0000000001)[:,None]
|
||||||
|
normalized_latent = latent / latent_vector_magnitude
|
||||||
|
|
||||||
|
mean = torch.mean(latent_vector_magnitude, dim=(1,2,3), keepdim=True)
|
||||||
|
std = torch.std(latent_vector_magnitude, dim=(1,2,3), keepdim=True)
|
||||||
|
|
||||||
|
top = (std * 5 + mean) * multiplier
|
||||||
|
|
||||||
|
#reinhard
|
||||||
|
latent_vector_magnitude *= (1.0 / top)
|
||||||
|
new_magnitude = latent_vector_magnitude / (latent_vector_magnitude + 1.0)
|
||||||
|
new_magnitude *= top
|
||||||
|
|
||||||
|
return normalized_latent * new_magnitude
|
||||||
|
return (tonemap_reinhard,)
|
||||||
|
|
||||||
|
class LatentOperationSharpen:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": {
|
||||||
|
"sharpen_radius": ("INT", {
|
||||||
|
"default": 9,
|
||||||
|
"min": 1,
|
||||||
|
"max": 31,
|
||||||
|
"step": 1
|
||||||
|
}),
|
||||||
|
"sigma": ("FLOAT", {
|
||||||
|
"default": 1.0,
|
||||||
|
"min": 0.1,
|
||||||
|
"max": 10.0,
|
||||||
|
"step": 0.1
|
||||||
|
}),
|
||||||
|
"alpha": ("FLOAT", {
|
||||||
|
"default": 0.1,
|
||||||
|
"min": 0.0,
|
||||||
|
"max": 5.0,
|
||||||
|
"step": 0.01
|
||||||
|
}),
|
||||||
|
}}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("LATENT_OPERATION",)
|
||||||
|
FUNCTION = "op"
|
||||||
|
|
||||||
|
CATEGORY = "latent/advanced/operations"
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
|
||||||
|
def op(self, sharpen_radius, sigma, alpha):
|
||||||
|
def sharpen(latent, **kwargs):
|
||||||
|
luminance = (torch.linalg.vector_norm(latent, dim=(1)) + 1e-6)[:,None]
|
||||||
|
normalized_latent = latent / luminance
|
||||||
|
channels = latent.shape[1]
|
||||||
|
|
||||||
|
kernel_size = sharpen_radius * 2 + 1
|
||||||
|
kernel = comfy_extras.nodes_post_processing.gaussian_kernel(kernel_size, sigma, device=luminance.device)
|
||||||
|
center = kernel_size // 2
|
||||||
|
|
||||||
|
kernel *= alpha * -10
|
||||||
|
kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0
|
||||||
|
|
||||||
|
padded_image = torch.nn.functional.pad(normalized_latent, (sharpen_radius,sharpen_radius,sharpen_radius,sharpen_radius), 'reflect')
|
||||||
|
sharpened = torch.nn.functional.conv2d(padded_image, kernel.repeat(channels, 1, 1).unsqueeze(1), padding=kernel_size // 2, groups=channels)[:,:,sharpen_radius:-sharpen_radius, sharpen_radius:-sharpen_radius]
|
||||||
|
|
||||||
|
return luminance * sharpened
|
||||||
|
return (sharpen,)
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"LatentAdd": LatentAdd,
|
"LatentAdd": LatentAdd,
|
||||||
"LatentSubtract": LatentSubtract,
|
"LatentSubtract": LatentSubtract,
|
||||||
@@ -152,4 +278,8 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"LatentInterpolate": LatentInterpolate,
|
"LatentInterpolate": LatentInterpolate,
|
||||||
"LatentBatch": LatentBatch,
|
"LatentBatch": LatentBatch,
|
||||||
"LatentBatchSeedBehavior": LatentBatchSeedBehavior,
|
"LatentBatchSeedBehavior": LatentBatchSeedBehavior,
|
||||||
|
"LatentApplyOperation": LatentApplyOperation,
|
||||||
|
"LatentApplyOperationCFG": LatentApplyOperationCFG,
|
||||||
|
"LatentOperationTonemapReinhard": LatentOperationTonemapReinhard,
|
||||||
|
"LatentOperationSharpen": LatentOperationSharpen,
|
||||||
}
|
}
|
||||||
|
|||||||
119
comfy_extras/nodes_lora_extract.py
Normal file
119
comfy_extras/nodes_lora_extract.py
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
import torch
|
||||||
|
import comfy.model_management
|
||||||
|
import comfy.utils
|
||||||
|
import folder_paths
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
CLAMP_QUANTILE = 0.99
|
||||||
|
|
||||||
|
def extract_lora(diff, rank):
|
||||||
|
conv2d = (len(diff.shape) == 4)
|
||||||
|
kernel_size = None if not conv2d else diff.size()[2:4]
|
||||||
|
conv2d_3x3 = conv2d and kernel_size != (1, 1)
|
||||||
|
out_dim, in_dim = diff.size()[0:2]
|
||||||
|
rank = min(rank, in_dim, out_dim)
|
||||||
|
|
||||||
|
if conv2d:
|
||||||
|
if conv2d_3x3:
|
||||||
|
diff = diff.flatten(start_dim=1)
|
||||||
|
else:
|
||||||
|
diff = diff.squeeze()
|
||||||
|
|
||||||
|
|
||||||
|
U, S, Vh = torch.linalg.svd(diff.float())
|
||||||
|
U = U[:, :rank]
|
||||||
|
S = S[:rank]
|
||||||
|
U = U @ torch.diag(S)
|
||||||
|
Vh = Vh[:rank, :]
|
||||||
|
|
||||||
|
dist = torch.cat([U.flatten(), Vh.flatten()])
|
||||||
|
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
|
||||||
|
low_val = -hi_val
|
||||||
|
|
||||||
|
U = U.clamp(low_val, hi_val)
|
||||||
|
Vh = Vh.clamp(low_val, hi_val)
|
||||||
|
if conv2d:
|
||||||
|
U = U.reshape(out_dim, rank, 1, 1)
|
||||||
|
Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
|
||||||
|
return (U, Vh)
|
||||||
|
|
||||||
|
class LORAType(Enum):
|
||||||
|
STANDARD = 0
|
||||||
|
FULL_DIFF = 1
|
||||||
|
|
||||||
|
LORA_TYPES = {"standard": LORAType.STANDARD,
|
||||||
|
"full_diff": LORAType.FULL_DIFF}
|
||||||
|
|
||||||
|
def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora_type, bias_diff=False):
|
||||||
|
comfy.model_management.load_models_gpu([model_diff], force_patch_weights=True)
|
||||||
|
sd = model_diff.model_state_dict(filter_prefix=prefix_model)
|
||||||
|
|
||||||
|
for k in sd:
|
||||||
|
if k.endswith(".weight"):
|
||||||
|
weight_diff = sd[k]
|
||||||
|
if lora_type == LORAType.STANDARD:
|
||||||
|
if weight_diff.ndim < 2:
|
||||||
|
if bias_diff:
|
||||||
|
output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().half().cpu()
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
out = extract_lora(weight_diff, rank)
|
||||||
|
output_sd["{}{}.lora_up.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[0].contiguous().half().cpu()
|
||||||
|
output_sd["{}{}.lora_down.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[1].contiguous().half().cpu()
|
||||||
|
except:
|
||||||
|
logging.warning("Could not generate lora weights for key {}, is the weight difference a zero?".format(k))
|
||||||
|
elif lora_type == LORAType.FULL_DIFF:
|
||||||
|
output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().half().cpu()
|
||||||
|
|
||||||
|
elif bias_diff and k.endswith(".bias"):
|
||||||
|
output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().half().cpu()
|
||||||
|
return output_sd
|
||||||
|
|
||||||
|
class LoraSave:
|
||||||
|
def __init__(self):
|
||||||
|
self.output_dir = folder_paths.get_output_directory()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": {"filename_prefix": ("STRING", {"default": "loras/ComfyUI_extracted_lora"}),
|
||||||
|
"rank": ("INT", {"default": 8, "min": 1, "max": 4096, "step": 1}),
|
||||||
|
"lora_type": (tuple(LORA_TYPES.keys()),),
|
||||||
|
"bias_diff": ("BOOLEAN", {"default": True}),
|
||||||
|
},
|
||||||
|
"optional": {"model_diff": ("MODEL", {"tooltip": "The ModelSubtract output to be converted to a lora."}),
|
||||||
|
"text_encoder_diff": ("CLIP", {"tooltip": "The CLIPSubtract output to be converted to a lora."})},
|
||||||
|
}
|
||||||
|
RETURN_TYPES = ()
|
||||||
|
FUNCTION = "save"
|
||||||
|
OUTPUT_NODE = True
|
||||||
|
|
||||||
|
CATEGORY = "_for_testing"
|
||||||
|
|
||||||
|
def save(self, filename_prefix, rank, lora_type, bias_diff, model_diff=None, text_encoder_diff=None):
|
||||||
|
if model_diff is None and text_encoder_diff is None:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
lora_type = LORA_TYPES.get(lora_type)
|
||||||
|
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
|
||||||
|
|
||||||
|
output_sd = {}
|
||||||
|
if model_diff is not None:
|
||||||
|
output_sd = calc_lora_model(model_diff, rank, "diffusion_model.", "diffusion_model.", output_sd, lora_type, bias_diff=bias_diff)
|
||||||
|
if text_encoder_diff is not None:
|
||||||
|
output_sd = calc_lora_model(text_encoder_diff.patcher, rank, "", "text_encoders.", output_sd, lora_type, bias_diff=bias_diff)
|
||||||
|
|
||||||
|
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
|
||||||
|
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
|
||||||
|
|
||||||
|
comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=None)
|
||||||
|
return {}
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"LoraSave": LoraSave
|
||||||
|
}
|
||||||
|
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
|
"LoraSave": "Extract and Save Lora"
|
||||||
|
}
|
||||||
26
comfy_extras/nodes_mochi.py
Normal file
26
comfy_extras/nodes_mochi.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
import nodes
|
||||||
|
import torch
|
||||||
|
import comfy.model_management
|
||||||
|
|
||||||
|
class EmptyMochiLatentVideo:
|
||||||
|
def __init__(self):
|
||||||
|
self.device = comfy.model_management.intermediate_device()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||||
|
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||||
|
"length": ("INT", {"default": 25, "min": 7, "max": nodes.MAX_RESOLUTION, "step": 6}),
|
||||||
|
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
|
||||||
|
RETURN_TYPES = ("LATENT",)
|
||||||
|
FUNCTION = "generate"
|
||||||
|
|
||||||
|
CATEGORY = "latent/mochi"
|
||||||
|
|
||||||
|
def generate(self, width, height, length, batch_size=1):
|
||||||
|
latent = torch.zeros([batch_size, 12, ((length - 1) // 6) + 1, height // 8, width // 8], device=self.device)
|
||||||
|
return ({"samples":latent}, )
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"EmptyMochiLatentVideo": EmptyMochiLatentVideo,
|
||||||
|
}
|
||||||
@@ -17,7 +17,7 @@ class PatchModelAddDownscale:
|
|||||||
RETURN_TYPES = ("MODEL",)
|
RETURN_TYPES = ("MODEL",)
|
||||||
FUNCTION = "patch"
|
FUNCTION = "patch"
|
||||||
|
|
||||||
CATEGORY = "_for_testing"
|
CATEGORY = "model_patches/unet"
|
||||||
|
|
||||||
def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method):
|
def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method):
|
||||||
model_sampling = model.get_model_object("model_sampling")
|
model_sampling = model.get_model_object("model_sampling")
|
||||||
|
|||||||
@@ -101,10 +101,34 @@ class ModelMergeFlux1(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
|||||||
|
|
||||||
return {"required": arg_dict}
|
return {"required": arg_dict}
|
||||||
|
|
||||||
|
class ModelMergeSD35_Large(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
||||||
|
CATEGORY = "advanced/model_merging/model_specific"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
arg_dict = { "model1": ("MODEL",),
|
||||||
|
"model2": ("MODEL",)}
|
||||||
|
|
||||||
|
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
|
||||||
|
|
||||||
|
arg_dict["pos_embed."] = argument
|
||||||
|
arg_dict["x_embedder."] = argument
|
||||||
|
arg_dict["context_embedder."] = argument
|
||||||
|
arg_dict["y_embedder."] = argument
|
||||||
|
arg_dict["t_embedder."] = argument
|
||||||
|
|
||||||
|
for i in range(38):
|
||||||
|
arg_dict["joint_blocks.{}.".format(i)] = argument
|
||||||
|
|
||||||
|
arg_dict["final_layer."] = argument
|
||||||
|
|
||||||
|
return {"required": arg_dict}
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"ModelMergeSD1": ModelMergeSD1,
|
"ModelMergeSD1": ModelMergeSD1,
|
||||||
"ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks
|
"ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks
|
||||||
"ModelMergeSDXL": ModelMergeSDXL,
|
"ModelMergeSDXL": ModelMergeSDXL,
|
||||||
"ModelMergeSD3_2B": ModelMergeSD3_2B,
|
"ModelMergeSD3_2B": ModelMergeSD3_2B,
|
||||||
"ModelMergeFlux1": ModelMergeFlux1,
|
"ModelMergeFlux1": ModelMergeFlux1,
|
||||||
|
"ModelMergeSD35_Large": ModelMergeSD35_Large,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ class PerpNeg:
|
|||||||
FUNCTION = "patch"
|
FUNCTION = "patch"
|
||||||
|
|
||||||
CATEGORY = "_for_testing"
|
CATEGORY = "_for_testing"
|
||||||
|
DEPRECATED = True
|
||||||
|
|
||||||
def patch(self, model, empty_conditioning, neg_scale):
|
def patch(self, model, empty_conditioning, neg_scale):
|
||||||
m = model.clone()
|
m = model.clone()
|
||||||
|
|||||||
@@ -126,7 +126,7 @@ class PhotoMakerLoader:
|
|||||||
CATEGORY = "_for_testing/photomaker"
|
CATEGORY = "_for_testing/photomaker"
|
||||||
|
|
||||||
def load_photomaker_model(self, photomaker_model_name):
|
def load_photomaker_model(self, photomaker_model_name):
|
||||||
photomaker_model_path = folder_paths.get_full_path("photomaker", photomaker_model_name)
|
photomaker_model_path = folder_paths.get_full_path_or_raise("photomaker", photomaker_model_name)
|
||||||
photomaker_model = PhotoMakerIDEncoder()
|
photomaker_model = PhotoMakerIDEncoder()
|
||||||
data = comfy.utils.load_torch_file(photomaker_model_path, safe_load=True)
|
data = comfy.utils.load_torch_file(photomaker_model_path, safe_load=True)
|
||||||
if "id_encoder" in data:
|
if "id_encoder" in data:
|
||||||
|
|||||||
@@ -3,11 +3,11 @@ import comfy.sd
|
|||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import nodes
|
import nodes
|
||||||
import torch
|
import torch
|
||||||
|
import re
|
||||||
class TripleCLIPLoader:
|
class TripleCLIPLoader:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "clip_name1": (folder_paths.get_filename_list("clip"), ), "clip_name2": (folder_paths.get_filename_list("clip"), ), "clip_name3": (folder_paths.get_filename_list("clip"), )
|
return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), "clip_name3": (folder_paths.get_filename_list("text_encoders"), )
|
||||||
}}
|
}}
|
||||||
RETURN_TYPES = ("CLIP",)
|
RETURN_TYPES = ("CLIP",)
|
||||||
FUNCTION = "load_clip"
|
FUNCTION = "load_clip"
|
||||||
@@ -15,9 +15,9 @@ class TripleCLIPLoader:
|
|||||||
CATEGORY = "advanced/loaders"
|
CATEGORY = "advanced/loaders"
|
||||||
|
|
||||||
def load_clip(self, clip_name1, clip_name2, clip_name3):
|
def load_clip(self, clip_name1, clip_name2, clip_name3):
|
||||||
clip_path1 = folder_paths.get_full_path("clip", clip_name1)
|
clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
|
||||||
clip_path2 = folder_paths.get_full_path("clip", clip_name2)
|
clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
|
||||||
clip_path3 = folder_paths.get_full_path("clip", clip_name3)
|
clip_path3 = folder_paths.get_full_path_or_raise("text_encoders", clip_name3)
|
||||||
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||||
return (clip,)
|
return (clip,)
|
||||||
|
|
||||||
@@ -36,7 +36,7 @@ class EmptySD3LatentImage:
|
|||||||
CATEGORY = "latent/sd3"
|
CATEGORY = "latent/sd3"
|
||||||
|
|
||||||
def generate(self, width, height, batch_size=1):
|
def generate(self, width, height, batch_size=1):
|
||||||
latent = torch.ones([batch_size, 16, height // 8, width // 8], device=self.device) * 0.0609
|
latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=self.device)
|
||||||
return ({"samples":latent}, )
|
return ({"samples":latent}, )
|
||||||
|
|
||||||
class CLIPTextEncodeSD3:
|
class CLIPTextEncodeSD3:
|
||||||
@@ -93,15 +93,75 @@ class ControlNetApplySD3(nodes.ControlNetApplyAdvanced):
|
|||||||
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
|
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
|
||||||
}}
|
}}
|
||||||
CATEGORY = "conditioning/controlnet"
|
CATEGORY = "conditioning/controlnet"
|
||||||
|
DEPRECATED = True
|
||||||
|
|
||||||
|
class SkipLayerGuidanceSD3:
|
||||||
|
'''
|
||||||
|
Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers.
|
||||||
|
Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377)
|
||||||
|
Experimental implementation by Dango233@StabilityAI.
|
||||||
|
'''
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": {"model": ("MODEL", ),
|
||||||
|
"layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
|
||||||
|
"scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 10.0, "step": 0.1}),
|
||||||
|
"start_percent": ("FLOAT", {"default": 0.01, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||||
|
"end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001})
|
||||||
|
}}
|
||||||
|
RETURN_TYPES = ("MODEL",)
|
||||||
|
FUNCTION = "skip_guidance"
|
||||||
|
|
||||||
|
CATEGORY = "advanced/guidance"
|
||||||
|
|
||||||
|
|
||||||
|
def skip_guidance(self, model, layers, scale, start_percent, end_percent):
|
||||||
|
if layers == "" or layers == None:
|
||||||
|
return (model, )
|
||||||
|
# check if layer is comma separated integers
|
||||||
|
def skip(args, extra_args):
|
||||||
|
return args
|
||||||
|
|
||||||
|
model_sampling = model.get_model_object("model_sampling")
|
||||||
|
sigma_start = model_sampling.percent_to_sigma(start_percent)
|
||||||
|
sigma_end = model_sampling.percent_to_sigma(end_percent)
|
||||||
|
|
||||||
|
def post_cfg_function(args):
|
||||||
|
model = args["model"]
|
||||||
|
cond_pred = args["cond_denoised"]
|
||||||
|
cond = args["cond"]
|
||||||
|
cfg_result = args["denoised"]
|
||||||
|
sigma = args["sigma"]
|
||||||
|
x = args["input"]
|
||||||
|
model_options = args["model_options"].copy()
|
||||||
|
|
||||||
|
for layer in layers:
|
||||||
|
model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, skip, "dit", "double_block", layer)
|
||||||
|
model_sampling.percent_to_sigma(start_percent)
|
||||||
|
|
||||||
|
sigma_ = sigma[0].item()
|
||||||
|
if scale > 0 and sigma_ >= sigma_end and sigma_ <= sigma_start:
|
||||||
|
(slg,) = comfy.samplers.calc_cond_batch(model, [cond], x, sigma, model_options)
|
||||||
|
cfg_result = cfg_result + (cond_pred - slg) * scale
|
||||||
|
return cfg_result
|
||||||
|
|
||||||
|
layers = re.findall(r'\d+', layers)
|
||||||
|
layers = [int(i) for i in layers]
|
||||||
|
m = model.clone()
|
||||||
|
m.set_model_sampler_post_cfg_function(post_cfg_function)
|
||||||
|
|
||||||
|
return (m, )
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"TripleCLIPLoader": TripleCLIPLoader,
|
"TripleCLIPLoader": TripleCLIPLoader,
|
||||||
"EmptySD3LatentImage": EmptySD3LatentImage,
|
"EmptySD3LatentImage": EmptySD3LatentImage,
|
||||||
"CLIPTextEncodeSD3": CLIPTextEncodeSD3,
|
"CLIPTextEncodeSD3": CLIPTextEncodeSD3,
|
||||||
"ControlNetApplySD3": ControlNetApplySD3,
|
"ControlNetApplySD3": ControlNetApplySD3,
|
||||||
|
"SkipLayerGuidanceSD3": SkipLayerGuidanceSD3,
|
||||||
}
|
}
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
# Sampling
|
# Sampling
|
||||||
"ControlNetApplySD3": "ControlNetApply SD3 and HunyuanDiT",
|
"ControlNetApplySD3": "Apply Controlnet with VAE",
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -116,6 +116,7 @@ class StableCascade_SuperResolutionControlnet:
|
|||||||
RETURN_NAMES = ("controlnet_input", "stage_c", "stage_b")
|
RETURN_NAMES = ("controlnet_input", "stage_c", "stage_b")
|
||||||
FUNCTION = "generate"
|
FUNCTION = "generate"
|
||||||
|
|
||||||
|
EXPERIMENTAL = True
|
||||||
CATEGORY = "_for_testing/stable_cascade"
|
CATEGORY = "_for_testing/stable_cascade"
|
||||||
|
|
||||||
def generate(self, image, vae):
|
def generate(self, image, vae):
|
||||||
|
|||||||
@@ -154,7 +154,7 @@ class TomePatchModel:
|
|||||||
RETURN_TYPES = ("MODEL",)
|
RETURN_TYPES = ("MODEL",)
|
||||||
FUNCTION = "patch"
|
FUNCTION = "patch"
|
||||||
|
|
||||||
CATEGORY = "_for_testing"
|
CATEGORY = "model_patches/unet"
|
||||||
|
|
||||||
def patch(self, model, ratio):
|
def patch(self, model, ratio):
|
||||||
self.u = None
|
self.u = None
|
||||||
|
|||||||
22
comfy_extras/nodes_torch_compile.py
Normal file
22
comfy_extras/nodes_torch_compile.py
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
class TorchCompileModel:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "model": ("MODEL",),
|
||||||
|
"backend": (["inductor", "cudagraphs"],),
|
||||||
|
}}
|
||||||
|
RETURN_TYPES = ("MODEL",)
|
||||||
|
FUNCTION = "patch"
|
||||||
|
|
||||||
|
CATEGORY = "_for_testing"
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
|
||||||
|
def patch(self, model, backend):
|
||||||
|
m = model.clone()
|
||||||
|
m.add_object_patch("diffusion_model", torch.compile(model=m.get_model_object("diffusion_model"), backend=backend))
|
||||||
|
return (m, )
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"TorchCompileModel": TorchCompileModel,
|
||||||
|
}
|
||||||
@@ -25,7 +25,7 @@ class UpscaleModelLoader:
|
|||||||
CATEGORY = "loaders"
|
CATEGORY = "loaders"
|
||||||
|
|
||||||
def load_model(self, model_name):
|
def load_model(self, model_name):
|
||||||
model_path = folder_paths.get_full_path("upscale_models", model_name)
|
model_path = folder_paths.get_full_path_or_raise("upscale_models", model_name)
|
||||||
sd = comfy.utils.load_torch_file(model_path, safe_load=True)
|
sd = comfy.utils.load_torch_file(model_path, safe_load=True)
|
||||||
if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd:
|
if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd:
|
||||||
sd = comfy.utils.state_dict_prefix_replace(sd, {"module.":""})
|
sd = comfy.utils.state_dict_prefix_replace(sd, {"module.":""})
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ class ImageOnlyCheckpointLoader:
|
|||||||
CATEGORY = "loaders/video_models"
|
CATEGORY = "loaders/video_models"
|
||||||
|
|
||||||
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
|
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
|
||||||
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
|
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
|
||||||
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||||
return (out[0], out[3], out[2])
|
return (out[0], out[3], out[2])
|
||||||
|
|
||||||
@@ -107,7 +107,7 @@ class VideoTriangleCFGGuidance:
|
|||||||
return (m, )
|
return (m, )
|
||||||
|
|
||||||
class ImageOnlyCheckpointSave(comfy_extras.nodes_model_merging.CheckpointSave):
|
class ImageOnlyCheckpointSave(comfy_extras.nodes_model_merging.CheckpointSave):
|
||||||
CATEGORY = "_for_testing"
|
CATEGORY = "advanced/model_merging"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ class SaveImageWebsocket:
|
|||||||
|
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
def IS_CHANGED(s, images):
|
def IS_CHANGED(s, images):
|
||||||
return time.time()
|
return time.time()
|
||||||
|
|
||||||
|
|||||||
@@ -179,7 +179,13 @@ def merge_result_data(results, obj):
|
|||||||
# merge node execution results
|
# merge node execution results
|
||||||
for i, is_list in zip(range(len(results[0])), output_is_list):
|
for i, is_list in zip(range(len(results[0])), output_is_list):
|
||||||
if is_list:
|
if is_list:
|
||||||
output.append([x for o in results for x in o[i]])
|
value = []
|
||||||
|
for o in results:
|
||||||
|
if isinstance(o[i], ExecutionBlocker):
|
||||||
|
value.append(o[i])
|
||||||
|
else:
|
||||||
|
value.extend(o[i])
|
||||||
|
output.append(value)
|
||||||
else:
|
else:
|
||||||
output.append([o[i] for o in results])
|
output.append([o[i] for o in results])
|
||||||
return output
|
return output
|
||||||
|
|||||||
@@ -25,11 +25,16 @@ a111:
|
|||||||
|
|
||||||
#comfyui:
|
#comfyui:
|
||||||
# base_path: path/to/comfyui/
|
# base_path: path/to/comfyui/
|
||||||
|
# # You can use is_default to mark that these folders should be listed first, and used as the default dirs for eg downloads
|
||||||
|
# #is_default: true
|
||||||
# checkpoints: models/checkpoints/
|
# checkpoints: models/checkpoints/
|
||||||
# clip: models/clip/
|
# clip: models/clip/
|
||||||
# clip_vision: models/clip_vision/
|
# clip_vision: models/clip_vision/
|
||||||
# configs: models/configs/
|
# configs: models/configs/
|
||||||
# controlnet: models/controlnet/
|
# controlnet: models/controlnet/
|
||||||
|
# diffusion_models: |
|
||||||
|
# models/diffusion_models
|
||||||
|
# models/unet
|
||||||
# embeddings: models/embeddings/
|
# embeddings: models/embeddings/
|
||||||
# loras: models/loras/
|
# loras: models/loras/
|
||||||
# upscale_models: models/upscale_models/
|
# upscale_models: models/upscale_models/
|
||||||
|
|||||||
108
folder_paths.py
108
folder_paths.py
@@ -2,7 +2,9 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
import mimetypes
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Set, List, Dict, Tuple, Literal
|
||||||
from collections.abc import Collection
|
from collections.abc import Collection
|
||||||
|
|
||||||
supported_pt_extensions: set[str] = {'.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'}
|
supported_pt_extensions: set[str] = {'.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'}
|
||||||
@@ -16,7 +18,7 @@ folder_names_and_paths["configs"] = ([os.path.join(models_dir, "configs")], [".y
|
|||||||
|
|
||||||
folder_names_and_paths["loras"] = ([os.path.join(models_dir, "loras")], supported_pt_extensions)
|
folder_names_and_paths["loras"] = ([os.path.join(models_dir, "loras")], supported_pt_extensions)
|
||||||
folder_names_and_paths["vae"] = ([os.path.join(models_dir, "vae")], supported_pt_extensions)
|
folder_names_and_paths["vae"] = ([os.path.join(models_dir, "vae")], supported_pt_extensions)
|
||||||
folder_names_and_paths["clip"] = ([os.path.join(models_dir, "clip")], supported_pt_extensions)
|
folder_names_and_paths["text_encoders"] = ([os.path.join(models_dir, "text_encoders"), os.path.join(models_dir, "clip")], supported_pt_extensions)
|
||||||
folder_names_and_paths["diffusion_models"] = ([os.path.join(models_dir, "unet"), os.path.join(models_dir, "diffusion_models")], supported_pt_extensions)
|
folder_names_and_paths["diffusion_models"] = ([os.path.join(models_dir, "unet"), os.path.join(models_dir, "diffusion_models")], supported_pt_extensions)
|
||||||
folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision")], supported_pt_extensions)
|
folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision")], supported_pt_extensions)
|
||||||
folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions)
|
folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions)
|
||||||
@@ -44,8 +46,43 @@ user_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "user
|
|||||||
|
|
||||||
filename_list_cache: dict[str, tuple[list[str], dict[str, float], float]] = {}
|
filename_list_cache: dict[str, tuple[list[str], dict[str, float], float]] = {}
|
||||||
|
|
||||||
|
class CacheHelper:
|
||||||
|
"""
|
||||||
|
Helper class for managing file list cache data.
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
self.cache: dict[str, tuple[list[str], dict[str, float], float]] = {}
|
||||||
|
self.active = False
|
||||||
|
|
||||||
|
def get(self, key: str, default=None) -> tuple[list[str], dict[str, float], float]:
|
||||||
|
if not self.active:
|
||||||
|
return default
|
||||||
|
return self.cache.get(key, default)
|
||||||
|
|
||||||
|
def set(self, key: str, value: tuple[list[str], dict[str, float], float]) -> None:
|
||||||
|
if self.active:
|
||||||
|
self.cache[key] = value
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
self.cache.clear()
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self.active = True
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_value, traceback):
|
||||||
|
self.active = False
|
||||||
|
self.clear()
|
||||||
|
|
||||||
|
cache_helper = CacheHelper()
|
||||||
|
|
||||||
|
extension_mimetypes_cache = {
|
||||||
|
"webp" : "image",
|
||||||
|
}
|
||||||
|
|
||||||
def map_legacy(folder_name: str) -> str:
|
def map_legacy(folder_name: str) -> str:
|
||||||
legacy = {"unet": "diffusion_models"}
|
legacy = {"unet": "diffusion_models",
|
||||||
|
"clip": "text_encoders"}
|
||||||
return legacy.get(folder_name, folder_name)
|
return legacy.get(folder_name, folder_name)
|
||||||
|
|
||||||
if not os.path.exists(input_directory):
|
if not os.path.exists(input_directory):
|
||||||
@@ -78,6 +115,13 @@ def get_input_directory() -> str:
|
|||||||
global input_directory
|
global input_directory
|
||||||
return input_directory
|
return input_directory
|
||||||
|
|
||||||
|
def get_user_directory() -> str:
|
||||||
|
return user_directory
|
||||||
|
|
||||||
|
def set_user_directory(user_dir: str) -> None:
|
||||||
|
global user_directory
|
||||||
|
user_directory = user_dir
|
||||||
|
|
||||||
|
|
||||||
#NOTE: used in http server so don't put folders that should not be accessed remotely
|
#NOTE: used in http server so don't put folders that should not be accessed remotely
|
||||||
def get_directory_by_type(type_name: str) -> str | None:
|
def get_directory_by_type(type_name: str) -> str | None:
|
||||||
@@ -89,6 +133,28 @@ def get_directory_by_type(type_name: str) -> str | None:
|
|||||||
return get_input_directory()
|
return get_input_directory()
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def filter_files_content_types(files: List[str], content_types: Literal["image", "video", "audio"]) -> List[str]:
|
||||||
|
"""
|
||||||
|
Example:
|
||||||
|
files = os.listdir(folder_paths.get_input_directory())
|
||||||
|
filter_files_content_types(files, ["image", "audio", "video"])
|
||||||
|
"""
|
||||||
|
global extension_mimetypes_cache
|
||||||
|
result = []
|
||||||
|
for file in files:
|
||||||
|
extension = file.split('.')[-1]
|
||||||
|
if extension not in extension_mimetypes_cache:
|
||||||
|
mime_type, _ = mimetypes.guess_type(file, strict=False)
|
||||||
|
if not mime_type:
|
||||||
|
continue
|
||||||
|
content_type = mime_type.split('/')[0]
|
||||||
|
extension_mimetypes_cache[extension] = content_type
|
||||||
|
else:
|
||||||
|
content_type = extension_mimetypes_cache[extension]
|
||||||
|
|
||||||
|
if content_type in content_types:
|
||||||
|
result.append(file)
|
||||||
|
return result
|
||||||
|
|
||||||
# determine base_dir rely on annotation if name is 'filename.ext [annotation]' format
|
# determine base_dir rely on annotation if name is 'filename.ext [annotation]' format
|
||||||
# otherwise use default_path as base_dir
|
# otherwise use default_path as base_dir
|
||||||
@@ -130,11 +196,14 @@ def exists_annotated_filepath(name) -> bool:
|
|||||||
return os.path.exists(filepath)
|
return os.path.exists(filepath)
|
||||||
|
|
||||||
|
|
||||||
def add_model_folder_path(folder_name: str, full_folder_path: str) -> None:
|
def add_model_folder_path(folder_name: str, full_folder_path: str, is_default: bool = False) -> None:
|
||||||
global folder_names_and_paths
|
global folder_names_and_paths
|
||||||
folder_name = map_legacy(folder_name)
|
folder_name = map_legacy(folder_name)
|
||||||
if folder_name in folder_names_and_paths:
|
if folder_name in folder_names_and_paths:
|
||||||
folder_names_and_paths[folder_name][0].append(full_folder_path)
|
if is_default:
|
||||||
|
folder_names_and_paths[folder_name][0].insert(0, full_folder_path)
|
||||||
|
else:
|
||||||
|
folder_names_and_paths[folder_name][0].append(full_folder_path)
|
||||||
else:
|
else:
|
||||||
folder_names_and_paths[folder_name] = ([full_folder_path], set())
|
folder_names_and_paths[folder_name] = ([full_folder_path], set())
|
||||||
|
|
||||||
@@ -166,8 +235,12 @@ def recursive_search(directory: str, excluded_dir_names: list[str] | None=None)
|
|||||||
for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
|
for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
|
||||||
subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
|
subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
|
||||||
for file_name in filenames:
|
for file_name in filenames:
|
||||||
relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
|
try:
|
||||||
result.append(relative_path)
|
relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
|
||||||
|
result.append(relative_path)
|
||||||
|
except:
|
||||||
|
logging.warning(f"Warning: Unable to access {file_name}. Skipping this file.")
|
||||||
|
continue
|
||||||
|
|
||||||
for d in subdirs:
|
for d in subdirs:
|
||||||
path: str = os.path.join(dirpath, d)
|
path: str = os.path.join(dirpath, d)
|
||||||
@@ -200,6 +273,14 @@ def get_full_path(folder_name: str, filename: str) -> str | None:
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_full_path_or_raise(folder_name: str, filename: str) -> str:
|
||||||
|
full_path = get_full_path(folder_name, filename)
|
||||||
|
if full_path is None:
|
||||||
|
raise FileNotFoundError(f"Model in folder '{folder_name}' with filename '{filename}' not found.")
|
||||||
|
return full_path
|
||||||
|
|
||||||
|
|
||||||
def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]:
|
def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]:
|
||||||
folder_name = map_legacy(folder_name)
|
folder_name = map_legacy(folder_name)
|
||||||
global folder_names_and_paths
|
global folder_names_and_paths
|
||||||
@@ -214,6 +295,10 @@ def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], f
|
|||||||
return sorted(list(output_list)), output_folders, time.perf_counter()
|
return sorted(list(output_list)), output_folders, time.perf_counter()
|
||||||
|
|
||||||
def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float] | None:
|
def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float] | None:
|
||||||
|
strong_cache = cache_helper.get(folder_name)
|
||||||
|
if strong_cache is not None:
|
||||||
|
return strong_cache
|
||||||
|
|
||||||
global filename_list_cache
|
global filename_list_cache
|
||||||
global folder_names_and_paths
|
global folder_names_and_paths
|
||||||
folder_name = map_legacy(folder_name)
|
folder_name = map_legacy(folder_name)
|
||||||
@@ -242,6 +327,7 @@ def get_filename_list(folder_name: str) -> list[str]:
|
|||||||
out = get_filename_list_(folder_name)
|
out = get_filename_list_(folder_name)
|
||||||
global filename_list_cache
|
global filename_list_cache
|
||||||
filename_list_cache[folder_name] = out
|
filename_list_cache[folder_name] = out
|
||||||
|
cache_helper.set(folder_name, out)
|
||||||
return list(out[0])
|
return list(out[0])
|
||||||
|
|
||||||
def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, image_height=0) -> tuple[str, str, int, str, str]:
|
def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, image_height=0) -> tuple[str, str, int, str, str]:
|
||||||
@@ -257,9 +343,17 @@ def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, im
|
|||||||
def compute_vars(input: str, image_width: int, image_height: int) -> str:
|
def compute_vars(input: str, image_width: int, image_height: int) -> str:
|
||||||
input = input.replace("%width%", str(image_width))
|
input = input.replace("%width%", str(image_width))
|
||||||
input = input.replace("%height%", str(image_height))
|
input = input.replace("%height%", str(image_height))
|
||||||
|
now = time.localtime()
|
||||||
|
input = input.replace("%year%", str(now.tm_year))
|
||||||
|
input = input.replace("%month%", str(now.tm_mon).zfill(2))
|
||||||
|
input = input.replace("%day%", str(now.tm_mday).zfill(2))
|
||||||
|
input = input.replace("%hour%", str(now.tm_hour).zfill(2))
|
||||||
|
input = input.replace("%minute%", str(now.tm_min).zfill(2))
|
||||||
|
input = input.replace("%second%", str(now.tm_sec).zfill(2))
|
||||||
return input
|
return input
|
||||||
|
|
||||||
filename_prefix = compute_vars(filename_prefix, image_width, image_height)
|
if "%" in filename_prefix:
|
||||||
|
filename_prefix = compute_vars(filename_prefix, image_width, image_height)
|
||||||
|
|
||||||
subfolder = os.path.dirname(os.path.normpath(filename_prefix))
|
subfolder = os.path.dirname(os.path.normpath(filename_prefix))
|
||||||
filename = os.path.basename(os.path.normpath(filename_prefix))
|
filename = os.path.basename(os.path.normpath(filename_prefix))
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import folder_paths
|
|||||||
import comfy.utils
|
import comfy.utils
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
MAX_PREVIEW_RESOLUTION = 512
|
MAX_PREVIEW_RESOLUTION = args.preview_size
|
||||||
|
|
||||||
def preview_to_image(latent_image):
|
def preview_to_image(latent_image):
|
||||||
latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1
|
latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1
|
||||||
@@ -36,12 +36,20 @@ class TAESDPreviewerImpl(LatentPreviewer):
|
|||||||
|
|
||||||
|
|
||||||
class Latent2RGBPreviewer(LatentPreviewer):
|
class Latent2RGBPreviewer(LatentPreviewer):
|
||||||
def __init__(self, latent_rgb_factors):
|
def __init__(self, latent_rgb_factors, latent_rgb_factors_bias=None):
|
||||||
self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu")
|
self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu").transpose(0, 1)
|
||||||
|
self.latent_rgb_factors_bias = None
|
||||||
|
if latent_rgb_factors_bias is not None:
|
||||||
|
self.latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device="cpu")
|
||||||
|
|
||||||
def decode_latent_to_preview(self, x0):
|
def decode_latent_to_preview(self, x0):
|
||||||
self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device)
|
self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device)
|
||||||
latent_image = x0[0].permute(1, 2, 0) @ self.latent_rgb_factors
|
if self.latent_rgb_factors_bias is not None:
|
||||||
|
self.latent_rgb_factors_bias = self.latent_rgb_factors_bias.to(dtype=x0.dtype, device=x0.device)
|
||||||
|
|
||||||
|
latent_image = torch.nn.functional.linear(x0[0].permute(1, 2, 0), self.latent_rgb_factors, bias=self.latent_rgb_factors_bias)
|
||||||
|
# latent_image = x0[0].permute(1, 2, 0) @ self.latent_rgb_factors
|
||||||
|
|
||||||
return preview_to_image(latent_image)
|
return preview_to_image(latent_image)
|
||||||
|
|
||||||
|
|
||||||
@@ -71,7 +79,7 @@ def get_previewer(device, latent_format):
|
|||||||
|
|
||||||
if previewer is None:
|
if previewer is None:
|
||||||
if latent_format.latent_rgb_factors is not None:
|
if latent_format.latent_rgb_factors is not None:
|
||||||
previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors)
|
previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors, latent_format.latent_rgb_factors_bias)
|
||||||
return previewer
|
return previewer
|
||||||
|
|
||||||
def prepare_callback(model, steps, x0_output_dict=None):
|
def prepare_callback(model, steps, x0_output_dict=None):
|
||||||
|
|||||||
43
main.py
43
main.py
@@ -9,7 +9,7 @@ from comfy.cli_args import args
|
|||||||
from app.logger import setup_logger
|
from app.logger import setup_logger
|
||||||
|
|
||||||
|
|
||||||
setup_logger(verbose=args.verbose)
|
setup_logger(log_level=args.verbose)
|
||||||
|
|
||||||
|
|
||||||
def execute_prestartup_script():
|
def execute_prestartup_script():
|
||||||
@@ -63,6 +63,7 @@ import threading
|
|||||||
import gc
|
import gc
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import utils.extra_config
|
||||||
|
|
||||||
if os.name == "nt":
|
if os.name == "nt":
|
||||||
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
|
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
|
||||||
@@ -85,7 +86,6 @@ if args.windows_standalone_build:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import yaml
|
|
||||||
|
|
||||||
import execution
|
import execution
|
||||||
import server
|
import server
|
||||||
@@ -160,7 +160,10 @@ def prompt_worker(q, server):
|
|||||||
need_gc = False
|
need_gc = False
|
||||||
|
|
||||||
async def run(server, address='', port=8188, verbose=True, call_on_start=None):
|
async def run(server, address='', port=8188, verbose=True, call_on_start=None):
|
||||||
await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())
|
addresses = []
|
||||||
|
for addr in address.split(","):
|
||||||
|
addresses.append((addr, port))
|
||||||
|
await asyncio.gather(server.start_multi_address(addresses, call_on_start), server.publish_loop())
|
||||||
|
|
||||||
|
|
||||||
def hijack_progress(server):
|
def hijack_progress(server):
|
||||||
@@ -180,27 +183,6 @@ def cleanup_temp():
|
|||||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||||
|
|
||||||
|
|
||||||
def load_extra_path_config(yaml_path):
|
|
||||||
with open(yaml_path, 'r') as stream:
|
|
||||||
config = yaml.safe_load(stream)
|
|
||||||
for c in config:
|
|
||||||
conf = config[c]
|
|
||||||
if conf is None:
|
|
||||||
continue
|
|
||||||
base_path = None
|
|
||||||
if "base_path" in conf:
|
|
||||||
base_path = conf.pop("base_path")
|
|
||||||
for x in conf:
|
|
||||||
for y in conf[x].split("\n"):
|
|
||||||
if len(y) == 0:
|
|
||||||
continue
|
|
||||||
full_path = y
|
|
||||||
if base_path is not None:
|
|
||||||
full_path = os.path.join(base_path, full_path)
|
|
||||||
logging.info("Adding extra search path {} {}".format(x, full_path))
|
|
||||||
folder_paths.add_model_folder_path(x, full_path)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
if args.temp_directory:
|
if args.temp_directory:
|
||||||
temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp")
|
temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp")
|
||||||
@@ -222,11 +204,11 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
|
extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
|
||||||
if os.path.isfile(extra_model_paths_config_path):
|
if os.path.isfile(extra_model_paths_config_path):
|
||||||
load_extra_path_config(extra_model_paths_config_path)
|
utils.extra_config.load_extra_path_config(extra_model_paths_config_path)
|
||||||
|
|
||||||
if args.extra_model_paths_config:
|
if args.extra_model_paths_config:
|
||||||
for config_path in itertools.chain(*args.extra_model_paths_config):
|
for config_path in itertools.chain(*args.extra_model_paths_config):
|
||||||
load_extra_path_config(config_path)
|
utils.extra_config.load_extra_path_config(config_path)
|
||||||
|
|
||||||
nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes)
|
nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes)
|
||||||
|
|
||||||
@@ -247,21 +229,30 @@ if __name__ == "__main__":
|
|||||||
folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip"))
|
folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip"))
|
||||||
folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae"))
|
folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae"))
|
||||||
folder_paths.add_model_folder_path("diffusion_models", os.path.join(folder_paths.get_output_directory(), "diffusion_models"))
|
folder_paths.add_model_folder_path("diffusion_models", os.path.join(folder_paths.get_output_directory(), "diffusion_models"))
|
||||||
|
folder_paths.add_model_folder_path("loras", os.path.join(folder_paths.get_output_directory(), "loras"))
|
||||||
|
|
||||||
if args.input_directory:
|
if args.input_directory:
|
||||||
input_dir = os.path.abspath(args.input_directory)
|
input_dir = os.path.abspath(args.input_directory)
|
||||||
logging.info(f"Setting input directory to: {input_dir}")
|
logging.info(f"Setting input directory to: {input_dir}")
|
||||||
folder_paths.set_input_directory(input_dir)
|
folder_paths.set_input_directory(input_dir)
|
||||||
|
|
||||||
|
if args.user_directory:
|
||||||
|
user_dir = os.path.abspath(args.user_directory)
|
||||||
|
logging.info(f"Setting user directory to: {user_dir}")
|
||||||
|
folder_paths.set_user_directory(user_dir)
|
||||||
|
|
||||||
if args.quick_test_for_ci:
|
if args.quick_test_for_ci:
|
||||||
exit(0)
|
exit(0)
|
||||||
|
|
||||||
|
os.makedirs(folder_paths.get_temp_directory(), exist_ok=True)
|
||||||
call_on_start = None
|
call_on_start = None
|
||||||
if args.auto_launch:
|
if args.auto_launch:
|
||||||
def startup_server(scheme, address, port):
|
def startup_server(scheme, address, port):
|
||||||
import webbrowser
|
import webbrowser
|
||||||
if os.name == 'nt' and address == '0.0.0.0':
|
if os.name == 'nt' and address == '0.0.0.0':
|
||||||
address = '127.0.0.1'
|
address = '127.0.0.1'
|
||||||
|
if ':' in address:
|
||||||
|
address = "[{}]".format(address)
|
||||||
webbrowser.open(f"{scheme}://{address}:{port}")
|
webbrowser.open(f"{scheme}://{address}:{port}")
|
||||||
call_on_start = startup_server
|
call_on_start = startup_server
|
||||||
|
|
||||||
|
|||||||
@@ -1,2 +1,2 @@
|
|||||||
# model_manager/__init__.py
|
# model_manager/__init__.py
|
||||||
from .download_models import download_model, DownloadModelStatus, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_model_subdirectory, validate_filename
|
from .download_models import download_model, DownloadModelStatus, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_filename
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
|
#NOTE: This was an experiment and WILL BE REMOVED
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import os
|
import os
|
||||||
import traceback
|
import traceback
|
||||||
import logging
|
import logging
|
||||||
from folder_paths import models_dir
|
from folder_paths import folder_names_and_paths, get_folder_paths
|
||||||
import re
|
import re
|
||||||
from typing import Callable, Any, Optional, Awaitable, Dict
|
from typing import Callable, Any, Optional, Awaitable, Dict
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
@@ -17,6 +18,7 @@ class DownloadStatusType(Enum):
|
|||||||
COMPLETED = "completed"
|
COMPLETED = "completed"
|
||||||
ERROR = "error"
|
ERROR = "error"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DownloadModelStatus():
|
class DownloadModelStatus():
|
||||||
status: str
|
status: str
|
||||||
@@ -38,10 +40,12 @@ class DownloadModelStatus():
|
|||||||
"already_existed": self.already_existed
|
"already_existed": self.already_existed
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
|
async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
|
||||||
model_name: str,
|
model_name: str,
|
||||||
model_url: str,
|
model_url: str,
|
||||||
model_sub_directory: str,
|
model_directory: str,
|
||||||
|
folder_path: str,
|
||||||
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
||||||
progress_interval: float = 1.0) -> DownloadModelStatus:
|
progress_interval: float = 1.0) -> DownloadModelStatus:
|
||||||
"""
|
"""
|
||||||
@@ -54,23 +58,17 @@ async def download_model(model_download_request: Callable[[str], Awaitable[aioht
|
|||||||
The name of the model file to be downloaded. This will be the filename on disk.
|
The name of the model file to be downloaded. This will be the filename on disk.
|
||||||
model_url (str):
|
model_url (str):
|
||||||
The URL from which to download the model.
|
The URL from which to download the model.
|
||||||
model_sub_directory (str):
|
model_directory (str):
|
||||||
The subdirectory within the main models directory where the model
|
The subdirectory within the main models directory where the model
|
||||||
should be saved (e.g., 'checkpoints', 'loras', etc.).
|
should be saved (e.g., 'checkpoints', 'loras', etc.).
|
||||||
progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]):
|
progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]):
|
||||||
An asynchronous function to call with progress updates.
|
An asynchronous function to call with progress updates.
|
||||||
|
folder_path (str);
|
||||||
|
Path to which model folder should be used as the root.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
DownloadModelStatus: The result of the download operation.
|
DownloadModelStatus: The result of the download operation.
|
||||||
"""
|
"""
|
||||||
if not validate_model_subdirectory(model_sub_directory):
|
|
||||||
return DownloadModelStatus(
|
|
||||||
DownloadStatusType.ERROR,
|
|
||||||
0,
|
|
||||||
"Invalid model subdirectory",
|
|
||||||
False
|
|
||||||
)
|
|
||||||
|
|
||||||
if not validate_filename(model_name):
|
if not validate_filename(model_name):
|
||||||
return DownloadModelStatus(
|
return DownloadModelStatus(
|
||||||
DownloadStatusType.ERROR,
|
DownloadStatusType.ERROR,
|
||||||
@@ -79,52 +77,67 @@ async def download_model(model_download_request: Callable[[str], Awaitable[aioht
|
|||||||
False
|
False
|
||||||
)
|
)
|
||||||
|
|
||||||
file_path, relative_path = create_model_path(model_name, model_sub_directory, models_dir)
|
if not model_directory in folder_names_and_paths:
|
||||||
existing_file = await check_file_exists(file_path, model_name, progress_callback, relative_path)
|
return DownloadModelStatus(
|
||||||
|
DownloadStatusType.ERROR,
|
||||||
|
0,
|
||||||
|
"Invalid or unrecognized model directory. model_directory must be a known model type (eg 'checkpoints'). If you are seeing this error for a custom model type, ensure the relevant custom nodes are installed and working.",
|
||||||
|
False
|
||||||
|
)
|
||||||
|
|
||||||
|
if not folder_path in get_folder_paths(model_directory):
|
||||||
|
return DownloadModelStatus(
|
||||||
|
DownloadStatusType.ERROR,
|
||||||
|
0,
|
||||||
|
f"Invalid folder path '{folder_path}', does not match the list of known directories ({get_folder_paths(model_directory)}). If you're seeing this in the downloader UI, you may need to refresh the page.",
|
||||||
|
False
|
||||||
|
)
|
||||||
|
|
||||||
|
file_path = create_model_path(model_name, folder_path)
|
||||||
|
existing_file = await check_file_exists(file_path, model_name, progress_callback)
|
||||||
if existing_file:
|
if existing_file:
|
||||||
return existing_file
|
return existing_file
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
logging.info(f"Downloading {model_name} from {model_url}")
|
||||||
status = DownloadModelStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}", False)
|
status = DownloadModelStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}", False)
|
||||||
await progress_callback(relative_path, status)
|
await progress_callback(model_name, status)
|
||||||
|
|
||||||
response = await model_download_request(model_url)
|
response = await model_download_request(model_url)
|
||||||
if response.status != 200:
|
if response.status != 200:
|
||||||
error_message = f"Failed to download {model_name}. Status code: {response.status}"
|
error_message = f"Failed to download {model_name}. Status code: {response.status}"
|
||||||
logging.error(error_message)
|
logging.error(error_message)
|
||||||
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
||||||
await progress_callback(relative_path, status)
|
await progress_callback(model_name, status)
|
||||||
return DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
return DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
||||||
|
|
||||||
return await track_download_progress(response, file_path, model_name, progress_callback, relative_path, progress_interval)
|
return await track_download_progress(response, file_path, model_name, progress_callback, progress_interval)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Error in downloading model: {e}")
|
logging.error(f"Error in downloading model: {e}")
|
||||||
return await handle_download_error(e, model_name, progress_callback, relative_path)
|
return await handle_download_error(e, model_name, progress_callback)
|
||||||
|
|
||||||
|
|
||||||
def create_model_path(model_name: str, model_directory: str, models_base_dir: str) -> tuple[str, str]:
|
def create_model_path(model_name: str, folder_path: str) -> tuple[str, str]:
|
||||||
full_model_dir = os.path.join(models_base_dir, model_directory)
|
os.makedirs(folder_path, exist_ok=True)
|
||||||
os.makedirs(full_model_dir, exist_ok=True)
|
file_path = os.path.join(folder_path, model_name)
|
||||||
file_path = os.path.join(full_model_dir, model_name)
|
|
||||||
|
|
||||||
# Ensure the resulting path is still within the base directory
|
# Ensure the resulting path is still within the base directory
|
||||||
abs_file_path = os.path.abspath(file_path)
|
abs_file_path = os.path.abspath(file_path)
|
||||||
abs_base_dir = os.path.abspath(str(models_base_dir))
|
abs_base_dir = os.path.abspath(folder_path)
|
||||||
if os.path.commonprefix([abs_file_path, abs_base_dir]) != abs_base_dir:
|
if os.path.commonprefix([abs_file_path, abs_base_dir]) != abs_base_dir:
|
||||||
raise Exception(f"Invalid model directory: {model_directory}/{model_name}")
|
raise Exception(f"Invalid model directory: {folder_path}/{model_name}")
|
||||||
|
|
||||||
|
return file_path
|
||||||
|
|
||||||
relative_path = '/'.join([model_directory, model_name])
|
|
||||||
return file_path, relative_path
|
|
||||||
|
|
||||||
async def check_file_exists(file_path: str,
|
async def check_file_exists(file_path: str,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]]
|
||||||
relative_path: str) -> Optional[DownloadModelStatus]:
|
) -> Optional[DownloadModelStatus]:
|
||||||
if os.path.exists(file_path):
|
if os.path.exists(file_path):
|
||||||
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True)
|
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True)
|
||||||
await progress_callback(relative_path, status)
|
await progress_callback(model_name, status)
|
||||||
return status
|
return status
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -133,7 +146,6 @@ async def track_download_progress(response: aiohttp.ClientResponse,
|
|||||||
file_path: str,
|
file_path: str,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
||||||
relative_path: str,
|
|
||||||
interval: float = 1.0) -> DownloadModelStatus:
|
interval: float = 1.0) -> DownloadModelStatus:
|
||||||
try:
|
try:
|
||||||
total_size = int(response.headers.get('Content-Length', 0))
|
total_size = int(response.headers.get('Content-Length', 0))
|
||||||
@@ -144,10 +156,11 @@ async def track_download_progress(response: aiohttp.ClientResponse,
|
|||||||
nonlocal last_update_time
|
nonlocal last_update_time
|
||||||
progress = (downloaded / total_size) * 100 if total_size > 0 else 0
|
progress = (downloaded / total_size) * 100 if total_size > 0 else 0
|
||||||
status = DownloadModelStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}", False)
|
status = DownloadModelStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}", False)
|
||||||
await progress_callback(relative_path, status)
|
await progress_callback(model_name, status)
|
||||||
last_update_time = time.time()
|
last_update_time = time.time()
|
||||||
|
|
||||||
with open(file_path, 'wb') as f:
|
temp_file_path = file_path + '.tmp'
|
||||||
|
with open(temp_file_path, 'wb') as f:
|
||||||
chunk_iterator = response.content.iter_chunked(8192)
|
chunk_iterator = response.content.iter_chunked(8192)
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
@@ -160,49 +173,30 @@ async def track_download_progress(response: aiohttp.ClientResponse,
|
|||||||
if time.time() - last_update_time >= interval:
|
if time.time() - last_update_time >= interval:
|
||||||
await update_progress()
|
await update_progress()
|
||||||
|
|
||||||
|
os.rename(temp_file_path, file_path)
|
||||||
|
|
||||||
await update_progress()
|
await update_progress()
|
||||||
|
|
||||||
logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}")
|
logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}")
|
||||||
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False)
|
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False)
|
||||||
await progress_callback(relative_path, status)
|
await progress_callback(model_name, status)
|
||||||
|
|
||||||
return status
|
return status
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Error in track_download_progress: {e}")
|
logging.error(f"Error in track_download_progress: {e}")
|
||||||
logging.error(traceback.format_exc())
|
logging.error(traceback.format_exc())
|
||||||
return await handle_download_error(e, model_name, progress_callback, relative_path)
|
return await handle_download_error(e, model_name, progress_callback)
|
||||||
|
|
||||||
|
|
||||||
async def handle_download_error(e: Exception,
|
async def handle_download_error(e: Exception,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
progress_callback: Callable[[str, DownloadModelStatus], Any],
|
progress_callback: Callable[[str, DownloadModelStatus], Any]
|
||||||
relative_path: str) -> DownloadModelStatus:
|
) -> DownloadModelStatus:
|
||||||
error_message = f"Error downloading {model_name}: {str(e)}"
|
error_message = f"Error downloading {model_name}: {str(e)}"
|
||||||
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
||||||
await progress_callback(relative_path, status)
|
await progress_callback(model_name, status)
|
||||||
return status
|
return status
|
||||||
|
|
||||||
def validate_model_subdirectory(model_subdirectory: str) -> bool:
|
|
||||||
"""
|
|
||||||
Validate that the model subdirectory is safe to install into.
|
|
||||||
Must not contain relative paths, nested paths or special characters
|
|
||||||
other than underscores and hyphens.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_subdirectory (str): The subdirectory for the specific model type.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if the subdirectory is safe, False otherwise.
|
|
||||||
"""
|
|
||||||
if len(model_subdirectory) > 50:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if '..' in model_subdirectory or '/' in model_subdirectory:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if not re.match(r'^[a-zA-Z0-9_-]+$', model_subdirectory):
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def validate_filename(filename: str)-> bool:
|
def validate_filename(filename: str)-> bool:
|
||||||
"""
|
"""
|
||||||
|
|||||||
0
models/text_encoders/put_text_encoder_files_here
Normal file
0
models/text_encoders/put_text_encoder_files_here
Normal file
84
nodes.py
84
nodes.py
@@ -281,7 +281,10 @@ class VAEDecode:
|
|||||||
DESCRIPTION = "Decodes latent images back into pixel space images."
|
DESCRIPTION = "Decodes latent images back into pixel space images."
|
||||||
|
|
||||||
def decode(self, vae, samples):
|
def decode(self, vae, samples):
|
||||||
return (vae.decode(samples["samples"]), )
|
images = vae.decode(samples["samples"])
|
||||||
|
if len(images.shape) == 5: #Combine batches
|
||||||
|
images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
|
||||||
|
return (images, )
|
||||||
|
|
||||||
class VAEDecodeTiled:
|
class VAEDecodeTiled:
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -511,10 +514,11 @@ class CheckpointLoader:
|
|||||||
FUNCTION = "load_checkpoint"
|
FUNCTION = "load_checkpoint"
|
||||||
|
|
||||||
CATEGORY = "advanced/loaders"
|
CATEGORY = "advanced/loaders"
|
||||||
|
DEPRECATED = True
|
||||||
|
|
||||||
def load_checkpoint(self, config_name, ckpt_name):
|
def load_checkpoint(self, config_name, ckpt_name):
|
||||||
config_path = folder_paths.get_full_path("configs", config_name)
|
config_path = folder_paths.get_full_path("configs", config_name)
|
||||||
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
|
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
|
||||||
return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||||
|
|
||||||
class CheckpointLoaderSimple:
|
class CheckpointLoaderSimple:
|
||||||
@@ -535,7 +539,7 @@ class CheckpointLoaderSimple:
|
|||||||
DESCRIPTION = "Loads a diffusion model checkpoint, diffusion models are used to denoise latents."
|
DESCRIPTION = "Loads a diffusion model checkpoint, diffusion models are used to denoise latents."
|
||||||
|
|
||||||
def load_checkpoint(self, ckpt_name):
|
def load_checkpoint(self, ckpt_name):
|
||||||
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
|
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
|
||||||
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||||
return out[:3]
|
return out[:3]
|
||||||
|
|
||||||
@@ -577,7 +581,7 @@ class unCLIPCheckpointLoader:
|
|||||||
CATEGORY = "loaders"
|
CATEGORY = "loaders"
|
||||||
|
|
||||||
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
|
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
|
||||||
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
|
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
|
||||||
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
@@ -624,7 +628,7 @@ class LoraLoader:
|
|||||||
if strength_model == 0 and strength_clip == 0:
|
if strength_model == 0 and strength_clip == 0:
|
||||||
return (model, clip)
|
return (model, clip)
|
||||||
|
|
||||||
lora_path = folder_paths.get_full_path("loras", lora_name)
|
lora_path = folder_paths.get_full_path_or_raise("loras", lora_name)
|
||||||
lora = None
|
lora = None
|
||||||
if self.loaded_lora is not None:
|
if self.loaded_lora is not None:
|
||||||
if self.loaded_lora[0] == lora_path:
|
if self.loaded_lora[0] == lora_path:
|
||||||
@@ -703,11 +707,11 @@ class VAELoader:
|
|||||||
encoder = next(filter(lambda a: a.startswith("{}_encoder.".format(name)), approx_vaes))
|
encoder = next(filter(lambda a: a.startswith("{}_encoder.".format(name)), approx_vaes))
|
||||||
decoder = next(filter(lambda a: a.startswith("{}_decoder.".format(name)), approx_vaes))
|
decoder = next(filter(lambda a: a.startswith("{}_decoder.".format(name)), approx_vaes))
|
||||||
|
|
||||||
enc = comfy.utils.load_torch_file(folder_paths.get_full_path("vae_approx", encoder))
|
enc = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("vae_approx", encoder))
|
||||||
for k in enc:
|
for k in enc:
|
||||||
sd["taesd_encoder.{}".format(k)] = enc[k]
|
sd["taesd_encoder.{}".format(k)] = enc[k]
|
||||||
|
|
||||||
dec = comfy.utils.load_torch_file(folder_paths.get_full_path("vae_approx", decoder))
|
dec = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("vae_approx", decoder))
|
||||||
for k in dec:
|
for k in dec:
|
||||||
sd["taesd_decoder.{}".format(k)] = dec[k]
|
sd["taesd_decoder.{}".format(k)] = dec[k]
|
||||||
|
|
||||||
@@ -738,7 +742,7 @@ class VAELoader:
|
|||||||
if vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]:
|
if vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]:
|
||||||
sd = self.load_taesd(vae_name)
|
sd = self.load_taesd(vae_name)
|
||||||
else:
|
else:
|
||||||
vae_path = folder_paths.get_full_path("vae", vae_name)
|
vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
|
||||||
sd = comfy.utils.load_torch_file(vae_path)
|
sd = comfy.utils.load_torch_file(vae_path)
|
||||||
vae = comfy.sd.VAE(sd=sd)
|
vae = comfy.sd.VAE(sd=sd)
|
||||||
return (vae,)
|
return (vae,)
|
||||||
@@ -754,7 +758,7 @@ class ControlNetLoader:
|
|||||||
CATEGORY = "loaders"
|
CATEGORY = "loaders"
|
||||||
|
|
||||||
def load_controlnet(self, control_net_name):
|
def load_controlnet(self, control_net_name):
|
||||||
controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
|
controlnet_path = folder_paths.get_full_path_or_raise("controlnet", control_net_name)
|
||||||
controlnet = comfy.controlnet.load_controlnet(controlnet_path)
|
controlnet = comfy.controlnet.load_controlnet(controlnet_path)
|
||||||
return (controlnet,)
|
return (controlnet,)
|
||||||
|
|
||||||
@@ -770,7 +774,7 @@ class DiffControlNetLoader:
|
|||||||
CATEGORY = "loaders"
|
CATEGORY = "loaders"
|
||||||
|
|
||||||
def load_controlnet(self, model, control_net_name):
|
def load_controlnet(self, model, control_net_name):
|
||||||
controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
|
controlnet_path = folder_paths.get_full_path_or_raise("controlnet", control_net_name)
|
||||||
controlnet = comfy.controlnet.load_controlnet(controlnet_path, model)
|
controlnet = comfy.controlnet.load_controlnet(controlnet_path, model)
|
||||||
return (controlnet,)
|
return (controlnet,)
|
||||||
|
|
||||||
@@ -786,6 +790,7 @@ class ControlNetApply:
|
|||||||
RETURN_TYPES = ("CONDITIONING",)
|
RETURN_TYPES = ("CONDITIONING",)
|
||||||
FUNCTION = "apply_controlnet"
|
FUNCTION = "apply_controlnet"
|
||||||
|
|
||||||
|
DEPRECATED = True
|
||||||
CATEGORY = "conditioning/controlnet"
|
CATEGORY = "conditioning/controlnet"
|
||||||
|
|
||||||
def apply_controlnet(self, conditioning, control_net, image, strength):
|
def apply_controlnet(self, conditioning, control_net, image, strength):
|
||||||
@@ -815,7 +820,10 @@ class ControlNetApplyAdvanced:
|
|||||||
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||||
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||||
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
|
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
|
||||||
}}
|
},
|
||||||
|
"optional": {"vae": ("VAE", ),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
RETURN_TYPES = ("CONDITIONING","CONDITIONING")
|
RETURN_TYPES = ("CONDITIONING","CONDITIONING")
|
||||||
RETURN_NAMES = ("positive", "negative")
|
RETURN_NAMES = ("positive", "negative")
|
||||||
@@ -823,7 +831,7 @@ class ControlNetApplyAdvanced:
|
|||||||
|
|
||||||
CATEGORY = "conditioning/controlnet"
|
CATEGORY = "conditioning/controlnet"
|
||||||
|
|
||||||
def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None):
|
def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None, extra_concat=[]):
|
||||||
if strength == 0:
|
if strength == 0:
|
||||||
return (positive, negative)
|
return (positive, negative)
|
||||||
|
|
||||||
@@ -840,7 +848,7 @@ class ControlNetApplyAdvanced:
|
|||||||
if prev_cnet in cnets:
|
if prev_cnet in cnets:
|
||||||
c_net = cnets[prev_cnet]
|
c_net = cnets[prev_cnet]
|
||||||
else:
|
else:
|
||||||
c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent), vae)
|
c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent), vae=vae, extra_concat=extra_concat)
|
||||||
c_net.set_previous_controlnet(prev_cnet)
|
c_net.set_previous_controlnet(prev_cnet)
|
||||||
cnets[prev_cnet] = c_net
|
cnets[prev_cnet] = c_net
|
||||||
|
|
||||||
@@ -856,7 +864,7 @@ class UNETLoader:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "unet_name": (folder_paths.get_filename_list("diffusion_models"), ),
|
return {"required": { "unet_name": (folder_paths.get_filename_list("diffusion_models"), ),
|
||||||
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e5m2"],)
|
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"],)
|
||||||
}}
|
}}
|
||||||
RETURN_TYPES = ("MODEL",)
|
RETURN_TYPES = ("MODEL",)
|
||||||
FUNCTION = "load_unet"
|
FUNCTION = "load_unet"
|
||||||
@@ -867,18 +875,21 @@ class UNETLoader:
|
|||||||
model_options = {}
|
model_options = {}
|
||||||
if weight_dtype == "fp8_e4m3fn":
|
if weight_dtype == "fp8_e4m3fn":
|
||||||
model_options["dtype"] = torch.float8_e4m3fn
|
model_options["dtype"] = torch.float8_e4m3fn
|
||||||
|
elif weight_dtype == "fp8_e4m3fn_fast":
|
||||||
|
model_options["dtype"] = torch.float8_e4m3fn
|
||||||
|
model_options["fp8_optimizations"] = True
|
||||||
elif weight_dtype == "fp8_e5m2":
|
elif weight_dtype == "fp8_e5m2":
|
||||||
model_options["dtype"] = torch.float8_e5m2
|
model_options["dtype"] = torch.float8_e5m2
|
||||||
|
|
||||||
unet_path = folder_paths.get_full_path("diffusion_models", unet_name)
|
unet_path = folder_paths.get_full_path_or_raise("diffusion_models", unet_name)
|
||||||
model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
|
model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
|
||||||
return (model,)
|
return (model,)
|
||||||
|
|
||||||
class CLIPLoader:
|
class CLIPLoader:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "clip_name": (folder_paths.get_filename_list("clip"), ),
|
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
|
||||||
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio"], ),
|
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi"], ),
|
||||||
}}
|
}}
|
||||||
RETURN_TYPES = ("CLIP",)
|
RETURN_TYPES = ("CLIP",)
|
||||||
FUNCTION = "load_clip"
|
FUNCTION = "load_clip"
|
||||||
@@ -892,18 +903,20 @@ class CLIPLoader:
|
|||||||
clip_type = comfy.sd.CLIPType.SD3
|
clip_type = comfy.sd.CLIPType.SD3
|
||||||
elif type == "stable_audio":
|
elif type == "stable_audio":
|
||||||
clip_type = comfy.sd.CLIPType.STABLE_AUDIO
|
clip_type = comfy.sd.CLIPType.STABLE_AUDIO
|
||||||
|
elif type == "mochi":
|
||||||
|
clip_type = comfy.sd.CLIPType.MOCHI
|
||||||
else:
|
else:
|
||||||
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
|
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
|
||||||
|
|
||||||
clip_path = folder_paths.get_full_path("clip", clip_name)
|
clip_path = folder_paths.get_full_path_or_raise("text_encoders", clip_name)
|
||||||
clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
|
clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
|
||||||
return (clip,)
|
return (clip,)
|
||||||
|
|
||||||
class DualCLIPLoader:
|
class DualCLIPLoader:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "clip_name1": (folder_paths.get_filename_list("clip"), ),
|
return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
|
||||||
"clip_name2": (folder_paths.get_filename_list("clip"), ),
|
"clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
|
||||||
"type": (["sdxl", "sd3", "flux"], ),
|
"type": (["sdxl", "sd3", "flux"], ),
|
||||||
}}
|
}}
|
||||||
RETURN_TYPES = ("CLIP",)
|
RETURN_TYPES = ("CLIP",)
|
||||||
@@ -912,8 +925,8 @@ class DualCLIPLoader:
|
|||||||
CATEGORY = "advanced/loaders"
|
CATEGORY = "advanced/loaders"
|
||||||
|
|
||||||
def load_clip(self, clip_name1, clip_name2, type):
|
def load_clip(self, clip_name1, clip_name2, type):
|
||||||
clip_path1 = folder_paths.get_full_path("clip", clip_name1)
|
clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
|
||||||
clip_path2 = folder_paths.get_full_path("clip", clip_name2)
|
clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
|
||||||
if type == "sdxl":
|
if type == "sdxl":
|
||||||
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
|
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
|
||||||
elif type == "sd3":
|
elif type == "sd3":
|
||||||
@@ -935,7 +948,7 @@ class CLIPVisionLoader:
|
|||||||
CATEGORY = "loaders"
|
CATEGORY = "loaders"
|
||||||
|
|
||||||
def load_clip(self, clip_name):
|
def load_clip(self, clip_name):
|
||||||
clip_path = folder_paths.get_full_path("clip_vision", clip_name)
|
clip_path = folder_paths.get_full_path_or_raise("clip_vision", clip_name)
|
||||||
clip_vision = comfy.clip_vision.load(clip_path)
|
clip_vision = comfy.clip_vision.load(clip_path)
|
||||||
return (clip_vision,)
|
return (clip_vision,)
|
||||||
|
|
||||||
@@ -965,7 +978,7 @@ class StyleModelLoader:
|
|||||||
CATEGORY = "loaders"
|
CATEGORY = "loaders"
|
||||||
|
|
||||||
def load_style_model(self, style_model_name):
|
def load_style_model(self, style_model_name):
|
||||||
style_model_path = folder_paths.get_full_path("style_models", style_model_name)
|
style_model_path = folder_paths.get_full_path_or_raise("style_models", style_model_name)
|
||||||
style_model = comfy.sd.load_style_model(style_model_path)
|
style_model = comfy.sd.load_style_model(style_model_path)
|
||||||
return (style_model,)
|
return (style_model,)
|
||||||
|
|
||||||
@@ -1030,7 +1043,7 @@ class GLIGENLoader:
|
|||||||
CATEGORY = "loaders"
|
CATEGORY = "loaders"
|
||||||
|
|
||||||
def load_gligen(self, gligen_name):
|
def load_gligen(self, gligen_name):
|
||||||
gligen_path = folder_paths.get_full_path("gligen", gligen_name)
|
gligen_path = folder_paths.get_full_path_or_raise("gligen", gligen_name)
|
||||||
gligen = comfy.sd.load_gligen(gligen_path)
|
gligen = comfy.sd.load_gligen(gligen_path)
|
||||||
return (gligen,)
|
return (gligen,)
|
||||||
|
|
||||||
@@ -1171,10 +1184,10 @@ class LatentUpscale:
|
|||||||
|
|
||||||
if width == 0:
|
if width == 0:
|
||||||
height = max(64, height)
|
height = max(64, height)
|
||||||
width = max(64, round(samples["samples"].shape[3] * height / samples["samples"].shape[2]))
|
width = max(64, round(samples["samples"].shape[-1] * height / samples["samples"].shape[-2]))
|
||||||
elif height == 0:
|
elif height == 0:
|
||||||
width = max(64, width)
|
width = max(64, width)
|
||||||
height = max(64, round(samples["samples"].shape[2] * width / samples["samples"].shape[3]))
|
height = max(64, round(samples["samples"].shape[-2] * width / samples["samples"].shape[-1]))
|
||||||
else:
|
else:
|
||||||
width = max(64, width)
|
width = max(64, width)
|
||||||
height = max(64, height)
|
height = max(64, height)
|
||||||
@@ -1196,8 +1209,8 @@ class LatentUpscaleBy:
|
|||||||
|
|
||||||
def upscale(self, samples, upscale_method, scale_by):
|
def upscale(self, samples, upscale_method, scale_by):
|
||||||
s = samples.copy()
|
s = samples.copy()
|
||||||
width = round(samples["samples"].shape[3] * scale_by)
|
width = round(samples["samples"].shape[-1] * scale_by)
|
||||||
height = round(samples["samples"].shape[2] * scale_by)
|
height = round(samples["samples"].shape[-2] * scale_by)
|
||||||
s["samples"] = comfy.utils.common_upscale(samples["samples"], width, height, upscale_method, "disabled")
|
s["samples"] = comfy.utils.common_upscale(samples["samples"], width, height, upscale_method, "disabled")
|
||||||
return (s,)
|
return (s,)
|
||||||
|
|
||||||
@@ -1916,8 +1929,8 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"ConditioningSetArea": "Conditioning (Set Area)",
|
"ConditioningSetArea": "Conditioning (Set Area)",
|
||||||
"ConditioningSetAreaPercentage": "Conditioning (Set Area with Percentage)",
|
"ConditioningSetAreaPercentage": "Conditioning (Set Area with Percentage)",
|
||||||
"ConditioningSetMask": "Conditioning (Set Mask)",
|
"ConditioningSetMask": "Conditioning (Set Mask)",
|
||||||
"ControlNetApply": "Apply ControlNet",
|
"ControlNetApply": "Apply ControlNet (OLD)",
|
||||||
"ControlNetApplyAdvanced": "Apply ControlNet (Advanced)",
|
"ControlNetApplyAdvanced": "Apply ControlNet",
|
||||||
# Latent
|
# Latent
|
||||||
"VAEEncodeForInpaint": "VAE Encode (for Inpainting)",
|
"VAEEncodeForInpaint": "VAE Encode (for Inpainting)",
|
||||||
"SetLatentNoiseMask": "Set Latent Noise Mask",
|
"SetLatentNoiseMask": "Set Latent Noise Mask",
|
||||||
@@ -1944,6 +1957,12 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"ImageInvert": "Invert Image",
|
"ImageInvert": "Invert Image",
|
||||||
"ImagePadForOutpaint": "Pad Image for Outpainting",
|
"ImagePadForOutpaint": "Pad Image for Outpainting",
|
||||||
"ImageBatch": "Batch Images",
|
"ImageBatch": "Batch Images",
|
||||||
|
"ImageCrop": "Image Crop",
|
||||||
|
"ImageBlend": "Image Blend",
|
||||||
|
"ImageBlur": "Image Blur",
|
||||||
|
"ImageQuantize": "Image Quantize",
|
||||||
|
"ImageSharpen": "Image Sharpen",
|
||||||
|
"ImageScaleToTotalPixels": "Scale Image to Total Pixels",
|
||||||
# _for_testing
|
# _for_testing
|
||||||
"VAEDecodeTiled": "VAE Decode (Tiled)",
|
"VAEDecodeTiled": "VAE Decode (Tiled)",
|
||||||
"VAEEncodeTiled": "VAE Encode (Tiled)",
|
"VAEEncodeTiled": "VAE Encode (Tiled)",
|
||||||
@@ -2101,6 +2120,9 @@ def init_builtin_extra_nodes():
|
|||||||
"nodes_controlnet.py",
|
"nodes_controlnet.py",
|
||||||
"nodes_hunyuan.py",
|
"nodes_hunyuan.py",
|
||||||
"nodes_flux.py",
|
"nodes_flux.py",
|
||||||
|
"nodes_lora_extract.py",
|
||||||
|
"nodes_torch_compile.py",
|
||||||
|
"nodes_mochi.py",
|
||||||
]
|
]
|
||||||
|
|
||||||
import_failed = []
|
import_failed = []
|
||||||
|
|||||||
@@ -38,6 +38,9 @@ def get_images(ws, prompt):
|
|||||||
if data['node'] is None and data['prompt_id'] == prompt_id:
|
if data['node'] is None and data['prompt_id'] == prompt_id:
|
||||||
break #Execution is done
|
break #Execution is done
|
||||||
else:
|
else:
|
||||||
|
# If you want to be able to decode the binary stream for latent previews, here is how you can do it:
|
||||||
|
# bytesIO = BytesIO(out[8:])
|
||||||
|
# preview_image = Image.open(bytesIO) # This is your preview in PIL image format, store it in a global
|
||||||
continue #previews are binary data
|
continue #previews are binary data
|
||||||
|
|
||||||
history = get_history(prompt_id)[prompt_id]
|
history = get_history(prompt_id)[prompt_id]
|
||||||
@@ -151,7 +154,7 @@ prompt["3"]["inputs"]["seed"] = 5
|
|||||||
ws = websocket.WebSocket()
|
ws = websocket.WebSocket()
|
||||||
ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
|
ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
|
||||||
images = get_images(ws, prompt)
|
images = get_images(ws, prompt)
|
||||||
|
ws.close() # for in case this example is used in an environment where it will be repeatedly called, like in a Gradio app. otherwise, you'll randomly receive connection timeouts
|
||||||
#Commented out code to display the output images:
|
#Commented out code to display the output images:
|
||||||
|
|
||||||
# for node_id in images:
|
# for node_id in images:
|
||||||
|
|||||||
@@ -147,7 +147,7 @@ prompt["3"]["inputs"]["seed"] = 5
|
|||||||
ws = websocket.WebSocket()
|
ws = websocket.WebSocket()
|
||||||
ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
|
ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
|
||||||
images = get_images(ws, prompt)
|
images = get_images(ws, prompt)
|
||||||
|
ws.close() # for in case this example is used in an environment where it will be repeatedly called, like in a Gradio app. otherwise, you'll randomly receive connection timeouts
|
||||||
#Commented out code to display the output images:
|
#Commented out code to display the output images:
|
||||||
|
|
||||||
# for node_id in images:
|
# for node_id in images:
|
||||||
|
|||||||
130
server.py
130
server.py
@@ -12,6 +12,8 @@ import json
|
|||||||
import glob
|
import glob
|
||||||
import struct
|
import struct
|
||||||
import ssl
|
import ssl
|
||||||
|
import socket
|
||||||
|
import ipaddress
|
||||||
from PIL import Image, ImageOps
|
from PIL import Image, ImageOps
|
||||||
from PIL.PngImagePlugin import PngInfo
|
from PIL.PngImagePlugin import PngInfo
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
@@ -38,7 +40,7 @@ class BinaryEventTypes:
|
|||||||
async def send_socket_catch_exception(function, message):
|
async def send_socket_catch_exception(function, message):
|
||||||
try:
|
try:
|
||||||
await function(message)
|
await function(message)
|
||||||
except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError) as err:
|
except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError, BrokenPipeError, ConnectionError) as err:
|
||||||
logging.warning("send error: {}".format(err))
|
logging.warning("send error: {}".format(err))
|
||||||
|
|
||||||
def get_comfyui_version():
|
def get_comfyui_version():
|
||||||
@@ -80,6 +82,68 @@ def create_cors_middleware(allowed_origin: str):
|
|||||||
|
|
||||||
return cors_middleware
|
return cors_middleware
|
||||||
|
|
||||||
|
def is_loopback(host):
|
||||||
|
if host is None:
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
if ipaddress.ip_address(host).is_loopback:
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
loopback = False
|
||||||
|
for family in (socket.AF_INET, socket.AF_INET6):
|
||||||
|
try:
|
||||||
|
r = socket.getaddrinfo(host, None, family, socket.SOCK_STREAM)
|
||||||
|
for family, _, _, _, sockaddr in r:
|
||||||
|
if not ipaddress.ip_address(sockaddr[0]).is_loopback:
|
||||||
|
return loopback
|
||||||
|
else:
|
||||||
|
loopback = True
|
||||||
|
except socket.gaierror:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return loopback
|
||||||
|
|
||||||
|
|
||||||
|
def create_origin_only_middleware():
|
||||||
|
@web.middleware
|
||||||
|
async def origin_only_middleware(request: web.Request, handler):
|
||||||
|
#this code is used to prevent the case where a random website can queue comfy workflows by making a POST to 127.0.0.1 which browsers don't prevent for some dumb reason.
|
||||||
|
#in that case the Host and Origin hostnames won't match
|
||||||
|
#I know the proper fix would be to add a cookie but this should take care of the problem in the meantime
|
||||||
|
if 'Host' in request.headers and 'Origin' in request.headers:
|
||||||
|
host = request.headers['Host']
|
||||||
|
origin = request.headers['Origin']
|
||||||
|
host_domain = host.lower()
|
||||||
|
parsed = urllib.parse.urlparse(origin)
|
||||||
|
origin_domain = parsed.netloc.lower()
|
||||||
|
host_domain_parsed = urllib.parse.urlsplit('//' + host_domain)
|
||||||
|
|
||||||
|
#limit the check to when the host domain is localhost, this makes it slightly less safe but should still prevent the exploit
|
||||||
|
loopback = is_loopback(host_domain_parsed.hostname)
|
||||||
|
|
||||||
|
if parsed.port is None: #if origin doesn't have a port strip it from the host to handle weird browsers, same for host
|
||||||
|
host_domain = host_domain_parsed.hostname
|
||||||
|
if host_domain_parsed.port is None:
|
||||||
|
origin_domain = parsed.hostname
|
||||||
|
|
||||||
|
if loopback and host_domain is not None and origin_domain is not None and len(host_domain) > 0 and len(origin_domain) > 0:
|
||||||
|
if host_domain != origin_domain:
|
||||||
|
logging.warning("WARNING: request with non matching host and origin {} != {}, returning 403".format(host_domain, origin_domain))
|
||||||
|
return web.Response(status=403)
|
||||||
|
|
||||||
|
if request.method == "OPTIONS":
|
||||||
|
response = web.Response()
|
||||||
|
else:
|
||||||
|
response = await handler(request)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
return origin_only_middleware
|
||||||
|
|
||||||
class PromptServer():
|
class PromptServer():
|
||||||
def __init__(self, loop):
|
def __init__(self, loop):
|
||||||
PromptServer.instance = self
|
PromptServer.instance = self
|
||||||
@@ -99,6 +163,8 @@ class PromptServer():
|
|||||||
middlewares = [cache_control]
|
middlewares = [cache_control]
|
||||||
if args.enable_cors_header:
|
if args.enable_cors_header:
|
||||||
middlewares.append(create_cors_middleware(args.enable_cors_header))
|
middlewares.append(create_cors_middleware(args.enable_cors_header))
|
||||||
|
else:
|
||||||
|
middlewares.append(create_origin_only_middleware())
|
||||||
|
|
||||||
max_upload_size = round(args.max_upload_size * 1024 * 1024)
|
max_upload_size = round(args.max_upload_size * 1024 * 1024)
|
||||||
self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares)
|
self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares)
|
||||||
@@ -156,6 +222,12 @@ class PromptServer():
|
|||||||
embeddings = folder_paths.get_filename_list("embeddings")
|
embeddings = folder_paths.get_filename_list("embeddings")
|
||||||
return web.json_response(list(map(lambda a: os.path.splitext(a)[0], embeddings)))
|
return web.json_response(list(map(lambda a: os.path.splitext(a)[0], embeddings)))
|
||||||
|
|
||||||
|
@routes.get("/models")
|
||||||
|
def list_model_types(request):
|
||||||
|
model_types = list(folder_paths.folder_names_and_paths.keys())
|
||||||
|
|
||||||
|
return web.json_response(model_types)
|
||||||
|
|
||||||
@routes.get("/models/{folder}")
|
@routes.get("/models/{folder}")
|
||||||
async def get_models(request):
|
async def get_models(request):
|
||||||
folder = request.match_info.get("folder", None)
|
folder = request.match_info.get("folder", None)
|
||||||
@@ -418,12 +490,17 @@ class PromptServer():
|
|||||||
async def system_stats(request):
|
async def system_stats(request):
|
||||||
device = comfy.model_management.get_torch_device()
|
device = comfy.model_management.get_torch_device()
|
||||||
device_name = comfy.model_management.get_torch_device_name(device)
|
device_name = comfy.model_management.get_torch_device_name(device)
|
||||||
|
cpu_device = comfy.model_management.torch.device("cpu")
|
||||||
|
ram_total = comfy.model_management.get_total_memory(cpu_device)
|
||||||
|
ram_free = comfy.model_management.get_free_memory(cpu_device)
|
||||||
vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
|
vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
|
||||||
vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
|
vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
|
||||||
|
|
||||||
system_stats = {
|
system_stats = {
|
||||||
"system": {
|
"system": {
|
||||||
"os": os.name,
|
"os": os.name,
|
||||||
|
"ram_total": ram_total,
|
||||||
|
"ram_free": ram_free,
|
||||||
"comfyui_version": get_comfyui_version(),
|
"comfyui_version": get_comfyui_version(),
|
||||||
"python_version": sys.version,
|
"python_version": sys.version,
|
||||||
"pytorch_version": comfy.model_management.torch_version,
|
"pytorch_version": comfy.model_management.torch_version,
|
||||||
@@ -480,14 +557,15 @@ class PromptServer():
|
|||||||
|
|
||||||
@routes.get("/object_info")
|
@routes.get("/object_info")
|
||||||
async def get_object_info(request):
|
async def get_object_info(request):
|
||||||
out = {}
|
with folder_paths.cache_helper:
|
||||||
for x in nodes.NODE_CLASS_MAPPINGS:
|
out = {}
|
||||||
try:
|
for x in nodes.NODE_CLASS_MAPPINGS:
|
||||||
out[x] = node_info(x)
|
try:
|
||||||
except Exception as e:
|
out[x] = node_info(x)
|
||||||
logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.")
|
except Exception as e:
|
||||||
logging.error(traceback.format_exc())
|
logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.")
|
||||||
return web.json_response(out)
|
logging.error(traceback.format_exc())
|
||||||
|
return web.json_response(out)
|
||||||
|
|
||||||
@routes.get("/object_info/{node_class}")
|
@routes.get("/object_info/{node_class}")
|
||||||
async def get_object_info_node(request):
|
async def get_object_info_node(request):
|
||||||
@@ -601,6 +679,7 @@ class PromptServer():
|
|||||||
|
|
||||||
# Internal route. Should not be depended upon and is subject to change at any time.
|
# Internal route. Should not be depended upon and is subject to change at any time.
|
||||||
# TODO(robinhuang): Move to internal route table class once we refactor PromptServer to pass around Websocket.
|
# TODO(robinhuang): Move to internal route table class once we refactor PromptServer to pass around Websocket.
|
||||||
|
# NOTE: This was an experiment and WILL BE REMOVED
|
||||||
@routes.post("/internal/models/download")
|
@routes.post("/internal/models/download")
|
||||||
async def download_handler(request):
|
async def download_handler(request):
|
||||||
async def report_progress(filename: str, status: DownloadModelStatus):
|
async def report_progress(filename: str, status: DownloadModelStatus):
|
||||||
@@ -611,10 +690,11 @@ class PromptServer():
|
|||||||
data = await request.json()
|
data = await request.json()
|
||||||
url = data.get('url')
|
url = data.get('url')
|
||||||
model_directory = data.get('model_directory')
|
model_directory = data.get('model_directory')
|
||||||
|
folder_path = data.get('folder_path')
|
||||||
model_filename = data.get('model_filename')
|
model_filename = data.get('model_filename')
|
||||||
progress_interval = data.get('progress_interval', 1.0) # In seconds, how often to report download progress.
|
progress_interval = data.get('progress_interval', 1.0) # In seconds, how often to report download progress.
|
||||||
|
|
||||||
if not url or not model_directory or not model_filename:
|
if not url or not model_directory or not model_filename or not folder_path:
|
||||||
return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400)
|
return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400)
|
||||||
|
|
||||||
session = self.client_session
|
session = self.client_session
|
||||||
@@ -622,7 +702,7 @@ class PromptServer():
|
|||||||
logging.error("Client session is not initialized")
|
logging.error("Client session is not initialized")
|
||||||
return web.Response(status=500)
|
return web.Response(status=500)
|
||||||
|
|
||||||
task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, report_progress, progress_interval))
|
task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, folder_path, report_progress, progress_interval))
|
||||||
await task
|
await task
|
||||||
|
|
||||||
return web.json_response(task.result().to_dict())
|
return web.json_response(task.result().to_dict())
|
||||||
@@ -739,6 +819,9 @@ class PromptServer():
|
|||||||
await self.send(*msg)
|
await self.send(*msg)
|
||||||
|
|
||||||
async def start(self, address, port, verbose=True, call_on_start=None):
|
async def start(self, address, port, verbose=True, call_on_start=None):
|
||||||
|
await self.start_multi_address([(address, port)], call_on_start=call_on_start)
|
||||||
|
|
||||||
|
async def start_multi_address(self, addresses, call_on_start=None):
|
||||||
runner = web.AppRunner(self.app, access_log=None)
|
runner = web.AppRunner(self.app, access_log=None)
|
||||||
await runner.setup()
|
await runner.setup()
|
||||||
ssl_ctx = None
|
ssl_ctx = None
|
||||||
@@ -749,17 +832,26 @@ class PromptServer():
|
|||||||
keyfile=args.tls_keyfile)
|
keyfile=args.tls_keyfile)
|
||||||
scheme = "https"
|
scheme = "https"
|
||||||
|
|
||||||
site = web.TCPSite(runner, address, port, ssl_context=ssl_ctx)
|
logging.info("Starting server\n")
|
||||||
await site.start()
|
for addr in addresses:
|
||||||
|
address = addr[0]
|
||||||
|
port = addr[1]
|
||||||
|
site = web.TCPSite(runner, address, port, ssl_context=ssl_ctx)
|
||||||
|
await site.start()
|
||||||
|
|
||||||
self.address = address
|
if not hasattr(self, 'address'):
|
||||||
self.port = port
|
self.address = address #TODO: remove this
|
||||||
|
self.port = port
|
||||||
|
|
||||||
|
if ':' in address:
|
||||||
|
address_print = "[{}]".format(address)
|
||||||
|
else:
|
||||||
|
address_print = address
|
||||||
|
|
||||||
|
logging.info("To see the GUI go to: {}://{}:{}".format(scheme, address_print, port))
|
||||||
|
|
||||||
if verbose:
|
|
||||||
logging.info("Starting server\n")
|
|
||||||
logging.info("To see the GUI go to: {}://{}:{}".format(scheme, address, port))
|
|
||||||
if call_on_start is not None:
|
if call_on_start is not None:
|
||||||
call_on_start(scheme, address, port)
|
call_on_start(scheme, self.address, self.port)
|
||||||
|
|
||||||
def add_on_prompt_handler(self, handler):
|
def add_on_prompt_handler(self, handler):
|
||||||
self.on_prompt_handlers.append(handler)
|
self.on_prompt_handlers.append(handler)
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
## Install test dependencies
|
## Install test dependencies
|
||||||
|
|
||||||
`pip install -r tests-units/requirements.txt`
|
`pip install -r tests-unit/requirements.txt`
|
||||||
|
|
||||||
## Run tests
|
## Run tests
|
||||||
`pytest tests-units/`
|
`pytest tests-unit/`
|
||||||
|
|||||||
66
tests-unit/comfy_test/folder_path_test.py
Normal file
66
tests-unit/comfy_test/folder_path_test.py
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
### 🗻 This file is created through the spirit of Mount Fuji at its peak
|
||||||
|
# TODO(yoland): clean up this after I get back down
|
||||||
|
import pytest
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import folder_paths
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_dir():
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
yield tmpdirname
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_directory_by_type():
|
||||||
|
test_dir = "/test/dir"
|
||||||
|
folder_paths.set_output_directory(test_dir)
|
||||||
|
assert folder_paths.get_directory_by_type("output") == test_dir
|
||||||
|
assert folder_paths.get_directory_by_type("invalid") is None
|
||||||
|
|
||||||
|
def test_annotated_filepath():
|
||||||
|
assert folder_paths.annotated_filepath("test.txt") == ("test.txt", None)
|
||||||
|
assert folder_paths.annotated_filepath("test.txt [output]") == ("test.txt", folder_paths.get_output_directory())
|
||||||
|
assert folder_paths.annotated_filepath("test.txt [input]") == ("test.txt", folder_paths.get_input_directory())
|
||||||
|
assert folder_paths.annotated_filepath("test.txt [temp]") == ("test.txt", folder_paths.get_temp_directory())
|
||||||
|
|
||||||
|
def test_get_annotated_filepath():
|
||||||
|
default_dir = "/default/dir"
|
||||||
|
assert folder_paths.get_annotated_filepath("test.txt", default_dir) == os.path.join(default_dir, "test.txt")
|
||||||
|
assert folder_paths.get_annotated_filepath("test.txt [output]") == os.path.join(folder_paths.get_output_directory(), "test.txt")
|
||||||
|
|
||||||
|
def test_add_model_folder_path():
|
||||||
|
folder_paths.add_model_folder_path("test_folder", "/test/path")
|
||||||
|
assert "/test/path" in folder_paths.get_folder_paths("test_folder")
|
||||||
|
|
||||||
|
def test_recursive_search(temp_dir):
|
||||||
|
os.makedirs(os.path.join(temp_dir, "subdir"))
|
||||||
|
open(os.path.join(temp_dir, "file1.txt"), "w").close()
|
||||||
|
open(os.path.join(temp_dir, "subdir", "file2.txt"), "w").close()
|
||||||
|
|
||||||
|
files, dirs = folder_paths.recursive_search(temp_dir)
|
||||||
|
assert set(files) == {"file1.txt", os.path.join("subdir", "file2.txt")}
|
||||||
|
assert len(dirs) == 2 # temp_dir and subdir
|
||||||
|
|
||||||
|
def test_filter_files_extensions():
|
||||||
|
files = ["file1.txt", "file2.jpg", "file3.png", "file4.txt"]
|
||||||
|
assert folder_paths.filter_files_extensions(files, [".txt"]) == ["file1.txt", "file4.txt"]
|
||||||
|
assert folder_paths.filter_files_extensions(files, [".jpg", ".png"]) == ["file2.jpg", "file3.png"]
|
||||||
|
assert folder_paths.filter_files_extensions(files, []) == files
|
||||||
|
|
||||||
|
@patch("folder_paths.recursive_search")
|
||||||
|
@patch("folder_paths.folder_names_and_paths")
|
||||||
|
def test_get_filename_list(mock_folder_names_and_paths, mock_recursive_search):
|
||||||
|
mock_folder_names_and_paths.__getitem__.return_value = (["/test/path"], {".txt"})
|
||||||
|
mock_recursive_search.return_value = (["file1.txt", "file2.jpg"], {})
|
||||||
|
assert folder_paths.get_filename_list("test_folder") == ["file1.txt"]
|
||||||
|
|
||||||
|
def test_get_save_image_path(temp_dir):
|
||||||
|
with patch("folder_paths.output_directory", temp_dir):
|
||||||
|
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path("test", temp_dir, 100, 100)
|
||||||
|
assert os.path.samefile(full_output_folder, temp_dir)
|
||||||
|
assert filename == "test"
|
||||||
|
assert counter == 1
|
||||||
|
assert subfolder == ""
|
||||||
|
assert filename_prefix == "test"
|
||||||
0
tests-unit/folder_paths_test/__init__.py
Normal file
0
tests-unit/folder_paths_test/__init__.py
Normal file
52
tests-unit/folder_paths_test/filter_by_content_types_test.py
Normal file
52
tests-unit/folder_paths_test/filter_by_content_types_test.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
import pytest
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from folder_paths import filter_files_content_types
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def file_extensions():
|
||||||
|
return {
|
||||||
|
'image': ['gif', 'heif', 'ico', 'jpeg', 'jpg', 'png', 'pnm', 'ppm', 'svg', 'tiff', 'webp', 'xbm', 'xpm'],
|
||||||
|
'audio': ['aif', 'aifc', 'aiff', 'au', 'flac', 'm4a', 'mp2', 'mp3', 'ogg', 'snd', 'wav'],
|
||||||
|
'video': ['avi', 'm2v', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ogv', 'qt', 'webm', 'wmv']
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def mock_dir(file_extensions):
|
||||||
|
with tempfile.TemporaryDirectory() as directory:
|
||||||
|
for content_type, extensions in file_extensions.items():
|
||||||
|
for extension in extensions:
|
||||||
|
with open(f"{directory}/sample_{content_type}.{extension}", "w") as f:
|
||||||
|
f.write(f"Sample {content_type} file in {extension} format")
|
||||||
|
yield directory
|
||||||
|
|
||||||
|
|
||||||
|
def test_categorizes_all_correctly(mock_dir, file_extensions):
|
||||||
|
files = os.listdir(mock_dir)
|
||||||
|
for content_type, extensions in file_extensions.items():
|
||||||
|
filtered_files = filter_files_content_types(files, [content_type])
|
||||||
|
for extension in extensions:
|
||||||
|
assert f"sample_{content_type}.{extension}" in filtered_files
|
||||||
|
|
||||||
|
|
||||||
|
def test_categorizes_all_uniquely(mock_dir, file_extensions):
|
||||||
|
files = os.listdir(mock_dir)
|
||||||
|
for content_type, extensions in file_extensions.items():
|
||||||
|
filtered_files = filter_files_content_types(files, [content_type])
|
||||||
|
assert len(filtered_files) == len(extensions)
|
||||||
|
|
||||||
|
|
||||||
|
def test_handles_bad_extensions():
|
||||||
|
files = ["file1.txt", "file2.py", "file3.example", "file4.pdf", "file5.ini", "file6.doc", "file7.md"]
|
||||||
|
assert filter_files_content_types(files, ["image", "audio", "video"]) == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_handles_no_extension():
|
||||||
|
files = ["file1", "file2", "file3", "file4", "file5", "file6", "file7"]
|
||||||
|
assert filter_files_content_types(files, ["image", "audio", "video"]) == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_handles_no_files():
|
||||||
|
files = []
|
||||||
|
assert filter_files_content_types(files, ["image", "audio", "video"]) == []
|
||||||
@@ -1,10 +1,17 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
import tempfile
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from aiohttp import ClientResponse
|
from aiohttp import ClientResponse
|
||||||
import itertools
|
import itertools
|
||||||
import os
|
import os
|
||||||
from unittest.mock import AsyncMock, patch, MagicMock
|
from unittest.mock import AsyncMock, patch, MagicMock
|
||||||
from model_filemanager import download_model, validate_model_subdirectory, track_download_progress, create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus, validate_filename
|
from model_filemanager import download_model, track_download_progress, create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus, validate_filename
|
||||||
|
import folder_paths
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_dir():
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
yield tmpdirname
|
||||||
|
|
||||||
class AsyncIteratorMock:
|
class AsyncIteratorMock:
|
||||||
"""
|
"""
|
||||||
@@ -42,7 +49,7 @@ class ContentMock:
|
|||||||
return AsyncIteratorMock(self.chunks)
|
return AsyncIteratorMock(self.chunks)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_download_model_success():
|
async def test_download_model_success(temp_dir):
|
||||||
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
|
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
|
||||||
mock_response.status = 200
|
mock_response.status = 200
|
||||||
mock_response.headers = {'Content-Length': '1000'}
|
mock_response.headers = {'Content-Length': '1000'}
|
||||||
@@ -53,15 +60,13 @@ async def test_download_model_success():
|
|||||||
mock_make_request = AsyncMock(return_value=mock_response)
|
mock_make_request = AsyncMock(return_value=mock_response)
|
||||||
mock_progress_callback = AsyncMock()
|
mock_progress_callback = AsyncMock()
|
||||||
|
|
||||||
# Mock file operations
|
|
||||||
mock_open = MagicMock()
|
|
||||||
mock_file = MagicMock()
|
|
||||||
mock_open.return_value.__enter__.return_value = mock_file
|
|
||||||
time_values = itertools.count(0, 0.1)
|
time_values = itertools.count(0, 0.1)
|
||||||
|
|
||||||
with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'checkpoints/model.sft')), \
|
fake_paths = {'checkpoints': ([temp_dir], folder_paths.supported_pt_extensions)}
|
||||||
|
|
||||||
|
with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'model.sft')), \
|
||||||
patch('model_filemanager.check_file_exists', return_value=None), \
|
patch('model_filemanager.check_file_exists', return_value=None), \
|
||||||
patch('builtins.open', mock_open), \
|
patch('folder_paths.folder_names_and_paths', fake_paths), \
|
||||||
patch('time.time', side_effect=time_values): # Simulate time passing
|
patch('time.time', side_effect=time_values): # Simulate time passing
|
||||||
|
|
||||||
result = await download_model(
|
result = await download_model(
|
||||||
@@ -69,6 +74,7 @@ async def test_download_model_success():
|
|||||||
'model.sft',
|
'model.sft',
|
||||||
'http://example.com/model.sft',
|
'http://example.com/model.sft',
|
||||||
'checkpoints',
|
'checkpoints',
|
||||||
|
temp_dir,
|
||||||
mock_progress_callback
|
mock_progress_callback
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -83,44 +89,48 @@ async def test_download_model_success():
|
|||||||
|
|
||||||
# Check initial call
|
# Check initial call
|
||||||
mock_progress_callback.assert_any_call(
|
mock_progress_callback.assert_any_call(
|
||||||
'checkpoints/model.sft',
|
'model.sft',
|
||||||
DownloadModelStatus(DownloadStatusType.PENDING, 0, "Starting download of model.sft", False)
|
DownloadModelStatus(DownloadStatusType.PENDING, 0, "Starting download of model.sft", False)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check final call
|
# Check final call
|
||||||
mock_progress_callback.assert_any_call(
|
mock_progress_callback.assert_any_call(
|
||||||
'checkpoints/model.sft',
|
'model.sft',
|
||||||
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.sft", False)
|
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.sft", False)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify file writing
|
mock_file_path = os.path.join(temp_dir, 'model.sft')
|
||||||
mock_file.write.assert_any_call(b'a' * 500)
|
assert os.path.exists(mock_file_path)
|
||||||
mock_file.write.assert_any_call(b'b' * 300)
|
with open(mock_file_path, 'rb') as mock_file:
|
||||||
mock_file.write.assert_any_call(b'c' * 200)
|
assert mock_file.read() == b''.join(chunks)
|
||||||
|
os.remove(mock_file_path)
|
||||||
|
|
||||||
# Verify request was made
|
# Verify request was made
|
||||||
mock_make_request.assert_called_once_with('http://example.com/model.sft')
|
mock_make_request.assert_called_once_with('http://example.com/model.sft')
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_download_model_url_request_failure():
|
async def test_download_model_url_request_failure(temp_dir):
|
||||||
# Mock dependencies
|
# Mock dependencies
|
||||||
mock_response = AsyncMock(spec=ClientResponse)
|
mock_response = AsyncMock(spec=ClientResponse)
|
||||||
mock_response.status = 404 # Simulate a "Not Found" error
|
mock_response.status = 404 # Simulate a "Not Found" error
|
||||||
mock_get = AsyncMock(return_value=mock_response)
|
mock_get = AsyncMock(return_value=mock_response)
|
||||||
mock_progress_callback = AsyncMock()
|
mock_progress_callback = AsyncMock()
|
||||||
|
|
||||||
|
fake_paths = {'checkpoints': ([temp_dir], folder_paths.supported_pt_extensions)}
|
||||||
|
|
||||||
# Mock the create_model_path function
|
# Mock the create_model_path function
|
||||||
with patch('model_filemanager.create_model_path', return_value=('/mock/path/model.safetensors', 'mock/path/model.safetensors')):
|
with patch('model_filemanager.create_model_path', return_value='/mock/path/model.safetensors'), \
|
||||||
# Mock the check_file_exists function to return None (file doesn't exist)
|
patch('model_filemanager.check_file_exists', return_value=None), \
|
||||||
with patch('model_filemanager.check_file_exists', return_value=None):
|
patch('folder_paths.folder_names_and_paths', fake_paths):
|
||||||
# Call the function
|
# Call the function
|
||||||
result = await download_model(
|
result = await download_model(
|
||||||
mock_get,
|
mock_get,
|
||||||
'model.safetensors',
|
'model.safetensors',
|
||||||
'http://example.com/model.safetensors',
|
'http://example.com/model.safetensors',
|
||||||
'mock_directory',
|
'checkpoints',
|
||||||
mock_progress_callback
|
temp_dir,
|
||||||
)
|
mock_progress_callback
|
||||||
|
)
|
||||||
|
|
||||||
# Assert the expected behavior
|
# Assert the expected behavior
|
||||||
assert isinstance(result, DownloadModelStatus)
|
assert isinstance(result, DownloadModelStatus)
|
||||||
@@ -130,7 +140,7 @@ async def test_download_model_url_request_failure():
|
|||||||
|
|
||||||
# Check that progress_callback was called with the correct arguments
|
# Check that progress_callback was called with the correct arguments
|
||||||
mock_progress_callback.assert_any_call(
|
mock_progress_callback.assert_any_call(
|
||||||
'mock_directory/model.safetensors',
|
'model.safetensors',
|
||||||
DownloadModelStatus(
|
DownloadModelStatus(
|
||||||
status=DownloadStatusType.PENDING,
|
status=DownloadStatusType.PENDING,
|
||||||
progress_percentage=0,
|
progress_percentage=0,
|
||||||
@@ -139,7 +149,7 @@ async def test_download_model_url_request_failure():
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
mock_progress_callback.assert_called_with(
|
mock_progress_callback.assert_called_with(
|
||||||
'mock_directory/model.safetensors',
|
'model.safetensors',
|
||||||
DownloadModelStatus(
|
DownloadModelStatus(
|
||||||
status=DownloadStatusType.ERROR,
|
status=DownloadStatusType.ERROR,
|
||||||
progress_percentage=0,
|
progress_percentage=0,
|
||||||
@@ -153,40 +163,59 @@ async def test_download_model_url_request_failure():
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_download_model_invalid_model_subdirectory():
|
async def test_download_model_invalid_model_subdirectory():
|
||||||
|
|
||||||
mock_make_request = AsyncMock()
|
mock_make_request = AsyncMock()
|
||||||
mock_progress_callback = AsyncMock()
|
mock_progress_callback = AsyncMock()
|
||||||
|
|
||||||
|
|
||||||
result = await download_model(
|
result = await download_model(
|
||||||
mock_make_request,
|
mock_make_request,
|
||||||
'model.sft',
|
'model.sft',
|
||||||
'http://example.com/model.sft',
|
'http://example.com/model.sft',
|
||||||
'../bad_path',
|
'../bad_path',
|
||||||
|
'../bad_path',
|
||||||
mock_progress_callback
|
mock_progress_callback
|
||||||
)
|
)
|
||||||
|
|
||||||
# Assert the result
|
# Assert the result
|
||||||
assert isinstance(result, DownloadModelStatus)
|
assert isinstance(result, DownloadModelStatus)
|
||||||
assert result.message == 'Invalid model subdirectory'
|
assert result.message.startswith('Invalid or unrecognized model directory')
|
||||||
assert result.status == 'error'
|
assert result.status == 'error'
|
||||||
assert result.already_existed is False
|
assert result.already_existed is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_download_model_invalid_folder_path():
|
||||||
|
mock_make_request = AsyncMock()
|
||||||
|
mock_progress_callback = AsyncMock()
|
||||||
|
|
||||||
|
result = await download_model(
|
||||||
|
mock_make_request,
|
||||||
|
'model.sft',
|
||||||
|
'http://example.com/model.sft',
|
||||||
|
'checkpoints',
|
||||||
|
'invalid_path',
|
||||||
|
mock_progress_callback
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert the result
|
||||||
|
assert isinstance(result, DownloadModelStatus)
|
||||||
|
assert result.message.startswith("Invalid folder path")
|
||||||
|
assert result.status == 'error'
|
||||||
|
assert result.already_existed is False
|
||||||
|
|
||||||
# For create_model_path function
|
|
||||||
def test_create_model_path(tmp_path, monkeypatch):
|
def test_create_model_path(tmp_path, monkeypatch):
|
||||||
mock_models_dir = tmp_path / "models"
|
model_name = "model.safetensors"
|
||||||
monkeypatch.setattr('folder_paths.models_dir', str(mock_models_dir))
|
folder_path = os.path.join(tmp_path, "mock_dir")
|
||||||
|
|
||||||
model_name = "test_model.sft"
|
file_path = create_model_path(model_name, folder_path)
|
||||||
model_directory = "test_dir"
|
|
||||||
|
|
||||||
file_path, relative_path = create_model_path(model_name, model_directory, mock_models_dir)
|
assert file_path == os.path.join(folder_path, "model.safetensors")
|
||||||
|
|
||||||
assert file_path == str(mock_models_dir / model_directory / model_name)
|
|
||||||
assert relative_path == f"{model_directory}/{model_name}"
|
|
||||||
assert os.path.exists(os.path.dirname(file_path))
|
assert os.path.exists(os.path.dirname(file_path))
|
||||||
|
|
||||||
|
with pytest.raises(Exception, match="Invalid model directory"):
|
||||||
|
create_model_path("../path_traversal.safetensors", folder_path)
|
||||||
|
|
||||||
|
with pytest.raises(Exception, match="Invalid model directory"):
|
||||||
|
create_model_path("/etc/some_root_path", folder_path)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_check_file_exists_when_file_exists(tmp_path):
|
async def test_check_file_exists_when_file_exists(tmp_path):
|
||||||
@@ -195,7 +224,7 @@ async def test_check_file_exists_when_file_exists(tmp_path):
|
|||||||
|
|
||||||
mock_callback = AsyncMock()
|
mock_callback = AsyncMock()
|
||||||
|
|
||||||
result = await check_file_exists(str(file_path), "existing_model.sft", mock_callback, "test/existing_model.sft")
|
result = await check_file_exists(str(file_path), "existing_model.sft", mock_callback)
|
||||||
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result.status == "completed"
|
assert result.status == "completed"
|
||||||
@@ -203,7 +232,7 @@ async def test_check_file_exists_when_file_exists(tmp_path):
|
|||||||
assert result.already_existed is True
|
assert result.already_existed is True
|
||||||
|
|
||||||
mock_callback.assert_called_once_with(
|
mock_callback.assert_called_once_with(
|
||||||
"test/existing_model.sft",
|
"existing_model.sft",
|
||||||
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "existing_model.sft already exists", already_existed=True)
|
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "existing_model.sft already exists", already_existed=True)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -213,38 +242,46 @@ async def test_check_file_exists_when_file_does_not_exist(tmp_path):
|
|||||||
|
|
||||||
mock_callback = AsyncMock()
|
mock_callback = AsyncMock()
|
||||||
|
|
||||||
result = await check_file_exists(str(file_path), "non_existing_model.sft", mock_callback, "test/non_existing_model.sft")
|
result = await check_file_exists(str(file_path), "non_existing_model.sft", mock_callback)
|
||||||
|
|
||||||
assert result is None
|
assert result is None
|
||||||
mock_callback.assert_not_called()
|
mock_callback.assert_not_called()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_track_download_progress_no_content_length():
|
async def test_track_download_progress_no_content_length(temp_dir):
|
||||||
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
|
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
|
||||||
mock_response.headers = {} # No Content-Length header
|
mock_response.headers = {} # No Content-Length header
|
||||||
mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 500, b'b' * 500])
|
chunks = [b'a' * 500, b'b' * 500]
|
||||||
|
mock_response.content.iter_chunked.return_value = AsyncIteratorMock(chunks)
|
||||||
|
|
||||||
mock_callback = AsyncMock()
|
mock_callback = AsyncMock()
|
||||||
mock_open = MagicMock(return_value=MagicMock())
|
|
||||||
|
|
||||||
with patch('builtins.open', mock_open):
|
full_path = os.path.join(temp_dir, 'model.sft')
|
||||||
result = await track_download_progress(
|
|
||||||
mock_response, '/mock/path/model.sft', 'model.sft',
|
result = await track_download_progress(
|
||||||
mock_callback, 'models/model.sft', interval=0.1
|
mock_response, full_path, 'model.sft',
|
||||||
)
|
mock_callback, interval=0.1
|
||||||
|
)
|
||||||
|
|
||||||
assert result.status == "completed"
|
assert result.status == "completed"
|
||||||
|
|
||||||
|
assert os.path.exists(full_path)
|
||||||
|
with open(full_path, 'rb') as f:
|
||||||
|
assert f.read() == b''.join(chunks)
|
||||||
|
os.remove(full_path)
|
||||||
|
|
||||||
# Check that progress was reported even without knowing the total size
|
# Check that progress was reported even without knowing the total size
|
||||||
mock_callback.assert_any_call(
|
mock_callback.assert_any_call(
|
||||||
'models/model.sft',
|
'model.sft',
|
||||||
DownloadModelStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.sft", already_existed=False)
|
DownloadModelStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.sft", already_existed=False)
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_track_download_progress_interval():
|
async def test_track_download_progress_interval(temp_dir):
|
||||||
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
|
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
|
||||||
mock_response.headers = {'Content-Length': '1000'}
|
mock_response.headers = {'Content-Length': '1000'}
|
||||||
mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 100] * 10)
|
chunks = [b'a' * 100] * 10
|
||||||
|
mock_response.content.iter_chunked.return_value = AsyncIteratorMock(chunks)
|
||||||
|
|
||||||
mock_callback = AsyncMock()
|
mock_callback = AsyncMock()
|
||||||
mock_open = MagicMock(return_value=MagicMock())
|
mock_open = MagicMock(return_value=MagicMock())
|
||||||
@@ -253,18 +290,18 @@ async def test_track_download_progress_interval():
|
|||||||
mock_time = MagicMock()
|
mock_time = MagicMock()
|
||||||
mock_time.side_effect = [i * 0.5 for i in range(30)] # This should be enough for 10 chunks
|
mock_time.side_effect = [i * 0.5 for i in range(30)] # This should be enough for 10 chunks
|
||||||
|
|
||||||
with patch('builtins.open', mock_open), \
|
full_path = os.path.join(temp_dir, 'model.sft')
|
||||||
patch('time.time', mock_time):
|
|
||||||
|
with patch('time.time', mock_time):
|
||||||
await track_download_progress(
|
await track_download_progress(
|
||||||
mock_response, '/mock/path/model.sft', 'model.sft',
|
mock_response, full_path, 'model.sft',
|
||||||
mock_callback, 'models/model.sft', interval=1.0
|
mock_callback, interval=1.0
|
||||||
)
|
)
|
||||||
|
|
||||||
# Print out the actual call count and the arguments of each call for debugging
|
assert os.path.exists(full_path)
|
||||||
print(f"mock_callback was called {mock_callback.call_count} times")
|
with open(full_path, 'rb') as f:
|
||||||
for i, call in enumerate(mock_callback.call_args_list):
|
assert f.read() == b''.join(chunks)
|
||||||
args, kwargs = call
|
os.remove(full_path)
|
||||||
print(f"Call {i + 1}: {args[1].status}, Progress: {args[1].progress_percentage:.2f}%")
|
|
||||||
|
|
||||||
# Assert that progress was updated at least 3 times (start, at least one interval, and end)
|
# Assert that progress was updated at least 3 times (start, at least one interval, and end)
|
||||||
assert mock_callback.call_count >= 3, f"Expected at least 3 calls, but got {mock_callback.call_count}"
|
assert mock_callback.call_count >= 3, f"Expected at least 3 calls, but got {mock_callback.call_count}"
|
||||||
@@ -279,27 +316,6 @@ async def test_track_download_progress_interval():
|
|||||||
assert last_call[0][1].status == "completed"
|
assert last_call[0][1].status == "completed"
|
||||||
assert last_call[0][1].progress_percentage == 100
|
assert last_call[0][1].progress_percentage == 100
|
||||||
|
|
||||||
def test_valid_subdirectory():
|
|
||||||
assert validate_model_subdirectory("valid-model123") is True
|
|
||||||
|
|
||||||
def test_subdirectory_too_long():
|
|
||||||
assert validate_model_subdirectory("a" * 51) is False
|
|
||||||
|
|
||||||
def test_subdirectory_with_double_dots():
|
|
||||||
assert validate_model_subdirectory("model/../unsafe") is False
|
|
||||||
|
|
||||||
def test_subdirectory_with_slash():
|
|
||||||
assert validate_model_subdirectory("model/unsafe") is False
|
|
||||||
|
|
||||||
def test_subdirectory_with_special_characters():
|
|
||||||
assert validate_model_subdirectory("model@unsafe") is False
|
|
||||||
|
|
||||||
def test_subdirectory_with_underscore_and_dash():
|
|
||||||
assert validate_model_subdirectory("valid_model-name") is True
|
|
||||||
|
|
||||||
def test_empty_subdirectory():
|
|
||||||
assert validate_model_subdirectory("") is False
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("filename, expected", [
|
@pytest.mark.parametrize("filename, expected", [
|
||||||
("valid_model.safetensors", True),
|
("valid_model.safetensors", True),
|
||||||
("valid_model.sft", True),
|
("valid_model.sft", True),
|
||||||
|
|||||||
120
tests-unit/prompt_server_test/user_manager_test.py
Normal file
120
tests-unit/prompt_server_test/user_manager_test.py
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
import pytest
|
||||||
|
import os
|
||||||
|
from aiohttp import web
|
||||||
|
from app.user_manager import UserManager
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
pytestmark = (
|
||||||
|
pytest.mark.asyncio
|
||||||
|
) # This applies the asyncio mark to all test functions in the module
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def user_manager(tmp_path):
|
||||||
|
um = UserManager()
|
||||||
|
um.get_request_user_filepath = lambda req, file, **kwargs: os.path.join(
|
||||||
|
tmp_path, file
|
||||||
|
)
|
||||||
|
return um
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def app(user_manager):
|
||||||
|
app = web.Application()
|
||||||
|
routes = web.RouteTableDef()
|
||||||
|
user_manager.add_routes(routes)
|
||||||
|
app.add_routes(routes)
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
async def test_listuserdata_empty_directory(aiohttp_client, app, tmp_path):
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.get("/userdata?dir=test_dir")
|
||||||
|
assert resp.status == 404
|
||||||
|
|
||||||
|
|
||||||
|
async def test_listuserdata_with_files(aiohttp_client, app, tmp_path):
|
||||||
|
os.makedirs(tmp_path / "test_dir")
|
||||||
|
with open(tmp_path / "test_dir" / "file1.txt", "w") as f:
|
||||||
|
f.write("test content")
|
||||||
|
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.get("/userdata?dir=test_dir")
|
||||||
|
assert resp.status == 200
|
||||||
|
assert await resp.json() == ["file1.txt"]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_listuserdata_recursive(aiohttp_client, app, tmp_path):
|
||||||
|
os.makedirs(tmp_path / "test_dir" / "subdir")
|
||||||
|
with open(tmp_path / "test_dir" / "file1.txt", "w") as f:
|
||||||
|
f.write("test content")
|
||||||
|
with open(tmp_path / "test_dir" / "subdir" / "file2.txt", "w") as f:
|
||||||
|
f.write("test content")
|
||||||
|
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.get("/userdata?dir=test_dir&recurse=true")
|
||||||
|
assert resp.status == 200
|
||||||
|
assert set(await resp.json()) == {"file1.txt", "subdir/file2.txt"}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_listuserdata_full_info(aiohttp_client, app, tmp_path):
|
||||||
|
os.makedirs(tmp_path / "test_dir")
|
||||||
|
with open(tmp_path / "test_dir" / "file1.txt", "w") as f:
|
||||||
|
f.write("test content")
|
||||||
|
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.get("/userdata?dir=test_dir&full_info=true")
|
||||||
|
assert resp.status == 200
|
||||||
|
result = await resp.json()
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0]["path"] == "file1.txt"
|
||||||
|
assert "size" in result[0]
|
||||||
|
assert "modified" in result[0]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_listuserdata_split_path(aiohttp_client, app, tmp_path):
|
||||||
|
os.makedirs(tmp_path / "test_dir" / "subdir")
|
||||||
|
with open(tmp_path / "test_dir" / "subdir" / "file1.txt", "w") as f:
|
||||||
|
f.write("test content")
|
||||||
|
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.get("/userdata?dir=test_dir&recurse=true&split=true")
|
||||||
|
assert resp.status == 200
|
||||||
|
assert await resp.json() == [
|
||||||
|
["subdir/file1.txt", "subdir", "file1.txt"]
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_listuserdata_invalid_directory(aiohttp_client, app):
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.get("/userdata?dir=")
|
||||||
|
assert resp.status == 400
|
||||||
|
|
||||||
|
|
||||||
|
async def test_listuserdata_normalized_separator(aiohttp_client, app, tmp_path):
|
||||||
|
os_sep = "\\"
|
||||||
|
with patch("os.sep", os_sep):
|
||||||
|
with patch("os.path.sep", os_sep):
|
||||||
|
os.makedirs(tmp_path / "test_dir" / "subdir")
|
||||||
|
with open(tmp_path / "test_dir" / "subdir" / "file1.txt", "w") as f:
|
||||||
|
f.write("test content")
|
||||||
|
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.get("/userdata?dir=test_dir&recurse=true")
|
||||||
|
assert resp.status == 200
|
||||||
|
result = await resp.json()
|
||||||
|
assert len(result) == 1
|
||||||
|
assert "/" in result[0] # Ensure forward slash is used
|
||||||
|
assert "\\" not in result[0] # Ensure backslash is not present
|
||||||
|
assert result[0] == "subdir/file1.txt"
|
||||||
|
|
||||||
|
# Test with full_info
|
||||||
|
resp = await client.get(
|
||||||
|
"/userdata?dir=test_dir&recurse=true&full_info=true"
|
||||||
|
)
|
||||||
|
assert resp.status == 200
|
||||||
|
result = await resp.json()
|
||||||
|
assert len(result) == 1
|
||||||
|
assert "/" in result[0]["path"] # Ensure forward slash is used
|
||||||
|
assert "\\" not in result[0]["path"] # Ensure backslash is not present
|
||||||
|
assert result[0]["path"] == "subdir/file1.txt"
|
||||||
126
tests-unit/utils/extra_config_test.py
Normal file
126
tests-unit/utils/extra_config_test.py
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
import pytest
|
||||||
|
import yaml
|
||||||
|
import os
|
||||||
|
from unittest.mock import Mock, patch, mock_open
|
||||||
|
|
||||||
|
from utils.extra_config import load_extra_path_config
|
||||||
|
import folder_paths
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_yaml_content():
|
||||||
|
return {
|
||||||
|
'test_config': {
|
||||||
|
'base_path': '~/App/',
|
||||||
|
'checkpoints': 'subfolder1',
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_expanded_home():
|
||||||
|
return '/home/user'
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def yaml_config_with_appdata():
|
||||||
|
return """
|
||||||
|
test_config:
|
||||||
|
base_path: '%APPDATA%/ComfyUI'
|
||||||
|
checkpoints: 'models/checkpoints'
|
||||||
|
"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_yaml_content_appdata(yaml_config_with_appdata):
|
||||||
|
return yaml.safe_load(yaml_config_with_appdata)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_expandvars_appdata():
|
||||||
|
mock = Mock()
|
||||||
|
mock.side_effect = lambda path: path.replace('%APPDATA%', 'C:/Users/TestUser/AppData/Roaming')
|
||||||
|
return mock
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_add_model_folder_path():
|
||||||
|
return Mock()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_expanduser(mock_expanded_home):
|
||||||
|
def _expanduser(path):
|
||||||
|
if path.startswith('~/'):
|
||||||
|
return os.path.join(mock_expanded_home, path[2:])
|
||||||
|
return path
|
||||||
|
return _expanduser
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_yaml_safe_load(mock_yaml_content):
|
||||||
|
return Mock(return_value=mock_yaml_content)
|
||||||
|
|
||||||
|
@patch('builtins.open', new_callable=mock_open, read_data="dummy file content")
|
||||||
|
def test_load_extra_model_paths_expands_userpath(
|
||||||
|
mock_file,
|
||||||
|
monkeypatch,
|
||||||
|
mock_add_model_folder_path,
|
||||||
|
mock_expanduser,
|
||||||
|
mock_yaml_safe_load,
|
||||||
|
mock_expanded_home
|
||||||
|
):
|
||||||
|
# Attach mocks used by load_extra_path_config
|
||||||
|
monkeypatch.setattr(folder_paths, 'add_model_folder_path', mock_add_model_folder_path)
|
||||||
|
monkeypatch.setattr(os.path, 'expanduser', mock_expanduser)
|
||||||
|
monkeypatch.setattr(yaml, 'safe_load', mock_yaml_safe_load)
|
||||||
|
|
||||||
|
dummy_yaml_file_name = 'dummy_path.yaml'
|
||||||
|
load_extra_path_config(dummy_yaml_file_name)
|
||||||
|
|
||||||
|
expected_calls = [
|
||||||
|
('checkpoints', os.path.join(mock_expanded_home, 'App', 'subfolder1'), False),
|
||||||
|
]
|
||||||
|
|
||||||
|
assert mock_add_model_folder_path.call_count == len(expected_calls)
|
||||||
|
|
||||||
|
# Check if add_model_folder_path was called with the correct arguments
|
||||||
|
for actual_call, expected_call in zip(mock_add_model_folder_path.call_args_list, expected_calls):
|
||||||
|
assert actual_call.args[0] == expected_call[0]
|
||||||
|
assert os.path.normpath(actual_call.args[1]) == os.path.normpath(expected_call[1]) # Normalize and check the path to check on multiple OS.
|
||||||
|
assert actual_call.args[2] == expected_call[2]
|
||||||
|
|
||||||
|
# Check if yaml.safe_load was called
|
||||||
|
mock_yaml_safe_load.assert_called_once()
|
||||||
|
|
||||||
|
# Check if open was called with the correct file path
|
||||||
|
mock_file.assert_called_once_with(dummy_yaml_file_name, 'r')
|
||||||
|
|
||||||
|
@patch('builtins.open', new_callable=mock_open)
|
||||||
|
def test_load_extra_model_paths_expands_appdata(
|
||||||
|
mock_file,
|
||||||
|
monkeypatch,
|
||||||
|
mock_add_model_folder_path,
|
||||||
|
mock_expandvars_appdata,
|
||||||
|
yaml_config_with_appdata,
|
||||||
|
mock_yaml_content_appdata
|
||||||
|
):
|
||||||
|
# Set the mock_file to return yaml with appdata as a variable
|
||||||
|
mock_file.return_value.read.return_value = yaml_config_with_appdata
|
||||||
|
|
||||||
|
# Attach mocks
|
||||||
|
monkeypatch.setattr(folder_paths, 'add_model_folder_path', mock_add_model_folder_path)
|
||||||
|
monkeypatch.setattr(os.path, 'expandvars', mock_expandvars_appdata)
|
||||||
|
monkeypatch.setattr(yaml, 'safe_load', Mock(return_value=mock_yaml_content_appdata))
|
||||||
|
|
||||||
|
# Mock expanduser to do nothing (since we're not testing it here)
|
||||||
|
monkeypatch.setattr(os.path, 'expanduser', lambda x: x)
|
||||||
|
|
||||||
|
dummy_yaml_file_name = 'dummy_path.yaml'
|
||||||
|
load_extra_path_config(dummy_yaml_file_name)
|
||||||
|
|
||||||
|
expected_base_path = 'C:/Users/TestUser/AppData/Roaming/ComfyUI'
|
||||||
|
expected_calls = [
|
||||||
|
('checkpoints', os.path.join(expected_base_path, 'models/checkpoints'), False),
|
||||||
|
]
|
||||||
|
|
||||||
|
assert mock_add_model_folder_path.call_count == len(expected_calls)
|
||||||
|
|
||||||
|
# Check the base path variable was expanded
|
||||||
|
for actual_call, expected_call in zip(mock_add_model_folder_path.call_args_list, expected_calls):
|
||||||
|
assert actual_call.args == expected_call
|
||||||
|
|
||||||
|
# Verify that expandvars was called
|
||||||
|
assert mock_expandvars_appdata.called
|
||||||
@@ -496,3 +496,29 @@ class TestExecution:
|
|||||||
assert len(images) == 1, "Should have 1 image"
|
assert len(images) == 1, "Should have 1 image"
|
||||||
assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"
|
assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"
|
||||||
assert not result.did_run(test_node), "The execution should have been cached"
|
assert not result.did_run(test_node), "The execution should have been cached"
|
||||||
|
|
||||||
|
# This tests that nodes with OUTPUT_IS_LIST function correctly when they receive an ExecutionBlocker
|
||||||
|
# as input. We also test that when that list (containing an ExecutionBlocker) is passed to a node,
|
||||||
|
# only that one entry in the list is blocked.
|
||||||
|
def test_execution_block_list_output(self, client: ComfyClient, builder: GraphBuilder):
|
||||||
|
g = builder
|
||||||
|
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||||
|
image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||||
|
image3 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||||
|
image_list = g.node("TestMakeListNode", value1=image1.out(0), value2=image2.out(0), value3=image3.out(0))
|
||||||
|
int1 = g.node("StubInt", value=1)
|
||||||
|
int2 = g.node("StubInt", value=2)
|
||||||
|
int3 = g.node("StubInt", value=3)
|
||||||
|
int_list = g.node("TestMakeListNode", value1=int1.out(0), value2=int2.out(0), value3=int3.out(0))
|
||||||
|
compare = g.node("TestIntConditions", a=int_list.out(0), b=2, operation="==")
|
||||||
|
blocker = g.node("TestExecutionBlocker", input=image_list.out(0), block=compare.out(0), verbose=False)
|
||||||
|
|
||||||
|
list_output = g.node("TestMakeListNode", value1=blocker.out(0))
|
||||||
|
output = g.node("PreviewImage", images=list_output.out(0))
|
||||||
|
|
||||||
|
result = client.run(g)
|
||||||
|
assert result.did_run(output), "The execution should have run"
|
||||||
|
images = result.get_images(output)
|
||||||
|
assert len(images) == 2, "Should have 2 images"
|
||||||
|
assert numpy.array(images[0]).min() == 0 and numpy.array(images[0]).max() == 0, "First image should be black"
|
||||||
|
assert numpy.array(images[1]).min() == 0 and numpy.array(images[1]).max() == 0, "Second image should also be black"
|
||||||
|
|||||||
0
utils/__init__.py
Normal file
0
utils/__init__.py
Normal file
28
utils/extra_config.py
Normal file
28
utils/extra_config.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
import os
|
||||||
|
import yaml
|
||||||
|
import folder_paths
|
||||||
|
import logging
|
||||||
|
|
||||||
|
def load_extra_path_config(yaml_path):
|
||||||
|
with open(yaml_path, 'r') as stream:
|
||||||
|
config = yaml.safe_load(stream)
|
||||||
|
for c in config:
|
||||||
|
conf = config[c]
|
||||||
|
if conf is None:
|
||||||
|
continue
|
||||||
|
base_path = None
|
||||||
|
if "base_path" in conf:
|
||||||
|
base_path = conf.pop("base_path")
|
||||||
|
base_path = os.path.expandvars(os.path.expanduser(base_path))
|
||||||
|
is_default = False
|
||||||
|
if "is_default" in conf:
|
||||||
|
is_default = conf.pop("is_default")
|
||||||
|
for x in conf:
|
||||||
|
for y in conf[x].split("\n"):
|
||||||
|
if len(y) == 0:
|
||||||
|
continue
|
||||||
|
full_path = y
|
||||||
|
if base_path is not None:
|
||||||
|
full_path = os.path.join(base_path, full_path)
|
||||||
|
logging.info("Adding extra search path {} {}".format(x, full_path))
|
||||||
|
folder_paths.add_model_folder_path(x, full_path, is_default)
|
||||||
1
web/assets/CREDIT.txt
generated
vendored
Normal file
1
web/assets/CREDIT.txt
generated
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
Thanks to OpenArt (https://openart.ai) for providing the sorted-custom-node-map data, captured in September 2024.
|
||||||
103
web/assets/ExtensionPanel-BmKi_NKS.js
generated
vendored
Normal file
103
web/assets/ExtensionPanel-BmKi_NKS.js
generated
vendored
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
var __defProp = Object.defineProperty;
|
||||||
|
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
|
||||||
|
import { d as defineComponent, bQ as useExtensionStore, u as useSettingStore, r as ref, o as onMounted, q as computed, g as openBlock, h as createElementBlock, i as createVNode, y as withCtx, z as unref, bR as script$1, A as createBaseVNode, x as createBlock, N as Fragment, O as renderList, a6 as toDisplayString, aw as createTextVNode, j as createCommentVNode, D as script$4 } from "./index-BHayQCxv.js";
|
||||||
|
import { s as script, a as script$2, b as script$3 } from "./index-CwRXxFdA.js";
|
||||||
|
import "./index-C_wOqB0f.js";
|
||||||
|
const _hoisted_1 = { class: "extension-panel" };
|
||||||
|
const _hoisted_2 = { class: "mt-4" };
|
||||||
|
const _sfc_main = /* @__PURE__ */ defineComponent({
|
||||||
|
__name: "ExtensionPanel",
|
||||||
|
setup(__props) {
|
||||||
|
const extensionStore = useExtensionStore();
|
||||||
|
const settingStore = useSettingStore();
|
||||||
|
const editingEnabledExtensions = ref({});
|
||||||
|
onMounted(() => {
|
||||||
|
extensionStore.extensions.forEach((ext) => {
|
||||||
|
editingEnabledExtensions.value[ext.name] = extensionStore.isExtensionEnabled(ext.name);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
const changedExtensions = computed(() => {
|
||||||
|
return extensionStore.extensions.filter(
|
||||||
|
(ext) => editingEnabledExtensions.value[ext.name] !== extensionStore.isExtensionEnabled(ext.name)
|
||||||
|
);
|
||||||
|
});
|
||||||
|
const hasChanges = computed(() => {
|
||||||
|
return changedExtensions.value.length > 0;
|
||||||
|
});
|
||||||
|
const updateExtensionStatus = /* @__PURE__ */ __name(() => {
|
||||||
|
const editingDisabledExtensionNames = Object.entries(
|
||||||
|
editingEnabledExtensions.value
|
||||||
|
).filter(([_, enabled]) => !enabled).map(([name]) => name);
|
||||||
|
settingStore.set("Comfy.Extension.Disabled", [
|
||||||
|
...extensionStore.inactiveDisabledExtensionNames,
|
||||||
|
...editingDisabledExtensionNames
|
||||||
|
]);
|
||||||
|
}, "updateExtensionStatus");
|
||||||
|
const applyChanges = /* @__PURE__ */ __name(() => {
|
||||||
|
window.location.reload();
|
||||||
|
}, "applyChanges");
|
||||||
|
return (_ctx, _cache) => {
|
||||||
|
return openBlock(), createElementBlock("div", _hoisted_1, [
|
||||||
|
createVNode(unref(script$2), {
|
||||||
|
value: unref(extensionStore).extensions,
|
||||||
|
stripedRows: "",
|
||||||
|
size: "small"
|
||||||
|
}, {
|
||||||
|
default: withCtx(() => [
|
||||||
|
createVNode(unref(script), {
|
||||||
|
field: "name",
|
||||||
|
header: _ctx.$t("extensionName"),
|
||||||
|
sortable: ""
|
||||||
|
}, null, 8, ["header"]),
|
||||||
|
createVNode(unref(script), { pt: {
|
||||||
|
bodyCell: "flex items-center justify-end"
|
||||||
|
} }, {
|
||||||
|
body: withCtx((slotProps) => [
|
||||||
|
createVNode(unref(script$1), {
|
||||||
|
modelValue: editingEnabledExtensions.value[slotProps.data.name],
|
||||||
|
"onUpdate:modelValue": /* @__PURE__ */ __name(($event) => editingEnabledExtensions.value[slotProps.data.name] = $event, "onUpdate:modelValue"),
|
||||||
|
onChange: updateExtensionStatus
|
||||||
|
}, null, 8, ["modelValue", "onUpdate:modelValue"])
|
||||||
|
]),
|
||||||
|
_: 1
|
||||||
|
})
|
||||||
|
]),
|
||||||
|
_: 1
|
||||||
|
}, 8, ["value"]),
|
||||||
|
createBaseVNode("div", _hoisted_2, [
|
||||||
|
hasChanges.value ? (openBlock(), createBlock(unref(script$3), {
|
||||||
|
key: 0,
|
||||||
|
severity: "info"
|
||||||
|
}, {
|
||||||
|
default: withCtx(() => [
|
||||||
|
createBaseVNode("ul", null, [
|
||||||
|
(openBlock(true), createElementBlock(Fragment, null, renderList(changedExtensions.value, (ext) => {
|
||||||
|
return openBlock(), createElementBlock("li", {
|
||||||
|
key: ext.name
|
||||||
|
}, [
|
||||||
|
createBaseVNode("span", null, toDisplayString(unref(extensionStore).isExtensionEnabled(ext.name) ? "[-]" : "[+]"), 1),
|
||||||
|
createTextVNode(" " + toDisplayString(ext.name), 1)
|
||||||
|
]);
|
||||||
|
}), 128))
|
||||||
|
])
|
||||||
|
]),
|
||||||
|
_: 1
|
||||||
|
})) : createCommentVNode("", true),
|
||||||
|
createVNode(unref(script$4), {
|
||||||
|
label: _ctx.$t("reloadToApplyChanges"),
|
||||||
|
icon: "pi pi-refresh",
|
||||||
|
onClick: applyChanges,
|
||||||
|
disabled: !hasChanges.value,
|
||||||
|
text: "",
|
||||||
|
fluid: "",
|
||||||
|
severity: "danger"
|
||||||
|
}, null, 8, ["label", "disabled"])
|
||||||
|
])
|
||||||
|
]);
|
||||||
|
};
|
||||||
|
}
|
||||||
|
});
|
||||||
|
export {
|
||||||
|
_sfc_main as default
|
||||||
|
};
|
||||||
|
//# sourceMappingURL=ExtensionPanel-BmKi_NKS.js.map
|
||||||
1
web/assets/ExtensionPanel-BmKi_NKS.js.map
generated
vendored
Normal file
1
web/assets/ExtensionPanel-BmKi_NKS.js.map
generated
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
{"version":3,"file":"ExtensionPanel-BmKi_NKS.js","sources":["../../src/components/dialog/content/setting/ExtensionPanel.vue"],"sourcesContent":["<template>\n <div class=\"extension-panel\">\n <DataTable :value=\"extensionStore.extensions\" stripedRows size=\"small\">\n <Column field=\"name\" :header=\"$t('extensionName')\" sortable></Column>\n <Column\n :pt=\"{\n bodyCell: 'flex items-center justify-end'\n }\"\n >\n <template #body=\"slotProps\">\n <ToggleSwitch\n v-model=\"editingEnabledExtensions[slotProps.data.name]\"\n @change=\"updateExtensionStatus\"\n />\n </template>\n </Column>\n </DataTable>\n <div class=\"mt-4\">\n <Message v-if=\"hasChanges\" severity=\"info\">\n <ul>\n <li v-for=\"ext in changedExtensions\" :key=\"ext.name\">\n <span>\n {{ extensionStore.isExtensionEnabled(ext.name) ? '[-]' : '[+]' }}\n </span>\n {{ ext.name }}\n </li>\n </ul>\n </Message>\n <Button\n :label=\"$t('reloadToApplyChanges')\"\n icon=\"pi pi-refresh\"\n @click=\"applyChanges\"\n :disabled=\"!hasChanges\"\n text\n fluid\n severity=\"danger\"\n />\n </div>\n </div>\n</template>\n\n<script setup lang=\"ts\">\nimport { ref, computed, onMounted } from 'vue'\nimport { useExtensionStore } from '@/stores/extensionStore'\nimport { useSettingStore } from '@/stores/settingStore'\nimport DataTable from 'primevue/datatable'\nimport Column from 'primevue/column'\nimport ToggleSwitch from 'primevue/toggleswitch'\nimport Button from 'primevue/button'\nimport Message from 'primevue/message'\n\nconst extensionStore = useExtensionStore()\nconst settingStore = useSettingStore()\n\nconst editingEnabledExtensions = ref<Record<string, boolean>>({})\n\nonMounted(() => {\n extensionStore.extensions.forEach((ext) => {\n editingEnabledExtensions.value[ext.name] =\n extensionStore.isExtensionEnabled(ext.name)\n })\n})\n\nconst changedExtensions = computed(() => {\n return extensionStore.extensions.filter(\n (ext) =>\n editingEnabledExtensions.value[ext.name] !==\n extensionStore.isExtensionEnabled(ext.name)\n )\n})\n\nconst hasChanges = computed(() => {\n return changedExtensions.value.length > 0\n})\n\nconst updateExtensionStatus = () => {\n const editingDisabledExtensionNames = Object.entries(\n editingEnabledExtensions.value\n )\n .filter(([_, enabled]) => !enabled)\n .map(([name]) => name)\n\n settingStore.set('Comfy.Extension.Disabled', [\n ...extensionStore.inactiveDisabledExtensionNames,\n ...editingDisabledExtensionNames\n ])\n}\n\nconst applyChanges = () => {\n // Refresh the page to apply changes\n window.location.reload()\n}\n</script>\n"],"names":[],"mappings":";;;;;;;;;;AAmDA,UAAM,iBAAiB;AACvB,UAAM,eAAe;AAEf,UAAA,2BAA2B,IAA6B,CAAA,CAAE;AAEhE,cAAU,MAAM;AACC,qBAAA,WAAW,QAAQ,CAAC,QAAQ;AACzC,iCAAyB,MAAM,IAAI,IAAI,IACrC,eAAe,mBAAmB,IAAI,IAAI;AAAA,MAAA,CAC7C;AAAA,IAAA,CACF;AAEK,UAAA,oBAAoB,SAAS,MAAM;AACvC,aAAO,eAAe,WAAW;AAAA,QAC/B,CAAC,QACC,yBAAyB,MAAM,IAAI,IAAI,MACvC,eAAe,mBAAmB,IAAI,IAAI;AAAA,MAAA;AAAA,IAC9C,CACD;AAEK,UAAA,aAAa,SAAS,MAAM;AACzB,aAAA,kBAAkB,MAAM,SAAS;AAAA,IAAA,CACzC;AAED,UAAM,wBAAwB,6BAAM;AAClC,YAAM,gCAAgC,OAAO;AAAA,QAC3C,yBAAyB;AAAA,MAExB,EAAA,OAAO,CAAC,CAAC,GAAG,OAAO,MAAM,CAAC,OAAO,EACjC,IAAI,CAAC,CAAC,IAAI,MAAM,IAAI;AAEvB,mBAAa,IAAI,4BAA4B;AAAA,QAC3C,GAAG,eAAe;AAAA,QAClB,GAAG;AAAA,MAAA,CACJ;AAAA,IAAA,GAV2B;AAa9B,UAAM,eAAe,6BAAM;AAEzB,aAAO,SAAS;IAAO,GAFJ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;"}
|
||||||
7986
web/assets/GraphView-C4blCugc.js
generated
vendored
Normal file
7986
web/assets/GraphView-C4blCugc.js
generated
vendored
Normal file
File diff suppressed because one or more lines are too long
1
web/assets/GraphView-C4blCugc.js.map
generated
vendored
Normal file
1
web/assets/GraphView-C4blCugc.js.map
generated
vendored
Normal file
File diff suppressed because one or more lines are too long
281
web/assets/GraphView-Cf7ubG48.css
generated
vendored
Normal file
281
web/assets/GraphView-Cf7ubG48.css
generated
vendored
Normal file
@@ -0,0 +1,281 @@
|
|||||||
|
|
||||||
|
.group-title-editor.node-title-editor[data-v-8a100d5a] {
|
||||||
|
z-index: 9999;
|
||||||
|
padding: 0.25rem;
|
||||||
|
}
|
||||||
|
[data-v-8a100d5a] .editable-text {
|
||||||
|
width: 100%;
|
||||||
|
height: 100%;
|
||||||
|
}
|
||||||
|
[data-v-8a100d5a] .editable-text input {
|
||||||
|
width: 100%;
|
||||||
|
height: 100%;
|
||||||
|
/* Override the default font size */
|
||||||
|
font-size: inherit;
|
||||||
|
}
|
||||||
|
|
||||||
|
.side-bar-button-icon {
|
||||||
|
font-size: var(--sidebar-icon-size) !important;
|
||||||
|
}
|
||||||
|
.side-bar-button-selected .side-bar-button-icon {
|
||||||
|
font-size: var(--sidebar-icon-size) !important;
|
||||||
|
font-weight: bold;
|
||||||
|
}
|
||||||
|
|
||||||
|
.side-bar-button[data-v-caa3ee9c] {
|
||||||
|
width: var(--sidebar-width);
|
||||||
|
height: var(--sidebar-width);
|
||||||
|
border-radius: 0;
|
||||||
|
}
|
||||||
|
.comfyui-body-left .side-bar-button.side-bar-button-selected[data-v-caa3ee9c],
|
||||||
|
.comfyui-body-left .side-bar-button.side-bar-button-selected[data-v-caa3ee9c]:hover {
|
||||||
|
border-left: 4px solid var(--p-button-text-primary-color);
|
||||||
|
}
|
||||||
|
.comfyui-body-right .side-bar-button.side-bar-button-selected[data-v-caa3ee9c],
|
||||||
|
.comfyui-body-right .side-bar-button.side-bar-button-selected[data-v-caa3ee9c]:hover {
|
||||||
|
border-right: 4px solid var(--p-button-text-primary-color);
|
||||||
|
}
|
||||||
|
|
||||||
|
:root {
|
||||||
|
--sidebar-width: 64px;
|
||||||
|
--sidebar-icon-size: 1.5rem;
|
||||||
|
}
|
||||||
|
:root .small-sidebar {
|
||||||
|
--sidebar-width: 40px;
|
||||||
|
--sidebar-icon-size: 1rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.side-tool-bar-container[data-v-37fd2fa4] {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
align-items: center;
|
||||||
|
|
||||||
|
pointer-events: auto;
|
||||||
|
|
||||||
|
width: var(--sidebar-width);
|
||||||
|
height: 100%;
|
||||||
|
|
||||||
|
background-color: var(--comfy-menu-bg);
|
||||||
|
color: var(--fg-color);
|
||||||
|
}
|
||||||
|
.side-tool-bar-end[data-v-37fd2fa4] {
|
||||||
|
align-self: flex-end;
|
||||||
|
margin-top: auto;
|
||||||
|
}
|
||||||
|
|
||||||
|
[data-v-b49f20b1] .p-splitter-gutter {
|
||||||
|
pointer-events: auto;
|
||||||
|
}
|
||||||
|
.side-bar-panel[data-v-b49f20b1] {
|
||||||
|
background-color: var(--bg-color);
|
||||||
|
pointer-events: auto;
|
||||||
|
}
|
||||||
|
.bottom-panel[data-v-b49f20b1] {
|
||||||
|
background-color: var(--bg-color);
|
||||||
|
pointer-events: auto;
|
||||||
|
}
|
||||||
|
.splitter-overlay[data-v-b49f20b1] {
|
||||||
|
pointer-events: none;
|
||||||
|
border-style: none;
|
||||||
|
background-color: transparent;
|
||||||
|
}
|
||||||
|
.splitter-overlay-root[data-v-b49f20b1] {
|
||||||
|
position: absolute;
|
||||||
|
top: 0px;
|
||||||
|
left: 0px;
|
||||||
|
height: 100%;
|
||||||
|
width: 100%;
|
||||||
|
|
||||||
|
/* Set it the same as the ComfyUI menu */
|
||||||
|
/* Note: Lite-graph DOM widgets have the same z-index as the node id, so
|
||||||
|
999 should be sufficient to make sure splitter overlays on node's DOM
|
||||||
|
widgets */
|
||||||
|
z-index: 999;
|
||||||
|
}
|
||||||
|
|
||||||
|
[data-v-37f672ab] .highlight {
|
||||||
|
background-color: var(--p-primary-color);
|
||||||
|
color: var(--p-primary-contrast-color);
|
||||||
|
font-weight: bold;
|
||||||
|
border-radius: 0.25rem;
|
||||||
|
padding: 0rem 0.125rem;
|
||||||
|
margin: -0.125rem 0.125rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.comfy-vue-node-search-container[data-v-2d409367] {
|
||||||
|
display: flex;
|
||||||
|
width: 100%;
|
||||||
|
min-width: 26rem;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: center;
|
||||||
|
}
|
||||||
|
.comfy-vue-node-search-container[data-v-2d409367] * {
|
||||||
|
pointer-events: auto;
|
||||||
|
}
|
||||||
|
.comfy-vue-node-preview-container[data-v-2d409367] {
|
||||||
|
position: absolute;
|
||||||
|
left: -350px;
|
||||||
|
top: 50px;
|
||||||
|
}
|
||||||
|
.comfy-vue-node-search-box[data-v-2d409367] {
|
||||||
|
z-index: 10;
|
||||||
|
flex-grow: 1;
|
||||||
|
}
|
||||||
|
._filter-button[data-v-2d409367] {
|
||||||
|
z-index: 10;
|
||||||
|
}
|
||||||
|
._dialog[data-v-2d409367] {
|
||||||
|
min-width: 26rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.invisible-dialog-root {
|
||||||
|
width: 60%;
|
||||||
|
min-width: 24rem;
|
||||||
|
max-width: 48rem;
|
||||||
|
border: 0 !important;
|
||||||
|
background-color: transparent !important;
|
||||||
|
margin-top: 25vh;
|
||||||
|
margin-left: 400px;
|
||||||
|
}
|
||||||
|
@media all and (max-width: 768px) {
|
||||||
|
.invisible-dialog-root {
|
||||||
|
margin-left: 0px;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.node-search-box-dialog-mask {
|
||||||
|
align-items: flex-start !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
.node-tooltip[data-v-79ec8c53] {
|
||||||
|
background: var(--comfy-input-bg);
|
||||||
|
border-radius: 5px;
|
||||||
|
box-shadow: 0 0 5px rgba(0, 0, 0, 0.4);
|
||||||
|
color: var(--input-text);
|
||||||
|
font-family: sans-serif;
|
||||||
|
left: 0;
|
||||||
|
max-width: 30vw;
|
||||||
|
padding: 4px 8px;
|
||||||
|
position: absolute;
|
||||||
|
top: 0;
|
||||||
|
transform: translate(5px, calc(-100% - 5px));
|
||||||
|
white-space: pre-wrap;
|
||||||
|
z-index: 99999;
|
||||||
|
}
|
||||||
|
|
||||||
|
.p-buttongroup-vertical[data-v-444d3768] {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
border-radius: var(--p-button-border-radius);
|
||||||
|
overflow: hidden;
|
||||||
|
border: 1px solid var(--p-panel-border-color);
|
||||||
|
}
|
||||||
|
.p-buttongroup-vertical .p-button[data-v-444d3768] {
|
||||||
|
margin: 0;
|
||||||
|
border-radius: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
[data-v-84e785b8] .p-togglebutton::before {
|
||||||
|
display: none
|
||||||
|
}
|
||||||
|
[data-v-84e785b8] .p-togglebutton {
|
||||||
|
position: relative;
|
||||||
|
flex-shrink: 0;
|
||||||
|
border-radius: 0px;
|
||||||
|
background-color: transparent;
|
||||||
|
padding-left: 0.5rem;
|
||||||
|
padding-right: 0.5rem
|
||||||
|
}
|
||||||
|
[data-v-84e785b8] .p-togglebutton.p-togglebutton-checked {
|
||||||
|
border-bottom-width: 2px;
|
||||||
|
border-bottom-color: var(--p-button-text-primary-color)
|
||||||
|
}
|
||||||
|
[data-v-84e785b8] .p-togglebutton-checked .close-button,[data-v-84e785b8] .p-togglebutton:hover .close-button {
|
||||||
|
visibility: visible
|
||||||
|
}
|
||||||
|
.status-indicator[data-v-84e785b8] {
|
||||||
|
position: absolute;
|
||||||
|
font-weight: 700;
|
||||||
|
font-size: 1.5rem;
|
||||||
|
top: 50%;
|
||||||
|
left: 50%;
|
||||||
|
transform: translate(-50%, -50%)
|
||||||
|
}
|
||||||
|
[data-v-84e785b8] .p-togglebutton:hover .status-indicator {
|
||||||
|
display: none
|
||||||
|
}
|
||||||
|
[data-v-84e785b8] .p-togglebutton .close-button {
|
||||||
|
visibility: hidden
|
||||||
|
}
|
||||||
|
|
||||||
|
.top-menubar[data-v-9646ca0a] .p-menubar-item-link svg {
|
||||||
|
display: none;
|
||||||
|
}
|
||||||
|
[data-v-9646ca0a] .p-menubar-submenu.dropdown-direction-up {
|
||||||
|
top: auto;
|
||||||
|
bottom: 100%;
|
||||||
|
flex-direction: column-reverse;
|
||||||
|
}
|
||||||
|
.keybinding-tag[data-v-9646ca0a] {
|
||||||
|
background: var(--p-content-hover-background);
|
||||||
|
border-color: var(--p-content-border-color);
|
||||||
|
border-style: solid;
|
||||||
|
}
|
||||||
|
|
||||||
|
[data-v-713442be] .p-inputtext {
|
||||||
|
border-top-left-radius: 0;
|
||||||
|
border-bottom-left-radius: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.comfyui-queue-button[data-v-2b80bf74] .p-splitbutton-dropdown {
|
||||||
|
border-top-right-radius: 0;
|
||||||
|
border-bottom-right-radius: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.actionbar[data-v-2e54db00] {
|
||||||
|
pointer-events: all;
|
||||||
|
position: fixed;
|
||||||
|
z-index: 1000;
|
||||||
|
}
|
||||||
|
.actionbar.is-docked[data-v-2e54db00] {
|
||||||
|
position: static;
|
||||||
|
border-style: none;
|
||||||
|
background-color: transparent;
|
||||||
|
padding: 0px;
|
||||||
|
}
|
||||||
|
.actionbar.is-dragging[data-v-2e54db00] {
|
||||||
|
-webkit-user-select: none;
|
||||||
|
-moz-user-select: none;
|
||||||
|
user-select: none;
|
||||||
|
}
|
||||||
|
[data-v-2e54db00] .p-panel-content {
|
||||||
|
padding: 0.25rem;
|
||||||
|
}
|
||||||
|
[data-v-2e54db00] .p-panel-header {
|
||||||
|
display: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
.comfyui-menu[data-v-ad2c662b] {
|
||||||
|
width: 100vw;
|
||||||
|
background: var(--comfy-menu-bg);
|
||||||
|
color: var(--fg-color);
|
||||||
|
font-family: Arial, Helvetica, sans-serif;
|
||||||
|
font-size: 0.8em;
|
||||||
|
box-sizing: border-box;
|
||||||
|
z-index: 1000;
|
||||||
|
order: 0;
|
||||||
|
grid-column: 1/-1;
|
||||||
|
max-height: 90vh;
|
||||||
|
}
|
||||||
|
.comfyui-menu.dropzone[data-v-ad2c662b] {
|
||||||
|
background: var(--p-highlight-background);
|
||||||
|
}
|
||||||
|
.comfyui-menu.dropzone-active[data-v-ad2c662b] {
|
||||||
|
background: var(--p-highlight-background-focus);
|
||||||
|
}
|
||||||
|
.comfyui-logo[data-v-ad2c662b] {
|
||||||
|
font-size: 1.2em;
|
||||||
|
-webkit-user-select: none;
|
||||||
|
-moz-user-select: none;
|
||||||
|
user-select: none;
|
||||||
|
cursor: default;
|
||||||
|
}
|
||||||
8
web/assets/KeybindingPanel-BNYKhW1k.css
generated
vendored
Normal file
8
web/assets/KeybindingPanel-BNYKhW1k.css
generated
vendored
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
|
||||||
|
[data-v-e5724e4d] .p-datatable-tbody > tr > td {
|
||||||
|
padding: 1px;
|
||||||
|
min-height: 2rem;
|
||||||
|
}
|
||||||
|
[data-v-e5724e4d] .p-datatable-row-selected .actions,[data-v-e5724e4d] .p-datatable-selectable-row:hover .actions {
|
||||||
|
visibility: visible;
|
||||||
|
}
|
||||||
264
web/assets/KeybindingPanel-Dm_3sBT5.js
generated
vendored
Normal file
264
web/assets/KeybindingPanel-Dm_3sBT5.js
generated
vendored
Normal file
@@ -0,0 +1,264 @@
|
|||||||
|
var __defProp = Object.defineProperty;
|
||||||
|
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
|
||||||
|
import { d as defineComponent, q as computed, g as openBlock, h as createElementBlock, N as Fragment, O as renderList, i as createVNode, y as withCtx, aw as createTextVNode, a6 as toDisplayString, z as unref, aA as script, j as createCommentVNode, r as ref, bN as FilterMatchMode, M as useKeybindingStore, F as useCommandStore, aJ as watchEffect, b9 as useToast, t as resolveDirective, bO as SearchBox, A as createBaseVNode, D as script$2, x as createBlock, ao as script$4, be as withModifiers, aH as script$6, v as withDirectives, P as pushScopeId, Q as popScopeId, bJ as KeyComboImpl, bP as KeybindingImpl, _ as _export_sfc } from "./index-BHayQCxv.js";
|
||||||
|
import { s as script$1, a as script$3, b as script$5 } from "./index-CwRXxFdA.js";
|
||||||
|
import "./index-C_wOqB0f.js";
|
||||||
|
const _hoisted_1$1 = {
|
||||||
|
key: 0,
|
||||||
|
class: "px-2"
|
||||||
|
};
|
||||||
|
const _sfc_main$1 = /* @__PURE__ */ defineComponent({
|
||||||
|
__name: "KeyComboDisplay",
|
||||||
|
props: {
|
||||||
|
keyCombo: {},
|
||||||
|
isModified: { type: Boolean, default: false }
|
||||||
|
},
|
||||||
|
setup(__props) {
|
||||||
|
const props = __props;
|
||||||
|
const keySequences = computed(() => props.keyCombo.getKeySequences());
|
||||||
|
return (_ctx, _cache) => {
|
||||||
|
return openBlock(), createElementBlock("span", null, [
|
||||||
|
(openBlock(true), createElementBlock(Fragment, null, renderList(keySequences.value, (sequence, index) => {
|
||||||
|
return openBlock(), createElementBlock(Fragment, { key: index }, [
|
||||||
|
createVNode(unref(script), {
|
||||||
|
severity: _ctx.isModified ? "info" : "secondary"
|
||||||
|
}, {
|
||||||
|
default: withCtx(() => [
|
||||||
|
createTextVNode(toDisplayString(sequence), 1)
|
||||||
|
]),
|
||||||
|
_: 2
|
||||||
|
}, 1032, ["severity"]),
|
||||||
|
index < keySequences.value.length - 1 ? (openBlock(), createElementBlock("span", _hoisted_1$1, "+")) : createCommentVNode("", true)
|
||||||
|
], 64);
|
||||||
|
}), 128))
|
||||||
|
]);
|
||||||
|
};
|
||||||
|
}
|
||||||
|
});
|
||||||
|
const _withScopeId = /* @__PURE__ */ __name((n) => (pushScopeId("data-v-e5724e4d"), n = n(), popScopeId(), n), "_withScopeId");
|
||||||
|
const _hoisted_1 = { class: "keybinding-panel" };
|
||||||
|
const _hoisted_2 = { class: "actions invisible" };
|
||||||
|
const _hoisted_3 = { key: 1 };
|
||||||
|
const _sfc_main = /* @__PURE__ */ defineComponent({
|
||||||
|
__name: "KeybindingPanel",
|
||||||
|
setup(__props) {
|
||||||
|
const filters = ref({
|
||||||
|
global: { value: "", matchMode: FilterMatchMode.CONTAINS }
|
||||||
|
});
|
||||||
|
const keybindingStore = useKeybindingStore();
|
||||||
|
const commandStore = useCommandStore();
|
||||||
|
const commandsData = computed(() => {
|
||||||
|
return Object.values(commandStore.commands).map((command) => ({
|
||||||
|
id: command.id,
|
||||||
|
keybinding: keybindingStore.getKeybindingByCommandId(command.id)
|
||||||
|
}));
|
||||||
|
});
|
||||||
|
const selectedCommandData = ref(null);
|
||||||
|
const editDialogVisible = ref(false);
|
||||||
|
const newBindingKeyCombo = ref(null);
|
||||||
|
const currentEditingCommand = ref(null);
|
||||||
|
const keybindingInput = ref(null);
|
||||||
|
const existingKeybindingOnCombo = computed(() => {
|
||||||
|
if (!currentEditingCommand.value) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
if (currentEditingCommand.value.keybinding?.combo?.equals(
|
||||||
|
newBindingKeyCombo.value
|
||||||
|
)) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
if (!newBindingKeyCombo.value) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return keybindingStore.getKeybinding(newBindingKeyCombo.value);
|
||||||
|
});
|
||||||
|
function editKeybinding(commandData) {
|
||||||
|
currentEditingCommand.value = commandData;
|
||||||
|
newBindingKeyCombo.value = commandData.keybinding ? commandData.keybinding.combo : null;
|
||||||
|
editDialogVisible.value = true;
|
||||||
|
}
|
||||||
|
__name(editKeybinding, "editKeybinding");
|
||||||
|
watchEffect(() => {
|
||||||
|
if (editDialogVisible.value) {
|
||||||
|
setTimeout(() => {
|
||||||
|
keybindingInput.value?.$el?.focus();
|
||||||
|
}, 300);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
function removeKeybinding(commandData) {
|
||||||
|
if (commandData.keybinding) {
|
||||||
|
keybindingStore.unsetKeybinding(commandData.keybinding);
|
||||||
|
keybindingStore.persistUserKeybindings();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
__name(removeKeybinding, "removeKeybinding");
|
||||||
|
function captureKeybinding(event) {
|
||||||
|
const keyCombo = KeyComboImpl.fromEvent(event);
|
||||||
|
newBindingKeyCombo.value = keyCombo;
|
||||||
|
}
|
||||||
|
__name(captureKeybinding, "captureKeybinding");
|
||||||
|
function cancelEdit() {
|
||||||
|
editDialogVisible.value = false;
|
||||||
|
currentEditingCommand.value = null;
|
||||||
|
newBindingKeyCombo.value = null;
|
||||||
|
}
|
||||||
|
__name(cancelEdit, "cancelEdit");
|
||||||
|
function saveKeybinding() {
|
||||||
|
if (currentEditingCommand.value && newBindingKeyCombo.value) {
|
||||||
|
const updated = keybindingStore.updateKeybindingOnCommand(
|
||||||
|
new KeybindingImpl({
|
||||||
|
commandId: currentEditingCommand.value.id,
|
||||||
|
combo: newBindingKeyCombo.value
|
||||||
|
})
|
||||||
|
);
|
||||||
|
if (updated) {
|
||||||
|
keybindingStore.persistUserKeybindings();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cancelEdit();
|
||||||
|
}
|
||||||
|
__name(saveKeybinding, "saveKeybinding");
|
||||||
|
const toast = useToast();
|
||||||
|
async function resetKeybindings() {
|
||||||
|
keybindingStore.resetKeybindings();
|
||||||
|
await keybindingStore.persistUserKeybindings();
|
||||||
|
toast.add({
|
||||||
|
severity: "info",
|
||||||
|
summary: "Info",
|
||||||
|
detail: "Keybindings reset",
|
||||||
|
life: 3e3
|
||||||
|
});
|
||||||
|
}
|
||||||
|
__name(resetKeybindings, "resetKeybindings");
|
||||||
|
return (_ctx, _cache) => {
|
||||||
|
const _directive_tooltip = resolveDirective("tooltip");
|
||||||
|
return openBlock(), createElementBlock("div", _hoisted_1, [
|
||||||
|
createVNode(unref(script$3), {
|
||||||
|
value: commandsData.value,
|
||||||
|
selection: selectedCommandData.value,
|
||||||
|
"onUpdate:selection": _cache[1] || (_cache[1] = ($event) => selectedCommandData.value = $event),
|
||||||
|
"global-filter-fields": ["id"],
|
||||||
|
filters: filters.value,
|
||||||
|
selectionMode: "single",
|
||||||
|
stripedRows: "",
|
||||||
|
pt: {
|
||||||
|
header: "px-0"
|
||||||
|
}
|
||||||
|
}, {
|
||||||
|
header: withCtx(() => [
|
||||||
|
createVNode(SearchBox, {
|
||||||
|
modelValue: filters.value["global"].value,
|
||||||
|
"onUpdate:modelValue": _cache[0] || (_cache[0] = ($event) => filters.value["global"].value = $event),
|
||||||
|
placeholder: _ctx.$t("searchKeybindings") + "..."
|
||||||
|
}, null, 8, ["modelValue", "placeholder"])
|
||||||
|
]),
|
||||||
|
default: withCtx(() => [
|
||||||
|
createVNode(unref(script$1), {
|
||||||
|
field: "actions",
|
||||||
|
header: ""
|
||||||
|
}, {
|
||||||
|
body: withCtx((slotProps) => [
|
||||||
|
createBaseVNode("div", _hoisted_2, [
|
||||||
|
createVNode(unref(script$2), {
|
||||||
|
icon: "pi pi-pencil",
|
||||||
|
class: "p-button-text",
|
||||||
|
onClick: /* @__PURE__ */ __name(($event) => editKeybinding(slotProps.data), "onClick")
|
||||||
|
}, null, 8, ["onClick"]),
|
||||||
|
createVNode(unref(script$2), {
|
||||||
|
icon: "pi pi-trash",
|
||||||
|
class: "p-button-text p-button-danger",
|
||||||
|
onClick: /* @__PURE__ */ __name(($event) => removeKeybinding(slotProps.data), "onClick"),
|
||||||
|
disabled: !slotProps.data.keybinding
|
||||||
|
}, null, 8, ["onClick", "disabled"])
|
||||||
|
])
|
||||||
|
]),
|
||||||
|
_: 1
|
||||||
|
}),
|
||||||
|
createVNode(unref(script$1), {
|
||||||
|
field: "id",
|
||||||
|
header: "Command ID",
|
||||||
|
sortable: ""
|
||||||
|
}),
|
||||||
|
createVNode(unref(script$1), {
|
||||||
|
field: "keybinding",
|
||||||
|
header: "Keybinding"
|
||||||
|
}, {
|
||||||
|
body: withCtx((slotProps) => [
|
||||||
|
slotProps.data.keybinding ? (openBlock(), createBlock(_sfc_main$1, {
|
||||||
|
key: 0,
|
||||||
|
keyCombo: slotProps.data.keybinding.combo,
|
||||||
|
isModified: unref(keybindingStore).isCommandKeybindingModified(slotProps.data.id)
|
||||||
|
}, null, 8, ["keyCombo", "isModified"])) : (openBlock(), createElementBlock("span", _hoisted_3, "-"))
|
||||||
|
]),
|
||||||
|
_: 1
|
||||||
|
})
|
||||||
|
]),
|
||||||
|
_: 1
|
||||||
|
}, 8, ["value", "selection", "filters"]),
|
||||||
|
createVNode(unref(script$6), {
|
||||||
|
class: "min-w-96",
|
||||||
|
visible: editDialogVisible.value,
|
||||||
|
"onUpdate:visible": _cache[2] || (_cache[2] = ($event) => editDialogVisible.value = $event),
|
||||||
|
modal: "",
|
||||||
|
header: currentEditingCommand.value?.id,
|
||||||
|
onHide: cancelEdit
|
||||||
|
}, {
|
||||||
|
footer: withCtx(() => [
|
||||||
|
createVNode(unref(script$2), {
|
||||||
|
label: "Save",
|
||||||
|
icon: "pi pi-check",
|
||||||
|
onClick: saveKeybinding,
|
||||||
|
disabled: !!existingKeybindingOnCombo.value,
|
||||||
|
autofocus: ""
|
||||||
|
}, null, 8, ["disabled"])
|
||||||
|
]),
|
||||||
|
default: withCtx(() => [
|
||||||
|
createBaseVNode("div", null, [
|
||||||
|
createVNode(unref(script$4), {
|
||||||
|
class: "mb-2 text-center",
|
||||||
|
ref_key: "keybindingInput",
|
||||||
|
ref: keybindingInput,
|
||||||
|
modelValue: newBindingKeyCombo.value?.toString() ?? "",
|
||||||
|
placeholder: "Press keys for new binding",
|
||||||
|
onKeydown: withModifiers(captureKeybinding, ["stop", "prevent"]),
|
||||||
|
autocomplete: "off",
|
||||||
|
fluid: "",
|
||||||
|
invalid: !!existingKeybindingOnCombo.value
|
||||||
|
}, null, 8, ["modelValue", "invalid"]),
|
||||||
|
existingKeybindingOnCombo.value ? (openBlock(), createBlock(unref(script$5), {
|
||||||
|
key: 0,
|
||||||
|
severity: "error"
|
||||||
|
}, {
|
||||||
|
default: withCtx(() => [
|
||||||
|
createTextVNode(" Keybinding already exists on "),
|
||||||
|
createVNode(unref(script), {
|
||||||
|
severity: "secondary",
|
||||||
|
value: existingKeybindingOnCombo.value.commandId
|
||||||
|
}, null, 8, ["value"])
|
||||||
|
]),
|
||||||
|
_: 1
|
||||||
|
})) : createCommentVNode("", true)
|
||||||
|
])
|
||||||
|
]),
|
||||||
|
_: 1
|
||||||
|
}, 8, ["visible", "header"]),
|
||||||
|
withDirectives(createVNode(unref(script$2), {
|
||||||
|
class: "mt-4",
|
||||||
|
label: _ctx.$t("reset"),
|
||||||
|
icon: "pi pi-trash",
|
||||||
|
severity: "danger",
|
||||||
|
fluid: "",
|
||||||
|
text: "",
|
||||||
|
onClick: resetKeybindings
|
||||||
|
}, null, 8, ["label"]), [
|
||||||
|
[_directive_tooltip, _ctx.$t("resetKeybindingsTooltip")]
|
||||||
|
])
|
||||||
|
]);
|
||||||
|
};
|
||||||
|
}
|
||||||
|
});
|
||||||
|
const KeybindingPanel = /* @__PURE__ */ _export_sfc(_sfc_main, [["__scopeId", "data-v-e5724e4d"]]);
|
||||||
|
export {
|
||||||
|
KeybindingPanel as default
|
||||||
|
};
|
||||||
|
//# sourceMappingURL=KeybindingPanel-Dm_3sBT5.js.map
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user