Compare commits
8 Commits
v0.3.27 ... huchenlei-

| Author | SHA1 | Date |
|---|---|---|
| | 0d68603dfe | |
| | 0a1f8869c9 | |
| | 3661c833bc | |
| | 84fdaf7b0e | |
| | 8edc1f44c1 | |
| | eade1551bb | |
| | 581a9991ff | |
| | e471c726e5 | |
```diff
@@ -69,6 +69,8 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
 - [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
 - [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/)
 - [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
+- 3D Models
+   - [Hunyuan3D 2.0](https://docs.comfy.org/tutorials/3d/hunyuan3D-2)
 - [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
 - Asynchronous Queue system
 - Many optimizations: Only re-executes the parts of the workflow that changes between executions.
```
```diff
@@ -471,7 +471,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
 def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
     if skip_reshape:
         b, _, _, dim_head = q.shape
-        tensor_layout="HND"
+        tensor_layout = "HND"
     else:
         b, _, dim_head = q.shape
         dim_head //= heads
```
```diff
@@ -479,7 +479,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
             lambda t: t.view(b, -1, heads, dim_head),
             (q, k, v),
         )
-        tensor_layout="NHD"
+        tensor_layout = "NHD"

     if mask is not None:
         # add a batch dimension if there isn't already one
```
```diff
@@ -489,7 +489,17 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
         if mask.ndim == 3:
             mask = mask.unsqueeze(1)

-    out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
+    try:
+        out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
+    except Exception as e:
+        logging.error("Error running sage attention: {}, using pytorch attention instead.".format(e))
+        if tensor_layout == "NHD":
+            q, k, v = map(
+                lambda t: t.transpose(1, 2),
+                (q, k, v),
+            )
+        return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape)

     if tensor_layout == "HND":
         if not skip_output_reshape:
             out = (
```
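Note (not part of the diff): the fallback added above has to undo sageattn's "NHD" layout before calling attention_pytorch, which expects the heads axis before the sequence axis when skip_reshape=True. A minimal shape sketch of the two layouts, with made-up sizes:

```python
import torch

# "HND" packs heads before the sequence axis; "NHD" packs them after.
b, heads, seq, dim_head = 2, 8, 77, 64
q_hnd = torch.randn(b, heads, seq, dim_head)  # layout attention_pytorch expects with skip_reshape=True
q_nhd = q_hnd.transpose(1, 2)                 # layout produced by the skip_reshape=False path
assert q_nhd.shape == (b, seq, heads, dim_head)
# transpose(1, 2) again restores "HND", which is exactly what the except-branch does
assert q_nhd.transpose(1, 2).shape == q_hnd.shape
```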
```diff
@@ -992,7 +992,8 @@ class WAN21(BaseModel):

     def concat_cond(self, **kwargs):
         noise = kwargs.get("noise", None)
-        if self.diffusion_model.patch_embedding.weight.shape[1] == noise.shape[1]:
+        extra_channels = self.diffusion_model.patch_embedding.weight.shape[1] - noise.shape[1]
+        if extra_channels == 0:
             return None

         image = kwargs.get("concat_latent_image", None)
```
```diff
@@ -1000,23 +1001,30 @@ class WAN21(BaseModel):

         if image is None:
             image = torch.zeros_like(noise)
+            shape_image = list(noise.shape)
+            shape_image[1] = extra_channels
+            image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device)
+        else:
+            image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
+            for i in range(0, image.shape[1], 16):
+                image[:, i: i + 16] = self.process_latent_in(image[:, i: i + 16])
+            image = utils.resize_to_batch_size(image, noise.shape[0])

-        image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
-        image = self.process_latent_in(image)
-        image = utils.resize_to_batch_size(image, noise.shape[0])
-
-        if not self.image_to_video:
+        if not self.image_to_video or extra_channels == image.shape[1]:
             return image

         mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
         if mask is None:
             mask = torch.zeros_like(noise)[:, :4]
         else:
-            mask = 1.0 - torch.mean(mask, dim=1, keepdim=True)
+            if mask.shape[1] != 4:
+                mask = torch.mean(mask, dim=1, keepdim=True)
+            mask = 1.0 - mask
             mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
             if mask.shape[-3] < noise.shape[-3]:
                 mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0)
-            mask = mask.repeat(1, 4, 1, 1, 1)
+            if mask.shape[1] == 1:
+                mask = mask.repeat(1, 4, 1, 1, 1)
             mask = utils.resize_to_batch_size(mask, noise.shape[0])

         return torch.cat((mask, image), dim=1)
```
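Note (not part of the diff): the new for-loop normalizes each 16-channel group separately, because the concatenated conditioning may now stack several 16-channel Wan latents rather than a single one. A sketch under that assumption, with process_latent_in replaced by a placeholder:

```python
import torch

def process_latent_in(latent):
    # placeholder for the model's latent normalization; the real method
    # applies the Wan latent scale/shift factors, not this constant
    return latent * 0.5

# e.g. a control-video latent stacked on a start-image latent: 2 x 16 channels
image = torch.randn(1, 32, 5, 60, 104)
for i in range(0, image.shape[1], 16):
    image[:, i: i + 16] = process_latent_in(image[:, i: i + 16])
```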
```diff
@@ -46,6 +46,32 @@ cpu_state = CPUState.GPU

 total_vram = 0

+def get_supported_float8_types():
+    float8_types = []
+    try:
+        float8_types.append(torch.float8_e4m3fn)
+    except:
+        pass
+    try:
+        float8_types.append(torch.float8_e4m3fnuz)
+    except:
+        pass
+    try:
+        float8_types.append(torch.float8_e5m2)
+    except:
+        pass
+    try:
+        float8_types.append(torch.float8_e5m2fnuz)
+    except:
+        pass
+    try:
+        float8_types.append(torch.float8_e8m0fnu)
+    except:
+        pass
+    return float8_types
+
+FLOAT8_TYPES = get_supported_float8_types()
+
 xpu_available = False
 torch_version = ""
 try:
```
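Note (not part of the diff): the try/except ladder exists because older torch builds lack some of these dtype attributes (torch.float8_e8m0fnu in particular is recent), so each probe is guarded and missing dtypes are skipped. An equivalent, more compact formulation, shown only as an illustration:

```python
import torch

float8_types = []
for name in ("float8_e4m3fn", "float8_e4m3fnuz", "float8_e5m2",
             "float8_e5m2fnuz", "float8_e8m0fnu"):
    dtype = getattr(torch, name, None)  # None on builds without this dtype
    if dtype is not None:
        float8_types.append(dtype)
```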
```diff
@@ -701,11 +727,8 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
         return torch.float8_e5m2

     fp8_dtype = None
-    try:
-        if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
-            fp8_dtype = weight_dtype
-    except:
-        pass
+    if weight_dtype in FLOAT8_TYPES:
+        fp8_dtype = weight_dtype

     if fp8_dtype is not None:
         if supports_fp8_compute(device): #if fp8 compute is supported the casting is most likely not expensive
```
```diff
@@ -969,12 +969,24 @@ class WAN21_I2V(WAN21_T2V):
     unet_config = {
         "image_model": "wan2.1",
         "model_type": "i2v",
+        "in_dim": 36,
     }

     def get_model(self, state_dict, prefix="", device=None):
         out = model_base.WAN21(self, image_to_video=True, device=device)
         return out

+class WAN21_FunControl2V(WAN21_T2V):
+    unet_config = {
+        "image_model": "wan2.1",
+        "model_type": "i2v",
+        "in_dim": 48,
+    }
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.WAN21(self, image_to_video=False, device=device)
+        return out
+
 class Hunyuan3Dv2(supported_models_base.BASE):
     unet_config = {
         "image_model": "hunyuan3d2",
```
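Note (not part of the diff): the in_dim values appear to follow from the concat channel layout elsewhere in this compare (4 mask + 16 image channels for I2V, two stacked 16-channel latents for FunControl). A back-of-the-envelope check under that reading:

```python
noise_channels = 16                 # Wan 2.1 latent channels
i2v_concat = 4 + 16                 # mask (4) + start-image latent (16), see WAN21.concat_cond
fun_control_concat = 16 + 16        # control-video latent + start-image latent, see WanFunControlToVideo
assert noise_channels + i2v_concat == 36          # WAN21_I2V "in_dim"
assert noise_channels + fun_control_concat == 48  # WAN21_FunControl2V "in_dim"
```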
```diff
@@ -1013,6 +1025,6 @@ class Hunyuan3Dv2mini(Hunyuan3Dv2):

     latent_format = latent_formats.Hunyuan3Dv2mini

-models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, Hunyuan3Dv2mini, Hunyuan3Dv2]
+models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, Hunyuan3Dv2mini, Hunyuan3Dv2]

 models += [SVD_img2vid]
```
comfy_extras/nodes_cfg.py (new file, +45)

```diff
@@ -0,0 +1,45 @@
+import torch
+
+# https://github.com/WeichenFan/CFG-Zero-star
+def optimized_scale(positive, negative):
+    positive_flat = positive.reshape(positive.shape[0], -1)
+    negative_flat = negative.reshape(negative.shape[0], -1)
+
+    # Calculate dot production
+    dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
+
+    # Squared norm of uncondition
+    squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
+
+    # st_star = v_cond^T * v_uncond / ||v_uncond||^2
+    st_star = dot_product / squared_norm
+
+    return st_star.reshape([positive.shape[0]] + [1] * (positive.ndim - 1))
+
+class CFGZeroStar:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"model": ("MODEL",),
+                             }}
+    RETURN_TYPES = ("MODEL",)
+    RETURN_NAMES = ("patched_model",)
+    FUNCTION = "patch"
+    CATEGORY = "advanced/guidance"
+
+    def patch(self, model):
+        m = model.clone()
+        def cfg_zero_star(args):
+            guidance_scale = args['cond_scale']
+            x = args['input']
+            cond_p = args['cond_denoised']
+            uncond_p = args['uncond_denoised']
+            out = args["denoised"]
+            alpha = optimized_scale(x - cond_p, x - uncond_p)
+
+            return out + uncond_p * (alpha - 1.0) + guidance_scale * uncond_p * (1.0 - alpha)
+        m.set_model_sampler_post_cfg_function(cfg_zero_star)
+        return (m, )
+
+NODE_CLASS_MAPPINGS = {
+    "CFGZeroStar": CFGZeroStar
+}
```
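Note (not part of the diff): optimized_scale computes the per-batch-element least-squares projection coefficient st_star = <v_cond, v_uncond> / ||v_uncond||^2, matching the comment in the file. A tiny numeric check of that identity:

```python
import torch

neg = torch.randn(1, 4, 8, 8)
pos = 2.0 * neg  # exactly parallel predictions, scaled by 2

pos_flat = pos.reshape(pos.shape[0], -1)
neg_flat = neg.reshape(neg.shape[0], -1)
st_star = torch.sum(pos_flat * neg_flat, dim=1) / (torch.sum(neg_flat ** 2, dim=1) + 1e-8)
print(st_star)  # ~2.0: for parallel vectors the coefficient is the scale ratio
```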
```diff
@@ -244,6 +244,30 @@ class ModelMergeCosmos14B(comfy_extras.nodes_model_merging.ModelMergeBlocks):

         return {"required": arg_dict}

+class ModelMergeWAN2_1(comfy_extras.nodes_model_merging.ModelMergeBlocks):
+    CATEGORY = "advanced/model_merging/model_specific"
+    DESCRIPTION = "1.3B model has 30 blocks, 14B model has 40 blocks. Image to video model has the extra img_emb."
+
+    @classmethod
+    def INPUT_TYPES(s):
+        arg_dict = { "model1": ("MODEL",),
+                              "model2": ("MODEL",)}
+
+        argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
+
+        arg_dict["patch_embedding."] = argument
+        arg_dict["time_embedding."] = argument
+        arg_dict["time_projection."] = argument
+        arg_dict["text_embedding."] = argument
+        arg_dict["img_emb."] = argument
+
+        for i in range(40):
+            arg_dict["blocks.{}.".format(i)] = argument
+
+        arg_dict["head."] = argument
+
+        return {"required": arg_dict}
+
 NODE_CLASS_MAPPINGS = {
     "ModelMergeSD1": ModelMergeSD1,
     "ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks
```
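Note (not part of the diff): as with the other ModelMerge* nodes, the arg_dict built above maps state-dict key prefixes to per-block blend ratios; the actual merge lives in ModelMergeBlocks. A sketch of the idea only, under that assumption:

```python
def lerp_weight(key, w1, w2, ratios, default=1.0):
    # pick the ratio of the longest matching prefix, then blend model1/model2
    ratio = default
    best = -1
    for prefix, r in ratios.items():
        if key.startswith(prefix) and len(prefix) > best:
            best, ratio = len(prefix), r
    return w1 * ratio + w2 * (1.0 - ratio)
```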
```diff
@@ -256,4 +280,5 @@ NODE_CLASS_MAPPINGS = {
     "ModelMergeLTXV": ModelMergeLTXV,
     "ModelMergeCosmos7B": ModelMergeCosmos7B,
     "ModelMergeCosmos14B": ModelMergeCosmos14B,
+    "ModelMergeWAN2_1": ModelMergeWAN2_1,
 }
```
```diff
@@ -3,6 +3,7 @@ import node_helpers
 import torch
 import comfy.model_management
 import comfy.utils
+import comfy.latent_formats


 class WanImageToVideo:
```
```diff
@@ -49,6 +50,110 @@ class WanImageToVideo:
         return (positive, negative, out_latent)


+class WanFunControlToVideo:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"positive": ("CONDITIONING", ),
+                             "negative": ("CONDITIONING", ),
+                             "vae": ("VAE", ),
+                             "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
+                             "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
+                             "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
+                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
+                },
+                "optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
+                             "start_image": ("IMAGE", ),
+                             "control_video": ("IMAGE", ),
+                }}
+
+    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
+    RETURN_NAMES = ("positive", "negative", "latent")
+    FUNCTION = "encode"
+
+    CATEGORY = "conditioning/video_models"
+
+    def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, control_video=None):
+        latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
+        concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
+        concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent)
+        concat_latent = concat_latent.repeat(1, 2, 1, 1, 1)
+
+        if start_image is not None:
+            start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
+            concat_latent_image = vae.encode(start_image[:, :, :, :3])
+            concat_latent[:,16:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
+
+        if control_video is not None:
+            control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
+            concat_latent_image = vae.encode(control_video[:, :, :, :3])
+            concat_latent[:,:16,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
+
+        positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent})
+        negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent})
+
+        if clip_vision_output is not None:
+            positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
+            negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
+
+        out_latent = {}
+        out_latent["samples"] = latent
+        return (positive, negative, out_latent)
+
+class WanFunInpaintToVideo:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"positive": ("CONDITIONING", ),
+                             "negative": ("CONDITIONING", ),
+                             "vae": ("VAE", ),
+                             "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
+                             "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
+                             "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
+                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
+                },
+                "optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
+                             "start_image": ("IMAGE", ),
+                             "end_image": ("IMAGE", ),
+                }}
+
+    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
+    RETURN_NAMES = ("positive", "negative", "latent")
+    FUNCTION = "encode"
+
+    CATEGORY = "conditioning/video_models"
+
+    def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_output=None):
+        latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
+        if start_image is not None:
+            start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
+        if end_image is not None:
+            end_image = comfy.utils.common_upscale(end_image[-length:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
+
+        image = torch.ones((length, height, width, 3)) * 0.5
+        mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1]))
+
+        if start_image is not None:
+            image[:start_image.shape[0]] = start_image
+            mask[:, :, :start_image.shape[0] + 3] = 0.0
+
+        if end_image is not None:
+            image[-end_image.shape[0]:] = end_image
+            mask[:, :, -end_image.shape[0]:] = 0.0
+
+        concat_latent_image = vae.encode(image[:, :, :, :3])
+        mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2)
+        positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
+        negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
+
+        if clip_vision_output is not None:
+            positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
+            negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
+
+        out_latent = {}
+        out_latent["samples"] = latent
+        return (positive, negative, out_latent)
+
 NODE_CLASS_MAPPINGS = {
     "WanImageToVideo": WanImageToVideo,
+    "WanFunControlToVideo": WanFunControlToVideo,
+    "WanFunInpaintToVideo": WanFunInpaintToVideo,
 }
```
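Note (not part of the diff): WanFunInpaintToVideo's mask bookkeeping relies on the Wan VAE's 4x temporal compression, so the (T*4)-frame pixel mask is folded into 4 channels per latent frame. A shape walk-through with the node's default sizes:

```python
import torch

latent_t, h, w = 21, 60, 104                   # length=81 -> (81 - 1) // 4 + 1 = 21 latent frames
mask = torch.ones((1, 1, latent_t * 4, h, w))  # one mask value per pixel frame
mask = mask.view(1, mask.shape[2] // 4, 4, h, w).transpose(1, 2)
print(mask.shape)  # torch.Size([1, 4, 21, 60, 104]): the concat_mask shape WAN21.concat_cond consumes
```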
nodes.py (+1)

```diff
@@ -2267,6 +2267,7 @@ def init_builtin_extra_nodes():
         "nodes_lotus.py",
         "nodes_hunyuan3d.py",
         "nodes_primitive.py",
+        "nodes_cfg.py",
     ]

     import_failed = []
```
```diff
@@ -1,4 +1,4 @@
-comfyui-frontend-package==1.14.5
+comfyui-frontend-package==1.14.6
 torch
 torchsde
 torchvision
```