Compare commits

3 Commits: v3-definit ... openapi-sp

| Author | SHA1 | Date |
|---|---|---|
|  | d65ad9940b |  |
|  | e8a92e4c9b |  |
|  | fa9688b1fb |  |
49  .github/workflows/openapi-validation.yml  (vendored, new file)
@@ -0,0 +1,49 @@
name: Validate OpenAPI

on:
  push:
    branches: [ master ]
    paths:
      - 'openapi.yaml'
  pull_request:
    branches: [ master ]
    paths:
      - 'openapi.yaml'

jobs:
  openapi-check:
    runs-on: ubuntu-latest

    # Service containers to run with `runner-job`
    services:
      # Label used to access the service container
      swagger-editor:
        # Docker Hub image
        image: swaggerapi/swagger-editor
        ports:
          # Maps port 8080 on service container to the host 80
          - 80:8080

    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Validate OpenAPI definition
        uses: swaggerexpert/swagger-editor-validate@v1
        with:
          definition-file: openapi.yaml
          swagger-editor-url: http://localhost/
          default-timeout: 20000

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.11'

      - name: Install test dependencies
        run: |
          pip install -r tests-api/requirements.txt

      - name: Run OpenAPI spec validation tests
        run: |
          pytest tests-api/test_spec_validation.py -v
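For local debugging, the same validation can be reproduced outside Actions. A minimal sketch, assuming Docker is installed and you run it from the repository root (the image, file paths, and test command come from the workflow above):

```python
# Sketch: mirror the CI job locally — service container, deps, then tests.
import subprocess

# Start swagger-editor, mapping host port 80 to the container's 8080 (as in `services:` above).
subprocess.run(["docker", "run", "-d", "-p", "80:8080", "swaggerapi/swagger-editor"], check=True)
# Install test dependencies and run the OpenAPI spec validation tests.
subprocess.run(["pip", "install", "-r", "tests-api/requirements.txt"], check=True)
subprocess.run(["pytest", "tests-api/test_spec_validation.py", "-v"], check=True)
```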
1  .gitignore  (vendored)
@@ -21,6 +21,5 @@ venv/
*.log
web_custom_versions/
.DS_Store
openapi.yaml
filtered-openapi.yaml
uv.lock
26  CODEOWNERS
@@ -5,20 +5,20 @@
# Inlined the team members for now.

# Maintainers
*.md @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/tests/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/tests-unit/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/notebooks/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/script_examples/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/.github/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/requirements.txt @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/pyproject.toml @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
*.md @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/tests/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/tests-unit/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/notebooks/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/script_examples/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/.github/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/requirements.txt @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/pyproject.toml @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne

# Python web server
/api_server/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
/app/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
/utils/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
/api_server/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne
/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne
/utils/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne

# Node developers
/comfy_extras/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
/comfy/comfy_types/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
/comfy/comfy_types/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
@@ -205,19 +205,6 @@ comfyui-workflow-templates is not installed.
""".strip()
        )

    @classmethod
    def embedded_docs_path(cls) -> str:
        """Get the path to embedded documentation"""
        try:
            import comfyui_embedded_docs

            return str(
                importlib.resources.files(comfyui_embedded_docs) / "docs"
            )
        except ImportError:
            logging.info("comfyui-embedded-docs package not found")
            return None

    @classmethod
    def parse_version_string(cls, value: str) -> tuple[str, str, str]:
        """
@@ -88,7 +88,6 @@ parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE"
parser.add_argument("--oneapi-device-selector", type=str, default=None, metavar="SELECTOR_STRING", help="Sets the oneAPI device(s) this instance will use.")
parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize default when loading models with Intel's Extension for Pytorch.")
parser.add_argument("--supports-fp8-compute", action="store_true", help="ComfyUI will act like if the device supports fp8 compute.")

class LatentPreviewMethod(enum.Enum):
    NoPreviews = "none"
@@ -24,10 +24,6 @@ class CONDRegular:
            conds.append(x.cond)
        return torch.cat(conds)

    def size(self):
        return list(self.cond.size())


class CONDNoiseShape(CONDRegular):
    def process_cond(self, batch_size, device, area, **kwargs):
        data = self.cond
@@ -68,7 +64,6 @@ class CONDCrossAttn(CONDRegular):
            out.append(c)
        return torch.cat(out)


class CONDConstant(CONDRegular):
    def __init__(self, cond):
        self.cond = cond

@@ -83,48 +78,3 @@ class CONDConstant(CONDRegular):

    def concat(self, others):
        return self.cond

    def size(self):
        return [1]


class CONDList(CONDRegular):
    def __init__(self, cond):
        self.cond = cond

    def process_cond(self, batch_size, device, **kwargs):
        out = []
        for c in self.cond:
            out.append(comfy.utils.repeat_to_batch_size(c, batch_size).to(device))

        return self._copy_with(out)

    def can_concat(self, other):
        if len(self.cond) != len(other.cond):
            return False
        for i in range(len(self.cond)):
            if self.cond[i].shape != other.cond[i].shape:
                return False

        return True

    def concat(self, others):
        out = []
        for i in range(len(self.cond)):
            o = [self.cond[i]]
            for x in others:
                o.append(x.cond[i])
            out.append(torch.cat(o))

        return out

    def size(self):  # hackish implementation to make the mem estimation work
        o = 0
        c = 1
        for c in self.cond:
            size = c.size()
            o += math.prod(size)
            if len(size) > 1:
                c = size[1]

        return [1, c, o // c]
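For context on the removed `CONDList`: `concat` zips conds element-wise, batching the i-th tensor of every cond together. A standalone sketch of that pairing with stand-in tensors (not ComfyUI code):

```python
# Element-wise batching as in CONDList.concat, with stand-in tensors.
import torch

self_cond = [torch.randn(1, 4), torch.randn(1, 8)]   # this cond's tensor list
others = [[torch.randn(1, 4), torch.randn(1, 8)]]    # other conds' tensor lists

out = []
for i in range(len(self_cond)):
    o = [self_cond[i]] + [x[i] for x in others]
    out.append(torch.cat(o))                          # batch dim grows per index

print([tuple(t.shape) for t in out])  # [(2, 4), (2, 8)]
```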
@@ -80,13 +80,15 @@ class DoubleStreamBlock(nn.Module):
        (img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec

        # prepare image for attention
        img_modulated = torch.addcmul(img_mod1.shift, 1 + img_mod1.scale, self.img_norm1(img))
        img_modulated = self.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = self.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = torch.addcmul(txt_mod1.shift, 1 + txt_mod1.scale, self.txt_norm1(txt))
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
@@ -100,12 +102,12 @@ class DoubleStreamBlock(nn.Module):
        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]

        # calculate the img bloks
        img.addcmul_(img_mod1.gate, self.img_attn.proj(img_attn))
        img.addcmul_(img_mod2.gate, self.img_mlp(torch.addcmul(img_mod2.shift, 1 + img_mod2.scale, self.img_norm2(img))))
        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
        img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)

        # calculate the txt bloks
        txt.addcmul_(txt_mod1.gate, self.txt_attn.proj(txt_attn))
        txt.addcmul_(txt_mod2.gate, self.txt_mlp(torch.addcmul(txt_mod2.shift, 1 + txt_mod2.scale, self.txt_norm2(txt))))
        txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
        txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)

        if txt.dtype == torch.float16:
            txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
@@ -150,7 +152,7 @@ class SingleStreamBlock(nn.Module):
    def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None) -> Tensor:
        mod = vec
        x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x))
        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

        q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)

@@ -160,7 +162,7 @@ class SingleStreamBlock(nn.Module):

        attn = attention(q, k, v, pe=pe, mask=attn_mask)
        # compute activation in mlp stream, cat again and run second linear layer
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        x.addcmul_(mod.gate, output)
        x += mod.gate * output
        if x.dtype == torch.float16:
            x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
        return x
@@ -176,6 +178,6 @@ class LastLayer(nn.Module):
        shift, scale = vec
        shift = shift.squeeze(1)
        scale = scale.squeeze(1)
        x = torch.addcmul(shift[:, None, :], 1 + scale[:, None, :], self.norm_final(x))
        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
        x = self.linear(x)
        return x
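Every modulation change in this file is the same algebraic rewrite: `torch.addcmul(shift, 1 + scale, h)` computes `shift + (1 + scale) * h` as one fused op (the in-place `addcmul_` variants also avoid an intermediate tensor). A quick equivalence check with stand-in shapes:

```python
# The fused and unfused modulation forms agree up to floating-point error.
import torch

h = torch.randn(2, 16, 64)       # stand-in for a normalized hidden state
shift = torch.randn(2, 1, 64)
scale = torch.randn(2, 1, 64)

fused = torch.addcmul(shift, 1 + scale, h)   # shift + (1 + scale) * h
unfused = (1 + scale) * h + shift
assert torch.allclose(fused, unfused, atol=1e-6)
```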
@@ -163,7 +163,7 @@ class Chroma(nn.Module):
        distil_guidance = timestep_embedding(guidance.detach().clone(), 16).to(img.device, img.dtype)

        # get all modulation index
        modulation_index = timestep_embedding(torch.arange(mod_index_length, device=img.device), 32).to(img.device, img.dtype)
        modulation_index = timestep_embedding(torch.arange(mod_index_length), 32).to(img.device, img.dtype)
        # we need to broadcast the modulation index here so each batch has all of the index
        modulation_index = modulation_index.unsqueeze(0).repeat(img.shape[0], 1, 1).to(img.device, img.dtype)
        # and we need to broadcast timestep and guidance along too
@@ -20,11 +20,8 @@ if model_management.xformers_enabled():
if model_management.sage_attention_enabled():
    try:
        from sageattention import sageattn
    except ModuleNotFoundError as e:
        if e.name == "sageattention":
            logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
        else:
            raise e
    except ModuleNotFoundError:
        logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
        exit(-1)

if model_management.flash_attention_enabled():
@@ -539,20 +539,13 @@ class WanModel(torch.nn.Module):
        x = self.unpatchify(x, grid_sizes)
        return x

    def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
    def forward(self, x, timestep, context, clip_fea=None, transformer_options={}, **kwargs):
        bs, c, t, h, w = x.shape
        x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)

        patch_size = self.patch_size
        t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
        h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
        w_len = ((w + (patch_size[2] // 2)) // patch_size[2])

        if time_dim_concat is not None:
            time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size)
            x = torch.cat([x, time_dim_concat], dim=2)
            t_len = ((x.shape[2] + (patch_size[0] // 2)) // patch_size[0])

        img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
        img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
        img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
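A note on the patch-count arithmetic (my gloss, not stated in the diff): Wan uses a patch size of (1, 2, 2), and for p of 1 or 2 the expression `(n + p // 2) // p` is exactly ceiling division, so frames appended via `time_dim_concat` along `dim=2` always extend `t_len` to cover a full final patch:

```python
# (n + p // 2) // p == ceil(n / p) for p in {1, 2}.
import math

def patch_len(n, p):
    return (n + (p // 2)) // p

for n in range(1, 10):
    assert patch_len(n, 1) == n
    assert patch_len(n, 2) == math.ceil(n / 2)
```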
@@ -642,7 +635,7 @@ class VaceWanModel(WanModel):
        t,
        context,
        vace_context,
        vace_strength,
        vace_strength=1.0,
        clip_fea=None,
        freqs=None,
        transformer_options={},
@@ -668,11 +661,8 @@ class VaceWanModel(WanModel):
            context = torch.concat([context_clip, context], dim=1)
            context_img_len = clip_fea.shape[-2]

        orig_shape = list(vace_context.shape)
        vace_context = vace_context.movedim(0, 1).reshape([-1] + orig_shape[2:])
        c = self.vace_patch_embedding(vace_context.float()).to(vace_context.dtype)
        c = c.flatten(2).transpose(1, 2)
        c = list(c.split(orig_shape[0], dim=0))

        # arguments
        x_orig = x
@@ -692,9 +682,8 @@ class VaceWanModel(WanModel):
            ii = self.vace_layers_mapping.get(i, None)
            if ii is not None:
                for iii in range(len(c)):
                    c_skip, c[iii] = self.vace_blocks[ii](c[iii], x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
                    x += c_skip * vace_strength[iii]
                c_skip, c = self.vace_blocks[ii](c, x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
                x += c_skip * vace_strength
                del c_skip
        # head
        x = self.head(x, e)
@@ -283,9 +283,8 @@ def model_lora_keys_unet(model, key_map={}):
        for k in sdk:
            if k.startswith("diffusion_model."):
                if k.endswith(".weight"):
                    key_lora = k[len("diffusion_model."):-len(".weight")]
                    key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k #SimpleTuner lycoris format
                    key_map["transformer.{}".format(key_lora)] = k #SimpleTuner regular format
                    key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
                    key_map["lycoris_{}".format(key_lora)] = k #SimpleTuner lycoris format

    if isinstance(model, comfy.model_base.ACEStep):
        for k in sdk:
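Both variants derive LoRA lookup keys from the state-dict name; the first additionally registers the SimpleTuner "regular" (`transformer.`-prefixed) form. For a hypothetical state-dict entry:

```python
# Key derivation for a hypothetical diffusion_model weight name.
k = "diffusion_model.double_blocks.0.img_attn.qkv.weight"

key_lora = k[len("diffusion_model."):-len(".weight")]    # double_blocks.0.img_attn.qkv
print("lycoris_{}".format(key_lora.replace(".", "_")))   # lycoris_double_blocks_0_img_attn_qkv
print("transformer.{}".format(key_lora))                 # transformer.double_blocks.0.img_attn.qkv
```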
@@ -135,7 +135,6 @@ class BaseModel(torch.nn.Module):
        logging.info("model_type {}".format(model_type.name))
        logging.debug("adm {}".format(self.adm_channels))
        self.memory_usage_factor = model_config.memory_usage_factor
        self.memory_usage_factor_conds = ()

    def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
        return comfy.patcher_extension.WrapperExecutor.new_class_executor(
@@ -168,11 +167,6 @@ class BaseModel(torch.nn.Module):
            if hasattr(extra, "dtype"):
                if extra.dtype != torch.int and extra.dtype != torch.long:
                    extra = extra.to(dtype)
            if isinstance(extra, list):
                ex = []
                for ext in extra:
                    ex.append(ext.to(dtype))
                extra = ex
            extra_conds[o] = extra

        t = self.process_timestep(t, x=x, **extra_conds)
@@ -331,28 +325,19 @@ class BaseModel(torch.nn.Module):
    def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
        return self.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(noise.shape) - 1)), noise, latent_image)

    def memory_required(self, input_shape, cond_shapes={}):
        input_shapes = [input_shape]
        for c in self.memory_usage_factor_conds:
            shape = cond_shapes.get(c, None)
            if shape is not None and len(shape) > 0:
                input_shapes += shape

    def memory_required(self, input_shape):
        if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
            dtype = self.get_dtype()
            if self.manual_cast_dtype is not None:
                dtype = self.manual_cast_dtype
            #TODO: this needs to be tweaked
            area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes))
            area = input_shape[0] * math.prod(input_shape[2:])
            return (area * comfy.model_management.dtype_size(dtype) * 0.01 * self.memory_usage_factor) * (1024 * 1024)
        else:
            #TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
            area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes))
            area = input_shape[0] * math.prod(input_shape[2:])
            return (area * 0.15 * self.memory_usage_factor) * (1024 * 1024)

    def extra_conds_shapes(self, **kwargs):
        return {}


def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0, seed=None):
    adm_inputs = []
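In both branches the estimate reduces to `area * constant * memory_usage_factor` mebibytes, with `area = batch * prod(spatial dims)`; the newer variant simply sums that area over the extra cond shapes as well. Plugging in illustrative numbers to show the scale:

```python
# Worked example of memory_required's arithmetic (illustrative values only).
import math

input_shape = [2, 4, 128, 128]                        # batch 2, 128x128 latent
area = input_shape[0] * math.prod(input_shape[2:])    # 2 * 128 * 128 = 32768

dtype_size = 2             # bytes for fp16
memory_usage_factor = 4.0  # hypothetical model-config value

flash = area * dtype_size * 0.01 * memory_usage_factor * (1024 * 1024)
fallback = area * 0.15 * memory_usage_factor * (1024 * 1024)
print(f"{flash / 1e9:.2f} GB vs {fallback / 1e9:.2f} GB")
```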
@@ -1062,11 +1047,6 @@ class WAN21(BaseModel):
        clip_vision_output = kwargs.get("clip_vision_output", None)
        if clip_vision_output is not None:
            out['clip_fea'] = comfy.conds.CONDRegular(clip_vision_output.penultimate_hidden_states)

        time_dim_concat = kwargs.get("time_dim_concat", None)
        if time_dim_concat is not None:
            out['time_dim_concat'] = comfy.conds.CONDRegular(self.process_latent_in(time_dim_concat))

        return out
@@ -1082,25 +1062,20 @@ class WAN21_Vace(WAN21):
        vace_frames = kwargs.get("vace_frames", None)
        if vace_frames is None:
            noise_shape[1] = 32
            vace_frames = [torch.zeros(noise_shape, device=noise.device, dtype=noise.dtype)]
            vace_frames = torch.zeros(noise_shape, device=noise.device, dtype=noise.dtype)

        for i in range(0, vace_frames.shape[1], 16):
            vace_frames = vace_frames.clone()
            vace_frames[:, i:i + 16] = self.process_latent_in(vace_frames[:, i:i + 16])

        mask = kwargs.get("vace_mask", None)
        if mask is None:
            noise_shape[1] = 64
            mask = [torch.ones(noise_shape, device=noise.device, dtype=noise.dtype)] * len(vace_frames)
            mask = torch.ones(noise_shape, device=noise.device, dtype=noise.dtype)

        vace_frames_out = []
        for j in range(len(vace_frames)):
            vf = vace_frames[j].clone()
            for i in range(0, vf.shape[1], 16):
                vf[:, i:i + 16] = self.process_latent_in(vf[:, i:i + 16])
            vf = torch.cat([vf, mask[j]], dim=1)
            vace_frames_out.append(vf)
        out['vace_context'] = comfy.conds.CONDRegular(torch.cat([vace_frames.to(noise), mask.to(noise)], dim=1))

        vace_frames = torch.stack(vace_frames_out, dim=1)
        out['vace_context'] = comfy.conds.CONDRegular(vace_frames)

        vace_strength = kwargs.get("vace_strength", [1.0] * len(vace_frames_out))
        vace_strength = kwargs.get("vace_strength", 1.0)
        out['vace_strength'] = comfy.conds.CONDConstant(vace_strength)
        return out
@@ -620,9 +620,6 @@ def convert_config(unet_config):
def unet_config_from_diffusers_unet(state_dict, dtype=None):
    if "conv_in.weight" not in state_dict:
        return None

    match = {}
    transformer_depth = []
@@ -297,16 +297,11 @@ except:
try:
    if is_amd():
        try:
            rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2]))
        except:
            rocm_version = (6, -1)
        arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
        logging.info("AMD arch: {}".format(arch))
        logging.info("ROCm version: {}".format(rocm_version))
        if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
            if torch_version_numeric[0] >= 2 and torch_version_numeric[1] >= 7:  # works on 2.6 but doesn't actually seem to improve much
                if any((a in arch) for a in ["gfx1100", "gfx1101", "gfx1151"]):  # TODO: more arches
                if any((a in arch) for a in ["gfx1100", "gfx1101"]):  # TODO: more arches
                    ENABLE_PYTORCH_ATTENTION = True
except:
    pass
@@ -700,7 +695,7 @@ def unet_inital_load_device(parameters, dtype):
        return torch_dev

    cpu_dev = torch.device("cpu")
    if DISABLE_SMART_MEMORY or vram_state == VRAMState.NO_VRAM:
    if DISABLE_SMART_MEMORY:
        return cpu_dev

    model_size = dtype_size(dtype) * parameters
@@ -1262,9 +1257,6 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
    return False

def supports_fp8_compute(device=None):
    if args.supports_fp8_compute:
        return True

    if not is_nvidia():
        return False
@@ -1,7 +1,5 @@
from __future__ import annotations
import uuid
import math
import collections
import comfy.model_management
import comfy.conds
import comfy.utils
@@ -106,21 +104,6 @@ def cleanup_additional_models(models):
        if hasattr(m, 'cleanup'):
            m.cleanup()

def estimate_memory(model, noise_shape, conds):
    cond_shapes = collections.defaultdict(list)
    cond_shapes_min = {}
    for _, cs in conds.items():
        for cond in cs:
            for k, v in model.model.extra_conds_shapes(**cond).items():
                cond_shapes[k].append(v)
                if cond_shapes_min.get(k, None) is None:
                    cond_shapes_min[k] = [v]
                elif math.prod(v) > math.prod(cond_shapes_min[k][0]):
                    cond_shapes_min[k] = [v]

    memory_required = model.model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:]), cond_shapes=cond_shapes)
    minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min)
    return memory_required, minimum_memory_required

def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
    executor = comfy.patcher_extension.WrapperExecutor.new_executor(

@@ -134,8 +117,9 @@ def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=Non

    models, inference_memory = get_additional_models(conds, model.model_dtype())
    models += get_additional_models_from_model_options(model_options)
    models += model.get_nested_additional_models()  # TODO: does this require inference_memory update?
    memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
    comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory)
    memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory
    minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory
    comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required)
    real_model = model.model

    return real_model, conds, models
@@ -256,13 +256,7 @@ def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Te
        for i in range(1, len(to_batch_temp) + 1):
            batch_amount = to_batch_temp[:len(to_batch_temp)//i]
            input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
            cond_shapes = collections.defaultdict(list)
            for tt in batch_amount:
                cond = {k: v.size() for k, v in to_run[tt][0].conditioning.items()}
                for k, v in to_run[tt][0].conditioning.items():
                    cond_shapes[k].append(v.size())

            if model.memory_required(input_shape, cond_shapes=cond_shapes) * 1.5 < free_memory:
            if model.memory_required(input_shape) * 1.5 < free_memory:
                to_batch = batch_amount
                break
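The batching loop is a shrinking search: it tries the full candidate batch, then 1/2, 1/3, ... of it, and stops at the first size whose memory estimate (with 1.5x headroom) fits in free memory. The candidate sizes it walks through:

```python
# Candidate batch sizes the loop tries for 8 queued conds.
to_batch_temp = list(range(8))
for i in range(1, len(to_batch_temp) + 1):
    batch_amount = to_batch_temp[:len(to_batch_temp) // i]
    print(len(batch_amount), end=" ")   # 8 4 2 2 1 1 1 1
```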
25  comfy/text_encoders/long_clipl.json  (new file)
@@ -0,0 +1,25 @@
{
  "_name_or_path": "openai/clip-vit-large-patch14",
  "architectures": [
    "CLIPTextModel"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 0,
  "dropout": 0.0,
  "eos_token_id": 49407,
  "hidden_act": "quick_gelu",
  "hidden_size": 768,
  "initializer_factor": 1.0,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 248,
  "model_type": "clip_text_model",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 1,
  "projection_dim": 768,
  "torch_dtype": "float32",
  "transformers_version": "4.24.0",
  "vocab_size": 49408
}
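This is a stock CLIP-L text-encoder config except for `max_position_embeddings: 248` — LongCLIP's extended token context versus 77 in standard CLIP-L (that comparison is mine, not stated in the file). Assuming `transformers` is installed, the file loads directly:

```python
# Load the config above and confirm the extended context length.
import json
from transformers import CLIPTextConfig

with open("comfy/text_encoders/long_clipl.json") as f:
    cfg = CLIPTextConfig(**json.load(f))
print(cfg.max_position_embeddings)  # 248
```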
@@ -1,5 +0,0 @@
from .torch_compile import set_torch_compile_wrapper

__all__ = [
    "set_torch_compile_wrapper",
]
@@ -1,69 +0,0 @@
from __future__ import annotations
import torch

import comfy.utils
from comfy.patcher_extension import WrappersMP
from typing import TYPE_CHECKING, Callable, Optional
if TYPE_CHECKING:
    from comfy.model_patcher import ModelPatcher
    from comfy.patcher_extension import WrapperExecutor


COMPILE_KEY = "torch.compile"
TORCH_COMPILE_KWARGS = "torch_compile_kwargs"


def apply_torch_compile_factory(compiled_module_dict: dict[str, Callable]) -> Callable:
    '''
    Create a wrapper that will refer to the compiled_diffusion_model.
    '''
    def apply_torch_compile_wrapper(executor: WrapperExecutor, *args, **kwargs):
        try:
            orig_modules = {}
            for key, value in compiled_module_dict.items():
                orig_modules[key] = comfy.utils.get_attr(executor.class_obj, key)
                comfy.utils.set_attr(executor.class_obj, key, value)
            return executor(*args, **kwargs)
        finally:
            for key, value in orig_modules.items():
                comfy.utils.set_attr(executor.class_obj, key, value)
    return apply_torch_compile_wrapper


def set_torch_compile_wrapper(model: ModelPatcher, backend: str, options: Optional[dict[str,str]]=None,
                              mode: Optional[str]=None, fullgraph=False, dynamic: Optional[bool]=None,
                              keys: list[str]=["diffusion_model"], *args, **kwargs):
    '''
    Perform torch.compile that will be applied at sample time for either the whole model or specific params of the BaseModel instance.

    When keys is None, it will default to using ["diffusion_model"], compiling the whole diffusion_model.
    When a list of keys is provided, it will perform torch.compile on only the selected modules.
    '''
    # clear out any other torch.compile wrappers
    model.remove_wrappers_with_key(WrappersMP.APPLY_MODEL, COMPILE_KEY)
    # if no keys, default to 'diffusion_model'
    if not keys:
        keys = ["diffusion_model"]
    # create kwargs dict that can be referenced later
    compile_kwargs = {
        "backend": backend,
        "options": options,
        "mode": mode,
        "fullgraph": fullgraph,
        "dynamic": dynamic,
    }
    # get a dict of compiled keys
    compiled_modules = {}
    for key in keys:
        compiled_modules[key] = torch.compile(
            model=model.get_model_object(key),
            **compile_kwargs,
        )
    # add torch.compile wrapper
    wrapper_func = apply_torch_compile_factory(
        compiled_module_dict=compiled_modules,
    )
    # store wrapper to run on BaseModel's apply_model function
    model.add_wrapper_with_key(WrappersMP.APPLY_MODEL, COMPILE_KEY, wrapper_func)
    # keep compile kwargs for reference
    model.model_options[TORCH_COMPILE_KWARGS] = compile_kwargs
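For context on what this deleted helper did, a hypothetical call site (names from the signature above; `patcher` stands in for a real `ModelPatcher`):

```python
# Compile only the diffusion_model submodule; the wrapper swaps it in around
# apply_model at sample time and restores the original in the finally block.
set_torch_compile_wrapper(patcher, backend="inductor", keys=["diffusion_model"])
```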
@@ -1,855 +0,0 @@
from __future__ import annotations
from typing import Any, Literal
from enum import Enum
from abc import ABC, abstractmethod
from dataclasses import dataclass, asdict
from comfy.comfy_types.node_typing import IO


class InputBehavior(str, Enum):
    required = "required"
    optional = "optional"


def is_class(obj):
    '''
    Returns True if is a class type.
    Returns False if is a class instance.
    '''
    return isinstance(obj, type)


class NumberDisplay(str, Enum):
    number = "number"
    slider = "slider"


class IO_V3:
    '''
    Base class for V3 Inputs and Outputs.
    '''
    def __init__(self):
        pass

    def __init_subclass__(cls, io_type: IO | str, **kwargs):
        cls.io_type = io_type
        super().__init_subclass__(**kwargs)

class InputV3(IO_V3, io_type=None):
    '''
    Base class for a V3 Input.
    '''
    def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None):
        super().__init__()
        self.id = id
        self.display_name = display_name
        self.behavior = behavior
        self.tooltip = tooltip
        self.lazy = lazy

    def as_dict_V1(self):
        return prune_dict({
            "display_name": self.display_name,
            "tooltip": self.tooltip,
            "lazy": self.lazy
        })

    def get_io_type_V1(self):
        return self.io_type

class WidgetInputV3(InputV3, io_type=None):
    '''
    Base class for a V3 Input with widget.
    '''
    def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None,
                 default: Any=None,
                 socketless: bool=None, widgetType: str=None):
        super().__init__(id, display_name, behavior, tooltip, lazy)
        self.default = default
        self.socketless = socketless
        self.widgetType = widgetType

    def as_dict_V1(self):
        return super().as_dict_V1() | prune_dict({
            "default": self.default,
            "socketless": self.socketless,
            "widgetType": self.widgetType,
        })

def CustomType(io_type: IO | str) -> type[IO_V3]:
    name = f"{io_type}_IO_V3"
    return type(name, (IO_V3,), {}, io_type=io_type)

def CustomInput(id: str, io_type: IO | str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None) -> InputV3:
    '''
    Defines input for 'io_type'. Can be used to stand in for non-core types.
    '''
    input_kwargs = {
        "id": id,
        "display_name": display_name,
        "behavior": behavior,
        "tooltip": tooltip,
        "lazy": lazy,
    }
    return type(f"{io_type}Input", (InputV3,), {}, io_type=io_type)(**input_kwargs)

def CustomOutput(id: str, io_type: IO | str, display_name: str=None, tooltip: str=None) -> OutputV3:
    '''
    Defines output for 'io_type'. Can be used to stand in for non-core types.
    '''
    input_kwargs = {
        "id": id,
        "display_name": display_name,
        "tooltip": tooltip,
    }
    return type(f"{io_type}Output", (OutputV3,), {}, io_type=io_type)(**input_kwargs)


class BooleanInput(WidgetInputV3, io_type=IO.BOOLEAN):
    '''
    Boolean input.
    '''
    def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None,
                 default: bool=None, label_on: str=None, label_off: str=None,
                 socketless: bool=None, widgetType: str=None):
        super().__init__(id, display_name, behavior, tooltip, lazy, default, socketless, widgetType)
        self.label_on = label_on
        self.label_off = label_off
        self.default: bool

    def as_dict_V1(self):
        return super().as_dict_V1() | prune_dict({
            "label_on": self.label_on,
            "label_off": self.label_off,
        })

class IntegerInput(WidgetInputV3, io_type=IO.INT):
    '''
    Integer input.
    '''
    def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None,
                 default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None,
                 display_mode: NumberDisplay=None, socketless: bool=None, widgetType: str=None):
        super().__init__(id, display_name, behavior, tooltip, lazy, default, socketless, widgetType)
        self.min = min
        self.max = max
        self.step = step
        self.control_after_generate = control_after_generate
        self.display_mode = display_mode
        self.default: int

    def as_dict_V1(self):
        return super().as_dict_V1() | prune_dict({
            "min": self.min,
            "max": self.max,
            "step": self.step,
            "control_after_generate": self.control_after_generate,
            "display": self.display_mode,  # NOTE: in frontend, the parameter is called "display"
        })
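To make the V1 bridge concrete, here is roughly what one of these inputs serializes to — `prune_dict` drops unset fields, so only explicitly passed options appear (a sketch using the classes above):

```python
# Sketch: IntegerInput's contribution to the V1 input format.
inp = IntegerInput("steps", min=1, max=100, default=20, display_mode=NumberDisplay.slider)
print(inp.get_io_type_V1())   # IO.INT, i.e. "INT"
print(inp.as_dict_V1())       # {'default': 20, 'min': 1, 'max': 100, 'display': NumberDisplay.slider}
```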
class FloatInput(WidgetInputV3, io_type=IO.FLOAT):
    '''
    Float input.
    '''
    def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None,
                 default: float=None, min: float=None, max: float=None, step: float=None, round: float=None,
                 display_mode: NumberDisplay=None, socketless: bool=None, widgetType: str=None):
        super().__init__(id, display_name, behavior, tooltip, lazy, default, socketless, widgetType)
        self.default = default
        self.min = min
        self.max = max
        self.step = step
        self.round = round
        self.display_mode = display_mode
        self.default: float

    def as_dict_V1(self):
        return super().as_dict_V1() | prune_dict({
            "min": self.min,
            "max": self.max,
            "step": self.step,
            "round": self.round,
            "display": self.display_mode,  # NOTE: in frontend, the parameter is called "display"
        })

class StringInput(WidgetInputV3, io_type=IO.STRING):
    '''
    String input.
    '''
    def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None,
                 multiline=False, placeholder: str=None, default: int=None,
                 socketless: bool=None, widgetType: str=None):
        super().__init__(id, display_name, behavior, tooltip, lazy, default, socketless, widgetType)
        self.multiline = multiline
        self.placeholder = placeholder
        self.default: str

    def as_dict_V1(self):
        return super().as_dict_V1() | prune_dict({
            "multiline": self.multiline,
            "placeholder": self.placeholder,
        })

class ComboInput(WidgetInputV3, io_type=IO.COMBO):
    '''Combo input (dropdown).'''
    def __init__(self, id: str, options: list[str], display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None,
                 default: str=None, control_after_generate: bool=None,
                 socketless: bool=None, widgetType: str=None):
        super().__init__(id, display_name, behavior, tooltip, lazy, default, socketless, widgetType)
        self.multiselect = False
        self.options = options
        self.control_after_generate = control_after_generate
        self.default: str

    def as_dict_V1(self):
        return super().as_dict_V1() | prune_dict({
            "multiselect": self.multiselect,
            "options": self.options,
            "control_after_generate": self.control_after_generate,
        })

class MultiselectComboWidget(ComboInput, io_type=IO.COMBO):
    '''Multiselect Combo input (dropdown for selecting potentially more than one value).'''
    def __init__(self, id: str, options: list[str], display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None,
                 default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None,
                 socketless: bool=None, widgetType: str=None):
        super().__init__(id, options, display_name, behavior, tooltip, lazy, default, control_after_generate, socketless, widgetType)
        self.multiselect = True
        self.placeholder = placeholder
        self.chip = chip
        self.default: list[str]

    def as_dict_V1(self):
        return super().as_dict_V1() | prune_dict({
            "multiselect": self.multiselect,
            "placeholder": self.placeholder,
            "chip": self.chip,
        })

class ImageInput(InputV3, io_type=IO.IMAGE):
    '''
    Image input.
    '''
    def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None):
        super().__init__(id, display_name, behavior, tooltip)

class MaskInput(InputV3, io_type=IO.MASK):
    '''
    Mask input.
    '''
    def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None):
        super().__init__(id, display_name, behavior, tooltip)

class LatentInput(InputV3, io_type=IO.LATENT):
    '''
    Latent input.
    '''
    def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None):
        super().__init__(id, display_name, behavior, tooltip)

class MultitypedInput(InputV3, io_type="COMFY_MULTITYPED_V3"):
    '''
    Input that permits more than one input type.
    '''
    def __init__(self, id: str, io_types: list[type[IO_V3] | InputV3 | IO |str], display_name: str=None, behavior=InputBehavior.required, tooltip: str=None,):
        super().__init__(id, display_name, behavior, tooltip)
        self._io_types = io_types

    @property
    def io_types(self) -> list[type[InputV3]]:
        '''
        Returns list of InputV3 class types permitted.
        '''
        io_types = []
        for x in self._io_types:
            if not is_class(x):
                io_types.append(type(x))
            else:
                io_types.append(x)
        return io_types

    def get_io_type_V1(self):
        return ",".join(x.io_type for x in self.io_types)


class OutputV3:
    def __init__(self, id: str, display_name: str=None, tooltip: str=None,
                 is_output_list=False):
        self.id = id
        self.display_name = display_name
        self.tooltip = tooltip
        self.is_output_list = is_output_list

    def __init_subclass__(cls, io_type, **kwargs):
        cls.io_type = io_type
        super().__init_subclass__(**kwargs)

class IntegerOutput(OutputV3, io_type=IO.INT):
    pass

class FloatOutput(OutputV3, io_type=IO.FLOAT):
    pass

class StringOutput(OutputV3, io_type=IO.STRING):
    pass
    # def __init__(self, id: str, display_name: str=None, tooltip: str=None):
    #     super().__init__(id, display_name, tooltip)

class ImageOutput(OutputV3, io_type=IO.IMAGE):
    pass

class MaskOutput(OutputV3, io_type=IO.MASK):
    pass

class LatentOutput(OutputV3, io_type=IO.LATENT):
    pass


class DynamicInput(InputV3, io_type=None):
    '''
    Abstract class for dynamic input registration.
    '''
    def __init__(self, io_type: str, id: str, display_name: str=None):
        super().__init__(io_type, id, display_name)

class DynamicOutput(OutputV3, io_type=None):
    '''
    Abstract class for dynamic output registration.
    '''
    def __init__(self, io_type: str, id: str, display_name: str=None):
        super().__init__(io_type, id, display_name)

class AutoGrowDynamicInput(DynamicInput, io_type="COMFY_MULTIGROW_V3"):
    '''
    Dynamic Input that adds another template_input each time one is provided.

    Additional inputs are forced to have 'InputBehavior.optional'.
    '''
    def __init__(self, id: str, template_input: InputV3, min: int=1, max: int=None):
        super().__init__("AutoGrowDynamicInput", id)
        self.template_input = template_input
        if min is not None:
            assert(min >= 1)
        if max is not None:
            assert(max >= 1)
        self.min = min
        self.max = max

class ComboDynamicInput(DynamicInput, io_type="COMFY_COMBODYNAMIC_V3"):
    def __init__(self, id: str):
        pass

AutoGrowDynamicInput(id="dynamic", template_input=ImageInput(id="image"))


class Hidden(str, Enum):
    '''
    Enumerator for requesting hidden variables in nodes.
    '''

    unique_id = "UNIQUE_ID"
    """UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages)."""
    prompt = "PROMPT"
    """PROMPT is the complete prompt sent by the client to the server. See the prompt object for a full description."""
    extra_pnginfo = "EXTRA_PNGINFO"
    """EXTRA_PNGINFO is a dictionary that will be copied into the metadata of any .png files saved. Custom nodes can store additional information in this dictionary for saving (or as a way to communicate with a downstream node)."""
    dynprompt = "DYNPROMPT"
    """DYNPROMPT is an instance of comfy_execution.graph.DynamicPrompt. It differs from PROMPT in that it may mutate during the course of execution in response to Node Expansion."""
    auth_token_comfy_org = "AUTH_TOKEN_COMFY_ORG"
    """AUTH_TOKEN_COMFY_ORG is a token acquired from signing into a ComfyOrg account on frontend."""
    api_key_comfy_org = "API_KEY_COMFY_ORG"
    """API_KEY_COMFY_ORG is an API Key generated by ComfyOrg that allows skipping signing into a ComfyOrg account on frontend."""


@dataclass
class NodeInfoV1:
    input: dict=None
    input_order: dict[str, list[str]]=None
    output: list[str]=None
    output_is_list: list[bool]=None
    output_name: list[str]=None
    output_tooltips: list[str]=None
    name: str=None
    display_name: str=None
    description: str=None
    python_module: Any=None
    category: str=None
    output_node: bool=None
    deprecated: bool=None
    experimental: bool=None
    api_node: bool=None


def as_pruned_dict(dataclass_obj):
    '''Return dict of dataclass object with pruned None values.'''
    return prune_dict(asdict(dataclass_obj))

def prune_dict(d: dict):
    return {k: v for k,v in d.items() if v is not None}


@dataclass
class SchemaV3:
    """Definition of V3 node properties."""

    node_id: str
    """ID of node - should be globally unique. If this is a custom node, add a prefix or postfix to avoid name clashes."""
    display_name: str = None
    """Display name of node."""
    category: str = "sd"
    """The category of the node, as per the "Add Node" menu."""
    inputs: list[InputV3]=None
    outputs: list[OutputV3]=None
    hidden: list[Hidden]=None
    description: str=""
    """Node description, shown as a tooltip when hovering over the node."""
    is_input_list: bool = False
    """A flag indicating if this node implements the additional code necessary to deal with OUTPUT_IS_LIST nodes.

    All inputs of ``type`` will become ``list[type]``, regardless of how many items are passed in. This also affects ``check_lazy_status``.

    From the docs:

    A node can also override the default input behaviour and receive the whole list in a single call. This is done by setting a class attribute `INPUT_IS_LIST` to ``True``.

    Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lists#list-processing
    """
    is_output_node: bool=False
    """Flags this node as an output node, causing any inputs it requires to be executed.

    If a node is not connected to any output nodes, that node will not be executed. Usage::

        OUTPUT_NODE = True

    From the docs:

    By default, a node is not considered an output. Set ``OUTPUT_NODE = True`` to specify that it is.

    Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#output-node
    """
    is_deprecated: bool=False
    """Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
    is_experimental: bool=False
    """Flags a node as experimental, informing users that it may change or not work as expected."""
    is_api_node: bool=False
    """Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview."""

# class SchemaV3Class:
#     def __init__(self,
#             node_id: str,
#             node_name: str,
#             category: str,
#             inputs: list[InputV3],
#             outputs: list[OutputV3]=None,
#             hidden: list[Hidden]=None,
#             description: str="",
#             is_input_list: bool = False,
#             is_output_node: bool=False,
#             is_deprecated: bool=False,
#             is_experimental: bool=False,
#             is_api_node: bool=False,
#         ):
#         self.node_id = node_id
#         """ID of node - should be globally unique. If this is a custom node, add a prefix or postfix to avoid name clashes."""
#         self.node_name = node_name
#         """Display name of node."""
#         self.category = category
#         """The category of the node, as per the "Add Node" menu."""
#         self.inputs = inputs
#         self.outputs = outputs
#         self.hidden = hidden
#         self.description = description
#         """Node description, shown as a tooltip when hovering over the node."""
#         self.is_input_list = is_input_list
#         """A flag indicating if this node implements the additional code necessary to deal with OUTPUT_IS_LIST nodes.

#         All inputs of ``type`` will become ``list[type]``, regardless of how many items are passed in. This also affects ``check_lazy_status``.

#         From the docs:

#         A node can also override the default input behaviour and receive the whole list in a single call. This is done by setting a class attribute `INPUT_IS_LIST` to ``True``.

#         Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lists#list-processing
#         """
#         self.is_output_node = is_output_node
#         """Flags this node as an output node, causing any inputs it requires to be executed.

#         If a node is not connected to any output nodes, that node will not be executed. Usage::

#             OUTPUT_NODE = True

#         From the docs:

#         By default, a node is not considered an output. Set ``OUTPUT_NODE = True`` to specify that it is.

#         Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#output-node
#         """
#         self.is_deprecated = is_deprecated
#         """Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
#         self.is_experimental = is_experimental
#         """Flags a node as experimental, informing users that it may change or not work as expected."""
#         self.is_api_node = is_api_node
#         """Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview."""
class classproperty(object):
    def __init__(self, f):
        self.f = f
    def __get__(self, obj, owner):
        return self.f(owner)


class ComfyNodeV3(ABC):
    """Common base class for all V3 nodes."""

    RELATIVE_PYTHON_MODULE = None
    #############################################
    # V1 Backwards Compatibility code
    #--------------------------------------------
    _DESCRIPTION = None
    @classproperty
    def DESCRIPTION(cls):
        if cls._DESCRIPTION is None:
            cls.GET_SCHEMA()
        return cls._DESCRIPTION

    _CATEGORY = None
    @classproperty
    def CATEGORY(cls):
        if cls._CATEGORY is None:
            cls.GET_SCHEMA()
        return cls._CATEGORY

    _EXPERIMENTAL = None
    @classproperty
    def EXPERIMENTAL(cls):
        if cls._EXPERIMENTAL is None:
            cls.GET_SCHEMA()
        return cls._EXPERIMENTAL

    _DEPRECATED = None
    @classproperty
    def DEPRECATED(cls):
        if cls._DEPRECATED is None:
            cls.GET_SCHEMA()
        return cls._DEPRECATED

    _API_NODE = None
    @classproperty
    def API_NODE(cls):
        if cls._API_NODE is None:
            cls.GET_SCHEMA()
        return cls._API_NODE

    _OUTPUT_NODE = None
    @classproperty
    def OUTPUT_NODE(cls):
        if cls._OUTPUT_NODE is None:
            cls.GET_SCHEMA()
        return cls._OUTPUT_NODE

    _INPUT_IS_LIST = None
    @classproperty
    def INPUT_IS_LIST(cls):
        if cls._INPUT_IS_LIST is None:
            cls.GET_SCHEMA()
        return cls._INPUT_IS_LIST
    _OUTPUT_IS_LIST = None

    @classproperty
    def OUTPUT_IS_LIST(cls):
        if cls._OUTPUT_IS_LIST is None:
            cls.GET_SCHEMA()
        return cls._OUTPUT_IS_LIST

    _RETURN_TYPES = None
    @classproperty
    def RETURN_TYPES(cls):
        if cls._RETURN_TYPES is None:
            cls.GET_SCHEMA()
        return cls._RETURN_TYPES

    _RETURN_NAMES = None
    @classproperty
    def RETURN_NAMES(cls):
        if cls._RETURN_NAMES is None:
            cls.GET_SCHEMA()
        return cls._RETURN_NAMES

    _OUTPUT_TOOLTIPS = None
    @classproperty
    def OUTPUT_TOOLTIPS(cls):
        if cls._OUTPUT_TOOLTIPS is None:
            cls.GET_SCHEMA()
        return cls._OUTPUT_TOOLTIPS

    FUNCTION = "execute"

    @classmethod
    def INPUT_TYPES(cls) -> dict[str, dict]:
        schema = cls.DEFINE_SCHEMA()
        # for V1, make inputs be a dict with potential keys {required, optional, hidden}
        input = {
            "required": {}
        }
        if schema.inputs:
            for i in schema.inputs:
                input.setdefault(i.behavior.value, {})[i.id] = (i.get_io_type_V1(), i.as_dict_V1())
        if schema.hidden:
            for hidden in schema.hidden:
                input.setdefault("hidden", {})[hidden.name] = (hidden.value,)
        return input
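`INPUT_TYPES` flattens the schema back into the V1 dict: inputs are grouped by `behavior` (`required`/`optional`), and hidden entries are keyed by the enum member's name. Roughly, for a schema with one required `IntegerInput("my_int")` and `Hidden.unique_id` (illustrative shape only):

```python
# Approximate INPUT_TYPES() result for the schema described above.
expected = {
    "required": {
        "my_int": ("INT", {}),        # (get_io_type_V1(), as_dict_V1())
    },
    "hidden": {
        "unique_id": ("UNIQUE_ID",),  # keyed by Hidden.name, valued (Hidden.value,)
    },
}
```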
@classmethod
|
||||
def GET_SCHEMA(cls) -> SchemaV3:
|
||||
schema = cls.DEFINE_SCHEMA()
|
||||
if cls._DESCRIPTION is None:
|
||||
cls._DESCRIPTION = schema.description
|
||||
if cls._CATEGORY is None:
|
||||
cls._CATEGORY = schema.category
|
||||
if cls._EXPERIMENTAL is None:
|
||||
cls._EXPERIMENTAL = schema.is_experimental
|
||||
if cls._DEPRECATED is None:
|
||||
cls._DEPRECATED = schema.is_deprecated
|
||||
if cls._API_NODE is None:
|
||||
cls._API_NODE = schema.is_api_node
|
||||
if cls._OUTPUT_NODE is None:
|
||||
cls._OUTPUT_NODE = schema.is_output_node
|
||||
if cls._INPUT_IS_LIST is None:
|
||||
cls._INPUT_IS_LIST = schema.is_input_list
|
||||
|
||||
if cls._RETURN_TYPES is None:
|
||||
output = []
|
||||
output_name = []
|
||||
output_is_list = []
|
||||
output_tooltips = []
|
||||
if schema.outputs:
|
||||
for o in schema.outputs:
|
||||
output.append(o.io_type)
|
||||
output_name.append(o.display_name if o.display_name else o.io_type)
|
||||
output_is_list.append(o.is_output_list)
|
||||
output_tooltips.append(o.tooltip if o.tooltip else None)
|
||||
|
||||
cls._RETURN_TYPES = output
|
||||
cls._RETURN_NAMES = output_name
|
||||
cls._OUTPUT_IS_LIST = output_is_list
|
||||
cls._OUTPUT_TOOLTIPS = output_tooltips
|
||||
|
||||
return schema
|
||||
|
||||
@classmethod
|
||||
def GET_NODE_INFO_V1(cls) -> dict[str, Any]:
|
||||
schema = cls.GET_SCHEMA()
|
||||
# get V1 inputs
|
||||
input = cls.INPUT_TYPES()
|
||||
|
||||
# create separate lists from output fields
|
||||
output = []
|
||||
output_is_list = []
|
||||
output_name = []
|
||||
output_tooltips = []
|
||||
if schema.outputs:
|
||||
for o in schema.outputs:
|
||||
output.append(o.io_type)
|
||||
output_is_list.append(o.is_output_list)
|
||||
output_name.append(o.display_name if o.display_name else o.io_type)
|
||||
output_tooltips.append(o.tooltip if o.tooltip else None)
|
||||
|
||||
info = NodeInfoV1(
|
||||
input=input,
|
||||
input_order={key: list(value.keys()) for (key, value) in input.items()},
|
||||
output=output,
|
||||
output_is_list=output_is_list,
|
||||
output_name=output_name,
|
||||
output_tooltips=output_tooltips,
|
||||
name=schema.node_id,
|
||||
display_name=schema.display_name,
|
||||
category=schema.category,
|
||||
description=schema.description,
|
||||
output_node=schema.is_output_node,
|
||||
deprecated=schema.is_deprecated,
|
||||
experimental=schema.is_experimental,
|
||||
api_node=schema.is_api_node,
|
||||
python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes")
|
||||
)
|
||||
return asdict(info)
|
||||
#--------------------------------------------
|
||||
#############################################
|
||||
|
||||
@classmethod
|
||||
def GET_NODE_INFO_V3(cls) -> dict[str, Any]:
|
||||
schema = cls.GET_SCHEMA()
|
||||
# TODO: finish
|
||||
return None
|
||||
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def DEFINE_SCHEMA(cls) -> SchemaV3:
|
||||
"""
|
||||
Override this function with one that returns a SchemaV3 instance.
|
||||
"""
|
||||
return None
|
||||
DEFINE_SCHEMA = None
|
||||
|
||||
def __init__(self):
|
||||
if self.DEFINE_SCHEMA is None:
|
||||
raise Exception("No DEFINE_SCHEMA function was defined for this node.")
|
||||
|
||||
@abstractmethod
|
||||
def execute(self, **kwargs) -> NodeOutput:
|
||||
pass
|
||||
|
||||
|
||||
# class ReturnedInputs:
|
||||
# def __init__(self):
|
||||
# pass
|
||||
|
||||
# class ReturnedOutputs:
|
||||
# def __init__(self):
|
||||
# pass
|
||||
|
||||
|
||||
class NodeOutput:
|
||||
'''
|
||||
Standardized output of a node; can pass in any number of args and/or a UIOutput into 'ui' kwarg.
|
||||
'''
|
||||
def __init__(self, *args: Any, ui: UIOutput | dict=None, expand: dict=None, block_execution: str=None, **kwargs):
|
||||
self.args = args
|
||||
self.ui = ui
|
||||
self.expand = expand
|
||||
self.block_execution = block_execution
|
||||
|
||||
@property
|
||||
def result(self):
|
||||
return self.args if len(self.args) > 0 else None


class SavedResult:
    def __init__(self, filename: str, subfolder: str, type: Literal["input", "output", "temp"]):
        self.filename = filename
        self.subfolder = subfolder
        self.type = type

    def as_dict(self):
        return {
            "filename": self.filename,
            "subfolder": self.subfolder,
            "type": self.type
        }

class UIOutput(ABC):
    def __init__(self):
        pass

    @abstractmethod
    def as_dict(self) -> dict:
        ... # TODO: finish

class UIImages(UIOutput):
    def __init__(self, values: list[SavedResult | dict], animated=False, **kwargs):
        self.values = values
        self.animated = animated

    def as_dict(self):
        values = [x.as_dict() if isinstance(x, SavedResult) else x for x in self.values]
        return {
            "images": values,
            "animated": (self.animated,)
        }

class UILatents(UIOutput):
    def __init__(self, values: list[SavedResult | dict], **kwargs):
        self.values = values

    def as_dict(self):
        values = [x.as_dict() if isinstance(x, SavedResult) else x for x in self.values]
        return {
            "latents": values,
        }

class UIAudio(UIOutput):
    def __init__(self, values: list[SavedResult | dict], **kwargs):
        self.values = values

    def as_dict(self):
        values = [x.as_dict() if isinstance(x, SavedResult) else x for x in self.values]
        return {
            "audio": values,
        }

class UI3D(UIOutput):
    def __init__(self, values: list[SavedResult | dict], **kwargs):
        self.values = values

    def as_dict(self):
        values = [x.as_dict() if isinstance(x, SavedResult) else x for x in self.values]
        return {
            "3d": values,
        }

class UIText(UIOutput):
    def __init__(self, value: str, **kwargs):
        self.value = value

    def as_dict(self):
        return {"text": (self.value,)}


class TestNode(ComfyNodeV3):
    SCHEMA = SchemaV3(
        node_id="TestNode_v3",
        display_name="Test Node (V3)",
        category="v3_test",
        inputs=[IntegerInput("my_int"),
                #AutoGrowDynamicInput("growing", ImageInput),
                MaskInput("thing"),
                ],
        outputs=[ImageOutput("image_output")],
        hidden=[Hidden.api_key_comfy_org, Hidden.auth_token_comfy_org, Hidden.unique_id]
    )

    # @classmethod
    # def GET_SCHEMA(cls):
    #     return cls.SCHEMA

    @classmethod
    def DEFINE_SCHEMA(cls):
        return cls.SCHEMA

    def execute(self, **kwargs):
        pass


if __name__ == "__main__":
    print("hello there")
    inputs: list[InputV3] = [
        IntegerInput("my_int"),
        CustomInput("xyz", "XYZ"),
        CustomInput("model1", "MODEL_M"),
        ImageInput("my_image"),
        FloatInput("my_float"),
        MultitypedInput("my_inputs", [CustomType("MODEL_M"), CustomType("XYZ")]),
    ]

    outputs: list[OutputV3] = [
        ImageOutput("image"),
        CustomOutput("xyz", "XYZ")
    ]

    for c in inputs:
        if isinstance(c, MultitypedInput):
            print(f"{c}, {type(c)}, {type(c).io_type}, {c.id}, {[x.io_type for x in c.io_types]}")
            print(c.get_io_type_V1())
        else:
            print(f"{c}, {type(c)}, {type(c).io_type}, {c.id}")

    for c in outputs:
        print(f"{c}, {type(c)}, {type(c).io_type}, {c.id}")

    zz = TestNode()
    print(zz.GET_NODE_INFO_V1())

    # aa = NodeInfoV1()
    # print(asdict(aa))
    # print(as_pruned_dict(aa))

@@ -18,8 +18,6 @@ Follow the instructions [here](https://github.com/Comfy-Org/ComfyUI_frontend) to
python main.py --comfy-api-base https://stagingapi.comfy.org
```

To authenticate against staging, log in and then ask a member of the Comfy Org team to whitelist you for staging access.

API stubs are generated through automatic codegen tools from OpenAPI definitions. Since the Comfy Org OpenAPI definition contains many things from the Comfy Registry as well, we use redocly/cli to filter out only the paths relevant for API nodes.

### Redocly Instructions
@@ -30,7 +28,7 @@ When developing locally, use the `redocly-dev.yaml` file to generate pydantic mo
Before your API node PR merges, make sure to add the `Released` tag to the `openapi.yaml` file and test in staging.

```bash
# Download the OpenAPI file from staging server.
# Download the OpenAPI file from prod server.
curl -o openapi.yaml https://stagingapi.comfy.org/openapi

# Filter out unneeded API definitions.
@@ -41,25 +39,3 @@ redocly bundle openapi.yaml --output filtered-openapi.yaml --config comfy_api_no
datamodel-codegen --use-subclass-enum --field-constraints --strict-types bytes --input filtered-openapi.yaml --output comfy_api_nodes/apis/__init__.py --output-model-type pydantic_v2.BaseModel

```


# Merging to Master

Before merging to comfyanonymous/ComfyUI master, follow these steps:

1. Add the "Released" tag to the ComfyUI OpenAPI yaml file for each endpoint you are using in the nodes.
2. Make sure the ComfyUI API is deployed to prod with your changes.
3. Run the code generation again with `redocly.yaml` and the production OpenAPI yaml file.

```bash
# Download the OpenAPI file from prod server.
curl -o openapi.yaml https://api.comfy.org/openapi

# Filter out unneeded API definitions.
npm install -g @redocly/cli
redocly bundle openapi.yaml --output filtered-openapi.yaml --config comfy_api_nodes/redocly.yaml --remove-unused-components

# Generate the pydantic datamodels for validation.
datamodel-codegen --use-subclass-enum --field-constraints --strict-types bytes --input filtered-openapi.yaml --output comfy_api_nodes/apis/__init__.py --output-model-type pydantic_v2.BaseModel

```
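
A quick sanity check after codegen (editor's suggestion, not part of the original steps) is to confirm the regenerated package still imports:

```bash
python -c "import comfy_api_nodes.apis"
```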

@@ -1,7 +1,6 @@
from __future__ import annotations
import io
import logging
import mimetypes
from typing import Optional, Union
from comfy.utils import common_upscale
from comfy_api.input_impl import VideoFromFile
@@ -215,7 +214,6 @@ def download_url_to_image_tensor(url: str, timeout: int = None) -> torch.Tensor:
    image_bytesio = download_url_to_bytesio(url, timeout)
    return bytesio_to_image_tensor(image_bytesio)


def process_image_response(response: requests.Response) -> torch.Tensor:
    """Uses content from a Response object and converts it to a torch.Tensor"""
    return bytesio_to_image_tensor(BytesIO(response.content))
@@ -320,27 +318,11 @@ def tensor_to_data_uri(
    return f"data:{mime_type};base64,{base64_string}"


def text_filepath_to_base64_string(filepath: str) -> str:
    """Converts a text file to a base64 string."""
    with open(filepath, "rb") as f:
        file_content = f.read()
    return base64.b64encode(file_content).decode("utf-8")


def text_filepath_to_data_uri(filepath: str) -> str:
    """Converts a text file to a data URI."""
    base64_string = text_filepath_to_base64_string(filepath)
    mime_type, _ = mimetypes.guess_type(filepath)
    if mime_type is None:
        mime_type = "application/octet-stream"
    return f"data:{mime_type};base64,{base64_string}"


def upload_file_to_comfyapi(
    file_bytes_io: BytesIO,
    filename: str,
    upload_mime_type: str,
    auth_kwargs: Optional[dict[str, str]] = None,
    auth_kwargs: Optional[dict[str,str]] = None,
) -> str:
    """
    Uploads a single file to ComfyUI API and returns its download URL.
@@ -375,33 +357,9 @@ def upload_file_to_comfyapi(
    return response.download_url


def video_to_base64_string(
    video: VideoInput,
    container_format: VideoContainer = None,
    codec: VideoCodec = None
) -> str:
    """
    Converts a video input to a base64 string.

    Args:
        video: The video input to convert
        container_format: Optional container format to use (defaults to video.container if available)
        codec: Optional codec to use (defaults to video.codec if available)
    """
    video_bytes_io = io.BytesIO()

    # Use provided format/codec if specified, otherwise use video's own if available
    format_to_use = container_format if container_format is not None else getattr(video, 'container', VideoContainer.MP4)
    codec_to_use = codec if codec is not None else getattr(video, 'codec', VideoCodec.H264)

    video.save_to(video_bytes_io, format=format_to_use, codec=codec_to_use)
    video_bytes_io.seek(0)
    return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8")


def upload_video_to_comfyapi(
    video: VideoInput,
    auth_kwargs: Optional[dict[str, str]] = None,
    auth_kwargs: Optional[dict[str,str]] = None,
    container: VideoContainer = VideoContainer.MP4,
    codec: VideoCodec = VideoCodec.H264,
    max_duration: Optional[int] = None,
@@ -503,7 +461,7 @@ def audio_ndarray_to_bytesio(

def upload_audio_to_comfyapi(
    audio: AudioInput,
    auth_kwargs: Optional[dict[str, str]] = None,
    auth_kwargs: Optional[dict[str,str]] = None,
    container_format: str = "mp4",
    codec_name: str = "aac",
    mime_type: str = "audio/mp4",
@@ -530,25 +488,8 @@ def upload_audio_to_comfyapi(
    return upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs)


def audio_to_base64_string(
    audio: AudioInput, container_format: str = "mp4", codec_name: str = "aac"
) -> str:
    """Converts an audio input to a base64 string."""
    sample_rate: int = audio["sample_rate"]
    waveform: torch.Tensor = audio["waveform"]
    audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
    audio_bytes_io = audio_ndarray_to_bytesio(
        audio_data_np, sample_rate, container_format, codec_name
    )
    audio_bytes = audio_bytes_io.getvalue()
    return base64.b64encode(audio_bytes).decode("utf-8")


def upload_images_to_comfyapi(
    image: torch.Tensor,
    max_images=8,
    auth_kwargs: Optional[dict[str, str]] = None,
    mime_type: Optional[str] = None,
    image: torch.Tensor, max_images=8, auth_kwargs: Optional[dict[str,str]] = None, mime_type: Optional[str] = None
) -> list[str]:
    """
    Uploads images to ComfyUI API and returns download URLs.
@@ -613,24 +554,17 @@ def upload_images_to_comfyapi(
    return download_urls


def resize_mask_to_image(
    mask: torch.Tensor,
    image: torch.Tensor,
    upscale_method="nearest-exact",
    crop="disabled",
    allow_gradient=True,
    add_channel_dim=False,
):
def resize_mask_to_image(mask: torch.Tensor, image: torch.Tensor,
                         upscale_method="nearest-exact", crop="disabled",
                         allow_gradient=True, add_channel_dim=False):
    """
    Resize mask to be the same dimensions as an image, while maintaining proper format for API calls.
    """
    _, H, W, _ = image.shape
    mask = mask.unsqueeze(-1)
    mask = mask.movedim(-1, 1)
    mask = common_upscale(
        mask, width=W, height=H, upscale_method=upscale_method, crop=crop
    )
    mask = mask.movedim(1, -1)
    mask = mask.movedim(-1,1)
    mask = common_upscale(mask, width=W, height=H, upscale_method=upscale_method, crop=crop)
    mask = mask.movedim(1,-1)
    if not add_channel_dim:
        mask = mask.squeeze(-1)
    if not allow_gradient:
@@ -638,41 +572,12 @@ def resize_mask_to_image(
    return mask
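
# Shape walk-through (editor's note, not in the original file): a mask of shape
# (B, H0, W0) becomes (B, H0, W0, 1) after unsqueeze(-1), (B, 1, H0, W0) for
# common_upscale, (B, H, W, 1) after the channels move back, and (B, H, W) again
# when add_channel_dim is False.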


def validate_string(
    string: str,
    strip_whitespace=True,
    field_name="prompt",
    min_length=None,
    max_length=None,
):
    if string is None:
        raise Exception(f"Field '{field_name}' cannot be empty.")
def validate_string(string: str, strip_whitespace=True, field_name="prompt", min_length=None, max_length=None):
    if strip_whitespace:
        string = string.strip()
    if min_length and len(string) < min_length:
        raise Exception(
            f"Field '{field_name}' cannot be shorter than {min_length} characters; was {len(string)} characters long."
        )
        raise Exception(f"Field '{field_name}' cannot be shorter than {min_length} characters; was {len(string)} characters long.")
    if max_length and len(string) > max_length:
        raise Exception(
            f"Field '{field_name}' cannot be longer than {max_length} characters; was {len(string)} characters long."
        )


def image_tensor_pair_to_batch(
    image1: torch.Tensor, image2: torch.Tensor
) -> torch.Tensor:
    """
    Converts a pair of image tensors to a batch tensor.
    If the images are not the same size, the smaller image is resized to
    match the larger image.
    """
    if image1.shape[1:] != image2.shape[1:]:
        image2 = common_upscale(
            image2.movedim(-1, 1),
            image1.shape[2],
            image1.shape[1],
            "bilinear",
            "center",
        ).movedim(1, -1)
    return torch.cat((image1, image2), dim=0)
        raise Exception(f"Field '{field_name}' cannot be longer than {max_length} characters; was {len(string)} characters long.")
    if not string:
        raise Exception(f"Field '{field_name}' cannot be empty.")

File diff suppressed because it is too large
@@ -108,24 +108,6 @@ class BFLFluxProGenerateRequest(BaseModel):
#     )


class BFLFluxKontextProGenerateRequest(BaseModel):
    prompt: str = Field(..., description='The text prompt for what you want to edit.')
    input_image: Optional[str] = Field(None, description='Image to edit in base64 format')
    seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
    guidance: confloat(ge=0.1, le=99.0) = Field(..., description='Guidance strength for the image generation process')
    steps: conint(ge=1, le=150) = Field(..., description='Number of steps for the image generation process')
    safety_tolerance: Optional[conint(ge=0, le=2)] = Field(
        2, description='Tolerance level for input and output moderation. Between 0 and 2, 0 being most strict, 2 being least strict. Defaults to 2.'
    )
    output_format: Optional[BFLOutputFormat] = Field(
        BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
    )
    aspect_ratio: Optional[str] = Field(None, description='Aspect ratio of the image between 21:9 and 9:21.')
    prompt_upsampling: Optional[bool] = Field(
        None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
    )


class BFLFluxProUltraGenerateRequest(BaseModel):
    prompt: str = Field(..., description='The text prompt for image generation.')
    prompt_upsampling: Optional[bool] = Field(
@@ -139,7 +139,7 @@ class EmptyRequest(BaseModel):

class UploadRequest(BaseModel):
    file_name: str = Field(..., description="Filename to upload")
    content_type: Optional[str] = Field(
    content_type: str | None = Field(
        None,
        description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.",
    )
@@ -327,9 +327,7 @@ class ApiClient:
            ApiServerError: If the API server is unreachable but internet is working
            Exception: For other request failures
        """
        # Use urljoin but ensure path is relative to avoid absolute path behavior
        relative_path = path.lstrip('/')
        url = urljoin(self.base_url, relative_path)
        url = urljoin(self.base_url, path)
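        # Editor's note (illustrative, not in the original file): urljoin treats an
        # absolute path as replacing the base URL's path component:
        #     urljoin("https://api.comfy.org/v1/", "/proxy/x") -> "https://api.comfy.org/proxy/x"
        #     urljoin("https://api.comfy.org/v1/", "proxy/x")  -> "https://api.comfy.org/v1/proxy/x"
        # which is the behavior the lstrip('/') variant above guards against.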
        self.check_auth(self.auth_token, self.comfy_api_key)
        # Combine default headers with any provided headers
        request_headers = self.get_headers()

@@ -1,57 +0,0 @@
from __future__ import annotations

from enum import Enum
from typing import Optional, List
from pydantic import BaseModel, Field


class Rodin3DGenerateRequest(BaseModel):
    seed: int = Field(..., description="seed_")
    tier: str = Field(..., description="Tier of generation.")
    material: str = Field(..., description="The material type.")
    quality: str = Field(..., description="The generation quality of the mesh.")
    mesh_mode: str = Field(..., description="It controls the type of faces of generated models.")

class GenerateJobsData(BaseModel):
    uuids: List[str] = Field(..., description="str LIST")
    subscription_key: str = Field(..., description="subscription key")

class Rodin3DGenerateResponse(BaseModel):
    message: Optional[str] = Field(None, description="Return message.")
    prompt: Optional[str] = Field(None, description="Generated Prompt from image.")
    submit_time: Optional[str] = Field(None, description="Submit Time")
    uuid: Optional[str] = Field(None, description="Task str")
    jobs: Optional[GenerateJobsData] = Field(None, description="Details of jobs")

class JobStatus(str, Enum):
    """
    Status for jobs
    """
    Done = "Done"
    Failed = "Failed"
    Generating = "Generating"
    Waiting = "Waiting"

class Rodin3DCheckStatusRequest(BaseModel):
    subscription_key: str = Field(..., description="subscription from generate endpoint")

class JobItem(BaseModel):
    uuid: str = Field(..., description="uuid")
    status: JobStatus = Field(..., description="Status Currently")

class Rodin3DCheckStatusResponse(BaseModel):
    jobs: List[JobItem] = Field(..., description="Job status List")

class Rodin3DDownloadRequest(BaseModel):
    task_uuid: str = Field(..., description="Task str")

class RodinResourceItem(BaseModel):
    url: str = Field(..., description="Download Url")
    name: str = Field(..., description="File name with ext")

class Rodin3DDownloadResponse(BaseModel):
    list: List[RodinResourceItem] = Field(..., description="Source List")

@@ -1,275 +0,0 @@
from __future__ import annotations
from comfy_api_nodes.apis import (
    TripoModelVersion,
    TripoTextureQuality,
)
from enum import Enum
from typing import Optional, List, Dict, Any, Union

from pydantic import BaseModel, Field, RootModel

class TripoStyle(str, Enum):
    PERSON_TO_CARTOON = "person:person2cartoon"
    ANIMAL_VENOM = "animal:venom"
    OBJECT_CLAY = "object:clay"
    OBJECT_STEAMPUNK = "object:steampunk"
    OBJECT_CHRISTMAS = "object:christmas"
    OBJECT_BARBIE = "object:barbie"
    GOLD = "gold"
    ANCIENT_BRONZE = "ancient_bronze"
    NONE = "None"

class TripoTaskType(str, Enum):
    TEXT_TO_MODEL = "text_to_model"
    IMAGE_TO_MODEL = "image_to_model"
    MULTIVIEW_TO_MODEL = "multiview_to_model"
    TEXTURE_MODEL = "texture_model"
    REFINE_MODEL = "refine_model"
    ANIMATE_PRERIGCHECK = "animate_prerigcheck"
    ANIMATE_RIG = "animate_rig"
    ANIMATE_RETARGET = "animate_retarget"
    STYLIZE_MODEL = "stylize_model"
    CONVERT_MODEL = "convert_model"

class TripoTextureAlignment(str, Enum):
    ORIGINAL_IMAGE = "original_image"
    GEOMETRY = "geometry"

class TripoOrientation(str, Enum):
    ALIGN_IMAGE = "align_image"
    DEFAULT = "default"

class TripoOutFormat(str, Enum):
    GLB = "glb"
    FBX = "fbx"

class TripoTopology(str, Enum):
    BIP = "bip"
    QUAD = "quad"

class TripoSpec(str, Enum):
    MIXAMO = "mixamo"
    TRIPO = "tripo"

class TripoAnimation(str, Enum):
    IDLE = "preset:idle"
    WALK = "preset:walk"
    CLIMB = "preset:climb"
    JUMP = "preset:jump"
    RUN = "preset:run"
    SLASH = "preset:slash"
    SHOOT = "preset:shoot"
    HURT = "preset:hurt"
    FALL = "preset:fall"
    TURN = "preset:turn"

class TripoStylizeStyle(str, Enum):
    LEGO = "lego"
    VOXEL = "voxel"
    VORONOI = "voronoi"
    MINECRAFT = "minecraft"

class TripoConvertFormat(str, Enum):
    GLTF = "GLTF"
    USDZ = "USDZ"
    FBX = "FBX"
    OBJ = "OBJ"
    STL = "STL"
    _3MF = "3MF"

class TripoTextureFormat(str, Enum):
    BMP = "BMP"
    DPX = "DPX"
    HDR = "HDR"
    JPEG = "JPEG"
    OPEN_EXR = "OPEN_EXR"
    PNG = "PNG"
    TARGA = "TARGA"
    TIFF = "TIFF"
    WEBP = "WEBP"

class TripoTaskStatus(str, Enum):
    QUEUED = "queued"
    RUNNING = "running"
    SUCCESS = "success"
    FAILED = "failed"
    CANCELLED = "cancelled"
    UNKNOWN = "unknown"
    BANNED = "banned"
    EXPIRED = "expired"

class TripoFileTokenReference(BaseModel):
    type: Optional[str] = Field(None, description='The type of the reference')
    file_token: str

class TripoUrlReference(BaseModel):
    type: Optional[str] = Field(None, description='The type of the reference')
    url: str

class TripoObjectStorage(BaseModel):
    bucket: str
    key: str

class TripoObjectReference(BaseModel):
    type: str
    object: TripoObjectStorage

class TripoFileEmptyReference(BaseModel):
    pass

class TripoFileReference(RootModel):
    root: Union[TripoFileTokenReference, TripoUrlReference, TripoObjectReference, TripoFileEmptyReference]

class TripoGetStsTokenRequest(BaseModel):
    format: str = Field(..., description='The format of the image')

class TripoTextToModelRequest(BaseModel):
    type: TripoTaskType = Field(TripoTaskType.TEXT_TO_MODEL, description='Type of task')
    prompt: str = Field(..., description='The text prompt describing the model to generate', max_length=1024)
    negative_prompt: Optional[str] = Field(None, description='The negative text prompt', max_length=1024)
    model_version: Optional[TripoModelVersion] = TripoModelVersion.V2_5
    face_limit: Optional[int] = Field(None, description='The number of faces to limit the generation to')
    texture: Optional[bool] = Field(True, description='Whether to apply texture to the generated model')
    pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the generated model')
    image_seed: Optional[int] = Field(None, description='The seed for the text')
    model_seed: Optional[int] = Field(None, description='The seed for the model')
    texture_seed: Optional[int] = Field(None, description='The seed for the texture')
    texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
    style: Optional[TripoStyle] = None
    auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
    quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model')

class TripoImageToModelRequest(BaseModel):
    type: TripoTaskType = Field(TripoTaskType.IMAGE_TO_MODEL, description='Type of task')
    file: TripoFileReference = Field(..., description='The file reference to convert to a model')
    model_version: Optional[TripoModelVersion] = Field(None, description='The model version to use for generation')
    face_limit: Optional[int] = Field(None, description='The number of faces to limit the generation to')
    texture: Optional[bool] = Field(True, description='Whether to apply texture to the generated model')
    pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the generated model')
    model_seed: Optional[int] = Field(None, description='The seed for the model')
    texture_seed: Optional[int] = Field(None, description='The seed for the texture')
    texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
    texture_alignment: Optional[TripoTextureAlignment] = Field(TripoTextureAlignment.ORIGINAL_IMAGE, description='The texture alignment method')
    style: Optional[TripoStyle] = Field(None, description='The style to apply to the generated model')
    auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
    orientation: Optional[TripoOrientation] = TripoOrientation.DEFAULT
    quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model')

class TripoMultiviewToModelRequest(BaseModel):
    type: TripoTaskType = TripoTaskType.MULTIVIEW_TO_MODEL
    files: List[TripoFileReference] = Field(..., description='The file references to convert to a model')
    model_version: Optional[TripoModelVersion] = Field(None, description='The model version to use for generation')
    orthographic_projection: Optional[bool] = Field(False, description='Whether to use orthographic projection')
    face_limit: Optional[int] = Field(None, description='The number of faces to limit the generation to')
    texture: Optional[bool] = Field(True, description='Whether to apply texture to the generated model')
    pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the generated model')
    model_seed: Optional[int] = Field(None, description='The seed for the model')
    texture_seed: Optional[int] = Field(None, description='The seed for the texture')
    texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
    texture_alignment: Optional[TripoTextureAlignment] = TripoTextureAlignment.ORIGINAL_IMAGE
    auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
    orientation: Optional[TripoOrientation] = Field(TripoOrientation.DEFAULT, description='The orientation for the model')
    quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model')

class TripoTextureModelRequest(BaseModel):
    type: TripoTaskType = Field(TripoTaskType.TEXTURE_MODEL, description='Type of task')
    original_model_task_id: str = Field(..., description='The task ID of the original model')
    texture: Optional[bool] = Field(True, description='Whether to apply texture to the model')
    pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the model')
    model_seed: Optional[int] = Field(None, description='The seed for the model')
    texture_seed: Optional[int] = Field(None, description='The seed for the texture')
    texture_quality: Optional[TripoTextureQuality] = Field(None, description='The quality of the texture')
    texture_alignment: Optional[TripoTextureAlignment] = Field(TripoTextureAlignment.ORIGINAL_IMAGE, description='The texture alignment method')

class TripoRefineModelRequest(BaseModel):
    type: TripoTaskType = Field(TripoTaskType.REFINE_MODEL, description='Type of task')
    draft_model_task_id: str = Field(..., description='The task ID of the draft model')

class TripoAnimatePrerigcheckRequest(BaseModel):
    type: TripoTaskType = Field(TripoTaskType.ANIMATE_PRERIGCHECK, description='Type of task')
    original_model_task_id: str = Field(..., description='The task ID of the original model')

class TripoAnimateRigRequest(BaseModel):
    type: TripoTaskType = Field(TripoTaskType.ANIMATE_RIG, description='Type of task')
    original_model_task_id: str = Field(..., description='The task ID of the original model')
    out_format: Optional[TripoOutFormat] = Field(TripoOutFormat.GLB, description='The output format')
    spec: Optional[TripoSpec] = Field(TripoSpec.TRIPO, description='The specification for rigging')

class TripoAnimateRetargetRequest(BaseModel):
    type: TripoTaskType = Field(TripoTaskType.ANIMATE_RETARGET, description='Type of task')
    original_model_task_id: str = Field(..., description='The task ID of the original model')
    animation: TripoAnimation = Field(..., description='The animation to apply')
    out_format: Optional[TripoOutFormat] = Field(TripoOutFormat.GLB, description='The output format')
    bake_animation: Optional[bool] = Field(True, description='Whether to bake the animation')

class TripoStylizeModelRequest(BaseModel):
    type: TripoTaskType = Field(TripoTaskType.STYLIZE_MODEL, description='Type of task')
    style: TripoStylizeStyle = Field(..., description='The style to apply to the model')
    original_model_task_id: str = Field(..., description='The task ID of the original model')
    block_size: Optional[int] = Field(80, description='The block size for stylization')

class TripoConvertModelRequest(BaseModel):
    type: TripoTaskType = Field(TripoTaskType.CONVERT_MODEL, description='Type of task')
    format: TripoConvertFormat = Field(..., description='The format to convert to')
    original_model_task_id: str = Field(..., description='The task ID of the original model')
    quad: Optional[bool] = Field(False, description='Whether to apply quad to the model')
    force_symmetry: Optional[bool] = Field(False, description='Whether to force symmetry')
    face_limit: Optional[int] = Field(10000, description='The number of faces to limit the conversion to')
    flatten_bottom: Optional[bool] = Field(False, description='Whether to flatten the bottom of the model')
    flatten_bottom_threshold: Optional[float] = Field(0.01, description='The threshold for flattening the bottom')
    texture_size: Optional[int] = Field(4096, description='The size of the texture')
    texture_format: Optional[TripoTextureFormat] = Field(TripoTextureFormat.JPEG, description='The format of the texture')
    pivot_to_center_bottom: Optional[bool] = Field(False, description='Whether to pivot to the center bottom')

class TripoTaskRequest(RootModel):
    root: Union[
        TripoTextToModelRequest,
        TripoImageToModelRequest,
        TripoMultiviewToModelRequest,
        TripoTextureModelRequest,
        TripoRefineModelRequest,
        TripoAnimatePrerigcheckRequest,
        TripoAnimateRigRequest,
        TripoAnimateRetargetRequest,
        TripoStylizeModelRequest,
        TripoConvertModelRequest
    ]
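
# Editor's sketch (not in the original file): the union members are discriminated by
# their `type` field, so a request can be built as
#     TripoTaskRequest(root=TripoTextToModelRequest(prompt="a red chair"))
# which serializes with "type": "text_to_model" alongside the model's other fields.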

class TripoTaskOutput(BaseModel):
    model: Optional[str] = Field(None, description='URL to the model')
    base_model: Optional[str] = Field(None, description='URL to the base model')
    pbr_model: Optional[str] = Field(None, description='URL to the PBR model')
    rendered_image: Optional[str] = Field(None, description='URL to the rendered image')
    riggable: Optional[bool] = Field(None, description='Whether the model is riggable')

class TripoTask(BaseModel):
    task_id: str = Field(..., description='The task ID')
    type: Optional[str] = Field(None, description='The type of task')
    status: Optional[TripoTaskStatus] = Field(None, description='The status of the task')
    input: Optional[Dict[str, Any]] = Field(None, description='The input parameters for the task')
    output: Optional[TripoTaskOutput] = Field(None, description='The output of the task')
    progress: Optional[int] = Field(None, description='The progress of the task', ge=0, le=100)
    create_time: Optional[int] = Field(None, description='The creation time of the task')
    running_left_time: Optional[int] = Field(None, description='The estimated time left for the task')
    queue_position: Optional[int] = Field(None, description='The position in the queue')

class TripoTaskResponse(BaseModel):
    code: int = Field(0, description='The response code')
    data: TripoTask = Field(..., description='The task data')

class TripoGeneralResponse(BaseModel):
    code: int = Field(0, description='The response code')
    data: Dict[str, str] = Field(..., description='The task ID data')

class TripoBalanceData(BaseModel):
    balance: float = Field(..., description='The account balance')
    frozen: float = Field(..., description='The frozen balance')

class TripoBalanceResponse(BaseModel):
    code: int = Field(0, description='The response code')
    data: TripoBalanceData = Field(..., description='The balance data')

class TripoErrorResponse(BaseModel):
    code: int = Field(..., description='The error code')
    message: str = Field(..., description='The error message')
    suggestion: str = Field(..., description='The suggestion for fixing the error')

@@ -1,6 +1,6 @@
import io
from inspect import cleandoc
from typing import Union, Optional
from typing import Union
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
from comfy_api_nodes.apis.bfl_api import (
    BFLStatus,
@@ -9,7 +9,6 @@ from comfy_api_nodes.apis.bfl_api import (
    BFLFluxCannyImageRequest,
    BFLFluxDepthImageRequest,
    BFLFluxProGenerateRequest,
    BFLFluxKontextProGenerateRequest,
    BFLFluxProUltraGenerateRequest,
    BFLFluxProGenerateResponse,
)
@@ -270,158 +269,6 @@ class FluxProUltraImageNode(ComfyNodeABC):
        return (output_image,)


class FluxKontextProImageNode(ComfyNodeABC):
    """
    Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio.
    """

    MINIMUM_RATIO = 1 / 4
    MAXIMUM_RATIO = 4 / 1
    MINIMUM_RATIO_STR = "1:4"
    MAXIMUM_RATIO_STR = "4:1"

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "prompt": (
                    IO.STRING,
                    {
                        "multiline": True,
                        "default": "",
                        "tooltip": "Prompt for the image generation - specify what and how to edit.",
                    },
                ),
                "aspect_ratio": (
                    IO.STRING,
                    {
                        "default": "16:9",
                        "tooltip": "Aspect ratio of image; must be between 1:4 and 4:1.",
                    },
                ),
                "guidance": (
                    IO.FLOAT,
                    {
                        "default": 3.0,
                        "min": 0.1,
                        "max": 99.0,
                        "step": 0.1,
                        "tooltip": "Guidance strength for the image generation process"
                    },
                ),
                "steps": (
                    IO.INT,
                    {
                        "default": 50,
                        "min": 1,
                        "max": 150,
                        "tooltip": "Number of steps for the image generation process"
                    },
                ),
                "seed": (
                    IO.INT,
                    {
                        "default": 1234,
                        "min": 0,
                        "max": 0xFFFFFFFFFFFFFFFF,
                        "control_after_generate": True,
                        "tooltip": "The random seed used for creating the noise.",
                    },
                ),
                "prompt_upsampling": (
                    IO.BOOLEAN,
                    {
                        "default": False,
                        "tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
                    },
                ),
            },
            "optional": {
                "input_image": (IO.IMAGE,),
            },
            "hidden": {
                "auth_token": "AUTH_TOKEN_COMFY_ORG",
                "comfy_api_key": "API_KEY_COMFY_ORG",
                "unique_id": "UNIQUE_ID",
            },
        }

    @classmethod
    def VALIDATE_INPUTS(cls, aspect_ratio: str):
        try:
            validate_aspect_ratio(
                aspect_ratio,
                minimum_ratio=cls.MINIMUM_RATIO,
                maximum_ratio=cls.MAXIMUM_RATIO,
                minimum_ratio_str=cls.MINIMUM_RATIO_STR,
                maximum_ratio_str=cls.MAXIMUM_RATIO_STR,
            )
        except Exception as e:
            return str(e)
        return True

    RETURN_TYPES = (IO.IMAGE,)
    DESCRIPTION = cleandoc(__doc__ or "")  # Handle potential None value
    FUNCTION = "api_call"
    API_NODE = True
    CATEGORY = "api node/image/BFL"

    BFL_PATH = "/proxy/bfl/flux-kontext-pro/generate"

    def api_call(
        self,
        prompt: str,
        aspect_ratio: str,
        guidance: float,
        steps: int,
        input_image: Optional[torch.Tensor]=None,
        seed=0,
        prompt_upsampling=False,
        unique_id: Union[str, None] = None,
        **kwargs,
    ):
        if input_image is None:
            validate_string(prompt, strip_whitespace=False)
        operation = SynchronousOperation(
            endpoint=ApiEndpoint(
                path=self.BFL_PATH,
                method=HttpMethod.POST,
                request_model=BFLFluxKontextProGenerateRequest,
                response_model=BFLFluxProGenerateResponse,
            ),
            request=BFLFluxKontextProGenerateRequest(
                prompt=prompt,
                prompt_upsampling=prompt_upsampling,
                guidance=round(guidance, 1),
                steps=steps,
                seed=seed,
                aspect_ratio=validate_aspect_ratio(
                    aspect_ratio,
                    minimum_ratio=self.MINIMUM_RATIO,
                    maximum_ratio=self.MAXIMUM_RATIO,
                    minimum_ratio_str=self.MINIMUM_RATIO_STR,
                    maximum_ratio_str=self.MAXIMUM_RATIO_STR,
                ),
                input_image=(
                    input_image
                    if input_image is None
                    else convert_image_to_base64(input_image)
                )
            ),
            auth_kwargs=kwargs,
        )
        output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
        return (output_image,)


class FluxKontextMaxImageNode(FluxKontextProImageNode):
    """
    Edits images using Flux.1 Kontext [max] via api based on prompt and aspect ratio.
    """

    DESCRIPTION = cleandoc(__doc__ or "")
    BFL_PATH = "/proxy/bfl/flux-kontext-max/generate"


class FluxProImageNode(ComfyNodeABC):
    """
@@ -1067,8 +914,6 @@ class FluxProDepthNode(ComfyNodeABC):
NODE_CLASS_MAPPINGS = {
    "FluxProUltraImageNode": FluxProUltraImageNode,
    # "FluxProImageNode": FluxProImageNode,
    "FluxKontextProImageNode": FluxKontextProImageNode,
    "FluxKontextMaxImageNode": FluxKontextMaxImageNode,
    "FluxProExpandNode": FluxProExpandNode,
    "FluxProFillNode": FluxProFillNode,
    "FluxProCannyNode": FluxProCannyNode,
@@ -1079,8 +924,6 @@ NODE_CLASS_MAPPINGS = {
NODE_DISPLAY_NAME_MAPPINGS = {
    "FluxProUltraImageNode": "Flux 1.1 [pro] Ultra Image",
    # "FluxProImageNode": "Flux 1.1 [pro] Image",
    "FluxKontextProImageNode": "Flux.1 Kontext [pro] Image",
    "FluxKontextMaxImageNode": "Flux.1 Kontext [max] Image",
    "FluxProExpandNode": "Flux.1 Expand Image",
    "FluxProFillNode": "Flux.1 Fill Image",
    "FluxProCannyNode": "Flux.1 Canny Control Image",

@@ -1,446 +0,0 @@
"""
API Nodes for Gemini Multimodal LLM Usage via Remote API
See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
"""

import os
from enum import Enum
from typing import Optional, Literal

import torch

import folder_paths
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict
from server import PromptServer
from comfy_api_nodes.apis import (
    GeminiContent,
    GeminiGenerateContentRequest,
    GeminiGenerateContentResponse,
    GeminiInlineData,
    GeminiPart,
    GeminiMimeType,
)
from comfy_api_nodes.apis.client import (
    ApiEndpoint,
    HttpMethod,
    SynchronousOperation,
)
from comfy_api_nodes.apinode_utils import (
    validate_string,
    audio_to_base64_string,
    video_to_base64_string,
    tensor_to_base64_string,
)


GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini"
GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024  # 20 MB


class GeminiModel(str, Enum):
    """
    Gemini Model Names allowed by comfy-api
    """

    gemini_2_5_pro_preview_05_06 = "gemini-2.5-pro-preview-05-06"
    gemini_2_5_flash_preview_04_17 = "gemini-2.5-flash-preview-04-17"


def get_gemini_endpoint(
    model: GeminiModel,
) -> ApiEndpoint[GeminiGenerateContentRequest, GeminiGenerateContentResponse]:
    """
    Get the API endpoint for a given Gemini model.

    Args:
        model: The Gemini model to use, either as enum or string value.

    Returns:
        ApiEndpoint configured for the specific Gemini model.
    """
    if isinstance(model, str):
        model = GeminiModel(model)
    return ApiEndpoint(
        path=f"{GEMINI_BASE_ENDPOINT}/{model.value}",
        method=HttpMethod.POST,
        request_model=GeminiGenerateContentRequest,
        response_model=GeminiGenerateContentResponse,
    )
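
# Editor's note (illustrative, not in the original file): for the default model this
# resolves to the path "/proxy/vertexai/gemini/gemini-2.5-pro-preview-05-06".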


class GeminiNode(ComfyNodeABC):
    """
    Node to generate text responses from a Gemini model.

    This node allows users to interact with Google's Gemini AI models, providing
    multimodal inputs (text, images, audio, video, files) to generate coherent
    text responses. The node works with the latest Gemini models, handling the
    API communication and response parsing.
    """

    @classmethod
    def INPUT_TYPES(cls) -> InputTypeDict:
        return {
            "required": {
                "prompt": (
                    IO.STRING,
                    {
                        "multiline": True,
                        "default": "",
                        "tooltip": "Text inputs to the model, used to generate a response. You can include detailed instructions, questions, or context for the model.",
                    },
                ),
                "model": (
                    IO.COMBO,
                    {
                        "tooltip": "The Gemini model to use for generating responses.",
                        "options": [model.value for model in GeminiModel],
                        "default": GeminiModel.gemini_2_5_pro_preview_05_06.value,
                    },
                ),
                "seed": (
                    IO.INT,
                    {
                        "default": 42,
                        "min": 0,
                        "max": 0xFFFFFFFFFFFFFFFF,
                        "control_after_generate": True,
                        "tooltip": "When seed is fixed to a specific value, the model makes a best effort to provide the same response for repeated requests. Deterministic output isn't guaranteed. Also, changing the model or parameter settings, such as the temperature, can cause variations in the response even when you use the same seed value. By default, a random seed value is used.",
                    },
                ),
            },
            "optional": {
                "images": (
                    IO.IMAGE,
                    {
                        "default": None,
                        "tooltip": "Optional image(s) to use as context for the model. To include multiple images, you can use the Batch Images node.",
                    },
                ),
                "audio": (
                    IO.AUDIO,
                    {
                        "tooltip": "Optional audio to use as context for the model.",
                        "default": None,
                    },
                ),
                "video": (
                    IO.VIDEO,
                    {
                        "tooltip": "Optional video to use as context for the model.",
                        "default": None,
                    },
                ),
                "files": (
                    "GEMINI_INPUT_FILES",
                    {
                        "default": None,
                        "tooltip": "Optional file(s) to use as context for the model. Accepts inputs from the Gemini Generate Content Input Files node.",
                    },
                ),
            },
            "hidden": {
                "auth_token": "AUTH_TOKEN_COMFY_ORG",
                "comfy_api_key": "API_KEY_COMFY_ORG",
                "unique_id": "UNIQUE_ID",
            },
        }

    DESCRIPTION = "Generate text responses with Google's Gemini AI model. You can provide multiple types of inputs (text, images, audio, video) as context for generating more relevant and meaningful responses."
    RETURN_TYPES = ("STRING",)
    FUNCTION = "api_call"
    CATEGORY = "api node/text/Gemini"
    API_NODE = True

    def get_parts_from_response(
        self, response: GeminiGenerateContentResponse
    ) -> list[GeminiPart]:
        """
        Extract all parts from the Gemini API response.

        Args:
            response: The API response from Gemini.

        Returns:
            List of response parts from the first candidate.
        """
        return response.candidates[0].content.parts

    def get_parts_by_type(
        self, response: GeminiGenerateContentResponse, part_type: Literal["text"] | str
    ) -> list[GeminiPart]:
        """
        Filter response parts by their type.

        Args:
            response: The API response from Gemini.
            part_type: Type of parts to extract ("text" or a MIME type).

        Returns:
            List of response parts matching the requested type.
        """
        parts = []
        for part in self.get_parts_from_response(response):
            if part_type == "text" and hasattr(part, "text") and part.text:
                parts.append(part)
            elif (
                hasattr(part, "inlineData")
                and part.inlineData
                and part.inlineData.mimeType == part_type
            ):
                parts.append(part)
            # Skip parts that don't match the requested type
        return parts

    def get_text_from_response(self, response: GeminiGenerateContentResponse) -> str:
        """
        Extract and concatenate all text parts from the response.

        Args:
            response: The API response from Gemini.

        Returns:
            Combined text from all text parts in the response.
        """
        parts = self.get_parts_by_type(response, "text")
        return "\n".join([part.text for part in parts])

    def create_video_parts(self, video_input: IO.VIDEO, **kwargs) -> list[GeminiPart]:
        """
        Convert video input to Gemini API compatible parts.

        Args:
            video_input: Video tensor from ComfyUI.
            **kwargs: Additional arguments to pass to the conversion function.

        Returns:
            List of GeminiPart objects containing the encoded video.
        """
        from comfy_api.util import VideoContainer, VideoCodec
        base_64_string = video_to_base64_string(
            video_input,
            container_format=VideoContainer.MP4,
            codec=VideoCodec.H264
        )
        return [
            GeminiPart(
                inlineData=GeminiInlineData(
                    mimeType=GeminiMimeType.video_mp4,
                    data=base_64_string,
                )
            )
        ]

    def create_audio_parts(self, audio_input: IO.AUDIO) -> list[GeminiPart]:
        """
        Convert audio input to Gemini API compatible parts.

        Args:
            audio_input: Audio input from ComfyUI, containing waveform tensor and sample rate.

        Returns:
            List of GeminiPart objects containing the encoded audio.
        """
        audio_parts: list[GeminiPart] = []
        for batch_index in range(audio_input["waveform"].shape[0]):
            # Recreate an IO.AUDIO object for the given batch dimension index
            audio_at_index = {
                "waveform": audio_input["waveform"][batch_index].unsqueeze(0),
                "sample_rate": audio_input["sample_rate"],
            }
            # Convert to MP3 format for compatibility with Gemini API
            audio_bytes = audio_to_base64_string(
                audio_at_index,
                container_format="mp3",
                codec_name="libmp3lame",
            )
            audio_parts.append(
                GeminiPart(
                    inlineData=GeminiInlineData(
                        mimeType=GeminiMimeType.audio_mp3,
                        data=audio_bytes,
                    )
                )
            )
        return audio_parts

    def create_image_parts(self, image_input: torch.Tensor) -> list[GeminiPart]:
        """
        Convert image tensor input to Gemini API compatible parts.

        Args:
            image_input: Batch of image tensors from ComfyUI.

        Returns:
            List of GeminiPart objects containing the encoded images.
        """
        image_parts: list[GeminiPart] = []
        for image_index in range(image_input.shape[0]):
            image_as_b64 = tensor_to_base64_string(
                image_input[image_index].unsqueeze(0)
            )
            image_parts.append(
                GeminiPart(
                    inlineData=GeminiInlineData(
                        mimeType=GeminiMimeType.image_png,
                        data=image_as_b64,
                    )
                )
            )
        return image_parts

    def create_text_part(self, text: str) -> GeminiPart:
        """
        Create a text part for the Gemini API request.

        Args:
            text: The text content to include in the request.

        Returns:
            A GeminiPart object with the text content.
        """
        return GeminiPart(text=text)

    def api_call(
        self,
        prompt: str,
        model: GeminiModel,
        images: Optional[IO.IMAGE] = None,
        audio: Optional[IO.AUDIO] = None,
        video: Optional[IO.VIDEO] = None,
        files: Optional[list[GeminiPart]] = None,
        unique_id: Optional[str] = None,
        **kwargs,
    ) -> tuple[str]:
        # Validate inputs
        validate_string(prompt, strip_whitespace=False)

        # Create parts list with text prompt as the first part
        parts: list[GeminiPart] = [self.create_text_part(prompt)]

        # Add other modal parts
        if images is not None:
            image_parts = self.create_image_parts(images)
            parts.extend(image_parts)
        if audio is not None:
            parts.extend(self.create_audio_parts(audio))
        if video is not None:
            parts.extend(self.create_video_parts(video))
        if files is not None:
            parts.extend(files)

        # Create response
        response = SynchronousOperation(
            endpoint=get_gemini_endpoint(model),
            request=GeminiGenerateContentRequest(
                contents=[
                    GeminiContent(
                        role="user",
                        parts=parts,
                    )
                ]
            ),
            auth_kwargs=kwargs,
        ).execute()

        # Get result output
        output_text = self.get_text_from_response(response)
        if unique_id and output_text:
            PromptServer.instance.send_progress_text(output_text, node_id=unique_id)

        return (output_text or "Empty response from Gemini model...",)


class GeminiInputFiles(ComfyNodeABC):
    """
    Loads and formats input files for use with the Gemini API.

    This node allows users to include text (.txt) and PDF (.pdf) files as input
    context for the Gemini model. Files are converted to the appropriate format
    required by the API and can be chained together to include multiple files
    in a single request.
    """

    @classmethod
    def INPUT_TYPES(cls) -> InputTypeDict:
        """
        For details about the supported file input types, see:
        https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
        """
        input_dir = folder_paths.get_input_directory()
        input_files = [
            f
            for f in os.scandir(input_dir)
            if f.is_file()
            and (f.name.endswith(".txt") or f.name.endswith(".pdf"))
            and f.stat().st_size < GEMINI_MAX_INPUT_FILE_SIZE
        ]
        input_files = sorted(input_files, key=lambda x: x.name)
        input_files = [f.name for f in input_files]
        return {
            "required": {
                "file": (
                    IO.COMBO,
                    {
                        "tooltip": "Input files to include as context for the model. Only accepts text (.txt) and PDF (.pdf) files for now.",
                        "options": input_files,
                        "default": input_files[0] if input_files else None,
                    },
                ),
            },
            "optional": {
                "GEMINI_INPUT_FILES": (
                    "GEMINI_INPUT_FILES",
                    {
                        "tooltip": "Optional additional file(s) to batch together with the file loaded from this node. Allows chaining of input files so that a single message can include multiple input files.",
                        "default": None,
                    },
                ),
            },
        }

    DESCRIPTION = "Loads and prepares input files to include as inputs for Gemini LLM nodes. The files will be read by the Gemini model when generating a response. The contents of the text file count toward the token limit. 🛈 TIP: Can be chained together with other Gemini Input File nodes."
    RETURN_TYPES = ("GEMINI_INPUT_FILES",)
    FUNCTION = "prepare_files"
    CATEGORY = "api node/text/Gemini"

    def create_file_part(self, file_path: str) -> GeminiPart:
        mime_type = (
            GeminiMimeType.pdf
            if file_path.endswith(".pdf")
            else GeminiMimeType.text_plain
        )
        # Use base64 string directly, not the data URI
        with open(file_path, "rb") as f:
            file_content = f.read()
        import base64
        base64_str = base64.b64encode(file_content).decode("utf-8")

        return GeminiPart(
            inlineData=GeminiInlineData(
                mimeType=mime_type,
                data=base64_str,
            )
        )

    def prepare_files(
        self, file: str, GEMINI_INPUT_FILES: list[GeminiPart] = []
    ) -> tuple[list[GeminiPart]]:
        """
        Loads and formats input files for the Gemini API.
        """
        file_path = folder_paths.get_annotated_filepath(file)
        input_file_content = self.create_file_part(file_path)
        files = [input_file_content] + GEMINI_INPUT_FILES
        return (files,)


NODE_CLASS_MAPPINGS = {
    "GeminiNode": GeminiNode,
    "GeminiInputFiles": GeminiInputFiles,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "GeminiNode": "Google Gemini",
    "GeminiInputFiles": "Gemini Input Files",
}
@@ -1,86 +1,29 @@
import io
from typing import TypedDict, Optional
import json
import os
import time
import re
import uuid
from enum import Enum
from inspect import cleandoc
import numpy as np
import torch
from PIL import Image

from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict
from server import PromptServer
import folder_paths


from comfy_api_nodes.apis import (
    OpenAIImageGenerationRequest,
    OpenAIImageEditRequest,
    OpenAIImageGenerationResponse,
    OpenAICreateResponse,
    OpenAIResponse,
    CreateModelResponseProperties,
    Item,
    Includable,
    OutputContent,
    InputImageContent,
    Detail,
    InputTextContent,
    InputMessage,
    InputMessageContentList,
    InputContent,
    InputFileContent,
)

from comfy_api_nodes.apis.client import (
    ApiEndpoint,
    HttpMethod,
    SynchronousOperation,
    PollingOperation,
    EmptyRequest,
)

from comfy_api_nodes.apinode_utils import (
    downscale_image_tensor,
    validate_and_cast_response,
    validate_string,
    tensor_to_base64_string,
    text_filepath_to_data_uri,
)
from comfy_api_nodes.mapper_utils import model_field_to_node_input


RESPONSES_ENDPOINT = "/proxy/openai/v1/responses"
STARTING_POINT_ID_PATTERN = r"<starting_point_id:(.*)>"


class HistoryEntry(TypedDict):
    """Type definition for a single history entry in the chat."""

    prompt: str
    response: str
    response_id: str
    timestamp: float


class ChatHistory(TypedDict):
    """Type definition for the chat history dictionary."""

    __annotations__: dict[str, list[HistoryEntry]]


class SupportedOpenAIModel(str, Enum):
    o4_mini = "o4-mini"
    o1 = "o1"
    o3 = "o3"
    o1_pro = "o1-pro"
    gpt_4o = "gpt-4o"
    gpt_4_1 = "gpt-4.1"
    gpt_4_1_mini = "gpt-4.1-mini"
    gpt_4_1_nano = "gpt-4.1-nano"
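
# Editor's note (illustrative, not in the original file): this enum feeds the model
# dropdown via
#     model_field_to_node_input(IO.COMBO, OpenAICreateResponse, "model", enum_type=SupportedOpenAIModel)
# as used in OpenAIChatNode.INPUT_TYPES below.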
|
||||
|
||||
|
||||
class OpenAIDalle2(ComfyNodeABC):
|
||||
"""
|
||||
@@ -172,7 +115,7 @@ class OpenAIDalle2(ComfyNodeABC):
        n=1,
        size="1024x1024",
        unique_id=None,
        **kwargs,
    ):
        validate_string(prompt, strip_whitespace=False)
        model = "dall-e-2"
@@ -319,7 +262,7 @@ class OpenAIDalle3(ComfyNodeABC):
        quality="standard",
        size="1024x1024",
        unique_id=None,
        **kwargs,
    ):
        validate_string(prompt, strip_whitespace=False)
        model = "dall-e-3"
@@ -457,12 +400,12 @@ class OpenAIGPTImage1(ComfyNodeABC):
        n=1,
        size="1024x1024",
        unique_id=None,
        **kwargs,
    ):
        validate_string(prompt, strip_whitespace=False)
        model = "gpt-image-1"
        path = "/proxy/openai/images/generations"
        content_type = "application/json"
        request_class = OpenAIImageGenerationRequest
        img_binaries = []
        mask_binary = None
@@ -471,7 +414,7 @@ class OpenAIGPTImage1(ComfyNodeABC):
        if image is not None:
            path = "/proxy/openai/images/edits"
            request_class = OpenAIImageEditRequest
            content_type = "multipart/form-data"

            batch_size = image.shape[0]

@@ -543,466 +486,17 @@ class OpenAIGPTImage1(ComfyNodeABC):
        return (img_tensor,)


class OpenAITextNode(ComfyNodeABC):
    """
    Base class for OpenAI text generation nodes.
    """

    RETURN_TYPES = (IO.STRING,)
    FUNCTION = "api_call"
    CATEGORY = "api node/text/OpenAI"
    API_NODE = True


class OpenAIChatNode(OpenAITextNode):
    """
    Node to generate text responses from an OpenAI model.
    """

    def __init__(self) -> None:
        """Initialize the chat node with a new session ID and empty history."""
        self.current_session_id: str = str(uuid.uuid4())
        self.history: dict[str, list[HistoryEntry]] = {}
        self.previous_response_id: Optional[str] = None

    @classmethod
    def INPUT_TYPES(cls) -> InputTypeDict:
        return {
            "required": {
                "prompt": (
                    IO.STRING,
                    {
                        "multiline": True,
                        "default": "",
                        "tooltip": "Text inputs to the model, used to generate a response.",
                    },
                ),
                "persist_context": (
                    IO.BOOLEAN,
                    {
                        "default": True,
                        "tooltip": "Persist chat context between calls (multi-turn conversation)",
                    },
                ),
                "model": model_field_to_node_input(
                    IO.COMBO,
                    OpenAICreateResponse,
                    "model",
                    enum_type=SupportedOpenAIModel,
                ),
            },
            "optional": {
                "images": (
                    IO.IMAGE,
                    {
                        "default": None,
                        "tooltip": "Optional image(s) to use as context for the model. To include multiple images, you can use the Batch Images node.",
                    },
                ),
                "files": (
                    "OPENAI_INPUT_FILES",
                    {
                        "default": None,
                        "tooltip": "Optional file(s) to use as context for the model. Accepts inputs from the OpenAI Chat Input Files node.",
                    },
                ),
                "advanced_options": (
                    "OPENAI_CHAT_CONFIG",
                    {
                        "default": None,
                        "tooltip": "Optional configuration for the model. Accepts inputs from the OpenAI Chat Advanced Options node.",
                    },
                ),
            },
            "hidden": {
                "auth_token": "AUTH_TOKEN_COMFY_ORG",
                "comfy_api_key": "API_KEY_COMFY_ORG",
                "unique_id": "UNIQUE_ID",
            },
        }

    DESCRIPTION = "Generate text responses from an OpenAI model."
    def get_result_response(
        self,
        response_id: str,
        include: Optional[list[Includable]] = None,
        auth_kwargs: Optional[dict[str, str]] = None,
    ) -> OpenAIResponse:
        """
        Retrieve a model response with the given ID from the OpenAI API.

        Args:
            response_id (str): The ID of the response to retrieve.
            include (Optional[List[Includable]]): Additional fields to include
                in the response. See the `include` parameter for Response
                creation above for more information.
        """
        return PollingOperation(
            poll_endpoint=ApiEndpoint(
                path=f"{RESPONSES_ENDPOINT}/{response_id}",
                method=HttpMethod.GET,
                request_model=EmptyRequest,
                response_model=OpenAIResponse,
                query_params={"include": include},
            ),
            completed_statuses=["completed"],
            failed_statuses=["failed"],
            status_extractor=lambda response: response.status,
            auth_kwargs=auth_kwargs,
        ).execute()

    def get_message_content_from_response(
        self, response: OpenAIResponse
    ) -> list[OutputContent]:
        """Extract message content from the API response."""
        for output in response.output:
            if output.root.type == "message":
                return output.root.content
        raise TypeError("No output message found in response")

    def get_text_from_message_content(
        self, message_content: list[OutputContent]
    ) -> str:
        """Extract text content from message content."""
        for content_item in message_content:
            if content_item.root.type == "output_text":
                return str(content_item.root.text)
        return "No text output found in response"

    def get_history_text(self, session_id: str) -> str:
        """Convert the entire history for a given session to a JSON string."""
        return json.dumps(self.history[session_id])

    def display_history_on_node(self, session_id: str, node_id: str) -> None:
        """Display formatted chat history on the node UI."""
        render_spec = {
            "node_id": node_id,
            "component": "ChatHistoryWidget",
            "props": {
                "history": self.get_history_text(session_id),
            },
        }
        PromptServer.instance.send_sync(
            "display_component",
            render_spec,
        )

    def add_to_history(
        self, session_id: str, prompt: str, output_text: str, response_id: str
    ) -> None:
        """Add a new entry to the chat history."""
        if session_id not in self.history:
            self.history[session_id] = []
        self.history[session_id].append(
            {
                "prompt": prompt,
                "response": output_text,
                "response_id": response_id,
                "timestamp": time.time(),
            }
        )

    def parse_output_text_from_response(self, response: OpenAIResponse) -> str:
        """Extract text output from the API response."""
        message_contents = self.get_message_content_from_response(response)
        return self.get_text_from_message_content(message_contents)

    def generate_new_session_id(self) -> str:
        """Generate a new unique session ID."""
        return str(uuid.uuid4())

    def get_session_id(self, persist_context: bool) -> str:
        """Get the current or generate a new session ID based on context persistence."""
        return (
            self.current_session_id
            if persist_context
            else self.generate_new_session_id()
        )

    def tensor_to_input_image_content(
        self, image: torch.Tensor, detail_level: Detail = "auto"
    ) -> InputImageContent:
        """Convert a tensor to an input image content object."""
        return InputImageContent(
            detail=detail_level,
            image_url=f"data:image/png;base64,{tensor_to_base64_string(image)}",
            type="input_image",
        )

    def create_input_message_contents(
        self,
        prompt: str,
        image: Optional[torch.Tensor] = None,
        files: Optional[list[InputFileContent]] = None,
    ) -> InputMessageContentList:
        """Create a list of input message contents from prompt and optional image."""
        content_list: list[InputContent] = [
            InputTextContent(text=prompt, type="input_text"),
        ]
        if image is not None:
            for i in range(image.shape[0]):
                content_list.append(
                    self.tensor_to_input_image_content(image[i].unsqueeze(0))
                )
        if files is not None:
            content_list.extend(files)

        return InputMessageContentList(
            root=content_list,
        )

    def parse_response_id_from_prompt(self, prompt: str) -> Optional[str]:
        """Extract response ID from prompt if it exists."""
        parsed_id = re.search(STARTING_POINT_ID_PATTERN, prompt)
        return parsed_id.group(1) if parsed_id else None

    def strip_response_tag_from_prompt(self, prompt: str) -> str:
        """Remove the response ID tag from the prompt."""
        return re.sub(STARTING_POINT_ID_PATTERN, "", prompt.strip())

    def delete_history_after_response_id(
        self, new_start_id: str, session_id: str
    ) -> None:
        """Delete history entries after a specific response ID."""
        if session_id not in self.history:
            return

        new_history = []
        i = 0
        while (
            i < len(self.history[session_id])
            and self.history[session_id][i]["response_id"] != new_start_id
        ):
            new_history.append(self.history[session_id][i])
            i += 1

        # Since it's the new starting point (not the response being edited), we include it as well
        if i < len(self.history[session_id]):
            new_history.append(self.history[session_id][i])

        self.history[session_id] = new_history
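
    # Worked example (hypothetical IDs): with history response_ids
    # ["r1", "r2", "r3"], delete_history_after_response_id("r2", session_id)
    # keeps ["r1", "r2"] and drops "r3", so generation resumes from "r2".
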
    def api_call(
        self,
        prompt: str,
        persist_context: bool,
        model: SupportedOpenAIModel,
        unique_id: Optional[str] = None,
        images: Optional[torch.Tensor] = None,
        files: Optional[list[InputFileContent]] = None,
        advanced_options: Optional[CreateModelResponseProperties] = None,
        **kwargs,
    ) -> tuple[str]:
        # Validate inputs
        validate_string(prompt, strip_whitespace=False)

        session_id = self.get_session_id(persist_context)
        response_id_override = self.parse_response_id_from_prompt(prompt)
        if response_id_override:
            is_starting_from_beginning = response_id_override == "start"
            if is_starting_from_beginning:
                self.history[session_id] = []
                previous_response_id = None
            else:
                previous_response_id = response_id_override
                self.delete_history_after_response_id(response_id_override, session_id)
            prompt = self.strip_response_tag_from_prompt(prompt)
        elif persist_context:
            previous_response_id = self.previous_response_id
        else:
            previous_response_id = None

        # Create response
        create_response = SynchronousOperation(
            endpoint=ApiEndpoint(
                path=RESPONSES_ENDPOINT,
                method=HttpMethod.POST,
                request_model=OpenAICreateResponse,
                response_model=OpenAIResponse,
            ),
            request=OpenAICreateResponse(
                input=[
                    Item(
                        root=InputMessage(
                            content=self.create_input_message_contents(
                                prompt, images, files
                            ),
                            role="user",
                        )
                    ),
                ],
                store=True,
                stream=False,
                model=model,
                previous_response_id=previous_response_id,
                **(
                    advanced_options.model_dump(exclude_none=True)
                    if advanced_options
                    else {}
                ),
            ),
            auth_kwargs=kwargs,
        ).execute()
        response_id = create_response.id

        # Get result output
        result_response = self.get_result_response(response_id, auth_kwargs=kwargs)
        output_text = self.parse_output_text_from_response(result_response)

        # Update history
        self.add_to_history(session_id, prompt, output_text, response_id)
        self.display_history_on_node(session_id, unique_id)
        self.previous_response_id = response_id

        return (output_text,)


class OpenAIInputFiles(ComfyNodeABC):
    """
    Loads and formats input files for the OpenAI API.
    """

    @classmethod
    def INPUT_TYPES(cls) -> InputTypeDict:
        """
        For details about the supported file input types, see:
        https://platform.openai.com/docs/guides/pdf-files?api-mode=responses
        """
        input_dir = folder_paths.get_input_directory()
        input_files = [
            f
            for f in os.scandir(input_dir)
            if f.is_file()
            and (f.name.endswith(".txt") or f.name.endswith(".pdf"))
            and f.stat().st_size < 32 * 1024 * 1024
        ]
        input_files = sorted(input_files, key=lambda x: x.name)
        input_files = [f.name for f in input_files]
        return {
            "required": {
                "file": (
                    IO.COMBO,
                    {
                        "tooltip": "Input files to include as context for the model. Only accepts text (.txt) and PDF (.pdf) files for now.",
                        "options": input_files,
                        "default": input_files[0] if input_files else None,
                    },
                ),
            },
            "optional": {
                "OPENAI_INPUT_FILES": (
                    "OPENAI_INPUT_FILES",
                    {
                        "tooltip": "Optional additional file(s) to batch together with the file loaded from this node. Allows chaining of input files so that a single message can include multiple input files.",
                        "default": None,
                    },
                ),
            },
        }

    DESCRIPTION = "Loads and prepares input files (text, pdf, etc.) to include as inputs for the OpenAI Chat Node. The files will be read by the OpenAI model when generating a response. 🛈 TIP: Can be chained together with other OpenAI Input File nodes."
    RETURN_TYPES = ("OPENAI_INPUT_FILES",)
    FUNCTION = "prepare_files"
    CATEGORY = "api node/text/OpenAI"

    def create_input_file_content(self, file_path: str) -> InputFileContent:
        return InputFileContent(
            file_data=text_filepath_to_data_uri(file_path),
            filename=os.path.basename(file_path),
            type="input_file",
        )

    def prepare_files(
        self, file: str, OPENAI_INPUT_FILES: list[InputFileContent] = []
    ) -> tuple[list[InputFileContent]]:
        """
        Loads and formats input files for the OpenAI API.
        """
        file_path = folder_paths.get_annotated_filepath(file)
        input_file_content = self.create_input_file_content(file_path)
        files = [input_file_content] + OPENAI_INPUT_FILES
        return (files,)
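
    # Usage sketch (hypothetical file names): chaining two Input Files nodes.
    # If node A loads "notes.txt" and feeds node B's OPENAI_INPUT_FILES input
    # while B loads "paper.pdf", B outputs [paper.pdf, notes.txt]: the file
    # loaded by each node is prepended to the incoming list, so one chat
    # message can carry any number of files.
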


class OpenAIChatConfig(ComfyNodeABC):
    """Allows setting additional configuration for the OpenAI Chat Node."""

    RETURN_TYPES = ("OPENAI_CHAT_CONFIG",)
    FUNCTION = "configure"
    DESCRIPTION = (
        "Allows specifying advanced configuration options for the OpenAI Chat Nodes."
    )
    CATEGORY = "api node/text/OpenAI"

    @classmethod
    def INPUT_TYPES(cls) -> InputTypeDict:
        return {
            "required": {
                "truncation": (
                    IO.COMBO,
                    {
                        "options": ["auto", "disabled"],
                        "default": "auto",
                        "tooltip": "The truncation strategy to use for the model response. auto: If the context of this response and previous ones exceeds the model's context window size, the model will truncate the response to fit the context window by dropping input items in the middle of the conversation. disabled: If a model response will exceed the context window size for a model, the request will fail with a 400 error.",
                    },
                ),
            },
            "optional": {
                "max_output_tokens": model_field_to_node_input(
                    IO.INT,
                    OpenAICreateResponse,
                    "max_output_tokens",
                    min=16,
                    default=4096,
                    max=16384,
                    tooltip="An upper bound for the number of tokens that can be generated for a response, including visible output tokens",
                ),
                "instructions": model_field_to_node_input(
                    IO.STRING, OpenAICreateResponse, "instructions", multiline=True
                ),
            },
        }

    def configure(
        self,
        truncation: str,
        instructions: Optional[str] = None,
        max_output_tokens: Optional[int] = None,
    ) -> tuple[CreateModelResponseProperties]:
        """
        Configure advanced options for the OpenAI Chat Node.

        Note:
            While `top_p` and `temperature` are listed as properties in the
            spec, they are not supported for all models (e.g., o4-mini).
            They are not exposed as inputs at all to avoid having to manually
            remove them depending on model choice.
        """
        return (
            CreateModelResponseProperties(
                instructions=instructions,
                truncation=truncation,
                max_output_tokens=max_output_tokens,
            ),
        )


# A dictionary that contains all nodes you want to export with their names
# NOTE: names should be globally unique
NODE_CLASS_MAPPINGS = {
    "OpenAIDalle2": OpenAIDalle2,
    "OpenAIDalle3": OpenAIDalle3,
    "OpenAIGPTImage1": OpenAIGPTImage1,
    "OpenAIChatNode": OpenAIChatNode,
    "OpenAIInputFiles": OpenAIInputFiles,
    "OpenAIChatConfig": OpenAIChatConfig,
}

# A dictionary that contains the friendly/humanly readable titles for the nodes
NODE_DISPLAY_NAME_MAPPINGS = {
    "OpenAIDalle2": "OpenAI DALL·E 2",
    "OpenAIDalle3": "OpenAI DALL·E 3",
    "OpenAIGPTImage1": "OpenAI GPT Image 1",
    "OpenAIChatNode": "OpenAI Chat",
    "OpenAIInputFiles": "OpenAI Chat Input Files",
    "OpenAIChatConfig": "OpenAI Chat Advanced Options",
}
@@ -6,42 +6,40 @@ Pika API docs: https://pika-827374fb.mintlify.app/api-reference
from __future__ import annotations

import io
import logging
from typing import Optional, TypeVar

import numpy as np
import torch

from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeOptions
from comfy_api.input_impl import VideoFromFile
from comfy_api.input_impl.video_types import VideoCodec, VideoContainer, VideoInput
from comfy_api_nodes.apinode_utils import (
    download_url_to_video_output,
    tensor_to_bytesio,
)
from comfy_api_nodes.apis import (
    IngredientsMode,
    PikaBodyGenerate22C2vGenerate22PikascenesPost,
    PikaBodyGenerate22I2vGenerate22I2vPost,
    PikaBodyGenerate22KeyframeGenerate22PikaframesPost,
    PikaBodyGenerate22T2vGenerate22T2vPost,
    PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
    PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
    PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
    PikaDurationEnum,
    Pikaffect,
    PikaGenerateResponse,
    PikaResolutionEnum,
    PikaVideoResponse,
)
from comfy_api_nodes.apis.client import (
    ApiEndpoint,
    EmptyRequest,
    HttpMethod,
    PollingOperation,
    SynchronousOperation,
)
from comfy_api_nodes.mapper_utils import model_field_to_node_input

R = TypeVar("R")
@@ -206,7 +204,6 @@ class PikaImageToVideoV2_2(PikaNodeBase):
            "hidden": {
                "auth_token": "AUTH_TOKEN_COMFY_ORG",
                "comfy_api_key": "API_KEY_COMFY_ORG",
                "unique_id": "UNIQUE_ID",
            },
        }

@@ -460,7 +457,7 @@ class PikAdditionsNode(PikaNodeBase):
            },
        }

    DESCRIPTION = "Add any object or image into your video. Upload a video and specify what you'd like to add to create a seamlessly integrated result."

    def api_call(
        self,
@@ -1,462 +0,0 @@
"""
ComfyUI X Rodin3D(Deemos) API Nodes

Rodin API docs: https://developer.hyper3d.ai/

"""

from __future__ import annotations
from inspect import cleandoc
from comfy.comfy_types.node_typing import IO
import folder_paths as comfy_paths
import requests
import os
import datetime
import shutil
import time
import io
import logging
import math
from PIL import Image
from comfy_api_nodes.apis.rodin_api import (
    Rodin3DGenerateRequest,
    Rodin3DGenerateResponse,
    Rodin3DCheckStatusRequest,
    Rodin3DCheckStatusResponse,
    Rodin3DDownloadRequest,
    Rodin3DDownloadResponse,
    JobStatus,
)
from comfy_api_nodes.apis.client import (
    ApiEndpoint,
    HttpMethod,
    SynchronousOperation,
    PollingOperation,
)


COMMON_PARAMETERS = {
    "Seed": (
        IO.INT,
        {
            "default": 0,
            "min": 0,
            "max": 65535,
            "display": "number"
        }
    ),
    "Material_Type": (
        IO.COMBO,
        {
            "options": ["PBR", "Shaded"],
            "default": "PBR"
        }
    ),
    "Polygon_count": (
        IO.COMBO,
        {
            "options": ["4K-Quad", "8K-Quad", "18K-Quad", "50K-Quad", "200K-Triangle"],
            "default": "18K-Quad"
        }
    )
}

def create_task_error(response: Rodin3DGenerateResponse):
    """Check if the response has an error."""
    return hasattr(response, "error")


class Rodin3DAPI:
    """
    Generate 3D Assets using Rodin API
    """
    RETURN_TYPES = (IO.STRING,)
    RETURN_NAMES = ("3D Model Path",)
    CATEGORY = "api node/3d/Rodin"
    DESCRIPTION = cleandoc(__doc__ or "")
    FUNCTION = "api_call"
    API_NODE = True

    def tensor_to_filelike(self, tensor, max_pixels: int = 2048*2048):
        """
        Converts a PyTorch tensor to a file-like object.

        Args:
        - tensor (torch.Tensor): A tensor representing an image of shape (H, W, C)
          where C is the number of channels (3 for RGB), H is height, and W is width.

        Returns:
        - io.BytesIO: A file-like object containing the image data.
        """
        array = tensor.cpu().numpy()
        array = (array * 255).astype('uint8')
        image = Image.fromarray(array, 'RGB')

        original_width, original_height = image.size
        original_pixels = original_width * original_height
        if original_pixels > max_pixels:
            scale = math.sqrt(max_pixels / original_pixels)
            new_width = int(original_width * scale)
            new_height = int(original_height * scale)
        else:
            new_width, new_height = original_width, original_height

        if new_width != original_width or new_height != original_height:
            image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)

        img_byte_arr = io.BytesIO()
        image.save(img_byte_arr, format='PNG')  # PNG is used for lossless compression
        img_byte_arr.seek(0)
        return img_byte_arr
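
    # Worked example (hypothetical dimensions): a 4096x3072 input has
    # 12,582,912 pixels, above the default cap of 2048*2048 = 4,194,304, so
    # scale = sqrt(4194304 / 12582912) ~= 0.577 and the image is resized to
    # roughly 2364x1773 before being PNG-encoded.
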
    def check_rodin_status(self, response: Rodin3DCheckStatusResponse) -> str:
        has_failed = any(job.status == JobStatus.Failed for job in response.jobs)
        all_done = all(job.status == JobStatus.Done for job in response.jobs)
        status_list = [str(job.status) for job in response.jobs]
        logging.info(f"[ Rodin3D API - CheckStatus ] Generate Status: {status_list}")
        if has_failed:
            logging.error(f"[ Rodin3D API - CheckStatus ] Generation failed: {status_list}, please try again.")
            raise Exception("[ Rodin3D API ] Generation failed, please try again.")
        elif all_done:
            return "DONE"
        else:
            return "Generating"

    def CreateGenerateTask(self, images=None, seed=1, material="PBR", quality="medium", tier="Regular", mesh_mode="Quad", **kwargs):
        if images is None:
            raise Exception("Rodin 3D generate requires at least 1 image.")
        if len(images) >= 5:
            raise Exception("Rodin 3D generate accepts at most 4 images.")

        path = "/proxy/rodin/api/v2/rodin"
        operation = SynchronousOperation(
            endpoint=ApiEndpoint(
                path=path,
                method=HttpMethod.POST,
                request_model=Rodin3DGenerateRequest,
                response_model=Rodin3DGenerateResponse,
            ),
            request=Rodin3DGenerateRequest(
                seed=seed,
                tier=tier,
                material=material,
                quality=quality,
                mesh_mode=mesh_mode
            ),
            files=[
                (
                    "images",
                    open(image, "rb") if isinstance(image, str) else self.tensor_to_filelike(image)
                )
                for image in images if image is not None
            ],
            content_type="multipart/form-data",
            auth_kwargs=kwargs,
        )

        response = operation.execute()

        if create_task_error(response):
            error_message = f"Rodin3D Create 3D generate Task Failed. Message: {response.message}, error: {response.error}"
            logging.error(error_message)
            raise Exception(error_message)

        logging.info("[ Rodin3D API - Submit Jobs ] Generate task submitted successfully!")
        subscription_key = response.jobs.subscription_key
        task_uuid = response.uuid
        logging.info(f"[ Rodin3D API - Submit Jobs ] UUID: {task_uuid}")
        return task_uuid, subscription_key

    def poll_for_task_status(self, subscription_key, **kwargs) -> Rodin3DCheckStatusResponse:
        path = "/proxy/rodin/api/v2/status"

        poll_operation = PollingOperation(
            poll_endpoint=ApiEndpoint(
                path=path,
                method=HttpMethod.POST,
                request_model=Rodin3DCheckStatusRequest,
                response_model=Rodin3DCheckStatusResponse,
            ),
            request=Rodin3DCheckStatusRequest(
                subscription_key=subscription_key
            ),
            completed_statuses=["DONE"],
            failed_statuses=["FAILED"],
            status_extractor=self.check_rodin_status,
            poll_interval=3.0,
            auth_kwargs=kwargs,
        )

        logging.info("[ Rodin3D API - CheckStatus ] Generation started!")

        return poll_operation.execute()

    def GetRodinDownloadList(self, uuid, **kwargs) -> Rodin3DDownloadResponse:
        logging.info("[ Rodin3D API - Downloading ] Generation succeeded!")

        path = "/proxy/rodin/api/v2/download"
        operation = SynchronousOperation(
            endpoint=ApiEndpoint(
                path=path,
                method=HttpMethod.POST,
                request_model=Rodin3DDownloadRequest,
                response_model=Rodin3DDownloadResponse,
            ),
            request=Rodin3DDownloadRequest(
                task_uuid=uuid
            ),
            auth_kwargs=kwargs
        )

        return operation.execute()

    def GetQualityAndMode(self, PolyCount):
        if PolyCount == "200K-Triangle":
            mesh_mode = "Raw"
            quality = "medium"
        else:
            mesh_mode = "Quad"
            if PolyCount == "4K-Quad":
                quality = "extra-low"
            elif PolyCount == "8K-Quad":
                quality = "low"
            elif PolyCount == "18K-Quad":
                quality = "medium"
            elif PolyCount == "50K-Quad":
                quality = "high"
            else:
                quality = "medium"

        return mesh_mode, quality
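
    # Example: the Polygon_count combo maps directly onto the API's
    # (mesh_mode, quality) pair, e.g.
    #   GetQualityAndMode("50K-Quad")      -> ("Quad", "high")
    #   GetQualityAndMode("200K-Triangle") -> ("Raw", "medium")
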
    def DownLoadFiles(self, Url_List):
        Save_path = os.path.join(comfy_paths.get_output_directory(), "Rodin3D", datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
        os.makedirs(Save_path, exist_ok=True)
        model_file_path = None
        for Item in Url_List.list:
            url = Item.url
            file_name = Item.name
            file_path = os.path.join(Save_path, file_name)
            if file_path.endswith(".glb"):
                model_file_path = file_path
            logging.info(f"[ Rodin3D API - download_files ] Downloading file: {file_path}")
            max_retries = 5
            for attempt in range(max_retries):
                try:
                    with requests.get(url, stream=True) as r:
                        r.raise_for_status()
                        with open(file_path, "wb") as f:
                            shutil.copyfileobj(r.raw, f)
                    break
                except Exception as e:
                    logging.error(f"[ Rodin3D API - download_files ] Error downloading {file_path}: {e}")
                    if attempt < max_retries - 1:
                        logging.info("Retrying...")
                        time.sleep(2)
                    else:
                        logging.error(f"[ Rodin3D API - download_files ] Failed to download {file_path} after {max_retries} attempts.")

        return model_file_path


class Rodin3D_Regular(Rodin3DAPI):
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "Images": (
                    IO.IMAGE,
                    {
                        "forceInput": True,
                    }
                )
            },
            "optional": {
                **COMMON_PARAMETERS
            },
            "hidden": {
                "auth_token": "AUTH_TOKEN_COMFY_ORG",
                "comfy_api_key": "API_KEY_COMFY_ORG",
            },
        }

    def api_call(
        self,
        Images,
        Seed,
        Material_Type,
        Polygon_count,
        **kwargs
    ):
        tier = "Regular"
        num_images = Images.shape[0]
        m_images = []
        for i in range(num_images):
            m_images.append(Images[i])
        mesh_mode, quality = self.GetQualityAndMode(Polygon_count)
        task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=Material_Type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs)
        self.poll_for_task_status(subscription_key, **kwargs)
        Download_List = self.GetRodinDownloadList(task_uuid, **kwargs)
        model = self.DownLoadFiles(Download_List)

        return (model,)


class Rodin3D_Detail(Rodin3DAPI):
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "Images": (
                    IO.IMAGE,
                    {
                        "forceInput": True,
                    }
                )
            },
            "optional": {
                **COMMON_PARAMETERS
            },
            "hidden": {
                "auth_token": "AUTH_TOKEN_COMFY_ORG",
                "comfy_api_key": "API_KEY_COMFY_ORG",
            },
        }

    def api_call(
        self,
        Images,
        Seed,
        Material_Type,
        Polygon_count,
        **kwargs
    ):
        tier = "Detail"
        num_images = Images.shape[0]
        m_images = []
        for i in range(num_images):
            m_images.append(Images[i])
        mesh_mode, quality = self.GetQualityAndMode(Polygon_count)
        task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=Material_Type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs)
        self.poll_for_task_status(subscription_key, **kwargs)
        Download_List = self.GetRodinDownloadList(task_uuid, **kwargs)
        model = self.DownLoadFiles(Download_List)

        return (model,)


class Rodin3D_Smooth(Rodin3DAPI):
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "Images": (
                    IO.IMAGE,
                    {
                        "forceInput": True,
                    }
                )
            },
            "optional": {
                **COMMON_PARAMETERS
            },
            "hidden": {
                "auth_token": "AUTH_TOKEN_COMFY_ORG",
                "comfy_api_key": "API_KEY_COMFY_ORG",
            },
        }

    def api_call(
        self,
        Images,
        Seed,
        Material_Type,
        Polygon_count,
        **kwargs
    ):
        tier = "Smooth"
        num_images = Images.shape[0]
        m_images = []
        for i in range(num_images):
            m_images.append(Images[i])
        mesh_mode, quality = self.GetQualityAndMode(Polygon_count)
        task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=Material_Type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs)
        self.poll_for_task_status(subscription_key, **kwargs)
        Download_List = self.GetRodinDownloadList(task_uuid, **kwargs)
        model = self.DownLoadFiles(Download_List)

        return (model,)


class Rodin3D_Sketch(Rodin3DAPI):
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "Images": (
                    IO.IMAGE,
                    {
                        "forceInput": True,
                    }
                )
            },
            "optional": {
                "Seed": (
                    IO.INT,
                    {
                        "default": 0,
                        "min": 0,
                        "max": 65535,
                        "display": "number"
                    }
                )
            },
            "hidden": {
                "auth_token": "AUTH_TOKEN_COMFY_ORG",
                "comfy_api_key": "API_KEY_COMFY_ORG",
            },
        }

    def api_call(
        self,
        Images,
        Seed,
        **kwargs
    ):
        tier = "Sketch"
        num_images = Images.shape[0]
        m_images = []
        for i in range(num_images):
            m_images.append(Images[i])
        material_type = "PBR"
        quality = "medium"
        mesh_mode = "Quad"
        task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=material_type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs)
        self.poll_for_task_status(subscription_key, **kwargs)
        Download_List = self.GetRodinDownloadList(task_uuid, **kwargs)
        model = self.DownLoadFiles(Download_List)

        return (model,)


# A dictionary that contains all nodes you want to export with their names
# NOTE: names should be globally unique
NODE_CLASS_MAPPINGS = {
    "Rodin3D_Regular": Rodin3D_Regular,
    "Rodin3D_Detail": Rodin3D_Detail,
    "Rodin3D_Smooth": Rodin3D_Smooth,
    "Rodin3D_Sketch": Rodin3D_Sketch,
}

# A dictionary that contains the friendly/humanly readable titles for the nodes
NODE_DISPLAY_NAME_MAPPINGS = {
    "Rodin3D_Regular": "Rodin 3D Generate - Regular Generate",
    "Rodin3D_Detail": "Rodin 3D Generate - Detail Generate",
    "Rodin3D_Smooth": "Rodin 3D Generate - Smooth Generate",
    "Rodin3D_Sketch": "Rodin 3D Generate - Sketch Generate",
}
@@ -1,635 +0,0 @@
"""Runway API Nodes

API Docs:
- https://docs.dev.runwayml.com/api/#tag/Task-management/paths/~1v1~1tasks~1%7Bid%7D/delete

User Guides:
- https://help.runwayml.com/hc/en-us/sections/30265301423635-Gen-3-Alpha
- https://help.runwayml.com/hc/en-us/articles/37327109429011-Creating-with-Gen-4-Video
- https://help.runwayml.com/hc/en-us/articles/33927968552339-Creating-with-Act-One-on-Gen-3-Alpha-and-Turbo
- https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3

"""

from typing import Union, Optional, Any
from enum import Enum

import torch

from comfy_api_nodes.apis import (
    RunwayImageToVideoRequest,
    RunwayImageToVideoResponse,
    RunwayTaskStatusResponse as TaskStatusResponse,
    RunwayTaskStatusEnum as TaskStatus,
    RunwayModelEnum as Model,
    RunwayDurationEnum as Duration,
    RunwayAspectRatioEnum as AspectRatio,
    RunwayPromptImageObject,
    RunwayPromptImageDetailedObject,
    RunwayTextToImageRequest,
    RunwayTextToImageResponse,
    Model4,
    ReferenceImage,
    RunwayTextToImageAspectRatioEnum,
)
from comfy_api_nodes.apis.client import (
    ApiEndpoint,
    HttpMethod,
    SynchronousOperation,
    PollingOperation,
    EmptyRequest,
)
from comfy_api_nodes.apinode_utils import (
    upload_images_to_comfyapi,
    download_url_to_video_output,
    image_tensor_pair_to_batch,
    validate_string,
    download_url_to_image_tensor,
)
from comfy_api_nodes.mapper_utils import model_field_to_node_input
from comfy_api.input_impl import VideoFromFile
from comfy.comfy_types.node_typing import IO, ComfyNodeABC

PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video"
PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image"
PATH_GET_TASK_STATUS = "/proxy/runway/tasks"

AVERAGE_DURATION_I2V_SECONDS = 64
AVERAGE_DURATION_FLF_SECONDS = 256
AVERAGE_DURATION_T2I_SECONDS = 41


class RunwayApiError(Exception):
    """Base exception for Runway API errors."""

    pass


class RunwayGen4TurboAspectRatio(str, Enum):
    """Aspect ratios supported for Image to Video API when using gen4_turbo model."""

    field_1280_720 = "1280:720"
    field_720_1280 = "720:1280"
    field_1104_832 = "1104:832"
    field_832_1104 = "832:1104"
    field_960_960 = "960:960"
    field_1584_672 = "1584:672"


class RunwayGen3aAspectRatio(str, Enum):
    """Aspect ratios supported for Image to Video API when using gen3a_turbo model."""

    field_768_1280 = "768:1280"
    field_1280_768 = "1280:768"


def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
    """Returns the video URL from the task status response if it exists."""
    if response.output and len(response.output) > 0:
        return response.output[0]
    return None


# TODO: replace with updated image validation utils (upstream)
def validate_input_image(image: torch.Tensor) -> bool:
    """
    Validate the input image is within the size limits for the Runway API.
    See: https://docs.dev.runwayml.com/assets/inputs/#common-error-reasons
    """
    return image.shape[2] < 8000 and image.shape[1] < 8000
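
# Note (illustrative): ComfyUI image tensors are shaped
# (batch, height, width, channels), so the check above reads as
# width < 8000 and height < 8000. A 1920x1080 frame gives a tensor of shape
# (1, 1080, 1920, 3), which passes validation.
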

def poll_until_finished(
    auth_kwargs: dict[str, str],
    api_endpoint: ApiEndpoint[Any, TaskStatusResponse],
    estimated_duration: Optional[int] = None,
    node_id: Optional[str] = None,
) -> TaskStatusResponse:
    """Polls the Runway API endpoint until the task reaches a terminal state, then returns the response."""
    return PollingOperation(
        poll_endpoint=api_endpoint,
        completed_statuses=[
            TaskStatus.SUCCEEDED.value,
        ],
        failed_statuses=[
            TaskStatus.FAILED.value,
            TaskStatus.CANCELLED.value,
        ],
        status_extractor=lambda response: response.status.value,
        auth_kwargs=auth_kwargs,
        result_url_extractor=get_video_url_from_task_status,
        estimated_duration=estimated_duration,
        node_id=node_id,
        progress_extractor=extract_progress_from_task_status,
    ).execute()
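
# Usage sketch (assumed placeholder task ID "task_123"): this mirrors how the
# node classes below poll a created task until it succeeds, fails, or is
# cancelled.
#
#   response = poll_until_finished(
#       auth_kwargs=kwargs,
#       api_endpoint=ApiEndpoint(
#           path=f"{PATH_GET_TASK_STATUS}/task_123",
#           method=HttpMethod.GET,
#           request_model=EmptyRequest,
#           response_model=TaskStatusResponse,
#       ),
#       estimated_duration=AVERAGE_DURATION_I2V_SECONDS,
#   )
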

def extract_progress_from_task_status(
    response: TaskStatusResponse,
) -> Union[float, None]:
    if hasattr(response, "progress") and response.progress is not None:
        return response.progress * 100
    return None


def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
    """Returns the image URL from the task status response if it exists."""
    if response.output and len(response.output) > 0:
        return response.output[0]
    return None


class RunwayVideoGenNode(ComfyNodeABC):
    """Runway Video Node Base."""

    RETURN_TYPES = ("VIDEO",)
    FUNCTION = "api_call"
    CATEGORY = "api node/video/Runway"
    API_NODE = True

    def validate_task_created(self, response: RunwayImageToVideoResponse) -> bool:
        """
        Validate the task creation response from the Runway API matches
        expected format.
        """
        if not bool(response.id):
            raise RunwayApiError("Invalid initial response from Runway API.")
        return True

    def validate_response(self, response: RunwayImageToVideoResponse) -> bool:
        """
        Validate the successful task status response from the Runway API
        matches expected format.
        """
        if not response.output or len(response.output) == 0:
            raise RunwayApiError(
                "Runway task succeeded but no video data found in response."
            )
        return True

    def get_response(
        self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
    ) -> RunwayImageToVideoResponse:
        """Poll the task status until it is finished then get the response."""
        return poll_until_finished(
            auth_kwargs,
            ApiEndpoint(
                path=f"{PATH_GET_TASK_STATUS}/{task_id}",
                method=HttpMethod.GET,
                request_model=EmptyRequest,
                response_model=TaskStatusResponse,
            ),
            estimated_duration=AVERAGE_DURATION_FLF_SECONDS,
            node_id=node_id,
        )

    def generate_video(
        self,
        request: RunwayImageToVideoRequest,
        auth_kwargs: dict[str, str],
        node_id: Optional[str] = None,
    ) -> tuple[VideoFromFile]:
        initial_operation = SynchronousOperation(
            endpoint=ApiEndpoint(
                path=PATH_IMAGE_TO_VIDEO,
                method=HttpMethod.POST,
                request_model=RunwayImageToVideoRequest,
                response_model=RunwayImageToVideoResponse,
            ),
            request=request,
            auth_kwargs=auth_kwargs,
        )

        initial_response = initial_operation.execute()
        self.validate_task_created(initial_response)
        task_id = initial_response.id

        final_response = self.get_response(task_id, auth_kwargs, node_id)
        self.validate_response(final_response)

        video_url = get_video_url_from_task_status(final_response)
        return (download_url_to_video_output(video_url),)


class RunwayImageToVideoNodeGen3a(RunwayVideoGenNode):
    """Runway Image to Video Node using Gen3a Turbo model."""

    DESCRIPTION = "Generate a video from a single starting frame using Gen3a Turbo model. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/33927968552339-Creating-with-Act-One-on-Gen-3-Alpha-and-Turbo."

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "prompt": model_field_to_node_input(
                    IO.STRING, RunwayImageToVideoRequest, "promptText", multiline=True
                ),
                "start_frame": (
                    IO.IMAGE,
                    {"tooltip": "Start frame to be used for the video"},
                ),
                "duration": model_field_to_node_input(
                    IO.COMBO, RunwayImageToVideoRequest, "duration", enum_type=Duration
                ),
                "ratio": model_field_to_node_input(
                    IO.COMBO,
                    RunwayImageToVideoRequest,
                    "ratio",
                    enum_type=RunwayGen3aAspectRatio,
                ),
                "seed": model_field_to_node_input(
                    IO.INT,
                    RunwayImageToVideoRequest,
                    "seed",
                    control_after_generate=True,
                ),
            },
            "hidden": {
                "auth_token": "AUTH_TOKEN_COMFY_ORG",
                "comfy_api_key": "API_KEY_COMFY_ORG",
                "unique_id": "UNIQUE_ID",
            },
        }

    def api_call(
        self,
        prompt: str,
        start_frame: torch.Tensor,
        duration: str,
        ratio: str,
        seed: int,
        unique_id: Optional[str] = None,
        **kwargs,
    ) -> tuple[VideoFromFile]:
        # Validate inputs
        validate_string(prompt, min_length=1)
        validate_input_image(start_frame)

        # Upload image
        download_urls = upload_images_to_comfyapi(
            start_frame,
            max_images=1,
            mime_type="image/png",
            auth_kwargs=kwargs,
        )
        if len(download_urls) != 1:
            raise RunwayApiError("Failed to upload one or more images to comfy api.")

        return self.generate_video(
            RunwayImageToVideoRequest(
                promptText=prompt,
                seed=seed,
                model=Model("gen3a_turbo"),
                duration=Duration(duration),
                ratio=AspectRatio(ratio),
                promptImage=RunwayPromptImageObject(
                    root=[
                        RunwayPromptImageDetailedObject(
                            uri=str(download_urls[0]), position="first"
                        )
                    ]
                ),
            ),
            auth_kwargs=kwargs,
            node_id=unique_id,
        )


class RunwayImageToVideoNodeGen4(RunwayVideoGenNode):
    """Runway Image to Video Node using Gen4 Turbo model."""

    DESCRIPTION = "Generate a video from a single starting frame using Gen4 Turbo model. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/37327109429011-Creating-with-Gen-4-Video."

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "prompt": model_field_to_node_input(
                    IO.STRING, RunwayImageToVideoRequest, "promptText", multiline=True
                ),
                "start_frame": (
                    IO.IMAGE,
                    {"tooltip": "Start frame to be used for the video"},
                ),
                "duration": model_field_to_node_input(
                    IO.COMBO, RunwayImageToVideoRequest, "duration", enum_type=Duration
                ),
                "ratio": model_field_to_node_input(
                    IO.COMBO,
                    RunwayImageToVideoRequest,
                    "ratio",
                    enum_type=RunwayGen4TurboAspectRatio,
                ),
                "seed": model_field_to_node_input(
                    IO.INT,
                    RunwayImageToVideoRequest,
                    "seed",
                    control_after_generate=True,
                ),
            },
            "hidden": {
                "auth_token": "AUTH_TOKEN_COMFY_ORG",
                "comfy_api_key": "API_KEY_COMFY_ORG",
                "unique_id": "UNIQUE_ID",
            },
        }

    def api_call(
        self,
        prompt: str,
        start_frame: torch.Tensor,
        duration: str,
        ratio: str,
        seed: int,
        unique_id: Optional[str] = None,
        **kwargs,
    ) -> tuple[VideoFromFile]:
        # Validate inputs
        validate_string(prompt, min_length=1)
        validate_input_image(start_frame)

        # Upload image
        download_urls = upload_images_to_comfyapi(
            start_frame,
            max_images=1,
            mime_type="image/png",
            auth_kwargs=kwargs,
        )
        if len(download_urls) != 1:
            raise RunwayApiError("Failed to upload one or more images to comfy api.")

        return self.generate_video(
            RunwayImageToVideoRequest(
                promptText=prompt,
                seed=seed,
                model=Model("gen4_turbo"),
                duration=Duration(duration),
                ratio=AspectRatio(ratio),
                promptImage=RunwayPromptImageObject(
                    root=[
                        RunwayPromptImageDetailedObject(
                            uri=str(download_urls[0]), position="first"
                        )
                    ]
                ),
            ),
            auth_kwargs=kwargs,
            node_id=unique_id,
        )


class RunwayFirstLastFrameNode(RunwayVideoGenNode):
    """Runway First-Last Frame Node."""

    DESCRIPTION = "Upload first and last keyframes, draft a prompt, and generate a video. More complex transitions, such as cases where the last frame is completely different from the first frame, may benefit from the longer 10s duration, which gives the generation more time to transition smoothly between the two inputs. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3."

    def get_response(
        self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
    ) -> RunwayImageToVideoResponse:
        return poll_until_finished(
            auth_kwargs,
            ApiEndpoint(
                path=f"{PATH_GET_TASK_STATUS}/{task_id}",
                method=HttpMethod.GET,
                request_model=EmptyRequest,
                response_model=TaskStatusResponse,
            ),
            estimated_duration=AVERAGE_DURATION_FLF_SECONDS,
            node_id=node_id,
        )

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "prompt": model_field_to_node_input(
                    IO.STRING, RunwayImageToVideoRequest, "promptText", multiline=True
                ),
                "start_frame": (
                    IO.IMAGE,
                    {"tooltip": "Start frame to be used for the video"},
                ),
                "end_frame": (
                    IO.IMAGE,
                    {
                        "tooltip": "End frame to be used for the video. Supported for gen3a_turbo only."
                    },
                ),
                "duration": model_field_to_node_input(
                    IO.COMBO, RunwayImageToVideoRequest, "duration", enum_type=Duration
                ),
                "ratio": model_field_to_node_input(
                    IO.COMBO,
                    RunwayImageToVideoRequest,
                    "ratio",
                    enum_type=RunwayGen3aAspectRatio,
                ),
                "seed": model_field_to_node_input(
                    IO.INT,
                    RunwayImageToVideoRequest,
                    "seed",
                    control_after_generate=True,
                ),
            },
            "hidden": {
                "auth_token": "AUTH_TOKEN_COMFY_ORG",
                "unique_id": "UNIQUE_ID",
                "comfy_api_key": "API_KEY_COMFY_ORG",
            },
        }

    def api_call(
        self,
        prompt: str,
        start_frame: torch.Tensor,
        end_frame: torch.Tensor,
        duration: str,
        ratio: str,
        seed: int,
        unique_id: Optional[str] = None,
        **kwargs,
    ) -> tuple[VideoFromFile]:
        # Validate inputs
        validate_string(prompt, min_length=1)
        validate_input_image(start_frame)
        validate_input_image(end_frame)

        # Upload images
        stacked_input_images = image_tensor_pair_to_batch(start_frame, end_frame)
        download_urls = upload_images_to_comfyapi(
            stacked_input_images,
            max_images=2,
            mime_type="image/png",
            auth_kwargs=kwargs,
        )
        if len(download_urls) != 2:
            raise RunwayApiError("Failed to upload one or more images to comfy api.")

        return self.generate_video(
            RunwayImageToVideoRequest(
                promptText=prompt,
                seed=seed,
                model=Model("gen3a_turbo"),
                duration=Duration(duration),
                ratio=AspectRatio(ratio),
                promptImage=RunwayPromptImageObject(
                    root=[
                        RunwayPromptImageDetailedObject(
                            uri=str(download_urls[0]), position="first"
                        ),
                        RunwayPromptImageDetailedObject(
                            uri=str(download_urls[1]), position="last"
                        ),
                    ]
                ),
            ),
            auth_kwargs=kwargs,
            node_id=unique_id,
        )


class RunwayTextToImageNode(ComfyNodeABC):
    """Runway Text to Image Node."""

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "api_call"
    CATEGORY = "api node/image/Runway"
    API_NODE = True
    DESCRIPTION = "Generate an image from a text prompt using Runway's Gen 4 model. You can also include reference images to guide the generation."

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "prompt": model_field_to_node_input(
                    IO.STRING, RunwayTextToImageRequest, "promptText", multiline=True
                ),
                "ratio": model_field_to_node_input(
                    IO.COMBO,
                    RunwayTextToImageRequest,
                    "ratio",
                    enum_type=RunwayTextToImageAspectRatioEnum,
                ),
            },
            "optional": {
                "reference_image": (
                    IO.IMAGE,
                    {"tooltip": "Optional reference image to guide the generation"},
                )
            },
            "hidden": {
                "auth_token": "AUTH_TOKEN_COMFY_ORG",
                "comfy_api_key": "API_KEY_COMFY_ORG",
                "unique_id": "UNIQUE_ID",
            },
        }

    def validate_task_created(self, response: RunwayTextToImageResponse) -> bool:
        """
        Validate the task creation response from the Runway API matches
        expected format.
        """
        if not bool(response.id):
            raise RunwayApiError("Invalid initial response from Runway API.")
        return True

    def validate_response(self, response: TaskStatusResponse) -> bool:
        """
        Validate the successful task status response from the Runway API
        matches expected format.
        """
        if not response.output or len(response.output) == 0:
            raise RunwayApiError(
                "Runway task succeeded but no image data found in response."
            )
        return True

    def get_response(
        self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
    ) -> TaskStatusResponse:
        """Poll the task status until it is finished then get the response."""
        return poll_until_finished(
            auth_kwargs,
            ApiEndpoint(
                path=f"{PATH_GET_TASK_STATUS}/{task_id}",
                method=HttpMethod.GET,
                request_model=EmptyRequest,
                response_model=TaskStatusResponse,
            ),
            estimated_duration=AVERAGE_DURATION_T2I_SECONDS,
            node_id=node_id,
        )

    def api_call(
        self,
        prompt: str,
        ratio: str,
        reference_image: Optional[torch.Tensor] = None,
        unique_id: Optional[str] = None,
        **kwargs,
    ) -> tuple[torch.Tensor]:
        # Validate inputs
        validate_string(prompt, min_length=1)

        # Prepare reference images if provided
        reference_images = None
        if reference_image is not None:
            validate_input_image(reference_image)
            download_urls = upload_images_to_comfyapi(
                reference_image,
                max_images=1,
                mime_type="image/png",
                auth_kwargs=kwargs,
            )
            if len(download_urls) != 1:
                raise RunwayApiError("Failed to upload reference image to comfy api.")

            reference_images = [ReferenceImage(uri=str(download_urls[0]))]

        # Create request
        request = RunwayTextToImageRequest(
            promptText=prompt,
            model=Model4.gen4_image,
            ratio=ratio,
            referenceImages=reference_images,
        )

        # Execute initial request
        initial_operation = SynchronousOperation(
            endpoint=ApiEndpoint(
                path=PATH_TEXT_TO_IMAGE,
                method=HttpMethod.POST,
                request_model=RunwayTextToImageRequest,
                response_model=RunwayTextToImageResponse,
            ),
            request=request,
            auth_kwargs=kwargs,
        )

        initial_response = initial_operation.execute()
        self.validate_task_created(initial_response)
        task_id = initial_response.id

        # Poll for completion
        final_response = self.get_response(
            task_id, auth_kwargs=kwargs, node_id=unique_id
        )
        self.validate_response(final_response)

        # Download and return image
        image_url = get_image_url_from_task_status(final_response)
        return (download_url_to_image_tensor(image_url),)


NODE_CLASS_MAPPINGS = {
    "RunwayFirstLastFrameNode": RunwayFirstLastFrameNode,
    "RunwayImageToVideoNodeGen3a": RunwayImageToVideoNodeGen3a,
    "RunwayImageToVideoNodeGen4": RunwayImageToVideoNodeGen4,
    "RunwayTextToImageNode": RunwayTextToImageNode,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "RunwayFirstLastFrameNode": "Runway First-Last-Frame to Video",
    "RunwayImageToVideoNodeGen3a": "Runway Image to Video (Gen3a Turbo)",
    "RunwayImageToVideoNodeGen4": "Runway Image to Video (Gen4 Turbo)",
    "RunwayTextToImageNode": "Runway Text to Image",
}
@@ -1,574 +0,0 @@
import os
from folder_paths import get_output_directory
from comfy_api_nodes.mapper_utils import model_field_to_node_input
from comfy.comfy_types.node_typing import IO
from comfy_api_nodes.apis import (
    TripoOrientation,
    TripoModelVersion,
)
from comfy_api_nodes.apis.tripo_api import (
    TripoTaskType,
    TripoStyle,
    TripoFileReference,
    TripoFileEmptyReference,
    TripoUrlReference,
    TripoTaskResponse,
    TripoTaskStatus,
    TripoTextToModelRequest,
    TripoImageToModelRequest,
    TripoMultiviewToModelRequest,
    TripoTextureModelRequest,
    TripoRefineModelRequest,
    TripoAnimateRigRequest,
    TripoAnimateRetargetRequest,
    TripoConvertModelRequest,
)

from comfy_api_nodes.apis.client import (
    ApiEndpoint,
    HttpMethod,
    SynchronousOperation,
    PollingOperation,
    EmptyRequest,
)
from comfy_api_nodes.apinode_utils import (
    upload_images_to_comfyapi,
    download_url_to_bytesio,
)


def upload_image_to_tripo(image, **kwargs):
    urls = upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs)
    return TripoFileReference(TripoUrlReference(url=urls[0], type="jpeg"))

def get_model_url_from_response(response: TripoTaskResponse) -> str:
    if response.data is not None:
        for key in ["pbr_model", "model", "base_model"]:
            if getattr(response.data.output, key, None) is not None:
                return getattr(response.data.output, key)
    raise RuntimeError(f"Failed to get model url from response: {response}")
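
# Example (illustrative): get_model_url_from_response prefers a textured PBR
# asset. If response.data.output carries both pbr_model and base_model URLs,
# the pbr_model URL wins; "model" and "base_model" are fallbacks, in that order.
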

def poll_until_finished(
    kwargs: dict[str, str],
    response: TripoTaskResponse,
) -> tuple[str, str]:
    """Polls the Tripo API endpoint until the task reaches a terminal state, then returns the response."""
    if response.code != 0:
        raise RuntimeError(f"Failed to generate mesh: {response.error}")
    task_id = response.data.task_id
    response_poll = PollingOperation(
        poll_endpoint=ApiEndpoint(
            path=f"/proxy/tripo/v2/openapi/task/{task_id}",
            method=HttpMethod.GET,
            request_model=EmptyRequest,
            response_model=TripoTaskResponse,
        ),
        completed_statuses=[TripoTaskStatus.SUCCESS],
        failed_statuses=[
            TripoTaskStatus.FAILED,
            TripoTaskStatus.CANCELLED,
            TripoTaskStatus.UNKNOWN,
            TripoTaskStatus.BANNED,
            TripoTaskStatus.EXPIRED,
        ],
        status_extractor=lambda x: x.data.status,
        auth_kwargs=kwargs,
        node_id=kwargs["unique_id"],
        result_url_extractor=get_model_url_from_response,
        progress_extractor=lambda x: x.data.progress,
    ).execute()
    if response_poll.data.status == TripoTaskStatus.SUCCESS:
        url = get_model_url_from_response(response_poll)
        bytesio = download_url_to_bytesio(url)
        # Save the downloaded model file
        model_file = f"tripo_model_{task_id}.glb"
        with open(os.path.join(get_output_directory(), model_file), "wb") as f:
            f.write(bytesio.getvalue())
        return model_file, task_id
    raise RuntimeError(f"Failed to generate mesh: {response_poll}")

class TripoTextToModelNode:
    """
    Generates 3D models synchronously based on a text prompt using Tripo's API.
    """
    AVERAGE_DURATION = 80

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "prompt": ("STRING", {"multiline": True}),
            },
            "optional": {
                "negative_prompt": ("STRING", {"multiline": True}),
                "model_version": model_field_to_node_input(IO.COMBO, TripoTextToModelRequest, "model_version", enum_type=TripoModelVersion),
                "style": model_field_to_node_input(IO.COMBO, TripoTextToModelRequest, "style", enum_type=TripoStyle, default="None"),
                "texture": ("BOOLEAN", {"default": True}),
                "pbr": ("BOOLEAN", {"default": True}),
                "image_seed": ("INT", {"default": 42}),
                "model_seed": ("INT", {"default": 42}),
                "texture_seed": ("INT", {"default": 42}),
                "texture_quality": (["standard", "detailed"], {"default": "standard"}),
                "face_limit": ("INT", {"min": -1, "max": 500000, "default": -1}),
                "quad": ("BOOLEAN", {"default": False})
            },
            "hidden": {
                "auth_token": "AUTH_TOKEN_COMFY_ORG",
                "comfy_api_key": "API_KEY_COMFY_ORG",
                "unique_id": "UNIQUE_ID",
            },
        }

    RETURN_TYPES = ("STRING", "MODEL_TASK_ID",)
    RETURN_NAMES = ("model_file", "model task_id")
    FUNCTION = "generate_mesh"
    CATEGORY = "api node/3d/Tripo"
    API_NODE = True
    OUTPUT_NODE = True

    def generate_mesh(self, prompt, negative_prompt=None, model_version=None, style=None, texture=None, pbr=None, image_seed=None, model_seed=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs):
        style_enum = None if style == "None" else style
        if not prompt:
            raise RuntimeError("Prompt is required")
        response = SynchronousOperation(
            endpoint=ApiEndpoint(
                path="/proxy/tripo/v2/openapi/task",
                method=HttpMethod.POST,
                request_model=TripoTextToModelRequest,
                response_model=TripoTaskResponse,
            ),
            request=TripoTextToModelRequest(
                type=TripoTaskType.TEXT_TO_MODEL,
                prompt=prompt,
                negative_prompt=negative_prompt if negative_prompt else None,
                model_version=model_version,
style=style_enum,
|
||||
texture=texture,
|
||||
pbr=pbr,
|
||||
image_seed=image_seed,
|
||||
model_seed=model_seed,
|
||||
texture_seed=texture_seed,
|
||||
texture_quality=texture_quality,
|
||||
face_limit=face_limit,
|
||||
auto_size=True,
|
||||
quad=quad
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
).execute()
|
||||
return poll_until_finished(kwargs, response)
|
||||
|
||||
class TripoImageToModelNode:
|
||||
"""
|
||||
Generates 3D models synchronously based on a single image using Tripo's API.
|
||||
"""
|
||||
AVERAGE_DURATION = 80
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
},
|
||||
"optional": {
|
||||
"model_version": model_field_to_node_input(IO.COMBO, TripoImageToModelRequest, "model_version", enum_type=TripoModelVersion),
|
||||
"style": model_field_to_node_input(IO.COMBO, TripoTextToModelRequest, "style", enum_type=TripoStyle, default="None"),
|
||||
"texture": ("BOOLEAN", {"default": True}),
|
||||
"pbr": ("BOOLEAN", {"default": True}),
|
||||
"model_seed": ("INT", {"default": 42}),
|
||||
"orientation": model_field_to_node_input(IO.COMBO, TripoImageToModelRequest, "orientation", enum_type=TripoOrientation),
|
||||
"texture_seed": ("INT", {"default": 42}),
|
||||
"texture_quality": (["standard", "detailed"], {"default": "standard"}),
|
||||
"texture_alignment": (["original_image", "geometry"], {"default": "original_image"}),
|
||||
"face_limit": ("INT", {"min": -1, "max": 500000, "default": -1}),
|
||||
"quad": ("BOOLEAN", {"default": False})
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("STRING", "MODEL_TASK_ID",)
|
||||
RETURN_NAMES = ("model_file", "model task_id")
|
||||
FUNCTION = "generate_mesh"
|
||||
CATEGORY = "api node/3d/Tripo"
|
||||
API_NODE = True
|
||||
OUTPUT_NODE = True
|
||||
|
||||
def generate_mesh(self, image, model_version=None, style=None, texture=None, pbr=None, model_seed=None, orientation=None, texture_alignment=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs):
|
||||
style_enum = None if style == "None" else style
|
||||
if image is None:
|
||||
raise RuntimeError("Image is required")
|
||||
tripo_file = upload_image_to_tripo(image, **kwargs)
|
||||
response = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/tripo/v2/openapi/task",
|
||||
method=HttpMethod.POST,
|
||||
request_model=TripoImageToModelRequest,
|
||||
response_model=TripoTaskResponse,
|
||||
),
|
||||
request=TripoImageToModelRequest(
|
||||
type=TripoTaskType.IMAGE_TO_MODEL,
|
||||
file=tripo_file,
|
||||
model_version=model_version,
|
||||
style=style_enum,
|
||||
texture=texture,
|
||||
pbr=pbr,
|
||||
model_seed=model_seed,
|
||||
orientation=orientation,
|
||||
texture_alignment=texture_alignment,
|
||||
texture_seed=texture_seed,
|
||||
texture_quality=texture_quality,
|
||||
face_limit=face_limit,
|
||||
auto_size=True,
|
||||
quad=quad
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
).execute()
|
||||
return poll_until_finished(kwargs, response)
|
||||
|
||||
class TripoMultiviewToModelNode:
|
||||
"""
|
||||
Generates 3D models synchronously based on up to four images (front, left, back, right) using Tripo's API.
|
||||
"""
|
||||
AVERAGE_DURATION = 80
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
},
|
||||
"optional": {
|
||||
"image_left": ("IMAGE",),
|
||||
"image_back": ("IMAGE",),
|
||||
"image_right": ("IMAGE",),
|
||||
"model_version": model_field_to_node_input(IO.COMBO, TripoMultiviewToModelRequest, "model_version", enum_type=TripoModelVersion),
|
||||
"orientation": model_field_to_node_input(IO.COMBO, TripoImageToModelRequest, "orientation", enum_type=TripoOrientation),
|
||||
"texture": ("BOOLEAN", {"default": True}),
|
||||
"pbr": ("BOOLEAN", {"default": True}),
|
||||
"model_seed": ("INT", {"default": 42}),
|
||||
"texture_seed": ("INT", {"default": 42}),
|
||||
"texture_quality": (["standard", "detailed"], {"default": "standard"}),
|
||||
"texture_alignment": (["original_image", "geometry"], {"default": "original_image"}),
|
||||
"face_limit": ("INT", {"min": -1, "max": 500000, "default": -1}),
|
||||
"quad": ("BOOLEAN", {"default": False})
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("STRING", "MODEL_TASK_ID",)
|
||||
RETURN_NAMES = ("model_file", "model task_id")
|
||||
FUNCTION = "generate_mesh"
|
||||
CATEGORY = "api node/3d/Tripo"
|
||||
API_NODE = True
|
||||
OUTPUT_NODE = True
|
||||
|
||||
def generate_mesh(self, image, image_left=None, image_back=None, image_right=None, model_version=None, orientation=None, texture=None, pbr=None, model_seed=None, texture_seed=None, texture_quality=None, texture_alignment=None, face_limit=None, quad=None, **kwargs):
|
||||
if image is None:
|
||||
raise RuntimeError("front image for multiview is required")
|
||||
images = []
|
||||
image_dict = {
|
||||
"image": image,
|
||||
"image_left": image_left,
|
||||
"image_back": image_back,
|
||||
"image_right": image_right
|
||||
}
|
||||
if image_left is None and image_back is None and image_right is None:
|
||||
raise RuntimeError("At least one of left, back, or right image must be provided for multiview")
|
||||
for image_name in ["image", "image_left", "image_back", "image_right"]:
|
||||
image_ = image_dict[image_name]
|
||||
if image_ is not None:
|
||||
tripo_file = upload_image_to_tripo(image_, **kwargs)
|
||||
images.append(tripo_file)
|
||||
else:
|
||||
images.append(TripoFileEmptyReference())
|
||||
response = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/tripo/v2/openapi/task",
|
||||
method=HttpMethod.POST,
|
||||
request_model=TripoMultiviewToModelRequest,
|
||||
response_model=TripoTaskResponse,
|
||||
),
|
||||
request=TripoMultiviewToModelRequest(
|
||||
type=TripoTaskType.MULTIVIEW_TO_MODEL,
|
||||
files=images,
|
||||
model_version=model_version,
|
||||
orientation=orientation,
|
||||
texture=texture,
|
||||
pbr=pbr,
|
||||
model_seed=model_seed,
|
||||
texture_seed=texture_seed,
|
||||
texture_quality=texture_quality,
|
||||
texture_alignment=texture_alignment,
|
||||
face_limit=face_limit,
|
||||
quad=quad,
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
).execute()
|
||||
return poll_until_finished(kwargs, response)
|
||||
|
||||
class TripoTextureNode:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"model_task_id": ("MODEL_TASK_ID",),
|
||||
},
|
||||
"optional": {
|
||||
"texture": ("BOOLEAN", {"default": True}),
|
||||
"pbr": ("BOOLEAN", {"default": True}),
|
||||
"texture_seed": ("INT", {"default": 42}),
|
||||
"texture_quality": (["standard", "detailed"], {"default": "standard"}),
|
||||
"texture_alignment": (["original_image", "geometry"], {"default": "original_image"}),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("STRING", "MODEL_TASK_ID",)
|
||||
RETURN_NAMES = ("model_file", "model task_id")
|
||||
FUNCTION = "generate_mesh"
|
||||
CATEGORY = "api node/3d/Tripo"
|
||||
API_NODE = True
|
||||
OUTPUT_NODE = True
|
||||
AVERAGE_DURATION = 80
|
||||
|
||||
def generate_mesh(self, model_task_id, texture=None, pbr=None, texture_seed=None, texture_quality=None, texture_alignment=None, **kwargs):
|
||||
response = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/tripo/v2/openapi/task",
|
||||
method=HttpMethod.POST,
|
||||
request_model=TripoTextureModelRequest,
|
||||
response_model=TripoTaskResponse,
|
||||
),
|
||||
request=TripoTextureModelRequest(
|
||||
original_model_task_id=model_task_id,
|
||||
texture=texture,
|
||||
pbr=pbr,
|
||||
texture_seed=texture_seed,
|
||||
texture_quality=texture_quality,
|
||||
texture_alignment=texture_alignment
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
).execute()
|
||||
return poll_until_finished(kwargs, response)
|
||||
|
||||
|
||||
class TripoRefineNode:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"model_task_id": ("MODEL_TASK_ID", {
|
||||
"tooltip": "Must be a v1.4 Tripo model"
|
||||
}),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
DESCRIPTION = "Refine a draft model created by v1.4 Tripo models only."
|
||||
|
||||
RETURN_TYPES = ("STRING", "MODEL_TASK_ID",)
|
||||
RETURN_NAMES = ("model_file", "model task_id")
|
||||
FUNCTION = "generate_mesh"
|
||||
CATEGORY = "api node/3d/Tripo"
|
||||
API_NODE = True
|
||||
OUTPUT_NODE = True
|
||||
AVERAGE_DURATION = 240
|
||||
|
||||
def generate_mesh(self, model_task_id, **kwargs):
|
||||
response = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/tripo/v2/openapi/task",
|
||||
method=HttpMethod.POST,
|
||||
request_model=TripoRefineModelRequest,
|
||||
response_model=TripoTaskResponse,
|
||||
),
|
||||
request=TripoRefineModelRequest(
|
||||
draft_model_task_id=model_task_id
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
).execute()
|
||||
return poll_until_finished(kwargs, response)
|
||||
|
||||
|
||||
class TripoRigNode:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"original_model_task_id": ("MODEL_TASK_ID",),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("STRING", "RIG_TASK_ID")
|
||||
RETURN_NAMES = ("model_file", "rig task_id")
|
||||
FUNCTION = "generate_mesh"
|
||||
CATEGORY = "api node/3d/Tripo"
|
||||
API_NODE = True
|
||||
OUTPUT_NODE = True
|
||||
AVERAGE_DURATION = 180
|
||||
|
||||
def generate_mesh(self, original_model_task_id, **kwargs):
|
||||
response = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/tripo/v2/openapi/task",
|
||||
method=HttpMethod.POST,
|
||||
request_model=TripoAnimateRigRequest,
|
||||
response_model=TripoTaskResponse,
|
||||
),
|
||||
request=TripoAnimateRigRequest(
|
||||
original_model_task_id=original_model_task_id,
|
||||
out_format="glb",
|
||||
spec="tripo"
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
).execute()
|
||||
return poll_until_finished(kwargs, response)
|
||||
|
||||
class TripoRetargetNode:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"original_model_task_id": ("RIG_TASK_ID",),
|
||||
"animation": ([
|
||||
"preset:idle",
|
||||
"preset:walk",
|
||||
"preset:climb",
|
||||
"preset:jump",
|
||||
"preset:slash",
|
||||
"preset:shoot",
|
||||
"preset:hurt",
|
||||
"preset:fall",
|
||||
"preset:turn",
|
||||
],),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("STRING", "RETARGET_TASK_ID")
|
||||
RETURN_NAMES = ("model_file", "retarget task_id")
|
||||
FUNCTION = "generate_mesh"
|
||||
CATEGORY = "api node/3d/Tripo"
|
||||
API_NODE = True
|
||||
OUTPUT_NODE = True
|
||||
AVERAGE_DURATION = 30
|
||||
|
||||
def generate_mesh(self, animation, original_model_task_id, **kwargs):
|
||||
response = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/tripo/v2/openapi/task",
|
||||
method=HttpMethod.POST,
|
||||
request_model=TripoAnimateRetargetRequest,
|
||||
response_model=TripoTaskResponse,
|
||||
),
|
||||
request=TripoAnimateRetargetRequest(
|
||||
original_model_task_id=original_model_task_id,
|
||||
animation=animation,
|
||||
out_format="glb",
|
||||
bake_animation=True
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
).execute()
|
||||
return poll_until_finished(kwargs, response)
|
||||
|
||||
class TripoConversionNode:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"original_model_task_id": ("MODEL_TASK_ID,RIG_TASK_ID,RETARGET_TASK_ID",),
|
||||
"format": (["GLTF", "USDZ", "FBX", "OBJ", "STL", "3MF"],),
|
||||
},
|
||||
"optional": {
|
||||
"quad": ("BOOLEAN", {"default": False}),
|
||||
"face_limit": ("INT", {"min": -1, "max": 500000, "default": -1}),
|
||||
"texture_size": ("INT", {"min": 128, "max": 4096, "default": 4096}),
|
||||
"texture_format": (["BMP", "DPX", "HDR", "JPEG", "OPEN_EXR", "PNG", "TARGA", "TIFF", "WEBP"], {"default": "JPEG"})
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def VALIDATE_INPUTS(cls, input_types):
|
||||
# The min and max of input1 and input2 are still validated because
|
||||
# we didn't take `input1` or `input2` as arguments
|
||||
if input_types["original_model_task_id"] not in ("MODEL_TASK_ID", "RIG_TASK_ID", "RETARGET_TASK_ID"):
|
||||
return "original_model_task_id must be MODEL_TASK_ID, RIG_TASK_ID or RETARGET_TASK_ID type"
|
||||
return True
|
||||
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "generate_mesh"
|
||||
CATEGORY = "api node/3d/Tripo"
|
||||
API_NODE = True
|
||||
OUTPUT_NODE = True
|
||||
AVERAGE_DURATION = 30
|
||||
|
||||
def generate_mesh(self, original_model_task_id, format, quad, face_limit, texture_size, texture_format, **kwargs):
|
||||
if not original_model_task_id:
|
||||
raise RuntimeError("original_model_task_id is required")
|
||||
response = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/tripo/v2/openapi/task",
|
||||
method=HttpMethod.POST,
|
||||
request_model=TripoConvertModelRequest,
|
||||
response_model=TripoTaskResponse,
|
||||
),
|
||||
request=TripoConvertModelRequest(
|
||||
original_model_task_id=original_model_task_id,
|
||||
format=format,
|
||||
quad=quad if quad else None,
|
||||
face_limit=face_limit if face_limit != -1 else None,
|
||||
texture_size=texture_size if texture_size != 4096 else None,
|
||||
texture_format=texture_format if texture_format != "JPEG" else None
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
).execute()
|
||||
return poll_until_finished(kwargs, response)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"TripoTextToModelNode": TripoTextToModelNode,
|
||||
"TripoImageToModelNode": TripoImageToModelNode,
|
||||
"TripoMultiviewToModelNode": TripoMultiviewToModelNode,
|
||||
"TripoTextureNode": TripoTextureNode,
|
||||
"TripoRefineNode": TripoRefineNode,
|
||||
"TripoRigNode": TripoRigNode,
|
||||
"TripoRetargetNode": TripoRetargetNode,
|
||||
"TripoConversionNode": TripoConversionNode,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"TripoTextToModelNode": "Tripo: Text to Model",
|
||||
"TripoImageToModelNode": "Tripo: Image to Model",
|
||||
"TripoMultiviewToModelNode": "Tripo: Multiview to Model",
|
||||
"TripoTextureNode": "Tripo: Texture model",
|
||||
"TripoRefineNode": "Tripo: Refine Draft model",
|
||||
"TripoRigNode": "Tripo: Rig model",
|
||||
"TripoRetargetNode": "Tripo: Retarget rigged model",
|
||||
"TripoConversionNode": "Tripo: Convert model",
|
||||
}
|
||||
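
The Tripo nodes above chain through task IDs: the generation nodes return a MODEL_TASK_ID that the texture, refine, and rig nodes consume. A rough standalone sketch of that chaining, outside ComfyUI's normal graph execution (the credential values here are placeholders, not real tokens):

# Hypothetical chaining of the Tripo nodes above; auth values are placeholders.
model_file, model_task_id = TripoTextToModelNode().generate_mesh(
    prompt="a small wooden chair",
    auth_token="<AUTH_TOKEN_COMFY_ORG>",
    comfy_api_key="<API_KEY_COMFY_ORG>",
    unique_id="1",
)
textured_file, _ = TripoTextureNode().generate_mesh(
    model_task_id=model_task_id,
    auth_token="<AUTH_TOKEN_COMFY_ORG>",
    comfy_api_key="<API_KEY_COMFY_ORG>",
    unique_id="2",
)

The extra keyword arguments land in **kwargs, which is exactly how the hidden auth inputs and unique_id reach poll_until_finished in the code above.
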
@@ -14,7 +14,6 @@ import re
from io import BytesIO
from inspect import cleandoc
import torch
import comfy.utils

from comfy.comfy_types import FileLocator

@@ -230,186 +229,6 @@ class SVG:
        all_svgs_list.extend(svg_item.data)
        return SVG(all_svgs_list)


class ImageStitch:
    """Upstreamed from https://github.com/kijai/ComfyUI-KJNodes"""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image1": ("IMAGE",),
                "direction": (["right", "down", "left", "up"], {"default": "right"}),
                "match_image_size": ("BOOLEAN", {"default": True}),
                "spacing_width": (
                    "INT",
                    {"default": 0, "min": 0, "max": 1024, "step": 2},
                ),
                "spacing_color": (
                    ["white", "black", "red", "green", "blue"],
                    {"default": "white"},
                ),
            },
            "optional": {
                "image2": ("IMAGE",),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "stitch"
    CATEGORY = "image/transform"
    DESCRIPTION = """
Stitches image2 to image1 in the specified direction.
If image2 is not provided, returns image1 unchanged.
Optional spacing can be added between images.
"""

    def stitch(
        self,
        image1,
        direction,
        match_image_size,
        spacing_width,
        spacing_color,
        image2=None,
    ):
        if image2 is None:
            return (image1,)

        # Handle batch size differences
        if image1.shape[0] != image2.shape[0]:
            max_batch = max(image1.shape[0], image2.shape[0])
            if image1.shape[0] < max_batch:
                image1 = torch.cat(
                    [image1, image1[-1:].repeat(max_batch - image1.shape[0], 1, 1, 1)]
                )
            if image2.shape[0] < max_batch:
                image2 = torch.cat(
                    [image2, image2[-1:].repeat(max_batch - image2.shape[0], 1, 1, 1)]
                )

        # Match image sizes if requested
        if match_image_size:
            h1, w1 = image1.shape[1:3]
            h2, w2 = image2.shape[1:3]
            aspect_ratio = w2 / h2

            if direction in ["left", "right"]:
                target_h, target_w = h1, int(h1 * aspect_ratio)
            else:  # up, down
                target_w, target_h = w1, int(w1 / aspect_ratio)

            image2 = comfy.utils.common_upscale(
                image2.movedim(-1, 1), target_w, target_h, "lanczos", "disabled"
            ).movedim(1, -1)

        # When not matching sizes, pad to align non-concat dimensions
        if not match_image_size:
            h1, w1 = image1.shape[1:3]
            h2, w2 = image2.shape[1:3]

            if direction in ["left", "right"]:
                # For horizontal concat, pad heights to match
                if h1 != h2:
                    target_h = max(h1, h2)
                    if h1 < target_h:
                        pad_h = target_h - h1
                        pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2
                        image1 = torch.nn.functional.pad(image1, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=0.0)
                    if h2 < target_h:
                        pad_h = target_h - h2
                        pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2
                        image2 = torch.nn.functional.pad(image2, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=0.0)
            else:  # up, down
                # For vertical concat, pad widths to match
                if w1 != w2:
                    target_w = max(w1, w2)
                    if w1 < target_w:
                        pad_w = target_w - w1
                        pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2
                        image1 = torch.nn.functional.pad(image1, (0, 0, pad_left, pad_right), mode='constant', value=0.0)
                    if w2 < target_w:
                        pad_w = target_w - w2
                        pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2
                        image2 = torch.nn.functional.pad(image2, (0, 0, pad_left, pad_right), mode='constant', value=0.0)

        # Ensure same number of channels
        if image1.shape[-1] != image2.shape[-1]:
            max_channels = max(image1.shape[-1], image2.shape[-1])
            if image1.shape[-1] < max_channels:
                image1 = torch.cat(
                    [
                        image1,
                        torch.ones(
                            *image1.shape[:-1],
                            max_channels - image1.shape[-1],
                            device=image1.device,
                        ),
                    ],
                    dim=-1,
                )
            if image2.shape[-1] < max_channels:
                image2 = torch.cat(
                    [
                        image2,
                        torch.ones(
                            *image2.shape[:-1],
                            max_channels - image2.shape[-1],
                            device=image2.device,
                        ),
                    ],
                    dim=-1,
                )

        # Add spacing if specified
        if spacing_width > 0:
            spacing_width = spacing_width + (spacing_width % 2)  # Ensure even

            color_map = {
                "white": 1.0,
                "black": 0.0,
                "red": (1.0, 0.0, 0.0),
                "green": (0.0, 1.0, 0.0),
                "blue": (0.0, 0.0, 1.0),
            }
            color_val = color_map[spacing_color]

            if direction in ["left", "right"]:
                spacing_shape = (
                    image1.shape[0],
                    max(image1.shape[1], image2.shape[1]),
                    spacing_width,
                    image1.shape[-1],
                )
            else:
                spacing_shape = (
                    image1.shape[0],
                    spacing_width,
                    max(image1.shape[2], image2.shape[2]),
                    image1.shape[-1],
                )

            spacing = torch.full(spacing_shape, 0.0, device=image1.device)
            if isinstance(color_val, tuple):
                for i, c in enumerate(color_val):
                    if i < spacing.shape[-1]:
                        spacing[..., i] = c
                if spacing.shape[-1] == 4:  # Add alpha
                    spacing[..., 3] = 1.0
            else:
                spacing[..., : min(3, spacing.shape[-1])] = color_val
                if spacing.shape[-1] == 4:
                    spacing[..., 3] = 1.0

        # Concatenate images
        images = [image2, image1] if direction in ["left", "up"] else [image1, image2]
        if spacing_width > 0:
            images.insert(1, spacing)

        concat_dim = 2 if direction in ["left", "right"] else 1
        return (torch.cat(images, dim=concat_dim),)


class SaveSVGNode:
    """
    Save SVG files on disk.
@@ -499,5 +318,4 @@ NODE_CLASS_MAPPINGS = {
    "SaveAnimatedWEBP": SaveAnimatedWEBP,
    "SaveAnimatedPNG": SaveAnimatedPNG,
    "SaveSVGNode": SaveSVGNode,
    "ImageStitch": ImageStitch,
}

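One detail worth spelling out from the ImageStitch padding above: torch.nn.functional.pad reads its padding pairs from the last dimension backwards, so for the [batch, height, width, channels] tensors used here, (0, 0, 0, 0, pad_top, pad_bottom) leaves channels and width untouched and pads height. A minimal standalone check (shapes are illustrative):

import torch
import torch.nn.functional as F

x = torch.zeros(1, 4, 6, 3)  # [batch, height, width, channels]

# Pairs apply to the last dim first: (0, 0) -> channels unchanged,
# (0, 0) -> width unchanged, (1, 2) -> 1 row above, 2 rows below in height.
y = F.pad(x, (0, 0, 0, 0, 1, 2), mode="constant", value=0.0)
assert y.shape == (1, 7, 6, 3)
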
@@ -16,7 +16,7 @@ class Load3D():

        os.makedirs(input_dir, exist_ok=True)

        files = [normalize_path(os.path.join("3d", f)) for f in os.listdir(input_dir) if f.endswith(('.gltf', '.glb', '.obj', '.fbx', '.stl'))]
        files = [normalize_path(os.path.join("3d", f)) for f in os.listdir(input_dir) if f.endswith(('.gltf', '.glb', '.obj', '.mtl', '.fbx', '.stl'))]

        return {"required": {
            "model_file": (sorted(files), {"file_upload": True}),

@@ -296,41 +296,6 @@ class RegexExtract():

        return result,


class RegexReplace():
    DESCRIPTION = "Find and replace text using regex patterns."
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "string": (IO.STRING, {"multiline": True}),
                "regex_pattern": (IO.STRING, {"multiline": True}),
                "replace": (IO.STRING, {"multiline": True}),
            },
            "optional": {
                "case_insensitive": (IO.BOOLEAN, {"default": True}),
                "multiline": (IO.BOOLEAN, {"default": False}),
                "dotall": (IO.BOOLEAN, {"default": False, "tooltip": "When enabled, the dot (.) character will match any character including newline characters. When disabled, dots won't match newlines."}),
                "count": (IO.INT, {"default": 0, "min": 0, "max": 100, "tooltip": "Maximum number of replacements to make. Set to 0 to replace all occurrences (default). Set to 1 to replace only the first match, 2 for the first two matches, etc."}),
            }
        }

    RETURN_TYPES = (IO.STRING,)
    FUNCTION = "execute"
    CATEGORY = "utils/string"

    def execute(self, string, regex_pattern, replace, case_insensitive=True, multiline=False, dotall=False, count=0, **kwargs):
        flags = 0

        if case_insensitive:
            flags |= re.IGNORECASE
        if multiline:
            flags |= re.MULTILINE
        if dotall:
            flags |= re.DOTALL
        result = re.sub(regex_pattern, replace, string, count=count, flags=flags)
        return result,

NODE_CLASS_MAPPINGS = {
    "StringConcatenate": StringConcatenate,
    "StringSubstring": StringSubstring,
@@ -341,8 +306,7 @@ NODE_CLASS_MAPPINGS = {
    "StringContains": StringContains,
    "StringCompare": StringCompare,
    "RegexMatch": RegexMatch,
    "RegexExtract": RegexExtract,
    "RegexReplace": RegexReplace,
    "RegexExtract": RegexExtract
}

NODE_DISPLAY_NAME_MAPPINGS = {
@@ -355,6 +319,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
    "StringContains": "Contains",
    "StringCompare": "Compare",
    "RegexMatch": "Regex Match",
    "RegexExtract": "Regex Extract",
    "RegexReplace": "Regex Replace",
    "RegexExtract": "Regex Extract"
}

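For reference, the RegexReplace flag and count semantics above reduce to plain re.sub; a standalone sketch:

import re

flags = re.IGNORECASE | re.MULTILINE  # case_insensitive=True, multiline=True

# count=0 replaces every occurrence; count=1 stops after the first match.
print(re.sub(r"^foo", "bar", "foo\nFOO", count=0, flags=flags))  # bar\nbar
print(re.sub(r"^foo", "bar", "foo\nFOO", count=1, flags=flags))  # bar\nFOO
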
@@ -1,5 +1,4 @@
from comfy_api.torch_helpers import set_torch_compile_wrapper

import torch

class TorchCompileModel:
    @classmethod
@@ -15,7 +14,7 @@ class TorchCompileModel:

    def patch(self, model, backend):
        m = model.clone()
        set_torch_compile_wrapper(model=m, backend=backend)
        m.add_object_patch("diffusion_model", torch.compile(model=m.get_model_object("diffusion_model"), backend=backend))
        return (m, )

NODE_CLASS_MAPPINGS = {

@@ -1,67 +0,0 @@
import torch

from comfy_api.v3.io import (
    ComfyNodeV3, SchemaV3, CustomType, CustomInput, CustomOutput, InputBehavior, NumberDisplay,
    IntegerInput, MaskInput, ImageInput, ComboDynamicInput, NodeOutput,
)


class V3TestNode(ComfyNodeV3):

    @classmethod
    def DEFINE_SCHEMA(cls):
        return SchemaV3(
            node_id="V3TestNode1",
            display_name="V3 Test Node (1djekjd)",
            description="This is a funky V3 node test.",
            category="v3 nodes",
            inputs=[
                IntegerInput("some_int", display_name="new_name", min=0, tooltip="My tooltip 😎", display_mode=NumberDisplay.slider),
                MaskInput("mask", behavior=InputBehavior.optional),
                ImageInput("image", display_name="new_image"),
                # IntegerInput("some_int", display_name="new_name", min=0, tooltip="My tooltip 😎", display=NumberDisplay.slider, ),
                # ComboDynamicInput("mask", behavior=InputBehavior.optional),
                # IntegerInput("some_int", display_name="new_name", min=0, tooltip="My tooltip 😎", display=NumberDisplay.slider,
                #              dependent_inputs=[ComboDynamicInput("mask", behavior=InputBehavior.optional)],
                #              dependent_values=[lambda my_value: IO.STRING if my_value < 5 else IO.NUMBER],
                #              ),
                # ["option1", "option2". "option3"]
                # ComboDynamicInput["sdfgjhl", [ComboDynamicOptions("option1", [IntegerInput("some_int", display_name="new_name", min=0, tooltip="My tooltip 😎", display=NumberDisplay.slider, ImageInput(), MaskInput(), String()]),
                #                               CombyDynamicOptons("option2", [])
                #                              ]]
            ],
            is_output_node=True,
        )

    def execute(self, some_int: int, image: torch.Tensor, mask: torch.Tensor=None, **kwargs):
        a = NodeOutput(1)
        aa = NodeOutput(1, "hellothere")
        ab = NodeOutput(1, "hellothere", ui={"lol": "jk"})
        b = NodeOutput()
        c = NodeOutput(ui={"lol": "jk"})
        return NodeOutput()
        return NodeOutput(1)
        return NodeOutput(1, block_execution="Kill yourself")
        return ()


NODES_LIST: list[ComfyNodeV3] = [
    V3TestNode,
]


# NODE_CLASS_MAPPINGS = {}
# NODE_DISPLAY_NAME_MAPPINGS = {}
# for node in NODES_LIST:
#     schema = node.GET_SCHEMA()
#     NODE_CLASS_MAPPINGS[schema.node_id] = node
#     if schema.display_name:
#         NODE_DISPLAY_NAME_MAPPINGS[schema.node_id] = schema.display_name
@@ -268,9 +268,8 @@ class WanVaceToVideo:
            trim_latent = reference_image.shape[2]

        mask = mask.unsqueeze(0)

        positive = node_helpers.conditioning_set_values(positive, {"vace_frames": [control_video_latent], "vace_mask": [mask], "vace_strength": [strength]}, append=True)
        negative = node_helpers.conditioning_set_values(negative, {"vace_frames": [control_video_latent], "vace_mask": [mask], "vace_strength": [strength]}, append=True)
        positive = node_helpers.conditioning_set_values(positive, {"vace_frames": control_video_latent, "vace_mask": mask, "vace_strength": strength})
        negative = node_helpers.conditioning_set_values(negative, {"vace_frames": control_video_latent, "vace_mask": mask, "vace_strength": strength})

        latent = torch.zeros([batch_size, 16, latent_length, height // 8, width // 8], device=comfy.model_management.intermediate_device())
        out_latent = {}
@@ -345,44 +344,6 @@ class WanCameraImageToVideo:
        out_latent["samples"] = latent
        return (positive, negative, out_latent)

class WanPhantomSubjectToVideo:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"positive": ("CONDITIONING", ),
                             "negative": ("CONDITIONING", ),
                             "vae": ("VAE", ),
                             "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
                             "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
                             "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
                             },
                "optional": {"images": ("IMAGE", ),
                             }}

    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "CONDITIONING", "LATENT")
    RETURN_NAMES = ("positive", "negative_text", "negative_img_text", "latent")
    FUNCTION = "encode"

    CATEGORY = "conditioning/video_models"

    def encode(self, positive, negative, vae, width, height, length, batch_size, images):
        latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
        cond2 = negative
        if images is not None:
            images = comfy.utils.common_upscale(images[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
            latent_images = []
            for i in images:
                latent_images += [vae.encode(i.unsqueeze(0)[:, :, :, :3])]
            concat_latent_image = torch.cat(latent_images, dim=2)

            positive = node_helpers.conditioning_set_values(positive, {"time_dim_concat": concat_latent_image})
            cond2 = node_helpers.conditioning_set_values(negative, {"time_dim_concat": concat_latent_image})
            negative = node_helpers.conditioning_set_values(negative, {"time_dim_concat": comfy.latent_formats.Wan21().process_out(torch.zeros_like(concat_latent_image))})

        out_latent = {}
        out_latent["samples"] = latent
        return (positive, cond2, negative, out_latent)

NODE_CLASS_MAPPINGS = {
    "WanImageToVideo": WanImageToVideo,
    "WanFunControlToVideo": WanFunControlToVideo,
@@ -391,5 +352,4 @@ NODE_CLASS_MAPPINGS = {
    "WanVaceToVideo": WanVaceToVideo,
    "TrimVideoLatent": TrimVideoLatent,
    "WanCameraImageToVideo": WanCameraImageToVideo,
    "WanPhantomSubjectToVideo": WanPhantomSubjectToVideo,
}

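A quick arithmetic check of the temporal term in WanPhantomSubjectToVideo.encode above: as the step of 4 on the length input and the ((length - 1) // 4) + 1 expression suggest, the VAE compresses time 4x with the first frame kept on its own. A standalone sketch (the helper name is ours):

def latent_frames(length: int) -> int:
    # 4x temporal compression, plus the first frame kept separately,
    # matching the ((length - 1) // 4) + 1 term in the node above.
    return ((length - 1) // 4) + 1

assert latent_frames(81) == 21  # the node's default length
assert latent_frames(1) == 1    # a single frame stays a single frame
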
@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.3.39"
__version__ = "0.3.34"

26
execution.py
@@ -17,7 +17,6 @@ from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt,
from comfy_execution.graph_utils import is_link, GraphBuilder
from comfy_execution.caching import HierarchicalCache, LRUCache, DependencyAwareCache, CacheKeySetInputSignature, CacheKeySetID
from comfy_execution.validation import validate_node_input
from comfy_api.v3.io import NodeOutput

class ExecutionResult(Enum):
    SUCCESS = 0
@@ -243,22 +242,6 @@ def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb
                result = tuple([result] * len(obj.RETURN_TYPES))
            results.append(result)
            subgraph_results.append((None, result))
        elif isinstance(r, NodeOutput):
            if r.ui is not None:
                uis.append(r.ui.as_dict())
            if r.expand is not None:
                has_subgraph = True
                new_graph = r.expand
                result = r.result
                if r.block_execution is not None:
                    result = tuple([ExecutionBlocker(r.block_execution)] * len(obj.RETURN_TYPES))
                subgraph_results.append((new_graph, result))
            elif r.result is not None:
                result = r.result
                if r.block_execution is not None:
                    result = tuple([ExecutionBlocker(r.block_execution)] * len(obj.RETURN_TYPES))
                results.append(result)
                subgraph_results.append((None, result))
        else:
            if isinstance(r, ExecutionBlocker):
                r = tuple([r] * len(obj.RETURN_TYPES))
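
The branch removed above was the executor half of the V3 NodeOutput contract: r.ui was collected into the UI payload, r.expand switched the node into subgraph expansion, and a non-None r.block_execution replaced every declared output with an ExecutionBlocker carrying the message. The producer half, as exercised by the V3 test node elsewhere in this diff, looked roughly like this (a sketch, not live API):

# Producer-side sketch, mirroring the V3 test node in this diff:
NodeOutput(1)                                  # plain result value(s)
NodeOutput(1, "hellothere", ui={"lol": "jk"})  # results plus a UI payload
NodeOutput(block_execution="reason")           # blocks downstream execution
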
@@ -926,6 +909,7 @@ class PromptQueue:
        self.currently_running = {}
        self.history = {}
        self.flags = {}
        server.prompt_queue = self

    def put(self, item):
        with self.mutex:
@@ -970,7 +954,6 @@ class PromptQueue:
            self.history[prompt[1]].update(history_result)
            self.server.queue_updated()

    # Note: slow
    def get_current_queue(self):
        with self.mutex:
            out = []
@@ -978,13 +961,6 @@ class PromptQueue:
                out += [x]
            return (out, copy.deepcopy(self.queue))

    # read-safe as long as queue items are immutable
    def get_current_queue_volatile(self):
        with self.mutex:
            running = [x for x in self.currently_running.values()]
            queued = copy.copy(self.queue)
            return (running, queued)

    def get_tasks_remaining(self):
        with self.mutex:
            return len(self.queue) + len(self.currently_running)

3
main.py
@@ -260,6 +260,7 @@ def start_comfyui(asyncio_loop=None):
        asyncio_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(asyncio_loop)
    prompt_server = server.PromptServer(asyncio_loop)
    q = execution.PromptQueue(prompt_server)

    hook_breaker_ac10a0.save_functions()
    nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes, init_api_nodes=not args.disable_api_nodes)
@@ -270,7 +271,7 @@ def start_comfyui(asyncio_loop=None):
    prompt_server.add_routes()
    hijack_progress(prompt_server)

    threading.Thread(target=prompt_worker, daemon=True, args=(prompt_server.prompt_queue, prompt_server,)).start()
    threading.Thread(target=prompt_worker, daemon=True, args=(q, prompt_server,)).start()

    if args.quick_test_for_ci:
        exit(0)

@@ -5,18 +5,12 @@ from comfy.cli_args import args

from PIL import ImageFile, UnidentifiedImageError

def conditioning_set_values(conditioning, values={}, append=False):
def conditioning_set_values(conditioning, values={}):
    c = []
    for t in conditioning:
        n = [t[0], t[1].copy()]
        for k in values:
            val = values[k]
            if append:
                old_val = n[1].get(k, None)
                if old_val is not None:
                    val = old_val + val

            n[1][k] = val
            n[1][k] = values[k]
        c.append(n)

    return c

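To illustrate the behavioral difference in this hunk: with append=True, the newer conditioning_set_values concatenates into any existing value instead of overwriting it, which is what the WanVaceToVideo and unCLIPConditioning call sites in this diff rely on (they pass lists). A standalone sketch of just that merge rule (the helper name is ours):

def set_values(cond_dict, values, append=False):
    # Mirrors the per-entry update inside conditioning_set_values above.
    out = cond_dict.copy()
    for k, val in values.items():
        if append and out.get(k) is not None:
            val = out[k] + val  # list + list -> concatenation
        out[k] = val
    return out

d = {"vace_frames": ["latent_a"]}
print(set_values(d, {"vace_frames": ["latent_b"]}))               # overwritten
print(set_values(d, {"vace_frames": ["latent_b"]}, append=True))  # ['latent_a', 'latent_b']
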
32
nodes.py
@@ -26,7 +26,6 @@ import comfy.sd
import comfy.utils
import comfy.controlnet
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict, FileLocator
from comfy_api.v3.io import ComfyNodeV3

import comfy.clip_vision

@@ -1104,7 +1103,16 @@ class unCLIPConditioning:
        if strength == 0:
            return (conditioning, )

        c = node_helpers.conditioning_set_values(conditioning, {"unclip_conditioning": [{"clip_vision_output": clip_vision_output, "strength": strength, "noise_augmentation": noise_augmentation}]}, append=True)
        c = []
        for t in conditioning:
            o = t[1].copy()
            x = {"clip_vision_output": clip_vision_output, "strength": strength, "noise_augmentation": noise_augmentation}
            if "unclip_conditioning" in o:
                o["unclip_conditioning"] = o["unclip_conditioning"][:] + [x]
            else:
                o["unclip_conditioning"] = [x]
            n = [t[0], o]
            c.append(n)
        return (c, )

class GLIGENLoader:
@@ -2062,7 +2070,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
    "ImagePadForOutpaint": "Pad Image for Outpainting",
    "ImageBatch": "Batch Images",
    "ImageCrop": "Image Crop",
    "ImageStitch": "Image Stitch",
    "ImageBlend": "Image Blend",
    "ImageBlur": "Image Blur",
    "ImageQuantize": "Image Quantize",
@@ -2130,7 +2137,6 @@ def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes
        if os.path.isdir(web_dir):
            EXTENSION_WEB_DIRS[module_name] = web_dir

        # V1 node definition
        if hasattr(module, "NODE_CLASS_MAPPINGS") and getattr(module, "NODE_CLASS_MAPPINGS") is not None:
            for name, node_cls in module.NODE_CLASS_MAPPINGS.items():
                if name not in ignore:
@@ -2139,19 +2145,8 @@ def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes
            if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None:
                NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS)
            return True
        # V3 node definition
        elif getattr(module, "NODES_LIST", None) is not None:
            for node_cls in module.NODES_LIST:
                node_cls: ComfyNodeV3
                schema = node_cls.GET_SCHEMA()
                if schema.node_id not in ignore:
                    NODE_CLASS_MAPPINGS[schema.node_id] = node_cls
                    node_cls.RELATIVE_PYTHON_MODULE = "{}.{}".format(module_parent, get_module_name(module_path))
                if schema.display_name is not None:
                    NODE_DISPLAY_NAME_MAPPINGS[schema.node_id] = schema.display_name
            return True
        else:
            logging.warning(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS or NODES_LIST (need one).")
            logging.warning(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.")
            return False
    except Exception as e:
        logging.warning(traceback.format_exc())
@@ -2271,7 +2266,6 @@ def init_builtin_extra_nodes():
        "nodes_ace.py",
        "nodes_string.py",
        "nodes_camera_trajectory.py",
        "nodes_v3_test.py",
    ]

    import_failed = []
@@ -2296,10 +2290,6 @@ def init_builtin_api_nodes():
        "nodes_pixverse.py",
        "nodes_stability.py",
        "nodes_pika.py",
        "nodes_runway.py",
        "nodes_tripo.py",
        "nodes_rodin.py",
        "nodes_gemini.py",
    ]

    if not load_custom_node(os.path.join(api_nodes_dir, "canary.py"), module_parent="comfy_api_nodes"):

904
openapi.yaml
Normal file
@@ -0,0 +1,904 @@
openapi: 3.0.3
info:
  title: ComfyUI API
  description: |
    API for ComfyUI - A powerful and modular UI for Stable Diffusion.

    This API allows you to interact with ComfyUI programmatically, including:
    - Submitting workflows for execution
    - Managing the execution queue
    - Retrieving generated images
    - Managing models
    - Retrieving node information
  version: 1.0.0
  license:
    name: GNU General Public License v3.0
    url: https://github.com/comfyanonymous/ComfyUI/blob/master/LICENSE

servers:
  - url: /
    description: Default ComfyUI server

tags:
  - name: workflow
    description: Workflow execution and management
  - name: queue
    description: Queue management
  - name: image
    description: Image handling
  - name: node
    description: Node information
  - name: model
    description: Model management
  - name: system
    description: System information
  - name: internal
    description: Internal API routes

paths:
  /prompt:
    get:
      tags:
        - workflow
      summary: Get information about current prompt execution
      description: Returns information about the current prompt in the execution queue
      operationId: getPromptInfo
      responses:
        '200':
          description: Success
          content:
            application/json:
              schema:
                $ref: '#/components/schemas/PromptInfo'
    post:
      tags:
        - workflow
      summary: Submit a workflow for execution
      description: |
        Submit a workflow to be executed by the backend.
        The workflow is a JSON object describing the nodes and their connections.
      operationId: executePrompt
      requestBody:
        required: true
        content:
          application/json:
            schema:
              $ref: '#/components/schemas/PromptRequest'
      responses:
        '200':
          description: Success - Prompt accepted
          content:
            application/json:
              schema:
                $ref: '#/components/schemas/PromptResponse'
        '400':
          description: Invalid prompt
          content:
            application/json:
              schema:
                $ref: '#/components/schemas/ErrorResponse'

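  # Example /prompt submission (illustrative only; the graph content and
  # client_id below are made up and are not part of this specification):
  #
  #   POST /prompt
  #   Content-Type: application/json
  #
  #   {
  #     "prompt": {"3": {"class_type": "KSampler", "inputs": {"seed": 42}}},
  #     "client_id": "example-client"
  #   }
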
/queue:
|
||||
get:
|
||||
tags:
|
||||
- queue
|
||||
summary: Get queue information
|
||||
description: Returns information about running and pending items in the queue
|
||||
operationId: getQueueInfo
|
||||
responses:
|
||||
'200':
|
||||
description: Success
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/QueueInfo'
|
||||
post:
|
||||
tags:
|
||||
- queue
|
||||
summary: Manage queue
|
||||
description: Clear the queue or delete specific items
|
||||
operationId: manageQueue
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
clear:
|
||||
type: boolean
|
||||
description: If true, clears the entire queue
|
||||
delete:
|
||||
type: array
|
||||
description: Array of prompt IDs to delete from the queue
|
||||
items:
|
||||
type: string
|
||||
format: uuid
|
||||
responses:
|
||||
'200':
|
||||
description: Success
|
||||
|
||||
/interrupt:
|
||||
post:
|
||||
tags:
|
||||
- workflow
|
||||
summary: Interrupt the current execution
|
||||
description: Interrupts the currently running workflow execution
|
||||
operationId: interruptExecution
|
||||
responses:
|
||||
'200':
|
||||
description: Success
|
||||
|
||||
/free:
|
||||
post:
|
||||
tags:
|
||||
- system
|
||||
summary: Free resources
|
||||
description: Unload models and/or free memory
|
||||
operationId: freeResources
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
unload_models:
|
||||
type: boolean
|
||||
description: If true, unloads models from memory
|
||||
free_memory:
|
||||
type: boolean
|
||||
description: If true, frees GPU memory
|
||||
responses:
|
||||
'200':
|
||||
description: Success
|
||||
|
||||
/history:
|
||||
get:
|
||||
tags:
|
||||
- workflow
|
||||
summary: Get execution history
|
||||
description: Returns the history of executed workflows
|
||||
operationId: getHistory
|
||||
parameters:
|
||||
- name: max_items
|
||||
in: query
|
||||
description: Maximum number of history items to return
|
||||
required: false
|
||||
schema:
|
||||
type: integer
|
||||
format: int32
|
||||
responses:
|
||||
'200':
|
||||
description: Success
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/HistoryItem'
|
||||
post:
|
||||
tags:
|
||||
- workflow
|
||||
summary: Manage history
|
||||
description: Clear history or delete specific items
|
||||
operationId: manageHistory
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
clear:
|
||||
type: boolean
|
||||
description: If true, clears the entire history
|
||||
delete:
|
||||
type: array
|
||||
description: Array of prompt IDs to delete from history
|
||||
items:
|
||||
type: string
|
||||
format: uuid
|
||||
responses:
|
||||
'200':
|
||||
description: Success
|
||||
|
||||
/history/{prompt_id}:
|
||||
get:
|
||||
tags:
|
||||
- workflow
|
||||
summary: Get specific history item
|
||||
description: Returns a specific history item by ID
|
||||
operationId: getHistoryItem
|
||||
parameters:
|
||||
- name: prompt_id
|
||||
in: path
|
||||
description: ID of the prompt to retrieve
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
format: uuid
|
||||
responses:
|
||||
'200':
|
||||
description: Success
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/HistoryItem'
|
||||
|
||||
/object_info:
|
||||
get:
|
||||
tags:
|
||||
- node
|
||||
summary: Get all node information
|
||||
description: Returns information about all available nodes
|
||||
operationId: getNodeInfo
|
||||
responses:
|
||||
'200':
|
||||
description: Success
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
additionalProperties:
|
||||
$ref: '#/components/schemas/NodeInfo'
|
||||
|
||||
/object_info/{node_class}:
|
||||
get:
|
||||
tags:
|
||||
- node
|
||||
summary: Get specific node information
|
||||
description: Returns information about a specific node class
|
||||
operationId: getNodeClassInfo
|
||||
parameters:
|
||||
- name: node_class
|
||||
in: path
|
||||
description: Name of the node class
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
responses:
|
||||
'200':
|
||||
description: Success
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
additionalProperties:
|
||||
$ref: '#/components/schemas/NodeInfo'
|
||||
|
||||
/upload/image:
|
||||
post:
|
||||
tags:
|
||||
- image
|
||||
summary: Upload an image
|
||||
description: Uploads an image to the server
|
||||
operationId: uploadImage
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
multipart/form-data:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
image:
|
||||
type: string
|
||||
format: binary
|
||||
description: The image file to upload
|
||||
overwrite:
|
||||
type: string
|
||||
description: Whether to overwrite if file exists (true/false)
|
||||
type:
|
||||
type: string
|
||||
enum: [input, temp, output]
|
||||
description: Type of directory to store the image in
|
||||
subfolder:
|
||||
type: string
|
||||
description: Subfolder to store the image in
|
||||
responses:
|
||||
'200':
|
||||
description: Success
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
name:
|
||||
type: string
|
||||
description: Filename of the uploaded image
|
||||
subfolder:
|
||||
type: string
|
||||
description: Subfolder the image was stored in
|
||||
type:
|
||||
type: string
|
||||
description: Type of directory the image was stored in
|
||||
'400':
|
||||
description: Bad request
|
||||
|
||||
/upload/mask:
|
||||
post:
|
||||
tags:
|
||||
- image
|
||||
summary: Upload a mask for an image
|
||||
description: Uploads a mask image and applies it to a referenced original image
|
||||
operationId: uploadMask
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
multipart/form-data:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
image:
|
||||
type: string
|
||||
format: binary
|
||||
description: The mask image file to upload
|
||||
original_ref:
|
||||
type: string
|
||||
description: JSON string containing reference to the original image
|
||||
responses:
|
||||
'200':
|
||||
description: Success
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
name:
|
||||
type: string
|
||||
description: Filename of the uploaded mask
|
||||
subfolder:
|
||||
type: string
|
||||
description: Subfolder the mask was stored in
|
||||
type:
|
||||
type: string
|
||||
description: Type of directory the mask was stored in
|
||||
'400':
|
||||
description: Bad request
|
||||
|
||||
/view:
|
||||
get:
|
||||
tags:
|
||||
- image
|
||||
summary: View an image
|
||||
description: Retrieves an image from the server
|
||||
operationId: viewImage
|
||||
parameters:
|
||||
- name: filename
|
||||
in: query
|
||||
description: Name of the file to retrieve
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
- name: type
|
||||
in: query
|
||||
description: Type of directory to retrieve from
|
||||
required: false
|
||||
schema:
|
||||
type: string
|
||||
enum: [input, temp, output]
|
||||
default: output
|
||||
- name: subfolder
|
||||
in: query
|
||||
description: Subfolder to retrieve from
|
||||
required: false
|
||||
schema:
|
||||
type: string
|
||||
- name: preview
|
||||
in: query
|
||||
description: Preview options (format;quality)
|
||||
required: false
|
||||
schema:
|
||||
type: string
|
||||
- name: channel
|
||||
in: query
|
||||
description: Channel to retrieve (rgb, a, rgba)
|
||||
required: false
|
||||
schema:
|
||||
type: string
|
||||
enum: [rgb, a, rgba]
|
||||
default: rgba
|
||||
responses:
|
||||
'200':
|
||||
description: Success
|
||||
content:
|
||||
image/*:
|
||||
schema:
|
||||
type: string
|
||||
format: binary
|
||||
'400':
|
||||
description: Bad request
|
||||
'404':
|
||||
description: File not found
|
||||
|
||||
/view_metadata/{folder_name}:
|
||||
get:
|
||||
tags:
|
||||
- model
|
||||
summary: View model metadata
|
||||
description: Retrieves metadata from a safetensors file
|
||||
operationId: viewModelMetadata
|
||||
parameters:
|
||||
- name: folder_name
|
||||
in: path
|
||||
description: Name of the model folder
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
- name: filename
|
||||
in: query
|
||||
description: Name of the safetensors file
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
responses:
|
||||
'200':
|
||||
description: Success
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
'404':
|
||||
description: File not found
|
||||
|
||||
/models:
|
||||
get:
|
||||
tags:
|
||||
- model
|
||||
summary: Get model types
|
||||
description: Returns a list of available model types
|
||||
operationId: getModelTypes
|
||||
responses:
|
||||
'200':
|
||||
description: Success
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
|
||||
/models/{folder}:
|
||||
get:
|
||||
tags:
|
||||
- model
|
||||
summary: Get models of a specific type
|
||||
description: Returns a list of available models of a specific type
|
||||
operationId: getModels
|
||||
parameters:
|
||||
- name: folder
|
||||
in: path
|
||||
description: Model type folder
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
responses:
|
||||
'200':
|
||||
description: Success
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
'404':
|
||||
description: Folder not found
|
||||
|
||||
/embeddings:
|
||||
get:
|
||||
tags:
|
||||
- model
|
||||
summary: Get embeddings
|
||||
description: Returns a list of available embeddings
|
||||
operationId: getEmbeddings
|
||||
responses:
|
||||
'200':
|
||||
description: Success
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
|
||||
/extensions:
|
||||
get:
|
||||
tags:
|
||||
- system
|
||||
summary: Get extensions
|
||||
description: Returns a list of available extensions
|
||||
operationId: getExtensions
|
||||
responses:
|
||||
'200':
|
||||
description: Success
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
|
||||
/system_stats:
|
||||
get:
|
||||
tags:
|
||||
- system
|
||||
summary: Get system statistics
|
||||
description: Returns system information including RAM, VRAM, and ComfyUI version
|
||||
operationId: getSystemStats
|
||||
responses:
|
||||
'200':
|
||||
description: Success
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/SystemStats'
|
||||
|
||||
/ws:
|
||||
get:
|
||||
tags:
|
||||
- workflow
|
||||
summary: WebSocket connection
|
||||
description: |
|
||||
Establishes a WebSocket connection for real-time communication.
|
||||
This endpoint is used for receiving progress updates, status changes, and results from workflow executions.
|
||||
operationId: webSocketConnect
|
||||
parameters:
|
||||
- name: clientId
|
||||
in: query
|
||||
description: Optional client ID for reconnection
|
||||
required: false
|
||||
schema:
|
||||
type: string
|
||||
responses:
|
||||
'101':
|
||||
description: Switching Protocols to WebSocket
|

  /internal/logs:
    get:
      tags:
        - internal
      summary: Get logs
      description: Returns system logs as a single string
      operationId: getLogs
      responses:
        '200':
          description: Success
          content:
            application/json:
              schema:
                type: string

  /internal/logs/raw:
    get:
      tags:
        - internal
      summary: Get raw logs
      description: Returns raw system logs with terminal size information
      operationId: getRawLogs
      responses:
        '200':
          description: Success
          content:
            application/json:
              schema:
                type: object
                properties:
                  entries:
                    type: array
                    items:
                      type: object
                      properties:
                        t:
                          type: string
                          description: Timestamp
                        m:
                          type: string
                          description: Message
                  size:
                    type: object
                    properties:
                      cols:
                        type: integer
                        description: Terminal columns
                      rows:
                        type: integer
                        description: Terminal rows

  /internal/logs/subscribe:
    patch:
      tags:
        - internal
      summary: Subscribe to logs
      description: Subscribe or unsubscribe to log updates
      operationId: subscribeToLogs
      requestBody:
        required: true
        content:
          application/json:
            schema:
              type: object
              properties:
                clientId:
                  type: string
                  description: Client ID
                enabled:
                  type: boolean
                  description: Whether to enable or disable subscription
      responses:
        '200':
          description: Success
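      # An illustrative client call (assumes the default local server; field
      # names follow the request schema above):
      #   import requests
      #   requests.patch("http://127.0.0.1:8188/internal/logs/subscribe",
      #                  json={"clientId": "example", "enabled": True})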

  /internal/folder_paths:
    get:
      tags:
        - internal
      summary: Get folder paths
      description: Returns a map of folder names to their paths
      operationId: getFolderPaths
      responses:
        '200':
          description: Success
          content:
            application/json:
              schema:
                type: object
                additionalProperties:
                  type: string

  /internal/files/{directory_type}:
    get:
      tags:
        - internal
      summary: Get files
      description: Returns a list of files in a specific directory type
      operationId: getFiles
      parameters:
        - name: directory_type
          in: path
          description: Type of directory (output, input, temp)
          required: true
          schema:
            type: string
            enum: [output, input, temp]
      responses:
        '200':
          description: Success
          content:
            application/json:
              schema:
                type: array
                items:
                  type: string
        '400':
          description: Invalid directory type

components:
  schemas:
    PromptRequest:
      type: object
      required:
        - prompt
      properties:
        prompt:
          type: object
          description: The workflow graph to execute
          additionalProperties: true
        number:
          type: number
          description: Priority number for the queue (lower numbers have higher priority)
        front:
          type: boolean
          description: If true, adds the prompt to the front of the queue
        extra_data:
          type: object
          description: Extra data to be associated with the prompt
          additionalProperties: true
        client_id:
          type: string
          description: Client ID for attribution of the prompt
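    # An illustrative POST /prompt body (the graph under "prompt" is
    # application-defined; only the top-level keys follow this schema):
    #   {"prompt": {...}, "client_id": "example", "front": false}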

    PromptResponse:
      type: object
      properties:
        prompt_id:
          type: string
          format: uuid
          description: Unique identifier for the prompt execution
        number:
          type: number
          description: Priority number in the queue
        node_errors:
          type: object
          description: Any errors in the nodes of the prompt
          additionalProperties: true

    ErrorResponse:
      type: object
      properties:
        error:
          type: object
          properties:
            type:
              type: string
              description: Error type
            message:
              type: string
              description: Error message
            details:
              type: string
              description: Detailed error information
            extra_info:
              type: object
              description: Additional error information
              additionalProperties: true
        node_errors:
          type: object
          description: Node-specific errors
          additionalProperties: true

    PromptInfo:
      type: object
      properties:
        exec_info:
          type: object
          properties:
            queue_remaining:
              type: integer
              description: Number of items remaining in the queue

    QueueInfo:
      type: object
      properties:
        queue_running:
          type: array
          items:
            type: object
            description: Currently running items
            additionalProperties: true
        queue_pending:
          type: array
          items:
            type: object
            description: Pending items in the queue
            additionalProperties: true

    HistoryItem:
      type: object
      properties:
        prompt_id:
          type: string
          format: uuid
          description: Unique identifier for the prompt
        prompt:
          type: object
          description: The workflow graph that was executed
          additionalProperties: true
        extra_data:
          type: object
          description: Additional data associated with the execution
          additionalProperties: true
        outputs:
          type: object
          description: Output data from the execution
          additionalProperties: true

    NodeInfo:
      type: object
      properties:
        input:
          type: object
          description: Input specifications for the node
          additionalProperties: true
        input_order:
          type: object
          description: Order of inputs for display
          additionalProperties:
            type: array
            items:
              type: string
        output:
          type: array
          items:
            type: string
          description: Output types of the node
        output_is_list:
          type: array
          items:
            type: boolean
          description: Whether each output is a list
        output_name:
          type: array
          items:
            type: string
          description: Names of the outputs
        name:
          type: string
          description: Internal name of the node
        display_name:
          type: string
          description: Display name of the node
        description:
          type: string
          description: Description of the node
        python_module:
          type: string
          description: Python module implementing the node
        category:
          type: string
          description: Category of the node
        output_node:
          type: boolean
          description: Whether this is an output node
        output_tooltips:
          type: array
          items:
            type: string
          description: Tooltips for outputs
        deprecated:
          type: boolean
          description: Whether the node is deprecated
        experimental:
          type: boolean
          description: Whether the node is experimental
        api_node:
          type: boolean
          description: Whether this is an API node

    SystemStats:
      type: object
      properties:
        system:
          type: object
          properties:
            os:
              type: string
              description: Operating system
            ram_total:
              type: number
              description: Total system RAM in bytes
            ram_free:
              type: number
              description: Free system RAM in bytes
            comfyui_version:
              type: string
              description: ComfyUI version
            python_version:
              type: string
              description: Python version
            pytorch_version:
              type: string
              description: PyTorch version
            embedded_python:
              type: boolean
              description: Whether using embedded Python
            argv:
              type: array
              items:
                type: string
              description: Command line arguments
        devices:
          type: array
          items:
            type: object
            properties:
              name:
                type: string
                description: Device name
              type:
                type: string
                description: Device type
              index:
                type: integer
                description: Device index
              vram_total:
                type: number
                description: Total VRAM in bytes
              vram_free:
                type: number
                description: Free VRAM in bytes
              torch_vram_total:
                type: number
                description: Total VRAM as reported by PyTorch
              torch_vram_free:
                type: number
                description: Free VRAM as reported by PyTorch
pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "ComfyUI"
-version = "0.3.39"
+version = "0.3.34"
 readme = "README.md"
 license = { file = "LICENSE" }
 requires-python = ">=3.9"
requirements.txt
@@ -1,6 +1,5 @@
-comfyui-frontend-package==1.21.3
-comfyui-workflow-templates==0.1.23
-comfyui-embedded-docs==0.2.0
+comfyui-frontend-package==1.19.9
+comfyui-workflow-templates==0.1.14
 torch
 torchsde
 torchvision
18 server.py
@@ -29,8 +29,6 @@ import comfy.model_management
 import node_helpers
 from comfyui_version import __version__
 from app.frontend_management import FrontendManager
-from comfy_api.v3.io import ComfyNodeV3
-
 from app.user_manager import UserManager
 from app.model_manager import ModelFileManager
 from app.custom_node_manager import CustomNodeManager
@@ -161,7 +159,7 @@ class PromptServer():
         self.custom_node_manager = CustomNodeManager()
         self.internal_routes = InternalRoutes(self)
         self.supports = ["custom_nodes_from_web"]
-        self.prompt_queue = execution.PromptQueue(self)
+        self.prompt_queue = None
         self.loop = loop
         self.messages = asyncio.Queue()
         self.client_session:Optional[aiohttp.ClientSession] = None
@@ -228,7 +226,7 @@ class PromptServer():
             return response

         @routes.get("/embeddings")
-        def get_embeddings(request):
+        def get_embeddings(self):
             embeddings = folder_paths.get_filename_list("embeddings")
             return web.json_response(list(map(lambda a: os.path.splitext(a)[0], embeddings)))

@@ -284,6 +282,7 @@ class PromptServer():
                     a.update(f.read())
                     b.update(image.file.read())
                     image.file.seek(0)
+                    f.close()
                 return a.hexdigest() == b.hexdigest()
             return False

@@ -556,8 +555,6 @@ class PromptServer():

         def node_info(node_class):
             obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
-            if isinstance(obj_class, ComfyNodeV3):
-                return obj_class.GET_NODE_INFO_V1()
             info = {}
             info['input'] = obj_class.INPUT_TYPES()
             info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()}
@@ -624,7 +621,7 @@ class PromptServer():
         @routes.get("/queue")
         async def get_queue(request):
             queue_info = {}
-            current_queue = self.prompt_queue.get_current_queue_volatile()
+            current_queue = self.prompt_queue.get_current_queue()
             queue_info['queue_running'] = current_queue[0]
             queue_info['queue_pending'] = current_queue[1]
             return web.json_response(queue_info)
@@ -749,13 +746,6 @@ class PromptServer():
             web.static('/templates', workflow_templates_path)
         ])

-        # Serve embedded documentation from the package
-        embedded_docs_path = FrontendManager.embedded_docs_path()
-        if embedded_docs_path:
-            self.app.add_routes([
-                web.static('/docs', embedded_docs_path)
-            ])
-
         self.app.add_routes([
             web.static('/', self.web_root),
         ])
74 tests-api/README.md Normal file
@@ -0,0 +1,74 @@
# ComfyUI API Testing

This directory contains tests for validating the ComfyUI OpenAPI specification against a running instance of ComfyUI.

## Setup

1. Install the required dependencies:

```bash
pip install -r requirements.txt
```

2. Make sure you have a running instance of ComfyUI (default: http://127.0.0.1:8188)

## Running the Tests

Run all tests with pytest:

```bash
cd tests-api
pytest
```

Run specific test files:

```bash
pytest test_spec_validation.py
pytest test_endpoint_existence.py
pytest test_schema_validation.py
pytest test_api_by_tag.py
```

Run tests with more verbose output:

```bash
pytest -v
```

## Test Categories

The tests are organized into several categories:

1. **Spec Validation**: Validates that the OpenAPI specification is valid.
2. **Endpoint Existence**: Tests that the endpoints defined in the spec exist on the server.
3. **Schema Validation**: Tests that the server responses match the schemas defined in the spec.
4. **Tag-Based Tests**: Tests that the API's tag organization is consistent.

## Using a Different Server

By default, the tests connect to `http://127.0.0.1:8188`. To test against a different server, set the `COMFYUI_SERVER_URL` environment variable:

```bash
COMFYUI_SERVER_URL=http://example.com:8188 pytest
```

## Test Structure

- `conftest.py`: Contains pytest fixtures used by the tests.
- `utils/`: Contains utility functions for working with the OpenAPI spec.
- `test_*.py`: The actual test files.
- `resources/`: Contains resources used by the tests (e.g., sample workflows).

## Extending the Tests

To add new tests:

1. For testing new endpoints, add them to the appropriate test file based on their category.
2. For testing more complex functionality, create a new test file following the established patterns, as in the sketch below.
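
A minimal sketch of a new test module (illustrative only; the fixtures come from `conftest.py`, and `/history` stands in for whatever endpoint you are covering):

```python
"""
Example: tests for an additional endpoint
"""
from typing import Dict, Any


def test_history_endpoint_defined(api_spec: Dict[str, Any]):
    """The spec should define the endpoint before we probe the server"""
    assert '/history' in api_spec['paths']


def test_history_endpoint_responds(require_server, api_client):
    """Skipped automatically when no server is running"""
    response = api_client.get(api_client.get_url("/history"))  # type: ignore
    assert response.status_code != 404
```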

## Notes

- Tests that require a running server will be skipped if the server is not available.
- Some tests may fail if the server doesn't match the specification exactly.
- The tests don't modify any data on the server (they're read-only).
141 tests-api/conftest.py Normal file
@@ -0,0 +1,141 @@
"""
Test fixtures for API testing
"""
import os
import pytest
import yaml
import requests
import logging
from typing import Dict, Any, Generator, Optional
from urllib.parse import urljoin

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Default server configuration
DEFAULT_SERVER_URL = "http://127.0.0.1:8188"


@pytest.fixture(scope="session")
def api_spec_path() -> str:
    """
    Get the path to the OpenAPI specification file

    Returns:
        Path to the OpenAPI specification file
    """
    return os.path.abspath(os.path.join(
        os.path.dirname(__file__),
        "..",
        "openapi.yaml"
    ))


@pytest.fixture(scope="session")
def api_spec(api_spec_path: str) -> Dict[str, Any]:
    """
    Load the OpenAPI specification

    Args:
        api_spec_path: Path to the spec file

    Returns:
        Parsed OpenAPI specification
    """
    with open(api_spec_path, 'r') as f:
        return yaml.safe_load(f)


@pytest.fixture(scope="session")
def base_url() -> str:
    """
    Get the base URL for the API server

    Returns:
        Base URL string
    """
    # Allow overriding via environment variable
    return os.environ.get("COMFYUI_SERVER_URL", DEFAULT_SERVER_URL)


@pytest.fixture(scope="session")
def server_available(base_url: str) -> bool:
    """
    Check if the server is available

    Args:
        base_url: Base URL for the API

    Returns:
        True if the server is available, False otherwise
    """
    try:
        response = requests.get(base_url, timeout=2)
        return response.status_code == 200
    except requests.RequestException:
        logger.warning(f"Server at {base_url} is not available")
        return False


@pytest.fixture
def api_client(base_url: str) -> Generator[Optional[requests.Session], None, None]:
    """
    Create a requests session for API testing

    Args:
        base_url: Base URL for the API

    Yields:
        Requests session configured for the API
    """
    session = requests.Session()

    # Helper function to construct URLs
    def get_url(path: str) -> str:
        return urljoin(base_url, path)

    # Add url helper to the session
    session.get_url = get_url  # type: ignore

    yield session

    # Cleanup
    session.close()


@pytest.fixture
def api_get_json(api_client: requests.Session):
    """
    Helper fixture for making GET requests and parsing JSON responses

    Args:
        api_client: API client session

    Returns:
        Function that makes GET requests and returns JSON
    """
    def _get_json(path: str, **kwargs):
        url = api_client.get_url(path)  # type: ignore
        response = api_client.get(url, **kwargs)

        if response.status_code == 200:
            try:
                return response.json()
            except ValueError:
                return None
        return None

    return _get_json
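
# Illustrative use of the fixtures above in a test module (not an actual test
# in this suite):
#
#   def test_stats_shape(require_server, api_get_json):
#       stats = api_get_json("/system_stats")
#       assert stats is None or "system" in stats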


@pytest.fixture
def require_server(server_available):
    """
    Skip tests if server is not available

    Args:
        server_available: Whether the server is available
    """
    if not server_available:
        pytest.skip("Server is not available")
6 tests-api/requirements.txt Normal file
@@ -0,0 +1,6 @@
pytest>=7.0.0
pytest-asyncio>=0.21.0
openapi-spec-validator>=0.5.0
jsonschema>=4.17.0
requests>=2.28.0
pyyaml>=6.0.0
279 tests-api/test_api_by_tag.py Normal file
@@ -0,0 +1,279 @@
"""
Tests for API endpoints grouped by tags
"""
import pytest
import logging
import sys
import os
from typing import Dict, Any, Set

# Use a direct import with the full path
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, current_dir)

# Define functions inline to avoid import issues
def get_all_endpoints(spec):
    """
    Extract all endpoints from an OpenAPI spec
    """
    endpoints = []

    for path, path_item in spec['paths'].items():
        for method, operation in path_item.items():
            if method.lower() not in ['get', 'post', 'put', 'delete', 'patch']:
                continue

            endpoints.append({
                'path': path,
                'method': method.lower(),
                'tags': operation.get('tags', []),
                'operation_id': operation.get('operationId', ''),
                'summary': operation.get('summary', '')
            })

    return endpoints

def get_all_tags(spec):
    """
    Get all tags used in the API spec
    """
    tags = set()

    for path_item in spec['paths'].values():
        for operation in path_item.values():
            if isinstance(operation, dict) and 'tags' in operation:
                tags.update(operation['tags'])

    return tags

def extract_endpoints_by_tag(spec, tag):
    """
    Extract all endpoints with a specific tag
    """
    endpoints = []

    for path, path_item in spec['paths'].items():
        for method, operation in path_item.items():
            if method.lower() not in ['get', 'post', 'put', 'delete', 'patch']:
                continue

            if tag in operation.get('tags', []):
                endpoints.append({
                    'path': path,
                    'method': method.lower(),
                    'operation_id': operation.get('operationId', ''),
                    'summary': operation.get('summary', '')
                })

    return endpoints

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@pytest.fixture
def api_tags(api_spec: Dict[str, Any]) -> Set[str]:
    """
    Get all tags from the API spec

    Args:
        api_spec: Loaded OpenAPI spec

    Returns:
        Set of tag names
    """
    return get_all_tags(api_spec)


def test_api_has_tags(api_tags: Set[str]):
    """
    Test that the API has defined tags

    Args:
        api_tags: Set of tags
    """
    assert len(api_tags) > 0, "API spec should have at least one tag"

    # Log the tags
    logger.info(f"API spec has the following tags: {sorted(api_tags)}")


@pytest.mark.parametrize("tag", [
    "workflow",
    "image",
    "model",
    "node",
    "system"
])
def test_core_tags_exist(api_tags: Set[str], tag: str):
    """
    Test that core tags exist in the API spec

    Args:
        api_tags: Set of tags
        tag: Tag to check
    """
    assert tag in api_tags, f"API spec should have '{tag}' tag"


def test_workflow_tag_has_endpoints(api_spec: Dict[str, Any]):
    """
    Test that the 'workflow' tag has appropriate endpoints

    Args:
        api_spec: Loaded OpenAPI spec
    """
    endpoints = extract_endpoints_by_tag(api_spec, "workflow")

    assert len(endpoints) > 0, "No endpoints found with 'workflow' tag"

    # Check for key workflow endpoints
    endpoint_paths = [e["path"] for e in endpoints]
    assert "/prompt" in endpoint_paths, "Workflow tag should include /prompt endpoint"

    # Log the endpoints
    logger.info(f"Found {len(endpoints)} endpoints with 'workflow' tag:")
    for e in endpoints:
        logger.info(f"  {e['method'].upper()} {e['path']}")


def test_image_tag_has_endpoints(api_spec: Dict[str, Any]):
    """
    Test that the 'image' tag has appropriate endpoints

    Args:
        api_spec: Loaded OpenAPI spec
    """
    endpoints = extract_endpoints_by_tag(api_spec, "image")

    assert len(endpoints) > 0, "No endpoints found with 'image' tag"

    # Check for key image endpoints
    endpoint_paths = [e["path"] for e in endpoints]
    assert "/upload/image" in endpoint_paths, "Image tag should include /upload/image endpoint"
    assert "/view" in endpoint_paths, "Image tag should include /view endpoint"

    # Log the endpoints
    logger.info(f"Found {len(endpoints)} endpoints with 'image' tag:")
    for e in endpoints:
        logger.info(f"  {e['method'].upper()} {e['path']}")


def test_model_tag_has_endpoints(api_spec: Dict[str, Any]):
    """
    Test that the 'model' tag has appropriate endpoints

    Args:
        api_spec: Loaded OpenAPI spec
    """
    endpoints = extract_endpoints_by_tag(api_spec, "model")

    assert len(endpoints) > 0, "No endpoints found with 'model' tag"

    # Check for key model endpoints
    endpoint_paths = [e["path"] for e in endpoints]
    assert "/models" in endpoint_paths, "Model tag should include /models endpoint"

    # Log the endpoints
    logger.info(f"Found {len(endpoints)} endpoints with 'model' tag:")
    for e in endpoints:
        logger.info(f"  {e['method'].upper()} {e['path']}")


def test_node_tag_has_endpoints(api_spec: Dict[str, Any]):
    """
    Test that the 'node' tag has appropriate endpoints

    Args:
        api_spec: Loaded OpenAPI spec
    """
    endpoints = extract_endpoints_by_tag(api_spec, "node")

    assert len(endpoints) > 0, "No endpoints found with 'node' tag"

    # Check for key node endpoints
    endpoint_paths = [e["path"] for e in endpoints]
    assert "/object_info" in endpoint_paths, "Node tag should include /object_info endpoint"

    # Log the endpoints
    logger.info(f"Found {len(endpoints)} endpoints with 'node' tag:")
    for e in endpoints:
        logger.info(f"  {e['method'].upper()} {e['path']}")


def test_system_tag_has_endpoints(api_spec: Dict[str, Any]):
    """
    Test that the 'system' tag has appropriate endpoints

    Args:
        api_spec: Loaded OpenAPI spec
    """
    endpoints = extract_endpoints_by_tag(api_spec, "system")

    assert len(endpoints) > 0, "No endpoints found with 'system' tag"

    # Check for key system endpoints
    endpoint_paths = [e["path"] for e in endpoints]
    assert "/system_stats" in endpoint_paths, "System tag should include /system_stats endpoint"

    # Log the endpoints
    logger.info(f"Found {len(endpoints)} endpoints with 'system' tag:")
    for e in endpoints:
        logger.info(f"  {e['method'].upper()} {e['path']}")


def test_internal_tag_has_endpoints(api_spec: Dict[str, Any]):
    """
    Test that the 'internal' tag has appropriate endpoints

    Args:
        api_spec: Loaded OpenAPI spec
    """
    endpoints = extract_endpoints_by_tag(api_spec, "internal")

    assert len(endpoints) > 0, "No endpoints found with 'internal' tag"

    # Check for key internal endpoints
    endpoint_paths = [e["path"] for e in endpoints]
    assert "/internal/logs" in endpoint_paths, "Internal tag should include /internal/logs endpoint"

    # Log the endpoints
    logger.info(f"Found {len(endpoints)} endpoints with 'internal' tag:")
    for e in endpoints:
        logger.info(f"  {e['method'].upper()} {e['path']}")


def test_operation_ids_match_tag(api_spec: Dict[str, Any]):
    """
    Test that operation IDs follow a consistent pattern with their tag

    Args:
        api_spec: Loaded OpenAPI spec
    """
    failures = []

    for path, path_item in api_spec['paths'].items():
        for method, operation in path_item.items():
            if method in ['get', 'post', 'put', 'delete', 'patch']:
                if 'operationId' in operation and 'tags' in operation and operation['tags']:
                    op_id = operation['operationId']
                    primary_tag = operation['tags'][0].lower()

                    # Check if operationId starts with primary tag prefix
                    # This is a common convention, but might need adjusting
                    if not (op_id.startswith(primary_tag) or
                            any(op_id.lower().startswith(f"{tag.lower()}") for tag in operation['tags'])):
                        failures.append({
                            'path': path,
                            'method': method,
                            'operationId': op_id,
                            'primary_tag': primary_tag
                        })

    # Log failures for diagnosis but don't fail the test
    # as this is a style/convention check
    if failures:
        logger.warning(f"Found {len(failures)} operationIds that don't align with their tags:")
        for f in failures:
            logger.warning(f"  {f['method'].upper()} {f['path']} - operationId: {f['operationId']}, primary tag: {f['primary_tag']}")
240 tests-api/test_endpoint_existence.py Normal file
@@ -0,0 +1,240 @@
"""
Tests for endpoint existence and basic response codes
"""
import pytest
import requests
import logging
import sys
import os
from typing import Dict, Any, List
from urllib.parse import urljoin

# Use a direct import with the full path
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, current_dir)

# Define get_all_endpoints function inline to avoid import issues
def get_all_endpoints(spec):
    """
    Extract all endpoints from an OpenAPI spec

    Args:
        spec: Parsed OpenAPI specification

    Returns:
        List of dicts with path, method, and tags for each endpoint
    """
    endpoints = []

    for path, path_item in spec['paths'].items():
        for method, operation in path_item.items():
            if method.lower() not in ['get', 'post', 'put', 'delete', 'patch']:
                continue

            endpoints.append({
                'path': path,
                'method': method.lower(),
                'tags': operation.get('tags', []),
                'operation_id': operation.get('operationId', ''),
                'summary': operation.get('summary', '')
            })

    return endpoints

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@pytest.fixture
def all_endpoints(api_spec: Dict[str, Any]) -> List[Dict[str, Any]]:
    """
    Get all endpoints from the API spec

    Args:
        api_spec: Loaded OpenAPI spec

    Returns:
        List of endpoint information
    """
    return get_all_endpoints(api_spec)


def test_endpoints_exist(all_endpoints: List[Dict[str, Any]]):
    """
    Test that endpoints are defined in the spec

    Args:
        all_endpoints: List of endpoint information
    """
    # Simple check that we have endpoints defined
    assert len(all_endpoints) > 0, "No endpoints defined in the OpenAPI spec"

    # Log the endpoints for informational purposes
    logger.info(f"Found {len(all_endpoints)} endpoints in the OpenAPI spec")
    for endpoint in all_endpoints:
        logger.info(f"{endpoint['method'].upper()} {endpoint['path']} - {endpoint['summary']}")


@pytest.mark.parametrize("endpoint_path", [
    "/",  # Root path
    "/prompt",  # Get prompt info
    "/queue",  # Get queue
    "/models",  # Get model types
    "/object_info",  # Get node info
    "/system_stats"  # Get system stats
])
def test_basic_get_endpoints(require_server, api_client, endpoint_path: str):
    """
    Test that basic GET endpoints exist and respond

    Args:
        require_server: Fixture that skips if server is not available
        api_client: API client fixture
        endpoint_path: Path to test
    """
    url = api_client.get_url(endpoint_path)  # type: ignore

    try:
        response = api_client.get(url)

        # We're just checking that the endpoint exists and returns some kind of response
        # Not necessarily a 200 status code
        assert response.status_code not in [404, 405], f"Endpoint {endpoint_path} does not exist"

        logger.info(f"Endpoint {endpoint_path} exists with status code {response.status_code}")

    except requests.RequestException as e:
        pytest.fail(f"Request to {endpoint_path} failed: {str(e)}")


def test_websocket_endpoint_exists(require_server, base_url: str):
    """
    Test that the WebSocket endpoint exists

    Args:
        require_server: Fixture that skips if server is not available
        base_url: Base server URL
    """
    ws_url = urljoin(base_url, "/ws")

    # For WebSocket, we can't use a normal GET request
    # Instead, we make a HEAD request to check if the endpoint exists
    try:
        response = requests.head(ws_url)

        # WebSocket endpoints often return a 400 Bad Request for HEAD requests
        # but a 404 would indicate the endpoint doesn't exist
        assert response.status_code != 404, "WebSocket endpoint /ws does not exist"

        logger.info(f"WebSocket endpoint exists with status code {response.status_code}")

    except requests.RequestException as e:
        pytest.fail(f"Request to WebSocket endpoint failed: {str(e)}")


def test_api_models_folder_endpoint(require_server, api_client):
    """
    Test that the /models/{folder} endpoint exists and responds

    Args:
        require_server: Fixture that skips if server is not available
        api_client: API client fixture
    """
    # First get available model types
    models_url = api_client.get_url("/models")  # type: ignore

    try:
        models_response = api_client.get(models_url)
        assert models_response.status_code == 200, "Failed to get model types"

        model_types = models_response.json()

        # Skip if no model types available
        if not model_types:
            pytest.skip("No model types available to test")

        # Test with the first model type
        model_type = model_types[0]
        models_folder_url = api_client.get_url(f"/models/{model_type}")  # type: ignore

        folder_response = api_client.get(models_folder_url)

        # We're just checking that the endpoint exists
        assert folder_response.status_code != 404, f"Endpoint /models/{model_type} does not exist"

        logger.info(f"Endpoint /models/{model_type} exists with status code {folder_response.status_code}")

    except requests.RequestException as e:
        pytest.fail(f"Request failed: {str(e)}")
    except (ValueError, KeyError, IndexError) as e:
        pytest.fail(f"Failed to process response: {str(e)}")


def test_api_object_info_node_endpoint(require_server, api_client):
    """
    Test that the /object_info/{node_class} endpoint exists and responds

    Args:
        require_server: Fixture that skips if server is not available
        api_client: API client fixture
    """
    # First get available node classes
    objects_url = api_client.get_url("/object_info")  # type: ignore

    try:
        objects_response = api_client.get(objects_url)
        assert objects_response.status_code == 200, "Failed to get object info"

        node_classes = objects_response.json()

        # Skip if no node classes available
        if not node_classes:
            pytest.skip("No node classes available to test")

        # Test with the first node class
        node_class = next(iter(node_classes.keys()))
        node_url = api_client.get_url(f"/object_info/{node_class}")  # type: ignore

        node_response = api_client.get(node_url)

        # We're just checking that the endpoint exists
        assert node_response.status_code != 404, f"Endpoint /object_info/{node_class} does not exist"

        logger.info(f"Endpoint /object_info/{node_class} exists with status code {node_response.status_code}")

    except requests.RequestException as e:
        pytest.fail(f"Request failed: {str(e)}")
    except (ValueError, KeyError, StopIteration) as e:
        pytest.fail(f"Failed to process response: {str(e)}")


def test_internal_endpoints_exist(require_server, api_client):
    """
    Test that internal endpoints exist

    Args:
        require_server: Fixture that skips if server is not available
        api_client: API client fixture
    """
    internal_endpoints = [
        "/internal/logs",
        "/internal/logs/raw",
        "/internal/folder_paths",
        "/internal/files/output"
    ]

    for endpoint in internal_endpoints:
        url = api_client.get_url(endpoint)  # type: ignore

        try:
            response = api_client.get(url)

            # We're just checking that the endpoint exists
            assert response.status_code != 404, f"Endpoint {endpoint} does not exist"

            logger.info(f"Endpoint {endpoint} exists with status code {response.status_code}")

        except requests.RequestException as e:
            logger.warning(f"Request to {endpoint} failed: {str(e)}")
            # Don't fail the test as internal endpoints might be restricted
440 tests-api/test_schema_validation.py Normal file
@@ -0,0 +1,440 @@
"""
Tests for validating API responses against OpenAPI schema
"""
import pytest
import requests
import logging
import sys
import os
import json
from typing import Dict, Any

# Use a direct import with the full path
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, current_dir)

# Define validation functions inline to avoid import issues
def get_endpoint_schema(
    spec,
    path,
    method,
    status_code='200'
):
    """
    Extract response schema for a specific endpoint from OpenAPI spec
    """
    method = method.lower()

    # Handle path not found
    if path not in spec['paths']:
        return None

    # Handle method not found
    if method not in spec['paths'][path]:
        return None

    # Handle status code not found
    responses = spec['paths'][path][method].get('responses', {})
    if status_code not in responses:
        return None

    # Handle no content defined
    if 'content' not in responses[status_code]:
        return None

    # Get schema from first content type
    content_types = responses[status_code]['content']
    first_content_type = next(iter(content_types))

    if 'schema' not in content_types[first_content_type]:
        return None

    return content_types[first_content_type]['schema']

def resolve_schema_refs(schema, spec):
    """
    Resolve $ref references in a schema
    """
    if not isinstance(schema, dict):
        return schema

    result = {}

    for key, value in schema.items():
        if key == '$ref' and isinstance(value, str) and value.startswith('#/'):
            # Handle reference
            ref_path = value[2:].split('/')
            ref_value = spec
            for path_part in ref_path:
                ref_value = ref_value.get(path_part, {})

            # Recursively resolve any refs in the referenced schema
            ref_value = resolve_schema_refs(ref_value, spec)
            result.update(ref_value)
        elif isinstance(value, dict):
            # Recursively resolve refs in nested dictionaries
            result[key] = resolve_schema_refs(value, spec)
        elif isinstance(value, list):
            # Recursively resolve refs in list items
            result[key] = [
                resolve_schema_refs(item, spec) if isinstance(item, dict) else item
                for item in value
            ]
        else:
            # Pass through other values
            result[key] = value

    return result
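
# Illustrative behaviour (not executed): given
#   spec = {'components': {'schemas': {'Stats': {'type': 'object'}}}}
# resolving {'$ref': '#/components/schemas/Stats'} against spec yields
# {'type': 'object'}, with any nested $ref entries expanded recursively.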

def validate_response(
    response_data,
    spec,
    path,
    method,
    status_code='200'
):
    """
    Validate a response against the OpenAPI schema
    """
    schema = get_endpoint_schema(spec, path, method, status_code)

    if schema is None:
        return {
            'valid': False,
            'errors': [f"No schema found for {method.upper()} {path} with status {status_code}"]
        }

    # Resolve any $ref in the schema
    resolved_schema = resolve_schema_refs(schema, spec)

    try:
        import jsonschema
        jsonschema.validate(instance=response_data, schema=resolved_schema)
        return {'valid': True, 'errors': []}
    except jsonschema.exceptions.ValidationError as e:
        # Extract more detailed error information
        path = ".".join(str(p) for p in e.path) if e.path else "root"
        instance = e.instance if not isinstance(e.instance, dict) else "..."
        schema_path = ".".join(str(p) for p in e.schema_path) if e.schema_path else "unknown"

        detailed_error = (
            f"Validation error at path: {path}\n"
            f"Schema path: {schema_path}\n"
            f"Error message: {e.message}\n"
            f"Failed instance: {instance}\n"
        )

        return {'valid': False, 'errors': [detailed_error]}

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@pytest.mark.parametrize("endpoint_path,method", [
    ("/system_stats", "get"),
    ("/prompt", "get"),
    ("/queue", "get"),
    ("/models", "get"),
    ("/embeddings", "get")
])
def test_response_schema_validation(
    require_server,
    api_client,
    api_spec: Dict[str, Any],
    endpoint_path: str,
    method: str
):
    """
    Test that API responses match the defined schema

    Args:
        require_server: Fixture that skips if server is not available
        api_client: API client fixture
        api_spec: Loaded OpenAPI spec
        endpoint_path: Path to test
        method: HTTP method to test
    """
    url = api_client.get_url(endpoint_path)  # type: ignore

    # Skip if no schema defined
    schema = get_endpoint_schema(api_spec, endpoint_path, method)
    if not schema:
        pytest.skip(f"No schema defined for {method.upper()} {endpoint_path}")

    try:
        if method.lower() == "get":
            response = api_client.get(url)
        else:
            pytest.skip(f"Method {method} not implemented for automated testing")
            return

        # Skip if response is not 200
        if response.status_code != 200:
            pytest.skip(f"Endpoint {endpoint_path} returned status {response.status_code}")
            return

        # Skip if response is not JSON
        try:
            response_data = response.json()
        except ValueError:
            pytest.skip(f"Endpoint {endpoint_path} did not return valid JSON")
            return

        # Validate the response
        validation_result = validate_response(
            response_data,
            api_spec,
            endpoint_path,
            method
        )

        if validation_result['valid']:
            logger.info(f"Response from {method.upper()} {endpoint_path} matches schema")
        else:
            for error in validation_result['errors']:
                logger.error(f"Validation error for {method.upper()} {endpoint_path}: {error}")

        assert validation_result['valid'], f"Response from {method.upper()} {endpoint_path} does not match schema"

    except requests.RequestException as e:
        pytest.fail(f"Request to {endpoint_path} failed: {str(e)}")


def test_system_stats_response(require_server, api_client, api_spec: Dict[str, Any]):
    """
    Test the system_stats endpoint response in detail

    Args:
        require_server: Fixture that skips if server is not available
        api_client: API client fixture
        api_spec: Loaded OpenAPI spec
    """
    url = api_client.get_url("/system_stats")  # type: ignore

    try:
        response = api_client.get(url)

        assert response.status_code == 200, "Failed to get system stats"

        # Parse response
        stats = response.json()

        # Validate high-level structure
        assert 'system' in stats, "Response missing 'system' field"
        assert 'devices' in stats, "Response missing 'devices' field"

        # Validate system fields
        system = stats['system']
        assert 'os' in system, "System missing 'os' field"
        assert 'ram_total' in system, "System missing 'ram_total' field"
        assert 'ram_free' in system, "System missing 'ram_free' field"
        assert 'comfyui_version' in system, "System missing 'comfyui_version' field"

        # Validate devices fields
        devices = stats['devices']
        assert isinstance(devices, list), "Devices should be a list"

        if devices:
            device = devices[0]
            assert 'name' in device, "Device missing 'name' field"
            assert 'type' in device, "Device missing 'type' field"
            assert 'vram_total' in device, "Device missing 'vram_total' field"
            assert 'vram_free' in device, "Device missing 'vram_free' field"

        # Perform schema validation
        validation_result = validate_response(
            stats,
            api_spec,
            "/system_stats",
            "get"
        )

        # Print detailed error if validation fails
        if not validation_result['valid']:
            for error in validation_result['errors']:
                logger.error(f"Validation error for /system_stats: {error}")

            # Print schema details for debugging
            schema = get_endpoint_schema(api_spec, "/system_stats", "get")
            if schema:
                logger.error(f"Schema structure:\n{json.dumps(schema, indent=2)}")

            # Print sample of the response
            logger.error(f"Response:\n{json.dumps(stats, indent=2)}")

        assert validation_result['valid'], "System stats response does not match schema"

    except requests.RequestException as e:
        pytest.fail(f"Request to /system_stats failed: {str(e)}")


def test_models_listing_response(require_server, api_client, api_spec: Dict[str, Any]):
    """
    Test the models endpoint response

    Args:
        require_server: Fixture that skips if server is not available
        api_client: API client fixture
        api_spec: Loaded OpenAPI spec
    """
    url = api_client.get_url("/models")  # type: ignore

    try:
        response = api_client.get(url)

        assert response.status_code == 200, "Failed to get models"

        # Parse response
        models = response.json()

        # Validate it's a list
        assert isinstance(models, list), "Models response should be a list"

        # Each item should be a string
        for model in models:
            assert isinstance(model, str), "Each model type should be a string"

        # Perform schema validation
        validation_result = validate_response(
            models,
            api_spec,
            "/models",
            "get"
        )

        # Print detailed error if validation fails
        if not validation_result['valid']:
            for error in validation_result['errors']:
                logger.error(f"Validation error for /models: {error}")

            # Print schema details for debugging
            schema = get_endpoint_schema(api_spec, "/models", "get")
            if schema:
                logger.error(f"Schema structure:\n{json.dumps(schema, indent=2)}")

            # Print response
            sample_models = models[:5] if isinstance(models, list) else models
            logger.error(f"Models response:\n{json.dumps(sample_models, indent=2)}")

        assert validation_result['valid'], "Models response does not match schema"

    except requests.RequestException as e:
        pytest.fail(f"Request to /models failed: {str(e)}")


def test_object_info_response(require_server, api_client, api_spec: Dict[str, Any]):
    """
    Test the object_info endpoint response

    Args:
        require_server: Fixture that skips if server is not available
        api_client: API client fixture
        api_spec: Loaded OpenAPI spec
    """
    url = api_client.get_url("/object_info")  # type: ignore

    try:
        response = api_client.get(url)

        assert response.status_code == 200, "Failed to get object info"

        # Parse response
        objects = response.json()

        # Validate it's an object
        assert isinstance(objects, dict), "Object info response should be an object"

        # Check if we have any objects
        if objects:
            # Get the first object
            first_obj_name = next(iter(objects.keys()))
            first_obj = objects[first_obj_name]

            # Validate first object has required fields
            assert 'input' in first_obj, "Object missing 'input' field"
            assert 'output' in first_obj, "Object missing 'output' field"
            assert 'name' in first_obj, "Object missing 'name' field"

            # Perform schema validation
            validation_result = validate_response(
                objects,
                api_spec,
                "/object_info",
                "get"
            )

            # Print detailed error if validation fails
            if not validation_result['valid']:
                for error in validation_result['errors']:
                    logger.error(f"Validation error for /object_info: {error}")

                # Print schema details for debugging
                schema = get_endpoint_schema(api_spec, "/object_info", "get")
                if schema:
                    logger.error(f"Schema structure:\n{json.dumps(schema, indent=2)}")

                # Also print a small sample of the response
                sample = dict(list(objects.items())[:1]) if objects else {}
                logger.error(f"Sample response:\n{json.dumps(sample, indent=2)}")

            assert validation_result['valid'], "Object info response does not match schema"

    except requests.RequestException as e:
        pytest.fail(f"Request to /object_info failed: {str(e)}")
    except (KeyError, StopIteration) as e:
        pytest.fail(f"Failed to process response: {str(e)}")


def test_queue_response(require_server, api_client, api_spec: Dict[str, Any]):
    """
    Test the queue endpoint response

    Args:
        require_server: Fixture that skips if server is not available
        api_client: API client fixture
        api_spec: Loaded OpenAPI spec
    """
    url = api_client.get_url("/queue")  # type: ignore

    try:
        response = api_client.get(url)

        assert response.status_code == 200, "Failed to get queue"

        # Parse response
        queue = response.json()

        # Validate structure
        assert 'queue_running' in queue, "Queue missing 'queue_running' field"
        assert 'queue_pending' in queue, "Queue missing 'queue_pending' field"

        # Each should be a list
        assert isinstance(queue['queue_running'], list), "queue_running should be a list"
        assert isinstance(queue['queue_pending'], list), "queue_pending should be a list"

        # Perform schema validation
        validation_result = validate_response(
            queue,
            api_spec,
            "/queue",
            "get"
        )

        # Print detailed error if validation fails
        if not validation_result['valid']:
            for error in validation_result['errors']:
                logger.error(f"Validation error for /queue: {error}")

            # Print schema details for debugging
            schema = get_endpoint_schema(api_spec, "/queue", "get")
            if schema:
                logger.error(f"Schema structure:\n{json.dumps(schema, indent=2)}")

            # Print response
            logger.error(f"Queue response:\n{json.dumps(queue, indent=2)}")

        assert validation_result['valid'], "Queue response does not match schema"

    except requests.RequestException as e:
        pytest.fail(f"Request to /queue failed: {str(e)}")
144 tests-api/test_spec_validation.py Normal file
@@ -0,0 +1,144 @@
"""
Tests for validating the OpenAPI specification
"""
import pytest
from openapi_spec_validator import validate_spec
from openapi_spec_validator.exceptions import OpenAPISpecValidatorError
from typing import Dict, Any


def test_openapi_spec_is_valid(api_spec: Dict[str, Any]):
    """
    Test that the OpenAPI specification is valid

    Args:
        api_spec: Loaded OpenAPI spec
    """
    try:
        validate_spec(api_spec)
    except OpenAPISpecValidatorError as e:
        pytest.fail(f"OpenAPI spec validation failed: {str(e)}")


def test_spec_has_info(api_spec: Dict[str, Any]):
    """
    Test that the OpenAPI spec has the required info section

    Args:
        api_spec: Loaded OpenAPI spec
    """
    assert 'info' in api_spec, "Spec must have info section"
    assert 'title' in api_spec['info'], "Info must have title"
    assert 'version' in api_spec['info'], "Info must have version"


def test_spec_has_paths(api_spec: Dict[str, Any]):
    """
    Test that the OpenAPI spec has paths defined

    Args:
        api_spec: Loaded OpenAPI spec
    """
    assert 'paths' in api_spec, "Spec must have paths section"
    assert len(api_spec['paths']) > 0, "Spec must have at least one path"


def test_spec_has_components(api_spec: Dict[str, Any]):
    """
    Test that the OpenAPI spec has components defined

    Args:
        api_spec: Loaded OpenAPI spec
    """
    assert 'components' in api_spec, "Spec must have components section"
    assert 'schemas' in api_spec['components'], "Components must have schemas"


def test_workflow_endpoints_exist(api_spec: Dict[str, Any]):
    """
    Test that core workflow endpoints are defined

    Args:
        api_spec: Loaded OpenAPI spec
    """
    assert '/prompt' in api_spec['paths'], "Spec must define /prompt endpoint"
    assert 'post' in api_spec['paths']['/prompt'], "Spec must define POST /prompt"
    assert 'get' in api_spec['paths']['/prompt'], "Spec must define GET /prompt"


def test_image_endpoints_exist(api_spec: Dict[str, Any]):
    """
    Test that core image endpoints are defined

    Args:
        api_spec: Loaded OpenAPI spec
    """
    assert '/upload/image' in api_spec['paths'], "Spec must define /upload/image endpoint"
    assert '/view' in api_spec['paths'], "Spec must define /view endpoint"


def test_model_endpoints_exist(api_spec: Dict[str, Any]):
    """
    Test that core model endpoints are defined

    Args:
        api_spec: Loaded OpenAPI spec
    """
    assert '/models' in api_spec['paths'], "Spec must define /models endpoint"
    assert '/models/{folder}' in api_spec['paths'], "Spec must define /models/{folder} endpoint"


def test_operation_ids_are_unique(api_spec: Dict[str, Any]):
    """
    Test that all operationIds are unique

    Args:
        api_spec: Loaded OpenAPI spec
    """
    operation_ids = []

    for path, path_item in api_spec['paths'].items():
        for method, operation in path_item.items():
            if method in ['get', 'post', 'put', 'delete', 'patch']:
                if 'operationId' in operation:
                    operation_ids.append(operation['operationId'])

    # Check for duplicates
    duplicates = set([op_id for op_id in operation_ids if operation_ids.count(op_id) > 1])
    assert len(duplicates) == 0, f"Found duplicate operationIds: {duplicates}"


def test_all_endpoints_have_operation_ids(api_spec: Dict[str, Any]):
    """
    Test that all endpoints have operationIds

    Args:
        api_spec: Loaded OpenAPI spec
    """
    missing = []

    for path, path_item in api_spec['paths'].items():
        for method, operation in path_item.items():
            if method in ['get', 'post', 'put', 'delete', 'patch']:
                if 'operationId' not in operation:
                    missing.append(f"{method.upper()} {path}")

    assert len(missing) == 0, f"Found endpoints without operationIds: {missing}"


def test_all_endpoints_have_tags(api_spec: Dict[str, Any]):
    """
    Test that all endpoints have tags

    Args:
        api_spec: Loaded OpenAPI spec
    """
    missing = []

    for path, path_item in api_spec['paths'].items():
        for method, operation in path_item.items():
            if method in ['get', 'post', 'put', 'delete', 'patch']:
                if 'tags' not in operation or not operation['tags']:
                    missing.append(f"{method.upper()} {path}")

    assert len(missing) == 0, f"Found endpoints without tags: {missing}"
157 tests-api/utils/schema_utils.py Normal file
@@ -0,0 +1,157 @@
"""
Utilities for working with OpenAPI schemas
"""
from typing import Any, Dict, List, Optional, Set, Tuple


def extract_required_parameters(
    spec: Dict[str, Any],
    path: str,
    method: str
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """
    Extract required parameters for a specific endpoint

    Args:
        spec: Parsed OpenAPI specification
        path: API path (e.g., '/prompt')
        method: HTTP method (e.g., 'get', 'post')

    Returns:
        Tuple of (path_params, query_params) containing required parameters
    """
    method = method.lower()
    path_params = []
    query_params = []

    # Handle path not found
    if path not in spec['paths']:
        return path_params, query_params

    # Handle method not found
    if method not in spec['paths'][path]:
        return path_params, query_params

    # Get parameters
    params = spec['paths'][path][method].get('parameters', [])

    for param in params:
        if param.get('required', False):
            if param.get('in') == 'path':
                path_params.append(param)
            elif param.get('in') == 'query':
                query_params.append(param)

    return path_params, query_params
|
||||
|
||||
|
||||
def get_request_body_schema(
|
||||
spec: Dict[str, Any],
|
||||
path: str,
|
||||
method: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get request body schema for a specific endpoint
|
||||
|
||||
Args:
|
||||
spec: Parsed OpenAPI specification
|
||||
path: API path (e.g., '/prompt')
|
||||
method: HTTP method (e.g., 'get', 'post')
|
||||
|
||||
Returns:
|
||||
Request body schema or None if not found
|
||||
"""
|
||||
method = method.lower()
|
||||
|
||||
# Handle path not found
|
||||
if path not in spec['paths']:
|
||||
return None
|
||||
|
||||
# Handle method not found
|
||||
if method not in spec['paths'][path]:
|
||||
return None
|
||||
|
||||
# Handle no request body
|
||||
request_body = spec['paths'][path][method].get('requestBody', {})
|
||||
if not request_body or 'content' not in request_body:
|
||||
return None
|
||||
|
||||
# Get schema from first content type
|
||||
content_types = request_body['content']
|
||||
first_content_type = next(iter(content_types))
|
||||
|
||||
if 'schema' not in content_types[first_content_type]:
|
||||
return None
|
||||
|
||||
return content_types[first_content_type]['schema']
|
||||
|
||||
|
||||
def extract_endpoints_by_tag(spec: Dict[str, Any], tag: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Extract all endpoints with a specific tag
|
||||
|
||||
Args:
|
||||
spec: Parsed OpenAPI specification
|
||||
tag: Tag to filter by
|
||||
|
||||
Returns:
|
||||
List of endpoint details
|
||||
"""
|
||||
endpoints = []
|
||||
|
||||
for path, path_item in spec['paths'].items():
|
||||
for method, operation in path_item.items():
|
||||
if method.lower() not in ['get', 'post', 'put', 'delete', 'patch']:
|
||||
continue
|
||||
|
||||
if tag in operation.get('tags', []):
|
||||
endpoints.append({
|
||||
'path': path,
|
||||
'method': method.lower(),
|
||||
'operation_id': operation.get('operationId', ''),
|
||||
'summary': operation.get('summary', '')
|
||||
})
|
||||
|
||||
return endpoints
|
||||
|
||||
|
||||
def get_all_tags(spec: Dict[str, Any]) -> Set[str]:
|
||||
"""
|
||||
Get all tags used in the API spec
|
||||
|
||||
Args:
|
||||
spec: Parsed OpenAPI specification
|
||||
|
||||
Returns:
|
||||
Set of tag names
|
||||
"""
|
||||
tags = set()
|
||||
|
||||
for path_item in spec['paths'].values():
|
||||
for operation in path_item.values():
|
||||
if isinstance(operation, dict) and 'tags' in operation:
|
||||
tags.update(operation['tags'])
|
||||
|
||||
return tags
|
||||
|
||||
|
||||
def get_schema_examples(spec: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Extract all examples from component schemas
|
||||
|
||||
Args:
|
||||
spec: Parsed OpenAPI specification
|
||||
|
||||
Returns:
|
||||
Dict mapping schema names to examples
|
||||
"""
|
||||
examples = {}
|
||||
|
||||
if 'components' not in spec or 'schemas' not in spec['components']:
|
||||
return examples
|
||||
|
||||
for name, schema in spec['components']['schemas'].items():
|
||||
if 'example' in schema:
|
||||
examples[name] = schema['example']
|
||||
|
||||
return examples
|
||||
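Taken together, these helpers support tag-based test organization. A minimal usage sketch follows; the 'image' tag, the spec path, and the import style (which assumes tests-api is on sys.path) are illustrative assumptions:

# Hypothetical usage sketch; tag name, paths, and imports are assumptions.
from utils.schema_utils import extract_endpoints_by_tag, get_all_tags
from utils.validation import load_openapi_spec

spec = load_openapi_spec('openapi.yaml')  # assumes repo root as working directory
print(sorted(get_all_tags(spec)))  # every tag that appears on an operation
for endpoint in extract_endpoints_by_tag(spec, 'image'):
    print(endpoint['method'].upper(), endpoint['path'], endpoint['operation_id'])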
178
tests-api/utils/validation.py
Normal file
@@ -0,0 +1,178 @@
"""
Utilities for API response validation against OpenAPI spec
"""
import yaml
import jsonschema
from typing import Any, Dict, List, Optional, Union


def load_openapi_spec(spec_path: str) -> Dict[str, Any]:
    """
    Load the OpenAPI specification from a YAML file

    Args:
        spec_path: Path to the OpenAPI specification file

    Returns:
        Dict containing the parsed OpenAPI spec
    """
    with open(spec_path, 'r') as f:
        return yaml.safe_load(f)


def get_endpoint_schema(
    spec: Dict[str, Any],
    path: str,
    method: str,
    status_code: str = '200'
) -> Optional[Dict[str, Any]]:
    """
    Extract response schema for a specific endpoint from OpenAPI spec

    Args:
        spec: Parsed OpenAPI specification
        path: API path (e.g., '/prompt')
        method: HTTP method (e.g., 'get', 'post')
        status_code: HTTP status code to get schema for

    Returns:
        Schema dict or None if not found
    """
    method = method.lower()

    # Handle path not found
    if path not in spec['paths']:
        return None

    # Handle method not found
    if method not in spec['paths'][path]:
        return None

    # Handle status code not found
    responses = spec['paths'][path][method].get('responses', {})
    if status_code not in responses:
        return None

    # Handle no content defined
    if 'content' not in responses[status_code]:
        return None

    # Get schema from first content type
    content_types = responses[status_code]['content']
    first_content_type = next(iter(content_types))

    if 'schema' not in content_types[first_content_type]:
        return None

    return content_types[first_content_type]['schema']


def resolve_schema_refs(schema: Dict[str, Any], spec: Dict[str, Any]) -> Dict[str, Any]:
    """
    Resolve $ref references in a schema

    Args:
        schema: Schema that may contain references
        spec: Full OpenAPI spec with component definitions

    Returns:
        Schema with references resolved
    """
    if not isinstance(schema, dict):
        return schema

    result = {}

    for key, value in schema.items():
        if key == '$ref' and isinstance(value, str) and value.startswith('#/'):
            # Handle reference
            ref_path = value[2:].split('/')
            ref_value = spec
            for path_part in ref_path:
                ref_value = ref_value.get(path_part, {})

            # Recursively resolve any refs in the referenced schema
            ref_value = resolve_schema_refs(ref_value, spec)
            result.update(ref_value)
        elif isinstance(value, dict):
            # Recursively resolve refs in nested dictionaries
            result[key] = resolve_schema_refs(value, spec)
        elif isinstance(value, list):
            # Recursively resolve refs in list items
            result[key] = [
                resolve_schema_refs(item, spec) if isinstance(item, dict) else item
                for item in value
            ]
        else:
            # Pass through other values
            result[key] = value

    return result


def validate_response(
    response_data: Union[Dict[str, Any], List[Any]],
    spec: Dict[str, Any],
    path: str,
    method: str,
    status_code: str = '200'
) -> Dict[str, Any]:
    """
    Validate a response against the OpenAPI schema

    Args:
        response_data: Response data to validate
        spec: Parsed OpenAPI specification
        path: API path (e.g., '/prompt')
        method: HTTP method (e.g., 'get', 'post')
        status_code: HTTP status code to validate against

    Returns:
        Dict with validation result containing:
        - valid: bool indicating if validation passed
        - errors: List of validation errors if any
    """
    schema = get_endpoint_schema(spec, path, method, status_code)

    if schema is None:
        return {
            'valid': False,
            'errors': [f"No schema found for {method.upper()} {path} with status {status_code}"]
        }

    # Resolve any $ref in the schema
    resolved_schema = resolve_schema_refs(schema, spec)

    try:
        jsonschema.validate(instance=response_data, schema=resolved_schema)
        return {'valid': True, 'errors': []}
    except jsonschema.exceptions.ValidationError as e:
        return {'valid': False, 'errors': [str(e)]}


def get_all_endpoints(spec: Dict[str, Any]) -> List[Dict[str, Any]]:
    """
    Extract all endpoints from an OpenAPI spec

    Args:
        spec: Parsed OpenAPI specification

    Returns:
        List of dicts with path, method, and tags for each endpoint
    """
    endpoints = []

    for path, path_item in spec['paths'].items():
        for method, operation in path_item.items():
            if method.lower() not in ['get', 'post', 'put', 'delete', 'patch']:
                continue

            endpoints.append({
                'path': path,
                'method': method.lower(),
                'tags': operation.get('tags', []),
                'operation_id': operation.get('operationId', ''),
                'summary': operation.get('summary', '')
            })

    return endpoints
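End to end, a response check combines load_openapi_spec, get_endpoint_schema (with $ref resolution), and jsonschema. A minimal sketch of validating a live response; the server address, the /system_stats endpoint, and the use of requests are assumptions for illustration:

# Hypothetical end-to-end sketch; server URL and endpoint are assumptions.
import requests

from utils.validation import load_openapi_spec, validate_response

spec = load_openapi_spec('openapi.yaml')
response = requests.get('http://127.0.0.1:8188/system_stats')
result = validate_response(response.json(), spec, '/system_stats', 'get')
assert result['valid'], f"Response did not match the spec: {result['errors']}"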
@@ -1,240 +0,0 @@
import torch
from unittest.mock import patch, MagicMock

# Mock nodes module to prevent CUDA initialization during import
mock_nodes = MagicMock()
mock_nodes.MAX_RESOLUTION = 16384

with patch.dict('sys.modules', {'nodes': mock_nodes}):
    from comfy_extras.nodes_images import ImageStitch


class TestImageStitch:

    def create_test_image(self, batch_size=1, height=64, width=64, channels=3):
        """Helper to create test images with specific dimensions"""
        return torch.rand(batch_size, height, width, channels)

    def test_no_image2_passthrough(self):
        """Test that when image2 is None, image1 is returned unchanged"""
        node = ImageStitch()
        image1 = self.create_test_image()

        result = node.stitch(image1, "right", True, 0, "white", image2=None)

        assert len(result) == 1
        assert torch.equal(result[0], image1)

    def test_basic_horizontal_stitch_right(self):
        """Test basic horizontal stitching to the right"""
        node = ImageStitch()
        image1 = self.create_test_image(height=32, width=32)
        image2 = self.create_test_image(height=32, width=24)

        result = node.stitch(image1, "right", False, 0, "white", image2)

        assert result[0].shape == (1, 32, 56, 3)  # 32 + 24 width

    def test_basic_horizontal_stitch_left(self):
        """Test basic horizontal stitching to the left"""
        node = ImageStitch()
        image1 = self.create_test_image(height=32, width=32)
        image2 = self.create_test_image(height=32, width=24)

        result = node.stitch(image1, "left", False, 0, "white", image2)

        assert result[0].shape == (1, 32, 56, 3)  # 24 + 32 width

    def test_basic_vertical_stitch_down(self):
        """Test basic vertical stitching downward"""
        node = ImageStitch()
        image1 = self.create_test_image(height=32, width=32)
        image2 = self.create_test_image(height=24, width=32)

        result = node.stitch(image1, "down", False, 0, "white", image2)

        assert result[0].shape == (1, 56, 32, 3)  # 32 + 24 height

    def test_basic_vertical_stitch_up(self):
        """Test basic vertical stitching upward"""
        node = ImageStitch()
        image1 = self.create_test_image(height=32, width=32)
        image2 = self.create_test_image(height=24, width=32)

        result = node.stitch(image1, "up", False, 0, "white", image2)

        assert result[0].shape == (1, 56, 32, 3)  # 24 + 32 height

    def test_size_matching_horizontal(self):
        """Test size matching for horizontal concatenation"""
        node = ImageStitch()
        image1 = self.create_test_image(height=64, width=64)
        image2 = self.create_test_image(height=32, width=32)  # Different aspect ratio

        result = node.stitch(image1, "right", True, 0, "white", image2)

        # image2 should be resized to match image1's height (64) with preserved aspect ratio
        expected_width = 64 + 64  # original + resized (32*64/32 = 64)
        assert result[0].shape == (1, 64, expected_width, 3)

    def test_size_matching_vertical(self):
        """Test size matching for vertical concatenation"""
        node = ImageStitch()
        image1 = self.create_test_image(height=64, width=64)
        image2 = self.create_test_image(height=32, width=32)

        result = node.stitch(image1, "down", True, 0, "white", image2)

        # image2 should be resized to match image1's width (64) with preserved aspect ratio
        expected_height = 64 + 64  # original + resized (32*64/32 = 64)
        assert result[0].shape == (1, expected_height, 64, 3)

    def test_padding_for_mismatched_heights_horizontal(self):
        """Test padding when heights don't match in horizontal concatenation"""
        node = ImageStitch()
        image1 = self.create_test_image(height=64, width=32)
        image2 = self.create_test_image(height=48, width=24)  # Shorter height

        result = node.stitch(image1, "right", False, 0, "white", image2)

        # Both images should be padded to height 64
        assert result[0].shape == (1, 64, 56, 3)  # 32 + 24 width, max(64,48) height

    def test_padding_for_mismatched_widths_vertical(self):
        """Test padding when widths don't match in vertical concatenation"""
        node = ImageStitch()
        image1 = self.create_test_image(height=32, width=64)
        image2 = self.create_test_image(height=24, width=48)  # Narrower width

        result = node.stitch(image1, "down", False, 0, "white", image2)

        # Both images should be padded to width 64
        assert result[0].shape == (1, 56, 64, 3)  # 32 + 24 height, max(64,48) width

    def test_spacing_horizontal(self):
        """Test spacing addition in horizontal concatenation"""
        node = ImageStitch()
        image1 = self.create_test_image(height=32, width=32)
        image2 = self.create_test_image(height=32, width=24)
        spacing_width = 16

        result = node.stitch(image1, "right", False, spacing_width, "white", image2)

        # Expected width: 32 + 16 (spacing) + 24 = 72
        assert result[0].shape == (1, 32, 72, 3)

    def test_spacing_vertical(self):
        """Test spacing addition in vertical concatenation"""
        node = ImageStitch()
        image1 = self.create_test_image(height=32, width=32)
        image2 = self.create_test_image(height=24, width=32)
        spacing_width = 16

        result = node.stitch(image1, "down", False, spacing_width, "white", image2)

        # Expected height: 32 + 16 (spacing) + 24 = 72
        assert result[0].shape == (1, 72, 32, 3)

    def test_spacing_color_values(self):
        """Test that spacing colors are applied correctly"""
        node = ImageStitch()
        image1 = self.create_test_image(height=32, width=32)
        image2 = self.create_test_image(height=32, width=32)

        # Test white spacing
        result_white = node.stitch(image1, "right", False, 16, "white", image2)
        # Check that spacing region contains white values (close to 1.0)
        spacing_region = result_white[0][:, :, 32:48, :]  # Middle 16 pixels
        assert torch.all(spacing_region >= 0.9)  # Should be close to white

        # Test black spacing
        result_black = node.stitch(image1, "right", False, 16, "black", image2)
        spacing_region = result_black[0][:, :, 32:48, :]
        assert torch.all(spacing_region <= 0.1)  # Should be close to black

    def test_odd_spacing_width_made_even(self):
        """Test that odd spacing widths are made even"""
        node = ImageStitch()
        image1 = self.create_test_image(height=32, width=32)
        image2 = self.create_test_image(height=32, width=32)

        # Use odd spacing width
        result = node.stitch(image1, "right", False, 15, "white", image2)

        # Should be made even (16), so total width = 32 + 16 + 32 = 80
        assert result[0].shape == (1, 32, 80, 3)

    def test_batch_size_matching(self):
        """Test that different batch sizes are handled correctly"""
        node = ImageStitch()
        image1 = self.create_test_image(batch_size=2, height=32, width=32)
        image2 = self.create_test_image(batch_size=1, height=32, width=32)

        result = node.stitch(image1, "right", False, 0, "white", image2)

        # Should match larger batch size
        assert result[0].shape == (2, 32, 64, 3)

    def test_channel_matching_rgb_to_rgba(self):
        """Test that channel differences are handled (RGB + alpha)"""
        node = ImageStitch()
        image1 = self.create_test_image(channels=3)  # RGB
        image2 = self.create_test_image(channels=4)  # RGBA

        result = node.stitch(image1, "right", False, 0, "white", image2)

        # Should have 4 channels (RGBA)
        assert result[0].shape[-1] == 4

    def test_channel_matching_rgba_to_rgb(self):
        """Test that channel differences are handled (RGBA + RGB)"""
        node = ImageStitch()
        image1 = self.create_test_image(channels=4)  # RGBA
        image2 = self.create_test_image(channels=3)  # RGB

        result = node.stitch(image1, "right", False, 0, "white", image2)

        # Should have 4 channels (RGBA)
        assert result[0].shape[-1] == 4

    def test_all_color_options(self):
        """Test all available color options"""
        node = ImageStitch()
        image1 = self.create_test_image(height=32, width=32)
        image2 = self.create_test_image(height=32, width=32)

        colors = ["white", "black", "red", "green", "blue"]

        for color in colors:
            result = node.stitch(image1, "right", False, 16, color, image2)
            assert result[0].shape == (1, 32, 80, 3)  # Basic shape check

    def test_all_directions(self):
        """Test all direction options"""
        node = ImageStitch()
        image1 = self.create_test_image(height=32, width=32)
        image2 = self.create_test_image(height=32, width=32)

        directions = ["right", "left", "up", "down"]

        for direction in directions:
            result = node.stitch(image1, direction, False, 0, "white", image2)
            # Compute the expected shape first. A bare
            # `assert x == a if cond else b` parses as `assert ((x == a) if cond else b)`,
            # which always passes for "up"/"down" because a non-empty tuple is truthy.
            expected = (1, 32, 64, 3) if direction in ["right", "left"] else (1, 64, 32, 3)
            assert result[0].shape == expected

    def test_batch_size_channel_spacing_integration(self):
        """Test integration of batch matching, channel matching, size matching, and spacing"""
        node = ImageStitch()
        image1 = self.create_test_image(batch_size=2, height=64, width=48, channels=3)
        image2 = self.create_test_image(batch_size=1, height=32, width=32, channels=4)

        result = node.stitch(image1, "right", True, 8, "red", image2)

        # Should handle: batch matching, size matching, channel matching, spacing
        assert result[0].shape[0] == 2  # Batch size matched
        assert result[0].shape[-1] == 4  # Channels matched to max
        assert result[0].shape[1] == 64  # Height from image1 (size matching)
        # Width should be: 48 + 8 (spacing) + resized_image2_width
        expected_image2_width = int(64 * (32 / 32))  # Resized to height 64
        expected_total_width = 48 + 8 + expected_image2_width
        assert result[0].shape[2] == expected_total_width