Compare commits: desktoprel...get-size-n
32 commits
| SHA1 |
|---|
| c2cb81fd9e |
| 60e073e510 |
| a9a7c9385a |
| 6d46bb4b4c |
| 67f57c5bcc |
| fd943c928f |
| d3bd983b91 |
| fb4754624d |
| 180db6753f |
| d062fcc5c0 |
| 456abad834 |
| 19e45e9b0e |
| 97f23b81f3 |
| 08b7cc7506 |
| 6c319cbb4e |
| df1aebe52e |
| 704fc78854 |
| 1d9fee79fd |
| aeba0b3a26 |
| 094306b626 |
| 31260f0275 |
| f1c9ca816a |
| f2289a1f59 |
| fb83eda287 |
| 5e5e46d40c |
| 4eba3161cf |
| 592d056100 |
| 1c1687ab1c |
| e6609dacde |
| ba37e67964 |
| 06c661004e |
| c9e1821a7b |
.github/ISSUE_TEMPLATE/bug-report.yml (8 changes, vendored)
@@ -15,6 +15,14 @@ body:
         steps to replicate what went wrong and others will be able to repeat your steps and see the same issue happen.
 
         If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first.
+  - type: checkboxes
+    id: custom-nodes-test
+    attributes:
+      label: Custom Node Testing
+      description: Please confirm you have tried to reproduce the issue with all custom nodes disabled.
+      options:
+        - label: I have tried disabling custom nodes and the issue persists (see [how to disable custom nodes](https://docs.comfy.org/troubleshooting/custom-node-issues#step-1%3A-test-with-all-custom-nodes-disabled) if you need help)
+          required: true
   - type: textarea
     attributes:
       label: Expected Behavior
.github/ISSUE_TEMPLATE/user-support.yml (8 changes, vendored)
@@ -11,6 +11,14 @@ body:
         **2:** You have made an effort to find public answers to your question before asking here. In other words, you googled it first, and scrolled through recent help topics.
 
         If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first.
+  - type: checkboxes
+    id: custom-nodes-test
+    attributes:
+      label: Custom Node Testing
+      description: Please confirm you have tried to reproduce the issue with all custom nodes disabled.
+      options:
+        - label: I have tried disabling custom nodes and the issue persists (see [how to disable custom nodes](https://docs.comfy.org/troubleshooting/custom-node-issues#step-1%3A-test-with-all-custom-nodes-disabled) if you need help)
+          required: true
   - type: textarea
     attributes:
       label: Your question
CODEOWNERS (26 changes)
@@ -5,20 +5,20 @@
 # Inlined the team members for now.
 
 # Maintainers
-*.md @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
-/tests/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
-/tests-unit/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
-/notebooks/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
-/script_examples/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
-/.github/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
-/requirements.txt @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
-/pyproject.toml @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
+*.md @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
+/tests/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
+/tests-unit/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
+/notebooks/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
+/script_examples/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
+/.github/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
+/requirements.txt @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
+/pyproject.toml @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
 
 # Python web server
-/api_server/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne
-/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne
-/utils/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne
+/api_server/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
+/app/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
+/utils/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
 
 # Node developers
-/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
-/comfy/comfy_types/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
+/comfy_extras/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
+/comfy/comfy_types/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
@@ -205,6 +205,19 @@ comfyui-workflow-templates is not installed.
             """.strip()
         )
 
+    @classmethod
+    def embedded_docs_path(cls) -> str:
+        """Get the path to embedded documentation"""
+        try:
+            import comfyui_embedded_docs
+
+            return str(
+                importlib.resources.files(comfyui_embedded_docs) / "docs"
+            )
+        except ImportError:
+            logging.info("comfyui-embedded-docs package not found")
+        return None
+
     @classmethod
     def parse_version_string(cls, value: str) -> tuple[str, str, str]:
         """
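The embedded_docs_path helper added above uses importlib.resources.files(), which returns a Traversable for a package so bundled data can be located inside an installed wheel. A minimal sketch of the same optional-dependency lookup (standalone, assumptions: the package is the optional comfyui-embedded-docs dependency referenced in the hunk):

```python
# Sketch of the optional-dependency lookup pattern used by embedded_docs_path().
# importlib.resources.files() accepts a module object (or dotted name) and
# returns a Traversable; joining "docs" points at data shipped with the package.
import importlib.resources
import logging

def embedded_docs_path():
    try:
        import comfyui_embedded_docs  # optional package; may not be installed
        return str(importlib.resources.files(comfyui_embedded_docs) / "docs")
    except ImportError:
        logging.info("comfyui-embedded-docs package not found")
        return None

print(embedded_docs_path())
```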
@@ -24,6 +24,10 @@ class CONDRegular:
             conds.append(x.cond)
         return torch.cat(conds)
 
+    def size(self):
+        return list(self.cond.size())
+
+
 class CONDNoiseShape(CONDRegular):
     def process_cond(self, batch_size, device, area, **kwargs):
         data = self.cond
@@ -64,6 +68,7 @@ class CONDCrossAttn(CONDRegular):
             out.append(c)
         return torch.cat(out)
 
+
 class CONDConstant(CONDRegular):
     def __init__(self, cond):
         self.cond = cond
@@ -78,3 +83,48 @@ class CONDConstant(CONDRegular):
 
     def concat(self, others):
         return self.cond
+
+    def size(self):
+        return [1]
+
+
+class CONDList(CONDRegular):
+    def __init__(self, cond):
+        self.cond = cond
+
+    def process_cond(self, batch_size, device, **kwargs):
+        out = []
+        for c in self.cond:
+            out.append(comfy.utils.repeat_to_batch_size(c, batch_size).to(device))
+
+        return self._copy_with(out)
+
+    def can_concat(self, other):
+        if len(self.cond) != len(other.cond):
+            return False
+        for i in range(len(self.cond)):
+            if self.cond[i].shape != other.cond[i].shape:
+                return False
+
+        return True
+
+    def concat(self, others):
+        out = []
+        for i in range(len(self.cond)):
+            o = [self.cond[i]]
+            for x in others:
+                o.append(x.cond[i])
+            out.append(torch.cat(o))
+
+        return out
+
+    def size(self): # hackish implementation to make the mem estimation work
+        o = 0
+        c = 1
+        for c in self.cond:
+            size = c.size()
+            o += math.prod(size)
+            if len(size) > 1:
+                c = size[1]
+
+        return [1, c, o // c]
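The size() override on CONDList is, as its own comment says, a hack: it folds a list of tensors into a single pseudo-shape [1, c, o // c] whose element count matches the total, so the memory estimator can treat the list like one tensor. A standalone sketch of that folding (the tensor shapes below are invented for illustration):

```python
# Illustrative only: the conditioning shapes are made up to show the folding.
import math
import torch

conds = [torch.zeros(1, 77, 768), torch.zeros(1, 257, 1280)]

o = 0
c = 1
for t in conds:
    size = t.size()
    o += math.prod(size)   # accumulate the total element count
    if len(size) > 1:
        c = size[1]        # remember the last "token" dimension seen
print([1, c, o // c])      # -> [1, 257, 1510] for these shapes
```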
@@ -80,15 +80,13 @@ class DoubleStreamBlock(nn.Module):
         (img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
 
         # prepare image for attention
-        img_modulated = self.img_norm1(img)
-        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
+        img_modulated = torch.addcmul(img_mod1.shift, 1 + img_mod1.scale, self.img_norm1(img))
         img_qkv = self.img_attn.qkv(img_modulated)
         img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
 
         # prepare txt for attention
-        txt_modulated = self.txt_norm1(txt)
-        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
+        txt_modulated = torch.addcmul(txt_mod1.shift, 1 + txt_mod1.scale, self.txt_norm1(txt))
         txt_qkv = self.txt_attn.qkv(txt_modulated)
         txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
@@ -102,12 +100,12 @@ class DoubleStreamBlock(nn.Module):
         txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
 
         # calculate the img bloks
-        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
-        img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
+        img.addcmul_(img_mod1.gate, self.img_attn.proj(img_attn))
+        img.addcmul_(img_mod2.gate, self.img_mlp(torch.addcmul(img_mod2.shift, 1 + img_mod2.scale, self.img_norm2(img))))
 
         # calculate the txt bloks
-        txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
-        txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
+        txt.addcmul_(txt_mod1.gate, self.txt_attn.proj(txt_attn))
+        txt.addcmul_(txt_mod2.gate, self.txt_mlp(torch.addcmul(txt_mod2.shift, 1 + txt_mod2.scale, self.txt_norm2(txt))))
 
         if txt.dtype == torch.float16:
             txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
@@ -152,7 +150,7 @@ class SingleStreamBlock(nn.Module):
 
     def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None) -> Tensor:
         mod = vec
-        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
+        x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x))
         qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
 
         q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
@@ -162,7 +160,7 @@ class SingleStreamBlock(nn.Module):
         attn = attention(q, k, v, pe=pe, mask=attn_mask)
         # compute activation in mlp stream, cat again and run second linear layer
         output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
-        x += mod.gate * output
+        x.addcmul_(mod.gate, output)
         if x.dtype == torch.float16:
             x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
         return x
@@ -178,6 +176,6 @@ class LastLayer(nn.Module):
         shift, scale = vec
         shift = shift.squeeze(1)
         scale = scale.squeeze(1)
-        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
+        x = torch.addcmul(shift[:, None, :], 1 + scale[:, None, :], self.norm_final(x))
         x = self.linear(x)
         return x
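The block above rewrites the scale/shift modulation with torch.addcmul, which computes input + tensor1 * tensor2 in one fused op (and the in-place addcmul_ variant folds the gated residual add without an extra temporary). A quick equivalence check with made-up tensors, not taken from the PR:

```python
# Demonstrates that the addcmul form matches the original two-step math.
import torch

x = torch.randn(2, 16, 64)
shift = torch.randn(2, 1, 64)
scale = torch.randn(2, 1, 64)

old = (1 + scale) * x + shift                # original formulation
new = torch.addcmul(shift, 1 + scale, x)     # fused: shift + (1 + scale) * x
assert torch.allclose(old, new, atol=1e-6)

res = torch.zeros(2, 16, 64)
gate = torch.randn(2, 1, 64)
res.addcmul_(gate, x)                        # in-place: res += gate * x
```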
@@ -539,13 +539,20 @@ class WanModel(torch.nn.Module):
         x = self.unpatchify(x, grid_sizes)
         return x
 
-    def forward(self, x, timestep, context, clip_fea=None, transformer_options={}, **kwargs):
+    def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
         bs, c, t, h, w = x.shape
         x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
 
         patch_size = self.patch_size
         t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
         h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
         w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
 
+        if time_dim_concat is not None:
+            time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size)
+            x = torch.cat([x, time_dim_concat], dim=2)
+            t_len = ((x.shape[2] + (patch_size[0] // 2)) // patch_size[0])
+
         img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
         img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
         img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
@@ -283,8 +283,9 @@ def model_lora_keys_unet(model, key_map={}):
     for k in sdk:
         if k.startswith("diffusion_model."):
             if k.endswith(".weight"):
-                key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
-                key_map["lycoris_{}".format(key_lora)] = k #SimpleTuner lycoris format
+                key_lora = k[len("diffusion_model."):-len(".weight")]
+                key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k #SimpleTuner lycoris format
+                key_map["transformer.{}".format(key_lora)] = k #SimpleTuner regular format
 
     if isinstance(model, comfy.model_base.ACEStep):
         for k in sdk:
@@ -102,6 +102,13 @@ def model_sampling(model_config, model_type):
     return ModelSampling(model_config)
 
 
+def convert_tensor(extra, dtype):
+    if hasattr(extra, "dtype"):
+        if extra.dtype != torch.int and extra.dtype != torch.long:
+            extra = extra.to(dtype)
+    return extra
+
+
 class BaseModel(torch.nn.Module):
     def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_model=UNetModel):
         super().__init__()
@@ -135,6 +142,7 @@ class BaseModel(torch.nn.Module):
         logging.info("model_type {}".format(model_type.name))
         logging.debug("adm {}".format(self.adm_channels))
         self.memory_usage_factor = model_config.memory_usage_factor
+        self.memory_usage_factor_conds = ()
 
     def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
         return comfy.patcher_extension.WrapperExecutor.new_class_executor(
@@ -164,9 +172,14 @@ class BaseModel(torch.nn.Module):
         extra_conds = {}
         for o in kwargs:
             extra = kwargs[o]
+
             if hasattr(extra, "dtype"):
-                if extra.dtype != torch.int and extra.dtype != torch.long:
-                    extra = extra.to(dtype)
+                extra = convert_tensor(extra, dtype)
+            elif isinstance(extra, list):
+                ex = []
+                for ext in extra:
+                    ex.append(convert_tensor(ext, dtype))
+                extra = ex
             extra_conds[o] = extra
 
         t = self.process_timestep(t, x=x, **extra_conds)
@@ -325,19 +338,28 @@ class BaseModel(torch.nn.Module):
     def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
         return self.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(noise.shape) - 1)), noise, latent_image)
 
-    def memory_required(self, input_shape):
+    def memory_required(self, input_shape, cond_shapes={}):
+        input_shapes = [input_shape]
+        for c in self.memory_usage_factor_conds:
+            shape = cond_shapes.get(c, None)
+            if shape is not None and len(shape) > 0:
+                input_shapes += shape
+
         if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
             dtype = self.get_dtype()
             if self.manual_cast_dtype is not None:
                 dtype = self.manual_cast_dtype
             #TODO: this needs to be tweaked
-            area = input_shape[0] * math.prod(input_shape[2:])
+            area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes))
            return (area * comfy.model_management.dtype_size(dtype) * 0.01 * self.memory_usage_factor) * (1024 * 1024)
         else:
             #TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
-            area = input_shape[0] * math.prod(input_shape[2:])
+            area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes))
             return (area * 0.15 * self.memory_usage_factor) * (1024 * 1024)
 
+    def extra_conds_shapes(self, **kwargs):
+        return {}
+
 
 def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0, seed=None):
     adm_inputs = []
@@ -1047,6 +1069,11 @@ class WAN21(BaseModel):
         clip_vision_output = kwargs.get("clip_vision_output", None)
         if clip_vision_output is not None:
             out['clip_fea'] = comfy.conds.CONDRegular(clip_vision_output.penultimate_hidden_states)
 
+        time_dim_concat = kwargs.get("time_dim_concat", None)
+        if time_dim_concat is not None:
+            out['time_dim_concat'] = comfy.conds.CONDRegular(self.process_latent_in(time_dim_concat))
+
         return out
 
 
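With the new cond_shapes argument, memory_required sums its "area" term over the noise shape plus any conditioning shapes registered in memory_usage_factor_conds, instead of looking at the noise alone. A hedged sketch of just that area computation (the shapes below are invented):

```python
# Illustrative shapes only; "area" is batch * prod(dims after channels),
# matching the sum(map(lambda ...)) expression in the hunk above.
import math

def area_term(shapes):
    return sum(s[0] * math.prod(s[2:]) for s in shapes)

noise_shape = [2, 16, 21, 60, 104]     # hypothetical latent video batch
cond_shapes = [[1, 257, 1280]]         # hypothetical extra conditioning shape

print(area_term([noise_shape]))                 # old behaviour: noise only
print(area_term([noise_shape] + cond_shapes))   # new behaviour: noise + conds
```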
@@ -297,8 +297,13 @@ except:
 
 try:
     if is_amd():
+        try:
+            rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2]))
+        except:
+            rocm_version = (6, -1)
         arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
         logging.info("AMD arch: {}".format(arch))
+        logging.info("ROCm version: {}".format(rocm_version))
         if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
             if torch_version_numeric[0] >= 2 and torch_version_numeric[1] >= 7: # works on 2.6 but doesn't actually seem to improve much
                 if any((a in arch) for a in ["gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches
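The added ROCm detection just takes the first two dot-separated components of torch.version.hip as integers, falling back to (6, -1) when parsing fails. A small sketch with an example version string (the string itself is illustrative; at runtime it comes from torch.version.hip):

```python
# Example HIP version string; real value comes from torch.version.hip.
hip = "6.2.41133-abcdef"

try:
    rocm_version = tuple(map(int, str(hip).split(".")[:2]))
except Exception:
    rocm_version = (6, -1)   # fallback used by the diff when parsing fails
print(rocm_version)          # (6, 2)
```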
@@ -1,5 +1,7 @@
 from __future__ import annotations
 import uuid
+import math
+import collections
 import comfy.model_management
 import comfy.conds
 import comfy.utils
@@ -104,6 +106,21 @@ def cleanup_additional_models(models):
         if hasattr(m, 'cleanup'):
             m.cleanup()
 
+
+def estimate_memory(model, noise_shape, conds):
+    cond_shapes = collections.defaultdict(list)
+    cond_shapes_min = {}
+    for _, cs in conds.items():
+        for cond in cs:
+            for k, v in model.model.extra_conds_shapes(**cond).items():
+                cond_shapes[k].append(v)
+                if cond_shapes_min.get(k, None) is None:
+                    cond_shapes_min[k] = [v]
+                elif math.prod(v) > math.prod(cond_shapes_min[k][0]):
+                    cond_shapes_min[k] = [v]
+
+    memory_required = model.model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:]), cond_shapes=cond_shapes)
+    minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min)
+    return memory_required, minimum_memory_required
+
 def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
     executor = comfy.patcher_extension.WrapperExecutor.new_executor(
@@ -117,9 +134,8 @@ def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=Non
     models, inference_memory = get_additional_models(conds, model.model_dtype())
     models += get_additional_models_from_model_options(model_options)
     models += model.get_nested_additional_models()  # TODO: does this require inference_memory update?
-    memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory
-    minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory
-    comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required)
+    memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
+    comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory)
     real_model = model.model
 
     return real_model, conds, models
@@ -256,7 +256,13 @@ def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Te
             for i in range(1, len(to_batch_temp) + 1):
                 batch_amount = to_batch_temp[:len(to_batch_temp)//i]
                 input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
-                if model.memory_required(input_shape) * 1.5 < free_memory:
+                cond_shapes = collections.defaultdict(list)
+                for tt in batch_amount:
+                    cond = {k: v.size() for k, v in to_run[tt][0].conditioning.items()}
+                    for k, v in to_run[tt][0].conditioning.items():
+                        cond_shapes[k].append(v.size())
+
+                if model.memory_required(input_shape, cond_shapes=cond_shapes) * 1.5 < free_memory:
                     to_batch = batch_amount
                     break
 
@@ -1,25 +0,0 @@
-{
-  "_name_or_path": "openai/clip-vit-large-patch14",
-  "architectures": [
-    "CLIPTextModel"
-  ],
-  "attention_dropout": 0.0,
-  "bos_token_id": 0,
-  "dropout": 0.0,
-  "eos_token_id": 49407,
-  "hidden_act": "quick_gelu",
-  "hidden_size": 768,
-  "initializer_factor": 1.0,
-  "initializer_range": 0.02,
-  "intermediate_size": 3072,
-  "layer_norm_eps": 1e-05,
-  "max_position_embeddings": 248,
-  "model_type": "clip_text_model",
-  "num_attention_heads": 12,
-  "num_hidden_layers": 12,
-  "pad_token_id": 1,
-  "projection_dim": 768,
-  "torch_dtype": "float32",
-  "transformers_version": "4.24.0",
-  "vocab_size": 49408
-}
@@ -108,6 +108,24 @@ class BFLFluxProGenerateRequest(BaseModel):
 # )
 
 
+class BFLFluxKontextProGenerateRequest(BaseModel):
+    prompt: str = Field(..., description='The text prompt for what you wannt to edit.')
+    input_image: Optional[str] = Field(None, description='Image to edit in base64 format')
+    seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
+    guidance: confloat(ge=0.1, le=99.0) = Field(..., description='Guidance strength for the image generation process')
+    steps: conint(ge=1, le=150) = Field(..., description='Number of steps for the image generation process')
+    safety_tolerance: Optional[conint(ge=0, le=2)] = Field(
+        2, description='Tolerance level for input and output moderation. Between 0 and 2, 0 being most strict, 6 being least strict. Defaults to 2.'
+    )
+    output_format: Optional[BFLOutputFormat] = Field(
+        BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
+    )
+    aspect_ratio: Optional[str] = Field(None, description='Aspect ratio of the image between 21:9 and 9:21.')
+    prompt_upsampling: Optional[bool] = Field(
+        None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
+    )
+
+
 class BFLFluxProUltraGenerateRequest(BaseModel):
     prompt: str = Field(..., description='The text prompt for image generation.')
     prompt_upsampling: Optional[bool] = Field(
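The new request model leans on pydantic's constrained types: confloat(ge=0.1, le=99.0) and conint(ge=1, le=150) reject out-of-range values when the model is constructed. A minimal sketch of that behaviour (the demo class below is trimmed down and not part of the PR):

```python
# Minimal demo of the constrained fields; not the full request model.
from pydantic import BaseModel, Field, ValidationError, confloat, conint

class KontextDemo(BaseModel):
    guidance: confloat(ge=0.1, le=99.0) = Field(...)
    steps: conint(ge=1, le=150) = Field(...)

print(KontextDemo(guidance=3.0, steps=50))   # accepted

try:
    KontextDemo(guidance=0.0, steps=50)      # guidance below the minimum
except ValidationError as e:
    print("rejected:", e.errors()[0]["loc"])
```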
@@ -327,7 +327,9 @@ class ApiClient:
             ApiServerError: If the API server is unreachable but internet is working
             Exception: For other request failures
         """
-        url = urljoin(self.base_url, path)
+        # Use urljoin but ensure path is relative to avoid absolute path behavior
+        relative_path = path.lstrip('/')
+        url = urljoin(self.base_url, relative_path)
         self.check_auth(self.auth_token, self.comfy_api_key)
         # Combine default headers with any provided headers
         request_headers = self.get_headers()
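The reason for stripping the leading slash: urljoin treats an absolute path as replacing the entire path of the base URL, so any prefix baked into base_url would be silently dropped. A standard-library illustration (the base URL below is made up):

```python
from urllib.parse import urljoin

base = "https://api.example.com/v1/"   # hypothetical base_url with a path prefix

print(urljoin(base, "/proxy/bfl/flux-kontext-pro/generate"))
# -> https://api.example.com/proxy/bfl/flux-kontext-pro/generate   (prefix lost)

print(urljoin(base, "proxy/bfl/flux-kontext-pro/generate"))
# -> https://api.example.com/v1/proxy/bfl/flux-kontext-pro/generate (prefix kept)
```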
@@ -1,6 +1,6 @@
 import io
 from inspect import cleandoc
-from typing import Union
+from typing import Union, Optional
 from comfy.comfy_types.node_typing import IO, ComfyNodeABC
 from comfy_api_nodes.apis.bfl_api import (
     BFLStatus,
@@ -9,6 +9,7 @@ from comfy_api_nodes.apis.bfl_api import (
     BFLFluxCannyImageRequest,
     BFLFluxDepthImageRequest,
     BFLFluxProGenerateRequest,
+    BFLFluxKontextProGenerateRequest,
     BFLFluxProUltraGenerateRequest,
     BFLFluxProGenerateResponse,
 )
@@ -269,6 +270,158 @@ class FluxProUltraImageNode(ComfyNodeABC):
         return (output_image,)
 
 
+class FluxKontextProImageNode(ComfyNodeABC):
+    """
+    Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio.
+    """
+
+    MINIMUM_RATIO = 1 / 4
+    MAXIMUM_RATIO = 4 / 1
+    MINIMUM_RATIO_STR = "1:4"
+    MAXIMUM_RATIO_STR = "4:1"
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "prompt": (
+                    IO.STRING,
+                    {
+                        "multiline": True,
+                        "default": "",
+                        "tooltip": "Prompt for the image generation - specify what and how to edit.",
+                    },
+                ),
+                "aspect_ratio": (
+                    IO.STRING,
+                    {
+                        "default": "16:9",
+                        "tooltip": "Aspect ratio of image; must be between 1:4 and 4:1.",
+                    },
+                ),
+                "guidance": (
+                    IO.FLOAT,
+                    {
+                        "default": 3.0,
+                        "min": 0.1,
+                        "max": 99.0,
+                        "step": 0.1,
+                        "tooltip": "Guidance strength for the image generation process"
+                    },
+                ),
+                "steps": (
+                    IO.INT,
+                    {
+                        "default": 50,
+                        "min": 1,
+                        "max": 150,
+                        "tooltip": "Number of steps for the image generation process"
+                    },
+                ),
+                "seed": (
+                    IO.INT,
+                    {
+                        "default": 1234,
+                        "min": 0,
+                        "max": 0xFFFFFFFFFFFFFFFF,
+                        "control_after_generate": True,
+                        "tooltip": "The random seed used for creating the noise.",
+                    },
+                ),
+                "prompt_upsampling": (
+                    IO.BOOLEAN,
+                    {
+                        "default": False,
+                        "tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
+                    },
+                ),
+            },
+            "optional": {
+                "input_image": (IO.IMAGE,),
+            },
+            "hidden": {
+                "auth_token": "AUTH_TOKEN_COMFY_ORG",
+                "comfy_api_key": "API_KEY_COMFY_ORG",
+                "unique_id": "UNIQUE_ID",
+            },
+        }
+
+    @classmethod
+    def VALIDATE_INPUTS(cls, aspect_ratio: str):
+        try:
+            validate_aspect_ratio(
+                aspect_ratio,
+                minimum_ratio=cls.MINIMUM_RATIO,
+                maximum_ratio=cls.MAXIMUM_RATIO,
+                minimum_ratio_str=cls.MINIMUM_RATIO_STR,
+                maximum_ratio_str=cls.MAXIMUM_RATIO_STR,
+            )
+        except Exception as e:
+            return str(e)
+        return True
+
+    RETURN_TYPES = (IO.IMAGE,)
+    DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
+    FUNCTION = "api_call"
+    API_NODE = True
+    CATEGORY = "api node/image/BFL"
+
+    BFL_PATH = "/proxy/bfl/flux-kontext-pro/generate"
+
+    def api_call(
+        self,
+        prompt: str,
+        aspect_ratio: str,
+        guidance: float,
+        steps: int,
+        input_image: Optional[torch.Tensor]=None,
+        seed=0,
+        prompt_upsampling=False,
+        unique_id: Union[str, None] = None,
+        **kwargs,
+    ):
+        if input_image is None:
+            validate_string(prompt, strip_whitespace=False)
+        operation = SynchronousOperation(
+            endpoint=ApiEndpoint(
+                path=self.BFL_PATH,
+                method=HttpMethod.POST,
+                request_model=BFLFluxKontextProGenerateRequest,
+                response_model=BFLFluxProGenerateResponse,
+            ),
+            request=BFLFluxKontextProGenerateRequest(
+                prompt=prompt,
+                prompt_upsampling=prompt_upsampling,
+                guidance=round(guidance, 1),
+                steps=steps,
+                seed=seed,
+                aspect_ratio=validate_aspect_ratio(
+                    aspect_ratio,
+                    minimum_ratio=self.MINIMUM_RATIO,
+                    maximum_ratio=self.MAXIMUM_RATIO,
+                    minimum_ratio_str=self.MINIMUM_RATIO_STR,
+                    maximum_ratio_str=self.MAXIMUM_RATIO_STR,
+                ),
+                input_image=(
+                    input_image
+                    if input_image is None
+                    else convert_image_to_base64(input_image)
+                )
+            ),
+            auth_kwargs=kwargs,
+        )
+        output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
+        return (output_image,)
+
+
+class FluxKontextMaxImageNode(FluxKontextProImageNode):
+    """
+    Edits images using Flux.1 Kontext [max] via api based on prompt and aspect ratio.
+    """
+
+    DESCRIPTION = cleandoc(__doc__ or "")
+    BFL_PATH = "/proxy/bfl/flux-kontext-max/generate"
+
+
 class FluxProImageNode(ComfyNodeABC):
     """
@@ -914,6 +1067,8 @@ class FluxProDepthNode(ComfyNodeABC):
 NODE_CLASS_MAPPINGS = {
     "FluxProUltraImageNode": FluxProUltraImageNode,
     # "FluxProImageNode": FluxProImageNode,
+    "FluxKontextProImageNode": FluxKontextProImageNode,
+    "FluxKontextMaxImageNode": FluxKontextMaxImageNode,
     "FluxProExpandNode": FluxProExpandNode,
     "FluxProFillNode": FluxProFillNode,
     "FluxProCannyNode": FluxProCannyNode,
@@ -924,6 +1079,8 @@ NODE_CLASS_MAPPINGS = {
 NODE_DISPLAY_NAME_MAPPINGS = {
     "FluxProUltraImageNode": "Flux 1.1 [pro] Ultra Image",
     # "FluxProImageNode": "Flux 1.1 [pro] Image",
+    "FluxKontextProImageNode": "Flux.1 Kontext [pro] Image",
+    "FluxKontextMaxImageNode": "Flux.1 Kontext [max] Image",
     "FluxProExpandNode": "Flux.1 Expand Image",
     "FluxProFillNode": "Flux.1 Fill Image",
     "FluxProCannyNode": "Flux.1 Canny Control Image",
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import io
|
import io
|
||||||
from typing import Optional, TypeVar
|
|
||||||
import logging
|
import logging
|
||||||
import torch
|
from typing import Optional, TypeVar
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeOptions
|
||||||
|
from comfy_api.input_impl import VideoFromFile
|
||||||
|
from comfy_api.input_impl.video_types import VideoCodec, VideoContainer, VideoInput
|
||||||
|
from comfy_api_nodes.apinode_utils import (
|
||||||
|
download_url_to_video_output,
|
||||||
|
tensor_to_bytesio,
|
||||||
|
)
|
||||||
from comfy_api_nodes.apis import (
|
from comfy_api_nodes.apis import (
|
||||||
PikaBodyGenerate22T2vGenerate22T2vPost,
|
|
||||||
PikaGenerateResponse,
|
|
||||||
PikaBodyGenerate22I2vGenerate22I2vPost,
|
|
||||||
PikaVideoResponse,
|
|
||||||
PikaBodyGenerate22C2vGenerate22PikascenesPost,
|
|
||||||
IngredientsMode,
|
IngredientsMode,
|
||||||
PikaDurationEnum,
|
PikaBodyGenerate22C2vGenerate22PikascenesPost,
|
||||||
PikaResolutionEnum,
|
PikaBodyGenerate22I2vGenerate22I2vPost,
|
||||||
PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
|
|
||||||
PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
|
|
||||||
PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
|
|
||||||
PikaBodyGenerate22KeyframeGenerate22PikaframesPost,
|
PikaBodyGenerate22KeyframeGenerate22PikaframesPost,
|
||||||
|
PikaBodyGenerate22T2vGenerate22T2vPost,
|
||||||
|
PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
|
||||||
|
PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
|
||||||
|
PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
|
||||||
|
PikaDurationEnum,
|
||||||
Pikaffect,
|
Pikaffect,
|
||||||
|
PikaGenerateResponse,
|
||||||
|
PikaResolutionEnum,
|
||||||
|
PikaVideoResponse,
|
||||||
)
|
)
|
||||||
from comfy_api_nodes.apis.client import (
|
from comfy_api_nodes.apis.client import (
|
||||||
ApiEndpoint,
|
ApiEndpoint,
|
||||||
HttpMethod,
|
|
||||||
SynchronousOperation,
|
|
||||||
PollingOperation,
|
|
||||||
EmptyRequest,
|
EmptyRequest,
|
||||||
)
|
HttpMethod,
|
||||||
from comfy_api_nodes.apinode_utils import (
|
PollingOperation,
|
||||||
tensor_to_bytesio,
|
SynchronousOperation,
|
||||||
download_url_to_video_output,
|
|
||||||
)
|
)
|
||||||
from comfy_api_nodes.mapper_utils import model_field_to_node_input
|
from comfy_api_nodes.mapper_utils import model_field_to_node_input
|
||||||
from comfy_api.input_impl.video_types import VideoInput, VideoContainer, VideoCodec
|
|
||||||
from comfy_api.input_impl import VideoFromFile
|
|
||||||
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeOptions
|
|
||||||
|
|
||||||
R = TypeVar("R")
|
R = TypeVar("R")
|
||||||
|
|
||||||
@@ -204,6 +206,7 @@ class PikaImageToVideoV2_2(PikaNodeBase):
|
|||||||
"hidden": {
|
"hidden": {
|
||||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||||
|
"unique_id": "UNIQUE_ID",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -457,7 +460,7 @@ class PikAdditionsNode(PikaNodeBase):
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
DESCRIPTION = "Add any object or image into your video. Upload a video and specify what you’d like to add to create a seamlessly integrated result."
|
DESCRIPTION = "Add any object or image into your video. Upload a video and specify what you'd like to add to create a seamlessly integrated result."
|
||||||
|
|
||||||
def api_call(
|
def api_call(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -14,8 +14,10 @@ import re
 from io import BytesIO
 from inspect import cleandoc
 import torch
+import comfy.utils
 
-from comfy.comfy_types import FileLocator
+from comfy.comfy_types import FileLocator, IO
+from server import PromptServer
 
 MAX_RESOLUTION = nodes.MAX_RESOLUTION
 
@@ -229,6 +231,186 @@ class SVG:
             all_svgs_list.extend(svg_item.data)
         return SVG(all_svgs_list)
 
+
+class ImageStitch:
+    """Upstreamed from https://github.com/kijai/ComfyUI-KJNodes"""
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "image1": ("IMAGE",),
+                "direction": (["right", "down", "left", "up"], {"default": "right"}),
+                "match_image_size": ("BOOLEAN", {"default": True}),
+                "spacing_width": (
+                    "INT",
+                    {"default": 0, "min": 0, "max": 1024, "step": 2},
+                ),
+                "spacing_color": (
+                    ["white", "black", "red", "green", "blue"],
+                    {"default": "white"},
+                ),
+            },
+            "optional": {
+                "image2": ("IMAGE",),
+            },
+        }
+
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "stitch"
+    CATEGORY = "image/transform"
+    DESCRIPTION = """
+Stitches image2 to image1 in the specified direction.
+If image2 is not provided, returns image1 unchanged.
+Optional spacing can be added between images.
+"""
+
+    def stitch(
+        self,
+        image1,
+        direction,
+        match_image_size,
+        spacing_width,
+        spacing_color,
+        image2=None,
+    ):
+        if image2 is None:
+            return (image1,)
+
+        # Handle batch size differences
+        if image1.shape[0] != image2.shape[0]:
+            max_batch = max(image1.shape[0], image2.shape[0])
+            if image1.shape[0] < max_batch:
+                image1 = torch.cat(
+                    [image1, image1[-1:].repeat(max_batch - image1.shape[0], 1, 1, 1)]
+                )
+            if image2.shape[0] < max_batch:
+                image2 = torch.cat(
+                    [image2, image2[-1:].repeat(max_batch - image2.shape[0], 1, 1, 1)]
+                )
+
+        # Match image sizes if requested
+        if match_image_size:
+            h1, w1 = image1.shape[1:3]
+            h2, w2 = image2.shape[1:3]
+            aspect_ratio = w2 / h2
+
+            if direction in ["left", "right"]:
+                target_h, target_w = h1, int(h1 * aspect_ratio)
+            else:  # up, down
+                target_w, target_h = w1, int(w1 / aspect_ratio)
+
+            image2 = comfy.utils.common_upscale(
+                image2.movedim(-1, 1), target_w, target_h, "lanczos", "disabled"
+            ).movedim(1, -1)
+
+        # When not matching sizes, pad to align non-concat dimensions
+        if not match_image_size:
+            h1, w1 = image1.shape[1:3]
+            h2, w2 = image2.shape[1:3]
+
+            if direction in ["left", "right"]:
+                # For horizontal concat, pad heights to match
+                if h1 != h2:
+                    target_h = max(h1, h2)
+                    if h1 < target_h:
+                        pad_h = target_h - h1
+                        pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2
+                        image1 = torch.nn.functional.pad(image1, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=0.0)
+                    if h2 < target_h:
+                        pad_h = target_h - h2
+                        pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2
+                        image2 = torch.nn.functional.pad(image2, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=0.0)
+            else:  # up, down
+                # For vertical concat, pad widths to match
+                if w1 != w2:
+                    target_w = max(w1, w2)
+                    if w1 < target_w:
+                        pad_w = target_w - w1
+                        pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2
+                        image1 = torch.nn.functional.pad(image1, (0, 0, pad_left, pad_right), mode='constant', value=0.0)
+                    if w2 < target_w:
+                        pad_w = target_w - w2
+                        pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2
+                        image2 = torch.nn.functional.pad(image2, (0, 0, pad_left, pad_right), mode='constant', value=0.0)
+
+        # Ensure same number of channels
+        if image1.shape[-1] != image2.shape[-1]:
+            max_channels = max(image1.shape[-1], image2.shape[-1])
+            if image1.shape[-1] < max_channels:
+                image1 = torch.cat(
+                    [
+                        image1,
+                        torch.ones(
+                            *image1.shape[:-1],
+                            max_channels - image1.shape[-1],
+                            device=image1.device,
+                        ),
+                    ],
+                    dim=-1,
+                )
+            if image2.shape[-1] < max_channels:
+                image2 = torch.cat(
+                    [
+                        image2,
+                        torch.ones(
+                            *image2.shape[:-1],
+                            max_channels - image2.shape[-1],
+                            device=image2.device,
+                        ),
+                    ],
+                    dim=-1,
+                )
+
+        # Add spacing if specified
+        if spacing_width > 0:
+            spacing_width = spacing_width + (spacing_width % 2)  # Ensure even
+
+            color_map = {
+                "white": 1.0,
+                "black": 0.0,
+                "red": (1.0, 0.0, 0.0),
+                "green": (0.0, 1.0, 0.0),
+                "blue": (0.0, 0.0, 1.0),
+            }
+            color_val = color_map[spacing_color]
+
+            if direction in ["left", "right"]:
+                spacing_shape = (
+                    image1.shape[0],
+                    max(image1.shape[1], image2.shape[1]),
+                    spacing_width,
+                    image1.shape[-1],
+                )
+            else:
+                spacing_shape = (
+                    image1.shape[0],
+                    spacing_width,
+                    max(image1.shape[2], image2.shape[2]),
+                    image1.shape[-1],
+                )
+
+            spacing = torch.full(spacing_shape, 0.0, device=image1.device)
+            if isinstance(color_val, tuple):
+                for i, c in enumerate(color_val):
+                    if i < spacing.shape[-1]:
+                        spacing[..., i] = c
+                if spacing.shape[-1] == 4:  # Add alpha
+                    spacing[..., 3] = 1.0
+            else:
+                spacing[..., : min(3, spacing.shape[-1])] = color_val
+                if spacing.shape[-1] == 4:
+                    spacing[..., 3] = 1.0
+
+        # Concatenate images
+        images = [image2, image1] if direction in ["left", "up"] else [image1, image2]
+        if spacing_width > 0:
+            images.insert(1, spacing)
+
+        concat_dim = 2 if direction in ["left", "right"] else 1
+        return (torch.cat(images, dim=concat_dim),)
+
+
 class SaveSVGNode:
     """
     Save SVG files on disk.
@@ -310,6 +492,36 @@ class SaveSVGNode:
                 counter += 1
         return { "ui": { "images": results } }
 
+
+class GetImageSize:
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "image": (IO.IMAGE,),
+            },
+            "hidden": {
+                "unique_id": "UNIQUE_ID",
+            }
+        }
+
+    RETURN_TYPES = (IO.INT, IO.INT)
+    RETURN_NAMES = ("width", "height")
+    FUNCTION = "get_size"
+
+    CATEGORY = "image"
+    DESCRIPTION = """Returns width and height of the image, and passes it through unchanged."""
+
+    def get_size(self, image, unique_id=None) -> tuple[int, int]:
+        height = image.shape[1]
+        width = image.shape[2]
+
+        # Send progress text to display size on the node
+        if unique_id:
+            PromptServer.instance.send_progress_text(f"width: {width}, height: {height}", unique_id)
+
+        return width, height
+
+
 NODE_CLASS_MAPPINGS = {
     "ImageCrop": ImageCrop,
     "RepeatImageBatch": RepeatImageBatch,
@@ -318,4 +530,6 @@ NODE_CLASS_MAPPINGS = {
     "SaveAnimatedWEBP": SaveAnimatedWEBP,
     "SaveAnimatedPNG": SaveAnimatedPNG,
     "SaveSVGNode": SaveSVGNode,
+    "ImageStitch": ImageStitch,
+    "GetImageSize": GetImageSize,
 }
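Both new nodes rely on ComfyUI's IMAGE tensor layout of [batch, height, width, channels]: GetImageSize reads shape[1]/shape[2], and ImageStitch concatenates on dim=2 (width) or dim=1 (height). A quick illustration with dummy tensors (the sizes are arbitrary):

```python
# Dummy tensors in ComfyUI's [B, H, W, C] layout.
import torch

img1 = torch.zeros(1, 480, 640, 3)
img2 = torch.zeros(1, 480, 320, 3)

height, width = img1.shape[1], img1.shape[2]     # what GetImageSize returns
side_by_side = torch.cat([img1, img2], dim=2)    # "right"/"left" stitching axis
stacked = torch.cat([img1, img1], dim=1)         # "up"/"down" stitching axis
print(height, width, tuple(side_by_side.shape), tuple(stacked.shape))
```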
@@ -296,6 +296,41 @@ class RegexExtract():
 
         return result,
 
+
+class RegexReplace():
+    DESCRIPTION = "Find and replace text using regex patterns."
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "string": (IO.STRING, {"multiline": True}),
+                "regex_pattern": (IO.STRING, {"multiline": True}),
+                "replace": (IO.STRING, {"multiline": True}),
+            },
+            "optional": {
+                "case_insensitive": (IO.BOOLEAN, {"default": True}),
+                "multiline": (IO.BOOLEAN, {"default": False}),
+                "dotall": (IO.BOOLEAN, {"default": False, "tooltip": "When enabled, the dot (.) character will match any character including newline characters. When disabled, dots won't match newlines."}),
+                "count": (IO.INT, {"default": 0, "min": 0, "max": 100, "tooltip": "Maximum number of replacements to make. Set to 0 to replace all occurrences (default). Set to 1 to replace only the first match, 2 for the first two matches, etc."}),
+            }
+        }
+
+    RETURN_TYPES = (IO.STRING,)
+    FUNCTION = "execute"
+    CATEGORY = "utils/string"
+
+    def execute(self, string, regex_pattern, replace, case_insensitive=True, multiline=False, dotall=False, count=0, **kwargs):
+        flags = 0
+
+        if case_insensitive:
+            flags |= re.IGNORECASE
+        if multiline:
+            flags |= re.MULTILINE
+        if dotall:
+            flags |= re.DOTALL
+        result = re.sub(regex_pattern, replace, string, count=count, flags=flags)
+        return result,
+
 NODE_CLASS_MAPPINGS = {
     "StringConcatenate": StringConcatenate,
     "StringSubstring": StringSubstring,
@@ -306,7 +341,8 @@ NODE_CLASS_MAPPINGS = {
     "StringContains": StringContains,
     "StringCompare": StringCompare,
     "RegexMatch": RegexMatch,
-    "RegexExtract": RegexExtract
+    "RegexExtract": RegexExtract,
+    "RegexReplace": RegexReplace,
 }
 
 NODE_DISPLAY_NAME_MAPPINGS = {
@@ -319,5 +355,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "StringContains": "Contains",
     "StringCompare": "Compare",
     "RegexMatch": "Regex Match",
-    "RegexExtract": "Regex Extract"
+    "RegexExtract": "Regex Extract",
+    "RegexReplace": "Regex Replace",
 }
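RegexReplace is a thin wrapper over re.sub: the three booleans map to re.IGNORECASE, re.MULTILINE and re.DOTALL, and count=0 means replace every occurrence. A short standard-library illustration (the strings are examples):

```python
# Standalone re.sub demo mirroring the node's flag handling.
import re

flags = re.IGNORECASE | re.MULTILINE
text = "Foo bar\nfoo baz"

print(re.sub(r"^foo", "qux", text, count=0, flags=flags))  # both lines replaced
print(re.sub(r"^foo", "qux", text, count=1, flags=flags))  # only the first match
```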
@@ -345,6 +345,44 @@ class WanCameraImageToVideo:
|
|||||||
out_latent["samples"] = latent
|
out_latent["samples"] = latent
|
||||||
return (positive, negative, out_latent)
|
return (positive, negative, out_latent)
|
||||||
|
|
||||||
|
class WanPhantomSubjectToVideo:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": {"positive": ("CONDITIONING", ),
|
||||||
|
"negative": ("CONDITIONING", ),
|
||||||
|
"vae": ("VAE", ),
|
||||||
|
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||||
|
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||||
|
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
|
||||||
|
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||||
|
},
|
||||||
|
"optional": {"images": ("IMAGE", ),
|
||||||
|
}}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "CONDITIONING", "LATENT")
|
||||||
|
RETURN_NAMES = ("positive", "negative_text", "negative_img_text", "latent")
|
||||||
|
FUNCTION = "encode"
|
||||||
|
|
||||||
|
CATEGORY = "conditioning/video_models"
|
||||||
|
|
||||||
|
def encode(self, positive, negative, vae, width, height, length, batch_size, images):
|
||||||
|
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||||
|
cond2 = negative
|
||||||
|
if images is not None:
|
||||||
|
images = comfy.utils.common_upscale(images[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||||
|
latent_images = []
|
||||||
|
for i in images:
|
||||||
|
latent_images += [vae.encode(i.unsqueeze(0)[:, :, :, :3])]
|
||||||
|
concat_latent_image = torch.cat(latent_images, dim=2)
|
||||||
|
|
||||||
|
positive = node_helpers.conditioning_set_values(positive, {"time_dim_concat": concat_latent_image})
|
||||||
|
cond2 = node_helpers.conditioning_set_values(negative, {"time_dim_concat": concat_latent_image})
|
||||||
|
negative = node_helpers.conditioning_set_values(negative, {"time_dim_concat": comfy.latent_formats.Wan21().process_out(torch.zeros_like(concat_latent_image))})
|
||||||
|
|
||||||
|
out_latent = {}
|
||||||
|
out_latent["samples"] = latent
|
||||||
|
return (positive, cond2, negative, out_latent)
|
||||||
|
|
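The empty latent allocated in encode() follows the Wan 2.1 layout: 16 channels, ((length - 1) // 4) + 1 latent frames on the temporal axis, and spatial dimensions divided by 8. A quick arithmetic check with the node's default settings (this only restates the formula above; it is not code from the changeset):

# Defaults from INPUT_TYPES: width=832, height=480, length=81, batch_size=1.
batch_size, width, height, length = 1, 832, 480, 81

temporal = ((length - 1) // 4) + 1             # (81 - 1) // 4 + 1 = 21 latent frames
latent_h, latent_w = height // 8, width // 8   # 60 x 104

print([batch_size, 16, temporal, latent_h, latent_w])  # [1, 16, 21, 60, 104]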
NODE_CLASS_MAPPINGS = {
    "WanImageToVideo": WanImageToVideo,
    "WanFunControlToVideo": WanFunControlToVideo,
@@ -353,4 +391,5 @@ NODE_CLASS_MAPPINGS = {
    "WanVaceToVideo": WanVaceToVideo,
    "TrimVideoLatent": TrimVideoLatent,
    "WanCameraImageToVideo": WanCameraImageToVideo,
    "WanPhantomSubjectToVideo": WanPhantomSubjectToVideo,
}
@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
-__version__ = "0.3.36"
+__version__ = "0.3.39"
2 nodes.py
@@ -2061,11 +2061,13 @@ NODE_DISPLAY_NAME_MAPPINGS = {
    "ImagePadForOutpaint": "Pad Image for Outpainting",
    "ImageBatch": "Batch Images",
    "ImageCrop": "Image Crop",
    "ImageStitch": "Image Stitch",
    "ImageBlend": "Image Blend",
    "ImageBlur": "Image Blur",
    "ImageQuantize": "Image Quantize",
    "ImageSharpen": "Image Sharpen",
    "ImageScaleToTotalPixels": "Scale Image to Total Pixels",
    "GetImageSize": "Get Image Size",
    # _for_testing
    "VAEDecodeTiled": "VAE Decode (Tiled)",
    "VAEEncodeTiled": "VAE Encode (Tiled)",
@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
-version = "0.3.36"
+version = "0.3.39"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"
@@ -1,5 +1,6 @@
-comfyui-frontend-package==1.20.6
+comfyui-frontend-package==1.21.5
-comfyui-workflow-templates==0.1.20
+comfyui-workflow-templates==0.1.25
+comfyui-embedded-docs==0.2.0
torch
torchsde
torchvision
@@ -746,6 +746,13 @@ class PromptServer():
                web.static('/templates', workflow_templates_path)
            ])

        # Serve embedded documentation from the package
        embedded_docs_path = FrontendManager.embedded_docs_path()
        if embedded_docs_path:
            self.app.add_routes([
                web.static('/docs', embedded_docs_path)
            ])

        self.app.add_routes([
            web.static('/', self.web_root),
        ])
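The new block serves the comfyui-embedded-docs package (added to requirements.txt above) under /docs, guarded so the route is only registered when the package is actually installed. A minimal sketch of the same aiohttp pattern, outside of PromptServer and with made-up directory names:

from pathlib import Path
from aiohttp import web

app = web.Application()

docs_path = Path("embedded_docs")
if docs_path.is_dir():                   # only mount the route when the docs exist
    app.add_routes([web.static('/docs', docs_path)])

web_root = Path("web")
if web_root.is_dir():
    app.add_routes([web.static('/', web_root)])

if __name__ == "__main__":
    web.run_app(app, port=8188)          # ComfyUI's default port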
0 tests-unit/comfy_extras_test/__init__.py Normal file
243 tests-unit/comfy_extras_test/image_stitch_test.py Normal file
@@ -0,0 +1,243 @@
import torch
from unittest.mock import patch, MagicMock

# Mock nodes module to prevent CUDA initialization during import
mock_nodes = MagicMock()
mock_nodes.MAX_RESOLUTION = 16384

# Mock server module for PromptServer
mock_server = MagicMock()

with patch.dict('sys.modules', {'nodes': mock_nodes, 'server': mock_server}):
    from comfy_extras.nodes_images import ImageStitch


class TestImageStitch:

    def create_test_image(self, batch_size=1, height=64, width=64, channels=3):
        """Helper to create test images with specific dimensions"""
        return torch.rand(batch_size, height, width, channels)
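The patch.dict('sys.modules', ...) block above is what lets this test import comfy_extras.nodes_images on a machine with no GPU: any module stubbed into sys.modules before the import is returned instead of the real one, so the import never triggers the heavy initialization. The same pattern in isolation (heavy_gpu_module is a made-up name for illustration):

from unittest.mock import MagicMock, patch

# Stub a module before importing code that depends on it, so the dependent
# import resolves to the lightweight fake instead of the real package.
fake = MagicMock()
fake.MAX_RESOLUTION = 16384

with patch.dict('sys.modules', {'heavy_gpu_module': fake}):
    import heavy_gpu_module                      # resolves to the MagicMock stub
    assert heavy_gpu_module.MAX_RESOLUTION == 16384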
    def test_no_image2_passthrough(self):
        """Test that when image2 is None, image1 is returned unchanged"""
        node = ImageStitch()
        image1 = self.create_test_image()

        result = node.stitch(image1, "right", True, 0, "white", image2=None)

        assert len(result) == 1
        assert torch.equal(result[0], image1)

    def test_basic_horizontal_stitch_right(self):
        """Test basic horizontal stitching to the right"""
        node = ImageStitch()
        image1 = self.create_test_image(height=32, width=32)
        image2 = self.create_test_image(height=32, width=24)

        result = node.stitch(image1, "right", False, 0, "white", image2)

        assert result[0].shape == (1, 32, 56, 3)  # 32 + 24 width

    def test_basic_horizontal_stitch_left(self):
        """Test basic horizontal stitching to the left"""
        node = ImageStitch()
        image1 = self.create_test_image(height=32, width=32)
        image2 = self.create_test_image(height=32, width=24)

        result = node.stitch(image1, "left", False, 0, "white", image2)

        assert result[0].shape == (1, 32, 56, 3)  # 24 + 32 width

    def test_basic_vertical_stitch_down(self):
        """Test basic vertical stitching downward"""
        node = ImageStitch()
        image1 = self.create_test_image(height=32, width=32)
        image2 = self.create_test_image(height=24, width=32)

        result = node.stitch(image1, "down", False, 0, "white", image2)

        assert result[0].shape == (1, 56, 32, 3)  # 32 + 24 height

    def test_basic_vertical_stitch_up(self):
        """Test basic vertical stitching upward"""
        node = ImageStitch()
        image1 = self.create_test_image(height=32, width=32)
        image2 = self.create_test_image(height=24, width=32)

        result = node.stitch(image1, "up", False, 0, "white", image2)

        assert result[0].shape == (1, 56, 32, 3)  # 24 + 32 height
    def test_size_matching_horizontal(self):
        """Test size matching for horizontal concatenation"""
        node = ImageStitch()
        image1 = self.create_test_image(height=64, width=64)
        image2 = self.create_test_image(height=32, width=32)  # Different aspect ratio

        result = node.stitch(image1, "right", True, 0, "white", image2)

        # image2 should be resized to match image1's height (64) with preserved aspect ratio
        expected_width = 64 + 64  # original + resized (32*64/32 = 64)
        assert result[0].shape == (1, 64, expected_width, 3)

    def test_size_matching_vertical(self):
        """Test size matching for vertical concatenation"""
        node = ImageStitch()
        image1 = self.create_test_image(height=64, width=64)
        image2 = self.create_test_image(height=32, width=32)

        result = node.stitch(image1, "down", True, 0, "white", image2)

        # image2 should be resized to match image1's width (64) with preserved aspect ratio
        expected_height = 64 + 64  # original + resized (32*64/32 = 64)
        assert result[0].shape == (1, expected_height, 64, 3)

    def test_padding_for_mismatched_heights_horizontal(self):
        """Test padding when heights don't match in horizontal concatenation"""
        node = ImageStitch()
        image1 = self.create_test_image(height=64, width=32)
        image2 = self.create_test_image(height=48, width=24)  # Shorter height

        result = node.stitch(image1, "right", False, 0, "white", image2)

        # Both images should be padded to height 64
        assert result[0].shape == (1, 64, 56, 3)  # 32 + 24 width, max(64,48) height

    def test_padding_for_mismatched_widths_vertical(self):
        """Test padding when widths don't match in vertical concatenation"""
        node = ImageStitch()
        image1 = self.create_test_image(height=32, width=64)
        image2 = self.create_test_image(height=24, width=48)  # Narrower width

        result = node.stitch(image1, "down", False, 0, "white", image2)

        # Both images should be padded to width 64
        assert result[0].shape == (1, 56, 64, 3)  # 32 + 24 height, max(64,48) width
    def test_spacing_horizontal(self):
        """Test spacing addition in horizontal concatenation"""
        node = ImageStitch()
        image1 = self.create_test_image(height=32, width=32)
        image2 = self.create_test_image(height=32, width=24)
        spacing_width = 16

        result = node.stitch(image1, "right", False, spacing_width, "white", image2)

        # Expected width: 32 + 16 (spacing) + 24 = 72
        assert result[0].shape == (1, 32, 72, 3)

    def test_spacing_vertical(self):
        """Test spacing addition in vertical concatenation"""
        node = ImageStitch()
        image1 = self.create_test_image(height=32, width=32)
        image2 = self.create_test_image(height=24, width=32)
        spacing_width = 16

        result = node.stitch(image1, "down", False, spacing_width, "white", image2)

        # Expected height: 32 + 16 (spacing) + 24 = 72
        assert result[0].shape == (1, 72, 32, 3)

    def test_spacing_color_values(self):
        """Test that spacing colors are applied correctly"""
        node = ImageStitch()
        image1 = self.create_test_image(height=32, width=32)
        image2 = self.create_test_image(height=32, width=32)

        # Test white spacing
        result_white = node.stitch(image1, "right", False, 16, "white", image2)
        # Check that spacing region contains white values (close to 1.0)
        spacing_region = result_white[0][:, :, 32:48, :]  # Middle 16 pixels
        assert torch.all(spacing_region >= 0.9)  # Should be close to white

        # Test black spacing
        result_black = node.stitch(image1, "right", False, 16, "black", image2)
        spacing_region = result_black[0][:, :, 32:48, :]
        assert torch.all(spacing_region <= 0.1)  # Should be close to black

    def test_odd_spacing_width_made_even(self):
        """Test that odd spacing widths are made even"""
        node = ImageStitch()
        image1 = self.create_test_image(height=32, width=32)
        image2 = self.create_test_image(height=32, width=32)

        # Use odd spacing width
        result = node.stitch(image1, "right", False, 15, "white", image2)

        # Should be made even (16), so total width = 32 + 16 + 32 = 80
        assert result[0].shape == (1, 32, 80, 3)
    def test_batch_size_matching(self):
        """Test that different batch sizes are handled correctly"""
        node = ImageStitch()
        image1 = self.create_test_image(batch_size=2, height=32, width=32)
        image2 = self.create_test_image(batch_size=1, height=32, width=32)

        result = node.stitch(image1, "right", False, 0, "white", image2)

        # Should match larger batch size
        assert result[0].shape == (2, 32, 64, 3)

    def test_channel_matching_rgb_to_rgba(self):
        """Test that channel differences are handled (RGB + alpha)"""
        node = ImageStitch()
        image1 = self.create_test_image(channels=3)  # RGB
        image2 = self.create_test_image(channels=4)  # RGBA

        result = node.stitch(image1, "right", False, 0, "white", image2)

        # Should have 4 channels (RGBA)
        assert result[0].shape[-1] == 4

    def test_channel_matching_rgba_to_rgb(self):
        """Test that channel differences are handled (RGBA + RGB)"""
        node = ImageStitch()
        image1 = self.create_test_image(channels=4)  # RGBA
        image2 = self.create_test_image(channels=3)  # RGB

        result = node.stitch(image1, "right", False, 0, "white", image2)

        # Should have 4 channels (RGBA)
        assert result[0].shape[-1] == 4

    def test_all_color_options(self):
        """Test all available color options"""
        node = ImageStitch()
        image1 = self.create_test_image(height=32, width=32)
        image2 = self.create_test_image(height=32, width=32)

        colors = ["white", "black", "red", "green", "blue"]

        for color in colors:
            result = node.stitch(image1, "right", False, 16, color, image2)
            assert result[0].shape == (1, 32, 80, 3)  # Basic shape check

    def test_all_directions(self):
        """Test all direction options"""
        node = ImageStitch()
        image1 = self.create_test_image(height=32, width=32)
        image2 = self.create_test_image(height=32, width=32)

        directions = ["right", "left", "up", "down"]

        for direction in directions:
            result = node.stitch(image1, direction, False, 0, "white", image2)
            # Horizontal directions double the width; vertical directions double the height
            assert result[0].shape == ((1, 32, 64, 3) if direction in ["right", "left"] else (1, 64, 32, 3))

    def test_batch_size_channel_spacing_integration(self):
        """Test integration of batch matching, channel matching, size matching, and spacings"""
        node = ImageStitch()
        image1 = self.create_test_image(batch_size=2, height=64, width=48, channels=3)
        image2 = self.create_test_image(batch_size=1, height=32, width=32, channels=4)

        result = node.stitch(image1, "right", True, 8, "red", image2)

        # Should handle: batch matching, size matching, channel matching, spacing
        assert result[0].shape[0] == 2  # Batch size matched
        assert result[0].shape[-1] == 4  # Channels matched to max
        assert result[0].shape[1] == 64  # Height from image1 (size matching)
        # Width should be: 48 + 8 (spacing) + resized_image2_width
        expected_image2_width = int(64 * (32/32))  # Resized to height 64
        expected_total_width = 48 + 8 + expected_image2_width
        assert result[0].shape[2] == expected_total_width