Compare commits
11 Commits
not_requir
...
v0.3.29
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
93292bc450 | ||
|
|
05d5a75cdc | ||
|
|
eba7a25e7a | ||
|
|
dbcfd092a2 | ||
|
|
c14429940f | ||
|
|
0d720e4367 | ||
|
|
1fc00ba4b6 | ||
|
|
9899d187b1 | ||
|
|
f00f340a56 | ||
|
|
cce1d9145e | ||
|
|
b4dc03ad76 |
@@ -184,6 +184,27 @@ comfyui-frontend-package is not installed.
|
|||||||
)
|
)
|
||||||
sys.exit(-1)
|
sys.exit(-1)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def templates_path(cls) -> str:
|
||||||
|
try:
|
||||||
|
import comfyui_workflow_templates
|
||||||
|
|
||||||
|
return str(
|
||||||
|
importlib.resources.files(comfyui_workflow_templates) / "templates"
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
logging.error(
|
||||||
|
f"""
|
||||||
|
********** ERROR ***********
|
||||||
|
|
||||||
|
comfyui-workflow-templates is not installed.
|
||||||
|
|
||||||
|
{frontend_install_warning_message()}
|
||||||
|
|
||||||
|
********** ERROR ***********
|
||||||
|
""".strip()
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def parse_version_string(cls, value: str) -> tuple[str, str, str]:
|
def parse_version_string(cls, value: str) -> tuple[str, str, str]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -99,59 +99,59 @@ class InputTypeOptions(TypedDict):
|
|||||||
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/datatypes
|
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/datatypes
|
||||||
"""
|
"""
|
||||||
|
|
||||||
default: bool | str | float | int | list | tuple
|
default: NotRequired[bool | str | float | int | list | tuple]
|
||||||
"""The default value of the widget"""
|
"""The default value of the widget"""
|
||||||
defaultInput: bool
|
defaultInput: NotRequired[bool]
|
||||||
"""@deprecated in v1.16 frontend. v1.16 frontend allows input socket and widget to co-exist.
|
"""@deprecated in v1.16 frontend. v1.16 frontend allows input socket and widget to co-exist.
|
||||||
- defaultInput on required inputs should be dropped.
|
- defaultInput on required inputs should be dropped.
|
||||||
- defaultInput on optional inputs should be replaced with forceInput.
|
- defaultInput on optional inputs should be replaced with forceInput.
|
||||||
Ref: https://github.com/Comfy-Org/ComfyUI_frontend/pull/3364
|
Ref: https://github.com/Comfy-Org/ComfyUI_frontend/pull/3364
|
||||||
"""
|
"""
|
||||||
forceInput: bool
|
forceInput: NotRequired[bool]
|
||||||
"""Forces the input to be an input slot rather than a widget even a widget is available for the input type."""
|
"""Forces the input to be an input slot rather than a widget even a widget is available for the input type."""
|
||||||
lazy: bool
|
lazy: NotRequired[bool]
|
||||||
"""Declares that this input uses lazy evaluation"""
|
"""Declares that this input uses lazy evaluation"""
|
||||||
rawLink: bool
|
rawLink: NotRequired[bool]
|
||||||
"""When a link exists, rather than receiving the evaluated value, you will receive the link (i.e. `["nodeId", <outputIndex>]`). Designed for node expansion."""
|
"""When a link exists, rather than receiving the evaluated value, you will receive the link (i.e. `["nodeId", <outputIndex>]`). Designed for node expansion."""
|
||||||
tooltip: str
|
tooltip: NotRequired[str]
|
||||||
"""Tooltip for the input (or widget), shown on pointer hover"""
|
"""Tooltip for the input (or widget), shown on pointer hover"""
|
||||||
# class InputTypeNumber(InputTypeOptions):
|
# class InputTypeNumber(InputTypeOptions):
|
||||||
# default: float | int
|
# default: float | int
|
||||||
min: float
|
min: NotRequired[float]
|
||||||
"""The minimum value of a number (``FLOAT`` | ``INT``)"""
|
"""The minimum value of a number (``FLOAT`` | ``INT``)"""
|
||||||
max: float
|
max: NotRequired[float]
|
||||||
"""The maximum value of a number (``FLOAT`` | ``INT``)"""
|
"""The maximum value of a number (``FLOAT`` | ``INT``)"""
|
||||||
step: float
|
step: NotRequired[float]
|
||||||
"""The amount to increment or decrement a widget by when stepping up/down (``FLOAT`` | ``INT``)"""
|
"""The amount to increment or decrement a widget by when stepping up/down (``FLOAT`` | ``INT``)"""
|
||||||
round: float
|
round: NotRequired[float]
|
||||||
"""Floats are rounded by this value (``FLOAT``)"""
|
"""Floats are rounded by this value (``FLOAT``)"""
|
||||||
# class InputTypeBoolean(InputTypeOptions):
|
# class InputTypeBoolean(InputTypeOptions):
|
||||||
# default: bool
|
# default: bool
|
||||||
label_on: str
|
label_on: NotRequired[str]
|
||||||
"""The label to use in the UI when the bool is True (``BOOLEAN``)"""
|
"""The label to use in the UI when the bool is True (``BOOLEAN``)"""
|
||||||
label_off: str
|
label_off: NotRequired[str]
|
||||||
"""The label to use in the UI when the bool is False (``BOOLEAN``)"""
|
"""The label to use in the UI when the bool is False (``BOOLEAN``)"""
|
||||||
# class InputTypeString(InputTypeOptions):
|
# class InputTypeString(InputTypeOptions):
|
||||||
# default: str
|
# default: str
|
||||||
multiline: bool
|
multiline: NotRequired[bool]
|
||||||
"""Use a multiline text box (``STRING``)"""
|
"""Use a multiline text box (``STRING``)"""
|
||||||
placeholder: str
|
placeholder: NotRequired[str]
|
||||||
"""Placeholder text to display in the UI when empty (``STRING``)"""
|
"""Placeholder text to display in the UI when empty (``STRING``)"""
|
||||||
# Deprecated:
|
# Deprecated:
|
||||||
# defaultVal: str
|
# defaultVal: str
|
||||||
dynamicPrompts: bool
|
dynamicPrompts: NotRequired[bool]
|
||||||
"""Causes the front-end to evaluate dynamic prompts (``STRING``)"""
|
"""Causes the front-end to evaluate dynamic prompts (``STRING``)"""
|
||||||
# class InputTypeCombo(InputTypeOptions):
|
# class InputTypeCombo(InputTypeOptions):
|
||||||
image_upload: bool
|
image_upload: NotRequired[bool]
|
||||||
"""Specifies whether the input should have an image upload button and image preview attached to it. Requires that the input's name is `image`."""
|
"""Specifies whether the input should have an image upload button and image preview attached to it. Requires that the input's name is `image`."""
|
||||||
image_folder: Literal["input", "output", "temp"]
|
image_folder: NotRequired[Literal["input", "output", "temp"]]
|
||||||
"""Specifies which folder to get preview images from if the input has the ``image_upload`` flag.
|
"""Specifies which folder to get preview images from if the input has the ``image_upload`` flag.
|
||||||
"""
|
"""
|
||||||
remote: RemoteInputOptions
|
remote: NotRequired[RemoteInputOptions]
|
||||||
"""Specifies the configuration for a remote input.
|
"""Specifies the configuration for a remote input.
|
||||||
Available after ComfyUI frontend v1.9.7
|
Available after ComfyUI frontend v1.9.7
|
||||||
https://github.com/Comfy-Org/ComfyUI_frontend/pull/2422"""
|
https://github.com/Comfy-Org/ComfyUI_frontend/pull/2422"""
|
||||||
control_after_generate: bool
|
control_after_generate: NotRequired[bool]
|
||||||
"""Specifies whether a control widget should be added to the input, adding options to automatically change the value after each prompt is queued. Currently only used for INT and COMBO types."""
|
"""Specifies whether a control widget should be added to the input, adding options to automatically change the value after each prompt is queued. Currently only used for INT and COMBO types."""
|
||||||
options: NotRequired[list[str | int | float]]
|
options: NotRequired[list[str | int | float]]
|
||||||
"""COMBO type only. Specifies the selectable options for the combo widget.
|
"""COMBO type only. Specifies the selectable options for the combo widget.
|
||||||
@@ -169,15 +169,15 @@ class InputTypeOptions(TypedDict):
|
|||||||
class HiddenInputTypeDict(TypedDict):
|
class HiddenInputTypeDict(TypedDict):
|
||||||
"""Provides type hinting for the hidden entry of node INPUT_TYPES."""
|
"""Provides type hinting for the hidden entry of node INPUT_TYPES."""
|
||||||
|
|
||||||
node_id: Literal["UNIQUE_ID"]
|
node_id: NotRequired[Literal["UNIQUE_ID"]]
|
||||||
"""UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages)."""
|
"""UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages)."""
|
||||||
unique_id: Literal["UNIQUE_ID"]
|
unique_id: NotRequired[Literal["UNIQUE_ID"]]
|
||||||
"""UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages)."""
|
"""UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages)."""
|
||||||
prompt: Literal["PROMPT"]
|
prompt: NotRequired[Literal["PROMPT"]]
|
||||||
"""PROMPT is the complete prompt sent by the client to the server. See the prompt object for a full description."""
|
"""PROMPT is the complete prompt sent by the client to the server. See the prompt object for a full description."""
|
||||||
extra_pnginfo: Literal["EXTRA_PNGINFO"]
|
extra_pnginfo: NotRequired[Literal["EXTRA_PNGINFO"]]
|
||||||
"""EXTRA_PNGINFO is a dictionary that will be copied into the metadata of any .png files saved. Custom nodes can store additional information in this dictionary for saving (or as a way to communicate with a downstream node)."""
|
"""EXTRA_PNGINFO is a dictionary that will be copied into the metadata of any .png files saved. Custom nodes can store additional information in this dictionary for saving (or as a way to communicate with a downstream node)."""
|
||||||
dynprompt: Literal["DYNPROMPT"]
|
dynprompt: NotRequired[Literal["DYNPROMPT"]]
|
||||||
"""DYNPROMPT is an instance of comfy_execution.graph.DynamicPrompt. It differs from PROMPT in that it may mutate during the course of execution in response to Node Expansion."""
|
"""DYNPROMPT is an instance of comfy_execution.graph.DynamicPrompt. It differs from PROMPT in that it may mutate during the course of execution in response to Node Expansion."""
|
||||||
|
|
||||||
|
|
||||||
@@ -187,11 +187,11 @@ class InputTypeDict(TypedDict):
|
|||||||
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/more_on_inputs
|
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/more_on_inputs
|
||||||
"""
|
"""
|
||||||
|
|
||||||
required: dict[str, tuple[IO, InputTypeOptions]]
|
required: NotRequired[dict[str, tuple[IO, InputTypeOptions]]]
|
||||||
"""Describes all inputs that must be connected for the node to execute."""
|
"""Describes all inputs that must be connected for the node to execute."""
|
||||||
optional: dict[str, tuple[IO, InputTypeOptions]]
|
optional: NotRequired[dict[str, tuple[IO, InputTypeOptions]]]
|
||||||
"""Describes inputs which do not need to be connected."""
|
"""Describes inputs which do not need to be connected."""
|
||||||
hidden: HiddenInputTypeDict
|
hidden: NotRequired[HiddenInputTypeDict]
|
||||||
"""Offers advanced functionality and server-client communication.
|
"""Offers advanced functionality and server-client communication.
|
||||||
|
|
||||||
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/more_on_inputs#hidden-inputs
|
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/more_on_inputs#hidden-inputs
|
||||||
|
|||||||
@@ -8,25 +8,12 @@ from einops import repeat
|
|||||||
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
|
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from comfy.ldm.flux.math import apply_rope
|
from comfy.ldm.flux.math import apply_rope, rope
|
||||||
|
from comfy.ldm.flux.layers import LastLayer
|
||||||
|
|
||||||
from comfy.ldm.modules.attention import optimized_attention
|
from comfy.ldm.modules.attention import optimized_attention
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
|
import comfy.ldm.common_dit
|
||||||
# Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py
|
|
||||||
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
|
|
||||||
assert dim % 2 == 0, "The dimension must be even."
|
|
||||||
|
|
||||||
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
|
|
||||||
omega = 1.0 / (theta**scale)
|
|
||||||
|
|
||||||
batch_size, seq_length = pos.shape
|
|
||||||
out = torch.einsum("...n,d->...nd", pos, omega)
|
|
||||||
cos_out = torch.cos(out)
|
|
||||||
sin_out = torch.sin(out)
|
|
||||||
|
|
||||||
stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
|
|
||||||
out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
|
|
||||||
return out.float()
|
|
||||||
|
|
||||||
|
|
||||||
# Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
|
# Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
|
||||||
@@ -84,23 +71,6 @@ class TimestepEmbed(nn.Module):
|
|||||||
return t_emb
|
return t_emb
|
||||||
|
|
||||||
|
|
||||||
class OutEmbed(nn.Module):
|
|
||||||
def __init__(self, hidden_size, patch_size, out_channels, dtype=None, device=None, operations=None):
|
|
||||||
super().__init__()
|
|
||||||
self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
|
||||||
self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
|
|
||||||
self.adaLN_modulation = nn.Sequential(
|
|
||||||
nn.SiLU(),
|
|
||||||
operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device)
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x, adaln_input):
|
|
||||||
shift, scale = self.adaLN_modulation(adaln_input).chunk(2, dim=1)
|
|
||||||
x = self.norm_final(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
|
||||||
x = self.linear(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
|
def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
|
||||||
return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2])
|
return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2])
|
||||||
|
|
||||||
@@ -663,7 +633,7 @@ class HiDreamImageTransformer2DModel(nn.Module):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
self.final_layer = OutEmbed(self.inner_dim, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
|
self.final_layer = LastLayer(self.inner_dim, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||||
|
|
||||||
caption_channels = [caption_channels[1], ] * (num_layers + num_single_layers) + [caption_channels[0], ]
|
caption_channels = [caption_channels[1], ] * (num_layers + num_single_layers) + [caption_channels[0], ]
|
||||||
caption_projection = []
|
caption_projection = []
|
||||||
@@ -732,7 +702,8 @@ class HiDreamImageTransformer2DModel(nn.Module):
|
|||||||
control = None,
|
control = None,
|
||||||
transformer_options = {},
|
transformer_options = {},
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = x
|
bs, c, h, w = x.shape
|
||||||
|
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
|
||||||
timesteps = t
|
timesteps = t
|
||||||
pooled_embeds = y
|
pooled_embeds = y
|
||||||
T5_encoder_hidden_states = context
|
T5_encoder_hidden_states = context
|
||||||
@@ -825,4 +796,4 @@ class HiDreamImageTransformer2DModel(nn.Module):
|
|||||||
hidden_states = hidden_states[:, :image_tokens_seq_len, ...]
|
hidden_states = hidden_states[:, :image_tokens_seq_len, ...]
|
||||||
output = self.final_layer(hidden_states, adaln_input)
|
output = self.final_layer(hidden_states, adaln_input)
|
||||||
output = self.unpatchify(output, img_sizes)
|
output = self.unpatchify(output, img_sizes)
|
||||||
return -output
|
return -output[:, :, :h, :w]
|
||||||
|
|||||||
@@ -83,7 +83,7 @@ class WanSelfAttention(nn.Module):
|
|||||||
|
|
||||||
class WanT2VCrossAttention(WanSelfAttention):
|
class WanT2VCrossAttention(WanSelfAttention):
|
||||||
|
|
||||||
def forward(self, x, context):
|
def forward(self, x, context, **kwargs):
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
x(Tensor): Shape [B, L1, C]
|
x(Tensor): Shape [B, L1, C]
|
||||||
@@ -116,14 +116,14 @@ class WanI2VCrossAttention(WanSelfAttention):
|
|||||||
# self.alpha = nn.Parameter(torch.zeros((1, )))
|
# self.alpha = nn.Parameter(torch.zeros((1, )))
|
||||||
self.norm_k_img = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
|
self.norm_k_img = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
|
||||||
|
|
||||||
def forward(self, x, context):
|
def forward(self, x, context, context_img_len):
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
x(Tensor): Shape [B, L1, C]
|
x(Tensor): Shape [B, L1, C]
|
||||||
context(Tensor): Shape [B, L2, C]
|
context(Tensor): Shape [B, L2, C]
|
||||||
"""
|
"""
|
||||||
context_img = context[:, :257]
|
context_img = context[:, :context_img_len]
|
||||||
context = context[:, 257:]
|
context = context[:, context_img_len:]
|
||||||
|
|
||||||
# compute query, key, value
|
# compute query, key, value
|
||||||
q = self.norm_q(self.q(x))
|
q = self.norm_q(self.q(x))
|
||||||
@@ -193,6 +193,7 @@ class WanAttentionBlock(nn.Module):
|
|||||||
e,
|
e,
|
||||||
freqs,
|
freqs,
|
||||||
context,
|
context,
|
||||||
|
context_img_len=257,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
@@ -213,7 +214,7 @@ class WanAttentionBlock(nn.Module):
|
|||||||
x = x + y * e[2]
|
x = x + y * e[2]
|
||||||
|
|
||||||
# cross-attention & ffn
|
# cross-attention & ffn
|
||||||
x = x + self.cross_attn(self.norm3(x), context)
|
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len)
|
||||||
y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3])
|
y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3])
|
||||||
x = x + y * e[5]
|
x = x + y * e[5]
|
||||||
return x
|
return x
|
||||||
@@ -250,7 +251,7 @@ class Head(nn.Module):
|
|||||||
|
|
||||||
class MLPProj(torch.nn.Module):
|
class MLPProj(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, in_dim, out_dim, operation_settings={}):
|
def __init__(self, in_dim, out_dim, flf_pos_embed_token_number=None, operation_settings={}):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.proj = torch.nn.Sequential(
|
self.proj = torch.nn.Sequential(
|
||||||
@@ -258,7 +259,15 @@ class MLPProj(torch.nn.Module):
|
|||||||
torch.nn.GELU(), operation_settings.get("operations").Linear(in_dim, out_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
|
torch.nn.GELU(), operation_settings.get("operations").Linear(in_dim, out_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
|
||||||
operation_settings.get("operations").LayerNorm(out_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
|
operation_settings.get("operations").LayerNorm(out_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
|
||||||
|
|
||||||
|
if flf_pos_embed_token_number is not None:
|
||||||
|
self.emb_pos = nn.Parameter(torch.empty((1, flf_pos_embed_token_number, in_dim), device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
|
||||||
|
else:
|
||||||
|
self.emb_pos = None
|
||||||
|
|
||||||
def forward(self, image_embeds):
|
def forward(self, image_embeds):
|
||||||
|
if self.emb_pos is not None:
|
||||||
|
image_embeds = image_embeds[:, :self.emb_pos.shape[1]] + comfy.model_management.cast_to(self.emb_pos[:, :image_embeds.shape[1]], dtype=image_embeds.dtype, device=image_embeds.device)
|
||||||
|
|
||||||
clip_extra_context_tokens = self.proj(image_embeds)
|
clip_extra_context_tokens = self.proj(image_embeds)
|
||||||
return clip_extra_context_tokens
|
return clip_extra_context_tokens
|
||||||
|
|
||||||
@@ -284,6 +293,7 @@ class WanModel(torch.nn.Module):
|
|||||||
qk_norm=True,
|
qk_norm=True,
|
||||||
cross_attn_norm=True,
|
cross_attn_norm=True,
|
||||||
eps=1e-6,
|
eps=1e-6,
|
||||||
|
flf_pos_embed_token_number=None,
|
||||||
image_model=None,
|
image_model=None,
|
||||||
device=None,
|
device=None,
|
||||||
dtype=None,
|
dtype=None,
|
||||||
@@ -373,7 +383,7 @@ class WanModel(torch.nn.Module):
|
|||||||
self.rope_embedder = EmbedND(dim=d, theta=10000.0, axes_dim=[d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)])
|
self.rope_embedder = EmbedND(dim=d, theta=10000.0, axes_dim=[d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)])
|
||||||
|
|
||||||
if model_type == 'i2v':
|
if model_type == 'i2v':
|
||||||
self.img_emb = MLPProj(1280, dim, operation_settings=operation_settings)
|
self.img_emb = MLPProj(1280, dim, flf_pos_embed_token_number=flf_pos_embed_token_number, operation_settings=operation_settings)
|
||||||
else:
|
else:
|
||||||
self.img_emb = None
|
self.img_emb = None
|
||||||
|
|
||||||
@@ -420,9 +430,12 @@ class WanModel(torch.nn.Module):
|
|||||||
# context
|
# context
|
||||||
context = self.text_embedding(context)
|
context = self.text_embedding(context)
|
||||||
|
|
||||||
if clip_fea is not None and self.img_emb is not None:
|
context_img_len = None
|
||||||
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
|
if clip_fea is not None:
|
||||||
context = torch.concat([context_clip, context], dim=1)
|
if self.img_emb is not None:
|
||||||
|
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
|
||||||
|
context = torch.concat([context_clip, context], dim=1)
|
||||||
|
context_img_len = clip_fea.shape[-2]
|
||||||
|
|
||||||
patches_replace = transformer_options.get("patches_replace", {})
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
blocks_replace = patches_replace.get("dit", {})
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
@@ -430,12 +443,12 @@ class WanModel(torch.nn.Module):
|
|||||||
if ("double_block", i) in blocks_replace:
|
if ("double_block", i) in blocks_replace:
|
||||||
def block_wrap(args):
|
def block_wrap(args):
|
||||||
out = {}
|
out = {}
|
||||||
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
|
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
|
||||||
return out
|
return out
|
||||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
|
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
|
||||||
x = out["img"]
|
x = out["img"]
|
||||||
else:
|
else:
|
||||||
x = block(x, e=e0, freqs=freqs, context=context)
|
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
|
||||||
|
|
||||||
# head
|
# head
|
||||||
x = self.head(x, e)
|
x = self.head(x, e)
|
||||||
|
|||||||
@@ -321,6 +321,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
dit_config["model_type"] = "i2v"
|
dit_config["model_type"] = "i2v"
|
||||||
else:
|
else:
|
||||||
dit_config["model_type"] = "t2v"
|
dit_config["model_type"] = "t2v"
|
||||||
|
flf_weight = state_dict.get('{}img_emb.emb_pos'.format(key_prefix))
|
||||||
|
if flf_weight is not None:
|
||||||
|
dit_config["flf_pos_embed_token_number"] = flf_weight.shape[1]
|
||||||
return dit_config
|
return dit_config
|
||||||
|
|
||||||
if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D
|
if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D
|
||||||
|
|||||||
@@ -49,6 +49,7 @@ if RMSNorm is None:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.register_parameter("weight", None)
|
self.register_parameter("weight", None)
|
||||||
|
self.bias = None
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return rms_norm(x, self.weight, self.eps)
|
return rms_norm(x, self.weight, self.eps)
|
||||||
|
|||||||
@@ -11,14 +11,15 @@ class HiDreamTokenizer:
|
|||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||||
self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||||
self.t5xxl = sd3_clip.T5XXLTokenizer(embedding_directory=embedding_directory, min_length=128, tokenizer_data=tokenizer_data)
|
self.t5xxl = sd3_clip.T5XXLTokenizer(embedding_directory=embedding_directory, min_length=128, max_length=128, tokenizer_data=tokenizer_data)
|
||||||
self.llama = hunyuan_video.LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=128, pad_token=128009, tokenizer_data=tokenizer_data)
|
self.llama = hunyuan_video.LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=128, pad_token=128009, tokenizer_data=tokenizer_data)
|
||||||
|
|
||||||
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
||||||
out = {}
|
out = {}
|
||||||
out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
|
out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
|
||||||
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
|
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
|
||||||
out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids)
|
t5xxl = self.t5xxl.tokenize_with_weights(text, return_word_ids)
|
||||||
|
out["t5xxl"] = [t5xxl[0]] # Use only first 128 tokens
|
||||||
out["llama"] = self.llama.tokenize_with_weights(text, return_word_ids)
|
out["llama"] = self.llama.tokenize_with_weights(text, return_word_ids)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|||||||
@@ -32,9 +32,9 @@ def t5_xxl_detect(state_dict, prefix=""):
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=77):
|
def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=77, max_length=99999999):
|
||||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
|
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
|
||||||
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=min_length, tokenizer_data=tokenizer_data)
|
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=max_length, min_length=min_length, tokenizer_data=tokenizer_data)
|
||||||
|
|
||||||
|
|
||||||
class SD3Tokenizer:
|
class SD3Tokenizer:
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import torch
|
|||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import comfy.latent_formats
|
import comfy.latent_formats
|
||||||
|
import comfy.clip_vision
|
||||||
|
|
||||||
|
|
||||||
class WanImageToVideo:
|
class WanImageToVideo:
|
||||||
@@ -99,6 +100,72 @@ class WanFunControlToVideo:
|
|||||||
out_latent["samples"] = latent
|
out_latent["samples"] = latent
|
||||||
return (positive, negative, out_latent)
|
return (positive, negative, out_latent)
|
||||||
|
|
||||||
|
class WanFirstLastFrameToVideo:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": {"positive": ("CONDITIONING", ),
|
||||||
|
"negative": ("CONDITIONING", ),
|
||||||
|
"vae": ("VAE", ),
|
||||||
|
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||||
|
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||||
|
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
|
||||||
|
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||||
|
},
|
||||||
|
"optional": {"clip_vision_start_image": ("CLIP_VISION_OUTPUT", ),
|
||||||
|
"clip_vision_end_image": ("CLIP_VISION_OUTPUT", ),
|
||||||
|
"start_image": ("IMAGE", ),
|
||||||
|
"end_image": ("IMAGE", ),
|
||||||
|
}}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
|
||||||
|
RETURN_NAMES = ("positive", "negative", "latent")
|
||||||
|
FUNCTION = "encode"
|
||||||
|
|
||||||
|
CATEGORY = "conditioning/video_models"
|
||||||
|
|
||||||
|
def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_start_image=None, clip_vision_end_image=None):
|
||||||
|
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||||
|
if start_image is not None:
|
||||||
|
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||||
|
if end_image is not None:
|
||||||
|
end_image = comfy.utils.common_upscale(end_image[-length:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||||
|
|
||||||
|
image = torch.ones((length, height, width, 3)) * 0.5
|
||||||
|
mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1]))
|
||||||
|
|
||||||
|
if start_image is not None:
|
||||||
|
image[:start_image.shape[0]] = start_image
|
||||||
|
mask[:, :, :start_image.shape[0] + 3] = 0.0
|
||||||
|
|
||||||
|
if end_image is not None:
|
||||||
|
image[-end_image.shape[0]:] = end_image
|
||||||
|
mask[:, :, -end_image.shape[0]:] = 0.0
|
||||||
|
|
||||||
|
concat_latent_image = vae.encode(image[:, :, :, :3])
|
||||||
|
mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2)
|
||||||
|
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
|
||||||
|
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
|
||||||
|
|
||||||
|
if clip_vision_start_image is not None:
|
||||||
|
clip_vision_output = clip_vision_start_image
|
||||||
|
|
||||||
|
if clip_vision_end_image is not None:
|
||||||
|
if clip_vision_output is not None:
|
||||||
|
states = torch.cat([clip_vision_output.penultimate_hidden_states, clip_vision_end_image.penultimate_hidden_states], dim=-2)
|
||||||
|
clip_vision_output = comfy.clip_vision.Output()
|
||||||
|
clip_vision_output.penultimate_hidden_states = states
|
||||||
|
else:
|
||||||
|
clip_vision_output = clip_vision_end_image
|
||||||
|
|
||||||
|
if clip_vision_output is not None:
|
||||||
|
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
|
||||||
|
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
|
||||||
|
|
||||||
|
out_latent = {}
|
||||||
|
out_latent["samples"] = latent
|
||||||
|
return (positive, negative, out_latent)
|
||||||
|
|
||||||
|
|
||||||
class WanFunInpaintToVideo:
|
class WanFunInpaintToVideo:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
@@ -122,38 +189,13 @@ class WanFunInpaintToVideo:
|
|||||||
CATEGORY = "conditioning/video_models"
|
CATEGORY = "conditioning/video_models"
|
||||||
|
|
||||||
def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_output=None):
|
def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_output=None):
|
||||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
flfv = WanFirstLastFrameToVideo()
|
||||||
if start_image is not None:
|
return flfv.encode(positive, negative, vae, width, height, length, batch_size, start_image=start_image, end_image=end_image, clip_vision_start_image=clip_vision_output)
|
||||||
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
|
||||||
if end_image is not None:
|
|
||||||
end_image = comfy.utils.common_upscale(end_image[-length:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
|
||||||
|
|
||||||
image = torch.ones((length, height, width, 3)) * 0.5
|
|
||||||
mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1]))
|
|
||||||
|
|
||||||
if start_image is not None:
|
|
||||||
image[:start_image.shape[0]] = start_image
|
|
||||||
mask[:, :, :start_image.shape[0] + 3] = 0.0
|
|
||||||
|
|
||||||
if end_image is not None:
|
|
||||||
image[-end_image.shape[0]:] = end_image
|
|
||||||
mask[:, :, -end_image.shape[0]:] = 0.0
|
|
||||||
|
|
||||||
concat_latent_image = vae.encode(image[:, :, :, :3])
|
|
||||||
mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2)
|
|
||||||
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
|
|
||||||
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
|
|
||||||
|
|
||||||
if clip_vision_output is not None:
|
|
||||||
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
|
|
||||||
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
|
|
||||||
|
|
||||||
out_latent = {}
|
|
||||||
out_latent["samples"] = latent
|
|
||||||
return (positive, negative, out_latent)
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"WanImageToVideo": WanImageToVideo,
|
"WanImageToVideo": WanImageToVideo,
|
||||||
"WanFunControlToVideo": WanFunControlToVideo,
|
"WanFunControlToVideo": WanFunControlToVideo,
|
||||||
"WanFunInpaintToVideo": WanFunInpaintToVideo,
|
"WanFunInpaintToVideo": WanFunInpaintToVideo,
|
||||||
|
"WanFirstLastFrameToVideo": WanFirstLastFrameToVideo,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
# This file is automatically generated by the build process when version is
|
# This file is automatically generated by the build process when version is
|
||||||
# updated in pyproject.toml.
|
# updated in pyproject.toml.
|
||||||
__version__ = "0.3.28"
|
__version__ = "0.3.29"
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "ComfyUI"
|
name = "ComfyUI"
|
||||||
version = "0.3.28"
|
version = "0.3.29"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license = { file = "LICENSE" }
|
license = { file = "LICENSE" }
|
||||||
requires-python = ">=3.9"
|
requires-python = ">=3.9"
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
comfyui-frontend-package==1.15.13
|
comfyui-frontend-package==1.16.8
|
||||||
|
comfyui-workflow-templates==0.1.1
|
||||||
torch
|
torch
|
||||||
torchsde
|
torchsde
|
||||||
torchvision
|
torchvision
|
||||||
|
|||||||
@@ -736,6 +736,12 @@ class PromptServer():
|
|||||||
for name, dir in nodes.EXTENSION_WEB_DIRS.items():
|
for name, dir in nodes.EXTENSION_WEB_DIRS.items():
|
||||||
self.app.add_routes([web.static('/extensions/' + name, dir)])
|
self.app.add_routes([web.static('/extensions/' + name, dir)])
|
||||||
|
|
||||||
|
workflow_templates_path = FrontendManager.templates_path()
|
||||||
|
if workflow_templates_path:
|
||||||
|
self.app.add_routes([
|
||||||
|
web.static('/templates', workflow_templates_path)
|
||||||
|
])
|
||||||
|
|
||||||
self.app.add_routes([
|
self.app.add_routes([
|
||||||
web.static('/', self.web_root),
|
web.static('/', self.web_root),
|
||||||
])
|
])
|
||||||
|
|||||||
Reference in New Issue
Block a user