Compare commits

..

16 Commits

Author SHA1 Message Date
comfyanonymous
ef85058e97 Bump ComfyUI version to v0.3.13 2025-01-29 16:07:12 -05:00
comfyanonymous
f9230bd357 Update the python version in some workflows. 2025-01-29 15:54:13 -05:00
comfyanonymous
537c27cbf3 Bump default cuda version in standalone package to 126. 2025-01-29 08:13:33 -05:00
comfyanonymous
6ff2e4d550 Remove logging call added in last commit.
This is called before the logging is set up so it messes up some things.
2025-01-29 08:08:01 -05:00
filtered
222f48c0f2 Allow changing folder_paths.base_path via command line argument. (#6600)
* Reimpl. CLI arg directly inside folder_paths.

* Update tests to use CLI arg mocking.

* Revert last-minute refactor.

* Fix test state polution.
2025-01-29 08:06:28 -05:00
comfyanonymous
13fd4d6e45 More friendly error messages for corrupted safetensors files. 2025-01-28 09:41:09 -05:00
Bradley Reynolds
1210d094c7 Convert latents_ubyte to 8-bit unsigned int before converting to CPU (#6300)
* Convert latents_ubyte to 8-bit unsigned int before converting to CPU

* Only convert to unint8 if directml_enabled
2025-01-28 08:22:54 -05:00
comfyanonymous
255edf2246 Lower minimum ratio of loaded weights on Nvidia. 2025-01-27 05:26:51 -05:00
comfyanonymous
4f011b9a00 Better CLIPTextEncode error when clip input is None. 2025-01-26 06:04:57 -05:00
comfyanonymous
67feb05299 Remove redundant code. 2025-01-25 19:04:53 -05:00
comfyanonymous
6d21740346 Print ComfyUI version. 2025-01-25 15:03:57 -05:00
comfyanonymous
7fbf4b72fe Update nightly pytorch ROCm command in Readme. 2025-01-24 06:15:54 -05:00
comfyanonymous
14ca5f5a10 Remove useless code. 2025-01-24 06:15:54 -05:00
filtered
ce557cfb88 Remove redundant code (#6576) 2025-01-23 05:57:41 -05:00
comfyanonymous
96e2a45193 Remove useless code. 2025-01-23 05:56:23 -05:00
Chenlei Hu
dfa2b6d129 Remove unused function lcm in conds.py (#6572) 2025-01-23 05:54:09 -05:00
19 changed files with 117 additions and 137 deletions

View File

@@ -12,7 +12,7 @@ on:
description: 'CUDA version' description: 'CUDA version'
required: true required: true
type: string type: string
default: "124" default: "126"
python_minor: python_minor:
description: 'Python minor version' description: 'Python minor version'
required: true required: true

View File

@@ -18,7 +18,7 @@ jobs:
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
python-version: ["3.9", "3.10", "3.11", "3.12"] python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }} - name: Set up Python ${{ matrix.python-version }}

View File

@@ -18,7 +18,7 @@ jobs:
- name: Set up Python - name: Set up Python
uses: actions/setup-python@v4 uses: actions/setup-python@v4
with: with:
python-version: '3.10' python-version: '3.12'
- name: Install requirements - name: Install requirements
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip

View File

@@ -17,7 +17,7 @@ on:
description: 'cuda version' description: 'cuda version'
required: true required: true
type: string type: string
default: "124" default: "126"
python_minor: python_minor:
description: 'python minor version' description: 'python minor version'

View File

@@ -7,7 +7,7 @@ on:
description: 'cuda version' description: 'cuda version'
required: true required: true
type: string type: string
default: "124" default: "126"
python_minor: python_minor:
description: 'python minor version' description: 'python minor version'

View File

@@ -154,9 +154,9 @@ AMD users can install rocm and pytorch with pip if you don't have it already ins
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.2``` ```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.2```
This is the command to install the nightly with ROCm 6.2 which might have some performance improvements: This is the command to install the nightly with ROCm 6.3 which might have some performance improvements:
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.2.4``` ```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.3```
### Intel GPUs (Windows and Linux) ### Intel GPUs (Windows and Linux)

View File

@@ -43,10 +43,11 @@ parser.add_argument("--tls-certfile", type=str, help="Path to TLS (SSL) certific
parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.") parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.")
parser.add_argument("--max-upload-size", type=float, default=100, help="Set the maximum upload size in MB.") parser.add_argument("--max-upload-size", type=float, default=100, help="Set the maximum upload size in MB.")
parser.add_argument("--base-directory", type=str, default=None, help="Set the ComfyUI base directory for models, custom_nodes, input, output, temp, and user directories.")
parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.") parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.")
parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.") parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory. Overrides --base-directory.")
parser.add_argument("--temp-directory", type=str, default=None, help="Set the ComfyUI temp directory (default is in the ComfyUI directory).") parser.add_argument("--temp-directory", type=str, default=None, help="Set the ComfyUI temp directory (default is in the ComfyUI directory). Overrides --base-directory.")
parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory.") parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.")
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.") parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.") parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.") parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
@@ -176,7 +177,7 @@ parser.add_argument(
help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.", help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.",
) )
parser.add_argument("--user-directory", type=is_valid_directory, default=None, help="Set the ComfyUI user directory with an absolute path.") parser.add_argument("--user-directory", type=is_valid_directory, default=None, help="Set the ComfyUI user directory with an absolute path. Overrides --base-directory.")
if comfy.options.args_parsing: if comfy.options.args_parsing:
args = parser.parse_args() args = parser.parse_args()

