Compare commits

...

4 Commits

Author SHA1 Message Date
Chenlei Hu
0d68603dfe Update frontend to 1.14.6
Cherry-pick the fix: https://github.com/Comfy-Org/ComfyUI_frontend/pull/3252
2025-03-27 14:02:45 -04:00
comfyanonymous
0a1f8869c9 Add WanFunInpaintToVideo node for the Wan fun inpaint models. 2025-03-27 11:13:27 -04:00
comfyanonymous
3661c833bc Support the WAN 2.1 fun control models.
Use the new WanFunControlToVideo node.
2025-03-26 19:54:54 -04:00
comfyanonymous
84fdaf7b0e Add CFGZeroStar node.
Works on all models that use a negative prompt but is meant for rectified
flow models.
2025-03-26 05:09:52 -04:00
6 changed files with 181 additions and 10 deletions

View File

@@ -992,7 +992,8 @@ class WAN21(BaseModel):
def concat_cond(self, **kwargs):
noise = kwargs.get("noise", None)
if self.diffusion_model.patch_embedding.weight.shape[1] == noise.shape[1]:
extra_channels = self.diffusion_model.patch_embedding.weight.shape[1] - noise.shape[1]
if extra_channels == 0:
return None
image = kwargs.get("concat_latent_image", None)
@@ -1000,23 +1001,30 @@ class WAN21(BaseModel):
if image is None:
image = torch.zeros_like(noise)
shape_image = list(noise.shape)
shape_image[1] = extra_channels
image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device)
else:
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
for i in range(0, image.shape[1], 16):
image[:, i: i + 16] = self.process_latent_in(image[:, i: i + 16])
image = utils.resize_to_batch_size(image, noise.shape[0])
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
image = self.process_latent_in(image)
image = utils.resize_to_batch_size(image, noise.shape[0])
if not self.image_to_video:
if not self.image_to_video or extra_channels == image.shape[1]:
return image
mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
if mask is None:
mask = torch.zeros_like(noise)[:, :4]
else:
mask = 1.0 - torch.mean(mask, dim=1, keepdim=True)
if mask.shape[1] != 4:
mask = torch.mean(mask, dim=1, keepdim=True)
mask = 1.0 - mask
mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
if mask.shape[-3] < noise.shape[-3]:
mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0)
mask = mask.repeat(1, 4, 1, 1, 1)
if mask.shape[1] == 1:
mask = mask.repeat(1, 4, 1, 1, 1)
mask = utils.resize_to_batch_size(mask, noise.shape[0])
return torch.cat((mask, image), dim=1)

View File

@@ -969,12 +969,24 @@ class WAN21_I2V(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
"model_type": "i2v",
"in_dim": 36,
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.WAN21(self, image_to_video=True, device=device)
return out
class WAN21_FunControl2V(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
"model_type": "i2v",
"in_dim": 48,
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.WAN21(self, image_to_video=False, device=device)
return out
class Hunyuan3Dv2(supported_models_base.BASE):
unet_config = {
"image_model": "hunyuan3d2",
@@ -1013,6 +1025,6 @@ class Hunyuan3Dv2mini(Hunyuan3Dv2):
latent_format = latent_formats.Hunyuan3Dv2mini
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, Hunyuan3Dv2mini, Hunyuan3Dv2]
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, Hunyuan3Dv2mini, Hunyuan3Dv2]
models += [SVD_img2vid]

45
comfy_extras/nodes_cfg.py Normal file
View File

@@ -0,0 +1,45 @@
import torch
# https://github.com/WeichenFan/CFG-Zero-star
def optimized_scale(positive, negative):
positive_flat = positive.reshape(positive.shape[0], -1)
negative_flat = negative.reshape(negative.shape[0], -1)
# Calculate dot production
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
# Squared norm of uncondition
squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
# st_star = v_cond^T * v_uncond / ||v_uncond||^2
st_star = dot_product / squared_norm
return st_star.reshape([positive.shape[0]] + [1] * (positive.ndim - 1))
class CFGZeroStar:
@classmethod
def INPUT_TYPES(s):
return {"required": {"model": ("MODEL",),
}}
RETURN_TYPES = ("MODEL",)
RETURN_NAMES = ("patched_model",)
FUNCTION = "patch"
CATEGORY = "advanced/guidance"
def patch(self, model):
m = model.clone()
def cfg_zero_star(args):
guidance_scale = args['cond_scale']
x = args['input']
cond_p = args['cond_denoised']
uncond_p = args['uncond_denoised']
out = args["denoised"]
alpha = optimized_scale(x - cond_p, x - uncond_p)
return out + uncond_p * (alpha - 1.0) + guidance_scale * uncond_p * (1.0 - alpha)
m.set_model_sampler_post_cfg_function(cfg_zero_star)
return (m, )
NODE_CLASS_MAPPINGS = {
"CFGZeroStar": CFGZeroStar
}

View File

@@ -3,6 +3,7 @@ import node_helpers
import torch
import comfy.model_management
import comfy.utils
import comfy.latent_formats
class WanImageToVideo:
@@ -49,6 +50,110 @@ class WanImageToVideo:
return (positive, negative, out_latent)
class WanFunControlToVideo:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
},
"optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
"start_image": ("IMAGE", ),
"control_video": ("IMAGE", ),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, control_video=None):
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent)
concat_latent = concat_latent.repeat(1, 2, 1, 1, 1)
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
concat_latent_image = vae.encode(start_image[:, :, :, :3])
concat_latent[:,16:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
if control_video is not None:
control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
concat_latent_image = vae.encode(control_video[:, :, :, :3])
concat_latent[:,:16,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent})
if clip_vision_output is not None:
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
out_latent = {}
out_latent["samples"] = latent
return (positive, negative, out_latent)
class WanFunInpaintToVideo:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
},
"optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
"start_image": ("IMAGE", ),
"end_image": ("IMAGE", ),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_output=None):
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
if end_image is not None:
end_image = comfy.utils.common_upscale(end_image[-length:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
image = torch.ones((length, height, width, 3)) * 0.5
mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1]))
if start_image is not None:
image[:start_image.shape[0]] = start_image
mask[:, :, :start_image.shape[0] + 3] = 0.0
if end_image is not None:
image[-end_image.shape[0]:] = end_image
mask[:, :, -end_image.shape[0]:] = 0.0
concat_latent_image = vae.encode(image[:, :, :, :3])
mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2)
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
if clip_vision_output is not None:
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
out_latent = {}
out_latent["samples"] = latent
return (positive, negative, out_latent)
NODE_CLASS_MAPPINGS = {
"WanImageToVideo": WanImageToVideo,
"WanFunControlToVideo": WanFunControlToVideo,
"WanFunInpaintToVideo": WanFunInpaintToVideo,
}

View File

@@ -2267,6 +2267,7 @@ def init_builtin_extra_nodes():
"nodes_lotus.py",
"nodes_hunyuan3d.py",
"nodes_primitive.py",
"nodes_cfg.py",
]
import_failed = []

View File

@@ -1,4 +1,4 @@
comfyui-frontend-package==1.14.5
comfyui-frontend-package==1.14.6
torch
torchsde
torchvision