Compare commits
14 Commits
openapi-sp...v0.3.36
| Author | SHA1 | Date |
|---|---|---|
| | ad3bd8aa49 | |
| | 5a87757ef9 | |
| | 464aece92b | |
| | 0b50d4c0db | |
| | 30b2eb8a93 | |
| | f85c08df06 | |
| | 4202e956a0 | |
| | b838c36720 | |
| | fc39184ea9 | |
| | ded60c33a0 | |
| | 8bb858e4d3 | |
| | 57893c843f | |
| | 65da29aaa9 | |
| | 10024a38ea | |
.github/workflows/openapi-validation.yml (vendored, 49 changes, file deleted)
@@ -1,49 +0,0 @@
-name: Validate OpenAPI
-
-on:
-  push:
-    branches: [ master ]
-    paths:
-      - 'openapi.yaml'
-  pull_request:
-    branches: [ master ]
-    paths:
-      - 'openapi.yaml'
-
-jobs:
-  openapi-check:
-    runs-on: ubuntu-latest
-
-    # Service containers to run with `runner-job`
-    services:
-      # Label used to access the service container
-      swagger-editor:
-        # Docker Hub image
-        image: swaggerapi/swagger-editor
-        ports:
-          # Maps port 8080 on service container to the host 80
-          - 80:8080
-
-    steps:
-      - name: Checkout code
-        uses: actions/checkout@v4
-
-      - name: Validate OpenAPI definition
-        uses: swaggerexpert/swagger-editor-validate@v1
-        with:
-          definition-file: openapi.yaml
-          swagger-editor-url: http://localhost/
-          default-timeout: 20000
-
-      - name: Set up Python
-        uses: actions/setup-python@v4
-        with:
-          python-version: '3.11'
-
-      - name: Install test dependencies
-        run: |
-          pip install -r tests-api/requirements.txt
-
-      - name: Run OpenAPI spec validation tests
-        run: |
-          pytest tests-api/test_spec_validation.py -v
.gitignore (vendored, 1 change)
@@ -21,5 +21,6 @@ venv/
 *.log
 web_custom_versions/
 .DS_Store
+openapi.yaml
 filtered-openapi.yaml
 uv.lock
comfy/cli_args.py
@@ -88,6 +88,7 @@ parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE"
 
 parser.add_argument("--oneapi-device-selector", type=str, default=None, metavar="SELECTOR_STRING", help="Sets the oneAPI device(s) this instance will use.")
 parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize default when loading models with Intel's Extension for Pytorch.")
+parser.add_argument("--supports-fp8-compute", action="store_true", help="ComfyUI will act like if the device supports fp8 compute.")
 
 class LatentPreviewMethod(enum.Enum):
     NoPreviews = "none"
comfy/ldm/chroma/model.py
@@ -163,7 +163,7 @@ class Chroma(nn.Module):
         distil_guidance = timestep_embedding(guidance.detach().clone(), 16).to(img.device, img.dtype)
 
         # get all modulation index
-        modulation_index = timestep_embedding(torch.arange(mod_index_length), 32).to(img.device, img.dtype)
+        modulation_index = timestep_embedding(torch.arange(mod_index_length, device=img.device), 32).to(img.device, img.dtype)
         # we need to broadcast the modulation index here so each batch has all of the index
         modulation_index = modulation_index.unsqueeze(0).repeat(img.shape[0], 1, 1).to(img.device, img.dtype)
         # and we need to broadcast timestep and guidance along too
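A side note on the one-line Chroma fix above: passing `device=` to `torch.arange` allocates the index tensor directly on the target device instead of building it on the CPU and copying it across on every forward pass. A minimal sketch of the difference (the length is a hypothetical stand-in for `mod_index_length`):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
n = 344  # hypothetical stand-in for mod_index_length

idx_copy = torch.arange(n).to(device)        # CPU allocation, then a device copy
idx_direct = torch.arange(n, device=device)  # allocated directly on the device
assert torch.equal(idx_copy, idx_direct)
```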
comfy/ldm/modules/attention.py
@@ -20,8 +20,11 @@ if model_management.xformers_enabled():
 if model_management.sage_attention_enabled():
     try:
         from sageattention import sageattn
-    except ModuleNotFoundError:
-        logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
+    except ModuleNotFoundError as e:
+        if e.name == "sageattention":
+            logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
+        else:
+            raise e
         exit(-1)
 
 if model_management.flash_attention_enabled():
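The reworked handler only prints the install hint when `sageattention` itself is missing; a `ModuleNotFoundError` raised by one of sageattention's own missing dependencies is re-raised instead of being masked. The exception's `name` attribute makes the distinction, as this minimal sketch shows (the module name is deliberately fake):

```python
try:
    import not_a_real_module  # hypothetical missing package
except ModuleNotFoundError as e:
    # .name identifies exactly which module failed to import
    assert e.name == "not_a_real_module"
```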
comfy/ldm/wan/model.py
@@ -635,7 +635,7 @@ class VaceWanModel(WanModel):
         t,
         context,
         vace_context,
-        vace_strength=1.0,
+        vace_strength,
         clip_fea=None,
         freqs=None,
         transformer_options={},
@@ -661,8 +661,11 @@ class VaceWanModel(WanModel):
             context = torch.concat([context_clip, context], dim=1)
             context_img_len = clip_fea.shape[-2]
 
+        orig_shape = list(vace_context.shape)
+        vace_context = vace_context.movedim(0, 1).reshape([-1] + orig_shape[2:])
         c = self.vace_patch_embedding(vace_context.float()).to(vace_context.dtype)
         c = c.flatten(2).transpose(1, 2)
+        c = list(c.split(orig_shape[0], dim=0))
 
         # arguments
         x_orig = x
@@ -682,8 +685,9 @@ class VaceWanModel(WanModel):
 
             ii = self.vace_layers_mapping.get(i, None)
             if ii is not None:
-                c_skip, c = self.vace_blocks[ii](c, x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
-                x += c_skip * vace_strength
+                for iii in range(len(c)):
+                    c_skip, c[iii] = self.vace_blocks[ii](c[iii], x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
+                    x += c_skip * vace_strength[iii]
                 del c_skip
         # head
         x = self.head(x, e)
comfy/model_base.py
@@ -1062,20 +1062,25 @@ class WAN21_Vace(WAN21):
         vace_frames = kwargs.get("vace_frames", None)
         if vace_frames is None:
             noise_shape[1] = 32
-            vace_frames = torch.zeros(noise_shape, device=noise.device, dtype=noise.dtype)
-
-        for i in range(0, vace_frames.shape[1], 16):
-            vace_frames = vace_frames.clone()
-            vace_frames[:, i:i + 16] = self.process_latent_in(vace_frames[:, i:i + 16])
+            vace_frames = [torch.zeros(noise_shape, device=noise.device, dtype=noise.dtype)]
 
         mask = kwargs.get("vace_mask", None)
         if mask is None:
             noise_shape[1] = 64
-            mask = torch.ones(noise_shape, device=noise.device, dtype=noise.dtype)
+            mask = [torch.ones(noise_shape, device=noise.device, dtype=noise.dtype)] * len(vace_frames)
 
-        out['vace_context'] = comfy.conds.CONDRegular(torch.cat([vace_frames.to(noise), mask.to(noise)], dim=1))
+        vace_frames_out = []
+        for j in range(len(vace_frames)):
+            vf = vace_frames[j].clone()
+            for i in range(0, vf.shape[1], 16):
+                vf[:, i:i + 16] = self.process_latent_in(vf[:, i:i + 16])
+            vf = torch.cat([vf, mask[j]], dim=1)
+            vace_frames_out.append(vf)
 
-        vace_strength = kwargs.get("vace_strength", 1.0)
+        vace_frames = torch.stack(vace_frames_out, dim=1)
+        out['vace_context'] = comfy.conds.CONDRegular(vace_frames)
+
+        vace_strength = kwargs.get("vace_strength", [1.0] * len(vace_frames_out))
         out['vace_strength'] = comfy.conds.CONDConstant(vace_strength)
         return out
 
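The conditioning contract changes here from a single tensor to parallel lists: one latent, one mask, and one strength per VACE reference. A small sketch of the resulting shape convention, with hypothetical sizes:

```python
import torch

# Hypothetical sizes: two VACE references, batch 1, 32-channel latents, 64-channel masks.
vace_frames = [torch.zeros(1, 32, 4, 8, 8) for _ in range(2)]
mask = [torch.ones(1, 64, 4, 8, 8)] * len(vace_frames)
vace_strength = [1.0, 0.5]  # one strength per reference

# Each latent is concatenated with its mask along channels, then the references
# are stacked on a new dim=1, giving [batch, n_refs, C, T, H, W].
out = torch.stack([torch.cat([vf, m], dim=1) for vf, m in zip(vace_frames, mask)], dim=1)
assert out.shape == (1, 2, 96, 4, 8, 8)
```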
comfy/model_management.py
@@ -1257,6 +1257,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
     return False
 
 def supports_fp8_compute(device=None):
+    if args.supports_fp8_compute:
+        return True
+
     if not is_nvidia():
         return False
 
comfy_api/torch_helpers/__init__.py (new file, 5 changes)
@@ -0,0 +1,5 @@
+from .torch_compile import set_torch_compile_wrapper
+
+__all__ = [
+    "set_torch_compile_wrapper",
+]
comfy_api/torch_helpers/torch_compile.py (new file, 69 changes)
@@ -0,0 +1,69 @@
+from __future__ import annotations
+import torch
+
+import comfy.utils
+from comfy.patcher_extension import WrappersMP
+from typing import TYPE_CHECKING, Callable, Optional
+if TYPE_CHECKING:
+    from comfy.model_patcher import ModelPatcher
+    from comfy.patcher_extension import WrapperExecutor
+
+
+COMPILE_KEY = "torch.compile"
+TORCH_COMPILE_KWARGS = "torch_compile_kwargs"
+
+
+def apply_torch_compile_factory(compiled_module_dict: dict[str, Callable]) -> Callable:
+    '''
+    Create a wrapper that will refer to the compiled_diffusion_model.
+    '''
+    def apply_torch_compile_wrapper(executor: WrapperExecutor, *args, **kwargs):
+        try:
+            orig_modules = {}
+            for key, value in compiled_module_dict.items():
+                orig_modules[key] = comfy.utils.get_attr(executor.class_obj, key)
+                comfy.utils.set_attr(executor.class_obj, key, value)
+            return executor(*args, **kwargs)
+        finally:
+            for key, value in orig_modules.items():
+                comfy.utils.set_attr(executor.class_obj, key, value)
+    return apply_torch_compile_wrapper
+
+
+def set_torch_compile_wrapper(model: ModelPatcher, backend: str, options: Optional[dict[str,str]]=None,
+                              mode: Optional[str]=None, fullgraph=False, dynamic: Optional[bool]=None,
+                              keys: list[str]=["diffusion_model"], *args, **kwargs):
+    '''
+    Perform torch.compile that will be applied at sample time for either the whole model or specific params of the BaseModel instance.
+
+    When keys is None, it will default to using ["diffusion_model"], compiling the whole diffusion_model.
+    When a list of keys is provided, it will perform torch.compile on only the selected modules.
+    '''
+    # clear out any other torch.compile wrappers
+    model.remove_wrappers_with_key(WrappersMP.APPLY_MODEL, COMPILE_KEY)
+    # if no keys, default to 'diffusion_model'
+    if not keys:
+        keys = ["diffusion_model"]
+    # create kwargs dict that can be referenced later
+    compile_kwargs = {
+        "backend": backend,
+        "options": options,
+        "mode": mode,
+        "fullgraph": fullgraph,
+        "dynamic": dynamic,
+    }
+    # get a dict of compiled keys
+    compiled_modules = {}
+    for key in keys:
+        compiled_modules[key] = torch.compile(
+            model=model.get_model_object(key),
+            **compile_kwargs,
+        )
+    # add torch.compile wrapper
+    wrapper_func = apply_torch_compile_factory(
+        compiled_module_dict=compiled_modules,
+    )
+    # store wrapper to run on BaseModel's apply_model function
+    model.add_wrapper_with_key(WrappersMP.APPLY_MODEL, COMPILE_KEY, wrapper_func)
+    # keep compile kwargs for reference
+    model.model_options[TORCH_COMPILE_KWARGS] = compile_kwargs
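For reference, the intended call pattern appears later in this diff in the `TorchCompileModel` node. A minimal usage sketch, assuming `model` is an existing `ModelPatcher`:

```python
from comfy_api.torch_helpers import set_torch_compile_wrapper

m = model.clone()  # 'model' is assumed to be a ModelPatcher instance
# Compile only the diffusion_model submodule; the registered wrapper swaps the
# compiled module in around apply_model calls at sample time, then restores it.
set_torch_compile_wrapper(model=m, backend="inductor", keys=["diffusion_model"])
```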
comfy_extras/nodes_load_3d.py
@@ -16,7 +16,7 @@ class Load3D():
 
         os.makedirs(input_dir, exist_ok=True)
 
-        files = [normalize_path(os.path.join("3d", f)) for f in os.listdir(input_dir) if f.endswith(('.gltf', '.glb', '.obj', '.mtl', '.fbx', '.stl'))]
+        files = [normalize_path(os.path.join("3d", f)) for f in os.listdir(input_dir) if f.endswith(('.gltf', '.glb', '.obj', '.fbx', '.stl'))]
 
         return {"required": {
             "model_file": (sorted(files), {"file_upload": True}),
comfy_extras/nodes_torch_compile.py
@@ -1,4 +1,5 @@
 import torch
+from comfy_api.torch_helpers import set_torch_compile_wrapper
 
 
 class TorchCompileModel:
     @classmethod
@@ -14,7 +15,7 @@ class TorchCompileModel:
 
     def patch(self, model, backend):
         m = model.clone()
-        m.add_object_patch("diffusion_model", torch.compile(model=m.get_model_object("diffusion_model"), backend=backend))
+        set_torch_compile_wrapper(model=m, backend=backend)
         return (m, )
 
 NODE_CLASS_MAPPINGS = {
comfy_extras/nodes_wan.py
@@ -268,8 +268,9 @@ class WanVaceToVideo:
             trim_latent = reference_image.shape[2]
 
         mask = mask.unsqueeze(0)
-        positive = node_helpers.conditioning_set_values(positive, {"vace_frames": control_video_latent, "vace_mask": mask, "vace_strength": strength})
-        negative = node_helpers.conditioning_set_values(negative, {"vace_frames": control_video_latent, "vace_mask": mask, "vace_strength": strength})
+        positive = node_helpers.conditioning_set_values(positive, {"vace_frames": [control_video_latent], "vace_mask": [mask], "vace_strength": [strength]}, append=True)
+        negative = node_helpers.conditioning_set_values(negative, {"vace_frames": [control_video_latent], "vace_mask": [mask], "vace_strength": [strength]}, append=True)
 
         latent = torch.zeros([batch_size, 16, latent_length, height // 8, width // 8], device=comfy.model_management.intermediate_device())
         out_latent = {}
comfyui_version.py
@@ -1,3 +1,3 @@
 # This file is automatically generated by the build process when version is
 # updated in pyproject.toml.
-__version__ = "0.3.34"
+__version__ = "0.3.36"
execution.py
@@ -909,7 +909,6 @@ class PromptQueue:
         self.currently_running = {}
         self.history = {}
         self.flags = {}
-        server.prompt_queue = self
 
     def put(self, item):
         with self.mutex:
@@ -954,6 +953,7 @@ class PromptQueue:
                 self.history[prompt[1]].update(history_result)
             self.server.queue_updated()
 
+    # Note: slow
    def get_current_queue(self):
        with self.mutex:
            out = []
@@ -961,6 +961,13 @@ class PromptQueue:
                 out += [x]
             return (out, copy.deepcopy(self.queue))
 
+    # read-safe as long as queue items are immutable
+    def get_current_queue_volatile(self):
+        with self.mutex:
+            running = [x for x in self.currently_running.values()]
+            queued = copy.copy(self.queue)
+            return (running, queued)
+
     def get_tasks_remaining(self):
         with self.mutex:
             return len(self.queue) + len(self.currently_running)
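The companion change in server.py (further down) switches the `/queue` route to this volatile variant. The motivation is the cost of `copy.deepcopy` while holding the mutex (the old method is now flagged `# Note: slow`); a shallow copy duplicates only the outer list, which is safe for readers provided the queue items themselves are treated as immutable. A self-contained illustration of the difference:

```python
import copy

# A stand-in queue: many items, each with nested structure.
queue = [(i, f"prompt-{i}", {"nodes": list(range(100))}) for i in range(1000)]

shallow = copy.copy(queue)   # copies only the outer list; items are shared
deep = copy.deepcopy(queue)  # recursively clones every tuple, dict, and list

assert shallow[0] is queue[0]
assert deep[0] is not queue[0]
```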
main.py (3 changes)
@@ -260,7 +260,6 @@ def start_comfyui(asyncio_loop=None):
         asyncio_loop = asyncio.new_event_loop()
         asyncio.set_event_loop(asyncio_loop)
     prompt_server = server.PromptServer(asyncio_loop)
-    q = execution.PromptQueue(prompt_server)
 
     hook_breaker_ac10a0.save_functions()
     nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes, init_api_nodes=not args.disable_api_nodes)
@@ -271,7 +270,7 @@ def start_comfyui(asyncio_loop=None):
     prompt_server.add_routes()
     hijack_progress(prompt_server)
 
-    threading.Thread(target=prompt_worker, daemon=True, args=(q, prompt_server,)).start()
+    threading.Thread(target=prompt_worker, daemon=True, args=(prompt_server.prompt_queue, prompt_server,)).start()
 
     if args.quick_test_for_ci:
         exit(0)
node_helpers.py
@@ -5,12 +5,18 @@ from comfy.cli_args import args
 
 from PIL import ImageFile, UnidentifiedImageError
 
-def conditioning_set_values(conditioning, values={}):
+def conditioning_set_values(conditioning, values={}, append=False):
     c = []
     for t in conditioning:
         n = [t[0], t[1].copy()]
         for k in values:
-            n[1][k] = values[k]
+            val = values[k]
+            if append:
+                old_val = n[1].get(k, None)
+                if old_val is not None:
+                    val = old_val + val
+
+            n[1][k] = val
         c.append(n)
 
     return c
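With `append=True`, merging is plain `old_val + val`, so callers pass list-valued entries to accumulate them across calls, as the `WanVaceToVideo` change above and the `unCLIPConditioning` refactor below both do. A toy illustration (strings standing in for tensors):

```python
from node_helpers import conditioning_set_values  # as defined in the diff above

cond = [["model_out", {"vace_frames": ["frames_a"]}]]
cond = conditioning_set_values(cond, {"vace_frames": ["frames_b"]}, append=True)
assert cond[0][1]["vace_frames"] == ["frames_a", "frames_b"]
```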
nodes.py (11 changes)
@@ -1103,16 +1103,7 @@ class unCLIPConditioning:
         if strength == 0:
             return (conditioning, )
 
-        c = []
-        for t in conditioning:
-            o = t[1].copy()
-            x = {"clip_vision_output": clip_vision_output, "strength": strength, "noise_augmentation": noise_augmentation}
-            if "unclip_conditioning" in o:
-                o["unclip_conditioning"] = o["unclip_conditioning"][:] + [x]
-            else:
-                o["unclip_conditioning"] = [x]
-            n = [t[0], o]
-            c.append(n)
+        c = node_helpers.conditioning_set_values(conditioning, {"unclip_conditioning": [{"clip_vision_output": clip_vision_output, "strength": strength, "noise_augmentation": noise_augmentation}]}, append=True)
         return (c, )
 
 class GLIGENLoader:
openapi.yaml (904 changes, file deleted)
@@ -1,904 +0,0 @@
-openapi: 3.0.3
-info:
-  title: ComfyUI API
-  description: |
-    API for ComfyUI - A powerful and modular UI for Stable Diffusion.
-
-    This API allows you to interact with ComfyUI programmatically, including:
-    - Submitting workflows for execution
-    - Managing the execution queue
-    - Retrieving generated images
-    - Managing models
-    - Retrieving node information
-  version: 1.0.0
-  license:
-    name: GNU General Public License v3.0
-    url: https://github.com/comfyanonymous/ComfyUI/blob/master/LICENSE
-
-servers:
-  - url: /
-    description: Default ComfyUI server
-
-tags:
-  - name: workflow
-    description: Workflow execution and management
-  - name: queue
-    description: Queue management
-  - name: image
-    description: Image handling
-  - name: node
-    description: Node information
-  - name: model
-    description: Model management
-  - name: system
-    description: System information
-  - name: internal
-    description: Internal API routes
-
-paths:
-  /prompt:
-    get:
-      tags:
-        - workflow
-      summary: Get information about current prompt execution
-      description: Returns information about the current prompt in the execution queue
-      operationId: getPromptInfo
-      responses:
-        '200':
-          description: Success
-          content:
-            application/json:
-              schema:
-                $ref: '#/components/schemas/PromptInfo'
-    post:
-      tags:
-        - workflow
-      summary: Submit a workflow for execution
-      description: |
-        Submit a workflow to be executed by the backend.
-        The workflow is a JSON object describing the nodes and their connections.
-      operationId: executePrompt
-      requestBody:
-        required: true
-        content:
-          application/json:
-            schema:
-              $ref: '#/components/schemas/PromptRequest'
-      responses:
-        '200':
-          description: Success - Prompt accepted
-          content:
-            application/json:
-              schema:
-                $ref: '#/components/schemas/PromptResponse'
-        '400':
-          description: Invalid prompt
-          content:
-            application/json:
-              schema:
-                $ref: '#/components/schemas/ErrorResponse'
-
-  /queue:
-    get:
-      tags:
-        - queue
-      summary: Get queue information
-      description: Returns information about running and pending items in the queue
-      operationId: getQueueInfo
-      responses:
-        '200':
-          description: Success
-          content:
-            application/json:
-              schema:
-                $ref: '#/components/schemas/QueueInfo'
-    post:
-      tags:
-        - queue
-      summary: Manage queue
-      description: Clear the queue or delete specific items
-      operationId: manageQueue
-      requestBody:
-        required: true
-        content:
-          application/json:
-            schema:
-              type: object
-              properties:
-                clear:
-                  type: boolean
-                  description: If true, clears the entire queue
-                delete:
-                  type: array
-                  description: Array of prompt IDs to delete from the queue
-                  items:
-                    type: string
-                    format: uuid
-      responses:
-        '200':
-          description: Success
-
-  /interrupt:
-    post:
-      tags:
-        - workflow
-      summary: Interrupt the current execution
-      description: Interrupts the currently running workflow execution
-      operationId: interruptExecution
-      responses:
-        '200':
-          description: Success
-
-  /free:
-    post:
-      tags:
-        - system
-      summary: Free resources
-      description: Unload models and/or free memory
-      operationId: freeResources
-      requestBody:
-        required: true
-        content:
-          application/json:
-            schema:
-              type: object
-              properties:
-                unload_models:
-                  type: boolean
-                  description: If true, unloads models from memory
-                free_memory:
-                  type: boolean
-                  description: If true, frees GPU memory
-      responses:
-        '200':
-          description: Success
-
-  /history:
-    get:
-      tags:
-        - workflow
-      summary: Get execution history
-      description: Returns the history of executed workflows
-      operationId: getHistory
-      parameters:
-        - name: max_items
-          in: query
-          description: Maximum number of history items to return
-          required: false
-          schema:
-            type: integer
-            format: int32
-      responses:
-        '200':
-          description: Success
-          content:
-            application/json:
-              schema:
-                type: array
-                items:
-                  $ref: '#/components/schemas/HistoryItem'
-    post:
-      tags:
-        - workflow
-      summary: Manage history
-      description: Clear history or delete specific items
-      operationId: manageHistory
-      requestBody:
-        required: true
-        content:
-          application/json:
-            schema:
-              type: object
-              properties:
-                clear:
-                  type: boolean
-                  description: If true, clears the entire history
-                delete:
-                  type: array
-                  description: Array of prompt IDs to delete from history
-                  items:
-                    type: string
-                    format: uuid
-      responses:
-        '200':
-          description: Success
-
-  /history/{prompt_id}:
-    get:
-      tags:
-        - workflow
-      summary: Get specific history item
-      description: Returns a specific history item by ID
-      operationId: getHistoryItem
-      parameters:
-        - name: prompt_id
-          in: path
-          description: ID of the prompt to retrieve
-          required: true
-          schema:
-            type: string
-            format: uuid
-      responses:
-        '200':
-          description: Success
-          content:
-            application/json:
-              schema:
-                $ref: '#/components/schemas/HistoryItem'
-
-  /object_info:
-    get:
-      tags:
-        - node
-      summary: Get all node information
-      description: Returns information about all available nodes
-      operationId: getNodeInfo
-      responses:
-        '200':
-          description: Success
-          content:
-            application/json:
-              schema:
-                type: object
-                additionalProperties:
-                  $ref: '#/components/schemas/NodeInfo'
-
-  /object_info/{node_class}:
-    get:
-      tags:
-        - node
-      summary: Get specific node information
-      description: Returns information about a specific node class
-      operationId: getNodeClassInfo
-      parameters:
-        - name: node_class
-          in: path
-          description: Name of the node class
-          required: true
-          schema:
-            type: string
-      responses:
-        '200':
-          description: Success
-          content:
-            application/json:
-              schema:
-                type: object
-                additionalProperties:
-                  $ref: '#/components/schemas/NodeInfo'
-
-  /upload/image:
-    post:
-      tags:
-        - image
-      summary: Upload an image
-      description: Uploads an image to the server
-      operationId: uploadImage
-      requestBody:
-        required: true
-        content:
-          multipart/form-data:
-            schema:
-              type: object
-              properties:
-                image:
-                  type: string
-                  format: binary
-                  description: The image file to upload
-                overwrite:
-                  type: string
-                  description: Whether to overwrite if file exists (true/false)
-                type:
-                  type: string
-                  enum: [input, temp, output]
-                  description: Type of directory to store the image in
-                subfolder:
-                  type: string
-                  description: Subfolder to store the image in
-      responses:
-        '200':
-          description: Success
-          content:
-            application/json:
-              schema:
-                type: object
-                properties:
-                  name:
-                    type: string
-                    description: Filename of the uploaded image
-                  subfolder:
-                    type: string
-                    description: Subfolder the image was stored in
-                  type:
-                    type: string
-                    description: Type of directory the image was stored in
-        '400':
-          description: Bad request
-
-  /upload/mask:
-    post:
-      tags:
-        - image
-      summary: Upload a mask for an image
-      description: Uploads a mask image and applies it to a referenced original image
-      operationId: uploadMask
-      requestBody:
-        required: true
-        content:
-          multipart/form-data:
-            schema:
-              type: object
-              properties:
-                image:
-                  type: string
-                  format: binary
-                  description: The mask image file to upload
-                original_ref:
-                  type: string
-                  description: JSON string containing reference to the original image
-      responses:
-        '200':
-          description: Success
-          content:
-            application/json:
-              schema:
-                type: object
-                properties:
-                  name:
-                    type: string
-                    description: Filename of the uploaded mask
-                  subfolder:
-                    type: string
-                    description: Subfolder the mask was stored in
-                  type:
-                    type: string
-                    description: Type of directory the mask was stored in
-        '400':
-          description: Bad request
-
-  /view:
-    get:
-      tags:
-        - image
-      summary: View an image
-      description: Retrieves an image from the server
-      operationId: viewImage
-      parameters:
-        - name: filename
-          in: query
-          description: Name of the file to retrieve
-          required: true
-          schema:
-            type: string
-        - name: type
-          in: query
-          description: Type of directory to retrieve from
-          required: false
-          schema:
-            type: string
-            enum: [input, temp, output]
-            default: output
-        - name: subfolder
-          in: query
-          description: Subfolder to retrieve from
-          required: false
-          schema:
-            type: string
-        - name: preview
-          in: query
-          description: Preview options (format;quality)
-          required: false
-          schema:
-            type: string
-        - name: channel
-          in: query
-          description: Channel to retrieve (rgb, a, rgba)
-          required: false
-          schema:
-            type: string
-            enum: [rgb, a, rgba]
-            default: rgba
-      responses:
-        '200':
-          description: Success
-          content:
-            image/*:
-              schema:
-                type: string
-                format: binary
-        '400':
-          description: Bad request
-        '404':
-          description: File not found
-
-  /view_metadata/{folder_name}:
-    get:
-      tags:
-        - model
-      summary: View model metadata
-      description: Retrieves metadata from a safetensors file
-      operationId: viewModelMetadata
-      parameters:
-        - name: folder_name
-          in: path
-          description: Name of the model folder
-          required: true
-          schema:
-            type: string
-        - name: filename
-          in: query
-          description: Name of the safetensors file
-          required: true
-          schema:
-            type: string
-      responses:
-        '200':
-          description: Success
-          content:
-            application/json:
-              schema:
-                type: object
-        '404':
-          description: File not found
-
-  /models:
-    get:
-      tags:
-        - model
-      summary: Get model types
-      description: Returns a list of available model types
-      operationId: getModelTypes
-      responses:
-        '200':
-          description: Success
-          content:
-            application/json:
-              schema:
-                type: array
-                items:
-                  type: string
-
-  /models/{folder}:
-    get:
-      tags:
-        - model
-      summary: Get models of a specific type
-      description: Returns a list of available models of a specific type
-      operationId: getModels
-      parameters:
-        - name: folder
-          in: path
-          description: Model type folder
-          required: true
-          schema:
-            type: string
-      responses:
-        '200':
-          description: Success
-          content:
-            application/json:
-              schema:
-                type: array
-                items:
-                  type: string
-        '404':
-          description: Folder not found
-
-  /embeddings:
-    get:
-      tags:
-        - model
-      summary: Get embeddings
-      description: Returns a list of available embeddings
-      operationId: getEmbeddings
-      responses:
-        '200':
-          description: Success
-          content:
-            application/json:
-              schema:
-                type: array
-                items:
-                  type: string
-
-  /extensions:
-    get:
-      tags:
-        - system
-      summary: Get extensions
-      description: Returns a list of available extensions
-      operationId: getExtensions
-      responses:
-        '200':
-          description: Success
-          content:
-            application/json:
-              schema:
-                type: array
-                items:
-                  type: string
-
-  /system_stats:
-    get:
-      tags:
-        - system
-      summary: Get system statistics
-      description: Returns system information including RAM, VRAM, and ComfyUI version
-      operationId: getSystemStats
-      responses:
-        '200':
-          description: Success
-          content:
-            application/json:
-              schema:
-                $ref: '#/components/schemas/SystemStats'
-
-  /ws:
-    get:
-      tags:
-        - workflow
-      summary: WebSocket connection
-      description: |
-        Establishes a WebSocket connection for real-time communication.
-        This endpoint is used for receiving progress updates, status changes, and results from workflow executions.
-      operationId: webSocketConnect
-      parameters:
-        - name: clientId
-          in: query
-          description: Optional client ID for reconnection
-          required: false
-          schema:
-            type: string
-      responses:
-        '101':
-          description: Switching Protocols to WebSocket
-
-  /internal/logs:
-    get:
-      tags:
-        - internal
-      summary: Get logs
-      description: Returns system logs as a single string
-      operationId: getLogs
-      responses:
-        '200':
-          description: Success
-          content:
-            application/json:
-              schema:
-                type: string
-
-  /internal/logs/raw:
-    get:
-      tags:
-        - internal
-      summary: Get raw logs
-      description: Returns raw system logs with terminal size information
-      operationId: getRawLogs
-      responses:
-        '200':
-          description: Success
-          content:
-            application/json:
-              schema:
-                type: object
-                properties:
-                  entries:
-                    type: array
-                    items:
-                      type: object
-                      properties:
-                        t:
-                          type: string
-                          description: Timestamp
-                        m:
-                          type: string
-                          description: Message
-                  size:
-                    type: object
-                    properties:
-                      cols:
-                        type: integer
-                        description: Terminal columns
-                      rows:
-                        type: integer
-                        description: Terminal rows
-
-  /internal/logs/subscribe:
-    patch:
-      tags:
-        - internal
-      summary: Subscribe to logs
-      description: Subscribe or unsubscribe to log updates
-      operationId: subscribeToLogs
-      requestBody:
-        required: true
-        content:
-          application/json:
-            schema:
-              type: object
-              properties:
-                clientId:
-                  type: string
-                  description: Client ID
-                enabled:
-                  type: boolean
-                  description: Whether to enable or disable subscription
-      responses:
-        '200':
-          description: Success
-
-  /internal/folder_paths:
-    get:
-      tags:
-        - internal
-      summary: Get folder paths
-      description: Returns a map of folder names to their paths
-      operationId: getFolderPaths
-      responses:
-        '200':
-          description: Success
-          content:
-            application/json:
-              schema:
-                type: object
-                additionalProperties:
-                  type: string
-
-  /internal/files/{directory_type}:
-    get:
-      tags:
-        - internal
-      summary: Get files
-      description: Returns a list of files in a specific directory type
-      operationId: getFiles
-      parameters:
-        - name: directory_type
-          in: path
-          description: Type of directory (output, input, temp)
-          required: true
-          schema:
-            type: string
-            enum: [output, input, temp]
-      responses:
-        '200':
-          description: Success
-          content:
-            application/json:
-              schema:
-                type: array
-                items:
-                  type: string
-        '400':
-          description: Invalid directory type
-
-components:
-  schemas:
-    PromptRequest:
-      type: object
-      required:
-        - prompt
-      properties:
-        prompt:
-          type: object
-          description: The workflow graph to execute
-          additionalProperties: true
-        number:
-          type: number
-          description: Priority number for the queue (lower numbers have higher priority)
-        front:
-          type: boolean
-          description: If true, adds the prompt to the front of the queue
-        extra_data:
-          type: object
-          description: Extra data to be associated with the prompt
-          additionalProperties: true
-        client_id:
-          type: string
-          description: Client ID for attribution of the prompt
-
-    PromptResponse:
-      type: object
-      properties:
-        prompt_id:
-          type: string
-          format: uuid
-          description: Unique identifier for the prompt execution
-        number:
-          type: number
-          description: Priority number in the queue
-        node_errors:
-          type: object
-          description: Any errors in the nodes of the prompt
-          additionalProperties: true
-
-    ErrorResponse:
-      type: object
-      properties:
-        error:
-          type: object
-          properties:
-            type:
-              type: string
-              description: Error type
-            message:
-              type: string
-              description: Error message
-            details:
-              type: string
-              description: Detailed error information
-            extra_info:
-              type: object
-              description: Additional error information
-              additionalProperties: true
-        node_errors:
-          type: object
-          description: Node-specific errors
-          additionalProperties: true
-
-    PromptInfo:
-      type: object
-      properties:
-        exec_info:
-          type: object
-          properties:
-            queue_remaining:
-              type: integer
-              description: Number of items remaining in the queue
-
-    QueueInfo:
-      type: object
-      properties:
-        queue_running:
-          type: array
-          items:
-            type: object
-            description: Currently running items
-            additionalProperties: true
-        queue_pending:
-          type: array
-          items:
-            type: object
-            description: Pending items in the queue
-            additionalProperties: true
-
-    HistoryItem:
-      type: object
-      properties:
-        prompt_id:
-          type: string
-          format: uuid
-          description: Unique identifier for the prompt
-        prompt:
-          type: object
-          description: The workflow graph that was executed
-          additionalProperties: true
-        extra_data:
-          type: object
-          description: Additional data associated with the execution
-          additionalProperties: true
-        outputs:
-          type: object
-          description: Output data from the execution
-          additionalProperties: true
-
-    NodeInfo:
-      type: object
-      properties:
-        input:
-          type: object
-          description: Input specifications for the node
-          additionalProperties: true
-        input_order:
-          type: object
-          description: Order of inputs for display
-          additionalProperties:
-            type: array
-            items:
-              type: string
-        output:
-          type: array
-          items:
-            type: string
-          description: Output types of the node
-        output_is_list:
-          type: array
-          items:
-            type: boolean
-          description: Whether each output is a list
-        output_name:
-          type: array
-          items:
-            type: string
-          description: Names of the outputs
-        name:
-          type: string
-          description: Internal name of the node
-        display_name:
-          type: string
-          description: Display name of the node
-        description:
-          type: string
-          description: Description of the node
-        python_module:
-          type: string
-          description: Python module implementing the node
-        category:
-          type: string
-          description: Category of the node
-        output_node:
-          type: boolean
-          description: Whether this is an output node
-        output_tooltips:
-          type: array
-          items:
-            type: string
-          description: Tooltips for outputs
-        deprecated:
-          type: boolean
-          description: Whether the node is deprecated
-        experimental:
-          type: boolean
-          description: Whether the node is experimental
-        api_node:
-          type: boolean
-          description: Whether this is an API node
-
-    SystemStats:
-      type: object
-      properties:
-        system:
-          type: object
-          properties:
-            os:
-              type: string
-              description: Operating system
-            ram_total:
-              type: number
-              description: Total system RAM in bytes
-            ram_free:
-              type: number
-              description: Free system RAM in bytes
-            comfyui_version:
-              type: string
-              description: ComfyUI version
-            python_version:
-              type: string
-              description: Python version
-            pytorch_version:
-              type: string
-              description: PyTorch version
-            embedded_python:
-              type: boolean
-              description: Whether using embedded Python
-            argv:
-              type: array
-              items:
-                type: string
-              description: Command line arguments
-        devices:
-          type: array
-          items:
-            type: object
-            properties:
-              name:
-                type: string
-                description: Device name
-              type:
-                type: string
-                description: Device type
-              index:
-                type: integer
-                description: Device index
-              vram_total:
-                type: number
-                description: Total VRAM in bytes
-              vram_free:
-                type: number
-                description: Free VRAM in bytes
-              torch_vram_total:
-                type: number
-                description: Total VRAM as reported by PyTorch
-              torch_vram_free:
-                type: number
-                description: Free VRAM as reported by PyTorch
pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "ComfyUI"
-version = "0.3.34"
+version = "0.3.36"
 readme = "README.md"
 license = { file = "LICENSE" }
 requires-python = ">=3.9"
requirements.txt
@@ -1,5 +1,5 @@
-comfyui-frontend-package==1.19.9
-comfyui-workflow-templates==0.1.14
+comfyui-frontend-package==1.20.5
+comfyui-workflow-templates==0.1.18
 torch
 torchsde
 torchvision
server.py
@@ -29,6 +29,7 @@ import comfy.model_management
 import node_helpers
 from comfyui_version import __version__
 from app.frontend_management import FrontendManager
+
 from app.user_manager import UserManager
 from app.model_manager import ModelFileManager
 from app.custom_node_manager import CustomNodeManager
@@ -159,7 +160,7 @@ class PromptServer():
         self.custom_node_manager = CustomNodeManager()
         self.internal_routes = InternalRoutes(self)
         self.supports = ["custom_nodes_from_web"]
-        self.prompt_queue = None
+        self.prompt_queue = execution.PromptQueue(self)
         self.loop = loop
         self.messages = asyncio.Queue()
         self.client_session:Optional[aiohttp.ClientSession] = None
@@ -226,7 +227,7 @@ class PromptServer():
             return response
 
         @routes.get("/embeddings")
-        def get_embeddings(self):
+        def get_embeddings(request):
             embeddings = folder_paths.get_filename_list("embeddings")
             return web.json_response(list(map(lambda a: os.path.splitext(a)[0], embeddings)))
 
@@ -282,7 +283,6 @@ class PromptServer():
                     a.update(f.read())
                     b.update(image.file.read())
                    image.file.seek(0)
-                    f.close()
                     return a.hexdigest() == b.hexdigest()
             return False
 
@@ -621,7 +621,7 @@ class PromptServer():
         @routes.get("/queue")
         async def get_queue(request):
             queue_info = {}
-            current_queue = self.prompt_queue.get_current_queue()
+            current_queue = self.prompt_queue.get_current_queue_volatile()
             queue_info['queue_running'] = current_queue[0]
             queue_info['queue_pending'] = current_queue[1]
             return web.json_response(queue_info)
tests-api/README.md (file deleted)
@@ -1,74 +0,0 @@
-# ComfyUI API Testing
-
-This directory contains tests for validating the ComfyUI OpenAPI specification against a running instance of ComfyUI.
-
-## Setup
-
-1. Install the required dependencies:
-
-```bash
-pip install -r requirements.txt
-```
-
-2. Make sure you have a running instance of ComfyUI (default: http://127.0.0.1:8188)
-
-## Running the Tests
-
-Run all tests with pytest:
-
-```bash
-cd tests-api
-pytest
-```
-
-Run specific test files:
-
-```bash
-pytest test_spec_validation.py
-pytest test_endpoint_existence.py
-pytest test_schema_validation.py
-pytest test_api_by_tag.py
-```
-
-Run tests with more verbose output:
-
-```bash
-pytest -v
-```
-
-## Test Categories
-
-The tests are organized into several categories:
-
-1. **Spec Validation**: Validates that the OpenAPI specification is valid.
-2. **Endpoint Existence**: Tests that the endpoints defined in the spec exist on the server.
-3. **Schema Validation**: Tests that the server responses match the schemas defined in the spec.
-4. **Tag-Based Tests**: Tests that the API's tag organization is consistent.
-
-## Using a Different Server
-
-By default, the tests connect to `http://127.0.0.1:8188`. To test against a different server, set the `COMFYUI_SERVER_URL` environment variable:
-
-```bash
-COMFYUI_SERVER_URL=http://example.com:8188 pytest
-```
-
-## Test Structure
-
-- `conftest.py`: Contains pytest fixtures used by the tests.
-- `utils/`: Contains utility functions for working with the OpenAPI spec.
-- `test_*.py`: The actual test files.
-- `resources/`: Contains resources used by the tests (e.g., sample workflows).
-
-## Extending the Tests
-
-To add new tests:
-
-1. For testing new endpoints, add them to the appropriate test file based on their category.
-2. For testing more complex functionality, create a new test file following the established patterns.
-
-## Notes
-
-- Tests that require a running server will be skipped if the server is not available.
-- Some tests may fail if the server doesn't match the specification exactly.
-- The tests don't modify any data on the server (they're read-only).
tests-api/conftest.py (file deleted)
@@ -1,141 +0,0 @@
-"""
-Test fixtures for API testing
-"""
-import os
-import pytest
-import yaml
-import requests
-import logging
-from typing import Dict, Any, Generator, Optional
-from urllib.parse import urljoin
-
-# Setup logging
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
-# Default server configuration
-DEFAULT_SERVER_URL = "http://127.0.0.1:8188"
-
-
-@pytest.fixture(scope="session")
-def api_spec_path() -> str:
-    """
-    Get the path to the OpenAPI specification file
-
-    Returns:
-        Path to the OpenAPI specification file
-    """
-    return os.path.abspath(os.path.join(
-        os.path.dirname(__file__),
-        "..",
-        "openapi.yaml"
-    ))
-
-
-@pytest.fixture(scope="session")
-def api_spec(api_spec_path: str) -> Dict[str, Any]:
-    """
-    Load the OpenAPI specification
-
-    Args:
-        api_spec_path: Path to the spec file
-
-    Returns:
-        Parsed OpenAPI specification
-    """
-    with open(api_spec_path, 'r') as f:
-        return yaml.safe_load(f)
-
-
-@pytest.fixture(scope="session")
-def base_url() -> str:
-    """
-    Get the base URL for the API server
-
-    Returns:
-        Base URL string
-    """
-    # Allow overriding via environment variable
-    return os.environ.get("COMFYUI_SERVER_URL", DEFAULT_SERVER_URL)
-
-
-@pytest.fixture(scope="session")
-def server_available(base_url: str) -> bool:
-    """
-    Check if the server is available
-
-    Args:
-        base_url: Base URL for the API
-
-    Returns:
-        True if the server is available, False otherwise
-    """
-    try:
-        response = requests.get(base_url, timeout=2)
-        return response.status_code == 200
-    except requests.RequestException:
-        logger.warning(f"Server at {base_url} is not available")
-        return False
-
-
-@pytest.fixture
-def api_client(base_url: str) -> Generator[Optional[requests.Session], None, None]:
-    """
-    Create a requests session for API testing
-
-    Args:
-        base_url: Base URL for the API
-
-    Yields:
-        Requests session configured for the API
-    """
-    session = requests.Session()
-
-    # Helper function to construct URLs
-    def get_url(path: str) -> str:
-        return urljoin(base_url, path)
-
-    # Add url helper to the session
-    session.get_url = get_url  # type: ignore
-
-    yield session
-
-    # Cleanup
-    session.close()
-
-
-@pytest.fixture
-def api_get_json(api_client: requests.Session):
-    """
-    Helper fixture for making GET requests and parsing JSON responses
-
-    Args:
-        api_client: API client session
-
-    Returns:
-        Function that makes GET requests and returns JSON
-    """
-    def _get_json(path: str, **kwargs):
-        url = api_client.get_url(path)  # type: ignore
-        response = api_client.get(url, **kwargs)
-
-        if response.status_code == 200:
-            try:
-                return response.json()
-            except ValueError:
-                return None
-        return None
-
-    return _get_json
-
-
-@pytest.fixture
-def require_server(server_available):
-    """
-    Skip tests if server is not available
-
-    Args:
-        server_available: Whether the server is available
-    """
-    if not server_available:
-        pytest.skip("Server is not available")
@@ -1,6 +0,0 @@
pytest>=7.0.0
pytest-asyncio>=0.21.0
openapi-spec-validator>=0.5.0
jsonschema>=4.17.0
requests>=2.28.0
pyyaml>=6.0.0
@@ -1,279 +0,0 @@
"""
Tests for API endpoints grouped by tags
"""
import pytest
import logging
import sys
import os
from typing import Dict, Any, Set

# Use a direct import with the full path
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, current_dir)

# Define functions inline to avoid import issues
def get_all_endpoints(spec):
    """
    Extract all endpoints from an OpenAPI spec
    """
    endpoints = []

    for path, path_item in spec['paths'].items():
        for method, operation in path_item.items():
            if method.lower() not in ['get', 'post', 'put', 'delete', 'patch']:
                continue

            endpoints.append({
                'path': path,
                'method': method.lower(),
                'tags': operation.get('tags', []),
                'operation_id': operation.get('operationId', ''),
                'summary': operation.get('summary', '')
            })

    return endpoints

def get_all_tags(spec):
    """
    Get all tags used in the API spec
    """
    tags = set()

    for path_item in spec['paths'].values():
        for operation in path_item.values():
            if isinstance(operation, dict) and 'tags' in operation:
                tags.update(operation['tags'])

    return tags

def extract_endpoints_by_tag(spec, tag):
    """
    Extract all endpoints with a specific tag
    """
    endpoints = []

    for path, path_item in spec['paths'].items():
        for method, operation in path_item.items():
            if method.lower() not in ['get', 'post', 'put', 'delete', 'patch']:
                continue

            if tag in operation.get('tags', []):
                endpoints.append({
                    'path': path,
                    'method': method.lower(),
                    'operation_id': operation.get('operationId', ''),
                    'summary': operation.get('summary', '')
                })

    return endpoints

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@pytest.fixture
def api_tags(api_spec: Dict[str, Any]) -> Set[str]:
    """
    Get all tags from the API spec

    Args:
        api_spec: Loaded OpenAPI spec

    Returns:
        Set of tag names
    """
    return get_all_tags(api_spec)


def test_api_has_tags(api_tags: Set[str]):
    """
    Test that the API has defined tags

    Args:
        api_tags: Set of tags
    """
    assert len(api_tags) > 0, "API spec should have at least one tag"

    # Log the tags
    logger.info(f"API spec has the following tags: {sorted(api_tags)}")


@pytest.mark.parametrize("tag", [
    "workflow",
    "image",
    "model",
    "node",
    "system"
])
def test_core_tags_exist(api_tags: Set[str], tag: str):
    """
    Test that core tags exist in the API spec

    Args:
        api_tags: Set of tags
        tag: Tag to check
    """
    assert tag in api_tags, f"API spec should have '{tag}' tag"


def test_workflow_tag_has_endpoints(api_spec: Dict[str, Any]):
    """
    Test that the 'workflow' tag has appropriate endpoints

    Args:
        api_spec: Loaded OpenAPI spec
    """
    endpoints = extract_endpoints_by_tag(api_spec, "workflow")

    assert len(endpoints) > 0, "No endpoints found with 'workflow' tag"

    # Check for key workflow endpoints
    endpoint_paths = [e["path"] for e in endpoints]
    assert "/prompt" in endpoint_paths, "Workflow tag should include /prompt endpoint"

    # Log the endpoints
    logger.info(f"Found {len(endpoints)} endpoints with 'workflow' tag:")
    for e in endpoints:
        logger.info(f"  {e['method'].upper()} {e['path']}")


def test_image_tag_has_endpoints(api_spec: Dict[str, Any]):
    """
    Test that the 'image' tag has appropriate endpoints

    Args:
        api_spec: Loaded OpenAPI spec
    """
    endpoints = extract_endpoints_by_tag(api_spec, "image")

    assert len(endpoints) > 0, "No endpoints found with 'image' tag"

    # Check for key image endpoints
    endpoint_paths = [e["path"] for e in endpoints]
    assert "/upload/image" in endpoint_paths, "Image tag should include /upload/image endpoint"
    assert "/view" in endpoint_paths, "Image tag should include /view endpoint"

    # Log the endpoints
    logger.info(f"Found {len(endpoints)} endpoints with 'image' tag:")
    for e in endpoints:
        logger.info(f"  {e['method'].upper()} {e['path']}")


def test_model_tag_has_endpoints(api_spec: Dict[str, Any]):
    """
    Test that the 'model' tag has appropriate endpoints

    Args:
        api_spec: Loaded OpenAPI spec
    """
    endpoints = extract_endpoints_by_tag(api_spec, "model")

    assert len(endpoints) > 0, "No endpoints found with 'model' tag"

    # Check for key model endpoints
    endpoint_paths = [e["path"] for e in endpoints]
    assert "/models" in endpoint_paths, "Model tag should include /models endpoint"

    # Log the endpoints
    logger.info(f"Found {len(endpoints)} endpoints with 'model' tag:")
    for e in endpoints:
        logger.info(f"  {e['method'].upper()} {e['path']}")


def test_node_tag_has_endpoints(api_spec: Dict[str, Any]):
    """
    Test that the 'node' tag has appropriate endpoints

    Args:
        api_spec: Loaded OpenAPI spec
    """
    endpoints = extract_endpoints_by_tag(api_spec, "node")

    assert len(endpoints) > 0, "No endpoints found with 'node' tag"

    # Check for key node endpoints
    endpoint_paths = [e["path"] for e in endpoints]
    assert "/object_info" in endpoint_paths, "Node tag should include /object_info endpoint"

    # Log the endpoints
    logger.info(f"Found {len(endpoints)} endpoints with 'node' tag:")
    for e in endpoints:
        logger.info(f"  {e['method'].upper()} {e['path']}")


def test_system_tag_has_endpoints(api_spec: Dict[str, Any]):
    """
    Test that the 'system' tag has appropriate endpoints

    Args:
        api_spec: Loaded OpenAPI spec
    """
    endpoints = extract_endpoints_by_tag(api_spec, "system")

    assert len(endpoints) > 0, "No endpoints found with 'system' tag"

    # Check for key system endpoints
    endpoint_paths = [e["path"] for e in endpoints]
    assert "/system_stats" in endpoint_paths, "System tag should include /system_stats endpoint"

    # Log the endpoints
    logger.info(f"Found {len(endpoints)} endpoints with 'system' tag:")
    for e in endpoints:
        logger.info(f"  {e['method'].upper()} {e['path']}")


def test_internal_tag_has_endpoints(api_spec: Dict[str, Any]):
    """
    Test that the 'internal' tag has appropriate endpoints

    Args:
        api_spec: Loaded OpenAPI spec
    """
    endpoints = extract_endpoints_by_tag(api_spec, "internal")

    assert len(endpoints) > 0, "No endpoints found with 'internal' tag"

    # Check for key internal endpoints
    endpoint_paths = [e["path"] for e in endpoints]
    assert "/internal/logs" in endpoint_paths, "Internal tag should include /internal/logs endpoint"

    # Log the endpoints
    logger.info(f"Found {len(endpoints)} endpoints with 'internal' tag:")
    for e in endpoints:
        logger.info(f"  {e['method'].upper()} {e['path']}")


def test_operation_ids_match_tag(api_spec: Dict[str, Any]):
    """
    Test that operation IDs follow a consistent pattern with their tag

    Args:
        api_spec: Loaded OpenAPI spec
    """
    failures = []

    for path, path_item in api_spec['paths'].items():
        for method, operation in path_item.items():
            if method in ['get', 'post', 'put', 'delete', 'patch']:
                if 'operationId' in operation and 'tags' in operation and operation['tags']:
                    op_id = operation['operationId']
                    primary_tag = operation['tags'][0].lower()

                    # Check if operationId starts with primary tag prefix
                    # This is a common convention, but might need adjusting
                    if not (op_id.startswith(primary_tag) or
                            any(op_id.lower().startswith(f"{tag.lower()}") for tag in operation['tags'])):
                        failures.append({
                            'path': path,
                            'method': method,
                            'operationId': op_id,
                            'primary_tag': primary_tag
                        })

    # Log failures for diagnosis but don't fail the test
    # as this is a style/convention check
    if failures:
        logger.warning(f"Found {len(failures)} operationIds that don't align with their tags:")
        for f in failures:
            logger.warning(f"  {f['method'].upper()} {f['path']} - operationId: {f['operationId']}, primary tag: {f['primary_tag']}")
@@ -1,240 +0,0 @@
"""
Tests for endpoint existence and basic response codes
"""
import pytest
import requests
import logging
import sys
import os
from typing import Dict, Any, List
from urllib.parse import urljoin

# Use a direct import with the full path
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, current_dir)

# Define get_all_endpoints function inline to avoid import issues
def get_all_endpoints(spec):
    """
    Extract all endpoints from an OpenAPI spec

    Args:
        spec: Parsed OpenAPI specification

    Returns:
        List of dicts with path, method, and tags for each endpoint
    """
    endpoints = []

    for path, path_item in spec['paths'].items():
        for method, operation in path_item.items():
            if method.lower() not in ['get', 'post', 'put', 'delete', 'patch']:
                continue

            endpoints.append({
                'path': path,
                'method': method.lower(),
                'tags': operation.get('tags', []),
                'operation_id': operation.get('operationId', ''),
                'summary': operation.get('summary', '')
            })

    return endpoints

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@pytest.fixture
def all_endpoints(api_spec: Dict[str, Any]) -> List[Dict[str, Any]]:
    """
    Get all endpoints from the API spec

    Args:
        api_spec: Loaded OpenAPI spec

    Returns:
        List of endpoint information
    """
    return get_all_endpoints(api_spec)


def test_endpoints_exist(all_endpoints: List[Dict[str, Any]]):
    """
    Test that endpoints are defined in the spec

    Args:
        all_endpoints: List of endpoint information
    """
    # Simple check that we have endpoints defined
    assert len(all_endpoints) > 0, "No endpoints defined in the OpenAPI spec"

    # Log the endpoints for informational purposes
    logger.info(f"Found {len(all_endpoints)} endpoints in the OpenAPI spec")
    for endpoint in all_endpoints:
        logger.info(f"{endpoint['method'].upper()} {endpoint['path']} - {endpoint['summary']}")


@pytest.mark.parametrize("endpoint_path", [
    "/",  # Root path
    "/prompt",  # Get prompt info
    "/queue",  # Get queue
    "/models",  # Get model types
    "/object_info",  # Get node info
    "/system_stats"  # Get system stats
])
def test_basic_get_endpoints(require_server, api_client, endpoint_path: str):
    """
    Test that basic GET endpoints exist and respond

    Args:
        require_server: Fixture that skips if server is not available
        api_client: API client fixture
        endpoint_path: Path to test
    """
    url = api_client.get_url(endpoint_path)  # type: ignore

    try:
        response = api_client.get(url)

        # We're just checking that the endpoint exists and returns some kind of response
        # Not necessarily a 200 status code
        assert response.status_code not in [404, 405], f"Endpoint {endpoint_path} does not exist"

        logger.info(f"Endpoint {endpoint_path} exists with status code {response.status_code}")

    except requests.RequestException as e:
        pytest.fail(f"Request to {endpoint_path} failed: {str(e)}")


def test_websocket_endpoint_exists(require_server, base_url: str):
    """
    Test that the WebSocket endpoint exists

    Args:
        require_server: Fixture that skips if server is not available
        base_url: Base server URL
    """
    ws_url = urljoin(base_url, "/ws")

    # For WebSocket, we can't use a normal GET request
    # Instead, we make a HEAD request to check if the endpoint exists
    try:
        response = requests.head(ws_url)

        # WebSocket endpoints often return a 400 Bad Request for HEAD requests
        # but a 404 would indicate the endpoint doesn't exist
        assert response.status_code != 404, "WebSocket endpoint /ws does not exist"

        logger.info(f"WebSocket endpoint exists with status code {response.status_code}")

    except requests.RequestException as e:
        pytest.fail(f"Request to WebSocket endpoint failed: {str(e)}")


def test_api_models_folder_endpoint(require_server, api_client):
    """
    Test that the /models/{folder} endpoint exists and responds

    Args:
        require_server: Fixture that skips if server is not available
        api_client: API client fixture
    """
    # First get available model types
    models_url = api_client.get_url("/models")  # type: ignore

    try:
        models_response = api_client.get(models_url)
        assert models_response.status_code == 200, "Failed to get model types"

        model_types = models_response.json()

        # Skip if no model types available
        if not model_types:
            pytest.skip("No model types available to test")

        # Test with the first model type
        model_type = model_types[0]
        models_folder_url = api_client.get_url(f"/models/{model_type}")  # type: ignore

        folder_response = api_client.get(models_folder_url)

        # We're just checking that the endpoint exists
        assert folder_response.status_code != 404, f"Endpoint /models/{model_type} does not exist"

        logger.info(f"Endpoint /models/{model_type} exists with status code {folder_response.status_code}")

    except requests.RequestException as e:
        pytest.fail(f"Request failed: {str(e)}")
    except (ValueError, KeyError, IndexError) as e:
        pytest.fail(f"Failed to process response: {str(e)}")


def test_api_object_info_node_endpoint(require_server, api_client):
    """
    Test that the /object_info/{node_class} endpoint exists and responds

    Args:
        require_server: Fixture that skips if server is not available
        api_client: API client fixture
    """
    # First get available node classes
    objects_url = api_client.get_url("/object_info")  # type: ignore

    try:
        objects_response = api_client.get(objects_url)
        assert objects_response.status_code == 200, "Failed to get object info"

        node_classes = objects_response.json()

        # Skip if no node classes available
        if not node_classes:
            pytest.skip("No node classes available to test")

        # Test with the first node class
        node_class = next(iter(node_classes.keys()))
        node_url = api_client.get_url(f"/object_info/{node_class}")  # type: ignore

        node_response = api_client.get(node_url)

        # We're just checking that the endpoint exists
        assert node_response.status_code != 404, f"Endpoint /object_info/{node_class} does not exist"

        logger.info(f"Endpoint /object_info/{node_class} exists with status code {node_response.status_code}")

    except requests.RequestException as e:
        pytest.fail(f"Request failed: {str(e)}")
    except (ValueError, KeyError, StopIteration) as e:
        pytest.fail(f"Failed to process response: {str(e)}")


def test_internal_endpoints_exist(require_server, api_client):
    """
    Test that internal endpoints exist

    Args:
        require_server: Fixture that skips if server is not available
        api_client: API client fixture
    """
    internal_endpoints = [
        "/internal/logs",
        "/internal/logs/raw",
        "/internal/folder_paths",
        "/internal/files/output"
    ]

    for endpoint in internal_endpoints:
        url = api_client.get_url(endpoint)  # type: ignore

        try:
            response = api_client.get(url)

            # We're just checking that the endpoint exists
            assert response.status_code != 404, f"Endpoint {endpoint} does not exist"

            logger.info(f"Endpoint {endpoint} exists with status code {response.status_code}")

        except requests.RequestException as e:
            logger.warning(f"Request to {endpoint} failed: {str(e)}")
            # Don't fail the test as internal endpoints might be restricted
@@ -1,440 +0,0 @@
"""
Tests for validating API responses against OpenAPI schema
"""
import pytest
import requests
import logging
import sys
import os
import json
from typing import Dict, Any

# Use a direct import with the full path
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, current_dir)

# Define validation functions inline to avoid import issues
def get_endpoint_schema(
    spec,
    path,
    method,
    status_code='200'
):
    """
    Extract response schema for a specific endpoint from OpenAPI spec
    """
    method = method.lower()

    # Handle path not found
    if path not in spec['paths']:
        return None

    # Handle method not found
    if method not in spec['paths'][path]:
        return None

    # Handle status code not found
    responses = spec['paths'][path][method].get('responses', {})
    if status_code not in responses:
        return None

    # Handle no content defined
    if 'content' not in responses[status_code]:
        return None

    # Get schema from first content type
    content_types = responses[status_code]['content']
    first_content_type = next(iter(content_types))

    if 'schema' not in content_types[first_content_type]:
        return None

    return content_types[first_content_type]['schema']

def resolve_schema_refs(schema, spec):
    """
    Resolve $ref references in a schema
    """
    if not isinstance(schema, dict):
        return schema

    result = {}

    for key, value in schema.items():
        if key == '$ref' and isinstance(value, str) and value.startswith('#/'):
            # Handle reference
            ref_path = value[2:].split('/')
            ref_value = spec
            for path_part in ref_path:
                ref_value = ref_value.get(path_part, {})

            # Recursively resolve any refs in the referenced schema
            ref_value = resolve_schema_refs(ref_value, spec)
            result.update(ref_value)
        elif isinstance(value, dict):
            # Recursively resolve refs in nested dictionaries
            result[key] = resolve_schema_refs(value, spec)
        elif isinstance(value, list):
            # Recursively resolve refs in list items
            result[key] = [
                resolve_schema_refs(item, spec) if isinstance(item, dict) else item
                for item in value
            ]
        else:
            # Pass through other values
            result[key] = value

    return result

def validate_response(
    response_data,
    spec,
    path,
    method,
    status_code='200'
):
    """
    Validate a response against the OpenAPI schema
    """
    schema = get_endpoint_schema(spec, path, method, status_code)

    if schema is None:
        return {
            'valid': False,
            'errors': [f"No schema found for {method.upper()} {path} with status {status_code}"]
        }

    # Resolve any $ref in the schema
    resolved_schema = resolve_schema_refs(schema, spec)

    try:
        import jsonschema
        jsonschema.validate(instance=response_data, schema=resolved_schema)
        return {'valid': True, 'errors': []}
    except jsonschema.exceptions.ValidationError as e:
        # Extract more detailed error information
        path = ".".join(str(p) for p in e.path) if e.path else "root"
        instance = e.instance if not isinstance(e.instance, dict) else "..."
        schema_path = ".".join(str(p) for p in e.schema_path) if e.schema_path else "unknown"

        detailed_error = (
            f"Validation error at path: {path}\n"
            f"Schema path: {schema_path}\n"
            f"Error message: {e.message}\n"
            f"Failed instance: {instance}\n"
        )

        return {'valid': False, 'errors': [detailed_error]}

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@pytest.mark.parametrize("endpoint_path,method", [
    ("/system_stats", "get"),
    ("/prompt", "get"),
    ("/queue", "get"),
    ("/models", "get"),
    ("/embeddings", "get")
])
def test_response_schema_validation(
    require_server,
    api_client,
    api_spec: Dict[str, Any],
    endpoint_path: str,
    method: str
):
    """
    Test that API responses match the defined schema

    Args:
        require_server: Fixture that skips if server is not available
        api_client: API client fixture
        api_spec: Loaded OpenAPI spec
        endpoint_path: Path to test
        method: HTTP method to test
    """
    url = api_client.get_url(endpoint_path)  # type: ignore

    # Skip if no schema defined
    schema = get_endpoint_schema(api_spec, endpoint_path, method)
    if not schema:
        pytest.skip(f"No schema defined for {method.upper()} {endpoint_path}")

    try:
        if method.lower() == "get":
            response = api_client.get(url)
        else:
            pytest.skip(f"Method {method} not implemented for automated testing")
            return

        # Skip if response is not 200
        if response.status_code != 200:
            pytest.skip(f"Endpoint {endpoint_path} returned status {response.status_code}")
            return

        # Skip if response is not JSON
        try:
            response_data = response.json()
        except ValueError:
            pytest.skip(f"Endpoint {endpoint_path} did not return valid JSON")
            return

        # Validate the response
        validation_result = validate_response(
            response_data,
            api_spec,
            endpoint_path,
            method
        )

        if validation_result['valid']:
            logger.info(f"Response from {method.upper()} {endpoint_path} matches schema")
        else:
            for error in validation_result['errors']:
                logger.error(f"Validation error for {method.upper()} {endpoint_path}: {error}")

        assert validation_result['valid'], f"Response from {method.upper()} {endpoint_path} does not match schema"

    except requests.RequestException as e:
        pytest.fail(f"Request to {endpoint_path} failed: {str(e)}")


def test_system_stats_response(require_server, api_client, api_spec: Dict[str, Any]):
    """
    Test the system_stats endpoint response in detail

    Args:
        require_server: Fixture that skips if server is not available
        api_client: API client fixture
        api_spec: Loaded OpenAPI spec
    """
    url = api_client.get_url("/system_stats")  # type: ignore

    try:
        response = api_client.get(url)

        assert response.status_code == 200, "Failed to get system stats"

        # Parse response
        stats = response.json()

        # Validate high-level structure
        assert 'system' in stats, "Response missing 'system' field"
        assert 'devices' in stats, "Response missing 'devices' field"

        # Validate system fields
        system = stats['system']
        assert 'os' in system, "System missing 'os' field"
        assert 'ram_total' in system, "System missing 'ram_total' field"
        assert 'ram_free' in system, "System missing 'ram_free' field"
        assert 'comfyui_version' in system, "System missing 'comfyui_version' field"

        # Validate devices fields
        devices = stats['devices']
        assert isinstance(devices, list), "Devices should be a list"

        if devices:
            device = devices[0]
            assert 'name' in device, "Device missing 'name' field"
            assert 'type' in device, "Device missing 'type' field"
            assert 'vram_total' in device, "Device missing 'vram_total' field"
            assert 'vram_free' in device, "Device missing 'vram_free' field"

        # Perform schema validation
        validation_result = validate_response(
            stats,
            api_spec,
            "/system_stats",
            "get"
        )

        # Print detailed error if validation fails
        if not validation_result['valid']:
            for error in validation_result['errors']:
                logger.error(f"Validation error for /system_stats: {error}")

            # Print schema details for debugging
            schema = get_endpoint_schema(api_spec, "/system_stats", "get")
            if schema:
                logger.error(f"Schema structure:\n{json.dumps(schema, indent=2)}")

            # Print sample of the response
            logger.error(f"Response:\n{json.dumps(stats, indent=2)}")

        assert validation_result['valid'], "System stats response does not match schema"

    except requests.RequestException as e:
        pytest.fail(f"Request to /system_stats failed: {str(e)}")


def test_models_listing_response(require_server, api_client, api_spec: Dict[str, Any]):
    """
    Test the models endpoint response

    Args:
        require_server: Fixture that skips if server is not available
        api_client: API client fixture
        api_spec: Loaded OpenAPI spec
    """
    url = api_client.get_url("/models")  # type: ignore

    try:
        response = api_client.get(url)

        assert response.status_code == 200, "Failed to get models"

        # Parse response
        models = response.json()

        # Validate it's a list
        assert isinstance(models, list), "Models response should be a list"

        # Each item should be a string
        for model in models:
            assert isinstance(model, str), "Each model type should be a string"

        # Perform schema validation
        validation_result = validate_response(
            models,
            api_spec,
            "/models",
            "get"
        )

        # Print detailed error if validation fails
        if not validation_result['valid']:
            for error in validation_result['errors']:
                logger.error(f"Validation error for /models: {error}")

            # Print schema details for debugging
            schema = get_endpoint_schema(api_spec, "/models", "get")
            if schema:
                logger.error(f"Schema structure:\n{json.dumps(schema, indent=2)}")

            # Print response
            sample_models = models[:5] if isinstance(models, list) else models
            logger.error(f"Models response:\n{json.dumps(sample_models, indent=2)}")

        assert validation_result['valid'], "Models response does not match schema"

    except requests.RequestException as e:
        pytest.fail(f"Request to /models failed: {str(e)}")


def test_object_info_response(require_server, api_client, api_spec: Dict[str, Any]):
    """
    Test the object_info endpoint response

    Args:
        require_server: Fixture that skips if server is not available
        api_client: API client fixture
        api_spec: Loaded OpenAPI spec
    """
    url = api_client.get_url("/object_info")  # type: ignore

    try:
        response = api_client.get(url)

        assert response.status_code == 200, "Failed to get object info"

        # Parse response
        objects = response.json()

        # Validate it's an object
        assert isinstance(objects, dict), "Object info response should be an object"

        # Check if we have any objects
        if objects:
            # Get the first object
            first_obj_name = next(iter(objects.keys()))
            first_obj = objects[first_obj_name]

            # Validate first object has required fields
            assert 'input' in first_obj, "Object missing 'input' field"
            assert 'output' in first_obj, "Object missing 'output' field"
            assert 'name' in first_obj, "Object missing 'name' field"

        # Perform schema validation
        validation_result = validate_response(
            objects,
            api_spec,
            "/object_info",
            "get"
        )

        # Print detailed error if validation fails
        if not validation_result['valid']:
            for error in validation_result['errors']:
                logger.error(f"Validation error for /object_info: {error}")

            # Print schema details for debugging
            schema = get_endpoint_schema(api_spec, "/object_info", "get")
            if schema:
                logger.error(f"Schema structure:\n{json.dumps(schema, indent=2)}")

            # Also print a small sample of the response
            sample = dict(list(objects.items())[:1]) if objects else {}
            logger.error(f"Sample response:\n{json.dumps(sample, indent=2)}")

        assert validation_result['valid'], "Object info response does not match schema"

    except requests.RequestException as e:
        pytest.fail(f"Request to /object_info failed: {str(e)}")
    except (KeyError, StopIteration) as e:
        pytest.fail(f"Failed to process response: {str(e)}")


def test_queue_response(require_server, api_client, api_spec: Dict[str, Any]):
    """
    Test the queue endpoint response

    Args:
        require_server: Fixture that skips if server is not available
        api_client: API client fixture
        api_spec: Loaded OpenAPI spec
    """
    url = api_client.get_url("/queue")  # type: ignore

    try:
        response = api_client.get(url)

        assert response.status_code == 200, "Failed to get queue"

        # Parse response
        queue = response.json()

        # Validate structure
        assert 'queue_running' in queue, "Queue missing 'queue_running' field"
        assert 'queue_pending' in queue, "Queue missing 'queue_pending' field"

        # Each should be a list
        assert isinstance(queue['queue_running'], list), "queue_running should be a list"
        assert isinstance(queue['queue_pending'], list), "queue_pending should be a list"

        # Perform schema validation
        validation_result = validate_response(
            queue,
            api_spec,
            "/queue",
            "get"
        )

        # Print detailed error if validation fails
        if not validation_result['valid']:
            for error in validation_result['errors']:
                logger.error(f"Validation error for /queue: {error}")

            # Print schema details for debugging
            schema = get_endpoint_schema(api_spec, "/queue", "get")
            if schema:
                logger.error(f"Schema structure:\n{json.dumps(schema, indent=2)}")

            # Print response
            logger.error(f"Queue response:\n{json.dumps(queue, indent=2)}")

        assert validation_result['valid'], "Queue response does not match schema"

    except requests.RequestException as e:
        pytest.fail(f"Request to /queue failed: {str(e)}")
@@ -1,144 +0,0 @@
"""
Tests for validating the OpenAPI specification
"""
import pytest
from openapi_spec_validator import validate_spec
from openapi_spec_validator.exceptions import OpenAPISpecValidatorError
from typing import Dict, Any


def test_openapi_spec_is_valid(api_spec: Dict[str, Any]):
    """
    Test that the OpenAPI specification is valid

    Args:
        api_spec: Loaded OpenAPI spec
    """
    try:
        validate_spec(api_spec)
    except OpenAPISpecValidatorError as e:
        pytest.fail(f"OpenAPI spec validation failed: {str(e)}")


def test_spec_has_info(api_spec: Dict[str, Any]):
    """
    Test that the OpenAPI spec has the required info section

    Args:
        api_spec: Loaded OpenAPI spec
    """
    assert 'info' in api_spec, "Spec must have info section"
    assert 'title' in api_spec['info'], "Info must have title"
    assert 'version' in api_spec['info'], "Info must have version"


def test_spec_has_paths(api_spec: Dict[str, Any]):
    """
    Test that the OpenAPI spec has paths defined

    Args:
        api_spec: Loaded OpenAPI spec
    """
    assert 'paths' in api_spec, "Spec must have paths section"
    assert len(api_spec['paths']) > 0, "Spec must have at least one path"


def test_spec_has_components(api_spec: Dict[str, Any]):
    """
    Test that the OpenAPI spec has components defined

    Args:
        api_spec: Loaded OpenAPI spec
    """
    assert 'components' in api_spec, "Spec must have components section"
    assert 'schemas' in api_spec['components'], "Components must have schemas"


def test_workflow_endpoints_exist(api_spec: Dict[str, Any]):
    """
    Test that core workflow endpoints are defined

    Args:
        api_spec: Loaded OpenAPI spec
    """
    assert '/prompt' in api_spec['paths'], "Spec must define /prompt endpoint"
    assert 'post' in api_spec['paths']['/prompt'], "Spec must define POST /prompt"
    assert 'get' in api_spec['paths']['/prompt'], "Spec must define GET /prompt"


def test_image_endpoints_exist(api_spec: Dict[str, Any]):
    """
    Test that core image endpoints are defined

    Args:
        api_spec: Loaded OpenAPI spec
    """
    assert '/upload/image' in api_spec['paths'], "Spec must define /upload/image endpoint"
    assert '/view' in api_spec['paths'], "Spec must define /view endpoint"


def test_model_endpoints_exist(api_spec: Dict[str, Any]):
    """
    Test that core model endpoints are defined

    Args:
        api_spec: Loaded OpenAPI spec
    """
    assert '/models' in api_spec['paths'], "Spec must define /models endpoint"
    assert '/models/{folder}' in api_spec['paths'], "Spec must define /models/{folder} endpoint"


def test_operation_ids_are_unique(api_spec: Dict[str, Any]):
    """
    Test that all operationIds are unique

    Args:
        api_spec: Loaded OpenAPI spec
    """
    operation_ids = []

    for path, path_item in api_spec['paths'].items():
        for method, operation in path_item.items():
            if method in ['get', 'post', 'put', 'delete', 'patch']:
                if 'operationId' in operation:
                    operation_ids.append(operation['operationId'])

    # Check for duplicates
    duplicates = set([op_id for op_id in operation_ids if operation_ids.count(op_id) > 1])
    assert len(duplicates) == 0, f"Found duplicate operationIds: {duplicates}"


def test_all_endpoints_have_operation_ids(api_spec: Dict[str, Any]):
    """
    Test that all endpoints have operationIds

    Args:
        api_spec: Loaded OpenAPI spec
    """
    missing = []

    for path, path_item in api_spec['paths'].items():
        for method, operation in path_item.items():
            if method in ['get', 'post', 'put', 'delete', 'patch']:
                if 'operationId' not in operation:
                    missing.append(f"{method.upper()} {path}")

    assert len(missing) == 0, f"Found endpoints without operationIds: {missing}"


def test_all_endpoints_have_tags(api_spec: Dict[str, Any]):
    """
    Test that all endpoints have tags

    Args:
        api_spec: Loaded OpenAPI spec
    """
    missing = []

    for path, path_item in api_spec['paths'].items():
        for method, operation in path_item.items():
            if method in ['get', 'post', 'put', 'delete', 'patch']:
                if 'tags' not in operation or not operation['tags']:
                    missing.append(f"{method.upper()} {path}")

    assert len(missing) == 0, f"Found endpoints without tags: {missing}"
@@ -1,157 +0,0 @@
"""
Utilities for working with OpenAPI schemas
"""
from typing import Any, Dict, List, Optional, Set, Tuple


def extract_required_parameters(
    spec: Dict[str, Any],
    path: str,
    method: str
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """
    Extract required parameters for a specific endpoint

    Args:
        spec: Parsed OpenAPI specification
        path: API path (e.g., '/prompt')
        method: HTTP method (e.g., 'get', 'post')

    Returns:
        Tuple of (path_params, query_params) containing required parameters
    """
    method = method.lower()
    path_params = []
    query_params = []

    # Handle path not found
    if path not in spec['paths']:
        return path_params, query_params

    # Handle method not found
    if method not in spec['paths'][path]:
        return path_params, query_params

    # Get parameters
    params = spec['paths'][path][method].get('parameters', [])

    for param in params:
        if param.get('required', False):
            if param.get('in') == 'path':
                path_params.append(param)
            elif param.get('in') == 'query':
                query_params.append(param)

    return path_params, query_params


def get_request_body_schema(
    spec: Dict[str, Any],
    path: str,
    method: str
) -> Optional[Dict[str, Any]]:
    """
    Get request body schema for a specific endpoint

    Args:
        spec: Parsed OpenAPI specification
        path: API path (e.g., '/prompt')
        method: HTTP method (e.g., 'get', 'post')

    Returns:
        Request body schema or None if not found
    """
    method = method.lower()

    # Handle path not found
    if path not in spec['paths']:
        return None

    # Handle method not found
    if method not in spec['paths'][path]:
        return None

    # Handle no request body
    request_body = spec['paths'][path][method].get('requestBody', {})
    if not request_body or 'content' not in request_body:
        return None

    # Get schema from first content type
    content_types = request_body['content']
    first_content_type = next(iter(content_types))

    if 'schema' not in content_types[first_content_type]:
        return None

    return content_types[first_content_type]['schema']


def extract_endpoints_by_tag(spec: Dict[str, Any], tag: str) -> List[Dict[str, Any]]:
    """
    Extract all endpoints with a specific tag

    Args:
        spec: Parsed OpenAPI specification
        tag: Tag to filter by

    Returns:
        List of endpoint details
    """
    endpoints = []

    for path, path_item in spec['paths'].items():
        for method, operation in path_item.items():
            if method.lower() not in ['get', 'post', 'put', 'delete', 'patch']:
                continue

            if tag in operation.get('tags', []):
                endpoints.append({
                    'path': path,
                    'method': method.lower(),
                    'operation_id': operation.get('operationId', ''),
                    'summary': operation.get('summary', '')
                })

    return endpoints


def get_all_tags(spec: Dict[str, Any]) -> Set[str]:
    """
    Get all tags used in the API spec

    Args:
        spec: Parsed OpenAPI specification

    Returns:
        Set of tag names
    """
    tags = set()

    for path_item in spec['paths'].values():
        for operation in path_item.values():
            if isinstance(operation, dict) and 'tags' in operation:
                tags.update(operation['tags'])

    return tags


def get_schema_examples(spec: Dict[str, Any]) -> Dict[str, Any]:
    """
    Extract all examples from component schemas

    Args:
        spec: Parsed OpenAPI specification

    Returns:
        Dict mapping schema names to examples
    """
    examples = {}

    if 'components' not in spec or 'schemas' not in spec['components']:
        return examples

    for name, schema in spec['components']['schemas'].items():
        if 'example' in schema:
            examples[name] = schema['example']

    return examples
@@ -1,178 +0,0 @@
"""
Utilities for API response validation against OpenAPI spec
"""
import yaml
import jsonschema
from typing import Any, Dict, List, Optional, Union


def load_openapi_spec(spec_path: str) -> Dict[str, Any]:
    """
    Load the OpenAPI specification from a YAML file

    Args:
        spec_path: Path to the OpenAPI specification file

    Returns:
        Dict containing the parsed OpenAPI spec
    """
    with open(spec_path, 'r') as f:
        return yaml.safe_load(f)


def get_endpoint_schema(
    spec: Dict[str, Any],
    path: str,
    method: str,
    status_code: str = '200'
) -> Optional[Dict[str, Any]]:
    """
    Extract response schema for a specific endpoint from OpenAPI spec

    Args:
        spec: Parsed OpenAPI specification
        path: API path (e.g., '/prompt')
        method: HTTP method (e.g., 'get', 'post')
        status_code: HTTP status code to get schema for

    Returns:
        Schema dict or None if not found
    """
    method = method.lower()

    # Handle path not found
    if path not in spec['paths']:
        return None

    # Handle method not found
    if method not in spec['paths'][path]:
        return None

    # Handle status code not found
    responses = spec['paths'][path][method].get('responses', {})
    if status_code not in responses:
        return None

    # Handle no content defined
    if 'content' not in responses[status_code]:
        return None

    # Get schema from first content type
    content_types = responses[status_code]['content']
    first_content_type = next(iter(content_types))

    if 'schema' not in content_types[first_content_type]:
        return None

    return content_types[first_content_type]['schema']


def resolve_schema_refs(schema: Dict[str, Any], spec: Dict[str, Any]) -> Dict[str, Any]:
    """
    Resolve $ref references in a schema

    Args:
        schema: Schema that may contain references
        spec: Full OpenAPI spec with component definitions

    Returns:
        Schema with references resolved
    """
    if not isinstance(schema, dict):
        return schema

    result = {}

    for key, value in schema.items():
        if key == '$ref' and isinstance(value, str) and value.startswith('#/'):
            # Handle reference
            ref_path = value[2:].split('/')
            ref_value = spec
            for path_part in ref_path:
                ref_value = ref_value.get(path_part, {})

            # Recursively resolve any refs in the referenced schema
            ref_value = resolve_schema_refs(ref_value, spec)
            result.update(ref_value)
        elif isinstance(value, dict):
            # Recursively resolve refs in nested dictionaries
            result[key] = resolve_schema_refs(value, spec)
        elif isinstance(value, list):
            # Recursively resolve refs in list items
            result[key] = [
                resolve_schema_refs(item, spec) if isinstance(item, dict) else item
                for item in value
            ]
        else:
            # Pass through other values
            result[key] = value

    return result


def validate_response(
    response_data: Union[Dict[str, Any], List[Any]],
    spec: Dict[str, Any],
    path: str,
    method: str,
    status_code: str = '200'
) -> Dict[str, Any]:
    """
    Validate a response against the OpenAPI schema

    Args:
        response_data: Response data to validate
        spec: Parsed OpenAPI specification
        path: API path (e.g., '/prompt')
        method: HTTP method (e.g., 'get', 'post')
        status_code: HTTP status code to validate against

    Returns:
        Dict with validation result containing:
        - valid: bool indicating if validation passed
        - errors: List of validation errors if any
    """
    schema = get_endpoint_schema(spec, path, method, status_code)

    if schema is None:
        return {
            'valid': False,
            'errors': [f"No schema found for {method.upper()} {path} with status {status_code}"]
        }

    # Resolve any $ref in the schema
    resolved_schema = resolve_schema_refs(schema, spec)

    try:
        jsonschema.validate(instance=response_data, schema=resolved_schema)
        return {'valid': True, 'errors': []}
    except jsonschema.exceptions.ValidationError as e:
        return {'valid': False, 'errors': [str(e)]}


def get_all_endpoints(spec: Dict[str, Any]) -> List[Dict[str, Any]]:
    """
    Extract all endpoints from an OpenAPI spec

    Args:
        spec: Parsed OpenAPI specification

    Returns:
        List of dicts with path, method, and tags for each endpoint
    """
    endpoints = []

    for path, path_item in spec['paths'].items():
        for method, operation in path_item.items():
            if method.lower() not in ['get', 'post', 'put', 'delete', 'patch']:
                continue

            endpoints.append({
                'path': path,
                'method': method.lower(),
                'tags': operation.get('tags', []),
                'operation_id': operation.get('operationId', ''),
                'summary': operation.get('summary', '')
            })

    return endpoints