View File

@@ -3,9 +3,6 @@ import math
import comfy.utils import comfy.utils
def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
return abs(a*b) // math.gcd(a, b)
class CONDRegular: class CONDRegular:
def __init__(self, cond): def __init__(self, cond):
self.cond = cond self.cond = cond

View File

@@ -4,105 +4,6 @@ import logging
# conversion code from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py # conversion code from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py
# =================#
# UNet Conversion #
# =================#
unet_conversion_map = [
# (stable-diffusion, HF Diffusers)
("time_embed.0.weight", "time_embedding.linear_1.weight"),
("time_embed.0.bias", "time_embedding.linear_1.bias"),
("time_embed.2.weight", "time_embedding.linear_2.weight"),
("time_embed.2.bias", "time_embedding.linear_2.bias"),
("input_blocks.0.0.weight", "conv_in.weight"),
("input_blocks.0.0.bias", "conv_in.bias"),
("out.0.weight", "conv_norm_out.weight"),
("out.0.bias", "conv_norm_out.bias"),
("out.2.weight", "conv_out.weight"),
("out.2.bias", "conv_out.bias"),
]
unet_conversion_map_resnet = [
# (stable-diffusion, HF Diffusers)
("in_layers.0", "norm1"),
("in_layers.2", "conv1"),
("out_layers.0", "norm2"),
("out_layers.3", "conv2"),
("emb_layers.1", "time_emb_proj"),
("skip_connection", "conv_shortcut"),
]
unet_conversion_map_layer = []
# hardcoded number of downblocks and resnets/attentions...
# would need smarter logic for other networks.
for i in range(4):
# loop over downblocks/upblocks
for j in range(2):
# loop over resnets/attentions for downblocks
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
if i < 3:
# no attention layers in down_blocks.3
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
for j in range(3):
# loop over resnets/attentions for upblocks
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
if i > 0:
# no attention layers in up_blocks.0
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
if i < 3:
# no downsample in down_blocks.3
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
# no upsample in up_blocks.3
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}."
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
hf_mid_atn_prefix = "mid_block.attentions.0."
sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
for j in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{j}."
sd_mid_res_prefix = f"middle_block.{2 * j}."
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
def convert_unet_state_dict(unet_state_dict):
# buyer beware: this is a *brittle* function,
# and correct output requires that all of these pieces interact in
# the exact order in which I have arranged them.
mapping = {k: k for k in unet_state_dict.keys()}
for sd_name, hf_name in unet_conversion_map:
mapping[hf_name] = sd_name
for k, v in mapping.items():
if "resnets" in k:
for sd_part, hf_part in unet_conversion_map_resnet:
v = v.replace(hf_part, sd_part)
mapping[k] = v
for k, v in mapping.items():
for sd_part, hf_part in unet_conversion_map_layer:
v = v.replace(hf_part, sd_part)
mapping[k] = v
new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
return new_state_dict
# ================# # ================#
# VAE Conversion # # VAE Conversion #
# ================# # ================#
@@ -213,6 +114,7 @@ textenc_pattern = re.compile("|".join(protected.keys()))
# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp # Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
code2idx = {"q": 0, "k": 1, "v": 2} code2idx = {"q": 0, "k": 1, "v": 2}
# This function exists because at the time of writing torch.cat can't do fp8 with cuda # This function exists because at the time of writing torch.cat can't do fp8 with cuda
def cat_tensors(tensors): def cat_tensors(tensors):
x = 0 x = 0
@@ -229,6 +131,7 @@ def cat_tensors(tensors):
return out return out
def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""): def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
new_state_dict = {} new_state_dict = {}
capture_qkv_weight = {} capture_qkv_weight = {}
@@ -284,5 +187,3 @@ def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
def convert_text_enc_state_dict(text_enc_dict): def convert_text_enc_state_dict(text_enc_dict):
return text_enc_dict return text_enc_dict

