Compare commits
3 Commits: yo-lora-tr ... annoate_ge
| Author | SHA1 | Date |
|---|---|---|
|  | 522d923948 |  |
|  | c05c9b552b |  |
|  | 27598702e9 |  |
@@ -37,8 +37,6 @@ class IO(StrEnum):
     CONTROL_NET = "CONTROL_NET"
     VAE = "VAE"
     MODEL = "MODEL"
-    LORA_MODEL = "LORA_MODEL"
-    LOSS_MAP = "LOSS_MAP"
     CLIP_VISION = "CLIP_VISION"
     CLIP_VISION_OUTPUT = "CLIP_VISION_OUTPUT"
     STYLE_MODEL = "STYLE_MODEL"
@@ -797,15 +797,12 @@ class GeneralDITTransformerBlock(nn.Module):
         adaln_lora_B_3D: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         for block in self.blocks:
-            if self.training:
-                x = torch.utils.checkpoint.checkpoint(block, x, emb_B_D, crossattn_emb, crossattn_mask, rope_emb_L_1_1_D, adaln_lora_B_3D, use_reentrant=False)
-            else:
-                x = block(
-                    x,
-                    emb_B_D,
-                    crossattn_emb,
-                    crossattn_mask,
-                    rope_emb_L_1_1_D=rope_emb_L_1_1_D,
-                    adaln_lora_B_3D=adaln_lora_B_3D,
-                )
+            x = block(
+                x,
+                emb_B_D,
+                crossattn_emb,
+                crossattn_mask,
+                rope_emb_L_1_1_D=rope_emb_L_1_1_D,
+                adaln_lora_B_3D=adaln_lora_B_3D,
+            )
         return x
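
Note: the branch removed above wrapped each transformer block in activation (gradient) checkpointing during training, recomputing intermediate activations in the backward pass instead of storing them, which lowers peak memory at the cost of an extra forward pass per block. A minimal, repository-independent sketch of the pattern (the block stack here is hypothetical, not this codebase's class):

import torch
import torch.nn as nn

class CheckpointedStack(nn.Module):
    def __init__(self, blocks: nn.ModuleList):
        super().__init__()
        self.blocks = blocks

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            if self.training:
                # Activations inside `block` are recomputed during backward,
                # trading extra compute for lower peak memory.
                x = torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x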

@@ -750,7 +750,7 @@ class BasicTransformerBlock(nn.Module):
             for p in patch:
                 n = p(n, extra_options)
 
-        x = n + x
+        x += n
         if "middle_patch" in transformer_patches:
             patch = transformer_patches["middle_patch"]
             for p in patch:

@@ -790,12 +790,12 @@ class BasicTransformerBlock(nn.Module):
             for p in patch:
                 n = p(n, extra_options)
 
-        x = n + x
+        x += n
         if self.is_res:
             x_skip = x
         x = self.ff(self.norm3(x))
         if self.is_res:
-            x = x_skip + x
+            x += x_skip
 
         return x
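
Note: `x = n + x` allocates a new tensor, while `x += n` adds in place. The in-place form avoids an allocation, but autograd rejects in-place updates on leaf tensors that require grad, and on any tensor whose saved values are needed for backward — presumably why the training-oriented branch used the out-of-place form. A small stock-PyTorch illustration:

import torch

x = torch.ones(3, requires_grad=True)
n = torch.ones(3)

y = n + x            # out-of-place: builds a new node in the autograd graph
y.sum().backward()   # works

# x += n             # would raise RuntimeError: a leaf Variable that
#                    # requires grad is being used in an in-place operation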

@@ -17,26 +17,23 @@
 """
 
 from __future__ import annotations
-import collections
+from typing import Optional, Callable
+import torch
 import copy
 import inspect
 import logging
-import math
 import uuid
-from typing import Callable, Optional
+import collections
+import math
 
-import torch
 
-import comfy.float
-import comfy.hooks
-import comfy.lora
-import comfy.model_management
-import comfy.patcher_extension
 import comfy.utils
+import comfy.float
+import comfy.model_management
+import comfy.lora
+import comfy.hooks
+import comfy.patcher_extension
+from comfy.patcher_extension import CallbacksMP, WrappersMP, PatcherInjection
 from comfy.comfy_types import UnetWrapperFunction
-from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
 
 
 def string_to_seed(data):
     crc = 0xFFFFFFFF

comfy/sd.py (23 changed lines)
@@ -986,28 +986,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
     return (model_patcher, clip, vae, clipvision)
 
 
-def load_diffusion_model_state_dict(sd, model_options={}):
-    """
-    Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats.
-
-    Args:
-        sd (dict): State dictionary containing model weights and configuration
-        model_options (dict, optional): Additional options for model loading. Supports:
-            - dtype: Override model data type
-            - custom_operations: Custom model operations
-            - fp8_optimizations: Enable FP8 optimizations
-
-    Returns:
-        ModelPatcher: A wrapped model instance that handles device management and weight loading.
-        Returns None if the model configuration cannot be detected.
-
-    The function:
-        1. Detects and handles different model formats (regular, diffusers, mmdit)
-        2. Configures model dtype based on parameters and device capabilities
-        3. Handles weight conversion and device placement
-        4. Manages model optimization settings
-        5. Loads weights and returns a device-managed model instance
-    """
+def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffusers or regular format
     dtype = model_options.get("dtype", None)
 
     #Allow loading unets from checkpoint files

@@ -1,6 +1,9 @@
-import nodes
+from __future__ import annotations
+from typing import Type, Literal
 
+import nodes
 from comfy_execution.graph_utils import is_link
+from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions
 
 class DependencyCycleError(Exception):
     pass

@@ -54,7 +57,22 @@ class DynamicPrompt:
     def get_original_prompt(self):
         return self.original_prompt
 
-def get_input_info(class_def, input_name, valid_inputs=None):
+def get_input_info(
+    class_def: Type[ComfyNodeABC],
+    input_name: str,
+    valid_inputs: InputTypeDict | None = None
+) -> tuple[str, Literal["required", "optional", "hidden"], InputTypeOptions] | tuple[None, None, None]:
+    """Get the input type, category, and extra info for a given input name.
+
+    Arguments:
+        class_def: The class definition of the node.
+        input_name: The name of the input to get info for.
+        valid_inputs: The valid inputs for the node, or None to use the class_def.INPUT_TYPES().
+
+    Returns:
+        tuple[str, str, dict] | tuple[None, None, None]: The input type, category, and extra info for the input name.
+    """
+
     valid_inputs = valid_inputs or class_def.INPUT_TYPES()
     input_info = None
     input_category = None

@@ -126,7 +144,7 @@ class TopologicalSort:
                 from_node_id, from_socket = value
                 if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
                     continue
-                input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
+                _, _, input_info = self.get_input_info(unique_id, input_name)
                 is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
                 if (include_lazy or not is_lazy) and not self.is_cached(from_node_id):
                     node_ids.append(from_node_id)
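
Note: with the annotated signature above, `get_input_info` returns an `(input_type, input_category, input_info)` triple, and call sites needing only part of it discard the rest with underscores, as in the updated line. A hypothetical node class, for illustration only (the expected results follow from the docstring above, not from code shown in this diff):

from comfy_execution.graph import get_input_info

class ExampleNode:
    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"steps": ("INT", {"default": 50, "lazy": True})}}

input_type, input_category, input_info = get_input_info(ExampleNode, "steps")
# ("INT", "required", {"default": 50, "lazy": True})

_, _, info = get_input_info(ExampleNode, "steps")
is_lazy = info is not None and info.get("lazy", False)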

@@ -1,45 +0,0 @@
-import torch
-
-# https://github.com/WeichenFan/CFG-Zero-star
-def optimized_scale(positive, negative):
-    positive_flat = positive.reshape(positive.shape[0], -1)
-    negative_flat = negative.reshape(negative.shape[0], -1)
-
-    # Calculate dot production
-    dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
-
-    # Squared norm of uncondition
-    squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
-
-    # st_star = v_cond^T * v_uncond / ||v_uncond||^2
-    st_star = dot_product / squared_norm
-
-    return st_star.reshape([positive.shape[0]] + [1] * (positive.ndim - 1))
-
-class CFGZeroStar:
-    @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {"model": ("MODEL",),
-                            }}
-    RETURN_TYPES = ("MODEL",)
-    RETURN_NAMES = ("patched_model",)
-    FUNCTION = "patch"
-    CATEGORY = "advanced/guidance"
-
-    def patch(self, model):
-        m = model.clone()
-        def cfg_zero_star(args):
-            guidance_scale = args['cond_scale']
-            x = args['input']
-            cond_p = args['cond_denoised']
-            uncond_p = args['uncond_denoised']
-            out = args["denoised"]
-            alpha = optimized_scale(x - cond_p, x - uncond_p)
-
-            return out + uncond_p * (alpha - 1.0) + guidance_scale * uncond_p * (1.0 - alpha)
-        m.set_model_sampler_post_cfg_function(cfg_zero_star)
-        return (m, )
-
-NODE_CLASS_MAPPINGS = {
-    "CFGZeroStar": CFGZeroStar
-}
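
For reference, the deleted optimized_scale computes the per-sample projection coefficient from the CFG-Zero* method linked in the code; in the notation of its own comment, with the stabilizer from the code:

\alpha = s^{*} = \frac{v_{\mathrm{cond}}^{\top} v_{\mathrm{uncond}}}{\lVert v_{\mathrm{uncond}} \rVert^{2} + \varepsilon}, \qquad \varepsilon = 10^{-8}

and the post-CFG hook rewrites the denoised output as

\mathrm{out}' = \mathrm{out} + (\alpha - 1)\,u + g\,(1 - \alpha)\,u = \mathrm{out} + (1 - \alpha)(g - 1)\,u,

where u is the unconditional denoised prediction and g the guidance scale, so the correction vanishes when alpha = 1 or g = 1.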

@@ -1,646 +0,0 @@
-import datetime
-import json
-import logging
-import math
-import os
-
-import numpy as np
-import safetensors
-import torch
-from PIL import Image, ImageDraw, ImageFont
-from PIL.PngImagePlugin import PngInfo
-
-import comfy.samplers
-import comfy.utils
-import comfy_extras.nodes_custom_sampler
-import folder_paths
-import node_helpers
-from comfy.cli_args import args
-from comfy.comfy_types.node_typing import IO
-
-
-class TrainSampler(comfy.samplers.Sampler):
-
-    def __init__(self, loss_fn, optimizer, loss_callback=None):
-        self.loss_fn = loss_fn
-        self.optimizer = optimizer
-        self.loss_callback = loss_callback
-
-    def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
-        self.optimizer.zero_grad()
-        noise = model_wrap.inner_model.model_sampling.noise_scaling(sigmas, noise, latent_image, False)
-        latent = model_wrap.inner_model.model_sampling.noise_scaling(
-            torch.zeros_like(sigmas),
-            torch.zeros_like(noise, requires_grad=True),
-            latent_image,
-            False
-        )
-
-        # Ensure model is in training mode and computing gradients
-        denoised = model_wrap(noise, sigmas, **extra_args)
-        try:
-            loss = self.loss_fn(denoised, latent.clone())
-        except RuntimeError as e:
-            if "does not require grad and does not have a grad_fn" in str(e):
-                logging.info("WARNING: This is likely due to the model is loaded in inference mode.")
-        loss.backward()
-        logging.info(f"Current Training Loss: {loss.item():.6f}")
-        if self.loss_callback:
-            self.loss_callback(loss.item())
-
-        self.optimizer.step()
-        # torch.cuda.memory._dump_snapshot("trainn.pickle")
-        # torch.cuda.memory._record_memory_history(enabled=None)
-        return torch.zeros_like(latent_image)
-
-
-class BiasDiff(torch.nn.Module):
-    def __init__(self, bias):
-        super().__init__()
-        self.bias = bias
-
-    def __call__(self, b):
-        return b + self.bias
-
-    def passive_memory_usage(self):
-        return self.bias.nelement() * self.bias.element_size()
-
-    def move_to(self, device):
-        self.to(device=device)
-        return self.passive_memory_usage()
-
-
-class LoraDiff(torch.nn.Module):
-    def __init__(self, lora_down, lora_up):
-        super().__init__()
-        self.lora_down = lora_down
-        self.lora_up = lora_up
-
-    def __call__(self, w):
-        return w + (self.lora_up @ self.lora_down).reshape(w.shape)
-
-    def passive_memory_usage(self):
-        return self.lora_down.nelement() * self.lora_down.element_size() + self.lora_up.nelement() * self.lora_up.element_size()
-
-    def move_to(self, device):
-        self.to(device=device)
-        return self.passive_memory_usage()
-
-
-def load_and_process_images(image_files, input_dir, resize_method="None"):
-    """Utility function to load and process a list of images.
-
-    Args:
-        image_files: List of image filenames
-        input_dir: Base directory containing the images
-        resize_method: How to handle images of different sizes ("None", "Stretch", "Crop", "Pad")
-
-    Returns:
-        torch.Tensor: Batch of processed images
-    """
-    if not image_files:
-        raise ValueError("No valid images found in input")
-
-    output_images = []
-    w, h = None, None
-
-    for file in image_files:
-        image_path = os.path.join(input_dir, file)
-        img = node_helpers.pillow(Image.open, image_path)
-
-        if img.mode == "I":
-            img = img.point(lambda i: i * (1 / 255))
-        img = img.convert("RGB")
-
-        if w is None and h is None:
-            w, h = img.size[0], img.size[1]
-
-        # Resize image to first image
-        if img.size[0] != w or img.size[1] != h:
-            if resize_method == "Stretch":
-                img = img.resize((w, h), Image.Resampling.LANCZOS)
-            elif resize_method == "Crop":
-                img = img.crop((0, 0, w, h))
-            elif resize_method == "Pad":
-                img = img.resize((w, h), Image.Resampling.LANCZOS)
-            elif resize_method == "None":
-                raise ValueError(
-                    "Your input image size does not match the first image in the dataset. Either select a valid resize method or use the same size for all images."
-                )
-
-        img_array = np.array(img).astype(np.float32) / 255.0
-        img_tensor = torch.from_numpy(img_array)[None,]
-        output_images.append(img_tensor)
-
-    return torch.cat(output_images, dim=0)
-
-
-class LoadImageSetNode:
-    @classmethod
-    def INPUT_TYPES(s):
-        return {
-            "required": {
-                "images": (
-                    [
-                        f
-                        for f in os.listdir(folder_paths.get_input_directory())
-                        if f.endswith((".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif", ".jpe", ".apng", ".tif", ".tiff"))
-                    ],
-                    {"image_upload": True, "allow_batch": True},
-                )
-            },
-            "optional": {
-                "resize_method": (
-                    ["None", "Stretch", "Crop", "Pad"],
-                    {"default": "None"},
-                ),
-            },
-        }
-
-    INPUT_IS_LIST = True
-    RETURN_TYPES = ("IMAGE",)
-    FUNCTION = "load_images"
-    CATEGORY = "loaders"
-    EXPERIMENTAL = True
-    DESCRIPTION = "Loads a batch of images from a directory for training."
-
-    @classmethod
-    def VALIDATE_INPUTS(s, images, resize_method):
-        filenames = images[0] if isinstance(images[0], list) else images
-
-        for image in filenames:
-            if not folder_paths.exists_annotated_filepath(image):
-                return "Invalid image file: {}".format(image)
-        return True
-
-    def load_images(self, input_files, resize_method):
-        input_dir = folder_paths.get_input_directory()
-        valid_extensions = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif", ".jpe", ".apng", ".tif", ".tiff"]
-        image_files = [
-            f
-            for f in input_files
-            if any(f.lower().endswith(ext) for ext in valid_extensions)
-        ]
-        output_tensor = load_and_process_images(image_files, input_dir, resize_method)
-        return (output_tensor,)
-
-
-class LoadImageSetFromFolderNode:
-    @classmethod
-    def INPUT_TYPES(s):
-        return {
-            "required": {
-                "folder": (folder_paths.get_input_subfolders(), {"tooltip": "The folder to load images from."})
-            },
-            "optional": {
-                "resize_method": (
-                    ["None", "Stretch", "Crop", "Pad"],
-                    {"default": "None"},
-                ),
-            },
-        }
-
-    RETURN_TYPES = ("IMAGE",)
-    FUNCTION = "load_images"
-    CATEGORY = "loaders"
-    EXPERIMENTAL = True
-    DESCRIPTION = "Loads a batch of images from a directory for training."
-
-    def load_images(self, folder, resize_method):
-        sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
-        valid_extensions = [".png", ".jpg", ".jpeg", ".webp"]
-        image_files = [
-            f
-            for f in os.listdir(sub_input_dir)
-            if any(f.lower().endswith(ext) for ext in valid_extensions)
-        ]
-        output_tensor = load_and_process_images(image_files, sub_input_dir, resize_method)
-        return (output_tensor,)
-
-
-def draw_loss_graph(loss_map, steps):
-    width, height = 500, 300
-    img = Image.new("RGB", (width, height), "white")
-    draw = ImageDraw.Draw(img)
-
-    min_loss, max_loss = min(loss_map.values()), max(loss_map.values())
-    scaled_loss = [(l - min_loss) / (max_loss - min_loss) for l in loss_map.values()]
-
-    prev_point = (0, height - int(scaled_loss[0] * height))
-    for i, l in enumerate(scaled_loss[1:], start=1):
-        x = int(i / (steps - 1) * width)
-        y = height - int(l * height)
-        draw.line([prev_point, (x, y)], fill="blue", width=2)
-        prev_point = (x, y)
-
-    return img
-
-
-class TrainLoraNode:
-    @classmethod
-    def INPUT_TYPES(s):
-        return {
-            "required": {
-                "model": (IO.MODEL, {"tooltip": "The model to train the LoRA on."}),
-                "vae": (
-                    IO.VAE,
-                    {
-                        "tooltip": "The VAE model to use for encoding images for training."
-                    },
-                ),
-                "positive": (
-                    IO.CONDITIONING,
-                    {"tooltip": "The positive conditioning to use for training."},
-                ),
-                "image": (
-                    IO.IMAGE,
-                    {"tooltip": "The image or image batch to train the LoRA on."},
-                ),
-                "batch_size": (
-                    IO.INT,
-                    {
-                        "default": 1,
-                        "min": 1,
-                        "max": 10000,
-                        "step": 1,
-                        "tooltip": "The batch size to use for training.",
-                    },
-                ),
-                "steps": (
-                    IO.INT,
-                    {
-                        "default": 50,
-                        "min": 1,
-                        "max": 1000,
-                        "tooltip": "The number of steps to train the LoRA for.",
-                    },
-                ),
-                "learning_rate": (
-                    IO.FLOAT,
-                    {
-                        "default": 0.0003,
-                        "min": 0.0000001,
-                        "max": 1.0,
-                        "step": 0.00001,
-                        "tooltip": "The learning rate to use for training.",
-                    },
-                ),
-                "rank": (
-                    IO.INT,
-                    {
-                        "default": 8,
-                        "min": 1,
-                        "max": 128,
-                        "tooltip": "The rank of the LoRA layers.",
-                    },
-                ),
-                "optimizer": (
-                    ["Adam", "AdamW", "SGD", "RMSprop"],
-                    {
-                        "default": "Adam",
-                        "tooltip": "The optimizer to use for training.",
-                    },
-                ),
-                "loss_function": (
-                    ["MSE", "L1", "Huber", "SmoothL1"],
-                    {
-                        "default": "MSE",
-                        "tooltip": "The loss function to use for training.",
-                    },
-                ),
-                "seed": (
-                    IO.INT,
-                    {
-                        "default": 0,
-                        "min": 0,
-                        "max": 0xFFFFFFFFFFFFFFFF,
-                        "tooltip": "The seed to use for training (used in generator for LoRA weight initialization and noise sampling)",
-                    },
-                ),
-                "training_dtype": (
-                    ["bf16", "fp32"],
-                    {"default": "bf16", "tooltip": "The dtype to use for training."},
-                ),
-                "existing_lora": (
-                    folder_paths.get_filename_list("loras") + ["[None]"],
-                    {
-                        "default": "[None]",
-                        "tooltip": "The existing LoRA to append to. Set to None for new LoRA.",
-                    },
-                ),
-            },
-        }
-
-    RETURN_TYPES = (IO.MODEL, IO.LORA_MODEL, IO.LOSS_MAP, IO.INT)
-    RETURN_NAMES = ("model_with_lora", "lora", "loss", "steps")
-    FUNCTION = "train"
-    CATEGORY = "training"
-    EXPERIMENTAL = True
-
-    def train(
-        self,
-        model,
-        vae,
-        positive,
-        image,
-        batch_size,
-        steps,
-        learning_rate,
-        rank,
-        optimizer,
-        loss_function,
-        seed,
-        training_dtype,
-        existing_lora,
-    ):
-        num_images = image.shape[0]
-        indices = torch.randperm(num_images)[:batch_size]
-        batch_tensor = image[indices]
-
-        # Ensure we're not in inference mode when encoding
-        encoded = vae.encode(batch_tensor)
-        mp = model.clone()
-        dtype = node_helpers.string_to_torch_dtype(training_dtype)
-        mp.set_model_compute_dtype(dtype)
-
-        with torch.inference_mode(False):
-            lora_sd = {}
-            generator = torch.Generator()
-            generator.manual_seed(seed)
-
-            # Load existing LoRA weights if provided
-            existing_weights = {}
-            existing_steps = 0
-            if existing_lora != "[None]":
-                lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora)
-                # Extract steps from filename like "trained_lora_10_steps_20250225_203716"
-                existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1])
-                if lora_path:
-                    existing_weights = comfy.utils.load_torch_file(lora_path)
-
-            for n, m in mp.model.named_modules():
-                if hasattr(m, "weight_function"):
-                    if m.weight is not None:
-                        key = "{}.weight".format(n)
-                        shape = m.weight.shape
-                        if len(shape) >= 2:
-                            in_dim = math.prod(shape[1:])
-                            out_dim = shape[0]
-
-                            # Check if we have existing weights for this layer
-                            lora_up_key = "{}.lora_up.weight".format(n)
-                            lora_down_key = "{}.lora_down.weight".format(n)
-
-                            if existing_lora != "[None]" and (
-                                lora_up_key in existing_weights
-                                and lora_down_key in existing_weights
-                            ):
-                                # Initialize with existing weights
-                                lora_up = torch.nn.Parameter(
-                                    existing_weights[lora_up_key].to(dtype=dtype),
-                                    requires_grad=True,
-                                )
-                                lora_down = torch.nn.Parameter(
-                                    existing_weights[lora_down_key].to(dtype=dtype),
-                                    requires_grad=True,
-                                )
-                            else:
-                                if existing_lora != "[None]":
-                                    logging.info(f"Warning: No existing weights found for {lora_up_key} or {lora_down_key}")
-                                # Initialize new weights
-                                lora_down = torch.nn.Parameter(
-                                    torch.zeros(
-                                        (
-                                            rank,
-                                            in_dim,
-                                        ),
-                                        dtype=dtype,
-                                    ),
-                                    requires_grad=True,
-                                )
-                                lora_up = torch.nn.Parameter(
-                                    torch.zeros((out_dim, rank), dtype=dtype),
-                                    requires_grad=True,
-                                )
-                                torch.nn.init.zeros_(lora_up)
-                                torch.nn.init.kaiming_uniform_(
-                                    lora_down, a=math.sqrt(5), generator=generator
-                                )
-
-                            lora_sd[lora_up_key] = lora_up
-                            lora_sd[lora_down_key] = lora_down
-                            mp.add_weight_wrapper(key, LoraDiff(lora_down, lora_up))
-                        else:
-                            diff = torch.nn.Parameter(
-                                torch.zeros(
-                                    m.weight.shape, dtype=dtype, requires_grad=True
-                                )
-                            )
-                            mp.add_weight_wrapper(key, BiasDiff(diff))
-                            lora_sd["{}.diff".format(n)] = diff
-                    if hasattr(m, "bias") and m.bias is not None:
-                        key = "{}.bias".format(n)
-                        bias = torch.nn.Parameter(
-                            torch.zeros(m.bias.shape, dtype=dtype, requires_grad=True)
-                        )
-                        lora_sd["{}.diff_b".format(n)] = bias
-                        mp.add_weight_wrapper(key, BiasDiff(bias))
-
-            if optimizer == "Adam":
-                optimizer = torch.optim.Adam(lora_sd.values(), lr=learning_rate)
-            elif optimizer == "AdamW":
-                optimizer = torch.optim.AdamW(lora_sd.values(), lr=learning_rate)
-            elif optimizer == "SGD":
-                optimizer = torch.optim.SGD(lora_sd.values(), lr=learning_rate)
-            elif optimizer == "RMSprop":
-                optimizer = torch.optim.RMSprop(lora_sd.values(), lr=learning_rate)
-
-            # Setup loss function based on selection
-            if loss_function == "MSE":
-                criterion = torch.nn.MSELoss()
-            elif loss_function == "L1":
-                criterion = torch.nn.L1Loss()
-            elif loss_function == "Huber":
-                criterion = torch.nn.HuberLoss()
-            elif loss_function == "SmoothL1":
-                criterion = torch.nn.SmoothL1Loss()
-
-            # Setup sampler and guider like in test script
-            loss_map = {"loss": []}
-            loss_callback = lambda loss: loss_map["loss"].append(loss)
-            train_sampler = TrainSampler(
-                criterion, optimizer, loss_callback=loss_callback
-            )
-            guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp)
-            guider.set_conds(positive)  # Set conditioning from input
-            ss = comfy_extras.nodes_custom_sampler.SamplerCustomAdvanced()
-
-            # yoland: this currently resize to the first image in the dataset
-
-            # Training loop
-            for step in range(steps):
-                # Generate random sigma
-                sigma = mp.model.model_sampling.percent_to_sigma(
-                    torch.rand((1,)).item()
-                )
-                sigma = torch.tensor([sigma])
-
-                noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(step * 1000 + seed)
-
-                ss.sample(
-                    noise, guider, train_sampler, sigma, {"samples": encoded.clone()}
-                )
-
-        return (mp, lora_sd, loss_map, steps + existing_steps)
-
-
-class SaveLoRA:
-    @classmethod
-    def INPUT_TYPES(s):
-        return {
-            "required": {
-                "lora": (
-                    IO.LORA_MODEL,
-                    {
-                        "tooltip": "The LoRA model to save. Do not use the model with LoRA layers."
-                    },
-                ),
-                "prefix": (
-                    "STRING",
-                    {
-                        "default": "trained_lora",
-                        "tooltip": "The prefix to use for the saved LoRA file.",
-                    },
-                ),
-            },
-            "optional": {
-                "steps": (
-                    IO.INT,
-                    {
-                        "forceInput": True,
-                        "tooltip": "Optional: The number of steps to LoRA has been trained for, used to name the saved file.",
-                    },
-                ),
-            },
-        }
-
-    RETURN_TYPES = ()
-    FUNCTION = "save"
-    CATEGORY = "loaders"
-    EXPERIMENTAL = True
-    OUTPUT_NODE = True
-
-    def save(self, lora, prefix, steps=None):
-        date = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
-        if steps is None:
-            output_file = f"models/loras/{prefix}_{date}_lora.safetensors"
-        else:
-            output_file = f"models/loras/{prefix}_{steps}_steps_{date}_lora.safetensors"
-        safetensors.torch.save_file(lora, output_file)
-        return {}
-
-
-class LossGraphNode:
-    def __init__(self):
-        self.output_dir = folder_paths.get_temp_directory()
-
-    @classmethod
-    def INPUT_TYPES(s):
-        return {
-            "required": {
-                "loss": (IO.LOSS_MAP, {"default": {}}),
-                "filename_prefix": (IO.STRING, {"default": "loss_graph"}),
-            },
-            "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
-        }
-
-    RETURN_TYPES = ()
-    FUNCTION = "plot_loss"
-    OUTPUT_NODE = True
-    CATEGORY = "training"
-    EXPERIMENTAL = True
-    DESCRIPTION = "Plots the loss graph and saves it to the output directory."
-
-    def plot_loss(self, loss, filename_prefix, prompt=None, extra_pnginfo=None):
-        loss_values = loss["loss"]
-        width, height = 500, 300
-        margin = 40
-
-        img = Image.new(
-            "RGB", (width + margin, height + margin), "white"
-        )  # Extend canvas
-        draw = ImageDraw.Draw(img)
-
-        min_loss, max_loss = min(loss_values), max(loss_values)
-        scaled_loss = [(l - min_loss) / (max_loss - min_loss) for l in loss_values]
-
-        steps = len(loss_values)
-
-        prev_point = (margin, height - int(scaled_loss[0] * height))
-        for i, l in enumerate(scaled_loss[1:], start=1):
-            x = margin + int(i / steps * width)  # Scale X properly
-            y = height - int(l * height)
-            draw.line([prev_point, (x, y)], fill="blue", width=2)
-            prev_point = (x, y)
-
-        draw.line([(margin, 0), (margin, height)], fill="black", width=2)  # Y-axis
-        draw.line(
-            [(margin, height), (width + margin, height)], fill="black", width=2
-        )  # X-axis
-
-        font = None
-        try:
-            font = ImageFont.truetype("arial.ttf", 12)
-        except IOError:
-            font = ImageFont.load_default()
-
-        # Add axis labels
-        draw.text((5, height // 2), "Loss", font=font, fill="black")
-        draw.text((width // 2, height + 10), "Steps", font=font, fill="black")
-
-        # Add min/max loss values
-        draw.text((margin - 30, 0), f"{max_loss:.2f}", font=font, fill="black")
-        draw.text(
-            (margin - 30, height - 10), f"{min_loss:.2f}", font=font, fill="black"
-        )
-
-        metadata = None
-        if not args.disable_metadata:
-            metadata = PngInfo()
-            if prompt is not None:
-                metadata.add_text("prompt", json.dumps(prompt))
-            if extra_pnginfo is not None:
-                for x in extra_pnginfo:
-                    metadata.add_text(x, json.dumps(extra_pnginfo[x]))
-
-        date = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
-        img.save(
-            os.path.join(self.output_dir, f"{filename_prefix}_{date}.png"),
-            pnginfo=metadata,
-        )
-        return {
-            "ui": {
-                "images": [
-                    {
-                        "filename": f"{filename_prefix}_{date}.png",
-                        "subfolder": "",
-                        "type": "temp",
-                    }
-                ]
-            }
-        }
-
-
-NODE_CLASS_MAPPINGS = {
-    "TrainLoraNode": TrainLoraNode,
-    "SaveLoRANode": SaveLoRA,
-    "LoadImageSetFromFolderNode": LoadImageSetFromFolderNode,
-    "LossGraphNode": LossGraphNode,
-}
-
-NODE_DISPLAY_NAME_MAPPINGS = {
-    "TrainLoraNode": "Train LoRA",
-    "SaveLoRANode": "Save LoRA Weights",
-    "LoadImageSetFromFolderNode": "Load Image Dataset from Folder",
-    "LossGraphNode": "Plot Loss Graph",
-}
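
Note: the deleted trainer patches weights rather than swapping modules: LoraDiff wraps a frozen weight W as W + (up @ down).reshape(W.shape), so only the two low-rank factors receive gradients, and zero-initializing up (with Kaiming init on down) makes the initial delta exactly zero. A self-contained sketch of that low-rank update outside ComfyUI's ModelPatcher machinery (function names here are illustrative, not part of the codebase):

import math
import torch

def make_lora_factors(out_dim: int, in_dim: int, rank: int, dtype=torch.float32):
    # down: Kaiming-uniform, up: zeros, so up @ down starts as the zero matrix.
    down = torch.nn.Parameter(torch.empty(rank, in_dim, dtype=dtype))
    up = torch.nn.Parameter(torch.zeros(out_dim, rank, dtype=dtype))
    torch.nn.init.kaiming_uniform_(down, a=math.sqrt(5))
    return down, up

def apply_lora(weight: torch.Tensor, down: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
    # Same shape logic as the deleted LoraDiff.__call__.
    return weight + (up @ down).reshape(weight.shape)

# Only the factors are optimized; the base weight stays frozen.
weight = torch.randn(16, 32)
down, up = make_lora_factors(out_dim=16, in_dim=32, rank=4)
opt = torch.optim.Adam([down, up], lr=3e-4)
y = apply_lora(weight, down, up)  # equals `weight` at initialization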

execution.py (58 changed lines)
@@ -1,34 +1,23 @@
-import copy
-import heapq
-import inspect
-import logging
 import sys
+import copy
+import logging
 import threading
+import heapq
 import time
 import traceback
 from enum import Enum
+import inspect
 from typing import List, Literal, NamedTuple, Optional
 
 import torch
+import nodes
 
 import comfy.model_management
-import nodes
-from comfy_execution.caching import (
-    CacheKeySetID,
-    CacheKeySetInputSignature,
-    HierarchicalCache,
-    LRUCache,
-)
-from comfy_execution.graph import (
-    DynamicPrompt,
-    ExecutionBlocker,
-    ExecutionList,
-    get_input_info,
-)
-from comfy_execution.graph_utils import GraphBuilder, is_link
+from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
+from comfy_execution.graph_utils import is_link, GraphBuilder
+from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
 from comfy_execution.validation import validate_node_input
 
 
 class ExecutionResult(Enum):
     SUCCESS = 0
     FAILURE = 1

@@ -104,7 +93,7 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
     missing_keys = {}
     for x in inputs:
         input_data = inputs[x]
-        input_type, input_category, input_info = get_input_info(class_def, x, valid_inputs)
+        _, input_category, input_info = get_input_info(class_def, x, valid_inputs)
         def mark_missing():
             missing_keys[x] = True
             input_data_all[x] = (None,)

@@ -566,7 +555,7 @@ def validate_inputs(prompt, item, validated):
     received_types = {}
 
     for x in valid_inputs:
-        type_input, input_category, extra_info = get_input_info(obj_class, x, class_inputs)
+        input_type, input_category, extra_info = get_input_info(obj_class, x, class_inputs)
         assert extra_info is not None
         if x not in inputs:
             if input_category == "required":

@@ -582,7 +571,7 @@ def validate_inputs(prompt, item, validated):
             continue
 
         val = inputs[x]
-        info = (type_input, extra_info)
+        info = (input_type, extra_info)
         if isinstance(val, list):
             if len(val) != 2:
                 error = {

@@ -603,8 +592,8 @@ def validate_inputs(prompt, item, validated):
             r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
             received_type = r[val[1]]
             received_types[x] = received_type
-            if 'input_types' not in validate_function_inputs and not validate_node_input(received_type, type_input):
-                details = f"{x}, received_type({received_type}) mismatch input_type({type_input})"
+            if 'input_types' not in validate_function_inputs and not validate_node_input(received_type, input_type):
+                details = f"{x}, received_type({received_type}) mismatch input_type({input_type})"
                 error = {
                     "type": "return_type_mismatch",
                     "message": "Return type mismatch between linked nodes",

@@ -652,22 +641,22 @@ def validate_inputs(prompt, item, validated):
                     val = val["__value__"]
                     inputs[x] = val
 
-                if type_input == "INT":
+                if input_type == "INT":
                     val = int(val)
                     inputs[x] = val
-                if type_input == "FLOAT":
+                if input_type == "FLOAT":
                     val = float(val)
                     inputs[x] = val
-                if type_input == "STRING":
+                if input_type == "STRING":
                     val = str(val)
                     inputs[x] = val
-                if type_input == "BOOLEAN":
+                if input_type == "BOOLEAN":
                     val = bool(val)
                     inputs[x] = val
             except Exception as ex:
                 error = {
                     "type": "invalid_input_type",
-                    "message": f"Failed to convert an input value to a {type_input} value",
+                    "message": f"Failed to convert an input value to a {input_type} value",
                     "details": f"{x}, {val}, {ex}",
                     "extra_info": {
                         "input_name": x,

@@ -707,18 +696,19 @@ def validate_inputs(prompt, item, validated):
             errors.append(error)
             continue
 
-        if isinstance(type_input, list):
-            if val not in type_input:
+        if isinstance(input_type, list):
+            combo_options = input_type
+            if val not in combo_options:
                 input_config = info
                 list_info = ""
 
                 # Don't send back gigantic lists like if they're lots of
                 # scanned model filepaths
-                if len(type_input) > 20:
-                    list_info = f"(list of length {len(type_input)})"
+                if len(combo_options) > 20:
+                    list_info = f"(list of length {len(combo_options)})"
                     input_config = None
                 else:
-                    list_info = str(type_input)
+                    list_info = str(combo_options)
 
                 error = {
                     "type": "value_not_in_list",

@@ -272,9 +272,6 @@ def filter_files_extensions(files: Collection[str], extensions: Collection[str])
 
 
 def get_full_path(folder_name: str, filename: str) -> str | None:
-    """
-    Get the full path of a file in a folder, has to be a file
-    """
     global folder_names_and_paths
     folder_name = map_legacy(folder_name)
     if folder_name not in folder_names_and_paths:

@@ -292,9 +289,6 @@ def get_full_path(folder_name: str, filename: str) -> str | None:
 
 
 def get_full_path_or_raise(folder_name: str, filename: str) -> str:
-    """
-    Get the full path of a file in a folder, has to be a file
-    """
     full_path = get_full_path(folder_name, filename)
     if full_path is None:
         raise FileNotFoundError(f"Model in folder '{folder_name}' with filename '{filename}' not found.")

@@ -396,26 +390,3 @@ def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, im
     os.makedirs(full_output_folder, exist_ok=True)
     counter = 1
     return full_output_folder, filename, counter, subfolder, filename_prefix
-
-def get_input_subfolders() -> list[str]:
-    """Returns a list of all subfolder paths in the input directory, recursively.
-
-    Returns:
-        List of folder paths relative to the input directory, excluding the root directory
-    """
-    input_dir = get_input_directory()
-    folders = []
-
-    try:
-        if not os.path.exists(input_dir):
-            return []
-
-        for root, dirs, _ in os.walk(input_dir):
-            rel_path = os.path.relpath(root, input_dir)
-            if rel_path != ".":  # Only include non-root directories
-                # Normalize path separators to forward slashes
-                folders.append(rel_path.replace(os.sep, '/'))
-
-        return sorted(folders)
-    except FileNotFoundError:
-        return []

nodes.py (2 changed lines)
@@ -2229,7 +2229,6 @@ def init_builtin_extra_nodes():
         "nodes_model_downscale.py",
         "nodes_images.py",
         "nodes_video_model.py",
-        "nodes_train.py",
         "nodes_sag.py",
         "nodes_perpneg.py",
         "nodes_stable3d.py",

@@ -2268,7 +2267,6 @@ def init_builtin_extra_nodes():
         "nodes_lotus.py",
         "nodes_hunyuan3d.py",
         "nodes_primitive.py",
-        "nodes_cfg.py",
     ]
 
     import_failed = []

@@ -1,51 +0,0 @@
-import pytest
-import os
-import tempfile
-from folder_paths import get_input_subfolders, set_input_directory
-
-@pytest.fixture(scope="module")
-def mock_folder_structure():
-    with tempfile.TemporaryDirectory() as temp_dir:
-        # Create a nested folder structure
-        folders = [
-            "folder1",
-            "folder1/subfolder1",
-            "folder1/subfolder2",
-            "folder2",
-            "folder2/deep",
-            "folder2/deep/nested",
-            "empty_folder"
-        ]
-
-        # Create the folders
-        for folder in folders:
-            os.makedirs(os.path.join(temp_dir, folder))
-
-        # Add some files to test they're not included
-        with open(os.path.join(temp_dir, "root_file.txt"), "w") as f:
-            f.write("test")
-        with open(os.path.join(temp_dir, "folder1", "test.txt"), "w") as f:
-            f.write("test")
-
-        set_input_directory(temp_dir)
-        yield temp_dir
-
-
-def test_gets_all_folders(mock_folder_structure):
-    folders = get_input_subfolders()
-    expected = ["folder1", "folder1/subfolder1", "folder1/subfolder2",
-                "folder2", "folder2/deep", "folder2/deep/nested", "empty_folder"]
-    assert sorted(folders) == sorted(expected)
-
-
-def test_handles_nonexistent_input_directory():
-    with tempfile.TemporaryDirectory() as temp_dir:
-        nonexistent = os.path.join(temp_dir, "nonexistent")
-        set_input_directory(nonexistent)
-        assert get_input_subfolders() == []
-
-
-def test_empty_input_directory():
-    with tempfile.TemporaryDirectory() as temp_dir:
-        set_input_directory(temp_dir)
-        assert get_input_subfolders() == []  # Empty since we don't include root