Compare commits


39 Commits

Author SHA1 Message Date
Yoland Y
1d47ec38d8 Set torch version to be 2.3.1 for v0.0.3 2024-07-26 18:54:29 -07:00
comfyanonymous
f87810cd3e Let tokenizers return weights to be stored in the saved checkpoint. 2024-07-25 10:52:09 -04:00
comfyanonymous
10c919f4c7 Make it possible to load tokenizer data from checkpoints. 2024-07-24 16:43:53 -04:00
comfyanonymous
ce80e69fb8 Avoid loading the dll when it's not necessary. 2024-07-24 13:50:34 -04:00
comfyanonymous
19944ad252 Add code to fix issues with new pytorch version on the standalone. 2024-07-24 12:49:29 -04:00
comfyanonymous
10b43ceea5 Remove duplicate code. 2024-07-24 01:12:59 -04:00
comfyanonymous
0a4c49c57c Support MT5. 2024-07-23 15:35:28 -04:00
comfyanonymous
88ed893034 Allow SPieceTokenizer to load model from a byte string. 2024-07-23 14:17:42 -04:00
comfyanonymous
334ba48cea More generic unet prefix detection code. 2024-07-23 14:13:32 -04:00
comfyanonymous
14764aa2e2 Rename LLAMATokenizer to SPieceTokenizer. 2024-07-22 12:21:45 -04:00
comfyanonymous
b2c995f623 "auto" type is only relevant to the SetUnionControlNetType node. 2024-07-22 11:30:38 -04:00
Chenlei Hu
4151fbfa8a Add error message on union controlnet (#4081) 2024-07-22 11:27:32 -04:00
Chenlei Hu
6045ed31f8 Supress frontend exception on unhandled message type (#4078)
* Supress frontend exception on unhandled message type

* nit
2024-07-21 21:15:01 -04:00
comfyanonymous
f836e69346 Fix bug with SaveAudio node with --gpu-only 2024-07-21 16:16:45 -04:00
Chenlei Hu
5b69cfe7c3 Add timestamp to execution messages (#4076)
* Add timestamp to execution messages

* Add execution_end message

* Rename to execution_success
2024-07-21 15:29:10 -04:00
comfyanonymous
95fa9545f1 Only append zero to noise schedule if last sigma isn't zero. 2024-07-20 12:37:30 -04:00
Greg Wainer
11b74147ee Fix/webp exif little endian (#4061)
* Fix for isLittleEndian flag in parseExifData.

* Add break after reading first exif chunk in getWebpMetadata.
2024-07-19 18:39:04 -04:00
comfyanonymous
6ab8cad22e Implement beta sampling scheduler.
It is based on: https://arxiv.org/abs/2407.12173

Add "beta" to the list of schedulers and the BetaSamplingScheduler node.
2024-07-19 18:05:09 -04:00
bymyself
011b11d8d7 LoadAudio restores file value from workflow (#4043)
* LoadAudio restores file value from workflow

* use onAfterGraphConfigured

* Don't use anonnymous function
2024-07-18 21:59:18 -04:00
comfyanonymous
ff6ca2a892 Move PAG to model_patches/unet section.
Move other unet model_patches nodes to model_patches/unet section.
2024-07-18 17:22:51 -04:00
bymyself
374e093e09 Disable audio widget trying to get previews (#4044) 2024-07-17 16:11:10 -04:00
喵哩个咪
855789403b support clip-vit-large-patch14-336 (#4042)
* support clip-vit-large-patch14-336

* support clip-vit-large-patch14-336
2024-07-17 13:12:50 -04:00
comfyanonymous
6f7869f365 Get clip vision image size from config. 2024-07-17 13:05:38 -04:00
comfyanonymous
281ad42df4 Fix lowvram union controlnet bug. 2024-07-17 10:16:31 -04:00
Chenlei Hu
1cde6b2eff Disallow use of eval with pylint (#4033) 2024-07-16 21:15:08 -04:00
Thomas Ward
c5a48b15bd Make default hash lib configurable without code changes via CLI argument (#3947)
* cli_args: Add --duplicate-check-hash-function.

* server.py: compare_image_hash configurable hash function

Uses an argument added in cli_args to specify the type of hashing to default to for duplicate hash checking.  Uses an `eval()` to identify the specific hashlib class to utilize, but ultimately safely operates because we have specific options and only those options/choices in the arg parser.  So we don't have any unsafe input there.

* Add hasher() to node_helpers

* hashlib selection moved to node_helpers

* default-hashing-function instead of dupe checking hasher

This makes a default-hashing-function option instead of previous selected option.

* Use args.default_hashing_function

* Use safer handling for node_helpers.hasher()

Uses a safer handling method than `eval` to evaluate default hashing function.

* Stray parentheses are evil.

* Indentation fix.

Somehow when I hit save I didn't notice I missed a space to make indentation work proper.  Oops!
2024-07-16 18:27:09 -04:00
Chenlei Hu
f2298799ba Fix annotation (#4035) 2024-07-16 18:20:39 -04:00
comfyanonymous
60383f3b64 Move controlnet nodes to conditioning/controlnet. 2024-07-16 17:08:25 -04:00
comfyanonymous
8270c62530 Add SetUnionControlNetType to set the type of the union controlnet model. 2024-07-16 17:04:53 -04:00
comfyanonymous
821f93872e Allow model sampling to set number of timesteps. 2024-07-16 15:18:40 -04:00
comfyanonymous
e1630391d6 Allow version names like v0.0.1 for the FrontendManager. 2024-07-16 11:29:38 -04:00
Chenlei Hu
99458e8aca Add FrontendManager to manage non-default front-end impl (#3897)
* Add frontend manager

* Add tests

* nit

* Add unit test to github CI

* Fix path

* nit

* ignore

* Add logging

* Install test deps

* Remove 'stable' keyword support

* Update test

* Add web-root arg

* Rename web-root to front-end-root

* Add test on non-exist version number

* Use repo owner/name to replace hard coded provider list

* Inline cmd args

* nit

* Fix unit test
2024-07-16 11:26:11 -04:00
comfyanonymous
33346fd9b8 Fix bug with custom nodes on other drives. 2024-07-15 20:38:26 -04:00
comfyanonymous
136c93cb47 Fix bug with workflow not registering change.
There was an issue when only the class type of a node changed with all the
inputs staying the same.
2024-07-15 20:01:49 -04:00
comfyanonymous
1305fb294c Refactor: Move some code to the comfy/text_encoders folder. 2024-07-15 17:36:24 -04:00
comfyanonymous
7914c47d5a Quick fix for the promax controlnet. 2024-07-14 10:07:36 -04:00
pythongosssss
79547efb65 New menu fixes - fix send to workflow (#3909)
* Fix send to workflow
Fix center align of close workflow dialog
Better support for elements around canvas

* More resilent to extra elements added to body
2024-07-14 02:04:40 -04:00
comfyanonymous
a3dffc447a Support AuraFlow Lora and loading model weights in diffusers format.
You can load model weights in diffusers format using the UNETLoader node.
2024-07-13 13:51:40 -04:00
comfyanonymous
ce2473bb01 Add link to AuraFlow example in Readme. 2024-07-12 15:25:07 -04:00
59 changed files with 883 additions and 213 deletions

.github/workflows/pylint.yml (new file)

name: Python Linting

on: [push, pull_request]

jobs:
  pylint:
    name: Run Pylint
    runs-on: ubuntu-latest

    steps:
    - name: Checkout repository
      uses: actions/checkout@v4

    - name: Set up Python
      uses: actions/setup-python@v2
      with:
        python-version: 3.x

    - name: Install Pylint
      run: pip install pylint

    - name: Run Pylint
      run: pylint --rcfile=.pylintrc $(find . -type f -name "*.py")

.github/workflows/test-ui.yaml

@@ -24,3 +24,7 @@ jobs:
           npm run test:generate
           npm test -- --verbose
         working-directory: ./tests-ui
+      - name: Run Unit Tests
+        run: |
+          pip install -r tests-unit/requirements.txt
+          python -m pytest tests-unit

.gitignore

@@ -18,3 +18,4 @@ venv/
 /tests-ui/data/object_info.json
 /user/
 *.log
+web_custom_versions/

.pylintrc (new file)

[MESSAGES CONTROL]
disable=all
enable=eval-used
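
This config disables every Pylint check and re-enables only eval-used (W0123), so the new CI job fails precisely on calls to the built-in eval. A minimal illustration of what that check flags:

# pylint --rcfile=.pylintrc example.py
x = eval("1 + 1")   # flagged: W0123 (eval-used)
x = int("2")        # fine: no eval involved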

README.md

@@ -32,6 +32,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
 - [Model Merging](https://comfyanonymous.github.io/ComfyUI_examples/model_merging/)
 - [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/)
 - [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/)
+- [AuraFlow](https://comfyanonymous.github.io/ComfyUI_examples/aura_flow/)
 - Latent previews with [TAESD](#how-to-show-high-quality-previews)
 - Starts up very fast.
 - Works fully offline: will never download anything.

app/__init__.py (new, empty file)

app/frontend_management.py (new file)

from __future__ import annotations
import argparse
import logging
import os
import re
import tempfile
import zipfile
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
from typing import TypedDict

import requests
from typing_extensions import NotRequired

from comfy.cli_args import DEFAULT_VERSION_STRING

REQUEST_TIMEOUT = 10  # seconds


class Asset(TypedDict):
    url: str


class Release(TypedDict):
    id: int
    tag_name: str
    name: str
    prerelease: bool
    created_at: str
    published_at: str
    body: str
    assets: NotRequired[list[Asset]]


@dataclass
class FrontEndProvider:
    owner: str
    repo: str

    @property
    def folder_name(self) -> str:
        return f"{self.owner}_{self.repo}"

    @property
    def release_url(self) -> str:
        return f"https://api.github.com/repos/{self.owner}/{self.repo}/releases"

    @cached_property
    def all_releases(self) -> list[Release]:
        releases = []
        api_url = self.release_url
        while api_url:
            response = requests.get(api_url, timeout=REQUEST_TIMEOUT)
            response.raise_for_status()  # Raises an HTTPError if the response was an error
            releases.extend(response.json())
            # GitHub uses the Link header to provide pagination links. Check if it exists and update api_url accordingly.
            if "next" in response.links:
                api_url = response.links["next"]["url"]
            else:
                api_url = None
        return releases

    @cached_property
    def latest_release(self) -> Release:
        latest_release_url = f"{self.release_url}/latest"
        response = requests.get(latest_release_url, timeout=REQUEST_TIMEOUT)
        response.raise_for_status()  # Raises an HTTPError if the response was an error
        return response.json()

    def get_release(self, version: str) -> Release:
        if version == "latest":
            return self.latest_release
        else:
            for release in self.all_releases:
                if release["tag_name"] in [version, f"v{version}"]:
                    return release
            raise ValueError(f"Version {version} not found in releases")


def download_release_asset_zip(release: Release, destination_path: str) -> None:
    """Download dist.zip from github release."""
    asset_url = None
    for asset in release.get("assets", []):
        if asset["name"] == "dist.zip":
            asset_url = asset["url"]
            break
    if not asset_url:
        raise ValueError("dist.zip not found in the release assets")

    # Use a temporary file to download the zip content
    with tempfile.TemporaryFile() as tmp_file:
        headers = {"Accept": "application/octet-stream"}
        response = requests.get(
            asset_url, headers=headers, allow_redirects=True, timeout=REQUEST_TIMEOUT
        )
        response.raise_for_status()  # Ensure we got a successful response

        # Write the content to the temporary file
        tmp_file.write(response.content)

        # Go back to the beginning of the temporary file
        tmp_file.seek(0)

        # Extract the zip file content to the destination path
        with zipfile.ZipFile(tmp_file, "r") as zip_ref:
            zip_ref.extractall(destination_path)


class FrontendManager:
    DEFAULT_FRONTEND_PATH = str(Path(__file__).parents[1] / "web")
    CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")

    @classmethod
    def parse_version_string(cls, value: str) -> tuple[str, str, str]:
        """
        Args:
            value (str): The version string to parse.

        Returns:
            tuple[str, str, str]: A tuple containing the repo owner, repo name, and version.

        Raises:
            argparse.ArgumentTypeError: If the version string is invalid.
        """
        VERSION_PATTERN = r"^([a-zA-Z0-9][a-zA-Z0-9-]{0,38})/([a-zA-Z0-9_.-]+)@(v?\d+\.\d+\.\d+|latest)$"
        match_result = re.match(VERSION_PATTERN, value)
        if match_result is None:
            raise argparse.ArgumentTypeError(f"Invalid version string: {value}")

        return match_result.group(1), match_result.group(2), match_result.group(3)

    @classmethod
    def init_frontend_unsafe(cls, version_string: str) -> str:
        """
        Initializes the frontend for the specified version.

        Args:
            version_string (str): The version string.

        Returns:
            str: The path to the initialized frontend.

        Raises:
            Exception: If there is an error during the initialization process.
            The main error sources are request timeouts and invalid URLs.
        """
        if version_string == DEFAULT_VERSION_STRING:
            return cls.DEFAULT_FRONTEND_PATH

        repo_owner, repo_name, version = cls.parse_version_string(version_string)
        provider = FrontEndProvider(repo_owner, repo_name)
        release = provider.get_release(version)

        semantic_version = release["tag_name"].lstrip("v")
        web_root = str(
            Path(cls.CUSTOM_FRONTENDS_ROOT) / provider.folder_name / semantic_version
        )
        if not os.path.exists(web_root):
            os.makedirs(web_root, exist_ok=True)
            logging.info(
                "Downloading frontend(%s) version(%s) to (%s)",
                provider.folder_name,
                semantic_version,
                web_root,
            )
            logging.debug(release)
            download_release_asset_zip(release, destination_path=web_root)
        return web_root

    @classmethod
    def init_frontend(cls, version_string: str) -> str:
        """
        Initializes the frontend with the specified version string.

        Args:
            version_string (str): The version string to initialize the frontend with.

        Returns:
            str: The path of the initialized frontend.
        """
        try:
            return cls.init_frontend_unsafe(version_string)
        except Exception as e:
            logging.error("Failed to initialize frontend: %s", e)
            logging.info("Falling back to the default frontend.")
            return cls.DEFAULT_FRONTEND_PATH
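
A hedged usage sketch of the class above; the repo name is illustrative and the server-side wiring is not part of this diff:

from app.frontend_management import FrontendManager

# version strings follow [repoOwner]/[repoName]@[version]
owner, repo, version = FrontendManager.parse_version_string("Comfy-Org/ComfyUI_frontend@latest")

# Downloads dist.zip from the matching GitHub release on first use,
# caches it under web_custom_versions/, and falls back to the bundled
# web/ directory if anything goes wrong.
web_root = FrontendManager.init_frontend("Comfy-Org/ComfyUI_frontend@latest")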

comfy/cldm/cldm.py

@@ -13,6 +13,7 @@ from ..ldm.modules.diffusionmodules.util import (
 from ..ldm.modules.attention import SpatialTransformer
 from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample
 from ..ldm.util import exists
+from .control_types import UNION_CONTROLNET_TYPES
 from collections import OrderedDict
 import comfy.ops
 from comfy.ldm.modules.attention import optimized_attention

@@ -92,7 +93,7 @@ class ControlNet(nn.Module):
         transformer_depth_middle=None,
         transformer_depth_output=None,
         attn_precision=None,
-        union_controlnet=False,
+        union_controlnet_num_control_type=None,
         device=None,
         operations=comfy.ops.disable_weight_init,
         **kwargs,

@@ -320,8 +321,8 @@ class ControlNet(nn.Module):
         self.middle_block_out = self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device)
         self._feature_size += ch

-        if union_controlnet:
-            self.num_control_type = 6
+        if union_controlnet_num_control_type is not None:
+            self.num_control_type = union_controlnet_num_control_type
             num_trans_channel = 320
             num_trans_head = 8
             num_trans_layer = 1

@@ -361,7 +362,7 @@ class ControlNet(nn.Module):
                 controlnet_cond = self.input_hint_block(hint[idx], emb, context)
                 feat_seq = torch.mean(controlnet_cond, dim=(2, 3))
                 if idx < len(control_type):
-                    feat_seq += self.task_embedding[control_type[idx]]
+                    feat_seq += self.task_embedding[control_type[idx]].to(dtype=feat_seq.dtype, device=feat_seq.device)

                 inputs.append(feat_seq.unsqueeze(1))
                 condition_list.append(controlnet_cond)

@@ -390,6 +391,18 @@ class ControlNet(nn.Module):
         if self.control_add_embedding is not None: #Union Controlnet
             control_type = kwargs.get("control_type", [])

+            if any([c >= self.num_control_type for c in control_type]):
+                max_type = max(control_type)
+                max_type_name = {
+                    v: k for k, v in UNION_CONTROLNET_TYPES.items()
+                }[max_type]
+                raise ValueError(
+                    f"Control type {max_type_name}({max_type}) is out of range for the number of control types" +
+                    f"({self.num_control_type}) supported.\n" +
+                    "Please consider using the ProMax ControlNet Union model.\n" +
+                    "https://huggingface.co/xinsir/controlnet-union-sdxl-1.0/tree/main"
+                )
+
             emb += self.control_add_embedding(control_type, emb.dtype, emb.device)
             if len(control_type) > 0:
                 if len(hint.shape) < 5:

comfy/cldm/control_types.py (new file)

UNION_CONTROLNET_TYPES = {
    "openpose": 0,
    "depth": 1,
    "hed/pidi/scribble/ted": 2,
    "canny/lineart/anime_lineart/mlsd": 3,
    "normal": 4,
    "segment": 5,
    "tile": 6,
    "repaint": 7,
}

comfy/cli_args.py

@@ -1,7 +1,10 @@
 import argparse
 import enum
+import os
+from typing import Optional
 import comfy.options


 class EnumAction(argparse.Action):
     """
     Argparse action for handling Enums

@@ -109,6 +112,7 @@ vram_group.add_argument("--lowvram", action="store_true", help="Split the unet i
 vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
 vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")

+parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
 parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
 parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")

@@ -124,6 +128,38 @@ parser.add_argument("--multi-user", action="store_true", help="Enables per-user
 parser.add_argument("--verbose", action="store_true", help="Enables more debug prints.")

+# The default built-in provider hosted under web/
+DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
+
+parser.add_argument(
+    "--front-end-version",
+    type=str,
+    default=DEFAULT_VERSION_STRING,
+    help="""
+    Specifies the version of the frontend to be used. This command needs internet connectivity to query and
+    download available frontend implementations from GitHub releases.
+
+    The version string should be in the format of:
+    [repoOwner]/[repoName]@[version]
+    where version is one of: "latest" or a valid version number (e.g. "1.0.0")
+    """,
+)
+
+def is_valid_directory(path: Optional[str]) -> Optional[str]:
+    """Validate if the given path is a directory."""
+    if path is None:
+        return None
+    if not os.path.isdir(path):
+        raise argparse.ArgumentTypeError(f"{path} is not a valid directory.")
+    return path
+
+parser.add_argument(
+    "--front-end-root",
+    type=is_valid_directory,
+    default=None,
+    help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.",
+)
+
 if comfy.options.args_parsing:
     args = parser.parse_args()
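
Commit c5a48b15bd above routes this flag through node_helpers.hasher(); that file's diff is not shown in this compare, but the "safer handling" it describes amounts to replacing an eval() lookup with an explicit table. A sketch of that pattern, assuming the usual args object from comfy.cli_args:

import hashlib
from comfy.cli_args import args

def hasher():
    # Explicit whitelist lookup instead of eval() on the argument value;
    # the choices list in the parser already restricts it to these four names.
    hashfuncs = {
        "md5": hashlib.md5,
        "sha1": hashlib.sha1,
        "sha256": hashlib.sha256,
        "sha512": hashlib.sha512,
    }
    return hashfuncs[args.default_hashing_function]

# usage for duplicate checking:
# digest = hasher()(file_bytes).hexdigest()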

comfy/clip_vision.py

@@ -34,6 +34,7 @@ class ClipVisionModel():
         with open(json_config) as f:
             config = json.load(f)

+        self.image_size = config.get("image_size", 224)
         self.load_device = comfy.model_management.text_encoder_device()
         offload_device = comfy.model_management.text_encoder_offload_device()
         self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)

@@ -50,7 +51,7 @@ class ClipVisionModel():
     def encode_image(self, image):
         comfy.model_management.load_model_gpu(self.patcher)
-        pixel_values = clip_preprocess(image.to(self.load_device)).float()
+        pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size).float()
         out = self.model(pixel_values=pixel_values, intermediate_output=-2)

         outputs = Output()

@@ -93,7 +94,10 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
     elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
         json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
     elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
-        json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
+        if sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
+            json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
+        else:
+            json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
     else:
         return None
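
Why 577: a ViT-L/14 at 336x336 resolution yields (336 / 14)^2 = 576 patch tokens plus one class token, so its position-embedding table has 577 rows, while the 224px variant has (224 / 14)^2 + 1 = 257. A quick sanity check of that arithmetic:

def num_position_embeddings(image_size, patch_size=14):
    # patch grid plus one class token
    return (image_size // patch_size) ** 2 + 1

assert num_position_embeddings(336) == 577  # clip-vit-large-patch14-336
assert num_position_embeddings(224) == 257  # clip-vit-large-patch14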

comfy/clip_vision_config_vitl_336.json (new file)

{
    "attention_dropout": 0.0,
    "dropout": 0.0,
    "hidden_act": "quick_gelu",
    "hidden_size": 1024,
    "image_size": 336,
    "initializer_factor": 1.0,
    "initializer_range": 0.02,
    "intermediate_size": 4096,
    "layer_norm_eps": 1e-5,
    "model_type": "clip_vision_model",
    "num_attention_heads": 16,
    "num_channels": 3,
    "num_hidden_layers": 24,
    "patch_size": 14,
    "projection_dim": 768,
    "torch_dtype": "float32"
}

comfy/controlnet.py

@@ -45,6 +45,7 @@ class ControlBase:
         self.timestep_range = None
         self.compression_ratio = 8
         self.upscale_algorithm = 'nearest-exact'
+        self.extra_args = {}

         if device is None:
             device = comfy.model_management.get_torch_device()

@@ -90,6 +91,7 @@ class ControlBase:
         c.compression_ratio = self.compression_ratio
         c.upscale_algorithm = self.upscale_algorithm
         c.latent_format = self.latent_format
+        c.extra_args = self.extra_args.copy()
         c.vae = self.vae

     def inference_memory_requirements(self, dtype):

@@ -135,6 +137,10 @@ class ControlBase:
                     o[i] = prev_val + o[i] #TODO: change back to inplace add if shared tensors stop being an issue
         return out

+    def set_extra_arg(self, argument, value=None):
+        self.extra_args[argument] = value
+
 class ControlNet(ControlBase):
     def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None):
         super().__init__(device)

@@ -191,7 +197,7 @@ class ControlNet(ControlBase):
         timestep = self.model_sampling_current.timestep(t)
         x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)

-        control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
+        control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y, **self.extra_args)
         return self.control_merge(control, control_prev, output_dtype)

     def copy(self):

@@ -414,7 +420,7 @@ def load_controlnet(ckpt_path, model=None):
             new_sd[diffusers_keys[k]] = controlnet_data.pop(k)

         if "control_add_embedding.linear_1.bias" in controlnet_data: #Union Controlnet
-            controlnet_config["union_controlnet"] = True
+            controlnet_config["union_controlnet_num_control_type"] = controlnet_data["task_embedding"].shape[0]
         for k in list(controlnet_data.keys()):
             new_k = k.replace('.attn.in_proj_', '.attn.in_proj.')
             new_sd[new_k] = controlnet_data.pop(k)
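
The extra_args dict is forwarded verbatim as keyword arguments into the ControlNet forward call, which is how per-instance options like the union control type reach the model. A hedged sketch of how a caller (for example the SetUnionControlNetType node, whose code is not shown in this compare) might use it:

from comfy.cldm.control_types import UNION_CONTROLNET_TYPES

# Sketch only: assumes `control_net` is a loaded union controlnet with
# the ControlBase interface shown above.
c = control_net.copy()
c.set_extra_arg("control_type", [UNION_CONTROLNET_TYPES["depth"]])
# c.extra_args is now passed as **kwargs into control_model(...),
# where cldm.py reads kwargs.get("control_type", []).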

comfy/ldm/modules/diffusionmodules/mmdit.py

@@ -7,6 +7,7 @@ import torch
 import torch.nn as nn

 from .. import attention
 from einops import rearrange, repeat
+from .util import timestep_embedding

 def default(x, y):
     if x is not None:

@@ -230,34 +231,8 @@ class TimestepEmbedder(nn.Module):
         )
         self.frequency_embedding_size = frequency_embedding_size

-    @staticmethod
-    def timestep_embedding(t, dim, max_period=10000):
-        """
-        Create sinusoidal timestep embeddings.
-        :param t: a 1-D Tensor of N indices, one per batch element.
-                  These may be fractional.
-        :param dim: the dimension of the output.
-        :param max_period: controls the minimum frequency of the embeddings.
-        :return: an (N, D) Tensor of positional embeddings.
-        """
-        half = dim // 2
-        freqs = torch.exp(
-            -math.log(max_period)
-            * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
-            / half
-        )
-        args = t[:, None].float() * freqs[None]
-        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
-        if dim % 2:
-            embedding = torch.cat(
-                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
-            )
-        if torch.is_floating_point(t):
-            embedding = embedding.to(dtype=t.dtype)
-        return embedding
-
     def forward(self, t, dtype, **kwargs):
-        t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
+        t_freq = timestep_embedding(t, self.frequency_embedding_size).to(dtype)
         t_emb = self.mlp(t_freq)
         return t_emb

comfy/lora.py

@@ -274,4 +274,12 @@ def model_lora_keys_unet(model, key_map={}):
                 key_lora = "lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_")) #OneTrainer lora
                 key_map[key_lora] = to

+    if isinstance(model, comfy.model_base.AuraFlow): #Diffusers lora AuraFlow
+        diffusers_keys = comfy.utils.auraflow_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
+        for k in diffusers_keys:
+            if k.endswith(".weight"):
+                to = diffusers_keys[k]
+                key_lora = "transformer.{}".format(k[:-len(".weight")]) #simpletrainer and probably regular diffusers lora format
+                key_map[key_lora] = to
+
     return key_map

comfy/model_detection.py

@@ -109,6 +109,10 @@ def detect_unet_config(state_dict, key_prefix):
         unet_config = {}
         unet_config["max_seq"] = state_dict['{}positional_encoding'.format(key_prefix)].shape[1]
         unet_config["cond_seq_dim"] = state_dict['{}cond_seq_linear.weight'.format(key_prefix)].shape[1]
+        double_layers = count_blocks(state_dict_keys, '{}double_layers.'.format(key_prefix) + '{}.')
+        single_layers = count_blocks(state_dict_keys, '{}single_layers.'.format(key_prefix) + '{}.')
+        unet_config["n_double_layers"] = double_layers
+        unet_config["n_layers"] = double_layers + single_layers
         return unet_config

     if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:

@@ -257,13 +261,22 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
     return model_config

 def unet_prefix_from_state_dict(state_dict):
-    if "model.model.postprocess_conv.weight" in state_dict: #audio models
-        unet_key_prefix = "model.model."
-    elif "model.double_layers.0.attn.w1q.weight" in state_dict: #aura flow
-        unet_key_prefix = "model."
+    candidates = ["model.diffusion_model.", #ldm/sgm models
+                  "model.model.", #audio models
+                  ]
+    counts = {k: 0 for k in candidates}
+    for k in state_dict:
+        for c in candidates:
+            if k.startswith(c):
+                counts[c] += 1
+                break
+
+    top = max(counts, key=counts.get)
+    if counts[top] > 5:
+        return top
     else:
-        unet_key_prefix = "model.diffusion_model."
-    return unet_key_prefix
+        return "model." #aura flow and others

 def convert_config(unet_config):
     new_config = unet_config.copy()

@@ -450,37 +463,45 @@
     return None

 def convert_diffusers_mmdit(state_dict, output_prefix=""):
-    num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
-    if num_blocks > 0:
+    out_sd = {}
+
+    if 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict: #SD3
+        num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
         depth = state_dict["pos_embed.proj.weight"].shape[0] // 64
-        out_sd = {}
         sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix)
-        for k in sd_map:
-            weight = state_dict.get(k, None)
-            if weight is not None:
-                t = sd_map[k]
-                if not isinstance(t, str):
-                    if len(t) > 2:
-                        fun = t[2]
-                    else:
-                        fun = lambda a: a
-                    offset = t[1]
-                    if offset is not None:
-                        old_weight = out_sd.get(t[0], None)
-                        if old_weight is None:
-                            old_weight = torch.empty_like(weight)
-                            old_weight = old_weight.repeat([3] + [1] * (len(old_weight.shape) - 1))
-                        w = old_weight.narrow(offset[0], offset[1], offset[2])
-                    else:
-                        old_weight = weight
-                        w = weight
-                    w[:] = fun(weight)
-                    t = t[0]
-                    out_sd[t] = old_weight
-                else:
-                    out_sd[t] = weight
-                state_dict.pop(k)
+    elif 'joint_transformer_blocks.0.attn.add_k_proj.weight' in state_dict: #AuraFlow
+        num_joint = count_blocks(state_dict, 'joint_transformer_blocks.{}.')
+        num_single = count_blocks(state_dict, 'single_transformer_blocks.{}.')
+        sd_map = comfy.utils.auraflow_to_diffusers({"n_double_layers": num_joint, "n_layers": num_joint + num_single}, output_prefix=output_prefix)
+    else:
+        return None
+
+    for k in sd_map:
+        weight = state_dict.get(k, None)
+        if weight is not None:
+            t = sd_map[k]
+
+            if not isinstance(t, str):
+                if len(t) > 2:
+                    fun = t[2]
+                else:
+                    fun = lambda a: a
+                offset = t[1]
+                if offset is not None:
+                    old_weight = out_sd.get(t[0], None)
+                    if old_weight is None:
+                        old_weight = torch.empty_like(weight)
+                        old_weight = old_weight.repeat([3] + [1] * (len(old_weight.shape) - 1))
+
+                    w = old_weight.narrow(offset[0], offset[1], offset[2])
+                else:
+                    old_weight = weight
+                    w = weight
+                w[:] = fun(weight)
+                t = t[0]
+                out_sd[t] = old_weight
+            else:
+                out_sd[t] = weight
+            state_dict.pop(k)

     return out_sd

comfy/model_sampling.py

@@ -59,8 +59,9 @@ class ModelSamplingDiscrete(torch.nn.Module):
         beta_schedule = sampling_settings.get("beta_schedule", "linear")
         linear_start = sampling_settings.get("linear_start", 0.00085)
         linear_end = sampling_settings.get("linear_end", 0.012)
+        timesteps = sampling_settings.get("timesteps", 1000)

-        self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=1000, linear_start=linear_start, linear_end=linear_end, cosine_s=8e-3)
+        self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=8e-3)
         self.sigma_data = 1.0

     def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,

comfy/samplers.py

@@ -6,6 +6,8 @@ from comfy import model_management
 import math
 import logging
 import comfy.sampler_helpers
+import scipy
+import numpy

 def get_area_and_mult(conds, x_in, timestep_in):
     dims = tuple(x_in.shape[2:])

@@ -311,13 +313,18 @@ def simple_scheduler(model_sampling, steps):
 def ddim_scheduler(model_sampling, steps):
     s = model_sampling
     sigs = []
-    ss = max(len(s.sigmas) // steps, 1)
     x = 1
+    if math.isclose(float(s.sigmas[x]), 0, abs_tol=0.00001):
+        steps += 1
+        sigs = []
+    else:
+        sigs = [0.0]
+
+    ss = max(len(s.sigmas) // steps, 1)
     while x < len(s.sigmas):
         sigs += [float(s.sigmas[x])]
         x += ss
     sigs = sigs[::-1]
-    sigs += [0.0]
     return torch.FloatTensor(sigs)

@@ -325,15 +332,34 @@ def normal_scheduler(model_sampling, steps, sgm=False, floor=False):
     start = s.timestep(s.sigma_max)
     end = s.timestep(s.sigma_min)

+    append_zero = True
     if sgm:
         timesteps = torch.linspace(start, end, steps + 1)[:-1]
     else:
+        if math.isclose(float(s.sigma(end)), 0, abs_tol=0.00001):
+            steps += 1
+            append_zero = False
         timesteps = torch.linspace(start, end, steps)

     sigs = []
     for x in range(len(timesteps)):
         ts = timesteps[x]
-        sigs.append(s.sigma(ts))
-    sigs += [0.0]
+        sigs.append(float(s.sigma(ts)))
+
+    if append_zero:
+        sigs += [0.0]
+
     return torch.FloatTensor(sigs)
+
+# Implemented based on: https://arxiv.org/abs/2407.12173
+def beta_scheduler(model_sampling, steps, alpha=0.6, beta=0.6):
+    total_timesteps = (len(model_sampling.sigmas) - 1)
+    ts = 1 - numpy.linspace(0, 1, steps, endpoint=False)
+    ts = numpy.rint(scipy.stats.beta.ppf(ts, alpha, beta) * total_timesteps)
+
+    sigs = []
+    for t in ts:
+        sigs += [float(model_sampling.sigmas[int(t)])]
+    sigs += [0.0]
+    return torch.FloatTensor(sigs)

@@ -703,7 +729,7 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
     return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)

-SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"]
+SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "beta"]
 SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]

@@ -719,6 +745,8 @@ def calculate_sigmas(model_sampling, scheduler_name, steps):
         sigmas = ddim_scheduler(model_sampling, steps)
     elif scheduler_name == "sgm_uniform":
         sigmas = normal_scheduler(model_sampling, steps, sgm=True)
+    elif scheduler_name == "beta":
+        sigmas = beta_scheduler(model_sampling, steps)
     else:
         logging.error("error invalid scheduler {}".format(scheduler_name))
     return sigmas

comfy/sd.py

@@ -19,8 +19,8 @@ from . import model_detection
 from . import sd1_clip
 from . import sd2_clip
 from . import sdxl_clip
-from . import sd3_clip
-from . import sa_t5
+import comfy.text_encoders.sd3_clip
+import comfy.text_encoders.sa_t5
 import comfy.text_encoders.aura_t5

 import comfy.model_patcher

@@ -60,7 +60,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):

 class CLIP:
-    def __init__(self, target=None, embedding_directory=None, no_init=False):
+    def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}):
         if no_init:
             return
         params = target.params.copy()

@@ -79,7 +79,7 @@ class CLIP:
             if not model_management.supports_cast(load_device, dt):
                 load_device = offload_device

-        self.tokenizer = tokenizer(embedding_directory=embedding_directory)
+        self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
         self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
         self.layer_idx = None
         logging.debug("CLIP model load device: {}, offload device: {}".format(load_device, offload_device))

@@ -135,7 +135,11 @@ class CLIP:
         return self.cond_stage_model.load_sd(sd)

     def get_sd(self):
-        return self.cond_stage_model.state_dict()
+        sd_clip = self.cond_stage_model.state_dict()
+        sd_tokenizer = self.tokenizer.state_dict()
+        for k in sd_tokenizer:
+            sd_clip[k] = sd_tokenizer[k]
+        return sd_clip

     def load_model(self):
         model_management.load_model_gpu(self.patcher)

@@ -414,27 +418,27 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
             weight = clip_data[0]["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
             dtype_t5 = weight.dtype
             if weight.shape[-1] == 4096:
-                clip_target.clip = sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=dtype_t5)
-                clip_target.tokenizer = sd3_clip.SD3Tokenizer
+                clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=dtype_t5)
+                clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
             elif weight.shape[-1] == 2048:
                 clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model
                 clip_target.tokenizer = comfy.text_encoders.aura_t5.AuraT5Tokenizer
         elif "encoder.block.0.layer.0.SelfAttention.k.weight" in clip_data[0]:
-            clip_target.clip = sa_t5.SAT5Model
-            clip_target.tokenizer = sa_t5.SAT5Tokenizer
+            clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
+            clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
         else:
             clip_target.clip = sd1_clip.SD1ClipModel
             clip_target.tokenizer = sd1_clip.SD1Tokenizer
     elif len(clip_data) == 2:
         if clip_type == CLIPType.SD3:
-            clip_target.clip = sd3_clip.sd3_clip(clip_l=True, clip_g=True, t5=False)
-            clip_target.tokenizer = sd3_clip.SD3Tokenizer
+            clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=True, t5=False)
+            clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
         else:
             clip_target.clip = sdxl_clip.SDXLClipModel
             clip_target.tokenizer = sdxl_clip.SDXLTokenizer
     elif len(clip_data) == 3:
-        clip_target.clip = sd3_clip.SD3ClipModel
-        clip_target.tokenizer = sd3_clip.SD3Tokenizer
+        clip_target.clip = comfy.text_encoders.sd3_clip.SD3ClipModel
+        clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer

     clip = CLIP(clip_target, embedding_directory=embedding_directory)
     for c in clip_data:

@@ -520,7 +524,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
         if clip_target is not None:
             clip_sd = model_config.process_clip_state_dict(sd)
             if len(clip_sd) > 0:
-                clip = CLIP(clip_target, embedding_directory=embedding_directory)
+                clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd)
                 m, u = clip.load_sd(clip_sd, full_model=True)
                 if len(m) > 0:
                     m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))

@@ -562,26 +566,25 @@ def load_unet_state_dict(sd): #load unet in diffusers or regular format
     if model_config is not None:
         new_sd = sd
-    elif 'transformer_blocks.0.attn.add_q_proj.weight' in sd: #MMDIT SD3
+    else:
         new_sd = model_detection.convert_diffusers_mmdit(sd, "")
-        if new_sd is None:
-            return None
-        model_config = model_detection.model_config_from_unet(new_sd, "")
-        if model_config is None:
-            return None
-    else: #diffusers
-        model_config = model_detection.model_config_from_diffusers_unet(sd)
-        if model_config is None:
-            return None
+        if new_sd is not None: #diffusers mmdit
+            model_config = model_detection.model_config_from_unet(new_sd, "")
+            if model_config is None:
+                return None
+        else: #diffusers unet
+            model_config = model_detection.model_config_from_diffusers_unet(sd)
+            if model_config is None:
+                return None

-        diffusers_keys = comfy.utils.unet_to_diffusers(model_config.unet_config)
+            diffusers_keys = comfy.utils.unet_to_diffusers(model_config.unet_config)

-        new_sd = {}
-        for k in diffusers_keys:
-            if k in sd:
-                new_sd[diffusers_keys[k]] = sd.pop(k)
-            else:
-                logging.warning("{} {}".format(diffusers_keys[k], k))
+            new_sd = {}
+            for k in diffusers_keys:
+                if k in sd:
+                    new_sd[diffusers_keys[k]] = sd.pop(k)
+                else:
+                    logging.warning("{} {}".format(diffusers_keys[k], k))

     offload_device = model_management.unet_offload_device()
     unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)

comfy/sd1_clip.py

@@ -386,7 +386,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
     return embed_out

 class SDTokenizer:
-    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True, min_length=None, pad_token=None):
+    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True, min_length=None, pad_token=None, tokenizer_data={}):
         if tokenizer_path is None:
             tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
         self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path)

@@ -519,12 +519,14 @@ class SDTokenizer:
     def untokenize(self, token_weight_pair):
         return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))

+    def state_dict(self):
+        return {}
+
 class SD1Tokenizer:
-    def __init__(self, embedding_directory=None, clip_name="l", tokenizer=SDTokenizer):
+    def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer):
         self.clip_name = clip_name
         self.clip = "clip_{}".format(self.clip_name)
-        setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory))
+        setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data))

     def tokenize_with_weights(self, text:str, return_word_ids=False):
         out = {}

@@ -534,6 +536,8 @@ class SD1Tokenizer:
     def untokenize(self, token_weight_pair):
         return getattr(self, self.clip).untokenize(token_weight_pair)

+    def state_dict(self):
+        return {}
+
 class SD1ClipModel(torch.nn.Module):
     def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, name=None, **kwargs):

comfy/sd2_clip.py

@@ -11,12 +11,12 @@ class SD2ClipHModel(sd1_clip.SDClipModel):
         super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0})

 class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
-    def __init__(self, tokenizer_path=None, embedding_directory=None):
+    def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}):
         super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024)

 class SD2Tokenizer(sd1_clip.SD1Tokenizer):
-    def __init__(self, embedding_directory=None):
-        super().__init__(embedding_directory=embedding_directory, clip_name="h", tokenizer=SD2ClipHTokenizer)
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="h", tokenizer=SD2ClipHTokenizer)

 class SD2ClipModel(sd1_clip.SD1ClipModel):
     def __init__(self, device="cpu", dtype=None, **kwargs):

comfy/sdxl_clip.py

@@ -16,12 +16,12 @@ class SDXLClipG(sd1_clip.SDClipModel):
         return super().load_sd(sd)

 class SDXLClipGTokenizer(sd1_clip.SDTokenizer):
-    def __init__(self, tokenizer_path=None, embedding_directory=None):
+    def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}):
         super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g')

 class SDXLTokenizer:
-    def __init__(self, embedding_directory=None):
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
         self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
         self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory)

@@ -34,6 +34,9 @@ class SDXLTokenizer:
     def untokenize(self, token_weight_pair):
         return self.clip_g.untokenize(token_weight_pair)

+    def state_dict(self):
+        return {}
+
 class SDXLClipModel(torch.nn.Module):
     def __init__(self, device="cpu", dtype=None):
         super().__init__()

@@ -68,12 +71,12 @@ class SDXLRefinerClipModel(sd1_clip.SD1ClipModel):

 class StableCascadeClipGTokenizer(sd1_clip.SDTokenizer):
-    def __init__(self, tokenizer_path=None, embedding_directory=None):
+    def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}):
         super().__init__(tokenizer_path, pad_with_end=True, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g')

 class StableCascadeTokenizer(sd1_clip.SD1Tokenizer):
-    def __init__(self, embedding_directory=None):
-        super().__init__(embedding_directory=embedding_directory, clip_name="g", tokenizer=StableCascadeClipGTokenizer)
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="g", tokenizer=StableCascadeClipGTokenizer)

 class StableCascadeClipG(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None):

comfy/supported_models.py

@@ -5,8 +5,8 @@ from . import utils
 from . import sd1_clip
 from . import sd2_clip
 from . import sdxl_clip
-from . import sd3_clip
-from . import sa_t5
+import comfy.text_encoders.sd3_clip
+import comfy.text_encoders.sa_t5
 import comfy.text_encoders.aura_t5

 from . import supported_models_base

@@ -524,7 +524,7 @@ class SD3(supported_models_base.BASE):
             t5 = True
             dtype_t5 = state_dict[t5_key].dtype

-        return supported_models_base.ClipTarget(sd3_clip.SD3Tokenizer, sd3_clip.sd3_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5))
+        return supported_models_base.ClipTarget(comfy.text_encoders.sd3_clip.SD3Tokenizer, comfy.text_encoders.sd3_clip.sd3_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5))

 class StableAudio(supported_models_base.BASE):
     unet_config = {

@@ -555,7 +555,7 @@ class StableAudio(supported_models_base.BASE):
         return utils.state_dict_prefix_replace(state_dict, replace_prefix)

     def clip_target(self, state_dict={}):
-        return supported_models_base.ClipTarget(sa_t5.SAT5Tokenizer, sa_t5.SAT5Model)
+        return supported_models_base.ClipTarget(comfy.text_encoders.sa_t5.SAT5Tokenizer, comfy.text_encoders.sa_t5.SAT5Model)

 class AuraFlow(supported_models_base.BASE):
     unet_config = {

comfy/text_encoders/aura_t5.py

@@ -1,21 +1,21 @@
 from comfy import sd1_clip
-from .llama_tokenizer import LLAMATokenizer
-import comfy.t5
+from .spiece_tokenizer import SPieceTokenizer
+import comfy.text_encoders.t5
 import os

 class PT5XlModel(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_pile_config_xl.json")
-        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 2, "pad": 1}, model_class=comfy.t5.T5, enable_attention_masks=True, zero_out_masked=True)
+        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 2, "pad": 1}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True)

 class PT5XlTokenizer(sd1_clip.SDTokenizer):
-    def __init__(self, embedding_directory=None):
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
         tokenizer_path = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_pile_tokenizer"), "tokenizer.model")
-        super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='pile_t5xl', tokenizer_class=LLAMATokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, pad_token=1)
+        super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='pile_t5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, pad_token=1)

 class AuraT5Tokenizer(sd1_clip.SD1Tokenizer):
-    def __init__(self, embedding_directory=None):
-        super().__init__(embedding_directory=embedding_directory, clip_name="pile_t5xl", tokenizer=PT5XlTokenizer)
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="pile_t5xl", tokenizer=PT5XlTokenizer)

 class AuraT5Model(sd1_clip.SD1ClipModel):
     def __init__(self, device="cpu", dtype=None, **kwargs):

comfy/text_encoders/llama_tokenizer.py (deleted)

import os

class LLAMATokenizer:
    @staticmethod
    def from_pretrained(path):
        return LLAMATokenizer(path)

    def __init__(self, tokenizer_path):
        import sentencepiece
        self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=tokenizer_path)
        self.end = self.tokenizer.eos_id()

    def get_vocab(self):
        out = {}
        for i in range(self.tokenizer.get_piece_size()):
            out[self.tokenizer.id_to_piece(i)] = i
        return out

    def __call__(self, string):
        out = self.tokenizer.encode(string)
        out += [self.end]
        return {"input_ids": out}

comfy/text_encoders/sa_t5.py

@@ -1,21 +1,21 @@
 from comfy import sd1_clip
 from transformers import T5TokenizerFast
-import comfy.t5
+import comfy.text_encoders.t5
 import os

 class T5BaseModel(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_base.json")
-        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.t5.T5, enable_attention_masks=True, zero_out_masked=True)
+        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True)

 class T5BaseTokenizer(sd1_clip.SDTokenizer):
-    def __init__(self, embedding_directory=None):
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
         tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
         super().__init__(tokenizer_path, pad_with_end=False, embedding_size=768, embedding_key='t5base', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128)

 class SAT5Tokenizer(sd1_clip.SD1Tokenizer):
-    def __init__(self, embedding_directory=None):
-        super().__init__(embedding_directory=embedding_directory, clip_name="t5base", tokenizer=T5BaseTokenizer)
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5base", tokenizer=T5BaseTokenizer)

 class SAT5Model(sd1_clip.SD1ClipModel):
     def __init__(self, device="cpu", dtype=None, **kwargs):


@@ -1,7 +1,7 @@
 from comfy import sd1_clip
 from comfy import sdxl_clip
 from transformers import T5TokenizerFast
-import comfy.t5
+import comfy.text_encoders.t5
 import torch
 import os
 import comfy.model_management
@@ -10,25 +10,16 @@ import logging
 class T5XXLModel(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
-        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.t5.T5)
+        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5)

 class T5XXLTokenizer(sd1_clip.SDTokenizer):
-    def __init__(self, embedding_directory=None):
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
         tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
         super().__init__(tokenizer_path, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77)

-class SDT5XXLTokenizer(sd1_clip.SD1Tokenizer):
-    def __init__(self, embedding_directory=None):
-        super().__init__(embedding_directory=embedding_directory, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
-
-class SDT5XXLModel(sd1_clip.SD1ClipModel):
-    def __init__(self, device="cpu", dtype=None, **kwargs):
-        super().__init__(device=device, dtype=dtype, clip_name="t5xxl", clip_model=T5XXLModel, **kwargs)
-
 class SD3Tokenizer:
-    def __init__(self, embedding_directory=None):
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
         self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
         self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory)
         self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
@@ -43,6 +34,9 @@ class SD3Tokenizer:
     def untokenize(self, token_weight_pair):
         return self.clip_g.untokenize(token_weight_pair)

+    def state_dict(self):
+        return {}
+
 class SD3ClipModel(torch.nn.Module):
     def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu", dtype=None):
         super().__init__()


@@ -0,0 +1,29 @@
+import os
+import torch
+
+class SPieceTokenizer:
+    add_eos = True
+
+    @staticmethod
+    def from_pretrained(path):
+        return SPieceTokenizer(path)
+
+    def __init__(self, tokenizer_path):
+        import sentencepiece
+        if torch.is_tensor(tokenizer_path):
+            tokenizer_path = tokenizer_path.numpy().tobytes()
+
+        if isinstance(tokenizer_path, bytes):
+            self.tokenizer = sentencepiece.SentencePieceProcessor(model_proto=tokenizer_path, add_eos=self.add_eos)
+        else:
+            self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=tokenizer_path, add_eos=self.add_eos)
+
+    def get_vocab(self):
+        out = {}
+        for i in range(self.tokenizer.get_piece_size()):
+            out[self.tokenizer.id_to_piece(i)] = i
+        return out
+
+    def __call__(self, string):
+        out = self.tokenizer.encode(string)
+        return {"input_ids": out}
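
A minimal usage sketch of the wrapper above (assumes sentencepiece is installed; the model path is a placeholder):

```python
# Hypothetical path; SPieceTokenizer as defined in the new file above.
tok = SPieceTokenizer.from_pretrained("t5_pile_tokenizer/tokenizer.model")

ids = tok("a red fox")["input_ids"]  # EOS id appended because add_eos = True
vocab = tok.get_vocab()              # {piece_string: id, ...}
```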


@@ -223,7 +223,7 @@ class T5(torch.nn.Module):
         self.num_layers = config_dict["num_layers"]
         model_dim = config_dict["d_model"]

-        self.encoder = T5Stack(self.num_layers, model_dim, model_dim, config_dict["d_ff"], config_dict["dense_act_fn"], config_dict["is_gated_act"], config_dict["num_heads"], config_dict["model_type"] == "t5", dtype, device, operations)
+        self.encoder = T5Stack(self.num_layers, model_dim, model_dim, config_dict["d_ff"], config_dict["dense_act_fn"], config_dict["is_gated_act"], config_dict["num_heads"], config_dict["model_type"] != "umt5", dtype, device, operations)
         self.dtype = dtype
         self.shared = torch.nn.Embedding(config_dict["vocab_size"], model_dim, device=device)
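
Note the inverted test: the stack previously took this path only for model_type == "t5"; it now takes it for everything except "umt5", which brings MT5 along. A hedged reading of the boolean (the parameter name and semantics are assumptions, not shown in this hunk):

```python
# Sketch only: what the positional flag passed to T5Stack appears to select.
model_type = config_dict["model_type"]       # "t5", "mt5" or "umt5"
shared_relative_bias = model_type != "umt5"  # assumed: True means only the first block owns a relative-attention-bias table
```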


@@ -332,6 +332,76 @@ def mmdit_to_diffusers(mmdit_config, output_prefix=""):

     return key_map

+def auraflow_to_diffusers(mmdit_config, output_prefix=""):
+    n_double_layers = mmdit_config.get("n_double_layers", 0)
+    n_layers = mmdit_config.get("n_layers", 0)
+
+    key_map = {}
+    for i in range(n_layers):
+        if i < n_double_layers:
+            index = i
+            prefix_from = "joint_transformer_blocks"
+            prefix_to = "{}double_layers".format(output_prefix)
+
+            block_map = {
+                "attn.to_q.weight": "attn.w2q.weight",
+                "attn.to_k.weight": "attn.w2k.weight",
+                "attn.to_v.weight": "attn.w2v.weight",
+                "attn.to_out.0.weight": "attn.w2o.weight",
+                "attn.add_q_proj.weight": "attn.w1q.weight",
+                "attn.add_k_proj.weight": "attn.w1k.weight",
+                "attn.add_v_proj.weight": "attn.w1v.weight",
+                "attn.to_add_out.weight": "attn.w1o.weight",
+                "ff.linear_1.weight": "mlpX.c_fc1.weight",
+                "ff.linear_2.weight": "mlpX.c_fc2.weight",
+                "ff.out_projection.weight": "mlpX.c_proj.weight",
+                "ff_context.linear_1.weight": "mlpC.c_fc1.weight",
+                "ff_context.linear_2.weight": "mlpC.c_fc2.weight",
+                "ff_context.out_projection.weight": "mlpC.c_proj.weight",
+                "norm1.linear.weight": "modX.1.weight",
+                "norm1_context.linear.weight": "modC.1.weight",
+            }
+        else:
+            index = i - n_double_layers
+            prefix_from = "single_transformer_blocks"
+            prefix_to = "{}single_layers".format(output_prefix)
+
+            block_map = {
+                "attn.to_q.weight": "attn.w1q.weight",
+                "attn.to_k.weight": "attn.w1k.weight",
+                "attn.to_v.weight": "attn.w1v.weight",
+                "attn.to_out.0.weight": "attn.w1o.weight",
+                "norm1.linear.weight": "modCX.1.weight",
+                "ff.linear_1.weight": "mlp.c_fc1.weight",
+                "ff.linear_2.weight": "mlp.c_fc2.weight",
+                "ff.out_projection.weight": "mlp.c_proj.weight"
+            }
+
+        for k in block_map:
+            key_map["{}.{}.{}".format(prefix_from, index, k)] = "{}.{}.{}".format(prefix_to, index, block_map[k])
+
+    MAP_BASIC = {
+        ("positional_encoding", "pos_embed.pos_embed"),
+        ("register_tokens", "register_tokens"),
+        ("t_embedder.mlp.0.weight", "time_step_proj.linear_1.weight"),
+        ("t_embedder.mlp.0.bias", "time_step_proj.linear_1.bias"),
+        ("t_embedder.mlp.2.weight", "time_step_proj.linear_2.weight"),
+        ("t_embedder.mlp.2.bias", "time_step_proj.linear_2.bias"),
+        ("cond_seq_linear.weight", "context_embedder.weight"),
+        ("init_x_linear.weight", "pos_embed.proj.weight"),
+        ("init_x_linear.bias", "pos_embed.proj.bias"),
+        ("final_linear.weight", "proj_out.weight"),
+        ("modF.1.weight", "norm_out.linear.weight", swap_scale_shift),
+    }
+
+    for k in MAP_BASIC:
+        if len(k) > 2:
+            key_map[k[1]] = ("{}{}".format(output_prefix, k[0]), None, k[2])
+        else:
+            key_map[k[1]] = "{}{}".format(output_prefix, k[0])
+
+    return key_map
+
 def repeat_to_batch_size(tensor, batch_size, dim=0):
     if tensor.shape[dim] > batch_size:
         return tensor.narrow(dim, 0, batch_size)
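
A toy check of the mapping this builds (assumes swap_scale_shift from the surrounding module is in scope; the loader that consumes the map is not part of this hunk):

```python
key_map = auraflow_to_diffusers({"n_double_layers": 1, "n_layers": 2}, output_prefix="model.")

# Diffusers-style names key comfy-native destinations:
assert key_map["joint_transformer_blocks.0.attn.to_q.weight"] == "model.double_layers.0.attn.w2q.weight"
assert key_map["single_transformer_blocks.0.ff.linear_1.weight"] == "model.single_layers.0.mlp.c_fc1.weight"
```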


@@ -147,7 +147,7 @@ class SaveAudio:
             for x in extra_pnginfo:
                 metadata[x] = json.dumps(extra_pnginfo[x])

-        for (batch_number, waveform) in enumerate(audio["waveform"]):
+        for (batch_number, waveform) in enumerate(audio["waveform"].cpu()):
             filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
             file = f"{filename_with_batch_num}_{counter:05}_.flac"
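
The .cpu() call moves the whole batch to host memory before iterating, since with --gpu-only the waveforms arrive on the GPU and the save path needs CPU tensors. A minimal sketch of the pattern (torchaudio and a FLAC-capable backend assumed available):

```python
import io
import torch
import torchaudio

waveform = torch.zeros(2, 44100)  # (channels, samples); imagine it lived on the GPU
buff = io.BytesIO()
torchaudio.save(buff, waveform.cpu(), 44100, format="FLAC")  # .cpu() keeps encoding on the host
```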


@@ -0,0 +1,27 @@
+from comfy.cldm.control_types import UNION_CONTROLNET_TYPES
+
+class SetUnionControlNetType:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"control_net": ("CONTROL_NET", ),
+                             "type": (["auto"] + list(UNION_CONTROLNET_TYPES.keys()),)
+                             }}
+
+    CATEGORY = "conditioning/controlnet"
+    RETURN_TYPES = ("CONTROL_NET",)
+
+    FUNCTION = "set_controlnet_type"
+
+    def set_controlnet_type(self, control_net, type):
+        control_net = control_net.copy()
+        type_number = UNION_CONTROLNET_TYPES.get(type, -1)
+        if type_number >= 0:
+            control_net.set_extra_arg("control_type", [type_number])
+        else:
+            control_net.set_extra_arg("control_type", [])
+
+        return (control_net,)
+
+NODE_CLASS_MAPPINGS = {
+    "SetUnionControlNetType": SetUnionControlNetType,
+}
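
A self-contained sketch of the node's behavior, with an invented _StubControlNet standing in for a real CONTROL_NET object:

```python
class _StubControlNet:
    def __init__(self):
        self.extra_args = {}
    def copy(self):
        return _StubControlNet()
    def set_extra_arg(self, name, value):
        self.extra_args[name] = value

node = SetUnionControlNetType()
(out,) = node.set_controlnet_type(_StubControlNet(), "auto")
print(out.extra_args)  # {'control_type': []} -- "auto" leaves the type for the model to infer
```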


@@ -111,6 +111,25 @@ class SDTurboScheduler:
         sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
         return (sigmas, )

+class BetaSamplingScheduler:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required":
+                    {"model": ("MODEL",),
+                     "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
+                     "alpha": ("FLOAT", {"default": 0.6, "min": 0.0, "max": 50.0, "step":0.01, "round": False}),
+                     "beta": ("FLOAT", {"default": 0.6, "min": 0.0, "max": 50.0, "step":0.01, "round": False}),
+                    }
+               }
+    RETURN_TYPES = ("SIGMAS",)
+    CATEGORY = "sampling/custom_sampling/schedulers"
+
+    FUNCTION = "get_sigmas"
+
+    def get_sigmas(self, model, steps, alpha, beta):
+        sigmas = comfy.samplers.beta_scheduler(model.get_model_object("model_sampling"), steps, alpha=alpha, beta=beta)
+        return (sigmas, )
+
 class VPScheduler:
     @classmethod
     def INPUT_TYPES(s):

@@ -638,6 +657,7 @@ NODE_CLASS_MAPPINGS = {
     "ExponentialScheduler": ExponentialScheduler,
     "PolyexponentialScheduler": PolyexponentialScheduler,
     "VPScheduler": VPScheduler,
+    "BetaSamplingScheduler": BetaSamplingScheduler,
     "SDTurboScheduler": SDTurboScheduler,
     "KSamplerSelect": KSamplerSelect,
     "SamplerEulerAncestral": SamplerEulerAncestral,


@@ -34,7 +34,7 @@ class FreeU:
     RETURN_TYPES = ("MODEL",)
     FUNCTION = "patch"

-    CATEGORY = "model_patches"
+    CATEGORY = "model_patches/unet"

     def patch(self, model, b1, b2, s1, s2):
         model_channels = model.model.model_config.unet_config["model_channels"]
@@ -73,7 +73,7 @@ class FreeU_V2:
     RETURN_TYPES = ("MODEL",)
     FUNCTION = "patch"

-    CATEGORY = "model_patches"
+    CATEGORY = "model_patches/unet"

     def patch(self, model, b1, b2, s1, s2):
         model_channels = model.model.model_config.unet_config["model_channels"]


@@ -32,7 +32,7 @@ class HyperTile:
     RETURN_TYPES = ("MODEL",)
     FUNCTION = "patch"

-    CATEGORY = "model_patches"
+    CATEGORY = "model_patches/unet"

     def patch(self, model, tile_size, swap_size, max_depth, scale_depth):
         model_channels = model.model.model_config.unet_config["model_channels"]


@@ -19,7 +19,7 @@ class PerturbedAttentionGuidance:
     RETURN_TYPES = ("MODEL",)
     FUNCTION = "patch"

-    CATEGORY = "_for_testing"
+    CATEGORY = "model_patches/unet"

     def patch(self, model, scale):
         unet_block = "middle"


@@ -92,7 +92,7 @@ class ControlNetApplySD3(nodes.ControlNetApplyAdvanced):
                              "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
                              "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
                              }}
-    CATEGORY = "_for_testing/sd3"
+    CATEGORY = "conditioning/controlnet"

 NODE_CLASS_MAPPINGS = {
     "TripleCLIPLoader": TripleCLIPLoader,


@@ -3,6 +3,7 @@ import copy
 import logging
 import threading
 import heapq
+import time
 import traceback
 import inspect
 from typing import List, Literal, NamedTuple, Optional

@@ -247,6 +248,8 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
         to_delete = True
     elif unique_id not in old_prompt:
         to_delete = True
+    elif class_type != old_prompt[unique_id]['class_type']:
+        to_delete = True
     elif inputs == old_prompt[unique_id]['inputs']:
         for x in inputs:
             input_data = inputs[x]

@@ -281,7 +284,11 @@ class PromptExecutor:
         self.success = True
         self.old_prompt = {}

-    def add_message(self, event, data, broadcast: bool):
+    def add_message(self, event, data: dict, broadcast: bool):
+        data = {
+            **data,
+            "timestamp": int(time.time() * 1000),
+        }
         self.status_messages.append((event, data))
         if self.server.client_id is not None or broadcast:
             self.server.send_sync(event, data, self.server.client_id)

@@ -392,6 +399,9 @@ class PromptExecutor:
             if self.success is not True:
                 self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex)
                 break
+        else:
+            # Only execute when the while-loop ends without break
+            self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)

         for x in executed:
             self.old_prompt[x] = copy.deepcopy(prompt[x])
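
The else: here belongs to the executor's while loop, not to an if. For readers unfamiliar with Python's loop-else, a minimal illustration of why this fires exactly once on success:

```python
pending = ["node_a", "node_b"]
while pending:
    node = pending.pop(0)
    if node == "bad":
        break  # error path: the else clause is skipped
else:
    # Runs only when the loop ends without a break --
    # the moment "execution_success" is emitted above.
    print("execution_success")
```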

fix_torch.py Normal file

@@ -0,0 +1,24 @@
+import importlib.util
+import shutil
+import os
+import ctypes
+import logging
+
+
+torch_spec = importlib.util.find_spec("torch")
+for folder in torch_spec.submodule_search_locations:
+    lib_folder = os.path.join(folder, "lib")
+    test_file = os.path.join(lib_folder, "fbgemm.dll")
+    dest = os.path.join(lib_folder, "libomp140.x86_64.dll")
+    if os.path.exists(dest):
+        break
+
+    with open(test_file, 'rb') as f:
+        contents = f.read()
+    if b"libomp140.x86_64.dll" not in contents:
+        break
+    try:
+        mydll = ctypes.cdll.LoadLibrary(test_file)
+    except FileNotFoundError as e:
+        logging.warning("Detected pytorch version with libomp issue, patching.")
+        shutil.copyfile(os.path.join(lib_folder, "libiomp5md.dll"), dest)


@@ -74,6 +74,12 @@ if __name__ == "__main__":
     import cuda_malloc

+if args.windows_standalone_build:
+    try:
+        import fix_torch
+    except:
+        pass
+
 import comfy.utils
 import yaml


@@ -1,3 +1,7 @@
+import hashlib
+
+from comfy.cli_args import args
+
 from PIL import ImageFile, UnidentifiedImageError

 def conditioning_set_values(conditioning, values={}):
@@ -22,3 +26,12 @@ def pillow(fn, arg):
     if prev_value is not None:
         ImageFile.LOAD_TRUNCATED_IMAGES = prev_value
     return x
+
+def hasher():
+    hashfuncs = {
+        "md5": hashlib.md5,
+        "sha1": hashlib.sha1,
+        "sha256": hashlib.sha256,
+        "sha512": hashlib.sha512
+    }
+    return hashfuncs[args.default_hashing_function]
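
hasher() returns a hashlib constructor rather than a digest object, so call sites instantiate it themselves. A quick sketch (the CLI flag name is inferred from args.default_hashing_function and may differ):

```python
import node_helpers

h = node_helpers.hasher()()  # e.g. hashlib.sha256() unless overridden via --default-hashing-function
h.update(b"some image bytes")
print(h.hexdigest())
```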


@@ -748,7 +748,7 @@ class ControlNetApply:
     RETURN_TYPES = ("CONDITIONING",)
     FUNCTION = "apply_controlnet"

-    CATEGORY = "conditioning"
+    CATEGORY = "conditioning/controlnet"

     def apply_controlnet(self, conditioning, control_net, image, strength):
         if strength == 0:
@@ -783,7 +783,7 @@ class ControlNetApplyAdvanced:
     RETURN_NAMES = ("positive", "negative")
     FUNCTION = "apply_controlnet"

-    CATEGORY = "conditioning"
+    CATEGORY = "conditioning/controlnet"

     def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None):
         if strength == 0:
@@ -1890,29 +1890,29 @@ NODE_DISPLAY_NAME_MAPPINGS = {
 EXTENSION_WEB_DIRS = {}

-def get_relative_module_name(module_path: str) -> str:
+def get_module_name(module_path: str) -> str:
     """
     Returns the module name based on the given module path.
     Examples:
-        get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node.py") -> "custom_nodes.my_custom_node"
-        get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node") -> "custom_nodes.my_custom_node"
-        get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node/") -> "custom_nodes.my_custom_node"
-        get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node/__init__.py") -> "custom_nodes.my_custom_node"
-        get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node/__init__") -> "custom_nodes.my_custom_node"
-        get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node/__init__/") -> "custom_nodes.my_custom_node"
-        get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node.disabled") -> "custom_nodes.my
+        get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node.py") -> "my_custom_node"
+        get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node") -> "my_custom_node"
+        get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node/") -> "my_custom_node"
+        get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node/__init__.py") -> "my_custom_node"
+        get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node/__init__") -> "my_custom_node"
+        get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node/__init__/") -> "my_custom_node"
+        get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node.disabled") -> "custom_nodes
     Args:
         module_path (str): The path of the module.
     Returns:
         str: The module name.
     """
-    relative_path = os.path.relpath(module_path, folder_paths.base_path)
+    base_path = os.path.basename(module_path)
     if os.path.isfile(module_path):
-        relative_path = os.path.splitext(relative_path)[0]
-    return relative_path.replace(os.sep, '.')
+        base_path = os.path.splitext(base_path)[0]
+    return base_path

-def load_custom_node(module_path: str, ignore=set()) -> bool:
+def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes") -> bool:
     module_name = os.path.basename(module_path)
     if os.path.isfile(module_path):
         sp = os.path.splitext(module_path)
@@ -1939,7 +1939,7 @@ def load_custom_node(module_path: str, ignore=set()) -> bool:
                 for name, node_cls in module.NODE_CLASS_MAPPINGS.items():
                     if name not in ignore:
                         NODE_CLASS_MAPPINGS[name] = node_cls
-                        node_cls.RELATIVE_PYTHON_MODULE = get_relative_module_name(module_path)
+                        node_cls.RELATIVE_PYTHON_MODULE = "{}.{}".format(module_parent, get_module_name(module_path))
                 if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None:
                     NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS)
                 return True
@@ -1974,7 +1974,7 @@ def init_external_custom_nodes():
             if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue
             if module_path.endswith(".disabled"): continue
             time_before = time.perf_counter()
-            success = load_custom_node(module_path, base_node_names)
+            success = load_custom_node(module_path, base_node_names, module_parent="custom_nodes")
             node_import_times.append((time.perf_counter() - time_before, module_path, success))

     if len(node_import_times) > 0:
@@ -2036,11 +2036,12 @@ def init_builtin_extra_nodes():
         "nodes_audio.py",
         "nodes_sd3.py",
         "nodes_gits.py",
+        "nodes_controlnet.py",
     ]

     import_failed = []
     for node_file in extras_files:
-        if not load_custom_node(os.path.join(extras_dir, node_file)):
+        if not load_custom_node(os.path.join(extras_dir, node_file), module_parent="comfy_extras"):
             import_failed.append(node_file)

     return import_failed


@@ -1,5 +1,8 @@
 [pytest]
 markers =
     inference: mark as inference test (deselect with '-m "not inference"')
-testpaths = tests
+testpaths =
+    tests
+    tests-unit
 addopts = -s
+pythonpath = .


@@ -1,7 +1,7 @@
-torch
+torch==2.3.1
 torchsde
-torchvision
-torchaudio
+torchvision==0.18.1
+torchaudio==2.3.1
 einops
 transformers>=4.28.1
 tokenizers>=0.13.3
@@ -13,6 +13,7 @@ Pillow
 scipy
 tqdm
 psutil
+numpy<2.0.0

 #non essential dependencies:
 kornia>=0.7.1


@@ -25,9 +25,11 @@ import mimetypes
 from comfy.cli_args import args
 import comfy.utils
 import comfy.model_management
+import node_helpers
+from app.frontend_management import FrontendManager

 from app.user_manager import UserManager

 class BinaryEventTypes:
     PREVIEW_IMAGE = 1
     UNENCODED_PREVIEW_IMAGE = 2
@@ -83,8 +85,12 @@ class PromptServer():
         max_upload_size = round(args.max_upload_size * 1024 * 1024)
         self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares)
         self.sockets = dict()
-        self.web_root = os.path.join(os.path.dirname(
-            os.path.realpath(__file__)), "web")
+        self.web_root = (
+            FrontendManager.init_frontend(args.front_end_version)
+            if args.front_end_root is None
+            else args.front_end_root
+        )
+        logging.info(f"[Prompt Server] web root: {self.web_root}")
         routes = web.RouteTableDef()
         self.routes = routes
         self.last_node_id = None
@@ -156,10 +162,12 @@ class PromptServer():
             return type_dir, dir_type

         def compare_image_hash(filepath, image):
+            hasher = node_helpers.hasher()
+
             # function to compare hashes of two images to see if it already exists, fix to #3465
             if os.path.exists(filepath):
-                a = hashlib.sha256()
-                b = hashlib.sha256()
+                a = hasher()
+                b = hasher()
                 with open(filepath, "rb") as f:
                     a.update(f.read())
                     b.update(image.file.read())
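
Condensed, the duplicate check now reads as below (a sketch; uploaded_bytes stands in for the aiohttp upload field in the real handler):

```python
import os
import node_helpers

def is_duplicate(filepath: str, uploaded_bytes: bytes) -> bool:
    hasher = node_helpers.hasher()
    if not os.path.exists(filepath):
        return False
    a, b = hasher(), hasher()
    with open(filepath, "rb") as f:
        a.update(f.read())
    b.update(uploaded_bytes)
    return a.digest() == b.digest()
```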

tests-unit/README.md Normal file

@@ -0,0 +1,8 @@
+# Pytest Unit Tests
+
+## Install test dependencies
+
+`pip install -r tests-unit/requirements.txt`
+
+## Run tests
+`pytest tests-unit/`


@@ -0,0 +1,100 @@
+import argparse
+import pytest
+from requests.exceptions import HTTPError
+
+from app.frontend_management import (
+    FrontendManager,
+    FrontEndProvider,
+    Release,
+)
+from comfy.cli_args import DEFAULT_VERSION_STRING
+
+
+@pytest.fixture
+def mock_releases():
+    return [
+        Release(
+            id=1,
+            tag_name="1.0.0",
+            name="Release 1.0.0",
+            prerelease=False,
+            created_at="2022-01-01T00:00:00Z",
+            published_at="2022-01-01T00:00:00Z",
+            body="Release notes for 1.0.0",
+            assets=[{"name": "dist.zip", "url": "https://example.com/dist.zip"}],
+        ),
+        Release(
+            id=2,
+            tag_name="2.0.0",
+            name="Release 2.0.0",
+            prerelease=False,
+            created_at="2022-02-01T00:00:00Z",
+            published_at="2022-02-01T00:00:00Z",
+            body="Release notes for 2.0.0",
+            assets=[{"name": "dist.zip", "url": "https://example.com/dist.zip"}],
+        ),
+    ]
+
+
+@pytest.fixture
+def mock_provider(mock_releases):
+    provider = FrontEndProvider(
+        owner="test-owner",
+        repo="test-repo",
+    )
+    provider.all_releases = mock_releases
+    provider.latest_release = mock_releases[1]
+    FrontendManager.PROVIDERS = [provider]
+    return provider
+
+
+def test_get_release(mock_provider, mock_releases):
+    version = "1.0.0"
+    release = mock_provider.get_release(version)
+    assert release == mock_releases[0]
+
+
+def test_get_release_latest(mock_provider, mock_releases):
+    version = "latest"
+    release = mock_provider.get_release(version)
+    assert release == mock_releases[1]
+
+
+def test_get_release_invalid_version(mock_provider):
+    version = "invalid"
+    with pytest.raises(ValueError):
+        mock_provider.get_release(version)
+
+
+def test_init_frontend_default():
+    version_string = DEFAULT_VERSION_STRING
+    frontend_path = FrontendManager.init_frontend(version_string)
+    assert frontend_path == FrontendManager.DEFAULT_FRONTEND_PATH
+
+
+def test_init_frontend_invalid_version():
+    version_string = "test-owner/test-repo@1.100.99"
+    with pytest.raises(HTTPError):
+        FrontendManager.init_frontend_unsafe(version_string)
+
+
+def test_init_frontend_invalid_provider():
+    version_string = "invalid/invalid@latest"
+    with pytest.raises(HTTPError):
+        FrontendManager.init_frontend_unsafe(version_string)
+
+
+def test_parse_version_string():
+    version_string = "owner/repo@1.0.0"
+    repo_owner, repo_name, version = FrontendManager.parse_version_string(
+        version_string
+    )
+    assert repo_owner == "owner"
+    assert repo_name == "repo"
+    assert version == "1.0.0"
+
+
+def test_parse_version_string_invalid():
+    version_string = "invalid"
+    with pytest.raises(argparse.ArgumentTypeError):
+        FrontendManager.parse_version_string(version_string)
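
The tests pin down the "owner/repo@version" convention without showing the parser. One plausible shape, inferred purely from the assertions above (the regex and body are assumptions; only the contract comes from the tests):

```python
import argparse
import re

def parse_version_string_sketch(value: str):
    m = re.match(r"^([\w.-]+)/([\w.-]+)@(.+)$", value)
    if m is None:
        raise argparse.ArgumentTypeError(f"Invalid version string: {value}")
    return m.group(1), m.group(2), m.group(3)

assert parse_version_string_sketch("owner/repo@1.0.0") == ("owner", "repo", "1.0.0")
```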


@@ -0,0 +1 @@
+pytest>=7.8.0


@@ -17,7 +17,6 @@ function getResourceURL(subfolder, filename, type = "input") {
         "filename=" + encodeURIComponent(filename),
         "type=" + type,
         "subfolder=" + subfolder,
-        app.getPreviewFormatParam().substring(1),
         app.getRandParam().substring(1)
     ].join("&")
@@ -150,6 +149,15 @@
         }
         audioWidget.callback = onAudioWidgetUpdate

+        // Load saved audio file widget values if restoring from workflow
+        const onGraphConfigured = node.onGraphConfigured;
+        node.onGraphConfigured = function() {
+            onGraphConfigured?.apply(this, arguments)
+
+            if (audioWidget.value) {
+                onAudioWidgetUpdate()
+            }
+        }
+
         const fileInput = document.createElement("input")
         fileInput.type = "file"
         fileInput.accept = "audio/*"


@@ -136,6 +136,9 @@ class ComfyApi extends EventTarget {
           case "execution_start":
             this.dispatchEvent(new CustomEvent("execution_start", { detail: msg.data }));
             break;
+          case "execution_success":
+            this.dispatchEvent(new CustomEvent("execution_success", { detail: msg.data }));
+            break;
           case "execution_error":
             this.dispatchEvent(new CustomEvent("execution_error", { detail: msg.data }));
             break;


@@ -49,7 +49,7 @@ export function getPngMetadata(file) {

 function parseExifData(exifData) {
   // Check for the correct TIFF header (0x4949 for little-endian or 0x4D4D for big-endian)
-  const isLittleEndian = new Uint16Array(exifData.slice(0, 2))[0] === 0x4949;
+  const isLittleEndian = String.fromCharCode(...exifData.slice(0, 2)) === "II";

   // Function to read 16-bit and 32-bit integers from binary data
   function readInt(offset, isLittleEndian, length) {
@@ -134,6 +134,7 @@ export function getWebpMetadata(file) {
             let index = value.indexOf(':');
             txt_chunks[value.slice(0, index)] = value.slice(index + 1);
           }
+          break;
         }

         offset += 8 + chunk_length;
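
Depending on the input type, the old Uint16Array construction either converted the two bytes element-by-element or reinterpreted them with the platform's byte order; either way the flag was unreliable. Comparing the raw ASCII marker ("II" vs "MM") is unambiguous. The same check in Python for reference (a sketch with a toy header):

```python
import struct

def read_u16(data: bytes, offset: int, little_endian: bool) -> int:
    fmt = "<H" if little_endian else ">H"
    return struct.unpack_from(fmt, data, offset)[0]

exif = b"II\x2a\x00\x08\x00\x00\x00"  # toy little-endian TIFF header
is_little_endian = exif[:2] == b"II"  # b"MM" would mean big-endian
assert read_u16(exif, 2, is_little_endian) == 42  # TIFF magic number
```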


@@ -182,6 +182,11 @@ export class ComfyWorkflowsMenu {
    * @param {ComfyWorkflow} workflow
    */
   async function sendToWorkflow(img, workflow) {
+    const openWorkflow = app.workflowManager.openWorkflows.find((w) => w.path === workflow.path);
+    if (openWorkflow) {
+      workflow = openWorkflow;
+    }
     await workflow.load();
     let options = [];
     const nodes = app.graph.computeExecutionOrder(false);
@@ -214,7 +219,8 @@ export class ComfyWorkflowsMenu {
       nodeType.prototype["getExtraMenuOptions"] = function (_, options) {
         const r = getExtraMenuOptions?.apply?.(this, arguments);
-        if (app.ui.settings.getSettingValue("Comfy.UseNewMenu", false) === true) {
+        const setting = app.ui.settings.getSettingValue("Comfy.UseNewMenu", false);
+        if (setting && setting != "Disabled") {
           const t = /** @type { {imageIndex?: number, overIndex?: number, imgs: string[]} } */ /** @type {any} */ (this);
           let img;
           if (t.imageIndex != null) {


@@ -41,7 +41,7 @@ body {
   background-color: var(--bg-color);
   color: var(--fg-color);
   grid-template-columns: auto 1fr auto;
-  grid-template-rows: auto auto 1fr auto;
+  grid-template-rows: auto 1fr auto;
   min-height: -webkit-fill-available;
   max-height: -webkit-fill-available;
   min-width: -webkit-fill-available;
@@ -49,32 +49,37 @@ body {
 }

 .comfyui-body-top {
-  order: 0;
+  order: -5;
   grid-column: 1/-1;
   z-index: 10;
+  display: flex;
+  flex-direction: column;
 }

 .comfyui-body-left {
-  order: 1;
+  order: -4;
   z-index: 10;
+  display: flex;
 }

 #graph-canvas {
   width: 100%;
   height: 100%;
-  order: 2;
+  order: -3;
+  grid-column: 1/-1;
 }

 .comfyui-body-right {
-  order: 3;
+  order: -2;
   z-index: 10;
+  display: flex;
 }

 .comfyui-body-bottom {
-  order: 4;
+  order: -1;
   grid-column: 1/-1;
   z-index: 10;
+  display: flex;
+  flex-direction: column;
 }

 .comfy-multiline-input {
@@ -408,8 +413,12 @@ dialog::backdrop {
   background: rgba(0, 0, 0, 0.5);
 }

-.comfy-dialog.comfyui-dialog {
+.comfy-dialog.comfyui-dialog.comfy-modal {
   top: 0;
+  left: 0;
+  right: 0;
+  bottom: 0;
+  transform: none;
 }

 .comfy-dialog.comfy-modal {