View File

@@ -702,9 +702,6 @@ class Decoder(nn.Module):
padding=1) padding=1)
def forward(self, z, **kwargs): def forward(self, z, **kwargs):
#assert z.shape[1:] == self.z_shape[1:]
self.last_z_shape = z.shape
# timestep embedding # timestep embedding
temb = None temb = None

View File

@@ -218,7 +218,7 @@ def is_amd():
MIN_WEIGHT_MEMORY_RATIO = 0.4 MIN_WEIGHT_MEMORY_RATIO = 0.4
if is_nvidia(): if is_nvidia():
MIN_WEIGHT_MEMORY_RATIO = 0.2 MIN_WEIGHT_MEMORY_RATIO = 0.1
ENABLE_PYTORCH_ATTENTION = False ENABLE_PYTORCH_ATTENTION = False
if args.use_pytorch_cross_attention: if args.use_pytorch_cross_attention:
@@ -535,14 +535,11 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
vram_set_state = vram_state vram_set_state = vram_state
lowvram_model_memory = 0 lowvram_model_memory = 0
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM) and not force_full_load: if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM) and not force_full_load:
model_size = loaded_model.model_memory_required(torch_dev)
loaded_memory = loaded_model.model_loaded_memory() loaded_memory = loaded_model.model_loaded_memory()
current_free_mem = get_free_memory(torch_dev) + loaded_memory current_free_mem = get_free_memory(torch_dev) + loaded_memory
lowvram_model_memory = max(64 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory())) lowvram_model_memory = max(64 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
lowvram_model_memory = max(0.1, lowvram_model_memory - loaded_memory) lowvram_model_memory = max(0.1, lowvram_model_memory - loaded_memory)
if model_size <= lowvram_model_memory: #only switch to lowvram if really necessary
lowvram_model_memory = 0
if vram_set_state == VRAMState.NO_VRAM: if vram_set_state == VRAMState.NO_VRAM:
lowvram_model_memory = 0.1 lowvram_model_memory = 0.1

View File

@@ -50,7 +50,16 @@ def load_torch_file(ckpt, safe_load=False, device=None):
if device is None: if device is None:
device = torch.device("cpu") device = torch.device("cpu")
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"): if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
sd = safetensors.torch.load_file(ckpt, device=device.type) try:
sd = safetensors.torch.load_file(ckpt, device=device.type)
except Exception as e:
if len(e.args) > 0:
message = e.args[0]
if "HeaderTooLarge" in message:
raise ValueError("{}\n\nFile path: {}\n\nThe safetensors file is corrupt or invalid. Make sure this is actually a safetensors file and not a ckpt or pt or other filetype.".format(message, ckpt))
if "MetadataIncompleteBuffer" in message:
raise ValueError("{}\n\nFile path: {}\n\nThe safetensors file is incomplete. Check the file size and make sure you have copied/downloaded it correctly.".format(message, ckpt))
raise e
else: else:
if safe_load or ALWAYS_SAFE_LOAD: if safe_load or ALWAYS_SAFE_LOAD:
pl_sd = torch.load(ckpt, map_location=device, weights_only=True) pl_sd = torch.load(ckpt, map_location=device, weights_only=True)

View File

@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is # This file is automatically generated by the build process when version is
# updated in pyproject.toml. # updated in pyproject.toml.
__version__ = "0.3.12" __version__ = "0.3.13"

View File

@@ -7,12 +7,18 @@ import logging
from typing import Literal from typing import Literal
from collections.abc import Collection from collections.abc import Collection
from comfy.cli_args import args
supported_pt_extensions: set[str] = {'.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'} supported_pt_extensions: set[str] = {'.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'}
folder_names_and_paths: dict[str, tuple[list[str], set[str]]] = {} folder_names_and_paths: dict[str, tuple[list[str], set[str]]] = {}
env_base_path = os.environ.get("COMFYUI_FOLDERS_BASE_PATH") # --base-directory - Resets all default paths configured in folder_paths with a new base path
base_path = os.path.dirname(os.path.realpath(__file__)) if env_base_path is None else env_base_path if args.base_directory:
base_path = os.path.abspath(args.base_directory)
else:
base_path = os.path.dirname(os.path.realpath(__file__))
models_dir = os.path.join(base_path, "models") models_dir = os.path.join(base_path, "models")
folder_names_and_paths["checkpoints"] = ([os.path.join(models_dir, "checkpoints")], supported_pt_extensions) folder_names_and_paths["checkpoints"] = ([os.path.join(models_dir, "checkpoints")], supported_pt_extensions)
folder_names_and_paths["configs"] = ([os.path.join(models_dir, "configs")], [".yaml"]) folder_names_and_paths["configs"] = ([os.path.join(models_dir, "configs")], [".yaml"])
@@ -40,10 +46,10 @@ folder_names_and_paths["photomaker"] = ([os.path.join(models_dir, "photomaker")]
folder_names_and_paths["classifiers"] = ([os.path.join(models_dir, "classifiers")], {""}) folder_names_and_paths["classifiers"] = ([os.path.join(models_dir, "classifiers")], {""})
output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output") output_directory = os.path.join(base_path, "output")
temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp") temp_directory = os.path.join(base_path, "temp")
input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input") input_directory = os.path.join(base_path, "input")
user_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "user") user_directory = os.path.join(base_path, "user")
filename_list_cache: dict[str, tuple[list[str], dict[str, float], float]] = {} filename_list_cache: dict[str, tuple[list[str], dict[str, float], float]] = {}

View File

@@ -12,7 +12,10 @@ MAX_PREVIEW_RESOLUTION = args.preview_size
def preview_to_image(latent_image): def preview_to_image(latent_image):
latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1 latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1
.mul(0xFF) # to 0..255 .mul(0xFF) # to 0..255
).to(device="cpu", dtype=torch.uint8, non_blocking=comfy.model_management.device_supports_non_blocking(latent_image.device)) )
if comfy.model_management.directml_enabled:
latents_ubyte = latents_ubyte.to(dtype=torch.uint8)
latents_ubyte = latents_ubyte.to(device="cpu", dtype=torch.uint8, non_blocking=comfy.model_management.device_supports_non_blocking(latent_image.device))
return Image.fromarray(latents_ubyte.numpy()) return Image.fromarray(latents_ubyte.numpy())

