Compare commits
3 Commits
v0.3.28 ... annoate_ge

| Author | SHA1 | Date |
|---|---|---|
|  | 522d923948 |  |
|  | c05c9b552b |  |
|  | 27598702e9 |  |
@@ -9,14 +9,8 @@ class AppSettings():
         self.user_manager = user_manager

     def get_settings(self, request):
-        try:
-            file = self.user_manager.get_request_user_filepath(
-                request,
-                "comfy.settings.json"
-            )
-        except KeyError as e:
-            logging.error("User settings not found.")
-            raise web.HTTPUnauthorized() from e
+        file = self.user_manager.get_request_user_filepath(
+            request, "comfy.settings.json")
         if os.path.isfile(file):
             try:
                 with open(file) as f:
@@ -79,7 +79,6 @@ fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Stor
 fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).")
 fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text encoder weights in fp16.")
 fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
-fpte_group.add_argument("--bf16-text-enc", action="store_true", help="Store text encoder weights in bf16.")

 parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.")

@@ -101,7 +100,6 @@ parser.add_argument("--preview-size", type=int, default=512, help="Sets the maxi
 cache_group = parser.add_mutually_exclusive_group()
 cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
 cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
-cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")

 attn_group = parser.add_mutually_exclusive_group()
 attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
@@ -136,9 +134,8 @@ parser.add_argument("--deterministic", action="store_true", help="Make pytorch u
 class PerformanceFeature(enum.Enum):
     Fp16Accumulation = "fp16_accumulation"
     Fp8MatrixMultiplication = "fp8_matrix_mult"
-    CublasOps = "cublas_ops"

-parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult cublas_ops")
+parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult")

 parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
 parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
@@ -110,13 +110,9 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
     elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
         json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
     elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
-        embed_shape = sd["vision_model.embeddings.position_embedding.weight"].shape[0]
         if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152:
-            if embed_shape == 729:
-                json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
-            elif embed_shape == 1024:
-                json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_512.json")
-        elif embed_shape == 577:
+            json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
+        elif sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
             if "multi_modal_projector.linear_1.bias" in sd:
                 json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336_llava.json")
             else:
@@ -1,13 +0,0 @@
-{
-    "num_channels": 3,
-    "hidden_act": "gelu_pytorch_tanh",
-    "hidden_size": 1152,
-    "image_size": 512,
-    "intermediate_size": 4304,
-    "model_type": "siglip_vision_model",
-    "num_attention_heads": 16,
-    "num_hidden_layers": 27,
-    "patch_size": 16,
-    "image_mean": [0.5, 0.5, 0.5],
-    "image_std": [0.5, 0.5, 0.5]
-}
@@ -102,13 +102,9 @@ class InputTypeOptions(TypedDict):
     default: bool | str | float | int | list | tuple
     """The default value of the widget"""
     defaultInput: bool
-    """@deprecated in v1.16 frontend. v1.16 frontend allows input socket and widget to co-exist.
-    - defaultInput on required inputs should be dropped.
-    - defaultInput on optional inputs should be replaced with forceInput.
-    Ref: https://github.com/Comfy-Org/ComfyUI_frontend/pull/3364
-    """
+    """Defaults to an input slot rather than a widget"""
     forceInput: bool
-    """Forces the input to be an input slot rather than a widget even a widget is available for the input type."""
+    """`defaultInput` and also don't allow converting to a widget"""
     lazy: bool
     """Declares that this input uses lazy evaluation"""
     rawLink: bool
@@ -1422,101 +1422,3 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
             x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (sigmas[i + 1] ** 2 - sigmas[i] ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0)
         old_denoised = denoised
     return x
-
-@torch.no_grad()
-def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
-    '''
-    SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VE Data Prediction) stage 2
-    Arxiv: https://arxiv.org/abs/2305.14267
-    '''
-    extra_args = {} if extra_args is None else extra_args
-    seed = extra_args.get("seed", None)
-    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
-    s_in = x.new_ones([x.shape[0]])
-
-    inject_noise = eta > 0 and s_noise > 0
-
-    for i in trange(len(sigmas) - 1, disable=disable):
-        denoised = model(x, sigmas[i] * s_in, **extra_args)
-        if callback is not None:
-            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
-        if sigmas[i + 1] == 0:
-            x = denoised
-        else:
-            t, t_next = -sigmas[i].log(), -sigmas[i + 1].log()
-            h = t_next - t
-            h_eta = h * (eta + 1)
-            s = t + r * h
-            fac = 1 / (2 * r)
-            sigma_s = s.neg().exp()
-
-            coeff_1, coeff_2 = (-r * h_eta).expm1(), (-h_eta).expm1()
-            if inject_noise:
-                noise_coeff_1 = (-2 * r * h * eta).expm1().neg().sqrt()
-                noise_coeff_2 = ((-2 * r * h * eta).expm1() - (-2 * h * eta).expm1()).sqrt()
-                noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s), noise_sampler(sigma_s, sigmas[i + 1])
-
-            # Step 1
-            x_2 = (coeff_1 + 1) * x - coeff_1 * denoised
-            if inject_noise:
-                x_2 = x_2 + sigma_s * (noise_coeff_1 * noise_1) * s_noise
-            denoised_2 = model(x_2, sigma_s * s_in, **extra_args)
-
-            # Step 2
-            denoised_d = (1 - fac) * denoised + fac * denoised_2
-            x = (coeff_2 + 1) * x - coeff_2 * denoised_d
-            if inject_noise:
-                x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
-    return x
-
-@torch.no_grad()
-def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3):
-    '''
-    SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VE Data Prediction) stage 3
-    Arxiv: https://arxiv.org/abs/2305.14267
-    '''
-    extra_args = {} if extra_args is None else extra_args
-    seed = extra_args.get("seed", None)
-    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
-    s_in = x.new_ones([x.shape[0]])
-
-    inject_noise = eta > 0 and s_noise > 0
-
-    for i in trange(len(sigmas) - 1, disable=disable):
-        denoised = model(x, sigmas[i] * s_in, **extra_args)
-        if callback is not None:
-            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
-        if sigmas[i + 1] == 0:
-            x = denoised
-        else:
-            t, t_next = -sigmas[i].log(), -sigmas[i + 1].log()
-            h = t_next - t
-            h_eta = h * (eta + 1)
-            s_1 = t + r_1 * h
-            s_2 = t + r_2 * h
-            sigma_s_1, sigma_s_2 = s_1.neg().exp(), s_2.neg().exp()
-
-            coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1()
-            if inject_noise:
-                noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt()
-                noise_coeff_2 = ((-2 * r_1 * h * eta).expm1() - (-2 * r_2 * h * eta).expm1()).sqrt()
-                noise_coeff_3 = ((-2 * r_2 * h * eta).expm1() - (-2 * h * eta).expm1()).sqrt()
-                noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1])
-
-            # Step 1
-            x_2 = (coeff_1 + 1) * x - coeff_1 * denoised
-            if inject_noise:
-                x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
-            denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
-
-            # Step 2
-            x_3 = (coeff_2 + 1) * x - coeff_2 * denoised + (r_2 / r_1) * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised)
-            if inject_noise:
-                x_3 = x_3 + sigma_s_2 * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
-            denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args)
-
-            # Step 3
-            x = (coeff_3 + 1) * x - coeff_3 * denoised + (1. / r_2) * (coeff_3 / h_eta + 1) * (denoised_3 - denoised)
-            if inject_noise:
-                x = x + sigmas[i + 1] * (noise_coeff_3 * noise_1 + noise_coeff_2 * noise_2 + noise_coeff_1 * noise_3) * s_noise
-    return x
@@ -847,7 +847,6 @@ class SpatialTransformer(nn.Module):
|
         if not isinstance(context, list):
             context = [context] * len(self.transformer_blocks)
         b, c, h, w = x.shape
-        transformer_options["activations_shape"] = list(x.shape)
         x_in = x
         x = self.norm(x)
         if not self.use_linear:
@@ -963,7 +962,6 @@ class SpatialVideoTransformer(SpatialTransformer):
         transformer_options={}
     ) -> torch.Tensor:
         _, _, h, w = x.shape
-        transformer_options["activations_shape"] = list(x.shape)
         x_in = x
         spatial_context = None
         if exists(context):
@@ -1,5 +1,4 @@
 import torch
-import comfy.utils


 def convert_lora_bfl_control(sd): #BFL loras for Flux
@@ -12,13 +11,7 @@ def convert_lora_bfl_control(sd): #BFL loras for Flux
     return sd_out


-def convert_lora_wan_fun(sd): #Wan Fun loras
-    return comfy.utils.state_dict_prefix_replace(sd, {"lora_unet__": "lora_unet_"})
-
-
 def convert_lora(sd):
     if "img_in.lora_A.weight" in sd and "single_blocks.0.norm.key_norm.scale" in sd:
         return convert_lora_bfl_control(sd)
-    if "lora_unet__blocks_0_cross_attn_k.lora_down.weight" in sd:
-        return convert_lora_wan_fun(sd)
     return sd
@@ -992,41 +992,31 @@ class WAN21(BaseModel):

     def concat_cond(self, **kwargs):
         noise = kwargs.get("noise", None)
-        extra_channels = self.diffusion_model.patch_embedding.weight.shape[1] - noise.shape[1]
-        if extra_channels == 0:
+        if self.diffusion_model.patch_embedding.weight.shape[1] == noise.shape[1]:
             return None

         image = kwargs.get("concat_latent_image", None)
         device = kwargs["device"]

         if image is None:
-            shape_image = list(noise.shape)
-            shape_image[1] = extra_channels
-            image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device)
-        else:
-            image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
-            for i in range(0, image.shape[1], 16):
-                image[:, i: i + 16] = self.process_latent_in(image[:, i: i + 16])
-            image = utils.resize_to_batch_size(image, noise.shape[0])
+            image = torch.zeros_like(noise)

-        if not self.image_to_video or extra_channels == image.shape[1]:
+        image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
+        image = self.process_latent_in(image)
+        image = utils.resize_to_batch_size(image, noise.shape[0])
+
+        if not self.image_to_video:
             return image

-        if image.shape[1] > (extra_channels - 4):
-            image = image[:, :(extra_channels - 4)]
-
         mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
         if mask is None:
             mask = torch.zeros_like(noise)[:, :4]
         else:
-            if mask.shape[1] != 4:
-                mask = torch.mean(mask, dim=1, keepdim=True)
-            mask = 1.0 - mask
+            mask = 1.0 - torch.mean(mask, dim=1, keepdim=True)
             mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
             if mask.shape[-3] < noise.shape[-3]:
                 mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0)
-            if mask.shape[1] == 1:
-                mask = mask.repeat(1, 4, 1, 1, 1)
+            mask = mask.repeat(1, 4, 1, 1, 1)
             mask = utils.resize_to_batch_size(mask, noise.shape[0])

         return torch.cat((mask, image), dim=1)
@@ -823,8 +823,6 @@ def text_encoder_dtype(device=None):
         return torch.float8_e5m2
     elif args.fp16_text_enc:
         return torch.float16
-    elif args.bf16_text_enc:
-        return torch.bfloat16
     elif args.fp32_text_enc:
         return torch.float32

@@ -1237,8 +1235,6 @@ def soft_empty_cache(force=False):
         torch.xpu.empty_cache()
     elif is_ascend_npu():
         torch.npu.empty_cache()
-    elif is_mlu():
-        torch.mlu.empty_cache()
     elif torch.cuda.is_available():
         torch.cuda.empty_cache()
         torch.cuda.ipc_collect()
comfy/ops.py — 28 changed lines
@@ -357,25 +357,6 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None

     return scaled_fp8_op

-CUBLAS_IS_AVAILABLE = False
-try:
-    from cublas_ops import CublasLinear
-    CUBLAS_IS_AVAILABLE = True
-except ImportError:
-    pass
-
-if CUBLAS_IS_AVAILABLE:
-    class cublas_ops(disable_weight_init):
-        class Linear(CublasLinear, disable_weight_init.Linear):
-            def reset_parameters(self):
-                return None
-
-            def forward_comfy_cast_weights(self, input):
-                return super().forward(input)
-
-            def forward(self, *args, **kwargs):
-                return super().forward(*args, **kwargs)
-
 def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
     fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
     if scaled_fp8 is not None:
@@ -388,15 +369,6 @@ def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_
     ):
         return fp8_ops

-    if (
-        PerformanceFeature.CublasOps in args.fast and
-        CUBLAS_IS_AVAILABLE and
-        weight_dtype == torch.float16 and
-        (compute_dtype == torch.float16 or compute_dtype is None)
-    ):
-        logging.info("Using cublas ops")
-        return cublas_ops
-
     if compute_dtype is None or weight_dtype == compute_dtype:
         return disable_weight_init

@@ -48,7 +48,6 @@ def get_all_callbacks(call_type: str, transformer_options: dict, is_model_option

 class WrappersMP:
     OUTER_SAMPLE = "outer_sample"
-    PREPARE_SAMPLING = "prepare_sampling"
     SAMPLER_SAMPLE = "sampler_sample"
     CALC_COND_BATCH = "calc_cond_batch"
     APPLY_MODEL = "apply_model"
@@ -106,13 +106,6 @@ def cleanup_additional_models(models):


 def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
-    executor = comfy.patcher_extension.WrapperExecutor.new_executor(
-        _prepare_sampling,
-        comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True)
-    )
-    return executor.execute(model, noise_shape, conds, model_options=model_options)
-
-def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
     real_model: BaseModel = None
     models, inference_memory = get_additional_models(conds, model.model_dtype())
     models += get_additional_models_from_model_options(model_options)
@@ -710,7 +710,7 @@ KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_c
                    "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
                    "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
                    "ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
-                   "gradient_estimation", "er_sde", "seeds_2", "seeds_3"]
+                   "gradient_estimation", "er_sde"]

 class KSAMPLER(Sampler):
     def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
comfy/sd.py — 10 changed lines
@@ -265,7 +265,6 @@ class VAE:
         self.process_input = lambda image: image * 2.0 - 1.0
         self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
         self.working_dtypes = [torch.bfloat16, torch.float32]
-        self.disable_offload = False

         self.downscale_index_formula = None
         self.upscale_index_formula = None
@@ -338,7 +337,6 @@ class VAE:
             self.process_output = lambda audio: audio
             self.process_input = lambda audio: audio
             self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
-            self.disable_offload = True
         elif "blocks.2.blocks.3.stack.5.weight" in sd or "decoder.blocks.2.blocks.3.stack.5.weight" in sd or "layers.4.layers.1.attn_block.attn.qkv.weight" in sd or "encoder.layers.4.layers.1.attn_block.attn.qkv.weight" in sd: #genmo mochi vae
             if "blocks.2.blocks.3.stack.5.weight" in sd:
                 sd = comfy.utils.state_dict_prefix_replace(sd, {"": "decoder."})
@@ -517,7 +515,7 @@ class VAE:
         pixel_samples = None
         try:
             memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
-            model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
+            model_management.load_models_gpu([self.patcher], memory_required=memory_used)
             free_memory = model_management.get_free_memory(self.device)
             batch_number = int(free_memory / memory_used)
             batch_number = max(1, batch_number)
@@ -546,7 +544,7 @@ class VAE:
     def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
         self.throw_exception_if_invalid()
         memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
-        model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
+        model_management.load_models_gpu([self.patcher], memory_required=memory_used)
         dims = samples.ndim - 2
         args = {}
         if tile_x is not None:
@@ -580,7 +578,7 @@ class VAE:
         pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
         try:
             memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
-            model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
+            model_management.load_models_gpu([self.patcher], memory_required=memory_used)
             free_memory = model_management.get_free_memory(self.device)
             batch_number = int(free_memory / max(1, memory_used))
             batch_number = max(1, batch_number)
@@ -614,7 +612,7 @@ class VAE:
         pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)

         memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) # TODO: calculate mem required for tile
-        model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
+        model_management.load_models_gpu([self.patcher], memory_required=memory_used)

         args = {}
         if tile_x is not None:
@@ -969,24 +969,12 @@ class WAN21_I2V(WAN21_T2V):
     unet_config = {
         "image_model": "wan2.1",
         "model_type": "i2v",
-        "in_dim": 36,
     }

     def get_model(self, state_dict, prefix="", device=None):
         out = model_base.WAN21(self, image_to_video=True, device=device)
         return out

-class WAN21_FunControl2V(WAN21_T2V):
-    unet_config = {
-        "image_model": "wan2.1",
-        "model_type": "i2v",
-        "in_dim": 48,
-    }
-
-    def get_model(self, state_dict, prefix="", device=None):
-        out = model_base.WAN21(self, image_to_video=False, device=device)
-        return out
-
 class Hunyuan3Dv2(supported_models_base.BASE):
     unet_config = {
         "image_model": "hunyuan3d2",
@@ -1025,6 +1013,6 @@ class Hunyuan3Dv2mini(Hunyuan3Dv2):

     latent_format = latent_formats.Hunyuan3Dv2mini

-models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, Hunyuan3Dv2mini, Hunyuan3Dv2]
+models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, Hunyuan3Dv2mini, Hunyuan3Dv2]

 models += [SVD_img2vid]
@@ -316,156 +316,3 @@ class LRUCache(BasicCache):
             self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
         return self

-
-class DependencyAwareCache(BasicCache):
-    """
-    A cache implementation that tracks dependencies between nodes and manages
-    their execution and caching accordingly. It extends the BasicCache class.
-    Nodes are removed from this cache once all of their descendants have been
-    executed.
-    """
-
-    def __init__(self, key_class):
-        """
-        Initialize the DependencyAwareCache.
-
-        Args:
-            key_class: The class used for generating cache keys.
-        """
-        super().__init__(key_class)
-        self.descendants = {}  # Maps node_id -> set of descendant node_ids
-        self.ancestors = {}  # Maps node_id -> set of ancestor node_ids
-        self.executed_nodes = set()  # Tracks nodes that have been executed
-
-    def set_prompt(self, dynprompt, node_ids, is_changed_cache):
-        """
-        Clear the entire cache and rebuild the dependency graph.
-
-        Args:
-            dynprompt: The dynamic prompt object containing node information.
-            node_ids: List of node IDs to initialize the cache for.
-            is_changed_cache: Flag indicating if the cache has changed.
-        """
-        # Clear all existing cache data
-        self.cache.clear()
-        self.subcaches.clear()
-        self.descendants.clear()
-        self.ancestors.clear()
-        self.executed_nodes.clear()
-
-        # Call the parent method to initialize the cache with the new prompt
-        super().set_prompt(dynprompt, node_ids, is_changed_cache)
-
-        # Rebuild the dependency graph
-        self._build_dependency_graph(dynprompt, node_ids)
-
-    def _build_dependency_graph(self, dynprompt, node_ids):
-        """
-        Build the dependency graph for all nodes.
-
-        Args:
-            dynprompt: The dynamic prompt object containing node information.
-            node_ids: List of node IDs to build the graph for.
-        """
-        self.descendants.clear()
-        self.ancestors.clear()
-        for node_id in node_ids:
-            self.descendants[node_id] = set()
-            self.ancestors[node_id] = set()
-
-        for node_id in node_ids:
-            inputs = dynprompt.get_node(node_id)["inputs"]
-            for input_data in inputs.values():
-                if is_link(input_data):  # Check if the input is a link to another node
-                    ancestor_id = input_data[0]
-                    self.descendants[ancestor_id].add(node_id)
-                    self.ancestors[node_id].add(ancestor_id)
-
-    def set(self, node_id, value):
-        """
-        Mark a node as executed and store its value in the cache.
-
-        Args:
-            node_id: The ID of the node to store.
-            value: The value to store for the node.
-        """
-        self._set_immediate(node_id, value)
-        self.executed_nodes.add(node_id)
-        self._cleanup_ancestors(node_id)
-
-    def get(self, node_id):
-        """
-        Retrieve the cached value for a node.
-
-        Args:
-            node_id: The ID of the node to retrieve.
-
-        Returns:
-            The cached value for the node.
-        """
-        return self._get_immediate(node_id)
-
-    def ensure_subcache_for(self, node_id, children_ids):
-        """
-        Ensure a subcache exists for a node and update dependencies.
-
-        Args:
-            node_id: The ID of the parent node.
-            children_ids: List of child node IDs to associate with the parent node.
-
-        Returns:
-            The subcache object for the node.
-        """
-        subcache = super()._ensure_subcache(node_id, children_ids)
-        for child_id in children_ids:
-            self.descendants[node_id].add(child_id)
-            self.ancestors[child_id].add(node_id)
-        return subcache
-
-    def _cleanup_ancestors(self, node_id):
-        """
-        Check if ancestors of a node can be removed from the cache.
-
-        Args:
-            node_id: The ID of the node whose ancestors are to be checked.
-        """
-        for ancestor_id in self.ancestors.get(node_id, []):
-            if ancestor_id in self.executed_nodes:
-                # Remove ancestor if all its descendants have been executed
-                if all(descendant in self.executed_nodes for descendant in self.descendants[ancestor_id]):
-                    self._remove_node(ancestor_id)
-
-    def _remove_node(self, node_id):
-        """
-        Remove a node from the cache.
-
-        Args:
-            node_id: The ID of the node to remove.
-        """
-        cache_key = self.cache_key_set.get_data_key(node_id)
-        if cache_key in self.cache:
-            del self.cache[cache_key]
-        subcache_key = self.cache_key_set.get_subcache_key(node_id)
-        if subcache_key in self.subcaches:
-            del self.subcaches[subcache_key]
-
-    def clean_unused(self):
-        """
-        Clean up unused nodes. This is a no-op for this cache implementation.
-        """
-        pass
-
-    def recursive_debug_dump(self):
-        """
-        Dump the cache and dependency graph for debugging.
-
-        Returns:
-            A list containing the cache state and dependency graph.
-        """
-        result = super().recursive_debug_dump()
-        result.append({
-            "descendants": self.descendants,
-            "ancestors": self.ancestors,
-            "executed_nodes": list(self.executed_nodes),
-        })
-        return result
@@ -1,6 +1,9 @@
-import nodes
+from __future__ import annotations
+from typing import Type, Literal

+import nodes
 from comfy_execution.graph_utils import is_link
+from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions

 class DependencyCycleError(Exception):
     pass
@@ -54,7 +57,22 @@ class DynamicPrompt:
     def get_original_prompt(self):
         return self.original_prompt

-def get_input_info(class_def, input_name, valid_inputs=None):
+def get_input_info(
+    class_def: Type[ComfyNodeABC],
+    input_name: str,
+    valid_inputs: InputTypeDict | None = None
+) -> tuple[str, Literal["required", "optional", "hidden"], InputTypeOptions] | tuple[None, None, None]:
+    """Get the input type, category, and extra info for a given input name.
+
+    Arguments:
+        class_def: The class definition of the node.
+        input_name: The name of the input to get info for.
+        valid_inputs: The valid inputs for the node, or None to use the class_def.INPUT_TYPES().
+
+    Returns:
+        tuple[str, str, dict] | tuple[None, None, None]: The input type, category, and extra info for the input name.
+    """
     valid_inputs = valid_inputs or class_def.INPUT_TYPES()
     input_info = None
     input_category = None
@@ -126,7 +144,7 @@ class TopologicalSort:
                 from_node_id, from_socket = value
                 if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
                     continue
-                input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
+                _, _, input_info = self.get_input_info(unique_id, input_name)
                 is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
                 if (include_lazy or not is_lazy) and not self.is_cached(from_node_id):
                     node_ids.append(from_node_id)
@@ -1,45 +0,0 @@
-import torch
-
-# https://github.com/WeichenFan/CFG-Zero-star
-def optimized_scale(positive, negative):
-    positive_flat = positive.reshape(positive.shape[0], -1)
-    negative_flat = negative.reshape(negative.shape[0], -1)
-
-    # Calculate dot production
-    dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
-
-    # Squared norm of uncondition
-    squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
-
-    # st_star = v_cond^T * v_uncond / ||v_uncond||^2
-    st_star = dot_product / squared_norm
-
-    return st_star.reshape([positive.shape[0]] + [1] * (positive.ndim - 1))
-
-class CFGZeroStar:
-    @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {"model": ("MODEL",),
-                            }}
-    RETURN_TYPES = ("MODEL",)
-    RETURN_NAMES = ("patched_model",)
-    FUNCTION = "patch"
-    CATEGORY = "advanced/guidance"
-
-    def patch(self, model):
-        m = model.clone()
-        def cfg_zero_star(args):
-            guidance_scale = args['cond_scale']
-            x = args['input']
-            cond_p = args['cond_denoised']
-            uncond_p = args['uncond_denoised']
-            out = args["denoised"]
-            alpha = optimized_scale(x - cond_p, x - uncond_p)
-
-            return out + uncond_p * (alpha - 1.0) + guidance_scale * uncond_p * (1.0 - alpha)
-        m.set_model_sampler_post_cfg_function(cfg_zero_star)
-        return (m, )
-
-NODE_CLASS_MAPPINGS = {
-    "CFGZeroStar": CFGZeroStar
-}
@@ -209,196 +209,6 @@ def voxel_to_mesh(voxels, threshold=0.5, device=None):
     vertices = torch.fliplr(vertices)
     return vertices, faces

-def voxel_to_mesh_surfnet(voxels, threshold=0.5, device=None):
-    if device is None:
-        device = torch.device("cpu")
-    voxels = voxels.to(device)
-
-    D, H, W = voxels.shape
-
-    padded = torch.nn.functional.pad(voxels, (1, 1, 1, 1, 1, 1), 'constant', 0)
-    z, y, x = torch.meshgrid(
-        torch.arange(D, device=device),
-        torch.arange(H, device=device),
-        torch.arange(W, device=device),
-        indexing='ij'
-    )
-    cell_positions = torch.stack([z.flatten(), y.flatten(), x.flatten()], dim=1)
-
-    corner_offsets = torch.tensor([
-        [0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0],
-        [0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1]
-    ], device=device)
-
-    corner_values = torch.zeros((cell_positions.shape[0], 8), device=device)
-    for c, (dz, dy, dx) in enumerate(corner_offsets):
-        corner_values[:, c] = padded[
-            cell_positions[:, 0] + dz,
-            cell_positions[:, 1] + dy,
-            cell_positions[:, 2] + dx
-        ]
-
-    corner_signs = corner_values > threshold
-    has_inside = torch.any(corner_signs, dim=1)
-    has_outside = torch.any(~corner_signs, dim=1)
-    contains_surface = has_inside & has_outside
-
-    active_cells = cell_positions[contains_surface]
-    active_signs = corner_signs[contains_surface]
-    active_values = corner_values[contains_surface]
-
-    if active_cells.shape[0] == 0:
-        return torch.zeros((0, 3), device=device), torch.zeros((0, 3), dtype=torch.long, device=device)
-
-    edges = torch.tensor([
-        [0, 1], [0, 2], [0, 4], [1, 3],
-        [1, 5], [2, 3], [2, 6], [3, 7],
-        [4, 5], [4, 6], [5, 7], [6, 7]
-    ], device=device)
-
-    cell_vertices = {}
-    progress = comfy.utils.ProgressBar(100)
-
-    for edge_idx, (e1, e2) in enumerate(edges):
-        progress.update(1)
-        crossing = active_signs[:, e1] != active_signs[:, e2]
-        if not crossing.any():
-            continue
-
-        cell_indices = torch.nonzero(crossing, as_tuple=True)[0]
-
-        v1 = active_values[cell_indices, e1]
-        v2 = active_values[cell_indices, e2]
-
-        t = torch.zeros_like(v1, device=device)
-        denom = v2 - v1
-        valid = denom != 0
-        t[valid] = (threshold - v1[valid]) / denom[valid]
-        t[~valid] = 0.5
-
-        p1 = corner_offsets[e1].float()
-        p2 = corner_offsets[e2].float()
-
-        intersection = p1.unsqueeze(0) + t.unsqueeze(1) * (p2.unsqueeze(0) - p1.unsqueeze(0))
-
-        for i, point in zip(cell_indices.tolist(), intersection):
-            if i not in cell_vertices:
-                cell_vertices[i] = []
-            cell_vertices[i].append(point)
-
-    # Calculate the final vertices as the average of intersection points for each cell
-    vertices = []
-    vertex_lookup = {}
-
-    vert_progress_mod = round(len(cell_vertices)/50)
-
-    for i, points in cell_vertices.items():
-        if not i % vert_progress_mod:
-            progress.update(1)
-
-        if points:
-            vertex = torch.stack(points).mean(dim=0)
-            vertex = vertex + active_cells[i].float()
-            vertex_lookup[tuple(active_cells[i].tolist())] = len(vertices)
-            vertices.append(vertex)
-
-    if not vertices:
-        return torch.zeros((0, 3), device=device), torch.zeros((0, 3), dtype=torch.long, device=device)
-
-    final_vertices = torch.stack(vertices)
-
-    inside_corners_mask = active_signs
-    outside_corners_mask = ~active_signs
-
-    inside_counts = inside_corners_mask.sum(dim=1, keepdim=True).float()
-    outside_counts = outside_corners_mask.sum(dim=1, keepdim=True).float()
-
-    inside_pos = torch.zeros((active_cells.shape[0], 3), device=device)
-    outside_pos = torch.zeros((active_cells.shape[0], 3), device=device)
-
-    for i in range(8):
-        mask_inside = inside_corners_mask[:, i].unsqueeze(1)
-        mask_outside = outside_corners_mask[:, i].unsqueeze(1)
-        inside_pos += corner_offsets[i].float().unsqueeze(0) * mask_inside
-        outside_pos += corner_offsets[i].float().unsqueeze(0) * mask_outside
-
-    inside_pos /= inside_counts
-    outside_pos /= outside_counts
-    gradients = inside_pos - outside_pos
-
-    pos_dirs = torch.tensor([
-        [1, 0, 0],
-        [0, 1, 0],
-        [0, 0, 1]
-    ], device=device)
-
-    cross_products = [
-        torch.linalg.cross(pos_dirs[i].float(), pos_dirs[j].float())
-        for i in range(3) for j in range(i+1, 3)
-    ]
-
-    faces = []
-    all_keys = set(vertex_lookup.keys())
-
-    face_progress_mod = round(len(active_cells)/38*3)
-
-    for pair_idx, (i, j) in enumerate([(0,1), (0,2), (1,2)]):
-        dir_i = pos_dirs[i]
-        dir_j = pos_dirs[j]
-        cross_product = cross_products[pair_idx]
-
-        ni_positions = active_cells + dir_i
-        nj_positions = active_cells + dir_j
-        diag_positions = active_cells + dir_i + dir_j
-
-        alignments = torch.matmul(gradients, cross_product)
-
-        valid_quads = []
-        quad_indices = []
-
-        for idx, active_cell in enumerate(active_cells):
-            if not idx % face_progress_mod:
-                progress.update(1)
-            cell_key = tuple(active_cell.tolist())
-            ni_key = tuple(ni_positions[idx].tolist())
-            nj_key = tuple(nj_positions[idx].tolist())
-            diag_key = tuple(diag_positions[idx].tolist())
-
-            if cell_key in all_keys and ni_key in all_keys and nj_key in all_keys and diag_key in all_keys:
-                v0 = vertex_lookup[cell_key]
-                v1 = vertex_lookup[ni_key]
-                v2 = vertex_lookup[nj_key]
-                v3 = vertex_lookup[diag_key]
-
-                valid_quads.append((v0, v1, v2, v3))
-                quad_indices.append(idx)
-
-        for q_idx, (v0, v1, v2, v3) in enumerate(valid_quads):
-            cell_idx = quad_indices[q_idx]
-            if alignments[cell_idx] > 0:
-                faces.append(torch.tensor([v0, v1, v3], device=device, dtype=torch.long))
-                faces.append(torch.tensor([v0, v3, v2], device=device, dtype=torch.long))
-            else:
-                faces.append(torch.tensor([v0, v3, v1], device=device, dtype=torch.long))
-                faces.append(torch.tensor([v0, v2, v3], device=device, dtype=torch.long))
-
-    if faces:
-        faces = torch.stack(faces)
-    else:
-        faces = torch.zeros((0, 3), dtype=torch.long, device=device)
-
-    v_min = 0
-    v_max = max(D, H, W)
-
-    final_vertices = final_vertices - (v_min + v_max) / 2
-
-    scale = (v_max - v_min) / 2
-    if scale > 0:
-        final_vertices = final_vertices / scale
-
-    final_vertices = torch.fliplr(final_vertices)
-
-    return final_vertices, faces
-

 class MESH:
     def __init__(self, vertices, faces):
@@ -427,34 +237,6 @@ class VoxelToMeshBasic:

     return (MESH(torch.stack(vertices), torch.stack(faces)), )

-class VoxelToMesh:
-    @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {"voxel": ("VOXEL", ),
-                             "algorithm": (["surface net", "basic"], ),
-                             "threshold": ("FLOAT", {"default": 0.6, "min": -1.0, "max": 1.0, "step": 0.01}),
-                             }}
-    RETURN_TYPES = ("MESH",)
-    FUNCTION = "decode"
-
-    CATEGORY = "3d"
-
-    def decode(self, voxel, algorithm, threshold):
-        vertices = []
-        faces = []
-
-        if algorithm == "basic":
-            mesh_function = voxel_to_mesh
-        elif algorithm == "surface net":
-            mesh_function = voxel_to_mesh_surfnet
-
-        for x in voxel.data:
-            v, f = mesh_function(x, threshold=threshold, device=None)
-            vertices.append(v)
-            faces.append(f)
-
-        return (MESH(torch.stack(vertices), torch.stack(faces)), )
-

 def save_glb(vertices, faces, filepath, metadata=None):
     """
@@ -462,7 +244,7 @@ def save_glb(vertices, faces, filepath, metadata=None):

     Parameters:
     vertices: torch.Tensor of shape (N, 3) - The vertex coordinates
-    faces: torch.Tensor of shape (M, 3) - The face indices (triangle faces)
+    faces: torch.Tensor of shape (M, 4) or (M, 3) - The face indices (quad or triangle faces)
     filepath: str - Output filepath (should end with .glb)
     """

@@ -629,6 +411,5 @@ NODE_CLASS_MAPPINGS = {
     "Hunyuan3Dv2ConditioningMultiView": Hunyuan3Dv2ConditioningMultiView,
     "VAEDecodeHunyuan3D": VAEDecodeHunyuan3D,
     "VoxelToMeshBasic": VoxelToMeshBasic,
-    "VoxelToMesh": VoxelToMesh,
     "SaveGLB": SaveGLB,
 }
@@ -446,9 +446,10 @@ class LTXVPreprocess:
     CATEGORY = "image"

     def preprocess(self, image, img_compression):
-        output_images = []
-        for i in range(image.shape[0]):
-            output_images.append(preprocess(image[i], img_compression))
+        if img_compression > 0:
+            output_images = []
+            for i in range(image.shape[0]):
+                output_images.append(preprocess(image[i], img_compression))
         return (torch.stack(output_images),)

@@ -2,7 +2,6 @@ import numpy as np
 import scipy.ndimage
 import torch
 import comfy.utils
-import node_helpers

 from nodes import MAX_RESOLUTION

@@ -88,7 +87,6 @@ class ImageCompositeMasked:
     CATEGORY = "image"

     def composite(self, destination, source, x, y, resize_source, mask = None):
-        destination, source = node_helpers.image_alpha_fix(destination, source)
         destination = destination.clone().movedim(-1, 1)
         output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1)
         return (output,)
@@ -1,56 +0,0 @@
-# from https://github.com/bebebe666/OptimalSteps
-
-
-import numpy as np
-import torch
-
-def loglinear_interp(t_steps, num_steps):
-    """
-    Performs log-linear interpolation of a given array of decreasing numbers.
-    """
-    xs = np.linspace(0, 1, len(t_steps))
-    ys = np.log(t_steps[::-1])
-
-    new_xs = np.linspace(0, 1, num_steps)
-    new_ys = np.interp(new_xs, xs, ys)
-
-    interped_ys = np.exp(new_ys)[::-1].copy()
-    return interped_ys
-
-
-NOISE_LEVELS = {"FLUX": [0.9968, 0.9886, 0.9819, 0.975, 0.966, 0.9471, 0.9158, 0.8287, 0.5512, 0.2808, 0.001],
-                "Wan":[1.0, 0.997, 0.995, 0.993, 0.991, 0.989, 0.987, 0.985, 0.98, 0.975, 0.973, 0.968, 0.96, 0.946, 0.927, 0.902, 0.864, 0.776, 0.539, 0.208, 0.001],
-                }
-
-class OptimalStepsScheduler:
-    @classmethod
-    def INPUT_TYPES(s):
-        return {"required":
-                    {"model_type": (["FLUX", "Wan"], ),
-                     "steps": ("INT", {"default": 20, "min": 3, "max": 1000}),
-                     "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
-                     }
-               }
-    RETURN_TYPES = ("SIGMAS",)
-    CATEGORY = "sampling/custom_sampling/schedulers"
-
-    FUNCTION = "get_sigmas"
-
-    def get_sigmas(self, model_type, steps, denoise):
-        total_steps = steps
-        if denoise < 1.0:
-            if denoise <= 0.0:
-                return (torch.FloatTensor([]),)
-            total_steps = round(steps * denoise)
-
-        sigmas = NOISE_LEVELS[model_type][:]
-        if (steps + 1) != len(sigmas):
-            sigmas = loglinear_interp(sigmas, steps + 1)
-
-        sigmas = sigmas[-(total_steps + 1):]
-        sigmas[-1] = 0
-        return (torch.FloatTensor(sigmas), )
-
-NODE_CLASS_MAPPINGS = {
-    "OptimalStepsScheduler": OptimalStepsScheduler,
-}
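For reference, a minimal sketch (not part of the diff) of what the deleted loglinear_interp helper computes: it interpolates a decreasing noise schedule in log space and maps it back with exp(). The call below shrinks the 11-entry FLUX schedule to 6 sigmas; the exact printed values are illustrative.

    import numpy as np

    def loglinear_interp(t_steps, num_steps):
        # same approach as the deleted helper above
        xs = np.linspace(0, 1, len(t_steps))
        ys = np.log(np.asarray(t_steps)[::-1])
        new_ys = np.interp(np.linspace(0, 1, num_steps), xs, ys)
        return np.exp(new_ys)[::-1].copy()

    flux = [0.9968, 0.9886, 0.9819, 0.975, 0.966, 0.9471, 0.9158, 0.8287, 0.5512, 0.2808, 0.001]
    print(loglinear_interp(flux, 6))  # 6 decreasing values spanning 0.9968 .. 0.001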
@@ -6,7 +6,7 @@ import math
 
 import comfy.utils
 import comfy.model_management
-import node_helpers
 
 class Blend:
     def __init__(self):
@@ -34,7 +34,6 @@ class Blend:
     CATEGORY = "image/postprocessing"
 
     def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str):
-        image1, image2 = node_helpers.image_alpha_fix(image1, image2)
         image2 = image2.to(image1.device)
         if image1.shape != image2.shape:
             image2 = image2.permute(0, 3, 1, 2)
@@ -3,7 +3,6 @@ import node_helpers
 import torch
 import comfy.model_management
 import comfy.utils
-import comfy.latent_formats
 
 
 class WanImageToVideo:
@@ -50,110 +49,6 @@ class WanImageToVideo:
         return (positive, negative, out_latent)
 
 
-class WanFunControlToVideo:
-    @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {"positive": ("CONDITIONING", ),
-                             "negative": ("CONDITIONING", ),
-                             "vae": ("VAE", ),
-                             "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
-                             "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
-                             "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
-                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
-                },
-                "optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
-                             "start_image": ("IMAGE", ),
-                             "control_video": ("IMAGE", ),
-                }}
-
-    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
-    RETURN_NAMES = ("positive", "negative", "latent")
-    FUNCTION = "encode"
-
-    CATEGORY = "conditioning/video_models"
-
-    def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, control_video=None):
-        latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
-        concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
-        concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent)
-        concat_latent = concat_latent.repeat(1, 2, 1, 1, 1)
-
-        if start_image is not None:
-            start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
-            concat_latent_image = vae.encode(start_image[:, :, :, :3])
-            concat_latent[:,16:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
-
-        if control_video is not None:
-            control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
-            concat_latent_image = vae.encode(control_video[:, :, :, :3])
-            concat_latent[:,:16,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
-
-        positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent})
-        negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent})
-
-        if clip_vision_output is not None:
-            positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
-            negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
-
-        out_latent = {}
-        out_latent["samples"] = latent
-        return (positive, negative, out_latent)
-
-class WanFunInpaintToVideo:
-    @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {"positive": ("CONDITIONING", ),
-                             "negative": ("CONDITIONING", ),
-                             "vae": ("VAE", ),
-                             "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
-                             "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
-                             "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
-                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
-                },
-                "optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
-                             "start_image": ("IMAGE", ),
-                             "end_image": ("IMAGE", ),
-                }}
-
-    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
-    RETURN_NAMES = ("positive", "negative", "latent")
-    FUNCTION = "encode"
-
-    CATEGORY = "conditioning/video_models"
-
-    def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_output=None):
-        latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
-        if start_image is not None:
-            start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
-        if end_image is not None:
-            end_image = comfy.utils.common_upscale(end_image[-length:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
-
-        image = torch.ones((length, height, width, 3)) * 0.5
-        mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1]))
-
-        if start_image is not None:
-            image[:start_image.shape[0]] = start_image
-            mask[:, :, :start_image.shape[0] + 3] = 0.0
-
-        if end_image is not None:
-            image[-end_image.shape[0]:] = end_image
-            mask[:, :, -end_image.shape[0]:] = 0.0
-
-        concat_latent_image = vae.encode(image[:, :, :, :3])
-        mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2)
-        positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
-        negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
-
-        if clip_vision_output is not None:
-            positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
-            negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
-
-        out_latent = {}
-        out_latent["samples"] = latent
-        return (positive, negative, out_latent)
-
 NODE_CLASS_MAPPINGS = {
     "WanImageToVideo": WanImageToVideo,
-    "WanFunControlToVideo": WanFunControlToVideo,
-    "WanFunInpaintToVideo": WanFunInpaintToVideo,
 }
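As a quick sanity check on the tensor sizes used by the two removed Wan nodes (illustrative arithmetic only, using the defaults from INPUT_TYPES above):

    length, height, width = 81, 480, 832            # defaults above
    temporal = ((length - 1) // 4) + 1               # 21 latent frames (4x temporal compression)
    spatial_h, spatial_w = height // 8, width // 8   # 60 x 104 (8x spatial compression)
    # so the empty latent allocated by both encode() methods is
    # torch.zeros([batch_size, 16, 21, 60, 104])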
@@ -1,3 +1,3 @@
 # This file is automatically generated by the build process when version is
 # updated in pyproject.toml.
-__version__ = "0.3.28"
+__version__ = "0.3.27"
execution.py (86 changed lines)
@@ -15,7 +15,7 @@ import nodes
 import comfy.model_management
 from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
 from comfy_execution.graph_utils import is_link, GraphBuilder
-from comfy_execution.caching import HierarchicalCache, LRUCache, DependencyAwareCache, CacheKeySetInputSignature, CacheKeySetID
+from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
 from comfy_execution.validation import validate_node_input
 
 class ExecutionResult(Enum):
@@ -59,45 +59,27 @@ class IsChangedCache:
         self.is_changed[node_id] = node["is_changed"]
         return self.is_changed[node_id]
 
 
-class CacheType(Enum):
-    CLASSIC = 0
-    LRU = 1
-    DEPENDENCY_AWARE = 2
-
-
 class CacheSet:
-    def __init__(self, cache_type=None, cache_size=None):
-        if cache_type == CacheType.DEPENDENCY_AWARE:
-            self.init_dependency_aware_cache()
-            logging.info("Disabling intermediate node cache.")
-        elif cache_type == CacheType.LRU:
-            if cache_size is None:
-                cache_size = 0
-            self.init_lru_cache(cache_size)
-            logging.info("Using LRU cache")
-        else:
-            self.init_classic_cache()
+    def __init__(self, lru_size=None):
+        if lru_size is None or lru_size == 0:
+            self.init_classic_cache()
+        else:
+            self.init_lru_cache(lru_size)
         self.all = [self.outputs, self.ui, self.objects]
 
+    # Useful for those with ample RAM/VRAM -- allows experimenting without
+    # blowing away the cache every time
+    def init_lru_cache(self, cache_size):
+        self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
+        self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
+        self.objects = HierarchicalCache(CacheKeySetID)
+
     # Performs like the old cache -- dump data ASAP
     def init_classic_cache(self):
         self.outputs = HierarchicalCache(CacheKeySetInputSignature)
         self.ui = HierarchicalCache(CacheKeySetInputSignature)
         self.objects = HierarchicalCache(CacheKeySetID)
 
-    def init_lru_cache(self, cache_size):
-        self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
-        self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
-        self.objects = HierarchicalCache(CacheKeySetID)
-
-    # only hold cached items while the decendents have not executed
-    def init_dependency_aware_cache(self):
-        self.outputs = DependencyAwareCache(CacheKeySetInputSignature)
-        self.ui = DependencyAwareCache(CacheKeySetInputSignature)
-        self.objects = DependencyAwareCache(CacheKeySetID)
-
     def recursive_debug_dump(self):
         result = {
             "outputs": self.outputs.recursive_debug_dump(),
@@ -111,7 +93,7 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
     missing_keys = {}
     for x in inputs:
         input_data = inputs[x]
-        input_type, input_category, input_info = get_input_info(class_def, x, valid_inputs)
+        _, input_category, input_info = get_input_info(class_def, x, valid_inputs)
         def mark_missing():
             missing_keys[x] = True
             input_data_all[x] = (None,)
@@ -432,14 +414,13 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
     return (ExecutionResult.SUCCESS, None, None)
 
 class PromptExecutor:
-    def __init__(self, server, cache_type=False, cache_size=None):
-        self.cache_size = cache_size
-        self.cache_type = cache_type
+    def __init__(self, server, lru_size=None):
+        self.lru_size = lru_size
         self.server = server
         self.reset()
 
     def reset(self):
-        self.caches = CacheSet(cache_type=self.cache_type, cache_size=self.cache_size)
+        self.caches = CacheSet(self.lru_size)
         self.status_messages = []
         self.success = True
 
@@ -574,7 +555,7 @@ def validate_inputs(prompt, item, validated):
     received_types = {}
 
     for x in valid_inputs:
-        type_input, input_category, extra_info = get_input_info(obj_class, x, class_inputs)
+        input_type, input_category, extra_info = get_input_info(obj_class, x, class_inputs)
         assert extra_info is not None
         if x not in inputs:
             if input_category == "required":
@@ -590,7 +571,7 @@ def validate_inputs(prompt, item, validated):
             continue
 
         val = inputs[x]
-        info = (type_input, extra_info)
+        info = (input_type, extra_info)
        if isinstance(val, list):
            if len(val) != 2:
                error = {
@@ -611,8 +592,8 @@ def validate_inputs(prompt, item, validated):
            r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
            received_type = r[val[1]]
            received_types[x] = received_type
-           if 'input_types' not in validate_function_inputs and not validate_node_input(received_type, type_input):
-               details = f"{x}, received_type({received_type}) mismatch input_type({type_input})"
+           if 'input_types' not in validate_function_inputs and not validate_node_input(received_type, input_type):
+               details = f"{x}, received_type({received_type}) mismatch input_type({input_type})"
                error = {
                    "type": "return_type_mismatch",
                    "message": "Return type mismatch between linked nodes",
@@ -660,22 +641,22 @@ def validate_inputs(prompt, item, validated):
                    val = val["__value__"]
                    inputs[x] = val
 
-               if type_input == "INT":
+               if input_type == "INT":
                    val = int(val)
                    inputs[x] = val
-               if type_input == "FLOAT":
+               if input_type == "FLOAT":
                    val = float(val)
                    inputs[x] = val
-               if type_input == "STRING":
+               if input_type == "STRING":
                    val = str(val)
                    inputs[x] = val
-               if type_input == "BOOLEAN":
+               if input_type == "BOOLEAN":
                    val = bool(val)
                    inputs[x] = val
            except Exception as ex:
                error = {
                    "type": "invalid_input_type",
-                   "message": f"Failed to convert an input value to a {type_input} value",
+                   "message": f"Failed to convert an input value to a {input_type} value",
                    "details": f"{x}, {val}, {ex}",
                    "extra_info": {
                        "input_name": x,
@@ -715,18 +696,19 @@ def validate_inputs(prompt, item, validated):
                errors.append(error)
                continue
 
-           if isinstance(type_input, list):
-               if val not in type_input:
+           if isinstance(input_type, list):
+               combo_options = input_type
+               if val not in combo_options:
                    input_config = info
                    list_info = ""
 
                    # Don't send back gigantic lists like if they're lots of
                    # scanned model filepaths
-                   if len(type_input) > 20:
-                       list_info = f"(list of length {len(type_input)})"
+                   if len(combo_options) > 20:
+                       list_info = f"(list of length {len(combo_options)})"
                        input_config = None
                    else:
-                       list_info = str(type_input)
+                       list_info = str(combo_options)
 
                    error = {
                        "type": "value_not_in_list",
@@ -794,7 +776,7 @@ def validate_prompt(prompt):
            "details": f"Node ID '#{x}'",
            "extra_info": {}
        }
-       return (False, error, [], {})
+       return (False, error, [], [])
 
    class_type = prompt[x]['class_type']
    class_ = nodes.NODE_CLASS_MAPPINGS.get(class_type, None)
@@ -805,7 +787,7 @@ def validate_prompt(prompt):
            "details": f"Node ID '#{x}'",
            "extra_info": {}
        }
-       return (False, error, [], {})
+       return (False, error, [], [])
 
    if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True:
        outputs.add(x)
@@ -817,7 +799,7 @@ def validate_prompt(prompt):
            "details": "",
            "extra_info": {}
        }
-       return (False, error, [], {})
+       return (False, error, [], [])
 
    good_outputs = set()
    errors = []
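To make the difference between the two constructor signatures concrete, here is a minimal sketch (not part of the diff) of how CacheSet is instantiated on each side:

    # v0.3.28 side (left): strategy selected via the CacheType enum removed above
    caches = CacheSet(cache_type=CacheType.LRU, cache_size=10)  # keep up to 10 node results
    caches = CacheSet(cache_type=CacheType.DEPENDENCY_AWARE)    # drop results once dependents have run
    caches = CacheSet(cache_type=CacheType.CLASSIC)             # old-style aggressive cache

    # branch side (right): only classic vs. LRU
    caches = CacheSet(lru_size=10)   # LRU cache with 10 entries
    caches = CacheSet()              # classic cache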
@@ -85,7 +85,6 @@ cache_helper = CacheHelper()
 
 extension_mimetypes_cache = {
     "webp" : "image",
-    "fbx" : "model",
 }
 
 def map_legacy(folder_name: str) -> str:
@@ -141,14 +140,11 @@ def get_directory_by_type(type_name: str) -> str | None:
         return get_input_directory()
     return None
 
-def filter_files_content_types(files: list[str], content_types: Literal["image", "video", "audio", "model"]) -> list[str]:
+def filter_files_content_types(files: list[str], content_types: Literal["image", "video", "audio"]) -> list[str]:
     """
     Example:
         files = os.listdir(folder_paths.get_input_directory())
-        videos = filter_files_content_types(files, ["video"])
-
-    Note:
-    - 'model' in MIME context refers to 3D models, not files containing trained weights and parameters
+        filter_files_content_types(files, ["image", "audio", "video"])
     """
     global extension_mimetypes_cache
     result = []
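A minimal usage sketch of the function whose signature changes above, mirroring its docstring (variable names are illustrative):

    import os
    import folder_paths

    files = os.listdir(folder_paths.get_input_directory())
    images = folder_paths.filter_files_content_types(files, ["image"])
    # the v0.3.28 side also accepts "model" for 3D assets
    # (e.g. the "fbx": "model" entry added to the cache above)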
main.py (10 changed lines)
@@ -10,7 +10,6 @@ from app.logger import setup_logger
 import itertools
 import utils.extra_config
 import logging
-import sys
 
 if __name__ == "__main__":
     #NOTE: These do not do anything on core ComfyUI which should already have no communication with the internet, they are for custom nodes.
@@ -157,13 +156,7 @@ def cuda_malloc_warning():
 
 def prompt_worker(q, server_instance):
     current_time: float = 0.0
-    cache_type = execution.CacheType.CLASSIC
-    if args.cache_lru > 0:
-        cache_type = execution.CacheType.LRU
-    elif args.cache_none:
-        cache_type = execution.CacheType.DEPENDENCY_AWARE
-
-    e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_size=args.cache_lru)
+    e = execution.PromptExecutor(server_instance, lru_size=args.cache_lru)
     last_gc_collect = 0
     need_gc = False
     gc_collect_interval = 10.0
@@ -302,7 +295,6 @@ def start_comfyui(asyncio_loop=None):
 
 if __name__ == "__main__":
     # Running directly, just start ComfyUI.
-    logging.info("Python version: {}".format(sys.version))
     logging.info("ComfyUI version: {}".format(comfyui_version.__version__))
 
     event_loop, _, start_all_func = start_comfyui()
@@ -44,11 +44,3 @@ def string_to_torch_dtype(string):
         return torch.float16
     if string == "bf16":
         return torch.bfloat16
-
-def image_alpha_fix(destination, source):
-    if destination.shape[-1] < source.shape[-1]:
-        source = source[...,:destination.shape[-1]]
-    elif destination.shape[-1] > source.shape[-1]:
-        destination = torch.nn.functional.pad(destination, (0, 1))
-        destination[..., -1] = 1.0
-    return destination, source
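For context, a small sketch (shapes are hypothetical) of how the removed image_alpha_fix helper is used by ImageCompositeMasked and Blend on the v0.3.28 side:

    import torch
    import node_helpers

    destination = torch.rand(1, 64, 64, 3)   # RGB batch
    source = torch.rand(1, 64, 64, 4)        # RGBA batch
    destination, source = node_helpers.image_alpha_fix(destination, source)
    # source is truncated to destination's 3 channels, so the subsequent
    # composite()/blend call sees matching channel counts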
nodes.py (22 changed lines)
@@ -786,8 +786,6 @@ class ControlNetLoader:
     def load_controlnet(self, control_net_name):
         controlnet_path = folder_paths.get_full_path_or_raise("controlnet", control_net_name)
         controlnet = comfy.controlnet.load_controlnet(controlnet_path)
-        if controlnet is None:
-            raise RuntimeError("ERROR: controlnet file is invalid and does not contain a valid controlnet model.")
         return (controlnet,)
 
 class DiffControlNetLoader:
@@ -1008,8 +1006,6 @@ class CLIPVisionLoader:
     def load_clip(self, clip_name):
         clip_path = folder_paths.get_full_path_or_raise("clip_vision", clip_name)
         clip_vision = comfy.clip_vision.load(clip_path)
-        if clip_vision is None:
-            raise RuntimeError("ERROR: clip vision file is invalid and does not contain a valid vision model.")
         return (clip_vision,)
 
 class CLIPVisionEncode:
@@ -1654,7 +1650,6 @@ class LoadImage:
     def INPUT_TYPES(s):
         input_dir = folder_paths.get_input_directory()
         files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
-        files = folder_paths.filter_files_content_types(files, ["image"])
         return {"required":
                     {"image": (sorted(files), {"image_upload": True})},
                 }
@@ -1693,9 +1688,6 @@ class LoadImage:
         if 'A' in i.getbands():
             mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
             mask = 1. - torch.from_numpy(mask)
-        elif i.mode == 'P' and 'transparency' in i.info:
-            mask = np.array(i.convert('RGBA').getchannel('A')).astype(np.float32) / 255.0
-            mask = 1. - torch.from_numpy(mask)
         else:
             mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
         output_images.append(image)
@@ -2131,25 +2123,21 @@ def get_module_name(module_path: str) -> str:
 
 
 def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes") -> bool:
-    module_name = get_module_name(module_path)
+    module_name = os.path.basename(module_path)
     if os.path.isfile(module_path):
         sp = os.path.splitext(module_path)
         module_name = sp[0]
-        sys_module_name = module_name
-    elif os.path.isdir(module_path):
-        sys_module_name = module_path.replace(".", "_x_")
-
     try:
         logging.debug("Trying to load custom node {}".format(module_path))
         if os.path.isfile(module_path):
-            module_spec = importlib.util.spec_from_file_location(sys_module_name, module_path)
+            module_spec = importlib.util.spec_from_file_location(module_name, module_path)
             module_dir = os.path.split(module_path)[0]
         else:
-            module_spec = importlib.util.spec_from_file_location(sys_module_name, os.path.join(module_path, "__init__.py"))
+            module_spec = importlib.util.spec_from_file_location(module_name, os.path.join(module_path, "__init__.py"))
             module_dir = module_path
 
         module = importlib.util.module_from_spec(module_spec)
-        sys.modules[sys_module_name] = module
+        sys.modules[module_name] = module
         module_spec.loader.exec_module(module)
 
         LOADED_MODULE_DIRS[module_name] = os.path.abspath(module_dir)
@@ -2279,8 +2267,6 @@ def init_builtin_extra_nodes():
         "nodes_lotus.py",
         "nodes_hunyuan3d.py",
         "nodes_primitive.py",
-        "nodes_cfg.py",
-        "nodes_optimalsteps.py"
     ]
 
     import_failed = []
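The sys_module_name lines removed above exist so that custom node directories containing dots do not produce ambiguous sys.modules keys; a small, self-contained illustration (the path is hypothetical):

    import sys
    import types

    module_path = "custom_nodes/my.node.pack"           # hypothetical directory name
    sys_module_name = module_path.replace(".", "_x_")   # "custom_nodes/my_x_node_x_pack"
    module = types.ModuleType(sys_module_name)          # stand-in for the imported module
    sys.modules[sys_module_name] = module               # dots no longer look like package separators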
@@ -1,6 +1,6 @@
 [project]
 name = "ComfyUI"
-version = "0.3.28"
+version = "0.3.27"
 readme = "README.md"
 license = { file = "LICENSE" }
 requires-python = ">=3.9"
@@ -1,4 +1,4 @@
-comfyui-frontend-package==1.15.13
+comfyui-frontend-package==1.14.5
 torch
 torchsde
 torchvision
server.py (10 changed lines)
@@ -48,7 +48,7 @@ async def send_socket_catch_exception(function, message):
 @web.middleware
 async def cache_control(request: web.Request, handler):
     response: web.Response = await handler(request)
-    if request.path.endswith('.js') or request.path.endswith('.css') or request.path.endswith('index.json'):
+    if request.path.endswith('.js') or request.path.endswith('.css'):
         response.headers.setdefault('Cache-Control', 'no-cache')
     return response
 
@@ -657,13 +657,7 @@ class PromptServer():
                 logging.warning("invalid prompt: {}".format(valid[1]))
                 return web.json_response({"error": valid[1], "node_errors": valid[3]}, status=400)
             else:
-                error = {
-                    "type": "no_prompt",
-                    "message": "No prompt provided",
-                    "details": "No prompt provided",
-                    "extra_info": {}
-                }
-                return web.json_response({"error": error, "node_errors": {}}, status=400)
+                return web.json_response({"error": "no prompt", "node_errors": []}, status=400)
 
         @routes.post("/queue")
         async def post_queue(request):
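For clarity, the two /prompt error responses above differ in shape; sketched as Python dicts (illustrative only):

    # v0.3.28 side: structured error object, node_errors is a dict
    new_style = {
        "error": {
            "type": "no_prompt",
            "message": "No prompt provided",
            "details": "No prompt provided",
            "extra_info": {},
        },
        "node_errors": {},
    }

    # branch side: plain string error, node_errors is a list
    old_style = {"error": "no prompt", "node_errors": []}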
@@ -1,17 +1,14 @@
 import pytest
 import os
 import tempfile
-from folder_paths import filter_files_content_types, extension_mimetypes_cache
-from unittest.mock import patch
+from folder_paths import filter_files_content_types
 
 
 @pytest.fixture(scope="module")
 def file_extensions():
     return {
         'image': ['gif', 'heif', 'ico', 'jpeg', 'jpg', 'png', 'pnm', 'ppm', 'svg', 'tiff', 'webp', 'xbm', 'xpm'],
         'audio': ['aif', 'aifc', 'aiff', 'au', 'flac', 'm4a', 'mp2', 'mp3', 'ogg', 'snd', 'wav'],
-        'video': ['avi', 'm2v', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ogv', 'qt', 'webm', 'wmv'],
-        'model': ['gltf', 'glb', 'obj', 'fbx', 'stl']
+        'video': ['avi', 'm2v', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ogv', 'qt', 'webm', 'wmv']
     }
 
 
@@ -25,18 +22,7 @@ def mock_dir(file_extensions):
     yield directory
 
 
-@pytest.fixture
-def patched_mimetype_cache(file_extensions):
-    # Mock model file extensions since they may not be in the test-runner system's mimetype cache
-    new_cache = extension_mimetypes_cache.copy()
-    for extension in file_extensions["model"]:
-        new_cache[extension] = "model"
-
-    with patch("folder_paths.extension_mimetypes_cache", new_cache):
-        yield
-
-
-def test_categorizes_all_correctly(mock_dir, file_extensions, patched_mimetype_cache):
+def test_categorizes_all_correctly(mock_dir, file_extensions):
     files = os.listdir(mock_dir)
     for content_type, extensions in file_extensions.items():
         filtered_files = filter_files_content_types(files, [content_type])
@@ -44,7 +30,7 @@ def test_categorizes_all_correctly(mock_dir, file_extensions, patched_mimetype_c
             assert f"sample_{content_type}.{extension}" in filtered_files
 
 
-def test_categorizes_all_uniquely(mock_dir, file_extensions, patched_mimetype_cache):
+def test_categorizes_all_uniquely(mock_dir, file_extensions):
     files = os.listdir(mock_dir)
     for content_type, extensions in file_extensions.items():
         filtered_files = filter_files_content_types(files, [content_type])