Compare commits

...

39 Commits

Author SHA1 Message Date
comfyanonymous
7689917113 ComfyUI version 0.3.31 2025-05-03 00:34:01 -04:00
comfyanonymous
486ad8fdc5 Fix updater issue with newer portable. (#7917) 2025-05-03 00:28:10 -04:00
Terry Jia
065d855f14 upstream Preview Any from rgthree-comfy (#7815)
* upstream Preview Any from rgthree-comfy

* use IO.ANY
2025-05-02 13:15:54 -04:00
Chenlei Hu
530494588d [BugFix] Update frontend 1.18.6 (#7910) 2025-05-02 13:14:52 -04:00
Kohaku-Blueleaf
2ab9618732 Fix the bugs in OFT/BOFT module (#7909)
* Correct calculate_weight and load for OFT

* Correct calculate_weight and loading for BOFT
2025-05-02 13:12:37 -04:00
catboxanon
d9a87c1e6a Fix outdated comment about Internet connectivity (#7827) 2025-05-02 05:28:27 -04:00
catboxanon
551fe8dcee Add node to extend sigmas (#7901)
* Add ExpandSigmas node

* Rename, add interpolation functions

Co-authored-by: liesen <liesen.dev@gmail.com>

* Move computed interpolation outside loop

* Add type hints

---------

Co-authored-by: liesen <liesen.dev@gmail.com>
2025-05-02 05:28:05 -04:00
comfyanonymous
ff99861650 Make clipsave work with more TE models. (#7908) 2025-05-02 05:15:32 -04:00
catboxanon
8d0661d0ba Lint instance methods (#7903) 2025-05-01 19:32:04 -04:00
Chenlei Hu
6d32dc049e Update frontend to v1.18 (#7898) 2025-05-01 10:57:54 -04:00
comfyanonymous
aa9d759df3 Switch ltxv to use the pytorch RMSNorm. (#7897) 2025-05-01 06:33:42 -04:00
Christian Byrne
c6c19e9980 fix bug (#7894) 2025-05-01 03:24:32 -04:00
comfyanonymous
08ff5fa08a Cleanup chroma PR. 2025-04-30 20:57:30 -04:00
Silver
4ca3d84277 Support for Chroma - Flux1 Schnell distilled with CFG (#7355)
* Upload files for Chroma Implementation

* Remove trailing whitespace

* trim more trailing whitespace..oops

* remove unused imports

* Add supported_inference_dtypes

* Set min_length to 0 and remove attention_mask=True

* Set min_length to 1

* get_modulations added from blepping and minor changes

* Add lora conversion if statement in lora.py

* Update supported_models.py

* update model_base.py

* add upstream commits

* set modelType.FLOW, will cause beta scheduler to work properly

* Adjust memory usage factor and remove unnecessary code

* fix mistake

* reduce code duplication

* remove unused imports

* refactor for upstream sync

* sync chroma-support with upstream via syncbranch patch

* Update sd.py

* Add Chroma as option for the OptimalStepsScheduler node
2025-04-30 20:57:00 -04:00
comfyanonymous
39c27a3705 Add updater test to stable release workflow. (#7887) 2025-04-30 14:42:18 -04:00
comfyanonymous
b1c7291569 Test updater in the windows release workflow. (#7886) 2025-04-30 14:18:20 -04:00
comfyanonymous
dbc726f80c Better vace memory estimation. (#7875) 2025-04-29 20:42:00 -04:00
comfyanonymous
7ee96455e2 Bump minimum pyav version to 14.2.0 (#7874) 2025-04-29 20:38:45 -04:00
comfyanonymous
0a66d4b0af Per device stream counters for async offload. (#7873) 2025-04-29 20:28:52 -04:00
Terry Jia
5c5457a4ef support more example folders (#7836)
* support more example folders

* add warning message
2025-04-29 11:28:04 -04:00
Chenlei Hu
45503f6499 Add release process section to README (#7855)
* Add release process section to README

* move

* Update README.md
2025-04-29 06:32:34 -04:00
comfyanonymous
005a91ce2b Latest desktop and portable should work on blackwell. (#7861)
Removed the mention about the cards from the readme.
2025-04-29 06:29:38 -04:00
guill
68f0d35296 Add support for VIDEO as a built-in type (#7844)
* Add basic support for videos as types

This PR adds support for VIDEO as first-class types. In order to avoid
unnecessary costs, VIDEO outputs must implement the `VideoInput` ABC,
but their implementation details can vary. Included are two
implementations of this type which can be returned by other nodes:

* `VideoFromFile` - Created with either a path on disk (as a string) or
  a `io.BytesIO` containing the contents of a file in a supported format
  (like .mp4). This implementation won't actually load the video unless
  necessary. It will also avoid re-encoding when saving if possible.
* `VideoFromComponents` - Created from an image tensor and an optional
  audio tensor.

Currently, only h264 encoded videos in .mp4 containers are supported for
saving, but the plan is to add additional encodings/containers in the
near future (particularly .webm).

* Add optimization to avoid parsing entire video

* Improve type declarations to reduce warnings

* Make sure bytesIO objects can be read many times

* Fix a potential issue when saving long videos

* Fix incorrect type annotation

* Add a `LoadVideo` node to make testing easier

* Refactor new types out of the base comfy folder

I've created a new `comfy_api` top-level module. The intention is that
anything within this folder would be covered by semver-style versioning
that would allow custom nodes to rely on them not introducing breaking
changes.

* Fix linting issue
2025-04-29 05:58:00 -04:00
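
Based on the PR description above, a custom node that returns a VIDEO output might look like the following sketch. The comfy_api import path and the node itself are illustrative assumptions, not taken verbatim from the PR:

    import os
    from comfy_api.input_impl.video_types import VideoFromFile  # assumed module path inside the new comfy_api package

    class LoadClipExample:
        """Hypothetical node: wraps a file on disk in the lazy VideoFromFile type."""
        @classmethod
        def INPUT_TYPES(cls):
            return {"required": {"path": ("STRING", {"default": "clip.mp4"})}}

        RETURN_TYPES = ("VIDEO",)
        FUNCTION = "load"
        CATEGORY = "example"

        def load(self, path):
            # Per the description above, VideoFromFile defers decoding until the
            # frames are actually needed and can avoid re-encoding on save.
            return (VideoFromFile(os.path.abspath(path)),)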
comfyanonymous
83d04717b6 Support HiDream E1 model. (#7857) 2025-04-28 15:01:15 -04:00
Yoland Yan
7d329771f9 Add moderation level option to OpenAIGPTImage1 node and update api_call method signature (#7804) 2025-04-28 13:59:22 -04:00
chaObserv
c15909bb62 CFG++ for gradient estimation sampler (#7809) 2025-04-28 13:51:35 -04:00
Andrew Kvochko
772b4c5945 ltxv: overwrite existing mask on conditioned frame. (#7845)
This commit overwrites the noise mask on the latent frame that is being
conditioned with keyframe conditioning, setting it to one.
2025-04-28 13:42:04 -04:00
comfyanonymous
5a50c3c7e5 Fix stream priority to support older pytorch. (#7856) 2025-04-28 13:07:21 -04:00
Pam
30159a7fe6 Save v pred zsnr metadata (#7840) 2025-04-28 13:03:21 -04:00
Andrew Kvochko
cb9ac3db58 ltxv: add strength parameter to conditioning. (#7849)
This commit adds strength parameter to the LTXVImgToVideo node.
2025-04-28 12:59:17 -04:00
Benjamin Lu
8115a7895b Add /api/v2/userdata endpoint (#7817)
* Add list_userdata_v2

* nit

* nit

* nit

* nit

* please set me free

* \\\\

* \\\\
2025-04-27 20:06:55 -04:00
comfyanonymous
c8cd7ad795 Use stream for casting if enabled. (#7833) 2025-04-27 05:38:11 -04:00
comfyanonymous
542b4b36b6 Prevent custom nodes from hooking certain functions. (#7825) 2025-04-26 20:52:56 -04:00
comfyanonymous
ac10a0d69e Make loras work with --async-offload (#7824) 2025-04-26 19:56:22 -04:00
comfyanonymous
0dcc75ca54 Add experimental --async-offload lowvram weight offloading. (#7820)
This should speed up the lowvram mode a bit. It currently is only enabled when --async-offload is used but it will be enabled by default in the future if there are no problems.
2025-04-26 16:11:21 -04:00
comfyanonymous
b685b8a4e0 Update portable package workflow to cu128 (#7812) 2025-04-26 04:43:12 -04:00
comfyanonymous
23e39f2ba7 Add a T5TokenizerOptions node to set options for the T5 tokenizer. (#7803) 2025-04-25 19:36:00 -04:00
AustinMroz
78992c4b25 [NodeDef] Add documentation on widgetType (#7768)
* [NodeDef] Add documentation on widgetType

* Document required version for widgetType
2025-04-25 13:35:07 -04:00
comfyanonymous
f935d42d8e Support SimpleTuner lycoris lora format for HiDream. 2025-04-25 03:11:14 -04:00
59 changed files with 1644 additions and 129 deletions

View File

@@ -63,6 +63,11 @@ except:
     print("checking out master branch") # noqa: T201
     branch = repo.lookup_branch('master')
     if branch is None:
+        try:
+            ref = repo.lookup_reference('refs/remotes/origin/master')
+        except:
+            print("pulling.") # noqa: T201
+            pull(repo)
-        ref = repo.lookup_reference('refs/remotes/origin/master')
+            ref = repo.lookup_reference('refs/remotes/origin/master')
         repo.checkout(ref)
         branch = repo.lookup_branch('master')

View File

@@ -12,7 +12,7 @@ on:
       description: 'CUDA version'
       required: true
       type: string
-      default: "126"
+      default: "128"
     python_minor:
       description: 'Python minor version'
       required: true
@@ -22,7 +22,7 @@ on:
       description: 'Python patch version'
       required: true
       type: string
-      default: "9"
+      default: "10"
 
 jobs:
@@ -91,6 +91,8 @@ jobs:
         cd ComfyUI_windows_portable
         python_embeded/python.exe -s ComfyUI/main.py --quick-test-for-ci --cpu
+        python_embeded/python.exe -s ./update/update.py ComfyUI/
+
         ls
 
     - name: Upload binaries to release

View File

@@ -17,7 +17,7 @@ on:
       description: 'cuda version'
       required: true
       type: string
-      default: "126"
+      default: "128"
 
     python_minor:
       description: 'python minor version'
@@ -29,7 +29,7 @@ on:
       description: 'python patch version'
       required: true
       type: string
-      default: "9"
+      default: "10"
 #  push:
 #    branches:
 #      - master

View File

@@ -7,7 +7,7 @@ on:
       description: 'cuda version'
       required: true
       type: string
-      default: "126"
+      default: "128"
 
     python_minor:
       description: 'python minor version'
@@ -19,7 +19,7 @@ on:
       description: 'python patch version'
       required: true
       type: string
-      default: "9"
+      default: "10"
 #  push:
 #    branches:
 #      - master
@@ -88,6 +88,8 @@ jobs:
         cd ComfyUI_windows_portable
         python_embeded/python.exe -s ComfyUI/main.py --quick-test-for-ci --cpu
+        python_embeded/python.exe -s ./update/update.py ComfyUI/
+
         ls
 
     - name: Upload binaries to release

View File

@@ -49,7 +49,6 @@ Supports all operating systems and GPU types (NVIDIA, AMD, Intel, Apple Silicon,
 ## [Examples](https://comfyanonymous.github.io/ComfyUI_examples/)
 See what ComfyUI can do with the [example workflows](https://comfyanonymous.github.io/ComfyUI_examples/).
-
 ## Features
 - Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
 - Image Models
@@ -99,6 +98,23 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
 Workflow examples can be found on the [Examples page](https://comfyanonymous.github.io/ComfyUI_examples/)
 
+## Release Process
+
+ComfyUI follows a weekly release cycle every Friday, with three interconnected repositories:
+
+1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)**
+   - Releases a new stable version (e.g., v0.7.0)
+   - Serves as the foundation for the desktop release
+
+2. **[ComfyUI Desktop](https://github.com/Comfy-Org/desktop)**
+   - Builds a new release using the latest stable core version
+   - Version numbers match the core release (e.g., Desktop v1.7.0 uses Core v1.7.0)
+
+3. **[ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend)**
+   - Weekly frontend updates are merged into the core repository
+   - Features are frozen for the upcoming core release
+   - Development continues for the next release cycle
+
 ## Shortcuts
 
 | Keybind | Explanation |
@@ -149,8 +165,6 @@ Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you
 If you have trouble extracting it, right click the file -> properties -> unblock
 
-If you have a 50 series Blackwell card like a 5090 or 5080 see [this discussion thread](https://github.com/comfyanonymous/ComfyUI/discussions/6643)
-
 #### How do I share models between another UI and ComfyUI?
 
 See the [Config file](extra_model_paths.yaml.example) to set the search paths for models. In the standalone windows build you can find this file in the ComfyUI directory. Rename this file to extra_model_paths.yaml and edit it with your favorite text editor.

View File

@@ -93,16 +93,20 @@ class CustomNodeManager:
     def add_routes(self, routes, webapp, loadedModules):
 
+        example_workflow_folder_names = ["example_workflows", "example", "examples", "workflow", "workflows"]
+
         @routes.get("/workflow_templates")
         async def get_workflow_templates(request):
             """Returns a web response that contains the map of custom_nodes names and their associated workflow templates. The ones without templates are omitted."""
-            files = [
-                file
-                for folder in folder_paths.get_folder_paths("custom_nodes")
-                for file in glob.glob(
-                    os.path.join(folder, "*/example_workflows/*.json")
-                )
-            ]
+            files = []
+
+            for folder in folder_paths.get_folder_paths("custom_nodes"):
+                for folder_name in example_workflow_folder_names:
+                    pattern = os.path.join(folder, f"*/{folder_name}/*.json")
+                    matched_files = glob.glob(pattern)
+                    files.extend(matched_files)
+
             workflow_templates_dict = (
                 {}
             )  # custom_nodes folder name -> example workflow names
@@ -118,8 +122,15 @@ class CustomNodeManager:
         # Serve workflow templates from custom nodes.
         for module_name, module_dir in loadedModules:
-            workflows_dir = os.path.join(module_dir, "example_workflows")
-            if os.path.exists(workflows_dir):
-                webapp.add_routes(
-                    [
-                        web.static(
+            for folder_name in example_workflow_folder_names:
+                workflows_dir = os.path.join(module_dir, folder_name)
+                if os.path.exists(workflows_dir):
+                    if folder_name != "example_workflows":
+                        logging.warning(
+                            "WARNING: Found example workflow folder '%s' for custom node '%s', consider renaming it to 'example_workflows'",
+                            folder_name, module_name)
+                    webapp.add_routes(
+                        [
+                            web.static(

View File

@@ -197,6 +197,112 @@ class UserManager():
             return web.json_response(results)
 
+        @routes.get("/v2/userdata")
+        async def list_userdata_v2(request):
+            """
+            List files and directories in a user's data directory.
+
+            This endpoint provides a structured listing of contents within a specified
+            subdirectory of the user's data storage.
+
+            Query Parameters:
+            - path (optional): The relative path within the user's data directory
+                               to list. Defaults to the root ('').
+
+            Returns:
+            - 400: If the requested path is invalid, outside the user's data directory, or is not a directory.
+            - 404: If the requested path does not exist.
+            - 403: If the user is invalid.
+            - 500: If there is an error reading the directory contents.
+            - 200: JSON response containing a list of file and directory objects.
+              Each object includes:
+              - name: The name of the file or directory.
+              - type: 'file' or 'directory'.
+              - path: The relative path from the user's data root.
+              - size (for files): The size in bytes.
+              - modified (for files): The last modified timestamp (Unix epoch).
+            """
+            requested_rel_path = request.rel_url.query.get('path', '')
+
+            # URL-decode the path parameter
+            try:
+                requested_rel_path = parse.unquote(requested_rel_path)
+            except Exception as e:
+                logging.warning(f"Failed to decode path parameter: {requested_rel_path}, Error: {e}")
+                return web.Response(status=400, text="Invalid characters in path parameter")
+
+            # Check user validity and get the absolute path for the requested directory
+            try:
+                base_user_path = self.get_request_user_filepath(request, None, create_dir=False)
+
+                if requested_rel_path:
+                    target_abs_path = self.get_request_user_filepath(request, requested_rel_path, create_dir=False)
+                else:
+                    target_abs_path = base_user_path
+            except KeyError as e:
+                # Invalid user detected by get_request_user_id inside get_request_user_filepath
+                logging.warning(f"Access denied for user: {e}")
+                return web.Response(status=403, text="Invalid user specified in request")
+
+            if not target_abs_path:
+                # Path traversal or other issue detected by get_request_user_filepath
+                return web.Response(status=400, text="Invalid path requested")
+
+            # Handle cases where the user directory or target path doesn't exist
+            if not os.path.exists(target_abs_path):
+                # Check if it's the base user directory that's missing (new user case)
+                if target_abs_path == base_user_path:
+                    # It's okay if the base user directory doesn't exist yet, return empty list
+                    return web.json_response([])
+                else:
+                    # A specific subdirectory was requested but doesn't exist
+                    return web.Response(status=404, text="Requested path not found")
+
+            if not os.path.isdir(target_abs_path):
+                return web.Response(status=400, text="Requested path is not a directory")
+
+            results = []
+            try:
+                for root, dirs, files in os.walk(target_abs_path, topdown=True):
+                    # Process directories
+                    for dir_name in dirs:
+                        dir_path = os.path.join(root, dir_name)
+                        rel_path = os.path.relpath(dir_path, base_user_path).replace(os.sep, '/')
+                        results.append({
+                            "name": dir_name,
+                            "path": rel_path,
+                            "type": "directory"
+                        })
+
+                    # Process files
+                    for file_name in files:
+                        file_path = os.path.join(root, file_name)
+                        rel_path = os.path.relpath(file_path, base_user_path).replace(os.sep, '/')
+                        entry_info = {
+                            "name": file_name,
+                            "path": rel_path,
+                            "type": "file"
+                        }
+                        try:
+                            stats = os.stat(file_path) # Use os.stat for potentially better performance with os.walk
+                            entry_info["size"] = stats.st_size
+                            entry_info["modified"] = stats.st_mtime
+                        except OSError as stat_error:
+                            logging.warning(f"Could not stat file {file_path}: {stat_error}")
+                            pass # Include file with available info
+                        results.append(entry_info)
+            except OSError as e:
+                logging.error(f"Error listing directory {target_abs_path}: {e}")
+                return web.Response(status=500, text="Error reading directory contents")
+
+            # Sort results alphabetically, directories first then files
+            results.sort(key=lambda x: (x['type'] != 'directory', x['name'].lower()))
+
+            return web.json_response(results)
+
 def get_user_data_path(request, check_exists = False, param = "file"):
     file = request.match_info.get(param, None)
     if not file:
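
A quick way to exercise the new endpoint once ComfyUI is running; the host, port, and the "workflows" subdirectory are assumptions for illustration:

    import requests

    # Query the v2 listing endpoint added above (routes are also served under /api/).
    resp = requests.get("http://127.0.0.1:8188/api/v2/userdata", params={"path": "workflows"})
    resp.raise_for_status()
    for entry in resp.json():
        # Directories carry name/path/type; files additionally carry size/modified.
        print(entry["type"], entry["path"], entry.get("size", "-"))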

View File

@@ -128,6 +128,7 @@ vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for e
 parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.")
+parser.add_argument("--async-offload", action="store_true", help="Use async weight offloading.")
 parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")

View File

@@ -48,6 +48,7 @@ class IO(StrEnum):
     FACE_ANALYSIS = "FACE_ANALYSIS"
     BBOX = "BBOX"
     SEGS = "SEGS"
+    VIDEO = "VIDEO"
 
     ANY = "*"
     """Always matches any type, but at a price.
@@ -120,6 +121,10 @@ class InputTypeOptions(TypedDict):
     Available from frontend v1.17.5
     Ref: https://github.com/Comfy-Org/ComfyUI_frontend/pull/3548
     """
+    widgetType: NotRequired[str]
+    """Specifies a type to be used for widget initialization if different from the input type.
+    Available from frontend v1.18.0
+    https://github.com/Comfy-Org/ComfyUI_frontend/pull/3550"""
     # class InputTypeNumber(InputTypeOptions):
     # default: float | int
     min: NotRequired[float]
@@ -269,7 +274,7 @@ class ComfyNodeABC(ABC):
     Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lists#list-processing
     """
-    OUTPUT_IS_LIST: tuple[bool]
+    OUTPUT_IS_LIST: tuple[bool, ...]
     """A tuple indicating which node outputs are lists, but will be connected to nodes that expect individual items.
 
     Connected nodes that do not implement `INPUT_IS_LIST` will be executed once for every item in the list.
@@ -288,7 +293,7 @@ class ComfyNodeABC(ABC):
     Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lists#list-processing
     """
-    RETURN_TYPES: tuple[IO]
+    RETURN_TYPES: tuple[IO, ...]
     """A tuple representing the outputs of this node.
 
     Usage::
@@ -297,12 +302,12 @@ class ComfyNodeABC(ABC):
     Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#return-types
     """
-    RETURN_NAMES: tuple[str]
+    RETURN_NAMES: tuple[str, ...]
     """The output slot names for each item in `RETURN_TYPES`, e.g. ``RETURN_NAMES = ("count", "filter_string")``
 
     Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#return-names
     """
-    OUTPUT_TOOLTIPS: tuple[str]
+    OUTPUT_TOOLTIPS: tuple[str, ...]
     """A tuple of strings to use as tooltips for node outputs, one for each item in `RETURN_TYPES`."""
     FUNCTION: str
     """The name of the function to execute as a literal string, e.g. `FUNCTION = "execute"`

View File

@@ -1345,28 +1345,52 @@ def sample_res_multistep_ancestral_cfg_pp(model, x, sigmas, extra_args=None, cal
     return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=eta, cfg_pp=True)
 
 @torch.no_grad()
-def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2.):
+def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2., cfg_pp=False):
     """Gradient-estimation sampler. Paper: https://openreview.net/pdf?id=o2ND9v0CeK"""
     extra_args = {} if extra_args is None else extra_args
     s_in = x.new_ones([x.shape[0]])
     old_d = None
+    uncond_denoised = None
+
+    def post_cfg_function(args):
+        nonlocal uncond_denoised
+        uncond_denoised = args["uncond_denoised"]
+        return args["denoised"]
+
+    if cfg_pp:
+        model_options = extra_args.get("model_options", {}).copy()
+        extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
 
     for i in trange(len(sigmas) - 1, disable=disable):
         denoised = model(x, sigmas[i] * s_in, **extra_args)
-        d = to_d(x, sigmas[i], denoised)
+        if cfg_pp:
+            d = to_d(x, sigmas[i], uncond_denoised)
+        else:
+            d = to_d(x, sigmas[i], denoised)
         if callback is not None:
             callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
         dt = sigmas[i + 1] - sigmas[i]
         if i == 0:
             # Euler method
-            x = x + d * dt
+            if cfg_pp:
+                x = denoised + d * sigmas[i + 1]
+            else:
+                x = x + d * dt
         else:
             # Gradient estimation
-            d_bar = ge_gamma * d + (1 - ge_gamma) * old_d
-            x = x + d_bar * dt
+            if cfg_pp:
+                d_bar = (ge_gamma - 1) * (d - old_d)
+                x = denoised + d * sigmas[i + 1] + d_bar * dt
+            else:
+                d_bar = ge_gamma * d + (1 - ge_gamma) * old_d
+                x = x + d_bar * dt
         old_d = d
     return x
 
+@torch.no_grad()
+def sample_gradient_estimation_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2.):
+    return sample_gradient_estimation(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, ge_gamma=ge_gamma, cfg_pp=True)
+
 @torch.no_grad()
 def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None, noise_scaler=None, max_stage=3):
     """

comfy/ldm/chroma/layers.py Normal file
View File

@@ -0,0 +1,183 @@
import torch
from torch import Tensor, nn

from comfy.ldm.flux.math import attention
from comfy.ldm.flux.layers import (
    MLPEmbedder,
    RMSNorm,
    QKNorm,
    SelfAttention,
    ModulationOut,
)


class ChromaModulationOut(ModulationOut):
    @classmethod
    def from_offset(cls, tensor: torch.Tensor, offset: int = 0) -> ModulationOut:
        return cls(
            shift=tensor[:, offset : offset + 1, :],
            scale=tensor[:, offset + 1 : offset + 2, :],
            gate=tensor[:, offset + 2 : offset + 3, :],
        )


class Approximator(nn.Module):
    def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers=5, dtype=None, device=None, operations=None):
        super().__init__()
        self.in_proj = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
        self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim, dtype=dtype, device=device, operations=operations) for x in range(n_layers)])
        self.norms = nn.ModuleList([RMSNorm(hidden_dim, dtype=dtype, device=device, operations=operations) for x in range(n_layers)])
        self.out_proj = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)

    @property
    def device(self):
        # Get the device of the module (assumes all parameters are on the same device)
        return next(self.parameters()).device

    def forward(self, x: Tensor) -> Tensor:
        x = self.in_proj(x)
        for layer, norms in zip(self.layers, self.norms):
            x = x + layer(norms(x))
        x = self.out_proj(x)
        return x


class DoubleStreamBlock(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
        super().__init__()

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)

        self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.img_mlp = nn.Sequential(
            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
            nn.GELU(approximate="tanh"),
            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
        )

        self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)

        self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.txt_mlp = nn.Sequential(
            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
            nn.GELU(approximate="tanh"),
            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
        )
        self.flipped_img_txt = flipped_img_txt

    def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None):
        (img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec

        # prepare image for attention
        img_modulated = self.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = self.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        # run actual attention
        attn = attention(torch.cat((txt_q, img_q), dim=2),
                         torch.cat((txt_k, img_k), dim=2),
                         torch.cat((txt_v, img_v), dim=2),
                         pe=pe, mask=attn_mask)

        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]

        # calculate the img blocks
        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
        img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)

        # calculate the txt blocks
        txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
        txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)

        if txt.dtype == torch.float16:
            txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)

        return img, txt


class SingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float = None,
        dtype=None,
        device=None,
        operations=None
    ):
        super().__init__()
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # qkv and mlp_in
        self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
        # proj and mlp_out
        self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)

        self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)

        self.hidden_size = hidden_size
        self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)

        self.mlp_act = nn.GELU(approximate="tanh")

    def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None) -> Tensor:
        mod = vec
        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

        q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k = self.norm(q, k, v)

        # compute attention
        attn = attention(q, k, v, pe=pe, mask=attn_mask)
        # compute activation in mlp stream, cat again and run second linear layer
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        x += mod.gate * output
        if x.dtype == torch.float16:
            x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
        return x


class LastLayer(nn.Module):
    def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
        super().__init__()
        self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.linear = operations.Linear(hidden_size, out_channels, bias=True, dtype=dtype, device=device)

    def forward(self, x: Tensor, vec: Tensor) -> Tensor:
        shift, scale = vec
        shift = shift.squeeze(1)
        scale = scale.squeeze(1)
        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
        x = self.linear(x)
        return x

comfy/ldm/chroma/model.py Normal file
View File

@@ -0,0 +1,271 @@
#Original code can be found on: https://github.com/black-forest-labs/flux

from dataclasses import dataclass

import torch
from torch import Tensor, nn
from einops import rearrange, repeat
import comfy.ldm.common_dit

from comfy.ldm.flux.layers import (
    EmbedND,
    timestep_embedding,
)

from .layers import (
    DoubleStreamBlock,
    LastLayer,
    SingleStreamBlock,
    Approximator,
    ChromaModulationOut,
)


@dataclass
class ChromaParams:
    in_channels: int
    out_channels: int
    context_in_dim: int
    hidden_size: int
    mlp_ratio: float
    num_heads: int
    depth: int
    depth_single_blocks: int
    axes_dim: list
    theta: int
    patch_size: int
    qkv_bias: bool
    in_dim: int
    out_dim: int
    hidden_dim: int
    n_layers: int


class Chroma(nn.Module):
    """
    Transformer model for flow matching on sequences.
    """

    def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
        super().__init__()
        self.dtype = dtype
        params = ChromaParams(**kwargs)
        self.params = params
        self.patch_size = params.patch_size
        self.in_channels = params.in_channels
        self.out_channels = params.out_channels
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(
                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
            )
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.in_dim = params.in_dim
        self.out_dim = params.out_dim
        self.hidden_dim = params.hidden_dim
        self.n_layers = params.n_layers
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
        self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
        self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device)
        # set as nn identity for now, will overwrite it later.
        self.distilled_guidance_layer = Approximator(
            in_dim=self.in_dim,
            hidden_dim=self.hidden_dim,
            out_dim=self.out_dim,
            n_layers=self.n_layers,
            dtype=dtype, device=device, operations=operations
        )

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_bias=params.qkv_bias,
                    dtype=dtype, device=device, operations=operations
                )
                for _ in range(params.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
                for _ in range(params.depth_single_blocks)
            ]
        )

        if final_layer:
            self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)

        self.skip_mmdit = []
        self.skip_dit = []
        self.lite = False

    def get_modulations(self, tensor: torch.Tensor, block_type: str, *, idx: int = 0):
        # This function slices up the modulations tensor which has the following layout:
        #   single     : num_single_blocks * 3 elements
        #   double_img : num_double_blocks * 6 elements
        #   double_txt : num_double_blocks * 6 elements
        #   final      : 2 elements
        if block_type == "final":
            return (tensor[:, -2:-1, :], tensor[:, -1:, :])
        single_block_count = self.params.depth_single_blocks
        double_block_count = self.params.depth
        offset = 3 * idx
        if block_type == "single":
            return ChromaModulationOut.from_offset(tensor, offset)
        # Double block modulations are 6 elements so we double 3 * idx.
        offset *= 2
        if block_type in {"double_img", "double_txt"}:
            # Advance past the single block modulations.
            offset += 3 * single_block_count
            if block_type == "double_txt":
                # Advance past the double block img modulations.
                offset += 6 * double_block_count
            return (
                ChromaModulationOut.from_offset(tensor, offset),
                ChromaModulationOut.from_offset(tensor, offset + 3),
            )
        raise ValueError("Bad block_type")

    def forward_orig(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        guidance: Tensor = None,
        control = None,
        transformer_options={},
        attn_mask: Tensor = None,
    ) -> Tensor:
        patches_replace = transformer_options.get("patches_replace", {})

        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # running on sequences img
        img = self.img_in(img)

        # distilled vector guidance
        mod_index_length = 344
        distill_timestep = timestep_embedding(timesteps.detach().clone(), 16).to(img.device, img.dtype)
        # guidance = guidance *
        distil_guidance = timestep_embedding(guidance.detach().clone(), 16).to(img.device, img.dtype)

        # get all modulation index
        modulation_index = timestep_embedding(torch.arange(mod_index_length), 32).to(img.device, img.dtype)
        # we need to broadcast the modulation index here so each batch has all of the index
        modulation_index = modulation_index.unsqueeze(0).repeat(img.shape[0], 1, 1).to(img.device, img.dtype)
        # and we need to broadcast timestep and guidance along too
        timestep_guidance = torch.cat([distill_timestep, distil_guidance], dim=1).unsqueeze(1).repeat(1, mod_index_length, 1).to(img.dtype).to(img.device, img.dtype)
        # then and only then we could concatenate it together
        input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1).to(img.device, img.dtype)
        mod_vectors = self.distilled_guidance_layer(input_vec)

        txt = self.txt_in(txt)

        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        blocks_replace = patches_replace.get("dit", {})
        for i, block in enumerate(self.double_blocks):
            if i not in self.skip_mmdit:
                double_mod = (
                    self.get_modulations(mod_vectors, "double_img", idx=i),
                    self.get_modulations(mod_vectors, "double_txt", idx=i),
                )
                if ("double_block", i) in blocks_replace:
                    def block_wrap(args):
                        out = {}
                        out["img"], out["txt"] = block(img=args["img"],
                                                       txt=args["txt"],
                                                       vec=args["vec"],
                                                       pe=args["pe"],
                                                       attn_mask=args.get("attn_mask"))
                        return out

                    out = blocks_replace[("double_block", i)]({"img": img,
                                                               "txt": txt,
                                                               "vec": double_mod,
                                                               "pe": pe,
                                                               "attn_mask": attn_mask},
                                                              {"original_block": block_wrap})
                    txt = out["txt"]
                    img = out["img"]
                else:
                    img, txt = block(img=img,
                                     txt=txt,
                                     vec=double_mod,
                                     pe=pe,
                                     attn_mask=attn_mask)

                if control is not None: # Controlnet
                    control_i = control.get("input")
                    if i < len(control_i):
                        add = control_i[i]
                        if add is not None:
                            img += add

        img = torch.cat((txt, img), 1)

        for i, block in enumerate(self.single_blocks):
            if i not in self.skip_dit:
                single_mod = self.get_modulations(mod_vectors, "single", idx=i)
                if ("single_block", i) in blocks_replace:
                    def block_wrap(args):
                        out = {}
                        out["img"] = block(args["img"],
                                           vec=args["vec"],
                                           pe=args["pe"],
                                           attn_mask=args.get("attn_mask"))
                        return out

                    out = blocks_replace[("single_block", i)]({"img": img,
                                                               "vec": single_mod,
                                                               "pe": pe,
                                                               "attn_mask": attn_mask},
                                                              {"original_block": block_wrap})
                    img = out["img"]
                else:
                    img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask)

                if control is not None: # Controlnet
                    control_o = control.get("output")
                    if i < len(control_o):
                        add = control_o[i]
                        if add is not None:
                            img[:, txt.shape[1] :, ...] += add

        img = img[:, txt.shape[1] :, ...]
        final_mod = self.get_modulations(mod_vectors, "final")
        img = self.final_layer(img, vec=final_mod)  # (N, T, patch_size ** 2 * out_channels)
        return img

    def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
        bs, c, h, w = x.shape
        patch_size = 2
        x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))

        img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)

        h_len = ((h + (patch_size // 2)) // patch_size)
        w_len = ((w + (patch_size // 2)) // patch_size)
        img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
        img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

        txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
        out = self.forward_orig(img, img_ids, context, txt_ids, timestep, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
        return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:, :, :h, :w]
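
To make the modulation layout in get_modulations concrete, a small worked example (the block counts are assumptions; with depth_single_blocks=38 and depth=19 the totals match the mod_index_length=344 used in forward_orig):

    # single,     idx=2 -> offset = 3*2 = 6                  (one shift/scale/gate triple)
    # double_img, idx=2 -> offset = 6*2 + 3*38 = 126         (two triples: 126 and 129)
    # double_txt, idx=2 -> offset = 6*2 + 3*38 + 6*19 = 240  (two triples: 240 and 243)
    # final             -> the last two positions of the tensor
    # total positions: 3*38 + 6*19 + 6*19 + 2 = 344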

View File

@@ -699,10 +699,13 @@ class HiDreamImageTransformer2DModel(nn.Module):
         y: Optional[torch.Tensor] = None,
         context: Optional[torch.Tensor] = None,
         encoder_hidden_states_llama3=None,
+        image_cond=None,
         control = None,
         transformer_options = {},
     ) -> torch.Tensor:
         bs, c, h, w = x.shape
+        if image_cond is not None:
+            x = torch.cat([x, image_cond], dim=-1)
         hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
         timesteps = t
         pooled_embeds = y

View File

@@ -1,7 +1,6 @@
 import torch
 from torch import nn
 import comfy.ldm.modules.attention
-from comfy.ldm.genmo.joint_model.layers import RMSNorm
 import comfy.ldm.common_dit
 from einops import rearrange
 import math
@@ -262,8 +261,8 @@ class CrossAttention(nn.Module):
         self.heads = heads
         self.dim_head = dim_head
 
-        self.q_norm = RMSNorm(inner_dim, dtype=dtype, device=device)
-        self.k_norm = RMSNorm(inner_dim, dtype=dtype, device=device)
+        self.q_norm = operations.RMSNorm(inner_dim, dtype=dtype, device=device)
+        self.k_norm = operations.RMSNorm(inner_dim, dtype=dtype, device=device)
 
         self.to_q = operations.Linear(query_dim, inner_dim, bias=True, dtype=dtype, device=device)
         self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)

View File

@@ -631,6 +631,7 @@ class VaceWanModel(WanModel):
             if ii is not None:
                 c_skip, c = self.vace_blocks[ii](c, x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
                 x += c_skip * vace_strength
+        del c_skip
 
         # head
         x = self.head(x, e)

View File

@@ -279,6 +279,13 @@ def model_lora_keys_unet(model, key_map={}):
             key_map["transformer.{}".format(key_lora)] = k
             key_map["diffusion_model.{}".format(key_lora)] = k # Old loras
 
+    if isinstance(model, comfy.model_base.HiDream):
+        for k in sdk:
+            if k.startswith("diffusion_model."):
+                if k.endswith(".weight"):
+                    key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
+                    key_map["lycoris_{}".format(key_lora)] = k #SimpleTuner lycoris format
+
     return key_map

View File

@@ -38,6 +38,7 @@ import comfy.ldm.lumina.model
 import comfy.ldm.wan.model
 import comfy.ldm.hunyuan3d.model
 import comfy.ldm.hidream.model
+import comfy.ldm.chroma.model
 
 import comfy.model_management
 import comfy.patcher_extension
@@ -786,8 +787,8 @@ class PixArt(BaseModel):
         return out
 
 class Flux(BaseModel):
-    def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
-        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.flux.model.Flux)
+    def __init__(self, model_config, model_type=ModelType.FLUX, device=None, unet_model=comfy.ldm.flux.model.Flux):
+        super().__init__(model_config, model_type, device=device, unet_model=unet_model)
 
     def concat_cond(self, **kwargs):
         try:
@@ -1104,4 +1105,19 @@ class HiDream(BaseModel):
         conditioning_llama3 = kwargs.get("conditioning_llama3", None)
         if conditioning_llama3 is not None:
             out['encoder_hidden_states_llama3'] = comfy.conds.CONDRegular(conditioning_llama3)
+        image_cond = kwargs.get("concat_latent_image", None)
+        if image_cond is not None:
+            out['image_cond'] = comfy.conds.CONDNoiseShape(self.process_latent_in(image_cond))
+        return out
+
+class Chroma(Flux):
+    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma.model.Chroma)
+
+    def extra_conds(self, **kwargs):
+        out = super().extra_conds(**kwargs)
+        guidance = kwargs.get("guidance", 0)
+        if guidance is not None:
+            out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
         return out

View File

@@ -164,7 +164,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         if in_key in state_dict_keys:
             dit_config["in_channels"] = state_dict[in_key].shape[1] // (patch_size * patch_size)
         dit_config["out_channels"] = 16
-        dit_config["vec_in_dim"] = 768
+        vec_in_key = '{}vector_in.in_layer.weight'.format(key_prefix)
+        if vec_in_key in state_dict_keys:
+            dit_config["vec_in_dim"] = state_dict[vec_in_key].shape[1]
         dit_config["context_in_dim"] = 4096
         dit_config["hidden_size"] = 3072
         dit_config["mlp_ratio"] = 4.0
@@ -174,6 +176,15 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["axes_dim"] = [16, 56, 56]
         dit_config["theta"] = 10000
         dit_config["qkv_bias"] = True
+        if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma
+            dit_config["image_model"] = "chroma"
+            dit_config["in_channels"] = 64
+            dit_config["out_channels"] = 64
+            dit_config["in_dim"] = 64
+            dit_config["out_dim"] = 3072
+            dit_config["hidden_dim"] = 5120
+            dit_config["n_layers"] = 5
+        else:
-        dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
+            dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
         return dit_config

View File

@@ -939,13 +939,59 @@ def force_channels_last():
     #TODO
     return False
 
-def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
+STREAMS = {}
+NUM_STREAMS = 1
+if args.async_offload:
+    NUM_STREAMS = 2
+    logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS))
+
+stream_counters = {}
+def get_offload_stream(device):
+    stream_counter = stream_counters.get(device, 0)
+    if NUM_STREAMS <= 1:
+        return None
+
+    if device in STREAMS:
+        ss = STREAMS[device]
+        s = ss[stream_counter]
+        stream_counter = (stream_counter + 1) % len(ss)
+        if is_device_cuda(device):
+            ss[stream_counter].wait_stream(torch.cuda.current_stream())
+        stream_counters[device] = stream_counter
+        return s
+    elif is_device_cuda(device):
+        ss = []
+        for k in range(NUM_STREAMS):
+            ss.append(torch.cuda.Stream(device=device, priority=0))
+        STREAMS[device] = ss
+        s = ss[stream_counter]
+        stream_counter = (stream_counter + 1) % len(ss)
+        stream_counters[device] = stream_counter
+        return s
+    return None
+
+def sync_stream(device, stream):
+    if stream is None:
+        return
+    if is_device_cuda(device):
+        torch.cuda.current_stream().wait_stream(stream)
+
+def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
     if device is None or weight.device == device:
         if not copy:
             if dtype is None or weight.dtype == dtype:
                 return weight
+        if stream is not None:
+            with stream:
+                return weight.to(dtype=dtype, copy=copy)
         return weight.to(dtype=dtype, copy=copy)
 
-    r = torch.empty_like(weight, dtype=dtype, device=device)
-    r.copy_(weight, non_blocking=non_blocking)
+    if stream is not None:
+        with stream:
+            r = torch.empty_like(weight, dtype=dtype, device=device)
+            r.copy_(weight, non_blocking=non_blocking)
+    else:
+        r = torch.empty_like(weight, dtype=dtype, device=device)
+        r.copy_(weight, non_blocking=non_blocking)
     return r

View File

@@ -111,13 +111,14 @@ class ModelSamplingDiscrete(torch.nn.Module):
         self.num_timesteps = int(timesteps)
         self.linear_start = linear_start
         self.linear_end = linear_end
+        self.zsnr = zsnr
 
         # self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32))
         # self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
         # self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
 
         sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
-        if zsnr:
+        if self.zsnr:
             sigmas = rescale_zero_terminal_snr_sigmas(sigmas)
 
         self.set_sigmas(sigmas)

View File

@@ -22,6 +22,7 @@ import comfy.model_management
 from comfy.cli_args import args, PerformanceFeature
 import comfy.float
 import comfy.rmsnorm
+import contextlib
 
 cast_to = comfy.model_management.cast_to #TODO: remove once no more references
@@ -37,20 +38,31 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
         if device is None:
             device = input.device
 
+    offload_stream = comfy.model_management.get_offload_stream(device)
+    if offload_stream is not None:
+        wf_context = offload_stream
+    else:
+        wf_context = contextlib.nullcontext()
+
     bias = None
     non_blocking = comfy.model_management.device_supports_non_blocking(device)
     if s.bias is not None:
         has_function = len(s.bias_function) > 0
-        bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function)
+        bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
+
         if has_function:
-            for f in s.bias_function:
-                bias = f(bias)
+            with wf_context:
+                for f in s.bias_function:
+                    bias = f(bias)
 
     has_function = len(s.weight_function) > 0
-    weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function)
+    weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
+
     if has_function:
-        for f in s.weight_function:
-            weight = f(weight)
+        with wf_context:
+            for f in s.weight_function:
+                weight = f(weight)
 
+    comfy.model_management.sync_stream(device, offload_stream)
     return weight, bias
 
 class CastWeightBiasOp:

View File

@@ -710,7 +710,7 @@ KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_c
                   "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
                   "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
                   "ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
-                  "gradient_estimation", "er_sde", "seeds_2", "seeds_3"]
+                  "gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3"]
 
 class KSAMPLER(Sampler):
     def __init__(self, sampler_function, extra_options={}, inpaint_options={}):

View File

@@ -120,6 +120,7 @@ class CLIP:
self.layer_idx = None self.layer_idx = None
self.use_clip_schedule = False self.use_clip_schedule = False
logging.info("CLIP/text encoder model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype)) logging.info("CLIP/text encoder model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype))
self.tokenizer_options = {}
def clone(self): def clone(self):
n = CLIP(no_init=True) n = CLIP(no_init=True)
@@ -127,6 +128,7 @@ class CLIP:
n.cond_stage_model = self.cond_stage_model n.cond_stage_model = self.cond_stage_model
n.tokenizer = self.tokenizer n.tokenizer = self.tokenizer
n.layer_idx = self.layer_idx n.layer_idx = self.layer_idx
n.tokenizer_options = self.tokenizer_options.copy()
n.use_clip_schedule = self.use_clip_schedule n.use_clip_schedule = self.use_clip_schedule
n.apply_hooks_to_conds = self.apply_hooks_to_conds n.apply_hooks_to_conds = self.apply_hooks_to_conds
return n return n
@@ -134,10 +136,18 @@ class CLIP:
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0): def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
return self.patcher.add_patches(patches, strength_patch, strength_model) return self.patcher.add_patches(patches, strength_patch, strength_model)
def set_tokenizer_option(self, option_name, value):
self.tokenizer_options[option_name] = value
def clip_layer(self, layer_idx): def clip_layer(self, layer_idx):
self.layer_idx = layer_idx self.layer_idx = layer_idx
def tokenize(self, text, return_word_ids=False, **kwargs): def tokenize(self, text, return_word_ids=False, **kwargs):
tokenizer_options = kwargs.get("tokenizer_options", {})
if len(self.tokenizer_options) > 0:
tokenizer_options = {**self.tokenizer_options, **tokenizer_options}
if len(tokenizer_options) > 0:
kwargs["tokenizer_options"] = tokenizer_options
return self.tokenizer.tokenize_with_weights(text, return_word_ids, **kwargs) return self.tokenizer.tokenize_with_weights(text, return_word_ids, **kwargs)
def add_hooks_to_dict(self, pooled_dict: dict[str]): def add_hooks_to_dict(self, pooled_dict: dict[str]):
@@ -704,6 +714,7 @@ class CLIPType(Enum):
    LUMINA2 = 12
    WAN = 13
    HIDREAM = 14
+   CHROMA = 15

def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
@@ -808,7 +819,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
        elif clip_type == CLIPType.LTXV:
            clip_target.clip = comfy.text_encoders.lt.ltxv_te(**t5xxl_detect(clip_data))
            clip_target.tokenizer = comfy.text_encoders.lt.LTXVT5Tokenizer
-       elif clip_type == CLIPType.PIXART:
+       elif clip_type == CLIPType.PIXART or clip_type == CLIPType.CHROMA:
            clip_target.clip = comfy.text_encoders.pixart_t5.pixart_te(**t5xxl_detect(clip_data))
            clip_target.tokenizer = comfy.text_encoders.pixart_t5.PixArtTokenizer
        elif clip_type == CLIPType.WAN:

View File

@@ -457,13 +457,14 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
    return embed_out

class SDTokenizer:
-   def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, tokenizer_data={}, tokenizer_args={}):
+   def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, tokenizer_data={}, tokenizer_args={}):
        if tokenizer_path is None:
            tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
        self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
        self.max_length = tokenizer_data.get("{}_max_length".format(embedding_key), max_length)
        self.min_length = min_length
        self.end_token = None
+       self.min_padding = min_padding

        empty = self.tokenizer('')["input_ids"]
        self.tokenizer_adds_end_token = has_end_token
@@ -518,13 +519,15 @@ class SDTokenizer:
        return (embed, leftover)

-   def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
+   def tokenize_with_weights(self, text:str, return_word_ids=False, tokenizer_options={}, **kwargs):
        '''
        Takes a prompt and converts it to a list of (token, weight, word id) elements.
        Tokens can both be integer tokens and pre computed CLIP tensors.
        Word id values are unique per word and embedding, where the id 0 is reserved for non word tokens.
        Returned list has the dimensions NxM where M is the input size of CLIP
        '''
+       min_length = tokenizer_options.get("{}_min_length".format(self.embedding_key), self.min_length)
+       min_padding = tokenizer_options.get("{}_min_padding".format(self.embedding_key), self.min_padding)

        text = escape_important(text)
        parsed_weights = token_weights(text, 1.0)
@@ -603,10 +606,12 @@ class SDTokenizer:
        #fill last batch
        if self.end_token is not None:
            batch.append((self.end_token, 1.0, 0))
-       if self.pad_to_max_length:
+       if min_padding is not None:
+           batch.extend([(self.pad_token, 1.0, 0)] * min_padding)
+       if self.pad_to_max_length and len(batch) < self.max_length:
            batch.extend([(self.pad_token, 1.0, 0)] * (self.max_length - len(batch)))
-       if self.min_length is not None and len(batch) < self.min_length:
-           batch.extend([(self.pad_token, 1.0, 0)] * (self.min_length - len(batch)))
+       if min_length is not None and len(batch) < min_length:
+           batch.extend([(self.pad_token, 1.0, 0)] * (min_length - len(batch)))

        if not return_word_ids:
            batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]
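The new padding order is: explicit min_padding first, then pad-to-max, then pad-to-min. A standalone sketch of that behaviour (pad token 0 and the input values are invented for illustration):

    def pad_batch(batch, pad_to_max_length, max_length, min_padding, min_length):
        pad = (0, 1.0, 0)
        if min_padding is not None:
            batch.extend([pad] * min_padding)
        if pad_to_max_length and len(batch) < max_length:
            batch.extend([pad] * (max_length - len(batch)))
        if min_length is not None and len(batch) < min_length:
            batch.extend([pad] * (min_length - len(batch)))
        return batch

    pad_batch([(49406, 1.0, 0)], False, 77, min_padding=2, min_length=8)
    # one real token + two forced pads, then padded up to 8 entries total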
@@ -634,7 +639,7 @@ class SD1Tokenizer:
    def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
        out = {}
-       out[self.clip_name] = getattr(self, self.clip).tokenize_with_weights(text, return_word_ids)
+       out[self.clip_name] = getattr(self, self.clip).tokenize_with_weights(text, return_word_ids, **kwargs)
        return out

    def untokenize(self, token_weight_pair):

View File

@@ -28,8 +28,8 @@ class SDXLTokenizer:
    def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
        out = {}
-       out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
-       out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
+       out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids, **kwargs)
+       out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
        return out

    def untokenize(self, token_weight_pair):

View File

@@ -993,6 +993,10 @@ class WAN21_Vace(WAN21_T2V):
"model_type": "vace", "model_type": "vace",
} }
def __init__(self, unet_config):
super().__init__(unet_config)
self.memory_usage_factor = 1.2 * self.memory_usage_factor
def get_model(self, state_dict, prefix="", device=None): def get_model(self, state_dict, prefix="", device=None):
out = model_base.WAN21_Vace(self, image_to_video=False, device=device) out = model_base.WAN21_Vace(self, image_to_video=False, device=device)
return out return out
@@ -1064,7 +1068,34 @@ class HiDream(supported_models_base.BASE):
    def clip_target(self, state_dict={}):
        return None # TODO

+class Chroma(supported_models_base.BASE):
+    unet_config = {
+        "image_model": "chroma",
+    }
+
+    unet_extra_config = {
+    }
+
+    sampling_settings = {
+        "multiplier": 1.0,
+    }
+
+    latent_format = comfy.latent_formats.Flux
+
+    memory_usage_factor = 3.2
+
+    supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.Chroma(self, device=device)
+        return out
+
+    def clip_target(self, state_dict={}):
+        pref = self.text_encoder_key_prefix[0]
+        t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
+        return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect))

-models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream]
+models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma]

models += [SVD_img2vid]

View File

@@ -19,8 +19,8 @@ class FluxTokenizer:
    def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
        out = {}
-       out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
-       out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids)
+       out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
+       out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs)
        return out

    def untokenize(self, token_weight_pair):

View File

@@ -16,11 +16,11 @@ class HiDreamTokenizer:
    def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
        out = {}
-       out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
-       out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
-       t5xxl = self.t5xxl.tokenize_with_weights(text, return_word_ids)
+       out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids, **kwargs)
+       out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
+       t5xxl = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs)
        out["t5xxl"] = [t5xxl[0]] # Use only first 128 tokens
-       out["llama"] = self.llama.tokenize_with_weights(text, return_word_ids)
+       out["llama"] = self.llama.tokenize_with_weights(text, return_word_ids, **kwargs)
        return out

    def untokenize(self, token_weight_pair):

View File

@@ -49,13 +49,13 @@ class HunyuanVideoTokenizer:
    def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, image_embeds=None, image_interleave=1, **kwargs):
        out = {}
-       out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
+       out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)

        if llama_template is None:
            llama_text = self.llama_template.format(text)
        else:
            llama_text = llama_template.format(text)
-       llama_text_tokens = self.llama.tokenize_with_weights(llama_text, return_word_ids)
+       llama_text_tokens = self.llama.tokenize_with_weights(llama_text, return_word_ids, **kwargs)
        embed_count = 0
        for r in llama_text_tokens:
            for i in range(len(r)):

View File

@@ -41,8 +41,8 @@ class HyditTokenizer:
    def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
        out = {}
-       out["hydit_clip"] = self.hydit_clip.tokenize_with_weights(text, return_word_ids)
-       out["mt5xl"] = self.mt5xl.tokenize_with_weights(text, return_word_ids)
+       out["hydit_clip"] = self.hydit_clip.tokenize_with_weights(text, return_word_ids, **kwargs)
+       out["mt5xl"] = self.mt5xl.tokenize_with_weights(text, return_word_ids, **kwargs)
        return out

    def untokenize(self, token_weight_pair):

View File

@@ -45,9 +45,9 @@ class SD3Tokenizer:
    def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
        out = {}
-       out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
-       out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
-       out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids)
+       out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids, **kwargs)
+       out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
+       out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs)
        return out

    def untokenize(self, token_weight_pair):

View File

@@ -24,7 +24,7 @@ class BOFTAdapter(WeightAdapterBase):
    ) -> Optional["BOFTAdapter"]:
        if loaded_keys is None:
            loaded_keys = set()
-       blocks_name = "{}.boft_blocks".format(x)
+       blocks_name = "{}.oft_blocks".format(x)
        rescale_name = "{}.rescale".format(x)

        blocks = None
@@ -32,17 +32,18 @@
            blocks = lora[blocks_name]
            if blocks.ndim == 4:
                loaded_keys.add(blocks_name)
+           else:
+               blocks = None
+       if blocks is None:
+           return None

        rescale = None
        if rescale_name in lora.keys():
            rescale = lora[rescale_name]
            loaded_keys.add(rescale_name)

-       if blocks is not None:
        weights = (blocks, rescale, alpha, dora_scale)
        return cls(loaded_keys, weights)
-       else:
-           return None
def calculate_weight( def calculate_weight(
self, self,
@@ -71,7 +72,7 @@ class BOFTAdapter(WeightAdapterBase):
            # Get r
            I = torch.eye(boft_b, device=blocks.device, dtype=blocks.dtype)
            # for Q = -Q^T
-           q = blocks - blocks.transpose(1, 2)
+           q = blocks - blocks.transpose(-1, -2)
            normed_q = q
            if alpha > 0: # alpha in boft/bboft is for constraint
                q_norm = torch.norm(q) + 1e-8
@@ -79,9 +80,8 @@
                normed_q = q * alpha / q_norm
            # use float() to prevent unsupported type in .inverse()
            r = (I + normed_q) @ (I - normed_q).float().inverse()
-           r = r.to(original_weight)
-           inp = org = original_weight
+           r = r.to(weight)
+           inp = org = weight

            r_b = boft_b//2
            for i in range(boft_m):
@@ -91,14 +91,14 @@
                if strength != 1:
                    bi = bi * strength + (1-strength) * I
                inp = (
-                   inp.unflatten(-1, (-1, g, k))
-                   .transpose(-2, -1)
-                   .flatten(-3)
-                   .unflatten(-1, (-1, boft_b))
+                   inp.unflatten(0, (-1, g, k))
+                   .transpose(1, 2)
+                   .flatten(0, 2)
+                   .unflatten(0, (-1, boft_b))
                )
-               inp = torch.einsum("b n m, b n ... -> b m ...", inp, bi)
+               inp = torch.einsum("b i j, b j ...-> b i ...", bi, inp)
                inp = (
-                   inp.flatten(-2).unflatten(-1, (-1, k, g)).transpose(-2, -1).flatten(-3)
+                   inp.flatten(0, 1).unflatten(0, (-1, k, g)).transpose(1, 2).flatten(0, 2)
                )

            if rescale is not None:
@@ -109,7 +109,7 @@
            if dora_scale is not None:
                weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
            else:
-               weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
+               weight += function((strength * lora_diff).type(weight.dtype))
        except Exception as e:
            logging.error("ERROR {} {} {}".format(self.name, key, e))
        return weight
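The corrected math builds each block rotation with a Cayley transform: q is made skew-symmetric, so R = (I + Q)(I - Q)^-1 is orthogonal. A quick self-contained check of that property:

    import torch

    b = 4
    blocks = torch.randn(b, b)
    q = blocks - blocks.T                         # skew-symmetric: q == -q.T
    I = torch.eye(b)
    r = (I + q) @ torch.linalg.inv(I - q)         # Cayley transform
    print(torch.allclose(r @ r.T, I, atol=1e-5))  # True: r is orthogonal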

View File

@@ -32,17 +32,18 @@ class OFTAdapter(WeightAdapterBase):
            blocks = lora[blocks_name]
            if blocks.ndim == 3:
                loaded_keys.add(blocks_name)
+           else:
+               blocks = None
+       if blocks is None:
+           return None

        rescale = None
        if rescale_name in lora.keys():
            rescale = lora[rescale_name]
            loaded_keys.add(rescale_name)

-       if blocks is not None:
        weights = (blocks, rescale, alpha, dora_scale)
        return cls(loaded_keys, weights)
-       else:
-           return None
def calculate_weight( def calculate_weight(
self, self,
@@ -79,16 +80,17 @@ class OFTAdapter(WeightAdapterBase):
                normed_q = q * alpha / q_norm
            # use float() to prevent unsupported type in .inverse()
            r = (I + normed_q) @ (I - normed_q).float().inverse()
-           r = r.to(original_weight)
+           r = r.to(weight)
+           _, *shape = weight.shape
            lora_diff = torch.einsum(
                "k n m, k n ... -> k m ...",
                (r * strength) - strength * I,
-               original_weight,
-           )
+               weight.view(block_num, block_size, *shape),
+           ).view(-1, *shape)

            if dora_scale is not None:
                weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
            else:
-               weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
+               weight += function((strength * lora_diff).type(weight.dtype))
        except Exception as e:
            logging.error("ERROR {} {} {}".format(self.name, key, e))
        return weight
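The reshape in the fix applies each of the block_num rotations to its own block_size slice of the weight's output channels. A minimal sketch of just that einsum, with identity rotations so the resulting diff is exactly zero (sizes are invented for illustration):

    import torch

    block_num, block_size, in_features = 8, 16, 32
    r = torch.eye(block_size).expand(block_num, block_size, block_size)
    I = torch.eye(block_size)
    strength = 1.0
    weight = torch.randn(block_num * block_size, in_features)
    _, *shape = weight.shape
    lora_diff = torch.einsum(
        "k n m, k n ... -> k m ...",
        (r * strength) - strength * I,           # zero, since r == I here
        weight.view(block_num, block_size, *shape),
    ).view(-1, *shape)
    print(lora_diff.shape, lora_diff.abs().max())  # torch.Size([128, 32]) tensor(0.)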

View File

@@ -0,0 +1,8 @@
from .basic_types import ImageInput, AudioInput
from .video_types import VideoInput

__all__ = [
    "ImageInput",
    "AudioInput",
    "VideoInput",
]

View File

@@ -0,0 +1,20 @@
import torch
from typing import TypedDict

ImageInput = torch.Tensor
"""
An image in format [B, H, W, C] where B is the batch size, H and W are the
height and width, and C is the number of channels.
"""

class AudioInput(TypedDict):
    """
    TypedDict representing audio input.
    """

    waveform: torch.Tensor
    """
    Tensor in the format [B, C, T] where B is the batch size, C is the number
    of channels, and T is the number of samples.
    """

    sample_rate: int

View File

@@ -0,0 +1,45 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Optional
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents

class VideoInput(ABC):
    """
    Abstract base class for video input types.
    """

    @abstractmethod
    def get_components(self) -> VideoComponents:
        """
        Abstract method to get the video components (images, audio, and frame rate).

        Returns:
            VideoComponents containing images, audio, and frame rate
        """
        pass

    @abstractmethod
    def save_to(
        self,
        path: str,
        format: VideoContainer = VideoContainer.AUTO,
        codec: VideoCodec = VideoCodec.AUTO,
        metadata: Optional[dict] = None
    ):
        """
        Abstract method to save the video input to a file.
        """
        pass

    # Provide a default implementation, but subclasses can provide optimized versions
    # if possible.
    def get_dimensions(self) -> tuple[int, int]:
        """
        Returns the dimensions of the video input.

        Returns:
            Tuple of (width, height)
        """
        components = self.get_components()
        return components.images.shape[2], components.images.shape[1]
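The index order follows from the [B, H, W, C] layout of ImageInput; a small illustration with an invented tensor:

    import torch

    frames = torch.zeros(8, 480, 640, 3)   # 8 frames of 640x480 RGB
    width, height = frames.shape[2], frames.shape[1]
    print(width, height)                   # 640 480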

View File

@@ -0,0 +1,7 @@
from .video_types import VideoFromFile, VideoFromComponents

__all__ = [
    # Implementations
    "VideoFromFile",
    "VideoFromComponents",
]

View File

@@ -0,0 +1,224 @@
from __future__ import annotations
from av.container import InputContainer
from av.subtitles.stream import SubtitleStream
from fractions import Fraction
from typing import Optional
from comfy_api.input import AudioInput
import av
import io
import json
import numpy as np
import torch
from comfy_api.input import VideoInput
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents

class VideoFromFile(VideoInput):
    """
    Class representing video input from a file.
    """

    def __init__(self, file: str | io.BytesIO):
        """
        Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object
        containing the file contents.
        """
        self.__file = file

    def get_dimensions(self) -> tuple[int, int]:
        """
        Returns the dimensions of the video input.

        Returns:
            Tuple of (width, height)
        """
        if isinstance(self.__file, io.BytesIO):
            self.__file.seek(0)  # Reset the BytesIO object to the beginning
        with av.open(self.__file, mode='r') as container:
            for stream in container.streams:
                if stream.type == 'video':
                    assert isinstance(stream, av.VideoStream)
                    return stream.width, stream.height
        raise ValueError(f"No video stream found in file '{self.__file}'")

    def get_components_internal(self, container: InputContainer) -> VideoComponents:
        # Get video frames
        frames = []
        for frame in container.decode(video=0):
            img = frame.to_ndarray(format='rgb24')  # shape: (H, W, 3)
            img = torch.from_numpy(img) / 255.0  # shape: (H, W, 3)
            frames.append(img)
        images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 3, 0, 0)

        # Get frame rate
        video_stream = next(s for s in container.streams if s.type == 'video')
        frame_rate = Fraction(video_stream.average_rate) if video_stream and video_stream.average_rate else Fraction(1)

        # Get audio if available
        audio = None
        try:
            container.seek(0)  # Reset the container to the beginning
            for stream in container.streams:
                if stream.type != 'audio':
                    continue
                assert isinstance(stream, av.AudioStream)
                audio_frames = []
                for packet in container.demux(stream):
                    for frame in packet.decode():
                        assert isinstance(frame, av.AudioFrame)
                        audio_frames.append(frame.to_ndarray())  # shape: (channels, samples)
                if len(audio_frames) > 0:
                    audio_data = np.concatenate(audio_frames, axis=1)  # shape: (channels, total_samples)
                    audio_tensor = torch.from_numpy(audio_data).unsqueeze(0)  # shape: (1, channels, total_samples)
                    audio = AudioInput({
                        "waveform": audio_tensor,
                        "sample_rate": int(stream.sample_rate) if stream.sample_rate else 1,
                    })
        except StopIteration:
            pass  # No audio stream

        metadata = container.metadata
        return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)

    def get_components(self) -> VideoComponents:
        if isinstance(self.__file, io.BytesIO):
            self.__file.seek(0)  # Reset the BytesIO object to the beginning
        with av.open(self.__file, mode='r') as container:
            return self.get_components_internal(container)
        raise ValueError(f"No video stream found in file '{self.__file}'")

    def save_to(
        self,
        path: str,
        format: VideoContainer = VideoContainer.AUTO,
        codec: VideoCodec = VideoCodec.AUTO,
        metadata: Optional[dict] = None
    ):
        if isinstance(self.__file, io.BytesIO):
            self.__file.seek(0)  # Reset the BytesIO object to the beginning
        with av.open(self.__file, mode='r') as container:
            container_format = container.format.name
            video_encoding = container.streams.video[0].codec.name if len(container.streams.video) > 0 else None
            reuse_streams = True
            if format != VideoContainer.AUTO and format not in container_format.split(","):
                reuse_streams = False
            if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
                reuse_streams = False

            if not reuse_streams:
                components = self.get_components_internal(container)
                video = VideoFromComponents(components)
                return video.save_to(
                    path,
                    format=format,
                    codec=codec,
                    metadata=metadata
                )

            streams = container.streams
            with av.open(path, mode='w', options={"movflags": "use_metadata_tags"}) as output_container:
                # Copy over the original metadata
                for key, value in container.metadata.items():
                    if metadata is None or key not in metadata:
                        output_container.metadata[key] = value

                # Add our new metadata
                if metadata is not None:
                    for key, value in metadata.items():
                        if isinstance(value, str):
                            output_container.metadata[key] = value
                        else:
                            output_container.metadata[key] = json.dumps(value)

                # Add streams to the new container
                stream_map = {}
                for stream in streams:
                    if isinstance(stream, (av.VideoStream, av.AudioStream, SubtitleStream)):
                        out_stream = output_container.add_stream_from_template(template=stream, opaque=True)
                        stream_map[stream] = out_stream

                # Write packets to the new container
                for packet in container.demux():
                    if packet.stream in stream_map and packet.dts is not None:
                        packet.stream = stream_map[packet.stream]
                        output_container.mux(packet)

class VideoFromComponents(VideoInput):
    """
    Class representing video input from tensors.
    """

    def __init__(self, components: VideoComponents):
        self.__components = components

    def get_components(self) -> VideoComponents:
        return VideoComponents(
            images=self.__components.images,
            audio=self.__components.audio,
            frame_rate=self.__components.frame_rate
        )

    def save_to(
        self,
        path: str,
        format: VideoContainer = VideoContainer.AUTO,
        codec: VideoCodec = VideoCodec.AUTO,
        metadata: Optional[dict] = None
    ):
        if format != VideoContainer.AUTO and format != VideoContainer.MP4:
            raise ValueError("Only MP4 format is supported for now")
        if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
            raise ValueError("Only H264 codec is supported for now")
        with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}) as output:
            # Add metadata before writing any streams
            if metadata is not None:
                for key, value in metadata.items():
                    output.metadata[key] = json.dumps(value)

            frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000)
            # Create a video stream
            video_stream = output.add_stream('h264', rate=frame_rate)
            video_stream.width = self.__components.images.shape[2]
            video_stream.height = self.__components.images.shape[1]
            video_stream.pix_fmt = 'yuv420p'

            # Create an audio stream
            audio_sample_rate = 1
            audio_stream: Optional[av.AudioStream] = None
            if self.__components.audio:
                audio_sample_rate = int(self.__components.audio['sample_rate'])
                audio_stream = output.add_stream('aac', rate=audio_sample_rate)
                audio_stream.sample_rate = audio_sample_rate
                audio_stream.format = 'fltp'

            # Encode video
            for i, frame in enumerate(self.__components.images):
                img = (frame * 255).clamp(0, 255).byte().cpu().numpy()  # shape: (H, W, 3)
                frame = av.VideoFrame.from_ndarray(img, format='rgb24')
                frame = frame.reformat(format='yuv420p')  # Convert to YUV420P as required by h264
                packet = video_stream.encode(frame)
                output.mux(packet)

            # Flush video
            packet = video_stream.encode(None)
            output.mux(packet)

            if audio_stream and self.__components.audio:
                # Encode audio
                samples_per_frame = int(audio_sample_rate / frame_rate)
                num_frames = self.__components.audio['waveform'].shape[2] // samples_per_frame
                for i in range(num_frames):
                    start = i * samples_per_frame
                    end = start + samples_per_frame
                    # TODO(Feature) - Add support for stereo audio
                    chunk = self.__components.audio['waveform'][0, 0, start:end].unsqueeze(0).numpy()
                    audio_frame = av.AudioFrame.from_ndarray(chunk, format='fltp', layout='mono')
                    audio_frame.sample_rate = audio_sample_rate
                    audio_frame.pts = i * samples_per_frame
                    for packet in audio_stream.encode(audio_frame):
                        output.mux(packet)

                # Flush audio
                for packet in audio_stream.encode(None):
                    output.mux(packet)
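A hypothetical round-trip through the two implementations (the file paths are placeholders):

    from comfy_api.input_impl import VideoFromFile, VideoFromComponents
    from comfy_api.util import VideoContainer, VideoCodec

    video = VideoFromFile("input/example.mp4")
    print(video.get_dimensions())          # (width, height) probed from the stream
    # matching container/codec -> packets are remuxed; otherwise re-encoded:
    video.save_to("output/copy.mp4", format=VideoContainer.MP4,
                  codec=VideoCodec.H264, metadata={"note": "copied by sketch"})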

View File

@@ -0,0 +1,8 @@
from .video_types import VideoContainer, VideoCodec, VideoComponents

__all__ = [
    # Utility Types
    "VideoContainer",
    "VideoCodec",
    "VideoComponents",
]

View File

@@ -0,0 +1,51 @@
from __future__ import annotations
from dataclasses import dataclass
from enum import Enum
from fractions import Fraction
from typing import Optional
from comfy_api.input import ImageInput, AudioInput

class VideoCodec(str, Enum):
    AUTO = "auto"
    H264 = "h264"

    @classmethod
    def as_input(cls) -> list[str]:
        """
        Returns a list of codec names that can be used as node input.
        """
        return [member.value for member in cls]

class VideoContainer(str, Enum):
    AUTO = "auto"
    MP4 = "mp4"

    @classmethod
    def as_input(cls) -> list[str]:
        """
        Returns a list of container names that can be used as node input.
        """
        return [member.value for member in cls]

    @classmethod
    def get_extension(cls, value) -> str:
        """
        Returns the file extension for the container.
        """
        if isinstance(value, str):
            value = cls(value)
        if value == VideoContainer.MP4 or value == VideoContainer.AUTO:
            return "mp4"
        return ""

@dataclass
class VideoComponents:
    """
    Dataclass representing the components of a video.
    """

    images: ImageInput
    frame_rate: Fraction
    audio: Optional[AudioInput] = None
    metadata: Optional[dict] = None
View File

@@ -297,6 +297,10 @@ class SynchronousOperation(Generic[T, R]):
        # Convert request model to dict, but use None for EmptyRequest
        request_dict = None if isinstance(self.request, EmptyRequest) else self.request.model_dump(exclude_none=True)
+       if request_dict:
+           for key, value in request_dict.items():
+               if isinstance(value, Enum):
+                   request_dict[key] = value.value

        # Debug log for request
        logging.debug(f"[DEBUG] API Request: {self.endpoint.method.value} {self.endpoint.path}")
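The conversion matters because plain Enum members are not JSON serializable; a small illustration with a made-up enum:

    import json
    from enum import Enum

    class Moderation(Enum):
        LOW = "low"

    payload = {"moderation": Moderation.LOW}
    # json.dumps(payload) -> TypeError: Object of type Moderation is not JSON serializable
    payload = {k: v.value if isinstance(v, Enum) else v for k, v in payload.items()}
    print(json.dumps(payload))  # {"moderation": "low"}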

View File

@@ -1,21 +1,22 @@
+import base64
import io
+import math
from inspect import cleandoc
-from comfy.utils import common_upscale
+import numpy as np
+import requests
+import torch
+from PIL import Image
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict
+from comfy.utils import common_upscale
from comfy_api_nodes.apis import (
-   OpenAIImageGenerationRequest,
    OpenAIImageEditRequest,
+   OpenAIImageGenerationRequest,
-   OpenAIImageGenerationResponse
+   OpenAIImageGenerationResponse,
)
from comfy_api_nodes.apis.client import ApiEndpoint, HttpMethod, SynchronousOperation
-import numpy as np
-from PIL import Image
-import requests
-import torch
-import math
-import base64
def downscale_input(image): def downscale_input(image):
samples = image.movedim(-1,1) samples = image.movedim(-1,1)
@@ -331,6 +332,11 @@ class OpenAIGPTImage1(ComfyNodeABC):
"default": None, "default": None,
"tooltip": "Optional mask for inpainting (white areas will be replaced)", "tooltip": "Optional mask for inpainting (white areas will be replaced)",
}), }),
"moderation": (IO.COMBO, {
"options": ["low","auto"],
"default": "low",
"tooltip": "Moderation level",
}),
}, },
"hidden": { "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG" "auth_token": "AUTH_TOKEN_COMFY_ORG"
@@ -343,7 +349,7 @@ class OpenAIGPTImage1(ComfyNodeABC):
    DESCRIPTION = cleandoc(__doc__ or "")
    API_NODE = True

-   def api_call(self, prompt, seed=0, quality="low", background="opaque", image=None, mask=None, n=1, size="1024x1024", auth_token=None):
+   def api_call(self, prompt, seed=0, quality="low", background="opaque", image=None, mask=None, n=1, size="1024x1024", auth_token=None, moderation="low"):
        model = "gpt-image-1"
        path = "/proxy/openai/images/generations"
        request_class = OpenAIImageGenerationRequest
@@ -415,6 +421,7 @@ class OpenAIGPTImage1(ComfyNodeABC):
                n=n,
                seed=seed,
                size=size,
+               moderation=moderation,
            ),
            files=files if files else None,
            auth_token=auth_token

View File

@@ -20,6 +20,29 @@ class CLIPTextEncodeControlnet:
        c.append(n)
        return (c, )

+class T5TokenizerOptions:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "clip": ("CLIP", ),
+                "min_padding": ("INT", {"default": 0, "min": 0, "max": 10000, "step": 1}),
+                "min_length": ("INT", {"default": 0, "min": 0, "max": 10000, "step": 1}),
+            }
+        }
+
+    RETURN_TYPES = ("CLIP",)
+    FUNCTION = "set_options"
+
+    def set_options(self, clip, min_padding, min_length):
+        clip = clip.clone()
+        for t5_type in ["t5xxl", "pile_t5xl", "t5base", "mt5xl", "umt5xxl"]:
+            clip.set_tokenizer_option("{}_min_padding".format(t5_type), min_padding)
+            clip.set_tokenizer_option("{}_min_length".format(t5_type), min_length)
+        return (clip, )

NODE_CLASS_MAPPINGS = {
-   "CLIPTextEncodeControlnet": CLIPTextEncodeControlnet
+   "CLIPTextEncodeControlnet": CLIPTextEncodeControlnet,
+   "T5TokenizerOptions": T5TokenizerOptions,
}
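The node only stores options; they take effect when a downstream encode tokenizes. A hypothetical call chain (node classes are normally invoked by the graph executor, not directly):

    clip, = T5TokenizerOptions().set_options(clip, min_padding=0, min_length=256)
    tokens = clip.tokenize("a photo of a cat")
    # each listed T5 flavour's SDTokenizer now reads "<key>_min_length" and
    # "<key>_min_padding" from tokenizer_options inside tokenize_with_weights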

View File

@@ -1,3 +1,4 @@
+import math
import comfy.samplers
import comfy.sample
from comfy.k_diffusion import sampling as k_diffusion_sampling
@@ -249,6 +250,55 @@ class SetFirstSigma:
        sigmas[0] = sigma
        return (sigmas, )

+class ExtendIntermediateSigmas:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required":
+                    {"sigmas": ("SIGMAS", ),
+                     "steps": ("INT", {"default": 2, "min": 1, "max": 100}),
+                     "start_at_sigma": ("FLOAT", {"default": -1.0, "min": -1.0, "max": 20000.0, "step": 0.01, "round": False}),
+                     "end_at_sigma": ("FLOAT", {"default": 12.0, "min": 0.0, "max": 20000.0, "step": 0.01, "round": False}),
+                     "spacing": (['linear', 'cosine', 'sine'],),
+                     }
+                }
+
+    RETURN_TYPES = ("SIGMAS",)
+    CATEGORY = "sampling/custom_sampling/sigmas"
+
+    FUNCTION = "extend"
+
+    def extend(self, sigmas: torch.Tensor, steps: int, start_at_sigma: float, end_at_sigma: float, spacing: str):
+        if start_at_sigma < 0:
+            start_at_sigma = float("inf")
+
+        interpolator = {
+            'linear': lambda x: x,
+            'cosine': lambda x: torch.sin(x*math.pi/2),
+            'sine': lambda x: 1 - torch.cos(x*math.pi/2)
+        }[spacing]
+
+        # linear space for our interpolation function
+        x = torch.linspace(0, 1, steps + 1, device=sigmas.device)[1:-1]
+        computed_spacing = interpolator(x)
+
+        extended_sigmas = []
+        for i in range(len(sigmas) - 1):
+            sigma_current = sigmas[i]
+            sigma_next = sigmas[i+1]
+
+            extended_sigmas.append(sigma_current)
+
+            if end_at_sigma <= sigma_current <= start_at_sigma:
+                interpolated_steps = computed_spacing * (sigma_next - sigma_current) + sigma_current
+                extended_sigmas.extend(interpolated_steps.tolist())
+
+        # Add the last sigma value
+        if len(sigmas) > 0:
+            extended_sigmas.append(sigmas[-1])
+
+        extended_sigmas = torch.FloatTensor(extended_sigmas)
+
+        return (extended_sigmas,)
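A worked example with the defaults (start_at_sigma=-1 becomes infinity, end_at_sigma=12.0), so only intervals whose current sigma is at least 12.0 get extra steps:

    import torch

    sigmas = torch.tensor([14.61, 8.0, 0.0])
    steps = 3
    x = torch.linspace(0, 1, steps + 1)[1:-1]    # tensor([0.3333, 0.6667])
    # 14.61 -> 8.0 qualifies: 14.61 + x * (8.0 - 14.61) ~= [12.41, 10.20]
    # 8.0 -> 0.0 does not (8.0 < end_at_sigma)
    # result: [14.61, 12.41, 10.20, 8.0, 0.0]

Note that `steps` bounds the segment count per interval, so steps=3 inserts two intermediate sigmas into each qualifying interval.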
class KSamplerSelect:
    @classmethod
    def INPUT_TYPES(s):
@@ -735,6 +785,7 @@ NODE_CLASS_MAPPINGS = {
    "SplitSigmasDenoise": SplitSigmasDenoise,
    "FlipSigmas": FlipSigmas,
    "SetFirstSigma": SetFirstSigma,
+   "ExtendIntermediateSigmas": ExtendIntermediateSigmas,
    "CFGGuider": CFGGuider,
    "DualCFGGuider": DualCFGGuider,

View File

@@ -38,6 +38,7 @@ class LTXVImgToVideo:
"height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}), "height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
"length": ("INT", {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8}), "length": ("INT", {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0}),
}} }}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
@@ -46,7 +47,7 @@ class LTXVImgToVideo:
CATEGORY = "conditioning/video_models" CATEGORY = "conditioning/video_models"
FUNCTION = "generate" FUNCTION = "generate"
def generate(self, positive, negative, image, vae, width, height, length, batch_size): def generate(self, positive, negative, image, vae, width, height, length, batch_size, strength):
pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
encode_pixels = pixels[:, :, :, :3] encode_pixels = pixels[:, :, :, :3]
t = vae.encode(encode_pixels) t = vae.encode(encode_pixels)
@@ -59,7 +60,7 @@ class LTXVImgToVideo:
dtype=torch.float32, dtype=torch.float32,
device=latent.device, device=latent.device,
) )
conditioning_latent_frames_mask[:, :, :t.shape[2]] = 0 conditioning_latent_frames_mask[:, :, :t.shape[2]] = 1.0 - strength
return (positive, negative, {"samples": latent, "noise_mask": conditioning_latent_frames_mask}, ) return (positive, negative, {"samples": latent, "noise_mask": conditioning_latent_frames_mask}, )
@@ -152,6 +153,15 @@ class LTXVAddGuide:
        return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs})

    def append_keyframe(self, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors):
+       _, latent_idx = self.get_latent_index(
+           cond=positive,
+           latent_length=latent_image.shape[2],
+           guide_length=guiding_latent.shape[2],
+           frame_idx=frame_idx,
+           scale_factors=scale_factors,
+       )
+       noise_mask[:, :, latent_idx:latent_idx + guiding_latent.shape[2]] = 1.0
+
        positive = self.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors)
        negative = self.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors)

View File

@@ -209,6 +209,9 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi
metadata["modelspec.predict_key"] = "epsilon" metadata["modelspec.predict_key"] = "epsilon"
elif model.model.model_type == comfy.model_base.ModelType.V_PREDICTION: elif model.model.model_type == comfy.model_base.ModelType.V_PREDICTION:
metadata["modelspec.predict_key"] = "v" metadata["modelspec.predict_key"] = "v"
extra_keys["v_pred"] = torch.tensor([])
if getattr(model_sampling, "zsnr", False):
extra_keys["ztsnr"] = torch.tensor([])
if not args.disable_metadata: if not args.disable_metadata:
metadata["prompt"] = prompt_info metadata["prompt"] = prompt_info
@@ -273,7 +276,7 @@ class CLIPSave:
        comfy.model_management.load_models_gpu([clip.load_model()], force_patch_weights=True)
        clip_sd = clip.get_sd()

-       for prefix in ["clip_l.", "clip_g.", ""]:
+       for prefix in ["clip_l.", "clip_g.", "clip_h.", "t5xxl.", "pile_t5xl.", "mt5xl.", "umt5xxl.", "t5base.", "gemma2_2b.", "llama.", "hydit_clip.", ""]:
            k = list(filter(lambda a: a.startswith(prefix), clip_sd.keys()))
            current_clip_sd = {}
            for x in k:

View File

@@ -20,13 +20,14 @@ def loglinear_interp(t_steps, num_steps):
NOISE_LEVELS = {"FLUX": [0.9968, 0.9886, 0.9819, 0.975, 0.966, 0.9471, 0.9158, 0.8287, 0.5512, 0.2808, 0.001],
                "Wan":[1.0, 0.997, 0.995, 0.993, 0.991, 0.989, 0.987, 0.985, 0.98, 0.975, 0.973, 0.968, 0.96, 0.946, 0.927, 0.902, 0.864, 0.776, 0.539, 0.208, 0.001],
+               "Chroma": [0.992, 0.99, 0.988, 0.985, 0.982, 0.978, 0.973, 0.968, 0.961, 0.953, 0.943, 0.931, 0.917, 0.9, 0.881, 0.858, 0.832, 0.802, 0.769, 0.731, 0.69, 0.646, 0.599, 0.55, 0.501, 0.451, 0.402, 0.355, 0.311, 0.27, 0.232, 0.199, 0.169, 0.143, 0.12, 0.101, 0.084, 0.07, 0.058, 0.048, 0.001],
                }

class OptimalStepsScheduler:
    @classmethod
    def INPUT_TYPES(s):
        return {"required":
-                   {"model_type": (["FLUX", "Wan"], ),
+                   {"model_type": (["FLUX", "Wan", "Chroma"], ),
                     "steps": ("INT", {"default": 20, "min": 3, "max": 1000}),
                     "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                     }
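The scheduler resamples these anchor levels down to the requested step count with loglinear_interp (defined earlier in this file). A sketch of what such a helper typically looks like, in the spirit of the Align Your Steps schedule, not a verbatim copy of the upstream function:

    import numpy as np

    def loglinear_interp(t_steps, num_steps):
        # interpolate a descending noise schedule on a log scale
        xs = np.linspace(0, 1, len(t_steps))
        ys = np.log(np.asarray(t_steps)[::-1])
        new_ys = np.interp(np.linspace(0, 1, num_steps), xs, ys)
        return np.exp(new_ys)[::-1].tolist()

    loglinear_interp(NOISE_LEVELS["Chroma"], 20)  # 41 anchors -> 20 levels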

View File

@@ -141,6 +141,7 @@ class Quantize:
CATEGORY = "image/postprocessing" CATEGORY = "image/postprocessing"
@staticmethod
def bayer(im, pal_im, order): def bayer(im, pal_im, order):
def normalized_bayer_matrix(n): def normalized_bayer_matrix(n):
if n == 0: if n == 0:
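This is part of the N805 lint fix: ruff flags instance methods whose first parameter is not named self (the rule is enabled in pyproject.toml further down), and bayer() takes im first. Marking it @staticmethod resolves the warning without renaming parameters:

    class Quantize:
        @staticmethod
        def bayer(im, pal_im, order):   # ok: staticmethods take no self
            ...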

View File

@@ -0,0 +1,43 @@
import json

from comfy.comfy_types.node_typing import IO

# Preview Any - original implement from
# https://github.com/rgthree/rgthree-comfy/blob/main/py/display_any.py
# upstream requested in https://github.com/Kosinkadink/rfcs/blob/main/rfcs/0000-corenodes.md#preview-nodes
class PreviewAny():
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {"source": (IO.ANY, {})},
        }

    RETURN_TYPES = ()
    FUNCTION = "main"
    OUTPUT_NODE = True

    CATEGORY = "utils"

    def main(self, source=None):
        value = 'None'
        if isinstance(source, str):
            value = source
        elif isinstance(source, (int, float, bool)):
            value = str(source)
        elif source is not None:
            try:
                value = json.dumps(source)
            except Exception:
                try:
                    value = str(source)
                except Exception:
                    value = 'source exists, but could not be serialized.'

        return {"ui": {"text": (value,)}}

NODE_CLASS_MAPPINGS = {
    "PreviewAny": PreviewAny,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "PreviewAny": "Preview Any",
}

View File

@@ -5,9 +5,13 @@ import av
import torch
import folder_paths
import json
+from typing import Optional, Literal

from fractions import Fraction
-from comfy.comfy_types import FileLocator
+from comfy.comfy_types import IO, FileLocator, ComfyNodeABC
+from comfy_api.input import ImageInput, AudioInput, VideoInput
+from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
+from comfy_api.input_impl import VideoFromFile, VideoFromComponents
+from comfy.cli_args import args

class SaveWEBM:
    def __init__(self):
@@ -75,7 +79,163 @@ class SaveWEBM:
return {"ui": {"images": results, "animated": (True,)}} # TODO: frontend side return {"ui": {"images": results, "animated": (True,)}} # TODO: frontend side
class SaveVideo(ComfyNodeABC):
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.type: Literal["output"] = "output"
self.prefix_append = ""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"video": (IO.VIDEO, {"tooltip": "The video to save."}),
"filename_prefix": ("STRING", {"default": "video/ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."}),
"format": (VideoContainer.as_input(), {"default": "auto", "tooltip": "The format to save the video as."}),
"codec": (VideoCodec.as_input(), {"default": "auto", "tooltip": "The codec to use for the video."}),
},
"hidden": {
"prompt": "PROMPT",
"extra_pnginfo": "EXTRA_PNGINFO"
},
}
RETURN_TYPES = ()
FUNCTION = "save_video"
OUTPUT_NODE = True
CATEGORY = "image/video"
DESCRIPTION = "Saves the input images to your ComfyUI output directory."
def save_video(self, video: VideoInput, filename_prefix, format, codec, prompt=None, extra_pnginfo=None):
filename_prefix += self.prefix_append
width, height = video.get_dimensions()
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
filename_prefix,
self.output_dir,
width,
height
)
results: list[FileLocator] = list()
saved_metadata = None
if not args.disable_metadata:
metadata = {}
if extra_pnginfo is not None:
metadata.update(extra_pnginfo)
if prompt is not None:
metadata["prompt"] = prompt
if len(metadata) > 0:
saved_metadata = metadata
file = f"{filename}_{counter:05}_.{VideoContainer.get_extension(format)}"
video.save_to(
os.path.join(full_output_folder, file),
format=format,
codec=codec,
metadata=saved_metadata
)
results.append({
"filename": file,
"subfolder": subfolder,
"type": self.type
})
counter += 1
return { "ui": { "images": results, "animated": (True,) } }
class CreateVideo(ComfyNodeABC):
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"images": (IO.IMAGE, {"tooltip": "The images to create a video from."}),
"fps": ("FLOAT", {"default": 30.0, "min": 1.0, "max": 120.0, "step": 1.0}),
},
"optional": {
"audio": (IO.AUDIO, {"tooltip": "The audio to add to the video."}),
}
}
RETURN_TYPES = (IO.VIDEO,)
FUNCTION = "create_video"
CATEGORY = "image/video"
DESCRIPTION = "Create a video from images."
def create_video(self, images: ImageInput, fps: float, audio: Optional[AudioInput] = None):
return (VideoFromComponents(
VideoComponents(
images=images,
audio=audio,
frame_rate=Fraction(fps),
)
),)
class GetVideoComponents(ComfyNodeABC):
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"video": (IO.VIDEO, {"tooltip": "The video to extract components from."}),
}
}
RETURN_TYPES = (IO.IMAGE, IO.AUDIO, IO.FLOAT)
RETURN_NAMES = ("images", "audio", "fps")
FUNCTION = "get_components"
CATEGORY = "image/video"
DESCRIPTION = "Extracts all components from a video: frames, audio, and framerate."
def get_components(self, video: VideoInput):
components = video.get_components()
return (components.images, components.audio, float(components.frame_rate))
class LoadVideo(ComfyNodeABC):
@classmethod
def INPUT_TYPES(cls):
input_dir = folder_paths.get_input_directory()
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
files = folder_paths.filter_files_content_types(files, ["video"])
return {"required":
{"file": (sorted(files), {"video_upload": True})},
}
CATEGORY = "image/video"
RETURN_TYPES = (IO.VIDEO,)
FUNCTION = "load_video"
def load_video(self, file):
video_path = folder_paths.get_annotated_filepath(file)
return (VideoFromFile(video_path),)
@classmethod
def IS_CHANGED(cls, file):
video_path = folder_paths.get_annotated_filepath(file)
mod_time = os.path.getmtime(video_path)
# Instead of hashing the file, we can just use the modification time to avoid
# rehashing large files.
return mod_time
@classmethod
def VALIDATE_INPUTS(cls, file):
if not folder_paths.exists_annotated_filepath(file):
return "Invalid video file: {}".format(file)
return True
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"SaveWEBM": SaveWEBM, "SaveWEBM": SaveWEBM,
"SaveVideo": SaveVideo,
"CreateVideo": CreateVideo,
"GetVideoComponents": GetVideoComponents,
"LoadVideo": LoadVideo,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"SaveVideo": "Save Video",
"CreateVideo": "Create Video",
"GetVideoComponents": "Get Video Components",
"LoadVideo": "Load Video",
} }

View File

@@ -20,7 +20,7 @@ class WebcamCapture(nodes.LoadImage):
CATEGORY = "image" CATEGORY = "image"
def load_capture(s, image, **kwargs): def load_capture(self, image, **kwargs):
return super().load_image(folder_paths.get_annotated_filepath(image)) return super().load_image(folder_paths.get_annotated_filepath(image))

View File

@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
-__version__ = "0.3.30"
+__version__ = "0.3.31"

View File

@@ -4,7 +4,7 @@ import os
import time
import mimetypes
import logging
-from typing import Literal
+from typing import Literal, List
from collections.abc import Collection

from comfy.cli_args import args
@@ -141,7 +141,7 @@ def get_directory_by_type(type_name: str) -> str | None:
        return get_input_directory()
    return None

-def filter_files_content_types(files: list[str], content_types: Literal["image", "video", "audio", "model"]) -> list[str]:
+def filter_files_content_types(files: list[str], content_types: List[Literal["image", "video", "audio", "model"]]) -> list[str]:
    """
    Example:
        files = os.listdir(folder_paths.get_input_directory())

hook_breaker_ac10a0.py (new file, 17 lines)
View File

@@ -0,0 +1,17 @@
# Prevent custom nodes from hooking anything important
import comfy.model_management

HOOK_BREAK = [(comfy.model_management, "cast_to")]

SAVED_FUNCTIONS = []

def save_functions():
    for f in HOOK_BREAK:
        SAVED_FUNCTIONS.append((f[0], f[1], getattr(f[0], f[1])))

def restore_functions():
    for f in SAVED_FUNCTIONS:
        setattr(f[0], f[1], f[2])
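Usage mirrors the main.py changes below; anything imported between save and restore can monkey-patch cast_to without the patch surviving:

    import hook_breaker_ac10a0
    hook_breaker_ac10a0.save_functions()      # snapshot comfy.model_management.cast_to
    import_custom_nodes()                     # hypothetical: code that may patch it
    hook_breaker_ac10a0.restore_functions()   # put the original back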

View File

@@ -13,7 +13,7 @@ import logging
import sys

if __name__ == "__main__":
-   #NOTE: These do not do anything on core ComfyUI which should already have no communication with the internet, they are for custom nodes.
+   #NOTE: These do not do anything on core ComfyUI, they are for custom nodes.
    os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
    os.environ['DO_NOT_TRACK'] = '1'
@@ -141,7 +141,7 @@ import nodes
import comfy.model_management
import comfyui_version
import app.logger
+import hook_breaker_ac10a0

def cuda_malloc_warning():
    device = comfy.model_management.get_torch_device()
@@ -215,6 +215,7 @@ def prompt_worker(q, server_instance):
                    comfy.model_management.soft_empty_cache()
                    last_gc_collect = current_time
                    need_gc = False
+               hook_breaker_ac10a0.restore_functions()

async def run(server_instance, address='', port=8188, verbose=True, call_on_start=None):
@@ -268,7 +269,9 @@ def start_comfyui(asyncio_loop=None):
    prompt_server = server.PromptServer(asyncio_loop)
    q = execution.PromptQueue(prompt_server)

+   hook_breaker_ac10a0.save_functions()
    nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes)
+   hook_breaker_ac10a0.restore_functions()

    cuda_malloc_warning()

View File

@@ -917,7 +917,7 @@ class CLIPLoader:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
-                             "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream"], ),
+                             "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma"], ),
                              },
                "optional": {
                              "device": (["default", "cpu"], {"advanced": True}),
@@ -2258,6 +2258,7 @@ def init_builtin_extra_nodes():
"nodes_optimalsteps.py", "nodes_optimalsteps.py",
"nodes_hidream.py", "nodes_hidream.py",
"nodes_fresca.py", "nodes_fresca.py",
"nodes_preview_any.py",
] ]
api_nodes_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_api_nodes") api_nodes_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_api_nodes")

View File

@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
-version = "0.3.30"
+version = "0.3.31"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"
@@ -12,6 +12,7 @@ documentation = "https://docs.comfy.org/"

[tool.ruff]
lint.select = [
+   "N805", # invalid-first-argument-name-for-method
    "S307", # suspicious-eval-usage
    "S102", # exec
    "T",    # print-usage

View File

@@ -1,4 +1,4 @@
-comfyui-frontend-package==1.17.11
+comfyui-frontend-package==1.18.6
comfyui-workflow-templates==0.1.3
torch
torchsde
@@ -22,5 +22,5 @@ psutil
kornia>=0.7.1
spandrel
soundfile
-av>=14.1.0
+av>=14.2.0
pydantic~=2.0

View File

@@ -229,3 +229,61 @@ async def test_move_userdata_full_info(aiohttp_client, app, tmp_path):
    assert not os.path.exists(tmp_path / "source.txt")
    with open(tmp_path / "dest.txt", "r") as f:
        assert f.read() == "test content"

+async def test_listuserdata_v2_empty_root(aiohttp_client, app):
+    client = await aiohttp_client(app)
+    resp = await client.get("/v2/userdata")
+    assert resp.status == 200
+    assert await resp.json() == []
+
+async def test_listuserdata_v2_nonexistent_subdirectory(aiohttp_client, app):
+    client = await aiohttp_client(app)
+    resp = await client.get("/v2/userdata?path=does_not_exist")
+    assert resp.status == 404
+
+async def test_listuserdata_v2_default(aiohttp_client, app, tmp_path):
+    os.makedirs(tmp_path / "test_dir" / "subdir")
+    (tmp_path / "test_dir" / "file1.txt").write_text("content")
+    (tmp_path / "test_dir" / "subdir" / "file2.txt").write_text("content")
+    client = await aiohttp_client(app)
+    resp = await client.get("/v2/userdata?path=test_dir")
+    assert resp.status == 200
+    data = await resp.json()
+    file_paths = {item["path"] for item in data if item["type"] == "file"}
+    assert file_paths == {"test_dir/file1.txt", "test_dir/subdir/file2.txt"}
+
+async def test_listuserdata_v2_normalized_separators(aiohttp_client, app, tmp_path, monkeypatch):
+    # Force backslash as os separator
+    monkeypatch.setattr(os, 'sep', '\\')
+    monkeypatch.setattr(os.path, 'sep', '\\')
+    os.makedirs(tmp_path / "test_dir" / "subdir")
+    (tmp_path / "test_dir" / "subdir" / "file1.txt").write_text("x")
+    client = await aiohttp_client(app)
+    resp = await client.get("/v2/userdata?path=test_dir")
+    assert resp.status == 200
+    data = await resp.json()
+    for item in data:
+        assert "/" in item["path"]
+        assert "\\" not in item["path"]
+
+async def test_listuserdata_v2_url_encoded_path(aiohttp_client, app, tmp_path):
+    # Create a directory with a space in its name and a file inside
+    os.makedirs(tmp_path / "my dir")
+    (tmp_path / "my dir" / "file.txt").write_text("content")
+    client = await aiohttp_client(app)
+    # Use URL-encoded space in path parameter
+    resp = await client.get("/v2/userdata?path=my%20dir&recurse=false")
+    assert resp.status == 200
+    data = await resp.json()
+    assert len(data) == 1
+    entry = data[0]
+    assert entry["name"] == "file.txt"
+    # Ensure the path is correctly decoded and uses forward slash
+    assert entry["path"] == "my dir/file.txt"