View File

@@ -138,6 +138,8 @@ import server
from server import BinaryEventTypes from server import BinaryEventTypes
import nodes import nodes
import comfy.model_management import comfy.model_management
import comfyui_version
def cuda_malloc_warning(): def cuda_malloc_warning():
device = comfy.model_management.get_torch_device() device = comfy.model_management.get_torch_device()
@@ -292,6 +294,7 @@ def start_comfyui(asyncio_loop=None):
if __name__ == "__main__": if __name__ == "__main__":
# Running directly, just start ComfyUI. # Running directly, just start ComfyUI.
logging.info("ComfyUI version: {}".format(comfyui_version.__version__))
event_loop, _, start_all_func = start_comfyui() event_loop, _, start_all_func = start_comfyui()
try: try:
event_loop.run_until_complete(start_all_func()) event_loop.run_until_complete(start_all_func())

View File

@@ -63,6 +63,8 @@ class CLIPTextEncode(ComfyNodeABC):
DESCRIPTION = "Encodes a text prompt using a CLIP model into an embedding that can be used to guide the diffusion model towards generating specific images." DESCRIPTION = "Encodes a text prompt using a CLIP model into an embedding that can be used to guide the diffusion model towards generating specific images."
def encode(self, clip, text): def encode(self, clip, text):
if clip is None:
raise RuntimeError("ERROR: clip input is invalid: None\n\nIf the clip is from a checkpoint loader node your checkpoint does not contain a valid clip or text encoder model.")
tokens = clip.tokenize(text) tokens = clip.tokenize(text)
return (clip.encode_from_tokens_scheduled(tokens), ) return (clip.encode_from_tokens_scheduled(tokens), )

