Compare commits

52 commits (SHA1):

483004dd1d, 00a5d08103, d043997d30, f1c2301697, 8d31a6632f, b643eae08b, baa6b4dc36, d4aeefc297, 587e7ca654, c90459eba0, 04278afb10, 935ae153e1, e91662e784, 63fafaef45, ec28cd9136, 6eb5d64522, 10a79e9898, ea3f39bd69, b33cd61070, 34eda0f853, d31e226650, b79fd7d92c, 38c22e631a, 6bbdcd28ae, ab130001a8, ca4b8f30e0, 70b84058c1, 2ca8f6e23d, 7985ff88b9, c6812947e9, 9230f65823, 6ab1e6fd4a, 07dcbc3a3e, 8ae23d8e80, 7df42b9a23, 5d8bbb7281, 2c1d2375d6, 64ccb3c7e3, 9465b23432, bb4416dd5b, c0b0da264b, c26ca27207, 7c6bb84016, c54d3ed5e6, c7ee4b37a1, 7b70b266d8, 8f60d093ba, dafbe321d2, 5f84ea63e8, 843a7ff70c, a60620dcea, 015f73dc49
```diff
@@ -14,7 +14,7 @@ run_cpu.bat
 IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints
 
-You can download the stable diffusion 1.5 one from: https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt
+You can download the stable diffusion 1.5 one from: https://huggingface.co/Comfy-Org/stable-diffusion-v1-5-archive/blob/main/v1-5-pruned-emaonly-fp16.safetensors
 
 RECOMMENDED WAY TO UPDATE:
```
.ci/windows_nightly_base_files/run_nvidia_gpu_fast.bat (new file, +2)

```diff
@@ -0,0 +1,2 @@
+.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast
+pause
```
.github/workflows/stale-issues.yml (new file, vendored, +21)

```diff
@@ -0,0 +1,21 @@
+name: 'Close stale issues'
+on:
+  schedule:
+    # Run daily at 430 am PT
+    - cron: '30 11 * * *'
+permissions:
+  issues: write
+
+jobs:
+  stale:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/stale@v9
+        with:
+          stale-issue-message: "This issue is being marked stale because it has not had any activity for 30 days. Reply below within 7 days if your issue still isn't solved, and it will be left open. Otherwise, the issue will be closed automatically."
+          days-before-stale: 30
+          days-before-close: 7
+          stale-issue-label: 'Stale'
+          only-labels: 'User Support'
+          exempt-all-assignees: true
+          exempt-all-milestones: true
```
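The cron expression above fires at 11:30 UTC, which is 4:30 am PT while daylight saving time is in effect. A quick local sanity check of a cron schedule can look like the following sketch; it assumes the third-party `croniter` package, which is not part of this change:

```python
# Sketch: verify when the stale-issues workflow's cron expression fires.
# Assumes the third-party "croniter" package (pip install croniter).
from datetime import datetime, timezone
from croniter import croniter

schedule = "30 11 * * *"  # value used in .github/workflows/stale-issues.yml
base = datetime(2024, 8, 1, tzinfo=timezone.utc)

it = croniter(schedule, base)
for _ in range(3):
    # Each run lands at 11:30 UTC, i.e. 4:30 am PT while PDT (UTC-7) is in effect.
    print(it.get_next(datetime).isoformat())
```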
```diff
@@ -67,6 +67,7 @@ jobs:
 mkdir update
 cp -r ComfyUI/.ci/update_windows/* ./update/
 cp -r ComfyUI/.ci/windows_base_files/* ./
+cp -r ComfyUI/.ci/windows_nightly_base_files/* ./
 
 echo "call update_comfyui.bat nopause
 ..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2
```
README.md (12 lines changed)

````diff
@@ -1,7 +1,7 @@
 <div align="center">
 
 # ComfyUI
-**The most powerful and modular stable diffusion GUI and backend.**
+**The most powerful and modular diffusion model GUI and backend.**
 
 
 [![Website][website-shield]][website-url]
@@ -135,17 +135,17 @@ Put your VAE in: models/vae
 ### AMD GPUs (Linux only)
 AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
 
-```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.0```
+```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.1```
 
-This is the command to install the nightly with ROCm 6.0 which might have some performance improvements:
+This is the command to install the nightly with ROCm 6.2 which might have some performance improvements:
 
-```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.1```
+```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.2```
 
 ### NVIDIA
 
 Nvidia users should install stable pytorch using this command:
 
-```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu121```
+```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu124```
 
 This is the command to install pytorch nightly instead which might have performance improvements:
 
@@ -230,7 +230,7 @@ To use a textual inversion concepts/embeddings in a text prompt put them in the
 
 Use ```--preview-method auto``` to enable previews.
 
-The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth) (for SD1.x and SD2.x) and [taesdxl_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesdxl_decoder.pth) (for SDXL) models and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI to enable high-quality previews.
+The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_decoder.pth, taesdxl_decoder.pth, taesd3_decoder.pth and taef1_decoder.pth](https://github.com/madebyollin/taesd/) and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI and launch it with `--preview-method taesd` to enable high-quality previews.
 
 ## How to use TLS/SSL?
 Generate a self-signed certificate (not appropriate for shared/production use) and key by running the command: `openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem -sha256 -days 3650 -nodes -subj "/C=XX/ST=StateName/L=CityName/O=CompanyName/OU=CompanySectionName/CN=CommonNameOrHostname"`
````
```diff
@@ -2,6 +2,7 @@ from aiohttp import web
 from typing import Optional
 from folder_paths import models_dir, user_directory, output_directory
 from api_server.services.file_service import FileService
+import app.logger
 
 class InternalRoutes:
     '''
@@ -31,6 +32,9 @@ class InternalRoutes:
             except Exception as e:
                 return web.json_response({"error": str(e)}, status=500)
 
+        @self.routes.get('/logs')
+        async def get_logs(request):
+            return web.json_response(app.logger.get_logs())
 
     def get_app(self):
         if self._app is None:
```
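The new route returns the buffered log text as JSON. A minimal client sketch, assuming the server runs on the default `127.0.0.1:8188` and that `InternalRoutes` is mounted under an `/internal` prefix (neither is shown in this diff):

```python
# Sketch: fetch the buffered server logs exposed by the new /logs route.
# Host, port, and the "/internal" prefix are assumptions, not shown in this diff.
import json
import urllib.request

url = "http://127.0.0.1:8188/internal/logs"
with urllib.request.urlopen(url) as resp:
    logs_text = json.loads(resp.read())  # get_logs() returns one newline-joined string

print(logs_text)
```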
```diff
@@ -8,7 +8,7 @@ import zipfile
 from dataclasses import dataclass
 from functools import cached_property
 from pathlib import Path
-from typing import TypedDict
+from typing import TypedDict, Optional
 
 import requests
 from typing_extensions import NotRequired
@@ -132,12 +132,13 @@ class FrontendManager:
         return match_result.group(1), match_result.group(2), match_result.group(3)
 
     @classmethod
-    def init_frontend_unsafe(cls, version_string: str) -> str:
+    def init_frontend_unsafe(cls, version_string: str, provider: Optional[FrontEndProvider] = None) -> str:
         """
         Initializes the frontend for the specified version.
 
         Args:
             version_string (str): The version string.
+            provider (FrontEndProvider, optional): The provider to use. Defaults to None.
 
         Returns:
             str: The path to the initialized frontend.
@@ -150,7 +151,7 @@ class FrontendManager:
             return cls.DEFAULT_FRONTEND_PATH
 
         repo_owner, repo_name, version = cls.parse_version_string(version_string)
-        provider = FrontEndProvider(repo_owner, repo_name)
+        provider = provider or FrontEndProvider(repo_owner, repo_name)
         release = provider.get_release(version)
 
         semantic_version = release["tag_name"].lstrip("v")
@@ -158,6 +159,7 @@ class FrontendManager:
             Path(cls.CUSTOM_FRONTENDS_ROOT) / provider.folder_name / semantic_version
         )
         if not os.path.exists(web_root):
+            try:
                 os.makedirs(web_root, exist_ok=True)
                 logging.info(
                     "Downloading frontend(%s) version(%s) to (%s)",
@@ -167,6 +169,11 @@ class FrontendManager:
                 )
                 logging.debug(release)
                 download_release_asset_zip(release, destination_path=web_root)
+            finally:
+                # Clean up the directory if it is empty, i.e. the download failed
+                if not os.listdir(web_root):
+                    os.rmdir(web_root)
 
         return web_root
 
     @classmethod
```
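With the new `provider` parameter, a caller can hand `init_frontend_unsafe` a pre-built `FrontEndProvider` instead of letting it construct one from the version string. A minimal sketch, assuming the `owner/repo@version` string format accepted by `parse_version_string` and a hypothetical frontend repository:

```python
# Sketch: build a FrontEndProvider once and pass it in explicitly.
# The "owner/repo@version" format and the repository name here are assumptions.
from app.frontend_management import FrontendManager, FrontEndProvider

provider = FrontEndProvider("Comfy-Org", "ComfyUI_frontend")  # hypothetical repo
web_root = FrontendManager.init_frontend_unsafe(
    "Comfy-Org/ComfyUI_frontend@v1.0.0", provider=provider
)
print(web_root)  # path under CUSTOM_FRONTENDS_ROOT/<folder_name>/<semantic version>
```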
app/logger.py (new file, +31)

```diff
@@ -0,0 +1,31 @@
+import logging
+from logging.handlers import MemoryHandler
+from collections import deque
+
+logs = None
+formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+
+
+def get_logs():
+    return "\n".join([formatter.format(x) for x in logs])
+
+
+def setup_logger(verbose: bool = False, capacity: int = 300):
+    global logs
+    if logs:
+        return
+
+    # Setup default global logger
+    logger = logging.getLogger()
+    logger.setLevel(logging.DEBUG if verbose else logging.INFO)
+
+    stream_handler = logging.StreamHandler()
+    stream_handler.setFormatter(logging.Formatter("%(message)s"))
+    logger.addHandler(stream_handler)
+
+    # Create a memory handler with a deque as its buffer
+    logs = deque(maxlen=capacity)
+    memory_handler = MemoryHandler(capacity, flushLevel=logging.INFO)
+    memory_handler.buffer = logs
+    memory_handler.setFormatter(formatter)
+    logger.addHandler(memory_handler)
```
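A minimal usage sketch of the new module: `setup_logger` installs a stream handler plus a ring-buffer `MemoryHandler`, and `get_logs` re-renders whatever records are still in that buffer.

```python
# Sketch: exercise app/logger.py from inside the ComfyUI repository.
import logging

import app.logger

app.logger.setup_logger(verbose=False, capacity=300)

logging.info("model loaded")
logging.warning("low VRAM, enabling offload")

# Returns the most recent (up to 300) records, formatted with timestamp/name/level.
print(app.logger.get_logs())
```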
```diff
@@ -179,10 +179,3 @@ if args.windows_standalone_build:
 
 if args.disable_auto_launch:
     args.auto_launch = False
 
-import logging
-logging_level = logging.INFO
-if args.verbose:
-    logging_level = logging.DEBUG
-
-logging.basicConfig(format="%(message)s", level=logging_level)
```
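The ad-hoc `logging.basicConfig` block removed here overlaps with the new `app/logger.py` module added in this comparison, so the caller presumably now does something along these lines; the actual call site is not part of the hunks shown:

```python
# Sketch only: how the removed basicConfig block could be superseded by app/logger.py.
# The real call site is not shown in this diff; "verbose" stands in for args.verbose.
import app.logger

verbose = False  # would come from the parsed CLI arguments
app.logger.setup_logger(verbose=verbose)
```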
```diff
@@ -34,7 +34,7 @@ import comfy.t2i_adapter.adapter
 import comfy.ldm.cascade.controlnet
 import comfy.cldm.mmdit
 import comfy.ldm.hydit.controlnet
-import comfy.ldm.flux.controlnet_xlabs
+import comfy.ldm.flux.controlnet
 
 
 def broadcast_image_to(tensor, target_batch_size, batched_number):
@@ -148,7 +148,7 @@ class ControlBase:
             elif self.strength_type == StrengthType.LINEAR_UP:
                 x *= (self.strength ** float(len(control_output) - i))
 
-            if x.dtype != output_dtype:
+            if output_dtype is not None and x.dtype != output_dtype:
                 x = x.to(output_dtype)
 
             out[key].append(x)
@@ -206,7 +206,6 @@ class ControlNet(ControlBase):
         if self.manual_cast_dtype is not None:
             dtype = self.manual_cast_dtype
 
-        output_dtype = x_noisy.dtype
         if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
             if self.cond_hint is not None:
                 del self.cond_hint
@@ -236,7 +235,7 @@ class ControlNet(ControlBase):
         x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
 
         control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra)
-        return self.control_merge(control, control_prev, output_dtype)
+        return self.control_merge(control, control_prev, output_dtype=None)
 
     def copy(self):
         c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
@@ -391,7 +390,8 @@ def controlnet_config(sd):
     else:
         operations = comfy.ops.disable_weight_init
 
-    return model_config, operations, load_device, unet_dtype, manual_cast_dtype
+    offload_device = comfy.model_management.unet_offload_device()
+    return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device
 
 def controlnet_load_state_dict(control_model, sd):
     missing, unexpected = control_model.load_state_dict(sd, strict=False)
@@ -405,12 +405,12 @@ def controlnet_load_state_dict(control_model, sd):
 
 def load_controlnet_mmdit(sd):
     new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
-    model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(new_sd)
+    model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd)
     num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
     for k in sd:
         new_sd[k] = sd[k]
 
-    control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config)
+    control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
     control_model = controlnet_load_state_dict(control_model, new_sd)
 
     latent_format = comfy.latent_formats.SD3()
@@ -420,9 +420,9 @@ def load_controlnet_mmdit(sd):
 
 
 def load_controlnet_hunyuandit(controlnet_data):
-    model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(controlnet_data)
+    model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data)
 
-    control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=load_device, dtype=unet_dtype)
+    control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=offload_device, dtype=unet_dtype)
     control_model = controlnet_load_state_dict(control_model, controlnet_data)
 
     latent_format = comfy.latent_formats.SDXL()
@@ -431,13 +431,31 @@ def load_controlnet_hunyuandit(controlnet_data):
     return control
 
 def load_controlnet_flux_xlabs(sd):
-    model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(sd)
-    control_model = comfy.ldm.flux.controlnet_xlabs.ControlNetFlux(operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config)
+    model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd)
+    control_model = comfy.ldm.flux.controlnet.ControlNetFlux(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
     control_model = controlnet_load_state_dict(control_model, sd)
     extra_conds = ['y', 'guidance']
     control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
     return control
 
+def load_controlnet_flux_instantx(sd):
+    new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
+    model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd)
+    for k in sd:
+        new_sd[k] = sd[k]
+
+    num_union_modes = 0
+    union_cnet = "controlnet_mode_embedder.weight"
+    if union_cnet in new_sd:
+        num_union_modes = new_sd[union_cnet].shape[0]
+
+    control_model = comfy.ldm.flux.controlnet.ControlNetFlux(latent_input=True, num_union_modes=num_union_modes, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
+    control_model = controlnet_load_state_dict(control_model, new_sd)
+
+    latent_format = comfy.latent_formats.Flux()
+    extra_conds = ['y', 'guidance']
+    control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
+    return control
+
 def load_controlnet(ckpt_path, model=None):
     controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
@@ -503,8 +521,10 @@ def load_controlnet(ckpt_path, model=None):
     elif "controlnet_blocks.0.weight" in controlnet_data: #SD3 diffusers format
         if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
             return load_controlnet_flux_xlabs(controlnet_data)
-        else:
+        elif "pos_embed_input.proj.weight" in controlnet_data:
             return load_controlnet_mmdit(controlnet_data)
+        elif "controlnet_x_embedder.weight" in controlnet_data:
+            return load_controlnet_flux_instantx(controlnet_data)
 
     pth_key = 'control_model.zero_convs.0.0.weight'
     pth = False
@@ -536,6 +556,7 @@ def load_controlnet(ckpt_path, model=None):
         if manual_cast_dtype is not None:
             controlnet_config["operations"] = comfy.ops.manual_cast
         controlnet_config["dtype"] = unet_dtype
+        controlnet_config["device"] = comfy.model_management.unet_offload_device()
         controlnet_config.pop("out_channels")
         controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
         control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
```
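The loader now distinguishes three diffusers-format Flux/SD3 ControlNet layouts purely by which keys are present in the state dict. A condensed sketch of that dispatch, mirroring the `elif` chain added above:

```python
# Sketch: key-based dispatch for diffusers-format controlnets, mirroring load_controlnet().
def pick_flux_sd3_loader(controlnet_data: dict) -> str:
    if "controlnet_blocks.0.weight" in controlnet_data:              # SD3/Flux diffusers format
        if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
            return "load_controlnet_flux_xlabs"                      # XLabs Flux controlnet
        elif "pos_embed_input.proj.weight" in controlnet_data:
            return "load_controlnet_mmdit"                           # SD3 (MMDiT) controlnet
        elif "controlnet_x_embedder.weight" in controlnet_data:
            return "load_controlnet_flux_instantx"                   # InstantX / union Flux controlnet
    return "other formats handled later in load_controlnet()"
```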
```diff
@@ -1,7 +1,18 @@
 import torch
+import math
+
+def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None):
+    mantissa_scaled = torch.where(
+        normal_mask,
+        (abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0) * (2**MANTISSA_BITS),
+        (abs_x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS)))
+    )
+
+    mantissa_scaled += torch.rand(mantissa_scaled.size(), dtype=mantissa_scaled.dtype, layout=mantissa_scaled.layout, device=mantissa_scaled.device, generator=generator)
+    return mantissa_scaled.floor() / (2**MANTISSA_BITS)
 
 #Not 100% sure about this
-def manual_stochastic_round_to_float8(x, dtype):
+def manual_stochastic_round_to_float8(x, dtype, generator=None):
     if dtype == torch.float8_e4m3fn:
         EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 4, 3, 7
     elif dtype == torch.float8_e5m2:
@@ -9,44 +20,33 @@ def manual_stochastic_round_to_float8(x, dtype):
     else:
         raise ValueError("Unsupported dtype")
 
+    x = x.half()
     sign = torch.sign(x)
     abs_x = x.abs()
+    sign = torch.where(abs_x == 0, 0, sign)
 
     # Combine exponent calculation and clamping
     exponent = torch.clamp(
-        torch.floor(torch.log2(abs_x)).to(torch.int32) + EXPONENT_BIAS,
+        torch.floor(torch.log2(abs_x)) + EXPONENT_BIAS,
         0, 2**EXPONENT_BITS - 1
     )
 
     # Combine mantissa calculation and rounding
-    # min_normal = 2.0 ** (-EXPONENT_BIAS + 1)
-    # zero_mask = (abs_x == 0)
-    # subnormal_mask = (exponent == 0) & (abs_x != 0)
     normal_mask = ~(exponent == 0)
 
-    mantissa_scaled = torch.where(
-        normal_mask,
-        (abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0) * (2**MANTISSA_BITS),
-        (abs_x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS)))
-    )
-    mantissa_floor = mantissa_scaled.floor()
-    mantissa = torch.where(
-        torch.rand_like(mantissa_scaled) < (mantissa_scaled - mantissa_floor),
-        (mantissa_floor + 1) / (2**MANTISSA_BITS),
-        mantissa_floor / (2**MANTISSA_BITS)
-    )
-    result = torch.where(
-        normal_mask,
-        sign * (2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + mantissa),
-        sign * (2.0 ** (-EXPONENT_BIAS + 1)) * mantissa
-    )
+    abs_x[:] = calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=generator)
 
-    result = torch.where(abs_x == 0, 0, result)
-    return result.to(dtype=dtype)
+    sign *= torch.where(
+        normal_mask,
+        (2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + abs_x),
+        (2.0 ** (-EXPONENT_BIAS + 1)) * abs_x
+    )
+
+    return sign
 
 
 
-def stochastic_rounding(value, dtype):
+def stochastic_rounding(value, dtype, seed=0):
     if dtype == torch.float32:
         return value.to(dtype=torch.float32)
     if dtype == torch.float16:
@@ -54,6 +54,13 @@ def stochastic_rounding(value, dtype):
     if dtype == torch.bfloat16:
         return value.to(dtype=torch.bfloat16)
     if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
-        return manual_stochastic_round_to_float8(value, dtype)
+        generator = torch.Generator(device=value.device)
+        generator.manual_seed(seed)
+        output = torch.empty_like(value, dtype=dtype)
+        num_slices = max(1, (value.numel() / (4096 * 4096)))
+        slice_size = max(1, round(value.shape[0] / num_slices))
+        for i in range(0, value.shape[0], slice_size):
+            output[i:i+slice_size].copy_(manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator))
+        return output
 
     return value.to(dtype=dtype)
```
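The "stochastic" part means each value is rounded to one of its two representable float8 neighbours at random, weighted so that the expected value of the quantized tensor matches the input. A small hedged sketch of the public entry point; it assumes a PyTorch build that exposes `torch.float8_e4m3fn`, and the module path is inferred rather than shown in this diff:

```python
# Sketch: stochastic rounding keeps the quantized mean close to the original mean.
# Assumes a PyTorch build with torch.float8_e4m3fn; module path assumed from this diff.
import torch
from comfy.float import stochastic_rounding

x = torch.full((1000, 64), 0.3)  # 0.3 is not exactly representable in float8
q = stochastic_rounding(x, torch.float8_e4m3fn, seed=0)

# Individual elements land on neighbouring float8 values, but the average stays near 0.3.
print(q.float().unique())
print(q.float().mean())
```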
```diff
@@ -1,4 +1,5 @@
 import torch
+import comfy.ops
 
 def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
     if padding_mode == "circular" and torch.jit.is_tracing() or torch.jit.is_scripting():
@@ -6,3 +7,15 @@ def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
     pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0]
     pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1]
     return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode)
+
+try:
+    rms_norm_torch = torch.nn.functional.rms_norm
+except:
+    rms_norm_torch = None
+
+def rms_norm(x, weight, eps=1e-6):
+    if rms_norm_torch is not None:
+        return rms_norm_torch(x, weight.shape, weight=comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
+    else:
+        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
+        return (x * rrms) * comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device)
```
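The helper prefers the native `torch.nn.functional.rms_norm` (only present in newer PyTorch releases) and falls back to a manual implementation otherwise; both normalize over the last dimension and then apply the learned scale. A quick equivalence check, written as a standalone sketch that mirrors the two code paths rather than importing the repository module:

```python
# Sketch: the two rms_norm code paths agree (up to float tolerance).
import torch
import torch.nn.functional as F

def rms_norm_manual(x, weight, eps=1e-6):
    # Fallback path: normalize by the root-mean-square over the last dim, then scale.
    rrms = torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + eps)
    return (x * rrms) * weight

x = torch.randn(2, 4, 8)
w = torch.randn(8)

manual = rms_norm_manual(x, w)
if hasattr(F, "rms_norm"):  # only on recent PyTorch builds
    native = F.rms_norm(x, w.shape, weight=w, eps=1e-6)
    print(torch.allclose(manual, native, atol=1e-5))
```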
comfy/ldm/flux/controlnet.py (new file, +151)

```diff
@@ -0,0 +1,151 @@
+#Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
+
+import torch
+import math
+from torch import Tensor, nn
+from einops import rearrange, repeat
+
+from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
+                     MLPEmbedder, SingleStreamBlock,
+                     timestep_embedding)
+
+from .model import Flux
+import comfy.ldm.common_dit
+
+
+class ControlNetFlux(Flux):
+    def __init__(self, latent_input=False, num_union_modes=0, image_model=None, dtype=None, device=None, operations=None, **kwargs):
+        super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
+
+        self.main_model_double = 19
+        self.main_model_single = 38
+        # add ControlNet blocks
+        self.controlnet_blocks = nn.ModuleList([])
+        for _ in range(self.params.depth):
+            controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
+            self.controlnet_blocks.append(controlnet_block)
+
+        self.controlnet_single_blocks = nn.ModuleList([])
+        for _ in range(self.params.depth_single_blocks):
+            self.controlnet_single_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device))
+
+        self.num_union_modes = num_union_modes
+        self.controlnet_mode_embedder = None
+        if self.num_union_modes > 0:
+            self.controlnet_mode_embedder = operations.Embedding(self.num_union_modes, self.hidden_size, dtype=dtype, device=device)
+
+        self.gradient_checkpointing = False
+        self.latent_input = latent_input
+        self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
+        if not self.latent_input:
+            self.input_hint_block = nn.Sequential(
+                operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
+                nn.SiLU(),
+                operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
+                nn.SiLU(),
+                operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
+                nn.SiLU(),
+                operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
+                nn.SiLU(),
+                operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
+                nn.SiLU(),
+                operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
+                nn.SiLU(),
+                operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
+                nn.SiLU(),
+                operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
+            )
+
+    def forward_orig(
+        self,
+        img: Tensor,
+        img_ids: Tensor,
+        controlnet_cond: Tensor,
+        txt: Tensor,
+        txt_ids: Tensor,
+        timesteps: Tensor,
+        y: Tensor,
+        guidance: Tensor = None,
+        control_type: Tensor = None,
+    ) -> Tensor:
+        if img.ndim != 3 or txt.ndim != 3:
+            raise ValueError("Input img and txt tensors must have 3 dimensions.")
+
+        # running on sequences img
+        img = self.img_in(img)
+        if not self.latent_input:
+            controlnet_cond = self.input_hint_block(controlnet_cond)
+            controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
+
+        controlnet_cond = self.pos_embed_input(controlnet_cond)
+        img = img + controlnet_cond
+        vec = self.time_in(timestep_embedding(timesteps, 256))
+        if self.params.guidance_embed:
+            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
+        vec = vec + self.vector_in(y)
+        txt = self.txt_in(txt)
+
+        if self.controlnet_mode_embedder is not None and len(control_type) > 0:
+            control_cond = self.controlnet_mode_embedder(torch.tensor(control_type, device=img.device), out_dtype=img.dtype).unsqueeze(0).repeat((txt.shape[0], 1, 1))
+            txt = torch.cat([control_cond, txt], dim=1)
+            txt_ids = torch.cat([txt_ids[:,:1], txt_ids], dim=1)
+
+        ids = torch.cat((txt_ids, img_ids), dim=1)
+        pe = self.pe_embedder(ids)
+
+        controlnet_double = ()
+
+        for i in range(len(self.double_blocks)):
+            img, txt = self.double_blocks[i](img=img, txt=txt, vec=vec, pe=pe)
+            controlnet_double = controlnet_double + (self.controlnet_blocks[i](img),)
+
+        img = torch.cat((txt, img), 1)
+
+        controlnet_single = ()
+
+        for i in range(len(self.single_blocks)):
+            img = self.single_blocks[i](img, vec=vec, pe=pe)
+            controlnet_single = controlnet_single + (self.controlnet_single_blocks[i](img[:, txt.shape[1] :, ...]),)
+
+        repeat = math.ceil(self.main_model_double / len(controlnet_double))
+        if self.latent_input:
+            out_input = ()
+            for x in controlnet_double:
+                out_input += (x,) * repeat
+        else:
+            out_input = (controlnet_double * repeat)
+
+        out = {"input": out_input[:self.main_model_double]}
+        if len(controlnet_single) > 0:
+            repeat = math.ceil(self.main_model_single / len(controlnet_single))
+            out_output = ()
+            if self.latent_input:
+                for x in controlnet_single:
+                    out_output += (x,) * repeat
+            else:
+                out_output = (controlnet_single * repeat)
+            out["output"] = out_output[:self.main_model_single]
+        return out
+
+    def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
+        patch_size = 2
+        if self.latent_input:
+            hint = comfy.ldm.common_dit.pad_to_patch_size(hint, (patch_size, patch_size))
+            hint = rearrange(hint, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
+        else:
+            hint = hint * 2.0 - 1.0
+
+        bs, c, h, w = x.shape
+        x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
+
+        img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
+
+        h_len = ((h + (patch_size // 2)) // patch_size)
+        w_len = ((w + (patch_size // 2)) // patch_size)
+        img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
+        img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
+        img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
+        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
+
+        txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
+        return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance, control_type=kwargs.get("control_type", []))
```
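The tail of `forward_orig` stretches however many ControlNet taps exist to cover the 19 double and 38 single blocks of the main Flux model. A tiny worked sketch of that mapping, using hypothetical tap counts:

```python
# Sketch: how a small number of controlnet taps is stretched over the main model's blocks.
import math

main_model_double = 19                 # double-stream blocks in the full Flux model
controlnet_double = ("tap0", "tap1")   # hypothetical: a ControlNet with only 2 taps

rep = math.ceil(main_model_double / len(controlnet_double))  # -> 10

# latent_input=True path: each tap repeated back-to-back, then truncated to 19 entries.
out_input = ()
for x in controlnet_double:
    out_input += (x,) * rep
print(out_input[:main_model_double])   # ('tap0' * 10, then 'tap1' * 9)

# latent_input=False (xlabs) path: the whole tuple is tiled instead, so taps alternate.
print((controlnet_double * rep)[:main_model_double])
```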
```diff
@@ -1,104 +0,0 @@
-#Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
-
-import torch
-from torch import Tensor, nn
-from einops import rearrange, repeat
-
-from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
-                     MLPEmbedder, SingleStreamBlock,
-                     timestep_embedding)
-
-from .model import Flux
-import comfy.ldm.common_dit
-
-
-class ControlNetFlux(Flux):
-    def __init__(self, image_model=None, dtype=None, device=None, operations=None, **kwargs):
-        super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
-
-        # add ControlNet blocks
-        self.controlnet_blocks = nn.ModuleList([])
-        for _ in range(self.params.depth):
-            controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
-            # controlnet_block = zero_module(controlnet_block)
-            self.controlnet_blocks.append(controlnet_block)
-        self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
-        self.gradient_checkpointing = False
-        self.input_hint_block = nn.Sequential(
-            operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
-            nn.SiLU(),
-            operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
-            nn.SiLU(),
-            operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
-            nn.SiLU(),
-            operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
-            nn.SiLU(),
-            operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
-            nn.SiLU(),
-            operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
-            nn.SiLU(),
-            operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
-            nn.SiLU(),
-            operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
-        )
-
-    def forward_orig(
-        self,
-        img: Tensor,
-        img_ids: Tensor,
-        controlnet_cond: Tensor,
-        txt: Tensor,
-        txt_ids: Tensor,
-        timesteps: Tensor,
-        y: Tensor,
-        guidance: Tensor = None,
-    ) -> Tensor:
-        if img.ndim != 3 or txt.ndim != 3:
-            raise ValueError("Input img and txt tensors must have 3 dimensions.")
-
-        # running on sequences img
-        img = self.img_in(img)
-        controlnet_cond = self.input_hint_block(controlnet_cond)
-        controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
-        controlnet_cond = self.pos_embed_input(controlnet_cond)
-        img = img + controlnet_cond
-        vec = self.time_in(timestep_embedding(timesteps, 256))
-        if self.params.guidance_embed:
-            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
-        vec = vec + self.vector_in(y)
-        txt = self.txt_in(txt)
-
-        ids = torch.cat((txt_ids, img_ids), dim=1)
-        pe = self.pe_embedder(ids)
-
-        block_res_samples = ()
-
-        for block in self.double_blocks:
-            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
-            block_res_samples = block_res_samples + (img,)
-
-        controlnet_block_res_samples = ()
-        for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
-            block_res_sample = controlnet_block(block_res_sample)
-            controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
-
-        return {"input": (controlnet_block_res_samples * 10)[:19]}
-
-    def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
-        hint = hint * 2.0 - 1.0
-
-        bs, c, h, w = x.shape
-        patch_size = 2
-        x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
-
-        img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
-
-        h_len = ((h + (patch_size // 2)) // patch_size)
-        w_len = ((w + (patch_size // 2)) // patch_size)
-        img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
-        img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
-        img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
-        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
-
-        txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
-        return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance)
```
```diff
@@ -6,6 +6,7 @@ from torch import Tensor, nn
 
 from .math import attention, rope
 import comfy.ops
+import comfy.ldm.common_dit
 
 
 class EmbedND(nn.Module):
@@ -63,10 +64,7 @@ class RMSNorm(torch.nn.Module):
         self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
 
     def forward(self, x: Tensor):
-        x_dtype = x.dtype
-        x = x.float()
-        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
-        return (x * rrms).to(dtype=x_dtype) * comfy.ops.cast_to(self.scale, dtype=x_dtype, device=x.device)
+        return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6)
 
 
 class QKNorm(torch.nn.Module):
@@ -178,7 +176,7 @@ class DoubleStreamBlock(nn.Module):
         txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
 
         if txt.dtype == torch.float16:
-            txt = txt.clip(-65504, 65504)
+            txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
 
         return img, txt
 
@@ -233,7 +231,7 @@ class SingleStreamBlock(nn.Module):
         output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
         x += mod.gate * output
         if x.dtype == torch.float16:
-            x = x.clip(-65504, 65504)
+            x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
         return x
 
 
```
```diff
@@ -372,7 +372,7 @@ class HunYuanDiT(nn.Module):
         for layer, block in enumerate(self.blocks):
             if layer > self.depth // 2:
                 if controls is not None:
-                    skip = skips.pop() + controls.pop()
+                    skip = skips.pop() + controls.pop().to(dtype=x.dtype)
                 else:
                     skip = skips.pop()
                 x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)
```
```diff
@@ -355,29 +355,9 @@ class RMSNorm(torch.nn.Module):
         else:
             self.register_parameter("weight", None)
 
-    def _norm(self, x):
-        """
-        Apply the RMSNorm normalization to the input tensor.
-        Args:
-            x (torch.Tensor): The input tensor.
-        Returns:
-            torch.Tensor: The normalized tensor.
-        """
-        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
-
     def forward(self, x):
-        """
-        Forward pass through the RMSNorm layer.
-        Args:
-            x (torch.Tensor): The input tensor.
-        Returns:
-            torch.Tensor: The output tensor after applying RMSNorm.
-        """
-        x = self._norm(x)
-        if self.learnable_scale:
-            return x * self.weight.to(device=x.device, dtype=x.dtype)
-        else:
-            return x
+        return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)
 
 
 class SwiGLUFeedForward(nn.Module):
```
253
comfy/lora.py
253
comfy/lora.py
@@ -16,8 +16,12 @@
|
|||||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
|
import comfy.model_management
|
||||||
|
import comfy.model_base
|
||||||
import logging
|
import logging
|
||||||
|
import torch
|
||||||
|
|
||||||
LORA_CLIP_MAP = {
|
LORA_CLIP_MAP = {
|
||||||
"mlp.fc1": "mlp_fc1",
|
"mlp.fc1": "mlp_fc1",
|
||||||
@@ -320,5 +324,254 @@ def model_lora_keys_unet(model, key_map={}):
|
|||||||
to = diffusers_keys[k]
|
to = diffusers_keys[k]
|
||||||
key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format
|
key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format
|
||||||
key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
|
key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
|
||||||
|
key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer
|
||||||
|
|
||||||
return key_map
|
return key_map
|
||||||
|
|
||||||
|
|
||||||
|
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype):
|
||||||
|
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
|
||||||
|
lora_diff *= alpha
|
||||||
|
weight_calc = weight + lora_diff.type(weight.dtype)
|
||||||
|
weight_norm = (
|
||||||
|
weight_calc.transpose(0, 1)
|
||||||
|
.reshape(weight_calc.shape[1], -1)
|
||||||
|
.norm(dim=1, keepdim=True)
|
||||||
|
.reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
|
||||||
|
.transpose(0, 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
|
||||||
|
if strength != 1.0:
|
||||||
|
weight_calc -= weight
|
||||||
|
weight += strength * (weight_calc)
|
||||||
|
else:
|
||||||
|
weight[:] = weight_calc
|
||||||
|
return weight
|
||||||
|
|
||||||
|
def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Pad a tensor to a new shape with zeros.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tensor (torch.Tensor): The original tensor to be padded.
|
||||||
|
new_shape (List[int]): The desired shape of the padded tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: A new tensor padded with zeros to the specified shape.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
If the new shape is smaller than the original tensor in any dimension,
|
||||||
|
the original tensor will be truncated in that dimension.
|
||||||
|
"""
|
||||||
|
if any([new_shape[i] < tensor.shape[i] for i in range(len(new_shape))]):
|
||||||
|
raise ValueError("The new shape must be larger than the original tensor in all dimensions")
|
||||||
|
|
||||||
|
if len(new_shape) != len(tensor.shape):
|
||||||
|
raise ValueError("The new shape must have the same number of dimensions as the original tensor")
|
||||||
|
|
||||||
|
# Create a new tensor filled with zeros
|
||||||
|
padded_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device)
|
||||||
|
|
||||||
|
# Create slicing tuples for both tensors
|
||||||
|
orig_slices = tuple(slice(0, dim) for dim in tensor.shape)
|
||||||
|
new_slices = tuple(slice(0, dim) for dim in tensor.shape)
|
||||||
|
|
||||||
|
# Copy the original tensor into the new tensor
|
||||||
|
padded_tensor[new_slices] = tensor[orig_slices]
|
||||||
|
|
||||||
|
return padded_tensor
|
||||||
|
|
||||||
|
def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
|
||||||
|
for p in patches:
|
||||||
|
strength = p[0]
|
||||||
|
v = p[1]
|
||||||
|
strength_model = p[2]
|
||||||
|
offset = p[3]
|
||||||
|
function = p[4]
|
||||||
|
if function is None:
|
||||||
|
function = lambda a: a
|
||||||
|
|
||||||
|
old_weight = None
|
||||||
|
if offset is not None:
|
||||||
|
old_weight = weight
|
||||||
|
weight = weight.narrow(offset[0], offset[1], offset[2])
|
||||||
|
|
||||||
|
if strength_model != 1.0:
|
||||||
|
weight *= strength_model
|
||||||
|
|
||||||
|
if isinstance(v, list):
|
||||||
|
v = (calculate_weight(v[1:], v[0].clone(), key, intermediate_dtype=intermediate_dtype), )
|
||||||
|
|
||||||
|
if len(v) == 1:
|
||||||
|
patch_type = "diff"
|
||||||
|
elif len(v) == 2:
|
||||||
|
patch_type = v[0]
|
||||||
|
v = v[1]
|
||||||
|
|
||||||
|
if patch_type == "diff":
|
||||||
|
diff: torch.Tensor = v[0]
|
||||||
|
# An extra flag to pad the weight if the diff's shape is larger than the weight
|
||||||
|
do_pad_weight = len(v) > 1 and v[1]['pad_weight']
|
||||||
|
if do_pad_weight and diff.shape != weight.shape:
|
||||||
|
logging.info("Pad weight {} from {} to shape: {}".format(key, weight.shape, diff.shape))
|
||||||
|
weight = pad_tensor_to_shape(weight, diff.shape)
|
||||||
|
|
||||||
|
if strength != 0.0:
|
||||||
|
if diff.shape != weight.shape:
|
||||||
|
logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, diff.shape, weight.shape))
|
||||||
|
else:
|
||||||
|
weight += function(strength * comfy.model_management.cast_to_device(diff, weight.device, weight.dtype))
|
||||||
|
elif patch_type == "lora": #lora/locon
|
||||||
|
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
|
||||||
|
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype)
|
||||||
|
dora_scale = v[4]
|
||||||
|
if v[2] is not None:
|
||||||
|
alpha = v[2] / mat2.shape[0]
|
||||||
|
else:
|
||||||
|
alpha = 1.0
|
||||||
|
|
||||||
|
if v[3] is not None:
|
||||||
|
#locon mid weights, hopefully the math is fine because I didn't properly test it
|
||||||
|
mat3 = comfy.model_management.cast_to_device(v[3], weight.device, intermediate_dtype)
|
||||||
|
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
|
||||||
|
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
|
||||||
|
try:
|
||||||
|
lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
|
||||||
|
if dora_scale is not None:
|
||||||
|
                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
                else:
                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
            except Exception as e:
                logging.error("ERROR {} {} {}".format(patch_type, key, e))
        elif patch_type == "lokr":
            w1 = v[0]
            w2 = v[1]
            w1_a = v[3]
            w1_b = v[4]
            w2_a = v[5]
            w2_b = v[6]
            t2 = v[7]
            dora_scale = v[8]
            dim = None

            if w1 is None:
                dim = w1_b.shape[0]
                w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype),
                              comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype))
            else:
                w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype)

            if w2 is None:
                dim = w2_b.shape[0]
                if t2 is None:
                    w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype),
                                  comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype))
                else:
                    w2 = torch.einsum('i j k l, j r, i p -> p r k l',
                                      comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
                                      comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype),
                                      comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype))
            else:
                w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype)

            if len(w2.shape) == 4:
                w1 = w1.unsqueeze(2).unsqueeze(2)
            if v[2] is not None and dim is not None:
                alpha = v[2] / dim
            else:
                alpha = 1.0

            try:
                lora_diff = torch.kron(w1, w2).reshape(weight.shape)
                if dora_scale is not None:
                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
                else:
                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
            except Exception as e:
                logging.error("ERROR {} {} {}".format(patch_type, key, e))
        elif patch_type == "loha":
            w1a = v[0]
            w1b = v[1]
            if v[2] is not None:
                alpha = v[2] / w1b.shape[0]
            else:
                alpha = 1.0

            w2a = v[3]
            w2b = v[4]
            dora_scale = v[7]
            if v[5] is not None: #cp decomposition
                t1 = v[5]
                t2 = v[6]
                m1 = torch.einsum('i j k l, j r, i p -> p r k l',
                                  comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype),
                                  comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype),
                                  comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype))

                m2 = torch.einsum('i j k l, j r, i p -> p r k l',
                                  comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
                                  comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype),
                                  comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype))
            else:
                m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype),
                              comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype))
                m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype),
                              comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype))

            try:
                lora_diff = (m1 * m2).reshape(weight.shape)
                if dora_scale is not None:
                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
                else:
                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
            except Exception as e:
                logging.error("ERROR {} {} {}".format(patch_type, key, e))
        elif patch_type == "glora":
            dora_scale = v[5]

            old_glora = False
            if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]:
                rank = v[0].shape[0]
                old_glora = True

            if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]:
                if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]:
                    pass
                else:
                    old_glora = False
                    rank = v[1].shape[0]

            a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
            a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
            b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
            b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)

            if v[4] is not None:
                alpha = v[4] / rank
            else:
                alpha = 1.0

            try:
                if old_glora:
                    lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora
                else:
                    if weight.dim() > 2:
                        lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
                    else:
                        lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
                    lora_diff += torch.mm(b1, b2).reshape(weight.shape)

                if dora_scale is not None:
                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
                else:
                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
            except Exception as e:
                logging.error("ERROR {} {} {}".format(patch_type, key, e))
        else:
            logging.warning("patch type not recognized {} {}".format(patch_type, key))

        if old_weight is not None:
            weight = old_weight

    return weight
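For reference, the lokr branch above rebuilds the weight delta as a Kronecker product of the two factor matrices before scaling it by strength * alpha. A minimal standalone sketch with hypothetical shapes (not taken from the repository):

    import torch

    # hypothetical factors: kron of an (8, 4) and an (8, 8) matrix gives a (64, 32) delta
    w1 = torch.randn(8, 4)
    w2 = torch.randn(8, 8)
    weight = torch.zeros(64, 32)

    strength, alpha = 1.0, 1.0
    lora_diff = torch.kron(w1, w2).reshape(weight.shape)   # same call as the lokr branch
    weight += (strength * alpha) * lora_diff
    print(weight.shape)                                    # torch.Size([64, 32])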
@@ -473,8 +473,14 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
                      'context_dim': 768, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 1, 1, 1, 1],
                      'use_temporal_attention': False, 'use_temporal_resblock': False}

+    SD15_diffusers_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': None,
+                              'dtype': dtype, 'in_channels': 9, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0],
+                              'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, 'num_heads': 8,
+                              'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
+                              'use_temporal_attention': False, 'use_temporal_resblock': False}

-    supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p]
+    supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p, SD15_diffusers_inpaint]

     for unet_config in supported_models:
         matches = True
@@ -44,9 +44,15 @@ cpu_state = CPUState.GPU

 total_vram = 0

-lowvram_available = True
 xpu_available = False
+torch_version = ""
+try:
+    torch_version = torch.version.__version__
+    xpu_available = (int(torch_version[0]) < 2 or (int(torch_version[0]) == 2 and int(torch_version[2]) <= 4)) and torch.xpu.is_available()
+except:
+    pass
+
+lowvram_available = True
 if args.deterministic:
     logging.info("Using deterministic algorithms for pytorch")
     torch.use_deterministic_algorithms(True, warn_only=True)
@@ -66,10 +72,10 @@ if args.directml is not None:

 try:
     import intel_extension_for_pytorch as ipex
-    if torch.xpu.is_available():
-        xpu_available = True
+    _ = torch.xpu.device_count()
+    xpu_available = torch.xpu.is_available()
 except:
-    pass
+    xpu_available = xpu_available or (hasattr(torch, "xpu") and torch.xpu.is_available())

 try:
     if torch.backends.mps.is_available():
@@ -189,7 +195,6 @@ VAE_DTYPES = [torch.float32]

 try:
     if is_nvidia():
-        torch_version = torch.version.__version__
         if int(torch_version[0]) >= 2:
             if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
                 ENABLE_PYTORCH_ATTENTION = True
@@ -321,8 +326,9 @@ class LoadedModel:
                 self.model_unload()
                 raise e

-        if is_intel_xpu() and not args.disable_ipex_optimize:
-            self.real_model = ipex.optimize(self.real_model.eval(), graph_mode=True, concat_linear=True)
+        if is_intel_xpu() and not args.disable_ipex_optimize and self.real_model is not None:
+            with torch.no_grad():
+                self.real_model = ipex.optimize(self.real_model.eval(), inplace=True, graph_mode=True, concat_linear=True)

         self.weights_loaded = True
         return self.real_model
@@ -364,12 +370,11 @@ def offloaded_memory(loaded_models, device):
         offloaded_mem += m.model_offloaded_memory()
     return offloaded_mem

-def minimum_inference_memory():
-    return (1024 * 1024 * 1024) * 1.2
+WINDOWS = any(platform.win32_ver())

-EXTRA_RESERVED_VRAM = 200 * 1024 * 1024
-if any(platform.win32_ver()):
-    EXTRA_RESERVED_VRAM = 500 * 1024 * 1024 #Windows is higher because of the shared vram issue
+EXTRA_RESERVED_VRAM = 400 * 1024 * 1024
+if WINDOWS:
+    EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue

 if args.reserve_vram is not None:
     EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024
@@ -378,6 +383,9 @@ if args.reserve_vram is not None:
 def extra_reserved_memory():
     return EXTRA_RESERVED_VRAM

+def minimum_inference_memory():
+    return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
+
 def unload_model_clones(model, unload_weights_only=True, force_unload=True):
     to_unload = []
     for i in range(len(current_loaded_models)):
@@ -400,6 +408,8 @@ def unload_model_clones(model, unload_weights_only=True, force_unload=True):
     if not force_unload:
         if unload_weights_only and unload_weight == False:
             return None
+    else:
+        unload_weight = True

     for i in to_unload:
         logging.debug("unload clone {} {}".format(i, unload_weight))
@@ -561,7 +571,9 @@ def loaded_models(only_currently_used=False):
 def cleanup_models(keep_clone_weights_loaded=False):
     to_delete = []
     for i in range(len(current_loaded_models)):
-        if sys.getrefcount(current_loaded_models[i].model) <= 2:
+        #TODO: very fragile function needs improvement
+        num_refs = sys.getrefcount(current_loaded_models[i].model)
+        if num_refs <= 2:
             if not keep_clone_weights_loaded:
                 to_delete = [i] + to_delete
             #TODO: find a less fragile way to do this.
@@ -668,6 +680,7 @@ def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.flo
     if bf16_supported and weight_dtype == torch.bfloat16:
         return None

+    fp16_supported = should_use_fp16(inference_device, prioritize_performance=True)
     for dt in supported_dtypes:
         if dt == torch.float16 and fp16_supported:
             return torch.float16
@@ -883,7 +896,8 @@ def pytorch_attention_flash_attention():
 def force_upcast_attention_dtype():
     upcast = args.force_upcast_attention
     try:
-        if platform.mac_ver()[0] in ['14.5']: #black image bug on OSX Sonoma 14.5
+        macos_version = tuple(int(n) for n in platform.mac_ver()[0].split("."))
+        if (14, 5) <= macos_version < (14, 7):  # black image bug on recent versions of MacOS
             upcast = True
     except:
         pass
@@ -986,16 +1000,16 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
         if props.major < 6:
             return False

-    fp16_works = False
-    #FP16 is confirmed working on a 1080 (GP104) but it's a bit slower than FP32 so it should only be enabled
-    #when the model doesn't actually fit on the card
-    #TODO: actually test if GP106 and others have the same type of behavior
+    #FP16 is confirmed working on a 1080 (GP104) and on latest pytorch actually seems faster than fp32
     nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050", "p40", "p100", "p6", "p4"]
     for x in nvidia_10_series:
         if x in props.name.lower():
-            fp16_works = True
+            if WINDOWS or manual_cast:
+                return True
+            else:
+                return False #weird linux behavior where fp32 is faster

-    if fp16_works or manual_cast:
+    if manual_cast:
         free_model_memory = maximum_vram_for_weights(device)
         if (not prioritize_performance) or model_params * 4 > free_model_memory:
             return True
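As a rough sanity check of the new reservation constants (illustrative arithmetic only, not part of the diff), the VRAM floor kept free for inference now works out to roughly 1.4 GiB on Windows and 1.2 GiB elsewhere:

    GIB = 1024 * 1024 * 1024
    extra_reserved_windows = 600 * 1024 * 1024
    extra_reserved_other = 400 * 1024 * 1024

    # mirrors the new minimum_inference_memory(): 0.8 GiB base plus the reserved headroom
    print((0.8 * GIB + extra_reserved_windows) / GIB)   # ~1.39
    print((0.8 * GIB + extra_reserved_other) / GIB)     # ~1.19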
@@ -27,29 +27,21 @@ import math
 import comfy.utils
 import comfy.float
 import comfy.model_management
+import comfy.lora
 from comfy.types import UnetWrapperFunction

-def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype):
-    dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
-    lora_diff *= alpha
-    weight_calc = weight + lora_diff.type(weight.dtype)
-    weight_norm = (
-        weight_calc.transpose(0, 1)
-        .reshape(weight_calc.shape[1], -1)
-        .norm(dim=1, keepdim=True)
-        .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
-        .transpose(0, 1)
-    )
-
-    weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
-    if strength != 1.0:
-        weight_calc -= weight
-        weight += strength * (weight_calc)
-    else:
-        weight[:] = weight_calc
-    return weight
+def string_to_seed(data):
+    crc = 0xFFFFFFFF
+    for byte in data:
+        if isinstance(byte, str):
+            byte = ord(byte)
+        crc ^= byte
+        for _ in range(8):
+            if crc & 1:
+                crc = (crc >> 1) ^ 0xEDB88320
+            else:
+                crc >>= 1
+    return crc ^ 0xFFFFFFFF

 def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
     to = model_options["transformer_options"].copy()
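The new string_to_seed helper is a plain bitwise CRC-32 (polynomial 0xEDB88320, initial and final XOR 0xFFFFFFFF), used further down to derive a deterministic per-key seed for stochastic_rounding. For ASCII weight keys it should agree with zlib.crc32; a quick cross-check with an assumed example key (not from the patch):

    import zlib

    def string_to_seed(data):
        crc = 0xFFFFFFFF
        for byte in data:
            if isinstance(byte, str):
                byte = ord(byte)
            crc ^= byte
            for _ in range(8):
                crc = (crc >> 1) ^ 0xEDB88320 if crc & 1 else crc >> 1
        return crc ^ 0xFFFFFFFF

    key = "diffusion_model.input_blocks.0.0.weight"   # hypothetical patch key
    assert string_to_seed(key) == zlib.crc32(key.encode())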
@@ -92,12 +84,11 @@ def wipe_lowvram_weight(m):
     m.bias_function = None

 class LowVramPatch:
-    def __init__(self, key, model_patcher):
+    def __init__(self, key, patches):
         self.key = key
-        self.model_patcher = model_patcher
+        self.patches = patches
     def __call__(self, weight):
-        return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)
+        return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)


 class ModelPatcher:
     def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
@@ -329,8 +320,8 @@ class ModelPatcher:
             temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
         else:
             temp_weight = weight.to(torch.float32, copy=True)
-        out_weight = self.calculate_weight(self.patches[key], temp_weight, key)
-        out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype)
+        out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key)
+        out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
         if inplace_update:
             comfy.utils.copy_to_param(self.model, key, out_weight)
         else:
@@ -340,12 +331,21 @@ class ModelPatcher:
         mem_counter = 0
         patch_counter = 0
         lowvram_counter = 0
-        load_completely = []
+        loading = []
         for n, m in self.model.named_modules():
+            if hasattr(m, "comfy_cast_weights") or hasattr(m, "weight"):
+                loading.append((comfy.model_management.module_size(m), n, m))
+
+        load_completely = []
+        loading.sort(reverse=True)
+        for x in loading:
+            n = x[1]
+            m = x[2]
+            module_mem = x[0]
+
             lowvram_weight = False

             if not full_load and hasattr(m, "comfy_cast_weights"):
-                module_mem = comfy.model_management.module_size(m)
                 if mem_counter + module_mem >= lowvram_model_memory:
                     lowvram_weight = True
                     lowvram_counter += 1
@@ -360,13 +360,13 @@ class ModelPatcher:
                     if force_patch_weights:
                         self.patch_weight_to_device(weight_key)
                     else:
-                        m.weight_function = LowVramPatch(weight_key, self)
+                        m.weight_function = LowVramPatch(weight_key, self.patches)
                         patch_counter += 1
                 if bias_key in self.patches:
                     if force_patch_weights:
                         self.patch_weight_to_device(bias_key)
                     else:
-                        m.bias_function = LowVramPatch(bias_key, self)
+                        m.bias_function = LowVramPatch(bias_key, self.patches)
                         patch_counter += 1

                 m.prev_comfy_cast_weights = m.comfy_cast_weights
@@ -377,9 +377,8 @@ class ModelPatcher:
                     wipe_lowvram_weight(m)

                 if hasattr(m, "weight"):
-                    mem_used = comfy.model_management.module_size(m)
-                    mem_counter += mem_used
-                    load_completely.append((mem_used, n, m))
+                    mem_counter += module_mem
+                    load_completely.append((module_mem, n, m))

         load_completely.sort(reverse=True)
         for x in load_completely:
@@ -428,174 +427,6 @@ class ModelPatcher:
             self.load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights, full_load=full_load)
         return self.model

-    def calculate_weight(self, patches, weight, key, intermediate_dtype=torch.float32):
-        for p in patches:
-            strength = p[0]
-            v = p[1]
-            strength_model = p[2]
-            offset = p[3]
-            function = p[4]
-            if function is None:
-                function = lambda a: a
-
-            old_weight = None
-            if offset is not None:
-                old_weight = weight
-                weight = weight.narrow(offset[0], offset[1], offset[2])
-
-            if strength_model != 1.0:
-                weight *= strength_model
-
-            if isinstance(v, list):
-                v = (self.calculate_weight(v[1:], v[0].clone(), key, intermediate_dtype=intermediate_dtype), )
-
-            if len(v) == 1:
-                patch_type = "diff"
-            elif len(v) == 2:
-                patch_type = v[0]
-                v = v[1]
-
-            if patch_type == "diff":
-                w1 = v[0]
-                if strength != 0.0:
-                    if w1.shape != weight.shape:
-                        logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
-                    else:
-                        weight += function(strength * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype))
-            elif patch_type == "lora": #lora/locon
-                mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
-                mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype)
-                dora_scale = v[4]
-                if v[2] is not None:
-                    alpha = v[2] / mat2.shape[0]
-                else:
-                    alpha = 1.0
-
-                if v[3] is not None:
-                    #locon mid weights, hopefully the math is fine because I didn't properly test it
-                    mat3 = comfy.model_management.cast_to_device(v[3], weight.device, intermediate_dtype)
-                    final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
-                    mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
-                try:
-                    lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
-                    if dora_scale is not None:
-                        weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
-                    else:
-                        weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
-                except Exception as e:
-                    logging.error("ERROR {} {} {}".format(patch_type, key, e))
-            elif patch_type == "lokr":
-                w1 = v[0]
-                w2 = v[1]
-                w1_a = v[3]
-                w1_b = v[4]
-                w2_a = v[5]
-                w2_b = v[6]
-                t2 = v[7]
-                dora_scale = v[8]
-                dim = None
-
-                if w1 is None:
-                    dim = w1_b.shape[0]
-                    w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype),
-                                  comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype))
-                else:
-                    w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype)
-
-                if w2 is None:
-                    dim = w2_b.shape[0]
-                    if t2 is None:
-                        w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype),
-                                      comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype))
-                    else:
-                        w2 = torch.einsum('i j k l, j r, i p -> p r k l',
-                                          comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
-                                          comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype),
-                                          comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype))
-                else:
-                    w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype)
-
-                if len(w2.shape) == 4:
-                    w1 = w1.unsqueeze(2).unsqueeze(2)
-                if v[2] is not None and dim is not None:
-                    alpha = v[2] / dim
-                else:
-                    alpha = 1.0
-
-                try:
-                    lora_diff = torch.kron(w1, w2).reshape(weight.shape)
-                    if dora_scale is not None:
-                        weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
-                    else:
-                        weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
-                except Exception as e:
-                    logging.error("ERROR {} {} {}".format(patch_type, key, e))
-            elif patch_type == "loha":
-                w1a = v[0]
-                w1b = v[1]
-                if v[2] is not None:
-                    alpha = v[2] / w1b.shape[0]
-                else:
-                    alpha = 1.0
-
-                w2a = v[3]
-                w2b = v[4]
-                dora_scale = v[7]
-                if v[5] is not None: #cp decomposition
-                    t1 = v[5]
-                    t2 = v[6]
-                    m1 = torch.einsum('i j k l, j r, i p -> p r k l',
-                                      comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype),
-                                      comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype),
-                                      comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype))
-
-                    m2 = torch.einsum('i j k l, j r, i p -> p r k l',
-                                      comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
-                                      comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype),
-                                      comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype))
-                else:
-                    m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype),
-                                  comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype))
-                    m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype),
-                                  comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype))
-
-                try:
-                    lora_diff = (m1 * m2).reshape(weight.shape)
-                    if dora_scale is not None:
-                        weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
-                    else:
-                        weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
-                except Exception as e:
-                    logging.error("ERROR {} {} {}".format(patch_type, key, e))
-            elif patch_type == "glora":
-                if v[4] is not None:
-                    alpha = v[4] / v[0].shape[0]
-                else:
-                    alpha = 1.0
-
-                dora_scale = v[5]
-
-                a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
-                a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
-                b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
-                b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
-
-                try:
-                    lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape)
-                    if dora_scale is not None:
-                        weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
-                    else:
-                        weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
-                except Exception as e:
-                    logging.error("ERROR {} {} {}".format(patch_type, key, e))
-            else:
-                logging.warning("patch type not recognized {} {}".format(patch_type, key))
-
-            if old_weight is not None:
-                weight = old_weight
-
-        return weight
-
     def unpatch_model(self, device_to=None, unpatch_weights=True):
         if unpatch_weights:
             if self.model.model_lowvram:
@@ -664,10 +495,10 @@ class ModelPatcher:

                     m.to(device_to)
                     if weight_key in self.patches:
-                        m.weight_function = LowVramPatch(weight_key, self)
+                        m.weight_function = LowVramPatch(weight_key, self.patches)
                         patch_counter += 1
                     if bias_key in self.patches:
-                        m.bias_function = LowVramPatch(bias_key, self)
+                        m.bias_function = LowVramPatch(bias_key, self.patches)
                         patch_counter += 1

                     m.prev_comfy_cast_weights = m.comfy_cast_weights
@@ -695,3 +526,7 @@ class ModelPatcher:

     def current_loaded_device(self):
         return self.model.device

+    def calculate_weight(self, patches, weight, key, intermediate_dtype=torch.float32):
+        print("WARNING the ModelPatcher.calculate_weight function is deprecated, please use: comfy.lora.calculate_weight instead")
+        return comfy.lora.calculate_weight(patches, weight, key, intermediate_dtype=intermediate_dtype)
34 comfy/ops.py
@@ -20,31 +20,40 @@ import torch
 import comfy.model_management
 from comfy.cli_args import args

-def cast_to(weight, dtype=None, device=None, non_blocking=False):
-    if (dtype is None or weight.dtype == dtype) and (device is None or weight.device == device):
-        return weight
+def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
+    if device is None or weight.device == device:
+        if not copy:
+            if dtype is None or weight.dtype == dtype:
+                return weight
+        return weight.to(dtype=dtype, copy=copy)
+
     r = torch.empty_like(weight, dtype=dtype, device=device)
     r.copy_(weight, non_blocking=non_blocking)
     return r

-def cast_to_input(weight, input, non_blocking=False):
-    return cast_to(weight, input.dtype, input.device, non_blocking=non_blocking)
+def cast_to_input(weight, input, non_blocking=False, copy=True):
+    return cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)

-def cast_bias_weight(s, input=None, dtype=None, device=None):
+def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
     if input is not None:
         if dtype is None:
             dtype = input.dtype
+        if bias_dtype is None:
+            bias_dtype = dtype
         if device is None:
             device = input.device

     bias = None
     non_blocking = comfy.model_management.device_supports_non_blocking(device)
     if s.bias is not None:
-        bias = cast_to(s.bias, dtype, device, non_blocking=non_blocking)
-        if s.bias_function is not None:
+        has_function = s.bias_function is not None
+        bias = cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function)
+        if has_function:
             bias = s.bias_function(bias)
-    weight = cast_to(s.weight, dtype, device, non_blocking=non_blocking)
-    if s.weight_function is not None:
+
+    has_function = s.weight_function is not None
+    weight = cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function)
+    if has_function:
         weight = s.weight_function(weight)
     return weight, bias
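The new copy flag matters because weight_function / bias_function hooks (such as LowVramPatch) may modify the tensor they receive; requesting a copy when a hook is attached presumably keeps the stored parameter untouched even when it is already on the target device. A small sketch of the resulting behaviour (assumes CPU tensors, no hooks involved):

    import torch

    def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
        # same logic as the new comfy.ops.cast_to above
        if device is None or weight.device == device:
            if not copy:
                if dtype is None or weight.dtype == dtype:
                    return weight
            return weight.to(dtype=dtype, copy=copy)
        r = torch.empty_like(weight, dtype=dtype, device=device)
        r.copy_(weight, non_blocking=non_blocking)
        return r

    w = torch.ones(4)
    assert cast_to(w, device=w.device).data_ptr() == w.data_ptr()             # no copy: same storage
    assert cast_to(w, device=w.device, copy=True).data_ptr() != w.data_ptr()  # copy: safe to modify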
@@ -252,7 +261,8 @@ def fp8_linear(self, input):
         if len(input.shape) == 3:
             inn = input.reshape(-1, input.shape[2]).to(dtype)
             non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
-            w = cast_to(self.weight, device=input.device, non_blocking=non_blocking).t()
+            w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype)
+            w = w.t()

             scale_weight = self.scale_weight
             scale_input = self.scale_input
@@ -263,8 +273,8 @@ def fp8_linear(self, input):
             if scale_input is None:
                 scale_input = torch.ones((1), device=input.device, dtype=torch.float32)

-            if self.bias is not None:
-                o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=cast_to_input(self.bias, input, non_blocking=non_blocking), scale_a=scale_input, scale_b=scale_weight)
+            if bias is not None:
+                o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
             else:
                 o = torch._scaled_mm(inn, w, out_dtype=input.dtype, scale_a=scale_input, scale_b=scale_weight)
@@ -654,6 +654,7 @@ class Flux(supported_models_base.BASE):
     def clip_target(self, state_dict={}):
         pref = self.text_encoder_key_prefix[0]
         t5_key = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref)
+        dtype_t5 = None
         if t5_key in state_dict:
             dtype_t5 = state_dict[t5_key].dtype
         return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(dtype_t5=dtype_t5))
@@ -528,6 +528,8 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
         ("guidance_in.out_layer.weight", "time_text_embed.guidance_embedder.linear_2.weight"),
         ("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift),
         ("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift),
+        ("pos_embed_input.bias", "controlnet_x_embedder.bias"),
+        ("pos_embed_input.weight", "controlnet_x_embedder.weight"),
     }

     for k in MAP_BASIC:
@@ -56,6 +56,8 @@ class CacheKeySetID(CacheKeySet):
         for node_id in node_ids:
             if node_id in self.keys:
                 continue
+            if not self.dynprompt.has_node(node_id):
+                continue
             node = self.dynprompt.get_node(node_id)
             self.keys[node_id] = (node_id, node["class_type"])
             self.subcache_keys[node_id] = (node_id, node["class_type"])
@@ -74,6 +76,8 @@ class CacheKeySetInputSignature(CacheKeySet):
         for node_id in node_ids:
             if node_id in self.keys:
                 continue
+            if not self.dynprompt.has_node(node_id):
+                continue
             node = self.dynprompt.get_node(node_id)
             self.keys[node_id] = self.get_node_signature(self.dynprompt, node_id)
             self.subcache_keys[node_id] = (node_id, node["class_type"])
@@ -87,6 +91,9 @@ class CacheKeySetInputSignature(CacheKeySet):
         return to_hashable(signature)

     def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
+        if not dynprompt.has_node(node_id):
+            # This node doesn't exist -- we can't cache it.
+            return [float("NaN")]
         node = dynprompt.get_node(node_id)
         class_type = node["class_type"]
         class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
@@ -112,6 +119,8 @@ class CacheKeySetInputSignature(CacheKeySet):
         return ancestors, order_mapping

     def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_mapping):
+        if not dynprompt.has_node(node_id):
+            return
         inputs = dynprompt.get_node(node_id)["inputs"]
         input_keys = sorted(inputs.keys())
         for key in input_keys:
|||||||
@@ -47,7 +47,8 @@ class IsChangedCache:
|
|||||||
self.is_changed[node_id] = node["is_changed"]
|
self.is_changed[node_id] = node["is_changed"]
|
||||||
return self.is_changed[node_id]
|
return self.is_changed[node_id]
|
||||||
|
|
||||||
input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, self.outputs_cache)
|
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
|
||||||
|
input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, None)
|
||||||
try:
|
try:
|
||||||
is_changed = _map_node_over_list(class_def, input_data_all, "IS_CHANGED")
|
is_changed = _map_node_over_list(class_def, input_data_all, "IS_CHANGED")
|
||||||
node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
|
node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
|
||||||
@@ -491,6 +492,7 @@ class PromptExecutor:
|
|||||||
break
|
break
|
||||||
|
|
||||||
result, error, ex = execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results)
|
result, error, ex = execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results)
|
||||||
|
self.success = result != ExecutionResult.FAILURE
|
||||||
if result == ExecutionResult.FAILURE:
|
if result == ExecutionResult.FAILURE:
|
||||||
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
|
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
|
||||||
break
|
break
|
||||||
|
|||||||
4
main.py
4
main.py
@@ -6,6 +6,10 @@ import importlib.util
 import folder_paths
 import time
 from comfy.cli_args import args
+from app.logger import setup_logger
+
+
+setup_logger(verbose=args.verbose)


 def execute_prestartup_script():
2 nodes.py
@@ -2129,3 +2129,5 @@ def init_extra_nodes(init_custom_nodes=True):
         else:
             logging.warning("Please do a: pip install -r requirements.txt")
         logging.warning("")
+
+    return import_failed
@@ -79,7 +79,7 @@
     "#!wget -c https://huggingface.co/comfyanonymous/clip_vision_g/resolve/main/clip_vision_g.safetensors -P ./models/clip_vision/\n",
     "\n",
     "# SD1.5\n",
-    "!wget -c https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt -P ./models/checkpoints/\n",
+    "!wget -c https://huggingface.co/Comfy-Org/stable-diffusion-v1-5-archive/resolve/main/v1-5-pruned-emaonly-fp16.safetensors -P ./models/checkpoints/\n",
     "\n",
     "# SD2\n",
     "#!wget -c https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.safetensors -P ./models/checkpoints/\n",
@@ -43,7 +43,7 @@ prompt_text = """
    "4": {
        "class_type": "CheckpointLoaderSimple",
        "inputs": {
-            "ckpt_name": "v1-5-pruned-emaonly.ckpt"
+            "ckpt_name": "v1-5-pruned-emaonly.safetensors"
        }
    },
    "5": {
@@ -41,11 +41,10 @@ def get_images(ws, prompt):
             continue #previews are binary data

     history = get_history(prompt_id)[prompt_id]
-    for o in history['outputs']:
-        for node_id in history['outputs']:
-            node_output = history['outputs'][node_id]
-            if 'images' in node_output:
-                images_output = []
-                for image in node_output['images']:
-                    image_data = get_image(image['filename'], image['subfolder'], image['type'])
-                    images_output.append(image_data)
+    for node_id in history['outputs']:
+        node_output = history['outputs'][node_id]
+        images_output = []
+        if 'images' in node_output:
+            for image in node_output['images']:
+                image_data = get_image(image['filename'], image['subfolder'], image['type'])
+                images_output.append(image_data)
@@ -85,7 +84,7 @@ prompt_text = """
    "4": {
        "class_type": "CheckpointLoaderSimple",
        "inputs": {
-            "ckpt_name": "v1-5-pruned-emaonly.ckpt"
+            "ckpt_name": "v1-5-pruned-emaonly.safetensors"
        }
    },
    "5": {
@@ -81,7 +81,7 @@ prompt_text = """
    "4": {
        "class_type": "CheckpointLoaderSimple",
        "inputs": {
-            "ckpt_name": "v1-5-pruned-emaonly.ckpt"
+            "ckpt_name": "v1-5-pruned-emaonly.safetensors"
        }
    },
    "5": {
28 server.py
@@ -31,7 +31,6 @@ from model_filemanager import download_model, DownloadModelStatus
 from typing import Optional
 from api_server.routes.internal.internal_routes import InternalRoutes

-
 class BinaryEventTypes:
     PREVIEW_IMAGE = 1
     UNENCODED_PREVIEW_IMAGE = 2
@@ -42,6 +41,21 @@ async def send_socket_catch_exception(function, message):
     except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError) as err:
         logging.warning("send error: {}".format(err))

+def get_comfyui_version():
+    comfyui_version = "unknown"
+    repo_path = os.path.dirname(os.path.realpath(__file__))
+    try:
+        import pygit2
+        repo = pygit2.Repository(repo_path)
+        comfyui_version = repo.describe(describe_strategy=pygit2.GIT_DESCRIBE_TAGS)
+    except Exception:
+        try:
+            import subprocess
+            comfyui_version = subprocess.check_output(["git", "describe", "--tags"], cwd=repo_path).decode('utf-8')
+        except Exception as e:
+            logging.warning(f"Failed to get ComfyUI version: {e}")
+    return comfyui_version.strip()
+
 @web.middleware
 async def cache_control(request: web.Request, handler):
     response: web.Response = await handler(request)
@@ -401,16 +415,20 @@ class PromptServer():
             return web.json_response(dt["__metadata__"])

         @routes.get("/system_stats")
-        async def get_queue(request):
+        async def system_stats(request):
             device = comfy.model_management.get_torch_device()
             device_name = comfy.model_management.get_torch_device_name(device)
             vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
             vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
+
             system_stats = {
                 "system": {
                     "os": os.name,
+                    "comfyui_version": get_comfyui_version(),
                     "python_version": sys.version,
-                    "embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded"
+                    "pytorch_version": comfy.model_management.torch_version,
+                    "embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded",
+                    "argv": sys.argv
                 },
                 "devices": [
                     {
@@ -586,7 +604,9 @@ class PromptServer():
         @routes.post("/internal/models/download")
         async def download_handler(request):
             async def report_progress(filename: str, status: DownloadModelStatus):
-                await self.send_json("download_progress", status.to_dict())
+                payload = status.to_dict()
+                payload['download_path'] = filename
+                await self.send_json("download_progress", payload)

             data = await request.json()
             url = data.get('url')
@@ -1,6 +1,7 @@
 import argparse
 import pytest
 from requests.exceptions import HTTPError
+from unittest.mock import patch

 from app.frontend_management import (
     FrontendManager,
@@ -83,6 +84,35 @@ def test_init_frontend_invalid_provider():
     with pytest.raises(HTTPError):
         FrontendManager.init_frontend_unsafe(version_string)

+@pytest.fixture
+def mock_os_functions():
+    with patch('app.frontend_management.os.makedirs') as mock_makedirs, \
+         patch('app.frontend_management.os.listdir') as mock_listdir, \
+         patch('app.frontend_management.os.rmdir') as mock_rmdir:
+        mock_listdir.return_value = []  # Simulate empty directory
+        yield mock_makedirs, mock_listdir, mock_rmdir
+
+@pytest.fixture
+def mock_download():
+    with patch('app.frontend_management.download_release_asset_zip') as mock:
+        mock.side_effect = Exception("Download failed")  # Simulate download failure
+        yield mock
+
+def test_finally_block(mock_os_functions, mock_download, mock_provider):
+    # Arrange
+    mock_makedirs, mock_listdir, mock_rmdir = mock_os_functions
+    version_string = 'test-owner/test-repo@1.0.0'
+
+    # Act & Assert
+    with pytest.raises(Exception):
+        FrontendManager.init_frontend_unsafe(version_string, mock_provider)
+
+    # Assert
+    mock_makedirs.assert_called_once()
+    mock_download.assert_called_once()
+    mock_listdir.assert_called_once()
+    mock_rmdir.assert_called_once()
+
+
 def test_parse_version_string():
     version_string = "owner/repo@1.0.0"
@@ -95,12 +95,11 @@ class ComfyClient:
                 pass # Probably want to store this off for testing

         history = self.get_history(prompt_id)[prompt_id]
-        for o in history['outputs']:
-            for node_id in history['outputs']:
-                node_output = history['outputs'][node_id]
-                result.outputs[node_id] = node_output
-                if 'images' in node_output:
-                    images_output = []
-                    for image in node_output['images']:
-                        image_data = self.get_image(image['filename'], image['subfolder'], image['type'])
-                        image_obj = Image.open(BytesIO(image_data))
+        for node_id in history['outputs']:
+            node_output = history['outputs'][node_id]
+            result.outputs[node_id] = node_output
+            images_output = []
+            if 'images' in node_output:
+                for image in node_output['images']:
+                    image_data = self.get_image(image['filename'], image['subfolder'], image['type'])
+                    image_obj = Image.open(BytesIO(image_data))
@@ -357,6 +356,25 @@ class TestExecution:
         assert 'prompt_id' in e.args[0], f"Did not get back a proper error message: {e}"
         assert e.args[0]['node_id'] == generator.id, "Error should have been on the generator node"

+    def test_missing_node_error(self, client: ComfyClient, builder: GraphBuilder):
+        g = builder
+        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
+        input2 = g.node("StubImage", id="removeme", content="WHITE", height=512, width=512, batch_size=1)
+        input3 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
+        mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1)
+        mix1 = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0))
+        mix2 = g.node("TestLazyMixImages", image1=input1.out(0), image2=input3.out(0), mask=mask.out(0))
+        # We have multiple outputs. The first is invalid, but the second is valid
+        g.node("SaveImage", images=mix1.out(0))
+        g.node("SaveImage", images=mix2.out(0))
+        g.remove_node("removeme")
+
+        client.run(g)
+
+        # Add back in the missing node to make sure the error doesn't break the server
+        input2 = g.node("StubImage", id="removeme", content="WHITE", height=512, width=512, batch_size=1)
+        client.run(g)
+
     def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder):
         g = builder
         # Creating the nodes in this specific order previously caused a bug
@@ -450,8 +468,8 @@
         g = builder
         input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)

-        output1 = g.node("PreviewImage", images=input1.out(0))
-        output2 = g.node("PreviewImage", images=input1.out(0))
+        output1 = g.node("SaveImage", images=input1.out(0))
+        output2 = g.node("SaveImage", images=input1.out(0))

         result = client.run(g)
         images1 = result.get_images(output1)
@@ -459,3 +477,22 @@
         assert len(images1) == 1, "Should have 1 image"
         assert len(images2) == 1, "Should have 1 image"

+    # This tests that only constant outputs are used in the call to `IS_CHANGED`
+    def test_is_changed_with_outputs(self, client: ComfyClient, builder: GraphBuilder):
+        g = builder
+        input1 = g.node("StubConstantImage", value=0.5, height=512, width=512, batch_size=1)
+        test_node = g.node("TestIsChangedWithConstants", image=input1.out(0), value=0.5)
+
+        output = g.node("PreviewImage", images=test_node.out(0))
+
+        result = client.run(g)
+        images = result.get_images(output)
+        assert len(images) == 1, "Should have 1 image"
+        assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"
+
+        result = client.run(g)
+        images = result.get_images(output)
+        assert len(images) == 1, "Should have 1 image"
+        assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"
+        assert not result.did_run(test_node), "The execution should have been cached"
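The expected pixel value of 63 in the new test follows directly from the stub constants, assuming the usual float-to-8-bit conversion when the preview image is written out:

    # StubConstantImage fills the image with 0.5; TestIsChangedWithConstants multiplies by value=0.5
    expected = 0.5 * 0.5            # 0.25 in the [0, 1] image range
    assert int(expected * 255) == 63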
|||||||
@@ -109,11 +109,10 @@ class ComfyClient:
|
|||||||
continue #previews are binary data
|
continue #previews are binary data
|
||||||
|
|
||||||
history = self.get_history(prompt_id)[prompt_id]
|
history = self.get_history(prompt_id)[prompt_id]
|
||||||
for o in history['outputs']:
|
|
||||||
for node_id in history['outputs']:
|
for node_id in history['outputs']:
|
||||||
node_output = history['outputs'][node_id]
|
node_output = history['outputs'][node_id]
|
||||||
if 'images' in node_output:
|
|
||||||
images_output = []
|
images_output = []
|
||||||
|
if 'images' in node_output:
|
||||||
for image in node_output['images']:
|
for image in node_output['images']:
|
||||||
image_data = self.get_image(image['filename'], image['subfolder'], image['type'])
|
image_data = self.get_image(image['filename'], image['subfolder'], image['type'])
|
||||||
images_output.append(image_data)
|
images_output.append(image_data)
|
||||||
|
|||||||
@@ -95,6 +95,31 @@ class TestCustomIsChanged:
         else:
             return False

+class TestIsChangedWithConstants:
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "image": ("IMAGE",),
+                "value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0}),
+            },
+        }
+
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "custom_is_changed"
+
+    CATEGORY = "Testing/Nodes"
+
+    def custom_is_changed(self, image, value):
+        return (image * value,)
+
+    @classmethod
+    def IS_CHANGED(cls, image, value):
+        if image is None:
+            return value
+        else:
+            return image.mean().item() * value
+
 class TestCustomValidation1:
     @classmethod
     def INPUT_TYPES(cls):
@@ -312,6 +337,7 @@ TEST_NODE_CLASS_MAPPINGS = {
    "TestLazyMixImages": TestLazyMixImages,
    "TestVariadicAverage": TestVariadicAverage,
    "TestCustomIsChanged": TestCustomIsChanged,
+    "TestIsChangedWithConstants": TestIsChangedWithConstants,
    "TestCustomValidation1": TestCustomValidation1,
    "TestCustomValidation2": TestCustomValidation2,
    "TestCustomValidation3": TestCustomValidation3,
@@ -325,6 +351,7 @@ TEST_NODE_DISPLAY_NAME_MAPPINGS = {
    "TestLazyMixImages": "Lazy Mix Images",
    "TestVariadicAverage": "Variadic Average",
    "TestCustomIsChanged": "Custom IsChanged",
+    "TestIsChangedWithConstants": "IsChanged With Constants",
    "TestCustomValidation1": "Custom Validation 1",
    "TestCustomValidation2": "Custom Validation 2",
    "TestCustomValidation3": "Custom Validation 3",
@@ -28,6 +28,28 @@ class StubImage:
        elif content == "NOISE":
            return (torch.rand(batch_size, height, width, 3),)

+class StubConstantImage:
+    def __init__(self):
+        pass
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "value": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
+                "height": ("INT", {"default": 512, "min": 1, "max": 1024 ** 3, "step": 1}),
+                "width": ("INT", {"default": 512, "min": 1, "max": 4096 ** 3, "step": 1}),
+                "batch_size": ("INT", {"default": 1, "min": 1, "max": 1024 ** 3, "step": 1}),
+            },
+        }
+
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "stub_constant_image"
+
+    CATEGORY = "Testing/Stub Nodes"
+
+    def stub_constant_image(self, value, height, width, batch_size):
+        return (torch.ones(batch_size, height, width, 3) * value,)
+
class StubMask:
    def __init__(self):
        pass
@@ -93,12 +115,14 @@ class StubFloat:

TEST_STUB_NODE_CLASS_MAPPINGS = {
    "StubImage": StubImage,
+    "StubConstantImage": StubConstantImage,
    "StubMask": StubMask,
    "StubInt": StubInt,
    "StubFloat": StubFloat,
}
TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS = {
    "StubImage": "Stub Image",
+    "StubConstantImage": "Stub Constant Image",
    "StubMask": "Stub Mask",
    "StubInt": "Stub Int",
    "StubFloat": "Stub Float",
1208  web/assets/index-DkvOTKox.js → web/assets/index-BD-Ia1C4.js  generated  vendored
File diff suppressed because it is too large

1  web/assets/index-BD-Ia1C4.js.map  generated  vendored  Normal file
File diff suppressed because one or more lines are too long

84185  web/assets/index-CaD4RONs.js → web/assets/index-CI3N807S.js  generated  vendored
File diff suppressed because one or more lines are too long

1  web/assets/index-CI3N807S.js.map  generated  vendored  Normal file
File diff suppressed because one or more lines are too long

1  web/assets/index-CaD4RONs.js.map  generated  vendored
File diff suppressed because one or more lines are too long

1  web/assets/index-DkvOTKox.js.map  generated  vendored
File diff suppressed because one or more lines are too long

633  web/assets/index-DAK31IJJ.css → web/assets/index-_5czGnTA.css  generated  vendored
File diff suppressed because it is too large

120  web/assets/userSelection-CyXKCVy3.js  generated  vendored  Normal file
@@ -0,0 +1,120 @@
var __defProp = Object.defineProperty;
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
import { j as createSpinner, h as api, $ as $el } from "./index-CI3N807S.js";
class UserSelectionScreen {
  static {
    __name(this, "UserSelectionScreen");
  }
  async show(users, user) {
    const userSelection = document.getElementById("comfy-user-selection");
    userSelection.style.display = "";
    return new Promise((resolve) => {
      const input = userSelection.getElementsByTagName("input")[0];
      const select = userSelection.getElementsByTagName("select")[0];
      const inputSection = input.closest("section");
      const selectSection = select.closest("section");
      const form = userSelection.getElementsByTagName("form")[0];
      const error = userSelection.getElementsByClassName("comfy-user-error")[0];
      const button = userSelection.getElementsByClassName(
        "comfy-user-button-next"
      )[0];
      let inputActive = null;
      input.addEventListener("focus", () => {
        inputSection.classList.add("selected");
        selectSection.classList.remove("selected");
        inputActive = true;
      });
      select.addEventListener("focus", () => {
        inputSection.classList.remove("selected");
        selectSection.classList.add("selected");
        inputActive = false;
        select.style.color = "";
      });
      select.addEventListener("blur", () => {
        if (!select.value) {
          select.style.color = "var(--descrip-text)";
        }
      });
      form.addEventListener("submit", async (e) => {
        e.preventDefault();
        if (inputActive == null) {
          error.textContent = "Please enter a username or select an existing user.";
        } else if (inputActive) {
          const username = input.value.trim();
          if (!username) {
            error.textContent = "Please enter a username.";
            return;
          }
          input.disabled = select.disabled = // @ts-expect-error
          input.readonly = // @ts-expect-error
          select.readonly = true;
          const spinner = createSpinner();
          button.prepend(spinner);
          try {
            const resp = await api.createUser(username);
            if (resp.status >= 300) {
              let message = "Error creating user: " + resp.status + " " + resp.statusText;
              try {
                const res = await resp.json();
                if (res.error) {
                  message = res.error;
                }
              } catch (error2) {
              }
              throw new Error(message);
            }
            resolve({ username, userId: await resp.json(), created: true });
          } catch (err) {
            spinner.remove();
            error.textContent = err.message ?? err.statusText ?? err ?? "An unknown error occurred.";
            input.disabled = select.disabled = // @ts-expect-error
            input.readonly = // @ts-expect-error
            select.readonly = false;
            return;
          }
        } else if (!select.value) {
          error.textContent = "Please select an existing user.";
          return;
        } else {
          resolve({
            username: users[select.value],
            userId: select.value,
            created: false
          });
        }
      });
      if (user) {
        const name = localStorage["Comfy.userName"];
        if (name) {
          input.value = name;
        }
      }
      if (input.value) {
        input.focus();
      }
      const userIds = Object.keys(users ?? {});
      if (userIds.length) {
        for (const u of userIds) {
          $el("option", { textContent: users[u], value: u, parent: select });
        }
        select.style.color = "var(--descrip-text)";
        if (select.value) {
          select.focus();
        }
      } else {
        userSelection.classList.add("no-users");
        input.focus();
      }
    }).then((r) => {
      userSelection.remove();
      return r;
    });
  }
}
window.comfyAPI = window.comfyAPI || {};
window.comfyAPI.userSelection = window.comfyAPI.userSelection || {};
window.comfyAPI.userSelection.UserSelectionScreen = UserSelectionScreen;
export {
  UserSelectionScreen
};
//# sourceMappingURL=userSelection-CyXKCVy3.js.map
1  web/assets/userSelection-CyXKCVy3.js.map  generated  vendored  Normal file
File diff suppressed because one or more lines are too long

142  web/assets/userSelection-GRU1gtOt.js  generated  vendored
@@ -1,142 +0,0 @@
var __defProp = Object.defineProperty;
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
var __async = (__this, __arguments, generator) => {
  return new Promise((resolve, reject) => {
    var fulfilled = (value) => {
      try {
        step(generator.next(value));
      } catch (e) {
        reject(e);
      }
    };
    var rejected = (value) => {
      try {
        step(generator.throw(value));
      } catch (e) {
        reject(e);
      }
    };
    var step = (x) => x.done ? resolve(x.value) : Promise.resolve(x.value).then(fulfilled, rejected);
    step((generator = generator.apply(__this, __arguments)).next());
  });
};
import { j as createSpinner, g as api, $ as $el } from "./index-CaD4RONs.js";
const _UserSelectionScreen = class _UserSelectionScreen {
  show(users, user) {
    return __async(this, null, function* () {
      const userSelection = document.getElementById("comfy-user-selection");
      userSelection.style.display = "";
      return new Promise((resolve) => {
        const input = userSelection.getElementsByTagName("input")[0];
        const select = userSelection.getElementsByTagName("select")[0];
        const inputSection = input.closest("section");
        const selectSection = select.closest("section");
        const form = userSelection.getElementsByTagName("form")[0];
        const error = userSelection.getElementsByClassName("comfy-user-error")[0];
        const button = userSelection.getElementsByClassName(
          "comfy-user-button-next"
        )[0];
        let inputActive = null;
        input.addEventListener("focus", () => {
          inputSection.classList.add("selected");
          selectSection.classList.remove("selected");
          inputActive = true;
        });
        select.addEventListener("focus", () => {
          inputSection.classList.remove("selected");
          selectSection.classList.add("selected");
          inputActive = false;
          select.style.color = "";
        });
        select.addEventListener("blur", () => {
          if (!select.value) {
            select.style.color = "var(--descrip-text)";
          }
        });
        form.addEventListener("submit", (e) => __async(this, null, function* () {
          var _a, _b, _c;
          e.preventDefault();
          if (inputActive == null) {
            error.textContent = "Please enter a username or select an existing user.";
          } else if (inputActive) {
            const username = input.value.trim();
            if (!username) {
              error.textContent = "Please enter a username.";
              return;
            }
            input.disabled = select.disabled = // @ts-expect-error
            input.readonly = // @ts-expect-error
            select.readonly = true;
            const spinner = createSpinner();
            button.prepend(spinner);
            try {
              const resp = yield api.createUser(username);
              if (resp.status >= 300) {
                let message = "Error creating user: " + resp.status + " " + resp.statusText;
                try {
                  const res = yield resp.json();
                  if (res.error) {
                    message = res.error;
                  }
                } catch (error2) {
                }
                throw new Error(message);
              }
              resolve({ username, userId: yield resp.json(), created: true });
            } catch (err) {
              spinner.remove();
              error.textContent = (_c = (_b = (_a = err.message) != null ? _a : err.statusText) != null ? _b : err) != null ? _c : "An unknown error occurred.";
              input.disabled = select.disabled = // @ts-expect-error
              input.readonly = // @ts-expect-error
              select.readonly = false;
              return;
            }
          } else if (!select.value) {
            error.textContent = "Please select an existing user.";
            return;
          } else {
            resolve({
              username: users[select.value],
              userId: select.value,
              created: false
            });
          }
        }));
        if (user) {
          const name = localStorage["Comfy.userName"];
          if (name) {
            input.value = name;
          }
        }
        if (input.value) {
          input.focus();
        }
        const userIds = Object.keys(users != null ? users : {});
        if (userIds.length) {
          for (const u of userIds) {
            $el("option", { textContent: users[u], value: u, parent: select });
          }
          select.style.color = "var(--descrip-text)";
          if (select.value) {
            select.focus();
          }
        } else {
          userSelection.classList.add("no-users");
          input.focus();
        }
      }).then((r) => {
        userSelection.remove();
        return r;
      });
    });
  }
};
__name(_UserSelectionScreen, "UserSelectionScreen");
let UserSelectionScreen = _UserSelectionScreen;
window.comfyAPI = window.comfyAPI || {};
window.comfyAPI.userSelection = window.comfyAPI.userSelection || {};
window.comfyAPI.userSelection.UserSelectionScreen = UserSelectionScreen;
export {
  UserSelectionScreen
};
//# sourceMappingURL=userSelection-GRU1gtOt.js.map
1  web/assets/userSelection-GRU1gtOt.js.map  generated  vendored
File diff suppressed because one or more lines are too long

4  web/index.html  vendored
@@ -14,8 +14,8 @@
</style> -->



<link rel="stylesheet" type="text/css" href="user.css" />
<link rel="stylesheet" type="text/css" href="materialdesignicons.min.css" />
<script type="module" crossorigin src="./assets/index-CI3N807S.js"></script>
<link rel="stylesheet" crossorigin href="./assets/index-_5czGnTA.css">
2  web/materialdesignicons.min.css  vendored
File diff suppressed because one or more lines are too long