Compare commits


3 Commits

Author     SHA1         Message                                               Date
bymyself   d65ad9940b   add swagger validator workflow                        2025-05-21 21:15:41 -07:00
bymyself   e8a92e4c9b   Fix linting issues in API tests                       2025-05-20 12:26:56 -07:00
bymyself   fa9688b1fb   [docs] Add OpenAPI specification and test framework   2025-05-20 12:15:46 -07:00
60 changed files with 5914 additions and 7837 deletions

View File

@@ -0,0 +1,49 @@
name: Validate OpenAPI
on:
  push:
    branches: [ master ]
    paths:
      - 'openapi.yaml'
  pull_request:
    branches: [ master ]
    paths:
      - 'openapi.yaml'

jobs:
  openapi-check:
    runs-on: ubuntu-latest
    # Service containers to run with the `openapi-check` job
    services:
      # Label used to access the service container
      swagger-editor:
        # Docker Hub image
        image: swaggerapi/swagger-editor
        ports:
          # Maps port 8080 in the service container to port 80 on the host
          - 80:8080
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
      - name: Validate OpenAPI definition
        uses: swaggerexpert/swagger-editor-validate@v1
        with:
          definition-file: openapi.yaml
          swagger-editor-url: http://localhost/
          default-timeout: 20000
      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.11'
      - name: Install test dependencies
        run: |
          pip install -r tests-api/requirements.txt
      - name: Run OpenAPI spec validation tests
        run: |
          pytest tests-api/test_spec_validation.py -v
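
The workflow's last step runs a pytest suite against the spec. A minimal sketch of what such a spec-validation test might look like, assuming the `openapi-spec-validator` package (the actual tests-api/ suite may use different tooling):

```python
# Hypothetical spec-validation test; `openapi-spec-validator` is an assumption,
# not necessarily what tests-api/test_spec_validation.py actually uses.
import yaml
from openapi_spec_validator import validate_spec

def test_openapi_spec_is_valid():
    with open("openapi.yaml") as f:
        spec = yaml.safe_load(f)
    # Raises a validation error if the document is not a valid OpenAPI spec.
    validate_spec(spec)
```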

.gitignore
View File

@@ -21,6 +21,5 @@ venv/
*.log
web_custom_versions/
.DS_Store
openapi.yaml
filtered-openapi.yaml
uv.lock

View File

@@ -5,20 +5,20 @@
# Inlined the team members for now.
# Maintainers
*.md @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/tests/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/tests-unit/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/notebooks/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/script_examples/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/.github/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/requirements.txt @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/pyproject.toml @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
*.md @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/tests/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/tests-unit/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/notebooks/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/script_examples/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/.github/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/requirements.txt @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/pyproject.toml @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
# Python web server
/api_server/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
/app/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
/utils/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
/api_server/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne
/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne
/utils/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne
# Node developers
/comfy_extras/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
/comfy/comfy_types/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
/comfy/comfy_types/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne

View File

@@ -205,19 +205,6 @@ comfyui-workflow-templates is not installed.
""".strip()
)
@classmethod
def embedded_docs_path(cls) -> str:
"""Get the path to embedded documentation"""
try:
import comfyui_embedded_docs
return str(
importlib.resources.files(comfyui_embedded_docs) / "docs"
)
except ImportError:
logging.info("comfyui-embedded-docs package not found")
return None
@classmethod
def parse_version_string(cls, value: str) -> tuple[str, str, str]:
"""

View File

@@ -88,7 +88,6 @@ parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE"
parser.add_argument("--oneapi-device-selector", type=str, default=None, metavar="SELECTOR_STRING", help="Sets the oneAPI device(s) this instance will use.")
parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize default when loading models with Intel's Extension for Pytorch.")
parser.add_argument("--supports-fp8-compute", action="store_true", help="ComfyUI will act like if the device supports fp8 compute.")
class LatentPreviewMethod(enum.Enum):
NoPreviews = "none"

View File

@@ -24,10 +24,6 @@ class CONDRegular:
conds.append(x.cond)
return torch.cat(conds)
def size(self):
return list(self.cond.size())
class CONDNoiseShape(CONDRegular):
def process_cond(self, batch_size, device, area, **kwargs):
data = self.cond
@@ -68,7 +64,6 @@ class CONDCrossAttn(CONDRegular):
out.append(c)
return torch.cat(out)
class CONDConstant(CONDRegular):
def __init__(self, cond):
self.cond = cond
@@ -83,48 +78,3 @@ class CONDConstant(CONDRegular):
def concat(self, others):
return self.cond
def size(self):
return [1]
class CONDList(CONDRegular):
def __init__(self, cond):
self.cond = cond
def process_cond(self, batch_size, device, **kwargs):
out = []
for c in self.cond:
out.append(comfy.utils.repeat_to_batch_size(c, batch_size).to(device))
return self._copy_with(out)
def can_concat(self, other):
if len(self.cond) != len(other.cond):
return False
for i in range(len(self.cond)):
if self.cond[i].shape != other.cond[i].shape:
return False
return True
def concat(self, others):
out = []
for i in range(len(self.cond)):
o = [self.cond[i]]
for x in others:
o.append(x.cond[i])
out.append(torch.cat(o))
return out
def size(self): # hackish implementation to make the mem estimation work
o = 0
c = 1
for c in self.cond:
size = c.size()
o += math.prod(size)
if len(size) > 1:
c = size[1]
return [1, c, o // c]
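
The removed `CONDList.size()` collapses a list of cond tensors into a single pseudo-shape so the memory estimator can treat the whole list as one tensor. A worked example of that arithmetic:

```python
# Worked example of the removed size() heuristic: sum element counts across
# the list (o), remember the last dim-1 value (c), report [1, c, o // c].
import math
import torch

conds = [torch.zeros(2, 77, 768), torch.zeros(2, 77, 768)]
o, c = 0, 1
for t in conds:
    size = t.size()
    o += math.prod(size)
    if len(size) > 1:
        c = size[1]
print([1, c, o // c])  # [1, 77, 3072]
```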

View File

@@ -80,13 +80,15 @@ class DoubleStreamBlock(nn.Module):
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
# prepare image for attention
img_modulated = torch.addcmul(img_mod1.shift, 1 + img_mod1.scale, self.img_norm1(img))
img_modulated = self.img_norm1(img)
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
img_qkv = self.img_attn.qkv(img_modulated)
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
# prepare txt for attention
txt_modulated = torch.addcmul(txt_mod1.shift, 1 + txt_mod1.scale, self.txt_norm1(txt))
txt_modulated = self.txt_norm1(txt)
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
txt_qkv = self.txt_attn.qkv(txt_modulated)
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
@@ -100,12 +102,12 @@ class DoubleStreamBlock(nn.Module):
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
# calculate the img blocks
img.addcmul_(img_mod1.gate, self.img_attn.proj(img_attn))
img.addcmul_(img_mod2.gate, self.img_mlp(torch.addcmul(img_mod2.shift, 1 + img_mod2.scale, self.img_norm2(img))))
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
# calculate the txt blocks
txt.addcmul_(txt_mod1.gate, self.txt_attn.proj(txt_attn))
txt.addcmul_(txt_mod2.gate, self.txt_mlp(torch.addcmul(txt_mod2.shift, 1 + txt_mod2.scale, self.txt_norm2(txt))))
txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
if txt.dtype == torch.float16:
txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
@@ -150,7 +152,7 @@ class SingleStreamBlock(nn.Module):
def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None) -> Tensor:
mod = vec
x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x))
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
@@ -160,7 +162,7 @@ class SingleStreamBlock(nn.Module):
attn = attention(q, k, v, pe=pe, mask=attn_mask)
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
x.addcmul_(mod.gate, output)
x += mod.gate * output
if x.dtype == torch.float16:
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
return x
@@ -176,6 +178,6 @@ class LastLayer(nn.Module):
shift, scale = vec
shift = shift.squeeze(1)
scale = scale.squeeze(1)
x = torch.addcmul(shift[:, None, :], 1 + scale[:, None, :], self.norm_final(x))
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
x = self.linear(x)
return x
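
This file's hunks trade the fused `torch.addcmul`/`addcmul_` calls for the expanded modulation expressions. The two forms are numerically equivalent, as a quick sketch shows:

```python
# torch.addcmul(input, t1, t2) computes input + t1 * t2, so the fused and
# expanded modulation forms below are numerically identical.
import torch

x_norm = torch.randn(2, 16)
shift, scale = torch.randn(2, 16), torch.randn(2, 16)

fused = torch.addcmul(shift, 1 + scale, x_norm)
expanded = (1 + scale) * x_norm + shift
assert torch.allclose(fused, expanded)
```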

View File

@@ -163,7 +163,7 @@ class Chroma(nn.Module):
distil_guidance = timestep_embedding(guidance.detach().clone(), 16).to(img.device, img.dtype)
# get all modulation index
modulation_index = timestep_embedding(torch.arange(mod_index_length, device=img.device), 32).to(img.device, img.dtype)
modulation_index = timestep_embedding(torch.arange(mod_index_length), 32).to(img.device, img.dtype)
# we need to broadcast the modulation index here so each batch has all of the index
modulation_index = modulation_index.unsqueeze(0).repeat(img.shape[0], 1, 1).to(img.device, img.dtype)
# and we need to broadcast timestep and guidance along too
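
The two `arange` lines above differ only in where the index tensor is created. A minimal sketch of the device-placement difference (the length value is a placeholder):

```python
# Creating the index directly on the target device avoids a host-side tensor
# and any later device mismatch; mod_index_length is a placeholder value here.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
mod_index_length = 344  # placeholder, not taken from the model code
on_device = torch.arange(mod_index_length, device=device)
host_side = torch.arange(mod_index_length).to(device)  # extra copy
```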

View File

@@ -20,11 +20,8 @@ if model_management.xformers_enabled():
if model_management.sage_attention_enabled():
try:
from sageattention import sageattn
except ModuleNotFoundError as e:
if e.name == "sageattention":
logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
else:
raise e
except ModuleNotFoundError:
logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
exit(-1)
if model_management.flash_attention_enabled():
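
The two `except` variants differ in whether a transitive import failure inside sageattention gets swallowed: checking `ModuleNotFoundError.name` keeps the install hint for the missing package itself while re-raising anything else. A minimal sketch of the pattern:

```python
# Only treat the error as "package not installed" when the missing module is
# sageattention itself; a missing dependency inside it is re-raised.
try:
    from sageattention import sageattn  # noqa: F401
except ModuleNotFoundError as e:
    if e.name == "sageattention":
        print("pip install sageattention")
    else:
        raise
```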

View File

@@ -539,20 +539,13 @@ class WanModel(torch.nn.Module):
x = self.unpatchify(x, grid_sizes)
return x
def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
def forward(self, x, timestep, context, clip_fea=None, transformer_options={}, **kwargs):
bs, c, t, h, w = x.shape
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
patch_size = self.patch_size
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
if time_dim_concat is not None:
time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size)
x = torch.cat([x, time_dim_concat], dim=2)
t_len = ((x.shape[2] + (patch_size[0] // 2)) // patch_size[0])
img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
@@ -642,7 +635,7 @@ class VaceWanModel(WanModel):
t,
context,
vace_context,
vace_strength,
vace_strength=1.0,
clip_fea=None,
freqs=None,
transformer_options={},
@@ -668,11 +661,8 @@ class VaceWanModel(WanModel):
context = torch.concat([context_clip, context], dim=1)
context_img_len = clip_fea.shape[-2]
orig_shape = list(vace_context.shape)
vace_context = vace_context.movedim(0, 1).reshape([-1] + orig_shape[2:])
c = self.vace_patch_embedding(vace_context.float()).to(vace_context.dtype)
c = c.flatten(2).transpose(1, 2)
c = list(c.split(orig_shape[0], dim=0))
# arguments
x_orig = x
@@ -692,9 +682,8 @@ class VaceWanModel(WanModel):
ii = self.vace_layers_mapping.get(i, None)
if ii is not None:
for iii in range(len(c)):
c_skip, c[iii] = self.vace_blocks[ii](c[iii], x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
x += c_skip * vace_strength[iii]
c_skip, c = self.vace_blocks[ii](c, x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
x += c_skip * vace_strength
del c_skip
# head
x = self.head(x, e)
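
In this hunk, one side applies a single `vace_strength` value to one skip tensor, while the other keeps a list of context streams and scales each stream's skip by its own strength. A toy sketch of the per-stream version:

```python
# Toy version of the per-stream strength application: each vace stream's skip
# connection is scaled by its own strength before being added to x.
import torch

x = torch.zeros(1, 8)
c_skips = [torch.ones(1, 8), 2 * torch.ones(1, 8)]
vace_strength = [1.0, 0.5]
for c_skip, s in zip(c_skips, vace_strength):
    x = x + c_skip * s
```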

View File

@@ -283,9 +283,8 @@ def model_lora_keys_unet(model, key_map={}):
for k in sdk:
if k.startswith("diffusion_model."):
if k.endswith(".weight"):
key_lora = k[len("diffusion_model."):-len(".weight")]
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k #SimpleTuner lycoris format
key_map["transformer.{}".format(key_lora)] = k #SimpleTuner regular format
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
key_map["lycoris_{}".format(key_lora)] = k #SimpleTuner lycoris format
if isinstance(model, comfy.model_base.ACEStep):
for k in sdk:
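
The hunk above shows two variants of the SimpleTuner key remapping: one derives both a `lycoris_` key and a `transformer.` key from each diffusion_model weight name, the other only the lycoris form. A quick sketch of the string transforms:

```python
# How a checkpoint key is remapped: strip the "diffusion_model." prefix and
# ".weight" suffix, then build the SimpleTuner-style lookup keys.
k = "diffusion_model.double_blocks.0.img_attn.qkv.weight"
key_lora = k[len("diffusion_model."):-len(".weight")]
print("lycoris_" + key_lora.replace(".", "_"))  # lycoris_double_blocks_0_img_attn_qkv
print("transformer." + key_lora)                # transformer.double_blocks.0.img_attn.qkv
```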

View File

@@ -135,7 +135,6 @@ class BaseModel(torch.nn.Module):
logging.info("model_type {}".format(model_type.name))
logging.debug("adm {}".format(self.adm_channels))
self.memory_usage_factor = model_config.memory_usage_factor
self.memory_usage_factor_conds = ()
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
@@ -168,11 +167,6 @@ class BaseModel(torch.nn.Module):
if hasattr(extra, "dtype"):
if extra.dtype != torch.int and extra.dtype != torch.long:
extra = extra.to(dtype)
if isinstance(extra, list):
ex = []
for ext in extra:
ex.append(ext.to(dtype))
extra = ex
extra_conds[o] = extra
t = self.process_timestep(t, x=x, **extra_conds)
@@ -331,28 +325,19 @@ class BaseModel(torch.nn.Module):
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
return self.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(noise.shape) - 1)), noise, latent_image)
def memory_required(self, input_shape, cond_shapes={}):
input_shapes = [input_shape]
for c in self.memory_usage_factor_conds:
shape = cond_shapes.get(c, None)
if shape is not None and len(shape) > 0:
input_shapes += shape
def memory_required(self, input_shape):
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
dtype = self.get_dtype()
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype
#TODO: this needs to be tweaked
area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes))
area = input_shape[0] * math.prod(input_shape[2:])
return (area * comfy.model_management.dtype_size(dtype) * 0.01 * self.memory_usage_factor) * (1024 * 1024)
else:
#TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes))
area = input_shape[0] * math.prod(input_shape[2:])
return (area * 0.15 * self.memory_usage_factor) * (1024 * 1024)
def extra_conds_shapes(self, **kwargs):
return {}
def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0, seed=None):
adm_inputs = []
@@ -1062,11 +1047,6 @@ class WAN21(BaseModel):
clip_vision_output = kwargs.get("clip_vision_output", None)
if clip_vision_output is not None:
out['clip_fea'] = comfy.conds.CONDRegular(clip_vision_output.penultimate_hidden_states)
time_dim_concat = kwargs.get("time_dim_concat", None)
if time_dim_concat is not None:
out['time_dim_concat'] = comfy.conds.CONDRegular(self.process_latent_in(time_dim_concat))
return out
@@ -1082,25 +1062,20 @@ class WAN21_Vace(WAN21):
vace_frames = kwargs.get("vace_frames", None)
if vace_frames is None:
noise_shape[1] = 32
vace_frames = [torch.zeros(noise_shape, device=noise.device, dtype=noise.dtype)]
vace_frames = torch.zeros(noise_shape, device=noise.device, dtype=noise.dtype)
for i in range(0, vace_frames.shape[1], 16):
vace_frames = vace_frames.clone()
vace_frames[:, i:i + 16] = self.process_latent_in(vace_frames[:, i:i + 16])
mask = kwargs.get("vace_mask", None)
if mask is None:
noise_shape[1] = 64
mask = [torch.ones(noise_shape, device=noise.device, dtype=noise.dtype)] * len(vace_frames)
mask = torch.ones(noise_shape, device=noise.device, dtype=noise.dtype)
vace_frames_out = []
for j in range(len(vace_frames)):
vf = vace_frames[j].clone()
for i in range(0, vf.shape[1], 16):
vf[:, i:i + 16] = self.process_latent_in(vf[:, i:i + 16])
vf = torch.cat([vf, mask[j]], dim=1)
vace_frames_out.append(vf)
out['vace_context'] = comfy.conds.CONDRegular(torch.cat([vace_frames.to(noise), mask.to(noise)], dim=1))
vace_frames = torch.stack(vace_frames_out, dim=1)
out['vace_context'] = comfy.conds.CONDRegular(vace_frames)
vace_strength = kwargs.get("vace_strength", [1.0] * len(vace_frames_out))
vace_strength = kwargs.get("vace_strength", 1.0)
out['vace_strength'] = comfy.conds.CONDConstant(vace_strength)
return out

View File

@@ -620,9 +620,6 @@ def convert_config(unet_config):
def unet_config_from_diffusers_unet(state_dict, dtype=None):
if "conv_in.weight" not in state_dict:
return None
match = {}
transformer_depth = []

View File

@@ -297,16 +297,11 @@ except:
try:
if is_amd():
try:
rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2]))
except:
rocm_version = (6, -1)
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
logging.info("AMD arch: {}".format(arch))
logging.info("ROCm version: {}".format(rocm_version))
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
if torch_version_numeric[0] >= 2 and torch_version_numeric[1] >= 7: # works on 2.6 but doesn't actually seem to improve much
if any((a in arch) for a in ["gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches
if any((a in arch) for a in ["gfx1100", "gfx1101"]): # TODO: more arches
ENABLE_PYTORCH_ATTENTION = True
except:
pass
@@ -700,7 +695,7 @@ def unet_inital_load_device(parameters, dtype):
return torch_dev
cpu_dev = torch.device("cpu")
if DISABLE_SMART_MEMORY or vram_state == VRAMState.NO_VRAM:
if DISABLE_SMART_MEMORY:
return cpu_dev
model_size = dtype_size(dtype) * parameters
@@ -1262,9 +1257,6 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
return False
def supports_fp8_compute(device=None):
if args.supports_fp8_compute:
return True
if not is_nvidia():
return False

View File

@@ -1,7 +1,5 @@
from __future__ import annotations
import uuid
import math
import collections
import comfy.model_management
import comfy.conds
import comfy.utils
@@ -106,21 +104,6 @@ def cleanup_additional_models(models):
if hasattr(m, 'cleanup'):
m.cleanup()
def estimate_memory(model, noise_shape, conds):
cond_shapes = collections.defaultdict(list)
cond_shapes_min = {}
for _, cs in conds.items():
for cond in cs:
for k, v in model.model.extra_conds_shapes(**cond).items():
cond_shapes[k].append(v)
if cond_shapes_min.get(k, None) is None:
cond_shapes_min[k] = [v]
elif math.prod(v) > math.prod(cond_shapes_min[k][0]):
cond_shapes_min[k] = [v]
memory_required = model.model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:]), cond_shapes=cond_shapes)
minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min)
return memory_required, minimum_memory_required
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
executor = comfy.patcher_extension.WrapperExecutor.new_executor(
@@ -134,8 +117,9 @@ def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=Non
models, inference_memory = get_additional_models(conds, model.model_dtype())
models += get_additional_models_from_model_options(model_options)
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory)
memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory
minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required)
real_model = model.model
return real_model, conds, models
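
Both versions estimate sampling memory from a worst-case batch, with cond and uncond batched together, by doubling the leading dimension of the noise shape, plus a minimum at the original batch size:

```python
# The two shapes fed to memory_required(): the worst case doubles the batch
# dim (cond + uncond in one batch); the minimum keeps it as-is.
noise_shape = [1, 4, 64, 64]
worst_case = [noise_shape[0] * 2] + list(noise_shape[1:])  # [2, 4, 64, 64]
minimum = [noise_shape[0]] + list(noise_shape[1:])         # [1, 4, 64, 64]
print(worst_case, minimum)
```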

View File

@@ -256,13 +256,7 @@ def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Te
for i in range(1, len(to_batch_temp) + 1):
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
cond_shapes = collections.defaultdict(list)
for tt in batch_amount:
cond = {k: v.size() for k, v in to_run[tt][0].conditioning.items()}
for k, v in to_run[tt][0].conditioning.items():
cond_shapes[k].append(v.size())
if model.memory_required(input_shape, cond_shapes=cond_shapes) * 1.5 < free_memory:
if model.memory_required(input_shape) * 1.5 < free_memory:
to_batch = batch_amount
break

View File

@@ -0,0 +1,25 @@
{
  "_name_or_path": "openai/clip-vit-large-patch14",
  "architectures": [
    "CLIPTextModel"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 0,
  "dropout": 0.0,
  "eos_token_id": 49407,
  "hidden_act": "quick_gelu",
  "hidden_size": 768,
  "initializer_factor": 1.0,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 248,
  "model_type": "clip_text_model",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 1,
  "projection_dim": 768,
  "torch_dtype": "float32",
  "transformers_version": "4.24.0",
  "vocab_size": 49408
}
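
This config describes a CLIP-L text encoder with a non-standard `max_position_embeddings` of 248 (stock CLIP-L uses 77). A hedged sketch of instantiating a model from such a config with HuggingFace `transformers`, purely for illustration (ComfyUI has its own loader):

```python
# Assumes the config above is saved as config.json; transformers usage is an
# illustration, not how ComfyUI itself consumes this file.
from transformers import CLIPTextConfig, CLIPTextModel

config = CLIPTextConfig.from_json_file("config.json")
model = CLIPTextModel(config)  # randomly initialized; weights come from a checkpoint
print(config.max_position_embeddings)  # 248
```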

View File

@@ -1,5 +0,0 @@
from .torch_compile import set_torch_compile_wrapper

__all__ = [
    "set_torch_compile_wrapper",
]

View File

@@ -1,69 +0,0 @@
from __future__ import annotations
import torch

import comfy.utils
from comfy.patcher_extension import WrappersMP
from typing import TYPE_CHECKING, Callable, Optional
if TYPE_CHECKING:
    from comfy.model_patcher import ModelPatcher
    from comfy.patcher_extension import WrapperExecutor


COMPILE_KEY = "torch.compile"
TORCH_COMPILE_KWARGS = "torch_compile_kwargs"


def apply_torch_compile_factory(compiled_module_dict: dict[str, Callable]) -> Callable:
    '''
    Create a wrapper that will refer to the compiled_diffusion_model.
    '''
    def apply_torch_compile_wrapper(executor: WrapperExecutor, *args, **kwargs):
        try:
            orig_modules = {}
            for key, value in compiled_module_dict.items():
                orig_modules[key] = comfy.utils.get_attr(executor.class_obj, key)
                comfy.utils.set_attr(executor.class_obj, key, value)
            return executor(*args, **kwargs)
        finally:
            for key, value in orig_modules.items():
                comfy.utils.set_attr(executor.class_obj, key, value)
    return apply_torch_compile_wrapper


def set_torch_compile_wrapper(model: ModelPatcher, backend: str, options: Optional[dict[str,str]]=None,
                              mode: Optional[str]=None, fullgraph=False, dynamic: Optional[bool]=None,
                              keys: list[str]=["diffusion_model"], *args, **kwargs):
    '''
    Perform torch.compile that will be applied at sample time for either the whole model or specific params of the BaseModel instance.
    When keys is None, it will default to using ["diffusion_model"], compiling the whole diffusion_model.
    When a list of keys is provided, it will perform torch.compile on only the selected modules.
    '''
    # clear out any other torch.compile wrappers
    model.remove_wrappers_with_key(WrappersMP.APPLY_MODEL, COMPILE_KEY)
    # if no keys, default to 'diffusion_model'
    if not keys:
        keys = ["diffusion_model"]
    # create kwargs dict that can be referenced later
    compile_kwargs = {
        "backend": backend,
        "options": options,
        "mode": mode,
        "fullgraph": fullgraph,
        "dynamic": dynamic,
    }
    # get a dict of compiled keys
    compiled_modules = {}
    for key in keys:
        compiled_modules[key] = torch.compile(
            model=model.get_model_object(key),
            **compile_kwargs,
        )
    # add torch.compile wrapper
    wrapper_func = apply_torch_compile_factory(
        compiled_module_dict=compiled_modules,
    )
    # store wrapper to run on BaseModel's apply_model function
    model.add_wrapper_with_key(WrappersMP.APPLY_MODEL, COMPILE_KEY, wrapper_func)
    # keep compile kwargs for reference
    model.model_options[TORCH_COMPILE_KWARGS] = compile_kwargs
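
The factory above swaps compiled modules in for the duration of a call and restores the originals afterwards. A standalone sketch of that swap-and-restore pattern (toy objects, no torch.compile involved):

```python
# Standalone sketch of the swap-and-restore pattern used by the wrapper above:
# temporarily replace attributes with compiled versions, run, then restore.
class Obj:
    def __init__(self):
        self.fn = lambda x: x + 1

obj = Obj()
compiled = {"fn": lambda x: (x + 1) * 1}  # stand-in for torch.compile output
orig = {}
try:
    for key, value in compiled.items():
        orig[key] = getattr(obj, key)
        setattr(obj, key, value)
    result = obj.fn(41)  # runs the "compiled" version
finally:
    for key, value in orig.items():
        setattr(obj, key, value)
print(result, obj.fn(41))  # 42 42, original restored
```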

View File

@@ -1,855 +0,0 @@
from __future__ import annotations
from typing import Any, Literal
from enum import Enum
from abc import ABC, abstractmethod
from dataclasses import dataclass, asdict
from comfy.comfy_types.node_typing import IO
class InputBehavior(str, Enum):
required = "required"
optional = "optional"
def is_class(obj):
'''
Returns True if is a class type.
Returns False if is a class instance.
'''
return isinstance(obj, type)
class NumberDisplay(str, Enum):
number = "number"
slider = "slider"
class IO_V3:
'''
Base class for V3 Inputs and Outputs.
'''
def __init__(self):
pass
def __init_subclass__(cls, io_type: IO | str, **kwargs):
cls.io_type = io_type
super().__init_subclass__(**kwargs)
class InputV3(IO_V3, io_type=None):
'''
Base class for a V3 Input.
'''
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None):
super().__init__()
self.id = id
self.display_name = display_name
self.behavior = behavior
self.tooltip = tooltip
self.lazy = lazy
def as_dict_V1(self):
return prune_dict({
"display_name": self.display_name,
"tooltip": self.tooltip,
"lazy": self.lazy
})
def get_io_type_V1(self):
return self.io_type
class WidgetInputV3(InputV3, io_type=None):
'''
Base class for a V3 Input with widget.
'''
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None,
default: Any=None,
socketless: bool=None, widgetType: str=None):
super().__init__(id, display_name, behavior, tooltip, lazy)
self.default = default
self.socketless = socketless
self.widgetType = widgetType
def as_dict_V1(self):
return super().as_dict_V1() | prune_dict({
"default": self.default,
"socketless": self.socketless,
"widgetType": self.widgetType,
})
def CustomType(io_type: IO | str) -> type[IO_V3]:
name = f"{io_type}_IO_V3"
return type(name, (IO_V3,), {}, io_type=io_type)
def CustomInput(id: str, io_type: IO | str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None) -> InputV3:
'''
Defines input for 'io_type'. Can be used to stand in for non-core types.
'''
input_kwargs = {
"id": id,
"display_name": display_name,
"behavior": behavior,
"tooltip": tooltip,
"lazy": lazy,
}
return type(f"{io_type}Input", (InputV3,), {}, io_type=io_type)(**input_kwargs)
def CustomOutput(id: str, io_type: IO | str, display_name: str=None, tooltip: str=None) -> OutputV3:
'''
Defines output for 'io_type'. Can be used to stand in for non-core types.
'''
input_kwargs = {
"id": id,
"display_name": display_name,
"tooltip": tooltip,
}
return type(f"{io_type}Output", (OutputV3,), {}, io_type=io_type)(**input_kwargs)
class BooleanInput(WidgetInputV3, io_type=IO.BOOLEAN):
'''
Boolean input.
'''
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None,
default: bool=None, label_on: str=None, label_off: str=None,
socketless: bool=None, widgetType: str=None):
super().__init__(id, display_name, behavior, tooltip, lazy, default, socketless, widgetType)
self.label_on = label_on
self.label_off = label_off
self.default: bool
def as_dict_V1(self):
return super().as_dict_V1() | prune_dict({
"label_on": self.label_on,
"label_off": self.label_off,
})
class IntegerInput(WidgetInputV3, io_type=IO.INT):
'''
Integer input.
'''
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None,
default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None,
display_mode: NumberDisplay=None, socketless: bool=None, widgetType: str=None):
super().__init__(id, display_name, behavior, tooltip, lazy, default, socketless, widgetType)
self.min = min
self.max = max
self.step = step
self.control_after_generate = control_after_generate
self.display_mode = display_mode
self.default: int
def as_dict_V1(self):
return super().as_dict_V1() | prune_dict({
"min": self.min,
"max": self.max,
"step": self.step,
"control_after_generate": self.control_after_generate,
"display": self.display_mode, # NOTE: in frontend, the parameter is called "display"
})
class FloatInput(WidgetInputV3, io_type=IO.FLOAT):
'''
Float input.
'''
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None,
default: float=None, min: float=None, max: float=None, step: float=None, round: float=None,
display_mode: NumberDisplay=None, socketless: bool=None, widgetType: str=None):
super().__init__(id, display_name, behavior, tooltip, lazy, default, socketless, widgetType)
self.default = default
self.min = min
self.max = max
self.step = step
self.round = round
self.display_mode = display_mode
self.default: float
def as_dict_V1(self):
return super().as_dict_V1() | prune_dict({
"min": self.min,
"max": self.max,
"step": self.step,
"round": self.round,
"display": self.display_mode, # NOTE: in frontend, the parameter is called "display"
})
class StringInput(WidgetInputV3, io_type=IO.STRING):
'''
String input.
'''
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None,
multiline=False, placeholder: str=None, default: int=None,
socketless: bool=None, widgetType: str=None):
super().__init__(id, display_name, behavior, tooltip, lazy, default, socketless, widgetType)
self.multiline = multiline
self.placeholder = placeholder
self.default: str
def as_dict_V1(self):
return super().as_dict_V1() | prune_dict({
"multiline": self.multiline,
"placeholder": self.placeholder,
})
class ComboInput(WidgetInputV3, io_type=IO.COMBO):
'''Combo input (dropdown).'''
def __init__(self, id: str, options: list[str], display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None,
default: str=None, control_after_generate: bool=None,
socketless: bool=None, widgetType: str=None):
super().__init__(id, display_name, behavior, tooltip, lazy, default, socketless, widgetType)
self.multiselect = False
self.options = options
self.control_after_generate = control_after_generate
self.default: str
def as_dict_V1(self):
return super().as_dict_V1() | prune_dict({
"multiselect": self.multiselect,
"options": self.options,
"control_after_generate": self.control_after_generate,
})
class MultiselectComboWidget(ComboInput, io_type=IO.COMBO):
'''Multiselect Combo input (dropdown for selecting potentially more than one value).'''
def __init__(self, id: str, options: list[str], display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None,
default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None,
socketless: bool=None, widgetType: str=None):
super().__init__(id, options, display_name, behavior, tooltip, lazy, default, control_after_generate, socketless, widgetType)
self.multiselect = True
self.placeholder = placeholder
self.chip = chip
self.default: list[str]
def as_dict_V1(self):
return super().as_dict_V1() | prune_dict({
"multiselect": self.multiselect,
"placeholder": self.placeholder,
"chip": self.chip,
})
class ImageInput(InputV3, io_type=IO.IMAGE):
'''
Image input.
'''
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None):
super().__init__(id, display_name, behavior, tooltip)
class MaskInput(InputV3, io_type=IO.MASK):
'''
Mask input.
'''
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None):
super().__init__(id, display_name, behavior, tooltip)
class LatentInput(InputV3, io_type=IO.LATENT):
'''
Latent input.
'''
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None):
super().__init__(id, display_name, behavior, tooltip)
class MultitypedInput(InputV3, io_type="COMFY_MULTITYPED_V3"):
'''
Input that permits more than one input type.
'''
def __init__(self, id: str, io_types: list[type[IO_V3] | InputV3 | IO |str], display_name: str=None, behavior=InputBehavior.required, tooltip: str=None,):
super().__init__(id, display_name, behavior, tooltip)
self._io_types = io_types
@property
def io_types(self) -> list[type[InputV3]]:
'''
Returns list of InputV3 class types permitted.
'''
io_types = []
for x in self._io_types:
if not is_class(x):
io_types.append(type(x))
else:
io_types.append(x)
return io_types
def get_io_type_V1(self):
return ",".join(x.io_type for x in self.io_types)
class OutputV3:
def __init__(self, id: str, display_name: str=None, tooltip: str=None,
is_output_list=False):
self.id = id
self.display_name = display_name
self.tooltip = tooltip
self.is_output_list = is_output_list
def __init_subclass__(cls, io_type, **kwargs):
cls.io_type = io_type
super().__init_subclass__(**kwargs)
class IntegerOutput(OutputV3, io_type=IO.INT):
pass
class FloatOutput(OutputV3, io_type=IO.FLOAT):
pass
class StringOutput(OutputV3, io_type=IO.STRING):
pass
# def __init__(self, id: str, display_name: str=None, tooltip: str=None):
# super().__init__(id, display_name, tooltip)
class ImageOutput(OutputV3, io_type=IO.IMAGE):
pass
class MaskOutput(OutputV3, io_type=IO.MASK):
pass
class LatentOutput(OutputV3, io_type=IO.LATENT):
pass
class DynamicInput(InputV3, io_type=None):
'''
Abstract class for dynamic input registration.
'''
def __init__(self, io_type: str, id: str, display_name: str=None):
super().__init__(io_type, id, display_name)
class DynamicOutput(OutputV3, io_type=None):
'''
Abstract class for dynamic output registration.
'''
def __init__(self, io_type: str, id: str, display_name: str=None):
super().__init__(io_type, id, display_name)
class AutoGrowDynamicInput(DynamicInput, io_type="COMFY_MULTIGROW_V3"):
'''
Dynamic Input that adds another template_input each time one is provided.
Additional inputs are forced to have 'InputBehavior.optional'.
'''
def __init__(self, id: str, template_input: InputV3, min: int=1, max: int=None):
super().__init__("AutoGrowDynamicInput", id)
self.template_input = template_input
if min is not None:
assert(min >= 1)
if max is not None:
assert(max >= 1)
self.min = min
self.max = max
class ComboDynamicInput(DynamicInput, io_type="COMFY_COMBODYNAMIC_V3"):
def __init__(self, id: str):
pass
AutoGrowDynamicInput(id="dynamic", template_input=ImageInput(id="image"))
class Hidden(str, Enum):
'''
Enumerator for requesting hidden variables in nodes.
'''
unique_id = "UNIQUE_ID"
"""UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages)."""
prompt = "PROMPT"
"""PROMPT is the complete prompt sent by the client to the server. See the prompt object for a full description."""
extra_pnginfo = "EXTRA_PNGINFO"
"""EXTRA_PNGINFO is a dictionary that will be copied into the metadata of any .png files saved. Custom nodes can store additional information in this dictionary for saving (or as a way to communicate with a downstream node)."""
dynprompt = "DYNPROMPT"
"""DYNPROMPT is an instance of comfy_execution.graph.DynamicPrompt. It differs from PROMPT in that it may mutate during the course of execution in response to Node Expansion."""
auth_token_comfy_org = "AUTH_TOKEN_COMFY_ORG"
"""AUTH_TOKEN_COMFY_ORG is a token acquired from signing into a ComfyOrg account on frontend."""
api_key_comfy_org = "API_KEY_COMFY_ORG"
"""API_KEY_COMFY_ORG is an API Key generated by ComfyOrg that allows skipping signing into a ComfyOrg account on frontend."""
@dataclass
class NodeInfoV1:
input: dict=None
input_order: dict[str, list[str]]=None
output: list[str]=None
output_is_list: list[bool]=None
output_name: list[str]=None
output_tooltips: list[str]=None
name: str=None
display_name: str=None
description: str=None
python_module: Any=None
category: str=None
output_node: bool=None
deprecated: bool=None
experimental: bool=None
api_node: bool=None
def as_pruned_dict(dataclass_obj):
'''Return dict of dataclass object with pruned None values.'''
return prune_dict(asdict(dataclass_obj))
def prune_dict(d: dict):
return {k: v for k,v in d.items() if v is not None}
@dataclass
class SchemaV3:
"""Definition of V3 node properties."""
node_id: str
"""ID of node - should be globally unique. If this is a custom node, add a prefix or postfix to avoid name clashes."""
display_name: str = None
"""Display name of node."""
category: str = "sd"
"""The category of the node, as per the "Add Node" menu."""
inputs: list[InputV3]=None
outputs: list[OutputV3]=None
hidden: list[Hidden]=None
description: str=""
"""Node description, shown as a tooltip when hovering over the node."""
is_input_list: bool = False
"""A flag indicating if this node implements the additional code necessary to deal with OUTPUT_IS_LIST nodes.
All inputs of ``type`` will become ``list[type]``, regardless of how many items are passed in. This also affects ``check_lazy_status``.
From the docs:
A node can also override the default input behaviour and receive the whole list in a single call. This is done by setting a class attribute `INPUT_IS_LIST` to ``True``.
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lists#list-processing
"""
is_output_node: bool=False
"""Flags this node as an output node, causing any inputs it requires to be executed.
If a node is not connected to any output nodes, that node will not be executed. Usage::
OUTPUT_NODE = True
From the docs:
By default, a node is not considered an output. Set ``OUTPUT_NODE = True`` to specify that it is.
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#output-node
"""
is_deprecated: bool=False
"""Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
is_experimental: bool=False
"""Flags a node as experimental, informing users that it may change or not work as expected."""
is_api_node: bool=False
"""Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview."""
# class SchemaV3Class:
# def __init__(self,
# node_id: str,
# node_name: str,
# category: str,
# inputs: list[InputV3],
# outputs: list[OutputV3]=None,
# hidden: list[Hidden]=None,
# description: str="",
# is_input_list: bool = False,
# is_output_node: bool=False,
# is_deprecated: bool=False,
# is_experimental: bool=False,
# is_api_node: bool=False,
# ):
# self.node_id = node_id
# """ID of node - should be globally unique. If this is a custom node, add a prefix or postfix to avoid name clashes."""
# self.node_name = node_name
# """Display name of node."""
# self.category = category
# """The category of the node, as per the "Add Node" menu."""
# self.inputs = inputs
# self.outputs = outputs
# self.hidden = hidden
# self.description = description
# """Node description, shown as a tooltip when hovering over the node."""
# self.is_input_list = is_input_list
# """A flag indicating if this node implements the additional code necessary to deal with OUTPUT_IS_LIST nodes.
# All inputs of ``type`` will become ``list[type]``, regardless of how many items are passed in. This also affects ``check_lazy_status``.
# From the docs:
# A node can also override the default input behaviour and receive the whole list in a single call. This is done by setting a class attribute `INPUT_IS_LIST` to ``True``.
# Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lists#list-processing
# """
# self.is_output_node = is_output_node
# """Flags this node as an output node, causing any inputs it requires to be executed.
# If a node is not connected to any output nodes, that node will not be executed. Usage::
# OUTPUT_NODE = True
# From the docs:
# By default, a node is not considered an output. Set ``OUTPUT_NODE = True`` to specify that it is.
# Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#output-node
# """
# self.is_deprecated = is_deprecated
# """Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
# self.is_experimental = is_experimental
# """Flags a node as experimental, informing users that it may change or not work as expected."""
# self.is_api_node = is_api_node
# """Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview."""
class classproperty(object):
def __init__(self, f):
self.f = f
def __get__(self, obj, owner):
return self.f(owner)
class ComfyNodeV3(ABC):
"""Common base class for all V3 nodes."""
RELATIVE_PYTHON_MODULE = None
#############################################
# V1 Backwards Compatibility code
#--------------------------------------------
_DESCRIPTION = None
@classproperty
def DESCRIPTION(cls):
if cls._DESCRIPTION is None:
cls.GET_SCHEMA()
return cls._DESCRIPTION
_CATEGORY = None
@classproperty
def CATEGORY(cls):
if cls._CATEGORY is None:
cls.GET_SCHEMA()
return cls._CATEGORY
_EXPERIMENTAL = None
@classproperty
def EXPERIMENTAL(cls):
if cls._EXPERIMENTAL is None:
cls.GET_SCHEMA()
return cls._EXPERIMENTAL
_DEPRECATED = None
@classproperty
def DEPRECATED(cls):
if cls._DEPRECATED is None:
cls.GET_SCHEMA()
return cls._DEPRECATED
_API_NODE = None
@classproperty
def API_NODE(cls):
if cls._API_NODE is None:
cls.GET_SCHEMA()
return cls._API_NODE
_OUTPUT_NODE = None
@classproperty
def OUTPUT_NODE(cls):
if cls._OUTPUT_NODE is None:
cls.GET_SCHEMA()
return cls._OUTPUT_NODE
_INPUT_IS_LIST = None
@classproperty
def INPUT_IS_LIST(cls):
if cls._INPUT_IS_LIST is None:
cls.GET_SCHEMA()
return cls._INPUT_IS_LIST
_OUTPUT_IS_LIST = None
@classproperty
def OUTPUT_IS_LIST(cls):
if cls._OUTPUT_IS_LIST is None:
cls.GET_SCHEMA()
return cls._OUTPUT_IS_LIST
_RETURN_TYPES = None
@classproperty
def RETURN_TYPES(cls):
if cls._RETURN_TYPES is None:
cls.GET_SCHEMA()
return cls._RETURN_TYPES
_RETURN_NAMES = None
@classproperty
def RETURN_NAMES(cls):
if cls._RETURN_NAMES is None:
cls.GET_SCHEMA()
return cls._RETURN_NAMES
_OUTPUT_TOOLTIPS = None
@classproperty
def OUTPUT_TOOLTIPS(cls):
if cls._OUTPUT_TOOLTIPS is None:
cls.GET_SCHEMA()
return cls._OUTPUT_TOOLTIPS
FUNCTION = "execute"
@classmethod
def INPUT_TYPES(cls) -> dict[str, dict]:
schema = cls.DEFINE_SCHEMA()
# for V1, make inputs be a dict with potential keys {required, optional, hidden}
input = {
"required": {}
}
if schema.inputs:
for i in schema.inputs:
input.setdefault(i.behavior.value, {})[i.id] = (i.get_io_type_V1(), i.as_dict_V1())
if schema.hidden:
for hidden in schema.hidden:
input.setdefault("hidden", {})[hidden.name] = (hidden.value,)
return input
@classmethod
def GET_SCHEMA(cls) -> SchemaV3:
schema = cls.DEFINE_SCHEMA()
if cls._DESCRIPTION is None:
cls._DESCRIPTION = schema.description
if cls._CATEGORY is None:
cls._CATEGORY = schema.category
if cls._EXPERIMENTAL is None:
cls._EXPERIMENTAL = schema.is_experimental
if cls._DEPRECATED is None:
cls._DEPRECATED = schema.is_deprecated
if cls._API_NODE is None:
cls._API_NODE = schema.is_api_node
if cls._OUTPUT_NODE is None:
cls._OUTPUT_NODE = schema.is_output_node
if cls._INPUT_IS_LIST is None:
cls._INPUT_IS_LIST = schema.is_input_list
if cls._RETURN_TYPES is None:
output = []
output_name = []
output_is_list = []
output_tooltips = []
if schema.outputs:
for o in schema.outputs:
output.append(o.io_type)
output_name.append(o.display_name if o.display_name else o.io_type)
output_is_list.append(o.is_output_list)
output_tooltips.append(o.tooltip if o.tooltip else None)
cls._RETURN_TYPES = output
cls._RETURN_NAMES = output_name
cls._OUTPUT_IS_LIST = output_is_list
cls._OUTPUT_TOOLTIPS = output_tooltips
return schema
@classmethod
def GET_NODE_INFO_V1(cls) -> dict[str, Any]:
schema = cls.GET_SCHEMA()
# get V1 inputs
input = cls.INPUT_TYPES()
# create separate lists from output fields
output = []
output_is_list = []
output_name = []
output_tooltips = []
if schema.outputs:
for o in schema.outputs:
output.append(o.io_type)
output_is_list.append(o.is_output_list)
output_name.append(o.display_name if o.display_name else o.io_type)
output_tooltips.append(o.tooltip if o.tooltip else None)
info = NodeInfoV1(
input=input,
input_order={key: list(value.keys()) for (key, value) in input.items()},
output=output,
output_is_list=output_is_list,
output_name=output_name,
output_tooltips=output_tooltips,
name=schema.node_id,
display_name=schema.display_name,
category=schema.category,
description=schema.description,
output_node=schema.is_output_node,
deprecated=schema.is_deprecated,
experimental=schema.is_experimental,
api_node=schema.is_api_node,
python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes")
)
return asdict(info)
#--------------------------------------------
#############################################
@classmethod
def GET_NODE_INFO_V3(cls) -> dict[str, Any]:
schema = cls.GET_SCHEMA()
# TODO: finish
return None
@classmethod
@abstractmethod
def DEFINE_SCHEMA(cls) -> SchemaV3:
"""
Override this function with one that returns a SchemaV3 instance.
"""
return None
DEFINE_SCHEMA = None
def __init__(self):
if self.DEFINE_SCHEMA is None:
raise Exception("No DEFINE_SCHEMA function was defined for this node.")
@abstractmethod
def execute(self, **kwargs) -> NodeOutput:
pass
# class ReturnedInputs:
# def __init__(self):
# pass
# class ReturnedOutputs:
# def __init__(self):
# pass
class NodeOutput:
'''
Standardized output of a node; can pass in any number of args and/or a UIOutput into 'ui' kwarg.
'''
def __init__(self, *args: Any, ui: UIOutput | dict=None, expand: dict=None, block_execution: str=None, **kwargs):
self.args = args
self.ui = ui
self.expand = expand
self.block_execution = block_execution
@property
def result(self):
return self.args if len(self.args) > 0 else None
class SavedResult:
def __init__(self, filename: str, subfolder: str, type: Literal["input", "output", "temp"]):
self.filename = filename
self.subfolder = subfolder
self.type = type
def as_dict(self):
return {
"filename": self.filename,
"subfolder": self.subfolder,
"type": self.type
}
class UIOutput(ABC):
def __init__(self):
pass
@abstractmethod
def as_dict(self) -> dict:
... # TODO: finish
class UIImages(UIOutput):
def __init__(self, values: list[SavedResult | dict], animated=False, **kwargs):
self.values = values
self.animated = animated
def as_dict(self):
values = [x.as_dict() if isinstance(x, SavedResult) else x for x in self.values]
return {
"images": values,
"animated": (self.animated,)
}
class UILatents(UIOutput):
def __init__(self, values: list[SavedResult | dict], **kwargs):
self.values = values
def as_dict(self):
values = [x.as_dict() if isinstance(x, SavedResult) else x for x in self.values]
return {
"latents": values,
}
class UIAudio(UIOutput):
def __init__(self, values: list[SavedResult | dict], **kwargs):
self.values = values
def as_dict(self):
values = [x.as_dict() if isinstance(x, SavedResult) else x for x in self.values]
return {
"audio": values,
}
class UI3D(UIOutput):
def __init__(self, values: list[SavedResult | dict], **kwargs):
self.values = values
def as_dict(self):
values = [x.as_dict() if isinstance(x, SavedResult) else x for x in self.values]
return {
"3d": values,
}
class UIText(UIOutput):
def __init__(self, value: str, **kwargs):
self.value = value
def as_dict(self):
return {"text": (self.value,)}
class TestNode(ComfyNodeV3):
SCHEMA = SchemaV3(
node_id="TestNode_v3",
display_name="Test Node (V3)",
category="v3_test",
inputs=[IntegerInput("my_int"),
#AutoGrowDynamicInput("growing", ImageInput),
MaskInput("thing"),
],
outputs=[ImageOutput("image_output")],
hidden=[Hidden.api_key_comfy_org, Hidden.auth_token_comfy_org, Hidden.unique_id]
)
# @classmethod
# def GET_SCHEMA(cls):
# return cls.SCHEMA
@classmethod
def DEFINE_SCHEMA(cls):
return cls.SCHEMA
def execute(**kwargs):
pass
if __name__ == "__main__":
print("hello there")
inputs: list[InputV3] = [
IntegerInput("my_int"),
CustomInput("xyz", "XYZ"),
CustomInput("model1", "MODEL_M"),
ImageInput("my_image"),
FloatInput("my_float"),
MultitypedInput("my_inputs", [CustomType("MODEL_M"), CustomType("XYZ")]),
]
outputs: list[OutputV3] = [
ImageOutput("image"),
CustomOutput("xyz", "XYZ")
]
for c in inputs:
if isinstance(c, MultitypedInput):
print(f"{c}, {type(c)}, {type(c).io_type}, {c.id}, {[x.io_type for x in c.io_types]}")
print(c.get_io_type_V1())
else:
print(f"{c}, {type(c)}, {type(c).io_type}, {c.id}")
for c in outputs:
print(f"{c}, {type(c)}, {type(c).io_type}, {c.id}")
zz = TestNode()
print(zz.GET_NODE_INFO_V1())
# aa = NodeInfoV1()
# print(asdict(aa))
# print(as_pruned_dict(aa))

View File

@@ -18,8 +18,6 @@ Follow the instructions [here](https://github.com/Comfy-Org/ComfyUI_frontend) to
python main.py --comfy-api-base https://stagingapi.comfy.org
```
To authenticate to staging, please log in and then ask someone on the Comfy Org team to whitelist you for staging access.
API stubs are generated through automatic codegen tools from OpenAPI definitions. Since the Comfy Org OpenAPI definition also contains many things from the Comfy Registry, we use redocly/cli to filter it down to just the paths relevant to API nodes.
### Redocly Instructions
@@ -30,7 +28,7 @@ When developing locally, use the `redocly-dev.yaml` file to generate pydantic mo
Before your API node PR merges, make sure to add the `Released` tag to the `openapi.yaml` file and test in staging.
```bash
# Download the OpenAPI file from staging server.
# Download the OpenAPI file from prod server.
curl -o openapi.yaml https://stagingapi.comfy.org/openapi
# Filter out unneeded API definitions.
@@ -41,25 +39,3 @@ redocly bundle openapi.yaml --output filtered-openapi.yaml --config comfy_api_no
datamodel-codegen --use-subclass-enum --field-constraints --strict-types bytes --input filtered-openapi.yaml --output comfy_api_nodes/apis/__init__.py --output-model-type pydantic_v2.BaseModel
```
# Merging to Master
Before merging to comfyanonymous/ComfyUI master, follow these steps:
1. Add the "Released" tag to the ComfyUI OpenAPI yaml file for each endpoint you are using in the nodes.
1. Make sure the ComfyUI API is deployed to prod with your changes.
1. Run the code generation again with `redocly.yaml` and the production OpenAPI yaml file.
```bash
# Download the OpenAPI file from prod server.
curl -o openapi.yaml https://api.comfy.org/openapi
# Filter out unneeded API definitions.
npm install -g @redocly/cli
redocly bundle openapi.yaml --output filtered-openapi.yaml --config comfy_api_nodes/redocly.yaml --remove-unused-components
# Generate the pydantic datamodels for validation.
datamodel-codegen --use-subclass-enum --field-constraints --strict-types bytes --input filtered-openapi.yaml --output comfy_api_nodes/apis/__init__.py --output-model-type pydantic_v2.BaseModel
```

View File

@@ -1,7 +1,6 @@
from __future__ import annotations
import io
import logging
import mimetypes
from typing import Optional, Union
from comfy.utils import common_upscale
from comfy_api.input_impl import VideoFromFile
@@ -215,7 +214,6 @@ def download_url_to_image_tensor(url: str, timeout: int = None) -> torch.Tensor:
image_bytesio = download_url_to_bytesio(url, timeout)
return bytesio_to_image_tensor(image_bytesio)
def process_image_response(response: requests.Response) -> torch.Tensor:
"""Uses content from a Response object and converts it to a torch.Tensor"""
return bytesio_to_image_tensor(BytesIO(response.content))
@@ -320,27 +318,11 @@ def tensor_to_data_uri(
return f"data:{mime_type};base64,{base64_string}"
def text_filepath_to_base64_string(filepath: str) -> str:
"""Converts a text file to a base64 string."""
with open(filepath, "rb") as f:
file_content = f.read()
return base64.b64encode(file_content).decode("utf-8")
def text_filepath_to_data_uri(filepath: str) -> str:
"""Converts a text file to a data URI."""
base64_string = text_filepath_to_base64_string(filepath)
mime_type, _ = mimetypes.guess_type(filepath)
if mime_type is None:
mime_type = "application/octet-stream"
return f"data:{mime_type};base64,{base64_string}"
def upload_file_to_comfyapi(
file_bytes_io: BytesIO,
filename: str,
upload_mime_type: str,
auth_kwargs: Optional[dict[str, str]] = None,
auth_kwargs: Optional[dict[str,str]] = None,
) -> str:
"""
Uploads a single file to ComfyUI API and returns its download URL.
@@ -375,33 +357,9 @@ def upload_file_to_comfyapi(
return response.download_url
def video_to_base64_string(
video: VideoInput,
container_format: VideoContainer = None,
codec: VideoCodec = None
) -> str:
"""
Converts a video input to a base64 string.
Args:
video: The video input to convert
container_format: Optional container format to use (defaults to video.container if available)
codec: Optional codec to use (defaults to video.codec if available)
"""
video_bytes_io = io.BytesIO()
# Use provided format/codec if specified, otherwise use video's own if available
format_to_use = container_format if container_format is not None else getattr(video, 'container', VideoContainer.MP4)
codec_to_use = codec if codec is not None else getattr(video, 'codec', VideoCodec.H264)
video.save_to(video_bytes_io, format=format_to_use, codec=codec_to_use)
video_bytes_io.seek(0)
return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8")
def upload_video_to_comfyapi(
video: VideoInput,
auth_kwargs: Optional[dict[str, str]] = None,
auth_kwargs: Optional[dict[str,str]] = None,
container: VideoContainer = VideoContainer.MP4,
codec: VideoCodec = VideoCodec.H264,
max_duration: Optional[int] = None,
@@ -503,7 +461,7 @@ def audio_ndarray_to_bytesio(
def upload_audio_to_comfyapi(
audio: AudioInput,
auth_kwargs: Optional[dict[str, str]] = None,
auth_kwargs: Optional[dict[str,str]] = None,
container_format: str = "mp4",
codec_name: str = "aac",
mime_type: str = "audio/mp4",
@@ -530,25 +488,8 @@ def upload_audio_to_comfyapi(
return upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs)
def audio_to_base64_string(
audio: AudioInput, container_format: str = "mp4", codec_name: str = "aac"
) -> str:
"""Converts an audio input to a base64 string."""
sample_rate: int = audio["sample_rate"]
waveform: torch.Tensor = audio["waveform"]
audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
audio_bytes_io = audio_ndarray_to_bytesio(
audio_data_np, sample_rate, container_format, codec_name
)
audio_bytes = audio_bytes_io.getvalue()
return base64.b64encode(audio_bytes).decode("utf-8")
def upload_images_to_comfyapi(
image: torch.Tensor,
max_images=8,
auth_kwargs: Optional[dict[str, str]] = None,
mime_type: Optional[str] = None,
image: torch.Tensor, max_images=8, auth_kwargs: Optional[dict[str,str]] = None, mime_type: Optional[str] = None
) -> list[str]:
"""
Uploads images to ComfyUI API and returns download URLs.
@@ -613,24 +554,17 @@ def upload_images_to_comfyapi(
return download_urls
def resize_mask_to_image(
mask: torch.Tensor,
image: torch.Tensor,
upscale_method="nearest-exact",
crop="disabled",
allow_gradient=True,
add_channel_dim=False,
):
def resize_mask_to_image(mask: torch.Tensor, image: torch.Tensor,
upscale_method="nearest-exact", crop="disabled",
allow_gradient=True, add_channel_dim=False):
"""
Resize mask to be the same dimensions as an image, while maintaining proper format for API calls.
"""
_, H, W, _ = image.shape
mask = mask.unsqueeze(-1)
mask = mask.movedim(-1, 1)
mask = common_upscale(
mask, width=W, height=H, upscale_method=upscale_method, crop=crop
)
mask = mask.movedim(1, -1)
mask = mask.movedim(-1,1)
mask = common_upscale(mask, width=W, height=H, upscale_method=upscale_method, crop=crop)
mask = mask.movedim(1,-1)
if not add_channel_dim:
mask = mask.squeeze(-1)
if not allow_gradient:
@@ -638,41 +572,12 @@ def resize_mask_to_image(
return mask
def validate_string(
string: str,
strip_whitespace=True,
field_name="prompt",
min_length=None,
max_length=None,
):
if string is None:
raise Exception(f"Field '{field_name}' cannot be empty.")
def validate_string(string: str, strip_whitespace=True, field_name="prompt", min_length=None, max_length=None):
if strip_whitespace:
string = string.strip()
if min_length and len(string) < min_length:
raise Exception(
f"Field '{field_name}' cannot be shorter than {min_length} characters; was {len(string)} characters long."
)
raise Exception(f"Field '{field_name}' cannot be shorter than {min_length} characters; was {len(string)} characters long.")
if max_length and len(string) > max_length:
raise Exception(
f"Field '{field_name}' cannot be longer than {max_length} characters; was {len(string)} characters long."
)
def image_tensor_pair_to_batch(
image1: torch.Tensor, image2: torch.Tensor
) -> torch.Tensor:
"""
Converts a pair of image tensors to a batch tensor.
If the images are not the same size, the smaller image is resized to
match the larger image.
"""
if image1.shape[1:] != image2.shape[1:]:
image2 = common_upscale(
image2.movedim(-1, 1),
image1.shape[2],
image1.shape[1],
"bilinear",
"center",
).movedim(1, -1)
return torch.cat((image1, image2), dim=0)
raise Exception(f" Field '{field_name} cannot be longer than {max_length} characters; was {len(string)} characters long.")
if not string:
raise Exception(f"Field '{field_name}' cannot be empty.")

File diff suppressed because it is too large

View File

@@ -108,24 +108,6 @@ class BFLFluxProGenerateRequest(BaseModel):
# )
class BFLFluxKontextProGenerateRequest(BaseModel):
prompt: str = Field(..., description='The text prompt for what you want to edit.')
input_image: Optional[str] = Field(None, description='Image to edit in base64 format')
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
guidance: confloat(ge=0.1, le=99.0) = Field(..., description='Guidance strength for the image generation process')
steps: conint(ge=1, le=150) = Field(..., description='Number of steps for the image generation process')
safety_tolerance: Optional[conint(ge=0, le=2)] = Field(
2, description='Tolerance level for input and output moderation. Between 0 and 2, 0 being most strict and 2 being least strict. Defaults to 2.'
)
output_format: Optional[BFLOutputFormat] = Field(
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
)
aspect_ratio: Optional[str] = Field(None, description='Aspect ratio of the image between 21:9 and 9:21.')
prompt_upsampling: Optional[bool] = Field(
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
)
class BFLFluxProUltraGenerateRequest(BaseModel):
prompt: str = Field(..., description='The text prompt for image generation.')
prompt_upsampling: Optional[bool] = Field(

View File

@@ -139,7 +139,7 @@ class EmptyRequest(BaseModel):
class UploadRequest(BaseModel):
file_name: str = Field(..., description="Filename to upload")
content_type: Optional[str] = Field(
content_type: str | None = Field(
None,
description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.",
)
@@ -327,9 +327,7 @@ class ApiClient:
ApiServerError: If the API server is unreachable but internet is working
Exception: For other request failures
"""
# Use urljoin but ensure path is relative to avoid absolute path behavior
relative_path = path.lstrip('/')
url = urljoin(self.base_url, relative_path)
url = urljoin(self.base_url, path)
self.check_auth(self.auth_token, self.comfy_api_key)
# Combine default headers with any provided headers
request_headers = self.get_headers()
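The removed lstrip('/') guarded against a urljoin subtlety worth spelling out: a leading-slash path is treated as absolute and discards the base URL's own path component (base URL below is hypothetical):
from urllib.parse import urljoin

base = "https://api.example.com/v1/"
urljoin(base, "/users/queue")   # 'https://api.example.com/users/queue'  (v1 dropped)
urljoin(base, "users/queue")    # 'https://api.example.com/v1/users/queue'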

View File

@@ -1,57 +0,0 @@
from __future__ import annotations
from enum import Enum
from typing import Optional, List
from pydantic import BaseModel, Field
class Rodin3DGenerateRequest(BaseModel):
seed: int = Field(..., description="Random seed.")
tier: str = Field(..., description="Tier of generation.")
material: str = Field(..., description="The material type.")
quality: str = Field(..., description="The generation quality of the mesh.")
mesh_mode: str = Field(..., description="Controls the face type of generated models.")
class GenerateJobsData(BaseModel):
uuids: List[str] = Field(..., description="List of job UUIDs.")
subscription_key: str = Field(..., description="subscription key")
class Rodin3DGenerateResponse(BaseModel):
message: Optional[str] = Field(None, description="Return message.")
prompt: Optional[str] = Field(None, description="Generated Prompt from image.")
submit_time: Optional[str] = Field(None, description="Submit Time")
uuid: Optional[str] = Field(None, description="Task UUID.")
jobs: Optional[GenerateJobsData] = Field(None, description="Details of jobs")
class JobStatus(str, Enum):
"""
Status for jobs
"""
Done = "Done"
Failed = "Failed"
Generating = "Generating"
Waiting = "Waiting"
class Rodin3DCheckStatusRequest(BaseModel):
subscription_key: str = Field(..., description="subscription from generate endpoint")
class JobItem(BaseModel):
uuid: str = Field(..., description="uuid")
status: JobStatus = Field(..., description="Current status of the job.")
class Rodin3DCheckStatusResponse(BaseModel):
jobs: List[JobItem] = Field(..., description="Job status List")
class Rodin3DDownloadRequest(BaseModel):
task_uuid: str = Field(..., description="Task UUID.")
class RodinResourceItem(BaseModel):
url: str = Field(..., description="Download Url")
name: str = Field(..., description="File name with ext")
class Rodin3DDownloadResponse(BaseModel):
list: List[RodinResourceItem] = Field(..., description="Source List")
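A round-trip sketch of these (now removed) models, assuming Pydantic v2 as the model_dump() calls elsewhere in this codebase suggest (UUID value hypothetical):
req = Rodin3DGenerateRequest(seed=0, tier="Regular", material="PBR",
                             quality="medium", mesh_mode="Quad")
data = req.model_dump()

status = Rodin3DCheckStatusResponse.model_validate(
    {"jobs": [{"uuid": "abc", "status": "Generating"}]}
)
assert status.jobs[0].status is JobStatus.Generating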

View File

@@ -1,275 +0,0 @@
from __future__ import annotations
from comfy_api_nodes.apis import (
TripoModelVersion,
TripoTextureQuality,
)
from enum import Enum
from typing import Optional, List, Dict, Any, Union
from pydantic import BaseModel, Field, RootModel
class TripoStyle(str, Enum):
PERSON_TO_CARTOON = "person:person2cartoon"
ANIMAL_VENOM = "animal:venom"
OBJECT_CLAY = "object:clay"
OBJECT_STEAMPUNK = "object:steampunk"
OBJECT_CHRISTMAS = "object:christmas"
OBJECT_BARBIE = "object:barbie"
GOLD = "gold"
ANCIENT_BRONZE = "ancient_bronze"
NONE = "None"
class TripoTaskType(str, Enum):
TEXT_TO_MODEL = "text_to_model"
IMAGE_TO_MODEL = "image_to_model"
MULTIVIEW_TO_MODEL = "multiview_to_model"
TEXTURE_MODEL = "texture_model"
REFINE_MODEL = "refine_model"
ANIMATE_PRERIGCHECK = "animate_prerigcheck"
ANIMATE_RIG = "animate_rig"
ANIMATE_RETARGET = "animate_retarget"
STYLIZE_MODEL = "stylize_model"
CONVERT_MODEL = "convert_model"
class TripoTextureAlignment(str, Enum):
ORIGINAL_IMAGE = "original_image"
GEOMETRY = "geometry"
class TripoOrientation(str, Enum):
ALIGN_IMAGE = "align_image"
DEFAULT = "default"
class TripoOutFormat(str, Enum):
GLB = "glb"
FBX = "fbx"
class TripoTopology(str, Enum):
BIP = "bip"
QUAD = "quad"
class TripoSpec(str, Enum):
MIXAMO = "mixamo"
TRIPO = "tripo"
class TripoAnimation(str, Enum):
IDLE = "preset:idle"
WALK = "preset:walk"
CLIMB = "preset:climb"
JUMP = "preset:jump"
RUN = "preset:run"
SLASH = "preset:slash"
SHOOT = "preset:shoot"
HURT = "preset:hurt"
FALL = "preset:fall"
TURN = "preset:turn"
class TripoStylizeStyle(str, Enum):
LEGO = "lego"
VOXEL = "voxel"
VORONOI = "voronoi"
MINECRAFT = "minecraft"
class TripoConvertFormat(str, Enum):
GLTF = "GLTF"
USDZ = "USDZ"
FBX = "FBX"
OBJ = "OBJ"
STL = "STL"
_3MF = "3MF"
class TripoTextureFormat(str, Enum):
BMP = "BMP"
DPX = "DPX"
HDR = "HDR"
JPEG = "JPEG"
OPEN_EXR = "OPEN_EXR"
PNG = "PNG"
TARGA = "TARGA"
TIFF = "TIFF"
WEBP = "WEBP"
class TripoTaskStatus(str, Enum):
QUEUED = "queued"
RUNNING = "running"
SUCCESS = "success"
FAILED = "failed"
CANCELLED = "cancelled"
UNKNOWN = "unknown"
BANNED = "banned"
EXPIRED = "expired"
class TripoFileTokenReference(BaseModel):
type: Optional[str] = Field(None, description='The type of the reference')
file_token: str
class TripoUrlReference(BaseModel):
type: Optional[str] = Field(None, description='The type of the reference')
url: str
class TripoObjectStorage(BaseModel):
bucket: str
key: str
class TripoObjectReference(BaseModel):
type: str
object: TripoObjectStorage
class TripoFileEmptyReference(BaseModel):
pass
class TripoFileReference(RootModel):
root: Union[TripoFileTokenReference, TripoUrlReference, TripoObjectReference, TripoFileEmptyReference]
class TripoGetStsTokenRequest(BaseModel):
format: str = Field(..., description='The format of the image')
class TripoTextToModelRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.TEXT_TO_MODEL, description='Type of task')
prompt: str = Field(..., description='The text prompt describing the model to generate', max_length=1024)
negative_prompt: Optional[str] = Field(None, description='The negative text prompt', max_length=1024)
model_version: Optional[TripoModelVersion] = TripoModelVersion.V2_5
face_limit: Optional[int] = Field(None, description='The number of faces to limit the generation to')
texture: Optional[bool] = Field(True, description='Whether to apply texture to the generated model')
pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the generated model')
image_seed: Optional[int] = Field(None, description='The seed for the text')
model_seed: Optional[int] = Field(None, description='The seed for the model')
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
style: Optional[TripoStyle] = None
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model')
class TripoImageToModelRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.IMAGE_TO_MODEL, description='Type of task')
file: TripoFileReference = Field(..., description='The file reference to convert to a model')
model_version: Optional[TripoModelVersion] = Field(None, description='The model version to use for generation')
face_limit: Optional[int] = Field(None, description='The number of faces to limit the generation to')
texture: Optional[bool] = Field(True, description='Whether to apply texture to the generated model')
pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the generated model')
model_seed: Optional[int] = Field(None, description='The seed for the model')
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
texture_alignment: Optional[TripoTextureAlignment] = Field(TripoTextureAlignment.ORIGINAL_IMAGE, description='The texture alignment method')
style: Optional[TripoStyle] = Field(None, description='The style to apply to the generated model')
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
orientation: Optional[TripoOrientation] = TripoOrientation.DEFAULT
quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model')
class TripoMultiviewToModelRequest(BaseModel):
type: TripoTaskType = TripoTaskType.MULTIVIEW_TO_MODEL
files: List[TripoFileReference] = Field(..., description='The file references to convert to a model')
model_version: Optional[TripoModelVersion] = Field(None, description='The model version to use for generation')
orthographic_projection: Optional[bool] = Field(False, description='Whether to use orthographic projection')
face_limit: Optional[int] = Field(None, description='The number of faces to limit the generation to')
texture: Optional[bool] = Field(True, description='Whether to apply texture to the generated model')
pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the generated model')
model_seed: Optional[int] = Field(None, description='The seed for the model')
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
texture_alignment: Optional[TripoTextureAlignment] = TripoTextureAlignment.ORIGINAL_IMAGE
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
orientation: Optional[TripoOrientation] = Field(TripoOrientation.DEFAULT, description='The orientation for the model')
quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model')
class TripoTextureModelRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.TEXTURE_MODEL, description='Type of task')
original_model_task_id: str = Field(..., description='The task ID of the original model')
texture: Optional[bool] = Field(True, description='Whether to apply texture to the model')
pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the model')
model_seed: Optional[int] = Field(None, description='The seed for the model')
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
texture_quality: Optional[TripoTextureQuality] = Field(None, description='The quality of the texture')
texture_alignment: Optional[TripoTextureAlignment] = Field(TripoTextureAlignment.ORIGINAL_IMAGE, description='The texture alignment method')
class TripoRefineModelRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.REFINE_MODEL, description='Type of task')
draft_model_task_id: str = Field(..., description='The task ID of the draft model')
class TripoAnimatePrerigcheckRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.ANIMATE_PRERIGCHECK, description='Type of task')
original_model_task_id: str = Field(..., description='The task ID of the original model')
class TripoAnimateRigRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.ANIMATE_RIG, description='Type of task')
original_model_task_id: str = Field(..., description='The task ID of the original model')
out_format: Optional[TripoOutFormat] = Field(TripoOutFormat.GLB, description='The output format')
spec: Optional[TripoSpec] = Field(TripoSpec.TRIPO, description='The specification for rigging')
class TripoAnimateRetargetRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.ANIMATE_RETARGET, description='Type of task')
original_model_task_id: str = Field(..., description='The task ID of the original model')
animation: TripoAnimation = Field(..., description='The animation to apply')
out_format: Optional[TripoOutFormat] = Field(TripoOutFormat.GLB, description='The output format')
bake_animation: Optional[bool] = Field(True, description='Whether to bake the animation')
class TripoStylizeModelRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.STYLIZE_MODEL, description='Type of task')
style: TripoStylizeStyle = Field(..., description='The style to apply to the model')
original_model_task_id: str = Field(..., description='The task ID of the original model')
block_size: Optional[int] = Field(80, description='The block size for stylization')
class TripoConvertModelRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.CONVERT_MODEL, description='Type of task')
format: TripoConvertFormat = Field(..., description='The format to convert to')
original_model_task_id: str = Field(..., description='The task ID of the original model')
quad: Optional[bool] = Field(False, description='Whether to apply quad to the model')
force_symmetry: Optional[bool] = Field(False, description='Whether to force symmetry')
face_limit: Optional[int] = Field(10000, description='The number of faces to limit the conversion to')
flatten_bottom: Optional[bool] = Field(False, description='Whether to flatten the bottom of the model')
flatten_bottom_threshold: Optional[float] = Field(0.01, description='The threshold for flattening the bottom')
texture_size: Optional[int] = Field(4096, description='The size of the texture')
texture_format: Optional[TripoTextureFormat] = Field(TripoTextureFormat.JPEG, description='The format of the texture')
pivot_to_center_bottom: Optional[bool] = Field(False, description='Whether to pivot to the center bottom')
class TripoTaskRequest(RootModel):
root: Union[
TripoTextToModelRequest,
TripoImageToModelRequest,
TripoMultiviewToModelRequest,
TripoTextureModelRequest,
TripoRefineModelRequest,
TripoAnimatePrerigcheckRequest,
TripoAnimateRigRequest,
TripoAnimateRetargetRequest,
TripoStylizeModelRequest,
TripoConvertModelRequest
]
class TripoTaskOutput(BaseModel):
model: Optional[str] = Field(None, description='URL to the model')
base_model: Optional[str] = Field(None, description='URL to the base model')
pbr_model: Optional[str] = Field(None, description='URL to the PBR model')
rendered_image: Optional[str] = Field(None, description='URL to the rendered image')
riggable: Optional[bool] = Field(None, description='Whether the model is riggable')
class TripoTask(BaseModel):
task_id: str = Field(..., description='The task ID')
type: Optional[str] = Field(None, description='The type of task')
status: Optional[TripoTaskStatus] = Field(None, description='The status of the task')
input: Optional[Dict[str, Any]] = Field(None, description='The input parameters for the task')
output: Optional[TripoTaskOutput] = Field(None, description='The output of the task')
progress: Optional[int] = Field(None, description='The progress of the task', ge=0, le=100)
create_time: Optional[int] = Field(None, description='The creation time of the task')
running_left_time: Optional[int] = Field(None, description='The estimated time left for the task')
queue_position: Optional[int] = Field(None, description='The position in the queue')
class TripoTaskResponse(BaseModel):
code: int = Field(0, description='The response code')
data: TripoTask = Field(..., description='The task data')
class TripoGeneralResponse(BaseModel):
code: int = Field(0, description='The response code')
data: Dict[str, str] = Field(..., description='The task ID data')
class TripoBalanceData(BaseModel):
balance: float = Field(..., description='The account balance')
frozen: float = Field(..., description='The frozen balance')
class TripoBalanceResponse(BaseModel):
code: int = Field(0, description='The response code')
data: TripoBalanceData = Field(..., description='The balance data')
class TripoErrorResponse(BaseModel):
code: int = Field(..., description='The error code')
message: str = Field(..., description='The error message')
suggestion: str = Field(..., description='The suggestion for fixing the error')
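Because TripoTaskRequest is a RootModel over a Union, Pydantic resolves the concrete request type from the payload's fields; a minimal sketch (Pydantic v2 assumed, prompt text hypothetical):
task = TripoTaskRequest.model_validate({
    "type": "text_to_model",
    "prompt": "a small wooden chair",
})
assert isinstance(task.root, TripoTextToModelRequest)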

View File

@@ -1,6 +1,6 @@
import io
from inspect import cleandoc
from typing import Union, Optional
from typing import Union
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
from comfy_api_nodes.apis.bfl_api import (
BFLStatus,
@@ -9,7 +9,6 @@ from comfy_api_nodes.apis.bfl_api import (
BFLFluxCannyImageRequest,
BFLFluxDepthImageRequest,
BFLFluxProGenerateRequest,
BFLFluxKontextProGenerateRequest,
BFLFluxProUltraGenerateRequest,
BFLFluxProGenerateResponse,
)
@@ -270,158 +269,6 @@ class FluxProUltraImageNode(ComfyNodeABC):
return (output_image,)
class FluxKontextProImageNode(ComfyNodeABC):
"""
Edits images using Flux.1 Kontext [pro] via API, based on prompt and aspect ratio.
"""
MINIMUM_RATIO = 1 / 4
MAXIMUM_RATIO = 4 / 1
MINIMUM_RATIO_STR = "1:4"
MAXIMUM_RATIO_STR = "4:1"
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the image generation - specify what and how to edit.",
},
),
"aspect_ratio": (
IO.STRING,
{
"default": "16:9",
"tooltip": "Aspect ratio of image; must be between 1:4 and 4:1.",
},
),
"guidance": (
IO.FLOAT,
{
"default": 3.0,
"min": 0.1,
"max": 99.0,
"step": 0.1,
"tooltip": "Guidance strength for the image generation process"
},
),
"steps": (
IO.INT,
{
"default": 50,
"min": 1,
"max": 150,
"tooltip": "Number of steps for the image generation process"
},
),
"seed": (
IO.INT,
{
"default": 1234,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "The random seed used for creating the noise.",
},
),
"prompt_upsampling": (
IO.BOOLEAN,
{
"default": False,
"tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
},
),
},
"optional": {
"input_image": (IO.IMAGE,),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@classmethod
def VALIDATE_INPUTS(cls, aspect_ratio: str):
try:
validate_aspect_ratio(
aspect_ratio,
minimum_ratio=cls.MINIMUM_RATIO,
maximum_ratio=cls.MAXIMUM_RATIO,
minimum_ratio_str=cls.MINIMUM_RATIO_STR,
maximum_ratio_str=cls.MAXIMUM_RATIO_STR,
)
except Exception as e:
return str(e)
return True
RETURN_TYPES = (IO.IMAGE,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/image/BFL"
BFL_PATH = "/proxy/bfl/flux-kontext-pro/generate"
def api_call(
self,
prompt: str,
aspect_ratio: str,
guidance: float,
steps: int,
input_image: Optional[torch.Tensor]=None,
seed=0,
prompt_upsampling=False,
unique_id: Union[str, None] = None,
**kwargs,
):
if input_image is None:
validate_string(prompt, strip_whitespace=False)
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=self.BFL_PATH,
method=HttpMethod.POST,
request_model=BFLFluxKontextProGenerateRequest,
response_model=BFLFluxProGenerateResponse,
),
request=BFLFluxKontextProGenerateRequest(
prompt=prompt,
prompt_upsampling=prompt_upsampling,
guidance=round(guidance, 1),
steps=steps,
seed=seed,
aspect_ratio=validate_aspect_ratio(
aspect_ratio,
minimum_ratio=self.MINIMUM_RATIO,
maximum_ratio=self.MAXIMUM_RATIO,
minimum_ratio_str=self.MINIMUM_RATIO_STR,
maximum_ratio_str=self.MAXIMUM_RATIO_STR,
),
input_image=(
input_image
if input_image is None
else convert_image_to_base64(input_image)
)
),
auth_kwargs=kwargs,
)
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
return (output_image,)
class FluxKontextMaxImageNode(FluxKontextProImageNode):
"""
Edits images using Flux.1 Kontext [max] via API, based on prompt and aspect ratio.
"""
DESCRIPTION = cleandoc(__doc__ or "")
BFL_PATH = "/proxy/bfl/flux-kontext-max/generate"
class FluxProImageNode(ComfyNodeABC):
"""
@@ -1067,8 +914,6 @@ class FluxProDepthNode(ComfyNodeABC):
NODE_CLASS_MAPPINGS = {
"FluxProUltraImageNode": FluxProUltraImageNode,
# "FluxProImageNode": FluxProImageNode,
"FluxKontextProImageNode": FluxKontextProImageNode,
"FluxKontextMaxImageNode": FluxKontextMaxImageNode,
"FluxProExpandNode": FluxProExpandNode,
"FluxProFillNode": FluxProFillNode,
"FluxProCannyNode": FluxProCannyNode,
@@ -1079,8 +924,6 @@ NODE_CLASS_MAPPINGS = {
NODE_DISPLAY_NAME_MAPPINGS = {
"FluxProUltraImageNode": "Flux 1.1 [pro] Ultra Image",
# "FluxProImageNode": "Flux 1.1 [pro] Image",
"FluxKontextProImageNode": "Flux.1 Kontext [pro] Image",
"FluxKontextMaxImageNode": "Flux.1 Kontext [max] Image",
"FluxProExpandNode": "Flux.1 Expand Image",
"FluxProFillNode": "Flux.1 Fill Image",
"FluxProCannyNode": "Flux.1 Canny Control Image",

View File

@@ -1,446 +0,0 @@
"""
API Nodes for Gemini Multimodal LLM Usage via Remote API
See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
"""
import os
from enum import Enum
from typing import Optional, Literal
import torch
import folder_paths
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict
from server import PromptServer
from comfy_api_nodes.apis import (
GeminiContent,
GeminiGenerateContentRequest,
GeminiGenerateContentResponse,
GeminiInlineData,
GeminiPart,
GeminiMimeType,
)
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
)
from comfy_api_nodes.apinode_utils import (
validate_string,
audio_to_base64_string,
video_to_base64_string,
tensor_to_base64_string,
)
GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini"
GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024 # 20 MB
class GeminiModel(str, Enum):
"""
Gemini Model Names allowed by comfy-api
"""
gemini_2_5_pro_preview_05_06 = "gemini-2.5-pro-preview-05-06"
gemini_2_5_flash_preview_04_17 = "gemini-2.5-flash-preview-04-17"
def get_gemini_endpoint(
model: GeminiModel,
) -> ApiEndpoint[GeminiGenerateContentRequest, GeminiGenerateContentResponse]:
"""
Get the API endpoint for a given Gemini model.
Args:
model: The Gemini model to use, either as enum or string value.
Returns:
ApiEndpoint configured for the specific Gemini model.
"""
if isinstance(model, str):
model = GeminiModel(model)
return ApiEndpoint(
path=f"{GEMINI_BASE_ENDPOINT}/{model.value}",
method=HttpMethod.POST,
request_model=GeminiGenerateContentRequest,
response_model=GeminiGenerateContentResponse,
)
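A usage sketch: the node's COMBO input passes the model as a plain string, which this helper coerces back to the enum before building the endpoint:
endpoint = get_gemini_endpoint("gemini-2.5-flash-preview-04-17")
assert endpoint.path == "/proxy/vertexai/gemini/gemini-2.5-flash-preview-04-17"
assert endpoint.method == HttpMethod.POST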
class GeminiNode(ComfyNodeABC):
"""
Node to generate text responses from a Gemini model.
This node allows users to interact with Google's Gemini AI models, providing
multimodal inputs (text, images, audio, video, files) to generate coherent
text responses. The node works with the latest Gemini models, handling the
API communication and response parsing.
"""
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Text inputs to the model, used to generate a response. You can include detailed instructions, questions, or context for the model.",
},
),
"model": (
IO.COMBO,
{
"tooltip": "The Gemini model to use for generating responses.",
"options": [model.value for model in GeminiModel],
"default": GeminiModel.gemini_2_5_pro_preview_05_06.value,
},
),
"seed": (
IO.INT,
{
"default": 42,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "When seed is fixed to a specific value, the model makes a best effort to provide the same response for repeated requests. Deterministic output isn't guaranteed. Also, changing the model or parameter settings, such as the temperature, can cause variations in the response even when you use the same seed value. By default, a random seed value is used.",
},
),
},
"optional": {
"images": (
IO.IMAGE,
{
"default": None,
"tooltip": "Optional image(s) to use as context for the model. To include multiple images, you can use the Batch Images node.",
},
),
"audio": (
IO.AUDIO,
{
"tooltip": "Optional audio to use as context for the model.",
"default": None,
},
),
"video": (
IO.VIDEO,
{
"tooltip": "Optional video to use as context for the model.",
"default": None,
},
),
"files": (
"GEMINI_INPUT_FILES",
{
"default": None,
"tooltip": "Optional file(s) to use as context for the model. Accepts inputs from the Gemini Generate Content Input Files node.",
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
DESCRIPTION = "Generate text responses with Google's Gemini AI model. You can provide multiple types of inputs (text, images, audio, video) as context for generating more relevant and meaningful responses."
RETURN_TYPES = ("STRING",)
FUNCTION = "api_call"
CATEGORY = "api node/text/Gemini"
API_NODE = True
def get_parts_from_response(
self, response: GeminiGenerateContentResponse
) -> list[GeminiPart]:
"""
Extract all parts from the Gemini API response.
Args:
response: The API response from Gemini.
Returns:
List of response parts from the first candidate.
"""
return response.candidates[0].content.parts
def get_parts_by_type(
self, response: GeminiGenerateContentResponse, part_type: Literal["text"] | str
) -> list[GeminiPart]:
"""
Filter response parts by their type.
Args:
response: The API response from Gemini.
part_type: Type of parts to extract ("text" or a MIME type).
Returns:
List of response parts matching the requested type.
"""
parts = []
for part in self.get_parts_from_response(response):
if part_type == "text" and hasattr(part, "text") and part.text:
parts.append(part)
elif (
hasattr(part, "inlineData")
and part.inlineData
and part.inlineData.mimeType == part_type
):
parts.append(part)
# Skip parts that don't match the requested type
return parts
def get_text_from_response(self, response: GeminiGenerateContentResponse) -> str:
"""
Extract and concatenate all text parts from the response.
Args:
response: The API response from Gemini.
Returns:
Combined text from all text parts in the response.
"""
parts = self.get_parts_by_type(response, "text")
return "\n".join([part.text for part in parts])
def create_video_parts(self, video_input: IO.VIDEO, **kwargs) -> list[GeminiPart]:
"""
Convert video input to Gemini API compatible parts.
Args:
video_input: Video tensor from ComfyUI.
**kwargs: Additional arguments to pass to the conversion function.
Returns:
List of GeminiPart objects containing the encoded video.
"""
from comfy_api.util import VideoContainer, VideoCodec
base_64_string = video_to_base64_string(
video_input,
container_format=VideoContainer.MP4,
codec=VideoCodec.H264
)
return [
GeminiPart(
inlineData=GeminiInlineData(
mimeType=GeminiMimeType.video_mp4,
data=base_64_string,
)
)
]
def create_audio_parts(self, audio_input: IO.AUDIO) -> list[GeminiPart]:
"""
Convert audio input to Gemini API compatible parts.
Args:
audio_input: Audio input from ComfyUI, containing waveform tensor and sample rate.
Returns:
List of GeminiPart objects containing the encoded audio.
"""
audio_parts: list[GeminiPart] = []
for batch_index in range(audio_input["waveform"].shape[0]):
# Recreate an IO.AUDIO object for the given batch dimension index
audio_at_index = {
"waveform": audio_input["waveform"][batch_index].unsqueeze(0),
"sample_rate": audio_input["sample_rate"],
}
# Convert to MP3 format for compatibility with Gemini API
audio_bytes = audio_to_base64_string(
audio_at_index,
container_format="mp3",
codec_name="libmp3lame",
)
audio_parts.append(
GeminiPart(
inlineData=GeminiInlineData(
mimeType=GeminiMimeType.audio_mp3,
data=audio_bytes,
)
)
)
return audio_parts
def create_image_parts(self, image_input: torch.Tensor) -> list[GeminiPart]:
"""
Convert image tensor input to Gemini API compatible parts.
Args:
image_input: Batch of image tensors from ComfyUI.
Returns:
List of GeminiPart objects containing the encoded images.
"""
image_parts: list[GeminiPart] = []
for image_index in range(image_input.shape[0]):
image_as_b64 = tensor_to_base64_string(
image_input[image_index].unsqueeze(0)
)
image_parts.append(
GeminiPart(
inlineData=GeminiInlineData(
mimeType=GeminiMimeType.image_png,
data=image_as_b64,
)
)
)
return image_parts
def create_text_part(self, text: str) -> GeminiPart:
"""
Create a text part for the Gemini API request.
Args:
text: The text content to include in the request.
Returns:
A GeminiPart object with the text content.
"""
return GeminiPart(text=text)
def api_call(
self,
prompt: str,
model: GeminiModel,
images: Optional[IO.IMAGE] = None,
audio: Optional[IO.AUDIO] = None,
video: Optional[IO.VIDEO] = None,
files: Optional[list[GeminiPart]] = None,
unique_id: Optional[str] = None,
**kwargs,
) -> tuple[str]:
# Validate inputs
validate_string(prompt, strip_whitespace=False)
# Create parts list with text prompt as the first part
parts: list[GeminiPart] = [self.create_text_part(prompt)]
# Add other modal parts
if images is not None:
image_parts = self.create_image_parts(images)
parts.extend(image_parts)
if audio is not None:
parts.extend(self.create_audio_parts(audio))
if video is not None:
parts.extend(self.create_video_parts(video))
if files is not None:
parts.extend(files)
# Create response
response = SynchronousOperation(
endpoint=get_gemini_endpoint(model),
request=GeminiGenerateContentRequest(
contents=[
GeminiContent(
role="user",
parts=parts,
)
]
),
auth_kwargs=kwargs,
).execute()
# Get result output
output_text = self.get_text_from_response(response)
if unique_id and output_text:
PromptServer.instance.send_progress_text(output_text, node_id=unique_id)
return (output_text or "Empty response from Gemini model...",)
class GeminiInputFiles(ComfyNodeABC):
"""
Loads and formats input files for use with the Gemini API.
This node allows users to include text (.txt) and PDF (.pdf) files as input
context for the Gemini model. Files are converted to the appropriate format
required by the API and can be chained together to include multiple files
in a single request.
"""
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
"""
For details about the supported file input types, see:
https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
"""
input_dir = folder_paths.get_input_directory()
input_files = [
f
for f in os.scandir(input_dir)
if f.is_file()
and (f.name.endswith(".txt") or f.name.endswith(".pdf"))
and f.stat().st_size < GEMINI_MAX_INPUT_FILE_SIZE
]
input_files = sorted(input_files, key=lambda x: x.name)
input_files = [f.name for f in input_files]
return {
"required": {
"file": (
IO.COMBO,
{
"tooltip": "Input files to include as context for the model. Only accepts text (.txt) and PDF (.pdf) files for now.",
"options": input_files,
"default": input_files[0] if input_files else None,
},
),
},
"optional": {
"GEMINI_INPUT_FILES": (
"GEMINI_INPUT_FILES",
{
"tooltip": "An optional additional file(s) to batch together with the file loaded from this node. Allows chaining of input files so that a single message can include multiple input files.",
"default": None,
},
),
},
}
DESCRIPTION = "Loads and prepares input files to include as inputs for Gemini LLM nodes. The files will be read by the Gemini model when generating a response. The contents of the text file count toward the token limit. 🛈 TIP: Can be chained together with other Gemini Input File nodes."
RETURN_TYPES = ("GEMINI_INPUT_FILES",)
FUNCTION = "prepare_files"
CATEGORY = "api node/text/Gemini"
def create_file_part(self, file_path: str) -> GeminiPart:
mime_type = (
GeminiMimeType.pdf
if file_path.endswith(".pdf")
else GeminiMimeType.text_plain
)
# Use base64 string directly, not the data URI
with open(file_path, "rb") as f:
file_content = f.read()
import base64
base64_str = base64.b64encode(file_content).decode("utf-8")
return GeminiPart(
inlineData=GeminiInlineData(
mimeType=mime_type,
data=base64_str,
)
)
def prepare_files(
self, file: str, GEMINI_INPUT_FILES: list[GeminiPart] = []
) -> tuple[list[GeminiPart]]:
"""
Loads and formats input files for Gemini API.
"""
file_path = folder_paths.get_annotated_filepath(file)
input_file_content = self.create_file_part(file_path)
files = [input_file_content] + GEMINI_INPUT_FILES
return (files,)
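The chaining is plain list prepending, so the most recently loaded file ends up first; sketched with strings (the helper and file names below are hypothetical stand-ins for prepare_files):
def chain(new_item, existing=None):  # hypothetical stand-in
    return [new_item] + (existing or [])

files = chain("notes.txt")          # ['notes.txt']
files = chain("paper.pdf", files)   # ['paper.pdf', 'notes.txt']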
NODE_CLASS_MAPPINGS = {
"GeminiNode": GeminiNode,
"GeminiInputFiles": GeminiInputFiles,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"GeminiNode": "Google Gemini",
"GeminiInputFiles": "Gemini Input Files",
}

View File

@@ -1,86 +1,29 @@
import io
from typing import TypedDict, Optional
import json
import os
import time
import re
import uuid
from enum import Enum
from inspect import cleandoc
import numpy as np
import torch
from PIL import Image
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict
from server import PromptServer
import folder_paths
from comfy_api_nodes.apis import (
OpenAIImageGenerationRequest,
OpenAIImageEditRequest,
OpenAIImageGenerationResponse,
OpenAICreateResponse,
OpenAIResponse,
CreateModelResponseProperties,
Item,
Includable,
OutputContent,
InputImageContent,
Detail,
InputTextContent,
InputMessage,
InputMessageContentList,
InputContent,
InputFileContent,
)
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
)
from comfy_api_nodes.apinode_utils import (
downscale_image_tensor,
validate_and_cast_response,
validate_string,
tensor_to_base64_string,
text_filepath_to_data_uri,
)
from comfy_api_nodes.mapper_utils import model_field_to_node_input
RESPONSES_ENDPOINT = "/proxy/openai/v1/responses"
STARTING_POINT_ID_PATTERN = r"<starting_point_id:(.*)>"
class HistoryEntry(TypedDict):
"""Type definition for a single history entry in the chat."""
prompt: str
response: str
response_id: str
timestamp: float
class ChatHistory(TypedDict):
"""Type definition for the chat history dictionary."""
__annotations__: dict[str, list[HistoryEntry]]
class SupportedOpenAIModel(str, Enum):
o4_mini = "o4-mini"
o1 = "o1"
o3 = "o3"
o1_pro = "o1-pro"
gpt_4o = "gpt-4o"
gpt_4_1 = "gpt-4.1"
gpt_4_1_mini = "gpt-4.1-mini"
gpt_4_1_nano = "gpt-4.1-nano"
class OpenAIDalle2(ComfyNodeABC):
"""
@@ -172,7 +115,7 @@ class OpenAIDalle2(ComfyNodeABC):
n=1,
size="1024x1024",
unique_id=None,
**kwargs,
**kwargs
):
validate_string(prompt, strip_whitespace=False)
model = "dall-e-2"
@@ -319,7 +262,7 @@ class OpenAIDalle3(ComfyNodeABC):
quality="standard",
size="1024x1024",
unique_id=None,
**kwargs,
**kwargs
):
validate_string(prompt, strip_whitespace=False)
model = "dall-e-3"
@@ -457,12 +400,12 @@ class OpenAIGPTImage1(ComfyNodeABC):
n=1,
size="1024x1024",
unique_id=None,
**kwargs,
**kwargs
):
validate_string(prompt, strip_whitespace=False)
model = "gpt-image-1"
path = "/proxy/openai/images/generations"
content_type = "application/json"
content_type="application/json"
request_class = OpenAIImageGenerationRequest
img_binaries = []
mask_binary = None
@@ -471,7 +414,7 @@ class OpenAIGPTImage1(ComfyNodeABC):
if image is not None:
path = "/proxy/openai/images/edits"
request_class = OpenAIImageEditRequest
content_type = "multipart/form-data"
content_type ="multipart/form-data"
batch_size = image.shape[0]
@@ -543,466 +486,17 @@ class OpenAIGPTImage1(ComfyNodeABC):
return (img_tensor,)
class OpenAITextNode(ComfyNodeABC):
"""
Base class for OpenAI text generation nodes.
"""
RETURN_TYPES = (IO.STRING,)
FUNCTION = "api_call"
CATEGORY = "api node/text/OpenAI"
API_NODE = True
class OpenAIChatNode(OpenAITextNode):
"""
Node to generate text responses from an OpenAI model.
"""
def __init__(self) -> None:
"""Initialize the chat node with a new session ID and empty history."""
self.current_session_id: str = str(uuid.uuid4())
self.history: dict[str, list[HistoryEntry]] = {}
self.previous_response_id: Optional[str] = None
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Text inputs to the model, used to generate a response.",
},
),
"persist_context": (
IO.BOOLEAN,
{
"default": True,
"tooltip": "Persist chat context between calls (multi-turn conversation)",
},
),
"model": model_field_to_node_input(
IO.COMBO,
OpenAICreateResponse,
"model",
enum_type=SupportedOpenAIModel,
),
},
"optional": {
"images": (
IO.IMAGE,
{
"default": None,
"tooltip": "Optional image(s) to use as context for the model. To include multiple images, you can use the Batch Images node.",
},
),
"files": (
"OPENAI_INPUT_FILES",
{
"default": None,
"tooltip": "Optional file(s) to use as context for the model. Accepts inputs from the OpenAI Chat Input Files node.",
},
),
"advanced_options": (
"OPENAI_CHAT_CONFIG",
{
"default": None,
"tooltip": "Optional configuration for the model. Accepts inputs from the OpenAI Chat Advanced Options node.",
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
DESCRIPTION = "Generate text responses from an OpenAI model."
def get_result_response(
self,
response_id: str,
include: Optional[list[Includable]] = None,
auth_kwargs: Optional[dict[str, str]] = None,
) -> OpenAIResponse:
"""
Retrieve a model response with the given ID from the OpenAI API.
Args:
response_id (str): The ID of the response to retrieve.
include (Optional[List[Includable]]): Additional fields to include
in the response. See the `include` parameter for Response
creation above for more information.
"""
return PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"{RESPONSES_ENDPOINT}/{response_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=OpenAIResponse,
query_params={"include": include},
),
completed_statuses=["completed"],
failed_statuses=["failed"],
status_extractor=lambda response: response.status,
auth_kwargs=auth_kwargs,
).execute()
def get_message_content_from_response(
self, response: OpenAIResponse
) -> list[OutputContent]:
"""Extract message content from the API response."""
for output in response.output:
if output.root.type == "message":
return output.root.content
raise TypeError("No output message found in response")
def get_text_from_message_content(
self, message_content: list[OutputContent]
) -> str:
"""Extract text content from message content."""
for content_item in message_content:
if content_item.root.type == "output_text":
return str(content_item.root.text)
return "No text output found in response"
def get_history_text(self, session_id: str) -> str:
"""Convert the entire history for a given session to JSON string."""
return json.dumps(self.history[session_id])
def display_history_on_node(self, session_id: str, node_id: str) -> None:
"""Display formatted chat history on the node UI."""
render_spec = {
"node_id": node_id,
"component": "ChatHistoryWidget",
"props": {
"history": self.get_history_text(session_id),
},
}
PromptServer.instance.send_sync(
"display_component",
render_spec,
)
def add_to_history(
self, session_id: str, prompt: str, output_text: str, response_id: str
) -> None:
"""Add a new entry to the chat history."""
if session_id not in self.history:
self.history[session_id] = []
self.history[session_id].append(
{
"prompt": prompt,
"response": output_text,
"response_id": response_id,
"timestamp": time.time(),
}
)
def parse_output_text_from_response(self, response: OpenAIResponse) -> str:
"""Extract text output from the API response."""
message_contents = self.get_message_content_from_response(response)
return self.get_text_from_message_content(message_contents)
def generate_new_session_id(self) -> str:
"""Generate a new unique session ID."""
return str(uuid.uuid4())
def get_session_id(self, persist_context: bool) -> str:
"""Get the current or generate a new session ID based on context persistence."""
return (
self.current_session_id
if persist_context
else self.generate_new_session_id()
)
def tensor_to_input_image_content(
self, image: torch.Tensor, detail_level: Detail = "auto"
) -> InputImageContent:
"""Convert a tensor to an input image content object."""
return InputImageContent(
detail=detail_level,
image_url=f"data:image/png;base64,{tensor_to_base64_string(image)}",
type="input_image",
)
def create_input_message_contents(
self,
prompt: str,
image: Optional[torch.Tensor] = None,
files: Optional[list[InputFileContent]] = None,
) -> InputMessageContentList:
"""Create a list of input message contents from prompt and optional image."""
content_list: list[InputContent] = [
InputTextContent(text=prompt, type="input_text"),
]
if image is not None:
for i in range(image.shape[0]):
content_list.append(
self.tensor_to_input_image_content(image[i].unsqueeze(0))
)
if files is not None:
content_list.extend(files)
return InputMessageContentList(
root=content_list,
)
def parse_response_id_from_prompt(self, prompt: str) -> Optional[str]:
"""Extract response ID from prompt if it exists."""
parsed_id = re.search(STARTING_POINT_ID_PATTERN, prompt)
return parsed_id.group(1) if parsed_id else None
def strip_response_tag_from_prompt(self, prompt: str) -> str:
"""Remove the response ID tag from the prompt."""
return re.sub(STARTING_POINT_ID_PATTERN, "", prompt.strip())
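A worked example of the starting-point tag round-trip (response ID hypothetical):
import re

pattern = r"<starting_point_id:(.*)>"
prompt = "Rewrite the intro. <starting_point_id:resp_123>"
re.search(pattern, prompt).group(1)   # 'resp_123'
re.sub(pattern, "", prompt).strip()   # 'Rewrite the intro.'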
def delete_history_after_response_id(
self, new_start_id: str, session_id: str
) -> None:
"""Delete history entries after a specific response ID."""
if session_id not in self.history:
return
new_history = []
i = 0
while (
i < len(self.history[session_id])
and self.history[session_id][i]["response_id"] != new_start_id
):
new_history.append(self.history[session_id][i])
i += 1
# Since it's the new starting point (not the response being edited), we include it as well
if i < len(self.history[session_id]):
new_history.append(self.history[session_id][i])
self.history[session_id] = new_history
def api_call(
self,
prompt: str,
persist_context: bool,
model: SupportedOpenAIModel,
unique_id: Optional[str] = None,
images: Optional[torch.Tensor] = None,
files: Optional[list[InputFileContent]] = None,
advanced_options: Optional[CreateModelResponseProperties] = None,
**kwargs,
) -> tuple[str]:
# Validate inputs
validate_string(prompt, strip_whitespace=False)
session_id = self.get_session_id(persist_context)
response_id_override = self.parse_response_id_from_prompt(prompt)
if response_id_override:
is_starting_from_beginning = response_id_override == "start"
if is_starting_from_beginning:
self.history[session_id] = []
previous_response_id = None
else:
previous_response_id = response_id_override
self.delete_history_after_response_id(response_id_override, session_id)
prompt = self.strip_response_tag_from_prompt(prompt)
elif persist_context:
previous_response_id = self.previous_response_id
else:
previous_response_id = None
# Create response
create_response = SynchronousOperation(
endpoint=ApiEndpoint(
path=RESPONSES_ENDPOINT,
method=HttpMethod.POST,
request_model=OpenAICreateResponse,
response_model=OpenAIResponse,
),
request=OpenAICreateResponse(
input=[
Item(
root=InputMessage(
content=self.create_input_message_contents(
prompt, images, files
),
role="user",
)
),
],
store=True,
stream=False,
model=model,
previous_response_id=previous_response_id,
**(
advanced_options.model_dump(exclude_none=True)
if advanced_options
else {}
),
),
auth_kwargs=kwargs,
).execute()
response_id = create_response.id
# Get result output
result_response = self.get_result_response(response_id, auth_kwargs=kwargs)
output_text = self.parse_output_text_from_response(result_response)
# Update history
self.add_to_history(session_id, prompt, output_text, response_id)
self.display_history_on_node(session_id, unique_id)
self.previous_response_id = response_id
return (output_text,)
class OpenAIInputFiles(ComfyNodeABC):
"""
Loads and formats input files for OpenAI API.
"""
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
"""
For details about the supported file input types, see:
https://platform.openai.com/docs/guides/pdf-files?api-mode=responses
"""
input_dir = folder_paths.get_input_directory()
input_files = [
f
for f in os.scandir(input_dir)
if f.is_file()
and (f.name.endswith(".txt") or f.name.endswith(".pdf"))
and f.stat().st_size < 32 * 1024 * 1024
]
input_files = sorted(input_files, key=lambda x: x.name)
input_files = [f.name for f in input_files]
return {
"required": {
"file": (
IO.COMBO,
{
"tooltip": "Input files to include as context for the model. Only accepts text (.txt) and PDF (.pdf) files for now.",
"options": input_files,
"default": input_files[0] if input_files else None,
},
),
},
"optional": {
"OPENAI_INPUT_FILES": (
"OPENAI_INPUT_FILES",
{
"tooltip": "An optional additional file(s) to batch together with the file loaded from this node. Allows chaining of input files so that a single message can include multiple input files.",
"default": None,
},
),
},
}
DESCRIPTION = "Loads and prepares input files (text, pdf, etc.) to include as inputs for the OpenAI Chat Node. The files will be read by the OpenAI model when generating a response. 🛈 TIP: Can be chained together with other OpenAI Input File nodes."
RETURN_TYPES = ("OPENAI_INPUT_FILES",)
FUNCTION = "prepare_files"
CATEGORY = "api node/text/OpenAI"
def create_input_file_content(self, file_path: str) -> InputFileContent:
return InputFileContent(
file_data=text_filepath_to_data_uri(file_path),
filename=os.path.basename(file_path),
type="input_file",
)
def prepare_files(
self, file: str, OPENAI_INPUT_FILES: list[InputFileContent] = []
) -> tuple[list[InputFileContent]]:
"""
Loads and formats input files for OpenAI API.
"""
file_path = folder_paths.get_annotated_filepath(file)
input_file_content = self.create_input_file_content(file_path)
files = [input_file_content] + OPENAI_INPUT_FILES
return (files,)
class OpenAIChatConfig(ComfyNodeABC):
"""Allows setting additional configuration for the OpenAI Chat Node."""
RETURN_TYPES = ("OPENAI_CHAT_CONFIG",)
FUNCTION = "configure"
DESCRIPTION = (
"Allows specifying advanced configuration options for the OpenAI Chat Nodes."
)
CATEGORY = "api node/text/OpenAI"
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {
"truncation": (
IO.COMBO,
{
"options": ["auto", "disabled"],
"default": "auto",
"tooltip": "The truncation strategy to use for the model response. auto: If the context of this response and previous ones exceeds the model's context window size, the model will truncate the response to fit the context window by dropping input items in the middle of the conversation.disabled: If a model response will exceed the context window size for a model, the request will fail with a 400 error",
},
),
},
"optional": {
"max_output_tokens": model_field_to_node_input(
IO.INT,
OpenAICreateResponse,
"max_output_tokens",
min=16,
default=4096,
max=16384,
tooltip="An upper bound for the number of tokens that can be generated for a response, including visible output tokens",
),
"instructions": model_field_to_node_input(
IO.STRING, OpenAICreateResponse, "instructions", multiline=True
),
},
}
def configure(
self,
truncation: str,
instructions: Optional[str] = None,
max_output_tokens: Optional[int] = None,
) -> tuple[CreateModelResponseProperties]:
"""
Configure advanced options for the OpenAI Chat Node.
Note:
While `top_p` and `temperature` are listed as properties in the
spec, they are not supported for all models (e.g., o4-mini).
They are not exposed as inputs at all to avoid having to manually
remove depending on model choice.
"""
return (
CreateModelResponseProperties(
instructions=instructions,
truncation=truncation,
max_output_tokens=max_output_tokens,
),
)
# A dictionary that contains all nodes you want to export with their names
# NOTE: names should be globally unique
NODE_CLASS_MAPPINGS = {
"OpenAIDalle2": OpenAIDalle2,
"OpenAIDalle3": OpenAIDalle3,
"OpenAIGPTImage1": OpenAIGPTImage1,
"OpenAIChatNode": OpenAIChatNode,
"OpenAIInputFiles": OpenAIInputFiles,
"OpenAIChatConfig": OpenAIChatConfig,
}
# A dictionary that contains the friendly/humanly readable titles for the nodes
NODE_DISPLAY_NAME_MAPPINGS = {
"OpenAIDalle2": "OpenAI DALL·E 2",
"OpenAIDalle3": "OpenAI DALL·E 3",
"OpenAIGPTImage1": "OpenAI GPT Image 1",
"OpenAIChatNode": "OpenAI Chat",
"OpenAIInputFiles": "OpenAI Chat Input Files",
"OpenAIChatConfig": "OpenAI Chat Advanced Options",
}

View File

@@ -6,42 +6,40 @@ Pika API docs: https://pika-827374fb.mintlify.app/api-reference
from __future__ import annotations
import io
import logging
from typing import Optional, TypeVar
import numpy as np
import logging
import torch
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeOptions
from comfy_api.input_impl import VideoFromFile
from comfy_api.input_impl.video_types import VideoCodec, VideoContainer, VideoInput
from comfy_api_nodes.apinode_utils import (
download_url_to_video_output,
tensor_to_bytesio,
)
import numpy as np
from comfy_api_nodes.apis import (
IngredientsMode,
PikaBodyGenerate22C2vGenerate22PikascenesPost,
PikaBodyGenerate22I2vGenerate22I2vPost,
PikaBodyGenerate22KeyframeGenerate22PikaframesPost,
PikaBodyGenerate22T2vGenerate22T2vPost,
PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
PikaDurationEnum,
Pikaffect,
PikaGenerateResponse,
PikaResolutionEnum,
PikaBodyGenerate22I2vGenerate22I2vPost,
PikaVideoResponse,
PikaBodyGenerate22C2vGenerate22PikascenesPost,
IngredientsMode,
PikaDurationEnum,
PikaResolutionEnum,
PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
PikaBodyGenerate22KeyframeGenerate22PikaframesPost,
Pikaffect,
)
from comfy_api_nodes.apis.client import (
ApiEndpoint,
EmptyRequest,
HttpMethod,
PollingOperation,
SynchronousOperation,
PollingOperation,
EmptyRequest,
)
from comfy_api_nodes.apinode_utils import (
tensor_to_bytesio,
download_url_to_video_output,
)
from comfy_api_nodes.mapper_utils import model_field_to_node_input
from comfy_api.input_impl.video_types import VideoInput, VideoContainer, VideoCodec
from comfy_api.input_impl import VideoFromFile
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeOptions
R = TypeVar("R")
@@ -206,7 +204,6 @@ class PikaImageToVideoV2_2(PikaNodeBase):
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -460,7 +457,7 @@ class PikAdditionsNode(PikaNodeBase):
},
}
DESCRIPTION = "Add any object or image into your video. Upload a video and specify what you'd like to add to create a seamlessly integrated result."
DESCRIPTION = "Add any object or image into your video. Upload a video and specify what youd like to add to create a seamlessly integrated result."
def api_call(
self,

View File

@@ -1,462 +0,0 @@
"""
ComfyUI X Rodin3D(Deemos) API Nodes
Rodin API docs: https://developer.hyper3d.ai/
"""
from __future__ import annotations
from inspect import cleandoc
from comfy.comfy_types.node_typing import IO
import folder_paths as comfy_paths
import requests
import os
import datetime
import shutil
import time
import io
import logging
import math
from PIL import Image
from comfy_api_nodes.apis.rodin_api import (
Rodin3DGenerateRequest,
Rodin3DGenerateResponse,
Rodin3DCheckStatusRequest,
Rodin3DCheckStatusResponse,
Rodin3DDownloadRequest,
Rodin3DDownloadResponse,
JobStatus,
)
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
)
COMMON_PARAMETERS = {
"Seed": (
IO.INT,
{
"default":0,
"min":0,
"max":65535,
"display":"number"
}
),
"Material_Type": (
IO.COMBO,
{
"options": ["PBR", "Shaded"],
"default": "PBR"
}
),
"Polygon_count": (
IO.COMBO,
{
"options": ["4K-Quad", "8K-Quad", "18K-Quad", "50K-Quad", "200K-Triangle"],
"default": "18K-Quad"
}
)
}
def create_task_error(response: Rodin3DGenerateResponse):
"""Check if the response has error"""
return hasattr(response, "error")
class Rodin3DAPI:
"""
Generate 3D Assets using Rodin API
"""
RETURN_TYPES = (IO.STRING,)
RETURN_NAMES = ("3D Model Path",)
CATEGORY = "api node/3d/Rodin"
DESCRIPTION = cleandoc(__doc__ or "")
FUNCTION = "api_call"
API_NODE = True
def tensor_to_filelike(self, tensor, max_pixels: int = 2048*2048):
"""
Converts a PyTorch tensor to a file-like object.
Args:
- tensor (torch.Tensor): A tensor representing an image of shape (H, W, C)
where C is the number of channels (3 for RGB), H is height, and W is width.
Returns:
- io.BytesIO: A file-like object containing the image data.
"""
array = tensor.cpu().numpy()
array = (array * 255).astype('uint8')
image = Image.fromarray(array, 'RGB')
original_width, original_height = image.size
original_pixels = original_width * original_height
if original_pixels > max_pixels:
scale = math.sqrt(max_pixels / original_pixels)
new_width = int(original_width * scale)
new_height = int(original_height * scale)
else:
new_width, new_height = original_width, original_height
if new_width != original_width or new_height != original_height:
image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
img_byte_arr = io.BytesIO()
image.save(img_byte_arr, format='PNG') # PNG is used for lossless compression
img_byte_arr.seek(0)
return img_byte_arr
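A worked example of the sqrt rescale above: a 4096x4096 input against the default 2048*2048 pixel budget halves each side, preserving aspect ratio while capping total pixels:
import math

orig_w, orig_h, max_pixels = 4096, 4096, 2048 * 2048
scale = math.sqrt(max_pixels / (orig_w * orig_h))   # 0.5
int(orig_w * scale), int(orig_h * scale)            # (2048, 2048)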
def check_rodin_status(self, response: Rodin3DCheckStatusResponse) -> str:
has_failed = any(job.status == JobStatus.Failed for job in response.jobs)
all_done = all(job.status == JobStatus.Done for job in response.jobs)
status_list = [str(job.status) for job in response.jobs]
logging.info(f"[ Rodin3D API - CheckStatus ] Generate Status: {status_list}")
if has_failed:
logging.error(f"[ Rodin3D API - CheckStatus ] Generate Failed: {status_list}, Please try again.")
raise Exception("[ Rodin3D API ] Generate Failed, Please Try again.")
elif all_done:
return "DONE"
else:
return "Generating"
def CreateGenerateTask(self, images=None, seed=1, material="PBR", quality="medium", tier="Regular", mesh_mode="Quad", **kwargs):
if images is None:
raise Exception("Rodin 3D generate requires at least 1 image.")
if len(images) > 5:
raise Exception("Rodin 3D generate accepts at most 5 images.")
path = "/proxy/rodin/api/v2/rodin"
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=path,
method=HttpMethod.POST,
request_model=Rodin3DGenerateRequest,
response_model=Rodin3DGenerateResponse,
),
request=Rodin3DGenerateRequest(
seed=seed,
tier=tier,
material=material,
quality=quality,
mesh_mode=mesh_mode
),
files=[
(
"images",
open(image, "rb") if isinstance(image, str) else self.tensor_to_filelike(image)
)
for image in images if image is not None
],
content_type = "multipart/form-data",
auth_kwargs=kwargs,
)
response = operation.execute()
if create_task_error(response):
error_message = f"Rodin3D Create 3D generate Task Failed. Message: {response.message}, error: {response.error}"
logging.error(error_message)
raise Exception(error_message)
logging.info("[ Rodin3D API - Submit Jobs ] Submit Generate Task Success!")
subscription_key = response.jobs.subscription_key
task_uuid = response.uuid
logging.info(f"[ Rodin3D API - Submit Jobs ] UUID: {task_uuid}")
return task_uuid, subscription_key
def poll_for_task_status(self, subscription_key, **kwargs) -> Rodin3DCheckStatusResponse:
path = "/proxy/rodin/api/v2/status"
poll_operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=path,
method=HttpMethod.POST,
request_model=Rodin3DCheckStatusRequest,
response_model=Rodin3DCheckStatusResponse,
),
request=Rodin3DCheckStatusRequest(
subscription_key=subscription_key
),
completed_statuses=["DONE"],
failed_statuses=["FAILED"],
status_extractor=self.check_rodin_status,
poll_interval=3.0,
auth_kwargs=kwargs,
)
logging.info("[ Rodin3D API - CheckStatus ] Generate Start!")
return poll_operation.execute()
def GetRodinDownloadList(self, uuid, **kwargs) -> Rodin3DDownloadResponse:
logging.info("[ Rodin3D API - Downloading ] Generate Successfully!")
path = "/proxy/rodin/api/v2/download"
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=path,
method=HttpMethod.POST,
request_model=Rodin3DDownloadRequest,
response_model=Rodin3DDownloadResponse,
),
request=Rodin3DDownloadRequest(
task_uuid=uuid
),
auth_kwargs=kwargs
)
return operation.execute()
def GetQualityAndMode(self, PolyCount):
if PolyCount == "200K-Triangle":
mesh_mode = "Raw"
quality = "medium"
else:
mesh_mode = "Quad"
if PolyCount == "4K-Quad":
quality = "extra-low"
elif PolyCount == "8K-Quad":
quality = "low"
elif PolyCount == "18K-Quad":
quality = "medium"
elif PolyCount == "50K-Quad":
quality = "high"
else:
quality = "medium"
return mesh_mode, quality
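# The chain above can equivalently be written as a lookup table; a
# behavior-preserving sketch (names are illustrative, including the
# ("Quad", "medium") fallback for unknown values):
#
#     POLY_COUNT_SETTINGS = {
#         "200K-Triangle": ("Raw", "medium"),
#         "4K-Quad": ("Quad", "extra-low"),
#         "8K-Quad": ("Quad", "low"),
#         "18K-Quad": ("Quad", "medium"),
#         "50K-Quad": ("Quad", "high"),
#     }
#
#     def get_quality_and_mode(poly_count):
#         return POLY_COUNT_SETTINGS.get(poly_count, ("Quad", "medium"))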
def DownLoadFiles(self, Url_List):
Save_path = os.path.join(comfy_paths.get_output_directory(), "Rodin3D", datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
os.makedirs(Save_path, exist_ok=True)
model_file_path = None
for Item in Url_List.list:
url = Item.url
file_name = Item.name
file_path = os.path.join(Save_path, file_name)
if file_path.endswith(".glb"):
model_file_path = file_path
logging.info(f"[ Rodin3D API - download_files ] Downloading file: {file_path}")
max_retries = 5
for attempt in range(max_retries):
try:
with requests.get(url, stream=True) as r:
r.raise_for_status()
with open(file_path, "wb") as f:
shutil.copyfileobj(r.raw, f)
break
except Exception as e:
logging.info(f"[ Rodin3D API - download_files ] Error downloading {file_path}:{e}")
if attempt < max_retries - 1:
logging.info("Retrying...")
time.sleep(2)
else:
logging.info(f"[ Rodin3D API - download_files ] Failed to download {file_path} after {max_retries} attempts.")
return model_file_path
class Rodin3D_Regular(Rodin3DAPI):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"Images":
(
IO.IMAGE,
{
"forceInput":True,
}
)
},
"optional": {
**COMMON_PARAMETERS
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
def api_call(
self,
Images,
Seed,
Material_Type,
Polygon_count,
**kwargs
):
tier = "Regular"
num_images = Images.shape[0]
m_images = []
for i in range(num_images):
m_images.append(Images[i])
mesh_mode, quality = self.GetQualityAndMode(Polygon_count)
task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=Material_Type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs)
self.poll_for_task_status(subscription_key, **kwargs)
Download_List = self.GetRodinDownloadList(task_uuid, **kwargs)
model = self.DownLoadFiles(Download_List)
return (model,)
class Rodin3D_Detail(Rodin3DAPI):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"Images":
(
IO.IMAGE,
{
"forceInput":True,
}
)
},
"optional": {
**COMMON_PARAMETERS
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
def api_call(
self,
Images,
Seed,
Material_Type,
Polygon_count,
**kwargs
):
tier = "Detail"
num_images = Images.shape[0]
m_images = []
for i in range(num_images):
m_images.append(Images[i])
mesh_mode, quality = self.GetQualityAndMode(Polygon_count)
task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=Material_Type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs)
self.poll_for_task_status(subscription_key, **kwargs)
Download_List = self.GetRodinDownloadList(task_uuid, **kwargs)
model = self.DownLoadFiles(Download_List)
return (model,)
class Rodin3D_Smooth(Rodin3DAPI):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"Images":
(
IO.IMAGE,
{
"forceInput":True,
}
)
},
"optional": {
**COMMON_PARAMETERS
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
def api_call(
self,
Images,
Seed,
Material_Type,
Polygon_count,
**kwargs
):
tier = "Smooth"
num_images = Images.shape[0]
m_images = []
for i in range(num_images):
m_images.append(Images[i])
mesh_mode, quality = self.GetQualityAndMode(Polygon_count)
task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=Material_Type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs)
self.poll_for_task_status(subscription_key, **kwargs)
Download_List = self.GetRodinDownloadList(task_uuid, **kwargs)
model = self.DownLoadFiles(Download_List)
return (model,)
class Rodin3D_Sketch(Rodin3DAPI):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"Images":
(
IO.IMAGE,
{
"forceInput":True,
}
)
},
"optional": {
"Seed":
(
IO.INT,
{
"default":0,
"min":0,
"max":65535,
"display":"number"
}
)
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
def api_call(
self,
Images,
Seed,
**kwargs
):
tier = "Sketch"
num_images = Images.shape[0]
m_images = []
for i in range(num_images):
m_images.append(Images[i])
material_type = "PBR"
quality = "medium"
mesh_mode = "Quad"
task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=material_type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs)
self.poll_for_task_status(subscription_key, **kwargs)
Download_List = self.GetRodinDownloadList(task_uuid, **kwargs)
model = self.DownLoadFiles(Download_List)
return (model,)
# A dictionary that contains all nodes you want to export with their names
# NOTE: names should be globally unique
NODE_CLASS_MAPPINGS = {
"Rodin3D_Regular": Rodin3D_Regular,
"Rodin3D_Detail": Rodin3D_Detail,
"Rodin3D_Smooth": Rodin3D_Smooth,
"Rodin3D_Sketch": Rodin3D_Sketch,
}
# A dictionary that contains the friendly/humanly readable titles for the nodes
NODE_DISPLAY_NAME_MAPPINGS = {
"Rodin3D_Regular": "Rodin 3D Generate - Regular Generate",
"Rodin3D_Detail": "Rodin 3D Generate - Detail Generate",
"Rodin3D_Smooth": "Rodin 3D Generate - Smooth Generate",
"Rodin3D_Sketch": "Rodin 3D Generate - Sketch Generate",
}
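# Hedged refactoring sketch, not the shipped code: Rodin3D_Regular, _Detail,
# and _Smooth differ only in the tier string, so the shared api_call could
# live on a common subclass. Subclasses would set TIER and keep their
# existing INPUT_TYPES definitions.
#
#     class Rodin3DTiered(Rodin3DAPI):
#         TIER = "Regular"
#
#         def api_call(self, Images, Seed, Material_Type, Polygon_count, **kwargs):
#             m_images = [Images[i] for i in range(Images.shape[0])]
#             mesh_mode, quality = self.GetQualityAndMode(Polygon_count)
#             task_uuid, subscription_key = self.CreateGenerateTask(
#                 images=m_images, seed=Seed, material=Material_Type,
#                 quality=quality, tier=self.TIER, mesh_mode=mesh_mode, **kwargs)
#             self.poll_for_task_status(subscription_key, **kwargs)
#             download_list = self.GetRodinDownloadList(task_uuid, **kwargs)
#             return (self.DownLoadFiles(download_list),)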

View File

@@ -1,635 +0,0 @@
"""Runway API Nodes
API Docs:
- https://docs.dev.runwayml.com/api/#tag/Task-management/paths/~1v1~1tasks~1%7Bid%7D/delete
User Guides:
- https://help.runwayml.com/hc/en-us/sections/30265301423635-Gen-3-Alpha
- https://help.runwayml.com/hc/en-us/articles/37327109429011-Creating-with-Gen-4-Video
- https://help.runwayml.com/hc/en-us/articles/33927968552339-Creating-with-Act-One-on-Gen-3-Alpha-and-Turbo
- https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3
"""
from typing import Union, Optional, Any
from enum import Enum
import torch
from comfy_api_nodes.apis import (
RunwayImageToVideoRequest,
RunwayImageToVideoResponse,
RunwayTaskStatusResponse as TaskStatusResponse,
RunwayTaskStatusEnum as TaskStatus,
RunwayModelEnum as Model,
RunwayDurationEnum as Duration,
RunwayAspectRatioEnum as AspectRatio,
RunwayPromptImageObject,
RunwayPromptImageDetailedObject,
RunwayTextToImageRequest,
RunwayTextToImageResponse,
Model4,
ReferenceImage,
RunwayTextToImageAspectRatioEnum,
)
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
)
from comfy_api_nodes.apinode_utils import (
upload_images_to_comfyapi,
download_url_to_video_output,
image_tensor_pair_to_batch,
validate_string,
download_url_to_image_tensor,
)
from comfy_api_nodes.mapper_utils import model_field_to_node_input
from comfy_api.input_impl import VideoFromFile
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video"
PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image"
PATH_GET_TASK_STATUS = "/proxy/runway/tasks"
AVERAGE_DURATION_I2V_SECONDS = 64
AVERAGE_DURATION_FLF_SECONDS = 256
AVERAGE_DURATION_T2I_SECONDS = 41
class RunwayApiError(Exception):
"""Base exception for Runway API errors."""
pass
class RunwayGen4TurboAspectRatio(str, Enum):
"""Aspect ratios supported for Image to Video API when using gen4_turbo model."""
field_1280_720 = "1280:720"
field_720_1280 = "720:1280"
field_1104_832 = "1104:832"
field_832_1104 = "832:1104"
field_960_960 = "960:960"
field_1584_672 = "1584:672"
class RunwayGen3aAspectRatio(str, Enum):
"""Aspect ratios supported for Image to Video API when using gen3a_turbo model."""
field_768_1280 = "768:1280"
field_1280_768 = "1280:768"
def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
"""Returns the video URL from the task status response if it exists."""
if response.output and len(response.output) > 0:
return response.output[0]
return None
# TODO: replace with updated image validation utils (upstream)
def validate_input_image(image: torch.Tensor) -> bool:
    """
    Validate that the input image is within the size limits for the Runway API.
    Raises RunwayApiError if either dimension is 8000px or larger, so the check
    takes effect even though callers do not inspect the boolean return.
    See: https://docs.dev.runwayml.com/assets/inputs/#common-error-reasons
    """
    if image.shape[2] >= 8000 or image.shape[1] >= 8000:
        raise RunwayApiError("Input image must be smaller than 8000px on each side.")
    return True
def poll_until_finished(
auth_kwargs: dict[str, str],
api_endpoint: ApiEndpoint[Any, TaskStatusResponse],
estimated_duration: Optional[int] = None,
node_id: Optional[str] = None,
) -> TaskStatusResponse:
"""Polls the Runway API endpoint until the task reaches a terminal state, then returns the response."""
return PollingOperation(
poll_endpoint=api_endpoint,
completed_statuses=[
TaskStatus.SUCCEEDED.value,
],
failed_statuses=[
TaskStatus.FAILED.value,
TaskStatus.CANCELLED.value,
],
status_extractor=lambda response: (response.status.value),
auth_kwargs=auth_kwargs,
result_url_extractor=get_video_url_from_task_status,
estimated_duration=estimated_duration,
node_id=node_id,
progress_extractor=extract_progress_from_task_status,
).execute()
def extract_progress_from_task_status(
response: TaskStatusResponse,
) -> Union[float, None]:
if hasattr(response, "progress") and response.progress is not None:
return response.progress * 100
return None
def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
"""Returns the image URL from the task status response if it exists."""
if response.output and len(response.output) > 0:
return response.output[0]
return None
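# Illustrative standalone use of poll_until_finished (the task id and auth
# payload are placeholders; real nodes receive these via hidden inputs):
#
#     final = poll_until_finished(
#         {"auth_token": "<token>"},
#         ApiEndpoint(
#             path=f"{PATH_GET_TASK_STATUS}/<task-id>",
#             method=HttpMethod.GET,
#             request_model=EmptyRequest,
#             response_model=TaskStatusResponse,
#         ),
#         estimated_duration=AVERAGE_DURATION_I2V_SECONDS,
#     )
#     video_url = get_video_url_from_task_status(final)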
class RunwayVideoGenNode(ComfyNodeABC):
"""Runway Video Node Base."""
RETURN_TYPES = ("VIDEO",)
FUNCTION = "api_call"
CATEGORY = "api node/video/Runway"
API_NODE = True
def validate_task_created(self, response: RunwayImageToVideoResponse) -> bool:
"""
Validate the task creation response from the Runway API matches
expected format.
"""
if not bool(response.id):
raise RunwayApiError("Invalid initial response from Runway API.")
return True
def validate_response(self, response: RunwayImageToVideoResponse) -> bool:
"""
Validate the successful task status response from the Runway API
matches expected format.
"""
if not response.output or len(response.output) == 0:
raise RunwayApiError(
"Runway task succeeded but no video data found in response."
)
return True
def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> RunwayImageToVideoResponse:
"""Poll the task status until it is finished then get the response."""
return poll_until_finished(
auth_kwargs,
ApiEndpoint(
path=f"{PATH_GET_TASK_STATUS}/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=TaskStatusResponse,
),
estimated_duration=AVERAGE_DURATION_I2V_SECONDS,
node_id=node_id,
)
def generate_video(
self,
request: RunwayImageToVideoRequest,
auth_kwargs: dict[str, str],
node_id: Optional[str] = None,
) -> tuple[VideoFromFile]:
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_IMAGE_TO_VIDEO,
method=HttpMethod.POST,
request_model=RunwayImageToVideoRequest,
response_model=RunwayImageToVideoResponse,
),
request=request,
auth_kwargs=auth_kwargs,
)
initial_response = initial_operation.execute()
self.validate_task_created(initial_response)
task_id = initial_response.id
final_response = self.get_response(task_id, auth_kwargs, node_id)
self.validate_response(final_response)
video_url = get_video_url_from_task_status(final_response)
return (download_url_to_video_output(video_url),)
class RunwayImageToVideoNodeGen3a(RunwayVideoGenNode):
"""Runway Image to Video Node using Gen3a Turbo model."""
DESCRIPTION = "Generate a video from a single starting frame using Gen3a Turbo model. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/33927968552339-Creating-with-Act-One-on-Gen-3-Alpha-and-Turbo."
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt": model_field_to_node_input(
IO.STRING, RunwayImageToVideoRequest, "promptText", multiline=True
),
"start_frame": (
IO.IMAGE,
{"tooltip": "Start frame to be used for the video"},
),
"duration": model_field_to_node_input(
IO.COMBO, RunwayImageToVideoRequest, "duration", enum_type=Duration
),
"ratio": model_field_to_node_input(
IO.COMBO,
RunwayImageToVideoRequest,
"ratio",
enum_type=RunwayGen3aAspectRatio,
),
"seed": model_field_to_node_input(
IO.INT,
RunwayImageToVideoRequest,
"seed",
control_after_generate=True,
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
def api_call(
self,
prompt: str,
start_frame: torch.Tensor,
duration: str,
ratio: str,
seed: int,
unique_id: Optional[str] = None,
**kwargs,
) -> tuple[VideoFromFile]:
# Validate inputs
validate_string(prompt, min_length=1)
validate_input_image(start_frame)
# Upload image
download_urls = upload_images_to_comfyapi(
start_frame,
max_images=1,
mime_type="image/png",
auth_kwargs=kwargs,
)
if len(download_urls) != 1:
raise RunwayApiError("Failed to upload one or more images to comfy api.")
return self.generate_video(
RunwayImageToVideoRequest(
promptText=prompt,
seed=seed,
model=Model("gen3a_turbo"),
duration=Duration(duration),
ratio=AspectRatio(ratio),
promptImage=RunwayPromptImageObject(
root=[
RunwayPromptImageDetailedObject(
uri=str(download_urls[0]), position="first"
)
]
),
),
auth_kwargs=kwargs,
node_id=unique_id,
)
class RunwayImageToVideoNodeGen4(RunwayVideoGenNode):
"""Runway Image to Video Node using Gen4 Turbo model."""
DESCRIPTION = "Generate a video from a single starting frame using Gen4 Turbo model. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/37327109429011-Creating-with-Gen-4-Video."
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt": model_field_to_node_input(
IO.STRING, RunwayImageToVideoRequest, "promptText", multiline=True
),
"start_frame": (
IO.IMAGE,
{"tooltip": "Start frame to be used for the video"},
),
"duration": model_field_to_node_input(
IO.COMBO, RunwayImageToVideoRequest, "duration", enum_type=Duration
),
"ratio": model_field_to_node_input(
IO.COMBO,
RunwayImageToVideoRequest,
"ratio",
enum_type=RunwayGen4TurboAspectRatio,
),
"seed": model_field_to_node_input(
IO.INT,
RunwayImageToVideoRequest,
"seed",
control_after_generate=True,
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
def api_call(
self,
prompt: str,
start_frame: torch.Tensor,
duration: str,
ratio: str,
seed: int,
unique_id: Optional[str] = None,
**kwargs,
) -> tuple[VideoFromFile]:
# Validate inputs
validate_string(prompt, min_length=1)
validate_input_image(start_frame)
# Upload image
download_urls = upload_images_to_comfyapi(
start_frame,
max_images=1,
mime_type="image/png",
auth_kwargs=kwargs,
)
if len(download_urls) != 1:
raise RunwayApiError("Failed to upload one or more images to comfy api.")
return self.generate_video(
RunwayImageToVideoRequest(
promptText=prompt,
seed=seed,
model=Model("gen4_turbo"),
duration=Duration(duration),
ratio=AspectRatio(ratio),
promptImage=RunwayPromptImageObject(
root=[
RunwayPromptImageDetailedObject(
uri=str(download_urls[0]), position="first"
)
]
),
),
auth_kwargs=kwargs,
node_id=unique_id,
)
class RunwayFirstLastFrameNode(RunwayVideoGenNode):
"""Runway First-Last Frame Node."""
DESCRIPTION = "Upload first and last keyframes, draft a prompt, and generate a video. More complex transitions, such as cases where the Last frame is completely different from the First frame, may benefit from the longer 10s duration. This would give the generation more time to smoothly transition between the two inputs. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3."
def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> RunwayImageToVideoResponse:
return poll_until_finished(
auth_kwargs,
ApiEndpoint(
path=f"{PATH_GET_TASK_STATUS}/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=TaskStatusResponse,
),
estimated_duration=AVERAGE_DURATION_FLF_SECONDS,
node_id=node_id,
)
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt": model_field_to_node_input(
IO.STRING, RunwayImageToVideoRequest, "promptText", multiline=True
),
"start_frame": (
IO.IMAGE,
{"tooltip": "Start frame to be used for the video"},
),
"end_frame": (
IO.IMAGE,
{
"tooltip": "End frame to be used for the video. Supported for gen3a_turbo only."
},
),
"duration": model_field_to_node_input(
IO.COMBO, RunwayImageToVideoRequest, "duration", enum_type=Duration
),
"ratio": model_field_to_node_input(
IO.COMBO,
RunwayImageToVideoRequest,
"ratio",
enum_type=RunwayGen3aAspectRatio,
),
"seed": model_field_to_node_input(
IO.INT,
RunwayImageToVideoRequest,
"seed",
control_after_generate=True,
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"unique_id": "UNIQUE_ID",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
def api_call(
self,
prompt: str,
start_frame: torch.Tensor,
end_frame: torch.Tensor,
duration: str,
ratio: str,
seed: int,
unique_id: Optional[str] = None,
**kwargs,
) -> tuple[VideoFromFile]:
# Validate inputs
validate_string(prompt, min_length=1)
validate_input_image(start_frame)
validate_input_image(end_frame)
# Upload images
stacked_input_images = image_tensor_pair_to_batch(start_frame, end_frame)
download_urls = upload_images_to_comfyapi(
stacked_input_images,
max_images=2,
mime_type="image/png",
auth_kwargs=kwargs,
)
if len(download_urls) != 2:
raise RunwayApiError("Failed to upload one or more images to comfy api.")
return self.generate_video(
RunwayImageToVideoRequest(
promptText=prompt,
seed=seed,
model=Model("gen3a_turbo"),
duration=Duration(duration),
ratio=AspectRatio(ratio),
promptImage=RunwayPromptImageObject(
root=[
RunwayPromptImageDetailedObject(
uri=str(download_urls[0]), position="first"
),
RunwayPromptImageDetailedObject(
uri=str(download_urls[1]), position="last"
),
]
),
),
auth_kwargs=kwargs,
node_id=unique_id,
)
class RunwayTextToImageNode(ComfyNodeABC):
"""Runway Text to Image Node."""
RETURN_TYPES = ("IMAGE",)
FUNCTION = "api_call"
CATEGORY = "api node/image/Runway"
API_NODE = True
DESCRIPTION = "Generate an image from a text prompt using Runway's Gen 4 model. You can also include reference images to guide the generation."
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt": model_field_to_node_input(
IO.STRING, RunwayTextToImageRequest, "promptText", multiline=True
),
"ratio": model_field_to_node_input(
IO.COMBO,
RunwayTextToImageRequest,
"ratio",
enum_type=RunwayTextToImageAspectRatioEnum,
),
},
"optional": {
"reference_image": (
IO.IMAGE,
{"tooltip": "Optional reference image to guide the generation"},
)
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
def validate_task_created(self, response: RunwayTextToImageResponse) -> bool:
"""
Validate the task creation response from the Runway API matches
expected format.
"""
if not bool(response.id):
raise RunwayApiError("Invalid initial response from Runway API.")
return True
def validate_response(self, response: TaskStatusResponse) -> bool:
"""
Validate the successful task status response from the Runway API
matches expected format.
"""
if not response.output or len(response.output) == 0:
raise RunwayApiError(
"Runway task succeeded but no image data found in response."
)
return True
def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> TaskStatusResponse:
"""Poll the task status until it is finished then get the response."""
return poll_until_finished(
auth_kwargs,
ApiEndpoint(
path=f"{PATH_GET_TASK_STATUS}/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=TaskStatusResponse,
),
estimated_duration=AVERAGE_DURATION_T2I_SECONDS,
node_id=node_id,
)
def api_call(
self,
prompt: str,
ratio: str,
reference_image: Optional[torch.Tensor] = None,
unique_id: Optional[str] = None,
**kwargs,
) -> tuple[torch.Tensor]:
# Validate inputs
validate_string(prompt, min_length=1)
# Prepare reference images if provided
reference_images = None
if reference_image is not None:
validate_input_image(reference_image)
download_urls = upload_images_to_comfyapi(
reference_image,
max_images=1,
mime_type="image/png",
auth_kwargs=kwargs,
)
if len(download_urls) != 1:
raise RunwayApiError("Failed to upload reference image to comfy api.")
reference_images = [ReferenceImage(uri=str(download_urls[0]))]
# Create request
request = RunwayTextToImageRequest(
promptText=prompt,
model=Model4.gen4_image,
ratio=ratio,
referenceImages=reference_images,
)
# Execute initial request
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_TEXT_TO_IMAGE,
method=HttpMethod.POST,
request_model=RunwayTextToImageRequest,
response_model=RunwayTextToImageResponse,
),
request=request,
auth_kwargs=kwargs,
)
initial_response = initial_operation.execute()
self.validate_task_created(initial_response)
task_id = initial_response.id
# Poll for completion
final_response = self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
)
self.validate_response(final_response)
# Download and return image
image_url = get_image_url_from_task_status(final_response)
return (download_url_to_image_tensor(image_url),)
NODE_CLASS_MAPPINGS = {
"RunwayFirstLastFrameNode": RunwayFirstLastFrameNode,
"RunwayImageToVideoNodeGen3a": RunwayImageToVideoNodeGen3a,
"RunwayImageToVideoNodeGen4": RunwayImageToVideoNodeGen4,
"RunwayTextToImageNode": RunwayTextToImageNode,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"RunwayFirstLastFrameNode": "Runway First-Last-Frame to Video",
"RunwayImageToVideoNodeGen3a": "Runway Image to Video (Gen3a Turbo)",
"RunwayImageToVideoNodeGen4": "Runway Image to Video (Gen4 Turbo)",
"RunwayTextToImageNode": "Runway Text to Image",
}

View File

@@ -1,574 +0,0 @@
import os
from folder_paths import get_output_directory
from comfy_api_nodes.mapper_utils import model_field_to_node_input
from comfy.comfy_types.node_typing import IO
from comfy_api_nodes.apis import (
TripoOrientation,
TripoModelVersion,
)
from comfy_api_nodes.apis.tripo_api import (
TripoTaskType,
TripoStyle,
TripoFileReference,
TripoFileEmptyReference,
TripoUrlReference,
TripoTaskResponse,
TripoTaskStatus,
TripoTextToModelRequest,
TripoImageToModelRequest,
TripoMultiviewToModelRequest,
TripoTextureModelRequest,
TripoRefineModelRequest,
TripoAnimateRigRequest,
TripoAnimateRetargetRequest,
TripoConvertModelRequest,
)
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
)
from comfy_api_nodes.apinode_utils import (
upload_images_to_comfyapi,
download_url_to_bytesio,
)
def upload_image_to_tripo(image, **kwargs):
urls = upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs)
return TripoFileReference(TripoUrlReference(url=urls[0], type="jpeg"))
def get_model_url_from_response(response: TripoTaskResponse) -> str:
if response.data is not None:
for key in ["pbr_model", "model", "base_model"]:
if getattr(response.data.output, key, None) is not None:
return getattr(response.data.output, key)
raise RuntimeError(f"Failed to get model url from response: {response}")
def poll_until_finished(
kwargs: dict[str, str],
response: TripoTaskResponse,
) -> tuple[str, str]:
"""Polls the Tripo API endpoint until the task reaches a terminal state, then returns the response."""
if response.code != 0:
raise RuntimeError(f"Failed to generate mesh: {response.error}")
task_id = response.data.task_id
response_poll = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/tripo/v2/openapi/task/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=TripoTaskResponse,
),
completed_statuses=[TripoTaskStatus.SUCCESS],
failed_statuses=[
TripoTaskStatus.FAILED,
TripoTaskStatus.CANCELLED,
TripoTaskStatus.UNKNOWN,
TripoTaskStatus.BANNED,
TripoTaskStatus.EXPIRED,
],
status_extractor=lambda x: x.data.status,
auth_kwargs=kwargs,
node_id=kwargs["unique_id"],
result_url_extractor=get_model_url_from_response,
progress_extractor=lambda x: x.data.progress,
).execute()
if response_poll.data.status == TripoTaskStatus.SUCCESS:
url = get_model_url_from_response(response_poll)
bytesio = download_url_to_bytesio(url)
# Save the downloaded model file
model_file = f"tripo_model_{task_id}.glb"
with open(os.path.join(get_output_directory(), model_file), "wb") as f:
f.write(bytesio.getvalue())
return model_file, task_id
raise RuntimeError(f"Failed to generate mesh: {response_poll}")
class TripoTextToModelNode:
"""
Generates 3D models synchronously based on a text prompt using Tripo's API.
"""
AVERAGE_DURATION = 80
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt": ("STRING", {"multiline": True}),
},
"optional": {
"negative_prompt": ("STRING", {"multiline": True}),
"model_version": model_field_to_node_input(IO.COMBO, TripoTextToModelRequest, "model_version", enum_type=TripoModelVersion),
"style": model_field_to_node_input(IO.COMBO, TripoTextToModelRequest, "style", enum_type=TripoStyle, default="None"),
"texture": ("BOOLEAN", {"default": True}),
"pbr": ("BOOLEAN", {"default": True}),
"image_seed": ("INT", {"default": 42}),
"model_seed": ("INT", {"default": 42}),
"texture_seed": ("INT", {"default": 42}),
"texture_quality": (["standard", "detailed"], {"default": "standard"}),
"face_limit": ("INT", {"min": -1, "max": 500000, "default": -1}),
"quad": ("BOOLEAN", {"default": False})
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = ("STRING", "MODEL_TASK_ID",)
RETURN_NAMES = ("model_file", "model task_id")
FUNCTION = "generate_mesh"
CATEGORY = "api node/3d/Tripo"
API_NODE = True
OUTPUT_NODE = True
def generate_mesh(self, prompt, negative_prompt=None, model_version=None, style=None, texture=None, pbr=None, image_seed=None, model_seed=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs):
style_enum = None if style == "None" else style
if not prompt:
raise RuntimeError("Prompt is required")
response = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/tripo/v2/openapi/task",
method=HttpMethod.POST,
request_model=TripoTextToModelRequest,
response_model=TripoTaskResponse,
),
request=TripoTextToModelRequest(
type=TripoTaskType.TEXT_TO_MODEL,
prompt=prompt,
negative_prompt=negative_prompt if negative_prompt else None,
model_version=model_version,
style=style_enum,
texture=texture,
pbr=pbr,
image_seed=image_seed,
model_seed=model_seed,
texture_seed=texture_seed,
texture_quality=texture_quality,
face_limit=face_limit,
auto_size=True,
quad=quad
),
auth_kwargs=kwargs,
).execute()
return poll_until_finished(kwargs, response)
class TripoImageToModelNode:
"""
Generates 3D models synchronously based on a single image using Tripo's API.
"""
AVERAGE_DURATION = 80
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE",),
},
"optional": {
"model_version": model_field_to_node_input(IO.COMBO, TripoImageToModelRequest, "model_version", enum_type=TripoModelVersion),
"style": model_field_to_node_input(IO.COMBO, TripoTextToModelRequest, "style", enum_type=TripoStyle, default="None"),
"texture": ("BOOLEAN", {"default": True}),
"pbr": ("BOOLEAN", {"default": True}),
"model_seed": ("INT", {"default": 42}),
"orientation": model_field_to_node_input(IO.COMBO, TripoImageToModelRequest, "orientation", enum_type=TripoOrientation),
"texture_seed": ("INT", {"default": 42}),
"texture_quality": (["standard", "detailed"], {"default": "standard"}),
"texture_alignment": (["original_image", "geometry"], {"default": "original_image"}),
"face_limit": ("INT", {"min": -1, "max": 500000, "default": -1}),
"quad": ("BOOLEAN", {"default": False})
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = ("STRING", "MODEL_TASK_ID",)
RETURN_NAMES = ("model_file", "model task_id")
FUNCTION = "generate_mesh"
CATEGORY = "api node/3d/Tripo"
API_NODE = True
OUTPUT_NODE = True
def generate_mesh(self, image, model_version=None, style=None, texture=None, pbr=None, model_seed=None, orientation=None, texture_alignment=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs):
style_enum = None if style == "None" else style
if image is None:
raise RuntimeError("Image is required")
tripo_file = upload_image_to_tripo(image, **kwargs)
response = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/tripo/v2/openapi/task",
method=HttpMethod.POST,
request_model=TripoImageToModelRequest,
response_model=TripoTaskResponse,
),
request=TripoImageToModelRequest(
type=TripoTaskType.IMAGE_TO_MODEL,
file=tripo_file,
model_version=model_version,
style=style_enum,
texture=texture,
pbr=pbr,
model_seed=model_seed,
orientation=orientation,
texture_alignment=texture_alignment,
texture_seed=texture_seed,
texture_quality=texture_quality,
face_limit=face_limit,
auto_size=True,
quad=quad
),
auth_kwargs=kwargs,
).execute()
return poll_until_finished(kwargs, response)
class TripoMultiviewToModelNode:
"""
Generates 3D models synchronously based on up to four images (front, left, back, right) using Tripo's API.
"""
AVERAGE_DURATION = 80
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE",),
},
"optional": {
"image_left": ("IMAGE",),
"image_back": ("IMAGE",),
"image_right": ("IMAGE",),
"model_version": model_field_to_node_input(IO.COMBO, TripoMultiviewToModelRequest, "model_version", enum_type=TripoModelVersion),
"orientation": model_field_to_node_input(IO.COMBO, TripoImageToModelRequest, "orientation", enum_type=TripoOrientation),
"texture": ("BOOLEAN", {"default": True}),
"pbr": ("BOOLEAN", {"default": True}),
"model_seed": ("INT", {"default": 42}),
"texture_seed": ("INT", {"default": 42}),
"texture_quality": (["standard", "detailed"], {"default": "standard"}),
"texture_alignment": (["original_image", "geometry"], {"default": "original_image"}),
"face_limit": ("INT", {"min": -1, "max": 500000, "default": -1}),
"quad": ("BOOLEAN", {"default": False})
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = ("STRING", "MODEL_TASK_ID",)
RETURN_NAMES = ("model_file", "model task_id")
FUNCTION = "generate_mesh"
CATEGORY = "api node/3d/Tripo"
API_NODE = True
OUTPUT_NODE = True
def generate_mesh(self, image, image_left=None, image_back=None, image_right=None, model_version=None, orientation=None, texture=None, pbr=None, model_seed=None, texture_seed=None, texture_quality=None, texture_alignment=None, face_limit=None, quad=None, **kwargs):
if image is None:
raise RuntimeError("front image for multiview is required")
images = []
image_dict = {
"image": image,
"image_left": image_left,
"image_back": image_back,
"image_right": image_right
}
if image_left is None and image_back is None and image_right is None:
raise RuntimeError("At least one of left, back, or right image must be provided for multiview")
for image_name in ["image", "image_left", "image_back", "image_right"]:
image_ = image_dict[image_name]
if image_ is not None:
tripo_file = upload_image_to_tripo(image_, **kwargs)
images.append(tripo_file)
else:
images.append(TripoFileEmptyReference())
response = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/tripo/v2/openapi/task",
method=HttpMethod.POST,
request_model=TripoMultiviewToModelRequest,
response_model=TripoTaskResponse,
),
request=TripoMultiviewToModelRequest(
type=TripoTaskType.MULTIVIEW_TO_MODEL,
files=images,
model_version=model_version,
orientation=orientation,
texture=texture,
pbr=pbr,
model_seed=model_seed,
texture_seed=texture_seed,
texture_quality=texture_quality,
texture_alignment=texture_alignment,
face_limit=face_limit,
quad=quad,
),
auth_kwargs=kwargs,
).execute()
return poll_until_finished(kwargs, response)
class TripoTextureNode:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model_task_id": ("MODEL_TASK_ID",),
},
"optional": {
"texture": ("BOOLEAN", {"default": True}),
"pbr": ("BOOLEAN", {"default": True}),
"texture_seed": ("INT", {"default": 42}),
"texture_quality": (["standard", "detailed"], {"default": "standard"}),
"texture_alignment": (["original_image", "geometry"], {"default": "original_image"}),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = ("STRING", "MODEL_TASK_ID",)
RETURN_NAMES = ("model_file", "model task_id")
FUNCTION = "generate_mesh"
CATEGORY = "api node/3d/Tripo"
API_NODE = True
OUTPUT_NODE = True
AVERAGE_DURATION = 80
def generate_mesh(self, model_task_id, texture=None, pbr=None, texture_seed=None, texture_quality=None, texture_alignment=None, **kwargs):
response = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/tripo/v2/openapi/task",
method=HttpMethod.POST,
request_model=TripoTextureModelRequest,
response_model=TripoTaskResponse,
),
request=TripoTextureModelRequest(
original_model_task_id=model_task_id,
texture=texture,
pbr=pbr,
texture_seed=texture_seed,
texture_quality=texture_quality,
texture_alignment=texture_alignment
),
auth_kwargs=kwargs,
).execute()
return poll_until_finished(kwargs, response)
class TripoRefineNode:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model_task_id": ("MODEL_TASK_ID", {
"tooltip": "Must be a v1.4 Tripo model"
}),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
DESCRIPTION = "Refine a draft model created by v1.4 Tripo models only."
RETURN_TYPES = ("STRING", "MODEL_TASK_ID",)
RETURN_NAMES = ("model_file", "model task_id")
FUNCTION = "generate_mesh"
CATEGORY = "api node/3d/Tripo"
API_NODE = True
OUTPUT_NODE = True
AVERAGE_DURATION = 240
def generate_mesh(self, model_task_id, **kwargs):
response = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/tripo/v2/openapi/task",
method=HttpMethod.POST,
request_model=TripoRefineModelRequest,
response_model=TripoTaskResponse,
),
request=TripoRefineModelRequest(
draft_model_task_id=model_task_id
),
auth_kwargs=kwargs,
).execute()
return poll_until_finished(kwargs, response)
class TripoRigNode:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"original_model_task_id": ("MODEL_TASK_ID",),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = ("STRING", "RIG_TASK_ID")
RETURN_NAMES = ("model_file", "rig task_id")
FUNCTION = "generate_mesh"
CATEGORY = "api node/3d/Tripo"
API_NODE = True
OUTPUT_NODE = True
AVERAGE_DURATION = 180
def generate_mesh(self, original_model_task_id, **kwargs):
response = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/tripo/v2/openapi/task",
method=HttpMethod.POST,
request_model=TripoAnimateRigRequest,
response_model=TripoTaskResponse,
),
request=TripoAnimateRigRequest(
original_model_task_id=original_model_task_id,
out_format="glb",
spec="tripo"
),
auth_kwargs=kwargs,
).execute()
return poll_until_finished(kwargs, response)
class TripoRetargetNode:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"original_model_task_id": ("RIG_TASK_ID",),
"animation": ([
"preset:idle",
"preset:walk",
"preset:climb",
"preset:jump",
"preset:slash",
"preset:shoot",
"preset:hurt",
"preset:fall",
"preset:turn",
],),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = ("STRING", "RETARGET_TASK_ID")
RETURN_NAMES = ("model_file", "retarget task_id")
FUNCTION = "generate_mesh"
CATEGORY = "api node/3d/Tripo"
API_NODE = True
OUTPUT_NODE = True
AVERAGE_DURATION = 30
def generate_mesh(self, animation, original_model_task_id, **kwargs):
response = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/tripo/v2/openapi/task",
method=HttpMethod.POST,
request_model=TripoAnimateRetargetRequest,
response_model=TripoTaskResponse,
),
request=TripoAnimateRetargetRequest(
original_model_task_id=original_model_task_id,
animation=animation,
out_format="glb",
bake_animation=True
),
auth_kwargs=kwargs,
).execute()
return poll_until_finished(kwargs, response)
class TripoConversionNode:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"original_model_task_id": ("MODEL_TASK_ID,RIG_TASK_ID,RETARGET_TASK_ID",),
"format": (["GLTF", "USDZ", "FBX", "OBJ", "STL", "3MF"],),
},
"optional": {
"quad": ("BOOLEAN", {"default": False}),
"face_limit": ("INT", {"min": -1, "max": 500000, "default": -1}),
"texture_size": ("INT", {"min": 128, "max": 4096, "default": 4096}),
"texture_format": (["BMP", "DPX", "HDR", "JPEG", "OPEN_EXR", "PNG", "TARGA", "TIFF", "WEBP"], {"default": "JPEG"})
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@classmethod
def VALIDATE_INPUTS(cls, input_types):
# original_model_task_id accepts three socket types ("MODEL_TASK_ID,RIG_TASK_ID,RETARGET_TASK_ID"),
# which bypasses the default type check, so validate the connected type here instead.
if input_types["original_model_task_id"] not in ("MODEL_TASK_ID", "RIG_TASK_ID", "RETARGET_TASK_ID"):
return "original_model_task_id must be MODEL_TASK_ID, RIG_TASK_ID or RETARGET_TASK_ID type"
return True
RETURN_TYPES = ()
FUNCTION = "generate_mesh"
CATEGORY = "api node/3d/Tripo"
API_NODE = True
OUTPUT_NODE = True
AVERAGE_DURATION = 30
def generate_mesh(self, original_model_task_id, format, quad, face_limit, texture_size, texture_format, **kwargs):
if not original_model_task_id:
raise RuntimeError("original_model_task_id is required")
response = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/tripo/v2/openapi/task",
method=HttpMethod.POST,
request_model=TripoConvertModelRequest,
response_model=TripoTaskResponse,
),
request=TripoConvertModelRequest(
original_model_task_id=original_model_task_id,
format=format,
quad=quad if quad else None,
face_limit=face_limit if face_limit != -1 else None,
texture_size=texture_size if texture_size != 4096 else None,
texture_format=texture_format if texture_format != "JPEG" else None
),
auth_kwargs=kwargs,
).execute()
return poll_until_finished(kwargs, response)
NODE_CLASS_MAPPINGS = {
"TripoTextToModelNode": TripoTextToModelNode,
"TripoImageToModelNode": TripoImageToModelNode,
"TripoMultiviewToModelNode": TripoMultiviewToModelNode,
"TripoTextureNode": TripoTextureNode,
"TripoRefineNode": TripoRefineNode,
"TripoRigNode": TripoRigNode,
"TripoRetargetNode": TripoRetargetNode,
"TripoConversionNode": TripoConversionNode,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"TripoTextToModelNode": "Tripo: Text to Model",
"TripoImageToModelNode": "Tripo: Image to Model",
"TripoMultiviewToModelNode": "Tripo: Multiview to Model",
"TripoTextureNode": "Tripo: Texture model",
"TripoRefineNode": "Tripo: Refine Draft model",
"TripoRigNode": "Tripo: Rig model",
"TripoRetargetNode": "Tripo: Retarget rigged model",
"TripoConversionNode": "Tripo: Convert model",
}

View File

@@ -14,7 +14,6 @@ import re
from io import BytesIO
from inspect import cleandoc
import torch
import comfy.utils
from comfy.comfy_types import FileLocator
@@ -230,186 +229,6 @@ class SVG:
all_svgs_list.extend(svg_item.data)
return SVG(all_svgs_list)
class ImageStitch:
"""Upstreamed from https://github.com/kijai/ComfyUI-KJNodes"""
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image1": ("IMAGE",),
"direction": (["right", "down", "left", "up"], {"default": "right"}),
"match_image_size": ("BOOLEAN", {"default": True}),
"spacing_width": (
"INT",
{"default": 0, "min": 0, "max": 1024, "step": 2},
),
"spacing_color": (
["white", "black", "red", "green", "blue"],
{"default": "white"},
),
},
"optional": {
"image2": ("IMAGE",),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "stitch"
CATEGORY = "image/transform"
DESCRIPTION = """
Stitches image2 to image1 in the specified direction.
If image2 is not provided, returns image1 unchanged.
Optional spacing can be added between images.
"""
def stitch(
self,
image1,
direction,
match_image_size,
spacing_width,
spacing_color,
image2=None,
):
if image2 is None:
return (image1,)
# Handle batch size differences
if image1.shape[0] != image2.shape[0]:
max_batch = max(image1.shape[0], image2.shape[0])
if image1.shape[0] < max_batch:
image1 = torch.cat(
[image1, image1[-1:].repeat(max_batch - image1.shape[0], 1, 1, 1)]
)
if image2.shape[0] < max_batch:
image2 = torch.cat(
[image2, image2[-1:].repeat(max_batch - image2.shape[0], 1, 1, 1)]
)
# Match image sizes if requested
if match_image_size:
h1, w1 = image1.shape[1:3]
h2, w2 = image2.shape[1:3]
aspect_ratio = w2 / h2
if direction in ["left", "right"]:
target_h, target_w = h1, int(h1 * aspect_ratio)
else: # up, down
target_w, target_h = w1, int(w1 / aspect_ratio)
image2 = comfy.utils.common_upscale(
image2.movedim(-1, 1), target_w, target_h, "lanczos", "disabled"
).movedim(1, -1)
# When not matching sizes, pad to align non-concat dimensions
if not match_image_size:
h1, w1 = image1.shape[1:3]
h2, w2 = image2.shape[1:3]
if direction in ["left", "right"]:
# For horizontal concat, pad heights to match
if h1 != h2:
target_h = max(h1, h2)
if h1 < target_h:
pad_h = target_h - h1
pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2
image1 = torch.nn.functional.pad(image1, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=0.0)
if h2 < target_h:
pad_h = target_h - h2
pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2
image2 = torch.nn.functional.pad(image2, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=0.0)
else: # up, down
# For vertical concat, pad widths to match
if w1 != w2:
target_w = max(w1, w2)
if w1 < target_w:
pad_w = target_w - w1
pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2
image1 = torch.nn.functional.pad(image1, (0, 0, pad_left, pad_right), mode='constant', value=0.0)
if w2 < target_w:
pad_w = target_w - w2
pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2
image2 = torch.nn.functional.pad(image2, (0, 0, pad_left, pad_right), mode='constant', value=0.0)
# Ensure same number of channels
if image1.shape[-1] != image2.shape[-1]:
max_channels = max(image1.shape[-1], image2.shape[-1])
if image1.shape[-1] < max_channels:
image1 = torch.cat(
[
image1,
torch.ones(
*image1.shape[:-1],
max_channels - image1.shape[-1],
device=image1.device,
),
],
dim=-1,
)
if image2.shape[-1] < max_channels:
image2 = torch.cat(
[
image2,
torch.ones(
*image2.shape[:-1],
max_channels - image2.shape[-1],
device=image2.device,
),
],
dim=-1,
)
# Add spacing if specified
if spacing_width > 0:
spacing_width = spacing_width + (spacing_width % 2) # Ensure even
color_map = {
"white": 1.0,
"black": 0.0,
"red": (1.0, 0.0, 0.0),
"green": (0.0, 1.0, 0.0),
"blue": (0.0, 0.0, 1.0),
}
color_val = color_map[spacing_color]
if direction in ["left", "right"]:
spacing_shape = (
image1.shape[0],
max(image1.shape[1], image2.shape[1]),
spacing_width,
image1.shape[-1],
)
else:
spacing_shape = (
image1.shape[0],
spacing_width,
max(image1.shape[2], image2.shape[2]),
image1.shape[-1],
)
spacing = torch.full(spacing_shape, 0.0, device=image1.device)
if isinstance(color_val, tuple):
for i, c in enumerate(color_val):
if i < spacing.shape[-1]:
spacing[..., i] = c
if spacing.shape[-1] == 4: # Add alpha
spacing[..., 3] = 1.0
else:
spacing[..., : min(3, spacing.shape[-1])] = color_val
if spacing.shape[-1] == 4:
spacing[..., 3] = 1.0
# Concatenate images
images = [image2, image1] if direction in ["left", "up"] else [image1, image2]
if spacing_width > 0:
images.insert(1, spacing)
concat_dim = 2 if direction in ["left", "right"] else 1
return (torch.cat(images, dim=concat_dim),)
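# Minimal usage sketch for ImageStitch (random tensors stand in for node
# inputs; with these shapes and match_image_size=True the result is
# [1, 256, 256 + 8 + 320, 3]):
#
#     stitcher = ImageStitch()
#     a = torch.rand(1, 256, 256, 3)
#     b = torch.rand(1, 256, 320, 3)
#     (result,) = stitcher.stitch(a, "right", True, 8, "white", image2=b)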
class SaveSVGNode:
"""
Save SVG files on disk.
@@ -499,5 +318,4 @@ NODE_CLASS_MAPPINGS = {
"SaveAnimatedWEBP": SaveAnimatedWEBP,
"SaveAnimatedPNG": SaveAnimatedPNG,
"SaveSVGNode": SaveSVGNode,
"ImageStitch": ImageStitch,
}

View File

@@ -16,7 +16,7 @@ class Load3D():
os.makedirs(input_dir, exist_ok=True)
files = [normalize_path(os.path.join("3d", f)) for f in os.listdir(input_dir) if f.endswith(('.gltf', '.glb', '.obj', '.fbx', '.stl'))]
files = [normalize_path(os.path.join("3d", f)) for f in os.listdir(input_dir) if f.endswith(('.gltf', '.glb', '.obj', '.mtl', '.fbx', '.stl'))]
return {"required": {
"model_file": (sorted(files), {"file_upload": True}),

View File

@@ -296,41 +296,6 @@ class RegexExtract():
return result,
class RegexReplace():
DESCRIPTION = "Find and replace text using regex patterns."
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"string": (IO.STRING, {"multiline": True}),
"regex_pattern": (IO.STRING, {"multiline": True}),
"replace": (IO.STRING, {"multiline": True}),
},
"optional": {
"case_insensitive": (IO.BOOLEAN, {"default": True}),
"multiline": (IO.BOOLEAN, {"default": False}),
"dotall": (IO.BOOLEAN, {"default": False, "tooltip": "When enabled, the dot (.) character will match any character including newline characters. When disabled, dots won't match newlines."}),
"count": (IO.INT, {"default": 0, "min": 0, "max": 100, "tooltip": "Maximum number of replacements to make. Set to 0 to replace all occurrences (default). Set to 1 to replace only the first match, 2 for the first two matches, etc."}),
}
}
RETURN_TYPES = (IO.STRING,)
FUNCTION = "execute"
CATEGORY = "utils/string"
def execute(self, string, regex_pattern, replace, case_insensitive=True, multiline=False, dotall=False, count=0, **kwargs):
flags = 0
if case_insensitive:
flags |= re.IGNORECASE
if multiline:
flags |= re.MULTILINE
if dotall:
flags |= re.DOTALL
result = re.sub(regex_pattern, replace, string, count=count, flags=flags)
return result,
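# The flag handling above mirrors a plain re.sub call; for example, with
# case_insensitive=True and count=1:
#
#     re.sub("foo", "qux", "Foo bar FOO baz", count=1, flags=re.IGNORECASE)
#     # -> "qux bar FOO baz"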
NODE_CLASS_MAPPINGS = {
"StringConcatenate": StringConcatenate,
"StringSubstring": StringSubstring,
@@ -341,8 +306,7 @@ NODE_CLASS_MAPPINGS = {
"StringContains": StringContains,
"StringCompare": StringCompare,
"RegexMatch": RegexMatch,
"RegexExtract": RegexExtract,
"RegexReplace": RegexReplace,
"RegexExtract": RegexExtract
}
NODE_DISPLAY_NAME_MAPPINGS = {
@@ -355,6 +319,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"StringContains": "Contains",
"StringCompare": "Compare",
"RegexMatch": "Regex Match",
"RegexExtract": "Regex Extract",
"RegexReplace": "Regex Replace",
"RegexExtract": "Regex Extract"
}

View File

@@ -1,5 +1,4 @@
from comfy_api.torch_helpers import set_torch_compile_wrapper
import torch
class TorchCompileModel:
@classmethod
@@ -15,7 +14,7 @@ class TorchCompileModel:
def patch(self, model, backend):
m = model.clone()
set_torch_compile_wrapper(model=m, backend=backend)
m.add_object_patch("diffusion_model", torch.compile(model=m.get_model_object("diffusion_model"), backend=backend))
return (m, )
NODE_CLASS_MAPPINGS = {

View File

@@ -1,67 +0,0 @@
import torch
from comfy_api.v3.io import (
ComfyNodeV3, SchemaV3, CustomType, CustomInput, CustomOutput, InputBehavior, NumberDisplay,
IntegerInput, MaskInput, ImageInput, ComboDynamicInput, NodeOutput,
)
class V3TestNode(ComfyNodeV3):
@classmethod
def DEFINE_SCHEMA(cls):
return SchemaV3(
node_id="V3TestNode1",
display_name="V3 Test Node (1djekjd)",
description="This is a funky V3 node test.",
category="v3 nodes",
inputs=[
IntegerInput("some_int", display_name="new_name", min=0, tooltip="My tooltip 😎", display_mode=NumberDisplay.slider),
MaskInput("mask", behavior=InputBehavior.optional),
ImageInput("image", display_name="new_image"),
# IntegerInput("some_int", display_name="new_name", min=0, tooltip="My tooltip 😎", display=NumberDisplay.slider, ),
# ComboDynamicInput("mask", behavior=InputBehavior.optional),
# IntegerInput("some_int", display_name="new_name", min=0, tooltip="My tooltip 😎", display=NumberDisplay.slider,
# dependent_inputs=[ComboDynamicInput("mask", behavior=InputBehavior.optional)],
# dependent_values=[lambda my_value: IO.STRING if my_value < 5 else IO.NUMBER],
# ),
# ["option1", "option2". "option3"]
# ComboDynamicInput["sdfgjhl", [ComboDynamicOptions("option1", [IntegerInput("some_int", display_name="new_name", min=0, tooltip="My tooltip 😎", display=NumberDisplay.slider, ImageInput(), MaskInput(), String()]),
# CombyDynamicOptons("option2", [])
# ]]
],
is_output_node=True,
)
def execute(self, some_int: int, image: torch.Tensor, mask: torch.Tensor=None, **kwargs):
a = NodeOutput(1)
aa = NodeOutput(1, "hellothere")
ab = NodeOutput(1, "hellothere", ui={"lol": "jk"})
b = NodeOutput()
c = NodeOutput(ui={"lol": "jk"})
return NodeOutput()
return NodeOutput(1)
return NodeOutput(1, block_execution="Execution blocked")
return ()
NODES_LIST: list[ComfyNodeV3] = [
V3TestNode,
]
# NODE_CLASS_MAPPINGS = {}
# NODE_DISPLAY_NAME_MAPPINGS = {}
# for node in NODES_LIST:
# schema = node.GET_SCHEMA()
# NODE_CLASS_MAPPINGS[schema.node_id] = node
# if schema.display_name:
# NODE_DISPLAY_NAME_MAPPINGS[schema.node_id] = schema.display_name

View File

@@ -268,9 +268,8 @@ class WanVaceToVideo:
trim_latent = reference_image.shape[2]
mask = mask.unsqueeze(0)
positive = node_helpers.conditioning_set_values(positive, {"vace_frames": [control_video_latent], "vace_mask": [mask], "vace_strength": [strength]}, append=True)
negative = node_helpers.conditioning_set_values(negative, {"vace_frames": [control_video_latent], "vace_mask": [mask], "vace_strength": [strength]}, append=True)
positive = node_helpers.conditioning_set_values(positive, {"vace_frames": control_video_latent, "vace_mask": mask, "vace_strength": strength})
negative = node_helpers.conditioning_set_values(negative, {"vace_frames": control_video_latent, "vace_mask": mask, "vace_strength": strength})
latent = torch.zeros([batch_size, 16, latent_length, height // 8, width // 8], device=comfy.model_management.intermediate_device())
out_latent = {}
@@ -345,44 +344,6 @@ class WanCameraImageToVideo:
out_latent["samples"] = latent
return (positive, negative, out_latent)
class WanPhantomSubjectToVideo:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
},
"optional": {"images": ("IMAGE", ),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative_text", "negative_img_text", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
def encode(self, positive, negative, vae, width, height, length, batch_size, images):
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
cond2 = negative
if images is not None:
images = comfy.utils.common_upscale(images[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
latent_images = []
for i in images:
latent_images += [vae.encode(i.unsqueeze(0)[:, :, :, :3])]
concat_latent_image = torch.cat(latent_images, dim=2)
positive = node_helpers.conditioning_set_values(positive, {"time_dim_concat": concat_latent_image})
cond2 = node_helpers.conditioning_set_values(negative, {"time_dim_concat": concat_latent_image})
negative = node_helpers.conditioning_set_values(negative, {"time_dim_concat": comfy.latent_formats.Wan21().process_out(torch.zeros_like(concat_latent_image))})
out_latent = {}
out_latent["samples"] = latent
return (positive, cond2, negative, out_latent)
NODE_CLASS_MAPPINGS = {
"WanImageToVideo": WanImageToVideo,
"WanFunControlToVideo": WanFunControlToVideo,
@@ -391,5 +352,4 @@ NODE_CLASS_MAPPINGS = {
"WanVaceToVideo": WanVaceToVideo,
"TrimVideoLatent": TrimVideoLatent,
"WanCameraImageToVideo": WanCameraImageToVideo,
"WanPhantomSubjectToVideo": WanPhantomSubjectToVideo,
}

View File

@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.3.39"
__version__ = "0.3.34"

View File

@@ -17,7 +17,6 @@ from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt,
from comfy_execution.graph_utils import is_link, GraphBuilder
from comfy_execution.caching import HierarchicalCache, LRUCache, DependencyAwareCache, CacheKeySetInputSignature, CacheKeySetID
from comfy_execution.validation import validate_node_input
from comfy_api.v3.io import NodeOutput
class ExecutionResult(Enum):
SUCCESS = 0
@@ -243,22 +242,6 @@ def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb
result = tuple([result] * len(obj.RETURN_TYPES))
results.append(result)
subgraph_results.append((None, result))
elif isinstance(r, NodeOutput):
if r.ui is not None:
uis.append(r.ui.as_dict())
if r.expand is not None:
has_subgraph = True
new_graph = r.expand
result = r.result
if r.block_execution is not None:
result = tuple([ExecutionBlocker(r.block_execution)] * len(obj.RETURN_TYPES))
subgraph_results.append((new_graph, result))
elif r.result is not None:
result = r.result
if r.block_execution is not None:
result = tuple([ExecutionBlocker(r.block_execution)] * len(obj.RETURN_TYPES))
results.append(result)
subgraph_results.append((None, result))
else:
if isinstance(r, ExecutionBlocker):
r = tuple([r] * len(obj.RETURN_TYPES))
@@ -926,6 +909,7 @@ class PromptQueue:
self.currently_running = {}
self.history = {}
self.flags = {}
server.prompt_queue = self
def put(self, item):
with self.mutex:
@@ -970,7 +954,6 @@ class PromptQueue:
self.history[prompt[1]].update(history_result)
self.server.queue_updated()
# Note: slow
def get_current_queue(self):
with self.mutex:
out = []
@@ -978,13 +961,6 @@ class PromptQueue:
out += [x]
return (out, copy.deepcopy(self.queue))
# read-safe as long as queue items are immutable
def get_current_queue_volatile(self):
with self.mutex:
running = [x for x in self.currently_running.values()]
queued = copy.copy(self.queue)
return (running, queued)
def get_tasks_remaining(self):
with self.mutex:
return len(self.queue) + len(self.currently_running)

View File

@@ -260,6 +260,7 @@ def start_comfyui(asyncio_loop=None):
asyncio_loop = asyncio.new_event_loop()
asyncio.set_event_loop(asyncio_loop)
prompt_server = server.PromptServer(asyncio_loop)
q = execution.PromptQueue(prompt_server)
hook_breaker_ac10a0.save_functions()
nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes, init_api_nodes=not args.disable_api_nodes)
@@ -270,7 +271,7 @@ def start_comfyui(asyncio_loop=None):
prompt_server.add_routes()
hijack_progress(prompt_server)
threading.Thread(target=prompt_worker, daemon=True, args=(prompt_server.prompt_queue, prompt_server,)).start()
threading.Thread(target=prompt_worker, daemon=True, args=(q, prompt_server,)).start()
if args.quick_test_for_ci:
exit(0)

View File

@@ -5,18 +5,12 @@ from comfy.cli_args import args
from PIL import ImageFile, UnidentifiedImageError
def conditioning_set_values(conditioning, values={}, append=False):
def conditioning_set_values(conditioning, values={}):
c = []
for t in conditioning:
n = [t[0], t[1].copy()]
for k in values:
val = values[k]
if append:
old_val = n[1].get(k, None)
if old_val is not None:
val = old_val + val
n[1][k] = val
n[1][k] = values[k]
c.append(n)
return c
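# For reference, the removed append=True path extended existing list values
# instead of overwriting them (illustrative, against the old signature shown
# above):
#
#     cond = [("tokens", {"vace_frames": ["frame_a"]})]
#     cond = conditioning_set_values(cond, {"vace_frames": ["frame_b"]}, append=True)
#     cond[0][1]["vace_frames"]  # -> ["frame_a", "frame_b"]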

View File

@@ -26,7 +26,6 @@ import comfy.sd
import comfy.utils
import comfy.controlnet
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict, FileLocator
from comfy_api.v3.io import ComfyNodeV3
import comfy.clip_vision
@@ -1104,7 +1103,16 @@ class unCLIPConditioning:
if strength == 0:
return (conditioning, )
c = node_helpers.conditioning_set_values(conditioning, {"unclip_conditioning": [{"clip_vision_output": clip_vision_output, "strength": strength, "noise_augmentation": noise_augmentation}]}, append=True)
c = []
for t in conditioning:
o = t[1].copy()
x = {"clip_vision_output": clip_vision_output, "strength": strength, "noise_augmentation": noise_augmentation}
if "unclip_conditioning" in o:
o["unclip_conditioning"] = o["unclip_conditioning"][:] + [x]
else:
o["unclip_conditioning"] = [x]
n = [t[0], o]
c.append(n)
return (c, )
class GLIGENLoader:
@@ -2062,7 +2070,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"ImagePadForOutpaint": "Pad Image for Outpainting",
"ImageBatch": "Batch Images",
"ImageCrop": "Image Crop",
"ImageStitch": "Image Stitch",
"ImageBlend": "Image Blend",
"ImageBlur": "Image Blur",
"ImageQuantize": "Image Quantize",
@@ -2130,7 +2137,6 @@ def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes
if os.path.isdir(web_dir):
EXTENSION_WEB_DIRS[module_name] = web_dir
# V1 node definition
if hasattr(module, "NODE_CLASS_MAPPINGS") and getattr(module, "NODE_CLASS_MAPPINGS") is not None:
for name, node_cls in module.NODE_CLASS_MAPPINGS.items():
if name not in ignore:
@@ -2139,19 +2145,8 @@ def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes
if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None:
NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS)
return True
# V3 node definition
elif getattr(module, "NODES_LIST", None) is not None:
for node_cls in module.NODES_LIST:
node_cls: ComfyNodeV3
schema = node_cls.GET_SCHEMA()
if schema.node_id not in ignore:
NODE_CLASS_MAPPINGS[schema.node_id] = node_cls
node_cls.RELATIVE_PYTHON_MODULE = "{}.{}".format(module_parent, get_module_name(module_path))
if schema.display_name is not None:
NODE_DISPLAY_NAME_MAPPINGS[schema.node_id] = schema.display_name
return True
else:
logging.warning(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS or NODES_LIST (need one).")
logging.warning(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.")
return False
except Exception as e:
logging.warning(traceback.format_exc())
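
For reference, the V1 contract checked first above is just a pair of module-level mappings; a minimal sketch of a module that load_custom_node() would accept (all names here are illustrative, not a bundled node):

```python
class ExampleNode:
    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"text": ("STRING", {"default": ""})}}

    RETURN_TYPES = ("STRING",)
    FUNCTION = "run"
    CATEGORY = "examples"

    def run(self, text):
        # Trivial body: upper-case the input and return it as a 1-tuple.
        return (text.upper(),)

NODE_CLASS_MAPPINGS = {"ExampleNode": ExampleNode}
NODE_DISPLAY_NAME_MAPPINGS = {"ExampleNode": "Example Node"}
```
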
@@ -2271,7 +2266,6 @@ def init_builtin_extra_nodes():
"nodes_ace.py",
"nodes_string.py",
"nodes_camera_trajectory.py",
"nodes_v3_test.py",
]
import_failed = []
@@ -2296,10 +2290,6 @@ def init_builtin_api_nodes():
"nodes_pixverse.py",
"nodes_stability.py",
"nodes_pika.py",
"nodes_runway.py",
"nodes_tripo.py",
"nodes_rodin.py",
"nodes_gemini.py",
]
if not load_custom_node(os.path.join(api_nodes_dir, "canary.py"), module_parent="comfy_api_nodes"):

904
openapi.yaml Normal file
View File

@@ -0,0 +1,904 @@
openapi: 3.0.3
info:
title: ComfyUI API
description: |
API for ComfyUI - A powerful and modular UI for Stable Diffusion.
This API allows you to interact with ComfyUI programmatically, including:
- Submitting workflows for execution
- Managing the execution queue
- Retrieving generated images
- Managing models
- Retrieving node information
version: 1.0.0
license:
name: GNU General Public License v3.0
url: https://github.com/comfyanonymous/ComfyUI/blob/master/LICENSE
servers:
- url: /
description: Default ComfyUI server
tags:
- name: workflow
description: Workflow execution and management
- name: queue
description: Queue management
- name: image
description: Image handling
- name: node
description: Node information
- name: model
description: Model management
- name: system
description: System information
- name: internal
description: Internal API routes
paths:
/prompt:
get:
tags:
- workflow
summary: Get information about current prompt execution
description: Returns information about the current prompt in the execution queue
operationId: getPromptInfo
responses:
'200':
description: Success
content:
application/json:
schema:
$ref: '#/components/schemas/PromptInfo'
post:
tags:
- workflow
summary: Submit a workflow for execution
description: |
Submit a workflow to be executed by the backend.
The workflow is a JSON object describing the nodes and their connections.
operationId: executePrompt
requestBody:
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/PromptRequest'
responses:
'200':
description: Success - Prompt accepted
content:
application/json:
schema:
$ref: '#/components/schemas/PromptResponse'
'400':
description: Invalid prompt
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
/queue:
get:
tags:
- queue
summary: Get queue information
description: Returns information about running and pending items in the queue
operationId: getQueueInfo
responses:
'200':
description: Success
content:
application/json:
schema:
$ref: '#/components/schemas/QueueInfo'
post:
tags:
- queue
summary: Manage queue
description: Clear the queue or delete specific items
operationId: manageQueue
requestBody:
required: true
content:
application/json:
schema:
type: object
properties:
clear:
type: boolean
description: If true, clears the entire queue
delete:
type: array
description: Array of prompt IDs to delete from the queue
items:
type: string
format: uuid
responses:
'200':
description: Success
/interrupt:
post:
tags:
- workflow
summary: Interrupt the current execution
description: Interrupts the currently running workflow execution
operationId: interruptExecution
responses:
'200':
description: Success
/free:
post:
tags:
- system
summary: Free resources
description: Unload models and/or free memory
operationId: freeResources
requestBody:
required: true
content:
application/json:
schema:
type: object
properties:
unload_models:
type: boolean
description: If true, unloads models from memory
free_memory:
type: boolean
description: If true, frees GPU memory
responses:
'200':
description: Success
/history:
get:
tags:
- workflow
summary: Get execution history
description: Returns the history of executed workflows
operationId: getHistory
parameters:
- name: max_items
in: query
description: Maximum number of history items to return
required: false
schema:
type: integer
format: int32
responses:
'200':
description: Success
content:
application/json:
schema:
type: array
items:
$ref: '#/components/schemas/HistoryItem'
post:
tags:
- workflow
summary: Manage history
description: Clear history or delete specific items
operationId: manageHistory
requestBody:
required: true
content:
application/json:
schema:
type: object
properties:
clear:
type: boolean
description: If true, clears the entire history
delete:
type: array
description: Array of prompt IDs to delete from history
items:
type: string
format: uuid
responses:
'200':
description: Success
/history/{prompt_id}:
get:
tags:
- workflow
summary: Get specific history item
description: Returns a specific history item by ID
operationId: getHistoryItem
parameters:
- name: prompt_id
in: path
description: ID of the prompt to retrieve
required: true
schema:
type: string
format: uuid
responses:
'200':
description: Success
content:
application/json:
schema:
$ref: '#/components/schemas/HistoryItem'
/object_info:
get:
tags:
- node
summary: Get all node information
description: Returns information about all available nodes
operationId: getNodeInfo
responses:
'200':
description: Success
content:
application/json:
schema:
type: object
additionalProperties:
$ref: '#/components/schemas/NodeInfo'
/object_info/{node_class}:
get:
tags:
- node
summary: Get specific node information
description: Returns information about a specific node class
operationId: getNodeClassInfo
parameters:
- name: node_class
in: path
description: Name of the node class
required: true
schema:
type: string
responses:
'200':
description: Success
content:
application/json:
schema:
type: object
additionalProperties:
$ref: '#/components/schemas/NodeInfo'
/upload/image:
post:
tags:
- image
summary: Upload an image
description: Uploads an image to the server
operationId: uploadImage
requestBody:
required: true
content:
multipart/form-data:
schema:
type: object
properties:
image:
type: string
format: binary
description: The image file to upload
overwrite:
type: string
description: Whether to overwrite if file exists (true/false)
type:
type: string
enum: [input, temp, output]
description: Type of directory to store the image in
subfolder:
type: string
description: Subfolder to store the image in
responses:
'200':
description: Success
content:
application/json:
schema:
type: object
properties:
name:
type: string
description: Filename of the uploaded image
subfolder:
type: string
description: Subfolder the image was stored in
type:
type: string
description: Type of directory the image was stored in
'400':
description: Bad request
/upload/mask:
post:
tags:
- image
summary: Upload a mask for an image
description: Uploads a mask image and applies it to a referenced original image
operationId: uploadMask
requestBody:
required: true
content:
multipart/form-data:
schema:
type: object
properties:
image:
type: string
format: binary
description: The mask image file to upload
original_ref:
type: string
description: JSON string containing reference to the original image
responses:
'200':
description: Success
content:
application/json:
schema:
type: object
properties:
name:
type: string
description: Filename of the uploaded mask
subfolder:
type: string
description: Subfolder the mask was stored in
type:
type: string
description: Type of directory the mask was stored in
'400':
description: Bad request
/view:
get:
tags:
- image
summary: View an image
description: Retrieves an image from the server
operationId: viewImage
parameters:
- name: filename
in: query
description: Name of the file to retrieve
required: true
schema:
type: string
- name: type
in: query
description: Type of directory to retrieve from
required: false
schema:
type: string
enum: [input, temp, output]
default: output
- name: subfolder
in: query
description: Subfolder to retrieve from
required: false
schema:
type: string
- name: preview
in: query
description: Preview options (format;quality)
required: false
schema:
type: string
- name: channel
in: query
description: Channel to retrieve (rgb, a, rgba)
required: false
schema:
type: string
enum: [rgb, a, rgba]
default: rgba
responses:
'200':
description: Success
content:
image/*:
schema:
type: string
format: binary
'400':
description: Bad request
'404':
description: File not found
/view_metadata/{folder_name}:
get:
tags:
- model
summary: View model metadata
description: Retrieves metadata from a safetensors file
operationId: viewModelMetadata
parameters:
- name: folder_name
in: path
description: Name of the model folder
required: true
schema:
type: string
- name: filename
in: query
description: Name of the safetensors file
required: true
schema:
type: string
responses:
'200':
description: Success
content:
application/json:
schema:
type: object
'404':
description: File not found
/models:
get:
tags:
- model
summary: Get model types
description: Returns a list of available model types
operationId: getModelTypes
responses:
'200':
description: Success
content:
application/json:
schema:
type: array
items:
type: string
/models/{folder}:
get:
tags:
- model
summary: Get models of a specific type
description: Returns a list of available models of a specific type
operationId: getModels
parameters:
- name: folder
in: path
description: Model type folder
required: true
schema:
type: string
responses:
'200':
description: Success
content:
application/json:
schema:
type: array
items:
type: string
'404':
description: Folder not found
/embeddings:
get:
tags:
- model
summary: Get embeddings
description: Returns a list of available embeddings
operationId: getEmbeddings
responses:
'200':
description: Success
content:
application/json:
schema:
type: array
items:
type: string
/extensions:
get:
tags:
- system
summary: Get extensions
description: Returns a list of available extensions
operationId: getExtensions
responses:
'200':
description: Success
content:
application/json:
schema:
type: array
items:
type: string
/system_stats:
get:
tags:
- system
summary: Get system statistics
description: Returns system information including RAM, VRAM, and ComfyUI version
operationId: getSystemStats
responses:
'200':
description: Success
content:
application/json:
schema:
$ref: '#/components/schemas/SystemStats'
/ws:
get:
tags:
- workflow
summary: WebSocket connection
description: |
Establishes a WebSocket connection for real-time communication.
This endpoint is used for receiving progress updates, status changes, and results from workflow executions.
operationId: webSocketConnect
parameters:
- name: clientId
in: query
description: Optional client ID for reconnection
required: false
schema:
type: string
responses:
'101':
description: Switching Protocols to WebSocket
/internal/logs:
get:
tags:
- internal
summary: Get logs
description: Returns system logs as a single string
operationId: getLogs
responses:
'200':
description: Success
content:
application/json:
schema:
type: string
/internal/logs/raw:
get:
tags:
- internal
summary: Get raw logs
description: Returns raw system logs with terminal size information
operationId: getRawLogs
responses:
'200':
description: Success
content:
application/json:
schema:
type: object
properties:
entries:
type: array
items:
type: object
properties:
t:
type: string
description: Timestamp
m:
type: string
description: Message
size:
type: object
properties:
cols:
type: integer
description: Terminal columns
rows:
type: integer
description: Terminal rows
/internal/logs/subscribe:
patch:
tags:
- internal
summary: Subscribe to logs
description: Subscribe or unsubscribe to log updates
operationId: subscribeToLogs
requestBody:
required: true
content:
application/json:
schema:
type: object
properties:
clientId:
type: string
description: Client ID
enabled:
type: boolean
description: Whether to enable or disable subscription
responses:
'200':
description: Success
/internal/folder_paths:
get:
tags:
- internal
summary: Get folder paths
description: Returns a map of folder names to their paths
operationId: getFolderPaths
responses:
'200':
description: Success
content:
application/json:
schema:
type: object
additionalProperties:
type: string
/internal/files/{directory_type}:
get:
tags:
- internal
summary: Get files
description: Returns a list of files in a specific directory type
operationId: getFiles
parameters:
- name: directory_type
in: path
description: Type of directory (output, input, temp)
required: true
schema:
type: string
enum: [output, input, temp]
responses:
'200':
description: Success
content:
application/json:
schema:
type: array
items:
type: string
'400':
description: Invalid directory type
components:
schemas:
PromptRequest:
type: object
required:
- prompt
properties:
prompt:
type: object
description: The workflow graph to execute
additionalProperties: true
number:
type: number
description: Priority number for the queue (lower numbers have higher priority)
front:
type: boolean
description: If true, adds the prompt to the front of the queue
extra_data:
type: object
description: Extra data to be associated with the prompt
additionalProperties: true
client_id:
type: string
description: Client ID for attribution of the prompt
PromptResponse:
type: object
properties:
prompt_id:
type: string
format: uuid
description: Unique identifier for the prompt execution
number:
type: number
description: Priority number in the queue
node_errors:
type: object
description: Any errors in the nodes of the prompt
additionalProperties: true
ErrorResponse:
type: object
properties:
error:
type: object
properties:
type:
type: string
description: Error type
message:
type: string
description: Error message
details:
type: string
description: Detailed error information
extra_info:
type: object
description: Additional error information
additionalProperties: true
node_errors:
type: object
description: Node-specific errors
additionalProperties: true
PromptInfo:
type: object
properties:
exec_info:
type: object
properties:
queue_remaining:
type: integer
description: Number of items remaining in the queue
QueueInfo:
type: object
properties:
queue_running:
type: array
items:
type: object
description: Currently running items
additionalProperties: true
queue_pending:
type: array
items:
type: object
description: Pending items in the queue
additionalProperties: true
HistoryItem:
type: object
properties:
prompt_id:
type: string
format: uuid
description: Unique identifier for the prompt
prompt:
type: object
description: The workflow graph that was executed
additionalProperties: true
extra_data:
type: object
description: Additional data associated with the execution
additionalProperties: true
outputs:
type: object
description: Output data from the execution
additionalProperties: true
NodeInfo:
type: object
properties:
input:
type: object
description: Input specifications for the node
additionalProperties: true
input_order:
type: object
description: Order of inputs for display
additionalProperties:
type: array
items:
type: string
output:
type: array
items:
type: string
description: Output types of the node
output_is_list:
type: array
items:
type: boolean
description: Whether each output is a list
output_name:
type: array
items:
type: string
description: Names of the outputs
name:
type: string
description: Internal name of the node
display_name:
type: string
description: Display name of the node
description:
type: string
description: Description of the node
python_module:
type: string
description: Python module implementing the node
category:
type: string
description: Category of the node
output_node:
type: boolean
description: Whether this is an output node
output_tooltips:
type: array
items:
type: string
description: Tooltips for outputs
deprecated:
type: boolean
description: Whether the node is deprecated
experimental:
type: boolean
description: Whether the node is experimental
api_node:
type: boolean
description: Whether this is an API node
SystemStats:
type: object
properties:
system:
type: object
properties:
os:
type: string
description: Operating system
ram_total:
type: number
description: Total system RAM in bytes
ram_free:
type: number
description: Free system RAM in bytes
comfyui_version:
type: string
description: ComfyUI version
python_version:
type: string
description: Python version
pytorch_version:
type: string
description: PyTorch version
embedded_python:
type: boolean
description: Whether using embedded Python
argv:
type: array
items:
type: string
description: Command line arguments
devices:
type: array
items:
type: object
properties:
name:
type: string
description: Device name
type:
type: string
description: Device type
index:
type: integer
description: Device index
vram_total:
type: number
description: Total VRAM in bytes
vram_free:
type: number
description: Free VRAM in bytes
torch_vram_total:
type: number
description: Total VRAM as reported by PyTorch
torch_vram_free:
type: number
description: Free VRAM as reported by PyTorch
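
Taken together, the spec above is enough to drive a small client. A hedged end-to-end sketch (not part of this PR; assumes a local server and a valid workflow graph supplied by you):

```python
import requests

BASE = "http://127.0.0.1:8188"

# A real graph maps node ids to {"class_type": ..., "inputs": {...}};
# the empty dict is a placeholder and would be rejected as an invalid prompt.
graph = {}

resp = requests.post(f"{BASE}/prompt", json={"prompt": graph})
resp.raise_for_status()
prompt_id = resp.json()["prompt_id"]  # PromptResponse.prompt_id

# Poll the execution record for this prompt (GET /history/{prompt_id}).
history = requests.get(f"{BASE}/history/{prompt_id}").json()
print(history)
```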

pyproject.toml
View File

@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.3.39"
version = "0.3.34"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"

requirements.txt
View File

@@ -1,6 +1,5 @@
comfyui-frontend-package==1.21.3
comfyui-workflow-templates==0.1.23
comfyui-embedded-docs==0.2.0
comfyui-frontend-package==1.19.9
comfyui-workflow-templates==0.1.14
torch
torchsde
torchvision

server.py
View File

@@ -29,8 +29,6 @@ import comfy.model_management
import node_helpers
from comfyui_version import __version__
from app.frontend_management import FrontendManager
from comfy_api.v3.io import ComfyNodeV3
from app.user_manager import UserManager
from app.model_manager import ModelFileManager
from app.custom_node_manager import CustomNodeManager
@@ -161,7 +159,7 @@ class PromptServer():
self.custom_node_manager = CustomNodeManager()
self.internal_routes = InternalRoutes(self)
self.supports = ["custom_nodes_from_web"]
self.prompt_queue = execution.PromptQueue(self)
self.prompt_queue = None
self.loop = loop
self.messages = asyncio.Queue()
self.client_session:Optional[aiohttp.ClientSession] = None
@@ -228,7 +226,7 @@ class PromptServer():
return response
@routes.get("/embeddings")
def get_embeddings(request):
def get_embeddings(self):
embeddings = folder_paths.get_filename_list("embeddings")
return web.json_response(list(map(lambda a: os.path.splitext(a)[0], embeddings)))
@@ -284,6 +282,7 @@ class PromptServer():
a.update(f.read())
b.update(image.file.read())
image.file.seek(0)
f.close()
return a.hexdigest() == b.hexdigest()
return False
@@ -556,8 +555,6 @@ class PromptServer():
def node_info(node_class):
obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
if isinstance(obj_class, ComfyNodeV3):
return obj_class.GET_NODE_INFO_V1()
info = {}
info['input'] = obj_class.INPUT_TYPES()
info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()}
@@ -624,7 +621,7 @@ class PromptServer():
@routes.get("/queue")
async def get_queue(request):
queue_info = {}
current_queue = self.prompt_queue.get_current_queue_volatile()
current_queue = self.prompt_queue.get_current_queue()
queue_info['queue_running'] = current_queue[0]
queue_info['queue_pending'] = current_queue[1]
return web.json_response(queue_info)
@@ -749,13 +746,6 @@ class PromptServer():
web.static('/templates', workflow_templates_path)
])
# Serve embedded documentation from the package
embedded_docs_path = FrontendManager.embedded_docs_path()
if embedded_docs_path:
self.app.add_routes([
web.static('/docs', embedded_docs_path)
])
self.app.add_routes([
web.static('/', self.web_root),
])

74
tests-api/README.md Normal file
View File

@@ -0,0 +1,74 @@
# ComfyUI API Testing
This directory contains tests for validating the ComfyUI OpenAPI specification against a running instance of ComfyUI.
## Setup
1. Install the required dependencies:
```bash
pip install -r requirements.txt
```
2. Make sure you have a running instance of ComfyUI (default: http://127.0.0.1:8188)
## Running the Tests
Run all tests with pytest:
```bash
cd tests-api
pytest
```
Run specific test files:
```bash
pytest test_spec_validation.py
pytest test_endpoint_existence.py
pytest test_schema_validation.py
pytest test_api_by_tag.py
```
Run tests with more verbose output:
```bash
pytest -v
```
## Test Categories
The tests are organized into several categories:
1. **Spec Validation**: Checks that the OpenAPI specification itself is well-formed and valid.
2. **Endpoint Existence**: Tests that the endpoints defined in the spec exist on the server.
3. **Schema Validation**: Tests that the server responses match the schemas defined in the spec.
4. **Tag-Based Tests**: Tests that the API's tag organization is consistent.
## Using a Different Server
By default, the tests connect to `http://127.0.0.1:8188`. To test against a different server, set the `COMFYUI_SERVER_URL` environment variable:
```bash
COMFYUI_SERVER_URL=http://example.com:8188 pytest
```
## Test Structure
- `conftest.py`: Contains pytest fixtures used by the tests.
- `utils/`: Contains utility functions for working with the OpenAPI spec.
- `test_*.py`: The actual test files.
- `resources/`: Contains resources used by the tests (e.g., sample workflows).
## Extending the Tests
To add new tests:
1. For testing new endpoints, add them to the appropriate test file based on their category.
2. For testing more complex functionality, create a new test file following the established patterns (see the sketch below).
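For example, a minimal read-only endpoint test built on the fixtures in `conftest.py` could look like this (an illustrative sketch; swap in the path and assertions for the endpoint under test):
```python
def test_extensions_endpoint(require_server, api_client):
    """Illustrative sketch: check that a read-only endpoint exists."""
    url = api_client.get_url("/extensions")  # type: ignore
    response = api_client.get(url)
    assert response.status_code not in [404, 405], "Endpoint /extensions does not exist"
```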
## Notes
- Tests that require a running server will be skipped if the server is not available.
- Some tests may fail if the server doesn't match the specification exactly.
- The tests don't modify any data on the server (they're read-only).

141
tests-api/conftest.py Normal file
View File

@@ -0,0 +1,141 @@
"""
Test fixtures for API testing
"""
import os
import pytest
import yaml
import requests
import logging
from typing import Dict, Any, Generator, Optional
from urllib.parse import urljoin
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Default server configuration
DEFAULT_SERVER_URL = "http://127.0.0.1:8188"
@pytest.fixture(scope="session")
def api_spec_path() -> str:
"""
Get the path to the OpenAPI specification file
Returns:
Path to the OpenAPI specification file
"""
return os.path.abspath(os.path.join(
os.path.dirname(__file__),
"..",
"openapi.yaml"
))
@pytest.fixture(scope="session")
def api_spec(api_spec_path: str) -> Dict[str, Any]:
"""
Load the OpenAPI specification
Args:
api_spec_path: Path to the spec file
Returns:
Parsed OpenAPI specification
"""
with open(api_spec_path, 'r') as f:
return yaml.safe_load(f)
@pytest.fixture(scope="session")
def base_url() -> str:
"""
Get the base URL for the API server
Returns:
Base URL string
"""
# Allow overriding via environment variable
return os.environ.get("COMFYUI_SERVER_URL", DEFAULT_SERVER_URL)
@pytest.fixture(scope="session")
def server_available(base_url: str) -> bool:
"""
Check if the server is available
Args:
base_url: Base URL for the API
Returns:
True if the server is available, False otherwise
"""
try:
response = requests.get(base_url, timeout=2)
return response.status_code == 200
except requests.RequestException:
logger.warning(f"Server at {base_url} is not available")
return False
@pytest.fixture
def api_client(base_url: str) -> Generator[Optional[requests.Session], None, None]:
"""
Create a requests session for API testing
Args:
base_url: Base URL for the API
Yields:
Requests session configured for the API
"""
session = requests.Session()
# Helper function to construct URLs
def get_url(path: str) -> str:
return urljoin(base_url, path)
# Add url helper to the session
session.get_url = get_url # type: ignore
yield session
# Cleanup
session.close()
@pytest.fixture
def api_get_json(api_client: requests.Session):
"""
Helper fixture for making GET requests and parsing JSON responses
Args:
api_client: API client session
Returns:
Function that makes GET requests and returns JSON
"""
def _get_json(path: str, **kwargs):
url = api_client.get_url(path) # type: ignore
response = api_client.get(url, **kwargs)
if response.status_code == 200:
try:
return response.json()
except ValueError:
return None
return None
return _get_json
@pytest.fixture
def require_server(server_available):
"""
Skip tests if server is not available
Args:
server_available: Whether the server is available
"""
if not server_available:
pytest.skip("Server is not available")

6
tests-api/requirements.txt Normal file
View File

@@ -0,0 +1,6 @@
pytest>=7.0.0
pytest-asyncio>=0.21.0
openapi-spec-validator>=0.5.0
jsonschema>=4.17.0
requests>=2.28.0
pyyaml>=6.0.0

279
tests-api/test_api_by_tag.py Normal file
View File

@@ -0,0 +1,279 @@
"""
Tests for API endpoints grouped by tags
"""
import pytest
import logging
import sys
import os
from typing import Dict, Any, Set
# Use a direct import with the full path
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, current_dir)
# Define functions inline to avoid import issues
def get_all_endpoints(spec):
"""
Extract all endpoints from an OpenAPI spec
"""
endpoints = []
for path, path_item in spec['paths'].items():
for method, operation in path_item.items():
if method.lower() not in ['get', 'post', 'put', 'delete', 'patch']:
continue
endpoints.append({
'path': path,
'method': method.lower(),
'tags': operation.get('tags', []),
'operation_id': operation.get('operationId', ''),
'summary': operation.get('summary', '')
})
return endpoints
def get_all_tags(spec):
"""
Get all tags used in the API spec
"""
tags = set()
for path_item in spec['paths'].values():
for operation in path_item.values():
if isinstance(operation, dict) and 'tags' in operation:
tags.update(operation['tags'])
return tags
def extract_endpoints_by_tag(spec, tag):
"""
Extract all endpoints with a specific tag
"""
endpoints = []
for path, path_item in spec['paths'].items():
for method, operation in path_item.items():
if method.lower() not in ['get', 'post', 'put', 'delete', 'patch']:
continue
if tag in operation.get('tags', []):
endpoints.append({
'path': path,
'method': method.lower(),
'operation_id': operation.get('operationId', ''),
'summary': operation.get('summary', '')
})
return endpoints
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@pytest.fixture
def api_tags(api_spec: Dict[str, Any]) -> Set[str]:
"""
Get all tags from the API spec
Args:
api_spec: Loaded OpenAPI spec
Returns:
Set of tag names
"""
return get_all_tags(api_spec)
def test_api_has_tags(api_tags: Set[str]):
"""
Test that the API has defined tags
Args:
api_tags: Set of tags
"""
assert len(api_tags) > 0, "API spec should have at least one tag"
# Log the tags
logger.info(f"API spec has the following tags: {sorted(api_tags)}")
@pytest.mark.parametrize("tag", [
"workflow",
"image",
"model",
"node",
"system"
])
def test_core_tags_exist(api_tags: Set[str], tag: str):
"""
Test that core tags exist in the API spec
Args:
api_tags: Set of tags
tag: Tag to check
"""
assert tag in api_tags, f"API spec should have '{tag}' tag"
def test_workflow_tag_has_endpoints(api_spec: Dict[str, Any]):
"""
Test that the 'workflow' tag has appropriate endpoints
Args:
api_spec: Loaded OpenAPI spec
"""
endpoints = extract_endpoints_by_tag(api_spec, "workflow")
assert len(endpoints) > 0, "No endpoints found with 'workflow' tag"
# Check for key workflow endpoints
endpoint_paths = [e["path"] for e in endpoints]
assert "/prompt" in endpoint_paths, "Workflow tag should include /prompt endpoint"
# Log the endpoints
logger.info(f"Found {len(endpoints)} endpoints with 'workflow' tag:")
for e in endpoints:
logger.info(f" {e['method'].upper()} {e['path']}")
def test_image_tag_has_endpoints(api_spec: Dict[str, Any]):
"""
Test that the 'image' tag has appropriate endpoints
Args:
api_spec: Loaded OpenAPI spec
"""
endpoints = extract_endpoints_by_tag(api_spec, "image")
assert len(endpoints) > 0, "No endpoints found with 'image' tag"
# Check for key image endpoints
endpoint_paths = [e["path"] for e in endpoints]
assert "/upload/image" in endpoint_paths, "Image tag should include /upload/image endpoint"
assert "/view" in endpoint_paths, "Image tag should include /view endpoint"
# Log the endpoints
logger.info(f"Found {len(endpoints)} endpoints with 'image' tag:")
for e in endpoints:
logger.info(f" {e['method'].upper()} {e['path']}")
def test_model_tag_has_endpoints(api_spec: Dict[str, Any]):
"""
Test that the 'model' tag has appropriate endpoints
Args:
api_spec: Loaded OpenAPI spec
"""
endpoints = extract_endpoints_by_tag(api_spec, "model")
assert len(endpoints) > 0, "No endpoints found with 'model' tag"
# Check for key model endpoints
endpoint_paths = [e["path"] for e in endpoints]
assert "/models" in endpoint_paths, "Model tag should include /models endpoint"
# Log the endpoints
logger.info(f"Found {len(endpoints)} endpoints with 'model' tag:")
for e in endpoints:
logger.info(f" {e['method'].upper()} {e['path']}")
def test_node_tag_has_endpoints(api_spec: Dict[str, Any]):
"""
Test that the 'node' tag has appropriate endpoints
Args:
api_spec: Loaded OpenAPI spec
"""
endpoints = extract_endpoints_by_tag(api_spec, "node")
assert len(endpoints) > 0, "No endpoints found with 'node' tag"
# Check for key node endpoints
endpoint_paths = [e["path"] for e in endpoints]
assert "/object_info" in endpoint_paths, "Node tag should include /object_info endpoint"
# Log the endpoints
logger.info(f"Found {len(endpoints)} endpoints with 'node' tag:")
for e in endpoints:
logger.info(f" {e['method'].upper()} {e['path']}")
def test_system_tag_has_endpoints(api_spec: Dict[str, Any]):
"""
Test that the 'system' tag has appropriate endpoints
Args:
api_spec: Loaded OpenAPI spec
"""
endpoints = extract_endpoints_by_tag(api_spec, "system")
assert len(endpoints) > 0, "No endpoints found with 'system' tag"
# Check for key system endpoints
endpoint_paths = [e["path"] for e in endpoints]
assert "/system_stats" in endpoint_paths, "System tag should include /system_stats endpoint"
# Log the endpoints
logger.info(f"Found {len(endpoints)} endpoints with 'system' tag:")
for e in endpoints:
logger.info(f" {e['method'].upper()} {e['path']}")
def test_internal_tag_has_endpoints(api_spec: Dict[str, Any]):
"""
Test that the 'internal' tag has appropriate endpoints
Args:
api_spec: Loaded OpenAPI spec
"""
endpoints = extract_endpoints_by_tag(api_spec, "internal")
assert len(endpoints) > 0, "No endpoints found with 'internal' tag"
# Check for key internal endpoints
endpoint_paths = [e["path"] for e in endpoints]
assert "/internal/logs" in endpoint_paths, "Internal tag should include /internal/logs endpoint"
# Log the endpoints
logger.info(f"Found {len(endpoints)} endpoints with 'internal' tag:")
for e in endpoints:
logger.info(f" {e['method'].upper()} {e['path']}")
def test_operation_ids_match_tag(api_spec: Dict[str, Any]):
"""
Test that operation IDs follow a consistent pattern with their tag
Args:
api_spec: Loaded OpenAPI spec
"""
failures = []
for path, path_item in api_spec['paths'].items():
for method, operation in path_item.items():
if method in ['get', 'post', 'put', 'delete', 'patch']:
if 'operationId' in operation and 'tags' in operation and operation['tags']:
op_id = operation['operationId']
primary_tag = operation['tags'][0].lower()
# Check if operationId starts with primary tag prefix
# This is a common convention, but might need adjusting
if not (op_id.startswith(primary_tag) or
any(op_id.lower().startswith(f"{tag.lower()}") for tag in operation['tags'])):
failures.append({
'path': path,
'method': method,
'operationId': op_id,
'primary_tag': primary_tag
})
# Log failures for diagnosis but don't fail the test
# as this is a style/convention check
if failures:
logger.warning(f"Found {len(failures)} operationIds that don't align with their tags:")
for f in failures:
logger.warning(f" {f['method'].upper()} {f['path']} - operationId: {f['operationId']}, primary tag: {f['primary_tag']}")

240
tests-api/test_endpoint_existence.py Normal file
View File

@@ -0,0 +1,240 @@
"""
Tests for endpoint existence and basic response codes
"""
import pytest
import requests
import logging
import sys
import os
from typing import Dict, Any, List
from urllib.parse import urljoin
# Use a direct import with the full path
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, current_dir)
# Define get_all_endpoints function inline to avoid import issues
def get_all_endpoints(spec):
"""
Extract all endpoints from an OpenAPI spec
Args:
spec: Parsed OpenAPI specification
Returns:
List of dicts with path, method, and tags for each endpoint
"""
endpoints = []
for path, path_item in spec['paths'].items():
for method, operation in path_item.items():
if method.lower() not in ['get', 'post', 'put', 'delete', 'patch']:
continue
endpoints.append({
'path': path,
'method': method.lower(),
'tags': operation.get('tags', []),
'operation_id': operation.get('operationId', ''),
'summary': operation.get('summary', '')
})
return endpoints
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@pytest.fixture
def all_endpoints(api_spec: Dict[str, Any]) -> List[Dict[str, Any]]:
"""
Get all endpoints from the API spec
Args:
api_spec: Loaded OpenAPI spec
Returns:
List of endpoint information
"""
return get_all_endpoints(api_spec)
def test_endpoints_exist(all_endpoints: List[Dict[str, Any]]):
"""
Test that endpoints are defined in the spec
Args:
all_endpoints: List of endpoint information
"""
# Simple check that we have endpoints defined
assert len(all_endpoints) > 0, "No endpoints defined in the OpenAPI spec"
# Log the endpoints for informational purposes
logger.info(f"Found {len(all_endpoints)} endpoints in the OpenAPI spec")
for endpoint in all_endpoints:
logger.info(f"{endpoint['method'].upper()} {endpoint['path']} - {endpoint['summary']}")
@pytest.mark.parametrize("endpoint_path", [
"/", # Root path
"/prompt", # Get prompt info
"/queue", # Get queue
"/models", # Get model types
"/object_info", # Get node info
"/system_stats" # Get system stats
])
def test_basic_get_endpoints(require_server, api_client, endpoint_path: str):
"""
Test that basic GET endpoints exist and respond
Args:
require_server: Fixture that skips if server is not available
api_client: API client fixture
endpoint_path: Path to test
"""
url = api_client.get_url(endpoint_path) # type: ignore
try:
response = api_client.get(url)
# We're just checking that the endpoint exists and returns some kind of response
# Not necessarily a 200 status code
assert response.status_code not in [404, 405], f"Endpoint {endpoint_path} does not exist"
logger.info(f"Endpoint {endpoint_path} exists with status code {response.status_code}")
except requests.RequestException as e:
pytest.fail(f"Request to {endpoint_path} failed: {str(e)}")
def test_websocket_endpoint_exists(require_server, base_url: str):
"""
Test that the WebSocket endpoint exists
Args:
require_server: Fixture that skips if server is not available
base_url: Base server URL
"""
ws_url = urljoin(base_url, "/ws")
# For WebSocket, we can't use a normal GET request
# Instead, we make a HEAD request to check if the endpoint exists
try:
response = requests.head(ws_url)
# WebSocket endpoints often return a 400 Bad Request for HEAD requests
# but a 404 would indicate the endpoint doesn't exist
assert response.status_code != 404, "WebSocket endpoint /ws does not exist"
logger.info(f"WebSocket endpoint exists with status code {response.status_code}")
except requests.RequestException as e:
pytest.fail(f"Request to WebSocket endpoint failed: {str(e)}")
def test_api_models_folder_endpoint(require_server, api_client):
"""
Test that the /models/{folder} endpoint exists and responds
Args:
require_server: Fixture that skips if server is not available
api_client: API client fixture
"""
# First get available model types
models_url = api_client.get_url("/models") # type: ignore
try:
models_response = api_client.get(models_url)
assert models_response.status_code == 200, "Failed to get model types"
model_types = models_response.json()
# Skip if no model types available
if not model_types:
pytest.skip("No model types available to test")
# Test with the first model type
model_type = model_types[0]
models_folder_url = api_client.get_url(f"/models/{model_type}") # type: ignore
folder_response = api_client.get(models_folder_url)
# We're just checking that the endpoint exists
assert folder_response.status_code != 404, f"Endpoint /models/{model_type} does not exist"
logger.info(f"Endpoint /models/{model_type} exists with status code {folder_response.status_code}")
except requests.RequestException as e:
pytest.fail(f"Request failed: {str(e)}")
except (ValueError, KeyError, IndexError) as e:
pytest.fail(f"Failed to process response: {str(e)}")
def test_api_object_info_node_endpoint(require_server, api_client):
"""
Test that the /object_info/{node_class} endpoint exists and responds
Args:
require_server: Fixture that skips if server is not available
api_client: API client fixture
"""
# First get available node classes
objects_url = api_client.get_url("/object_info") # type: ignore
try:
objects_response = api_client.get(objects_url)
assert objects_response.status_code == 200, "Failed to get object info"
node_classes = objects_response.json()
# Skip if no node classes available
if not node_classes:
pytest.skip("No node classes available to test")
# Test with the first node class
node_class = next(iter(node_classes.keys()))
node_url = api_client.get_url(f"/object_info/{node_class}") # type: ignore
node_response = api_client.get(node_url)
# We're just checking that the endpoint exists
assert node_response.status_code != 404, f"Endpoint /object_info/{node_class} does not exist"
logger.info(f"Endpoint /object_info/{node_class} exists with status code {node_response.status_code}")
except requests.RequestException as e:
pytest.fail(f"Request failed: {str(e)}")
except (ValueError, KeyError, StopIteration) as e:
pytest.fail(f"Failed to process response: {str(e)}")
def test_internal_endpoints_exist(require_server, api_client):
"""
Test that internal endpoints exist
Args:
require_server: Fixture that skips if server is not available
api_client: API client fixture
"""
internal_endpoints = [
"/internal/logs",
"/internal/logs/raw",
"/internal/folder_paths",
"/internal/files/output"
]
for endpoint in internal_endpoints:
url = api_client.get_url(endpoint) # type: ignore
try:
response = api_client.get(url)
# We're just checking that the endpoint exists
assert response.status_code != 404, f"Endpoint {endpoint} does not exist"
logger.info(f"Endpoint {endpoint} exists with status code {response.status_code}")
except requests.RequestException as e:
logger.warning(f"Request to {endpoint} failed: {str(e)}")
# Don't fail the test as internal endpoints might be restricted

440
tests-api/test_schema_validation.py Normal file
View File

@@ -0,0 +1,440 @@
"""
Tests for validating API responses against OpenAPI schema
"""
import pytest
import requests
import logging
import sys
import os
import json
from typing import Dict, Any
# Use a direct import with the full path
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, current_dir)
# Define validation functions inline to avoid import issues
def get_endpoint_schema(
spec,
path,
method,
status_code = '200'
):
"""
Extract response schema for a specific endpoint from OpenAPI spec
"""
method = method.lower()
# Handle path not found
if path not in spec['paths']:
return None
# Handle method not found
if method not in spec['paths'][path]:
return None
# Handle status code not found
responses = spec['paths'][path][method].get('responses', {})
if status_code not in responses:
return None
# Handle no content defined
if 'content' not in responses[status_code]:
return None
# Get schema from first content type
content_types = responses[status_code]['content']
first_content_type = next(iter(content_types))
if 'schema' not in content_types[first_content_type]:
return None
return content_types[first_content_type]['schema']
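# Illustrative example against the spec above: get_endpoint_schema(spec, '/system_stats', 'get')
# returns {'$ref': '#/components/schemas/SystemStats'}, which validate_response() below
# resolves via resolve_schema_refs() before calling jsonschema.validate().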
def resolve_schema_refs(schema, spec):
"""
Resolve $ref references in a schema
"""
if not isinstance(schema, dict):
return schema
result = {}
for key, value in schema.items():
if key == '$ref' and isinstance(value, str) and value.startswith('#/'):
# Handle reference
ref_path = value[2:].split('/')
ref_value = spec
for path_part in ref_path:
ref_value = ref_value.get(path_part, {})
# Recursively resolve any refs in the referenced schema
ref_value = resolve_schema_refs(ref_value, spec)
result.update(ref_value)
elif isinstance(value, dict):
# Recursively resolve refs in nested dictionaries
result[key] = resolve_schema_refs(value, spec)
elif isinstance(value, list):
# Recursively resolve refs in list items
result[key] = [
resolve_schema_refs(item, spec) if isinstance(item, dict) else item
for item in value
]
else:
# Pass through other values
result[key] = value
return result
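# Illustrative example: with spec = {'components': {'schemas': {'Foo': {'type': 'string'}}}},
# resolve_schema_refs({'$ref': '#/components/schemas/Foo'}, spec) returns {'type': 'string'};
# refs nested under 'properties' or inside list items are resolved the same way.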
def validate_response(
response_data,
spec,
path,
method,
status_code = '200'
):
"""
Validate a response against the OpenAPI schema
"""
schema = get_endpoint_schema(spec, path, method, status_code)
if schema is None:
return {
'valid': False,
'errors': [f"No schema found for {method.upper()} {path} with status {status_code}"]
}
# Resolve any $ref in the schema
resolved_schema = resolve_schema_refs(schema, spec)
try:
import jsonschema
jsonschema.validate(instance=response_data, schema=resolved_schema)
return {'valid': True, 'errors': []}
except jsonschema.exceptions.ValidationError as e:
# Extract more detailed error information
path = ".".join(str(p) for p in e.path) if e.path else "root"
instance = e.instance if not isinstance(e.instance, dict) else "..."
schema_path = ".".join(str(p) for p in e.schema_path) if e.schema_path else "unknown"
detailed_error = (
f"Validation error at path: {path}\n"
f"Schema path: {schema_path}\n"
f"Error message: {e.message}\n"
f"Failed instance: {instance}\n"
)
return {'valid': False, 'errors': [detailed_error]}
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@pytest.mark.parametrize("endpoint_path,method", [
("/system_stats", "get"),
("/prompt", "get"),
("/queue", "get"),
("/models", "get"),
("/embeddings", "get")
])
def test_response_schema_validation(
require_server,
api_client,
api_spec: Dict[str, Any],
endpoint_path: str,
method: str
):
"""
Test that API responses match the defined schema
Args:
require_server: Fixture that skips if server is not available
api_client: API client fixture
api_spec: Loaded OpenAPI spec
endpoint_path: Path to test
method: HTTP method to test
"""
url = api_client.get_url(endpoint_path) # type: ignore
# Skip if no schema defined
schema = get_endpoint_schema(api_spec, endpoint_path, method)
if not schema:
pytest.skip(f"No schema defined for {method.upper()} {endpoint_path}")
try:
if method.lower() == "get":
response = api_client.get(url)
else:
pytest.skip(f"Method {method} not implemented for automated testing")
return
# Skip if response is not 200
if response.status_code != 200:
pytest.skip(f"Endpoint {endpoint_path} returned status {response.status_code}")
return
# Skip if response is not JSON
try:
response_data = response.json()
except ValueError:
pytest.skip(f"Endpoint {endpoint_path} did not return valid JSON")
return
# Validate the response
validation_result = validate_response(
response_data,
api_spec,
endpoint_path,
method
)
if validation_result['valid']:
logger.info(f"Response from {method.upper()} {endpoint_path} matches schema")
else:
for error in validation_result['errors']:
logger.error(f"Validation error for {method.upper()} {endpoint_path}: {error}")
assert validation_result['valid'], f"Response from {method.upper()} {endpoint_path} does not match schema"
except requests.RequestException as e:
pytest.fail(f"Request to {endpoint_path} failed: {str(e)}")
def test_system_stats_response(require_server, api_client, api_spec: Dict[str, Any]):
"""
Test the system_stats endpoint response in detail
Args:
require_server: Fixture that skips if server is not available
api_client: API client fixture
api_spec: Loaded OpenAPI spec
"""
url = api_client.get_url("/system_stats") # type: ignore
try:
response = api_client.get(url)
assert response.status_code == 200, "Failed to get system stats"
# Parse response
stats = response.json()
# Validate high-level structure
assert 'system' in stats, "Response missing 'system' field"
assert 'devices' in stats, "Response missing 'devices' field"
# Validate system fields
system = stats['system']
assert 'os' in system, "System missing 'os' field"
assert 'ram_total' in system, "System missing 'ram_total' field"
assert 'ram_free' in system, "System missing 'ram_free' field"
assert 'comfyui_version' in system, "System missing 'comfyui_version' field"
# Validate devices fields
devices = stats['devices']
assert isinstance(devices, list), "Devices should be a list"
if devices:
device = devices[0]
assert 'name' in device, "Device missing 'name' field"
assert 'type' in device, "Device missing 'type' field"
assert 'vram_total' in device, "Device missing 'vram_total' field"
assert 'vram_free' in device, "Device missing 'vram_free' field"
# Perform schema validation
validation_result = validate_response(
stats,
api_spec,
"/system_stats",
"get"
)
# Print detailed error if validation fails
if not validation_result['valid']:
for error in validation_result['errors']:
logger.error(f"Validation error for /system_stats: {error}")
# Print schema details for debugging
schema = get_endpoint_schema(api_spec, "/system_stats", "get")
if schema:
logger.error(f"Schema structure:\n{json.dumps(schema, indent=2)}")
# Print sample of the response
logger.error(f"Response:\n{json.dumps(stats, indent=2)}")
assert validation_result['valid'], "System stats response does not match schema"
except requests.RequestException as e:
pytest.fail(f"Request to /system_stats failed: {str(e)}")
def test_models_listing_response(require_server, api_client, api_spec: Dict[str, Any]):
"""
Test the models endpoint response
Args:
require_server: Fixture that skips if server is not available
api_client: API client fixture
api_spec: Loaded OpenAPI spec
"""
url = api_client.get_url("/models") # type: ignore
try:
response = api_client.get(url)
assert response.status_code == 200, "Failed to get models"
# Parse response
models = response.json()
# Validate it's a list
assert isinstance(models, list), "Models response should be a list"
# Each item should be a string
for model in models:
assert isinstance(model, str), "Each model type should be a string"
# Perform schema validation
validation_result = validate_response(
models,
api_spec,
"/models",
"get"
)
# Print detailed error if validation fails
if not validation_result['valid']:
for error in validation_result['errors']:
logger.error(f"Validation error for /models: {error}")
# Print schema details for debugging
schema = get_endpoint_schema(api_spec, "/models", "get")
if schema:
logger.error(f"Schema structure:\n{json.dumps(schema, indent=2)}")
# Print response
sample_models = models[:5] if isinstance(models, list) else models
logger.error(f"Models response:\n{json.dumps(sample_models, indent=2)}")
assert validation_result['valid'], "Models response does not match schema"
except requests.RequestException as e:
pytest.fail(f"Request to /models failed: {str(e)}")
def test_object_info_response(require_server, api_client, api_spec: Dict[str, Any]):
"""
Test the object_info endpoint response
Args:
require_server: Fixture that skips if server is not available
api_client: API client fixture
api_spec: Loaded OpenAPI spec
"""
url = api_client.get_url("/object_info") # type: ignore
try:
response = api_client.get(url)
assert response.status_code == 200, "Failed to get object info"
# Parse response
objects = response.json()
# Validate it's an object
assert isinstance(objects, dict), "Object info response should be an object"
# Check if we have any objects
if objects:
# Get the first object
first_obj_name = next(iter(objects.keys()))
first_obj = objects[first_obj_name]
# Validate first object has required fields
assert 'input' in first_obj, "Object missing 'input' field"
assert 'output' in first_obj, "Object missing 'output' field"
assert 'name' in first_obj, "Object missing 'name' field"
# Perform schema validation
validation_result = validate_response(
objects,
api_spec,
"/object_info",
"get"
)
# Print detailed error if validation fails
if not validation_result['valid']:
for error in validation_result['errors']:
logger.error(f"Validation error for /object_info: {error}")
# Print schema details for debugging
schema = get_endpoint_schema(api_spec, "/object_info", "get")
if schema:
logger.error(f"Schema structure:\n{json.dumps(schema, indent=2)}")
# Also print a small sample of the response
sample = dict(list(objects.items())[:1]) if objects else {}
logger.error(f"Sample response:\n{json.dumps(sample, indent=2)}")
assert validation_result['valid'], "Object info response does not match schema"
except requests.RequestException as e:
pytest.fail(f"Request to /object_info failed: {str(e)}")
except (KeyError, StopIteration) as e:
pytest.fail(f"Failed to process response: {str(e)}")
def test_queue_response(require_server, api_client, api_spec: Dict[str, Any]):
"""
Test the queue endpoint response
Args:
require_server: Fixture that skips if server is not available
api_client: API client fixture
api_spec: Loaded OpenAPI spec
"""
url = api_client.get_url("/queue") # type: ignore
try:
response = api_client.get(url)
assert response.status_code == 200, "Failed to get queue"
# Parse response
queue = response.json()
# Validate structure
assert 'queue_running' in queue, "Queue missing 'queue_running' field"
assert 'queue_pending' in queue, "Queue missing 'queue_pending' field"
# Each should be a list
assert isinstance(queue['queue_running'], list), "queue_running should be a list"
assert isinstance(queue['queue_pending'], list), "queue_pending should be a list"
# Perform schema validation
validation_result = validate_response(
queue,
api_spec,
"/queue",
"get"
)
# Print detailed error if validation fails
if not validation_result['valid']:
for error in validation_result['errors']:
logger.error(f"Validation error for /queue: {error}")
# Print schema details for debugging
schema = get_endpoint_schema(api_spec, "/queue", "get")
if schema:
logger.error(f"Schema structure:\n{json.dumps(schema, indent=2)}")
# Print response
logger.error(f"Queue response:\n{json.dumps(queue, indent=2)}")
assert validation_result['valid'], "Queue response does not match schema"
except requests.RequestException as e:
pytest.fail(f"Request to /queue failed: {str(e)}")

144
tests-api/test_spec_validation.py Normal file
View File

@@ -0,0 +1,144 @@
"""
Tests for validating the OpenAPI specification
"""
import pytest
from openapi_spec_validator import validate_spec
from openapi_spec_validator.exceptions import OpenAPISpecValidatorError
from typing import Dict, Any
def test_openapi_spec_is_valid(api_spec: Dict[str, Any]):
"""
Test that the OpenAPI specification is valid
Args:
api_spec: Loaded OpenAPI spec
"""
try:
validate_spec(api_spec)
except OpenAPISpecValidatorError as e:
pytest.fail(f"OpenAPI spec validation failed: {str(e)}")
def test_spec_has_info(api_spec: Dict[str, Any]):
"""
Test that the OpenAPI spec has the required info section
Args:
api_spec: Loaded OpenAPI spec
"""
assert 'info' in api_spec, "Spec must have info section"
assert 'title' in api_spec['info'], "Info must have title"
assert 'version' in api_spec['info'], "Info must have version"
def test_spec_has_paths(api_spec: Dict[str, Any]):
"""
Test that the OpenAPI spec has paths defined
Args:
api_spec: Loaded OpenAPI spec
"""
assert 'paths' in api_spec, "Spec must have paths section"
assert len(api_spec['paths']) > 0, "Spec must have at least one path"
def test_spec_has_components(api_spec: Dict[str, Any]):
"""
Test that the OpenAPI spec has components defined
Args:
api_spec: Loaded OpenAPI spec
"""
assert 'components' in api_spec, "Spec must have components section"
assert 'schemas' in api_spec['components'], "Components must have schemas"
def test_workflow_endpoints_exist(api_spec: Dict[str, Any]):
"""
Test that core workflow endpoints are defined
Args:
api_spec: Loaded OpenAPI spec
"""
assert '/prompt' in api_spec['paths'], "Spec must define /prompt endpoint"
assert 'post' in api_spec['paths']['/prompt'], "Spec must define POST /prompt"
assert 'get' in api_spec['paths']['/prompt'], "Spec must define GET /prompt"
def test_image_endpoints_exist(api_spec: Dict[str, Any]):
"""
Test that core image endpoints are defined
Args:
api_spec: Loaded OpenAPI spec
"""
assert '/upload/image' in api_spec['paths'], "Spec must define /upload/image endpoint"
assert '/view' in api_spec['paths'], "Spec must define /view endpoint"
def test_model_endpoints_exist(api_spec: Dict[str, Any]):
"""
Test that core model endpoints are defined
Args:
api_spec: Loaded OpenAPI spec
"""
assert '/models' in api_spec['paths'], "Spec must define /models endpoint"
assert '/models/{folder}' in api_spec['paths'], "Spec must define /models/{folder} endpoint"
def test_operation_ids_are_unique(api_spec: Dict[str, Any]):
"""
Test that all operationIds are unique
Args:
api_spec: Loaded OpenAPI spec
"""
operation_ids = []
for path, path_item in api_spec['paths'].items():
for method, operation in path_item.items():
if method in ['get', 'post', 'put', 'delete', 'patch']:
if 'operationId' in operation:
operation_ids.append(operation['operationId'])
# Check for duplicates
duplicates = set([op_id for op_id in operation_ids if operation_ids.count(op_id) > 1])
assert len(duplicates) == 0, f"Found duplicate operationIds: {duplicates}"
def test_all_endpoints_have_operation_ids(api_spec: Dict[str, Any]):
"""
Test that all endpoints have operationIds
Args:
api_spec: Loaded OpenAPI spec
"""
missing = []
for path, path_item in api_spec['paths'].items():
for method, operation in path_item.items():
if method in ['get', 'post', 'put', 'delete', 'patch']:
if 'operationId' not in operation:
missing.append(f"{method.upper()} {path}")
assert len(missing) == 0, f"Found endpoints without operationIds: {missing}"
def test_all_endpoints_have_tags(api_spec: Dict[str, Any]):
"""
Test that all endpoints have tags
Args:
api_spec: Loaded OpenAPI spec
"""
missing = []
for path, path_item in api_spec['paths'].items():
for method, operation in path_item.items():
if method in ['get', 'post', 'put', 'delete', 'patch']:
if 'tags' not in operation or not operation['tags']:
missing.append(f"{method.upper()} {path}")
assert len(missing) == 0, f"Found endpoints without tags: {missing}"

157
tests-api/utils/schema_utils.py Normal file
View File

@@ -0,0 +1,157 @@
"""
Utilities for working with OpenAPI schemas
"""
from typing import Any, Dict, List, Optional, Set, Tuple
def extract_required_parameters(
spec: Dict[str, Any],
path: str,
method: str
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
"""
Extract required parameters for a specific endpoint
Args:
spec: Parsed OpenAPI specification
path: API path (e.g., '/prompt')
method: HTTP method (e.g., 'get', 'post')
Returns:
Tuple of (path_params, query_params) containing required parameters
"""
method = method.lower()
path_params = []
query_params = []
# Handle path not found
if path not in spec['paths']:
return path_params, query_params
# Handle method not found
if method not in spec['paths'][path]:
return path_params, query_params
# Get parameters
params = spec['paths'][path][method].get('parameters', [])
for param in params:
if param.get('required', False):
if param.get('in') == 'path':
path_params.append(param)
elif param.get('in') == 'query':
query_params.append(param)
return path_params, query_params
def get_request_body_schema(
spec: Dict[str, Any],
path: str,
method: str
) -> Optional[Dict[str, Any]]:
"""
Get request body schema for a specific endpoint
Args:
spec: Parsed OpenAPI specification
path: API path (e.g., '/prompt')
method: HTTP method (e.g., 'get', 'post')
Returns:
Request body schema or None if not found
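    Example (doctest sketch; the inline spec fragment is hypothetical):
        >>> spec = {'paths': {'/prompt': {'post': {'requestBody': {'content': {
        ...     'application/json': {'schema': {'type': 'object'}}}}}}}}
        >>> get_request_body_schema(spec, '/prompt', 'post')
        {'type': 'object'}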
"""
method = method.lower()
# Handle path not found
if path not in spec['paths']:
return None
# Handle method not found
if method not in spec['paths'][path]:
return None
# Handle no request body
request_body = spec['paths'][path][method].get('requestBody', {})
if not request_body or 'content' not in request_body:
return None
# Get schema from first content type
content_types = request_body['content']
first_content_type = next(iter(content_types))
if 'schema' not in content_types[first_content_type]:
return None
return content_types[first_content_type]['schema']
def extract_endpoints_by_tag(spec: Dict[str, Any], tag: str) -> List[Dict[str, Any]]:
"""
Extract all endpoints with a specific tag
Args:
spec: Parsed OpenAPI specification
tag: Tag to filter by
Returns:
List of endpoint details
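    Example (doctest sketch; the inline spec fragment is hypothetical):
        >>> spec = {'paths': {'/history': {'get': {'tags': ['workflow'],
        ...     'operationId': 'getHistory', 'summary': 'Get history'}}}}
        >>> extract_endpoints_by_tag(spec, 'workflow')
        [{'path': '/history', 'method': 'get', 'operation_id': 'getHistory', 'summary': 'Get history'}]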
"""
endpoints = []
for path, path_item in spec['paths'].items():
for method, operation in path_item.items():
if method.lower() not in ['get', 'post', 'put', 'delete', 'patch']:
continue
if tag in operation.get('tags', []):
endpoints.append({
'path': path,
'method': method.lower(),
'operation_id': operation.get('operationId', ''),
'summary': operation.get('summary', '')
})
return endpoints
def get_all_tags(spec: Dict[str, Any]) -> Set[str]:
"""
Get all tags used in the API spec
Args:
spec: Parsed OpenAPI specification
Returns:
Set of tag names
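    Example (doctest sketch; the inline spec fragment is hypothetical):
        >>> spec = {'paths': {'/prompt': {'post': {'tags': ['workflow']},
        ...                               'get': {'tags': ['workflow']}}}}
        >>> sorted(get_all_tags(spec))
        ['workflow']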
"""
tags = set()
for path_item in spec['paths'].values():
for operation in path_item.values():
if isinstance(operation, dict) and 'tags' in operation:
tags.update(operation['tags'])
return tags
def get_schema_examples(spec: Dict[str, Any]) -> Dict[str, Any]:
"""
Extract all examples from component schemas
Args:
spec: Parsed OpenAPI specification
Returns:
Dict mapping schema names to examples
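    Example (doctest sketch; the inline schema fragment is hypothetical):
        >>> spec = {'components': {'schemas': {'QueueInfo': {'example': {'queue_running': []}}}}}
        >>> get_schema_examples(spec)
        {'QueueInfo': {'queue_running': []}}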
"""
examples = {}
if 'components' not in spec or 'schemas' not in spec['components']:
return examples
for name, schema in spec['components']['schemas'].items():
if 'example' in schema:
examples[name] = schema['example']
return examples

View File

@@ -0,0 +1,178 @@
"""
Utilities for API response validation against OpenAPI spec
"""
import yaml
import jsonschema
from typing import Any, Dict, List, Optional, Union
def load_openapi_spec(spec_path: str) -> Dict[str, Any]:
"""
Load the OpenAPI specification from a YAML file
Args:
spec_path: Path to the OpenAPI specification file
Returns:
Dict containing the parsed OpenAPI spec
"""
with open(spec_path, 'r') as f:
return yaml.safe_load(f)
def get_endpoint_schema(
spec: Dict[str, Any],
path: str,
method: str,
status_code: str = '200'
) -> Optional[Dict[str, Any]]:
"""
Extract response schema for a specific endpoint from OpenAPI spec
Args:
spec: Parsed OpenAPI specification
path: API path (e.g., '/prompt')
method: HTTP method (e.g., 'get', 'post')
status_code: HTTP status code to get schema for
Returns:
Schema dict or None if not found
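    Example (doctest sketch; the inline spec fragment is hypothetical):
        >>> spec = {'paths': {'/system_stats': {'get': {'responses': {'200': {
        ...     'content': {'application/json': {'schema': {'type': 'object'}}}}}}}}}
        >>> get_endpoint_schema(spec, '/system_stats', 'get')
        {'type': 'object'}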
"""
method = method.lower()
# Handle path not found
if path not in spec['paths']:
return None
# Handle method not found
if method not in spec['paths'][path]:
return None
# Handle status code not found
responses = spec['paths'][path][method].get('responses', {})
if status_code not in responses:
return None
# Handle no content defined
if 'content' not in responses[status_code]:
return None
# Get schema from first content type
content_types = responses[status_code]['content']
first_content_type = next(iter(content_types))
if 'schema' not in content_types[first_content_type]:
return None
return content_types[first_content_type]['schema']
def resolve_schema_refs(schema: Dict[str, Any], spec: Dict[str, Any]) -> Dict[str, Any]:
"""
Resolve $ref references in a schema
Args:
schema: Schema that may contain references
spec: Full OpenAPI spec with component definitions
Returns:
Schema with references resolved
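    Note: assumes references are acyclic; a circular $ref would recurse without bound.
    Example (doctest sketch; the inline spec fragment is hypothetical):
        >>> spec = {'components': {'schemas': {'Node': {'type': 'object'}}}}
        >>> resolve_schema_refs({'$ref': '#/components/schemas/Node'}, spec)
        {'type': 'object'}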
"""
if not isinstance(schema, dict):
return schema
result = {}
for key, value in schema.items():
if key == '$ref' and isinstance(value, str) and value.startswith('#/'):
# Handle reference
ref_path = value[2:].split('/')
ref_value = spec
for path_part in ref_path:
ref_value = ref_value.get(path_part, {})
# Recursively resolve any refs in the referenced schema
ref_value = resolve_schema_refs(ref_value, spec)
result.update(ref_value)
elif isinstance(value, dict):
# Recursively resolve refs in nested dictionaries
result[key] = resolve_schema_refs(value, spec)
elif isinstance(value, list):
# Recursively resolve refs in list items
result[key] = [
resolve_schema_refs(item, spec) if isinstance(item, dict) else item
for item in value
]
else:
# Pass through other values
result[key] = value
return result
def validate_response(
response_data: Union[Dict[str, Any], List[Any]],
spec: Dict[str, Any],
path: str,
method: str,
status_code: str = '200'
) -> Dict[str, Any]:
"""
Validate a response against the OpenAPI schema
Args:
response_data: Response data to validate
spec: Parsed OpenAPI specification
path: API path (e.g., '/prompt')
method: HTTP method (e.g., 'get', 'post')
status_code: HTTP status code to validate against
Returns:
Dict with validation result containing:
- valid: bool indicating if validation passed
- errors: List of validation errors if any
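    Example (doctest sketch; the inline spec fragment is hypothetical):
        >>> spec = {'paths': {'/queue': {'get': {'responses': {'200': {'content': {
        ...     'application/json': {'schema': {'type': 'object'}}}}}}}}}
        >>> validate_response({}, spec, '/queue', 'get')
        {'valid': True, 'errors': []}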
"""
schema = get_endpoint_schema(spec, path, method, status_code)
if schema is None:
return {
'valid': False,
'errors': [f"No schema found for {method.upper()} {path} with status {status_code}"]
}
# Resolve any $ref in the schema
resolved_schema = resolve_schema_refs(schema, spec)
try:
jsonschema.validate(instance=response_data, schema=resolved_schema)
return {'valid': True, 'errors': []}
except jsonschema.exceptions.ValidationError as e:
return {'valid': False, 'errors': [str(e)]}
def get_all_endpoints(spec: Dict[str, Any]) -> List[Dict[str, Any]]:
"""
Extract all endpoints from an OpenAPI spec
Args:
spec: Parsed OpenAPI specification
Returns:
List of dicts with path, method, and tags for each endpoint
"""
endpoints = []
for path, path_item in spec['paths'].items():
for method, operation in path_item.items():
if method.lower() not in ['get', 'post', 'put', 'delete', 'patch']:
continue
endpoints.append({
'path': path,
'method': method.lower(),
'tags': operation.get('tags', []),
'operation_id': operation.get('operationId', ''),
'summary': operation.get('summary', '')
})
return endpoints

View File

@@ -1,240 +0,0 @@
import torch
from unittest.mock import patch, MagicMock
# Mock nodes module to prevent CUDA initialization during import
mock_nodes = MagicMock()
mock_nodes.MAX_RESOLUTION = 16384
with patch.dict('sys.modules', {'nodes': mock_nodes}):
from comfy_extras.nodes_images import ImageStitch
class TestImageStitch:
def create_test_image(self, batch_size=1, height=64, width=64, channels=3):
"""Helper to create test images with specific dimensions"""
return torch.rand(batch_size, height, width, channels)
def test_no_image2_passthrough(self):
"""Test that when image2 is None, image1 is returned unchanged"""
node = ImageStitch()
image1 = self.create_test_image()
result = node.stitch(image1, "right", True, 0, "white", image2=None)
assert len(result) == 1
assert torch.equal(result[0], image1)
def test_basic_horizontal_stitch_right(self):
"""Test basic horizontal stitching to the right"""
node = ImageStitch()
image1 = self.create_test_image(height=32, width=32)
image2 = self.create_test_image(height=32, width=24)
result = node.stitch(image1, "right", False, 0, "white", image2)
assert result[0].shape == (1, 32, 56, 3) # 32 + 24 width
def test_basic_horizontal_stitch_left(self):
"""Test basic horizontal stitching to the left"""
node = ImageStitch()
image1 = self.create_test_image(height=32, width=32)
image2 = self.create_test_image(height=32, width=24)
result = node.stitch(image1, "left", False, 0, "white", image2)
assert result[0].shape == (1, 32, 56, 3) # 24 + 32 width
def test_basic_vertical_stitch_down(self):
"""Test basic vertical stitching downward"""
node = ImageStitch()
image1 = self.create_test_image(height=32, width=32)
image2 = self.create_test_image(height=24, width=32)
result = node.stitch(image1, "down", False, 0, "white", image2)
assert result[0].shape == (1, 56, 32, 3) # 32 + 24 height
def test_basic_vertical_stitch_up(self):
"""Test basic vertical stitching upward"""
node = ImageStitch()
image1 = self.create_test_image(height=32, width=32)
image2 = self.create_test_image(height=24, width=32)
result = node.stitch(image1, "up", False, 0, "white", image2)
assert result[0].shape == (1, 56, 32, 3) # 24 + 32 height
def test_size_matching_horizontal(self):
"""Test size matching for horizontal concatenation"""
node = ImageStitch()
image1 = self.create_test_image(height=64, width=64)
        image2 = self.create_test_image(height=32, width=32)  # Smaller image, same 1:1 aspect ratio
result = node.stitch(image1, "right", True, 0, "white", image2)
# image2 should be resized to match image1's height (64) with preserved aspect ratio
expected_width = 64 + 64 # original + resized (32*64/32 = 64)
assert result[0].shape == (1, 64, expected_width, 3)
def test_size_matching_vertical(self):
"""Test size matching for vertical concatenation"""
node = ImageStitch()
image1 = self.create_test_image(height=64, width=64)
image2 = self.create_test_image(height=32, width=32)
result = node.stitch(image1, "down", True, 0, "white", image2)
# image2 should be resized to match image1's width (64) with preserved aspect ratio
expected_height = 64 + 64 # original + resized (32*64/32 = 64)
assert result[0].shape == (1, expected_height, 64, 3)
def test_padding_for_mismatched_heights_horizontal(self):
"""Test padding when heights don't match in horizontal concatenation"""
node = ImageStitch()
image1 = self.create_test_image(height=64, width=32)
image2 = self.create_test_image(height=48, width=24) # Shorter height
result = node.stitch(image1, "right", False, 0, "white", image2)
# Both images should be padded to height 64
assert result[0].shape == (1, 64, 56, 3) # 32 + 24 width, max(64,48) height
def test_padding_for_mismatched_widths_vertical(self):
"""Test padding when widths don't match in vertical concatenation"""
node = ImageStitch()
image1 = self.create_test_image(height=32, width=64)
image2 = self.create_test_image(height=24, width=48) # Narrower width
result = node.stitch(image1, "down", False, 0, "white", image2)
# Both images should be padded to width 64
assert result[0].shape == (1, 56, 64, 3) # 32 + 24 height, max(64,48) width
def test_spacing_horizontal(self):
"""Test spacing addition in horizontal concatenation"""
node = ImageStitch()
image1 = self.create_test_image(height=32, width=32)
image2 = self.create_test_image(height=32, width=24)
spacing_width = 16
result = node.stitch(image1, "right", False, spacing_width, "white", image2)
# Expected width: 32 + 16 (spacing) + 24 = 72
assert result[0].shape == (1, 32, 72, 3)
def test_spacing_vertical(self):
"""Test spacing addition in vertical concatenation"""
node = ImageStitch()
image1 = self.create_test_image(height=32, width=32)
image2 = self.create_test_image(height=24, width=32)
spacing_width = 16
result = node.stitch(image1, "down", False, spacing_width, "white", image2)
# Expected height: 32 + 16 (spacing) + 24 = 72
assert result[0].shape == (1, 72, 32, 3)
def test_spacing_color_values(self):
"""Test that spacing colors are applied correctly"""
node = ImageStitch()
image1 = self.create_test_image(height=32, width=32)
image2 = self.create_test_image(height=32, width=32)
# Test white spacing
result_white = node.stitch(image1, "right", False, 16, "white", image2)
# Check that spacing region contains white values (close to 1.0)
spacing_region = result_white[0][:, :, 32:48, :] # Middle 16 pixels
assert torch.all(spacing_region >= 0.9) # Should be close to white
# Test black spacing
result_black = node.stitch(image1, "right", False, 16, "black", image2)
spacing_region = result_black[0][:, :, 32:48, :]
assert torch.all(spacing_region <= 0.1) # Should be close to black
def test_odd_spacing_width_made_even(self):
"""Test that odd spacing widths are made even"""
node = ImageStitch()
image1 = self.create_test_image(height=32, width=32)
image2 = self.create_test_image(height=32, width=32)
# Use odd spacing width
result = node.stitch(image1, "right", False, 15, "white", image2)
# Should be made even (16), so total width = 32 + 16 + 32 = 80
assert result[0].shape == (1, 32, 80, 3)
def test_batch_size_matching(self):
"""Test that different batch sizes are handled correctly"""
node = ImageStitch()
image1 = self.create_test_image(batch_size=2, height=32, width=32)
image2 = self.create_test_image(batch_size=1, height=32, width=32)
result = node.stitch(image1, "right", False, 0, "white", image2)
# Should match larger batch size
assert result[0].shape == (2, 32, 64, 3)
def test_channel_matching_rgb_to_rgba(self):
"""Test that channel differences are handled (RGB + alpha)"""
node = ImageStitch()
image1 = self.create_test_image(channels=3) # RGB
image2 = self.create_test_image(channels=4) # RGBA
result = node.stitch(image1, "right", False, 0, "white", image2)
# Should have 4 channels (RGBA)
assert result[0].shape[-1] == 4
def test_channel_matching_rgba_to_rgb(self):
"""Test that channel differences are handled (RGBA + RGB)"""
node = ImageStitch()
image1 = self.create_test_image(channels=4) # RGBA
image2 = self.create_test_image(channels=3) # RGB
result = node.stitch(image1, "right", False, 0, "white", image2)
# Should have 4 channels (RGBA)
assert result[0].shape[-1] == 4
def test_all_color_options(self):
"""Test all available color options"""
node = ImageStitch()
image1 = self.create_test_image(height=32, width=32)
image2 = self.create_test_image(height=32, width=32)
colors = ["white", "black", "red", "green", "blue"]
for color in colors:
result = node.stitch(image1, "right", False, 16, color, image2)
assert result[0].shape == (1, 32, 80, 3) # Basic shape check
def test_all_directions(self):
"""Test all direction options"""
node = ImageStitch()
image1 = self.create_test_image(height=32, width=32)
image2 = self.create_test_image(height=32, width=32)
directions = ["right", "left", "up", "down"]
for direction in directions:
result = node.stitch(image1, direction, False, 0, "white", image2)
            # Bind the conditional to the expected shape; chaining it onto the
            # assert would make the up/down branch vacuously true.
            expected = (1, 32, 64, 3) if direction in ["right", "left"] else (1, 64, 32, 3)
            assert result[0].shape == expected
def test_batch_size_channel_spacing_integration(self):
"""Test integration of batch matching, channel matching, size matching, and spacings"""
node = ImageStitch()
image1 = self.create_test_image(batch_size=2, height=64, width=48, channels=3)
image2 = self.create_test_image(batch_size=1, height=32, width=32, channels=4)
result = node.stitch(image1, "right", True, 8, "red", image2)
# Should handle: batch matching, size matching, channel matching, spacing
assert result[0].shape[0] == 2 # Batch size matched
assert result[0].shape[-1] == 4 # Channels matched to max
assert result[0].shape[1] == 64 # Height from image1 (size matching)
# Width should be: 48 + 8 (spacing) + resized_image2_width
        expected_image2_width = int(64 * (32 / 32))  # new_height * (width / height) = 64
expected_total_width = 48 + 8 + expected_image2_width
assert result[0].shape[2] == expected_total_width