View File

@@ -1,6 +1,6 @@
[project] [project]
name = "ComfyUI" name = "ComfyUI"
version = "0.3.12" version = "0.3.13"
readme = "README.md" readme = "README.md"
license = { file = "LICENSE" } license = { file = "LICENSE" }
requires-python = ">=3.9" requires-python = ">=3.9"

View File

@@ -1,19 +1,23 @@
### 🗻 This file is created through the spirit of Mount Fuji at its peak ### 🗻 This file is created through the spirit of Mount Fuji at its peak
# TODO(yoland): clean up this after I get back down # TODO(yoland): clean up this after I get back down
import sys
import pytest import pytest
import os import os
import tempfile import tempfile
from unittest.mock import patch from unittest.mock import patch
from importlib import reload
import folder_paths import folder_paths
import comfy.cli_args
from comfy.options import enable_args_parsing
enable_args_parsing()
@pytest.fixture() @pytest.fixture()
def clear_folder_paths(): def clear_folder_paths():
# Clear the global dictionary before each test to ensure isolation # Reload the module after each test to ensure isolation
original = folder_paths.folder_names_and_paths.copy()
folder_paths.folder_names_and_paths.clear()
yield yield
folder_paths.folder_names_and_paths = original reload(folder_paths)
@pytest.fixture @pytest.fixture
def temp_dir(): def temp_dir():
@@ -21,7 +25,21 @@ def temp_dir():
yield tmpdirname yield tmpdirname
def test_get_directory_by_type(): @pytest.fixture
def set_base_dir():
def _set_base_dir(base_dir):
# Mock CLI args
with patch.object(sys, 'argv', ["main.py", "--base-directory", base_dir]):
reload(comfy.cli_args)
reload(folder_paths)
yield _set_base_dir
# Reload the modules after each test to ensure isolation
with patch.object(sys, 'argv', ["main.py"]):
reload(comfy.cli_args)
reload(folder_paths)
def test_get_directory_by_type(clear_folder_paths):
test_dir = "/test/dir" test_dir = "/test/dir"
folder_paths.set_output_directory(test_dir) folder_paths.set_output_directory(test_dir)
assert folder_paths.get_directory_by_type("output") == test_dir assert folder_paths.get_directory_by_type("output") == test_dir
@@ -96,3 +114,49 @@ def test_get_save_image_path(temp_dir):
assert counter == 1 assert counter == 1
assert subfolder == "" assert subfolder == ""
assert filename_prefix == "test" assert filename_prefix == "test"
def test_base_path_changes(set_base_dir):
test_dir = os.path.abspath("/test/dir")
set_base_dir(test_dir)
assert folder_paths.base_path == test_dir
assert folder_paths.models_dir == os.path.join(test_dir, "models")
assert folder_paths.input_directory == os.path.join(test_dir, "input")
assert folder_paths.output_directory == os.path.join(test_dir, "output")
assert folder_paths.temp_directory == os.path.join(test_dir, "temp")
assert folder_paths.user_directory == os.path.join(test_dir, "user")
assert os.path.join(test_dir, "custom_nodes") in folder_paths.get_folder_paths("custom_nodes")
for name in ["checkpoints", "loras", "vae", "configs", "embeddings", "controlnet", "classifiers"]:
assert folder_paths.get_folder_paths(name)[0] == os.path.join(test_dir, "models", name)
def test_base_path_change_clears_old(set_base_dir):
test_dir = os.path.abspath("/test/dir")
set_base_dir(test_dir)
assert len(folder_paths.get_folder_paths("custom_nodes")) == 1
single_model_paths = [
"checkpoints",
"loras",
"vae",
"configs",
"clip_vision",
"style_models",
"diffusers",
"vae_approx",
"gligen",
"upscale_models",
"embeddings",
"hypernetworks",
"photomaker",
"classifiers",
]
for name in single_model_paths:
assert len(folder_paths.get_folder_paths(name)) == 1
for name in ["controlnet", "diffusion_models", "text_encoders"]:
assert len(folder_paths.get_folder_paths(name)) == 2